In [1]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder


class MovieLensSequenceDataset(Dataset):
    """Next-movie samples built from a MovieLens-style ``ratings.dat`` file.

    For every user the ratings are ordered by timestamp; each rating after the
    first becomes one sample made of the ``seq_len`` movies rated just before
    it (left-padded when the history is shorter), the rated movie itself, and
    a label derived from the rating.

    Movie index 0 is reserved for padding. The movie encoder is fitted with
    one extra sentinel class that sorts before every real movie id, so real
    movies are encoded as 1..N and ``len(movie_encoder.classes_)`` equals
    N + 1, i.e. the embedding-table size that covers padding and all movies.

    Args:
        ratings_file: Path to a ``::``-separated file without a header row
            whose columns are userId, movieId, rating, timestamp.
        seq_len: Number of history positions in every sample.
        binary_classification: If 1, the label is 1 when ``rating > 3`` and 0
            otherwise. Any other value yields ``rating - 1`` as the label.
    """

    # Index used to left-pad short histories; never assigned to a real movie.
    PAD_INDEX = 0

    def __init__(self, ratings_file, seq_len=5, binary_classification=1):
        self.seq_len = seq_len

        self.ratings = pd.read_csv(
            ratings_file,
            header=None,
            sep='::',
            names=['userId', 'movieId', 'rating', 'timestamp'],
            engine='python',
            encoding='latin-1',
        )

        self.user_encoder = LabelEncoder()
        self.movie_encoder = LabelEncoder()

        self.ratings['user'] = self.user_encoder.fit_transform(self.ratings['userId'])

        # LabelEncoder numbers the sorted classes from 0, so fitting on the raw
        # ids alone would give the first real movie the same index as padding.
        # A sentinel smaller than every real id claims index 0 instead.
        # Assumes movieId parses as a numeric column.
        raw_movie_ids = self.ratings['movieId']
        known_movie_ids = raw_movie_ids.unique().tolist()
        pad_movie_id = min(known_movie_ids) - 1 if known_movie_ids else 0
        self.movie_encoder.fit([pad_movie_id] + known_movie_ids)
        self.ratings['movie'] = self.movie_encoder.transform(raw_movie_ids)

        if binary_classification == 1:
            self.ratings['label'] = (self.ratings['rating'] > 3).astype(int)
        else:
            self.ratings['label'] = (self.ratings['rating'] - 1).astype(int)

        # One (user, history, target movie, label) tuple per rating that has at
        # least one earlier rating by the same user.
        self.samples = []
        for user_id, group in self.ratings.groupby('user'):
            group = group.sort_values('timestamp')
            movies = group['movie'].tolist()
            labels = group['label'].tolist()

            for i in range(1, len(movies)):
                behavior = movies[max(0, i - self.seq_len):i]
                missing = self.seq_len - len(behavior)
                if missing > 0:
                    # Left-pad so the most recent movie is always last.
                    behavior = [self.PAD_INDEX] * missing + behavior

                self.samples.append((user_id, behavior, movies[i], labels[i]))

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(user, history, target_movie, label)`` tensors for one sample.

        The first three tensors are ``torch.long``; the label is ``torch.float``.
        """
        user_id, behavior_seq, target_movie, label = self.samples[idx]
        return (
            torch.tensor(user_id, dtype=torch.long),
            torch.tensor(behavior_seq, dtype=torch.long),
            torch.tensor(target_movie, dtype=torch.long),
            torch.tensor(label, dtype=torch.float)
        )


def get_data_loaders(ratings_file, batch_size=64, seq_len=5, val_fraction=0.2, seed=None, num_workers=0):
    """Build the dataset and wrap a random train/validation split in loaders.

    The split is drawn over individual samples, so samples from the same user
    can land in both the training and the validation subset.

    Args:
        ratings_file: Path passed to ``MovieLensSequenceDataset``.
        batch_size: Batch size for both loaders.
        seq_len: History length passed to ``MovieLensSequenceDataset``.
        val_fraction: Share of the samples held out for validation, in [0, 1].
        seed: If given, seeds a dedicated generator for the split so repeated
            calls produce the same subsets. ``None`` uses the global RNG.
        num_workers: Worker processes for both loaders.

    Returns:
        ``(train_loader, val_loader, dataset)``; the training loader shuffles,
        the validation loader does not.

    Raises:
        ValueError: If ``val_fraction`` is outside [0, 1].
    """
    if not 0.0 <= val_fraction <= 1.0:
        raise ValueError(f"val_fraction must be in [0, 1], got {val_fraction!r}")

    dataset = MovieLensSequenceDataset(ratings_file, seq_len)
    val_size = int(val_fraction * len(dataset))
    train_size = len(dataset) - val_size

    # Only override the generator when a seed is requested, so the default
    # call keeps drawing from the global RNG.
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], **split_kwargs
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, dataset


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


class MINDModel(pl.LightningModule):
    """Multi-interest model: routing over a movie history plus target-aware attention.

    A history of movie embeddings is condensed into ``num_interests`` interest
    vectors by ``dynamic_routing``; ``label_aware_attention`` pools them into a
    single user vector conditioned on the target movie, and the logit is the
    dot product of that vector with the target movie embedding. Training and
    validation use binary cross-entropy on the logits.

    Args:
        num_users: Rows of the user embedding table.
        num_movies: Rows of the movie embedding table.
        embedding_dim: Width of all embeddings and of the routing matrix ``S``.
        num_interests: Number of interest vectors produced per sample.
        lr: Adam learning rate.
        use_mha: If True, pool interests with ``nn.MultiheadAttention`` instead
            of the softmax over powered dot-product scores.
        num_heads: Attention heads when ``use_mha`` is True.
    """

    # Number of refinement passes performed by ``dynamic_routing``.
    ROUTING_ITERATIONS = 3

    def __init__(self, num_users, num_movies, embedding_dim=32, num_interests=4, lr=1e-3, use_mha=False, num_heads=2):
        super().__init__()
        self.save_hyperparameters()

        # Not read by ``forward``; kept so the parameter set stays unchanged.
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.movie_embedding = nn.Embedding(num_movies, embedding_dim)
        # Bilinear map shared by all behaviors and interests during routing.
        self.S = nn.Parameter(torch.randn(embedding_dim, embedding_dim))

        self.num_interests = num_interests
        self.lr = lr
        if use_mha:
            self.mha = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads, batch_first=True)

    def dynamic_routing(self, behavior_embeds):
        """Condense behavior embeddings into interest vectors.

        Args:
            behavior_embeds: Tensor of shape [batch, seq_len, embed_dim].

        Returns:
            Tensor of shape [batch, num_interests, embed_dim].

        Routing logits start from fresh Gaussian noise on every call, so two
        calls with the same input generally differ unless the RNG is seeded.
        """
        # NOTE(review): every history position takes part in routing; padded
        # positions are not masked out here.
        batch_size, seq_len, _ = behavior_embeds.shape
        routing_logits = torch.randn(
            batch_size, seq_len, self.num_interests,
            device=behavior_embeds.device, dtype=behavior_embeds.dtype,
        )

        # The projection does not depend on the routing logits, so compute it
        # once: [batch, seq_len, 1, embed_dim].
        projected = torch.matmul(behavior_embeds.unsqueeze(2), self.S)

        for _ in range(self.ROUTING_ITERATIONS):
            # Each behavior distributes its weight across the interests.
            weights = F.softmax(routing_logits, dim=2).unsqueeze(-1)
            z = torch.sum(weights * projected, dim=1)
            norm = torch.norm(z, dim=2, keepdim=True)
            # Rescale each vector to length norm / (1 + norm); the epsilon
            # guards the division when the norm is zero.
            interests = (norm / (1 + norm)) * (z / (norm + 1e-8))
            # Raise the logit of (behavior, interest) pairs that agree under S.
            routing_logits = routing_logits + torch.matmul(
                behavior_embeds, torch.matmul(self.S, interests.transpose(1, 2))
            )

        return interests

    def label_aware_attention(self, interests, target_movie_embed, p=2):
        """Pool the interest vectors into one user vector given the target movie.

        Args:
            interests: Tensor of shape [batch, num_interests, embed_dim].
            target_movie_embed: Tensor of shape [batch, embed_dim].
            p: Exponent applied to the dot-product scores before the softmax;
                ignored when ``use_mha`` is True.

        Returns:
            Tensor of shape [batch, embed_dim].
        """
        if self.hparams.use_mha:
            # target_movie_embed: [batch, embed_dim] -> [batch, 1, embed_dim]
            query = target_movie_embed.unsqueeze(1)
            attn_output, _ = self.mha(query=query, key=interests, value=interests)
            return attn_output.squeeze(1)   # [batch, embed_dim]

        scores = torch.matmul(interests, target_movie_embed.unsqueeze(-1)).squeeze(-1)
        attn = F.softmax(scores.pow(p), dim=1)
        return torch.sum(attn.unsqueeze(-1) * interests, dim=1)

    def forward(self, user_ids, behavior_seq, target_movie_ids):
        """Return one logit per sample.

        Args:
            user_ids: Accepted for interface compatibility; not used.
            behavior_seq: Long tensor of movie indices, [batch, seq_len].
            target_movie_ids: Long tensor of movie indices, [batch].
        """
        behavior_embeds = self.movie_embedding(behavior_seq)
        target_movie_embed = self.movie_embedding(target_movie_ids)
        interests = self.dynamic_routing(behavior_embeds)
        user_vector = self.label_aware_attention(interests, target_movie_embed)
        return torch.sum(user_vector * target_movie_embed, dim=1)

    def _bce_loss(self, batch):
        """Compute binary cross-entropy on the logits for one batch."""
        user_ids, behavior_seq, target_movie, label = batch
        logits = self(user_ids, behavior_seq, target_movie)
        return F.binary_cross_entropy_with_logits(logits, label)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log it as ``train_loss``."""
        loss = self._bce_loss(batch)
        self.log("train_loss", loss, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log it as ``val_loss``."""
        loss = self._bce_loss(batch)
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        return loss


In [3]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping

# Assumes get_data_loaders and MINDModel are already defined or imported.

# Seed Python, NumPy and PyTorch before any randomness is consumed, so the
# train/val split, weight initialisation and routing noise repeat across runs.
pl.seed_everything(42, workers=True)

ratings_path = "/kaggle/input/movielens-1m-dataset/ratings.dat"
batch_size = 1024
seq_len = 500

train_loader, val_loader, dataset = get_data_loaders(ratings_path, batch_size, seq_len)

# Embedding-table sizes come from the encoders fitted by the dataset.
num_users = len(dataset.user_encoder.classes_)
num_movies = len(dataset.movie_encoder.classes_)

model = MINDModel(
    num_users=num_users,
    num_movies=num_movies,
    embedding_dim=32,
    num_interests=4
)

logger = TensorBoardLogger("tb_logs", name="MIND_movielens")

# Stop once val_loss has not improved for 20 consecutive validation runs.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=20,
    mode='min',
    verbose=True
)

trainer = pl.Trainer(
    max_epochs=100,
    logger=logger,
    callbacks=[early_stop_callback]
)

trainer.fit(model, train_loader, val_loader)


2025-04-23 03:57:18.093223: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745380638.314206      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745380638.379210      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [4]:
# get_data_loaders and MINDModel come from the cells above; when they live in
# modules instead, import them from `dataset` and `model`.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

# Seed Python, NumPy and PyTorch before any randomness is consumed, so the
# train/val split, weight initialisation and routing noise repeat across runs.
pl.seed_everything(42, workers=True)

ratings_path = "/kaggle/input/movielens-1m-dataset/ratings.dat"
batch_size = 1024
seq_len = 500

train_loader, val_loader, dataset = get_data_loaders(ratings_path, batch_size, seq_len)

# Embedding-table sizes come from the encoders fitted by the dataset.
num_users = len(dataset.user_encoder.classes_)
num_movies = len(dataset.movie_encoder.classes_)

# NOTE(review): use_mha=False and num_heads=2 are the constructor defaults, so
# this is the same configuration as the previous training cell. Set
# use_mha=True here if the intent is to compare the attention variant.
model = MINDModel(num_users=num_users, num_movies=num_movies, embedding_dim=32, num_interests=4,
                  use_mha=False,
                  num_heads=2)

logger = TensorBoardLogger("tb_logs", name="MIND_movielens")

# Stop once val_loss has not improved for 20 consecutive validation runs.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=20,
    mode='min',
    verbose=True
)

trainer = pl.Trainer(
    max_epochs=100,
    logger=logger,
    callbacks=[early_stop_callback]
)

trainer.fit(model, train_loader, val_loader)

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]