In [2]:
# Reload edited project modules (e.g. recs.*) automatically before executing
# code, so changes to the package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
from collections import OrderedDict
import pandas as pd
import numpy as np
from recs.jaxmodels.cf import bprmf
from recs.evaluator import evaluate
from recs import utils
from recs.dataset import negative_sample_dataset

import tensorflow as tf
from tqdm.notebook import tqdm

import jax
from jax import numpy as jnp

In [4]:
# NOTE(review): this trains on RC15, but the evaluation cells further down load
# ~/work/dataset/ml-100k/test.df -- user/item ids from two different datasets
# will not line up. Confirm which dataset is intended.
# NOTE(review): a `sessionkey="userId"` argument used to be left commented out
# here; it is presumably needed for user-keyed data such as ml-100k -- confirm.
TRAIN_PATH = "~/work/dataset/RC15/derived/train.df"
BATCH_SIZE = 64

train_data, train_len, num_items, num_users = negative_sample_dataset(
    path=TRAIN_PATH,
    batch_size=BATCH_SIZE,
)

In [5]:
# Sanity check of the sizes returned by negative_sample_dataset.
num_items, num_users, train_len

(37961, 9977433, 495478)

In [6]:
# NOTE(review): the third argument is assumed to be the latent-factor
# dimension -- confirm against bprmf.init_params. No seed/PRNG key is passed
# here, so if initialisation is random the run is not reproducible.
LATENT_DIM = 10
params = bprmf.init_params(num_items, num_users, LATENT_DIM)

In [7]:
# Train BPR-MF with early stopping on the mean per-batch training loss.
N_EPOCHS = 10
PATIENCE = 3     # stop after this many consecutive epochs without improvement
ALPHA = 0.005    # passed to bprmf.update as `alpha`
LAM = 0.05       # passed to bprmf.update as `lam`

losses = []      # mean per-batch loss of every completed epoch
stop_count = 0
before_loss = np.inf
for epoch in range(N_EPOCHS):
    epoch_loss = 0.0
    n_batches = 0
    with tqdm(train_data.as_numpy_iterator(), desc=f"[Epoch {epoch+1}]", total=train_len) as tm:
        for batch in tm:
            params, error = bprmf.update(params, batch, alpha=ALPHA, lam=LAM)
            # float() pulls the (assumed scalar) loss to the host so the running
            # sum is a plain Python number rather than an accelerator array.
            epoch_loss += float(error)
            n_batches += 1
            tm.set_postfix(OrderedDict(loss=epoch_loss / n_batches))
    # Average over the batches actually seen (matches the progress-bar value),
    # guarding against an empty iterator.
    epoch_loss /= max(n_batches, 1)
    # Record before the stopping check so the last epoch is not dropped.
    # (The old `losses += np.asarray(...)` dispatched to numpy's __radd__ and
    # silently left `losses` as an empty ndarray.)
    losses.append(epoch_loss)
    if epoch_loss > before_loss:
        stop_count += 1
        if stop_count >= PATIENCE:
            break
    else:
        before_loss = epoch_loss
        stop_count = 0  # patience counts consecutive non-improving epochs

[Epoch 1]:   0%|          | 0/495478 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [11]:
def predict_items(params, data, k, predict_fn=None):
    """Return the top-``k`` item predictions for every distinct user in ``data``.

    Parameters
    ----------
    params
        Model parameters, passed through unchanged to ``predict_fn``.
    data : numpy.ndarray
        2-D array whose first column holds user ids (e.g. ``test_data.values``);
        the remaining columns are ignored.
    k : int
        Number of items requested per user.
    predict_fn : callable, optional
        Called as ``predict_fn(params, user, k)`` and expected to return that
        user's predicted items. Defaults to ``bprmf.predict``; pass another
        model's predict function to reuse this helper.

    Returns
    -------
    pandas.DataFrame
        One row per distinct user (ascending user id) with columns ``userId``
        and ``itemIds`` (a numpy array of the predicted items). Empty input
        yields an empty frame with the same columns.
    """
    if predict_fn is None:
        predict_fn = bprmf.predict
    # np.unique de-duplicates and sorts the user ids.
    users = np.unique(data[:, 0])
    # np.asarray converts e.g. device arrays to host numpy arrays.
    records = [(u, np.asarray(predict_fn(params, u, k))) for u in users]
    return pd.DataFrame(records, columns=["userId", "itemIds"])

In [8]:
# NOTE(review): evaluation data is ml-100k while training above uses RC15 --
# confirm both are meant to be the same dataset.
# read_pickle unpickles arbitrary objects; only load files you trust.
TEST_PATH = "~/work/dataset/ml-100k/test.df"
test_data = pd.read_pickle(TEST_PATH)

In [12]:
# Test split as loaded (columns per the rendered output: userId, itemId,
# rating, timestamp).
test_data

Unnamed: 0,userId,itemId,rating,timestamp
3758,3,323,2,889237269
38670,3,322,3,889237269
1257,3,335,1,889237269
18385,3,264,2,889237297
53950,3,325,1,889237297
...,...,...,...,...
46773,729,689,4,893286638
73008,729,313,3,893286638
46574,729,328,3,893286638
64312,729,748,4,893286638


In [13]:
# Top-20 predictions for every user in the test split.
# NOTE(review): the training cell was interrupted (KeyboardInterrupt), so
# `params` may be untrained or left over from an earlier kernel session.
preds = predict_items(params, test_data.to_numpy(), k=20)

In [14]:
# One row per test user: userId and the array of predicted itemIds.
preds

Unnamed: 0,userId,itemIds
0,1,"[50, 100, 174, 222, 181, 143, 121, 96, 288, 21..."
1,3,"[50, 172, 181, 121, 100, 79, 1, 56, 117, 286, ..."
2,4,"[50, 127, 222, 100, 121, 1, 181, 133, 202, 168..."
3,7,"[50, 1, 56, 28, 181, 98, 204, 118, 174, 546, 4..."
4,11,"[181, 50, 100, 288, 172, 1, 191, 405, 25, 173,..."
...,...,...
296,932,"[50, 181, 258, 294, 100, 127, 732, 174, 121, 7..."
297,934,"[50, 258, 181, 127, 100, 69, 1, 174, 121, 294,..."
298,938,"[50, 181, 100, 121, 117, 1, 210, 258, 286, 172..."
299,940,"[50, 121, 64, 181, 222, 294, 210, 168, 56, 268..."


In [15]:
# Ground truth for evaluation, built by the project helper utils.get_y_trues.
# NOTE(review): its exact return structure is not visible here -- presumably
# the held-out items per user; confirm against recs.utils.
true_y = utils.get_y_trues(test_data)

In [18]:
# Evaluate with the same cutoff (k=20) used when calling predict_items.
# NOTE(review): the recorded output (precision 0.22, mrr 0.0127, ndcg 0.003)
# looks inconsistent -- a user with any hit in the top 20 has reciprocal rank
# >= 1/20 -- so the evaluator's mrr/ndcg implementations may deserve a check.
evaluate.evaluate(true_y, preds, k=20)

{'precision': 0.22076411960132897,
 'map': 0.25506126916519134,
 'recall': 0.06466161466089679,
 'ndcg': 0.003050455485954647,
 'mrr': 0.012696029514966392}