# MovieLens Dataset


Based on the Keras code at: https://www.kaggle.com/colinmorris/embedding-layers

In [1]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import os
import random

import tensorflow as tf
import jax
import jax.numpy as jnp
import functools
import optax
from flax import struct
from flax import linen as nn
from flax.training import train_state

In [2]:
ratings_df = pd.read_csv("data/movie-lens/rating.csv", usecols=['userId', 'movieId', 'rating'])
movies_df = pd.read_csv("data/movie-lens/movie.csv", usecols=['movieId', 'title'])
df = ratings_df.merge(movies_df, on='movieId').sort_values(by='userId')

In [3]:
# Normalize the ratings
df['y'] = (df['rating'] - df['rating'].mean())/df['rating'].std()
df

Unnamed: 0,userId,movieId,rating,title,y
0,1,2,3.5,Jumanji (1995),-0.024267
505014,1,541,4.0,Blade Runner (1982),0.451023
2380423,1,6807,3.5,Monty Python's The Meaning of Life (1983),-0.024267
2378699,1,6774,4.0,Videodrome (1983),0.451023
2376750,1,6755,3.5,Bubba Ho-tep (2002),-0.024267
...,...,...,...,...,...
8765330,138493,4343,4.0,Evolution (2001),0.451023
2266440,138493,5679,5.0,"Ring, The (2002)",1.401604
17434898,138493,4673,4.0,Tango & Cash (1989),0.451023
14982572,138493,836,3.5,Chain Reaction (1996),-0.024267


In [4]:
# Shuffle
df = df.sample(frac=1, random_state=1) 
# First 50,000 for test
test_df = df.iloc[:50000]
# Rest for train
train_df = df.iloc[50000:]

In [5]:
n_movies = len(df.movieId.unique())
n_users = len(df.userId.unique())
print(
    "{1:,} distinct users rated {0:,} different movies (total ratings = {2:,})".format(
        n_movies, n_users, len(df),
    )
)

138,493 distinct users rated 26,744 different movies (total ratings = 20,000,263)


In [6]:
@struct.dataclass
class MovieLensConfig:
  """Global hyperparameters for our MovieLens Model"""
  users_size: int = n_users + 1   # 138,493 (+1 because we are 1-indexed)
  movies_size: int = n_movies + 1 # 26,744  (+1 because we are 1-indexed)
  emb_dim: int = 8
  dense_size_0: int = 32
  dense_size_1: int = 4
  out_size: int = 1
  num_epochs: int = 10
  batch_size: int = 5000
  lr: float = 0.005
  momentum: float = 0.1

In [7]:
class MovieLensModel(nn.Module):
  """A simple embedding model."""

  config: MovieLensConfig

  @nn.compact
  def __call__(self, user_id, movie_id):
    
    cfg = self.config
    
    user_id = user_id.astype('int32')
    user_emb = nn.Embed(num_embeddings=cfg.users_size, embedding_init=nn.initializers.xavier_uniform(), features=cfg.emb_dim, name='user')(user_id)
    
    movie_id = movie_id.astype('int32')
    movie_emb = nn.Embed(num_embeddings=cfg.movies_size, embedding_init=nn.initializers.xavier_uniform(), features=cfg.emb_dim, name='movie')(movie_id)
    
    x = jnp.concatenate((user_emb, movie_emb))
    
    x = nn.Dense(cfg.dense_size_0, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6))(x)
    x = nn.relu(x)
    x = nn.Dense(cfg.dense_size_1, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6))(x)
    x = nn.relu(x)
    x = nn.Dense(cfg.out_size, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6))(x)

    return x

In [8]:
@functools.partial(jax.jit, static_argnums=(0))
def apply_model(cfg, state, user_id, movie_id, rating):
  
  def loss_fn(params):
    logits = MovieLensModel(cfg).apply({'params': params}, user_id, movie_id)
    loss = jnp.mean(optax.l2_loss(predictions=logits, targets=rating))
    return loss, logits
  
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  mae = jnp.mean(jnp.abs(logits - rating))
  return grads, loss, mae

In [9]:
@jax.jit
def update_model(state, grads):
  return state.apply_gradients(grads=grads)

In [10]:
def train_epoch(cfg, state, train_df, rng):
  """Train for a single epoch."""
  train_df_size = len(train_df)
  steps_per_epoch = train_df_size // cfg.batch_size
  
  perms = jax.random.permutation(rng, train_df_size)
  perms = perms[:steps_per_epoch * cfg.batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, cfg.batch_size))
  
  epoch_loss = []
  epoch_mae = []
  
  for perm in perms:
    # x
    batch_user_id = train_df.iloc[perm]['userId'].values
    batch_movie_id = train_df.iloc[perm]['movieId'].values
    # y
    batch_rating = train_df.iloc[perm]['y'].values
    
    grads, loss, mae = apply_model(cfg, state, batch_user_id, batch_movie_id, batch_rating)
    state = update_model(state, grads)
    epoch_loss.append(loss)
    epoch_mae.append(mae)
    
  train_loss = np.mean(epoch_loss)
  train_mae = np.mean(epoch_mae)
  return state, train_loss, train_mae

In [11]:
def create_train_state(rng, config):
  """Creates initial `TrainState`."""
  model = MovieLensModel(config)
  user_id_fake = jnp.zeros((1), jnp.int32)
  movie_id_fake = jnp.zeros((1), jnp.int32)
  # Pass fake values through our model to initialize the parameters
  params = model.init(rng, user_id_fake, movie_id_fake)['params']
  
  # TODO(joshvarty): Consider other optimizers.
  #tx = optax.sgd(config.lr)
  tx = optax.adam(config.lr)
  return train_state.TrainState.create(
      apply_fn=model.apply, params=params, tx=tx)

In [12]:
# Check if we can create trainstate
config = MovieLensConfig()
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng, config)
print("Success.")

Success.


In [13]:
def train_and_evaluate(config, train_df, test_df):
  rng = jax.random.PRNGKey(0)
  rng, init_rng = jax.random.split(rng)
  state = create_train_state(init_rng, config)
  
  for epoch in range(1, config.num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_loss, train_mae = train_epoch(config, state, train_df, rng)
    
    user_id = test_df['userId'].values
    movie_id = test_df['movieId'].values
    rating = test_df['y'].values
    _, test_loss, test_mae = apply_model(config, state, user_id, movie_id, rating)
    
    print(
        'epoch:% 3d, train_loss: %.4f, train_mae: %.4f, test_loss: %.4f, test_mae: %.4f'
        % (epoch, train_loss, train_mae, test_loss, test_mae))
    
  return state

In [14]:
state.params['movie']

FrozenDict({
    embedding: DeviceArray([[-0.01368771, -0.00685614,  0.00945372, ...,  0.00268572,
                  -0.00731529, -0.01161271],
                 [-0.01298709,  0.00838352, -0.00408528, ...,  0.00954333,
                   0.00334894,  0.01404892],
                 [-0.01307638,  0.01251677,  0.01329136, ..., -0.004022  ,
                  -0.0090436 ,  0.0064981 ],
                 ...,
                 [-0.01188561,  0.00276067, -0.00438368, ...,  0.00160737,
                  -0.00121605,  0.01141576],
                 [ 0.01489048,  0.00284079, -0.00232035, ..., -0.00124785,
                   0.00152713, -0.01334447],
                 [-0.01218265,  0.00942387,  0.00938285, ..., -0.00825126,
                   0.01407073, -0.00186161]], dtype=float32),
})

In [15]:
state.params['user']

FrozenDict({
    embedding: DeviceArray([[-1.8625552e-04,  5.9650191e-03,  5.4601341e-04, ...,
                  -2.2080580e-03,  4.1289997e-04, -3.9349264e-03],
                 [-5.2755969e-03, -4.5083277e-03,  3.6032971e-03, ...,
                  -3.0476563e-03,  8.4897282e-04, -6.0522733e-03],
                 [ 2.1315611e-03,  5.0859689e-03,  4.6556517e-03, ...,
                  -3.5043349e-04, -1.9411424e-03,  4.7205477e-03],
                 ...,
                 [ 5.1226970e-03, -9.9316821e-06, -6.0135541e-03, ...,
                  -4.5624427e-03,  1.9483782e-03,  2.9091886e-03],
                 [-5.6149010e-03, -6.0111815e-03,  1.4060760e-03, ...,
                   3.7494712e-03,  4.8383907e-03, -8.6835132e-04],
                 [ 4.6219588e-03, -5.3815292e-03, -1.6332446e-03, ...,
                  -4.1636750e-03,  5.6607621e-03, -4.9878950e-03]],            dtype=float32),
})

In [16]:
state.params['Dense_0']

FrozenDict({
    kernel: DeviceArray([[ 0.34066415, -0.05986798, -0.16192755, -0.36246014,
                  -0.20655015, -0.06213647,  0.32437578,  0.1391137 ,
                  -0.30785975, -0.24873693,  0.30159897,  0.3001693 ,
                   0.0446817 ,  0.330935  , -0.3423905 , -0.04727726,
                   0.1855937 ,  0.15858072,  0.30912018, -0.09327415,
                  -0.11039475, -0.18706132, -0.21662776,  0.18575686,
                   0.03425615,  0.2844513 , -0.26061755,  0.12673508,
                  -0.06168669, -0.38502145,  0.08017751, -0.325711  ],
                 [ 0.21493879, -0.01516975, -0.30024102,  0.37731612,
                   0.07651681, -0.36540842, -0.32441273, -0.38400608,
                   0.19698861, -0.22489294,  0.27606323, -0.04067021,
                   0.2688869 , -0.1419014 , -0.00400826, -0.03927173,
                  -0.2551907 , -0.19325572,  0.14063184,  0.1954472 ,
                   0.28022262,  0.00699654, -0.27243087,  0.24493839

In [17]:
# Sample output
user_id = test_df['userId'].values
movie_id = test_df['movieId'].values
rating = test_df['y'].values
logits = MovieLensModel(config).apply({'params': state.params}, user_id, movie_id)
logits

DeviceArray([[-0.00035244],
             [-0.00162394],
             [-0.00116191],
             ...,
             [ 0.00347138],
             [ 0.00028476],
             [ 0.00134978]], dtype=float32)

In [None]:
# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
# it unavailable to JAX.
tf.config.experimental.set_visible_devices([], 'GPU')
config = MovieLensConfig()

state = train_and_evaluate(config, train_df, test_df)

epoch:  1, train_loss: 0.5000, train_mae: 0.7993, test_loss: 0.5018, test_mae: 0.8024
epoch:  2, train_loss: 0.5000, train_mae: 0.7993, test_loss: 0.5016, test_mae: 0.8023
epoch:  3, train_loss: 0.5000, train_mae: 0.7993, test_loss: 0.5017, test_mae: 0.8023
epoch:  4, train_loss: 0.5000, train_mae: 0.7993, test_loss: 0.5016, test_mae: 0.8023
epoch:  5, train_loss: 0.5000, train_mae: 0.7993, test_loss: 0.5017, test_mae: 0.8023
epoch:  6, train_loss: 0.5000, train_mae: 0.7993, test_loss: 0.5017, test_mae: 0.8023
epoch:  7, train_loss: 0.5000, train_mae: 0.7993, test_loss: 0.5017, test_mae: 0.8023


In [None]:
state.params['movie']

In [None]:
state.params['user']

In [None]:
state.params['Dense_0']

In [22]:
# Sample output
user_id = test_df['userId'].values
movie_id = test_df['movieId'].values
rating = test_df['y'].values
logits = MovieLensModel(config).apply({'params': state.params}, user_id, movie_id)
logits

DeviceArray([[-0.00026448],
             [-0.00026448],
             [-0.00026448],
             ...,
             [-0.00026448],
             [-0.00026448],
             [-0.00026448]], dtype=float32)

## Thoughts

It's not working:

 - Are the ratings balanced? Can we rebalance?
 - Are the User/Movie embedding updating?
 - Do we want different LRs for the embeddings vs weights?


###  What is the best train/test loss we've seen? 

In [18]:
# train loss: 0.5051
# train mae:  0.7994
# test loss:  0.5174
# test mae:   0.8102
@struct.dataclass
class MovieLensConfig:
  """Global hyperparameters for our MovieLens Model"""
  users_size: int = n_users   # 138,493
  movies_size: int = n_movies # 26,744
  emb_dim: int = 8
  dense_size_0: int = 32
  dense_size_1: int = 4
  out_size: int = 1
  num_epochs: int = 40
  batch_size: int = 8192
  lr: float = 0.000015
  momentum: float = 0.1
    
    
# train loss: 0.5010
# test loss:  0.5134
@struct.dataclass
class MovieLensConfig:
  """Global hyperparameters for our MovieLens Model"""
  users_size: int = n_users   # 138,493
  movies_size: int = n_movies # 26,744
  emb_dim: int = 8
  dense_size_0: int = 32
  dense_size_1: int = 4
  out_size: int = 1
  num_epochs: int = 40
  batch_size: int = 8192
  lr: float = 0.00015
  momentum: float = 0.1
    
    
# train loss: 0.5000
# test_loss:  0.5125
@struct.dataclass
class MovieLensConfig:
  """Global hyperparameters for our MovieLens Model"""
  users_size: int = n_users   # 138,493
  movies_size: int = n_movies # 26,744
  emb_dim: int = 8
  dense_size_0: int = 32
  dense_size_1: int = 4
  out_size: int = 1
  num_epochs: int = 40
  batch_size: int = 8192
  lr: float = 0.0015
  momentum: float = 0.1


## TODO: Is this working?

- Add a metric. Maybe Mean Absolute Error?
- Compare MAE from predicting mean to MAE from our model.
- Look at a user's ratings compared to the predicted ratings? Are they reasonable?
- Find movies that are similar to one another.