# MovieLens Dataset


Based on the Keras code at: https://www.kaggle.com/colinmorris/embedding-layers

In [5]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import os
import random

import tensorflow as tf
import jax
import jax.numpy as jnp
import functools
import optax
from flax import struct
from flax import linen as nn
from flax.training import train_state

In [6]:
ratings_df = pd.read_csv("data/movie-lens/rating.csv", usecols=['userId', 'movieId', 'rating', 'y'])
movies_df = pd.read_csv("data/movie-lens/movie.csv", usecols=['movieId', 'title', 'year'])
df = ratings_df.merge(movies_df, on='movieId').sort_values(by='userId')

In [7]:
# Shuffle
df = df.sample(frac=1, random_state=1) 
# First 50,000 for test
test_df = df.iloc[:50000]
# Rest for train
train_df = df.iloc[50000:]

In [8]:
train_df[(train_df['userId'] == 85731) & (train_df['movieId'] == 1883)]

Unnamed: 0,userId,movieId,rating,y,title,year
12904240,85731,1883,4.5,0.974498,Labyrinth,1986


In [9]:
n_movies = len(df.movieId.unique())
n_users = len(df.userId.unique())
print(
    "{1:,} distinct users rated {0:,} different movies (total ratings = {2:,})".format(
        n_movies, n_users, len(df),
    )
)

138,493 distinct users rated 26,744 different movies (total ratings = 20,000,263)


In [10]:
@struct.dataclass
class MovieLensConfig:
  """Global hyperparameters for our MovieLens Model"""
  users_size: int = n_users + 1   # 138,493 (+1 because we are 1-indexed)
  movies_size: int = n_movies + 1 # 26,744  (+1 because we are 1-indexed)
  emb_dim: int = 8
  dense_size_0: int = 32
  dense_size_1: int = 4
  out_size: int = 1
  num_epochs: int = 10
  batch_size: int = 5000
  lr: float = 0.01
  momentum: float = 0.1

In [24]:
class MovieLensModel(nn.Module):
  """A simple embedding model."""

  config: MovieLensConfig

  @nn.compact
  def __call__(self, user_id, movie_id):
    
    cfg = self.config
    
    user_id = user_id.astype('int32')
    user_emb = nn.Embed(num_embeddings=cfg.users_size, embedding_init=nn.initializers.xavier_uniform(), features=cfg.emb_dim, name='user')(user_id)
    
    movie_id = movie_id.astype('int32')
    movie_emb = nn.Embed(num_embeddings=cfg.movies_size, embedding_init=nn.initializers.xavier_uniform(), features=cfg.emb_dim, name='movie')(movie_id)
    
    x = jnp.concatenate((user_emb, movie_emb), axis=-1)
    
    # TODO(joshvarty): Consider initializing with bias of zero.
    x = nn.Dense(cfg.dense_size_0, kernel_init=nn.initializers.kaiming_uniform())(x)
    x = nn.relu(x)
    x = nn.Dense(cfg.dense_size_1, kernel_init=nn.initializers.kaiming_uniform())(x)
    x = nn.relu(x)
    x = nn.Dense(cfg.out_size, kernel_init=nn.initializers.kaiming_uniform())(x)

    return x

In [25]:
@functools.partial(jax.jit, static_argnums=(0))
def apply_model(cfg, state, user_id, movie_id, rating):
  
  def loss_fn(params):
    logits = MovieLensModel(cfg).apply({'params': params}, user_id, movie_id)
    #loss = jnp.mean(optax.l2_loss(predictions=logits, targets=rating))
    loss = jnp.mean(jnp.square(logits - rating))
    return loss, logits
  
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  mae = jnp.mean(jnp.abs(logits - rating))
  return grads, loss, mae

In [26]:
@jax.jit
def update_model(state, grads):
  return state.apply_gradients(grads=grads)

In [29]:
def train_epoch(cfg, state, train_df, rng):
  """Train for a single epoch."""
  train_df_size = len(train_df)
  steps_per_epoch = train_df_size // cfg.batch_size
  
  perms = jax.random.permutation(rng, train_df_size)
  perms = perms[:steps_per_epoch * cfg.batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, cfg.batch_size))
  
  epoch_loss = []
  epoch_mae = []
  print(len(perms))
  

  for perm in perms:
    # x
    batch_user_id = train_df.iloc[perm]['userId'].values
    batch_movie_id = train_df.iloc[perm]['movieId'].values
    # y
    batch_rating = train_df.iloc[perm]['y'].values
    

    grads, loss, mae = apply_model(cfg, state, batch_user_id, batch_movie_id, batch_rating)
    state = update_model(state, grads)
    
    user_emb_new = state.params['user']['embedding']
    movie_emb_new = state.params['movie']['embedding']
    # Percentage identical 
    #print("User: ", int((user_emb_old == user_emb_new).flatten().sum()) / len((user_emb_old == user_emb_new).flatten()))
    # Percentage identical 
    #print("Movie: ", int((movie_emb_old == movie_emb_new).flatten().sum()) / len((movie_emb_old == movie_emb_new).flatten()))
    
    epoch_loss.append(loss)
    epoch_mae.append(mae)
    
  print("done epoch: ", len(epoch_mae))
  train_loss = np.mean(epoch_loss)
  train_mae = np.mean(epoch_mae)
  return state, train_loss, train_mae

In [30]:
def create_train_state(rng, config):
  """Creates initial `TrainState`."""
  model = MovieLensModel(config)
  user_id_fake = jnp.zeros((1), jnp.int32)
  movie_id_fake = jnp.zeros((1), jnp.int32)
  # Pass fake values through our model to initialize the parameters
  params = model.init(rng, user_id_fake, movie_id_fake)['params']
  
  # TODO(joshvarty): Consider other optimizers.
  #tx = optax.sgd(config.lr)
  tx = optax.sgd(config.lr)
  return train_state.TrainState.create(
      apply_fn=model.apply, params=params, tx=tx)

In [31]:
# # Check if we can create trainstate
# config = MovieLensConfig()
# rng = jax.random.PRNGKey(0)
# rng, init_rng = jax.random.split(rng)
# state = create_train_state(init_rng, config)
# print("Success.")

In [32]:
def train_and_evaluate(config, train_df, test_df):
  rng = jax.random.PRNGKey(0)
  rng, init_rng = jax.random.split(rng)
  state = create_train_state(init_rng, config)
  
  for epoch in range(1, config.num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_loss, train_mae = train_epoch(config, state, train_df, rng)
    
    user_id = test_df['userId'].values
    movie_id = test_df['movieId'].values
    rating = test_df['y'].values
    _, test_loss, test_mae = apply_model(config, state, user_id, movie_id, rating)
    
    print(
        'epoch:% 3d, train_loss: %.4f, train_mae: %.4f, test_loss: %.4f, test_mae: %.4f'
        % (epoch, train_loss, train_mae, test_loss, test_mae))
    
  return state

In [33]:
# # Sample output
# user_id = test_df['userId'].values
# movie_id = test_df['movieId'].values
# rating = test_df['y'].values
# logits = MovieLensModel(config).apply({'params': state.params}, user_id, movie_id)
# logits

In [34]:
# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
# it unavailable to JAX.
tf.config.experimental.set_visible_devices([], 'GPU')
config = MovieLensConfig()

#user_emb_old = state.params['user']['embedding']
#movie_emb_old = state.params['movie']['embedding']

state = train_and_evaluate(config, train_df, test_df)

3990
done epoch:  3990
epoch:  1, train_loss: 1.1067, train_mae: 0.8408, test_loss: 1.1131, test_mae: 0.8441
3990
done epoch:  3990
epoch:  2, train_loss: 1.1067, train_mae: 0.8408, test_loss: 1.1131, test_mae: 0.8441
3990
done epoch:  3990
epoch:  3, train_loss: 1.1067, train_mae: 0.8408, test_loss: 1.1131, test_mae: 0.8441
3990
done epoch:  3990
epoch:  4, train_loss: 1.1067, train_mae: 0.8408, test_loss: 1.1131, test_mae: 0.8441
3990
done epoch:  3990
epoch:  5, train_loss: 1.1067, train_mae: 0.8408, test_loss: 1.1131, test_mae: 0.8441
3990
done epoch:  3990
epoch:  6, train_loss: 1.1067, train_mae: 0.8408, test_loss: 1.1131, test_mae: 0.8441
3990
done epoch:  3990
epoch:  7, train_loss: 1.1067, train_mae: 0.8408, test_loss: 1.1131, test_mae: 0.8441
3990
done epoch:  3990
epoch:  8, train_loss: 1.1067, train_mae: 0.8408, test_loss: 1.1131, test_mae: 0.8441
3990
done epoch:  3990
epoch:  9, train_loss: 1.1067, train_mae: 0.8408, test_loss: 1.1131, test_mae: 0.8441
3990
done epoch:  3

In [38]:
state.params

FrozenDict({
    Dense_0: {
        bias: DeviceArray([-0.00013747,  0.00110754, -0.00449737, -0.00134757,
                     -0.00165991, -0.00498055, -0.0063904 , -0.00415923,
                     -0.00318376,  0.00032267,  0.00685282, -0.00381637,
                     -0.00149815, -0.0004154 , -0.00461463,  0.00388077,
                      0.00117954, -0.0023359 ,  0.00042156, -0.00568703,
                      0.00087653, -0.0043879 ,  0.0016294 , -0.00485105,
                      0.00231849, -0.00219046, -0.0037174 , -0.00364537,
                     -0.00550478,  0.00726458, -0.00369152, -0.00152601],            dtype=float32),
        kernel: DeviceArray([[-0.23310493,  0.41453663,  0.32547474, -0.00497667,
                      -0.44627872,  0.5162732 ,  0.5208649 , -0.01651442,
                      -0.01088391,  0.4541001 ,  0.05976991,  0.16411215,
                      -0.17918831,  0.27834204,  0.01817691, -0.40108287,
                       0.4852821 , -0.5079375 , -0

In [None]:
user_emb_new = state.params['user']['embedding']
movie_emb_new = state.params['movie']['embedding']

In [None]:
n_users, len(user_emb_new), len(user_emb_new)*8

In [None]:
# Total number of embedding values
len((user_emb_old == user_emb_new).flatten())

In [None]:
# Number of identical embeddings
int((user_emb_old == user_emb_new).flatten().sum())

In [None]:
# Percentage identical 
int((user_emb_old == user_emb_new).flatten().sum()) / len((user_emb_old == user_emb_new).flatten())

In [None]:
n_movies, len(movie_emb_new), len(movie_emb_new)*8

In [None]:
# Percentage identical 
int((movie_emb_old == movie_emb_new).flatten().sum()) / len((movie_emb_old == movie_emb_new).flatten())

In [45]:
state.params

FrozenDict({
    Dense_0: {
        bias: DeviceArray([-0.00013747,  0.00110754, -0.00449737, -0.00134757,
                     -0.00165991, -0.00498055, -0.0063904 , -0.00415923,
                     -0.00318376,  0.00032267,  0.00685282, -0.00381637,
                     -0.00149815, -0.0004154 , -0.00461463,  0.00388077,
                      0.00117954, -0.0023359 ,  0.00042156, -0.00568703,
                      0.00087653, -0.0043879 ,  0.0016294 , -0.00485105,
                      0.00231849, -0.00219046, -0.0037174 , -0.00364537,
                     -0.00550478,  0.00726458, -0.00369152, -0.00152601],            dtype=float32),
        kernel: DeviceArray([[-0.23310493,  0.41453663,  0.32547474, -0.00497667,
                      -0.44627872,  0.5162732 ,  0.5208649 , -0.01651442,
                      -0.01088391,  0.4541001 ,  0.05976991,  0.16411215,
                      -0.17918831,  0.27834204,  0.01817691, -0.40108287,
                       0.4852821 , -0.5079375 , -0

In [37]:
# Sample output
user_id = test_df['userId'].values
movie_id = test_df['movieId'].values
rating = test_df['y'].values
logits = MovieLensModel(config).apply({'params': state.params}, user_id, movie_id)
logits

DeviceArray([[0.00030525],
             [0.00030525],
             [0.00030525],
             ...,
             [0.00030525],
             [0.00030525],
             [0.00030525]], dtype=float32)

In [43]:
test_df

Unnamed: 0,userId,movieId,rating,y,title,year
17081086,123425,167,2.0,-1.525502,Free Willy 2: The Adventure Home,1995
11578986,108287,1195,4.0,0.474498,The Godfather: Part II,1974
15951635,109381,3255,5.0,1.474498,Birdy,1984
12481916,64404,2957,5.0,1.474498,Dead Again,1991
15295156,137309,3018,2.5,-1.025502,Awakenings,1990
...,...,...,...,...,...,...
2278349,103718,2873,4.5,0.974498,Fight Club,1999
10185612,121483,7041,3.0,-0.525502,The Lord of the Rings: The Return of the King,2003
11804158,79531,286,3.0,-0.525502,Only You,1994
7720568,67054,1125,3.5,-0.025502,Wallace & Gromit: The Wrong Trousers,1993


## Thoughts

It's not working:

 - Are the ratings balanced? Can we rebalance?
 - Are the User/Movie embedding updating?
 - Do we want different LRs for the embeddings vs weights?


###  What is the best train/test loss we've seen? 

In [None]:
# train loss: 0.5051
# train mae:  0.7994
# test loss:  0.5174
# test mae:   0.8102
@struct.dataclass
class MovieLensConfig:
  """Global hyperparameters for our MovieLens Model"""
  users_size: int = n_users   # 138,493
  movies_size: int = n_movies # 26,744
  emb_dim: int = 8
  dense_size_0: int = 32
  dense_size_1: int = 4
  out_size: int = 1
  num_epochs: int = 40
  batch_size: int = 8192
  lr: float = 0.000015
  momentum: float = 0.1
    
    
# train loss: 0.5010
# test loss:  0.5134
@struct.dataclass
class MovieLensConfig:
  """Global hyperparameters for our MovieLens Model"""
  users_size: int = n_users   # 138,493
  movies_size: int = n_movies # 26,744
  emb_dim: int = 8
  dense_size_0: int = 32
  dense_size_1: int = 4
  out_size: int = 1
  num_epochs: int = 40
  batch_size: int = 8192
  lr: float = 0.00015
  momentum: float = 0.1
    
    
# train loss: 0.5000
# test_loss:  0.5125
@struct.dataclass
class MovieLensConfig:
  """Global hyperparameters for our MovieLens Model"""
  users_size: int = n_users   # 138,493
  movies_size: int = n_movies # 26,744
  emb_dim: int = 8
  dense_size_0: int = 32
  dense_size_1: int = 4
  out_size: int = 1
  num_epochs: int = 40
  batch_size: int = 8192
  lr: float = 0.0015
  momentum: float = 0.1


## TODO: Is this working?

- Add a metric. Maybe Mean Absolute Error?
- Compare MAE from predicting mean to MAE from our model.
- Look at a user's ratings compared to the predicted ratings? Are they reasonable?
- Find movies that are similar to one another.