In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from time import time
from tqdm import tqdm

In [17]:
device = torch.device("mps")

In [35]:
# read data

rating_df = pd.read_csv('rating.csv', parse_dates=['timestamp'])

In [36]:
rand_users = np.random.choice(rating_df['userId'].unique(),
                              size=int(len(rating_df['userId'].unique())*0.05),
                              replace=False)

In [37]:
rating_df = rating_df.loc[rating_df['userId'].isin(rand_users)]

print('Reduced dataframe: {} rows for {} different users'.format(len(rating_df), len(rand_users)))

Reduced dataframe: 1038616 rows for 6924 different users


In [38]:
# Get number of unique users and items
num_users = rating_df['userId'].nunique()
num_items = rating_df['movieId'].nunique()

# Mapping userId and movieId to indices
user2idx = {userId: idx for idx, userId in enumerate(rating_df['userId'].unique())}
item2idx = {movieId: idx for idx, movieId in enumerate(rating_df['movieId'].unique())}

# Apply the mapping to your dataframe
rating_df['userId'] = rating_df['userId'].map(user2idx)
rating_df['movieId'] = rating_df['movieId'].map(item2idx)

# Split the dataset into training and testing sets
train_data, test_data = train_test_split(rating_df, test_size=0.2, random_state=42)

In [39]:
class CFData(Dataset):
    def __init__(self, data):
        self.user_ids = torch.tensor(data['userId'].values, dtype=torch.long)
        self.item_ids = torch.tensor(data['movieId'].values, dtype=torch.long)
        self.ratings = torch.tensor(data['rating'].values, dtype=torch.float32)

    def __len__(self):
        return len(self.user_ids)

    def __getitem__(self, idx):
        return self.user_ids[idx], self.item_ids[idx], self.ratings[idx]

# Create DataLoader for training
train_dataset = CFData(train_data)
test_dataset = CFData(test_data)

train_loader = DataLoader(train_dataset, batch_size=4096, shuffle=True)

In [40]:
class MatrixFactorization(nn.Module):
    def __init__(self, num_users, num_items, latent_dim, regs=[0, 0]):
        super(MatrixFactorization, self).__init__()
        self.user_embedding = nn.Embedding(num_users, latent_dim)
        self.item_embedding = nn.Embedding(num_items, latent_dim)
        self.user_embedding.weight = nn.Parameter(self.user_embedding.weight * (1 - regs[0]))
        self.item_embedding.weight = nn.Parameter(self.item_embedding.weight * (1 - regs[1]))

    def forward(self, user, item):
        user_vecs = self.user_embedding(user)
        item_vecs = self.item_embedding(item)
        y_hat = torch.sum(user_vecs * item_vecs, dim=1)
        return y_hat

# Initialize the model
latent_dim = 15
model = MatrixFactorization(num_users, num_items, latent_dim).to(device)

In [51]:
def train_model(model, train_loader, device, epochs=100, learning_rate=0.001):
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    hr_list = []
    loss_list = []
    best_hr, best_iter = -1, -1
    patience = 5
    patience_count = 0

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        # Add tqdm progress bar
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}") as pbar:
            for user, item, rating in train_loader:
                # Move data to the MPS or CPU device
                user = user.to(device)
                item = item.to(device)
                rating = rating.to(device)
                
                optimizer.zero_grad()
                prediction = model(user, item)
                loss = criterion(prediction, rating)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

                # Update the progress bar
                pbar.set_postfix({'Loss': loss.item()})
                pbar.update(1)

        # Early stopping logic
        # hr = evaluate_model(model, test_data, device, topK=5)  
        # hr_list.append(hr)
        # loss_list.append(total_loss / len(train_loader))

        # if hr < np.max(hr_list):
        #     patience_count += 1
        # else:
        #     patience_count = 0
        #     best_hr, best_iter = hr, epoch

        # if patience_count == patience:
        #     print(f"Early stopping at epoch {epoch}")
        #     break

        print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader)}')

    print("End.")

In [48]:
def evaluate_model(model, test_data, device, topK=10):
    hits = []

    # Generate movie_ids_unique and interacted_items if not already done
    movie_ids_unique = rating_df['movieId'].unique()
    interacted_items = rating_df.groupby('userId')['movieId'].apply(list).to_dict()
    
    with tqdm(total=len(test_data), desc="Evaluating") as pbar:
        for index, row in test_data.iterrows():
            user = torch.tensor([row['userId']]).long().to(device)
            true_item = row['movieId']
            negative_items = set(movie_ids_unique) - set(interacted_items[user.item()])  # Adjust according to your data

            test_items = list(np.random.choice(list(negative_items), size=99)) + [true_item]
            user_tensor = torch.tensor([user.item()] * 100).long().to(device)
            item_tensor = torch.tensor(test_items).long().to(device)

            with torch.no_grad():
                predictions = model(user_tensor, item_tensor)
            
            top_k_items = item_tensor[torch.argsort(predictions, descending=True)[:topK]].tolist()

            if true_item in top_k_items:
                hits.append(1)
            else:
                hits.append(0)

            # Update progress bar
            pbar.update(1)

    return np.mean(hits)

In [52]:
# Train the model
train_model(model, train_loader, device, epochs=100, learning_rate=0.001)

# Evaluate the model 
hr = evaluate_model(model, test_data, device, topK=10)
print(f'Final Hit Ratio@5: {hr:.4f}')

Epoch 1/100: 100%|██████████████████| 203/203 [00:05<00:00, 39.03it/s, Loss=5.6]


Epoch 1/100, Loss: 6.359740449877208


Epoch 2/100: 100%|█████████████████| 203/203 [00:05<00:00, 39.64it/s, Loss=4.63]


Epoch 2/100, Loss: 5.061100882262432


Epoch 3/100: 100%|█████████████████| 203/203 [00:05<00:00, 38.84it/s, Loss=3.77]


Epoch 3/100, Loss: 4.107828383375272


Epoch 4/100: 100%|█████████████████| 203/203 [00:05<00:00, 39.48it/s, Loss=3.13]


Epoch 4/100, Loss: 3.407434671383186


Epoch 5/100: 100%|██████████████████| 203/203 [00:05<00:00, 39.98it/s, Loss=2.6]


Epoch 5/100, Loss: 2.883514761337506


Epoch 6/100: 100%|█████████████████| 203/203 [00:05<00:00, 37.77it/s, Loss=2.47]


Epoch 6/100, Loss: 2.4842099955516495


Epoch 7/100: 100%|█████████████████| 203/203 [00:05<00:00, 39.89it/s, Loss=2.07]


Epoch 7/100, Loss: 2.173894387747854


Epoch 8/100: 100%|█████████████████| 203/203 [00:04<00:00, 40.66it/s, Loss=1.95]


Epoch 8/100, Loss: 1.9291204789589191


Epoch 9/100: 100%|█████████████████| 203/203 [00:06<00:00, 33.33it/s, Loss=1.51]


Epoch 9/100, Loss: 1.7327651190640303


Epoch 10/100: 100%|████████████████| 203/203 [00:06<00:00, 33.80it/s, Loss=1.44]


Epoch 10/100, Loss: 1.5735882726208916


Epoch 11/100: 100%|████████████████| 203/203 [00:05<00:00, 34.63it/s, Loss=1.41]


Epoch 11/100, Loss: 1.4429204963110938


Epoch 12/100: 100%|████████████████| 203/203 [00:06<00:00, 33.50it/s, Loss=1.25]


Epoch 12/100, Loss: 1.3345122096573778


Epoch 13/100: 100%|████████████████| 203/203 [00:06<00:00, 33.62it/s, Loss=1.22]


Epoch 13/100, Loss: 1.2438144114217147


Epoch 14/100: 100%|████████████████| 203/203 [00:05<00:00, 34.71it/s, Loss=1.16]


Epoch 14/100, Loss: 1.1673317748337544


Epoch 15/100: 100%|█████████████████| 203/203 [00:05<00:00, 34.64it/s, Loss=1.1]


Epoch 15/100, Loss: 1.102315365974539


Epoch 16/100: 100%|████████████████| 203/203 [00:05<00:00, 35.39it/s, Loss=1.06]


Epoch 16/100, Loss: 1.0467732026071972


Epoch 17/100: 100%|████████████████| 203/203 [00:05<00:00, 34.53it/s, Loss=1.01]


Epoch 17/100, Loss: 0.9989947476997751


Epoch 18/100: 100%|███████████████| 203/203 [00:05<00:00, 35.54it/s, Loss=0.943]


Epoch 18/100, Loss: 0.9576728948818639


Epoch 19/100: 100%|███████████████| 203/203 [00:05<00:00, 38.99it/s, Loss=0.885]


Epoch 19/100, Loss: 0.9217122126682639


Epoch 20/100: 100%|███████████████| 203/203 [00:05<00:00, 38.15it/s, Loss=0.902]


Epoch 20/100, Loss: 0.8903569056482737


Epoch 21/100: 100%|███████████████| 203/203 [00:05<00:00, 38.57it/s, Loss=0.845]


Epoch 21/100, Loss: 0.8627853681301249


Epoch 22/100: 100%|███████████████| 203/203 [00:05<00:00, 38.70it/s, Loss=0.838]


Epoch 22/100, Loss: 0.8385540828329002


Epoch 23/100: 100%|███████████████| 203/203 [00:05<00:00, 38.33it/s, Loss=0.833]


Epoch 23/100, Loss: 0.8169971575290699


Epoch 24/100: 100%|███████████████| 203/203 [00:05<00:00, 39.46it/s, Loss=0.803]


Epoch 24/100, Loss: 0.7979369686154897


Epoch 25/100: 100%|█████████████████| 203/203 [00:05<00:00, 39.07it/s, Loss=0.8]


Epoch 25/100, Loss: 0.7808078459918205


Epoch 26/100: 100%|███████████████| 203/203 [00:05<00:00, 38.77it/s, Loss=0.767]


Epoch 26/100, Loss: 0.7654492449877884


Epoch 27/100: 100%|███████████████| 203/203 [00:05<00:00, 39.52it/s, Loss=0.799]


Epoch 27/100, Loss: 0.7516865589348554


Epoch 28/100: 100%|███████████████| 203/203 [00:05<00:00, 39.20it/s, Loss=0.707]


Epoch 28/100, Loss: 0.739182876542284


Epoch 29/100: 100%|███████████████| 203/203 [00:05<00:00, 38.55it/s, Loss=0.715]


Epoch 29/100, Loss: 0.7278923210252095


Epoch 30/100: 100%|████████████████| 203/203 [00:05<00:00, 38.24it/s, Loss=0.72]


Epoch 30/100, Loss: 0.7175411634844512


Epoch 31/100: 100%|███████████████| 203/203 [00:05<00:00, 39.39it/s, Loss=0.743]


Epoch 31/100, Loss: 0.7080938076150829


Epoch 32/100: 100%|████████████████| 203/203 [00:05<00:00, 39.67it/s, Loss=0.75]


Epoch 32/100, Loss: 0.6994187723239654


Epoch 33/100: 100%|███████████████| 203/203 [00:05<00:00, 39.26it/s, Loss=0.694]


Epoch 33/100, Loss: 0.6913340852178377


Epoch 34/100: 100%|███████████████| 203/203 [00:05<00:00, 39.59it/s, Loss=0.697]


Epoch 34/100, Loss: 0.6838870588781798


Epoch 35/100: 100%|███████████████| 203/203 [00:05<00:00, 39.68it/s, Loss=0.659]


Epoch 35/100, Loss: 0.676938859993601


Epoch 36/100: 100%|████████████████| 203/203 [00:05<00:00, 37.98it/s, Loss=0.69]


Epoch 36/100, Loss: 0.6705058816031282


Epoch 37/100: 100%|████████████████| 203/203 [00:05<00:00, 39.07it/s, Loss=0.69]


Epoch 37/100, Loss: 0.664423147739448


Epoch 38/100: 100%|███████████████| 203/203 [00:05<00:00, 35.94it/s, Loss=0.647]


Epoch 38/100, Loss: 0.658674785069057


Epoch 39/100: 100%|███████████████| 203/203 [00:05<00:00, 38.30it/s, Loss=0.662]


Epoch 39/100, Loss: 0.653317104125845


Epoch 40/100: 100%|███████████████| 203/203 [00:05<00:00, 38.77it/s, Loss=0.651]


Epoch 40/100, Loss: 0.6481748393603733


Epoch 41/100: 100%|███████████████| 203/203 [00:05<00:00, 38.69it/s, Loss=0.615]


Epoch 41/100, Loss: 0.6432086399623326


Epoch 42/100: 100%|███████████████| 203/203 [00:05<00:00, 38.08it/s, Loss=0.615]


Epoch 42/100, Loss: 0.6385828688226897


Epoch 43/100: 100%|████████████████| 203/203 [00:05<00:00, 37.80it/s, Loss=0.64]


Epoch 43/100, Loss: 0.6341461727184615


Epoch 44/100: 100%|███████████████| 203/203 [00:05<00:00, 38.89it/s, Loss=0.663]


Epoch 44/100, Loss: 0.6298415070684086


Epoch 45/100: 100%|████████████████| 203/203 [00:05<00:00, 35.83it/s, Loss=0.64]


Epoch 45/100, Loss: 0.625755849730205


Epoch 46/100: 100%|███████████████| 203/203 [00:05<00:00, 34.30it/s, Loss=0.616]


Epoch 46/100, Loss: 0.6217258499173696


Epoch 47/100: 100%|███████████████| 203/203 [00:05<00:00, 37.25it/s, Loss=0.615]


Epoch 47/100, Loss: 0.6178857543198346


Epoch 48/100: 100%|███████████████| 203/203 [00:05<00:00, 37.45it/s, Loss=0.623]


Epoch 48/100, Loss: 0.6141759368586422


Epoch 49/100: 100%|███████████████| 203/203 [00:05<00:00, 38.32it/s, Loss=0.615]


Epoch 49/100, Loss: 0.6105895703062049


Epoch 50/100: 100%|███████████████| 203/203 [00:05<00:00, 36.64it/s, Loss=0.632]


Epoch 50/100, Loss: 0.60708851209415


Epoch 51/100: 100%|███████████████| 203/203 [00:05<00:00, 37.85it/s, Loss=0.619]


Epoch 51/100, Loss: 0.6036769474668456


Epoch 52/100: 100%|███████████████| 203/203 [00:05<00:00, 36.71it/s, Loss=0.616]


Epoch 52/100, Loss: 0.6004387134401669


Epoch 53/100: 100%|███████████████| 203/203 [00:05<00:00, 37.44it/s, Loss=0.587]


Epoch 53/100, Loss: 0.5972089635327532


Epoch 54/100: 100%|███████████████| 203/203 [00:05<00:00, 37.78it/s, Loss=0.594]


Epoch 54/100, Loss: 0.5940631854123083


Epoch 55/100: 100%|███████████████| 203/203 [00:05<00:00, 37.78it/s, Loss=0.578]


Epoch 55/100, Loss: 0.5910997610961275


Epoch 56/100: 100%|███████████████| 203/203 [00:05<00:00, 37.93it/s, Loss=0.611]


Epoch 56/100, Loss: 0.5881243920678576


Epoch 57/100: 100%|███████████████| 203/203 [00:05<00:00, 37.26it/s, Loss=0.614]


Epoch 57/100, Loss: 0.5852833331512113


Epoch 58/100: 100%|███████████████| 203/203 [00:05<00:00, 37.08it/s, Loss=0.571]


Epoch 58/100, Loss: 0.5825324695685814


Epoch 59/100: 100%|███████████████| 203/203 [00:05<00:00, 37.61it/s, Loss=0.588]


Epoch 59/100, Loss: 0.5797762618276286


Epoch 60/100: 100%|███████████████| 203/203 [00:05<00:00, 37.97it/s, Loss=0.621]


Epoch 60/100, Loss: 0.5772303736268594


Epoch 61/100: 100%|███████████████| 203/203 [00:05<00:00, 37.79it/s, Loss=0.598]


Epoch 61/100, Loss: 0.5746526823842467


Epoch 62/100: 100%|███████████████| 203/203 [00:05<00:00, 37.14it/s, Loss=0.575]


Epoch 62/100, Loss: 0.5721234544157394


Epoch 63/100: 100%|███████████████| 203/203 [00:05<00:00, 36.55it/s, Loss=0.611]


Epoch 63/100, Loss: 0.5696821685495048


Epoch 64/100: 100%|███████████████| 203/203 [00:05<00:00, 38.30it/s, Loss=0.604]


Epoch 64/100, Loss: 0.5672973256393019


Epoch 65/100: 100%|███████████████| 203/203 [00:05<00:00, 37.43it/s, Loss=0.574]


Epoch 65/100, Loss: 0.5649552870853781


Epoch 66/100: 100%|███████████████| 203/203 [00:05<00:00, 36.29it/s, Loss=0.566]


Epoch 66/100, Loss: 0.5627670687407695


Epoch 67/100: 100%|███████████████| 203/203 [00:05<00:00, 37.22it/s, Loss=0.572]


Epoch 67/100, Loss: 0.5605412833209108


Epoch 68/100: 100%|████████████████| 203/203 [00:05<00:00, 37.56it/s, Loss=0.57]


Epoch 68/100, Loss: 0.5584125407223631


Epoch 69/100: 100%|████████████████| 203/203 [00:05<00:00, 38.06it/s, Loss=0.59]


Epoch 69/100, Loss: 0.5563489762433057


Epoch 70/100: 100%|███████████████| 203/203 [00:05<00:00, 35.96it/s, Loss=0.568]


Epoch 70/100, Loss: 0.5543343477648467


Epoch 71/100: 100%|████████████████| 203/203 [00:05<00:00, 36.90it/s, Loss=0.54]


Epoch 71/100, Loss: 0.5522796303180638


Epoch 72/100: 100%|███████████████| 203/203 [00:05<00:00, 37.67it/s, Loss=0.548]


Epoch 72/100, Loss: 0.5503877677353732


Epoch 73/100: 100%|███████████████| 203/203 [00:05<00:00, 37.49it/s, Loss=0.574]


Epoch 73/100, Loss: 0.5484641609814367


Epoch 74/100: 100%|███████████████| 203/203 [00:05<00:00, 36.92it/s, Loss=0.557]


Epoch 74/100, Loss: 0.5465796895215077


Epoch 75/100: 100%|███████████████| 203/203 [00:05<00:00, 37.49it/s, Loss=0.538]


Epoch 75/100, Loss: 0.5447625177247184


Epoch 76/100: 100%|███████████████| 203/203 [00:05<00:00, 38.10it/s, Loss=0.564]


Epoch 76/100, Loss: 0.5430458293759765


Epoch 77/100: 100%|████████████████| 203/203 [00:05<00:00, 37.57it/s, Loss=0.55]


Epoch 77/100, Loss: 0.5413006030280014


Epoch 78/100: 100%|███████████████| 203/203 [00:05<00:00, 36.91it/s, Loss=0.548]


Epoch 78/100, Loss: 0.5396036445507275


Epoch 79/100: 100%|███████████████| 203/203 [00:05<00:00, 35.30it/s, Loss=0.564]


Epoch 79/100, Loss: 0.537943159711772


Epoch 80/100: 100%|███████████████| 203/203 [00:05<00:00, 37.50it/s, Loss=0.557]


Epoch 80/100, Loss: 0.5363065482947627


Epoch 81/100: 100%|███████████████| 203/203 [00:05<00:00, 36.61it/s, Loss=0.541]


Epoch 81/100, Loss: 0.5347075725130259


Epoch 82/100: 100%|███████████████| 203/203 [00:05<00:00, 38.08it/s, Loss=0.539]


Epoch 82/100, Loss: 0.5331640801406259


Epoch 83/100: 100%|███████████████| 203/203 [00:05<00:00, 36.74it/s, Loss=0.519]


Epoch 83/100, Loss: 0.5316092464430578


Epoch 84/100: 100%|████████████████| 203/203 [00:05<00:00, 37.88it/s, Loss=0.54]


Epoch 84/100, Loss: 0.5301121545249018


Epoch 85/100: 100%|███████████████| 203/203 [00:05<00:00, 37.82it/s, Loss=0.549]


Epoch 85/100, Loss: 0.528668934048103


Epoch 86/100: 100%|███████████████| 203/203 [00:05<00:00, 37.16it/s, Loss=0.556]


Epoch 86/100, Loss: 0.5272505364100921


Epoch 87/100: 100%|███████████████| 203/203 [00:05<00:00, 36.83it/s, Loss=0.532]


Epoch 87/100, Loss: 0.5257413537044243


Epoch 88/100: 100%|███████████████| 203/203 [00:05<00:00, 36.25it/s, Loss=0.517]


Epoch 88/100, Loss: 0.5243748649587772


Epoch 89/100: 100%|███████████████| 203/203 [00:05<00:00, 35.74it/s, Loss=0.537]


Epoch 89/100, Loss: 0.5229885184706138


Epoch 90/100: 100%|████████████████| 203/203 [00:05<00:00, 37.30it/s, Loss=0.54]


Epoch 90/100, Loss: 0.5216904690700211


Epoch 91/100: 100%|███████████████| 203/203 [00:05<00:00, 36.80it/s, Loss=0.528]


Epoch 91/100, Loss: 0.5203613690261183


Epoch 92/100: 100%|████████████████| 203/203 [00:05<00:00, 37.10it/s, Loss=0.52]


Epoch 92/100, Loss: 0.519030998493063


Epoch 93/100: 100%|███████████████| 203/203 [00:05<00:00, 37.69it/s, Loss=0.539]


Epoch 93/100, Loss: 0.5177916670080476


Epoch 94/100: 100%|███████████████| 203/203 [00:05<00:00, 36.96it/s, Loss=0.512]


Epoch 94/100, Loss: 0.5165017342626168


Epoch 95/100: 100%|███████████████| 203/203 [00:05<00:00, 36.40it/s, Loss=0.519]


Epoch 95/100, Loss: 0.5152838782136664


Epoch 96/100: 100%|███████████████| 203/203 [00:05<00:00, 37.15it/s, Loss=0.534]


Epoch 96/100, Loss: 0.5140500396049669


Epoch 97/100: 100%|███████████████| 203/203 [00:05<00:00, 37.60it/s, Loss=0.543]


Epoch 97/100, Loss: 0.5128502359824815


Epoch 98/100: 100%|███████████████| 203/203 [00:05<00:00, 36.96it/s, Loss=0.513]


Epoch 98/100, Loss: 0.5116390903007808


Epoch 99/100: 100%|███████████████| 203/203 [00:05<00:00, 36.76it/s, Loss=0.509]


Epoch 99/100, Loss: 0.5104795158496631


Epoch 100/100: 100%|████████████████| 203/203 [00:05<00:00, 37.27it/s, Loss=0.5]


Epoch 100/100, Loss: 0.5093180353711979
End.


Evaluating: 100%|██████████████████████| 207724/207724 [10:31<00:00, 329.04it/s]

Final Hit Ratio@5: 0.0542



