In [1]:
# allows imports from parent folders
from prep import prep_nbook
prep_nbook()

import sklearn.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

import torch.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import torchdiffeq

from torchdyn.models import *; from torchdyn.data_utils import *
from torchdyn import *

# Image classification with Neural ODEs and variants

In this notebook we explore standard image classification on MNIST and CIFAR10 with convolutional neural ODE variants:
* Depth-invariant neural ODE
* Galerkin neural ODE (GalNODE)
* Galerkin neural ODE with integral adjoint loss

In the following notebooks we'll explore `augmentation` strategies that can be easily applied to the models below with the flexible `torchdyn` API. Here, we use simple `0-augmentation` (as in the ANODE model): the input is padded with extra zero-valued channels before it enters the neural ODE.

### Data

In [2]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
batch_size = 64
size = 28  # MNIST digits are already 28x28, so the Resize below is effectively a no-op
path_to_data = '../data/mnist_data'

all_transforms = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
])

train_data = datasets.MNIST(path_to_data, train=True, download=True,
                            transform=all_transforms)
test_data = datasets.MNIST(path_to_data, train=False,
                           transform=all_transforms)

trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Do not shuffle the evaluation set. The test metrics are an unweighted mean of per-batch
# means, and the last batch is smaller. Shuffling would change which samples land in that
# batch and make the reported numbers vary from run to run.
testloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

### Learner

In [4]:
def accuracy(preds, targets):
    """Top-1 accuracy of `preds` against `targets`, expressed as a percentage.

    preds:   per-class scores; the predicted class is the index of the max over dim 1.
    targets: integer class labels comparable with those indices.
    Returns a 0-dim float tensor in [0, 100].
    """
    predicted_class = preds.max(dim=1)[1]
    hits = (predicted_class == targets).float()
    return hits.mean() * 100

In [5]:
class Learner(pl.LightningModule):
    """PyTorch Lightning wrapper: trains `model` with cross-entropy and reports test loss/accuracy.

    Batches come from the notebook-level `trainloader` / `testloader`. They are moved
    explicitly to the notebook-level `device`, where `model` was placed with `.to(device)`.
    """

    def __init__(self, model:nn.Module, lr=1e-3):
        super().__init__()
        # NOTE(review): copies the global `settings` of whichever settings cell ran last into
        # `defaults` (presumably torchdyn's module-level default settings, pulled in by the
        # star-imports). This is hidden global state -- confirm it is still required.
        defaults.update(settings)
        self.lr = lr
        self.model = model
        self.c = 0  # NOTE(review): not read anywhere in this notebook

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)
        logits = self.model(inputs)
        loss = nn.CrossEntropyLoss()(logits, labels)
        return {'loss': loss, 'log': {'train_loss': loss}}

    def test_step(self, batch, batch_nb):
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)
        logits = self(inputs)
        batch_acc = accuracy(logits, labels)
        batch_loss = nn.CrossEntropyLoss()(logits, labels)
        return {'test_loss': batch_loss, 'test_accuracy': batch_acc}

    def test_epoch_end(self, outputs):
        # Unweighted mean over per-batch values (a smaller final batch counts the same).
        def _epoch_mean(key):
            return torch.stack([step_out[key] for step_out in outputs]).mean()

        avg_loss = _epoch_mean('test_loss')
        avg_acc = _epoch_mean('test_accuracy')
        shown = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'avg_test_accuracy': avg_acc,
                'log': shown, 'progress_bar': shown}

    def configure_optimizers(self):
        # Adam with a small L2 penalty on all parameters.
        return torch.optim.Adam(self.parameters(), weight_decay=1e-4, lr=self.lr)

    def train_dataloader(self):
        return trainloader

    def test_dataloader(self):
        return testloader

## MNIST (Depth-Invariant Neural ODE)

### Model

In [6]:
# Depth-invariant model: plain ("classic"), uncontrolled neural ODE integrated with dopri5.
settings = dict(type='classic', controlled=False, solver='dopri5')

In [7]:
# Vector field: three 3x3 convolutions over 6 channels (padding=1 keeps the spatial size),
# with tanh nonlinearities in between. There is no DepthCat layer, so presumably the dynamics
# do not depend on depth s (hence "depth-invariant").
func = DEFunc(
    nn.Sequential(
        nn.Conv2d(6, 6, 3, padding=1),
        nn.Tanh(),
        nn.Conv2d(6, 6, 3, padding=1),
        nn.Tanh(),
        nn.Conv2d(6, 6, 3, padding=1),
    )
).to(device)

In [8]:
neuralDE = NeuralDE(func, settings).to(device)

# Classifier: pad the 1-channel MNIST input with 5 extra channels (-> 6), batch-normalize,
# integrate the neural ODE, then read out 10 logits from a 1-channel 28x28 feature map.
model = nn.Sequential(
    Augmenter(augment_dims=5),
    nn.BatchNorm2d(6),
    neuralDE,
    nn.Conv2d(6, 1, 3, padding=1),
    nn.Flatten(),
    nn.Linear(28 * 28, 10),
).to(device)

In [9]:
# Weights & Biases experiment logger for this run.
logger = WandbLogger()

In [10]:
learn = Learner(model)
# Pass the W&B logger explicitly. Without `logger=`, Lightning uses its default logger and the
# `WandbLogger` created above never receives any metrics.
trainer = pl.Trainer(min_nb_epochs=1, max_nb_epochs=2, logger=logger)
trainer.fit(learn)

INFO:lightning:
   | Name               | Type        | Params
-----------------------------------------------
0  | model              | Sequential  | 8 K   
1  | model.0            | Augmenter   | 0     
2  | model.1            | BatchNorm2d | 12    
3  | model.2            | NeuralDE    | 990   
4  | model.2.defunc     | DEFunc      | 990   
5  | model.2.defunc.m   | Sequential  | 990   
6  | model.2.defunc.m.0 | Conv2d      | 330   
7  | model.2.defunc.m.1 | Tanh        | 0     
8  | model.2.defunc.m.2 | Conv2d      | 330   
9  | model.2.defunc.m.3 | Tanh        | 0     
10 | model.2.defunc.m.4 | Conv2d      | 330   
11 | model.2.adjoint    | Adjoint     | 0     
12 | model.3            | Conv2d      | 55    
13 | model.4            | Flatten     | 0     
14 | model.5            | Linear      | 7 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …






1

Two epochs are not enough for convergence. Feel free to keep training, and to try learning-rate schedules and other optimization tricks :)

In [11]:
# Run the Learner's test loop (`test_step` / `test_epoch_end`) over `testloader`.
trainer.test()



HBox(children=(FloatProgress(value=0.0, description='Testing', layout=Layout(flex='2'), max=157.0, style=Progr…

--------------------------------------------------------------------------------
TEST RESULTS
{'avg_test_accuracy': 95.87977600097656,
 'avg_test_loss': 0.1398904025554657,
 'test_loss': 0.1398904025554657}
--------------------------------------------------------------------------------



## MNIST (GalNODE)

In [22]:
# GalNODE: classic, uncontrolled neural ODE with dopri5; return only the final state.
settings = dict(type='classic', controlled=False, solver='dopri5', return_traj=False)

In [23]:
# Depth-variant vector field. `DepthCat(1)` exposes the depth variable to the following
# Galerkin layer `GalConv2d`, whose weights are expanded in a Fourier basis (4 harmonics).
# The exact mechanism lives in torchdyn -- TODO confirm. The remaining layers mirror the
# depth-invariant model.
func = DEFunc(
    nn.Sequential(
        DepthCat(1),
        GalConv2d(6, 6, 3, padding=1, expfunc=FourierExpansion, n_harmonics=4, n_eig=1),
        nn.Tanh(),
        nn.Conv2d(6, 6, 3, padding=1),
        nn.Tanh(),
        nn.Conv2d(6, 6, 3, padding=1),
    )
).to(device)

In [24]:
neuralDE = NeuralDE(func, settings).to(device)

# Same classifier layout as the depth-invariant model, with the GalNODE as its core:
# 1 -> 6 channels via augmentation, batch norm, ODE, conv readout, linear head to 10 logits.
model = nn.Sequential(
    Augmenter(augment_dims=5),
    nn.BatchNorm2d(6),
    neuralDE,
    nn.Conv2d(6, 1, 3, padding=1),
    nn.Flatten(),
    nn.Linear(28 * 28, 10),
).to(device)

In [25]:
# Weights & Biases experiment logger for the GalNODE run.
logger = WandbLogger()

In [26]:
learn = Learner(model, lr=1e-3)
# Pass the W&B logger explicitly. Without `logger=`, Lightning uses its default logger and the
# `WandbLogger` created above never receives any metrics.
trainer = pl.Trainer(min_nb_epochs=1, max_nb_epochs=2, logger=logger)
trainer.fit(learn)

INFO:lightning:
   | Name               | Type        | Params
-----------------------------------------------
0  | model              | Sequential  | 11 K  
1  | model.0            | Augmenter   | 0     
2  | model.1            | BatchNorm2d | 12    
3  | model.2            | NeuralDE    | 3 K   
4  | model.2.defunc     | DEFunc      | 3 K   
5  | model.2.defunc.m   | Sequential  | 3 K   
6  | model.2.defunc.m.0 | DepthCat    | 0     
7  | model.2.defunc.m.1 | GalConv2d   | 2 K   
8  | model.2.defunc.m.2 | Tanh        | 0     
9  | model.2.defunc.m.3 | Conv2d      | 330   
10 | model.2.defunc.m.4 | Tanh        | 0     
11 | model.2.defunc.m.5 | Conv2d      | 330   
12 | model.2.adjoint    | Adjoint     | 0     
13 | model.3            | Conv2d      | 55    
14 | model.4            | Flatten     | 0     
15 | model.5            | Linear      | 7 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …




1

In [71]:
# Run the Learner's test loop (`test_step` / `test_epoch_end`) over `testloader`.
trainer.test()



HBox(children=(FloatProgress(value=0.0, description='Testing', layout=Layout(flex='2'), max=157.0, style=Progr…

--------------------------------------------------------------------------------
TEST RESULTS
{'avg_test_accuracy': 91.18232727050781,
 'avg_test_loss': 0.3063780665397644,
 'test_loss': 0.3063780665397644}
--------------------------------------------------------------------------------



## MNIST (GalNODE, integral adjoint loss)

Next, we train an MNIST classifier with an integral loss: the cost is evaluated along the whole depth flow rather than only on the final state. As will be seen later, this improves the model's rate of convergence.

In [35]:
# Readout head shared by the model and by the integral cost: collapse 6 channels to 1 with a
# 3x3 conv, flatten the 28x28 map, and project it to 10 class logits.
predictor = nn.Sequential(
    nn.Conv2d(6, 1, 3, padding=1),
    nn.Flatten(),
    nn.Linear(28 * 28, 10),
).to(device)

In [36]:
class Cost(nn.Module):
    """Integral cost evaluated along the depth flow: ``criterion(predictor(x), y)``.

    For classification the target does not change with depth, so `y` is simply the current
    batch of labels. It starts as ``None`` and must be assigned before the cost is evaluated.
    """

    def __init__(self, criterion, predictor):
        super().__init__()
        self.y = None  # target labels, assigned externally for every batch
        self.criterion = criterion
        self.predictor = predictor

    def forward(self, s, x):
        # `s` (depth) is unused because the target is static; presumably it is part of the
        # signature torchdyn expects from a cost module.
        return self.criterion(self.predictor(x), self.y)

In [37]:
# Module-level cost instance; the integral-loss Learner's `training_step` sets `c.y` per batch.
c = Cost(criterion=nn.CrossEntropyLoss(), predictor=predictor)

In [38]:
# Integral-adjoint training: gradients are computed by torchdyn's 'integral_adjoint' backprop
# using the cost `c`, with dopri5 and the full trajectory returned.
settings = dict(
    type='classic',
    controlled=False,
    backprop_style='integral_adjoint',
    cost=c,
    solver='dopri5',
    return_traj=True,
)

In [39]:
# Same GalNODE vector field as in the previous section. `DepthCat(1)` exposes depth to the
# Fourier-expanded Galerkin layer `GalConv2d`; the exact mechanism lives in torchdyn --
# TODO confirm.
func = DEFunc(
    nn.Sequential(
        DepthCat(1),
        GalConv2d(6, 6, 3, padding=1, expfunc=FourierExpansion, n_harmonics=4, n_eig=1),
        nn.Tanh(),
        nn.Conv2d(6, 6, 3, padding=1),
        nn.Tanh(),
        nn.Conv2d(6, 6, 3, padding=1),
    )
).to(device)

In [40]:
neuralDE = NeuralDE(func, settings).to(device)

# No BatchNorm here. The readout is the shared `predictor`, the same head the cost `c` uses,
# so the integral loss and the final prediction share weights.
model = nn.Sequential(
    Augmenter(augment_dims=5),
    neuralDE,
    predictor,
).to(device)

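Redefine `Learner` to account for the `integral loss` case: before every forward pass, `training_step` sets the cost's target `c.y` to the current batch labels.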

In [42]:
class Learner(pl.LightningModule):
    """Lightning wrapper for the integral-loss experiments.

    Identical in spirit to a plain cross-entropy learner, except that `training_step` first
    writes the batch labels into the module-level cost `c`. The integral loss evaluated along
    the depth flow therefore has a target.

    Args:
        model: network to train (moved to `device` by the caller).
        lr: Adam learning rate.
        weight_decay: Adam L2 penalty.
    """

    def __init__(self, model:nn.Module, lr=1e-3, weight_decay=1e-5):
        super().__init__()
        # NOTE(review): copies the global `settings` of whichever settings cell ran last into
        # `defaults` (presumably torchdyn's module-level default settings, from the
        # star-imports). This is hidden global state -- confirm it is still required.
        defaults.update(settings)
        self.lr = lr
        self.weight_decay = weight_decay
        self.model = model

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        x, y = x.to(device), y.to(device)
        # Set the integral loss target to `y` throughout the depth flow; must happen before
        # the forward/backward pass of this batch.
        c.y = y
        y_hat = self.model(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}

    def test_step(self, batch, batch_nb):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = self(x)
        acc = accuracy(y_hat, y)
        return {'test_loss': nn.CrossEntropyLoss()(y_hat, y), 'test_accuracy': acc}

    def test_epoch_end(self, outputs):
        # Unweighted mean over per-batch values.
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['test_accuracy'] for x in outputs]).mean()
        logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'avg_test_accuracy': avg_acc,
                'log': logs, 'progress_bar': logs}

    # The original cell defined `configure_optimizers` and `train_dataloader` twice. Python
    # keeps only the last definition, so the earlier Adam(lr=0.005) variant was dead code and
    # has been removed. Behaviour is unchanged.
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

    def train_dataloader(self):
        return trainloader

    def test_dataloader(self):
        return testloader

In [43]:
logger = WandbLogger()
learn = Learner(model, lr=1e-3)
# Pass the W&B logger explicitly. Without `logger=`, Lightning uses its default logger and the
# `WandbLogger` created above never receives any metrics.
trainer = pl.Trainer(min_nb_epochs=1, max_nb_epochs=2, logger=logger)
trainer.fit(learn)

INFO:lightning:
   | Name               | Type       | Params
----------------------------------------------
0  | model              | Sequential | 11 K  
1  | model.0            | Augmenter  | 0     
2  | model.1            | NeuralDE   | 3 K   
3  | model.1.defunc     | DEFunc     | 3 K   
4  | model.1.defunc.m   | Sequential | 3 K   
5  | model.1.defunc.m.0 | DepthCat   | 0     
6  | model.1.defunc.m.1 | GalConv2d  | 2 K   
7  | model.1.defunc.m.2 | Tanh       | 0     
8  | model.1.defunc.m.3 | Conv2d     | 330   
9  | model.1.defunc.m.4 | Tanh       | 0     
10 | model.1.defunc.m.5 | Conv2d     | 330   
11 | model.1.adjoint    | Adjoint    | 0     
12 | model.2            | Sequential | 7 K   
13 | model.2.0          | Conv2d     | 55    
14 | model.2.1          | Flatten    | 0     
15 | model.2.2          | Linear     | 7 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …




1

In [116]:
# Run the Learner's test loop (`test_step` / `test_epoch_end`) over `testloader`.
trainer.test()



HBox(children=(FloatProgress(value=0.0, description='Testing', layout=Layout(flex='2'), max=157.0, style=Progr…

--------------------------------------------------------------------------------
TEST RESULTS
{'avg_test_accuracy': 89.79896545410156,
 'avg_test_loss': 0.3477240204811096,
 'test_loss': 0.3477240204811096}
--------------------------------------------------------------------------------



## CIFAR10 (Depth-Invariant Neural ODE)

In [131]:
batch_size = 64
size = 32
path_to_data = '../data/cifar10_data'

# Per-channel (RGB) normalization constants, shared by the train and test pipelines.
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training-time augmentation: random `size`x`size` crops of a 4-pixel padded image, plus
# horizontal flips.
transform_train = transforms.Compose([
    transforms.RandomCrop(size, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

train_data = datasets.CIFAR10(path_to_data, train=True, download=True,
                              transform=transform_train)
test_data = datasets.CIFAR10(path_to_data, train=False,
                             transform=transform_test)

trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Do not shuffle the evaluation set. The test metrics are an unweighted mean of per-batch
# means, and the last batch is smaller. Shuffling would make the reported numbers vary from
# run to run.
testloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

Files already downloaded and verified


### Model

In [202]:
# CIFAR10: back to plain autograd through the solver, final state only, dopri5.
settings = dict(type='classic', backprop_style='autograd', return_traj=False,
                controlled=False, solver='dopri5')

In [204]:
# Smaller depth-invariant vector field for CIFAR10: two 3x3 convolutions over 6 channels
# with a tanh in between.
func = DEFunc(
    nn.Sequential(
        nn.Conv2d(6, 6, 3, padding=1),
        nn.Tanh(),
        nn.Conv2d(6, 6, 3, padding=1),
    )
).to(device)

In [205]:
neuralDE = NeuralDE(func, settings).to(device)

# Classifier: pad the 3 RGB channels with 3 more (-> 6), integrate the ODE, then read out
# 10 logits from a 1-channel 32x32 (= 1024) feature map.
model = nn.Sequential(
    Augmenter(augment_dims=3),
    neuralDE,
    nn.Conv2d(6, 1, 3, padding=1),
    nn.Flatten(),
    nn.Linear(1024, 10),
).to(device)

In [206]:
logger = WandbLogger()
learn = Learner(model, lr=1e-3)
# Pass the W&B logger explicitly. Without `logger=`, Lightning uses its default logger and the
# `WandbLogger` created above never receives any metrics.
trainer = pl.Trainer(min_nb_epochs=1, max_nb_epochs=2, logger=logger)
trainer.fit(learn)

INFO:lightning:
   | Name               | Type       | Params
----------------------------------------------
0  | model              | Sequential | 10 K  
1  | model.0            | Augmenter  | 0     
2  | model.1            | NeuralDE   | 660   
3  | model.1.defunc     | DEFunc     | 660   
4  | model.1.defunc.m   | Sequential | 660   
5  | model.1.defunc.m.0 | Conv2d     | 330   
6  | model.1.defunc.m.1 | Tanh       | 0     
7  | model.1.defunc.m.2 | Conv2d     | 330   
8  | model.1.adjoint    | Adjoint    | 0     
9  | model.2            | Conv2d     | 55    
10 | model.3            | Flatten    | 0     
11 | model.4            | Linear     | 10 K  


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

INFO:lightning:Detected KeyboardInterrupt, attempting graceful shutdown...





1