In [1]:
#!git clone https://github.com/AlterVX22/Filtering_via_RetNet.git

In [2]:
import sys

# Put the checkout directory (the one the commented-out `git clone` in the first cell
# would create) on the module search path. The membership test keeps the cell
# idempotent: re-running it no longer stacks duplicate entries on sys.path.
RETNET_REPO_DIR = "Filtering_via_RetNet"
if RETNET_REPO_DIR not in sys.path:
    sys.path.append(RETNET_REPO_DIR)

In [3]:
import copy

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim


from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder

In [4]:
import pickle

# The pickle is published at: https://www.kaggle.com/datasets/vxmindset22/okko-data
# NOTE(review): unpickling can execute arbitrary code stored in the file, so only load
# a copy that was obtained from a source you trust.
with open('datasets.pkl', mode='rb') as dataset_file:
    datasets = pickle.load(dataset_file)

In [5]:
# Show the loaded object. It is a dict keyed by split name ('train' / 'test'); the output
# below shows the 'train' entry as a tuple of a userId/movieId DataFrame and a
# `rating_scaled` Series.
datasets

{'train': (         userId  movieId
  1926171  185543     2016
  1926172   50933     3284
  1926173  230764       68
  1926174  186260     1473
  1926175   94969      493
  ...         ...      ...
  9630848  479013     3766
  9630849   29856      876
  9630850  277832     4771
  9630851   77545      491
  9630852  260632     6147
  
  [7704682 rows x 2 columns],
  1926171    0.189311
  1926172    0.390769
  1926173    0.010635
  1926174    0.102065
  1926175    0.200385
               ...   
  9630848    0.089422
  9630849    0.200385
  9630850    0.176240
  9630851    0.190627
  9630852    0.200385
  Name: rating_scaled, Length: 7704682, dtype: float64),
 'test': (         userId  movieId
  0             0        0
  1             1        1
  2             2        2
  3             3        3
  4             4        4
  ...         ...      ...
  1926166  236143      587
  1926167  120603      130
  1926168   56511      245
  1926169  149668      493
  1926170  215417      302
  


In [6]:
# Pull out the feature frame (element 0 of each split's tuple) for both splits at once.
X_train, X_test = (datasets[split][0] for split in ('train', 'test'))

In [7]:
# Every user / movie id that appears in either split.
combined_users = pd.concat([X_train["userId"], X_test["userId"]])
combined_movie = pd.concat([X_train["movieId"], X_test["movieId"]])

# Size of each id space handed to the model.
# The number of distinct ids equals "largest id + 1" only when the ids are dense
# (0..n-1). A lookup table indexed by raw id needs one row per possible id value, so
# take whichever of the two is larger; for dense ids this is identical to nunique().
# NOTE(review): presumably NeuralColabFilteringRetNet uses these as embedding-table
# sizes — confirm against its definition.
user_count = max(combined_users.nunique(), int(combined_users.max()) + 1)
movie_count = max(combined_movie.nunique(), int(combined_movie.max()) + 1)

In [8]:
# Mini-batch size passed to the batch iterator in the training loop.
batch_size = 20_000

# RetNet decoder hyper-parameters, collected in one mapping and expanded into the
# config object below.
retnet_hparams = dict(
    vocab_size=200,
    decoder_layers=8,
    decoder_retention_heads=4,
    decoder_embed_dim=200,
    decoder_value_embed_dim=200,
    decoder_ffn_embed_dim=200,
    chunkwise_recurrent=False,
)
retnet_config = RetNetConfig(**retnet_hparams)

# Decoder instance built from that configuration; it is later handed to the
# collaborative-filtering wrapper.
retnet_model = RetNetDecoder(retnet_config)

In [9]:
class RMSELoss(nn.Module):
    """Square root of an ``nn.MSELoss`` value.

    Args:
        reduction: forwarded unchanged to ``nn.MSELoss``. With the default ``'sum'``
            the result is the square root of the *summed* squared error of the
            inputs; with ``'mean'`` it is the square root of the mean squared error.
    """

    def __init__(self, reduction='sum'):
        super().__init__()
        self.mse = nn.MSELoss(reduction=reduction)

    def forward(self, y_pred, y_true):
        """Return ``sqrt(mse(y_pred, y_true))`` as a tensor."""
        squared_error = self.mse(y_pred, y_true)
        return torch.sqrt(squared_error)

In [10]:
# Use the first CUDA GPU when PyTorch reports one as available; otherwise stay on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [11]:
from filteringRetNet import DatasetBatchIterator
from filteringRetNet import NeuralColabFilteringRetNet

In [12]:
# --- Training-loop bookkeeping ------------------------------------------------
max_epochs = 1                        # upper bound on epochs run by the loop
early_stop_epoch_threshold = 3        # non-improving epochs tolerated before stopping
no_loss_reduction_epoch_counter = 0   # epochs since the test loss last improved
min_loss = float('inf')               # lowest test loss seen so far
min_loss_model_weights = None         # weights snapshot taken at that lowest loss
history = []                          # per-epoch stats dicts

# --- Model --------------------------------------------------------------------
# Wrapper around the RetNet decoder; its hidden size follows the decoder's
# embedding dimension.
ncf_retnet = NeuralColabFilteringRetNet(
    user_count,
    movie_count,
    retnet_model,
    hidden_size=retnet_config.decoder_embed_dim,
    device=device,
)
ncf_retnet = ncf_retnet.to(device)

# --- Objective and optimiser --------------------------------------------------
loss_criterion = RMSELoss(reduction='sum')
loss_criterion = loss_criterion.to(device)
#loss_criterion_2 = nn.L1Loss(reduction='sum').to(device)
optimizer = optim.Adam(ncf_retnet.parameters(), weight_decay=1e-4, lr=1e-3)

In [13]:
import math
import time
from tqdm import tqdm

In [14]:
# Train / evaluate loop.
#
# Each epoch runs one optimisation pass over the 'train' split followed by one
# gradient-free pass over the 'test' split. A deep copy of the weights with the lowest
# test loss is kept in `min_loss_model_weights`, and the loop stops once the test loss
# has failed to improve for `early_stop_epoch_threshold` consecutive epochs.
#
# NOTE(review): `loss_criterion` is RMSELoss(reduction='sum'), i.e. the square root of
# the summed squared error of one batch. Adding those per-batch values and dividing by
# the number of rows yields a figure that depends on `batch_size`; it is not a
# dataset-level RMSE. The computation is left as it was so results stay comparable with
# earlier runs.
training_start_time = time.perf_counter()

# Bound before the loop so the early-stopping message cannot hit an undefined name when
# the test loss never drops below the initial `min_loss` (for example when it is NaN).
min_epoch_number = 0

for epoch in range(max_epochs):
    stats = {'epoch': epoch + 1, 'total': max_epochs}
    epoch_start_time = time.perf_counter()

    # Every epoch runs training on the train set, followed by evaluation on the test set.
    for phase in ('train', 'test'):
        is_training = phase == 'train'
        ncf_retnet.train(is_training)

        phase_features = datasets[phase][0]
        phase_targets = datasets[phase][1]
        running_loss = 0.0

        # Only used as the progress-bar length.
        total_batches = len(phase_features) // batch_size
        batch_iterator = DatasetBatchIterator(
            phase_features,
            phase_targets,
            batch_size=batch_size,
            shuffle=is_training,
        )

        for x_batch, y_batch in tqdm(batch_iterator, desc=f'{phase} phase', total=total_batches):
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)

            # PyTorch accumulates gradients by default, so clear them before each step.
            optimizer.zero_grad()

            # Gradients are only needed while training.
            with torch.set_grad_enabled(is_training):
                # Column 0 / column 1 of the batch are passed as the two model inputs.
                outputs = ncf_retnet(x_batch[:, 0], x_batch[:, 1])
                loss = loss_criterion(outputs, y_batch)
                if is_training:
                    loss.backward()
                    optimizer.step()
            running_loss += loss.item()

        # Per-row loss for this phase (see the review note at the top of the cell).
        stats[phase] = running_loss / len(phase_features)

    stats['time'] = time.perf_counter() - epoch_start_time
    # Record each epoch exactly once; appending inside the phase loop stored the same
    # dict twice per epoch.
    history.append(stats)
    print('Epoch [{epoch:03d}/{total:03d}][Time:{time:.2f} sec] Train Loss: {train:.4f} / Validation Loss: {test:.4f}'.format(**stats))

    # Best-checkpoint tracking and early stopping, driven by the test-phase loss.
    validation_loss = stats['test']
    if validation_loss < min_loss:
        min_loss = validation_loss
        # state_dict() hands back references to the live tensors, so take a deep copy to
        # freeze the weights as they are right now.
        min_loss_model_weights = copy.deepcopy(ncf_retnet.state_dict())
        no_loss_reduction_epoch_counter = 0
        min_epoch_number = epoch + 1
    else:
        no_loss_reduction_epoch_counter += 1

    if no_loss_reduction_epoch_counter >= early_stop_epoch_threshold:
        print(f'Early stopping applied. Minimal epoch: {min_epoch_number}')
        break

print(f'Training completion duration: {(time.perf_counter() - training_start_time):.2f} sec. Validation Loss: {min_loss}')

train phase:   1%|▎                                                                  | 2/385 [01:13<3:55:17, 36.86s/it]


KeyboardInterrupt: 