In [102]:
import sys
import logging

import numpy as np
import scipy as sp
import sklearn
import statsmodels.api as sm
from statsmodels.formula.api import ols

%load_ext autoreload
%autoreload 2

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import seaborn as sns
# Configure seaborn in ONE call. A bare `sns.set(rc=...)` re-applies the default
# plotting context ("notebook"), so calling it after `sns.set_context("poster")`
# silently discarded the poster context that was requested.
sns.set(context="poster", style="whitegrid", rc={'figure.figsize': (16, 9.)})

import pandas as pd
# Let pandas display up to 120 rows and 120 columns before truncating the output
pd.set_option("display.max_rows", 120)
pd.set_option("display.max_columns", 120)

# Send log records of level INFO and above to stdout
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [103]:
import torch

from lrann.datasets import DataLoader, random_train_test_split, Interactions
from lrann.estimators import ImplicitEst, ExplicitEst
from lrann.models import BilinearNet, DeepNet
from lrann.evaluations import mrr_score, precision_recall_score, rmse_score
from lrann.utils import is_cuda_available

In [104]:
def get_latent(n_users, n_items, seed=None):
    """
    Draw a random scalar latent variable for every user and every item

    Each latent value is sampled uniformly from the half-open interval [-1, 1).

    Args:
        n_users: number of user latent values to draw
        n_items: number of item latent values to draw
        seed: optional seed for a private `np.random.RandomState`, making the
            result reproducible. If None (default), NumPy's global random
            state is used, i.e. the behaviour is the same as without this argument.

    Returns:
        tuple `(users, items)` of float arrays holding `n_users` and
        `n_items` latent values respectively
    """
    # `np.random` (module) and `RandomState` expose the same `uniform` API
    rng = np.random if seed is None else np.random.RandomState(seed)
    users = rng.uniform(-1, 1, size=n_users)
    items = rng.uniform(-1, 1, size=n_items)
    return users, items

def get_interactions(users, items, size, seed=None):
    """
    Sample signed interactions from the product of user and item latent variables

    A user and an item are picked uniformly at random and their latent values
    are multiplied. The pair is accepted with probability `abs(product) ** 2`
    (two coin flips with head probability `abs(product)` that must both be heads).
    Accepted pairs are yielded with the sign of the product as rating.
    After the last sample the total number of draws is printed.

    Args:
        users: 1-d array of user latent values, expected within [-1, 1]
        items: 1-d array of item latent values, expected within [-1, 1]
        size: number of accepted interactions to generate
        seed: optional seed for a private `np.random.RandomState`, making the
            samples reproducible. If None (default), NumPy's global random
            state is used, i.e. the behaviour is the same as without this argument.

    Yields:
        tuples `(user_idx, item_idx, rating)` where rating is the sign of the product

    Raises:
        ValueError: if samples are requested but `users` or `items` has no
            non-zero entry (also the case for empty arrays); every product would
            be 0 and the sampling loop could never accept a pair.
    """
    if size > 0 and not (np.any(users) and np.any(items)):
        raise ValueError(
            "users and items must each contain at least one non-zero latent value"
        )
    # `np.random` (module) and `RandomState` expose the same `randint`/`binomial` API
    rng = np.random if seed is None else np.random.RandomState(seed)
    runs = 0
    while size > 0:
        user_idx = rng.randint(users.shape[0])
        item_idx = rng.randint(items.shape[0])
        prod = users[user_idx] * items[item_idx]
        # flip coin two times and only accept if two times head
        if rng.binomial(2, abs(prod)) == 2:
            size -= 1
            yield (user_idx, item_idx, int(np.sign(prod)))
        runs += 1
    print(f"Number of runs {runs}")

In [105]:
# generate random latent features for users and items
# Seed NumPy's global generator first so that the synthetic data set is the same on
# every "Restart & Run All"; the data generation draws from `np.random`.
np.random.seed(42)
users, items = get_latent(100, 2000)

In [106]:
# get the interactions using the latent features of users and items
# The generator is exhausted into an array with one row per accepted sample and
# three columns: user index, item index and rating (sign of the latent product).
raw = np.array(list(get_interactions(users, items, 20000)))

Number of runs 197320


In [107]:
# Split the sampled array into its three columns (user index, item index, rating)
# and wrap them in an Interactions object sized by the number of latent values.
user_ids, item_ids, ratings = (raw[:, column] for column in range(3))
interactions = Interactions(
    user_ids,
    item_ids,
    ratings,
    n_users=users.shape[0],
    n_items=items.shape[0],
)

In [108]:
# Randomly split the interactions into a training and a test set.
# NOTE(review): split ratio and seeding are left to the library defaults — confirm
# in lrann.datasets.random_train_test_split if reproducibility matters.
train, test = random_train_test_split(interactions)

In [111]:
# Build both models so they can be compared side by side: a deep neural network
# (DeepNet, sigmoid activation) and a bilinear model without bias terms (BilinearNet),
# each using 1-dimensional embeddings to mirror the scalar latent variables.
# NOTE(review): an earlier remark called BilinearNet the "MF model" — presumably matrix factorisation.
nn_model = DeepNet(interactions.n_users, interactions.n_items, embedding_dim=1, sparse=False, activation=torch.sigmoid)
lra_model = BilinearNet(interactions.n_users, interactions.n_items, embedding_dim=1, sparse=False, biases=False)

In [112]:
# Wrap each model in an explicit-feedback estimator. Both use identical
# hyper-parameters (50 iterations, batch size 128, learning rate 1e-2) so that
# differences in the results stem from the model and not from the training setup.
lra_est = ExplicitEst(model=lra_model, n_iter=50, batch_size=128, learning_rate=1e-2)
nn_est = ExplicitEst(model=nn_model, n_iter=50, batch_size=128, learning_rate=1e-2)

In [113]:
# Train the DeepNet-based estimator on the training split; verbose=True prints the loss of every epoch
nn_est.fit(train, verbose=True)

Epoch 0: loss 0.6947774755981324
Epoch 1: loss 0.6922463919538033
Epoch 2: loss 0.6855434861785604
Epoch 3: loss 0.6722886386613851
Epoch 4: loss 0.6599923098967442
Epoch 5: loss 0.624903425195586
Epoch 6: loss 0.560733764457881
Epoch 7: loss 0.5208981120447681
Epoch 8: loss 0.5006947540801504
Epoch 9: loss 0.4914328445417202
Epoch 10: loss 0.4860767373586982
Epoch 11: loss 0.48355621634625945
Epoch 12: loss 0.48181525095657496
Epoch 13: loss 0.4810477784462435
Epoch 14: loss 0.4801130995691393
Epoch 15: loss 0.4789100445048245
Epoch 16: loss 0.47805104399726267
Epoch 17: loss 0.4773739681688852
Epoch 18: loss 0.4769845797971727
Epoch 19: loss 0.47686072881990355
Epoch 20: loss 0.476827058758232
Epoch 21: loss 0.47657457849846896
Epoch 22: loss 0.4765679371958176
Epoch 23: loss 0.4766288414578053
Epoch 24: loss 0.47647129085266254
Epoch 25: loss 0.47662735692016095
Epoch 26: loss 0.47641312599900015
Epoch 27: loss 0.47643526609647546
Epoch 28: loss 0.4762992487455724
Epoch 29: loss 0.4

In [85]:
# Train the BilinearNet-based estimator on the training split; verbose=True prints the loss of every epoch
# NOTE(review): this cell's execution count (85) is lower than that of the cell creating
# the estimators (112), so the stored output may stem from an earlier estimator instance.
# Re-run all cells in order before relying on the numbers below.
lra_est.fit(train, verbose=True)

Epoch 0: loss 0.7608586030714338
Epoch 1: loss 0.7091694725261544
Epoch 2: loss 0.6784868082392363
Epoch 3: loss 0.6203073064435938
Epoch 4: loss 0.48330450981322176
Epoch 5: loss 0.29723157076525075
Epoch 6: loss 0.15999313538801427
Epoch 7: loss 0.08811275183832747
Epoch 8: loss 0.05248033852791808
Epoch 9: loss 0.033841191302825036
Epoch 10: loss 0.0233135456945501
Epoch 11: loss 0.01685581814844717
Epoch 12: loss 0.012649136715430826
Epoch 13: loss 0.0097808331673591
Epoch 14: loss 0.007733471781275093
Epoch 15: loss 0.006226185972047261
Epoch 16: loss 0.005088296016271556
Epoch 17: loss 0.0042114117318469615
Epoch 18: loss 0.0035241378490211968
Epoch 19: loss 0.0029784079931282105
Epoch 20: loss 0.002540380587982205
Epoch 21: loss 0.002184245476079342
Epoch 22: loss 0.0018904150384573546
Epoch 23: loss 0.00164746427344192
Epoch 24: loss 0.0014441051354210303
Epoch 25: loss 0.001272951538578634
Epoch 26: loss 0.0011266847544082945
Epoch 27: loss 0.0010018706035706314
Epoch 28: loss

In [86]:
# Smallest value among the predictions returned for the argument 3
# (presumably user id 3 scored against all items — confirm against ExplicitEst.predict)
np.min(nn_est.predict(3))

2.3900706e-05

In [87]:
# RMSE of the DeepNet-based estimator as a (train, test) pair
rmse_score(nn_est, train), rmse_score(nn_est, test)

(0.008392262378638134, 0.08265019373064733)

In [88]:
# RMSE of the BilinearNet-based estimator as a (train, test) pair
rmse_score(lra_est, train), rmse_score(lra_est, test)

(0.0003582965959585957, 0.07778868741917662)

In [101]:
# Compare, per user, the sign of the negated true latent value with the sign of the
# learned 1-d user embedding: 0 where both signs agree, +2 / -2 where they are opposite.
np.sign(-users) - np.sign(nn_model.user_embeddings.weight.detach().numpy().flatten())

array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  2., -2.,  0.,  0., -2.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  2.,  2.,  0.,  2.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  0.,  2.,  0.,  0.,
        0.,  0.,  0.,  0.,  2.,  0.,  0.,  0.,  2.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,
        0.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])

Parameter containing:
tensor([[-3.3687],
        [-4.3450],
        [ 3.6985],
        [-3.3264],
        [-4.1404],
        [-3.8634],
        [-2.0612],
        [-4.1487],
        [ 4.0532],
        [-4.2517],
        [-3.2820],
        [ 4.1630],
        [ 3.6715],
        [-4.2080],
        [-4.3776],
        [ 3.2657],
        [-3.2119],
        [ 4.1774],
        [ 0.6590],
        [-3.9009],
        [-4.0437],
        [-2.7811],
        [ 3.9576],
        [ 4.0030],
        [-3.3528],
        [ 4.0974],
        [ 2.0056],
        [-4.0730],
        [-3.7225],
        [ 3.9802],
        [-4.0895],
        [ 3.3490],
        [ 2.8927],
        [-4.2397],
        [ 2.3696],
        [ 4.0655],
        [ 4.0230],
        [-4.2375],
        [-3.2643],
        [ 4.3706],
        [ 4.0434],
        [ 3.0078],
        [ 2.9739],
        [ 3.1208],
        [ 1.3883],
        [ 3.9635],
        [-4.0353],
        [ 4.0761],
        [-3.8637],
        [ 3.6994],
        [ 3.3272],
        [