In [None]:
import pickle
import pandas as pd
from recent_neighbor import RecentNeighbor
import os
import time
import torch
import pickle
from model import CLTSBR
from utils import *
from torch.utils.data import DataLoader, Dataset
from collections import Counter
# import torch.nn as nn   

In [None]:
# Load the pickled training split used to build the neighbour index.
# The payload is read positionally: [0] -> ids, [1] -> item sequences,
# [2] -> timestamps, [3] -> the item to predict for each sequence.
# NOTE(security): pickle.load can execute arbitrary code; only point this at
# files you produced yourself.
with open('data/diginetica/train_session_3.txt', 'rb') as train_file:
    train_data = pickle.load(train_file)
train_id = train_data[0]
train_session = train_data[1]
train_timestamp = train_data[2]
train_predict = train_data[3]

# Append each sequence's prediction target to the sequence itself, in place,
# so the objects handed to RecentNeighbor below include that item.
for i in range(len(train_session)):
    train_session[i] += [train_predict[i]]

# Same positional layout for the test split.
with open('data/diginetica/test_session_3.txt', 'rb') as test_file:
    test_data = pickle.load(test_file)
test_id = test_data[0]
test_session = test_data[1]
test_timestamp = test_data[2]
test_predict = test_data[3]

# Neighbour lookup built from the training sessions only.
# NOTE(review): the meaning of sample_size / k / factor* / l* is defined in
# recent_neighbor.RecentNeighbor, which is not visible here -- confirm there.
pick_neighbor = RecentNeighbor(
    session_id=train_id,
    session=train_session,
    session_timestamp=train_timestamp,
    sample_size=0,
    k=500,
    factor1=True,
    l1=1.25,
    factor2=True,
    l2=80 * 24 * 3600,
    factor3=True,
    l3=22.5,
)


In [None]:
class Args:
    # Plain namespace of run settings; every value is a class attribute so
    # `Args()` needs no constructor arguments.

    # --- run identification (joined into the checkpoint folder name) ---
    dataset_name = "diginetica"
    train_dir = "default"

    # --- optimisation ---
    batch_size = 2
    lr = 0.001
    num_epochs = 2
    l2_emb = 0.0

    # --- sequence / model shape (handed to the model through `args`) ---
    maxlen = 50
    hidden_units = 50
    num_blocks = 2
    num_heads = 1
    num_layers = 6
    dropout_rate = 0.5

    # --- runtime behaviour ---
    device = "cpu"
    inference_only = False
    state_dict_path = None


args = Args()

In [None]:
# Load the training sessions used to build the model's dataset.
# NOTE: this rebinds `train_data`, which an earlier cell loaded from a
# different file ('train_session_3.txt'); later cells read this version.
# NOTE(security): pickle.load can execute arbitrary code; trusted files only.
with open('data/diginetica/train_session.txt', 'rb') as train_file:
    train_data = pickle.load(train_file)

# Map each id in train_data[0] to the sequence at the same position in
# train_data[1].
user_train = dict(zip(train_data[0], train_data[1]))

# Hard-coded dataset sizes handed to the dataset and the model.
# NOTE(review): presumably these match the diginetica split on disk -- confirm.
usernum, itemnum = 95425, 37522

In [None]:
# Every item occurrence across all training sequences, as one flat list.
flat_list = [item for sublist in train_data[1] for item in sublist]

# Occurrence count per item.
cnt = Counter(flat_list)

# The 7504 most frequent items form the "head"; every other item that
# appears in training is "tail".
most_common = cnt.most_common(7504)
head_items = {item for item, _ in most_common}
total_items = set(flat_list)
tail_items = total_items.difference(head_items)

In [None]:
class TrainDataset(Dataset):
    """Map-style dataset producing one left-padded training example per session.

    Args:
        user_train: mapping from session/user id to its item sequence; the
            last item of each sequence is used as the final target.
        usernum: number of users/sessions (stored, not used by this class).
        itemnum: largest item id; negatives are drawn via
            ``random_neq(1, itemnum + 1, ...)``.
        maxlen: fixed length of the returned ``seq`` / ``pos`` / ``neg`` arrays.
    """

    def __init__(self, user_train, usernum, itemnum, maxlen):
        self.user_train = user_train
        self.usernum = usernum
        self.itemnum = itemnum
        self.maxlen = maxlen
        # A DataLoader sampler hands out positions 0..len-1, while
        # `user_train` is keyed by id. Keep the keys in order so a position
        # can be translated into the id it stands for.
        self.users = list(user_train)

    def __len__(self):
        """Number of sessions available."""
        return len(self.users)

    def get_neighbour(self, user, seq):
        """Return the neighbouring sessions for ``user`` given its padded ``seq``.

        Delegates to the module-level ``pick_neighbor`` object, which must
        exist before the first item is requested.
        """
        neighboring_sessions = pick_neighbor.predict(session_id=user, session_items=seq)
        return neighboring_sessions

    def __getitem__(self, user_idx):
        """Build the example at position ``user_idx``.

        Returns:
            ``(user, seq, pos, neg, n_sess)`` where ``seq`` holds the input
            items right-aligned in a zero-padded array, ``pos`` the item that
            follows each input item, ``neg`` a sampled item for each filled
            position, and ``n_sess`` the squeezed neighbour array.
        """
        # Translate the sampler position into the real key of `user_train`.
        user = self.users[user_idx]
        history = self.user_train[user]

        seq = np.zeros(self.maxlen, dtype=np.int64)
        pos = np.zeros(self.maxlen, dtype=np.int64)
        neg = np.zeros(self.maxlen, dtype=np.int64)
        nxt = history[-1]
        idx = self.maxlen - 1

        # Items of this session, passed to random_neq alongside the id range.
        ts = set(history)
        # Walk backwards so the most recent items land at the right edge and
        # anything beyond `maxlen` is dropped.
        for i in reversed(history[:-1]):
            seq[idx] = i
            pos[idx] = nxt
            if nxt != 0:
                neg[idx] = random_neq(1, self.itemnum + 1, ts)
            nxt = i
            idx -= 1
            if idx == -1:
                break

        n_sess = np.array(self.get_neighbour(user, seq)).squeeze()
        return (user, seq, pos, neg, n_sess)

# Wrap the training sessions in a dataset and a shuffling loader.
train_dataset = TrainDataset(
    user_train=user_train,
    usernum=usernum,
    itemnum=itemnum,
    maxlen=args.maxlen,
)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
)

# Smoke check: pull a single batch and show the shape of its first field.
batch = next(iter(train_dataloader))
print(batch[0].shape)

In [None]:
# Build the model on the configured device and start in training mode.
model = CLTSBR(usernum, itemnum, args).to(args.device)
model.train()

# Optionally resume from a checkpoint. The starting epoch is parsed from the
# file name, which is expected to contain 'epoch=<n>.'.
epoch_start_idx = 1
if args.state_dict_path is not None:
    try:
        state_dict = torch.load(args.state_dict_path, map_location=torch.device(args.device))
        model.load_state_dict(state_dict)
        tail = args.state_dict_path[args.state_dict_path.find('epoch=') + 6:]
        epoch_start_idx = int(tail[:tail.find('.')]) + 1
    except Exception as err:
        # Best-effort restore: report what went wrong and carry on from
        # epoch 1. `Exception` (not a bare except) lets Ctrl-C / SystemExit
        # propagate.
        print('failed loading state_dicts, pls check file path: ', end="")
        print(args.state_dict_path)
        print('reason: {!r}'.format(err))


if args.inference_only:
    model.eval()
    # NOTE(review): `Args` defines `dataset_name` but no `dataset`, so this
    # line raises AttributeError as written. What `evaluate` expects as its
    # second argument is defined in utils -- confirm and fix there.
    t_test = evaluate(model, args.dataset, args, head_items, tail_items)
    print('test (NDCG@10: %.4f, HR@10: %.4f)' % (t_test[0], t_test[1]))


# Binary cross-entropy on raw logits; Adam with the configured learning rate.
criterion = torch.nn.BCEWithLogitsLoss() # torch.nn.BCELoss()
adam_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98))


In [None]:
# T accumulates wall-clock time between evaluations; t0 marks the start of
# the current timing window.
T = 0.0
t0 = time.time()

for epoch in range(epoch_start_idx, args.num_epochs + 1):
    if args.inference_only: break
    for batch_idx, (u, seq, pos, neg, n_sess) in enumerate(train_dataloader):
        # The model is called with numpy arrays for everything except the
        # neighbour tensor.
        u, seq, pos, neg = np.array(u), np.array(seq), np.array(pos), np.array(neg)
        pos_logits, neg_logits = model(u, seq, pos, neg, n_sess.squeeze().type(torch.LongTensor))
        pos_labels, neg_labels = torch.ones(pos_logits.shape, device=args.device), torch.zeros(neg_logits.shape, device=args.device)
        adam_optimizer.zero_grad()
        # Only positions with a non-zero target contribute to the loss
        # (zero is the padding value).
        indices = np.where(pos != 0)
        loss = criterion(pos_logits[indices], pos_labels[indices])
        loss += criterion(neg_logits[indices], neg_labels[indices])
        # Optional penalty on the item-embedding parameters (off when l2_emb == 0).
        for param in model.item_emb.parameters(): loss += args.l2_emb * torch.norm(param)
        loss.backward()
        adam_optimizer.step()
        print("loss in epoch {} iteration {}: {}".format(epoch, batch_idx, loss.item()))
    if epoch % 10 == 0:
        model.eval()
        t1 = time.time() - t0
        T += t1
        print('Evaluating', end='')
        # NOTE(review): `Args` has no `dataset` attribute, and this call
        # passes fewer arguments than the inference-only evaluate call
        # earlier in the notebook -- confirm the intended signature in utils.
        t_test = evaluate(model, args.dataset, args)
        # Restart the timing window so the next `T += t1` does not re-count
        # the time already added (nor the evaluation itself).
        t0 = time.time()
        model.train()

    if epoch == args.num_epochs:
        folder = args.dataset_name + '_' + args.train_dir
        fname = 'CLTSBR.epoch={}.lr={}.layer={}.head={}.hidden={}.maxlen={}.pth'
        fname = fname.format(args.num_epochs, args.lr, args.num_blocks, args.num_heads, args.hidden_units, args.maxlen)
        # torch.save does not create missing directories; without this the
        # whole run is lost at the very last step on a fresh checkout.
        os.makedirs(folder, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(folder, fname))

print("Done")
