In [29]:
import pandas as pd
import numpy as np

from funk_svd.dataset import fetch_ml_ratings
from funk_svd import SVD

from sklearn.metrics import mean_absolute_error


df = fetch_ml_ratings(variant='100k')
df = df.drop('timestamp', axis=1)

train = df.sample(frac=0.8, random_state=7)
val = df.drop(train.index.tolist()).sample(frac=0.5, random_state=8)
test = df.drop(train.index.tolist()).drop(val.index.tolist())

svd = SVD(learning_rate=0.001, regularization=0.005, n_epochs=100,
          n_factors=15, min_rating=1, max_rating=5)

svd.fit(X=train, X_val=val, early_stopping=True, shuffle=False)

pred = svd.predict(test)
mae = mean_absolute_error(test["rating"], pred)

print(f'Test MAE: {mae:.2f}')

Preprocessing data...

Epoch 1/100  | val_loss: 1.17 - val_rmse: 1.08 - val_mae: 0.91 - took 0.0 sec
Epoch 2/100  | val_loss: 1.11 - val_rmse: 1.05 - val_mae: 0.87 - took 0.0 sec
Epoch 3/100  | val_loss: 1.07 - val_rmse: 1.03 - val_mae: 0.85 - took 0.0 sec
Epoch 4/100  | val_loss: 1.04 - val_rmse: 1.02 - val_mae: 0.83 - took 0.0 sec
Epoch 5/100  | val_loss: 1.02 - val_rmse: 1.01 - val_mae: 0.82 - took 0.0 sec
Epoch 6/100  | val_loss: 1.00 - val_rmse: 1.00 - val_mae: 0.81 - took 0.0 sec
Epoch 7/100  | val_loss: 0.99 - val_rmse: 1.00 - val_mae: 0.80 - took 0.0 sec
Epoch 8/100  | val_loss: 0.98 - val_rmse: 0.99 - val_mae: 0.80 - took 0.0 sec
Epoch 9/100  | val_loss: 0.97 - val_rmse: 0.99 - val_mae: 0.79 - took 0.0 sec
Epoch 10/100 | val_loss: 0.96 - val_rmse: 0.98 - val_mae: 0.79 - took 0.0 sec
Epoch 11/100 | val_loss: 0.96 - val_rmse: 0.98 - val_mae: 0.78 - took 0.0 sec
Epoch 12/100 | val_loss: 0.95 - val_rmse: 0.98 - val_mae: 0.78 - took 0.0 sec
Epoch 13/100 | val_loss: 0.95 - val_rmse:

In [65]:
df

Unnamed: 0,u_id,i_id,rating
0,259,255,4.0
1,259,286,4.0
2,259,298,4.0
3,259,185,4.0
4,259,173,4.0
...,...,...,...
99994,729,328,3.0
99995,729,333,4.0
99996,729,313,3.0
99997,729,748,4.0


In [25]:
import sys

import tensorflow as tf

sys.path.append('/Users/tmo/Nodes/bandits/')

DATA_PATH = '/Users/tmo/Data/bandits/'

from src.sample_jester_data import sample_jester_data

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [26]:
dataset, opt_rewards, opt_actions, num_actions, context_dim = sample_jester_data(DATA_PATH+'jester/jester_data_40jokes_19181users.npy', 
                                                                                 num_contexts=2000,
                                                                                 pct_zero=None)

In [35]:
dataset

array([[ 8.79,  0.  ,  1.84, ...,  9.22, -0.73, -0.63],
       [ 5.87, -3.06, -1.02, ..., -1.75,  1.94,  0.97],
       [ 5.83, -6.5 , -9.27, ...,  6.99, -5.05,  2.18],
       ...,
       [ 4.76,  8.98, -1.31, ..., -9.13,  4.81, -9.71],
       [ 6.5 ,  8.88,  8.01, ...,  3.79, -4.17,  6.07],
       [ 7.82,  8.3 ,  8.25, ...,  7.72,  6.65,  3.16]])

In [37]:
_df = pd.DataFrame(dataset)

In [49]:
_df_unstacked = _df.unstack().reset_index()

In [66]:
_df_unstacked.columns = ['u_id', 'i_id', 'rating']

In [67]:
_df_unstacked

Unnamed: 0,u_id,i_id,rating
0,0,0,8.79
1,0,1,5.87
2,0,2,5.83
3,0,3,6.50
4,0,4,3.20
...,...,...,...
79995,39,1995,9.13
79996,39,1996,-6.99
79997,39,1997,-9.71
79998,39,1998,6.07


In [68]:
train = _df_unstacked.sample(frac=0.8, random_state=7)
val = _df_unstacked.drop(train.index.tolist()).sample(frac=0.5, random_state=8)
test = _df_unstacked.drop(train.index.tolist()).drop(val.index.tolist())

svd = SVD(learning_rate=0.001, regularization=0.1, n_epochs=1000,
          n_factors=100, min_rating=-10, max_rating=10)

svd.fit(X=train, X_val=val, early_stopping=True, shuffle=False)

pred = svd.predict(test)
mae = mean_absolute_error(test["rating"], pred)

print(f'Test MAE: {mae:.2f}')

Preprocessing data...

Epoch 1/1000  | val_loss: 23.12 - val_rmse: 4.81 - val_mae: 3.94 - took 0.0 sec
Epoch 2/1000  | val_loss: 22.63 - val_rmse: 4.76 - val_mae: 3.89 - took 0.0 sec
Epoch 3/1000  | val_loss: 22.22 - val_rmse: 4.71 - val_mae: 3.86 - took 0.0 sec
Epoch 4/1000  | val_loss: 21.74 - val_rmse: 4.66 - val_mae: 3.82 - took 0.0 sec
Epoch 5/1000  | val_loss: 21.11 - val_rmse: 4.59 - val_mae: 3.76 - took 0.0 sec
Epoch 6/1000  | val_loss: 20.24 - val_rmse: 4.50 - val_mae: 3.68 - took 0.0 sec
Epoch 7/1000  | val_loss: 19.20 - val_rmse: 4.38 - val_mae: 3.57 - took 0.0 sec
Epoch 8/1000  | val_loss: 18.28 - val_rmse: 4.28 - val_mae: 3.46 - took 0.0 sec
Epoch 9/1000  | val_loss: 17.69 - val_rmse: 4.21 - val_mae: 3.38 - took 0.0 sec
Epoch 10/1000 | val_loss: 17.39 - val_rmse: 4.17 - val_mae: 3.33 - took 0.0 sec
Epoch 11/1000 | val_loss: 17.24 - val_rmse: 4.15 - val_mae: 3.30 - took 0.0 sec
Epoch 12/1000 | val_loss: 17.13 - val_rmse: 4.14 - val_mae: 3.29 - took 0.0 sec
Epoch 13/1000 | v

In [69]:
test

Unnamed: 0,u_id,i_id,rating
16,0,16,6.50
18,0,18,9.37
20,0,20,3.98
22,0,22,8.93
25,0,25,-8.79
...,...,...,...
79959,39,1959,2.23
79964,39,1964,-4.08
79987,39,1987,1.46
79991,39,1991,8.83


In [70]:
pred

[5.484376722183473,
 5.911298447190121,
 2.790321089251032,
 5.651411343858635,
 -3.7212248610812613,
 2.4940173576640596,
 3.471709994231837,
 4.132546260208146,
 3.043700713151564,
 0.9493241117201419,
 -0.13400632566917547,
 5.621928205497216,
 1.6544916573794606,
 4.498445870316021,
 4.019980253281044,
 -3.5154545639319084,
 4.77792160997671,
 -0.6835558001102857,
 -3.8185506728608996,
 3.590353990984645,
 0.623579641955071,
 4.870570735383628,
 2.1251081365923588,
 2.2717118591725076,
 4.645559140002199,
 3.831110252928475,
 1.8086796626834039,
 3.457703901997771,
 0.7351716898644232,
 4.844616516030554,
 5.315206898345507,
 2.3458544929391585,
 4.9767772099662375,
 -0.1365460594403447,
 0.3366325746097636,
 2.267764632122058,
 4.734909360307653,
 -1.7242904817285996,
 4.090437908050443,
 2.442332952328657,
 4.225380519836837,
 4.355922361740967,
 3.9890320407297786,
 2.560453119363135,
 6.001544864915421,
 6.358177943197363,
 3.1502801478489095,
 7.410913693806693,
 4.23761330144