<a href="https://colab.research.google.com/github/Papa-Panda/industry_algo/blob/main/Deep_Interest_Network_(DIN)_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# https://gemini.google.com/app/72f602f1dcb0baa4

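# Background reading (Zhihu articles, written in Chinese):
# 屠龙少年与龙：漫谈深度学习驱动的广告推荐技术发展周期 - 朱小强的文章 - 知乎
#   (English gloss: "The Dragon-Slaying Youth and the Dragon: notes on the development
#    cycle of deep-learning-driven ad recommendation technology", by Zhu Xiaoqiang)
# https://zhuanlan.zhihu.com/p/398041971
# 推荐系统中的注意力机制——阿里深度兴趣网络（DIN） - 王喆的文章 - 知乎
#   (English gloss: "Attention mechanisms in recommender systems: Alibaba's Deep
#    Interest Network (DIN)", by Wang Zhe)
# https://zhuanlan.zhihu.com/p/51623339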

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import pandas as pd

print(f"TensorFlow Version: {tf.__version__}")

# --- 1. Synthetic Data Generation ---
# This dataset simulates user interactions with items.
# Each user has a history of clicked items, and we'll predict if they click a new candidate item.

def generate_synthetic_data(num_samples=10000, num_users=1000, num_items=500, history_length=10, embedding_dim=16):
    """
    Generates synthetic data for a Deep Interest Network.

    Every value is drawn from NumPy's global random state, so call
    ``np.random.seed`` beforehand for reproducible output. Labels are
    independent coin flips and carry no signal about the other fields.

    Args:
        num_samples (int): Total number of data points (user-candidate pairs).
        num_users (int): Number of unique users.
        num_items (int): Number of unique items.
        history_length (int): Maximum length of a user's historical clicked items.
            Must be at least 1 (each sample gets between 1 and
            ``history_length`` real items).
        embedding_dim (int): Width of each simulated item feature vector.

    Returns:
        tuple: A tuple containing:
            - np.array: User IDs, shape (num_samples,), values in [1, num_users].
            - np.array: Candidate Item IDs, shape (num_samples,), values in [1, num_items].
            - np.array: 2D array of Historical Item IDs, shape
              (num_samples, history_length), right-padded with 0s.
            - np.array: Labels (1 if clicked, 0 otherwise), shape (num_samples,).
            - np.array: Item feature embeddings (simulated), shape
              (num_items + 1, embedding_dim); row 0 (the padding ID) is all zeros.
    """
    user_ids = np.random.randint(1, num_users + 1, num_samples)  # User IDs start from 1
    candidate_item_ids = np.random.randint(1, num_items + 1, num_samples)  # Item IDs start from 1

    # Draw a full-width block of item IDs, then blank out everything past each
    # sample's own history length. Doing this on a 2-D array (rather than
    # appending Python lists) keeps the (num_samples, history_length) shape
    # even when num_samples is 0.
    history_lengths = np.random.randint(1, history_length + 1, num_samples)
    historical_item_ids = np.random.randint(1, num_items + 1, (num_samples, history_length))
    padding_positions = np.arange(history_length) >= history_lengths[:, np.newaxis]
    historical_item_ids[padding_positions] = 0

    # This is a very simple simulation of a click, based on random chance
    labels = np.random.randint(0, 2, num_samples)

    # Simulated item feature embeddings; +1 row because ID 0 is reserved for padding.
    item_features = np.random.rand(num_items + 1, embedding_dim)
    # The padding row must be all zeros: the attention layer in this file
    # recognises padded history slots by their all-zero embedding.
    item_features[0] = 0.0

    print(f"Generated synthetic data: {num_samples} samples.")
    print(f"  User IDs shape: {user_ids.shape}")
    print(f"  Candidate Item IDs shape: {candidate_item_ids.shape}")
    print(f"  Historical Item IDs shape: {historical_item_ids.shape}")
    print(f"  Labels shape: {labels.shape}")
    print(f"  Item Features shape: {item_features.shape}")

    return user_ids, candidate_item_ids, historical_item_ids, labels, item_features

# Problem-size settings shared by the data generator and the model builder
num_users = 1000
num_items = 500
history_length = 10  # Max length of user behavior sequence
embedding_dim = 16  # Dimensionality of item and user embeddings

# Draw the synthetic dataset (50k user/candidate pairs)
(
    user_ids_data,
    candidate_item_ids_data,
    historical_item_ids_data,
    labels_data,
    item_features_data,
) = generate_synthetic_data(
    num_samples=50000,
    num_users=num_users,
    num_items=num_items,
    history_length=history_length,
)

# --- 2. DIN Model Architecture ---

# Custom activation layer (Dice) used inside the DIN attention MLP
class Dice(layers.Layer):
    """
    Data Adaptive Activation Function (DICE) for DIN.

    A PReLU-like activation whose gate ``p`` depends on the input after it
    has been standardised per feature:

        x_normed = (x - mean) / sqrt(var + epsilon)
        p        = sigmoid(alphas * x_normed + beta)
        output   = p * x + (1 - p) * alphas * x

    ``mean`` and ``var`` are computed from the current batch over every axis
    except ``axis``. The same batch statistics are used for training and
    inference (no moving averages are kept).

    NOTE(review): the DIN paper appears to define the gate as
    ``sigmoid(x_normed)`` without the learnable ``alphas``/``beta`` affine
    term -- confirm whether this variant is intentional.

    Args:
        axis (int): Feature axis; one ``alphas``/``beta`` entry is learned per
            position along it. Defaults to the last axis.
        epsilon (float): Added to the variance for numerical stability.
    """
    def __init__(self, axis=-1, epsilon=1e-9, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis
        self.epsilon = epsilon

    def build(self, input_shape):
        # One parameter per feature, measured along the configured feature axis.
        feature_dim = input_shape[self.axis]
        self.alphas = self.add_weight(
            shape=(feature_dim,),
            initializer='zeros',
            trainable=True,
            name='dice_alpha'
        )
        self.beta = self.add_weight(
            shape=(feature_dim,),
            initializer='zeros',
            trainable=True,
            name='dice_beta'
        )
        super().build(input_shape)

    def call(self, inputs):
        rank = len(inputs.shape)
        feature_axis = self.axis % rank  # normalise negative axes such as -1

        # Reduce over every axis EXCEPT the feature axis so each feature gets
        # its own mean/variance. (Previously the default axis=-1 skipped this
        # exclusion and produced one scalar statistic for the whole tensor.)
        reduc_axis = [ax for ax in range(rank) if ax != feature_axis]
        mean = tf.reduce_mean(inputs, axis=reduc_axis, keepdims=True)
        variance = tf.reduce_mean(tf.square(inputs - mean), axis=reduc_axis, keepdims=True)

        # Normalize the input
        x_normed = (inputs - mean) / tf.sqrt(variance + self.epsilon)

        alphas, beta = self.alphas, self.beta
        if feature_axis != rank - 1:
            # The 1-D parameters only broadcast against the last axis by
            # default; reshape them so they line up with the feature axis.
            broadcast_shape = [1] * rank
            broadcast_shape[feature_axis] = -1
            alphas = tf.reshape(alphas, broadcast_shape)
            beta = tf.reshape(beta, broadcast_shape)

        # Calculate p (the gate of the PReLu-like activation)
        p = tf.sigmoid(alphas * x_normed + beta)

        # Apply DICE activation
        return p * inputs + (1 - p) * alphas * inputs

    def get_config(self):
        config = super().get_config()
        config.update({
            "axis": self.axis,
            "epsilon": self.epsilon,
        })
        return config

class AttentionPoolingLayer(layers.Layer):
    """
    Attention pooling layer for Deep Interest Network (DIN).

    Scores every historical item against the candidate item with a small MLP
    (Dense + Dice per hidden layer, then a 1-unit Dense), softmax-normalises
    the scores over the history axis, and returns the weighted sum of the
    historical item embeddings.

    Args:
        embedding_dim (int): Embedding width; stored for ``get_config`` only.
        hidden_units (sequence of int): Width of each hidden layer of the
            attention MLP.

    Call inputs (a list):
        - candidate_item_embedding: (batch_size, embedding_dim)
        - historical_item_embeddings: (batch_size, history_length, embedding_dim)
        - optional padding mask: (batch_size, history_length), non-zero / True
          for real items. When omitted, a history slot is treated as padding
          if its embedding is all zeros.

    Returns:
        Tensor of shape (batch_size, embedding_dim).
    """
    def __init__(self, embedding_dim, hidden_units=(80, 40), **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        # Copy into a fresh list: avoids sharing a mutable default between
        # instances and keeps get_config() JSON-friendly.
        self.hidden_units = list(hidden_units)

        # Attention network (MLP): Dense (no activation) followed by Dice
        self.dense_layers = []
        for units in self.hidden_units:
            self.dense_layers.append(layers.Dense(units, activation=None))
            self.dense_layers.append(Dice())

        self.output_layer = layers.Dense(1, activation=None) # Output attention score (scalar)

    def call(self, inputs):
        # Accept either [candidate, history] or [candidate, history, mask].
        if len(inputs) == 3:
            candidate_item_embedding, historical_item_embeddings, padding_mask = inputs
        else:
            candidate_item_embedding, historical_item_embeddings = inputs
            padding_mask = None

        # Repeat the candidate along the history axis:
        # (batch_size, embedding_dim) -> (batch_size, history_length, embedding_dim)
        candidate_item_embedding_expanded = tf.expand_dims(candidate_item_embedding, 1)
        candidate_item_embedding_tiled = tf.tile(
            candidate_item_embedding_expanded,
            [1, tf.shape(historical_item_embeddings)[1], 1]
        )

        # Candidate, history, their product and their difference, side by side:
        # (batch_size, history_length, embedding_dim * 4)
        concatenated_features = tf.concat([
            candidate_item_embedding_tiled,
            historical_item_embeddings,
            candidate_item_embedding_tiled * historical_item_embeddings,
            candidate_item_embedding_tiled - historical_item_embeddings
        ], axis=-1)

        # Pass through attention network
        attention_logits = concatenated_features
        for layer in self.dense_layers:
            attention_logits = layer(attention_logits)

        attention_logits = self.output_layer(attention_logits) # (batch_size, history_length, 1)

        # Build the validity mask, shape (batch_size, history_length, 1).
        if padding_mask is None:
            # Fallback heuristic: an all-zero embedding marks a padded slot.
            # This only works if the caller guarantees padding embeds to zeros.
            mask = tf.cast(
                tf.reduce_sum(tf.abs(historical_item_embeddings), axis=-1, keepdims=True) > 0,
                attention_logits.dtype
            )
        else:
            mask = tf.expand_dims(tf.cast(padding_mask, attention_logits.dtype), axis=-1)

        # Push padded logits to a large negative value before the softmax
        attention_logits = attention_logits - (1.0 - mask) * 1e9

        attention_weights = tf.nn.softmax(attention_logits, axis=1) # (batch_size, history_length, 1)
        # If every slot of a row is masked the softmax degenerates to uniform
        # weights; zero them so padded embeddings can never leak into the output.
        attention_weights = attention_weights * mask

        # Weighted sum over the history axis -> (batch_size, embedding_dim)
        weighted_history_embedding = tf.reduce_sum(attention_weights * historical_item_embeddings, axis=1)

        return weighted_history_embedding

    def get_config(self):
        config = super().get_config()
        config.update({
            "embedding_dim": self.embedding_dim,
            "hidden_units": self.hidden_units,
        })
        return config


class _ZeroPaddedHistory(layers.Layer):
    """Zeroes the embedding vectors of history slots whose item ID is 0 (padding)."""

    def call(self, inputs):
        # inputs: [embeddings, item_ids]; item_ids has one entry per history slot.
        embeddings, item_ids = inputs
        is_real_item = tf.cast(tf.not_equal(item_ids, 0), embeddings.dtype)
        return embeddings * tf.expand_dims(is_real_item, axis=-1)


def build_din_model(num_users, num_items, history_length, embedding_dim, item_features_matrix):
    """
    Builds the Deep Interest Network (DIN) model.

    Pipeline: embed the user ID, the candidate item ID and the item-ID
    history; pool the history with candidate-aware attention; concatenate
    user, candidate and pooled-history vectors; feed them to a 128-64-32 ReLU
    MLP (dropout 0.3 after the first two layers) with a sigmoid output unit.

    Args:
        num_users (int): Total number of unique users.
        num_items (int): Total number of unique items.
        history_length (int): Maximum length of user historical behavior sequence.
        embedding_dim (int): Dimensionality of item and user embeddings.
        item_features_matrix (np.array): Pre-trained or initial item feature
            embeddings used to initialise the (trainable) item embedding table.

    Returns:
        keras.Model: Compiled DIN model (Adam, lr=0.001, binary cross-entropy,
        accuracy and AUC metrics). Inputs are named 'user_id_input',
        'candidate_item_id_input' and 'historical_item_ids_input'.
    """
    # Input Layers
    user_id_input = keras.Input(shape=(1,), name='user_id_input', dtype='int32')
    candidate_item_id_input = keras.Input(shape=(1,), name='candidate_item_id_input', dtype='int32')
    historical_item_ids_input = keras.Input(shape=(history_length,), name='historical_item_ids_input', dtype='int32')

    # Embedding Layers
    # User embeddings: simple lookup
    user_embedding_layer = layers.Embedding(
        input_dim=num_users + 1, # +1 for 0-padding if user_id 0 exists
        output_dim=embedding_dim,
        name='user_embedding'
    )
    user_embedding = user_embedding_layer(user_id_input) # (batch_size, 1, embedding_dim)
    user_embedding = layers.Reshape((embedding_dim,))(user_embedding) # (batch_size, embedding_dim)

    # Item embeddings: initialised from item_features_matrix and fine-tuned
    # during training (set trainable=False to keep them fixed).
    item_embedding_layer = layers.Embedding(
        input_dim=num_items + 1, # +1 for 0-padding
        output_dim=embedding_dim,
        weights=[item_features_matrix], # Initialize with the simulated item features
        trainable=True,
        name='item_embedding'
    )

    # Candidate item embedding
    candidate_item_embedding = item_embedding_layer(candidate_item_id_input) # (batch_size, 1, embedding_dim)
    candidate_item_embedding = layers.Reshape((embedding_dim,))(candidate_item_embedding) # (batch_size, embedding_dim)

    # Historical items embeddings
    historical_item_embeddings = item_embedding_layer(historical_item_ids_input) # (batch_size, history_length, embedding_dim)

    # The attention layer recognises padding by an all-zero embedding, but
    # ID 0 is looked up in a trainable table whose row 0 is generally not
    # zero. Force padded slots to zero here so they are actually masked.
    historical_item_embeddings = _ZeroPaddedHistory(name='zero_padded_history')(
        [historical_item_embeddings, historical_item_ids_input]
    )

    # DIN Attention Mechanism
    # The AttentionPoolingLayer computes a weighted sum of historical item embeddings
    # based on their relevance to the candidate item.
    attention_output = AttentionPoolingLayer(
        embedding_dim=embedding_dim,
        hidden_units=[80, 40], # Attention MLP hidden units
        name='din_attention_pooling'
    )([candidate_item_embedding, historical_item_embeddings])

    # Concatenate all features for the final prediction layer
    # These are: user_embedding, candidate_item_embedding, and the attention-weighted historical embedding
    concatenated_features = layers.concatenate([
        user_embedding,
        candidate_item_embedding,
        attention_output
    ], axis=-1)

    # Prediction MLP (Deep Network)
    mlp_output = layers.Dense(128, activation='relu')(concatenated_features)
    mlp_output = layers.Dropout(0.3)(mlp_output)
    mlp_output = layers.Dense(64, activation='relu')(mlp_output)
    mlp_output = layers.Dropout(0.3)(mlp_output)
    mlp_output = layers.Dense(32, activation='relu')(mlp_output)

    # Output layer (sigmoid for binary classification)
    output = layers.Dense(1, activation='sigmoid', name='output')(mlp_output)

    # Create the model
    model = keras.Model(
        inputs=[user_id_input, candidate_item_id_input, historical_item_ids_input],
        outputs=output,
        name='deep_interest_network'
    )

    # Compile the model
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
    )

    return model

# Instantiate the compiled DIN model from the module-level settings
din_model = build_din_model(
    num_users, num_items, history_length, embedding_dim, item_features_data
)
din_model.summary()

# --- 3. Prepare Data for Training ---
# Keys must match the names of the model's Input layers
model_inputs = dict(
    user_id_input=user_ids_data,
    candidate_item_id_input=candidate_item_ids_data,
    historical_item_ids_input=historical_item_ids_data,
)

# --- 4. Train the Model ---
print("\n--- Training the DIN Model ---")
history = din_model.fit(
    model_inputs,
    labels_data,
    batch_size=256,
    epochs=5,  # short run, demonstration only
    validation_split=0.2,  # hold out 20% of the samples for validation
    verbose=1,
)

# Last-epoch value of every tracked metric
final_metrics = {metric: values[-1] for metric, values in history.history.items()}

print("\nTraining complete.")
print(f"Final training accuracy: {final_metrics['accuracy']:.4f}")
print(f"Final validation accuracy: {final_metrics['val_accuracy']:.4f}")
print(f"Final training AUC: {final_metrics['auc']:.4f}")
print(f"Final validation AUC: {final_metrics['val_auc']:.4f}")

# --- 5. Make Predictions (Example) ---
print("\n--- Making Predictions (Example) ---")


def _first_or_self(value):
    """Return value[0] for an array entry, or the entry itself when it is 0-d."""
    return value[0] if value.ndim > 0 else value


# Pick a handful of distinct samples at random
num_predict_samples = 5
random_indices = np.random.choice(len(user_ids_data), num_predict_samples, replace=False)

sample_user_ids = user_ids_data[random_indices]
sample_candidate_item_ids = candidate_item_ids_data[random_indices]
sample_historical_item_ids = historical_item_ids_data[random_indices]
sample_labels = labels_data[random_indices]

sample_inputs = dict(
    user_id_input=sample_user_ids,
    candidate_item_id_input=sample_candidate_item_ids,
    historical_item_ids_input=sample_historical_item_ids,
)

predictions = din_model.predict(sample_inputs)

print("\nSample Predictions:")
sample_rows = zip(
    sample_user_ids,
    sample_candidate_item_ids,
    sample_historical_item_ids,
    sample_labels,
    predictions,
)
for i, (user_id, candidate_id, history_ids, label, prediction) in enumerate(sample_rows):
    print(f"  Sample {i+1}:")
    print(f"    User ID: {_first_or_self(user_id)}")
    print(f"    Candidate Item ID: {_first_or_self(candidate_id)}")
    print(f"    Historical Item IDs: {history_ids}")
    print(f"    True Label: {label}")
    print(f"    Predicted Probability: {prediction[0]:.4f}")
    print("-" * 30)

TensorFlow Version: 2.18.0
Generated synthetic data: 50000 samples.
  User IDs shape: (50000,)
  Candidate Item IDs shape: (50000,)
  Historical Item IDs shape: (50000, 10)
  Labels shape: (50000,)
  Item Features shape: (501, 16)



--- Training the DIN Model ---
Epoch 1/5
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.5030 - auc: 0.5048 - loss: 0.6949 - val_accuracy: 0.5039 - val_auc: 0.5019 - val_loss: 0.6931
Epoch 2/5
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 37ms/step - accuracy: 0.5007 - auc: 0.4985 - loss: 0.6935 - val_accuracy: 0.4991 - val_auc: 0.5024 - val_loss: 0.6934
Epoch 3/5
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 30ms/step - accuracy: 0.4997 - auc: 0.4982 - loss: 0.6934 - val_accuracy: 0.4978 - val_auc: 0.4962 - val_loss: 0.6933
Epoch 4/5
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 37ms/step - accuracy: 0.5053 - auc: 0.5100 - loss: 0.6931 - val_accuracy: 0.4974 - val_auc: 0.4996 - val_loss: 0.6940
Epoch 5/5
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 33ms/step - accuracy: 0.5233 - auc: 0.5308 - loss: 0.6917 - val_accuracy: 0.4938 - val_auc: 0.4925 - val_loss: 0.6

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score

print(f"PyTorch Version: {torch.__version__}")

# Run on CUDA when it is available, otherwise fall back to the CPU
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

# --- 1. Synthetic Data Generation ---
# This dataset simulates user interactions with items, adapted for PyTorch.

def generate_synthetic_data(num_samples=10000, num_users=1000, num_items=500, history_length=10, embedding_dim=16):
    """
    Generates synthetic data for a Deep Interest Network.

    Every value is drawn from NumPy's global random state, so call
    ``np.random.seed`` beforehand for reproducible output. Labels are
    independent coin flips and carry no signal about the other fields.

    Args:
        num_samples (int): Total number of data points (user-candidate pairs).
        num_users (int): Number of unique users.
        num_items (int): Number of unique items.
        history_length (int): Maximum length of a user's historical clicked items.
            Must be at least 1 (each sample gets between 1 and
            ``history_length`` real items).
        embedding_dim (int): Dimensionality of item and user embeddings.

    Returns:
        tuple: A tuple containing PyTorch Tensors:
            - user_ids (torch.Tensor): int64 User IDs, shape (num_samples,).
            - candidate_item_ids (torch.Tensor): int64 Candidate Item IDs, shape (num_samples,).
            - historical_item_ids (torch.Tensor): int64 Historical Item IDs, shape
              (num_samples, history_length), right-padded with 0s.
            - labels (torch.Tensor): float32 Labels (1 if clicked, 0 otherwise),
              shape (num_samples, 1).
            - item_features_matrix (np.array): float32 item feature embeddings, shape
              (num_items + 1, embedding_dim); row 0 (the padding ID) is all zeros.
    """
    user_ids = np.random.randint(1, num_users + 1, num_samples)  # User IDs start from 1
    candidate_item_ids = np.random.randint(1, num_items + 1, num_samples)  # Item IDs start from 1

    # Draw a full-width block of item IDs, then blank out everything past each
    # sample's own history length. Working on a 2-D array keeps the
    # (num_samples, history_length) shape even when num_samples is 0.
    history_lengths = np.random.randint(1, history_length + 1, num_samples)
    historical_item_ids = np.random.randint(1, num_items + 1, (num_samples, history_length))
    padding_positions = np.arange(history_length) >= history_lengths[:, np.newaxis]
    historical_item_ids[padding_positions] = 0

    labels = np.random.randint(0, 2, num_samples)

    # +1 row because ID 0 is reserved for padding.
    item_features_matrix = np.random.rand(num_items + 1, embedding_dim).astype(np.float32)
    # The padding row must be all zeros: the attention layer in this file
    # recognises padded history slots by their all-zero embedding.
    item_features_matrix[0] = 0.0

    print(f"Generated synthetic data: {num_samples} samples.")
    print(f"  User IDs shape: {user_ids.shape}")
    print(f"  Candidate Item IDs shape: {candidate_item_ids.shape}")
    print(f"  Historical Item IDs shape: {historical_item_ids.shape}")
    print(f"  Labels shape: {labels.shape}")
    print(f"  Item Features matrix shape: {item_features_matrix.shape}")

    # Convert to PyTorch tensors with explicit dtypes, so the result does not
    # depend on NumPy's platform-default integer width.
    user_ids_t = torch.as_tensor(user_ids, dtype=torch.long)
    candidate_item_ids_t = torch.as_tensor(candidate_item_ids, dtype=torch.long)
    historical_item_ids_t = torch.as_tensor(historical_item_ids, dtype=torch.long)
    # Extra trailing dimension so labels line up with the model's (batch, 1) logits
    labels_t = torch.as_tensor(labels, dtype=torch.float32).unsqueeze(1)

    return user_ids_t, candidate_item_ids_t, historical_item_ids_t, labels_t, item_features_matrix

# Problem-size settings shared by the data generator and the model
num_users = 1000
num_items = 500
history_length = 10  # Max length of user behavior sequence
embedding_dim = 16  # Dimensionality of item and user embeddings

# Draw the synthetic dataset (50k user/candidate pairs)
(
    user_ids_data,
    candidate_item_ids_data,
    historical_item_ids_data,
    labels_data,
    item_features_data_np,
) = generate_synthetic_data(
    num_samples=50000,
    num_users=num_users,
    num_items=num_items,
    history_length=history_length,
    embedding_dim=embedding_dim,
)

# --- 2. DIN Model Architecture in PyTorch ---

# Custom Activation Function (Dice)
class Dice(nn.Module):
    """
    Data Adaptive Activation Function (DICE) for DIN in PyTorch.

    Computes, per feature of the last dimension:

        x_normed = (x - mean) / sqrt(var + epsilon)
        p        = sigmoid(alphas * x_normed + beta)
        output   = p * x + (1 - p) * alphas * x

    In training mode ``mean``/``var`` come from the current batch (all
    dimensions except the last) and are folded into running estimates. In
    eval mode the running estimates are used, so the output for a sample no
    longer depends on which other samples share its batch.

    The running estimates are registered buffers (``running_mean``,
    ``running_var``) and therefore appear in ``state_dict()``.

    Args:
        input_dim (int): Size of the last (feature) dimension of the input.
        epsilon (float): Added to the variance for numerical stability.
        momentum (float): Weight of the newest batch statistic when updating
            the running estimates.
    """
    def __init__(self, input_dim, epsilon=1e-9, momentum=0.1):
        super().__init__()
        self.epsilon = epsilon
        self.momentum = momentum
        # alphas and beta are learnable parameters, one for each feature dimension
        self.alphas = nn.Parameter(torch.zeros(input_dim))
        self.beta = nn.Parameter(torch.zeros(input_dim))
        # Statistics used at inference time; updated only in training mode.
        self.register_buffer("running_mean", torch.zeros(input_dim))
        self.register_buffer("running_var", torch.ones(input_dim))

    def forward(self, x):
        if self.training:
            # Collapse every leading dimension so statistics are per feature.
            flat = x.reshape(-1, x.shape[-1])
            mean = flat.mean(dim=0)
            variance = (flat - mean).pow(2).mean(dim=0)
            if flat.shape[0] > 0:  # an empty batch would poison the buffers with NaN
                with torch.no_grad():
                    keep = 1.0 - self.momentum
                    self.running_mean.mul_(keep).add_(
                        mean.to(self.running_mean.dtype), alpha=self.momentum
                    )
                    self.running_var.mul_(keep).add_(
                        variance.to(self.running_var.dtype), alpha=self.momentum
                    )
        else:
            mean = self.running_mean
            variance = self.running_var

        x_normed = (x - mean) / torch.sqrt(variance + self.epsilon)

        # p is calculated element-wise across the feature dimension
        p = torch.sigmoid(self.alphas * x_normed + self.beta)

        # Apply DICE activation formula
        return p * x + (1 - p) * self.alphas * x

# Attention Pooling Layer for DIN
class AttentionPoolingLayer(nn.Module):
    """
    Attention pooling layer for Deep Interest Network (DIN) in PyTorch.

    Scores every historical item against the candidate item with a small MLP
    (Linear + Dice per hidden layer, then a 1-unit Linear), softmax-normalises
    the scores over the history dimension, and returns the weighted sum of
    the historical item embeddings.

    Args:
        embedding_dim (int): Width of the item embeddings.
        hidden_units (sequence of int): Width of each hidden layer of the
            attention MLP; may be empty.
    """
    def __init__(self, embedding_dim, hidden_units=(80, 40)):
        super().__init__()
        self.embedding_dim = embedding_dim
        # Copy into a fresh list so instances never share a mutable default.
        self.hidden_units = list(hidden_units)

        # Attention network (MLP)
        attention_mlp_layers = []
        input_dim_mlp = embedding_dim * 4 # Concatenated features: candidate, history, product, difference

        for units in self.hidden_units:
            attention_mlp_layers.append(nn.Linear(input_dim_mlp, units))
            attention_mlp_layers.append(Dice(units)) # Apply Dice activation
            input_dim_mlp = units # Update input_dim for the next layer

        self.attention_mlp = nn.Sequential(*attention_mlp_layers)
        self.output_layer = nn.Linear(input_dim_mlp, 1) # Outputs a single attention score

    def forward(self, candidate_item_embedding, historical_item_embeddings, mask=None):
        """
        Args:
            candidate_item_embedding: (batch_size, embedding_dim)
            historical_item_embeddings: (batch_size, history_length, embedding_dim)
            mask: optional (batch_size, history_length) tensor, non-zero / True
                for real items. When omitted, a history slot is treated as
                padding if its embedding is all zeros.

        Returns:
            (batch_size, embedding_dim) attention-weighted history embedding.
        """
        history_length = historical_item_embeddings.shape[1]

        # Repeat the candidate along the history dimension (a view, no copy):
        # (batch_size, embedding_dim) -> (batch_size, history_length, embedding_dim)
        candidate_item_embedding_tiled = candidate_item_embedding.unsqueeze(1).expand(-1, history_length, -1)

        # Candidate, history, their product and their difference, side by side:
        # (batch_size, history_length, embedding_dim * 4)
        concatenated_features = torch.cat([
            candidate_item_embedding_tiled,
            historical_item_embeddings,
            candidate_item_embedding_tiled * historical_item_embeddings,
            candidate_item_embedding_tiled - historical_item_embeddings
        ], dim=-1)

        # Pass through attention network
        attention_logits = self.attention_mlp(concatenated_features) # (batch_size, history_length, hidden_units[-1])
        attention_logits = self.output_layer(attention_logits) # (batch_size, history_length, 1)

        # Validity mask, shape (batch_size, history_length, 1)
        if mask is None:
            # Fallback heuristic: an all-zero embedding marks a padded slot.
            # This only works if the caller guarantees padding embeds to zeros.
            valid = historical_item_embeddings.abs().sum(dim=-1, keepdim=True) > 0
        else:
            valid = mask.to(torch.bool).unsqueeze(-1)

        # Use the most negative finite value of the logits' dtype, so masking
        # also behaves for reduced-precision dtypes.
        attention_logits = attention_logits.masked_fill(~valid, torch.finfo(attention_logits.dtype).min)

        # Apply softmax to get attention weights
        attention_weights = F.softmax(attention_logits, dim=1) # (batch_size, history_length, 1)
        # If every slot of a row is masked the softmax degenerates to uniform
        # weights; zero them so padded embeddings can never leak into the output.
        attention_weights = attention_weights * valid.to(attention_weights.dtype)

        # Weighted sum over the history dimension -> (batch_size, embedding_dim)
        weighted_history_embedding = torch.sum(attention_weights * historical_item_embeddings, dim=1)

        return weighted_history_embedding

# Deep Interest Network (DIN) Model
class DIN(nn.Module):
    """
    Deep Interest Network: user, candidate-item and attention-pooled history
    embeddings are concatenated and fed to an MLP that outputs one logit.

    Args:
        num_users (int): Number of unique users (IDs up to ``num_users``).
        num_items (int): Number of unique items.
        history_length (int): Maximum history length (stored, not used in forward).
        embedding_dim (int): Width of user and item embeddings.
        item_features_matrix (np.array): Initial item embedding table; it is
            copied, so training never modifies the caller's array.
    """
    def __init__(self, num_users, num_items, history_length, embedding_dim, item_features_matrix):
        super().__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.history_length = history_length
        self.embedding_dim = embedding_dim

        # Embedding layers
        self.user_embedding = nn.Embedding(num_users + 1, embedding_dim)
        # Initialise the item table from the provided matrix. clone() is
        # required: as_tensor shares memory with a float32 NumPy array, and
        # the optimizer would otherwise update the caller's array in place.
        pretrained_items = torch.as_tensor(item_features_matrix, dtype=torch.float32).clone()
        self.item_embedding = nn.Embedding.from_pretrained(
            pretrained_items,
            freeze=False # Allow fine-tuning during training
        )

        # DIN Attention Pooling Layer
        self.attention_pooling_layer = AttentionPoolingLayer(embedding_dim=embedding_dim)

        # Final Prediction MLP
        # Input dim for MLP: user_embedding (embedding_dim) + candidate_item_embedding (embedding_dim)
        #                     + attention_output (embedding_dim)
        input_mlp_dim = embedding_dim * 3
        self.mlp = nn.Sequential(
            nn.Linear(input_mlp_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.ReLU(),
        )
        # Produces raw logits (no sigmoid): intended for a loss such as
        # BCEWithLogitsLoss that applies the sigmoid itself.
        self.output_layer = nn.Linear(32, 1)

    def forward(self, user_id, candidate_item_id, historical_item_ids):
        """
        Args:
            user_id: (batch_size,) or (batch_size, 1) integer tensor.
            candidate_item_id: (batch_size,) or (batch_size, 1) integer tensor.
            historical_item_ids: (batch_size, history_length) integer tensor,
                with 0 marking padded slots.

        Returns:
            (batch_size, 1) tensor of logits.
        """
        # Flatten IDs to 1-D. reshape(-1) is used instead of squeeze(-1)
        # because squeeze would collapse a 1-D batch of size 1 to a scalar
        # and break the concatenation below.
        user_id = user_id.reshape(-1)
        candidate_item_id = candidate_item_id.reshape(-1)

        # Embeddings lookup
        user_emb = self.user_embedding(user_id) # (batch_size, embedding_dim)
        candidate_item_emb = self.item_embedding(candidate_item_id) # (batch_size, embedding_dim)
        historical_item_embs = self.item_embedding(historical_item_ids) # (batch_size, history_length, embedding_dim)

        # The attention layer recognises padding by an all-zero embedding, but
        # row 0 of the trainable item table is generally not zero. Zero the
        # padded slots explicitly so they are actually masked out.
        is_real_item = (historical_item_ids != 0).unsqueeze(-1).to(historical_item_embs.dtype)
        historical_item_embs = historical_item_embs * is_real_item

        # Pass through Attention Pooling Layer
        attention_output = self.attention_pooling_layer(candidate_item_emb, historical_item_embs)

        # Concatenate all features
        concatenated_features = torch.cat([
            user_emb,
            candidate_item_emb,
            attention_output
        ], dim=-1)

        # Pass through prediction MLP
        mlp_output = self.mlp(concatenated_features)
        # Final output (logits)
        logits = self.output_layer(mlp_output)

        return logits

# Build the network, then place its parameters on the selected device
din_model = DIN(num_users, num_items, history_length, embedding_dim, item_features_data_np)
din_model = din_model.to(device)

print("\nDIN Model Summary:")
# Printing an nn.Module lists its registered sub-modules (rough analogue of Keras summary)
print(din_model)
# For a quick smoke test, run a forward pass on dummy inputs:
# print(din_model(torch.zeros(2,1).long().to(device), torch.zeros(2,1).long().to(device), torch.zeros(2, history_length).long().to(device)))


# --- 3. Prepare Data for Training (PyTorch DataLoader) ---
# One dataset item = (user id, candidate id, history ids, label)
dataset = TensorDataset(
    user_ids_data,
    candidate_item_ids_data,
    historical_item_ids_data,
    labels_data,
)

# Random 80/20 train/validation split
num_examples = len(dataset)
train_size = int(0.8 * num_examples)
val_size = num_examples - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Only the training loader reshuffles between epochs
batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# --- 4. Train the Model ---
optimizer = torch.optim.Adam(din_model.parameters(), lr=0.001)
# The model emits raw logits; this loss applies the sigmoid internally
criterion = nn.BCEWithLogitsLoss()

num_epochs = 5
print(f"\n--- Training the DIN Model for {num_epochs} epochs ---")

for epoch in range(num_epochs):
    # ---- Training pass ----
    din_model.train()
    total_loss = 0
    predictions_train, labels_train = [], []

    for batch in train_loader:
        user_ids, candidate_ids, historical_ids, labels = (t.to(device) for t in batch)

        optimizer.zero_grad()
        outputs = din_model(user_ids, candidate_ids, historical_ids)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Keep probabilities and targets for the epoch-level metrics
        predictions_train.extend(torch.sigmoid(outputs).detach().cpu().numpy().ravel())
        labels_train.extend(labels.detach().cpu().numpy().ravel())

    # Mean of the per-batch losses
    avg_train_loss = total_loss / len(train_loader)
    train_accuracy = accuracy_score(labels_train, np.round(predictions_train))
    train_auc = roc_auc_score(labels_train, predictions_train)

    # ---- Validation pass ----
    din_model.eval()
    val_total_loss = 0
    predictions_val, labels_val = [], []

    with torch.no_grad():
        for batch in val_loader:
            user_ids, candidate_ids, historical_ids, labels = (t.to(device) for t in batch)

            outputs = din_model(user_ids, candidate_ids, historical_ids)
            loss = criterion(outputs, labels)
            val_total_loss += loss.item()

            predictions_val.extend(torch.sigmoid(outputs).cpu().numpy().ravel())
            labels_val.extend(labels.cpu().numpy().ravel())

    avg_val_loss = val_total_loss / len(val_loader)
    val_accuracy = accuracy_score(labels_val, np.round(predictions_val))
    val_auc = roc_auc_score(labels_val, predictions_val)

    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"  Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.4f}, Train AUC: {train_auc:.4f}")
    print(f"  Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}, Val AUC: {val_auc:.4f}")

print("\nTraining complete.")

# --- 5. Make Predictions (Example) ---
print("\n--- Making Predictions (Example) ---")

din_model.eval()  # inference mode (disables dropout)

# Pick a handful of distinct samples at random
num_predict_samples = 5
random_indices = np.random.choice(len(user_ids_data), num_predict_samples, replace=False)

sample_user_ids = user_ids_data[random_indices].to(device)
sample_candidate_item_ids = candidate_item_ids_data[random_indices].to(device)
sample_historical_item_ids = historical_item_ids_data[random_indices].to(device)
sample_labels = labels_data[random_indices].to(device)

with torch.no_grad():
    sample_outputs = din_model(sample_user_ids, sample_candidate_item_ids, sample_historical_item_ids)
    # The model returns logits; the sigmoid turns them into probabilities
    sample_predictions = torch.sigmoid(sample_outputs).cpu().numpy().flatten()

print("\nSample Predictions:")
sample_rows = zip(
    sample_user_ids,
    sample_candidate_item_ids,
    sample_historical_item_ids,
    sample_labels,
    sample_predictions,
)
for i, (user_id, candidate_id, history_ids, label, probability) in enumerate(sample_rows):
    print(f"  Sample {i+1}:")
    print(f"    User ID: {user_id.item()}")
    print(f"    Candidate Item ID: {candidate_id.item()}")
    print(f"    Historical Item IDs: {history_ids.cpu().numpy().tolist()}")
    print(f"    True Label: {label.item()}")
    print(f"    Predicted Probability: {probability:.4f}")
    print("-" * 30)



PyTorch Version: 2.6.0+cu124
Using device: cpu
Generated synthetic data: 50000 samples.
  User IDs shape: (50000,)
  Candidate Item IDs shape: (50000,)
  Historical Item IDs shape: (50000, 10)
  Labels shape: (50000,)
  Item Features matrix shape: (501, 16)

DIN Model Summary:
DIN(
  (user_embedding): Embedding(1001, 16)
  (item_embedding): Embedding(501, 16)
  (attention_pooling_layer): AttentionPoolingLayer(
    (attention_mlp): Sequential(
      (0): Linear(in_features=64, out_features=80, bias=True)
      (1): Dice()
      (2): Linear(in_features=80, out_features=40, bias=True)
      (3): Dice()
    )
    (output_layer): Linear(in_features=40, out_features=1, bias=True)
  )
  (mlp): Sequential(
    (0): Linear(in_features=48, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=128, out_features=64, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.3, inplace=False)
    (6): Linear(in_features=64, out_features=32, bias=True)
