In [38]:
import functools
import itertools
import logging
import math
import os
import pickle
import sys
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import seaborn as sns
import yaml

# Load the autoreload extension and enable mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline, using the 'retina' figure format.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Configure seaborn in one call. `sns.set` re-applies its own default context
# ("notebook") every time it runs, so a separate `sns.set_context("poster")`
# issued before it is silently undone. Passing context, style and rc together
# keeps the poster scaling, the whitegrid style and the 16x12 figure size.
sns.set(context="poster", style="whitegrid", rc={"figure.figsize": (16, 12.0)})

import numpy as np
import pandas as pd
import torch.nn.functional as F

# Let pandas render up to 120 rows and 120 columns before truncating.
pd.options.display.max_rows = 120
pd.options.display.max_columns = 120

# Configure root logging at INFO level, writing to sys.stdout.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [39]:
from lda4rec.datasets import Interactions, DataLoader, random_train_test_split
from lda4rec.evaluations import summary
from lda4rec.estimators import MFEst, PopEst, LDA4RecEst, SNMFEst
from lda4rec.utils import process_ids, cmp_ranks

In [40]:
import pyro
import pyro.distributions as dist
import pyro.optim as optim
import torch
from pyro.distributions import constraints
from pyro.infer import SVI, Predictive, Trace_ELBO, TraceEnum_ELBO, config_enumerate

In [41]:
import neptune.new as neptune
# Initialise Neptune in offline mode. The return value is not bound to a name;
# the trailing semicolon keeps it from being echoed as the cell result.
# NOTE(review): the recorded output asks for the run to be stopped when done —
# consider keeping a handle to the returned object so that is possible.
neptune.init(mode="offline");

offline/1c030a42-677b-40c1-a619-b2757fb675fb
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


## Experimenting with Matrix Factorization as an adjoint LDA formulation

In [42]:
# Load the MovieLens "100k" variant through the project's DataLoader.
loader = DataLoader()
data = loader.load_movielens("100k")

In [43]:
# Pre-process the interactions, then split them into train and test sets.
# NOTE(review): the trailing-underscore methods look like in-place mutations of
# `data` (their return values are ignored here) — confirm in lda4rec.datasets.
# If so, re-running this cell applies them to already-processed data.
data.max_user_interactions_(200)
data.implicit_(0.)
# NOTE(review): no seed or RNG is passed to the split, so train/test may differ
# between runs — check whether random_train_test_split accepts one.
train, test = random_train_test_split(data)

In [44]:
# Build the matrix-factorization estimator (embedding_dim=8, n_iter=20) and fit
# it on the training split. The value returned by `fit` is the cell's result.
mf_est = MFEst(embedding_dim=8, n_iter=20)
mf_est.fit(train)

INFO:lda4rec.estimators:Epoch     0: loss 0.37674845956467296
INFO:lda4rec.estimators:Epoch     1: loss 0.1569650957511889
INFO:lda4rec.estimators:Epoch     2: loss 0.11594437254039017
INFO:lda4rec.estimators:Epoch     3: loss 0.09309354179815667
INFO:lda4rec.estimators:Epoch     4: loss 0.08185665727265783
INFO:lda4rec.estimators:Epoch     5: loss 0.07677619832794408
INFO:lda4rec.estimators:Epoch     6: loss 0.07107280397133248
INFO:lda4rec.estimators:Epoch     7: loss 0.066495165037545
INFO:lda4rec.estimators:Epoch     8: loss 0.06351528119396518
INFO:lda4rec.estimators:Epoch     9: loss 0.062339871185454164
INFO:lda4rec.estimators:Epoch    10: loss 0.059464904424306506
INFO:lda4rec.estimators:Epoch    11: loss 0.057876920221826514
INFO:lda4rec.estimators:Epoch    12: loss 0.058256027198119745
INFO:lda4rec.estimators:Epoch    13: loss 0.055741424632032174
INFO:lda4rec.estimators:Epoch    14: loss 0.05600147931458982
INFO:lda4rec.estimators:Epoch    15: loss 0.054386950539727065
INFO:

0.05216813558410551

### Overall summaries showing equivalence of MF and adjoint LDA formulation

In [45]:
# Predictions for user 42 with the estimator's `lda_trafo` flag switched off.
user_id = 42
mf_est.lda_trafo = False
mf_est.predict(user_id)

array([12.159285  , 14.964415  ,  6.598125  , ...,  0.17913675,
       -2.4951115 , -4.671628  ], dtype=float32)

In [46]:
# Evaluation summary on the train and test splits; `mf_est.lda_trafo` keeps
# whatever value the most recently executed cell assigned to it.
summary(mf_est, train=train, test=test)

Unnamed: 0_level_0,train,test
metric,Unnamed: 1_level_1,Unnamed: 2_level_1
prec,0.307213,0.103612
recall,0.056619,0.074053
mrr,0.517084,0.26205


In [47]:
# Repeat the prediction for the same `user_id` with `lda_trafo` switched on.
mf_est.lda_trafo = True
mf_est.predict(user_id) # the numbers differ as expected due to the transformation

array([1.4415156e-04, 1.4848042e-04, 1.2672733e-04, ..., 1.0437275e-04,
       9.6469092e-05, 9.0835732e-05], dtype=float32)

In [48]:
# Evaluation summary again on the same splits, now reflecting the `lda_trafo`
# value assigned in the cell executed just before this one.
summary(mf_est, train=train, test=test)

Unnamed: 0_level_0,train,test
metric,Unnamed: 1_level_1,Unnamed: 2_level_1
prec,0.311148,0.110837
recall,0.057248,0.08019
mrr,0.517007,0.262178


### Compare equivalence of the ranking from MF and adjoint LDA formulation for a single user

In [49]:
# For user 140, collect two results to compare: the estimator's predictions
# with `lda_trafo` switched off, and the output of `get_item_probs`.
# Note the differing input types: `predict` receives a NumPy integer array,
# `get_item_probs` a torch tensor.
user_id = 140
mf_est.lda_trafo = False
orig_scores = mf_est.predict(np.array([user_id], dtype=int))
item_probs = mf_est.get_item_probs(torch.tensor([user_id]))

In [50]:
# Compare the two results with the project's `cmp_ranks` helper (eps=1e-5).
# NOTE(review): the recorded result is False, which does not match the
# equivalence the section heading describes — confirm whether this is expected
# (e.g. `eps` too strict, or mismatched shapes/types of the two inputs).
cmp_ranks(orig_scores, item_probs, eps=1e-5)

False