## Roadmap
- [x] Load data into an in-memory TensorFlow dataset
- [x] Create a dummy MLP and training loop
- [x] Debug the MLP on linear regression
- [x] Add dummy regularization and make it a baby ResNet
    - Going with dropout and weight decay
- [x] Split the code cleanly into several files
- Apply Karpathy's "A Recipe for Training Neural Networks", first on dummy data, then on real data
- [ ] Unit tests
    - Understand how to specify parameter initialization
    - Understand init RNGs and `nn.Module`
    - Debug simple ResNet training
- [ ] Checkpointing and evaluation
- [ ] TensorBoard
- [ ] Baby version submission
- [ ] GPU training on AWS
- [ ] Ray for training on several instances at the same time
- [ ] Go all in on hyperparameter tuning
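For the two initialization items under *Unit tests*, here is a minimal pure-JAX sketch of explicit parameter init. `init_mlp` is a hypothetical helper, not this repo's `init_params`: it draws kernels with `jax.nn.initializers.lecun_normal()`, gives every layer its own subkey via `jax.random.split`, and sets the final bias to a chosen value such as the target mean (Karpathy's "init well" advice). Flax's `nn.Dense` does the equivalent internally through its `kernel_init`/`bias_init` attributes and the `'params'` RNG passed to `init`.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers


def init_mlp(key, layer_sizes, out_bias=0.0):
    """Hypothetical helper: build a list of (W, b) pairs for an MLP.

    Kernels use LeCun-normal init (variance 1 / fan_in); hidden biases start at 0.
    The last layer's bias is set to `out_bias`, e.g. the mean of the targets.
    """
    kernel_init = initializers.lecun_normal()
    # One independent subkey per layer, so changing one layer never shifts another's draws.
    keys = jax.random.split(key, len(layer_sizes) - 1)
    params = []
    for i, (k, fan_in, fan_out) in enumerate(zip(keys, layer_sizes[:-1], layer_sizes[1:])):
        W = kernel_init(k, (fan_in, fan_out), jnp.float32)
        is_last = i == len(layer_sizes) - 2
        b = jnp.full((fan_out,), out_bias if is_last else 0.0, dtype=jnp.float32)
        params.append((W, b))
    return params


root = jax.random.PRNGKey(0)
# Split the root key into separate streams: one for params, one reserved for dropout.
params_key, dropout_key = jax.random.split(root)
params = init_mlp(params_key, [10, 10, 10, 1], out_bias=15.0)
```

Reusing `params_key` reproduces the same weights exactly, which is what makes init behaviour unit-testable.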

### 2. Init sanity

- Fix the random seed
- No fanciness (e.g. regularization) yet
- Print the actual loss
- Initialize correctly
- Check the loss at init
- Human baseline: not needed on my synthetic data, since I know the upper and lower bounds on performance
- Input-independent baseline: the model should converge to predicting the mean
- Overfit one batch
- Verify that the training loss decreases
- Visualize the data just before it enters the net
- Visualize prediction dynamics
- Use backprop to sanity-check interactions between examples
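The last check can be sketched in JAX following Karpathy's recipe: use the sum of the outputs for example `i` as the loss, backpropagate to the inputs, and confirm that only row `i` of the input gradient is nonzero. `model_fn` and `buggy_model_fn` are hypothetical stand-ins; the buggy one subtracts the batch mean, which leaks information across examples the way a misplaced reduction would.

```python
import jax
import jax.numpy as jnp


def model_fn(W, b, x):
    # Hypothetical per-example model: linear map followed by tanh.
    return jnp.tanh(x @ W + b)


def buggy_model_fn(W, b, x):
    # Deliberate bug: subtracting the batch mean mixes information across examples.
    h = x - x.mean(axis=0, keepdims=True)
    return jnp.tanh(h @ W + b)


def input_dependencies(fn, params, x, i):
    """For each example in the batch, is d(sum of outputs of example i)/d(input) nonzero?"""
    def loss(x):
        return fn(*params, x)[i].sum()
    g = jax.grad(loss)(x)  # same shape as x: (batch, features)
    return jnp.any(g != 0, axis=1)


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
W = 0.1 * jax.random.normal(key_w, (10, 1))  # small weights keep tanh far from saturation
b = jnp.zeros((1,))
x = jax.random.normal(key_x, (8, 10))

print(input_dependencies(model_fn, (W, b), x, i=3))
print(input_dependencies(buggy_model_fn, (W, b), x, i=3))
```

The correct model should flag only example 3; the batch-mean version flags every example in the batch.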

### Notes
- Understand `nn.compact`
- Read the source code of Flax's `TrainState` (`flax.training.train_state`)

#### Imports

In [1]:
import matplotlib.pyplot as plt

import jax
import jax.random as random
import jax.numpy as jnp
import optax

from data_loader import get_dataset, linear_data
from models import MLP, Resnet
from train import train, init_params
from utils import poisson_loss, squared_loss

#### Load training data

In [2]:
bias = 15.0  # intercept of the synthetic data; also passed to train() below
X_train, Y_train = linear_data(seed=15, n=1000, p=10, bias=bias)
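The implementation of `data_loader.linear_data` is not shown. Judging from the statistics below (features with mean ≈ 0 and std ≈ 1, targets with mean ≈ 15), it plausibly generates $y = X\beta + \text{bias} + \varepsilon$. A hypothetical stand-in under that assumption; the distribution of `beta` and the noise level are guesses, and unlike the original it also returns `beta` so predictions can later be compared with the ground truth:

```python
import jax
import jax.numpy as jnp


def linear_data_sketch(seed, n, p, bias, noise_std=0.1):
    """Hypothetical stand-in for data_loader.linear_data: y = X @ beta + bias + noise."""
    key_x, key_beta, key_noise = jax.random.split(jax.random.PRNGKey(seed), 3)
    X = jax.random.normal(key_x, (n, p))        # standard-normal features
    beta = jax.random.normal(key_beta, (p,))    # true coefficients (assumed distribution)
    noise = noise_std * jax.random.normal(key_noise, (n,))
    Y = X @ beta + bias + noise
    return X, Y, beta


X, Y, beta = linear_data_sketch(seed=15, n=1000, p=10, bias=15.0)
```

Because the generating function is linear, ordinary least squares on `[X, 1]` recovers `beta` and the bias almost exactly, which gives an upper bound on what the network can achieve.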



### 1. Become one with the data

In [4]:
Y_train.shape

(1000,)

In [11]:
plt.figure()
plt.plot(Y_train)
plt.show()
plt.close()

In [14]:
Y_train.mean(), Y_train.std()

(DeviceArray(15.045962, dtype=float32), DeviceArray(2.1840475, dtype=float32))
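The target standard deviation gives the input-independent baseline directly: under squared loss the best constant prediction is the mean, and its MSE equals the variance, here $2.184^2 \approx 4.77$ (`std` defaults to `ddof=0`, so the match is exact). More generally $\frac{1}{m}\sum_i (y_i - c)^2 = \operatorname{Var}(y) + (\bar y - c)^2$, so any other constant does worse. A quick check on toy values:

```python
import jax.numpy as jnp


def constant_mse(y, c):
    """MSE of always predicting the constant c."""
    return jnp.mean((y - c) ** 2)


def constant_baseline_mse(y):
    """Best input-independent MSE under squared loss: predict mean(y), which gives Var(y)."""
    return constant_mse(y, y.mean())


y = jnp.array([1.0, 2.0, 3.0, 4.0])
c = 0.0
# Decomposition check: MSE(c) = Var(y) + (mean(y) - c)^2
print(constant_mse(y, c), jnp.var(y) + (y.mean() - c) ** 2)
```

Applied to `Y_train`, `constant_baseline_mse` is the number a trained model's mean squared error should clearly beat.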

In [13]:
X_train.mean(0), X_train.std(0)

(DeviceArray([ 0.01137057,  0.01005887,  0.01693998, -0.05002026,
              -0.01076224, -0.0156854 , -0.02256254,  0.00553981,
              -0.0200922 , -0.00961938], dtype=float32),
 DeviceArray([0.99739987, 0.9881106 , 1.0309198 , 0.9692272 , 0.9714178 ,
              1.0113798 , 0.9474586 , 0.99501777, 1.0120884 , 1.0414252 ],            dtype=float32))

In [15]:
p, m = X_train.shape[1], Y_train.shape[0]
p, m

(10, 1000)

In [16]:
n_hidden_layers = 2
layer_size = p

lr = 5.0 * 1e-3  # learning rate
num_epochs = 30
batch_size = 32

jax_squared = jax.vmap(jax.jit(squared_loss), in_axes=0)

p, m

(10, 1000)
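`jax.vmap(..., in_axes=0)` maps the per-example loss over the leading axis of every argument and returns one loss per example; the training step then reduces them, typically with a mean. The signature `squared_loss(pred, target)` on scalars is an assumption, since `utils.squared_loss` is not shown, so the sketch uses a hypothetical `squared_loss_sketch`. Putting `jax.jit` outside the vmapped function is the more common arrangement, since it compiles the whole batched computation at once.

```python
import jax
import jax.numpy as jnp


def squared_loss_sketch(pred, target):
    """Hypothetical per-example loss on scalars (utils.squared_loss may differ)."""
    return (pred - target) ** 2


# vmap over axis 0 of both arguments -> one loss per example; jit the batched function.
batched_squared = jax.jit(jax.vmap(squared_loss_sketch, in_axes=0))


def mean_loss(pred, target):
    # Reduce per-example losses to the scalar that gets differentiated.
    return batched_squared(pred, target).mean()


pred = jnp.array([1.0, 2.0, 3.0])
target = jnp.array([1.0, 0.0, 5.0])
per_example = batched_squared(pred, target)
```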

In [17]:
model = Resnet([layer_size for _ in range(n_hidden_layers)] + [1], dropout=False)
optim = optax.adam(lr)

In [15]:
trained_params = train(rng=random.PRNGKey(1515),
      model=model,
      optimizer=optim,
      dataset=get_dataset(X_train, Y_train, batch_size=batch_size, single_batch=True),
      loss_fn=jax_squared,
      metric_fn=jax_squared,
      num_epochs=num_epochs,
      inputs_shape=(p,),
      bias=bias,
      layer_name=str(n_hidden_layers),
      print_every=1,
      output_dir="logs",
      hist_every=1,
     )

Step 1, Metric: 2.560, Loss: 6.556
Step 2, Metric: 2.413, Loss: 5.822
Step 3, Metric: 2.285, Loss: 5.220
Step 4, Metric: 2.165, Loss: 4.686
Step 5, Metric: 2.055, Loss: 4.225
Step 6, Metric: 1.950, Loss: 3.801
Step 7, Metric: 1.840, Loss: 3.385
Step 8, Metric: 1.730, Loss: 2.993
Step 9, Metric: 1.628, Loss: 2.651
Step 10, Metric: 1.530, Loss: 2.341
Step 11, Metric: 1.431, Loss: 2.049
Step 12, Metric: 1.340, Loss: 1.796
Step 13, Metric: 1.259, Loss: 1.585
Step 14, Metric: 1.191, Loss: 1.419
Step 15, Metric: 1.137, Loss: 1.294
Step 16, Metric: 1.094, Loss: 1.197
Step 17, Metric: 1.062, Loss: 1.128
Step 18, Metric: 1.039, Loss: 1.079
Step 19, Metric: 1.023, Loss: 1.046
Step 20, Metric: 1.003, Loss: 1.007
Step 21, Metric: 0.975, Loss: 0.951
Step 22, Metric: 0.936, Loss: 0.876
Step 23, Metric: 0.890, Loss: 0.793
Step 24, Metric: 0.844, Loss: 0.713
Step 25, Metric: 0.802, Loss: 0.643
Step 26, Metric: 0.766, Loss: 0.587
Step 27, Metric: 0.738, Loss: 0.545
Step 28, Metric: 0.719, Loss: 0.517

In [None]:
beta

In [None]:
test_key = random.PRNGKey(12097341324)
test_input = random.normal(test_key, shape=(10,))
test_input.dot(beta) + 15.0  # include the intercept: linear_data was called with bias=15.0

In [None]:
void_resnet = Resnet([10, 1])

void_resnet.init(test_key, test_input)

In [None]:
model.apply(trained_params, test_input, train=False)  # use the architecture the params were trained with
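Applying parameters to a module with a different architecture (here `Resnet([10, 1])` versus the trained `Resnet([10, 10, 1], dropout=False)`) typically surfaces as an error inside `apply`. A hypothetical helper that compares two parameter pytrees by structure and leaf shapes can catch this up front. For Flax, `fresh` could come from `module.init(key, x)` or, without allocating arrays, from `jax.eval_shape(module.init, key, x)`.

```python
import jax
import jax.numpy as jnp


def params_compatible(trained, fresh):
    """Hypothetical check: same pytree structure and same leaf shapes."""
    if jax.tree_util.tree_structure(trained) != jax.tree_util.tree_structure(fresh):
        return False
    shapes_a = [leaf.shape for leaf in jax.tree_util.tree_leaves(trained)]
    shapes_b = [leaf.shape for leaf in jax.tree_util.tree_leaves(fresh)]
    return shapes_a == shapes_b


two_layers = {"Dense_0": {"kernel": jnp.zeros((10, 10)), "bias": jnp.zeros(10)},
              "Dense_1": {"kernel": jnp.zeros((10, 1)), "bias": jnp.zeros(1)}}
two_layers_again = jax.tree_util.tree_map(jnp.ones_like, two_layers)
one_layer = {"Dense_0": {"kernel": jnp.zeros((10, 1)), "bias": jnp.zeros(1)}}
```

`params_compatible(two_layers, two_layers_again)` holds because only the values differ, while `params_compatible(two_layers, one_layer)` fails on structure.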