# Image Classification: MNIST

Models:
+ Model 1: Vanilla NODE
+ Model 2: Augmented NODE
+ Model 3: Input-Layer
+ Model 4: 2nd-Order

In [1]:
from torchdyn.models import *; from torchdyn import *
from torchdyn.nn import Augmenter

import torch
import torch.nn as nn

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
# from pytorch_lightning.callbacks import Callback, ModelCheckpoint

from utils import Learner, get_MNIST_dloaders, MetricTracker

## Initialization

In [2]:
# quick run for automated notebook validation
dry_run = True

In [3]:
# GPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
print('GPU State:', device)

GPU State: cuda:0


## Data Loading

In [4]:
epoch = 10
path_to_data = './data/mnist_data'

trainloader, testloader = get_MNIST_dloaders(batch_size=64, size=28, path=path_to_data, download=False, num_workers=8)
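For scale, here is how many iterations one pass over the data takes with `batch_size=64`. This assumes the standard MNIST split (60,000 training and 10,000 test images) and that `get_MNIST_dloaders`, a local helper not shown here, keeps the last partial batch. With `dry_run = True`, `fast_dev_run` runs only one of these batches.

```python
import math

train_size, test_size = 60_000, 10_000  # standard MNIST split (assumed)
batch_size = 64

train_batches = math.ceil(train_size / batch_size)  # the last batch is partial
test_batches = math.ceil(test_size / batch_size)
last_train_batch = train_size - (train_batches - 1) * batch_size
print(train_batches, test_batches, last_train_batch)  # 938 157 32
```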

<a id = 'NODE'></a>
## Vanilla Neural ODE

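+ **vector field $f_\theta$**: a 3-layer, depth-invariant CNN (it does not take the depth variable $t$ as input), with instance normalization, implemented as `nn.GroupNorm` with one group per channel, applied to the input and before the final $1\times1$ convolution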

In [5]:
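dim = 42  # hidden channels of the vector field
vfunc = nn.Sequential(
    # input state x: (batch, 1, 28, 28), e.g. (64, 1, 28, 28)
    nn.GroupNorm(1, 1),                             # normalize the single input channel
    nn.Conv2d(1, dim, 3, padding=1, bias=False),    # 3x3 conv, keeps 28x28
    nn.Softplus(),
    nn.Conv2d(dim, dim, 3, padding=1, bias=False),  # 3x3 conv, keeps 28x28
    nn.Softplus(),
    nn.GroupNorm(dim, dim),                         # one group per channel = instance norm
    nn.Conv2d(dim, 1, 1),                           # 1x1 conv back to 1 channel: f(x) has the shape of x
).to(device)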
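Since $f_\theta$ defines the dynamics of the state, it must return a tensor with the same shape as its input. The short check below traces channels and spatial size through the three convolutions of `vfunc` using the output-size formula from the `nn.Conv2d` documentation; the normalization and Softplus layers do not change the shape, so they are left out.

```python
def conv2d_out(size, kernel, padding=0, stride=1, dilation=1):
    """Spatial output size of nn.Conv2d along one dimension (PyTorch docs formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# (out_channels, kernel_size, padding) of the three convolutions in vfunc
convs = [(42, 3, 1), (42, 3, 1), (1, 1, 0)]

channels, side = 1, 28  # one MNIST image: (1, 28, 28)
for out_channels, kernel, padding in convs:
    channels, side = out_channels, conv2d_out(side, kernel, padding)
print(channels, side)  # 1 28
```

Dropping `padding=1` from a 3x3 convolution would shrink each side by 2, and the vector field would no longer map the state space to itself.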

In [6]:
nde = NeuralODE(vfunc, 
               solver='dopri5',
               sensitivity='adjoint',
               atol=1e-4,
               rtol=1e-4
               ).to(device)
# NOTE: the two no-op `Augmenter` layers only keep `nde` at index 2 of the `nn.Sequential`; the `Learner` uses that index to extract the number of function evaluations (NFEs).
model = nn.Sequential(
        Augmenter(1, 0), # does nothing
        Augmenter(1, 0), # does nothing
        nde,
        nn.Conv2d(1, 3, 1),
        nn.AdaptiveAvgPool2d(4),
        nn.Flatten(),                     
        nn.Linear(3*16, 10)
    ).to(device)

Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.
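`solver='dopri5'` is an adaptive Runge-Kutta method: each step produces two estimates of different order, and the step is accepted only if their difference stays below a tolerance of the form `atol + rtol * |x|`. The pure-Python sketch below shows that accept/reject loop and the resulting NFE count on a scalar ODE. It swaps in the much simpler Heun-Euler 2(1) pair for the Dormand-Prince 5(4) pair, and the controller constants (safety factor 0.9, step-change limits 0.2 and 5) are illustrative assumptions, not torchdyn's values.

```python
import math

def heun_euler_adaptive(f, t0, y0, t1, atol=1e-4, rtol=1e-4, h0=0.1):
    """Integrate the scalar ODE dy/dt = f(t, y) from t0 to t1 with an embedded
    Heun (order 2) / Euler (order 1) pair and adaptive step size.
    Returns (y(t1), number of function evaluations)."""
    t, y, h, nfe = t0, y0, h0, 0
    while t < t1:
        h = min(h, t1 - t)
        last = h == t1 - t                    # this step would land exactly on t1
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        nfe += 2                              # rejected steps still cost evaluations
        y_low = y + h * k1                    # Euler estimate
        y_high = y + 0.5 * h * (k1 + k2)      # Heun estimate
        err = abs(y_high - y_low)             # local error estimate
        tol = atol + rtol * max(abs(y), abs(y_high))
        ratio = err / tol
        if ratio <= 1.0:                      # accept: advance with the higher-order estimate
            t, y = (t1 if last else t + h), y_high
        # Shrink or grow h; exponent 1/2 because the error estimate is O(h^2).
        h *= min(5.0, max(0.2, 0.9 * max(ratio, 1e-10) ** -0.5))
    return y, nfe

y1, nfe = heun_euler_adaptive(lambda t, y: -y, 0.0, 1.0, 1.0)
print(abs(y1 - math.exp(-1.0)), nfe)
```

Tighter `atol`/`rtol` force smaller steps and therefore more NFEs, which is why the NFE count is a useful measure of a model's inference cost.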

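`sensitivity='adjoint'` computes gradients by solving a second ODE backward in time instead of backpropagating through every solver step, so memory does not grow with the number of steps. For $\dot y = f(y, \theta)$ and loss $L(y(T))$, the adjoint $a(t) = \partial L / \partial y(t)$ satisfies $\dot a = -a\,\partial f/\partial y$, and $dL/d\theta = \int_0^T a\,\partial f/\partial \theta\,dt$. The pure-Python sketch below applies this to the toy problem $\dot y = \theta y$, $L = y(T)$, whose exact gradient is $T y_0 e^{\theta T}$. The toy ODE and the fixed-step RK4 solver are assumptions for illustration, not torchdyn's implementation.

```python
import math

def rk4(f, state, t0, t1, steps):
    """Fixed-step classical RK4 for a list-valued state."""
    h = (t1 - t0) / steps
    t = t0
    for _ in range(steps):
        k1 = f(t, state)
        k2 = f(t + h / 2, [s + h / 2 * k for s, k in zip(state, k1)])
        k3 = f(t + h / 2, [s + h / 2 * k for s, k in zip(state, k2)])
        k4 = f(t + h, [s + h * k for s, k in zip(state, k3)])
        state = [s + h / 6 * (a + 2 * b + 2 * c + d)
                 for s, a, b, c, d in zip(state, k1, k2, k3, k4)]
        t += h
    return state

def forward(theta, y0, T, steps=100):
    """y(T) for dy/dt = theta * y."""
    (y_T,) = rk4(lambda t, z: [theta * z[0]], [y0], 0.0, T, steps)
    return y_T

def adjoint_grad(theta, y0, T, steps=100):
    """dL/dtheta for L = y(T), computed with the adjoint method."""
    y_T = forward(theta, y0, T, steps)
    # Backward pass in reversed time s = T - t with state z = (y, a, g):
    #   y is re-solved backward from y(T) (nothing from the forward pass is stored),
    #   a = dL/dy(t) starts at a(T) = 1,
    #   g accumulates the integral of a * df/dtheta = a * y.
    def backward(s, z):
        y, a, g = z
        return [-theta * y, theta * a, a * y]
    _, _, grad = rk4(backward, [y_T, 1.0, 0.0], 0.0, T, steps)
    return grad

theta, y0, T = 0.5, 2.0, 1.0
print(adjoint_grad(theta, y0, T), T * y0 * math.exp(theta * T))
```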

In [None]:
learn = Learner(model, trainloader, testloader, device)
cb = MetricTracker()
logger = TensorBoardLogger(save_dir='lightning_logs/MNIST/', name='model1')
trainer = pl.Trainer(max_epochs=1,
                     accelerator=accelerator,
                     devices=torch.cuda.device_count() or 1,  # replaces the deprecated `gpus=` argument
                     fast_dev_run=dry_run,
                     logger=logger,
                     callbacks=[cb]
                     )
trainer.fit(learn)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in fast_dev_run mode: will run a full train, val, test and prediction loop using 1 batch(es).
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_predict_batches=1)` was configured so 1 batch will be used.
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type             | Params
----------------------------------------------
0 | model    | Sequential       | 16.9 K
1 | loss     | CrossEntropyLoss | 0     
2 | accuarcy | Accuracy         | 0     
----------------------------------------------
16.9 K    Trainable params
0         Non-tra

Training: 0it [00:00, ?it/s]
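The 16.9 K trainable parameters in the summary above can be checked by hand. The sketch below counts them layer by layer with the standard formulas: a k×k `nn.Conv2d` has c_in·c_out·k² weights plus c_out biases, an affine `nn.GroupNorm` over C channels has 2C parameters, and an `nn.Linear` has n_in·n_out + n_out. It assumes that the no-op `Augmenter` layers and the `NeuralODE` wrapper add no parameters of their own.

```python
def conv2d_params(c_in, c_out, k, bias=True):
    return c_in * c_out * k * k + (c_out if bias else 0)

def group_norm_params(num_channels):
    # affine=True (the default): one weight and one bias per channel
    return 2 * num_channels

def linear_params(n_in, n_out):
    return n_in * n_out + n_out

dim = 42
vfunc_params = (group_norm_params(1)
                + conv2d_params(1, dim, 3, bias=False)
                + conv2d_params(dim, dim, 3, bias=False)
                + group_norm_params(dim)
                + conv2d_params(dim, 1, 1))
# Head: 1x1 conv to 3 channels, adaptive pooling to 4x4 (no parameters), linear classifier
head_params = conv2d_params(1, 3, 1) + linear_params(3 * 4 * 4, 10)
total = vfunc_params + head_params
print(vfunc_params, head_params, total)  # 16383 496 16879
```

Almost all of the budget sits in the 42→42 3x3 convolution (15,876 weights), so `dim` is the main knob for model size.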

In [None]:
trainer.test(learn)

In [None]:
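import os
log_file = './lightning_logs/MNIST/model1/logs.pt'
os.makedirs(os.path.dirname(log_file), exist_ok=True)  # the directory may not exist yet, e.g. after a fast_dev_run
torch.save(cb.collection, log_file)  # metrics collected by the MetricTracker callback
torch.load(log_file)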