# Loss Sanity Check

In the previous notebook, we changed the forward pass so:

1. We initialized our Jax network with the Keras weights (no improvement).
2. We corrected the shapes of our network outputs (training loss now decreases).

However we're still not done. The validation loss and MAE get worse and worse as our model trains. (We don't even have an initial period in which they do better...).

In this notebook I want to:

1. Make sure test loss and test MAE are calculated correctly (compare to Keras if possible).
2. Make JAX training loss as close as possible to Keras training loss.




### Initialize Jax Network with Keras Weights

In [1]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import os
import random

import tensorflow as tf
import jax
import jax.numpy as jnp
import functools
import optax

from flax.core import frozen_dict
from flax import struct
from flax import linen as nn
from flax.training import train_state

from matplotlib import pyplot as plt
from tensorflow import keras

tf.__version__

'2.7.0'

In [2]:
ratings_df = pd.read_csv("data/movie-lens/rating.csv", usecols=['userId', 'movieId', 'rating', 'y'])
movies_df = pd.read_csv("data/movie-lens/movie.csv", usecols=['movieId', 'title', 'year'])
df = ratings_df.merge(movies_df, on='movieId').sort_values(by='userId')

In [39]:
# Shuffle
df = df.sample(frac=1, random_state=1) 
# First 50,000 for test
test_df = df.iloc[:50000]
# Rest for train
train_df = df.iloc[50000:]

n_movies = len(df.movieId.unique())
n_users = len(df.userId.unique())
print(
    "{1:,} distinct users rated {0:,} different movies (total ratings = {2:,})".format(
        n_movies, n_users, len(df),
    )
)

138,493 distinct users rated 26,744 different movies (total ratings = 20,000,263)


In [40]:
@struct.dataclass
class MovieLensConfig:
  """Global hyperparameters for our MovieLens Model"""
  users_size: int = n_users + 1   # 138,494 (+1 because we are 1-indexed)
  movies_size: int = n_movies + 0 # 26,744  (+0 because we are 0-indexed)
  emb_dim: int = 8
  dense_size_0: int = 32
  dense_size_1: int = 4
  out_size: int = 1
  num_epochs: int = 10
  batch_size: int = 5000
  lr: float = 0.01
  momentum: float = 0.1

In [41]:
class MovieLensModel(nn.Module):
  """A simple embedding model."""

  config: MovieLensConfig

  @nn.compact
  def __call__(self, user_id, movie_id):
    
    cfg = self.config
    
    user_id = user_id.astype('int32')
    user_emb = nn.Embed(num_embeddings=cfg.users_size, embedding_init=nn.initializers.xavier_uniform(), features=cfg.emb_dim, name='user')(user_id)
    
    movie_id = movie_id.astype('int32')
    movie_emb = nn.Embed(num_embeddings=cfg.movies_size, embedding_init=nn.initializers.xavier_uniform(), features=cfg.emb_dim, name='movie')(movie_id)
    
    x = jnp.concatenate((user_emb, movie_emb), axis=-1)
    x = jnp.squeeze(x)
        
    x = nn.Dense(cfg.dense_size_0, kernel_init=nn.initializers.kaiming_uniform())(x)
    x = nn.relu(x)
    x = nn.Dense(cfg.dense_size_1, kernel_init=nn.initializers.kaiming_uniform())(x)
    x = nn.relu(x)
    x = nn.Dense(cfg.out_size, kernel_init=nn.initializers.kaiming_uniform())(x)

    return x

In [42]:
def create_train_state(rng, config):
  """Creates initial `TrainState`."""
  model = MovieLensModel(config)
  user_id_fake = jnp.zeros((5,1), jnp.int32)
  movie_id_fake = jnp.zeros((5,1), jnp.int32)
  # Pass fake values through our model to initialize the parameters
  params = model.init(rng, user_id_fake, movie_id_fake)['params']
  
  # TODO(joshvarty): Consider other optimizers.
  #tx = optax.sgd(config.lr)
  tx = optax.adam(0.005)
  return train_state.TrainState.create(
      apply_fn=model.apply, params=params, tx=tx)

In [43]:
# Check if we can create trainstate
config = MovieLensConfig()
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng, config)
print("Success.")

Success.


In [44]:
def get_keras_model(): 
  tf.random.set_seed(1)
  np.random.seed(1)
  random.seed(1)

  hidden_units = (32,4)
  movie_embedding_size = 8
  user_embedding_size = 8

  # Each instance will consist of two inputs: a single user id, and a single movie id
  user_id_input = keras.Input(shape=(1,), name='user_id')
  movie_id_input = keras.Input(shape=(1,), name='movie_id')
  user_embedded = keras.layers.Embedding(df.userId.max()+1, user_embedding_size, 
                                         input_length=1, name='user_embedding')(user_id_input)
  movie_embedded = keras.layers.Embedding(df.movieId.max()+1, movie_embedding_size, 
                                          input_length=1, name='movie_embedding')(movie_id_input)
  # Concatenate the embeddings (and remove the useless extra dimension)
  concatenated = keras.layers.Concatenate()([user_embedded, movie_embedded])
  out = keras.layers.Flatten()(concatenated)

  # Add one or more hidden layers
  for n_hidden in hidden_units:
      out = keras.layers.Dense(n_hidden, activation='relu')(out)

  # A single output: our predicted rating
  out = keras.layers.Dense(1, activation='linear', name='prediction')(out)

  model = keras.Model(
      inputs = [user_id_input, movie_id_input],
      outputs = out,
  )
  return model

model = get_keras_model()
keras_weights = model.get_weights()
model.summary(line_length=88)

Model: "model_2"
________________________________________________________________________________________
 Layer (type)                Output Shape       Param #   Connected to                  
 user_id (InputLayer)        [(None, 1)]        0         []                            
                                                                                        
 movie_id (InputLayer)       [(None, 1)]        0         []                            
                                                                                        
 user_embedding (Embedding)  (None, 1, 8)       1107952   ['user_id[0][0]']             
                                                                                        
 movie_embedding (Embedding)  (None, 1, 8)      213952    ['movie_id[0][0]']            
                                                                                        
 concatenate_2 (Concatenate)  (None, 1, 16)     0         ['user_embedding[0][0]',      
    

In [45]:
# Unfreeze params so we can update it directly.
unfrozen_state_params = state.params.unfreeze()
# Update with Keras weights
unfrozen_state_params['user']['embedding'] = state.params['user']['embedding'].at[:].set(keras_weights[0])
unfrozen_state_params['movie']['embedding'] = state.params['movie']['embedding'].at[:].set(keras_weights[1])
unfrozen_state_params['Dense_0']['kernel'] = state.params['Dense_0']['kernel'].at[:].set(keras_weights[2])
unfrozen_state_params['Dense_0']['bias'] = state.params['Dense_0']['bias'].at[:].set(keras_weights[3])
unfrozen_state_params['Dense_1']['kernel'] = state.params['Dense_1']['kernel'].at[:].set(keras_weights[4])
unfrozen_state_params['Dense_1']['bias'] = state.params['Dense_1']['bias'].at[:].set(keras_weights[5])
unfrozen_state_params['Dense_2']['kernel'] = state.params['Dense_2']['kernel'].at[:].set(keras_weights[6])
unfrozen_state_params['Dense_2']['bias'] = state.params['Dense_2']['bias'].at[:].set(keras_weights[7])
# Freeze new params
new_params = frozen_dict.freeze(unfrozen_state_params)

In [46]:
# Update state
new_state = state.replace(step=0, params=new_params, opt_state=state.opt_state)

In [47]:
# Use the first few training examples
train_df.iloc[0:5]
user_id = train_df['userId'].iloc[0:5].values
movie_id = train_df['movieId'].iloc[0:5].values

In [48]:
# One Step of Jax Forward Pass
MovieLensModel(config).apply({'params': new_state.params}, user_id, movie_id)

DeviceArray([[ 0.00625672],
             [ 0.00206297],
             [-0.01768314],
             [-0.04741435],
             [-0.02822906]], dtype=float32)

In [49]:
# One step of Keras Forward Pass
model.predict([user_id, movie_id])

array([[ 0.00625672],
       [ 0.00206297],
       [-0.01768313],
       [-0.04741435],
       [-0.02822906]], dtype=float32)

### Train JAX with new initial weights

In [50]:
@functools.partial(jax.jit, static_argnums=(0))
def apply_model(cfg, state, user_id, movie_id, rating):
  
  def loss_fn(params):
    logits = MovieLensModel(cfg).apply({'params': params}, user_id, movie_id)
    #loss = jnp.mean(optax.l2_loss(predictions=logits, targets=rating))
    loss = jnp.mean(jnp.square(logits - rating))
    return loss, logits
  
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  mae = jnp.mean(jnp.abs(logits - rating))
  return grads, loss, mae

In [51]:
@jax.jit
def update_model(state, grads):
  return state.apply_gradients(grads=grads)

In [52]:
def train_epoch(cfg, state, train_df, rng):
  """Train for a single epoch."""
  train_df_size = len(train_df)
  steps_per_epoch = train_df_size // cfg.batch_size
  
  perms = jax.random.permutation(rng, train_df_size)
  perms = perms[:steps_per_epoch * cfg.batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, cfg.batch_size))
  
  epoch_loss = []
  epoch_mae = []  

  for perm in perms:
    # x
    batch_user_id = np.expand_dims(train_df.iloc[perm]['userId'].values, 1)
    batch_movie_id = np.expand_dims(train_df.iloc[perm]['movieId'].values, 1)
    # y
    batch_rating = np.expand_dims(train_df.iloc[perm]['y'].values, 1)
    

    grads, loss, mae = apply_model(cfg, state, batch_user_id, batch_movie_id, batch_rating)
    state = update_model(state, grads)
    
    user_emb_new = state.params['user']['embedding']
    movie_emb_new = state.params['movie']['embedding']
    
    epoch_loss.append(loss)
    epoch_mae.append(mae)
    
  train_loss = np.mean(epoch_loss)
  train_mae = np.mean(epoch_mae)
  return state, train_loss, train_mae

In [53]:
def train_and_evaluate(config, train_df, test_df):
  rng = jax.random.PRNGKey(0)
  rng, init_rng = jax.random.split(rng)
  # NOTE: This is where we're using the new Keras weights.
  state = new_state
  
  for epoch in range(1, config.num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_loss, train_mae = train_epoch(config, state, train_df, rng)
    
    user_id = np.expand_dims(test_df['userId'].values, 1)
    movie_id = np.expand_dims(test_df['movieId'].values, 1)
    rating = np.expand_dims(test_df['y'].values, 1)
    _, test_loss, test_mae = apply_model(config, state, user_id, movie_id, rating)
    
    print(
        'epoch:% 3d, train_loss: %.4f, train_mae: %.4f, test_loss: %.4f, test_mae: %.4f'
        % (epoch, train_loss, train_mae, test_loss, test_mae))
    
  return state

In [54]:
# Train Jax model with Keras weights.
config = MovieLensConfig()
state = train_and_evaluate(config, train_df, test_df)

epoch:  1, train_loss: 0.7204, train_mae: 0.6510, test_loss: 0.6883, test_mae: 0.6345
epoch:  2, train_loss: 0.6619, train_mae: 0.6228, test_loss: 0.6665, test_mae: 0.6245
epoch:  3, train_loss: 0.6351, train_mae: 0.6096, test_loss: 0.6537, test_mae: 0.6172
epoch:  4, train_loss: 0.6194, train_mae: 0.6017, test_loss: 0.6465, test_mae: 0.6147
epoch:  5, train_loss: 0.6091, train_mae: 0.5965, test_loss: 0.6447, test_mae: 0.6129
epoch:  6, train_loss: 0.6014, train_mae: 0.5926, test_loss: 0.6427, test_mae: 0.6134
epoch:  7, train_loss: 0.5954, train_mae: 0.5895, test_loss: 0.6424, test_mae: 0.6116
epoch:  8, train_loss: 0.5907, train_mae: 0.5872, test_loss: 0.6405, test_mae: 0.6111
epoch:  9, train_loss: 0.5868, train_mae: 0.5851, test_loss: 0.6406, test_mae: 0.6115
epoch: 10, train_loss: 0.5836, train_mae: 0.5835, test_loss: 0.6384, test_mae: 0.6096


After correcting the inputs to be 2-dimensional instead of 1-Dimensions (eg. `[5,1]` vs `[5,]`) the training loss and test loss both seem to decrease.

It works! We get a better test MAE than the Keras tutorial (likely because we use a larger train set and smaller test set).

Things I had to fix:

1. Make sure to shuffle dataset so test is filled exclusively with users we've never seen before.

2. Make sure that the train **and** test inputs are of the expected shape. Double check the loss values.

3. Adam seems to work much better than SGD.