In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from collections import defaultdict
import math

# Device configuration
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

# Hyperparameters
embedding_dim = 128
num_heads = 4
num_layers = 2
dropout = 0.1
max_seq_length = 200
window_size = 10
chunk_size = 10
batch_size = 256
num_epochs = 500
learning_rate = 0.0025
mask_prob = 0.15
num_masks_per_batch = 5
num_negatives = 99

# Custom collation function
def custom_collate_fn(batch):
    seqs = torch.stack([item["seq"] for item in batch])
    return {"seq": seqs}

# Dataset class
class MovieLensDataset(Dataset):
    def __init__(self, user_dict, num_items, max_seq_length):
        self.user_dict = user_dict
        self.num_items = num_items
        self.max_seq_length = max_seq_length
        self.users = list(user_dict.keys())
        self.precomputed = self.precompute_sequences()

    def precompute_sequences(self):
        precomputed = {}
        for user in self.users:
            seq = self.user_dict[user][:self.max_seq_length]
            if len(seq) < 2:
                seq = [0] * self.max_seq_length
            else:
                seq = seq + [0] * (self.max_seq_length - len(seq)) if len(seq) < self.max_seq_length else seq
            precomputed[user] = torch.tensor(seq, dtype=torch.long)
        return precomputed

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        user = self.users[idx]
        seq = self.precomputed[user].clone()
        return {"seq": seq}

# Model
class SequentialRecommender(nn.Module):
    def __init__(self, num_items, embedding_dim, num_heads, num_layers, dropout, window_size):
        super(SequentialRecommender, self).__init__()
        self.num_items = num_items
        self.window_size = window_size
        self.embedding = nn.Embedding(num_items + 10000, embedding_dim, padding_idx=0)
        self.pos_encoding = nn.Parameter(self.create_pos_encoding(5000, embedding_dim))
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embedding_dim,
                nhead=num_heads,
                dim_feedforward=embedding_dim * 4,
                dropout=dropout,
                batch_first=True
            ),
            num_layers=num_layers
        )
        self.fc = nn.Linear(embedding_dim, num_items)

    def create_pos_encoding(self, max_len, dim):
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe

    def compress_sequence(self, seq, mask_positions):
        batch_size, seq_len = seq.shape
        windows = []
        for pos in mask_positions:
            start = max(0, pos - self.window_size)
            end = min(seq_len, pos + self.window_size + 1)
            windows.append((start, end))

        merged = []
        if windows:
            current_start, current_end = windows[0]
            for start, end in windows[1:]:
                if start <= current_end:
                    current_end = max(current_end, end)
                else:
                    merged.append((current_start, current_end))
                    current_start, current_end = start, end
            merged.append((current_start, current_end))

        total_len = 0
        last_end = 0
        for start, end in merged:
            if last_end < start:
                total_len += 1
            total_len += end - start
            last_end = end
        if last_end < seq_len:
            total_len += 1

        compressed_seq = torch.zeros(batch_size, total_len, dtype=torch.long, device=seq.device)
        chunk_map = [[] for _ in range(batch_size)]
        mask_indices = []
        chunk_id = self.num_items + 1
        pos = 0

        last_end = 0
        for start, end in merged:
            if last_end < start:
                compressed_seq[:, pos] = chunk_id
                for b in range(batch_size):
                    chunk_map[b].append((last_end, start))
                chunk_id += 1
                pos += 1
            window_len = end - start
            compressed_seq[:, pos:pos + window_len] = seq[:, start:end].clone()
            for mp in mask_positions:
                if start <= mp < end and pos + (mp - start) not in mask_indices:
                    mask_indices.append(pos + (mp - start))
            for b in range(batch_size):
                chunk_map[b].extend(list(range(start, end)))
            pos += window_len
            last_end = end

        if last_end < seq_len:
            compressed_seq[:, pos] = chunk_id
            for b in range(batch_size):
                chunk_map[b].append((last_end, seq_len))

        return compressed_seq, chunk_map, torch.tensor(mask_indices, device=seq.device)

    def forward(self, seq, mask_positions, is_predict=False):
        batch_size, seq_len = seq.shape
        masked_seq = seq.clone()
        masked_seq[torch.arange(batch_size).unsqueeze(1), mask_positions] = self.num_items

        compressed_seq, chunk_map, mask_indices = self.compress_sequence(masked_seq, mask_positions)
        batch_size, comp_len = compressed_seq.shape
        embeddings = self.embedding(compressed_seq)

        is_chunk = (compressed_seq > self.num_items)
        chunk_indices = torch.where(is_chunk)
        if chunk_indices[0].numel() > 0:
            chunk_embeddings = []
            for b, i in zip(chunk_indices[0], chunk_indices[1]):
                chunk_start, chunk_end = chunk_map[b][i]
                chunk = seq[b, chunk_start:chunk_end].unsqueeze(0)
                chunk_emb = self.embedding(chunk)
                chunk_mask = (chunk != 0).float().unsqueeze(-1)
                chunk_sum = (chunk_emb * chunk_mask).sum(dim=1)
                chunk_count = chunk_mask.sum(dim=1).clamp(min=1)
                chunk_embeddings.append(chunk_sum / chunk_count)
            embeddings[chunk_indices] = torch.cat(chunk_embeddings)

        embeddings = embeddings + self.pos_encoding[:comp_len].unsqueeze(0)
        mask = (compressed_seq == 0).to(device)
        output = self.transformer(embeddings, src_key_padding_mask=mask)

        mask_output = output[torch.arange(batch_size).unsqueeze(1), mask_indices]
        logits = self.fc(mask_output)  # Shape: (batch_size, num_masks, num_items)
        
        # For predict, ensure only the last mask is used
        if is_predict and len(mask_indices) > 1:
            mask_output = mask_output[:, -1:, :]  # Take only the last mask
            logits = self.fc(mask_output)         # Shape: (batch_size, 1, num_items)
        
        return logits, mask_indices

    def predict(self, input_seq):
        seq = input_seq[:, :max_seq_length]
        seq = torch.where(seq >= self.num_items, torch.zeros_like(seq), seq)
        next_pos = seq.shape[1]
        seq = torch.cat([seq, torch.full((seq.shape[0], 1), self.num_items, device=seq.device)], dim=1)
        mask_positions = torch.tensor([next_pos], device=seq.device)
        logits, _ = self.forward(seq, mask_positions, is_predict=True)
        return logits

# Evaluation
def evaluate(model, user_dict, num_items, max_seq_length, device):
    model.eval()
    NDCG, HR, valid_users = 0.0, 0.0, 0

    for user, items in user_dict.items():
        if len(items) < 2:
            continue

        seq = items[:max_seq_length]
        input_seq = torch.tensor(seq[:-1], dtype=torch.long).unsqueeze(0).to(device)
        target = seq[-1]
        candidates = [target] + random.sample(list(set(range(1, num_items)) - set(items)), num_negatives)

        with torch.no_grad():
            logits = model.predict(input_seq)
            scores = logits[0, 0, candidates]  # Single mask in predict
            
            ranked = torch.argsort(scores, descending=True).cpu().numpy()
            rank = np.where(ranked == 0)[0][0] + 1

        valid_users += 1
        HR += int(rank <= 10)
        NDCG += 1 / np.log2(rank + 1) if rank <= 10 else 0

        # if valid_users % 100 == 0:
        #     print(f"Validated users: {valid_users}, HR@10: {HR / valid_users:.4f}, NDCG@10: {NDCG / valid_users:.4f}")

    # print(f"Final HR@10: {HR / valid_users:.4f}, NDCG@10: {NDCG / valid_users:.4f}")
    return HR / valid_users, NDCG / valid_users


# Load dataset
def load_movielens(file_path):
    user_dict = defaultdict(list)
    item_set = set()
    with open(file_path, 'r') as f:
        for line in f:
            user_id, item_id = map(int, line.strip().split())
            user_dict[user_id].append(item_id)
            item_set.add(item_id)
    num_items = max(item_set)
    return user_dict, num_items

# Main
if __name__ == "__main__":
    file_path = "data/ml-1m.txt"
    user_dict, num_items = load_movielens(file_path)
    print(f"Number of users: {len(user_dict)}, Number of items: {num_items}")

    dataset = MovieLensDataset(user_dict, num_items, max_seq_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)

    model = SequentialRecommender(num_items, embedding_dim, num_heads, num_layers, dropout, window_size).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    print(f"Length of dataloader: {len(dataloader)}")
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            seq = batch["seq"].to(device)
            batch_size = seq.shape[0]
            valid_positions = (seq != 0).sum(dim=0).nonzero(as_tuple=True)[0]
            if len(valid_positions) < num_masks_per_batch:
                continue
            mask_positions = random.sample(valid_positions.tolist(), num_masks_per_batch)
            mask_positions = torch.tensor(mask_positions, device=device)

            logits, mask_indices = model(seq, mask_positions)  # Shape: (batch_size, num_masks, num_items)
            if logits.numel() == 0:
                continue

            targets = seq[torch.arange(batch_size).unsqueeze(1), mask_positions]  # Shape: (batch_size, num_masks_per_batch)
            valid_mask = (targets != num_items) & (targets != 0)  # Shape: (batch_size, num_masks_per_batch)

            num_masks = mask_indices.shape[0]
            logits = logits.view(batch_size * num_masks, num_items)  # Shape: (batch_size * num_masks, num_items)
            targets = targets.view(batch_size * num_masks_per_batch)  # Shape: (batch_size * num_masks_per_batch)
            valid_mask = valid_mask.view(batch_size * num_masks_per_batch)

            targets = targets[valid_mask]
            logits = logits[:targets.shape[0]]  # Align with filtered targets

            if logits.shape[0] != targets.shape[0] or targets.numel() == 0:
                continue

            loss = criterion(logits, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            # if batch_idx % 10 == 0:
            #     print(f"Batch {batch_idx}, Loss: {loss.item():.4f}")

        print(f"Epoch {epoch+1}, Loss: {total_loss / len(dataloader):.4f}")
        if epoch % 10 == 0:
            # Evaluate every 10 epochs
            print("Evaluating...")
            HR, NDCG = evaluate(model, user_dict, num_items, max_seq_length, device)
            print(f"Epoch {epoch+1}, HR@10: {HR:.4f}, NDCG@10: {NDCG:.4f}")

Number of users: 6040, Number of items: 3416
Length of dataloader: 24
Epoch 1, Loss: 3.6467
Evaluating...


  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)


Epoch 1, HR@10: 0.3978, NDCG@10: 0.2055
Epoch 2, Loss: 3.7897
Epoch 3, Loss: 4.3596
Epoch 4, Loss: 3.1085
Epoch 5, Loss: 4.0256
Epoch 6, Loss: 3.7150
Epoch 7, Loss: 2.4765
Epoch 8, Loss: 2.1553
Epoch 9, Loss: 4.0136
Epoch 10, Loss: 1.5421
Epoch 11, Loss: 3.7231
Evaluating...
Epoch 11, HR@10: 0.4579, NDCG@10: 0.2494
Epoch 12, Loss: 3.3866
Epoch 13, Loss: 4.0010
Epoch 14, Loss: 4.6089
Epoch 15, Loss: 5.5257
Epoch 16, Loss: 5.2468
Epoch 17, Loss: 4.2982
Epoch 18, Loss: 4.6237
Epoch 19, Loss: 3.7032
Epoch 20, Loss: 3.7152
Epoch 21, Loss: 4.6244
Evaluating...
Epoch 21, HR@10: 0.4576, NDCG@10: 0.2528
Epoch 22, Loss: 3.6947
Epoch 23, Loss: 4.3166
Epoch 24, Loss: 3.9880
Epoch 25, Loss: 3.0766
Epoch 26, Loss: 3.3759
Epoch 27, Loss: 4.9337
Epoch 28, Loss: 1.5473
Epoch 29, Loss: 2.1576
Epoch 30, Loss: 3.3889
Epoch 31, Loss: 3.0728
Evaluating...
Epoch 31, HR@10: 0.4679, NDCG@10: 0.2583
Epoch 32, Loss: 3.6999
Epoch 33, Loss: 4.3043
Epoch 34, Loss: 3.6822
Epoch 35, Loss: 3.7041
Epoch 36, Loss: 3.996