# Neural ODEs from scratch

Integrators are the building block of every neural ODE and can be found under `utils.integrators`. An integrator advances the state over a single step, from time t to t + dt. Here are some basic examples:

In [19]:
# 1. The forward Euler method is the simplest first-order method (analogous to a ResNet residual update x + F(x))
class Euler():
    """Explicit (forward) Euler integrator: y_next = y + dt * f(t, y)."""
    name = "Euler"
    order = 1

    @staticmethod
    def step(f, t, dt, y):
        """Advance ``y`` from ``t`` to ``t + dt`` using one evaluation of ``f``.

        Declared static so it works both when called on the class
        (``Euler.step(...)``, as the solver does) and on an instance
        (``Euler().step(...)``).

        Args:
            f: Right-hand side, called as ``f(t, y)``.
            t: Current time.
            dt: Step size.
            y: Current state (a number or tensor supporting ``+`` and ``*``).

        Returns:
            The state estimate at ``t + dt``.
        """
        return y + dt * f(t, y)

# 2. The Modified Euler (Heun's) method averages the slope at the start and at an Euler-predicted end point
class ModifiedEuler():
    """Second-order explicit integrator (Heun / explicit trapezoidal rule).

    It uses two evaluations of ``f`` per step.
    """
    name = "ModifiedEuler"
    order = 2

    @staticmethod
    def step(f, t, dt, y):
        """Advance ``y`` from ``t`` to ``t + dt`` using two evaluations of ``f``.

        Declared static so it works on the class and on an instance alike.

        Args:
            f: Right-hand side, called as ``f(t, y)``.
            t: Current time.
            dt: Step size.
            y: Current state.

        Returns:
            The state estimate at ``t + dt``.
        """
        k1 = f(t, y)                     # slope at the start of the step
        k2 = f(t + dt, y + dt * k1)      # slope at the Euler-predicted end point
        return y + dt * (k1 + k2) / 2

# 3. The classic Runge-Kutta 4 method combines 4 estimates of the gradient
class RungeKutta4():
    """Classic fourth-order Runge-Kutta integrator.

    It uses four evaluations of ``f`` per step.
    """
    name = "RungeKutta4"
    order = 4

    @staticmethod
    def step(f, t, dt, y):
        """Advance ``y`` from ``t`` to ``t + dt`` with one RK4 step.

        Declared static so it works on the class and on an instance alike.

        Args:
            f: Right-hand side, called as ``f(t, y)``.
            t: Current time.
            dt: Step size.
            y: Current state.

        Returns:
            The state estimate at ``t + dt``.
        """
        k1 = f(t, y)
        k2 = f(t + dt / 2,  y + dt * k1 / 2)
        k3 = f(t + dt / 2,  y + dt * k2 / 2)
        k4 = f(t + dt,      y + dt * k3)
        # Weighted average of the four slopes (weights 1, 2, 2, 1)
        return y + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6

Solvers, in turn, chain multiple integrator steps together to integrate a trajectory and solve the initial value problem. For our use case we only use fixed-step solvers, but they could be replaced by adaptive-step solvers.

In [4]:
class FixedStepSolver():
    """Integrate an ODE from ``t0`` to ``t1`` with a precomputed list of step sizes.

    Args:
        f: Right-hand side, passed to ``method.step`` and called there as ``f(t, y)``.
        t0: Start time.
        t1: End time.
        method: Integrator exposing ``step(f, t, dt, y)`` and an ``order`` attribute.
        steps: Number of equal steps covering [t0, t1]. Takes precedence over
            ``step_size`` when both are given.
        step_size: Fixed step length. A shorter final step covers any remainder.
        verbose: Print the step sizes and an estimate of the gradient evaluations.
    """

    def __init__(self, f, t0, t1, method, steps=None, step_size=None, verbose=False):
        self.name = "FixedStepSolver"
        self.f = f
        self.t0 = t0
        self.t1 = t1
        self.method = method
        self.exp = None  # not used by this class; kept for attribute compatibility

        span = t1 - t0

        # Define the step sizes h to go from t0 to t1
        assert steps or step_size, "Either steps or step size should be defined!"
        if steps:
            self.hs = [span / steps] * steps
        else:
            assert step_size <= span, "Step size should be smaller than integration time!"
            self.hs = [step_size] * int(span / step_size)
            # Add the residual as a final, shorter step, but ignore floating-point noise.
            # Example: span=1.0, step_size=0.1 leaves a ~1e-16 remainder, which would
            # otherwise cost a whole extra (pointless) integrator step.
            residual = span - sum(self.hs)
            if residual > 1e-9 * step_size:
                self.hs.append(residual)

        if verbose:
            print("This solver will be using the following time deltas:", self.hs)
            # Estimate assumes one f-evaluation per order (stages == order). That holds for
            # Euler, Heun and classic RK4, but not for e.g. higher-order Runge-Kutta schemes.
            print("This solver will require", self.method.order * len(self.hs), "gradient evaluations")

    def integrate(self, y, reset=False):
        """Integrate state ``y`` from ``t0`` through every step in ``self.hs``.

        Args:
            y: Initial state at ``t0``.
            reset: Accepted for API compatibility; currently unused.

        Returns:
            The state after the last step (i.e. at ``t1``, up to rounding).
        """
        t = self.t0
        for h in self.hs:
            y = self.method.step(self.f, t, h, y)
            t += h
        return y

# Model

We can turn a ResNet (image classifier) into a Neural ODE with a few small changes:
1. Downsample the image before feeding it to the ODE net, instead of inside the residual blocks.
2. Use the GradientNet to estimate the gradient at each time step, and pass it to the solver as the function f. The forward pass becomes: `x = self.solver.integrate(x)`.
3. Concatenate the time t as an extra input channel in each forward call of the GradientNet.


In [7]:
# Inspired by:
# https://pytorch.org/docs/stable/_modules/torchvision/models/resnet.html#resnet18
# https://github.com/rtqichen/torchdiffeq/blob/master/examples/odenet_mnist.py

import sys, time
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.training.tensor_ops import Flatten

class GradientNet(nn.Module):
    """Time-conditioned convolutional network used as the ODE right-hand side f(t, x).

    It has two conv -> BatchNorm -> ReLU stages. Before each stage, the time ``t``
    is appended to the input as an extra constant channel. Each conv therefore
    maps ``channels + 1`` input channels to ``channels`` output channels, and
    keeps H/W (3x3 kernel, stride 1, padding 1).
    """

    def __init__(self, channels):
        super().__init__()
        # Attribute names and Sequential layout are unchanged, so state_dict keys match.
        self.block1 = self._make_block(channels)
        self.block2 = self._make_block(channels)

    @staticmethod
    def _make_block(channels):
        # The "+ 1" input channel carries the broadcast time value.
        return nn.Sequential(
            nn.Conv2d(channels + 1, channels, 3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )

    def concat_time(self, t, x):
        """Return ``x`` with one extra channel (dim 1) filled with the value ``t``."""
        time_channel = torch.ones_like(x[:, :1, :, :]) * t
        return torch.cat([x, time_channel], 1)

    def forward(self, t, x):
        """Evaluate f(t, x), re-appending the time channel before each stage."""
        for stage in (self.block1, self.block2):
            x = stage(self.concat_time(t, x))
        return x

class RKNet(nn.Module):
    """Neural-ODE image classifier: downsample -> integrate an ODE -> classify.

    Args:
        solver: Solver class, constructed as
            ``solver(f, t0, t1, integrator, steps=steps)``.
        integrator: Integration method handed to the solver.
        t0: Start of the integration interval.
        t1: End of the integration interval.
        classes: Number of output classes.
        steps: Number of fixed solver steps over [t0, t1] (default 1, as before).
        lr: Adam learning rate (default 0.1, as before).
    """

    def __init__(self, solver, integrator, t0, t1, classes, steps=1, lr=0.1):
        super(RKNet, self).__init__()

        # First downsample the image to a size the ODE solver can handle.
        # Shape comments assume a single-channel 28x28 input (MNIST).
        self.sampling = nn.Sequential(
            nn.Conv2d(1, 64, 3, stride=1),              # [batch_size, 1, 28, 28] -> [batch_size, 64, 26, 26] (no padding)
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, stride=2, padding=1),  # [batch_size, 64, 26, 26] -> [batch_size, 64, 13, 13]
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, stride=2, padding=1),  # [batch_size, 64, 13, 13] -> [batch_size, 64, 6, 6]
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # The ODE right-hand side f(t, x), integrated from t0 to t1 by the solver.
        # NOTE(review): GradientNet contains BatchNorm. Multi-stage integrators call it
        # several times per step, so running statistics update once per evaluation. Confirm intended.
        self.gradient = GradientNet(channels=64)
        self.solver = solver(self.gradient, t0, t1, integrator, steps=steps)

        # Classification head: global average pool -> flatten -> linear
        self.output = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            Flatten(),
            nn.Linear(64, classes)
        )
        self.loss_module = nn.CrossEntropyLoss()

        # Move to the GPU *before* constructing the optimizer, as recommended by torch.optim.
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.cuda()
            self.device = 'cuda'

        # Define an optimizer so it can be driven from the outside
        params = list(self.sampling.parameters()) + list(self.gradient.parameters()) + list(self.output.parameters())
        self.optimizer = torch.optim.Adam(params, lr=lr)

        # Print the number of trainable parameters
        print("This model is using %d parameters" % (sum(p.numel() for p in params if p.requires_grad)))

    def forward(self, x):
        """Downsample ``x``, integrate it through the ODE, and return class logits."""
        x = self.sampling(x)
        x = self.solver.integrate(x)
        return self.output(x)

In [9]:
# Build the MNIST train/test data loaders
from utils.training.datasets import get_mnist
BATCH_SIZE = 128
train_loader, test_loader = get_mnist(batch_size=BATCH_SIZE)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to .data/mnist\MNIST\raw\train-images-idx3-ubyte.gz


100.1%

Extracting .data/mnist\MNIST\raw\train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to .data/mnist\MNIST\raw\train-labels-idx1-ubyte.gz


113.5%

Extracting .data/mnist\MNIST\raw\train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to .data/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz


100.4%

Extracting .data/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to .data/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz


180.4%

Extracting .data/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [14]:
# Build the Neural ODE classifier: fixed-step RK4 over t in [0, 1], 10 output classes.
# NOTE: these imports shadow the FixedStepSolver / RungeKutta4 classes defined earlier in
# this notebook, so the utils versions are the ones actually used here.
from utils.solvers.fixed_step import FixedStepSolver
from utils.integrators.simple import RungeKutta4
model = RKNet(
    solver=FixedStepSolver,
    integrator=RungeKutta4,
    t0=0.0,
    t1=1.0,
    classes=10,
)

This model is using 208138 parameters


In [16]:
# Running averages for the per-epoch statistics
from utils.training.training_ops import Average, accuracy
train_loss = Average()
test_loss = Average()
test_acc = Average()

In [18]:
epochs = 20
for e in range(epochs):
    # --- Training pass ---
    model.train()
    for data, target in train_loader:
        # Move data to the model's device, run the forward pass and compute the loss
        data, target = data.to(model.device), target.to(model.device)
        pred = model(data)
        loss = model.loss_module(pred, target)

        # Take an optimizer step
        model.optimizer.zero_grad()
        loss.backward()
        model.optimizer.step()
        train_loss.update(loss.item())

    # --- Evaluation pass: no gradients needed, so skip building the autograd graph ---
    model.eval()
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(model.device), target.to(model.device)
            pred = model(data)
            loss = model.loss_module(pred, target)
            test_loss.update(loss.item())
            test_acc.update(accuracy(pred, target))

    print('Epoch: %d / %d | train loss: %.3f | test loss: %.3f | test acc: %.3f' % (e + 1, epochs, train_loss.eval(), test_loss.eval(), 100 * test_acc.eval()))

    # Reset statistics each epoch
    train_loss.reset()
    test_loss.reset()
    test_acc.reset()

Epoch: 1 / 20 | train loss: 0.218 | test loss: 0.199 | test acc: 93.750
Epoch: 2 / 20 | train loss: 0.060 | test loss: 0.077 | test acc: 97.776
Epoch: 3 / 20 | train loss: 0.053 | test loss: 0.053 | test acc: 98.287
Epoch: 4 / 20 | train loss: 0.044 | test loss: 0.048 | test acc: 98.618
Epoch: 5 / 20 | train loss: 0.046 | test loss: 0.150 | test acc: 95.523
Epoch: 6 / 20 | train loss: 0.039 | test loss: 0.040 | test acc: 98.778
Epoch: 7 / 20 | train loss: 0.038 | test loss: 0.047 | test acc: 98.598
Epoch: 8 / 20 | train loss: 0.035 | test loss: 0.045 | test acc: 98.678
Epoch: 9 / 20 | train loss: 0.037 | test loss: 0.047 | test acc: 98.538
Epoch: 10 / 20 | train loss: 0.030 | test loss: 0.046 | test acc: 98.728
Epoch: 11 / 20 | train loss: 0.031 | test loss: 0.103 | test acc: 96.955
Epoch: 12 / 20 | train loss: 0.029 | test loss: 0.050 | test acc: 98.478
Epoch: 13 / 20 | train loss: 0.032 | test loss: 0.103 | test acc: 97.316
Epoch: 14 / 20 | train loss: 0.030 | test loss: 0.049 | test