## Roadmap
- [x] Load data into an in-memory tf dataset
- [x] Create a dummy MLP and training loop
- [x] Debug MLP on linear regression
- [x] Add dummy regularization and make it a baby resnet
    - Going for dropout and weight decay
- [ ] Split code nicely into several files
- [ ] Unit tests
    - Understand how to specify neural net parameter initialization
    - Understand init RNGs and `nn.Module`
    - Debug simple resnet training
- [ ] Checkpointing and evaluation
- [ ] Tensorboard
- [ ] Baby version submission
- [ ] GPU training on AWS
- [ ] Rayon for training on several instances at the same time
- [ ] Go all in on hyperparameters tuning

### Notes
- Understand `nn.compact`
- Check out the source code for their train state

#### Imports

In [1]:
import numpy as np
import jax
import jax.random as random
import jax.numpy as jnp
import optax

from data_loader import get_dataset
from models import MLP, Resnet
from train import train
from utils import poisson_loss, squared_loss

#### Load training data

In [2]:
# Read the training inputs and targets from their .npy files under data/.
# Both arrays are indexed on axis 1 in the next cell, so they are at least 2-D.
X_train = np.load('data/X_train.npy')
Y_train = np.load('data/Y_train.npy')

In [3]:
# Input and output widths, read off the second axis of the training arrays.
p = X_train.shape[1]
m = Y_train.shape[1]

# Settings consumed by the MLP training cell.
num_h_layers = 0  # how many `h`-sized entries precede the output width
h = 16            # hidden layer size
lr = 1e-4         # learning rate
num_epochs = 2
batch_size = 32

# Each loss is jit-wrapped and then mapped over the leading axis of its inputs.
jax_poisson, jax_squared = (
    jax.vmap(jax.jit(loss), in_axes=0) for loss in (poisson_loss, squared_loss)
)

# Last expression of the cell: echo the two widths.
p, m

(8, 3)

In [4]:
# Train the MLP on the real data. The layer list is `num_h_layers` entries of
# width `h` followed by the output width `m`.
trained_params = train(
    rng=random.PRNGKey(15),
    model=MLP([h] * num_h_layers + [m]),
    optimizer=optax.adam(lr),
    dataset=get_dataset(X_train, Y_train, batch_size),
    loss_fn=jax_squared,
    # Metric: element-wise square root of the vmapped squared loss.
    metric_fn=lambda a, b: jnp.sqrt(jax_squared(a, b)),
    num_epochs=num_epochs,
    inputs_shape=(1, p),
)



Step 25, Metric: 79.976, Loss: 30736.629
Step 50, Metric: 63.508, Loss: 23480.641
Step 75, Metric: 67.143, Loss: 23843.727
Step 100, Metric: 67.422, Loss: 22010.256
Step 125, Metric: 65.654, Loss: 23634.754
Step 150, Metric: 73.030, Loss: 22683.008
Step 175, Metric: 84.351, Loss: 37960.688
Step 25, Metric: 94.466, Loss: 44216.824
Step 50, Metric: 82.179, Loss: 29926.129
Step 75, Metric: 101.346, Loss: 49665.156
Step 100, Metric: 68.996, Loss: 23342.051
Step 125, Metric: 61.867, Loss: 17652.324
Step 150, Metric: 70.633, Loss: 19446.002
Step 175, Metric: 78.239, Loss: 32207.465


In [None]:
# Synthetic linear-regression problem: n samples, p features.
# NOTE: this rebinds the notebook-global `p` that an earlier cell derived from
# X_train; re-run that cell before training on the real data again.
n, p = 10000, 10

key = random.PRNGKey(1515)

# A fresh subkey is split off `key` before every draw.
key, init_key = random.split(key)
X_linear = random.normal(init_key, shape=(n, p))  # design matrix

key, init_key = random.split(key)
beta = random.uniform(init_key, shape=(p,))  # true coefficients

key, init_key = random.split(key)
noise = random.normal(init_key, shape=(n,))  # additive noise, one value per sample

# Targets: linear response plus noise.
y_linear = X_linear.dot(beta) + noise

In [None]:
# Debug run: train the Resnet on the synthetic linear data with plain SGD and
# l2_scale=1e-3. The squared loss doubles as the reported metric.
# This rebinds `trained_params`, replacing the result of the MLP run.
trained_params = train(
    rng=random.PRNGKey(15),
    model=Resnet([10, 1]),
    optimizer=optax.sgd(5 * 1e-3),
    dataset=get_dataset(X_linear, y_linear, batch_size=32),
    loss_fn=jax_squared,
    metric_fn=jax_squared,
    num_epochs=5,
    inputs_shape=(1, p),
    l2_scale=1e-3,
)

In [None]:
# Display the true coefficients that generated y_linear, for inspection.
beta

In [None]:
# Draw a single random 10-feature input and display its noiseless linear
# response under the true coefficients `beta`.
test_key = random.PRNGKey(12097341324)
test_input = random.normal(test_key, (10,))
jnp.dot(test_input, beta)

In [None]:
# Build a fresh Resnet with the same layer list as the one passed to train().
void_resnet = Resnet([10, 1])

# Displays whatever init returns (presumably the initial variables); the
# result is not stored.
# NOTE(review): test_key was already used by random.normal in the previous
# cell; JAX recommends splitting a key instead of reusing it. Confirm this
# reuse is intentional.
void_resnet.init(test_key, test_input)

In [None]:
# Evaluate the module on the single test input using the notebook's current
# `trained_params`, passing train=False.
# NOTE(review): both training cells assign `trained_params`; this call
# presumably expects the Resnet run's result, so that cell must have been
# executed last. Confirm.
void_resnet.apply(trained_params, test_input, train=False)