In [1]:
import torch
import numpy as np
import torch.nn as nn

In [2]:
class Args:
    dropout_rate = 0.3
    hidden_units = 768
    num_heads = 8
    device = "cpu"
    maxlen = 512
    num_blocks = 8

In [3]:
class PointWiseFeedForward(torch.nn.Module):
    def __init__(self, hidden_units, dropout_rate):
        super(PointWiseFeedForward, self).__init__()

        self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout2 = torch.nn.Dropout(p=dropout_rate)

    def forward(self, inputs):
        outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
        outputs = outputs.transpose(-1, -2) # as Conv1D requires (N, C, Length)
        outputs += inputs
        return outputs

In [4]:
class SASRec(torch.nn.Module):
    def __init__(self, user_num, item_num, args):
        super(SASRec, self).__init__()

        self.user_num = user_num
        self.item_num = item_num
        self.dev = args.device

        # TODO: loss += args.l2_emb for regularizing embedding vectors during training
        # https://stackoverflow.com/questions/42704283/adding-l1-l2-regularization-in-pytorch
        self.item_emb = torch.nn.Embedding(self.item_num+1, args.hidden_units, padding_idx=0)
        self.pos_emb = torch.nn.Embedding(args.maxlen, args.hidden_units) # TO IMPROVE
        self.emb_dropout = torch.nn.Dropout(p=args.dropout_rate)

        self.attention_layernorms = torch.nn.ModuleList() # to be Q for self-attention
        self.attention_layers = torch.nn.ModuleList()
        self.forward_layernorms = torch.nn.ModuleList()
        self.forward_layers = torch.nn.ModuleList()

        self.last_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)

        for _ in range(args.num_blocks):
            new_attn_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.attention_layernorms.append(new_attn_layernorm)

            new_attn_layer =  torch.nn.MultiheadAttention(args.hidden_units,
                                                            args.num_heads,
                                                            args.dropout_rate)
            self.attention_layers.append(new_attn_layer)

            new_fwd_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.forward_layernorms.append(new_fwd_layernorm)

            new_fwd_layer = PointWiseFeedForward(args.hidden_units, args.dropout_rate)
            self.forward_layers.append(new_fwd_layer)

            # self.pos_sigmoid = torch.nn.Sigmoid()
            # self.neg_sigmoid = torch.nn.Sigmoid()

    def log2feats(self, log_seqs):
        seqs = self.item_emb(torch.LongTensor(log_seqs).to(self.dev))
        seqs *= self.item_emb.embedding_dim ** 0.5
        #positions = np.tile(np.array(range(log_seqs.shape[1])), [log_seqs.shape[0], 1])
        #seqs += self.pos_emb(torch.LongTensor(positions).to(self.dev))
        seqs = self.emb_dropout(seqs)

        timeline_mask = torch.BoolTensor(log_seqs == 0).to(self.dev)
        seqs *= ~timeline_mask.unsqueeze(-1) # broadcast in last dim

        tl = seqs.shape[1] # time dim len for enforce causality
        attention_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool, device=self.dev))

        for i in range(len(self.attention_layers)):
            seqs = torch.transpose(seqs, 0, 1)
            Q = self.attention_layernorms[i](seqs)
            mha_outputs, _ = self.attention_layers[i](Q, seqs, seqs, 
                                            attn_mask=attention_mask)
                                            # key_padding_mask=timeline_mask
                                            # need_weights=False) this arg do not work?
            seqs = Q + mha_outputs
            seqs = torch.transpose(seqs, 0, 1)

            seqs = self.forward_layernorms[i](seqs)
            seqs = self.forward_layers[i](seqs)
            seqs *=  ~timeline_mask.unsqueeze(-1)

        log_feats = self.last_layernorm(seqs) # (U, T, C) -> (U, -1, C)

        return log_feats

    def forward(self, log_seqs, pos_seqs, neg_seqs): # for training        
        log_feats = self.log2feats(log_seqs) # user_ids hasn't been used yet
        print("hi")
        print("log feats shape:{}".format(log_feats.shape))
        pos_embs = self.item_emb(torch.LongTensor(pos_seqs).to(self.dev))
        neg_embs = self.item_emb(torch.LongTensor(neg_seqs).to(self.dev))

        pos_logits = (log_feats * pos_embs).sum(dim=-1)
        neg_logits = (log_feats * neg_embs).sum(dim=-1)

        # pos_pred = self.pos_sigmoid(pos_logits)
        # neg_pred = self.neg_sigmoid(neg_logits)

        return pos_logits, neg_logits # pos_pred, neg_pred

    def predict(self, log_seqs, item_indices): # for inference
        log_feats = self.log2feats(log_seqs) # user_ids hasn't been used yet

        final_feat = log_feats[:, -1, :] # only use last QKV classifier, a waste

        item_embs = self.item_emb(torch.LongTensor(item_indices).to(self.dev)) # (U, I, C)

        logits = item_embs.matmul(final_feat.unsqueeze(-1)).squeeze(-1)

        # preds = self.pos_sigmoid(logits) # rank same item list for different users

        return logits # preds # (U, I)


In [5]:
log_seq_ids = torch.tensor([[1,12,0,43,45,66]])
user_ids = torch.tensor([[0,3,4,5,6,7]])
pos_seq = torch.tensor([[12,23,55,67,84,98]])
neg_seq = torch.tensor([[23,43,25,17,54,78]])

In [6]:
args = Args()
model = SASRec(1000,1000,args)

In [7]:
model(log_seq_ids,pos_seq,neg_seq)

hi
log feats shape:torch.Size([1, 6, 768])


(tensor([[ 2.1265, -4.0618,  0.0000,  8.8231, 26.9071, 21.7874]],
        grad_fn=<SumBackward1>),
 tensor([[ 0.4874,  1.5811,  0.0000,  6.9174, 19.3078, 19.2773]],
        grad_fn=<SumBackward1>))

In [8]:
from models.RWKV.model import *

In [9]:
class RWKV_Config:
    vocab_size = 768
    n_embd = 768
    n_layer = 12
    ctx_len = 768
    model_type = "RWKV"
    n_head = 8
    n_attn = 8
    n_ffn = 4
    rwkv_emb_scale = 4

In [10]:
rwkv = GPT(config = RWKV_Config())

768   768   4    tok_emb.weight
8     768   0    blocks.0.attn.key.weight
8     768   1    blocks.0.attn.value.weight
8     768   0    blocks.0.attn.receptance.weight
768   8     0    blocks.0.attn.output.weight
10    768   1    blocks.0.mlp.key.weight
10    768   1    blocks.0.mlp.value.weight
768   10    0    blocks.0.mlp.weight.weight
768   768   0    blocks.0.mlp.receptance.weight
8     768   0    blocks.1.attn.key.weight
8     768   1    blocks.1.attn.value.weight
8     768   0    blocks.1.attn.receptance.weight
768   8     0    blocks.1.attn.output.weight
10    768   1    blocks.1.mlp.key.weight
10    768   1    blocks.1.mlp.value.weight
768   10    0    blocks.1.mlp.weight.weight
768   768   0    blocks.1.mlp.receptance.weight
8     768   0    blocks.2.attn.key.weight
8     768   1    blocks.2.attn.value.weight
8     768   0    blocks.2.attn.receptance.weight
768   8     0    blocks.2.attn.output.weight
10    768   1    blocks.2.mlp.key.weight
10    768   1    blocks.2.mlp.value

In [11]:
rwkv(log_seq_ids,pos_seq,neg_seq)[0].shape

C shape :torch.Size([1, 6, 6]).
x shape:torch.Size([1, 6, 768])
pos embedding shape:torch.Size([1, 6, 768])
neg embedding shape:torch.Size([1, 6, 768])


torch.Size([6, 768])