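# Image Classification: CIFAR-10

This notebook trains four Neural ODE variants on CIFAR-10 with torchdyn and PyTorch Lightning. It then compares their test accuracy, number of function evaluations (NFEs) and parameter counts.

Models:
+ [Vanilla NODE](#NODE)
+ [Augmented NODE](#ANODE)
+ [Input-Layer NODE](#ILNODE)
+ [Second-Order NODE](#2NODE)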

In [1]:
from torchdyn.models import *; from torchdyn import *
from torchdyn.nn import DataControl, DepthCat, Augmenter, GalConv2d, Fourier

import torch
import torch.nn as nn

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
# from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.utilities.model_summary import ModelSummary

from utils import CIFARLearner, get_cifar_dloaders, MetricTracker


## Initialization

In [2]:
# Seed the Python, NumPy and PyTorch RNGs so that weight initialisation and data
# shuffling are reproducible under Restart & Run All. `workers=True` also derives
# distinct, deterministic seeds for the dataloader worker processes.
seed = 42
pl.seed_everything(seed, workers=True)

# Quick run for automated notebook validation: this flag is forwarded to
# `pl.Trainer(fast_dev_run=...)` in every training cell below.
dry_run = False

In [3]:
# Use the first CUDA device when one is present, otherwise fall back to the CPU.
# `accelerator` is the matching accelerator name passed to every pl.Trainer.
if torch.cuda.is_available():
    device, accelerator = torch.device("cuda:0"), 'gpu'
else:
    device, accelerator = torch.device("cpu"), 'cpu'
print('GPU State:', device)

GPU State: cpu


## Data Loading

In [4]:
epoch = 20  # max_epochs for every pl.Trainer below
path_to_data = './data/cifar10_data'

# Build the train/test loaders. Downloading is disabled, so the dataset must already
# exist under `path_to_data`. `size=32` is presumably the image side length (TODO confirm
# against utils.get_cifar_dloaders).
loader_kwargs = dict(batch_size=64, size=32, path=path_to_data, download=False, num_workers=8)
trainloader, testloader = get_cifar_dloaders(**loader_kwargs)

<a id = 'NODE'></a>
## Vanilla Neural ODE

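+ **vector field $f_\theta$**: a 3-layer, depth-invariant CNN ($3\times3$, $3\times3$ and $1\times1$ convolutions, with Softplus activations after the first two). Instance normalization, implemented as `GroupNorm` with one group per channel, is applied to the input and again before the final $1\times1$ convolution.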

In [5]:
# Vector field f_theta of the vanilla NODE. It maps the 3-channel state to a
# 3-channel derivative, so the state keeps its shape during integration.
dim = 62
vfunc_layers = [
    nn.GroupNorm(3, 3),                              # one group per channel
    nn.Conv2d(3, dim, 3, padding=1, bias=False),
    nn.Softplus(),
    nn.Conv2d(dim, dim, 3, padding=1, bias=False),
    nn.Softplus(),
    nn.GroupNorm(dim, dim),                          # one group per channel
    nn.Conv2d(dim, 3, 1),
]
vfunc = nn.Sequential(*vfunc_layers).to(device)

In [6]:
nde = NeuralODE(
    vfunc,
    solver='dopri5',
    sensitivity='adjoint',
    atol=1e-4,
    rtol=1e-4,
).to(device)

# Both `Augmenter(1, 0)` layers are no-ops. They exist only so that `nde` lands at
# index 2 of the Sequential, which is where CIFARLearner looks for it to extract NFEs.
model = nn.Sequential(
    Augmenter(1, 0),
    Augmenter(1, 0),
    nde,
    # head: 1x1 conv -> 4x4 adaptive average pool -> flatten (3 * 16 features) -> 10 outputs
    nn.Conv2d(3, 3, 1),
    nn.AdaptiveAvgPool2d(4),
    nn.Flatten(),
    nn.Linear(3 * 16, 10),
).to(device)

Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.


In [7]:
# Train the vanilla NODE. Metrics are logged to TensorBoard under
# lightning_logs/CIFAR/model1 and gathered by a MetricTracker (read back via cb1.collection).
learn = CIFARLearner(model, trainloader, testloader)
cb1 = MetricTracker()
logger = TensorBoardLogger(save_dir='lightning_logs/CIFAR/', name='model1')
trainer = pl.Trainer(max_epochs=epoch,
                     accelerator=accelerator,
                     # `gpus=` is deprecated in Lightning 1.7 and removed in 2.0. `devices=` is
                     # the replacement: all visible GPUs, or a single CPU process.
                     devices=max(1, torch.cuda.device_count()),
                     fast_dev_run=dry_run,
                     logger=logger,
                     callbacks=[cb1],
                     )
trainer.fit(learn)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 37.1 K
-------------------------------------
37.1 K    Trainable params
0         Non-trainable params
37.1 K    Total params
0.148     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [17]:
# Evaluate the trained vanilla NODE (presumably on the `testloader` handed to CIFARLearner).
trainer.test(model=learn)

Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.43220001459121704
        test_loss            1.590834140777588
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 1.590834140777588, 'test_accuracy': 0.43220001459121704}]

In [18]:
# Persist the metrics gathered by the MetricTracker callback, then reload them to
# confirm the file round-trips. The reloaded list is the cell's displayed output.
file = 'logs.pt'
vanilla_metrics = cb1.collection
torch.save(vanilla_metrics, file)
torch.load(file)

[{'epoch': tensor(0.),
  'train_loss': tensor(1.8231),
  'accuracy': tensor(0.3533),
  'NFE': tensor(71.2717)},
 {'test_loss': tensor(1.5908), 'test_accuracy': tensor(0.4322)}]

<a id = 'ANODE'></a>
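## Augmented Neural ODE

The 3 input channels are augmented with `dim - 3` extra channels, so the ODE state has `dim = 42` channels. The vector field maps `dim` channels to `dim` channels.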

In [None]:
# Width of the ODE state for the remaining models.
dim = 42
# Derived rather than hard-coded (== 21), so it stays in sync with `dim`. The
# second-order model splits its `dim`-channel state into two halves of this size.
dim_half = dim // 2

# Vector field for the augmented NODE: `dim`-channel state -> `dim`-channel derivative.
# GroupNorm(dim, dim) uses one group per channel, i.e. instance-style normalisation.
func = nn.Sequential(nn.GroupNorm(dim, dim),
                     nn.Conv2d(dim, dim, 3, padding=1, bias=False),
                     nn.Softplus(),
                     nn.Conv2d(dim, dim, 3, padding=1, bias=False),
                     nn.Softplus(),
                     nn.GroupNorm(dim, dim),
                     nn.Conv2d(dim, dim, 1)
                     ).to(device)

In [None]:
nde = NeuralODE(func,
                solver='dopri5',
                sensitivity='adjoint',
                atol=1e-4,
                rtol=1e-4,
                ).to(device)

# Augmented NODE: the 3 RGB input channels are augmented with `dim - 3` extra channels,
# so the ODE state has exactly the `dim` channels `func` expects. Previously this was the
# literal 39, which silently breaks if `dim` is changed.
# NOTE: the first `Augmenter(1, 0)` is a no-op. It only keeps `nde` at index 2, where
# CIFARLearner looks for it to extract NFEs.
model = nn.Sequential(Augmenter(1, 0),  # does nothing
                      Augmenter(1, dim - 3),
                      nde,
                      nn.Conv2d(dim, 6, 1),
                      nn.AdaptiveAvgPool2d(4),
                      nn.Flatten(),
                      nn.Linear(6*16, 10)).to(device)

In [None]:
# Train the augmented NODE. Metrics are logged to TensorBoard under
# lightning_logs/CIFAR/model2 and gathered by a MetricTracker (read back via cb2.collection).
learn = CIFARLearner(model, trainloader, testloader)
cb2 = MetricTracker()
logger = TensorBoardLogger(save_dir='lightning_logs/CIFAR/', name='model2')
trainer = pl.Trainer(max_epochs=epoch,
                     accelerator=accelerator,
                     # `gpus=` is deprecated in Lightning 1.7 and removed in 2.0. `devices=` is
                     # the replacement: all visible GPUs, or a single CPU process.
                     devices=max(1, torch.cuda.device_count()),
                     fast_dev_run=dry_run,
                     logger=logger,
                     callbacks=[cb2],
                     )
trainer.fit(learn)

In [None]:
# Evaluate the trained augmented NODE (presumably on the `testloader` handed to CIFARLearner).
trainer.test(model=learn)

In [None]:
# Show the last two entries of the metric log (in the vanilla run these were the final
# training-epoch record and the test record) and the parameter counts reported by
# Lightning's ModelSummary.
recent_metrics = cb2.collection[-2:]
print(recent_metrics)
print(ModelSummary(learn).param_nums)

<a id = 'ILNODE'></a>
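## Input-Layer Neural ODE

Instead of augmenting the input with extra channels, a learned $3\times3$ convolution lifts the 3 input channels to `dim` channels before the ODE is integrated.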

In [None]:
# Fresh vector field for this model. Reusing `func` from the Aug. NODE section would
# start this model from the weights already trained there, and training it would keep
# modifying the Aug. NODE's parameters. Both confound the comparison.
func_il = nn.Sequential(nn.GroupNorm(dim, dim),
                        nn.Conv2d(dim, dim, 3, padding=1, bias=False),
                        nn.Softplus(),
                        nn.Conv2d(dim, dim, 3, padding=1, bias=False),
                        nn.Softplus(),
                        nn.GroupNorm(dim, dim),
                        nn.Conv2d(dim, dim, 1)
                        ).to(device)

nde = NeuralODE(func_il,
                solver='dopri5',
                sensitivity='adjoint',
                atol=1e-4,
                rtol=1e-4,
                ).to(device)

# Input-layer NODE: a learned 3x3 convolution lifts the 3 RGB channels to `dim`
# channels before integration.
# NOTE: the single `Augmenter(1, 0)` is a no-op. It only keeps `nde` at index 2, where
# CIFARLearner looks for it to extract NFEs.
model = nn.Sequential(Augmenter(1, 0),  # does nothing
                      nn.Conv2d(3, dim, 3, padding=1, bias=False),
                      nde,
                      nn.Conv2d(dim, 6, 1),
                      nn.AdaptiveAvgPool2d(4),
                      nn.Flatten(),
                      nn.Linear(6*16, 10)).to(device)

In [None]:
# Train the input-layer NODE. Metrics are logged to TensorBoard under
# lightning_logs/CIFAR/model3 and gathered by a MetricTracker (read back via cb3.collection).
learn = CIFARLearner(model, trainloader, testloader)
cb3 = MetricTracker()
logger = TensorBoardLogger(save_dir='lightning_logs/CIFAR/', name='model3')
trainer = pl.Trainer(max_epochs=epoch,
                     accelerator=accelerator,
                     # `gpus=` is deprecated in Lightning 1.7 and removed in 2.0. `devices=` is
                     # the replacement: all visible GPUs, or a single CPU process.
                     devices=max(1, torch.cuda.device_count()),
                     fast_dev_run=dry_run,
                     logger=logger,
                     callbacks=[cb3],
                     )
trainer.fit(learn)

In [None]:
# Evaluate the trained input-layer NODE (presumably on the `testloader` handed to CIFARLearner).
trainer.test(model=learn)

In [None]:
# Show the last two entries of the metric log (in the vanilla run these were the final
# training-epoch record and the test record) and the parameter counts reported by
# Lightning's ModelSummary.
recent_metrics = cb3.collection[-2:]
print(recent_metrics)
print(ModelSummary(learn).param_nums)

<a id = '2NODE'></a>
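## Second-Order Neural ODE

A $3\times3$ convolution produces `dim_half` channels. These are augmented with another `dim_half` channels to form the `dim`-channel state of a second-order ODE (`order=2`).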

In [None]:

# Dedicated vector field for the second-order model. In torchdyn's higher-order forward
# (`DEFunc` with `order=2`; verify against the installed version), the `dim`-channel
# state is split into two halves of `dim_half` channels. The returned derivative is the
# second half of the state concatenated with the vector-field output, so the field must
# output `dim_half` channels.
# Reusing the `dim -> dim` field `func` would make the derivative wider than the state.
# It would also start from, and keep modifying, the weights trained in the Aug. NODE section.
assert dim == 2 * dim_half, "second-order state must split into two equal halves"
func_2nd = nn.Sequential(nn.GroupNorm(dim, dim),
                         nn.Conv2d(dim, dim, 3, padding=1, bias=False),
                         nn.Softplus(),
                         nn.Conv2d(dim, dim, 3, padding=1, bias=False),
                         nn.Softplus(),
                         nn.GroupNorm(dim, dim),
                         nn.Conv2d(dim, dim_half, 1)
                         ).to(device)

nde = NeuralODE(func_2nd,
                solver='dopri5',
                sensitivity='adjoint',
                atol=1e-4,
                rtol=1e-4,
                order=2,
                ).to(device)

# Lift the 3 RGB channels to `dim_half` channels with a learned 3x3 convolution, then
# augment with another `dim_half` channels to form the `dim`-channel second-order state.
# `nde` sits at index 2, where CIFARLearner looks for it to extract NFEs.
model = nn.Sequential(nn.Conv2d(3, dim_half, 3, padding=1, bias=False),
                      Augmenter(1, dim_half),
                      nde,
                      nn.Conv2d(dim, 6, 1),
                      nn.AdaptiveAvgPool2d(4),
                      nn.Flatten(),
                      nn.Linear(6*16, 10)).to(device)

In [None]:
# Train the second-order NODE. Metrics are logged to TensorBoard under
# lightning_logs/CIFAR/model4 and gathered by a MetricTracker (read back via cb4.collection).
learn = CIFARLearner(model, trainloader, testloader)
cb4 = MetricTracker()
logger = TensorBoardLogger(save_dir='lightning_logs/CIFAR/', name='model4')
trainer = pl.Trainer(max_epochs=epoch,
                     accelerator=accelerator,
                     # `gpus=` is deprecated in Lightning 1.7 and removed in 2.0. `devices=` is
                     # the replacement: all visible GPUs, or a single CPU process.
                     devices=max(1, torch.cuda.device_count()),
                     fast_dev_run=dry_run,
                     logger=logger,
                     callbacks=[cb4],
                     )
trainer.fit(learn)

In [None]:
# Evaluate the trained second-order NODE (presumably on the `testloader` handed to CIFARLearner).
trainer.test(model=learn)

In [None]:
# Show the last two entries of the metric log (in the vanilla run these were the final
# training-epoch record and the test record) and the parameter counts reported by
# Lightning's ModelSummary.
recent_metrics = cb4.collection[-2:]
print(recent_metrics)
print(ModelSummary(learn).param_nums)