**About:** Computes aid embeddings by matrix factorization of consecutive-event (aid, next aid) pairs within sessions.

In [None]:
cd ../src

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import gc
import sys
import cudf
import json
import glob
import pickle
import warnings
import itertools
import numpy as np
import pandas as pd

from tqdm import tqdm
from datetime import datetime
from collections import Counter
from numerize.numerize import numerize

from merlin.io import Dataset
from torch.optim import SparseAdam
from merlin.loader.torch import Loader

warnings.simplefilter(action="ignore", category=FutureWarning)

In [None]:
from params import *

from utils.load import load_sessions
from utils.metrics import get_coverage

### Params

In [None]:
NO_CLICKS = False  # when True, click events are removed before pairs are built
MODE = "val"  # "val" or "test"; any other value raises NotImplementedError

In [None]:
if MODE == "val":
    files = glob.glob("../output/full_train_parquet/*") + glob.glob(
        "../output/val_parquet/*"
    )
elif MODE == "test":
    files = glob.glob("../output/full_train_val_parquet/*") + glob.glob(
        "../output/test_parquet/*"
    )
else:
    raise NotImplementedError

In [None]:
# Load every selected parquet chunk with cudf and stack them into a single
# frame with a fresh 0..n-1 index.
train_pairs = cudf.concat(
    [cudf.read_parquet(path) for path in files],
    ignore_index=True,
)

In [None]:
if NO_CLICKS:
    # Drop click events; only the remaining event types are used for pairing.
    train_pairs = train_pairs.loc[train_pairs["type"] != "clicks"]
    train_pairs = train_pairs.reset_index(drop=True)

In [None]:
# Single shift
# Pair each event's aid with the aid SHIFT rows later in the same session.
# Assumes rows are time-ordered within each session (TODO confirm).
# Rows with any missing value are dropped by dropna, including the last SHIFT
# events of every session, which have no partner.
SHIFTS = None
SHIFT = 1  # this can be modified

train_pairs["aid_next"] = train_pairs.groupby("session")["aid"].shift(-SHIFT)
train_pairs = (
    train_pairs[["aid", "aid_next"]]
    .dropna()
    .reset_index(drop=True)
)

In [None]:
# Several Shifts 

# SHIFTS =  [1, 2, 3, 4, 5]  # this can be modified
# SHIFT = "1-5"  # this can be modified

# train_pairs_ = []

# for shift in tqdm(SHIFTS):
#     train_pairs['aid_next'] = train_pairs.groupby('session').aid.shift(-1 * shift)
#     train_pairs_.append(train_pairs[['aid', 'aid_next']].dropna().reset_index(drop=True).to_pandas())

# train_pairs = cudf.from_pandas(pd.concat(train_pairs_, ignore_index=True).drop_duplicates(keep="first"))

In [None]:
# Report the size of the pair table in a human-readable form.
print(f"Number of pairs {numerize(len(train_pairs))}")

In [None]:
# Create the output folder if needed. Otherwise the first write below fails on
# a fresh checkout where ../output/matrix_factorization does not exist yet.
os.makedirs("../output/matrix_factorization", exist_ok=True)

# Persist all (aid, aid_next) pairs; the merlin training Dataset reads this file.
train_pairs.to_pandas().to_parquet(
    f"../output/matrix_factorization/{MODE}_pairs.parquet"
)

In [None]:
# Write the last 10M pairs as the monitoring set behind the `val_acc` metric
# printed during training.
# NOTE(review): these rows are also contained in {MODE}_pairs.parquet, so the
# model trains on them and val_acc is not a held-out metric. Confirm this is
# intended.
train_pairs.tail(10_000_000).to_parquet(
    f"../output/matrix_factorization/{MODE}_pairs_val.parquet"
)

### Utils

In [None]:
import torch
from torch import nn


class MatrixFactorization(nn.Module):
    """Score aid pairs by the dot product of their embeddings.

    Both sides of a pair share one embedding table, so the score is symmetric:
    ``model(a, b) == model(b, a)``.

    Args:
        n_aids: Number of rows in the embedding table (valid ids are
            ``0 .. n_aids - 1``).
        n_factors: Embedding dimension.
    """

    def __init__(self, n_aids, n_factors):
        super().__init__()
        # sparse=True makes the embedding emit sparse gradients, which
        # torch.optim.SparseAdam requires.
        self.aid_factors = nn.Embedding(n_aids, n_factors, sparse=True)

    def forward(self, aid1, aid2):
        """Return the raw (pre-sigmoid) score of each (aid1, aid2) pair.

        ``aid1`` and ``aid2`` are integer index tensors of matching shape.
        The result has that same shape.
        The sum runs over the last (embedding) axis. This way ``(batch, 1)`` id
        tensors yield one score per row, just like flat ``(batch,)`` ones.
        Summing over ``dim=1`` would instead collapse the size-1 axis and return
        ``(batch, n_factors)``.
        """
        emb1 = self.aid_factors(aid1)
        emb2 = self.aid_factors(aid2)
        return (emb1 * emb2).sum(dim=-1)


In [None]:
class AverageMeter:
    """Track the latest value of a metric and its running weighted mean.

    Args:
        name: Label printed first by ``str(meter)``.
        fmt: Format-spec suffix, including the leading ``:``, applied to both
            the latest value and the mean when printing (e.g. ``":.4e"``).
    """

    def __init__(self, name, fmt=":f"):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the running sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted ``n`` times in the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = "{name} {val%s} ({avg%s})" % (self.fmt, self.fmt)
        return template.format(**vars(self))

In [None]:
# merlin Datasets over the two pair files.
train_ds = Dataset(f"../output/matrix_factorization/{MODE}_pairs.parquet")
valid_ds = Dataset(f"../output/matrix_factorization/{MODE}_pairs_val.parquet")

# Shuffled loaders yielding 65536 pairs per batch.
train_dl_merlin = Loader(train_ds, batch_size=65536, shuffle=True)
valid_dl_merlin = Loader(valid_ds, batch_size=65536, shuffle=True)

In [None]:
# Model and optimisation hyper-parameters.
DIM = 64  # embedding size per aid

# Presumably the largest aid id: the model allocates N_AIDS + 1 rows.
# TODO confirm against the data.
N_AIDS = 1_855_602

EPOCHS = 20
LR = 0.1  # SparseAdam learning rate

In [None]:
# Shared aid embedding table on the GPU. Its embedding emits sparse gradients,
# hence SparseAdam. It is trained with a binary positive-vs-negative logit loss.
model = MatrixFactorization(N_AIDS + 1, DIM).to("cuda")

criterion = nn.BCEWithLogitsLoss()
optimizer = SparseAdam(model.parameters(), lr=LR)

In [None]:
for epoch in range(1, EPOCHS + 1):
    # ---- Training ----
    model.train()
    # Created once per epoch so that losses.avg is the mean over every batch of
    # the epoch, and so the name exists even if the loader yields nothing.
    losses = AverageMeter("Loss", ":.4e")

    for batch, _ in train_dl_merlin:
        aid1 = batch["aid"].to("cuda")
        aid2 = batch["aid_next"].to("cuda")

        # Negatives pair each aid1 with the aid_next of a random row in the same
        # batch. A row may be matched with itself; that noise is tolerated.
        perm = torch.randperm(aid2.shape[0], device=aid2.device)
        output_pos = model(aid1, aid2)
        output_neg = model(aid1, aid2[perm])

        output = torch.cat([output_pos, output_neg])
        targets = torch.cat(
            [torch.ones_like(output_pos), torch.zeros_like(output_neg)]
        )
        loss = criterion(output, targets)
        # Weight by batch size so a smaller final batch does not skew the mean.
        losses.update(loss.item(), aid1.shape[0])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # ---- Validation ----
    model.eval()
    accuracy = AverageMeter("accuracy")

    with torch.no_grad():
        for batch, _ in valid_dl_merlin:
            # Move to the model's device, exactly as in the training loop.
            aid1 = batch["aid"].to("cuda")
            aid2 = batch["aid_next"].to("cuda")
            perm = torch.randperm(aid2.shape[0], device=aid2.device)
            output_pos = model(aid1, aid2)
            output_neg = model(aid1, aid2[perm])
            # Fraction of positives scored above 0.5 and negatives below 0.5.
            # .item() keeps the meter accumulating Python floats, not tensors.
            accuracy_batch = (
                torch.cat([output_pos.sigmoid() > 0.5, output_neg.sigmoid() < 0.5])
                .float()
                .mean()
                .item()
            )
            accuracy.update(accuracy_batch, aid1.shape[0])

    print(
        f"Epoch {epoch:02d}/{EPOCHS} \t loss={losses.avg:.3f} \t val_acc={accuracy.avg:.3f}"
    )

In [None]:
# Export the learned aid embedding table as a float32 NumPy array.
embeddings = (
    model.aid_factors.weight
    .detach()
    .cpu()
    .numpy()
    .astype("float32")
)

name = "embed_{}_{}{}_{}.npy".format(
    SHIFT, DIM, "_cartbuy" if NO_CLICKS else "", MODE
)
save_path = f"../output/matrix_factorization/{name}"
np.save(save_path, embeddings)

print(f"Saved matrix of shape {embeddings.shape} to", save_path)

Done!