In [1]:
import os
import sys
import json
from typing import Dict, List, Iterable

parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import pandas as pd
import numpy as np
import seaborn as sns
import torch
import joblib
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer

from models.BERT4Rec import Tokenizer

## Ideas
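1. Modify KeBERT4Rec by adding a genre-prediction objective trained with `BCEWithLogitsLoss`
2. Use the first 80% of train.csv as the training set and the last 20% as the validation set (note: the split cells below currently use `test_size=0.25`, i.e. 75% / 25%)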
3. 


## Preprocess

In [2]:
# Original (unmodified) datasets, read in one pass: train/test interactions,
# song metadata, and the extra song info table.
train_df, test_df, songs_df, songs_info_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        '../data/train.csv',
        '../data/test.csv',
        '../data/songs.csv',
        '../data/song_extra_info.csv',
    )
)

In [3]:
# shuffle=False keeps the original row order, so tr_df is the first 75% of
# train.csv rows and val_df the last 25%.
tr_df, val_df = train_test_split(train_df, test_size=0.25, shuffle=False)

In [4]:
# Attach song metadata to the training interactions; the inner join drops
# interactions whose song_id has no row in songs_df.
tr_song_df = tr_df.merge(songs_df, how='inner', on='song_id')
# One list of genre-id strings per row ('a|b' -> ['a', 'b']).
# NOTE(review): astype(str) turns a missing genre_ids value into the string
# 'nan', so such rows yield ['nan'] instead of NaN — confirm this is intended.
split_genre = tr_song_df['genre_ids'].astype(str).str.split('|')

In [6]:
# Load a tokenizer together with its keyword (genre) multi-hot encoder.
# NOTE(review): presumably this reads the files written by the
# Tokenizer.construct(...).save() cell — confirm that cell has been run first.
tknr = Tokenizer.load(load_kw_enc=True)

In [5]:
# Build a tokenizer from the training songs and their genre lists, then save it.
# NOTE(review): save() comes from models.BERT4Rec; where it writes is not
# visible in this notebook.
Tokenizer.construct(tr_song_df, 'song_id', split_genre).save()

In [13]:
# Member (user) table.
members = pd.read_csv('../data/members.csv')

In [14]:
# Preview the members table.
members

Unnamed: 0,msno,city,bd,gender,registered_via,registration_init_time,expiration_date
0,XQxgAYj3klVKjR3oxPPXYYFp4soD4TuBghkhMTD4oTw=,1,0,,7,20110820,20170920
1,UizsfmJb9mV54qE9hCYyU07Va97c0lCRLEQX3ae+ztM=,1,0,,7,20150628,20170622
2,D8nEhsIOBSoE6VthTaqDX8U6lqjJ7dLdr72mOyLya2A=,1,0,,4,20160411,20170712
3,mCuD+tZ1hERA/o5GPqk38e041J8ZsBaLcu7nGoIIvhI=,1,0,,9,20150906,20150907
4,q4HRBfVSssAFS9iRfxWrohxuk9kCYMKjHOEagUMV6rQ=,1,0,,4,20170126,20170613
...,...,...,...,...,...,...,...
34398,Wwd/cudKVuLJ3txRVxlg2Zaeliu+LRUfiBmfrnxhRCY=,1,0,,7,20131111,20170910
34399,g3JGnJX6Hg50lFbrNWfsHwCUmApIkiv2M8sXOaeXoIQ=,4,18,male,3,20141024,20170518
34400,IMaPMJuyN+ip9Vqi+z2XuXbFAP2kbHr+EvvCNkFfj+o=,1,0,,7,20130802,20170908
34401,WAnCAJjUty9Stv8yKtV7ZC7PN+ilOy5FX3aIJgGPANM=,1,0,,7,20151020,20170920


In [15]:
# Preview the test interactions.
test_df

Unnamed: 0,id,msno,song_id,source_system_tab,source_screen_name,source_type
0,0,V8ruy7SGk7tDm3zA51DPpn6qutt+vmKMBKa21dp54uM=,WmHKgKMlp1lQMecNdNvDMkvIycZYHnFwDT72I5sIssc=,my library,Local playlist more,local-library
1,1,V8ruy7SGk7tDm3zA51DPpn6qutt+vmKMBKa21dp54uM=,y/rsZ9DC7FwK5F2PK2D5mj+aOBUJAjuu3dZ14NgE0vM=,my library,Local playlist more,local-library
2,2,/uQAlrAkaczV+nWCd2sPF2ekvXPRipV7q0l+gbLuxjw=,8eZLFOdGVdXBSqoAv5nsLigeH2BvKXzTQYtUM53I0k4=,discover,,song-based-playlist
3,3,1a6oo/iXKatxQx4eS9zTVD+KlSVaAFbTIqVvwLC1Y0k=,ztCf8thYsS4YN3GcIL/bvoxLm/T5mYBVKOO4C9NiVfQ=,radio,Radio,radio
4,4,1a6oo/iXKatxQx4eS9zTVD+KlSVaAFbTIqVvwLC1Y0k=,MKVMpslKcQhMaFEgcEQhEfi5+RZhMYlU3eRDpySrH8Y=,radio,Radio,radio
...,...,...,...,...,...,...
2556785,2556785,XmA/cIkXJ8jZUfsUc4bBfJVWMMqmylnPW0WVkg/iz0s=,wJoWzZd7AL+qX9xZWZwRTzGRYg0Lxcl5Pe+9n5hZgAQ=,discover,Online playlist more,online-playlist
2556786,2556786,jvbujcxCExG0CrPShsEmZ6pePcHuRqru2OLEUw85iGk=,KCJ8BlSfRQRgB7EVuzFvg52AhR8m2fT032MzN5ewbEI=,discover,,online-playlist
2556787,2556787,jvbujcxCExG0CrPShsEmZ6pePcHuRqru2OLEUw85iGk=,ySDsKJSnhFMESzC9mBhY0hnFEEBWDYx0FOMNcYDLb/A=,discover,,online-playlist
2556788,2556788,jvbujcxCExG0CrPShsEmZ6pePcHuRqru2OLEUw85iGk=,WXH8kL8e+0H2jY+s2Y6FpHyyHdaMYV3b6yfDa3FAt9U=,discover,,online-playlist


In [3]:
# modified datasets
# used_songs.csv is the file produced by the used_songs_df.to_csv(...) cell
# of this notebook.
used_songs_df = pd.read_csv('../data/used_songs.csv')

In [8]:
# Number of distinct song ids in test.csv that never occur in train.csv.
test_df.loc[~test_df['song_id'].isin(train_df['song_id']),
            'song_id'].unique().shape

(59873,)

In [11]:
# new songs in test.csv
# The count is derived from the data instead of being pasted from an earlier
# cell's output, so the ratio stays correct if train.csv / test.csv change.
test_song_ids = pd.Series(test_df['song_id'].unique())
new_song_count = int((~test_song_ids.isin(train_df['song_id'])).sum())
round(new_song_count / test_song_ids.shape[0], 2)

0.27

In [10]:
# Number of distinct users (msno) in test.csv that never occur in train.csv.
test_df.loc[~test_df['msno'].isin(train_df['msno']), 'msno'].unique().shape

(3648,)

In [12]:
# new users in test.csv
# The count is derived from the data instead of being pasted from an earlier
# cell's output, so the ratio stays correct if train.csv / test.csv change.
test_user_ids = pd.Series(test_df['msno'].unique())
new_user_count = int((~test_user_ids.isin(train_df['msno'])).sum())
round(new_user_count / test_user_ids.shape[0], 2)

0.15

In [7]:
# Same order-preserving 75% / 25% split as the earlier split cell
# (shuffle=False keeps the row order of train.csv).
tr_df, val_df = train_test_split(train_df, test_size=0.25, shuffle=False)

In [8]:
# Share of distinct validation songs that never occur in the training split.
round(
    len(val_df.loc[~val_df['song_id'].isin(tr_df['song_id']),
                   'song_id'].unique()) / len(val_df['song_id'].unique()),
    2)

0.28

In [9]:
# Share of distinct validation users that never occur in the training split.
round(
    len(val_df.loc[~val_df['msno'].isin(tr_df['msno']), 'msno'].unique()) /
    len(val_df['msno'].unique()), 2)

0.14

In [13]:
# All validation rows of a single user, selected by msno, for a closer look.
temp_df = val_df.groupby('msno').get_group(
    'FnNP1yrSvV9bSxxmccXu3PSarO2wFqhOWByD89kwmvQ=')

In [15]:
# Share of this user's validation rows whose target equals 1.
round(len(temp_df[temp_df['target'] == 1]) / len(temp_df), 2)

0.2

In [17]:
# Attach song metadata to the training interactions (inner join on song_id)
# and list the resulting columns.
tr_song_df = tr_df.merge(songs_df, how='inner', on='song_id')
tr_song_df.columns

Index(['msno', 'song_id', 'source_system_tab', 'source_screen_name',
       'source_type', 'target', 'song_length', 'genre_ids', 'artist_name',
       'composer', 'lyricist', 'language'],
      dtype='object')

In [20]:
# Per-column share of missing values in tr_song_df, rounded to two decimals.
tr_song_df.isna().mean().round(2)

msno                  0.00
song_id               0.00
source_system_tab     0.00
source_screen_name    0.05
source_type           0.00
target                0.00
song_length           0.00
genre_ids             0.02
artist_name           0.00
composer              0.23
lyricist              0.43
language              0.00
dtype: float64

In [None]:
# TODO: attach each row's multi-hot genre vector as tr_song_df['genre_ids_multi_hot'].
# The original assignment had no right-hand side (a SyntaxError), so it is
# parked here as a note until the encoding (see the MultiLabelBinarizer cells
# in this notebook) is settled.

In [21]:
# Distinct genre ids seen in the training rows: stringify every genre_ids
# value, glue them with '|' and split the result back into single ids.
genre_set = set(
    '|'.join(map(str, tr_song_df['genre_ids'])).split('|'))

In [26]:
# One list of genre-id strings per row ('a|b' -> ['a', 'b']).
# NOTE(review): astype(str) turns a missing genre_ids value into the string
# 'nan', so such rows yield ['nan'] instead of NaN — confirm this is intended.
split_genre = tr_song_df['genre_ids'].astype(str).str.split('|')

In [36]:
# Multi-hot encoder that is fitted on the per-row genre lists in the next cell.
mlb = MultiLabelBinarizer()

In [41]:
# Learn the genre classes from split_genre and encode every row as a
# 0/1 indicator vector over those classes.
multi_hot = mlb.fit_transform(split_genre)

In [50]:
# Persist the fitted genre encoder under models/BERT4Rec/tokenizer.
joblib.dump(mlb, '../models/BERT4Rec/tokenizer/kw_enc.pkl')

['../models/BERT4Rec/tokenizer/kw_enc.pkl']

In [34]:
# Rows whose song carries more than one genre id.
split_genre[split_genre.map(len) > 1]

13318           [465, 458]
13319           [465, 458]
13320           [465, 458]
13321           [465, 458]
13322           [465, 458]
                ...       
5532898    [139, 125, 109]
5532914       [1180, 1152]
5532945    [139, 125, 109]
5532955        [1572, 275]
5532968       [1616, 2058]
Name: genre_ids, Length: 257973, dtype: object

In [35]:
# Count rows whose genre list is missing.
# NOTE(review): split_genre is built via astype(str), which maps missing
# genre_ids to the string 'nan', so this count cannot reveal them.
split_genre.isna().sum()

0

In [8]:
# Column names of the training interactions table.
train_df.columns

Index(['msno', 'song_id', 'source_system_tab', 'source_screen_name',
       'source_type', 'target'],
      dtype='object')

In [11]:
# Flag the origin of every row: 1 = train.csv, 0 = test.csv.
train_df['is_train'] = 1
test_df['is_train'] = 0
# Rebind instead of dropping in place and tolerate an already-missing 'id'
# column, so re-running this cell does not raise KeyError.
test_df = test_df.drop(columns='id', errors='ignore')

# Stack train and test interactions into one frame with a fresh 0..n-1 index.
train_test_df = pd.concat([train_df, test_df], ignore_index=True)

In [15]:
# Submission template shipped with the data.
sample_submission = pd.read_csv('../data/sample_submission.csv')
# The original (misspelled) name stays bound so any cell that still refers to
# it keeps working.
smaple_submission = sample_submission

In [14]:
# Per-column share of missing values in test_df, rounded to two decimals.
test_df.isna().mean().round(2)

msno                  0.00
song_id               0.00
source_system_tab     0.00
source_screen_name    0.06
source_type           0.00
is_train              0.00
dtype: float64

In [3]:
# Per-column share of missing values in songs_df, rounded to two decimals.
songs_df.isna().mean().round(2)

song_id        0.00
song_length    0.00
genre_ids      0.04
artist_name    0.00
composer       0.47
lyricist       0.85
language       0.00
dtype: float64

In [4]:
# Per-column share of missing values in songs_info_df, rounded to two decimals.
songs_info_df.isna().mean().round(2)

song_id    0.00
name       0.00
isrc       0.06
dtype: float64

In [5]:
# Combine song metadata with the extra song info; the inner join keeps only
# song_ids present in both tables.
full_songs_df = songs_df.merge(songs_info_df, how='inner', on='song_id')

In [6]:
# Per-column share of missing values in full_songs_df, rounded to two decimals.
full_songs_df.isna().mean().round(2)

song_id        0.00
song_length    0.00
genre_ids      0.04
artist_name    0.00
composer       0.47
lyricist       0.85
language       0.00
name           0.00
isrc           0.06
dtype: float64

In [7]:
# Persist the merged song table (UTF-8, without the index column).
full_songs_df.to_csv('../data/full_songs.csv', index=False, encoding='utf-8')

In [8]:
# Training interactions joined with the merged song table; the inner join
# drops interactions whose song_id is absent from full_songs_df.
train_usr_song_df = train_df.merge(full_songs_df, how='inner', on='song_id')

In [11]:
# Test interactions joined with the merged song table; the inner join
# drops interactions whose song_id is absent from full_songs_df.
test_usr_song_df = test_df.merge(full_songs_df, how='inner', on='song_id')

In [12]:
# Column names after joining interactions with song metadata.
train_usr_song_df.columns

Index(['msno', 'song_id', 'source_system_tab', 'source_screen_name',
       'source_type', 'target', 'song_length', 'genre_ids', 'artist_name',
       'composer', 'lyricist', 'language', 'name', 'isrc'],
      dtype='object')

In [18]:
# Stack the joined train and test interactions into one frame.
# errors='ignore' keeps this cell working when 'id' is already absent, e.g.
# when test_df had its 'id' column dropped before the merge.
full_usr_song_df = pd.concat(
    [train_usr_song_df,
     test_usr_song_df.drop(columns='id', errors='ignore')],
    ignore_index=True)

In [20]:
# Persist the combined train/test interaction table (UTF-8, no index column).
full_usr_song_df.to_csv('../data/full_usr_song.csv',
                        index=False,
                        encoding='utf-8')

In [27]:
# songs in train and test
# Keep only the songs that occur in the combined train/test interactions and
# remove rows that are exact duplicates, renumbering the index afterwards.
used_songs_df = full_songs_df[full_songs_df['song_id'].isin(
    full_usr_song_df['song_id'].unique())].drop_duplicates(ignore_index=True)

In [29]:
# Persist the filtered song table (UTF-8, without the index column).
used_songs_df.to_csv('../data/used_songs.csv', index=False, encoding='utf-8')

In [31]:
# Per-column share of missing values in used_songs_df, rounded to two decimals.
used_songs_df.isna().mean().round(2)

song_id        0.00
song_length    0.00
genre_ids      0.02
artist_name    0.00
composer       0.43
lyricist       0.75
language       0.00
name           0.00
isrc           0.13
dtype: float64

In [32]:
# Distinct genre ids among the used songs: stringify every genre_ids value,
# glue them with '|' and split the result back into single ids.
genre_set = set(
    '|'.join(map(str, used_songs_df['genre_ids'])).split('|'))

In [35]:
# Show every training interaction (with song metadata) of one user,
# selected by msno.
train_usr_song_df.groupby('msno').get_group(
    'FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=')

Unnamed: 0,msno,song_id,source_system_tab,source_screen_name,source_type,target,song_length,genre_ids,artist_name,composer,lyricist,language,name,isrc
0,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,BBzumQNXUHKdEBOB7mAJuzok+IJA1c2Ryg/yzTF6tik=,explore,Explore,online-playlist,1,206471,359,Bastille,Dan Smith| Mark Crew,,52.0,Good Grief,GBUM71602854
221,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,3qm6XTZ6MOCU11x8FIVbAGH5l5uMkT3/ZalWG1oo2Gc=,explore,Explore,online-playlist,1,187802,1011,Brett Young,Brett Young| Kelly Archer| Justin Ebach,,52.0,Sleep Without You,QM3E21606003
633,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,3Hg5kugV1S0wzEVLAEfqjIV5UHzb7bCrdBRQlGygLvU=,explore,Explore,online-playlist,1,247803,1259,Desiigner,Sidney Selby| Adnan Khan,,52.0,Panda,USUM71601094
5610,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,bPIvRTzfHxH5LgHrStll+tYwSQNVV8PySgA3M1PfTgc=,explore,Explore,online-playlist,1,181115,1011,Thomas Rhett,Thomas Rhett| Rhett Akins| Ben Hayslip,,52.0,Star Of The Show,USLXJ1607334
6621,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,EbI7xoNxI+3QSsiHxL13zBdgHIJOwa3srHd7cDcnJ0g=,explore,Explore,online-playlist,0,257369,465,OneRepublic,Ryan Tedder,,52.0,Counting Stars,USUM71301306
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7370732,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,J06isMbryq9+xddfV+bUQhEj9DKfrL3cOWN80Z87tPA=,radio,Radio,radio,0,520240,465,Tetsuya komuro (小室哲哉),Tetsuya Komuro,Tetsuya Komuro/Rap words: MARC,17.0,Judgement 2014,JPB601401528
7370733,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,YR8jzRXzET5IvYAjXBhTIFyD6JbWXkGsVPr5z79arlI=,radio,Radio,radio,0,445257,1609,Morsy| Noone Costelo,,,-1.0,Get Mad,USNRS1433298
7370767,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,LO88CiYqRzttnEVXjJvKTNi1odC2ZoHJ5cQcBXQzFzs=,radio,Radio,radio,0,435095,1609|2107,Sailor & I,,,52.0,Leave The Light On,USUS11500014
7370768,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,QXn/3rYpoGRcbsd45/lHcRvUsxqRGDHjKVVKT2CmytU=,radio,Radio,radio,0,399940,1609,Various Artists,,,52.0,Da Funk,DEH741507362


### build vocab

In [39]:
# Ids 0 and 1 are reserved for the special tokens; songs are numbered from 2
# in used_songs_df row order.
PAD, MASK = 0, 1
songs_ids = {
    song_id: token_id
    for token_id, song_id in enumerate(used_songs_df['song_id'], start=2)
}
songs_ids.update({'[PAD]': PAD, '[MASK]': MASK})

In [41]:
# Reverse lookup: integer id -> song id (or special token).
ids_songs = {idx: song for song, idx in songs_ids.items()}

In [48]:
# Bundle both lookups and persist them as JSON.
# Note: JSON object keys are always strings, so the integer keys of 'id2item'
# are written as strings and come back as str when the file is loaded.
vocab = {'item2id': songs_ids, 'id2item': ids_songs}
with open('../models/BERT4Rec/vocab.json', 'w', encoding='utf-8') as f:
    json.dump(vocab, f)

In [4]:
# Read the vocab back from disk; after the JSON round trip the keys of
# 'id2item' are strings, not integers.
with open('../models/BERT4Rec/vocab.json', 'r', encoding='utf-8') as f:
    vocab = json.load(f)

In [44]:
class Tokenizer:
    """Item-id vocabulary with an optional multi-hot keyword encoder.

    Token ids 0 and 1 are reserved for '[PAD]' and '[MASK]'; items are
    numbered from 2.
    """

    def __init__(self,
                 vocab: Dict[str, Dict],
                 keyword_encoder: 'MultiLabelBinarizer' = None) -> None:
        """Constructor

        :param vocab: a dict contains 'item2id' and 'id2item'
        :param keyword_encoder: a multi-hot encoder; None when keywords are not used
        """

        self.vocab = vocab
        self.keyword_encoder = keyword_encoder

    @classmethod
    def construct(cls,
                  item_df: pd.DataFrame,
                  item_column: str,
                  keyword_ls: Iterable[Iterable[str]] = None):
        """Construct a Tokenizer

        :param item_df: a dataframe contains items
        :param item_column: a column contains items' ids
        :param keyword_ls: a set of keywords for each item; if None, the
            Tokenizer is built without a keyword encoder

        :return: a Tokenizer
        """

        # Distinct items get ids 2, 3, ... in order of first appearance.
        item2id = {
            item: idx
            for idx, item in enumerate(item_df[item_column].unique(), start=2)
        }
        item2id['[PAD]'] = 0
        item2id['[MASK]'] = 1
        id2item = {idx: item for item, idx in item2id.items()}

        # Default to None so that omitting keyword_ls no longer ends in an
        # UnboundLocalError at the return statement.
        keyword_encoder = None
        if keyword_ls is not None:
            keyword_encoder = MultiLabelBinarizer()
            keyword_encoder.fit(keyword_ls)

        return cls({'item2id': item2id, 'id2item': id2item}, keyword_encoder)

    @classmethod
    def load(cls,
             vocab_fp: str = None,
             load_kw_enc: bool = False,
             keyword_enc_fp: str = None):
        """Load a Tokenizer

        :param vocab_fp: the vocab's file path; if None, use the default vocab
        :param load_kw_enc: whether to load the keyword multi-hot encoder; default False
        :param keyword_enc_fp: the multi-hot encoder's file path; if None, use the default multi-hot encoder

        :return: a Tokenizer
        :raises FileNotFoundError: if a default file is missing, or if an
            explicitly given keyword encoder path does not exist
        """

        # NOTE(review): the default paths rely on __file__, which is only
        # defined when this class lives in a module, not in a notebook cell.
        if not vocab_fp:
            vocab_fp = os.path.join(os.path.dirname(__file__), 'vocab.json')
            if not os.path.exists(vocab_fp):
                raise FileNotFoundError('No default vocab!')

        # NOTE(review): JSON object keys are strings, so the keys of
        # vocab['id2item'] are str after loading, not int.
        with open(vocab_fp, 'r', encoding='utf-8') as vocab_f:
            vocab = json.load(vocab_f)

        # Default to None so that load_kw_enc=False no longer ends in an
        # UnboundLocalError at the return statement.
        kw_enc = None
        if load_kw_enc:
            if not keyword_enc_fp:
                keyword_enc_fp = os.path.join(os.path.dirname(__file__),
                                              'kw_enc.pkl')
                # Check the encoder file itself (the original re-checked vocab_fp).
                if not os.path.exists(keyword_enc_fp):
                    raise FileNotFoundError('No default keyword encoder!')
            elif not os.path.exists(keyword_enc_fp):
                raise FileNotFoundError(
                    f'Keyword encoder not found: {keyword_enc_fp}')

            # joblib.load unpickles the file; only load encoders from trusted sources.
            kw_enc = joblib.load(keyword_enc_fp)

        return cls(vocab, kw_enc)


In [45]:
# Build a tokenizer from the song ids of the training split; the genre lists
# in split_genre are used to fit its keyword encoder.
tknr = Tokenizer.construct(tr_df, 'song_id', split_genre)

In [49]:
# NOTE(review): this cell raises unconditionally and will stop a
# "Restart & Run All"; it looks like a leftover scratch cell — consider removing it.
raise FileNotFoundError('No default vocab')

FileNotFoundError: Not default vocab

In [46]:
# Genre ids learned by the tokenizer's keyword encoder.
tknr.keyword_encoder.classes_

array(['1000', '1007', '1011', '1019', '102', '1026', '1033', '1040',
       '1047', '1054', '1068', '1082', '109', '1096', '1103', '1110',
       '1117', '1124', '1131', '1138', '1145', '1152', '1155', '1162',
       '1169', '118', '1180', '1187', '1194', '1201', '1208', '125',
       '1259', '1266', '1273', '1280', '1287', '139', '152', '1568',
       '1572', '1579', '1598', '1605', '1609', '1616', '1630', '1633',
       '177', '184', '191', '1944', '1955', '1965', '1969', '1977', '198',
       '1981', '1988', '1995', '2008', '2015', '2022', '2029', '2032',
       '205', '2052', '2058', '2065', '2072', '2079', '2086', '2093',
       '2100', '2107', '2109', '2116', '212', '2122', '2127', '2130',
       '2144', '2150', '2157', '2172', '2176', '2183', '2189', '2192',
       '2194', '2206', '2213', '2215', '2219', '2245', '2248', '242',
       '252', '275', '282', '296', '310', '331', '338', '352', '359',
       '367', '374', '381', '388', '402', '409', '416', '423', '430',
       '437',

In [7]:
# Inspect the tokenizer's vocab.
# NOTE(review): this renders the entire mapping as output; consider showing
# only a small sample.
tknr.vocab

{'item2id': {'CXoTN1eb7AI+DntdU1vbcwGRV4SCIDxZu+YD8JP8r4E=': 2,
  'o0kFgae9QtnYgRkVPqLJwa05zIhRlUjfF7O1tDw0ZDU=': 3,
  'DwVvVurfpuz+XPuFvucclVQEyPqcpUkHR0ne1RQzPs0=': 4,
  'dKMBWoZyScdxSkihKG+Vf47nc18N9q4m58+b4e7dSSE=': 5,
  'W3bqWd3T+VeHFzHAUfARgW9AvVRaF4N5Yzm4Mr6Eo/o=': 6,
  'kKJ2JNU5h8rphyW21ovC+RZU+yEHPM+3w85J37p7vEQ=': 7,
  'N9vbanw7BSMoUgdfJlgX1aZPE1XZg8OS1wf88AQEcMc=': 8,
  'GsCpr618xfveHYJdo+E5SybrpR906tsjLMeKyrCNw8s=': 9,
  'oTi7oINPX+rxoGp+3O6llSltQTl80jDqHoULfRoLcG4=': 10,
  'btcG03OHY3GNKWccPP0auvtSbhxog/kllIIOx5grE/k=': 11,
  'HulM/OaHgD5kUyjNQjDUf8VZdsy7h4EJUIff79Cifwo=': 12,
  'wypPzqFNdUJAqyBVxmFGaK4z7krUNWr5YqA0q0wi9eE=': 13,
  'fAZLdfQaLG76a6Ei4alt1eSjBM9rshQkiQEC6+n+y08=': 14,
  'tqBlH4r/q1Tf6C5+C6ucjGlLjMbfu5yjqB6ifRzy5dc=': 15,
  'an6EdIr+Z+KbqIVQiXn5PKkcXncefQ7hhWONseRuub4=': 16,
  'J2MFmy8iF94mExWfRWE3KxsMZB+ZIedV5liqZoSrERQ=': 17,
  'MrRilXQwoUAcoAf0N3RT82qX2/us/wEhYDXE+ZTIW5o=': 18,
  'OcG4Ya7iXmVMCMy24C5wxDMtr9w6WQZiFaN0uq6zdTk=': 19,
  'JcHIgDP5ivyqYIn7RxfXM1