In [1]:
import wandb
from pprint import pprint

import neural_stack.training as nsnn
from neural_stack.utils import get_project_root

In [2]:
# Anchor filesystem locations on the project root returned by the helper.
ROOT_DIR = get_project_root()
# Directory handed to the W&B callback further down (its `dir=` argument).
WANDB_DIR = str(ROOT_DIR) + "/data/.wandb"

# Kept as the final expression so the notebook displays the login result.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mandreifurdui[0m ([33mandreifurdui-team[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
# Load the experiment configuration from its YAML file.
# NOTE(review): the path is relative, so it resolves against the kernel's
# working directory — confirm the notebook is launched from the expected folder.
config = nsnn.load_config('../configs/vit_cifar10.yaml')

# Convert the config to a plain dict and pretty-print it as a sanity check.
config_as_dict = nsnn.config_to_dict(config)
pprint(config_as_dict)

{'checkpoint_dir': '../data/checkpoints',
 'data/batch_size': 128,
 'data/crop_ratio': (0.9, 1.1),
 'data/crop_scale': (0.8, 1.0),
 'data/data_dir': '../data',
 'data/dataset': 'cifar10',
 'data/num_workers': 4,
 'data/pin_memory': True,
 'data/random_crop': True,
 'data/random_horizontal_flip': True,
 'device': 'cuda:0',
 'experiment_name': 'vit_cifar10_baseline',
 'log_every': 50,
 'model/dropout': 0.1,
 'model/embed_dim': 256,
 'model/img_size': (32, 32),
 'model/in_channels': 3,
 'model/mlp_ratio': 2.0,
 'model/name': 'vit',
 'model/num_classes': 10,
 'model/num_heads': 8,
 'model/num_layers': 4,
 'model/patch_size': 8,
 'model/positional_embedding': 'learned-1d',
 'model/use_cls_token': True,
 'num_epochs': 25,
 'optimizer/betas': (0.9, 0.999),
 'optimizer/lr': 0.0003,
 'optimizer/momentum': 0.9,
 'optimizer/name': 'adamw',
 'optimizer/weight_decay': 0.01,
 'save_best': True,
 'save_every': None,
 'scheduler/T_max': None,
 'scheduler/eta_min': 0.0,
 'scheduler/gamma': 0.1,
 'sched

In [4]:
# Build the data loaders from the `data` section of the config; the helper
# returns a pair that is unpacked into the training and validation loaders.
train_loader, val_loader = nsnn.build_dataloaders(config.data)

In [5]:
# Build the config-driven training components. The result is unpacked with
# `**components` into `nsnn.Trainer` below, so it must be a mapping of keyword
# arguments (presumably model / optimizer / scheduler etc. — TODO confirm).
components = nsnn.build_from_config(config)

In [6]:
# Assemble the trainer callbacks one at a time, in the order they are passed
# to the trainer: console printing, W&B logging, LR scheduler, progress bar.
callback_list = []
callback_list.append(nsnn.PrintCallback())

# Keyword arguments for the Weights & Biases callback; the config is attached
# as a plain dict and run files are directed to WANDB_DIR.
wandb_settings = {
    'project': 'neural-stack',
    'name': 'trainer_setup_testing',
    'group': 'vit-cifar10-debug',
    'config': nsnn.config_to_dict(config),
    'dir': WANDB_DIR,
}
callback_list.append(nsnn.WandbCallback(**wandb_settings))

callback_list.append(nsnn.LRSchedulerCallback())
callback_list.append(nsnn.ProgressCallback())

In [7]:
# Arguments supplied by this notebook, on top of the config-built components.
trainer_settings = {
    'train_loader': train_loader,
    'val_loader': val_loader,
    'callbacks': callback_list,
    'device': config.device,
    'num_epochs': config.num_epochs,
}

# Both mappings are unpacked as keyword arguments; a key present in both
# still raises TypeError, exactly as with explicitly spelled-out keywords.
trainer = nsnn.Trainer(**components, **trainer_settings)

In [8]:
# Run training. In the recorded output below, the 25 configured epochs took
# roughly 11 minutes (~26 s/epoch) and per-epoch metrics are printed as it goes.
# NOTE(review): the structure of `results` is not shown here — inspect before use.
results = trainer.fit()

Training:   4%|▍         | 1/25 [00:26<10:29, 26.24s/epoch]

Epoch 1: train/accuracy=0.4090, train/loss=1.6306, val/accuracy=0.4843, val/loss=1.4146


Training:   8%|▊         | 2/25 [00:52<10:01, 26.13s/epoch]

Epoch 2: train/accuracy=0.5062, train/loss=1.3747, val/accuracy=0.5423, val/loss=1.2818


Training:  12%|█▏        | 3/25 [01:18<09:32, 26.05s/epoch]

Epoch 3: train/accuracy=0.5421, train/loss=1.2752, val/accuracy=0.5545, val/loss=1.2511


Training:  16%|█▌        | 4/25 [01:44<09:07, 26.05s/epoch]

Epoch 4: train/accuracy=0.5669, train/loss=1.2074, val/accuracy=0.5769, val/loss=1.1853


Training:  20%|██        | 5/25 [02:10<08:40, 26.00s/epoch]

Epoch 5: train/accuracy=0.5866, train/loss=1.1538, val/accuracy=0.5853, val/loss=1.1611


Training:  24%|██▍       | 6/25 [02:36<08:13, 25.97s/epoch]

Epoch 6: train/accuracy=0.6045, train/loss=1.1044, val/accuracy=0.5970, val/loss=1.1323


Training:  28%|██▊       | 7/25 [03:01<07:45, 25.87s/epoch]

Epoch 7: train/accuracy=0.6179, train/loss=1.0648, val/accuracy=0.6027, val/loss=1.1161


Training:  32%|███▏      | 8/25 [03:27<07:20, 25.88s/epoch]

Epoch 8: train/accuracy=0.6332, train/loss=1.0310, val/accuracy=0.6325, val/loss=1.0359


Training:  36%|███▌      | 9/25 [03:53<06:54, 25.88s/epoch]

Epoch 9: train/accuracy=0.6464, train/loss=0.9931, val/accuracy=0.6278, val/loss=1.0544


Training:  40%|████      | 10/25 [04:19<06:29, 25.97s/epoch]

Epoch 10: train/accuracy=0.6577, train/loss=0.9590, val/accuracy=0.6332, val/loss=1.0205


Training:  44%|████▍     | 11/25 [04:45<06:03, 25.98s/epoch]

Epoch 11: train/accuracy=0.6678, train/loss=0.9309, val/accuracy=0.6465, val/loss=1.0003


Training:  48%|████▊     | 12/25 [05:11<05:37, 25.97s/epoch]

Epoch 12: train/accuracy=0.6780, train/loss=0.9006, val/accuracy=0.6488, val/loss=0.9926


Training:  52%|█████▏    | 13/25 [05:37<05:10, 25.90s/epoch]

Epoch 13: train/accuracy=0.7180, train/loss=0.7933, val/accuracy=0.6739, val/loss=0.9183


Training:  56%|█████▌    | 14/25 [06:03<04:44, 25.90s/epoch]

Epoch 14: train/accuracy=0.7330, train/loss=0.7581, val/accuracy=0.6822, val/loss=0.9093


Training:  60%|██████    | 15/25 [06:29<04:20, 26.02s/epoch]

Epoch 15: train/accuracy=0.7341, train/loss=0.7461, val/accuracy=0.6834, val/loss=0.9064


Training:  64%|██████▍   | 16/25 [06:55<03:53, 25.96s/epoch]

Epoch 16: train/accuracy=0.7411, train/loss=0.7286, val/accuracy=0.6837, val/loss=0.9003


Training:  68%|██████▊   | 17/25 [07:20<03:26, 25.82s/epoch]

Epoch 17: train/accuracy=0.7459, train/loss=0.7180, val/accuracy=0.6868, val/loss=0.8964


Training:  72%|███████▏  | 18/25 [07:47<03:01, 25.91s/epoch]

Epoch 18: train/accuracy=0.7494, train/loss=0.7102, val/accuracy=0.6854, val/loss=0.8996


Training:  76%|███████▌  | 19/25 [08:12<02:35, 25.86s/epoch]

Epoch 19: train/accuracy=0.7523, train/loss=0.7011, val/accuracy=0.6888, val/loss=0.8956


Training:  80%|████████  | 20/25 [08:38<02:09, 25.86s/epoch]

Epoch 20: train/accuracy=0.7560, train/loss=0.6914, val/accuracy=0.6878, val/loss=0.8978


Training:  84%|████████▍ | 21/25 [09:04<01:43, 25.86s/epoch]

Epoch 21: train/accuracy=0.7585, train/loss=0.6842, val/accuracy=0.6898, val/loss=0.8938


Training:  88%|████████▊ | 22/25 [09:30<01:17, 25.91s/epoch]

Epoch 22: train/accuracy=0.7609, train/loss=0.6768, val/accuracy=0.6906, val/loss=0.8938


Training:  92%|█████████▏| 23/25 [09:56<00:51, 25.86s/epoch]

Epoch 23: train/accuracy=0.7620, train/loss=0.6694, val/accuracy=0.6930, val/loss=0.8912


Training:  96%|█████████▌| 24/25 [10:22<00:25, 25.84s/epoch]

Epoch 24: train/accuracy=0.7630, train/loss=0.6635, val/accuracy=0.6941, val/loss=0.8889


Training: 100%|██████████| 25/25 [10:48<00:00, 25.88s/epoch]

Epoch 25: train/accuracy=0.7716, train/loss=0.6432, val/accuracy=0.6949, val/loss=0.8860


0,1
train/accuracy,▁▃▄▄▄▅▅▅▆▆▆▆▇▇▇▇█████████
train/loss,▆▆█▆▇▆█▅▄▅▅▄▅▂▅▃▃▄▃▄▃▃▂▃▃▂▃▁▂▁▂▂▃▁▂▁▂▂▂▃
train/lr,███████████████▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁
val/accuracy,▁▃▃▄▄▅▅▆▆▆▆▆▇████████████
val/loss,█▆▆▅▅▄▄▃▃▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train/accuracy,0.77163
train/loss,0.64325
train/lr,0.0
val/accuracy,0.6949
val/loss,0.88596


Training: 100%|██████████| 25/25 [10:50<00:00, 26.01s/epoch]
