In [1]:
# Enable IPython's autoreload extension so edits to imported project modules
# are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules before executing each cell.
%autoreload 2

In [2]:
import os
import logging

from jax import random
import wandb

from src.models import make_Reg_Ens_loss as make_loss
import src.data
from src.data import NumpyLoader
from src.utils.training import setup_training, train_loop
from experiments.configs.toy_reg_ens import get_config

In [3]:
# W&B doesn't know how to handle VS Code notebooks, so the notebook name is
# set explicitly via WANDB_NOTEBOOK_NAME before logging in.
os.environ['WANDB_NOTEBOOK_NAME'] = 'train_reg_ens.ipynb'

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mjamesallingham[0m ([33minvariance-learners[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
# Seed for the root PRNG key; change it here to get a different (but still
# reproducible) run.
SEED = 0
# Root JAX PRNG key; the setup cell below splits a sub-key off it.
rng = random.PRNGKey(SEED)

In [5]:
# Load the experiment configuration defined in experiments.configs.toy_reg_ens.
config = get_config()

In [6]:
# Look up the dataset generator in `src.data` by the name given in the config,
# then call it with the dataset options from the config as keyword arguments.
data_gen_fn = getattr(src.data, config.dataset_name)
dataset_kwargs = config.dataset.to_dict()
# NOTE(review): the result is unpacked as (train, test, val) — confirm this
# matches the order the generator actually returns.
train_dataset, test_dataset, val_dataset = data_gen_fn(**dataset_kwargs)

# Wrap each split in a loader using the configured batch size.
train_loader, val_loader, test_loader = [
    NumpyLoader(split, config.batch_size)
    for split in (train_dataset, val_dataset, test_dataset)
]

In [7]:
# Split off a key for setup. `rng` is rebound to the remaining key, so
# re-running this cell on its own advances `rng` again.
setup_rng, rng = random.split(rng)

# The first two fields of the first training example are passed to
# `setup_training` as the example input and target.
init_x, init_y = train_dataset[0][0], train_dataset[0][1]

model, state = setup_training(config, setup_rng, init_x, init_y)

+---------------------------------------------+------------+--------+-----------+--------+
| Name                                        | Shape      | Size   | Mean      | Std    |
+---------------------------------------------+------------+--------+-----------+--------+
| batch_stats/nets_0/layer_0/BatchNorm_0/mean | (100,)     | 100    | 0.0       | 0.0    |
| batch_stats/nets_0/layer_0/BatchNorm_0/var  | (100,)     | 100    | 1.0       | 0.0    |
| batch_stats/nets_0/layer_1/BatchNorm_0/mean | (100,)     | 100    | 0.0       | 0.0    |
| batch_stats/nets_0/layer_1/BatchNorm_0/var  | (100,)     | 100    | 1.0       | 0.0    |
| batch_stats/nets_1/layer_0/BatchNorm_0/mean | (100,)     | 100    | 0.0       | 0.0    |
| batch_stats/nets_1/layer_0/BatchNorm_0/var  | (100,)     | 100    | 1.0       | 0.0    |
| batch_stats/nets_1/layer_1/BatchNorm_0/mean | (100,)     | 100    | 0.0       | 0.0    |
| batch_stats/nets_1/layer_1/BatchNorm_0/var  | (100,)     | 100    | 1.0       | 0.0    |

In [8]:
# Run training. `state` is both an input and the rebound result, so re-running
# this cell passes the already-updated state back into `train_loop`.
state = train_loop(
    model,
    state,
    config,
    rng,
    # NOTE(review): `make_loss` is passed twice — presumably once for the
    # training loss and once for the evaluation loss; confirm against
    # `train_loop`'s signature.
    make_loss,
    make_loss,
    train_loader,
    val_loader,
    # test_loader,
    wandb_kwargs={
        # Forwarded as W&B options; 'offline' is the configured run mode.
        'mode': 'offline',
        # 'notes': '',
    },
)

  0%|          | 0/50 [00:00<?, ?it/s]

epoch:   1 - train loss: 2.10950, val_loss: 1.62705, lr: 0.00010
Best val_loss
epoch:   2 - train loss: 1.47555, val_loss: 1.52929, lr: 0.00010
Best val_loss
epoch:   3 - train loss: 1.32613, val_loss: 1.46609, lr: 0.00010
Best val_loss
epoch:   4 - train loss: 1.34467, val_loss: 1.42987, lr: 0.00010
Best val_loss
epoch:   5 - train loss: 1.29503, val_loss: 1.41331, lr: 0.00010
Best val_loss
epoch:   6 - train loss: 1.21987, val_loss: 1.41003, lr: 0.00010
Best val_loss
epoch:   7 - train loss: 1.12627, val_loss: 1.42549, lr: 0.00010
epoch:   8 - train loss: 1.11018, val_loss: 1.45219, lr: 0.00010
epoch:   9 - train loss: 1.13372, val_loss: 1.47350, lr: 0.00010
epoch:  10 - train loss: 1.14183, val_loss: 1.47784, lr: 0.00010
epoch:  11 - train loss: 1.12735, val_loss: 1.46176, lr: 0.00010
epoch:  12 - train loss: 1.09606, val_loss: 1.43070, lr: 0.00010
epoch:  13 - train loss: 1.06787, val_loss: 1.39198, lr: 0.00010
Best val_loss
epoch:  14 - train loss: 1.03326, val_loss: 1.35084, lr: 

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▅▅▅▄▄▄▄▄▄▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁▂▃▂▇▃▃▃▂▂▂▂▂▂▂▂
val/loss,▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▃▃▃▃▃▃▃▃▅▆▅▄▅▄█▆▅▅▅▃▂▂▂▁

0,1
best_epoch,50.0
best_val_loss,0.94763
epoch,50.0
learning_rate,0.0001
train/loss,0.59774
val/loss,0.94763
