# [Forward Pass Sanity Check](https://github.com/JoshVarty/LearningJax/issues/1)

My Jax network is completely broken and I don't know why. When I run (what appears to be) the same network in Keras, everything looks correct and I see that the network is learning.

Let's make the weights identical and then run the forward pass. Do we get the same result?


### Print the default JAX weights

In [84]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import os
import random

import tensorflow as tf
import jax
import jax.numpy as jnp
import functools
import optax

from flax.core import frozen_dict
from flax import struct
from flax import linen as nn
from flax.training import train_state

In [2]:
# Load the ratings and the movie metadata, keeping only the listed columns.
rating_columns = ['userId', 'movieId', 'rating', 'y']
movie_columns = ['movieId', 'title', 'year']
ratings_df = pd.read_csv("data/movie-lens/rating.csv", usecols=rating_columns)
movies_df = pd.read_csv("data/movie-lens/movie.csv", usecols=movie_columns)
# Attach each rating to its movie row, then order the result by user id.
df = (
    ratings_df
    .merge(movies_df, on='movieId')
    .sort_values(by='userId')
)

In [3]:
# Hold out the first 50,000 rows for testing and train on everything after them.
# NOTE(review): the shuffle below is disabled and `df` was sorted by userId when it
# was built, so the held-out rows are simply the leading rows in that order --
# confirm this is intended.
# df = df.sample(frac=1, random_state=1)
n_test_rows = 50000
test_df = df.iloc[:n_test_rows]
train_df = df.iloc[n_test_rows:]

In [4]:
# Count the distinct user and movie ids and report the size of the dataset.
n_users = len(df["userId"].unique())
n_movies = len(df["movieId"].unique())
summary_template = "{1:,} distinct users rated {0:,} different movies (total ratings = {2:,})"
print(summary_template.format(n_movies, n_users, len(df)))

138,493 distinct users rated 26,744 different movies (total ratings = 20,000,263)


In [5]:
@struct.dataclass
class MovieLensConfig:
  """Global hyperparameters for our MovieLens Model.

  The two vocabulary sizes take their defaults from the notebook-level
  `n_users` / `n_movies` counts; those are read once, when this class is
  defined.
  """
  # Row counts of the user / movie embedding tables.
  # NOTE(review): both are derived from distinct-id counts, which only bound the
  # largest id if the ids are contiguous -- confirm against the data.
  users_size: int = n_users + 1   # 138,494 (+1 because we are 1-indexed)
  movies_size: int = n_movies     # 26,744  (+0 because we are 0-indexed)
  # Width of each embedding vector (used for both tables).
  emb_dim: int = 8
  # Widths of the two hidden dense layers and of the output layer.
  dense_size_0: int = 32
  dense_size_1: int = 4
  out_size: int = 1
  # Not read by any cell visible here -- presumably consumed by the training loop.
  num_epochs: int = 10
  batch_size: int = 5000
  # Learning rate handed to optax.sgd in create_train_state.
  lr: float = 0.01
  # NOTE(review): not read by any cell visible here; the optimizer is built as
  # optax.sgd(config.lr) without a momentum argument.
  momentum: float = 0.1

In [6]:
class MovieLensModel(nn.Module):
  """Two embedding tables feeding a small MLP.

  The user id and the movie id are looked up in separate embedding tables, the
  two vectors are concatenated along the last axis, and the result is passed
  through two ReLU-activated dense layers followed by a linear output layer of
  `config.out_size` units. All sizes come from `config`.
  """

  config: MovieLensConfig

  @nn.compact
  def __call__(self, user_id, movie_id):
    hp = self.config
    embed_init = nn.initializers.xavier_uniform()
    kernel_init = nn.initializers.kaiming_uniform()

    # Ids are cast to int32 before being used as row indices into the tables.
    user_table = nn.Embed(
        num_embeddings=hp.users_size,
        features=hp.emb_dim,
        embedding_init=embed_init,
        name='user',
    )
    user_vec = user_table(user_id.astype('int32'))

    movie_table = nn.Embed(
        num_embeddings=hp.movies_size,
        features=hp.emb_dim,
        embedding_init=embed_init,
        name='movie',
    )
    movie_vec = movie_table(movie_id.astype('int32'))

    hidden = jnp.concatenate([user_vec, movie_vec], axis=-1)

    # The dense layers are left unnamed, so they show up in the parameter tree
    # as Dense_0, Dense_1 and Dense_2 in creation order.
    # TODO(joshvarty): Consider initializing with bias of zero.
    for width in (hp.dense_size_0, hp.dense_size_1):
      hidden = nn.relu(nn.Dense(width, kernel_init=kernel_init)(hidden))

    return nn.Dense(hp.out_size, kernel_init=kernel_init)(hidden)

In [7]:
def create_train_state(rng, config):
  """Build a fresh `TrainState` for a `MovieLensModel`.

  Args:
    rng: PRNG key handed to `model.init` to initialise the parameters.
    config: `MovieLensConfig`; configures the model and supplies `lr`.

  Returns:
    A `train_state.TrainState` holding the model's `apply` function, the
    initialised parameters and a plain SGD optimizer.
  """
  model = MovieLensModel(config)
  # Parameters are created by running the model once on a dummy input: a
  # length-1 int32 array of zeros, used for both the user and the movie argument.
  dummy_ids = jnp.zeros((1,), jnp.int32)
  variables = model.init(rng, dummy_ids, dummy_ids)

  # TODO(joshvarty): Consider other optimizers.
  return train_state.TrainState.create(
      apply_fn=model.apply,
      params=variables['params'],
      tx=optax.sgd(config.lr),
  )

In [8]:
# Smoke test: creating the TrainState runs model.init end to end.
config = MovieLensConfig()
# Split a fixed-seed root key; `rng` is kept, `init_rng` is spent on initialisation.
rng, init_rng = jax.random.split(jax.random.PRNGKey(0))
state = create_train_state(rng=init_rng, config=config)
print("Success.")

Success.


In [9]:
# Bare expression so the notebook displays the freshly initialised parameter tree.
state.params

FrozenDict({
    user: {
        embedding: DeviceArray([[-1.8625552e-04,  5.9650191e-03,  5.4601341e-04, ...,
                      -2.2080580e-03,  4.1289997e-04, -3.9349264e-03],
                     [-5.2755969e-03, -4.5083277e-03,  3.6032971e-03, ...,
                      -3.0476563e-03,  8.4897282e-04, -6.0522733e-03],
                     [ 2.1315611e-03,  5.0859689e-03,  4.6556517e-03, ...,
                      -3.5043349e-04, -1.9411424e-03,  4.7205477e-03],
                     ...,
                     [ 5.1226970e-03, -9.9316821e-06, -6.0135541e-03, ...,
                      -4.5624427e-03,  1.9483782e-03,  2.9091886e-03],
                     [-5.6149010e-03, -6.0111815e-03,  1.4060760e-03, ...,
                       3.7494712e-03,  4.8383907e-03, -8.6835132e-04],
                     [ 4.6219588e-03, -5.3815292e-03, -1.6332446e-03, ...,
                      -4.1636750e-03,  5.6607621e-03, -4.9878950e-03]],            dtype=float32),
    },
    movie: {
        embedd

In [10]:
# Top-level entries of the parameter tree: the two named embeddings plus the
# auto-named dense layers.
state.params.keys()

frozen_dict_keys(['user', 'movie', 'Dense_0', 'Dense_1', 'Dense_2'])

In [11]:
# JAX user embedding table: shape, mean and standard deviation
jax_user_embedding = state.params['user']['embedding']
print("Shape:\t", jax_user_embedding.shape)
print("Mean:\t", jax_user_embedding.mean())
print("Std:\t", jax_user_embedding.std())

Shape:	 (138494, 8)
Mean:	 -4.4421763e-06
Std:	 0.0038013132


In [12]:
# JAX movie embedding table: shape, mean and standard deviation
jax_movie_embedding = state.params['movie']['embedding']
print("Shape:\t", jax_movie_embedding.shape)
print("Mean:\t", jax_movie_embedding.mean())
print("Std:\t", jax_movie_embedding.std())

Shape:	 (26744, 8)
Mean:	 2.797547e-05
Std:	 0.008661809


In [13]:
# JAX Dense_0 kernel: shape, mean and standard deviation
jax_dense_0_kernel = state.params['Dense_0']['kernel']
print("Shape:\t", jax_dense_0_kernel.shape)
print("Mean:\t", jax_dense_0_kernel.mean())
print("Std:\t", jax_dense_0_kernel.std())

Shape:	 (16, 32)
Mean:	 0.0025646356
Std:	 0.35773057


In [14]:
# JAX Dense_0 bias: shape, mean and standard deviation
jax_dense_0_bias = state.params['Dense_0']['bias']
print("Shape:\t", jax_dense_0_bias.shape)
print("Mean:\t", jax_dense_0_bias.mean())
print("Std:\t", jax_dense_0_bias.std())

Shape:	 (32,)
Mean:	 0.0
Std:	 0.0


In [15]:
# JAX Dense_1 kernel: shape, mean and standard deviation
jax_dense_1_kernel = state.params['Dense_1']['kernel']
print("Shape:\t", jax_dense_1_kernel.shape)
print("Mean:\t", jax_dense_1_kernel.mean())
print("Std:\t", jax_dense_1_kernel.std())

Shape:	 (32, 4)
Mean:	 0.029176384
Std:	 0.25124845


In [16]:
# JAX Dense_1 bias: shape, mean and standard deviation
jax_dense_1_bias = state.params['Dense_1']['bias']
print("Shape:\t", jax_dense_1_bias.shape)
print("Mean:\t", jax_dense_1_bias.mean())
print("Std:\t", jax_dense_1_bias.std())

Shape:	 (4,)
Mean:	 0.0
Std:	 0.0


In [17]:
# JAX Dense_2 kernel: shape, mean and standard deviation
jax_dense_2_kernel = state.params['Dense_2']['kernel']
print("Shape:\t", jax_dense_2_kernel.shape)
print("Mean:\t", jax_dense_2_kernel.mean())
print("Std:\t", jax_dense_2_kernel.std())

Shape:	 (4, 1)
Mean:	 0.013601229
Std:	 0.4561657


In [18]:
# JAX Dense_2 bias: shape, mean and standard deviation
jax_dense_2_bias = state.params['Dense_2']['bias']
print("Shape:\t", jax_dense_2_bias.shape)
print("Mean:\t", jax_dense_2_bias.mean())
print("Std:\t", jax_dense_2_bias.std())

Shape:	 (1,)
Mean:	 0.0
Std:	 0.0


### Print the default Keras weights

In [19]:
# Setup. Import libraries and load dataframes for Movielens data.
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow import keras
import os
import random

tf.__version__

'2.7.0'

In [20]:
# Reload the MovieLens tables for the Keras run (this rebinds `ratings_df`,
# `movies_df` and `df`).
input_dir = 'data/movie-lens'
ratings_path = os.path.join(input_dir, 'rating.csv')
movies_path = os.path.join(input_dir, 'movie.csv')

ratings_df = pd.read_csv(ratings_path, usecols=['userId', 'movieId', 'rating', 'y'])
movies_df = pd.read_csv(movies_path, usecols=['movieId', 'title', 'year'])

# Join ratings to movies, order by user, then shuffle every row with a fixed seed.
df = (
    ratings_df
    .merge(movies_df, on='movieId')
    .sort_values(by='userId')
    .sample(frac=1, random_state=1)
)

# Bare expression so the notebook displays five sampled rows.
df.sample(5, random_state=1)

Unnamed: 0,userId,movieId,rating,y,title,year
12904240,85731,1883,4.5,0.974498,Labyrinth,1986
6089380,45008,1221,4.5,0.974498,"Femme Nikita, La (Nikita)",1990
17901393,125144,3948,4.0,0.474498,The Alamo,1960
9024816,122230,3027,3.5,-0.025502,Toy Story 2,1999
11655659,21156,5202,3.0,-0.525502,My Big Fat Greek Wedding,2002


In [21]:
# Build the Keras reference model: two embeddings -> concatenate -> flatten ->
# two ReLU dense layers -> one linear output.

# Seed TensorFlow, NumPy and Python's `random` before any layer is created.
tf.random.set_seed(1)
np.random.seed(1)
random.seed(1)

hidden_units = (32,4)
movie_embedding_size = 8
user_embedding_size = 8

# Each instance will consist of two inputs: a single user id, and a single movie id
user_id_input = keras.Input(shape=(1,), name='user_id')
movie_id_input = keras.Input(shape=(1,), name='movie_id')
# Each embedding table gets (largest id + 1) rows, so every id is a valid row index.
# No initializer is passed to any layer in this cell, so Keras' defaults apply.
user_embedded = keras.layers.Embedding(df.userId.max()+1, user_embedding_size, 
                                       input_length=1, name='user_embedding')(user_id_input)
movie_embedded = keras.layers.Embedding(df.movieId.max()+1, movie_embedding_size, 
                                        input_length=1, name='movie_embedding')(movie_id_input)
# Concatenate the embeddings (and remove the useless extra dimension)
concatenated = keras.layers.Concatenate()([user_embedded, movie_embedded])
out = keras.layers.Flatten()(concatenated)

# Add one or more hidden layers
for n_hidden in hidden_units:
    out = keras.layers.Dense(n_hidden, activation='relu')(out)

# A single output: our predicted rating
out = keras.layers.Dense(1, activation='linear', name='prediction')(out)

# Tie the two id inputs to the prediction and print a layer-by-layer summary.
model = keras.Model(
    inputs = [user_id_input, movie_id_input],
    outputs = out,
)
model.summary(line_length=88)

Model: "model"
________________________________________________________________________________________
 Layer (type)                Output Shape       Param #   Connected to                  
 user_id (InputLayer)        [(None, 1)]        0         []                            
                                                                                        
 movie_id (InputLayer)       [(None, 1)]        0         []                            
                                                                                        
 user_embedding (Embedding)  (None, 1, 8)       1107952   ['user_id[0][0]']             
                                                                                        
 movie_embedding (Embedding)  (None, 1, 8)      213952    ['movie_id[0][0]']            
                                                                                        
 concatenate (Concatenate)   (None, 1, 16)      0         ['user_embedding[0][0]',      
      

In [22]:
# Snapshot the model's weights as a flat list of arrays. The cells below read
# indices 0-7: user embedding, movie embedding, then kernel and bias of each Dense layer.
keras_weights = model.get_weights()
# Bare expression so the notebook displays the list.
keras_weights

[array([[-0.03348692,  0.04014813,  0.01309742, ...,  0.01425021,
          0.04757855, -0.00649005],
        [ 0.01601019,  0.01048958,  0.01366315, ...,  0.01277617,
          0.0031975 , -0.04740218],
        [-0.00591249, -0.02473292,  0.03862232, ..., -0.0440448 ,
         -0.04289062, -0.01915853],
        ...,
        [-0.02868495,  0.005866  , -0.01473292, ..., -0.03378378,
         -0.04109376,  0.02150441],
        [ 0.04850432, -0.04216467,  0.02519933, ..., -0.01390805,
          0.04621546,  0.02585406],
        [-0.03591843,  0.02629702, -0.02125278, ...,  0.03549831,
         -0.00832587, -0.04367541]], dtype=float32),
 array([[ 0.0010107 , -0.00564682, -0.00914669, ..., -0.01541504,
         -0.0063933 ,  0.0101061 ],
        [-0.00433757,  0.0252698 , -0.03120028, ...,  0.00495271,
         -0.02201053, -0.03910435],
        [ 0.01309114,  0.0373818 ,  0.0032263 , ...,  0.04284838,
         -0.03902638,  0.0463542 ],
        ...,
        [-0.02446281, -0.01340784,  0.0

In [23]:
# Keras user embedding table: shape, mean and standard deviation
keras_user_embedding = keras_weights[0]
print("Shape:\t", keras_user_embedding.shape)
print("Mean:\t", keras_user_embedding.mean())
print("Std:\t", keras_user_embedding.std())

Shape:	 (138494, 8)
Mean:	 9.316845e-06
Std:	 0.028877884


In [24]:
# Keras movie embedding table: shape, mean and standard deviation
keras_movie_embedding = keras_weights[1]
print("Shape:\t", keras_movie_embedding.shape)
print("Mean:\t", keras_movie_embedding.mean())
print("Std:\t", keras_movie_embedding.std())

Shape:	 (26744, 8)
Mean:	 2.4546707e-05
Std:	 0.02890124


In [27]:
# Keras first hidden layer kernel: shape, mean and standard deviation
keras_dense_0_kernel = keras_weights[2]
print("Shape:\t", keras_dense_0_kernel.shape)
print("Mean:\t", keras_dense_0_kernel.mean())
print("Std:\t", keras_dense_0_kernel.std())

Shape:	 (16, 32)
Mean:	 0.002208138
Std:	 0.20556653


In [28]:
# Keras first hidden layer bias: shape, mean and standard deviation
keras_dense_0_bias = keras_weights[3]
print("Shape:\t", keras_dense_0_bias.shape)
print("Mean:\t", keras_dense_0_bias.mean())
print("Std:\t", keras_dense_0_bias.std())

Shape:	 (32,)
Mean:	 0.0
Std:	 0.0


In [31]:
# Keras second hidden layer kernel: shape, mean and standard deviation
keras_dense_1_kernel = keras_weights[4]
print("Shape:\t", keras_dense_1_kernel.shape)
print("Mean:\t", keras_dense_1_kernel.mean())
print("Std:\t", keras_dense_1_kernel.std())

Shape:	 (32, 4)
Mean:	 -0.02374632
Std:	 0.23336196


In [32]:
# Keras second hidden layer bias: shape, mean and standard deviation
keras_dense_1_bias = keras_weights[5]
print("Shape:\t", keras_dense_1_bias.shape)
print("Mean:\t", keras_dense_1_bias.mean())
print("Std:\t", keras_dense_1_bias.std())

Shape:	 (4,)
Mean:	 0.0
Std:	 0.0


In [33]:
# Keras output layer kernel: shape, mean and standard deviation
keras_dense_2_kernel = keras_weights[6]
print("Shape:\t", keras_dense_2_kernel.shape)
print("Mean:\t", keras_dense_2_kernel.mean())
print("Std:\t", keras_dense_2_kernel.std())

Shape:	 (4, 1)
Mean:	 -0.47674978
Std:	 0.58064425


In [34]:
# Keras output layer bias: shape, mean and standard deviation
keras_dense_2_bias = keras_weights[7]
print("Shape:\t", keras_dense_2_bias.shape)
print("Mean:\t", keras_dense_2_bias.mean())
print("Std:\t", keras_dense_2_bias.std())

Shape:	 (1,)
Mean:	 0.0
Std:	 0.0


### Comparison

The shapes and biases are identical so I've omitted them.

|               | Jax        | Keras    |
|---------------|------------|----------|
|User Mean      | -4.44e-06  | 9.32e-06 |
|User Std       |  0.00380   | 0.0289   |
|Movie Mean     |  2.80e-05  | 2.45e-05 |
|Movie Std      |  0.00866   | 0.0289   |
|Dense_0 Mean   |  0.00256   | 0.00221  |
|Dense_0 Std    |  0.358     | 0.206    |
|Dense_1 Mean   |  0.0292    |-0.0237   |
|Dense_1 Std    |  0.251     | 0.233    |
|Dense_2 Mean   |  0.0136    |-0.477    |
|Dense_2 Std    |  0.456     | 0.581    |

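- The Keras standard deviations are the same for both embedding tables (≈0.0289), while the Jax ones are much smaller and change with the table size. This looks like a difference in initializers rather than uniform vs. normal:
  - Embeddings: Keras `Embedding` defaults to `RandomUniform(-0.05, 0.05)`, whose std is 0.05/√3 ≈ 0.0289. The Jax model uses `xavier_uniform`, whose std is √(2/(fan_in + fan_out)) — ≈ 0.0038 for the 138,494 × 8 user table and ≈ 0.0086 for the 26,744 × 8 movie table.
  - Dense kernels: Keras `Dense` defaults to Glorot (Xavier) uniform, e.g. √(2/(16 + 32)) ≈ 0.204 for `Dense_0`. The Jax model uses `kaiming_uniform`, whose std is √(2/fan_in), e.g. √(2/16) ≈ 0.354 for `Dense_0`.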


## Load Keras Weights into Jax Model

Now that we've compared the weights and they seem somewhat different, let's try loading the Keras weights into our Jax model.

In [67]:
# Unfreeze params so we can update it directly.
# (`state.params` is displayed above as a FrozenDict; the unfrozen copy can be
# assigned into key by key.)
unfrozen_state_params = state.params.unfreeze()

In [72]:
# Update with Keras weights.
# `keras_weights` is the flat list returned by `model.get_weights()`; entry i
# belongs to the (module, parameter) slot at position i of this table.
keras_weight_slots = (
    ('user', 'embedding'),
    ('movie', 'embedding'),
    ('Dense_0', 'kernel'),
    ('Dense_0', 'bias'),
    ('Dense_1', 'kernel'),
    ('Dense_1', 'bias'),
    ('Dense_2', 'kernel'),
    ('Dense_2', 'bias'),
)
for weight_index, (module_name, param_name) in enumerate(keras_weight_slots):
  jax_param = state.params[module_name][param_name]
  keras_value = keras_weights[weight_index]
  # Fail loudly on a shape mismatch instead of letting the Keras value be
  # broadcast into the parameter's shape.
  if tuple(keras_value.shape) != tuple(jax_param.shape):
    raise ValueError(
        f"Keras weight {weight_index} has shape {keras_value.shape}, but "
        f"{module_name}/{param_name} expects {jax_param.shape}")
  # Keep the dtype of the parameter that is being replaced.
  unfrozen_state_params[module_name][param_name] = jnp.asarray(
      keras_value, dtype=jax_param.dtype)

In [85]:
# Freeze new params
# (wraps the edited nested dict back up with flax's `frozen_dict.freeze`)
new_params = frozen_dict.freeze(unfrozen_state_params)

In [91]:
# Build a state that carries the Keras weights and starts again at step 0.
# `replace` copies every field that is not overridden, so the optimizer state
# is carried over unchanged.
new_state = state.replace(step=0, params=new_params)

In [92]:
# TODO(joshvarty): Run this through the model.
# Might not have to create an entire train_state like we've done. 
# Might be able to integrate with `create_train_state`

## Run One Step of Forward Pass on Each Model with same weights