### Change the data_name and model_name at tmpargs class, then Run All. The result will save at final_result/*.csv

In [1]:
import torch
import os
import numpy as np

class tmpargs():
    def __init__(self):
        super().__init__()
        self.data_dir = 'data/'
        self.output_dir = 'final_result/'
        self.data_name = 'LastFM'   # Change data_name Here.   
        self.do_eval = None
        self.ckp = 50
        self.model_name = 'SASRec'  # Change model_name Here.  
        self.hidden_size = 64
        self.num_hidden_layers = 2
        self.num_attention_heads = 2
        self.hidden_act = 'gelu'
        self.attention_probs_dropout_prob = 0.5
        self.hidden_dropout_prob = 0.5
        self.initializer_range = 0.02
        self.max_seq_length = 50
        self.lr = 0.001
        self.batch_size = 256
        self.epochs = 50
        self.no_cuda = None
        self.log_freq = 1
        self.seed = 42
        self.weight_decay = 0.0
        self.adam_beta1 = 0.9
        self.adam_beta2 = 0.999
        self.gpu_id = 0
        self.cuda_condition = torch.cuda.is_available() and not self.no_cuda
        self.data_file = self.data_dir + self.data_name + '.txt'
        self.sample_file = self.data_dir + self.data_name + '_sample.txt'
        self.item2attribute = None
        self.item_size = 104546+2
        self.mask_id = 104546+1
        self.attribute_size = 1 #没用到预训练模型
        args_str = f'{self.model_name}-{self.data_name}-{self.ckp}'
        self.log_file = os.path.join(self.output_dir, args_str + '.txt')
        checkpoint = args_str + '.pt'
        self.checkpoint_path = os.path.join(self.output_dir, checkpoint)
        self.isfull = 0
        self.sample_num =99
        self.loss_type = None
        self.writer = None
        self.RQ3 = 1
args = tmpargs()

from utils import EarlyStopping, get_user_seqs
user_seq, max_item, valid_rating_matrix, test_rating_matrix = \
        get_user_seqs(args.data_file)
args.train_matrix = valid_rating_matrix
args.item_size = max_item + 2
args.mask_id = max_item + 1

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# load user seq and sample seq
user_seq,sample_seq = [], []
with open(args.data_file,'r') as f:
    for l in f:
        user, items = l.strip().split(' ', 1)
        items = items.split(' ')
        items = [int(item) for item in items]
        user_seq.append(items)

with open(args.sample_file,'r') as f:
    for l in f:
        user, items = l.strip().split(' ', 1)
        items = items.split(' ')
        items = [int(item) for item in items]
        sample_seq.append(items)


In [3]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from datasets import SASRecDataset
from trainers import FinetuneTrainer
from models import S3RecModel,GRU4Rec
train_dataset = SASRecDataset(args, user_seq, data_type='train')
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.batch_size)

eval_dataset = SASRecDataset(args, user_seq, data_type='valid')
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.batch_size)

test_dataset = SASRecDataset(args, user_seq, test_neg_items=None, data_type='test')
test_sampler = SequentialSampler(test_dataset)
test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=args.batch_size)

if args.model_name == "GRU4Rec":
    model = GRU4Rec(args=args)
else:
    model = S3RecModel(args=args)

trainer = FinetuneTrainer(model, train_dataloader, eval_dataloader,
                            test_dataloader, args)

Total Parameters: 315456


In [4]:
import tqdm
from utils import recall_at_k, ndcg_k, get_metric, neg_sample

import math

def ndcg_at_k(rank, topk = 10):
    metric = 0
    tmp_rank = rank[rank<topk]
    dcg_k = 0
    for i in tmp_rank:
        dcg_k += 1 / math.log(i+2,2)
    return dcg_k / len(rank)

def get_metric_at(time):
    train = False
    str_code = "train" if train else "test"
    epoch = 0
    dataloader = trainer.test_dataloader

    rank = []
    for i, batch in enumerate(dataloader):
        # 0. batch_data will be sent into the device(GPU or cpu)
        batch = tuple(t.to(trainer.device) for t in batch)
        user_ids, input_ids, target_pos, target_neg, answers = batch
        answers = target_pos[:,time]
        input_ids = input_ids[answers!=0,:]
        answers = answers[answers!=0]
        
        recommend_output = trainer.model.finetune(input_ids)
        recommend_output = recommend_output[:, time, :]
        rating_pred = trainer.predict_full(recommend_output)
        rating_pred = rating_pred.cpu().data.numpy().copy()

        tag = rating_pred[ np.arange(answers.shape[0]), answers.cpu().numpy() ]
        
        rating_pred.sort(axis=-1)
        rating_pred = rating_pred[:,-1::-1]
        
        rank += list(np.where(rating_pred == tag.reshape(-1,1))[1])
    rank = np.array(rank)
    
    metric = ndcg_at_k(rank)

    print( sum(rank<100), round(sum(rank<100)/rank.shape[0],4), rank.shape[0] ,metric)
    return sum(rank<100), round(sum(rank<100)/rank.shape[0],4), rank.shape[0] ,metric

In [5]:
file_name = 'final_result/GRU4Rec/' + args.data_name + '/GRU4Rec_CE-' + args.data_name + '-0.pt'
model.load_state_dict(torch.load(file_name))

res = []
for time in range(50):
    rank, rate ,size, metric = get_metric_at(time)
    res.append([rank, rate ,size, metric])

17 0.0625 272 0.0011067279252352249
20 0.0733 273 0.008350651881650253
20 0.0719 278 0.006123444906547735
29 0.1025 283 0.01276153248449223
18 0.0621 290 0.006664139706151445
26 0.0884 294 0.006963168131030963
19 0.0644 295 0.0029198410716840205
29 0.0963 301 0.0054758699711900385
22 0.0717 307 0.00997747723836737
25 0.0781 320 0.009379622501668539
24 0.0734 327 0.011446666649365637
30 0.0893 336 0.008255076233538938
30 0.0872 344 0.014327110970376691
33 0.0943 350 0.006791631966944512
27 0.0763 354 0.01201505060232395
49 0.1357 361 0.009380393782039729
38 0.1027 370 0.010172767555851678
33 0.0878 376 0.002475203611897322
34 0.0876 388 0.009142194408564699
36 0.0909 396 0.006934483684640193
43 0.1059 406 0.01346723315844987
42 0.1014 414 0.00920027023102921
54 0.1286 420 0.012757585190211005
43 0.1002 429 0.006154267413653194
42 0.0966 435 0.0022988505747126436
31 0.0697 445 0.008123878169766424
45 0.0989 455 0.003287288890426595
37 0.0786 471 0.005382603526915095
37 0.076 487 0.005101

In [6]:
file_name = 'final_result/GRU4Rec/' + args.data_name + '/GRU4Rec_PointwiseCE-' + args.data_name + '-0.pt'
model.load_state_dict(torch.load(file_name))

res1 = []
for time in range(50):
    rank, rate ,size, metric = get_metric_at(time)
    res1.append([rank, rate ,size, metric])

36 0.1324 272 0.010734845541452495
28 0.1026 273 0.003892932017534693
39 0.1403 278 0.008398297259444819
44 0.1555 283 0.01615143914675128
39 0.1349 289 0.007640697288002788
48 0.1633 294 0.013393799798283484
39 0.1322 295 0.002674796021416569
43 0.1433 300 0.005593670613648534
46 0.1498 307 0.007514160870110381
47 0.1469 320 0.014634150200977325
56 0.1718 326 0.022788833285378843
61 0.1815 336 0.015877403913536797
57 0.1662 343 0.0182914947541215
51 0.1461 349 0.009486287032325923
65 0.1836 354 0.02311013215138795
66 0.1828 361 0.017613903023879853
75 0.2027 370 0.021885780351259744
60 0.1596 376 0.008053015657105977
67 0.1727 388 0.012984817486300688
69 0.1742 396 0.004809752042406997
71 0.1753 405 0.0209110453691573
81 0.1957 414 0.016965281142556847
69 0.1643 420 0.01572858115409202
69 0.1612 428 0.009571991992574438
74 0.1701 435 0.014333798464940522
71 0.1599 444 0.008418907977762397
65 0.1429 455 0.005827658331895501
72 0.1529 471 0.007754025243743767
77 0.1584 486 0.01223571338

In [7]:
file_name = 'final_result/GRU4Rec/' + args.data_name + '/GRU4Rec_CE_ALL-' + args.data_name + '-0.pt'
model.load_state_dict(torch.load(file_name))

res2 = []
for time in range(50):
    rank, rate ,size, metric = get_metric_at(time)
    res2.append([rank, rate ,size, metric])

37 0.136 272 0.007274284363085453
30 0.1099 273 0.004830815225164762
36 0.1295 278 0.0037616824012567545
41 0.1449 283 0.016526799406714518
35 0.1211 289 0.006165371806964448
45 0.1531 294 0.008186053933886894
33 0.1119 295 0.0023807379119331197
45 0.15 300 0.011414941391229151
53 0.1726 307 0.006980752222441165
44 0.1375 320 0.016327923445830287
54 0.1656 326 0.019460282074565755
59 0.1756 336 0.01894072692908105
56 0.1633 343 0.016651283189204467
51 0.1461 349 0.014988659299522018
63 0.178 354 0.02234622908473074
74 0.205 361 0.019953625884519235
71 0.1914 371 0.022221686645063108
63 0.1676 376 0.007577994740441646
72 0.1856 388 0.01487625887097236
80 0.202 396 0.006921159100035949
77 0.1901 405 0.020252424168589297
76 0.1836 414 0.013074761790904322
77 0.1833 420 0.017046689177051096
70 0.1636 428 0.00850998049002817
76 0.1747 435 0.010412711908551829
74 0.1667 444 0.008384226120167914
70 0.1538 455 0.008829898997356261
72 0.1529 471 0.007919905610274076
85 0.1749 486 0.013208732522

In [8]:
import pandas as pd

col = ['count','rank','num', 'metric']
CE_result = pd.DataFrame(res,columns=col)
CE_result.to_csv(f"final_result/{args.data_name}_CE_result.csv")

BCEFT_result = pd.DataFrame(res1,columns=col)
BCEFT_result.to_csv(f"final_result/{args.data_name}_BCE_result.csv")

CEFT_result = pd.DataFrame(res2,columns=col)
CEFT_result.to_csv(f"final_result/{args.data_name}_CCE_result.csv")