In [1]:
import torch
import torch.optim as optim
import wandb
from Pytorch.environment_pytorch import WaterTank
from JAX.environment_JAX import WaterTank_Jax
from Pytorch.model_pytorch import MLP, log_weights_and_derivatives
from plotting import plot_history
from JAX.model_JAX import MLP_Jax, log_weights_and_derivatives_JAX
from JAX.environment_JAX import WaterTank_Jax
import optax
import jax
from jax import random
import jax.numpy as jnp
from flax.training import orbax_utils
import orbax
import wandb
from plotting import plot_history
import pandas as pd
%load_ext autoreload
%autoreload 2

# NN Speed Comparison

In [5]:
from params import influx_params, env_params, model_params, run_params, start_params, optimizer_params

# JAX
level = jnp.array([[start_params["level"]]])
time = jnp.array([[start_params["time"]]])
state_JAX = jnp.concatenate((level, time), axis = 1)
model_JAX = MLP_Jax(model_params["layer_sizes"][1:])
weight_params = model_JAX.init(random.PRNGKey(42), state_JAX)

# Pytorch
water_tank = WaterTank(start_params, env_params, influx_params)
state = water_tank.get_state()
model = MLP(model_params)

# test
JAX_apply = jax.jit(model_JAX.apply)
JAX_apply(weight_params, state_JAX)
%timeit JAX_apply(weight_params, state_JAX)
%timeit model_JAX.apply(weight_params, state_JAX)
%timeit model(state)


32.5 µs ± 1.57 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
5.6 ms ± 30.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
68.7 µs ± 3.4 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


# Initialized Parameter Stats of the two NNs

In [16]:
from params import influx_params, env_params, model_params, run_params, start_params, optimizer_params

# Seed PyTorch so the randomly initialised weights (and therefore the summary
# statistics below) are the same on every run; the JAX side is already
# deterministic through the explicit PRNGKey.
torch.manual_seed(42)

# JAX
level = jnp.array([[start_params["level"]]])
time = jnp.array([[start_params["time"]]])
state_JAX = jnp.concatenate((level, time), axis = 1)
model_JAX = MLP_Jax(model_params["layer_sizes"][1:])
weight_params = model_JAX.init(random.PRNGKey(42), state_JAX)

# Pytorch
water_tank = WaterTank(start_params, env_params, influx_params)
state = water_tank.get_state()
model = MLP(model_params)

# Flatten every PyTorch parameter tensor into one list of Python floats.
pytorch_par = []
for param in model.parameters():
    pytorch_par.extend(param.view(-1).tolist())
df_pytorch = pd.DataFrame(pytorch_par)
display(df_pytorch.describe())

# Flatten every leaf of the JAX parameter pytree into one 1-D array.
# jax.tree_util is used instead of the deprecated top-level jax.tree_map alias.
weights_flatten = jax.tree_util.tree_leaves(weight_params)
weights_flatten = jnp.concatenate([jnp.ravel(leaf) for leaf in weights_flatten])
df_JAX = pd.DataFrame(weights_flatten.tolist())
display(df_JAX.describe())

Unnamed: 0,0
count,129.0
mean,0.042463
std,0.343735
min,-0.655224
25%,-0.143862
50%,0.015522
75%,0.294862
max,0.707056


Unnamed: 0,0
count,129.0
mean,0.010316
std,0.51059
min,-1.601605
25%,-0.101484
50%,0.0
75%,0.123139
max,1.593314


# Initialized Parameter Size of the two NNs

In [43]:
from params import influx_params, env_params, model_params, run_params, start_params, optimizer_params

seed = 42
torch.manual_seed(seed)

# JAX
level = jnp.array([[start_params["level"]]])
time = jnp.array([[start_params["time"]]])
state_JAX = jnp.concatenate((level, time), axis = 1)
model_JAX = MLP_Jax(model_params["layer_sizes"][1:])
weight_params = model_JAX.init(random.PRNGKey(42), state_JAX)

# Pytorch
water_tank = WaterTank(start_params, env_params, influx_params)
state = water_tank.get_state()
model = MLP(model_params)

# PyTorch: print "<name> <shape>" for every parameter on a single line.
for name, param in model.named_parameters():
    print(name, list(param.shape), end = " ")
print()
print()

# JAX: same nested structure as the parameter pytree, with every array
# replaced by its shape. jax.tree_util.tree_map replaces the deprecated
# top-level jax.tree_map alias.
print(jax.tree_util.tree_map(lambda x: x.shape, weight_params))

layers.0.weight [32, 2] layers.0.bias [32] layers.2.weight [1, 32] layers.2.bias [1] 

{'params': {'layers_0': {'bias': (32,), 'kernel': (2, 32)}, 'layers_1': {'bias': (1,), 'kernel': (32, 1)}}}


In [44]:
# First-layer kernel as currently stored in the JAX parameter tree.
jax_kernel = weight_params["params"]["layers_0"]["kernel"]
# First PyTorch parameter tensor, transposed and converted to nested lists.
torch_kernel_T = list(model.parameters())[0].T.tolist()

print(jax_kernel)
print(torch_kernel_T)
# .at[:].set(...) returns a new array holding the PyTorch values; the entry
# stored in weight_params is left untouched by this expression.
print(jax_kernel.at[:].set(torch_kernel_T))

[[ 4.2911819e-01  1.2873628e+00  4.1615748e-01 -3.2046041e-01
  -1.0824715e+00 -3.4905264e-01 -4.0321961e-02  1.0576610e-03
  -8.2933229e-01  3.5555342e-01 -5.5605835e-01 -8.7483805e-01
  -1.6016047e+00  3.8891137e-01  1.0036556e-01 -3.3083498e-01
   8.1692898e-01  2.4937668e-01  2.0649074e-01  2.7163011e-01
   3.1569880e-01  1.0814291e+00 -2.6750933e-02 -1.3788413e+00
  -1.1066534e-01 -8.4871006e-01 -4.2546454e-01  7.7156889e-01
   1.9841006e-02  1.5933142e+00  6.0597646e-01  4.5173422e-01]
 [-7.2890383e-01  1.2313894e-01  1.5339019e+00 -1.4125415e-02
  -6.6341805e-01  8.5607606e-01  6.0746133e-01  1.1522740e-01
  -9.4534695e-01 -1.0287511e+00 -2.6974782e-01 -4.9252715e-02
  -3.5951115e-02  5.8555305e-01  7.5108573e-02  1.2834657e+00
   1.2580771e+00  7.0255592e-02  1.0237257e+00  5.2203673e-01
  -2.1699713e-01  5.0387986e-02 -9.9643892e-01  2.0291030e-01
  -4.1196761e-01 -1.0742004e+00  3.7874031e-01 -6.6206402e-01
   6.6892937e-02  5.0207633e-01 -6.2699866e-01 -8.7600297e-01]]
[[0.5

# Same Initialization

Code to be added to the JAX main script to copy the parameters from the PyTorch model, so that both networks start from the same initialization.

In [1]:
from params import influx_params, env_params, model_params, run_params, start_params, optimizer_params
import torch

seed = 42
torch.manual_seed(seed)

# starting conditions
level = jnp.array([[start_params["level"]]])
curr_time = jnp.array([[start_params["time"]]])
state = jnp.concatenate((level, curr_time), axis = 1)

# ML
# Keep the two networks under distinct names: binding both to `model` would
# drop the JAX module as soon as the PyTorch one is created.
model_JAX = MLP_Jax(model_params["layer_sizes"][1:]) # have to remove the first element
weight_params = model_JAX.init(random.PRNGKey(42), state)
model = MLP(model_params)

# PyTorch parameters in declaration order: weight 0, bias 0, weight 1, bias 1.
torch_values = [p.detach().cpu().numpy() for p in model.parameters()]

# PyTorch stores weights as (out_features, in_features) while the Flax kernels
# are (in_features, out_features), so the weights are transposed on copy.
copy_plan = (
    ("layers_0", "kernel", torch_values[0].T),
    ("layers_0", "bias", torch_values[1]),
    ("layers_1", "kernel", torch_values[2].T),
    ("layers_1", "bias", torch_values[3]),
)

for layer_name, param_name, value in copy_plan:
    target = weight_params["params"][layer_name][param_name]
    # Fail loudly if the two architectures do not line up.
    assert target.shape == value.shape, (
        f"shape mismatch for {layer_name}/{param_name}: "
        f"JAX {target.shape} vs PyTorch {value.shape}"
    )
    weight_params["params"][layer_name][param_name] = jnp.asarray(value, dtype=target.dtype)

NameError: name 'jnp' is not defined

# Parameter Precision, decimal point

In [None]:
from params import influx_params, env_params, model_params, run_params, start_params, optimizer_params
import torch

seed = 42
torch.manual_seed(seed)

# Initial state: one row holding the start level and the start time.
level = jnp.array([[start_params["level"]]])
curr_time = jnp.array([[start_params["time"]]])
state = jnp.concatenate([level, curr_time], axis=1)

# Build both networks; the JAX model takes the layer sizes without the first entry.
hidden_and_output_sizes = model_params["layer_sizes"][1:]
model_JAX = MLP_Jax(hidden_and_output_sizes)
init_key = random.PRNGKey(42)
weight_params = model_JAX.init(init_key, state)
model = MLP(model_params)

# Show the dtype of the first-layer parameters of each framework.
jax_kernel_dtype = weight_params["params"]["layers_0"]["kernel"].dtype
torch_first_param_dtype = next(iter(model.parameters())).dtype
print(jax_kernel_dtype)
print(torch_first_param_dtype)