In [1]:
class Args():
    cuda_visible_devices= 0
    data_path= 'Beauty.txt'
    max_length= 200
    full_negative_sampling= False
    num_negatives = None

    batch_size= 128
    test_batch_size= 256
    num_workers= 8
    validation_size= 10000
    model= "SASRec"

    maxlen= 200
    hidden_units= 64
    num_blocks= 2
    num_heads= 1
    dropout_rate= 0.1

    lr= 0.001
    predict_top_k= 10
    filter_seen= True

    max_epochs= 100
    patience= 10
    sampled_metrics= False
    top_k_metrics=10

    add_head = False

In [2]:
import time
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ModelSummary
import torch
from torch.utils.data import DataLoader
#from metrics import compute_metrics
from datasets import (CausalLMDataset, CausalLMPredictionDataset, MaskedLMDataset,
                          MaskedLMPredictionDataset, PaddingCollateFn)
from models import RNN, BERT4Rec, SASRec
from modules import SeqRec, SeqRecWithSampling
from postprocess import preds2recs
from preprocess import add_time_idx



In [3]:
config = Args()

In [4]:
data = pd.read_csv(config.data_path, sep=' ', header=None, names=['user_id', 'item_id'])
data = add_time_idx(data, sort=False)

# index 1 is used for masking value
if config.model == 'BERT4Rec':
    data.item_id += 1

print(data.shape)
data.head()

(394908, 4)


Unnamed: 0,user_id,item_id,time_idx,time_idx_reversed
0,1,12888,0,41
1,1,49583,1,40
2,1,1,2,39
3,1,4733,3,38
4,1,5761,4,37


In [5]:
train = data[data.time_idx_reversed >= 2]
validation = data[data.time_idx_reversed == 1]
validation_full = data[data.time_idx_reversed >= 1]
test = data[data.time_idx_reversed == 0]

In [6]:
validation_size = config.validation_size
validation_users = validation_full.user_id.unique()
if validation_size and (validation_size < len(validation_users)):
    validation_users = np.random.choice(validation_users, size=validation_size, replace=False)

if config.model in ['SASRec', 'RNN']:
    train_dataset = CausalLMDataset(train, config.max_length,config.full_negative_sampling)
    eval_dataset = CausalLMPredictionDataset(
        validation_full[validation_full.user_id.isin(validation_users)],
        max_length=config.max_length, validation_mode=True)
elif config.model == 'BERT4Rec':
    train_dataset = MaskedLMDataset(train, config.max_length,config.full_negative_sampling)
    eval_dataset = MaskedLMPredictionDataset(
        validation_full[validation_full.user_id.isin(validation_users)],
        max_length=config.max_length, validation_mode=True)

train_loader = DataLoader(
    train_dataset, shuffle=True,
    collate_fn=PaddingCollateFn(),
    batch_size=config.batch_size)
eval_loader = DataLoader(
    eval_dataset, shuffle=False,
    collate_fn=PaddingCollateFn(),
    batch_size=config.test_batch_size)

In [7]:
item_count = data.item_id.max()

if hasattr(config, 'num_negatives') and config.num_negatives:
    config.add_head = False
else:
    config.add_head = True

if config.model == 'SASRec':
    model = SASRec(item_num=item_count, maxlen=config.max_length, add_head=config.add_head)
if config.model == 'BERT4Rec':
    model = BERT4Rec(vocab_size=item_count + 1, maxlen=config.max_length, add_head=add_head,
                     bert_config=config.model_params)
elif config.model == 'RNN':
    model = RNN(vocab_size=item_count + 1, add_head=add_head,
                rnn_config=config.model_params)

In [8]:
if (config.num_negatives != None) :
    seqrec_module = SeqRecWithSampling(model, config.lr,config.predict_top_k,config.filter_seen)
else:
    seqrec_module = SeqRec(model, config.lr,config.predict_top_k,config.filter_seen)
    


In [9]:
for step, batch in enumerate(train_loader):
    continue

In [None]:
train_list = []
val_list = []
model_path ='recommendation.pth'
best_ndcg = -float('inf')

for e in range(config.max_epochs):
    e = e + 1
    print(
        f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'ndcg':^12} | {'hit_rate':^10} | {'mrr':^9} | {'Elapsed':^9}")
    print("-" * 86)
    t0_epoch, t0_batch = time.time(), time.time()
    total_loss = 0
    batch_counts = 0
    for step, batch in enumerate(train_loader):
        batch_counts += 1
        
        #print(batch['input_ids'][0][0].item())       
        #print(type(batch['input_ids'][0][0].item()))
        #if(batch['input_ids'][0][0].item() == 0.0):
        #    print(batch['input_ids'][0])
        loss = seqrec_module.training_step(batch,step)
        loss.backward()
        total_loss += loss
        print(
                f"{e:^7} | {step:^7} | {total_loss / batch_counts:^12.6f} | {'-':^14.6} | {'-':^10} | {'-':^9} | {'-':^9.2}")

    train_list.append(total_loss/len(train_loader))
    if eval_loader:
        # After the completion of each training epoch, measure the model's performance
        # on our validation set.        
        ndcg,hit_rate,mrr = 0, 0, 0
        matrix = []
        for step, batch in enumerate(eval_loader):
            x,y,z = seqrec_module.validation_step(batch,step)
            print(x,y,z)
            ndcg += x
            hit_rate += y
            mrr += z

        time_elapsed = time.time() - t0_epoch

        print("-" * 86)
        print(
            f"{'end':^7} | {'-':^7} | {total_loss/len(train_loader):^12.6f} | {ndcg:^14.6} | {hit_rate:^10.6f} | {mrr:^9.2f} | {time_elapsed:^9.2f}")
        print("-" * 86)
        print("\n")
        
        matrix.append(ndcg)
        matrix.append(hit_rate)
        matrix.append(mrr)
            
        val_list.append(matrix)

    if (ndcg > 0):
        best_ndcg = ndcg
        torch.save(model.state_dict(), model_path)

 Epoch  |  Batch  |  Train Loss  |     ndcg     |  hit_rate  |    mrr    |  Elapsed 
--------------------------------------------------------------------------------------
   1    |    0    |  10.968785   |       -        |     -      |     -     |     -    
   1    |    1    |  10.964926   |       -        |     -      |     -     |     -    
   1    |    2    |  10.967954   |       -        |     -      |     -     |     -    
   1    |    3    |  10.966573   |       -        |     -      |     -     |     -    
   1    |    4    |  10.965791   |       -        |     -      |     -     |     -    
   1    |    5    |  10.966888   |       -        |     -      |     -     |     -    
   1    |    6    |  10.968060   |       -        |     -      |     -     |     -    
   1    |    7    |  10.966969   |       -        |     -      |     -     |     -    
   1    |    8    |  10.967381   |       -        |     -      |     -     |     -    
   1    |    9    |  10.967466   |       -   

/Users/mac/Library/Python/3.9/lib/python/site-packages/pytorch_lightning/core/module.py:436: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
--------------------------------------------------------------------------------------
  end   |    -    |  10.968657   |      0.0       |  0.000000  |   0.00    |  2626.74 
--------------------------------------------------------------------------------------


 Epoch  |  Batch  |  Train Loss  |     ndcg     |  hit_rate  |    mrr    |  Elapsed 
--------------------------------------------------------------------------------------
   2    |    0    |  10.976839   |       -        |     -      |     -     |     -   

In [None]:
if eval_loader:
        # After the completion of each training epoch, measure the model's performance
        # on our validation set.        
        val_loss = 0
        for step, batch in enumerate(eval_loader):
            print(batch)
            val_loss += seqrec_module.validation_step(batch,step)
            
        val_list.append(val_loss)