In [1]:
'''
Author: Felix
Date: 2023-04-11 14:54:25
LastEditors: Felix
LastEditTime: 2023-04-11 19:39:55
Description: Train a BERT-style masked-item sequence model (BERTModel) on MovieLens-20M ratings
'''
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, random_split
from torch.utils.data import DataLoader
import random
from model import BERTModel,Trainer
import torch.nn as nn
import torch.optim as optim

In [2]:
from datasets import ML20MDataset

In [3]:
# Instantiate the project's ML20MDataset on the ml-20m data directory.
# NOTE(review): the positional arguments 4, 0 and 5 are defined by
# ML20MDataset, whose signature is not visible here — confirm their meaning.
# NOTE(review): `mldataset` is not referenced by any other visible cell.
mldataset = ML20MDataset("../bert4rec/data/ml-20m/",4,0,5)

In [4]:
# Locations of the MovieLens-20M ratings and movie-metadata CSV files.
data_csv_path, movies_path = (
    "../bert4rec/data/ml-20m/ratings.csv",
    "../bert4rec/data/ml-20m/movies.csv",
)

In [8]:
# Load the raw ratings table from the CSV path configured above.
data = pd.read_csv(filepath_or_buffer=data_csv_path)

In [12]:
# Inspect every rating row that belongs to user 34.
data.loc[data["userId"] == 34]

Unnamed: 0,userId,movieId,rating,timestamp
3903,34,1,5.0,846509445
3904,34,2,3.0,846509384
3905,34,7,3.0,846510487
3906,34,10,5.0,839249881
3907,34,15,3.0,846510556
...,...,...,...,...
3991,34,733,5.0,853510848
3992,34,736,4.0,844965523
3993,34,761,3.0,852036542
3994,34,780,4.0,847124405


In [5]:
# Order all ratings chronologically so that each user's grouped movie list
# (built in a later cell) is in viewing order.
# A stable sort keeps rows with identical timestamps in their original file
# order, so the resulting sequences do not depend on the sort implementation.
# The frame is rebound instead of mutated in place so re-running the cell is
# harmless and no other reference to the unsorted frame is silently changed.
data = data.sort_values(by="timestamp", kind="stable")

In [6]:
# Collect the distinct movie ids in ascending order; the position of a movie
# in this list determines its token id in the mapping built from it.
movies = data["movieId"].drop_duplicates().sort_values().tolist()

In [8]:
# Token ids 0 and 1 are reserved:
#   0 -> PAD
#   1 -> MASK
# so real movies are numbered from 2 upwards, following the order of `movies`.
movie_to_id = dict(zip(movies, range(2, len(movies) + 2)))
# Inverse lookup: token id -> original movie id.
id_to_movie = {token: movie for movie, token in movie_to_id.items()}

In [9]:
# One list of movie ids per user, in the current row order of `data`.
# Selecting the column before aggregating yields the same Series (indexed by
# userId, named "movieId") without first building per-user lists for every
# other column of the ratings table.
group_by_data = data.groupby(by='userId')["movieId"].agg(list)

In [10]:
# Drop the userId index: keep only the per-user movie-id lists, in index order.
groups_data = list(group_by_data)

In [11]:
# Split the users 80% / 10% / 10% into train / test / validation subsets.
# A seeded generator makes the split reproducible: without it every kernel
# restart assigns users to different subsets, so results cannot be compared
# across runs.
SPLIT_SEED = 42
train_set,test_set,validate_set = random_split(
    groups_data,
    [0.8,0.1,0.1],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [12]:
class BERTDataset(Dataset):
    """Map-style dataset that tokenizes, optionally masks, and pads item sequences.

    Every sample is a pair of ``LongTensor`` objects of length ``max_len``:
    the (possibly masked) token sequence, and a label vector that holds the
    original token at each position chosen as a prediction target and ``0``
    everywhere else.
    """

    # Label written at positions that are not prediction targets (unmasked
    # tokens and padding). It is deliberately independent of ``padding_id``.
    IGNORE_LABEL = 0

    def __init__(self, data, mapping, padding_id, mask_id, max_len = 128, train = False):
        """ Dataset class object for ml-20m dataset

        Args:
            data (list): indexable collection of sequences of movie ids
            mapping (dict): the dictionary that maps movie id to token id
            padding_id(int): the token id of [PAD]
            mask_id (int): the token id of [MASK]
            max_len (int, optional): the maximum length of a sequence. Defaults to 128.
            train (bool, optional): if this dataset is a training set. Defaults to False.
        """
        self.data = data
        self.mapping = mapping
        self.padding_id = padding_id
        self.masked_id = mask_id
        self.max_len = max_len
        self.train = train
        self.num_items = len(mapping)

    def __getitem__(self, index):
        """Return the ``(tokens, labels)`` tensors for the sequence at ``index``.

        Args:
            index (int): position of the sequence in ``self.data``

        Return:
            tokens(LongTensor) token ids, padded with ``padding_id`` to ``max_len``
            labels(LongTensor) original token at masked positions, 0 elsewhere
        """
        # Discard everything beyond max_len, then tokenize what is left.
        # NOTE(review): this keeps the FIRST max_len items of the sequence; if
        # sequences are chronological the most recent items are dropped —
        # confirm this is intended.
        tokens = [self.mapping[item] for item in self.data[index][:self.max_len]]
        if self.train:
            # Training samples get random masking and matching labels.
            tokens, labels = self.random_mask(tokens)
        else:
            # NOTE(review): outside training no position is a target, so the
            # label vector is all zeros — confirm how evaluation uses it.
            labels = [self.IGNORE_LABEL] * len(tokens)
        # Right-pad both vectors up to max_len. Labels are padded with the
        # ignore label rather than padding_id: a non-zero padding_id would
        # otherwise turn every padded position into a prediction target.
        padding_len = self.max_len - len(tokens)
        tokens = tokens + [self.padding_id] * padding_len
        labels = labels + [self.IGNORE_LABEL] * padding_len
        return torch.LongTensor(tokens), torch.LongTensor(labels)

    def __len__(self):
        """Number of sequences in the dataset."""
        return len(self.data)

    def random_mask(self, sequence):
        """randomly mask sequence using the following strategy:
           85% chance not to mask
           15% chance to mask
           when masking, 80% chance to use [MASK] to replace the token,
           10% chance to replace it with a random token and 10% chance to make no change

        Args:
            sequence(iterable) sequence of token ids to be masked

        Return:
            sequence(list) sequence after masking
            mask(list) label list: the original token at every selected
                position (whatever replaced it in the sequence), 0 elsewhere
        """
        tokens = []
        mask = []
        for s in sequence:
            # First draw: is this position a prediction target at all?
            if random.random() < 0.85:
                tokens.append(s)
                mask.append(self.IGNORE_LABEL)
                continue
            # Second draw: how the selected position appears in the input.
            prob = random.random()
            if prob < 0.8:
                tokens.append(self.masked_id)
            elif prob < 0.9:
                # Item tokens occupy [2, num_items + 1]; randint is inclusive.
                tokens.append(random.randint(2,self.num_items+1))
            else:
                tokens.append(s)
            # The label is always the original token.
            mask.append(s)
        return tokens, mask

In [13]:
# --- Training configuration -------------------------------------------------
# Prefer Apple-silicon MPS (the original setting), then CUDA, then CPU, so the
# notebook also runs on machines without an MPS backend.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
lr = 1e-3
epochs = 100
decay_steps=25   # StepLR: number of scheduler steps between learning-rate decays
gamma = 0.01     # StepLR: multiplicative decay factor
weight_decay = 0.01
# NOTE(review): movie_to_id assigns token ids 2 .. len(movies)+1 (0 = PAD,
# 1 = MASK). If BERTModel's second argument is the vocabulary size, it may
# need len(movies) + 2 — confirm against model.py.
model = BERTModel(128,len(movies),2,4,256,0.1)
optimizer = optim.Adam(model.parameters(),lr = lr, weight_decay= weight_decay)
# Label 0 marks "not a prediction target" in BERTDataset, so it is ignored.
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=decay_steps, gamma=gamma)
train_dataset = BERTDataset(train_set, movie_to_id, 0, 1, 128, True)
# Validate on the held-out validation split; the original built this dataset
# from train_set, so validation measured performance on training users.
# NOTE(review): with train=False BERTDataset emits all-zero labels, which
# loss_fn ignores — confirm how Trainer evaluates on this loader.
val_dataset = BERTDataset(validate_set, movie_to_id, 0, 1, 128, False)
train_loader = DataLoader(train_dataset,batch_size=64,shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset,batch_size=64,shuffle=False, drop_last=True,pin_memory=True)

In [7]:
# Only build a model here if the configuration cell has not already done so.
# Unconditionally re-creating it at this point would rebind `model` to a fresh
# network whose parameters are NOT the ones registered with `optimizer`, so
# the Trainer would receive a model that the optimizer never updates.
if "model" not in globals():
    model = BERTModel(128,len(movies),2,4,256,0.1)

In [14]:
# Wire model, data loaders, checkpoint directory, device and optimisation
# objects into the project's Trainer (arguments stay positional because the
# Trainer signature is defined in model.py).
bert_trainer = Trainer(
    model,
    train_loader,
    val_loader,
    './checkpoint/',
    device,
    optimizer,
    loss_fn,
    lr_scheduler,
    epochs,
)

In [15]:
# Start training via the project's Trainer.
# The recorded output below shows about 10 minutes per epoch with the loss
# staying near 7.68 across epochs 1-3; the run was stopped manually
# (KeyboardInterrupt) during epoch 3.
bert_trainer.train()

Epoch 1, loss 7.680707 : 100%|██████████| 1731/1731 [10:10<00:00,  2.84it/s]
Epoch 2, loss 7.685313 : 100%|██████████| 1731/1731 [10:04<00:00,  2.86it/s]
Epoch 3, loss 7.680290 :   3%|▎         | 50/1731 [00:17<09:43,  2.88it/s]


KeyboardInterrupt: 

SyntaxError: invalid syntax (1164068198.py, line 1)