<a href="https://colab.research.google.com/github/Papa-Panda/industry_algo/blob/main/Multi_channel_Interest_Memory_Network_(MIMN)_Implementation_in_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score

print(f"PyTorch Version: {torch.__version__}")

# Use the GPU whenever CUDA is available; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# --- 1. Synthetic Data Generation ---
# Adapted to potentially generate longer history sequences for MIMN.

def generate_synthetic_data(num_samples=10000, num_users=1000, num_items=500, max_history_length=50, embedding_dim=16):
    """
    Build a random (user, candidate, history, label) dataset for exercising an MIMN model.

    IDs start at 1 so that 0 can be used as the history padding value. Labels are drawn
    independently of every other field, so the data carries no learnable signal; it only
    exercises the training pipeline.

    Args:
        num_samples (int): Number of (user, candidate) rows to produce.
        num_users (int): User IDs are drawn uniformly from [1, num_users].
        num_items (int): Candidate and history item IDs are drawn from [1, num_items].
        max_history_length (int): Width of the padded history matrix. Each row holds
            between 1 and this many real items, followed by 0s.
        embedding_dim (int): Number of columns of the simulated item feature matrix.

    Returns:
        tuple: (user_ids, candidate_item_ids, historical_item_ids, labels, item_features_matrix)
            - user_ids: LongTensor of shape (num_samples,)
            - candidate_item_ids: LongTensor of shape (num_samples,)
            - historical_item_ids: LongTensor of shape (num_samples, max_history_length),
              right-padded with 0
            - labels: FloatTensor of shape (num_samples, 1) holding 0/1 values
            - item_features_matrix: float32 ndarray of shape (num_items + 1, embedding_dim),
              kept in numpy so it can seed an embedding table
    """
    users = np.random.randint(1, num_users + 1, num_samples)
    candidates = np.random.randint(1, num_items + 1, num_samples)

    def draw_padded_history():
        # Random length in [1, max_history_length]; item IDs in [1, num_items]; right-pad with 0.
        length = np.random.randint(1, max_history_length + 1)
        items = np.random.randint(1, num_items + 1, length).tolist()
        return items + [0] * (max_history_length - len(items))

    # The comprehension draws one history per sample, strictly in order.
    histories = np.array([draw_padded_history() for _ in range(num_samples)])

    clicks = np.random.randint(0, 2, num_samples)

    item_features = np.random.rand(num_items + 1, embedding_dim).astype(np.float32)

    print(f"Generated synthetic data: {num_samples} samples.")
    for title, array in (
        ("User IDs", users),
        ("Candidate Item IDs", candidates),
        ("Historical Item IDs", histories),
        ("Labels", clicks),
        ("Item Features matrix", item_features),
    ):
        print(f"  {title} shape: {array.shape}")

    return (
        torch.LongTensor(users),
        torch.LongTensor(candidates),
        torch.LongTensor(histories),
        torch.FloatTensor(clicks).unsqueeze(1),
        item_features,
    )

# Shared hyperparameters for data generation and the model, then generate the data.
num_users = 1000
num_items = 500
max_history_length = 50  # MIMN targets long behavior sequences
embedding_dim = 16
memory_slots = 8  # number of memory channels/slots
memory_dim = 32  # width of each memory slot (key and value)

(
    user_ids_data,
    candidate_item_ids_data,
    historical_item_ids_data,
    labels_data,
    item_features_data_np,
) = generate_synthetic_data(
    num_samples=50000,
    num_users=num_users,
    num_items=num_items,
    max_history_length=max_history_length,
    embedding_dim=embedding_dim,
)

# --- 2. MIMN Model Architecture in PyTorch ---

# Custom Activation Function (Dice) - Reused
class Dice(nn.Module):
    """
    Data Adaptive Activation Function (Dice).

    Computes, per feature (last dimension):

        f(x) = p(x) * x + (1 - p(x)) * alpha * x,   p(x) = sigmoid(x_hat + beta)

    where x_hat is x standardized with a per-feature mean and variance.

    In training mode the statistics are taken over every dimension except the last,
    from the current input, and are folded into running estimates. In eval mode the
    running estimates are used, so an inference output does not depend on the other
    rows in the batch.

    Args:
        input_dim (int): Size of the last (feature) dimension of the input.
        epsilon (float): Added to the variance before the square root.
        momentum (float): Weight of the current batch when updating the running stats.
    """
    def __init__(self, input_dim, epsilon=1e-9, momentum=0.1):
        super().__init__()
        self.epsilon = epsilon
        self.momentum = momentum
        # Slope of the "inactive" part of the output; zero init gives f(x) = p(x) * x.
        self.alphas = nn.Parameter(torch.zeros(input_dim))
        # Learnable shift of the gate; zero init gives p(x) = sigmoid(x_hat).
        self.beta = nn.Parameter(torch.zeros(input_dim))
        # Running per-feature statistics, used in eval mode.
        self.register_buffer("running_mean", torch.zeros(input_dim))
        self.register_buffer("running_var", torch.ones(input_dim))

    def forward(self, x):
        if self.training:
            reduction_axes = tuple(range(x.dim() - 1))
            mean = torch.mean(x, dim=reduction_axes, keepdim=True)
            variance = torch.mean(torch.pow(x - mean, 2), dim=reduction_axes, keepdim=True)
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean.reshape(-1))
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * variance.reshape(-1))
        else:
            # Buffers have shape (input_dim,) and broadcast over the leading dims of x.
            mean = self.running_mean
            variance = self.running_var
        x_normed = (x - mean) / torch.sqrt(variance + self.epsilon)
        # The gate depends on the standardized input itself. The previous form,
        # sigmoid(alphas * x_normed + beta), was a constant 0.5 at init because alphas == 0.
        p = torch.sigmoid(x_normed + self.beta)
        return p * x + (1 - p) * self.alphas * x

# Memory Unit for MIMN
class MemoryUnit(nn.Module):
    """
    Simplified memory network for MIMN.

    Holds ``memory_slots`` learnable key/value slots shared by all samples. A forward
    pass does three things:

    1. Addresses the slots with attention driven by the query embedding.
    2. Writes a summary of each sample's history into a per-sample copy of the values.
    3. Reads the updated values with the same attention weights.

    The shared ``memory_values`` parameter is never modified in place; only the
    optimizer changes it.

    Args:
        memory_slots (int): Number of memory slots (channels).
        memory_dim (int): Width of each key and value slot.
        embedding_dim (int): Width of the query and history item embeddings.
    """
    def __init__(self, memory_slots, memory_dim, embedding_dim):
        super().__init__()
        self.memory_slots = memory_slots
        self.memory_dim = memory_dim
        self.embedding_dim = embedding_dim

        # Learnable memory keys (M_K) and initial values (M_V), each (memory_slots, memory_dim).
        self.memory_keys = nn.Parameter(torch.randn(memory_slots, memory_dim))
        self.memory_values = nn.Parameter(torch.randn(memory_slots, memory_dim))

        # Scores one (query, key) pair; its input is the query concatenated with the key.
        self.attention_mlp = nn.Sequential(
            nn.Linear(memory_dim + embedding_dim, 64),
            Dice(64),
            nn.Linear(64, 1),
        )

        # Per-slot write. Input: [history summary (embedding_dim), read vector (memory_dim)].
        # Hidden state: the slot's current value (memory_dim).
        self.memory_update_gru = nn.GRUCell(input_size=embedding_dim + memory_dim, hidden_size=memory_dim)

    def read_memory(self, query_embedding):
        """
        Attention-weighted read of the shared memory values.

        Args:
            query_embedding: (batch_size, embedding_dim), typically the candidate item embedding.

        Returns:
            tuple: ``(read_output, attention_weights)``, with shapes
            (batch_size, memory_dim) and (batch_size, memory_slots, 1).
        """
        batch_size = query_embedding.shape[0]

        # Pair the query with every key: (batch_size, memory_slots, embedding_dim + memory_dim).
        query_tiled = query_embedding.unsqueeze(1).expand(-1, self.memory_slots, -1)
        keys_tiled = self.memory_keys.unsqueeze(0).expand(batch_size, -1, -1)
        attention_input = torch.cat([query_tiled, keys_tiled], dim=-1)

        attention_logits = self.attention_mlp(attention_input)      # (B, S, 1)
        attention_weights = F.softmax(attention_logits, dim=1)      # normalized across slots

        read_memory_output = torch.sum(attention_weights * self.memory_values.unsqueeze(0), dim=1)
        return read_memory_output, attention_weights

    def write_memory(self, historical_item_embeddings, read_memory_output, attention_weights):
        """
        Write a summary of the history into a per-sample copy of the memory values.

        The history is summarized as the mean of its non-padding embeddings. For every
        slot, ``memory_update_gru`` proposes a candidate value (input: history summary
        plus read vector; hidden: the slot's current value). The slot is then moved
        toward that candidate in proportion to its attention weight, so the
        most-addressed slots change most.

        Args:
            historical_item_embeddings: (batch_size, history_length, embedding_dim).
                Positions whose embedding is all zeros are treated as padding. This
                matches padded IDs only if the embedding table maps the padding ID to a
                zero vector.
            read_memory_output: (batch_size, memory_dim), from ``read_memory``.
            attention_weights: (batch_size, memory_slots, 1), from ``read_memory``.

        Returns:
            torch.Tensor: Updated memory values, (batch_size, memory_slots, memory_dim).
        """
        batch_size = historical_item_embeddings.shape[0]

        # Mean over non-padding history positions.
        mask = (historical_item_embeddings.abs().sum(dim=-1, keepdim=True) > 0).to(historical_item_embeddings.dtype)
        sum_history_embs = (historical_item_embeddings * mask).sum(dim=1)   # (B, E)
        num_valid_items = mask.sum(dim=1).clamp(min=1.0)                     # (B, 1); an all-padding row averages to 0
        avg_history_emb = sum_history_embs / num_valid_items                 # (B, E)

        gru_input = torch.cat([avg_history_emb, read_memory_output], dim=-1)  # (B, E + D)
        # Every slot receives the same input; each slot's own value is its hidden state.
        gru_input = gru_input.unsqueeze(1).expand(-1, self.memory_slots, -1)
        current_values = self.memory_values.unsqueeze(0).expand(batch_size, -1, -1)  # (B, S, D)

        flat_count = batch_size * self.memory_slots
        candidate_values = self.memory_update_gru(
            gru_input.reshape(flat_count, gru_input.shape[-1]),
            current_values.reshape(flat_count, self.memory_dim),
        ).view(batch_size, self.memory_slots, self.memory_dim)

        # Soft write: interpolate each slot toward its candidate by its attention weight.
        return current_values + attention_weights * (candidate_values - current_values)

    def forward(self, query_embedding, historical_item_embeddings):
        """
        Address memory with the query, write the history into it, and read it back.

        Args:
            query_embedding: (batch_size, embedding_dim), typically the candidate item embedding.
            historical_item_embeddings: (batch_size, history_length, embedding_dim).

        Returns:
            torch.Tensor: (batch_size, memory_dim) read-out of the updated memory.
        """
        read_output, attention_weights = self.read_memory(query_embedding)
        updated_values = self.write_memory(historical_item_embeddings, read_output, attention_weights)
        # The keys are unchanged by the write, so the addressing weights are reused
        # instead of running the attention MLP a second time.
        return torch.sum(attention_weights * updated_values, dim=1)

class MIMN(nn.Module):
    """
    Simplified Multi-channel Interest Memory Network for click prediction.

    The model:

    - embeds the user, the candidate item and the item history;
    - obtains a memory read-out from ``MemoryUnit``, queried with the candidate embedding;
    - maps [user, candidate, memory read-out] through an MLP to one click logit.

    Args:
        num_users (int): Largest user ID; the user table has ``num_users + 1`` rows.
        num_items (int): Largest item ID (stored for reference).
        max_history_length (int): History width (stored for reference).
        embedding_dim (int): Width of the user and item embeddings.
        memory_slots (int): Number of memory slots.
        memory_dim (int): Width of each memory slot.
        item_features_matrix (array-like): (rows, embedding_dim) initial item embeddings.
            It is copied, so training never modifies the caller's array. Row 0 (the
            history padding ID) is replaced with zeros and excluded from gradients.
    """
    def __init__(self, num_users, num_items, max_history_length, embedding_dim, memory_slots, memory_dim, item_features_matrix):
        super().__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.max_history_length = max_history_length
        self.embedding_dim = embedding_dim
        self.memory_slots = memory_slots
        self.memory_dim = memory_dim

        # Embedding layers
        self.user_embedding = nn.Embedding(num_users + 1, embedding_dim)
        # Copy the matrix. torch.from_numpy(m).float() on a float32 array shares memory,
        # which would let optimizer steps write back into the caller's numpy array.
        item_weights = torch.as_tensor(item_features_matrix, dtype=torch.float32).detach().clone()
        # ID 0 pads histories. A zero, gradient-free row keeps padded positions at a
        # zero embedding, so they contribute nothing to history aggregates.
        item_weights[0] = 0.0
        self.item_embedding = nn.Embedding.from_pretrained(item_weights, freeze=False, padding_idx=0)

        # Memory Network
        self.memory_unit = MemoryUnit(memory_slots, memory_dim, embedding_dim)

        # Final Prediction MLP
        # Input dim: user_emb (embedding_dim) + candidate_item_emb (embedding_dim)
        #            + read_memory_output (memory_dim)
        input_mlp_dim = embedding_dim + embedding_dim + memory_dim
        self.mlp = nn.Sequential(
            nn.Linear(input_mlp_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.ReLU(),
        )
        self.output_layer = nn.Linear(32, 1)

    @staticmethod
    def _as_id_vector(ids):
        """Return ``ids`` as a 1-D (batch,) tensor; accepts (batch,) or (batch, 1) input."""
        # A bare squeeze(-1) would turn a (1,) batch of one into a 0-d scalar.
        return ids.squeeze(-1) if ids.dim() > 1 else ids

    def forward(self, user_id, candidate_item_id, historical_item_ids):
        """
        Score a batch of (user, candidate, history) triples.

        Args:
            user_id: (batch,) or (batch, 1) user IDs.
            candidate_item_id: (batch,) or (batch, 1) candidate item IDs.
            historical_item_ids: (batch, history_length) item IDs, with 0 as padding.

        Returns:
            torch.Tensor: (batch, 1) click logits; apply a sigmoid for probabilities.
        """
        user_emb = self.user_embedding(self._as_id_vector(user_id))
        candidate_item_emb = self.item_embedding(self._as_id_vector(candidate_item_id))
        historical_item_embs = self.item_embedding(historical_item_ids)

        # The candidate embedding is the memory query; the history goes to the memory unit.
        read_memory_output = self.memory_unit(candidate_item_emb, historical_item_embs)

        concatenated_features = torch.cat([
            user_emb,
            candidate_item_emb,
            read_memory_output
        ], dim=-1)

        mlp_output = self.mlp(concatenated_features)
        logits = self.output_layer(mlp_output)

        return logits

# Build the model from the hyperparameters defined above, then place it on `device`.
mimn_model = MIMN(
    num_users=num_users,
    num_items=num_items,
    max_history_length=max_history_length,
    embedding_dim=embedding_dim,
    memory_slots=memory_slots,
    memory_dim=memory_dim,
    item_features_matrix=item_features_data_np,
)
mimn_model = mimn_model.to(device)  # Module.to moves parameters in place and returns the module

print("\nMIMN Model Summary:")
print(mimn_model)


# --- 3. Prepare Data for Training (PyTorch DataLoader) ---
dataset = TensorDataset(user_ids_data, candidate_item_ids_data, historical_item_ids_data, labels_data)

# Random 80/20 split: the first 80% of the (shuffled) rows train, the rest validate.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Only the training batches are shuffled between epochs.
batch_size = 256
train_loader, val_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=is_train)
    for split, is_train in ((train_dataset, True), (val_dataset, False))
)

# --- 4. Train the Model ---
optimizer = torch.optim.Adam(mimn_model.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()

num_epochs = 5
print(f"\n--- Training the MIMN Model for {num_epochs} epochs ---")

for epoch in range(num_epochs):
    # --- Training ---
    mimn_model.train()
    total_loss = 0.0
    num_train_examples = 0
    predictions_train = []
    labels_train = []

    for user_ids, candidate_ids, historical_ids, labels in train_loader:
        user_ids, candidate_ids, historical_ids, labels = \
            user_ids.to(device), candidate_ids.to(device), historical_ids.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = mimn_model(user_ids, candidate_ids, historical_ids)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion returns the batch mean. Weight it by batch size so the epoch loss
        # is a per-example mean even when the final batch is smaller than batch_size.
        batch_n = labels.size(0)
        total_loss += loss.item() * batch_n
        num_train_examples += batch_n

        predictions_train.extend(outputs.sigmoid().detach().cpu().numpy().flatten())
        labels_train.extend(labels.detach().cpu().numpy().flatten())

    avg_train_loss = total_loss / max(num_train_examples, 1)
    train_accuracy = accuracy_score(labels_train, np.round(predictions_train))
    train_auc = roc_auc_score(labels_train, predictions_train)

    # --- Validation ---
    mimn_model.eval()
    val_total_loss = 0.0
    num_val_examples = 0
    predictions_val = []
    labels_val = []

    with torch.no_grad():
        for user_ids, candidate_ids, historical_ids, labels in val_loader:
            user_ids, candidate_ids, historical_ids, labels = \
                user_ids.to(device), candidate_ids.to(device), historical_ids.to(device), labels.to(device)

            outputs = mimn_model(user_ids, candidate_ids, historical_ids)
            loss = criterion(outputs, labels)
            batch_n = labels.size(0)
            val_total_loss += loss.item() * batch_n  # sample-weighted, as in training
            num_val_examples += batch_n

            predictions_val.extend(outputs.sigmoid().cpu().numpy().flatten())
            labels_val.extend(labels.cpu().numpy().flatten())

    avg_val_loss = val_total_loss / max(num_val_examples, 1)
    val_accuracy = accuracy_score(labels_val, np.round(predictions_val))
    val_auc = roc_auc_score(labels_val, predictions_val)

    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"  Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.4f}, Train AUC: {train_auc:.4f}")
    print(f"  Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}, Val AUC: {val_auc:.4f}")

print("\nTraining complete.")

# --- 5. Make Predictions (Example) ---
print("\n--- Making Predictions (Example) ---")

mimn_model.eval()

num_predict_samples = 5
random_indices = np.random.choice(len(user_ids_data), num_predict_samples, replace=False)

# Take the same randomly chosen rows from every input tensor and move them to `device`.
sample_user_ids, sample_candidate_item_ids, sample_historical_item_ids, sample_labels = (
    source[random_indices].to(device)
    for source in (user_ids_data, candidate_item_ids_data, historical_item_ids_data, labels_data)
)

with torch.no_grad():
    sample_outputs = mimn_model(sample_user_ids, sample_candidate_item_ids, sample_historical_item_ids)
    sample_predictions = sample_outputs.sigmoid().cpu().numpy().flatten()

print("\nSample Predictions:")
for i in range(num_predict_samples):
    report = (
        f"  Sample {i+1}:",
        f"    User ID: {sample_user_ids[i].item()}",
        f"    Candidate Item ID: {sample_candidate_item_ids[i].item()}",
        f"    Historical Item IDs: {sample_historical_item_ids[i].cpu().numpy().tolist()}",
        f"    True Label: {sample_labels[i].item()}",
        f"    Predicted Probability: {sample_predictions[i]:.4f}",
        "-" * 30,
    )
    # One print per report; "\n".join gives the same output as one print per line.
    print("\n".join(report))

PyTorch Version: 2.6.0+cu124
Using device: cpu
Generated synthetic data: 50000 samples.
  User IDs shape: (50000,)
  Candidate Item IDs shape: (50000,)
  Historical Item IDs shape: (50000, 50)
  Labels shape: (50000,)
  Item Features matrix shape: (501, 16)

MIMN Model Summary:
MIMN(
  (user_embedding): Embedding(1001, 16)
  (item_embedding): Embedding(501, 16)
  (memory_unit): MemoryUnit(
    (attention_mlp): Sequential(
      (0): Linear(in_features=48, out_features=64, bias=True)
      (1): Dice()
      (2): Linear(in_features=64, out_features=1, bias=True)
    )
    (memory_update_gru): GRUCell(48, 32)
  )
  (mlp): Sequential(
    (0): Linear(in_features=64, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=128, out_features=64, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.3, inplace=False)
    (6): Linear(in_features=64, out_features=32, bias=True)
    (7): ReLU()
  )
  (output_layer): Linear(in_features=32, out_feat