# STAMP: Short-Term Attention/Memory Priority Model for Session-based Recommendation (STAMP)



### class 관계

* [AbstractRecommender](https://github.com/RUCAIBox/RecBole/blob/master/recbole/model/abstract_recommender.py#L25)  
    * [SequentialRecommender](https://github.com/RUCAIBox/RecBole/blob/master/recbole/model/abstract_recommender.py#L146)
        * [STAMP](https://github.com/RUCAIBox/RecBole/blob/master/recbole/model/sequential_recommender/stamp.py)
         
                    


### Recbole STAMP 코드실행 예시([참고](https://github.com/RUCAIBox/RecBole/blob/master/run_example/sequential-model-fixed-missing-last-item.ipynb))

    

In [1]:
# RecBole STAMP usage example

# 1. config
from recbole.config import Config
parameter_dict = {
    'data_path': './data',
    'USER_ID_FIELD': 'user_id',
    'ITEM_ID_FIELD': 'item_id',
    'TIME_FIELD': 'timestamp',
    'user_inter_num_interval': "[30,inf)",   # keep users with at least 30 interactions
    'item_inter_num_interval': "[40,inf)",   # keep items with at least 40 interactions
    'load_col': {'inter': ['user_id', 'item_id', 'timestamp']},
    'train_neg_sample_args': None,           # no negative sampling during training
    'epochs': 1,
    'eval_args': {
        'split': {'RS': [10, 0, 0]},         # ratio split 10:0:0 -> presumably everything goes to training
        'group_by': 'user',
        'order': 'TO',
        'mode': 'full'}
}
config = Config(model='STAMP', dataset='recbox_data', config_dict=parameter_dict)

# 2. dataset
from recbole.data import create_dataset, data_preparation
dataset = create_dataset(config)
train_data, valid_data, test_data = data_preparation(config, dataset)

# 3. model
from recbole.model.sequential_recommender import STAMP
model = STAMP(config, train_data.dataset).to(config['device'])

# # 4. training (disabled in this walkthrough; uncomment to train)
# from recbole.trainer import Trainer
# trainer = Trainer(config, model)
# best_valid_score, best_valid_result = trainer.fit(train_data)

In [3]:
model

STAMP(
  (item_embedding): Embedding(10962, 64, padding_idx=0)
  (w1): Linear(in_features=64, out_features=64, bias=False)
  (w2): Linear(in_features=64, out_features=64, bias=False)
  (w3): Linear(in_features=64, out_features=64, bias=False)
  (w0): Linear(in_features=64, out_features=1, bias=False)
  (mlp_a): Linear(in_features=64, out_features=64, bias=True)
  (mlp_b): Linear(in_features=64, out_features=64, bias=True)
  (sigmoid): Sigmoid()
  (tanh): Tanh()
  (loss_fct): CrossEntropyLoss()
)

### train data 예시([참고](https://github.com/RUCAIBox/RecBole/blob/master/recbole/trainer/trainer.py#L234))

In [2]:
# Take the first mini-batch of the training data as a running example.
batch_idx = 0
interaction = next(iter(train_data))

# field names used to index the interaction
USER_ID = 'user_id'
POS_ITEM_ID = 'item_id'
ITEM_SEQ = 'item_id_list'
ITEM_SEQ_LEN = 'item_length'

user_seq = interaction[USER_ID]
item_seq = interaction[ITEM_SEQ]
item_seq_len = interaction[ITEM_SEQ_LEN]

In [3]:
user_seq # torch.Size([2048])

tensor([ 8681,  4622, 22968,  ..., 30080, 19206,   759])

In [4]:
item_seq # torch.Size([2048, 50])

tensor([[ 1868,   266,  2206,  ...,     0,     0,     0],
        [ 2549,  2549,    60,  ...,     0,     0,     0],
        [ 1017,   522,   265,  ...,     0,     0,     0],
        ...,
        [ 4208,   818,  6991,  ...,     0,     0,     0],
        [ 3813,  4103,  4103,  ...,  9211, 10242,   265],
        [  181,  1299,  1076,  ...,     0,     0,     0]])

In [5]:
item_seq_len # torch.Size([2048])

tensor([ 7, 26, 13,  ..., 39, 50,  7])

* 유저 8681 학습데이터 예시

In [8]:
import numpy as np
import pandas as pd

user_idx = 8681  # internal (remapped) user id to inspect
uid = train_data.dataset.id2token(train_data.dataset.uid_field, [user_idx])[0]  # original user token
index = np.isin(train_data.dataset[train_data.dataset.uid_field].numpy(), user_idx)

user_interaction = train_data.dataset[index]
user_interaction

# df = pd.read_csv('./data/recbox_data/recbox_data.inter', sep='\t')
# ex = df[df['user_id:token'] == uid]  # user 8681 has 42 logs in total (39 of them are used as training data)

The batch_size of interaction: 39
    user_id, torch.Size([39]), cpu, torch.int64
    item_id, torch.Size([39]), cpu, torch.int64
    timestamp, torch.Size([39]), cpu, torch.float32
    item_length, torch.Size([39]), cpu, torch.int64
    item_id_list, torch.Size([39, 50]), cpu, torch.int64
    timestamp_list, torch.Size([39, 50]), cpu, torch.float32


In [11]:
user_interaction[USER_ID] # torch.Size([39])

tensor([8681, 8681, 8681, 8681, 8681, 8681, 8681, 8681, 8681, 8681, 8681, 8681,
        8681, 8681, 8681, 8681, 8681, 8681, 8681, 8681, 8681, 8681, 8681, 8681,
        8681, 8681, 8681, 8681, 8681, 8681, 8681, 8681, 8681, 8681, 8681, 8681,
        8681, 8681, 8681])

In [12]:
# x
user_interaction[ITEM_SEQ] # torch.Size([39, 50]) max sequence length = 50

tensor([[1868,    0,    0,  ...,    0,    0,    0],
        [1868,  266,    0,  ...,    0,    0,    0],
        [1868,  266, 2206,  ...,    0,    0,    0],
        ...,
        [1868,  266, 2206,  ...,    0,    0,    0],
        [1868,  266, 2206,  ...,    0,    0,    0],
        [1868,  266, 2206,  ...,    0,    0,    0]])

In [13]:
user_interaction[ITEM_SEQ][-1] 

tensor([ 1868,   266,  2206,  2488,   439,  6038,  5231,  5693,    39,  6429,
         1234,  5944,  3629,   173,  8296,  7010,  6897,  6897,   106,  3858,
         3858,  1196,   643,  5705,  4854,  7869,  8703,  3170,  9258,  9774,
         9362, 10102, 10440, 10334,  8651, 10242,  6997, 10610, 10465,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0])

In [14]:
# y(label)
user_interaction[POS_ITEM_ID] # torch.Size([39])

tensor([  266,  2206,  2488,   439,  6038,  5231,  5693,    39,  6429,  1234,
         5944,  3629,   173,  8296,  7010,  6897,  6897,   106,  3858,  3858,
         1196,   643,  5705,  4854,  7869,  8703,  3170,  9258,  9774,  9362,
        10102, 10440, 10334,  8651, 10242,  6997, 10610, 10465,  2939])

In [15]:
user_interaction[ITEM_SEQ_LEN] # torch.Size([39])

tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
        37, 38, 39])

### SequentialRecommender class
* [code](https://github.com/RUCAIBox/RecBole/blob/master/recbole/model/abstract_recommender.py#L146)

In [6]:
import torch
import torch.nn as nn

In [None]:
class SequentialRecommender(AbstractRecommender):
    """
    Abstract base class for sequential recommenders. Every sequential model should inherit from this class.
    """
    type = ModelType.SEQUENTIAL  # class-level tag: ModelType.SEQUENTIAL


```python
def __init__(self, config, dataset):
    """Read the sequential-recommendation field names and settings from ``config``.

    Args:
        config: RecBole configuration, indexed like a dict.
        dataset: dataset object; used here only to look up the number of item ids.
    """
    super(SequentialRecommender, self).__init__()

    # field names used to look up tensors in an interaction
    user_field = config["USER_ID_FIELD"]
    item_field = config["ITEM_ID_FIELD"]
    self.USER_ID = user_field
    self.ITEM_ID = item_field
    self.ITEM_SEQ = item_field + config["LIST_SUFFIX"]
    self.ITEM_SEQ_LEN = config["ITEM_LIST_LENGTH_FIELD"]
    self.POS_ITEM_ID = item_field
    self.NEG_ITEM_ID = config["NEG_PREFIX"] + item_field

    # sequence / vocabulary sizes
    self.max_seq_length = config["MAX_ITEM_LIST_LENGTH"]
    self.n_items = dataset.num(item_field)

    # runtime settings
    self.device = config["device"]
```

In [6]:
# Reproduce SequentialRecommender.__init__ outside the class to inspect the values it sets.

# load dataset info
USER_ID = config["USER_ID_FIELD"]
ITEM_ID = config["ITEM_ID_FIELD"]
ITEM_SEQ = ITEM_ID + config["LIST_SUFFIX"]          # name of the item-history field
ITEM_SEQ_LEN = config["ITEM_LIST_LENGTH_FIELD"]     # name of the history-length field

POS_ITEM_ID = ITEM_ID                               # the positive target reuses the item id field
NEG_ITEM_ID = config["NEG_PREFIX"] + ITEM_ID

max_seq_length = config["MAX_ITEM_LIST_LENGTH"]
n_items = dataset.num(ITEM_ID)

# load parameters info
device = config["device"]

print('USER_ID:', USER_ID)
print('ITEM_ID:', ITEM_ID)
print('ITEM_SEQ:', ITEM_SEQ)
print('ITEM_SEQ_LEN:', ITEM_SEQ_LEN)
print('POS_ITEM_ID:', POS_ITEM_ID)
print('NEG_ITEM_ID:', NEG_ITEM_ID)
print('max_seq_length:', max_seq_length)
print('n_items:', n_items)
print('device:', device)

USER_ID: user_id
ITEM_ID: item_id
ITEM_SEQ: item_id_list
ITEM_SEQ_LEN: item_length
POS_ITEM_ID: item_id
NEG_ITEM_ID: neg_item_id
max_seq_length: 50
n_items: 10962
device: cpu


```python
 def get_attention_mask(self, item_seq, bidirectional=False):
        """Build an additive attention mask from padded item sequences.

        Padding positions (item id 0) are always blocked. Unless ``bidirectional``
        is truthy, positions after each query are blocked as well (left-to-right).

        Args:
            item_seq: item id tensor, presumably ``[batch, seq_len]``.
            bidirectional: if truthy, skip the lower-triangular (causal) restriction.

        Returns:
            Float tensor holding 0.0 where attention is allowed and -10000.0 where
            it is blocked; ``[B, 1, 1, L]`` if bidirectional, else ``[B, 1, L, L]``.
        """
        seq_len = item_seq.size(-1)
        # True for real items, False for padding; add two singleton axes
        allowed = (item_seq != 0).unsqueeze(1).unsqueeze(2)
        if bidirectional:
            return torch.where(allowed, 0.0, -10000.0)
        # broadcast the key mask over every query row, then keep only keys j <= i
        causal = torch.tril(allowed.expand((-1, -1, seq_len, -1)))
        return torch.where(causal, 0.0, -10000.0)
    
```

In [26]:
# input 
item_seq  # torch.Size([2048, 50])

tensor([[ 1868,   266,  2206,  ...,     0,     0,     0],
        [ 2549,  2549,    60,  ...,     0,     0,     0],
        [ 1017,   522,   265,  ...,     0,     0,     0],
        ...,
        [ 4208,   818,  6991,  ...,     0,     0,     0],
        [ 3813,  4103,  4103,  ...,  9211, 10242,   265],
        [  181,  1299,  1076,  ...,     0,     0,     0]])

In [27]:
item_seq[0]

tensor([1868,  266, 2206, 2488,  439, 6038, 5231,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0])

In [28]:
attention_mask = item_seq != 0  # torch.Size([2048, 50])
attention_mask 

tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ..., False, False, False]])

In [29]:
attention_mask[0]

tensor([ True,  True,  True,  True,  True,  True,  True, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])

In [31]:
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # torch.Size([2048, 1, 50]) -> torch.Size([2048, 1, 1, 50])
extended_attention_mask

tensor([[[[ True,  True,  True,  ..., False, False, False]]],


        [[[ True,  True,  True,  ..., False, False, False]]],


        [[[ True,  True,  True,  ..., False, False, False]]],


        ...,


        [[[ True,  True,  True,  ..., False, False, False]]],


        [[[ True,  True,  True,  ...,  True,  True,  True]]],


        [[[ True,  True,  True,  ..., False, False, False]]]])

In [None]:
bidirectional = False
if not bidirectional:
    tmp = extended_attention_mask.expand((-1, -1, item_seq.size(-1), -1)) # torch.Size([2048, 1, 50, 50])
    extended_attention_mask = torch.tril(tmp)                             # torch.Size([2048, 1, 50, 50])

In [32]:
tmp = extended_attention_mask.expand((-1, -1, item_seq.size(-1), -1)) # torch.Size([2048, 1, 50, 50])
tmp

tensor([[[[ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          ...,
          [ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False]]],


        [[[ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          ...,
          [ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False]]],


        [[[ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          ...,
          [ True,  True,  True,  ..., False, Fa

In [33]:
tmp[0][0] # # torch.Size([50, 50])

tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]])

In [38]:
extended_attention_mask = torch.tril(tmp) # torch.Size([2048, 1, 50, 50])
extended_attention_mask

tensor([[[[ True, False, False,  ..., False, False, False],
          [ True,  True, False,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          ...,
          [ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False]]],


        [[[ True, False, False,  ..., False, False, False],
          [ True,  True, False,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          ...,
          [ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False]]],


        [[[ True, False, False,  ..., False, False, False],
          [ True,  True, False,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          ...,
          [ True,  True,  True,  ..., False, Fa

In [39]:
extended_attention_mask[0][0][0]

tensor([ True, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])

In [40]:
extended_attention_mask[0][0][1]

tensor([ True,  True, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])

In [45]:
extended_attention_mask[0][0][6]

tensor([ True,  True,  True,  True,  True,  True,  True, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])

In [46]:
extended_attention_mask[0][0][7] # 7 ~ 49까지 동일 

tensor([ True,  True,  True,  True,  True,  True,  True, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])

In [47]:
extended_attention_mask[0][0][49]

tensor([ True,  True,  True,  True,  True,  True,  True, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])

In [48]:
extended_attention_mask = torch.where(extended_attention_mask, 0.0, -10000.0) # torch.Size([2048, 1, 50, 50])

In [49]:
extended_attention_mask

tensor([[[[     0., -10000., -10000.,  ..., -10000., -10000., -10000.],
          [     0.,      0., -10000.,  ..., -10000., -10000., -10000.],
          [     0.,      0.,      0.,  ..., -10000., -10000., -10000.],
          ...,
          [     0.,      0.,      0.,  ..., -10000., -10000., -10000.],
          [     0.,      0.,      0.,  ..., -10000., -10000., -10000.],
          [     0.,      0.,      0.,  ..., -10000., -10000., -10000.]]],


        [[[     0., -10000., -10000.,  ..., -10000., -10000., -10000.],
          [     0.,      0., -10000.,  ..., -10000., -10000., -10000.],
          [     0.,      0.,      0.,  ..., -10000., -10000., -10000.],
          ...,
          [     0.,      0.,      0.,  ..., -10000., -10000., -10000.],
          [     0.,      0.,      0.,  ..., -10000., -10000., -10000.],
          [     0.,      0.,      0.,  ..., -10000., -10000., -10000.]]],


        [[[     0., -10000., -10000.,  ..., -10000., -10000., -10000.],
          [     0.,      0

```python
 def gather_indexes(self, output, gather_index):
        """Pick one vector per sequence from ``output`` at the given position.

        Args:
            output: 3-D tensor ``[batch, seq_len, hidden]`` (``gather(dim=1)`` with
                a ``[B, 1, H]`` index below requires this rank).
            gather_index: one position per batch row, e.g. ``item_seq_len - 1``.

        Returns:
            Tensor ``[batch, hidden]`` whose row b is ``output[b, gather_index[b]]``.
        """
        hidden_size = output.shape[-1]
        # [B] -> [B, 1, 1] -> [B, 1, H]: repeat each position across the hidden axis
        index = gather_index.view(-1, 1, 1).expand(-1, -1, hidden_size)
        picked = output.gather(dim=1, index=index)  # [B, 1, H]
        return picked.squeeze(1)
```

In [52]:
# input 
output = torch.rand([2048, 50, 64]) # output 예시
gather_index = item_seq_len - 1     # torch.Size([2048])

In [53]:
gather_index

tensor([ 6, 25, 12,  ..., 38, 49,  6])

In [54]:
gather_index = gather_index.view(-1, 1, 1).expand(-1, -1, output.shape[-1]) # torch.Size([2048, 1, 1]) -> [2048, 1, 64]
gather_index

tensor([[[ 6,  6,  6,  ...,  6,  6,  6]],

        [[25, 25, 25,  ..., 25, 25, 25]],

        [[12, 12, 12,  ..., 12, 12, 12]],

        ...,

        [[38, 38, 38,  ..., 38, 38, 38]],

        [[49, 49, 49,  ..., 49, 49, 49]],

        [[ 6,  6,  6,  ...,  6,  6,  6]]])

In [55]:
output_tensor = output.gather(dim=1, index=gather_index) # torch.Size([2048, 1, 64])

# output[0][6]
# output[1][25]
# ...
# output[2047][6]

In [56]:
output_tensor

tensor([[[0.0320, 0.4300, 0.1542,  ..., 0.1900, 0.8480, 0.5424]],

        [[0.3069, 0.4541, 0.0528,  ..., 0.8476, 0.6264, 0.6698]],

        [[0.0131, 0.2552, 0.2702,  ..., 0.0176, 0.9185, 0.1752]],

        ...,

        [[0.7754, 0.5159, 0.2408,  ..., 0.2827, 0.5453, 0.8278]],

        [[0.8666, 0.7152, 0.2644,  ..., 0.1230, 0.5616, 0.8798]],

        [[0.6902, 0.2934, 0.7709,  ..., 0.3917, 0.8086, 0.0504]]])

In [57]:
output[0][6]

tensor([0.0320, 0.4300, 0.1542, 0.0019, 0.3915, 0.0605, 0.1935, 0.6811, 0.5048,
        0.3541, 0.6980, 0.4920, 0.4206, 0.1087, 0.1328, 0.2339, 0.0416, 0.4659,
        0.7262, 0.5235, 0.0957, 0.5303, 0.6733, 0.7423, 0.9876, 0.5838, 0.1462,
        0.7638, 0.6565, 0.3153, 0.6075, 0.4204, 0.0161, 0.0368, 0.7844, 0.0072,
        0.5428, 0.8369, 0.2896, 0.3736, 0.8334, 0.6844, 0.2207, 0.3696, 0.0388,
        0.9414, 0.8661, 0.6743, 0.4789, 0.2714, 0.4282, 0.6675, 0.1873, 0.0504,
        0.5365, 0.9650, 0.1889, 0.7426, 0.0550, 0.3698, 0.6055, 0.1900, 0.8480,
        0.5424])

In [60]:
output[1][25]

tensor([0.3069, 0.4541, 0.0528, 0.7933, 0.7326, 0.9650, 0.4293, 0.9217, 0.0565,
        0.1969, 0.7149, 0.6095, 0.5099, 0.5188, 0.8574, 0.8394, 0.0267, 0.6900,
        0.3555, 0.2192, 0.2323, 0.1542, 0.0817, 0.1590, 0.0681, 0.8346, 0.7143,
        0.3378, 0.2859, 0.7785, 0.9668, 0.8525, 0.0838, 0.4283, 0.4188, 0.2971,
        0.5262, 0.4591, 0.9092, 0.6489, 0.6712, 0.7706, 0.7793, 0.5896, 0.7011,
        0.9960, 0.3039, 0.4623, 0.9865, 0.3380, 0.6413, 0.6226, 0.4300, 0.7187,
        0.6727, 0.4492, 0.6484, 0.8844, 0.3546, 0.3703, 0.8179, 0.8476, 0.6264,
        0.6698])

In [58]:
output_tensor = output_tensor.squeeze(1) # torch.Size([2048, 64])

In [59]:
output_tensor

tensor([[0.0320, 0.4300, 0.1542,  ..., 0.1900, 0.8480, 0.5424],
        [0.3069, 0.4541, 0.0528,  ..., 0.8476, 0.6264, 0.6698],
        [0.0131, 0.2552, 0.2702,  ..., 0.0176, 0.9185, 0.1752],
        ...,
        [0.7754, 0.5159, 0.2408,  ..., 0.2827, 0.5453, 0.8278],
        [0.8666, 0.7152, 0.2644,  ..., 0.1230, 0.5616, 0.8798],
        [0.6902, 0.2934, 0.7709,  ..., 0.3917, 0.8086, 0.0504]])

### STAMP class

* AbstractRecommender
    * SequentialRecommender
        * STAMP
            

In [3]:
import torch
from torch import nn
from torch.nn.init import normal_

from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.loss import BPRLoss

In [None]:
class STAMP(SequentialRecommender):
    r"""STAMP captures users' general interests from the long-term memory of a session context,
    while also taking into account their current interests from the short-term memory of the last click.

    Note:
        Based on test results, the score function differs slightly from the one in the paper:
        the final sigmoid activation function is not applied.

    """

```python

def __init__(self, config, dataset):
    """Create the STAMP layers and the loss function.

    Args:
        config: RecBole configuration; reads ``embedding_size`` and ``loss_type``.
        dataset: forwarded to ``SequentialRecommender.__init__``.

    Raises:
        NotImplementedError: if ``loss_type`` is neither ``'BPR'`` nor ``'CE'``.
    """
    super(STAMP, self).__init__(config, dataset)

    # load parameters info
    dim = config["embedding_size"]
    self.embedding_size = dim

    # define layers; creation order is kept identical because it fixes the
    # sequence of random draws made by the layer constructors and by apply()
    self.item_embedding = nn.Embedding(self.n_items, dim, padding_idx=0)
    for name in ("w1", "w2", "w3"):
        setattr(self, name, nn.Linear(dim, dim, bias=False))
    self.w0 = nn.Linear(dim, 1, bias=False)
    self.b_a = nn.Parameter(torch.zeros(dim), requires_grad=True)
    self.mlp_a = nn.Linear(dim, dim, bias=True)
    self.mlp_b = nn.Linear(dim, dim, bias=True)
    self.sigmoid = nn.Sigmoid()
    self.tanh = nn.Tanh()

    # define loss
    self.loss_type = config["loss_type"]
    if self.loss_type == "BPR":
        self.loss_fct = BPRLoss()
    elif self.loss_type == "CE":
        self.loss_fct = nn.CrossEntropyLoss()
    else:
        raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

    # parameters initialization
    self.apply(self._init_weights)

```

In [4]:
# load parameters info
embedding_size = config["embedding_size"]
embedding_size

64

In [7]:
# define layers (standalone copies of the STAMP layers for step-by-step inspection)
# NOTE: w0 is created before w1..w3 here, unlike STAMP.__init__, so the random
# initial weights will not match those of the `model` instance.
item_embedding = nn.Embedding(n_items, embedding_size, padding_idx=0)

## attention mechanism weights: alpha_i = w0 . sigmoid(W1 x_i + W2 x_t + W3 m_s + b_a)
w0 = nn.Linear(embedding_size, 1, bias=False)
w1 = nn.Linear(embedding_size, embedding_size, bias=False)
w2 = nn.Linear(embedding_size, embedding_size, bias=False)
w3 = nn.Linear(embedding_size, embedding_size, bias=False)
b_a = nn.Parameter(torch.zeros(embedding_size), requires_grad=True)

## MLP cells: mlp_a maps m_a to h_s, mlp_b maps the last click x_t to h_t
mlp_a = nn.Linear(embedding_size, embedding_size, bias=True)
mlp_b = nn.Linear(embedding_size, embedding_size, bias=True)

sigmoid = nn.Sigmoid()
tanh = nn.Tanh()

In [8]:
# define loss: 'BPR' compares positive vs. negative item scores,
# 'CE' is cross-entropy over the scores of all items
loss_type = config["loss_type"]
if loss_type == "BPR":
    loss_fct = BPRLoss()
elif loss_type == "CE":
    loss_fct = nn.CrossEntropyLoss()
else:
    raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

In [None]:
# parameters initialization
# `self` only exists inside STAMP.__init__; in the notebook use the `model` instance.
# nn.Module.apply calls the function on every submodule recursively and then on the
# module itself; _init_weights ignores anything that is not an Embedding or Linear.
model.apply(model._init_weights)

# STAMP's children are all leaf layers, so the call above behaves like:
# for submodule in model.children():
#     model._init_weights(submodule)

```python

def _init_weights(self, module):
    """Initialize the weights of one submodule (meant to be used with ``nn.Module.apply``).

    Embedding weights ~ N(0, 0.002); Linear weights ~ N(0, 0.05) with biases set to 0.
    Other module types are left untouched.

    ``normal_`` also overwrites the embedding row at ``padding_idx``. ``forward``
    sums padded positions into the session mean and attends over them, relying on
    padding embeddings being zero, so that row is reset to zero afterwards.
    """
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data, 0, 0.002)
        if module.padding_idx is not None:
            # restore nn.Embedding's all-zero padding vector
            module.weight.data[module.padding_idx].fill_(0.0)
    elif isinstance(module, nn.Linear):
        normal_(module.weight.data, 0, 0.05)
        if module.bias is not None:
            module.bias.data.fill_(0.0)
```

```python

def forward(self, item_seq, item_seq_len):
    """Compute the STAMP session representation ``h_s * h_t`` for each sequence.

    Args:
        item_seq: padded item id tensor, presumably ``[batch, seq_len]``.
        item_seq_len: number of real (non-padding) items per sequence, ``[batch]``.

    Returns:
        Tensor ``[batch, embedding_size]``.
    """
    memory = self.item_embedding(item_seq)  # x_1..x_n
    # x_t: embedding of the last real click of each session
    last_click = self.gather_indexes(memory, item_seq_len - 1)
    # m_s: session mean; padded rows are included in the sum (relies on zero padding vectors)
    general = memory.sum(dim=1) / item_seq_len.unsqueeze(1).float()
    # m_a: attention-weighted memory plus the session mean
    weights = self.count_alpha(memory, last_click, general)
    attended = torch.matmul(weights.unsqueeze(1), memory).squeeze(1)
    m_a = attended + general
    h_s = self.tanh(self.mlp_a(m_a))
    h_t = self.tanh(self.mlp_b(last_click))
    return h_s * h_t
```


In [9]:
item_seq # torch.Size([2048, 50]) batch_size, max_seq_len

tensor([[ 1868,   266,  2206,  ...,     0,     0,     0],
        [ 2549,  2549,    60,  ...,     0,     0,     0],
        [ 1017,   522,   265,  ...,     0,     0,     0],
        ...,
        [ 4208,   818,  6991,  ...,     0,     0,     0],
        [ 3813,  4103,  4103,  ...,  9211, 10242,   265],
        [  181,  1299,  1076,  ...,     0,     0,     0]])

In [10]:
item_seq_emb = item_embedding(item_seq) # torch.Size([2048, 50, 64])
org_memory = item_seq_emb

item_seq_emb

tensor([[[-0.3531,  0.1079, -0.0826,  ...,  2.1116,  1.5943,  0.9022],
         [-0.0926, -1.5311, -1.1752,  ...,  1.8774,  0.1764, -0.6408],
         [ 2.7649, -0.3843,  0.6599,  ...,  0.9337,  0.6107,  0.6867],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.0785,  0.2606, -1.2229,  ...,  0.9491, -0.1844,  1.5269],
         [ 0.0785,  0.2606, -1.2229,  ...,  0.9491, -0.1844,  1.5269],
         [-0.6542,  1.4045, -0.9057,  ...,  0.9867,  0.2270,  0.9651],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[-0.5249, -0.0755, -1.8381,  ...,  0.6181, -3.1122, -0.2412],
         [ 0.7290, -1.2760, -0.4319,  ...,  0

In [11]:
def gather_indexes(output, gather_index):
    """Return ``output[b, gather_index[b], :]`` for every batch row b.

    ``output`` is 3-D ``[batch, seq_len, hidden]`` and ``gather_index`` holds one
    position per row; the result has shape ``[batch, hidden]``.
    """
    hidden = output.shape[-1]
    positions = gather_index.view(-1, 1, 1).expand(-1, -1, hidden)  # [B, 1, H]
    return output.gather(dim=1, index=positions).squeeze(1)

last_inputs = gather_indexes(item_seq_emb, item_seq_len - 1) # take row item_seq_len - 1 (the last real item) of each sequence
last_inputs # torch.Size([2048, 64])

tensor([[-1.4458, -1.3469,  0.0455,  ..., -0.4465, -0.3475,  2.3354],
        [ 0.0329,  2.5744,  0.6203,  ..., -0.8717, -1.5073, -0.3776],
        [-0.8985,  0.3676,  1.8222,  ..., -0.2688,  0.1008,  2.1990],
        ...,
        [ 1.6410, -0.2803,  0.6968,  ..., -1.0407, -0.1854, -0.1268],
        [ 1.9577, -1.6954, -0.2121,  ...,  0.4650,  1.4187,  0.6217],
        [ 1.3527,  1.5362,  0.1002,  ...,  0.7500, -1.9779,  0.6683]],
       grad_fn=<SqueezeBackward1>)

In [12]:
# longterm memory 

ms = torch.div(torch.sum(org_memory, dim=1), item_seq_len.unsqueeze(1).float())
ms # torch.Size([2048, 64])

tensor([[-0.0602, -0.5067,  0.2281,  ...,  0.4464,  0.2637,  0.9053],
        [-0.0492,  0.4845, -0.4089,  ..., -0.2129, -0.2832,  0.4292],
        [-0.0764, -0.6103,  0.0675,  ..., -0.1303, -0.2001,  0.3318],
        ...,
        [ 0.0443,  0.1398, -0.2464,  ...,  0.0652, -0.2164,  0.1868],
        [ 0.2123, -0.2032, -0.0776,  ..., -0.0422,  0.1022,  0.2967],
        [-0.0197,  0.4879,  0.2212,  ..., -0.1135, -0.8879, -0.1351]],
       grad_fn=<DivBackward0>)

In [13]:
# attention mechanism

# alpha = count_alpha(org_memory, last_inputs, ms)  # x_i, x_t, m_s
alpha = torch.rand(org_memory.shape[:2])            # placeholder weights [batch, seq_len]; see count_alpha below
vec = torch.matmul(alpha.unsqueeze(1), org_memory)  # [B, 1, L] x [B, L, E] -> [B, 1, E]
ma = vec.squeeze(1) + ms                            # m_a = sum_i alpha_i * x_i + m_s, [B, E] (same as forward)

In [32]:
# mlp cell
hs = tanh(mlp_a(ma)).squeeze(1)  # torch.Size([2048, 64])
ht = tanh(mlp_b(last_inputs))    # torch.Size([2048, 64])
seq_output = hs * ht             # torch.Size([2048, 64])

```python
def count_alpha(self, context, aspect, output):
    r"""Compute STAMP attention scores over the session items.

    alpha_i = w0 . sigmoid(W1 x_i + W2 x_t + W3 m_s + b_a)

    Args:
        context(torch.FloatTensor): item list embeddings, shape [batch_size, time_steps, emb]
        aspect(torch.FloatTensor): embedding of the last clicked item, shape [batch_size, emb]
        output(torch.FloatTensor): mean of the context, shape [batch_size, emb]

    Returns:
        torch.Tensor: unnormalized attention weights (no softmax), shape [batch_size, time_steps]
    """
    steps = context.size(1)

    def tile(vec):
        # [B, emb] -> [B, steps, emb] by copying the vector at every time step
        return vec.repeat(1, steps).view(-1, steps, self.embedding_size)

    scores = (
        self.w1(context)
        + self.w2(tile(aspect))
        + self.w3(tile(output))
        + self.b_a
    )
    return self.w0(self.sigmoid(scores)).squeeze(2)
```

In [15]:
# input 
context = org_memory # x_i
aspect = last_inputs # x_t torch.Size([2048, 64])
output = ms          # m_s torch.Size([2048, 64])

In [16]:
timesteps = context.size(1) # 50
aspect_3dim = aspect.repeat(1, timesteps).view(-1, timesteps, embedding_size) # torch.Size([2048, 3200])-> torch.Size([2048, 50, 64])
output_3dim = output.repeat(1, timesteps).view(-1, timesteps, embedding_size)

In [17]:
res_ctx = w1(context)         # W1 * x_i  torch.Size([2048, 50, 64])
res_asp = w2(aspect_3dim)     # W2 * x_t  torch.Size([2048, 50, 64])
res_output = w3(output_3dim)  # W3 * m_s  torch.Size([2048, 50, 64])

In [18]:
res_sum = res_ctx + res_asp + res_output + b_a #  torch.Size([2048, 50, 64])
res_sum

tensor([[[ 1.3875e+00, -4.6703e-01,  6.2192e-01,  ..., -4.0820e-01,
           5.6535e-01, -5.2281e-01],
         [ 5.5756e-01, -1.4719e-01, -4.1336e-01,  ...,  7.6941e-01,
          -7.7056e-01, -7.6560e-01],
         [-7.1844e-01,  4.8012e-01, -1.1105e+00,  ..., -7.3048e-01,
          -5.8200e-01, -1.0103e+00],
         ...,
         [ 1.2692e-01, -6.6955e-01, -5.0031e-01,  ...,  8.1640e-01,
          -7.1951e-01, -3.6567e-01],
         [ 1.2692e-01, -6.6955e-01, -5.0031e-01,  ...,  8.1640e-01,
          -7.1951e-01, -3.6567e-01],
         [ 1.2692e-01, -6.6955e-01, -5.0031e-01,  ...,  8.1640e-01,
          -7.1951e-01, -3.6567e-01]],

        [[-3.8064e-01,  5.0792e-01,  3.7935e-02,  ...,  7.9792e-01,
          -4.0008e-01,  2.3094e-01],
         [-3.8064e-01,  5.0792e-01,  3.7935e-02,  ...,  7.9792e-01,
          -4.0008e-01,  2.3094e-01],
         [ 7.6625e-02, -6.7121e-01, -9.1939e-01,  ..., -6.3593e-01,
          -2.5423e-01,  1.2693e+00],
         ...,
         [-3.1710e-01,  1

In [19]:
res_act = w0(sigmoid(res_sum)) # torch.Size([2048, 50, 1])
res_act

tensor([[[-0.0125],
         [-0.2338],
         [ 0.0077],
         ...,
         [-0.1970],
         [-0.1970],
         [-0.1970]],

        [[-0.1639],
         [-0.1639],
         [-0.1499],
         ...,
         [-0.1744],
         [-0.1744],
         [-0.1744]],

        [[-0.3336],
         [-0.1058],
         [-0.3249],
         ...,
         [-0.2879],
         [-0.2879],
         [-0.2879]],

        ...,

        [[-0.3936],
         [-0.4166],
         [-0.3106],
         ...,
         [-0.2892],
         [-0.2892],
         [-0.2892]],

        [[-0.2170],
         [-0.2404],
         [-0.2404],
         ...,
         [-0.2515],
         [-0.2047],
         [-0.2671]],

        [[-0.0733],
         [-0.3106],
         [-0.1488],
         ...,
         [-0.2250],
         [-0.2250],
         [-0.2250]]], grad_fn=<UnsafeViewBackward0>)

In [20]:
alpha = res_act.squeeze(2) # torch.Size([2048, 50])
alpha

tensor([[-0.0125, -0.2338,  0.0077,  ..., -0.1970, -0.1970, -0.1970],
        [-0.1639, -0.1639, -0.1499,  ..., -0.1744, -0.1744, -0.1744],
        [-0.3336, -0.1058, -0.3249,  ..., -0.2879, -0.2879, -0.2879],
        ...,
        [-0.3936, -0.4166, -0.3106,  ..., -0.2892, -0.2892, -0.2892],
        [-0.2170, -0.2404, -0.2404,  ..., -0.2515, -0.2047, -0.2671],
        [-0.0733, -0.3106, -0.1488,  ..., -0.2250, -0.2250, -0.2250]],
       grad_fn=<SqueezeBackward1>)

```python

def calculate_loss(self, interaction):
    """Compute the training loss for one batch.

    With ``loss_type == 'BPR'`` the positive and negative item scores are passed to
    ``self.loss_fct``. With any other loss type (CE), the sequence output is scored
    against every item embedding and cross-entropy is taken against the positive item.
    """
    seq_output = self.forward(
        interaction[self.ITEM_SEQ], interaction[self.ITEM_SEQ_LEN]
    )
    pos_items = interaction[self.POS_ITEM_ID]

    if self.loss_type != "BPR":  # self.loss_type = 'CE'
        all_item_emb = self.item_embedding.weight
        logits = torch.matmul(seq_output, all_item_emb.transpose(0, 1))  # [B, n_items]
        return self.loss_fct(logits, pos_items)

    neg_items = interaction[self.NEG_ITEM_ID]

    def dot_score(items):
        # inner product between the sequence output and each item's embedding -> [B]
        return torch.sum(seq_output * self.item_embedding(items), dim=-1)

    return self.loss_fct(dot_score(pos_items), dot_score(neg_items))
```

In [21]:
item_seq = interaction[ITEM_SEQ]
item_seq_len = interaction[ITEM_SEQ_LEN]
# seq_output = self.forward(item_seq, item_seq_len)
pos_items = interaction[POS_ITEM_ID]

In [35]:
# cross-entropy loss: score the sequence output against every item embedding
test_item_emb = item_embedding.weight # torch.Size([10962, 64])

In [38]:
logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) # torch.Size([2048, 64]) x torch.Size([64, 10962])
logits # torch.Size([2048, 10962])

tensor([[ 0.0000, -3.8363, -3.7135,  ...,  1.2520, -1.9060,  1.2944],
        [ 0.0000, -2.4076, -0.7629,  ..., -1.0999, -5.1167,  0.9964],
        [ 0.0000, -0.1450,  0.3902,  ...,  0.8573,  0.5835, -0.0851],
        ...,
        [ 0.0000, -3.3026, -2.7665,  ..., -1.2535,  4.1188,  0.1657],
        [ 0.0000, -0.8994,  0.3988,  ...,  5.6830, -1.6246, -3.0343],
        [ 0.0000, -4.0702, -0.0350,  ..., -2.4813,  2.2666, -0.0387]],
       grad_fn=<MmBackward0>)

In [41]:
loss = loss_fct(logits, pos_items)
loss

tensor(13.0119, grad_fn=<NllLossBackward0>)

```python

def predict(self, interaction):
    """Score one candidate item per sequence.

    Returns the inner product between each sequence representation and the
    embedding of ``interaction[self.ITEM_ID]``, shape ``[B]``.
    """
    seq_output = self.forward(
        interaction[self.ITEM_SEQ], interaction[self.ITEM_SEQ_LEN]
    )
    candidate_emb = self.item_embedding(interaction[self.ITEM_ID])
    return (seq_output * candidate_emb).sum(dim=1)  # [B]

```

In [45]:
item_seq = interaction[ITEM_SEQ]
item_seq


tensor([[ 1868,   266,  2206,  ...,     0,     0,     0],
        [ 2549,  2549,    60,  ...,     0,     0,     0],
        [ 1017,   522,   265,  ...,     0,     0,     0],
        ...,
        [ 4208,   818,  6991,  ...,     0,     0,     0],
        [ 3813,  4103,  4103,  ...,  9211, 10242,   265],
        [  181,  1299,  1076,  ...,     0,     0,     0]])

In [46]:
item_seq_len = interaction[ITEM_SEQ_LEN]
item_seq_len


tensor([ 7, 26, 13,  ..., 39, 50,  7])

In [47]:
test_item = interaction[ITEM_ID] # item id that predict() scores for each sequence
test_item 

tensor([5693, 4209, 7263,  ...,  231, 1751, 7948])

In [None]:

def full_sort_predict(self, interaction):
    """Score every item for each sequence.

    Returns a ``[B, n_items]`` matrix of inner products between each sequence
    representation and all rows of the item embedding table.
    """
    seq_output = self.forward(
        interaction[self.ITEM_SEQ], interaction[self.ITEM_SEQ_LEN]
    )
    return seq_output @ self.item_embedding.weight.t()