In [1]:
"""
Code for Recommender System (Matrix Factorization), using pytorch 2.3 and python 3.12 and lightning 2.5
"""
import torch
from torch import nn
import lightning as L
import torch.nn.functional as F
from torchmetrics import MeanSquaredError

import os
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader, random_split

from pytorch_lightning.loggers import TensorBoardLogger

In [2]:
def read_data_ml100k(data_dir:str="../../Data/ml-100k") -> tuple[pd.DataFrame, int, int]:
    """Load the ratings file ``u.data`` found inside ``data_dir``.

    The file is parsed as tab-separated text without a header row; the four
    columns are named ``user_id``, ``item_id``, ``rating`` and ``timestamp``.

    Args:
        data_dir: Directory that contains ``u.data``.

    Returns:
        A ``(data, num_users, num_items)`` tuple: the parsed ratings frame,
        the number of distinct ``user_id`` values and the number of distinct
        ``item_id`` values in it.
    """
    names = ['user_id', 'item_id', 'rating', 'timestamp']
    data = pd.read_csv(os.path.join(data_dir, 'u.data'), sep='\t', names=names, engine='python')
    # Counts of distinct ids; callers use them to size the embedding tables.
    num_users = data.user_id.unique().shape[0]
    num_items = data.item_id.unique().shape[0]
    return data, num_users, num_items

class MFData(Dataset):
    """Map-style dataset yielding ``(user_index, item_index, rating)`` triples.

    The ratings are loaded once in ``__init__`` through ``read_data_ml100k``
    and kept as NumPy arrays for cheap per-sample indexing.

    Args:
        data_dir: Directory passed on to ``read_data_ml100k``.
        normalize_rating: Stored on the instance only.
            NOTE(review): nothing in this class reads the flag, so ratings are
            always returned on their original scale.
    """
    def __init__(self, data_dir:str="../../Data/ml-100k", normalize_rating:bool=False):
        self.data_dir = data_dir
        self.normalize_rating = normalize_rating
        self.df, self.num_users, self.num_items = read_data_ml100k(data_dir)
        # Ids are shifted down by one so they can be used as embedding indices
        # (presumably the raw ids start at 1 and are contiguous -- confirm).
        self.user_id = self.df.user_id.values - 1
        self.item_id = self.df.item_id.values - 1
        self.rating = self.df.rating.values.astype(np.float32)

    def split(self, train_ratio=0.8, generator=None):
        """Randomly partition the dataset into a train and a test subset.

        Args:
            train_ratio: Fraction of the samples (between 0 and 1) assigned to
                the train subset; the train length is rounded down.
            generator: Optional ``torch.Generator`` that makes the partition
                reproducible. When omitted the global torch RNG is used.

        Returns:
            The ``[train_subset, test_subset]`` list produced by
            ``random_split``.

        Raises:
            ValueError: If ``train_ratio`` lies outside ``[0, 1]``.
        """
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(f"train_ratio must be within [0, 1], got {train_ratio}")
        train_len = int(train_ratio * len(self))
        test_len = len(self) - train_len
        if generator is None:
            return random_split(self, [train_len, test_len])
        return random_split(self, [train_len, test_len], generator=generator)

    def __len__(self):
        """Number of rating rows."""
        return len(self.df)

    def __getitem__(self, idx:int):
        """Return the ``(user_index, item_index, rating)`` triple at ``idx``."""
        return self.user_id[idx], self.item_id[idx], self.rating[idx]

class LitMFData(L.LightningDataModule):
    """Lightning data module wrapping a dataset that exposes ``split(ratio)``.

    Args:
        dataset: Dataset to serve. It must provide ``split(train_ratio)``
            returning a ``(train, test)`` pair; ``num_users`` / ``num_items``
            attributes are picked up when present.
        train_ratio: Fraction forwarded to ``dataset.split``.
        batch_size: Batch size used by every dataloader.
        num_workers: Worker process count used by every dataloader.
    """
    def __init__(
        self, 
        dataset:Dataset, 
        train_ratio:float=0.8, 
        batch_size:int=32, 
        num_workers:int=4
    ):
        # Let the base class create its own bookkeeping attributes instead of
        # re-creating a hand-picked subset of them here.
        super().__init__()
        self.dataset = dataset
        self.train_ratio = train_ratio
        self.dataloader_kwargs = {
            "batch_size": batch_size,
            "num_workers": num_workers,
            # persistent workers are only valid when worker processes exist
            "persistent_workers": num_workers > 0
        }
        # Populated by setup().
        self.num_users = None
        self.num_items = None
        self.train_split = None
        self.test_split = None

    def setup(self, stage:str):
        """Read the dataset sizes and create the train/test split once.

        The split is random, so it is only drawn on the first call; later
        calls (for example the Trainer calling ``setup`` again after a manual
        call, or for another stage) reuse the same partition so that test
        samples never move into the training subset.
        """
        self.num_users = getattr(self.dataset, "num_users", None)
        self.num_items = getattr(self.dataset, "num_items", None)
        if self.train_split is None or self.test_split is None:
            self.train_split, self.test_split = self.dataset.split(
                self.train_ratio)

    def train_dataloader(self):
        """Shuffled loader over the train subset."""
        return DataLoader(self.train_split, **self.dataloader_kwargs, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the held-out subset (same data as the test loader)."""
        return DataLoader(self.test_split, **self.dataloader_kwargs, shuffle=False)

    def test_dataloader(self):
        """Unshuffled loader over the held-out subset."""
        return DataLoader(self.test_split, **self.dataloader_kwargs, shuffle=False)

In [3]:
class MF(nn.Module):
    """Matrix-factorization rating model with user and item bias terms.

    The predicted rating for a (user, item) pair is the dot product of the
    two latent vectors plus the user bias and the item bias.

    Args:
        num_factors: Size of each latent vector.
        num_users: Number of rows in the user tables.
        num_items: Number of rows in the item tables.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """
    def __init__(self, num_factors:int, num_users:int, num_items:int, **kwargs):
        super().__init__(**kwargs)
        self.P = nn.Embedding(num_users, num_factors)
        self.Q = nn.Embedding(num_items, num_factors)
        self.user_bias = nn.Embedding(num_users, 1)
        self.item_bias = nn.Embedding(num_items, 1)

    def forward(self, user_id:torch.Tensor, item_id:torch.Tensor):
        """Predict ratings for index tensors of identical shape.

        Args:
            user_id: Integer tensor of user indices (any shape, scalars included).
            item_id: Integer tensor of item indices with the same shape.

        Returns:
            Tensor of predictions with the same shape as ``user_id``.
        """
        P_u = self.P(user_id)
        Q_i = self.Q(item_id)
        # Reduce over the trailing factor axis and drop the trailing size-1
        # bias axis, so the result keeps the shape of the id tensors instead
        # of assuming a 1-D batch.
        b_u = self.user_bias(user_id).squeeze(-1)
        b_i = self.item_bias(item_id).squeeze(-1)
        outputs = (P_u * Q_i).sum(dim=-1) + b_u + b_i
        return outputs

class LitMF(L.LightningModule):
    """Lightning wrapper that trains a rating model with an MSE objective.

    Args:
        model: Model *class* (not an instance); it is instantiated as
            ``model(**kwargs)``.
        lr: Learning rate handed to Adam.
        **kwargs: Constructor arguments for ``model``.
    """
    def __init__(self, model:type[nn.Module], lr:float=0.002, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.model = model(**kwargs)
        self.lr = lr
        # Despite the attribute name, this is a default-configured
        # MeanSquaredError and its value is logged under "val/mse".
        self.rmse = MeanSquaredError()
        # Per-step losses collected during the running epoch.
        self.training_step_outputs = []
        self.validation_step_outputs = []

    def get_loss(self, pred_ratings:torch.Tensor, batch:tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
        """MSE between the predictions and the last element of ``batch``."""
        return F.mse_loss(pred_ratings, batch[-1])

    def forward(self, batch:tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
        """Run the wrapped model on the first two elements of ``batch``."""
        user_ids, item_ids, _ = batch
        return self.model(user_ids, item_ids)
        
    def training_step(self, batch:tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx:int):
        """Compute the training loss for one batch and remember it for epoch logging."""
        outputs = self(batch)
        loss = self.get_loss(outputs, batch)
        # Store a detached copy: keeping the live tensor would retain every
        # step's autograd graph until the end of the epoch.
        self.training_step_outputs.append(loss.detach())
        return loss
        
    def validation_step(self, batch:tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx:int):
        """Compute the validation loss for one batch and update the metric."""
        outputs = self(batch)
        loss = self.get_loss(outputs, batch)
        self.validation_step_outputs.append(loss.detach())
        self.update_metric(outputs, batch)
        return loss

    def update_metric(self, outputs:torch.Tensor, batch:tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
        """Feed predictions and the batch's targets into the metric."""
        _, _, gt = batch
        self.rmse.update(outputs, gt)
        
    def configure_optimizers(self):
        """Adam over all parameters with a weight decay of 1e-5."""
        return torch.optim.Adam(self.parameters(), self.lr, weight_decay=1e-5)

    def _log_epoch_scalar(self, tag:str, value):
        """Write ``value`` under ``tag`` at the current epoch, if a logger is attached."""
        if self.logger is None:
            return
        self.logger.experiment.add_scalar(tag, value, self.current_epoch)

    def on_train_epoch_end(self):
        """Log the mean of the epoch's per-step training losses, then reset the buffer."""
        # torch.stack raises on an empty list, so skip epochs without steps.
        if self.training_step_outputs:
            epoch_average = torch.stack(self.training_step_outputs).mean()
            self._log_epoch_scalar("train/loss", epoch_average)
        self.training_step_outputs.clear()

    def on_validation_epoch_end(self):
        """Log the validation loss and metric for the epoch, then reset both."""
        # The pre-training sanity check also runs this hook at epoch 0; logging
        # it would put a second, partial data point on top of the real epoch 0.
        if self.validation_step_outputs and not self.trainer.sanity_checking:
            epoch_average = torch.stack(self.validation_step_outputs).mean()
            self._log_epoch_scalar("val/loss", epoch_average)
            self._log_epoch_scalar("val/mse", self.rmse.compute())
        self.rmse.reset()
        self.validation_step_outputs.clear()

In [4]:
def matirx_factorization(embedding_dims:int=30, max_epochs:int=40, batch_size:int=512, num_workers:int=0):
    """Train the matrix-factorization model and log the run to TensorBoard.

    Args:
        embedding_dims: Number of latent factors; also used in the run name
            ``MF_<embedding_dims>``.
        max_epochs: Number of training epochs.
        batch_size: Batch size of the dataloaders.
        num_workers: Dataloader worker processes. Values larger than 0 do not
            work on Windows inside a Jupyter Lab environment; run this from a
            Python script to use worker processes.
    """
    # Local import: take the logger from the same `lightning` distribution that
    # provides L.Trainer rather than from the separate `pytorch_lightning` package.
    from lightning.pytorch.loggers import TensorBoardLogger

    data = LitMFData(MFData(), batch_size=batch_size, num_workers=num_workers)
    # setup() has to run before the model is built because it is what exposes
    # num_users / num_items on the data module.
    data.setup("fit")
    model = LitMF(MF, num_factors=embedding_dims, num_users=data.num_users, num_items=data.num_items)
    logger = TensorBoardLogger("log", name=f"MF_{embedding_dims}")
    trainer = L.Trainer(max_epochs=max_epochs, accelerator="auto", logger=logger)
    trainer.fit(model, data)

In [5]:
# Run the full training loop; progress and warnings are printed below this cell.
matirx_factorization()

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params | Mode 
---------------------------------------------------
0 | model | MF               | 81.4 K | train
1 | rmse  | MeanSquaredError | 0      | train
---------------------------------------------------
81.4 K    Trainable params
0         Non-trainable params
81.4 K    Total par

Sanity Checking: |                                                                                            …

C:\ProgramData\miniconda3\envs\ai\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
C:\ProgramData\miniconda3\envs\ai\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=40` reached.


In [6]:
# Load the TensorBoard notebook extension and open it on the "log" directory,
# which is the directory name the training function passes to its logger.
%load_ext tensorboard
%tensorboard --logdir log