In [1]:
import datetime
from collections import Counter, defaultdict
from itertools import chain

import pandas as pd
import numpy as np
import scipy.sparse as sp
import torch
from torch import nn
import torch.nn.functional as F

Используем [датасет прослушиваний с last.fm](http://ocelma.net/MusicRecommendationDataset/lastfm-1K.html)

In [2]:
header = ['user_name', 'time', 'artist_id', 'artist_name', 'track_id', 'track_name']
listens = pd.read_csv('lastfm-dataset-1K/userid-timestamp-artid-artname-traid-traname.tsv', 
                      delimiter='\t', error_bad_lines=False, names=header)

In [3]:
def to_unix_time(dt):
    return datetime.datetime.strptime(dt, '%Y-%m-%dT%H:%M:%SZ').timestamp()

listens['ts'] = listens['time'].apply(to_unix_time)

ts_threshold = sorted(listens['ts'])[int(len(listens) * .8)]

train_listens = listens[listens['ts'] < ts_threshold]
val_listens = listens[listens['ts'] >= ts_threshold]
# Разбиение на тренировочную и валидационную выборки по времени

In [4]:
train_users = set(train_listens['user_name'])
val_users = set(val_listens['user_name'])
common_users = list(train_users & val_users)
user_ids = {u: i for i, u in enumerate(common_users)}

val_artists = {v for v, c in Counter(val_listens['artist_id']).most_common() if c >= 250}
train_artists = {v for v, c in Counter(train_listens['artist_id']).most_common() if c >= 250}
common_artists = list(val_artists & train_artists - {np.NaN})
artist_ids = {a: i for i, a in enumerate(common_artists)}
# Возьмём слушателей и исполнителей, присутствующих в обеих выборках

In [5]:
def filter_listens(listens):
    listens = listens[listens['user_name'].apply(lambda u: u in user_ids) & 
                      listens['artist_id'].apply(lambda a: a in artist_ids)]
    listens = listens.assign(uid=listens['user_name'].apply(user_ids.get))
    listens = listens.assign(aid=listens['artist_id'].apply(artist_ids.get))
    return listens

train_listens = filter_listens(train_listens)
val_listens = filter_listens(val_listens)

In [6]:
def make_user_item_matrix(listens):
    user_item = sp.coo_matrix((len(common_users), len(common_artists)))
    for u, a in listens[['uid', 'aid']].iterrows():
        user_item[u, a] = 1
    return user_item

N_USERS = len(common_users)
N_ITEMS = len(common_artists)
train_user_item = sp.coo_matrix((np.ones(len(train_listens)), (train_listens['uid'], train_listens['aid']))).tocsr().tocoo()
val_user_item = sp.coo_matrix((np.ones(len(val_listens)), (val_listens['uid'], val_listens['aid']))).tocsr().tocoo()

In [7]:
art_id_name = {
    artist_ids[aid]: art_name 
    for _, aid, art_name in listens[['artist_id', 'artist_name']].itertuples() 
    if aid in artist_ids
}
art_name_id = {art_name: aid for aid, art_name in art_id_name.items()}

artists = pd.DataFrame({'name': art_id_name.values()}, index=art_id_name.keys())

In [8]:
def get_artist_names(art_ids):
    sim_df = pd.DataFrame({'aid': art_ids})
    return sim_df.join(artists, on='aid')['name']


def get_similars(item_id, model):
    similars = model.similar_items(item_id)
    sim_ids = [m_id for m_id, _ in similars]
    return get_artist_names(sim_ids)


def get_recommendations(user_id, model):
    recs = model.recommend(user_id)
    rec_ids = [m_id for m_id, _ in recs]
    return get_artist_names(rec_ids)

In [9]:
def view_hist(user_id):
    return train_listens[train_listens['uid'] == user_id]

In [10]:
view_hist(26)
# Будем использовать этого пользователя для вывода рекомендаций

Unnamed: 0,user_name,time,artist_id,artist_name,track_id,track_name,ts,uid,aid
14227973,user_000758,2008-10-15T19:20:43Z,0da0f48c-3689-4c38-bf4a-c5b50d516689,Mucc,,Saishuu Ressha ~Album Edit~,1.224088e+09,26,1559
14227974,user_000758,2008-10-15T19:16:26Z,0da0f48c-3689-4c38-bf4a-c5b50d516689,ムック,b40954f4-a082-4e89-bff4-30d9c6a694fa,Zetsubou,1.224087e+09,26,1559
14227975,user_000758,2008-10-15T19:12:42Z,0da0f48c-3689-4c38-bf4a-c5b50d516689,ムック,72d4a391-7a1e-4b4f-8a17-ca3fed4873b0,Gerbera,1.224087e+09,26,1559
14227976,user_000758,2008-10-15T19:10:38Z,0da0f48c-3689-4c38-bf4a-c5b50d516689,ムック,49951db3-80c9-41f3-8945-5a6535a379f2,Kuchiki No Tou,1.224087e+09,26,1559
14227977,user_000758,2008-10-15T19:06:51Z,0da0f48c-3689-4c38-bf4a-c5b50d516689,ムック,384ff389-4332-4b9f-ae25-f2faa4fe7a0e,Daremo Inai Ie,1.224087e+09,26,1559
...,...,...,...,...,...,...,...,...,...
14234703,user_000758,2006-02-19T01:42:14Z,20fb6261-2468-4eb7-acfa-48bdd67929a6,D'Espairsray,bd7ed34f-410d-4ade-b2b9-5bf610b6d927,Garnet,1.140303e+09,26,598
14234704,user_000758,2006-02-19T01:37:19Z,20fb6261-2468-4eb7-acfa-48bdd67929a6,D'Espairsray,a18d5b8a-e077-4a8b-b600-73cfab1bc215,Tsuki No Kioku -Fallen-,1.140302e+09,26,598
14234705,user_000758,2006-02-19T01:28:39Z,20fb6261-2468-4eb7-acfa-48bdd67929a6,D'Espairsray,2251efd3-d1ff-4541-a4f0-9e9c76eed545,In Vain,1.140302e+09,26,598
14234706,user_000758,2006-02-19T01:24:34Z,20fb6261-2468-4eb7-acfa-48bdd67929a6,D'Espairsray,2747e93d-def8-4e70-b8d8-fae8a87f2cc3,Dears,1.140301e+09,26,598


In [11]:
val_user_artists = defaultdict(set)
for _, uid, aid in val_listens[['uid', 'aid']].itertuples():
    val_user_artists[uid].add(aid)

In [12]:
# Метрики: precision@k, NDCG
def prec(model, k):
    rels = []
    for u, listened in val_user_artists.items():
        recs = model.recommend(u, n_recs=k)
        rels.append(sum(r in listened for r, _ in recs))
    return np.mean(rels) / k


gain_discounts = 1 / np.log2(np.arange(2, 52))


def ndcg(model, k):
    normalizer = 1 / gain_discounts[:k].sum()
    ndcgs = []
    for u, listened in val_user_artists.items():
        recs = model.recommend(u, n_recs=k)
        gains = np.array([r in listened for r, _ in recs])
        dcg = np.sum(gains * gain_discounts[:k])
        ndcg = dcg * normalizer
        ndcgs.append(ndcg)
    return np.mean(ndcgs)

In [13]:
class NegativeSampler:
    def __init__(self, interactions, n_items, pop_dist=False):
        self.positives = sp.csr_matrix(interactions)
        self.n_items = n_items
        self.items = interactions.col if pop_dist else np.unique(interactions.col)
        
    def get_positive_mask(self, samples, users):
        return np.array(self.positives[users, samples], np.bool).ravel()
        
    def sample(self, users):
        samples = np.random.choice(self.items, users.shape)
        positive_mask = self.get_positive_mask(samples, users)
        while np.any(positive_mask):
            samples[positive_mask] = np.random.choice(self.items, positive_mask.sum())
            positive_mask = self.get_positive_mask(samples, users)
        return samples

## WARP matrix factorization model

In [14]:
def scalar_prods(vecs1, vecs2):
    return np.sum(vecs1 * vecs2, axis=1).flatten()


class MatrixFactorizationBase:
    def __init__(self, dim, reg_param, n_users, n_items):
        self.dim = dim
        self.n_users = n_users
        self.n_items = n_items
        init_std = 1 / dim ** .5
        self.users_embeddings = np.random.normal(0, init_std, (n_users, dim))
        self.items_embeddings = np.random.normal(0, init_std, (n_items, dim))
        self.users_biases = np.random.uniform(0, .5, n_users)
        self.items_biases = np.random.uniform(0, .5, n_items)
        self.reg_param = reg_param
    
    def fit(self, interactions, n_epochs, lr):
        pass
    
    def similarities(self, users_ids, items_ids):
        return self.users_biases[users_ids] + self.items_biases[items_ids] + \
                scalar_prods(self.users_embeddings[users_ids], self.items_embeddings[items_ids])
    
    def recommend(self, user_id, n_recs = 20):
        similarities = self.items_embeddings @ self.users_embeddings[user_id]
        closest_item_ids = similarities.argsort()[::-1][:n_recs]
        return list(zip(closest_item_ids, similarities[closest_item_ids]))
    
    def similar_items(self, item_id, n_items = 20):
        similarities = self.items_embeddings @ self.items_embeddings[item_id]
        items_by_similariry = similarities.argsort()[::-1]
        items_by_similariry = items_by_similariry[items_by_similariry != item_id]
        most_similar_items = items_by_similariry[:n_items]
        return list(zip(most_similar_items, similarities[most_similar_items]))

In [15]:
WARP_BATCH_SIZE = 4
WARP_MAX_SAMPLE_TRIALS = 100
WARP_MARGIN = 1


def project_vectors(vectors, indexes, max_norm):
    vector_norms = np.linalg.norm(vectors[indexes], axis=1)
    vectors[indexes] *= np.maximum(max_norm / vector_norms, 1).reshape((-1, 1))


class WARPMF(MatrixFactorizationBase):
    def __init__(self, dim, reg_param, n_users, n_items):
        super().__init__(dim, reg_param, n_users, n_items)
        self.items_biases.fill(0.)
        
    def fit(self, interactions, n_epochs, lr):
        users = interactions.row
        positives = interactions.col
        neg_sampler = NegativeSampler(interactions, self.n_items)
            
        for epoch in range(1, n_epochs + 1):
            loss = 0.
            indexes = np.arange(interactions.nnz)
            for batch_start in range(0, interactions.nnz, WARP_BATCH_SIZE):
                batch_indexes = indexes[batch_start:batch_start + WARP_BATCH_SIZE]
                batch_users = users[batch_indexes]
                batch_positives = positives[batch_indexes]
                positives_similarities = self.similarities(batch_users, batch_positives)
                
                batch_negatives = neg_sampler.sample(batch_users)
                negatives_similarities = self.similarities(batch_users, batch_negatives)
                good_mask = positives_similarities - negatives_similarities > WARP_MARGIN
                sampling_counters = np.ones(len(batch_users))
                for _ in range(WARP_MAX_SAMPLE_TRIALS):
                    n_good = good_mask.sum()
                    if n_good == 0:
                        break
                    batch_negatives[good_mask] = neg_sampler.sample(batch_users[good_mask])
                    sampling_counters[good_mask] += 1
                    negatives_similarities[good_mask] = self.similarities(
                        batch_users[good_mask], batch_negatives[good_mask])
                    good_mask = positives_similarities - negatives_similarities > WARP_MARGIN
                to_opt_mask = ~good_mask
                n_to_opt = to_opt_mask.sum()
                
                batch_users = batch_users[to_opt_mask]
                batch_positives = batch_positives[to_opt_mask]
                batch_negatives = batch_negatives[to_opt_mask]
                positives_similarities = positives_similarities[to_opt_mask]
                negatives_similarities = negatives_similarities[to_opt_mask]
                samples_weights = np.log((WARP_MAX_SAMPLE_TRIALS - 1) / sampling_counters[to_opt_mask])
                
                
                loss += np.sum((WARP_MARGIN + negatives_similarities - positives_similarities) * samples_weights)
                positive_biases_grads = -samples_weights
                negative_biases_grads = samples_weights
                samples_weights = np.expand_dims(samples_weights, 1)
                user_grads = samples_weights * \
                        (self.items_embeddings[batch_negatives] - self.items_embeddings[batch_positives])
                positive_grads = samples_weights * (-self.users_embeddings[batch_users])
                negative_grads = samples_weights * self.users_embeddings[batch_users]
                
                np.add.at(self.users_embeddings, batch_users, -lr * user_grads)
                np.add.at(self.items_embeddings, batch_positives, -lr * positive_grads)
                np.add.at(self.items_embeddings, batch_negatives, -lr * negative_grads)
                project_vectors(self.users_embeddings, batch_users, self.reg_param)
                project_vectors(self.items_embeddings, batch_positives, self.reg_param)
                project_vectors(self.items_embeddings, batch_negatives, self.reg_param)
                np.add.at(self.items_biases, batch_positives, -lr * positive_biases_grads)
                np.add.at(self.items_biases, batch_negatives, -lr * negative_biases_grads)
            print(f'Epoch {epoch} loss {loss:.3f}')

In [16]:
warp_model = WARPMF(64, 4, N_USERS, N_ITEMS)
warp_model.fit(train_user_item, 5 , .01)

Epoch 1 loss 1374152.617
Epoch 2 loss 940351.201
Epoch 3 loss 826802.291
Epoch 4 loss 751453.331
Epoch 5 loss 704487.107


In [17]:
get_similars(art_name_id['Metallica'], warp_model)

0                Judas Priest
1                     Pantera
2           Dark Tranquillity
3                Machine Head
4                  Papa Roach
5                  Iced Earth
6                        Tool
7                     Trivium
8     Queens Of The Stone Age
9                   Aerosmith
10                  Motörhead
11                        Afi
12                   Megadeth
13                       Korn
14             Twisted Sister
15                   Static-X
16                     Danzig
17                  Rammstein
18                 Rob Zombie
19                Limp Bizkit
Name: name, dtype: object

In [18]:
get_recommendations(26, warp_model)

0                    雅-Miyavi-
1                         椎名林檎
2                      Pierrot
3                          ムック
4                         Korn
5                 Malice Mizer
6             System Of A Down
7                          Afi
8                 Orange Range
9                          Boa
10                    The Used
11                  Rob Zombie
12    Asian Kung-Fu Generation
13           Within Temptation
14                        坂本真綾
15              T.M.Revolution
16           Avenged Sevenfold
17                  Stone Sour
18               After Forever
19                   Buck-Tick
Name: name, dtype: object

In [19]:
prec(warp_model, 1), prec(warp_model, 10), ndcg(warp_model, 20)

(0.4654320987654321, 0.4253086419753086, 0.41674201497068053)

## Neural matrix factorization Model

In [20]:
class NMFModule(nn.Module):
    def __init__(self, n_users, n_items, dim):
        super().__init__()
        self.users_emb = nn.Embedding(n_users, dim)
        self.items_emb = nn.Embedding(n_items, dim)
        self.mlp = nn.Sequential(
            nn.Linear(2 * dim, 2 * dim),
            nn.PReLU(),
            nn.Linear(2 * dim, dim),
            nn.PReLU()
        )
        self.final_clf = nn.Sequential(
            nn.Linear(2 * dim, 1),
            nn.Sigmoid()
        )
        
    def forward(self, users, items):
        users_emb = self.users_emb(users)
        items_emb = self.items_emb(items)
        ewp = users_emb * items_emb
        mlp_out = self.mlp(torch.cat((users_emb, items_emb), dim=1))
        return self.final_clf(torch.cat((ewp, mlp_out), dim=1))

    
class NMFDataLoader:
    def __init__(self, users, items, neg_sampler, batch_size):
        super().__init__()
        self.users = users
        self.items = items
        self.neg_sampler = neg_sampler
        self.batch_size = batch_size
    
    def __len__(self):
        return (len(self.users) + self.batch_size - 1) // self.batch_size
    
    def __getitem__(self, index):
        batch_start = index * self.batch_size
        batch_end = batch_start + self.batch_size
        users = torch.tensor(self.users[batch_start:batch_end], dtype=torch.long)
        pos_items = self.items[batch_start:batch_end]
        neg_items = self.neg_sampler.sample(self.users[batch_start:batch_end])
        pos_items, neg_items = torch.tensor(pos_items, dtype=torch.long), torch.tensor(neg_items, dtype=torch.long)
        return users, pos_items, neg_items
    
    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

            
def train_nmf(model, data, n_epochs, opt, lr_scheduler, device):
    loss_function = nn.BCELoss()
    model = model.to(device)
    for epoch in range(n_epochs):
        losses = []
        for users, pos_items, neg_items in data:
            users, pos_items, neg_items = users.to(device), pos_items.to(device), neg_items.to(device)
            opt.zero_grad()
            pos_pred = model(users, pos_items)
            neg_pred = model(users, neg_items)
            loss = loss_function(torch.cat((pos_pred, neg_pred)), 
                        torch.cat((torch.ones((len(pos_items), 1)), torch.zeros((len(neg_items), 1)))).to(device))
            loss.backward()
            opt.step()
            losses.append(loss.detach())
        lr_scheduler.step()
        print(f'Epoch {epoch} loss: {torch.tensor(losses).mean():.6f}')

In [21]:
class NMFModel:
    def __init__(self, n_users, n_items, dim):
        self.nmf_module = NMFModule(n_users, n_items, dim)
        self.n_items = n_items
    
    def fit(self, interactions, device='cuda'):
        neg_sampler = NegativeSampler(interactions, self.n_items, True)
        users = interactions.row
        items = interactions.col
        nmf_dl = NMFDataLoader(users, items, neg_sampler, batch_size=16384)
        opt = torch.optim.Adam(self.nmf_module.parameters(), lr=.02)
        sched = torch.optim.lr_scheduler.StepLR(opt, 40, .20)
        train_nmf(self.nmf_module, nmf_dl, 200, opt, sched, device)
        self.nmf_module.cpu()
        
    def recommend(self, user_id, n_recs = 20):
        user_tensor = torch.tensor([user_id] * self.n_items, dtype=torch.long)
        items = torch.arange(self.n_items)
        with torch.no_grad():
            similarities = self.nmf_module(user_tensor, items).numpy().ravel()
        closest_item_ids = similarities.argsort()[::-1][:n_recs]
        return list(zip(closest_item_ids, similarities[closest_item_ids]))
    
    def similar_items(self, item_id, n_items = 20):
        items_emb = next(self.nmf_module.items_emb.parameters()).detach().numpy()
        similarities = items_emb @ items_emb[item_id]
        items_by_similariry = similarities.argsort()[::-1]
        items_by_similariry = items_by_similariry[items_by_similariry != item_id]
        most_similar_items = items_by_similariry[:n_items]
        return list(zip(most_similar_items, similarities[most_similar_items]))

In [22]:
nmf_model = NMFModel(N_USERS, N_ITEMS, 64)
nmf_model.fit(train_user_item)

Epoch 0 loss: 0.701347
Epoch 1 loss: 0.689099
Epoch 2 loss: 0.684184
Epoch 3 loss: 0.679517
Epoch 4 loss: 0.675791
Epoch 5 loss: 0.672159
Epoch 6 loss: 0.664645
Epoch 7 loss: 0.652486
Epoch 8 loss: 0.642903
Epoch 9 loss: 0.635540
Epoch 10 loss: 0.615018
Epoch 11 loss: 0.613042
Epoch 12 loss: 0.590295
Epoch 13 loss: 0.572129
Epoch 14 loss: 0.554333
Epoch 15 loss: 0.543234
Epoch 16 loss: 0.514438
Epoch 17 loss: 0.490559
Epoch 18 loss: 0.471722
Epoch 19 loss: 0.455592
Epoch 20 loss: 0.449577
Epoch 21 loss: 0.445043
Epoch 22 loss: 0.426775
Epoch 23 loss: 0.411436
Epoch 24 loss: 0.398060
Epoch 25 loss: 0.391628
Epoch 26 loss: 0.382789
Epoch 27 loss: 0.372112
Epoch 28 loss: 0.367241
Epoch 29 loss: 0.365319
Epoch 30 loss: 0.361790
Epoch 31 loss: 0.354124
Epoch 32 loss: 0.346698
Epoch 33 loss: 0.337026
Epoch 34 loss: 0.332369
Epoch 35 loss: 0.330591
Epoch 36 loss: 0.324201
Epoch 37 loss: 0.323091
Epoch 38 loss: 0.325803
Epoch 39 loss: 0.329352
Epoch 40 loss: 0.321898
Epoch 41 loss: 0.317315
Ep

In [24]:
get_similars(art_name_id['Eminem'], nmf_model)

0                    2Pac
1      Christina Aguilera
2                   Jay-Z
3                  Brandy
4              Snoop Dogg
5                 Dr. Dre
6                 50 Cent
7            Busta Rhymes
8                    T.I.
9                 Rihanna
10               The Game
11    The Black Eyed Peas
12                Cascada
13         Britney Spears
14             Jurassic 5
15      Justin Timberlake
16              Sean Paul
17                    Nas
18                  Ciara
19           Wu-Tang Clan
Name: name, dtype: object

In [25]:
get_recommendations(26, nmf_model)

0                     Boa
1          T.M.Revolution
2     My Chemical Romance
3            Rise Against
4             Iron Maiden
5                Thursday
6       The Coffinshakers
7                    椎名林檎
8               雅-Miyavi-
9                The Used
10                   植松伸夫
11                 宇多田ヒカル
12              [Unknown]
13           Danny Elfman
14           3 Doors Down
15            Linkin Park
16              Metallica
17             Audioslave
18            Dir En Grey
19                Pierrot
Name: name, dtype: object

In [26]:
prec(nmf_model, 1), prec(nmf_model, 10), ndcg(nmf_model, 20)

(0.4876543209876543, 0.4424691358024691, 0.4309850081425156)

## Attention model

In [27]:
class ComiRecDataloader:
    def __init__(self, n_users, listens, min_len=3):
        histories = [[] for u in range(n_users)]
        for _, uid, aid in listens[['uid', 'aid']].itertuples():
            histories[uid].append(aid)
        self.data = [(u, h) for u, h in enumerate(histories) if len(h) >= min_len]
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        user, history = self.data[index]
        n_pred_points = max(64, min(1, len(history) // 16))
        pred_points = torch.randint(2, len(history), (n_pred_points,))
        return user, torch.tensor(history, dtype=torch.long), pred_points

In [28]:
class AttentionHead(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attention_layers = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Tanh(),
            nn.Linear(dim, 1)
        )
        
    def forward(self, history):
        w = self.attention_layers(history)
        w = F.softmax(w, dim=0)
        return history.T @ w
    
    def calc_points(self, history, pred_points):
        w = self.attention_layers(history)
        w = torch.exp(w)
        return torch.stack([(history[:p].T @ (w[:p] / w[:p].sum())).flatten()
                          for p in pred_points])
        

def calc_dists(user_embeddings, item_embeddings):
    return (user_embeddings * item_embeddings).sum(axis=-1).max(axis=0)[0]
        
        
class ComiRecModel:
    def __init__(self, n_users, n_items, dim, n_heads=4):
        self.item_emb = nn.Embedding(n_items, dim)
        self.attention_heads = [AttentionHead(dim) for _ in range(n_heads)]
        self.n_users = n_users
        self.n_items = n_items
        self.user_histories = []
        self.item_vectors = None
        
    def fit(self, listens, user_item, device='cuda'):
        self.item_emb.to(device)
        for head in self.attention_heads:
            head.to(device)
        
        data_loader = ComiRecDataloader(self.n_users, listens)
        neg_sampler = NegativeSampler(user_item, self.n_items)
        opt = torch.optim.Adam(
            chain(self.item_emb.parameters(), 
            *(head.parameters() for head in self.attention_heads)), 
            lr=.01)
        sched = torch.optim.lr_scheduler.StepLR(opt, 2, .4)
        for epoch in range(10):
            losses = []
            for user, history, pred_points in data_loader:
                opt.zero_grad()
                neg_samples = neg_sampler.sample(user + np.zeros(len(pred_points), np.long))
                neg_samples = torch.tensor(neg_samples, dtype=torch.long)
                history, neg_samples = history.to(device), neg_samples.to(device)
                hist_emb = self.item_emb(history)
                neg_emb = self.item_emb(neg_samples)
                u_embs = [head.calc_points(hist_emb, pred_points) for head in self.attention_heads]
                u_embs = torch.stack(u_embs)
                hist_dist = calc_dists(u_embs, hist_emb[pred_points])
                neg_dist = calc_dists(u_embs, neg_emb)
                loss = -torch.sigmoid(hist_dist).mean() + torch.sigmoid(neg_dist).mean()
                loss.backward()
                opt.step()
                losses.append(loss.detach().cpu())
            print(f'Epoch {epoch} loss {torch.tensor(losses).mean():.4f}')
            sched.step()
        self.item_emb.cpu()
        for head in self.attention_heads:
            head.cpu()
        for u in range(self.n_users):
            self.user_histories.append(torch.tensor(listens[listens['uid'] == u]['aid'].values))
        self.item_vectors = next(self.item_emb.parameters()).detach()
    
    def similar_items(self, item_id, n_items = 20):
        similarities = (self.item_vectors @ self.item_vectors[item_id]).numpy()
        items_by_similariry = similarities.argsort()[::-1]
        items_by_similariry = items_by_similariry[items_by_similariry != item_id]
        most_similar_items = items_by_similariry[:n_items]
        return list(zip(most_similar_items, similarities[most_similar_items]))
    
    def recommend(self, user_id, n_recs = 20):
        with torch.no_grad():
            hist = self.user_histories[user_id]
            hist_emb = self.item_emb(hist)
            usr_embs = torch.stack([head(hist_emb).flatten() for head in self.attention_heads])
            prods = torch.sum(usr_embs.unsqueeze(1) * self.item_vectors.unsqueeze(0), dim=2)
            similarities = prods.max(dim=0)[0].numpy()
        closest_item_ids = similarities.argsort()[::-1][:n_recs]
        return list(zip(closest_item_ids, similarities[closest_item_ids]))

In [29]:
cr_model = ComiRecModel(N_USERS, N_ITEMS, 64)
cr_model.fit(train_listens, train_user_item)

Epoch 0 loss -0.3345
Epoch 1 loss -0.5009
Epoch 2 loss -0.5644
Epoch 3 loss -0.5851
Epoch 4 loss -0.6055
Epoch 5 loss -0.6203
Epoch 6 loss -0.6277
Epoch 7 loss -0.6273
Epoch 8 loss -0.6334
Epoch 9 loss -0.6351


In [30]:
get_recommendations(26, cr_model)

0                          Afi
1                    Metallica
2               Marilyn Manson
3             System Of A Down
4                Guns N' Roses
5                  Iron Maiden
6                    Green Day
7                  Linkin Park
8                    In Flames
9     The All-American Rejects
10             Jimmy Eat World
11                   Nightwish
12                       Queen
13         My Chemical Romance
14                         Him
15                    Megadeth
16          30 Seconds To Mars
17                    The Used
18                  Pink Floyd
19                 Lacuna Coil
Name: name, dtype: object

In [31]:
get_similars(art_name_id['Rihanna'], cr_model)

0      Christina Aguilera
1           Avril Lavigne
2                Paramore
3     Natasha Bedingfield
4      The Pussycat Dolls
5            Phil Collins
6              Yellowcard
7                    P!Nk
8         Jimmy Eat World
9                 Rhianna
10             Fort Minor
11             John Mayer
12              Anastacia
13         Jennifer Lopez
14                  Robyn
15           Fall Out Boy
16            Conjure One
17                 Fragma
18               Maroon 5
19            Dir En Grey
Name: name, dtype: object

In [32]:
prec(cr_model, 1), prec(cr_model, 10), ndcg(cr_model, 20)

(0.6370370370370371, 0.47358024691358025, 0.4512121739218625)