In [1]:
import torch
import os
import time

import numpy as np
import pandas as pd
from SASRec.utils import *

In [2]:
from SASRec.model import SASRec, SASRecWithDiffusion

In [3]:
os.chdir('SASRec')

In [12]:
args = pd.Series(
    dict(
        dataset='ml-1m',
        train_dir='default',
        maxlen=200,
        dropout_rate=0.2,
        batch_size=128,
        device='cpu',
        hidden_units=50,
        lr=0.001,
        num_blocks=2,
        num_epochs=1000,
        num_heads=1,
        l2_emb=0.0,
        state_dict_path=None,
        n_recs=20,
    )
)

In [5]:
u2i_index, i2u_index = build_index(args.dataset)
dataset = data_partition(args.dataset)

In [6]:
[user_train, user_valid, user_test, usernum, itemnum] = dataset
num_batch = (len(user_train) - 1) // args.batch_size + 1

cc = 0.0
for u in user_train:
    cc += len(user_train[u])

print('average sequence length: %.2f' % (cc / len(user_train)))

average sequence length: 163.50


In [7]:
f = open(os.path.join(args.dataset + '_' + args.train_dir, 'log.txt'), 'w')
f.write('epoch (val_ndcg, val_hr) (test_ndcg, test_hr)\n')

46

In [8]:
sampler = WarpSampler(user_train, usernum, itemnum, batch_size=args.batch_size, maxlen=args.maxlen, n_workers=3)
model = SASRec(usernum, itemnum, args)

# Training Vanilla SASRec

In [9]:
for name, param in model.named_parameters():
    try:
        torch.nn.init.xavier_normal_(param.data)
    except:
        pass # just ignore those failed init layers

In [10]:
model.pos_emb.weight.data[0, :] = 0
model.item_emb.weight.data[0, :] = 0

model.train() # enable model training

epoch_start_idx = 1
if args.state_dict_path is not None:
    try:
        model.load_state_dict(torch.load(args.state_dict_path, map_location=torch.device(args.device)))
        tail = args.state_dict_path[args.state_dict_path.find('epoch=') + 6:]
        epoch_start_idx = int(tail[:tail.find('.')]) + 1
    except: # in case your pytorch version is not 1.6 etc., pls debug by pdb if load weights failed
        print('failed loading state_dicts, pls check file path: ', end="")
        print(args.state_dict_path)
        print('pdb enabled for your quick check, pls type exit() if you do not need it')
        import pdb; pdb.set_trace()

In [11]:
bce_criterion = torch.nn.BCEWithLogitsLoss() # torch.nn.BCELoss()
adam_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98))

best_val_ndcg, best_val_hr = 0.0, 0.0
best_test_ndcg, best_test_hr = 0.0, 0.0
T = 0.0
t0 = time.time()

In [12]:
for epoch in range(epoch_start_idx, args.num_epochs + 1):
    for step in range(num_batch):
        u, seq, pos, neg = sampler.next_batch() # tuples to ndarray
        u, seq, pos, neg = np.array(u), np.array(seq), np.array(pos), np.array(neg)
        
        pos_logits, neg_logits = model(u, seq, pos, neg)
        pos_labels, neg_labels = torch.ones(pos_logits.shape, device=args.device), torch.zeros(neg_logits.shape, device=args.device)

        adam_optimizer.zero_grad()
        indices = np.where(pos != 0)
        loss = bce_criterion(pos_logits[indices], pos_labels[indices])
        loss += bce_criterion(neg_logits[indices], neg_labels[indices])
        for param in model.item_emb.parameters(): loss += args.l2_emb * torch.norm(param)
        loss.backward()
        adam_optimizer.step()
        print("loss in epoch {} iteration {}: {}".format(epoch, step, loss.item())) # expected 0.4~0.6 after init few epochs

    if epoch % 5 == 0:
        model.eval()
        t1 = time.time() - t0
        T += t1
        print('Evaluating', end='')
        t_test = evaluate(model, dataset, args)
        t_valid = evaluate_valid(model, dataset, args)
        print('epoch:%d, time: %f(s), valid (NDCG@10: %.4f, HR@10: %.4f), test (NDCG@10: %.4f, HR@10: %.4f)'
                % (epoch, T, t_valid[0], t_valid[1], t_test[0], t_test[1]))

        if t_valid[0] > best_val_ndcg or t_valid[1] > best_val_hr or t_test[0] > best_test_ndcg or t_test[1] > best_test_hr:
            best_val_ndcg = max(t_valid[0], best_val_ndcg)
            best_val_hr = max(t_valid[1], best_val_hr)
            best_test_ndcg = max(t_test[0], best_test_ndcg)
            best_test_hr = max(t_test[1], best_test_hr)
            folder = args.dataset + '_' + args.train_dir
            fname = 'SASRec.epoch={}.lr={}.layer={}.head={}.hidden={}.maxlen={}.pth'
            fname = fname.format(epoch, args.lr, args.num_blocks, args.num_heads, args.hidden_units, args.maxlen)
            torch.save(model.state_dict(), os.path.join(folder, fname))

        f.write(str(epoch) + ' ' + str(t_valid) + ' ' + str(t_test) + '\n')
        f.flush()
        t0 = time.time()
        model.train()

    if epoch == args.num_epochs:
        folder = args.dataset + '_' + args.train_dir
        fname = 'SASRec.epoch={}.lr={}.layer={}.head={}.hidden={}.maxlen={}.pth'
        fname = fname.format(args.num_epochs, args.lr, args.num_blocks, args.num_heads, args.hidden_units, args.maxlen)
        torch.save(model.state_dict(), os.path.join(folder, fname))

f.close()
sampler.close()
print("Done")

loss in epoch 1 iteration 0: 1.3935099840164185
loss in epoch 1 iteration 1: 1.3839380741119385
loss in epoch 1 iteration 2: 1.3789432048797607
loss in epoch 1 iteration 3: 1.3697421550750732
loss in epoch 1 iteration 4: 1.3589098453521729
loss in epoch 1 iteration 5: 1.3480799198150635


KeyboardInterrupt: 

# Prediction

In [29]:
proposed_items = np.arange(itemnum)

In [30]:
model.predict(u, seq, proposed_items)

tensor([[  1.8008,  -6.6977,  -7.7460,  ...,  -4.7461,  -6.7682,  -8.1332],
        [  1.4729,   1.3699,  -5.8549,  ...,  -7.9728,   1.2932,   5.8851],
        [  1.6071,   0.1935,  -9.1936,  ...,  -7.9932,  -9.3019, -10.7897],
        ...,
        [  0.7931,   2.8766,  -1.7839,  ...,  -4.6640,  -5.1508,   1.6283],
        [  1.1645,  -4.0154,  -2.8742,  ...,  -3.4040,  -6.1271,  -8.6607],
        [  1.9209,   3.1261,  -4.4585,  ..., -14.5340,  -0.1543,   0.5881]],
       grad_fn=<SqueezeBackward1>)

# Diffusin SASRec

In [24]:
model = SASRecWithDiffusion(usernum, itemnum, args)

In [25]:
sampler = WarpSampler(user_train, usernum, itemnum, batch_size=args.batch_size, maxlen=args.maxlen + args.n_recs, n_workers=3)

In [26]:
adam_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98))

In [28]:
for epoch in range(epoch_start_idx, args.num_epochs + 1):
    for step in range(num_batch):
        u, seq, pos, neg = sampler.next_batch() # tuples to ndarray
        u, seq, pos, neg = np.array(u), np.array(seq), np.array(pos), np.array(neg)

        adam_optimizer.zero_grad()
        indices = np.where(pos != 0)
        loss = model.get_loss(seq)
        
        for param in model.item_emb.parameters(): loss += args.l2_emb * torch.norm(param)
        loss.backward()
        adam_optimizer.step()
        
        print("loss in epoch {} iteration {}: {}".format(epoch, step, loss.item())) # expected 0.4~0.6 after init few epochs

loss in epoch 1 iteration 0: 18.688432693481445
loss in epoch 1 iteration 1: 18.77422523498535
loss in epoch 1 iteration 2: 18.20121192932129
loss in epoch 1 iteration 3: 17.8505916595459
loss in epoch 1 iteration 4: 18.061403274536133
loss in epoch 1 iteration 5: 18.004785537719727
loss in epoch 1 iteration 6: 17.839412689208984
loss in epoch 1 iteration 7: 17.895626068115234
loss in epoch 1 iteration 8: 17.905954360961914
loss in epoch 1 iteration 9: 16.981290817260742
loss in epoch 1 iteration 10: 17.117555618286133
loss in epoch 1 iteration 11: 17.530784606933594
loss in epoch 1 iteration 12: 17.30600357055664
loss in epoch 1 iteration 13: 17.533531188964844
loss in epoch 1 iteration 14: 17.224214553833008
loss in epoch 1 iteration 15: 16.459470748901367
loss in epoch 1 iteration 16: 16.650836944580078
loss in epoch 1 iteration 17: 16.868009567260742
loss in epoch 1 iteration 18: 16.31648826599121
loss in epoch 1 iteration 19: 16.015810012817383
loss in epoch 1 iteration 20: 16.509

KeyboardInterrupt: 