In [1]:
import paddle 
from paddle import nn 
import numpy as np 
from utils import *
import os

In [2]:

class SASRec(nn.Layer):
    def __init__(self, user_num, item_num, batch_size, lr, maxlen, hidden_units, num_blocks, num_epochs, num_heads, dropout_rate, l2_emb):
        super(SASRec, self).__init__()

        
        self.user_num = user_num
        self.item_num = item_num
        
        self.dev = paddle.get_device()

        self.item_emb = nn.Embedding(self.item_num+1, hidden_units, padding_idx=1)
        self.pos_emb = nn.Embedding(maxlen, hidden_units) 
        self.emb_dropout = nn.Dropout(p=dropout_rate)
        
        transencoderlayer = nn.TransformerEncoderLayer(d_model= hidden_units, nhead= num_heads, dropout= dropout_rate, dim_feedforward= hidden_units, normalize_before=True)
        self.encoder = nn.TransformerEncoder(encoder_layer = transencoderlayer, num_layers= num_blocks )


    def log2feats(self, log_seqs):

        #pos_emb
        seqs = self.item_emb(paddle.to_tensor(log_seqs,dtype='int64'))
        #seqs *= log_seqs.shape[0] ** 0.5
        positions = np.tile(np.array(range(seqs.shape[1])), [seqs.shape[0], 1])
        positions = self.pos_emb(paddle.to_tensor(positions))
        seqs_embed = self.emb_dropout(seqs + positions)
        
        attention_mask = paddle.triu(paddle.ones((maxlen, maxlen)))==0
        log_feats = self.encoder(seqs_embed,attention_mask)

        # log_feats = self.last_layernorm(seqs) # (U, T, C) -> (U, -1, C)

        return log_feats

    def forward(self, user_ids, log_seqs, pos_seqs, neg_seqs): # for training        
        log_feats = self.log2feats(log_seqs) # user_ids hasn't been used yet

        pos_embs = self.item_emb(paddle.to_tensor(pos_seqs))
        neg_embs = self.item_emb(paddle.to_tensor(neg_seqs))

        pos_logits = (log_feats * pos_embs).sum(axis=-1)
        neg_logits = (log_feats * neg_embs).sum(axis=-1)

        return pos_logits, neg_logits # pos_pred, neg_pred

    def predict(self, user_ids, log_seqs, item_indices): # for inference
        log_feats = self.log2feats(log_seqs) # user_ids hasn't been used yet

        final_feat = log_feats[:, -1, :] # only use last QKV classifier, a waste

        item_embs = self.item_emb(paddle.to_tensor(np.array(item_indices).astype(np.int64))) # (U, I, C)

        logits = item_embs.matmul(final_feat.unsqueeze(-1)).squeeze(-1)

        return logits # preds # (U, I)

In [3]:
dataset = "ml-1m"
batch_size = 128
lr = 0.001
maxlen = 200
hidden_units = 50
num_blocks = 2
num_epochs = 600
num_heads = 1
dropout_rate = 0.2
l2_emb = 0


In [4]:
user_train, user_valid, user_test, usernum, itemnum = data_partition(dataset)
num_batch = len(user_train) // batch_size

In [5]:
cc = 0.0
for u in user_train:
    cc += len(user_train[u])
print('average sequence length: %.2f' % (cc / len(user_train)))

average sequence length: 163.50


In [6]:
f = open(dataset + '_' + 'log.txt', 'w')

In [8]:
sampler = WarpSampler(user_train, usernum, itemnum, batch_size=batch_size, maxlen=maxlen, n_workers=3)
model = SASRec(usernum, itemnum, batch_size, lr, maxlen, hidden_units, num_blocks, num_epochs, num_heads, dropout_rate, l2_emb)

bce_criterion = nn.BCEWithLogitsLoss() 
adam_optimizer = paddle.optimizer.Adam(parameters= model.parameters(), beta1= 0.9, beta2= 0.98)

In [9]:
import time
T = 0.0
t0 = time.time()

for epoch in range(num_epochs+1 ):
    loss_e = 0
    # for i, u, seq, pos, neg in  tqdm(range(num_batch), total=num_batch, ncols=70, leave=False, unit='b'):
    for i in range(num_batch):
        u, seq, pos, neg = sampler.next_batch() # tuples to ndarray
        
        u, seq, pos, neg = np.array(u), np.array(seq), np.array(pos), np.array(neg)
        pos_logits, neg_logits = model(u, seq, pos, neg)
        
        pos_labels, neg_labels = paddle.ones(pos_logits.shape), paddle.zeros(neg_logits.shape)
        # print("\neye ball check raw_logits:"); print(pos_logits); print(neg_logits) # check pos_logits > 0, neg_logits < 0
        adam_optimizer.clear_grad()
        indices = paddle.to_tensor(np.where(pos != 0,1,0)).astype(paddle.float32)

        pos_logits = paddle.multiply(pos_logits, indices)
        pos_labels = paddle.multiply(pos_labels, indices)
        neg_logits = paddle.multiply(neg_logits, indices)
        neg_labels = paddle.multiply(neg_labels, indices)

        loss = bce_criterion(pos_logits, pos_labels)
        loss += bce_criterion(neg_logits, neg_labels)

        for param in model.item_emb.parameters(): loss += l2_emb * paddle.norm(param)
        loss.backward()
        adam_optimizer.step()
        loss_e += loss.item()
    print("loss in epoch {} : {}".format(epoch, loss_e/num_batch)) # expected 0.4~0.6 after init few epochs

    if epoch % 40 == 0 and epoch!=0:
        model.eval()
        t1 = time.time() - t0
        T += t1
        print('Evaluating', end='')
        #t_test = evaluate(model, user_train, user_valid, user_test, usernum, itemnum, maxlen)
        t_valid = evaluate_valid(model, user_train, user_valid, user_test, usernum, itemnum, maxlen)
        print('epoch:%d, time: %f(s), valid (NDCG@10: %.4f, HR@10: %.4f)'
                % (epoch, T, t_valid[0], t_valid[1]))

        f.write(str(t_valid) + '\n')
        f.flush()
        t0 = time.time()
        model.train()

    if epoch == num_epochs:
        folder = dataset
        fname = 'SASRec.epoch={}.lr={}.layer={}.head={}.hidden={}.maxlen={}.pdparams'
        fname = fname.format(num_epochs, lr, num_blocks, num_heads, hidden_units, maxlen)
        paddle.save(model.state_dict(), os.path.join(folder, fname))

f.close()
sampler.close()

loss in epoch 0 : 1.2351962327957153
loss in epoch 1 : 1.1841678238929587
loss in epoch 2 : 1.1688527908731015
loss in epoch 3 : 1.1472041124993182
loss in epoch 4 : 1.113295499314653
loss in epoch 5 : 1.0984528901729178
loss in epoch 6 : 1.0824620495451258
loss in epoch 7 : 1.0680202220348602
loss in epoch 8 : 1.054815723540935
loss in epoch 9 : 1.0438662412318778
loss in epoch 10 : 1.0376609234099692
loss in epoch 11 : 1.0305345197941393
loss in epoch 12 : 1.024889738001722
loss in epoch 13 : 1.0161835728807653
loss in epoch 14 : 1.0114926454868722
loss in epoch 15 : 1.0071433769895675
loss in epoch 16 : 1.0049965990350602
loss in epoch 17 : 0.9983519543992713
loss in epoch 18 : 0.9984757012509285
loss in epoch 19 : 1.0021409011901694
loss in epoch 20 : 0.9881873663435591
loss in epoch 21 : 0.9828925335660894
loss in epoch 22 : 0.9761358575618013
loss in epoch 23 : 0.9751568474668137
loss in epoch 24 : 0.9750050940412156
loss in epoch 25 : 0.9703084141650098
loss in epoch 26 : 0.9666

In [16]:
# load model and eval

state_dict = paddle.load('ml-1m/SASRec.epoch=600.lr=0.001.layer=2.head=1.hidden=50.maxlen=200.pdparams')
model_eval = SASRec(usernum, itemnum, batch_size, lr, maxlen, hidden_units, num_blocks, num_epochs, num_heads, dropout_rate, l2_emb)
model_eval.set_state_dict(state_dict)

valid = evaluate_valid(model_eval, user_train, user_valid, user_test, usernum, itemnum, maxlen)
test = evaluate(model_eval, user_train, user_valid, user_test, usernum, itemnum, maxlen)
print('Valid: NDCG@10: %.4f, HR@10: %.4f /n Test: NDCG@10: %.4f, HR@10: %.4f ' %(valid[0], valid[1], test[0], test[1]))

........................................................................................................................Valid: NDCG@10: 0.5912, HR@10: 0.8315 /n Test: NDCG@10: 0.5713, HR@10: 0.8103 
