In [1]:
import pandas as pd
import numpy as np
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as td

import pytorch_lightning as pl

from tqdm.autonotebook import tqdm
import json
import sklearn.metrics as sm
from sklearn.preprocessing import LabelEncoder
from scipy import spatial

import tensorboardX as tb
import tensorflow as tf
import datetime, os

import matplotlib.pyplot as plt
import seaborn as sns

import typing as tp
import faiss
import glob
from sklearn.metrics.pairwise import euclidean_distances
from functools import partial
import shutil

# Seed every RNG the notebook draws from: NumPy is used for negative sampling
# and the train/val/test split, PyTorch for embedding initialisation and batch
# shuffling.
SEED = 31337
np.random.seed(SEED)
torch.manual_seed(SEED)



In [2]:
# Root directory for the raw logs, the cached pairs CSV and the model checkpoints.
DATA_DIR = "./data"

In [3]:
# When True, the pairs table is read from the cached CSV instead of being rebuilt from the raw logs.
data_prepared = True

In [4]:
Pair = namedtuple("Session", ["user", "start", "track", "time"])


def get_pairs(user_data):
    """Turn one user's listening log into (user, start, track, time) records.

    Rows are visited in ``timestamp`` order. The first row of a session supplies
    the session's start track and produces no record; every later row of the
    same session yields a ``Pair`` linking that start track to the row's track.
    A row whose ``message`` equals ``"last"`` closes the session, so the next
    row opens a new one.

    Args:
        user_data: DataFrame with ``timestamp``, ``user``, ``track``, ``time``
            and ``message`` columns.

    Returns:
        A list of ``Pair`` tuples in chronological order.
    """
    session_start = None
    collected = []
    ordered = user_data.sort_values("timestamp")
    for _, event in ordered.iterrows():
        if session_start is not None:
            collected.append(
                Pair(event["user"], session_start, event["track"], event["time"])
            )
        else:
            session_start = event["track"]

        # "last" ends the session, including sessions made of a single row.
        if event["message"] == "last":
            session_start = None
    return collected

In [5]:
# Build the (user, start, track, time) pairs table from the raw logs, or load
# the cached copy written by a previous run.
if not data_prepared:
    # Raw logs: every <DATA_DIR>/*/data.json (JSON lines) plus contextual_data.csv.
    data = pd.concat([
        pd.read_json(data_path, lines=True)
        for data_path
        in glob.glob(DATA_DIR + "/*/data.json")
    ] + [pd.read_csv(DATA_DIR + "/contextual_data.csv")])

    data = data[["message", "timestamp", "user", "track", "time"]]
    data["timestamp"] = pd.to_datetime(data["timestamp"])

    pairs = pd.DataFrame(
        data
        .groupby("user")
        .apply(get_pairs)
        .explode()
        # explode() turns an empty pair list into a NaN row; such a row cannot
        # be unpacked into the four columns below, so drop it first.
        .dropna()
        .values
        .tolist(),
        columns=["user", "start", "track", "time"]
    )
    pairs.to_csv(f"{DATA_DIR}/pairs_prepared.csv", index=False)
else:
    pairs = pd.read_csv(f"{DATA_DIR}/pairs_prepared.csv")

pairs.head()

Unnamed: 0,user,start,track,time
0,0,544,491,0.8
1,0,544,567,0.64
2,0,544,1616,0.51
3,0,544,59,0.35
4,0,544,5768,0.13


In [6]:
# Positive examples are the pairs whose "time" value exceeds 0.7; .copy() detaches the result from `pairs`.
positives = pairs[pairs["time"] > 0.7].copy()

In [7]:
# Keep only positives whose track occurs as a positive at least 10 times.
track_counts = positives.groupby("track").size()
tracks = set(track_counts.index[(track_counts >= 10).to_numpy()].values)

positives = positives.loc[positives["track"].isin(tracks)]

(len(positives), len(tracks))

(219806, 13693)

In [8]:
# Load track metadata (one row per track id) from DATA_DIR, like every other
# input of this notebook, and integer-encode genre and artist so they can be
# used as embedding indices.
track_metadata = (
    pd.read_json(f"{DATA_DIR}/tracks.json", lines=True)
    .drop_duplicates(subset=["track"])
    # Missing genres get their own "Unk" category before encoding.
    .fillna(value={"genre": "Unk"})
)
track_metadata["genre"] = LabelEncoder().fit_transform(track_metadata["genre"])
track_metadata["artist"] = LabelEncoder().fit_transform(track_metadata["artist"])

In [9]:
# Per-track feature table indexed by track id; the column order (track, genre, artist) is what ItemNet.forward indexes into.
item_features = track_metadata[["track", "genre", "artist"]].set_index("track", drop=False)

In [10]:
# (user, start track, positive track) rows; the "track" column is renamed to "track_pos".
triplets = positives[["user", "start", "track"]].rename(columns={"track": "track_pos"})

In [11]:
NUM_NEGATIVE_SAMPLES = 15

# Rebuild the base table from `positives` instead of from `triplets` itself, so
# that re-running this cell does not multiply the table by
# NUM_NEGATIVE_SAMPLES a second time.
triplets = positives[["user", "start", "track"]].rename(columns={"track": "track_pos"})
# Repeat every positive row NUM_NEGATIVE_SAMPLES times; sort_index groups the copies by original row.
triplets = pd.concat([triplets] * NUM_NEGATIVE_SAMPLES).sort_index().reset_index(drop=True)
# One uniformly drawn negative per row from ids 0..49999; a draw may coincide with the row's positive.
triplets["track_neg"] = np.random.choice(range(50000), len(triplets))

In [12]:
# Random 80/10/10 split: every row gets one uniform draw in [0, 1).
# NOTE(review): rows are assigned independently, so repeated copies of the same
# positive pair can land in different splits — confirm this is acceptable.
rdm = np.random.random(len(triplets))
train_data = triplets.loc[rdm < 0.8]
val_data = triplets.loc[(0.8 <= rdm) & (rdm < 0.9)]
test_data = triplets.loc[0.9 <= rdm]

(len(train_data), len(val_data), len(test_data))

(2637540, 329905, 329645)

In [13]:
class DSSMData(pl.LightningDataModule):
    """Data module that turns triplet frames into tensor datasets.

    Each triplet frame must provide the columns ``user``, ``start``,
    ``track_pos`` and ``track_neg``. Track ids are replaced by their rows of
    ``item_features`` (looked up by index label), so every dataset item is
    ``(user, start_features, positive_features, negative_features)``.

    Args:
        train_triplets: Triplets used for training.
        val_triplets: Triplets used for validation.
        test_triplets: Triplets used for testing.
        item_features: DataFrame indexed by track id holding integer features.
        batch_size: Batch size of all three data loaders.
    """

    def __init__(self, train_triplets, val_triplets, test_triplets, item_features, batch_size=4096):
        super().__init__()
        self.train_triplets = train_triplets
        self.val_triplets = val_triplets
        self.test_triplets = test_triplets
        self.item_features = item_features
        self.batch_size = batch_size

    def _track_features(self, tracks):
        """Return the feature rows for ``tracks`` as a long tensor."""
        # Use the table handed to this instance, not whatever global happens
        # to be called `item_features`.
        return torch.from_numpy(self.item_features.loc[tracks].values).long()

    def _collect_data(self, triplets):
        """Build a TensorDataset of (users, start, positive, negative features)."""
        users = torch.from_numpy(triplets["user"].values).long()
        starts = triplets["start"].values
        positives = triplets["track_pos"].values
        negatives = triplets["track_neg"].values

        return td.TensorDataset(
            users,
            self._track_features(starts),
            self._track_features(positives),
            self._track_features(negatives),
        )

    def prepare_data(self, stage=None):
        """Materialise the datasets for ``stage``.

        ``"fit"`` builds the train and validation datasets, ``"test"`` builds
        the test dataset, and ``None`` builds all of them.
        """
        if stage == "fit" or stage is None:
            self.train_dataset = self._collect_data(self.train_triplets)
            self.val_dataset = self._collect_data(self.val_triplets)
        # A separate `if` (not `elif`) so that stage=None also reaches this branch.
        if stage == "test" or stage is None:
            self.test_dataset = self._collect_data(self.test_triplets)

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return td.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0)

    def val_dataloader(self):
        """Loader over the validation dataset."""
        return td.DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=0)

    def test_dataloader(self):
        """Unshuffled loader over the test dataset."""
        return td.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)

In [14]:
class ItemNet(nn.Module):
    """Embeds a track as the sum of its track, artist and genre embeddings.

    Args:
        genre_number: Number of genre ids; the table gets one extra row.
        artist_number: Number of artist ids; the table gets one extra row.
        item_number: Number of track ids; the table gets one extra row.
        emb_dim: Size of every embedding vector.
    """

    def __init__(
        self,
        genre_number: int,
        artist_number: int,
        item_number: int,
        emb_dim: int = 32,
    ) -> None:
        super().__init__()
        # Attribute names (including the spelling of `track_embedddings`) are
        # state_dict keys, so they are kept exactly as they are.
        self.genre_embeddings = nn.Embedding(genre_number + 1, emb_dim)
        self.artist_embeddings = nn.Embedding(artist_number + 1, emb_dim)
        self.track_embedddings = nn.Embedding(item_number + 1, emb_dim)

    def forward(self, features) -> torch.Tensor:
        """Embed a batch of feature rows.

        Args:
            features: Integer tensor whose columns 0, 1 and 2 hold the track,
                genre and artist id respectively.

        Returns:
            One embedding vector per input row.
        """
        track_ids, genre_ids, artist_ids = features[:, 0], features[:, 1], features[:, 2]
        by_track = self.track_embedddings(track_ids)
        by_artist = self.artist_embeddings(artist_ids)
        by_genre = self.genre_embeddings(genre_ids)
        return by_track + by_artist + by_genre

class SessionedDSSM(pl.LightningModule):
    """Triplet-loss model matching multi-vector user embeddings to track embeddings.

    Every user owns ``NUM_USER_VECTORS`` embedding vectors stored side by side
    in one ``nn.Embedding`` row. For each example the vector closest to the
    session's start-track embedding is used as the triplet anchor.

    Args:
        genre_number: Number of genre ids, forwarded to ``ItemNet``.
        artist_number: Number of artist ids, forwarded to ``ItemNet``.
        item_number: Number of track ids, forwarded to ``ItemNet``.
        user_number: Number of user ids; the table gets one extra row.
        embedding_dim: Size of one embedding vector.
        lr: Adam learning rate.
        triplet_loss_margin: Margin of both triplet losses.
        weight_decay: Adam weight decay.
        log_to_prog_bar: Whether validation/test losses appear in the progress bar.
    """

    # Number of embedding vectors stored per user.
    NUM_USER_VECTORS = 3

    def __init__(
        self,
        genre_number: int,
        artist_number: int,
        item_number: int,
        user_number: int,
        embedding_dim: int = 32,
        lr: float = 1e-3,
        triplet_loss_margin: float = 0.4,
        weight_decay: float = 1e-6,
        log_to_prog_bar: bool = True,
    ) -> None:
        super().__init__()
        self.lr = lr
        self.triplet_loss_margin = triplet_loss_margin
        self.weight_decay = weight_decay
        self.log_to_prog_bar = log_to_prog_bar
        # NOTE(review): no method of this class reads `item_embeds`; it is kept
        # so the module's parameter set (and state_dict keys) stays unchanged.
        self.item_embeds = nn.EmbeddingBag(item_number+1, embedding_dim, padding_idx=item_number)
        self.item_net = ItemNet(genre_number, artist_number, item_number, embedding_dim)
        self.user_embeddings = nn.Embedding(user_number+1, embedding_dim * self.NUM_USER_VECTORS)
        self.embedding_dim = embedding_dim

    def user_embedding_by_start(self, users, start_embeddings):
        """Pick, per example, the user vector nearest to the start-track embedding.

        Args:
            users: User ids, one per example.
            start_embeddings: Start-track embeddings, one row per example.

        Returns:
            The selected user vectors, one row per example.
        """
        user_embeddings = self.user_embeddings(users).reshape(
            -1, self.NUM_USER_VECTORS, self.embedding_dim
        )
        # The choice of vector is a hard selection; gradients flow only
        # through the selected vector, not through the distance computation.
        with torch.no_grad():
            distances = F.pairwise_distance(start_embeddings[:, None, :], user_embeddings)
            ind = torch.argmin(distances, dim=-1)
        # Build the row index on the same device as `ind` to avoid mixing
        # CPU and accelerator index tensors.
        rows = torch.arange(users.shape[0], device=ind.device)
        return user_embeddings[rows, ind, :]

    def forward(
        self,
        users: torch.Tensor,
        starts: torch.Tensor,
        positives: torch.Tensor,
        negatives: torch.Tensor,
    ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(anchor, start, pos, neg)`` embeddings for a batch."""
        start = self.item_net(starts)
        anchor = self.user_embedding_by_start(users, start)
        pos = self.item_net(positives)
        neg = self.item_net(negatives)
        return anchor, start, pos, neg

    def _step(self, batch, batch_idx, metric, prog_bar=False):
        """Compute and log the loss for one batch under the name ``metric``.

        The loss is the sum of two triplet margin losses sharing anchor and
        negative: one with the positive track, one with the start track.
        """
        users, starts, positive, negative = batch
        anchor, start, pos, neg = self(users, starts, positive, negative)
        loss = F.triplet_margin_loss(anchor, pos, neg, margin=self.triplet_loss_margin)
        loss = loss + F.triplet_margin_loss(anchor, start, neg, margin=self.triplet_loss_margin)
        self.log(metric, loss, prog_bar=prog_bar)
        return loss

    def training_step(self, batch: tp.Sequence[torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Log and return ``train_loss``."""
        return self._step(batch, batch_idx, "train_loss")

    def validation_step(self, batch: tp.Sequence[torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Log and return ``val_loss``."""
        return self._step(batch, batch_idx, "val_loss", self.log_to_prog_bar)

    def test_step(self, batch, batch_idx, prog_bar=False):
        """Log and return ``test_loss``; ``prog_bar`` is accepted but not used."""
        return self._step(batch, batch_idx, "test_loss", self.log_to_prog_bar)

    def configure_optimizers(self):
        """Adam plus a ReduceLROnPlateau scheduler driven by ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=True)
        scheduler = {
            'scheduler': lr_scheduler,
            'monitor': 'val_loss'
        }
        return [optimizer], [scheduler]

In [15]:
# Number of distinct genre / artist values in item_features; passed to the model constructor.
genre_number = len(np.unique(item_features["genre"]))
artist_number = len(np.unique(item_features["artist"]))
# Hard-coded id-space sizes. NOTE(review): 50000 is also the range negatives are
# sampled from; confirm both values against the actual data.
track_number = 50000
user_number = 10000

In [51]:
# Wire the three splits and the feature table into the data module, then build the model with 64-dimensional embeddings.
data_module = DSSMData(train_data, val_data, test_data, item_features)
net = SessionedDSSM(genre_number, artist_number, track_number, user_number, embedding_dim=64)

In [52]:
# Checkpoints are written under DATA_DIR/checkpoints and tracked by `val_loss`.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    dirpath=f"{DATA_DIR}/checkpoints",
)

# Single-GPU trainer, at most 50 epochs.
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=50,
    callbacks=[
        # Early stopping watches the same metric with a patience of 10.
        pl.callbacks.early_stopping.EarlyStopping(patience=10, monitor="val_loss"),
        pl.callbacks.LearningRateMonitor(logging_interval="step"),
        checkpoint_callback,
    ],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [53]:
# Train `net` on the loaders provided by `data_module`.
trainer.fit(net, data_module)

D:\c23\envs\lucky\lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:634: Checkpoint directory ./data/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type         | Params
-------------------------------------------------
0 | item_embeds     | EmbeddingBag | 3.2 M 
1 | item_net        | ItemNet      | 4.0 M 
2 | user_embeddings | Embedding    | 1.9 M 
-------------------------------------------------
9.1 M     Trainable params
0         Non-trainable params
9.1 M     Total params
36.438    Total estimated model params size (MB)


Sanity Checking: |                                                                               | 0/? [00:00<…

D:\c23\envs\lucky\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
D:\c23\envs\lucky\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |                                                                                      | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

`Trainer.fit` stopped: `max_epochs=50` reached.


In [None]:
# Build data_module.test_dataset, which test_dataloader() reads from.
data_module.prepare_data("test")
# Evaluate on the test loader using the checkpoint selected by ckpt_path="best".
trainer.test(ckpt_path="best", dataloaders=data_module.test_dataloader())

In [None]:
checkpoint_callback.best_model_path

'./data/checkpoints\\epoch=26-step=21978.ckpt'

In [16]:
# Restore a trained model; the constructor arguments are passed again explicitly.
# NOTE(review): the checkpoint path is hard-coded and is not the
# `best_model_path` displayed above — confirm this is the intended file.
best = SessionedDSSM.load_from_checkpoint(
    './data/checkpoints/epoch=49-step=32200.ckpt',
    genre_number=genre_number,
    artist_number=artist_number,
    item_number=track_number,
    user_number=user_number,
    embedding_dim=64,
)

In [17]:
# Inference below runs on the CUDA device; move the restored model there.
device = torch.device("cuda")
best = best.to(device)

In [19]:
# Embed every row of item_features in one forward pass without gradients and
# bring the result back to the CPU; row i of item_embeds belongs to row i of item_features.
with torch.no_grad():
    item_embeds = best.item_net(torch.from_numpy(item_features.values).to(device)).cpu()

In [20]:
item_embeds

tensor([[-2.3135e-02,  2.8014e-01,  2.0678e-01,  ..., -1.7322e-01,
         -1.9742e-02, -2.8156e-01],
        [ 9.6068e-02, -1.1737e-01, -2.1915e-01,  ...,  9.9605e-02,
          8.0314e-02,  2.2179e-02],
        [ 2.8117e-03,  6.8514e-03,  3.3315e-02,  ...,  3.8978e-04,
         -7.5968e-02, -4.1691e-02],
        ...,
        [ 6.5128e-02,  7.3491e-01, -2.6803e-01,  ...,  1.1308e-01,
          2.3676e-02,  1.4862e-02],
        [-9.3441e-02,  5.1142e-01,  2.2801e-01,  ...,  2.1516e-01,
         -9.2805e-02,  6.0926e-02],
        [ 8.9297e-02,  6.1818e-01,  3.3492e-01,  ...,  4.3432e-01,
          8.0773e-02, -1.7692e-02]])

In [21]:
# Build a "Flat" index with the L2 metric over 64-dimensional vectors and move it to GPU 0.
gpu_res = faiss.StandardGpuResources()
index = faiss.index_factory(64, "Flat", faiss.METRIC_L2)
# Inner-product alternative, currently disabled:
# index = faiss.IndexFlatIP(64)
index = faiss.index_cpu_to_gpu(gpu_res, 0, index)

# Vectors are added in item_embeds row order.
index.add(item_embeds)

In [26]:
# faiss returns positions of vectors in the order they were added, i.e. row
# numbers of item_features (the order item_embeds was computed in). Translate
# them to track ids explicitly instead of assuming row number == track id.
track_ids = item_features["track"].values

# One JSON line per user: for each of the user's 3 embedding vectors, the 50
# nearest tracks. NOTE(review): a search result of -1 (fewer than k vectors in
# the index) is not handled here.
with open("../../botify/data/recommendations_sessioned.json", "w") as rf, torch.no_grad():
    for user in tqdm(range(user_number)):
        users = torch.tensor([user], dtype=torch.long, device=device)
        # Split the user's embedding row into its 3 separate vectors.
        embeds = best.user_embeddings(users).reshape(3, best.embedding_dim).cpu()
        _, neighbours = index.search(embeds.numpy(), k=50)
        recommendation = {
            "user": int(user),
            "tracks": track_ids[neighbours].tolist()
        }
        rf.write(json.dumps(recommendation) + "\n")

  0%|          | 0/10000 [00:00<?, ?it/s]