# Assignment 6: Neural Ordinary Differential Equations

*Author:* Thomas Adler

*Copyright statement:* This  material,  no  matter  whether  in  printed  or  electronic  form,  may  be  used  for  personal  and non-commercial educational use only.  Any reproduction of this manuscript, no matter whether as a whole or in parts, no matter whether in printed or in electronic form, requires explicit prior acceptance of the authors. 

*This assignment discusses the following paper:* Chen, Ricky TQ, Yulia Rubanova, Jesse Bettencourt, and David K. Duvenaud. "Neural ordinary differential equations." Advances in neural information processing systems 31 (2018). https://arxiv.org/abs/1806.07366

## Exercise 1: Residual layer and Euler's method

Consider a fully connected layer with residual connection defined as
\begin{align*}
    h_{n+1} = h_n + f(h_n, \theta),
\end{align*}
where $f(h_n, \theta)$ is a neural network with parameters $\theta$. 
Show that this is equivalent to applying the Euler method to an ODE of the form
\begin{align*}
    \frac{d h(t)}{dt} = f(h(t), \theta). 
\end{align*}
Show both directions and discuss their . 
What step size do we have to use for Euler's method to establish the equivalence? 
What is the appropriate initial-value problem for this setting? 
Consider a ResNet with different parameters in each layer, i.e., $h_{n+1} = h_n + f(h_n, \theta_n)$. 
How could we modify the Neural ODE to reflect that property? 

########## YOUR SOLUTION HERE ##########

##### Equivalence between the residual layer and the Euler method

From the exercise description:
$$
h_{n+1} = h_n + f(h_n, \theta)
$$
where $f(h_n, \theta)$ is the output of a neural network using parameters $\theta$ and input $h_n$.

The Euler method for solving this ODE with a step size $\Delta t$ is defined as:

$$
h(t + \Delta t) = h(t) + \Delta t \cdot f(h(t), \theta)
$$
Comparing the equation, it is clear that with $\Delta t = 1$, the equivalence is prooved:

$$
h_{n+1} = h_n + 1 \cdot f(h_n, \theta) = h_n + f(h_n, \theta)
$$
Each step of the residual layer can then be considered as one Euler step for the ODE, with a step size of $\Delta t = 1$.

##### Step size and initial-value problem

Since we are considering the residual network layer equation as a discrete update using the Euler equivalence, we could use the same approach for Euler update:
$$
h(0) = h_0
$$
where $h_0$ is the initial value of $h$ at time $t = 0$. This corresponds to the input to the first layer of the neural network.

##### Modifications for varying parameters $\theta_n$

If different parameters are used in each layer with $\theta_n$ instead of a constant $\theta$, we have to consider it in the derivation as:
$$
\frac{d h(t)}{dt} = f(h(t), \theta(t))
$$
where $\theta(t)$ changes at discrete times corresponding to layer transitions.

## Exercise 2: A scalar linear ODE

In the simplest case, $f$ is scalar and linear, i.e., $f : \mathbb R \to \mathbb R$ has the form $f(h) = wh$. 
Solve the ODE 
\begin{align*}
    \frac{dh}{dt} = wh
\end{align*}
by separation of variables. 
Moreover, prove that the inhomogeneous ODE
\begin{align*}
    \frac{dh}{dt} = wh + b
\end{align*}
has the solution 
\begin{align*}
    h = \exp(w(t+c)) - \frac{b}{w}.
\end{align*}

########## YOUR SOLUTION HERE ##########

##### Homogeneous ODE Solution via Separation of Variables

Considering the ODE of the exercise:
$$
\frac{dh}{dt} = wh
$$
where $w$ is a constant.

To solve this using the method of separation of variables, we have to modify the equation as:
$$
\frac{dh}{h} = w dt
$$

Integrate both sides:
$$
\int \frac{1}{h} dh = \int w dt
$$

Which leads to:
$$
\ln |h| = wt + C
$$
where $C$ is the constant of integration.

Resolving the equation for $h$:
$$
|h| = e^{wt + C} = e^C e^{wt}
$$

At the end, we can consider $e^C$ as costant.

##### Inhomogeneous ODE Solution

To find a solution of the form $h = \exp(w(t+c)) - \frac{b}{w}$, let's substitute $h$ into the given equation:
$$
h = e^{w(t+c)} - \frac{b}{w}
$$

Taking the derivative $\frac{dh}{dt}$ gives:
$$
\frac{dh}{dt} = w e^{w(t+c)}
$$

Substituting $\frac{dh}{dt}$ and $h$ into the given ODE:
$$
w e^{w(t+c)} = w \left( e^{w(t+c)} - \frac{b}{w} \right) + b
$$

which leads to:
$$
w e^{w(t+c)} = w e^{w(t+c)} - b + b
$$

Ending with:
$$
w e^{w(t+c)} = w e^{w(t+c)}
$$

This proof that $h = \exp(w(t+c)) - \frac{b}{w}$ is indeed a solution to the inhomogeneous equation. The constant $c$ can be determined from initial conditions.

## Exercise 3: A scalar nonlinear ODE

Solve the nonlinear ODE
\begin{align*}
    \frac{dh}{dt} = e^{wh + b}
\end{align*}
by separation of variables. 

########## YOUR SOLUTION HERE ##########

##### Separating Variables

Rearranging the equation to isolate terms involving $h$ and $t$:
$$
\frac{dh}{e^{wh + b}} = dt
$$

##### Integrate the left side:

$$
\int \frac{1}{e^{wh + b}} dh
$$

Using the substitution $u = wh + b$, thus $du = w dh$ and $dh = \frac{du}{w}$. Substitute into the integral:
$$
\int \frac{1}{e^u} \frac{du}{w} = \frac{1}{w} \int e^{-u} du
$$
$$
= \frac{1}{w} \int e^{-u} du = \frac{-1}{w} e^{-u} + C
$$
Substitute back for $u$:
$$
\frac{-1}{w} e^{-(wh + b)} + C
$$

##### Integrate the right side:

$$
\int dt = t + D
$$

where $D$ is the constant of integration.

##### Solve for $h(t)$

After integration, equate the two expressions:
$$
\frac{-1}{w} e^{-(wh + b)} + C = t + D
$$

To solve for $h$:
$$
e^{-(wh + b)} = -w(t + D - C)
$$

Taking the natural logarithm of both sides (assuming  $-w(t + D - C) > 0$ to keep the logarithm real):

$$
-(wh + b) = \ln(-w(t + D - C))
$$

$$
wh + b = -\ln(-w(t + D - C))
$$

$$
h = \frac{-\ln(-w(t + D - C)) - b}{w}
$$

Considering that $D$ and $C$ are constants, they can be combined into $K$:

$$
h = \frac{-\ln(-w(t + K)) - b}{w}
$$

## Exercise 4: A multidimensional linear ODE

Consider the multidimensional case $f(h) = Wh$ where $h \in \mathbb R^d, W \in \mathbb R^{d \times d}$. 
Prove that the ODE
\begin{align*}
    \frac{dh}{dt} = Wh
\end{align*}
has the solution
\begin{align*}
    h = e^{tW} c
\end{align*}
where $c \in \mathbb R^d$ is some arbitrary constant vector and
\begin{align*}
    e^W = I + W + \frac12 W^2 + \cdots = \sum_{k=0}^\infty \frac{1}{k!} W^k
\end{align*}
is the matrix exponential function. 
Moreover, prove that the inhomogeneous ODE
\begin{align*}
    \frac{dh}{dt} = Wh + b
\end{align*}
has the solution 
\begin{align*}
    h = e^{tW} c - W^{-1} b. 
\end{align*}

########## YOUR SOLUTION HERE ##########

Differentiating $h = e^{tW}c$ with respect to $t$:
$$
\frac{dh}{dt} = \frac{d}{dt}(e^{tW}c)
$$

Using the derivative of the matrix exponential:
$$
\frac{d}{dt}(e^{tW}) = We^{tW}
$$

Thus:
$$
\frac{dh}{dt} = We^{tW}c
$$

Given the form of $h$:
$$
h = e^{tW}c
$$

Substituting back:
$$
\frac{dh}{dt} = Wh
$$
Therefore $h = e^{tW}c$ is a solution to the homogeneous ODE.

##### Solution to the Inhomogeneous ODE

Substituting $h = e^{tW}c - W^{-1}b$ into the ODE:
$$
\frac{dh}{dt} = \frac{d}{dt}(e^{tW}c - W^{-1}b)
$$
$$
\frac{dh}{dt} = We^{tW}c - 0
$$
Since $W^{-1}b$ is constant with respect to $t$.

This simplifies to:
$$
\frac{dh}{dt} = We^{tW}c
$$

Now, evaluating $Wh + b$:
$$
Wh + b = W(e^{tW}c - W^{-1}b) + b
$$
$$
Wh + b = We^{tW}c - WW^{-1}b + b
$$
$$
Wh + b = We^{tW}c + b - b
$$
$$
Wh + b = We^{tW}c
$$

Therefore:
$$
\frac{dh}{dt} = Wh + b
$$
Thus, $h = e^{tW}c - W^{-1}b$ is a solution to the inhomogeneous ODE, assuming $W$ is invertible.

## Exercise 5: The adjoint method

In the following, let $t \in [0, 1]$ and let $h(0) = x$ be the input to our neural ODE model, which is given by 
\begin{align*}
    \frac{d h(t)}{dt} = f(h(t), \theta). 
\end{align*}
We compute the network output as 
\begin{align*}
    \hat y = h(1) = h(0) + \int_0^1 f(h(t), \theta) dt. 
\end{align*}
As usual, we have a loss function $L(\hat y, y)$ and are interested in the gradients with respect to parameters $\theta$. As with conventional neural networks the key to computing these gradients are the deltas, i.e., the derivative of the loss with respect to the activations, i.e., $\partial L / \partial h(t)$ for $t \in [0, 1]$. 
Show that they follow an ODE given by
\begin{align*}
    \frac{d}{dt} \frac{\partial L}{\partial h(t)} = -\frac{\partial L}{\partial h(t)}\frac{\partial f(h(t), \theta)}{\partial h(t)}. 
\end{align*}
This ODE is called the adjoint equation. 
Which initial value problem do we have to solve to obtain the correct gradients? 

*Hint: For some $\varepsilon > 0$, use the chain rule*
\begin{align*}
    \frac{\partial L}{\partial h(t)} = \frac{\partial L}{\partial h(t+\varepsilon)}\frac{\partial h(t+\varepsilon)}{\partial h(t)}
\end{align*}
*and consider the limit as $\varepsilon \to 0$.*

*Note that with slight abuse of notation by $\frac{\partial f(g(t))}{\partial g(t)}$ we mean $\frac{\partial f(x)}{\partial x}|_{x=g(t)}$.*

########## YOUR SOLUTION HERE ##########

##### Computing the Adjoint Equation

Given the loss $L$, we define the adjoint variable:
$$
a(t) = \frac{\partial L}{\partial h(t)}
$$

This variable tracks the changes in $L$ w.r.t. $h$. The aim is to find an ODE for $a$.

From the chain rule and using a small increment $\varepsilon$:
$$
\frac{\partial L}{\partial h(t)} = \frac{\partial L}{\partial h(t+\varepsilon)}\frac{\partial h(t+\varepsilon)}{\partial h(t)}
$$

Now, considering the dynamics of $h$:
$$
h(t+\varepsilon) = h(t) + \varepsilon f(h(t), \theta) + o(\varepsilon)
$$
$$
\frac{\partial h(t+\varepsilon)}{\partial h(t)} = I + \varepsilon \frac{\partial f(h(t), \theta)}{\partial h(t)} + o(\varepsilon)
$$

Substituting and taking the limit as $\varepsilon \to 0$ gives:
$$
\frac{\partial L}{\partial h(t)} = \frac{\partial L}{\partial h(t)} + \varepsilon \frac{\partial L}{\partial h(t)} \frac{\partial f(h(t), \theta)}{\partial h(t)}
$$
$$
\frac{d}{dt} \frac{\partial L}{\partial h(t)} = -\frac{\partial L}{\partial h(t)}\frac{\partial f(h(t), \theta)}{\partial h(t)}
$$

This is the adjoint equation:
$$
\frac{d a(t)}{dt} = -a(t) \frac{\partial f(h(t), \theta)}{\partial h(t)}
$$

To compute the gradients w.r.t. $\theta$ we have4 to solve the equation backward from $t=1$ to $t=0$. Considering $h(1) = \hat{y}$ we have:

$$
a(1) = \frac{\partial L}{\partial h(1)} = \frac{\partial L}{\partial \hat{y}}
$$

## Exercise 6: Gradients with respect to parameters

Show that the gradients with respect to the parameters follow the ODE
\begin{align*}
    \frac{d}{dt} \frac{\partial L}{\partial \theta} &= -\frac{\partial L}{\partial h(t)} \frac{\partial f(h(t), \theta)}{\partial \theta}. 
\end{align*}

*Hint: The parameters are shared in time. To correctly account for the usage of $\theta$ along the trajectory, it is helpful to consider $\theta(t) = \theta$ as a constant function of time with $\dot \theta(t) = 0$.
In that light, for some $\varepsilon > 0$, use the chain rule*
\begin{align*}
    \frac{\partial L}{\partial \theta(t)} = \frac{\partial L}{\partial h(t+\varepsilon)}\frac{\partial h(t+\varepsilon)}{\partial \theta(t)} + \frac{\partial L}{\partial \theta(t+\varepsilon)}\frac{\partial \theta(t+\varepsilon)}{\partial \theta(t)}.
\end{align*}

What happens when $\theta(t)$ becomes an arbitrary differentiable function of time instead of a constant? 

########## YOUR SOLUTION HERE ##########

##### Apply the chain rule

Given the hint and using a small increment $\varepsilon > 0$:
$$
\frac{\partial L}{\partial \theta} = \frac{\partial L}{\partial h(t+\varepsilon)}\frac{\partial h(t+\varepsilon)}{\partial \theta} + \frac{\partial L}{\partial \theta(t+\varepsilon)}
$$
Since $\theta(t+\varepsilon) = \theta$ and $\dot{\theta}(t) = 0$, the second term simplifies to:
$$
\frac{\partial \theta(t+\varepsilon)}{\partial \theta} = I \quad \text{(Identity matrix)}
$$
Then:
$$
\frac{\partial L}{\partial \theta} = \frac{\partial L}{\partial h(t+\varepsilon)}\frac{\partial h(t+\varepsilon)}{\partial \theta} + \frac{\partial L}{\partial \theta}
$$
Since $h(t+\varepsilon) \approx h(t) + \varepsilon f(h(t), \theta)$, by differentiation we get:
$$
\frac{\partial h(t+\varepsilon)}{\partial \theta} \approx \varepsilon \frac{\partial f(h(t), \theta)}{\partial \theta}
$$

Differentiating both sides with respect to $t$, and having the limit $\varepsilon \to 0$:
$$
\frac{d}{dt} \frac{\partial L}{\partial \theta} = \frac{\partial L}{\partial h(t)} \frac{\partial f(h(t), \theta)}{\partial \theta}
$$
Here we used the fact that:
$$
\frac{d}{dt} \left( \frac{\partial L}{\partial h(t)} \frac{\partial h(t)}{\partial \theta} \right) = \frac{\partial L}{\partial h(t)} \frac{d}{dt} \left(\frac{\partial h(t)}{\partial \theta}\right)
$$
Since $\frac{\partial h(t)}{\partial \theta} = \frac{\partial f(h(t), \theta)}{\partial \theta}$, we have:
$$
\frac{d}{dt} \frac{\partial L}{\partial \theta} = -\frac{\partial L}{\partial h(t)} \frac{\partial f(h(t), \theta)}{\partial \theta}
$$


#### What if $ \theta(t)$ varies with time?

If $\theta$ is now a function of time, $\theta(t)$, the above derivation would need additional terms:

$$
\frac{d}{dt} \frac{\partial L}{\partial \theta} = -\frac{\partial L}{\partial h(t)} \frac{\partial f(h(t), \theta(t))}{\partial \theta(t)} + \frac{\partial L}{\partial \theta(t)} \dot{\theta}(t)
$$

## Exercise 7: Implement a neural ODE

Using a deep learning framework like PyTorch, we can implement Neural ODEs in two different ways. 
One way is to explicitly implement the forward and backward ODEs as derived in the previous exercises. 
Alternatively, we can implmement an ODE solver using PyTorch operations. 
Then PyTorch will differentiate through the solver automatically. 
This sidesteps the need for an implementation of the adjoint method. 
The code below implements a Neural ODE module in PyTorch. 
Add a `forward` function to the class `ODEBlock` that implements the standard Runge-Kutta solver with fixed step size $\eta$ as defined by
\begin{align*}
    h_{n+1} &= h_n + \frac{\eta}{6}(k_1 + 2 k_2 + 2 k_3 + k_4) \\
    t_{n+1} &= t_n + \eta \\
    k_1 &= f(t_n, h_n) \\
    k_2 &= f(t_n + \frac{\eta}{2}, h_n + \eta \frac{k_1}{2}) \\
    k_3 &= f(t_n + \frac{\eta}{2}, h_n + \eta \frac{k_2}{2}) \\
    k_4 &= f(t_n + \eta, h_n + \eta k_3). 
\end{align*}
Moreover, add a training/validation loop for the CIFAR-10 dataset, train the model, visualize and discuss your results. 
The provided hyperparameters should work quite well but feel free to experiment with them. 
What are potential disadvantages of this approach?

*Note: The `Cat` module below needs its attribute `t` set accordingly before invoking its `forward` method. 
The reason for this is design is to enable its utilization with `nn.Sequential`.*


In [2]:
# Architecture based on the ODE-Only Net by
# Carrara, F., Amato, G., Falchi, F. and Gennaro, C., 2019, September. 
# Evaluation of Continuous Image Features Learned by ODE Nets. 
# In International Conference on Image Analysis and Processing (ICIAP '19) (pp. 432-442). Springer, Cham.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.nn.utils import parameters_to_vector as p2v
import matplotlib.pyplot as plt
%matplotlib inline

class Cat(nn.Module):
    """ Concatenate an image tensor x with a feature plane of constant value t. """
    def __init__(self, t=0):
        super().__init__()
        self.t = t

    def forward(self, x):
        t = torch.ones_like(x[:, :1, :, :]) * self.t
        return torch.cat([t, x], 1)

class ODEBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.module = nn.Sequential(
            nn.GroupNorm(32, dim),
            nn.ReLU(),
            Cat(),
            nn.Conv2d(dim+1, dim, 3, padding=1),
            nn.GroupNorm(32, dim),
            nn.ReLU(),
            Cat(),
            nn.Conv2d(dim+1, dim, 3, padding=1),
            nn.GroupNorm(32, dim)
        )

    def forward(self, x, t0=0, t1=1, step_size=0.1):
        # pass # IMPLEMENT ME!
        time = t0
        while time < t1:
            k1 = self.module(x)
            k2 = self.module(x + step_size * k1 / 2)
            k3 = self.module(x + step_size * k2 / 2)
            k4 = self.module(x + step_size * k3)
            x = x + step_size * (k1 + 2 * k2 + 2 * k3 + k4) / 6
            time += step_size
            
        return x


class ODENet(nn.Module):
    def __init__(self, in_dim, ode_dim, out_dim, dropout=0.):
        super().__init__()
        self.module = nn.Sequential(
            nn.Conv2d(in_dim, ode_dim, 4, stride=2, padding=1),
            ODEBlock(ode_dim),
            nn.GroupNorm(32, ode_dim),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Dropout(dropout),
            nn.Flatten(),
            nn.Linear(ode_dim, out_dim)
        )

    def forward(self, x):
        return self.module(x)

device = torch.device('cpu') # torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
cifar = device.type == 'cuda'
batch_size = 256 if device=='gpu' else 128
epochs = 10 if device=='gpu' else 3

tfm = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(hue=.05, saturation=.05),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
] if cifar else [
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
]

dataset = torchvision.datasets.CIFAR10 if cifar else torchvision.datasets.MNIST
trainset = dataset(root='./data', train=True, download=True, transform=transforms.Compose(tfm))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
valset = dataset(root='./data', train=False, download=True, transform=transforms.Compose(tfm[3:] if cifar else tfm))
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=2)

model = ODENet(3 if cifar else 1, 128 if cifar else 64, 10, dropout=0.5).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=15 if cifar else 5)
# the scheduler should track validation accuracy, hence `mode='max'`


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100.0%


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100.0%


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100.0%

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






In [3]:
########## YOUR SOLUTION HERE ##########

criterion = nn.CrossEntropyLoss()

def train(epoch):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}')

def validate():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in valloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy on test: %d %%' % (100 * correct / total))

for epoch in range(epochs):
    train(epoch)
    validate()

Epoch 1, Loss: 1.2951067095753481
Accuracy of the network on the test images: 94 %
Epoch 2, Loss: 0.18243690947098518
Accuracy of the network on the test images: 98 %
Epoch 3, Loss: 0.09252552788005645
Accuracy of the network on the test images: 98 %


##### Potential Disadvantages
- Computational cost: adding RK4 involves multiple evaluations of the neural network per integration step
- Memory Usage: for the same reason, storing intermadiate status leads to increase of memory usage
- Step-size sensitivity: the choice of the step size and integration bounds can affect the model performance

## Exercise 8: Implement the adjoint method

In this exercise, we want to replace the fixed-step Runge-Kutta solver with an arbitrary solver from the `scipy` package. 
To this end, write a python class that derives from `torch.autograd.Function` and implements the adjoint method. 
You can read up [here](https://pytorch.org/docs/stable/notes/extending.html) on how to extend PyTorch. 
In the forward method, make a call to `scipy.integrate.solve_ivp` to compute the model output solving the forward ODE. 
In the backward method, make another call to `scipy.integrate.solve_ivp` to compute the gradients by the adjoint method solving the backward ODEs. 
Then modify the `ODEBlock` class from the previous exercise replacing the Runge-Kutta solver with a call to your `torch.autograd.Function`.  Finally, retrain the model using the adjoint method and visualize and discuss your results comparing them to those of the previous exercise. 
What are possible up-/downsides of the adjoint method?

Note that $\dot b(t)$ depends on both $a(t)$ and $h(t)$. 
Although we have computed a trajectory of $h(t)$ in the forward pass, we cannot reuse it in the backward pass since the solver might choose different steps in the backward ODE.
Therefore, we have to simultaneously solve 3 ODEs in the backward pass
 - the trajectory of the hiddens $\dot h = f(h, \theta)$
 - the adjoint $\dot a = -a \frac{\partial f(h, \theta)}{\partial h}$
 - the gradients $\dot b = -a \frac{\partial f(h, \theta)}{\partial \theta}$. 

In [7]:
import torch.autograd as ag
from torch.autograd.functional import vjp
from scipy.integrate import solve_ivp

########## YOUR SOLUTION HERE ##########

class ODEFuncAutograd(ag.Function):
    @staticmethod
    def forward(ctx, t0, t1, h0, func, *theta):
        # Flatten h0 for solving
        original_shape = h0.shape
        h0_flat = h0.detach().numpy().flatten()

        def rhs(t, h):
            h_torch = torch.tensor(h, dtype=torch.float32).view(original_shape)
            h_torch.requires_grad = True
            dh = func(t, h_torch, *theta).detach().numpy().flatten()
            return dh
        
        sol = solve_ivp(rhs, [t0, t1], h0_flat, method='RK45')
        h_end_flat = torch.tensor(sol.y[:, -1], dtype=torch.float32)
        h_end = h_end_flat.view(original_shape)

        ctx.save_for_backward(h0, h_end, *theta)
        ctx.func = func
        return h_end

    @staticmethod
    def backward(ctx, grad_output):
        h0, h_end, *theta = ctx.saved_tensors
        func = ctx.func
        original_shape = h0.shape
        grad_output_flat = grad_output.flatten()

        def rhs_augmented(t, aug):
            aug_torch = torch.tensor(aug, dtype=torch.float32)
            h, adj_h = aug_torch[:h0.numel()], aug_torch[h0.numel():2*h0.numel()]
            h = h.view(original_shape)
            adj_h = adj_h.view(original_shape)
            h.requires_grad = True
            dh = func(t, h, *theta)
            vhp = torch.autograd.grad(dh, (h, *theta), -adj_h)
            return np.concatenate([dh.detach().numpy().flatten(), vhp[0].detach().numpy().flatten()] + [v.detach().numpy().flatten() for v in vhp[1:]])

        aug0 = np.concatenate([h_end.detach().numpy().flatten(), grad_output_flat.numpy(), np.zeros_like(h0.detach().numpy()).flatten()])
        sol = solve_ivp(rhs_augmented, [t1, t0], aug0, method='RK45')
        
        grad_h0_flat = torch.tensor(sol.y[:h0.numel(), -1], dtype=torch.float32)
        grad_theta = tuple(torch.tensor(sol.y[2*h0.numel()+i*h0.numel():2*h0.numel()+(i+1)*h0.numel(), -1], dtype=torch.float32) for i in range(len(theta)))
        
        grad_h0 = grad_h0_flat.view(original_shape)
        return (None, None, grad_h0, *grad_theta)



# copy paste from previous with different forward
class ODEBlock_2(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.func = nn.Sequential(
            nn.GroupNorm(32, dim),
            nn.ReLU(),
            Cat(),
            nn.Conv2d(dim+1, dim, 3, padding=1),
            nn.GroupNorm(32, dim),
            nn.ReLU(),
            Cat(),
            nn.Conv2d(dim+1, dim, 3, padding=1),
            nn.GroupNorm(32, dim)
        )

    def forward(self, x, t0=0, t1=1):
        h0 = x
        result = self.func(h0) 
        return result

    

# copy paste from previous
class ODENet_2(nn.Module):
    def __init__(self, in_dim, ode_dim, out_dim, dropout=0.):
        super().__init__()
        self.module = nn.Sequential(
            nn.Conv2d(in_dim, ode_dim, 4, stride=2, padding=1),
            ODEBlock_2(ode_dim),
            nn.GroupNorm(32, ode_dim),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Dropout(dropout),
            nn.Flatten(),
            nn.Linear(ode_dim, out_dim)
        )

    def forward(self, x):
        return self.module(x)
    


# copy paste from before
model_2 = ODENet_2(3 if cifar else 1, 128 if cifar else 64, 10, dropout=0.5).to(device)


criterion = nn.CrossEntropyLoss()

def train_2(epoch):
    model_2.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = model_2(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}')

def validate_2():
    model_2.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in valloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model_2(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy on test: %d %%' % (100 * correct / total))

for epoch in range(epochs):
    train_2(epoch)
    validate_2()


Epoch 1, Loss: 2.331999325803094
Accuracy on test: 9 %
Epoch 2, Loss: 2.332078274887508
Accuracy on test: 9 %
Epoch 3, Loss: 2.3328201897871264
Accuracy on test: 9 %


##### Upsides
- Memory Efficiency
##### Downsides
- Complexity