In [1]:
import os
import sys
import json
import random
import time
from typing import Dict, List, Tuple, Iterable

parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import joblib
import pandas as pd
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from tqdm import tqdm
from pandas.core.groupby.generic import DataFrameGroupBy

from src.BERT4Rec import Tokenizer, SequenceDataset, KeBERT4Rec

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Ideas
1. Modify KeBERT4Rec by adding genre prediction `BCEWithLogitsLoss`
2. Use first 75% of train.csv as training set, and last 25% as validation set


## Preprocess

In [2]:
# origin datasets
train_df = pd.read_csv('../data/train.csv')
# test_df = pd.read_csv('../data/test.csv')
songs_df = pd.read_csv('../data/songs.csv')
# songs_info_df = pd.read_csv('../data/song_extra_info.csv')

In [4]:
tr_df, val_df = train_test_split(train_df, test_size=0.25, shuffle=False)

In [5]:
tr_song_df = tr_df.merge(songs_df, how='inner', on='song_id')
split_genre = tr_song_df['genre_ids'].astype(str).str.split('|')

In [6]:
tknr = Tokenizer.load(load_kw_enc=True)

### build dataset

In [6]:
train_ds = SequenceDataset(tr_song_df, 'train', tknr, 10)

In [7]:
(source_items, target_items, mask, source_keywords,
 target_keywords) = train_ds[0]

In [10]:
print(target_items)
print(source_items)
print(mask)

tensor([8841, 8888, 8900, 8964, 9060, 9072, 9139, 9158, 9486, 9695],
       device='cuda:0')
tensor([   1, 8888, 8900,    1,    1, 9072,    1, 9158, 9486, 9695],
       device='cuda:0')
tensor([1, 0, 0, 1, 1, 0, 1, 0, 0, 0], device='cuda:0')


In [14]:
print(source_keywords[5:])

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 

In [7]:
target_items, target_keywords = train_ds.user_seq[0]

In [10]:
seq_len = 10
print(target_items.shape, target_keywords.shape)
end_idx = random.randint(10, max(target_items.size(0), 10))
start_idx = max(0, end_idx - seq_len)
target_items = target_items[start_idx:end_idx]
target_keywords = target_keywords[start_idx:end_idx]

torch.Size([407]) torch.Size([407, 167])


In [14]:
mask = torch.rand(target_items.shape, device=DEVICE) <= 0.2
source_items = target_items.masked_fill(mask, tknr.MASK)

print(target_items)
print(source_items)

tensor([11954, 12070, 12128, 12183, 12249, 12524, 12663, 12833, 12970, 13362],
       device='cuda:0')
tensor([11954, 12070, 12128, 12183, 12249,     1, 12663, 12833, 12970,     1],
       device='cuda:0')


In [15]:
source_keywords = target_keywords.masked_fill(mask.unsqueeze(1), tknr.MASK)
print(target_keywords)
print(source_keywords)

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0')
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1]], device='cuda:0')


In [24]:
pad_len, pad_side = 0, 'left'
F.pad(target_items, (pad_len, 0) if pad_side == 'left' else (0, pad_len))

tensor([11954, 12070, 12128, 12183, 12249, 12524, 12663, 12833, 12970, 13362],
       device='cuda:0')

In [22]:
target_items

tensor([11954, 12070, 12128, 12183, 12249, 12524, 12663, 12833, 12970, 13362],
       device='cuda:0')

In [7]:
def preprocess(group_df: pd.DataFrame, tokenizer: Tokenizer):

    group_df = group_df.drop_duplicates(subset='song_id', ignore_index=True)
    target_items = group_df['song_id'].to_list()
    keywords_ls = group_df['genre_ids'].astype(str).str.split('|')
    target_keywords_tensor = tokenizer.encode_keywords(keywords_ls)
    target_item_tensor = tokenizer.convert_tokens_to_ids(target_items)

    return target_item_tensor, target_keywords_tensor

In [7]:
gpb = tr_song_df.groupby('msno').apply(preprocess, tokenizer=tknr).to_list()


In [10]:
user_seq = [(torch.tensor(items, dtype=torch.long, device=DEVICE),
             torch.tensor(keywords, dtype=torch.long, device=DEVICE))
            for items, keywords in gpb]

In [12]:
target_item, target_keyword = user_seq[0]

In [18]:
random.randint(10, max(target_item.size(0), 10))

78

In [19]:
target_keyword[:78].shape

torch.Size([78, 167])

In [16]:
torch.rand(target_item.shape) > 0.2

tensor([ True,  True,  True, False, False,  True,  True,  True,  True,  True,
        False,  True,  True,  True,  True,  True,  True,  True, False,  True,
         True,  True,  True,  True,  True,  True, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True, False,  True, False,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True, False,
         True,  True,  True, False, False,  True,  True,  True, False,  True,
         True,  True,  True, False,  True,  True,  True, False,  True,  True,
        False,  True,  True, False,  True,  True,  True,  True,  True, False,
         True,  True,  True,  True, False,  True,  True, False, False,  True,
         True,  True, False,  True,  True, False,  True,  True,  True,  True,
         True,  True,  True, False,  True,  True,  True,  True, 

In [33]:
sample_df = gpb.loc['zzqc2ja7z10FtSpagYVcAZXg/gPRq7wcDZuNFj+zJSU=']

In [34]:
kw_enc = tknr.encode_keywords(
    sample_df['genre_ids'].astype(str).str.split('|'))

In [35]:
kw_enc

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 1],
        [0, 0, 0,  ..., 0, 0, 1]])

In [28]:
gpb['msno'].duplicated()

msno                                             
++5wYjoMgQHoRuD3GbbvmphZbBBwymzv5Q4l8sywtuU=  0      False
                                              1       True
                                              2       True
                                              3       True
                                              4       True
                                                     ...  
zzqc2ja7z10FtSpagYVcAZXg/gPRq7wcDZuNFj+zJSU=  180     True
                                              181     True
zzzRi5ek1YCKTGns8C77xwAutE05PAPmz8T/pIIQhzE=  0      False
                                              1       True
                                              2       True
Name: msno, Length: 5532974, dtype: bool

In [7]:
dataset = SequenceDataset(gpb, mode='val', tokenizer=tknr, seq_len=100)

In [8]:
si, ti, mk, sk, tk = dataset[2]

In [10]:
mk

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0])

In [7]:
temp_df = gpb.get_group(list(gpb.groups.keys())[0])

In [8]:
temp_df = temp_df.drop_duplicates(subset='song_id', ignore_index=True)

In [9]:
tknr.padding(temp_df['song_id'].iloc[:5], max_len=4)

['lBYsvASSajQQPfE/o65Qy3FYsMHhbCmynFANAP29nks=',
 'DLBDZhOoW7zd7GBV99bi92ZXYUS26lzV+jJKbHshP5c=',
 'skehue/d/R59G71dXYpntDwdjRRPlweN3JE8g40TgZU=',
 'PgRtmmESVNtWjoZHO5a1r21vIz9sVZmcJJpFCbRa1LI=']

In [11]:
tknr.mask(temp_df['song_id'].iloc[:10], ratio=0.2, only_last=True)

['u6/Pb7X4u7KU4gXrBgGqt8RlRrNNFLn03tLAHyxRxwA=',
 'lBYsvASSajQQPfE/o65Qy3FYsMHhbCmynFANAP29nks=',
 'DLBDZhOoW7zd7GBV99bi92ZXYUS26lzV+jJKbHshP5c=',
 'skehue/d/R59G71dXYpntDwdjRRPlweN3JE8g40TgZU=',
 'PgRtmmESVNtWjoZHO5a1r21vIz9sVZmcJJpFCbRa1LI=',
 'GFWjz4apE8zMiXe6wc0qCjr+DbK9BFkwo/rr+FsNm+g=',
 'zHqZ07gn+YvF36FWzv9+y8KiCMhYhdAUS+vSIKY3UZY=',
 'i356q8t0P9emMJq8PsFkY6CGoi34mP3cgXuDpfEDyhY=',
 'pCf4HI+z4rQsY9FF79m+Sojnl207qtHqtOWU2VFW1ZE=',
 '[MASK]']

In [11]:
tknr.convert_ids_to_tokens(
    tknr.convert_tokens_to_ids(temp_df['song_id']).numpy())

['u6/Pb7X4u7KU4gXrBgGqt8RlRrNNFLn03tLAHyxRxwA=',
 'lBYsvASSajQQPfE/o65Qy3FYsMHhbCmynFANAP29nks=',
 'DLBDZhOoW7zd7GBV99bi92ZXYUS26lzV+jJKbHshP5c=',
 'skehue/d/R59G71dXYpntDwdjRRPlweN3JE8g40TgZU=',
 'PgRtmmESVNtWjoZHO5a1r21vIz9sVZmcJJpFCbRa1LI=',
 'GFWjz4apE8zMiXe6wc0qCjr+DbK9BFkwo/rr+FsNm+g=',
 'zHqZ07gn+YvF36FWzv9+y8KiCMhYhdAUS+vSIKY3UZY=',
 'i356q8t0P9emMJq8PsFkY6CGoi34mP3cgXuDpfEDyhY=',
 'pCf4HI+z4rQsY9FF79m+Sojnl207qtHqtOWU2VFW1ZE=',
 'wBTWuHbjdjxnG1lQcbqnK4FddV24rUhuyrYLd9c/hmk=',
 '43Qm2YzsP99P5wm37B1JIhezUcQ/1CDjYlQx6rBbz2U=',
 'gZphe9aRvr0vVO/oEt23amqyDc+YabbIJ9WIzZTVjG0=',
 'B4zTMZ/an9RmxBJHxlP07ByYW45ycQGtqu6G89GNFM0=',
 'd5ayexvXscdzmuGxENyY8Uwb4AQaxt0dEcnkAgES7xw=',
 'doSc+a6NWyq83MvQRihhH3RIQt617F7d4X9Q34ZapgQ=',
 '35dx60z4m4+Lg+qIS0l2A8vspbthqnpTylWUu51jW+4=',
 'JA6C0GEK1sSCVbHyqtruH/ARD1NKolYrw7HXy6EVNAc=',
 'DEQhJpQFad8IXRU71BVOmbSkx9vmG+2pvC8wpQamniI=',
 'pkKADcjPLFdP9kX40L8dlmbLWkgUTqUB0LhurVLSTME=',
 'A7Z2nDVASy04EpmfzOL9PRMN3hFVbWQ7ah4J+o9sLnE=',
 'OaEbZ6TJ1NePtNUeEg

In [14]:
enc_kw = tknr.encode_keywords(split_genre)

In [17]:
tknr.decode_keywords(enc_kw[-10:])

[('2022',),
 ('2022',),
 ('465',),
 ('423',),
 ('1616', '2058'),
 ('1609',),
 ('1259',),
 ('465',),
 ('1152',),
 ('1152',)]

In [46]:
s, d = 5, 3  # Example original dimensions
original_tensor = torch.randn(s, d)

# Desired padding
a = 2  # Number of rows to add


In [57]:
original_tensor

tensor([[-0.7740, -0.7554, -0.9208],
        [ 2.6779,  1.2348,  0.0938],
        [-1.0274,  0.7983,  0.0112],
        [-0.1177, -0.3713, -1.8274],
        [ 1.1774, -0.6543,  0.7098]])

In [54]:
F.pad(original_tensor, (0, 0, a, 0))

tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [-0.7740, -0.7554, -0.9208],
        [ 2.6779,  1.2348,  0.0938],
        [-1.0274,  0.7983,  0.0112],
        [-0.1177, -0.3713, -1.8274],
        [ 1.1774, -0.6543,  0.7098]])

In [58]:
original_tensor.masked_fill(torch.tensor([0, 1, 0, 0, 1]).unsqueeze(1) == 1, 0)

tensor([[-0.7740, -0.7554, -0.9208],
        [ 0.0000,  0.0000,  0.0000],
        [-1.0274,  0.7983,  0.0112],
        [-0.1177, -0.3713, -1.8274],
        [ 0.0000,  0.0000,  0.0000]])

In [11]:
torch.tensor(
    list(
        map(lambda item: tknr.vocab['item2id'].get(item, 1),
            temp_df['song_id'])))

tensor([    14,     45,     46,     49,     56,     73,    165,    177,    212,
           243,    249,    290,    309,    311,    324,    385,    407,    446,
           452,    460,    462,    521,    546,    565,    580,    612,    616,
           629,    630,    657,    700,    717,    743,    844,    924,    974,
           988,   1059,   1084,   1091,   1118,   1130,   1182,   1193,   1231,
          1264,   1361,   1370,   1451,   1496,   1551,   1572,   1598,   1599,
          1796,   2080,   2125,   2150,   2151,   2202,   2275,   2292,   2301,
          2316,   2351,   2449,   2508,   2569,   2693,   2707,   2726,   2730,
          2854,   2869,   2883,   2897,   2898,   2899,   2900,   2936,   2977,
          3054,   3101,   3111,   3172,   3173,   3174,   3309,   3323,   3352,
          3785,   4052,   4089,   4107,   4154,   4281,   4402,   4431,   4441,
          4478,   4509,   4510,   4557,   4662,   4666,   4693,   4745,   4796,
          4798,   4841,   5212,   5226, 

In [62]:
class SequenceDataset(Dataset):

    def __init__(self,
                 group_by: DataFrameGroupBy,
                 mode: str,
                 tokenizer: Tokenizer,
                 seq_len: int = 100) -> None:
        """Constructor

        :param group_by: a pandas groupby
        :param mode: 'train' or 'val'
        :param tokenizer: the Tokenizer
        :param seq_len: the sequence length
        """

        self.groups = list(group_by.groups.keys())
        self.group_by = group_by
        self.mode = mode
        self.tokenizer = tokenizer
        self.seq_len = seq_len

    def __len__(self):
        return len(self.groups)

    def __getitem__(self, index):

        group = self.groups[index]
        group_df: pd.DataFrame = self.group_by.get_group(group)
        group_df = group_df.drop_duplicates(subset='song_id',
                                            ignore_index=True)

        # sample start and end
        end_idx = random.randint(
            10,
            group_df.shape[0]) if self.mode == 'train' else group_df.shape[0]
        start_idx = max(0, end_idx - self.seq_len)
        group_df = group_df.iloc[start_idx:end_idx]

        # extract items and keywords
        target_items = group_df['song_id'].to_list()
        keywords_ls = group_df['genre_ids'].astype(str).str.split('|')
        target_keywords_tensor = self.tokenizer.encode_keywords(keywords_ls)

        # mask
        if self.mode == 'train':
            # only mask the last item, akin finetune
            if random.random() <= 0.05:
                source_items = target_items[:-1] + ['[MASK]']
            else:
                source_items = self.tokenizer.mask(target_items, ratio=0.2)
        else:
            source_items = target_items[:] + ['[MASK]']

        # padding
        pad_len = self.seq_len - len(target_items)
        pad_side = 'left' if random.random() <= 0.5 else 'right'
        target_items = self.tokenizer.padding(target_items,
                                              side=pad_side,
                                              max_len=self.seq_len)
        source_items = self.tokenizer.padding(source_items,
                                              side=pad_side,
                                              max_len=self.seq_len)
        target_keywords_tensor = F.pad(target_keywords_tensor,
                                       (0, 0, pad_len, 0) \
                                            if pad_side == 'left' else
                                       (0, 0, 0, pad_len))

        # convert to tensors
        target_item_tensor = self.tokenizer.convert_tokens_to_ids(target_items)
        source_item_tensor = self.tokenizer.convert_tokens_to_ids(source_items)
        mask = (target_item_tensor != source_item_tensor).to(torch.long)
        source_keywords_tensor = target_keywords_tensor.masked_fill(
            mask.unsqueeze(1) == 1, 1)  # mask keywords

        return (source_item_tensor, target_item_tensor, mask,
                source_keywords_tensor, target_keywords_tensor)

In [71]:
dataset = SequenceDataset(gpb, mode='train', tokenizer=tknr, seq_len=100)

In [72]:
si, ti, mk, sk, tk = dataset[2]

In [73]:
mk

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 1, 1, 0])

In [74]:
print(si)
print(ti)

tensor([    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,    97,   110,   186,     1,   243,   287,
          381,   452,     1,   599,   677,   906,     1,  1214,     1,     1,
            1,  1806,  2570,  2582,  2883,  3426,  3701,  4693,  5124,  5308,
         5526,     1,  6312,  6956,  7464,  8036,  8619,  9322,  9743, 10612,
            1, 13527, 14454, 15953, 16042, 16440, 18510, 20086,     1, 21777,
        22167, 23864, 27045, 27994, 28668, 29010, 29049,     1,     1, 37993])
tensor([    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,

In [75]:
print(sk)
print(tk)

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 0, 0, 0]])
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])


In [50]:
enc = tknr.encode_keywords(split_genre)

In [52]:
(enc.sum(axis=1) / enc.shape[1] == 1).sum()

tensor(0)

In [43]:
mk

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0])

In [9]:
Tokenizer.construct(tr_song_df, 'song_id').save()

In [13]:
members = pd.read_csv('../data/members.csv')

In [3]:
# modified datasets
used_songs_df = pd.read_csv('../data/used_songs.csv')

In [8]:
test_df[~test_df['song_id'].isin(train_df['song_id'])]['song_id'].unique(
).shape

(59873,)

In [11]:
# new songs in test.csv
round(59873 / test_df['song_id'].unique().shape[0], 2)

0.27

In [10]:
test_df[~test_df['msno'].isin(train_df['msno'])]['msno'].unique().shape

(3648,)

In [12]:
# new users in test.csv
round(3648 / test_df['msno'].unique().shape[0], 2)

0.15

In [7]:
tr_df, val_df = train_test_split(train_df, test_size=0.25, shuffle=False)

In [8]:
# new songs in the validation set
round(
    val_df[~val_df['song_id'].isin(tr_df['song_id'])]
    ['song_id'].unique().shape[0] / val_df['song_id'].unique().shape[0], 2)

0.28

In [9]:
# new users in the validation set
round(
    val_df[~val_df['msno'].isin(tr_df['msno'])]['msno'].unique().shape[0] /
    val_df['msno'].unique().shape[0], 2)

0.14

In [13]:
temp_df = val_df.groupby('msno').get_group(
    'FnNP1yrSvV9bSxxmccXu3PSarO2wFqhOWByD89kwmvQ=')

In [15]:
round(temp_df[temp_df['target'] == 1].shape[0] / temp_df.shape[0], 2)

0.2

In [17]:
tr_song_df = tr_df.merge(songs_df, how='inner', on='song_id')
tr_song_df.columns

Index(['msno', 'song_id', 'source_system_tab', 'source_screen_name',
       'source_type', 'target', 'song_length', 'genre_ids', 'artist_name',
       'composer', 'lyricist', 'language'],
      dtype='object')

In [20]:
round(tr_song_df.isna().sum() / tr_song_df.shape[0], 2)

msno                  0.00
song_id               0.00
source_system_tab     0.00
source_screen_name    0.05
source_type           0.00
target                0.00
song_length           0.00
genre_ids             0.02
artist_name           0.00
composer              0.23
lyricist              0.43
language              0.00
dtype: float64

In [None]:
tr_song_df['genre_ids_multi_hot'] = 

In [21]:
genre_set = set('|'.join(
    tr_song_df['genre_ids'].astype(str).tolist()).split('|'))

In [26]:
split_genre = tr_song_df['genre_ids'].astype(str).str.split('|')

In [36]:
mlb = MultiLabelBinarizer()

In [41]:
multi_hot = mlb.fit_transform(split_genre)

In [50]:
joblib.dump(mlb, '../models/BERT4Rec/tokenizer/kw_enc.pkl')

['../models/BERT4Rec/tokenizer/kw_enc.pkl']

In [34]:
split_genre[split_genre.apply(lambda x: len(x) > 1)]

13318           [465, 458]
13319           [465, 458]
13320           [465, 458]
13321           [465, 458]
13322           [465, 458]
                ...       
5532898    [139, 125, 109]
5532914       [1180, 1152]
5532945    [139, 125, 109]
5532955        [1572, 275]
5532968       [1616, 2058]
Name: genre_ids, Length: 257973, dtype: object

In [35]:
split_genre.isna().sum()

0

In [8]:
train_df.columns

Index(['msno', 'song_id', 'source_system_tab', 'source_screen_name',
       'source_type', 'target'],
      dtype='object')

In [11]:
train_df['is_train'] = 1
test_df['is_train'] = 0
test_df.drop(columns='id', inplace=True)

train_test_df = pd.concat([train_df, test_df], ignore_index=True)

In [15]:
smaple_submission = pd.read_csv('../data/sample_submission.csv')

In [14]:
(test_df.isna().sum() / len(test_df)).round(2)

msno                  0.00
song_id               0.00
source_system_tab     0.00
source_screen_name    0.06
source_type           0.00
is_train              0.00
dtype: float64

In [3]:
(songs_df.isna().sum() / len(songs_df)).round(2)

song_id        0.00
song_length    0.00
genre_ids      0.04
artist_name    0.00
composer       0.47
lyricist       0.85
language       0.00
dtype: float64

In [4]:
(songs_info_df.isna().sum() / len(songs_info_df)).round(2)

song_id    0.00
name       0.00
isrc       0.06
dtype: float64

In [5]:
full_songs_df = songs_df.merge(songs_info_df, how='inner', on='song_id')

In [6]:
(full_songs_df.isna().sum() / len(full_songs_df)).round(2)

song_id        0.00
song_length    0.00
genre_ids      0.04
artist_name    0.00
composer       0.47
lyricist       0.85
language       0.00
name           0.00
isrc           0.06
dtype: float64

In [7]:
full_songs_df.to_csv('../data/full_songs.csv', index=False, encoding='utf-8')

In [8]:
train_usr_song_df = train_df.merge(full_songs_df, how='inner', on='song_id')

In [11]:
test_usr_song_df = test_df.merge(full_songs_df, how='inner', on='song_id')

In [12]:
train_usr_song_df.columns

Index(['msno', 'song_id', 'source_system_tab', 'source_screen_name',
       'source_type', 'target', 'song_length', 'genre_ids', 'artist_name',
       'composer', 'lyricist', 'language', 'name', 'isrc'],
      dtype='object')

In [18]:
full_usr_song_df = pd.concat(
    [train_usr_song_df, test_usr_song_df.drop(columns='id')],
    ignore_index=True)

In [20]:
full_usr_song_df.to_csv('../data/full_usr_song.csv',
                        index=False,
                        encoding='utf-8')

In [27]:
# songs in train and test
used_songs_df = full_songs_df[full_songs_df['song_id'].isin(
    full_usr_song_df['song_id'].unique())].drop_duplicates(ignore_index=True)

In [29]:
used_songs_df.to_csv('../data/used_songs.csv', index=False, encoding='utf-8')

In [31]:
(used_songs_df.isna().sum() / len(used_songs_df)).round(2)

song_id        0.00
song_length    0.00
genre_ids      0.02
artist_name    0.00
composer       0.43
lyricist       0.75
language       0.00
name           0.00
isrc           0.13
dtype: float64

In [32]:
genre_set = set('|'.join(
    used_songs_df['genre_ids'].astype(str).tolist()).split('|'))

In [35]:
train_usr_song_df.groupby('msno').get_group(
    'FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=')

Unnamed: 0,msno,song_id,source_system_tab,source_screen_name,source_type,target,song_length,genre_ids,artist_name,composer,lyricist,language,name,isrc
0,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,BBzumQNXUHKdEBOB7mAJuzok+IJA1c2Ryg/yzTF6tik=,explore,Explore,online-playlist,1,206471,359,Bastille,Dan Smith| Mark Crew,,52.0,Good Grief,GBUM71602854
221,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,3qm6XTZ6MOCU11x8FIVbAGH5l5uMkT3/ZalWG1oo2Gc=,explore,Explore,online-playlist,1,187802,1011,Brett Young,Brett Young| Kelly Archer| Justin Ebach,,52.0,Sleep Without You,QM3E21606003
633,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,3Hg5kugV1S0wzEVLAEfqjIV5UHzb7bCrdBRQlGygLvU=,explore,Explore,online-playlist,1,247803,1259,Desiigner,Sidney Selby| Adnan Khan,,52.0,Panda,USUM71601094
5610,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,bPIvRTzfHxH5LgHrStll+tYwSQNVV8PySgA3M1PfTgc=,explore,Explore,online-playlist,1,181115,1011,Thomas Rhett,Thomas Rhett| Rhett Akins| Ben Hayslip,,52.0,Star Of The Show,USLXJ1607334
6621,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,EbI7xoNxI+3QSsiHxL13zBdgHIJOwa3srHd7cDcnJ0g=,explore,Explore,online-playlist,0,257369,465,OneRepublic,Ryan Tedder,,52.0,Counting Stars,USUM71301306
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7370732,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,J06isMbryq9+xddfV+bUQhEj9DKfrL3cOWN80Z87tPA=,radio,Radio,radio,0,520240,465,Tetsuya komuro (小室哲哉),Tetsuya Komuro,Tetsuya Komuro/Rap words: MARC,17.0,Judgement 2014,JPB601401528
7370733,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,YR8jzRXzET5IvYAjXBhTIFyD6JbWXkGsVPr5z79arlI=,radio,Radio,radio,0,445257,1609,Morsy| Noone Costelo,,,-1.0,Get Mad,USNRS1433298
7370767,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,LO88CiYqRzttnEVXjJvKTNi1odC2ZoHJ5cQcBXQzFzs=,radio,Radio,radio,0,435095,1609|2107,Sailor & I,,,52.0,Leave The Light On,USUS11500014
7370768,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,QXn/3rYpoGRcbsd45/lHcRvUsxqRGDHjKVVKT2CmytU=,radio,Radio,radio,0,399940,1609,Various Artists,,,52.0,Da Funk,DEH741507362


### build tokenizer

In [39]:
PAD, MASK = 0, 1
songs_ids = {song: i + 2 for i, song in enumerate(used_songs_df['song_id'])}
songs_ids['[PAD]'] = 0
songs_ids['[MASK]'] = 1

In [41]:
ids_songs = {idx: song for song, idx in songs_ids.items()}

In [48]:
vocab = {'item2id': songs_ids, 'id2item': ids_songs}
with open('../models/BERT4Rec/vocab.json', 'w', encoding='utf-8') as f:
    json.dump(vocab, f)

In [4]:
with open('../models/BERT4Rec/vocab.json', 'r', encoding='utf-8') as f:
    vocab = json.load(f)

In [44]:
class Tokenizer:

    def __init__(self,
                 vocab: Dict[str, Dict],
                 keyword_encoder: MultiLabelBinarizer = None) -> None:
        """Constructor
        
        :param vocab: a dict contains 'item2id' and 'id2item'
        :param keyword_encoder: a multi-hot encoder
        """

        self.vocab = vocab
        self.keyword_encoder = keyword_encoder

    @classmethod
    def construct(cls,
                  item_df: pd.DataFrame,
                  item_column: str,
                  keyword_ls: Iterable[Iterable[str]] = None):
        """Construct a Tokenizer

        :param item_df: a dataframe contains items
        :param item_column: a column contains items' ids
        :param keyword_ls: a set of keywords for each item

        :return: a Tokenizer
        """

        item2id = {
            item: idx + 2
            for idx, item in enumerate(item_df[item_column].unique())
        }
        item2id['[PAD]'] = 0
        item2id['[MASK]'] = 1
        id2item = {idx: item for item, idx in item2id.items()}

        if keyword_ls is not None:
            mlb = MultiLabelBinarizer()
            mlb.fit(keyword_ls)

        return cls({'item2id': item2id, 'id2item': id2item}, mlb)

    @classmethod
    def load(cls,
             vocab_fp: str = None,
             load_kw_enc: bool = False,
             keyword_enc_fp: str = None):
        """Load a Tokenizer

        :param vocab_fp: the vocab's file path; if None, use the default vocab
        :param load_kw_enc: whether to load the keyword multi-hot encoder; default False
        :param keyword_enc_fp: the multi-hot encoder's file path; if None, use the default multi-hot encoder

        :return: a Tokenizer
        """

        if not vocab_fp:
            vocab_fp = os.path.join(os.path.dirname(__file__), 'vocab.json')
            if not os.path.exists(vocab_fp):
                raise FileNotFoundError('No default vocab!')

        with open(vocab_fp, 'r', encoding='utf-8') as vocab_f:
            vocab = json.load(vocab_f)

        if load_kw_enc:
            if not keyword_enc_fp:
                keyword_enc_fp = os.path.join(os.path.dirname(__file__),
                                              'kw_enc.pkl')
                if not os.path.exists(vocab_fp):
                    raise FileNotFoundError('No default keyword encoder!')

            kw_enc = joblib.load(keyword_enc_fp)

        return cls(vocab, kw_enc)


In [45]:
tknr = Tokenizer.construct(tr_df, 'song_id', split_genre)

In [49]:
raise FileNotFoundError('No default vocab')

FileNotFoundError: Not default vocab

In [46]:
tknr.keyword_encoder.classes_

array(['1000', '1007', '1011', '1019', '102', '1026', '1033', '1040',
       '1047', '1054', '1068', '1082', '109', '1096', '1103', '1110',
       '1117', '1124', '1131', '1138', '1145', '1152', '1155', '1162',
       '1169', '118', '1180', '1187', '1194', '1201', '1208', '125',
       '1259', '1266', '1273', '1280', '1287', '139', '152', '1568',
       '1572', '1579', '1598', '1605', '1609', '1616', '1630', '1633',
       '177', '184', '191', '1944', '1955', '1965', '1969', '1977', '198',
       '1981', '1988', '1995', '2008', '2015', '2022', '2029', '2032',
       '205', '2052', '2058', '2065', '2072', '2079', '2086', '2093',
       '2100', '2107', '2109', '2116', '212', '2122', '2127', '2130',
       '2144', '2150', '2157', '2172', '2176', '2183', '2189', '2192',
       '2194', '2206', '2213', '2215', '2219', '2245', '2248', '242',
       '252', '275', '282', '296', '310', '331', '338', '352', '359',
       '367', '374', '381', '388', '402', '409', '416', '423', '430',
       '437',

In [7]:
tknr.vocab

{'item2id': {'CXoTN1eb7AI+DntdU1vbcwGRV4SCIDxZu+YD8JP8r4E=': 2,
  'o0kFgae9QtnYgRkVPqLJwa05zIhRlUjfF7O1tDw0ZDU=': 3,
  'DwVvVurfpuz+XPuFvucclVQEyPqcpUkHR0ne1RQzPs0=': 4,
  'dKMBWoZyScdxSkihKG+Vf47nc18N9q4m58+b4e7dSSE=': 5,
  'W3bqWd3T+VeHFzHAUfARgW9AvVRaF4N5Yzm4Mr6Eo/o=': 6,
  'kKJ2JNU5h8rphyW21ovC+RZU+yEHPM+3w85J37p7vEQ=': 7,
  'N9vbanw7BSMoUgdfJlgX1aZPE1XZg8OS1wf88AQEcMc=': 8,
  'GsCpr618xfveHYJdo+E5SybrpR906tsjLMeKyrCNw8s=': 9,
  'oTi7oINPX+rxoGp+3O6llSltQTl80jDqHoULfRoLcG4=': 10,
  'btcG03OHY3GNKWccPP0auvtSbhxog/kllIIOx5grE/k=': 11,
  'HulM/OaHgD5kUyjNQjDUf8VZdsy7h4EJUIff79Cifwo=': 12,
  'wypPzqFNdUJAqyBVxmFGaK4z7krUNWr5YqA0q0wi9eE=': 13,
  'fAZLdfQaLG76a6Ei4alt1eSjBM9rshQkiQEC6+n+y08=': 14,
  'tqBlH4r/q1Tf6C5+C6ucjGlLjMbfu5yjqB6ifRzy5dc=': 15,
  'an6EdIr+Z+KbqIVQiXn5PKkcXncefQ7hhWONseRuub4=': 16,
  'J2MFmy8iF94mExWfRWE3KxsMZB+ZIedV5liqZoSrERQ=': 17,
  'MrRilXQwoUAcoAf0N3RT82qX2/us/wEhYDXE+ZTIW5o=': 18,
  'OcG4Ya7iXmVMCMy24C5wxDMtr9w6WQZiFaN0uq6zdTk=': 19,
  'JcHIgDP5ivyqYIn7RxfXM1

## Model

In [79]:
class KeBERT4Rec(nn.Module):

    def __init__(self,
                 item_size: int,
                 keyword_size: int,
                 dim: int = 64,
                 dropout: float = 0.4,
                 num_layers: int = 3) -> None:
        """Constructor
        
        :param item_size: total number of items
        :param keyword_size: total number of keywords
        :param dim: the dim of embeddings
        :param dropout: dropout rate
        :param num_layers: number of Encoder layers
        """

        super().__init__()

        # embeddings
        self.item_embeddings = nn.Embedding(item_size, dim)
        self.pos_embdddings = nn.Embedding(512, dim)
        self.kw_embdddings = nn.Linear(keyword_size, dim)

        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(dim,
                                                   nhead=4,
                                                   dropout=dropout,
                                                   batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer,
                                             num_layers=num_layers)

        # output layer
        self.item_out = nn.Linear(dim, item_size)
        self.kw_out = nn.Linear(dim, keyword_size)

    def forward(
            self, source_items: torch.Tensor, source_keywords: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass
        
        :param source_items: items
        :param source_keywords: keywords
        :return: predicted item logits and keyword logits
        """

        batch_size, seq_len = source_items.shape

        # encode input
        item_embed = self.item_embeddings(source_items)
        kw_embed = self.kw_embdddings(source_keywords.to(torch.float32))
        pos_encode = torch.arange(
            0, seq_len,
            device=source_items.device).unsqueeze(0).repeat(batch_size, 1)
        pos_embed = self.pos_embdddings(pos_encode)

        # forward
        x = item_embed + kw_embed + pos_embed
        x = self.encoder(x)
        item_out = self.item_out(x)
        kw_out = self.kw_out(x)

        return item_out, kw_out

    @staticmethod
    def item_loss_acc(predict_items_logits: torch.Tensor,
                      target_items: torch.Tensor, mask: torch.Tensor):
        """Calculate item loss and accuracy
        
        :param predict_items_logits: the prediction logits
        :param target_items: the target items
        :param mask: the mask
        :return: the loss and the accuracy
        """
        predict_items_logits = predict_items_logits.view(
            -1, predict_items_logits.shape[-1])
        target_items = target_items.view(-1)
        mask = mask.flatten() == 1

        # loss
        loss = F.cross_entropy(predict_items_logits,
                               target_items,
                               reduction='none')
        loss *= mask
        loss = loss.sum() / (mask.sum() + 1e-8)

        # accuracy
        predict = predict_items_logits.argmax(dim=-1)
        y_true = target_items.masked_select(mask)
        y_predict = predict.masked_select(mask)
        acc = (y_true == y_predict).double().mean()

        return loss, acc

    @staticmethod
    def kw_loss_sim(predict_kw_logits: torch.Tensor, target_kw: torch.Tensor,
                    mask: torch.Tensor):
        """Calculate keywords loss and cosine similarity
        
        :param predict_kw_logits: the keywords logits
        :param target_kw: the target keywords
        :param mask: the mask
        :return: the loss and the cosine similarity
        """

        mask = mask == 1

        # loss
        loss = F.binary_cross_entropy_with_logits(predict_kw_logits,
                                                  target_kw.to(torch.float32),
                                                  reduction='none').sum(-1)
        loss *= mask
        loss = loss.sum() / (mask.sum() + 1e-8)

        # accuracy
        sim = torch.cosine_similarity(
            predict_kw_logits, target_kw,
            dim=-1).masked_select(mask == 1).sum() / (mask.sum() + 1e-8)

        return loss, sim


In [7]:
train_ds = SequenceDataset(tr_song_df, 'train', tknr, 100)

In [8]:
train_dl = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=10)

In [26]:
len(train_dl)

1714

In [9]:
def train(model: KeBERT4Rec, data_loader: DataLoader, num_epochs: int):

    for param in model.parameters():
        nn.init.trunc_normal_(param, mean=0, std=0.02, a=-0.02, b=0.02)

    optimizer = optim.Adam(model.parameters(),
                           lr=1e-4,
                           betas=(0.9, 0.999),
                           weight_decay=0.01)
    scheduler = LambdaLR(optimizer,
                         lr_lambda=lambda epoch: 1 - epoch / num_epochs)

    model.to(DEVICE)

    for epoch in range(num_epochs):

        model.train()
        total_loss = 0
        for i, batch in enumerate(data_loader):
            start = time.time()
            source_item, target_item, mask, source_keyword, target_keyword = batch

            item_out, keyword_out = model(source_item, source_keyword)

            item_loss, item_acc = model.item_loss_acc(item_out, target_item,
                                                      mask)
            keyword_loss, keyword_sim = model.kw_loss_sim(
                keyword_out, target_keyword, mask)

            loss = item_loss + keyword_loss
            total_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(model.parameters(), max_norm=5)
            optimizer.step()
            print(f"Batch {i}, use {time.time()-start:.2f}")

        print("Epoch %d avg loss %.2f" %
              (epoch, total_loss / len(data_loader)))

        scheduler.step()

In [10]:
model = KeBERT4Rec(len(tknr.vocab['id2item']),
                   len(tknr.keyword_encoder.classes_))

In [11]:
train(model, train_dl, 1)

Batch 0, use 3.62
Batch 1, use 1.70
Batch 2, use 1.21
Batch 3, use 1.54
Batch 4, use 1.04
Batch 5, use 1.54
Batch 6, use 1.20
Batch 7, use 1.53
Batch 8, use 1.33
Batch 9, use 1.47
Batch 10, use 1.30
Batch 11, use 1.62
Batch 12, use 1.27
Batch 13, use 1.63
Batch 14, use 1.00
Batch 15, use 1.46
Batch 16, use 1.23
Batch 17, use 1.46
Batch 18, use 1.24
Batch 19, use 1.27
Batch 20, use 1.20
Batch 21, use 1.57
Batch 22, use 1.23
Batch 23, use 1.46
Batch 24, use 0.96
Batch 25, use 1.51
Batch 26, use 1.19
Batch 27, use 1.55
Batch 28, use 1.20
Batch 29, use 1.46
Batch 30, use 1.09
Batch 31, use 1.59
Batch 32, use 1.18
Batch 33, use 1.71
Batch 34, use 0.94
Batch 35, use 1.51
Batch 36, use 0.94
Batch 37, use 1.53
Batch 38, use 1.20
Batch 39, use 1.66
Batch 40, use 1.08
Batch 41, use 1.51
Batch 42, use 1.18
Batch 43, use 1.61
Batch 44, use 0.96
Batch 45, use 1.57
Batch 46, use 1.24
Batch 47, use 1.56
Batch 48, use 1.00
Batch 49, use 1.37
Batch 50, use 1.18
Batch 51, use 1.53
Batch 52, use 1.17
Bat

In [8]:
si, ti, mk, sk, tk = train_ds[4]

In [12]:
batch = next(iter(train_dl))
source_item, target_item, mask, source_kw, target_kw = batch

In [17]:
source_item, target_item, mask, source_kw, target_kw = batch

In [23]:
for i in range(len(batch)):
    batch[i] = batch[i].to(DEVICE)
source_item, target_item, mask, source_kw, target_kw = batch
source_item.device

device(type='cuda', index=0)

In [80]:
model = KeBERT4Rec(len(tknr.vocab['id2item']),
                   len(tknr.keyword_encoder.classes_))

In [81]:
item_out, kw_out = model(source_item, source_kw)

In [83]:
model.item_loss_acc(item_out, target_item, mask)

(tensor(12.9208, grad_fn=<DivBackward0>), tensor(0., dtype=torch.float64))

In [84]:
model.kw_loss_sim(kw_out, target_kw, mask)

(tensor(121.9991, grad_fn=<DivBackward0>),
 tensor(0.0261, grad_fn=<DivBackward0>))

In [16]:
a_item_out = item_out.view(-1, item_out.shape[-1])
a_kw_out = kw_out.view(-1, kw_out.shape[-1])
a_item_trg = target_item.view(-1)
a_kw_trg = target_kw.view(-1)

In [56]:
loss = F.binary_cross_entropy_with_logits(kw_out,
                                          target_kw.to(torch.float32),
                                          reduction='none')

In [66]:
loss.sum(-1) * (mask == 1)

tensor([[125.8160,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000, 122.5573,
         125.8246,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000, 125.5507, 121.9584,   0.0000,   0.0000, 125.5337,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0

In [78]:
torch.cosine_similarity(kw_out, target_kw,
                        dim=-1).masked_select(mask == 1).sum() / mask.sum()

tensor(0.0339, grad_fn=<DivBackward0>)

In [75]:
mask.sum()

tensor(43)

In [54]:
kw_out.dtype

torch.float32

In [23]:
torch.allclose(mask.flatten(), (source_item.view(-1)
                                != a_item_trg).to(torch.long))

True

In [27]:
loss = F.cross_entropy(a_item_out, a_item_trg, reduction='none')
loss

tensor([13.9642, 13.6723, 13.4483, 12.6086, 12.3457, 11.4078, 12.7360, 12.6496,
        12.7356, 11.9945, 13.8190, 12.9021, 13.8431, 12.8897, 12.2497, 13.3047,
        13.1563, 12.7346, 13.8828, 13.3377, 13.1270, 13.5089, 13.0576, 13.0882,
        13.6664, 12.7604, 12.9936, 13.0940, 12.7521, 12.4667, 12.8097, 12.5713,
        12.6958, 12.7983, 12.8831, 12.8244, 12.7116, 12.8163, 13.1992, 12.9368,
        12.5753, 13.0136, 12.9261, 12.3828, 12.3319, 12.7954, 12.6944, 12.8703,
        12.5307, 13.2218, 12.4872, 12.4915, 12.5212, 13.4166, 12.4752, 13.6178,
        12.8816, 12.5286, 12.7947, 12.7063, 12.9730, 12.5026, 13.2123, 12.8360,
        12.2891, 12.9247, 11.9535, 12.7041, 13.1867, 12.9042, 12.7599, 12.4386,
        12.2086, 12.3103, 13.5146, 12.9431, 12.4321, 12.0057, 12.4941, 13.2652,
        12.8150, 12.7250, 12.8384, 11.9849, 13.4179, 12.0336, 13.1036, 12.3194,
        12.6713, 12.6999, 12.2669, 12.6512, 13.4600, 12.9374, 12.8595, 12.7819,
        13.2597, 12.8863, 13.0576, 12.73

In [33]:
(loss * mask.flatten()).sum() / (mask.sum() + 1e-8)

tensor(12.7632, grad_fn=<DivBackward0>)

In [35]:
it_prd = a_item_out.argmax(-1)
y_true = a_item_trg.masked_select(mask.flatten() == 1)
y_predict = it_prd.masked_select(mask.flatten() == 1)
acc = (y_true == y_predict).double().mean()

In [36]:
acc

tensor(0., dtype=torch.float64)

In [48]:
model.item_loss_acc(item_out, target_item, mask)

(tensor(12.6250, grad_fn=<DivBackward0>), tensor(0., dtype=torch.float64))

In [49]:
model.kw_loss_acc(kw_out, target_kw, mask)

RuntimeError: result type Float can't be cast to the desired output type Long

In [40]:
model.kw_embdddings.weight.dtype

torch.float32