In [1]:
import sys
sys.path.append('../')

In [2]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

import torch.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.metrics.functional import accuracy
import torchdiffeq

from torchdyn.models import *; from torchdyn import *

# Image Classification

In this notebook we explore standard image classification on MNIST and CIFAR10 with convolutional neural ODE variants.
* Depth-invariant neural ODE
* Galerkin neural ODE (GalNODE)
* Galerkin neural ODE with adjoint loss

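In the following notebooks we'll explore `augmentation` strategies that can easily be applied to the models below through the flexible `torchdyn` API. Here, we use simple zero-augmentation (`0-augmentation`, the ANODE model): the input image is padded with extra channels of zeros before it enters the neural ODE.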

In [3]:
# Prefer the first CUDA device when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
batch_size = 64
size = 28  # MNIST digits are 28x28, so Resize(size) is effectively a no-op here
path_to_data = '../data/mnist_data'
SEED = 0

# Seed the global torch RNG so weight initialisation and shuffling are
# reproducible under Restart & Run All.
torch.manual_seed(SEED)

all_transforms = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
])

train_data = datasets.MNIST(path_to_data, train=True, download=True,
                            transform=all_transforms)
test_data = datasets.MNIST(path_to_data, train=False, download=True,
                           transform=all_transforms)

trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Evaluation data is not shuffled. The trainers use `limit_test_batches` < 1,
# which only scores the leading batches; a fixed order keeps that subset (and
# the reported test metrics) stable across runs.
testloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

Next, we define the **Learner**, a PyTorch Lightning module that wraps the model and handles training and testing:

In [28]:
class Learner(pl.LightningModule):
    """PyTorch Lightning module that trains and tests an image classifier.

    Args:
        model: classifier mapping an image batch to class logits.
            ``training_step`` reads and resets ``model[2].nfe``, so the model is
            expected to be indexable with the neural DE at position 2 (as in
            the ``nn.Sequential`` models built in this notebook).
        lr: learning rate handed to Adam in ``configure_optimizers``
            (defaults to 1e-3, the previously hard-coded value).
    """

    def __init__(self, model: nn.Module, lr: float = 1e-3):
        super().__init__()
        self.lr = lr
        self.model = model
        self.iters = 0.  # number of training steps taken so far

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """One optimisation step; reports loss, accuracy and solver NFE."""
        self.iters += 1.
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = self.model(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)

        # Fractional epoch count; `loader_len` is set in `train_dataloader`.
        epoch_progress = self.iters / self.loader_len
        acc = accuracy(y_hat, y)
        # Read and reset the function-evaluation counter of *this* learner's
        # neural DE. Using self.model (not the notebook-global `model`) matters
        # because later cells rebind `model` to a different network.
        neural_de = self.model[2]
        nfe = neural_de.nfe
        neural_de.nfe = 0
        tqdm_dict = {'train_loss': loss, 'accuracy': acc, 'NFE': nfe}
        logs = {'train_loss': loss, 'epoch': epoch_progress}
        return {'loss': loss, 'progress_bar': tqdm_dict, 'log': logs}

    def test_step(self, batch, batch_nb):
        """Per-batch test loss and accuracy."""
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = self(x)
        acc = accuracy(y_hat, y)
        return {'test_loss': nn.CrossEntropyLoss()(y_hat, y), 'test_accuracy': acc}

    def test_epoch_end(self, outputs):
        """Average the per-batch test metrics over the whole test run."""
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['test_accuracy'] for x in outputs]).mean()
        logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'avg_test_accuracy': avg_acc,
                'log': logs, 'progress_bar': logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-4)

    def train_dataloader(self):
        # Remember the number of batches per epoch for `epoch_progress`.
        self.loader_len = len(trainloader)
        return trainloader

    def test_dataloader(self):
        # Record the length of the test loader itself (not the training loader).
        self.test_loader_len = len(testloader)
        return testloader

## Depth-Invariant Neural ODE 

In [29]:
# ODE vector field: three 3x3 convolutions acting on the 6-channel augmented
# state (1 grayscale image channel + 5 channels added by the Augmenter below).
func = nn.Sequential(nn.Conv2d(6, 6, 3, padding=1),
                     nn.Softplus(),
                     nn.Conv2d(6, 6, 3, padding=1),
                     nn.Softplus(),
                     nn.Conv2d(6, 6, 3, padding=1)
                     ).to(device)

# RK4 solver over s_span = 30 points in [0, 1]; gradients via the adjoint method.
neuralDE = NeuralDE(func,
                    solver='rk4',
                    sensitivity='adjoint',
                    s_span=torch.linspace(0, 1, 30)).to(device)

model = nn.Sequential(Augmenter(augment_dims=5),
                      nn.BatchNorm2d(6),
                      neuralDE,
                      nn.Conv2d(6, 1, 3, padding=1),
                      nn.Flatten(),
                      nn.Dropout(p=0.8),
                      # Derive the head's input width from the image `size`
                      # used by the data pipeline instead of hard-coding 28*28.
                      nn.Linear(size * size, 10)).to(device)


logger = WandbLogger() # feel free to comment out or use a different logging scheme :)

In [30]:
learn = Learner(model)
trainer = pl.Trainer(max_epochs=20,
                     logger=logger,
                     benchmark=True,
                     limit_test_batches=0.25,
                     # Only request a GPU when one exists (mirroring the `device`
                     # fallback) so the notebook also runs on CPU-only machines.
                     gpus=1 if torch.cuda.is_available() else 0,
                     progress_bar_refresh_rate=1
                     )

trainer.fit(learn)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]



  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 8 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1

A short training run is not enough to reach the best accuracy. Feel free to keep training and to use all kinds of learning-rate scheduling and optimization tricks :)

In [31]:
trainer.test(model=learn)  # evaluate the trained learner on the test loader



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
TEST RESULTS
{'avg_test_accuracy': tensor(0.9804, device='cuda:0'),
 'avg_test_loss': tensor(0.0550, device='cuda:0'),
 'test_loss': tensor(0.0550, device='cuda:0')}
--------------------------------------------------------------------------------



## Galerkin Neural ODE (GalNODE) on MNIST

In [22]:
# Solver configuration for the dict-style NeuralDE constructor used below.
settings = dict(type='classic', controlled=False, solver='dopri5', return_traj=False)

In [23]:
# Galerkin vector field. DepthCat(1) presumably concatenates the depth variable
# along dim 1 so that the Fourier-expanded GalConv2d (4 harmonics) can make its
# weights depth-dependent; the remaining layers are ordinary convolutions.
_galerkin_layers = [
    DepthCat(1),
    GalConv2d(6, 6, 3, padding=1, expfunc=FourierExpansion, n_harmonics=4, n_eig=1),
    nn.Tanh(),
    nn.Conv2d(6, 6, 3, padding=1),
    nn.Tanh(),
    nn.Conv2d(6, 6, 3, padding=1),
]
func = DEFunc(nn.Sequential(*_galerkin_layers)).to(device)

In [24]:
# NOTE(review): this uses the dict-style `NeuralDE(func, settings)` constructor,
# whereas the depth-invariant model above passes keyword arguments
# (solver=..., sensitivity=..., s_span=...). Confirm both match the installed
# torchdyn version.
neuralDE = NeuralDE(func, settings).to(device)

model = nn.Sequential(Augmenter(augment_dims=5),
                      nn.BatchNorm2d(6),
                      neuralDE,
                      nn.Conv2d(6, 1, 3, padding=1),
                      nn.Flatten(),
                      # Head width follows the image `size` from the data cell.
                      nn.Linear(size * size, 10)).to(device)

In [25]:
logger = WandbLogger()  # Weights & Biases logger for the GalNODE run; swap in a different logger if preferred

In [26]:
learn = Learner(model)  # Learner uses a learning rate of 1e-3 by default
# `min_nb_epochs`/`max_nb_epochs` are removed Lightning arguments; use the
# current names (as the first trainer does) and actually attach the logger.
trainer = pl.Trainer(min_epochs=1,
                     max_epochs=2,
                     logger=logger,
                     gpus=1 if torch.cuda.is_available() else 0)
trainer.fit(learn)

INFO:lightning:
   | Name               | Type        | Params
-----------------------------------------------
0  | model              | Sequential  | 11 K  
1  | model.0            | Augmenter   | 0     
2  | model.1            | BatchNorm2d | 12    
3  | model.2            | NeuralDE    | 3 K   
4  | model.2.defunc     | DEFunc      | 3 K   
5  | model.2.defunc.m   | Sequential  | 3 K   
6  | model.2.defunc.m.0 | DepthCat    | 0     
7  | model.2.defunc.m.1 | GalConv2d   | 2 K   
8  | model.2.defunc.m.2 | Tanh        | 0     
9  | model.2.defunc.m.3 | Conv2d      | 330   
10 | model.2.defunc.m.4 | Tanh        | 0     
11 | model.2.defunc.m.5 | Conv2d      | 330   
12 | model.2.adjoint    | Adjoint     | 0     
13 | model.3            | Conv2d      | 55    
14 | model.4            | Flatten     | 0     
15 | model.5            | Linear      | 7 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …




1

In [71]:
trainer.test(learn)  # pass the learner explicitly, as in the depth-invariant section



HBox(children=(FloatProgress(value=0.0, description='Testing', layout=Layout(flex='2'), max=157.0, style=Progr…

--------------------------------------------------------------------------------
TEST RESULTS
{'avg_test_accuracy': 91.18232727050781,
 'avg_test_loss': 0.3063780665397644,
 'test_loss': 0.3063780665397644}
--------------------------------------------------------------------------------

