# Neural ODEs

Name: []  
Student ID: []

In [None]:
"""
Before starting, you will need to install the following libraries (on a conda environment with python==3.10):

pip install jupyter
pip install deepchem[torch]
pip install matplotlib
pip install torchvision
pip install torchdiffeq

"""

In [None]:
### Import Libraries

import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torchdiffeq import odeint
import deepchem as dc
import matplotlib.pyplot as plt
import numpy as np
import gc

## What is a Neural ODE? Part 1: The ODE

This notebook will help you to understand the above question. First, let's start with understanding what an ODE (ordinary differential equation) is. An ODE is an equation that defines a relationship between a function and its derivative. A simple example is the following:

$$
\frac{dg(x)}{dx} = f(x)
$$
 
We want to know the function $g(x)$, given some initial condition $g(0)$. Let's do a simple example. Suppose $f(x) = 2x$ and $g(0) = 0$, then we can solve this ODE fairly simply by integration.
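
As a warm-up on a different right-hand side (chosen here purely for illustration, it is not part of the assignment), suppose $f(x) = 3x^2$ and $g(0) = 1$. Integrating both sides and then using the initial condition to fix the constant of integration gives:

$$
g(x) = g(0) + \int_0^x 3s^2 \, ds = 1 + x^3
$$

The same two steps, integrate and then apply the initial condition, are all that the next question needs.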

### Q1: Solve the above equation analytically and write the equation for g(x) below:

Great job! But the above equation was very simple. Not all equations are so easy to solve. In fact, many cannot be solved analytically. However, there are still ways to approximate them numerically, and that is where an ODE solver comes into play. In this notebook, we will use the odeint function from the torchdiffeq library to solve a simple ODE. There are other ODE solvers available, but we will eventually move to neural ODEs, and since we will be using the pytorch framework, it makes sense to stick with the torchdiffeq library.

So now, let's take that same equation as before and attempt to solve it numerically. We are not going to worry too much about the internals of odeint itself (see [here](https://en.wikipedia.org/wiki/Numerical_methods_for_ordinary_differential_equations) if you'd like to learn more), but instead focus on how we can use it to solve our equation. To use it, we need to define our function $f(x)$, our initial value $g(0)$, and a range of points we want to solve for. Then we call the odeint function on these three arguments in the order (function, initial value, points). Note that odeint calls the function with two arguments, the current point and the current value of $g$, which is why `f` below takes both `x` and `g`.
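
To get a feel for what a numerical solver does internally, here is a minimal sketch of the forward Euler method in plain Python. It is not how torchdiffeq works by default (its default solver is an adaptive-step method), and the names `euler_solve` and `cosine_rhs` are made up for this illustration. The example integrates $\frac{dg}{dx} = \cos(x)$ with $g(0) = 0$, whose exact solution is $\sin(x)$.

```python
import math

def euler_solve(f, g0, xs):
    """Integrate dg/dx = f(x, g) with forward Euler; g0 is the value at xs[0]."""
    values = [g0]
    for x_prev, x_next in zip(xs[:-1], xs[1:]):
        step = x_next - x_prev
        # Follow the slope at the previous point across one step
        values.append(values[-1] + step * f(x_prev, values[-1]))
    return values

def cosine_rhs(x, g):
    # Same (point, value) argument order that odeint expects from its function
    return math.cos(x)

n_points = 1001
xs = [i * math.pi / (n_points - 1) for i in range(n_points)]
approx = euler_solve(cosine_rhs, 0.0, xs)

# Compare against the exact solution sin(x) on the same grid
max_error = max(abs(a - math.sin(x)) for a, x in zip(approx, xs))
print(f"largest deviation from sin(x): {max_error:.5f}")
```

Using more points shrinks the deviation, which is the basic trade-off every solver makes between accuracy and the number of function evaluations.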

### Q2: Implement the function f(x) = 2x, and define a variable that calls the odeint function on this, the initial value g0, and the range of points x.  

In [None]:
def f(x, g):
    ### Start your code

    ### End your code
    return equation


# Define the initial value
g0 = torch.Tensor([0])

# Define the range of points to solve for
x = torch.linspace(-10, 10, 1000)

sol = #Your code here

And now, let's take a look at our g(x) on this interval! 

In [None]:
plt.plot(x, sol)
plt.show()

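Well, doesn't that look nice! We get a parabola, which is exactly what we should have. One detail to keep in mind: odeint treats `g0` as the value of $g$ at the first entry of `x` (here $x = -10$), so the curve is shifted vertically relative to the solution with $g(0) = 0$.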

## Resnet

Before we incorporate ODEs into neural networks, let's first take a look at what inspired this. Resnets ([Residual neural networks](https://en.wikipedia.org/wiki/Residual_neural_network)) introduce the concept of a "skip connection": the input to a layer is passed forward around that layer and combined with the layer's output afterwards. This helped with training stability and allowed networks with hundreds of layers to be trained. Let's load up some data, and then we'll take a look at the specifics of Resnets.

First, this is going to be a very intensive operation, so using a GPU is recommended.

In [None]:
#Set the device to train the NN on. This tries Nvidia CUDA (GPU) first, then CPU. If you have an AMD GPU, there may be a way, but I am not certain, sorry.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
print(device) #See which device we are using

Now let's load up the dataset we will train our Resnet on, the MNIST handwritten digits.
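
The loading cell below holds out part of the training data for validation. As a rough sketch of that index split, here is the same idea using only the standard library. Because it uses Python's `random` instead of NumPy, the exact permutation differs from the one in the notebook; only the sizes are meant to match.

```python
import random

num_used = 50000        # the loading cell draws from the first 50,000 training images
valid_fraction = 0.1    # 10% of those are held out for validation

indices = list(range(num_used))
random.Random(2025).shuffle(indices)   # fixed seed so the split is reproducible

split = int(valid_fraction * num_used)
valid_idx, train_idx = indices[:split], indices[split:]

print(len(train_idx), len(valid_idx))
```

This leaves 45,000 images for training and 5,000 for validation. MNIST's training set contains 60,000 images, so the remaining 10,000 are simply unused here.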

In [None]:
###Load the Data

shuffle=True

bsize = 512 #Set to a smaller batch size if you are running out of memory

indices = list(range(50000))
split = int(np.floor(0.1 * 50000))

if shuffle:
    np.random.seed(2025)
    np.random.shuffle(indices)

train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)


use_cuda = torch.cuda.is_available()  # only pin memory when a GPU is actually present

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=bsize, sampler=train_sampler, **kwargs)

valid_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=bsize, sampler=valid_sampler, **kwargs)

test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=bsize, shuffle=True, **kwargs)

Now that we have our dataset loaded, it's time to start building a Resnet. A Resnet is made up of residual blocks. Each block has a skip connection: the input to the block is run through some layers, and then the unmodified input is added to the output of those layers. Visually, it looks like the following picture (image from [Wikipedia](https://en.wikipedia.org/wiki/Residual_neural_network)):

<img src="images/ResBlock.png" style="width:300px;height:300px;">

### Q3: Write the forward pass for the basic Resnet Block

The task is to write the forward pass for the above Resnet block. In this case, we will have a 2-layer Resnet block that accepts an input $x$ ($x$ in this case being an image from our dataset). The output $h(x)$ of the block will be of the following form:

$$
h(x) = \sigma_2(F(x) + x) = \sigma_2(w_2\,\sigma_1(w_1 x + b_1) + b_2 + x)
$$

A few things to keep in mind here. First, both activation functions $\sigma$ will be ReLU units. Second, keep in mind that the second ReLU activation will come after you add back in the input using the skip connection.

In [None]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, units, skip_connection=None):
        super(BasicBlock, self).__init__()
        self.l1 = nn.Sequential(nn.LazyLinear(units), nn.ReLU())
        self.l2 = nn.Sequential(nn.LazyLinear(units))
        self.relu = nn.ReLU()

    def forward(self, x):
        ### Begin your code here

        ### End your code here
        return output



And now we put this into a Resnet network. This network uses four groups of layers, with each group containing 2 Resnet blocks. In total, there are 8 Resnet blocks in this network.

In [None]:
class resnet_model(nn.Module):
    def __init__(self, resnet_block, layers, num_outputs = 10):
        super(resnet_model, self).__init__()
        self.flatten = nn.Flatten()
        self.layer0 = nn.LazyLinear(64)
        self.layer1 = self._make_layer(resnet_block, 64, layers[0])
        self.layer2 = self._make_layer(resnet_block, 64, layers[1])
        self.layer3 = self._make_layer(resnet_block, 64, layers[2])
        self.layer4 = self._make_layer(resnet_block, 64, layers[3])
        self.classifier_layer = nn.LazyLinear(num_outputs)

    def _make_layer(self, resnet_block, units, resblock_layers):
        layers = []
        layers.append(resnet_block(units))
        for i in range(1, resblock_layers):
            layers.append(resnet_block(units))

        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.flatten(x)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.classifier_layer(x)

        return x


In [None]:
num_classes = 10
num_epochs = 20
batch_size = 512

model = resnet_model(BasicBlock, [2, 2, 2, 2], num_outputs=num_classes).to(device)

#Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())  

#Train the model
total_step = len(train_loader)

And now we train the network. Let's see how it does! This could take a while, so feel free to go get a cup of coffee, do some other homework, or something else while this is running.

In [None]:
total_step = len(train_loader)

for epoch in range(num_epochs):
    # print('Epoch: {}'.format(epoch))
    for i, (images, labels) in enumerate(train_loader): 
        print('Batch: {}'.format(i), end="\r") 
    
        #Move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)

        #Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        #Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        del images, labels, outputs
        torch.cuda.empty_cache()
        gc.collect()

    print ('Epoch [{}/{}], Loss: {:.4f}' 
                    .format(epoch+1, num_epochs, loss.item()))

#Validation
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in valid_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        del images, labels, outputs

    print('Accuracy of the network on the {} validation images: {} %'.format(5000, 100 * correct / total))

Not bad! But we took 8 Resnet blocks to do this. Try this on larger datasets, and you could quickly find yourself running out of time and memory. To solve this, let's take a look at a neural ODE, and bridge the gap. First, recall that when making the Resnet, we could write the output of a Resnet block in the form of:

$$
F(x) + x
$$

We have a hidden state in the network that is being passed along and updated by each block. Moreover, if you are familiar with numerical ODE solvers, this looks very much like one step of the Euler method, $h_{n+1} = h_n + \Delta t \, f(h_n)$, with step size $\Delta t = 1$. As such, we can think of neural ODEs as Resnets with an infinite (continuous) number of layers. Now let's go back to the very beginning of this notebook, and recall our simple example of an ODE:

$$
\frac{dg(x)}{dx} = f(x)
$$

In this case, $f(x)$ is a function. Neural networks happen to have a nice property: they can approximate a broad class of functions to arbitrary accuracy ([universal function approximation](https://en.wikipedia.org/wiki/Universal_approximation_theorem)). Ergo, we can replace $f(x)$ with a neural network approximation thereof. Writing $h(x)$ for the hidden state (it plays the role of $g(x)$ above), our simple example becomes:

$$
\frac{dh(x)}{dx} = f_\theta(h(x), x)
$$

where $f_\theta(h(x), x)$ is a neural network with parameters $\theta$ that takes the place of our original function $f(x)$.
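
To make the link between residual blocks and Euler steps concrete, here is a small sketch in plain Python. It assumes a toy residual branch, $F(h) = -h / N$ for a stack of $N$ blocks, which is chosen only for illustration (`residual_stack` is a hypothetical helper, not part of the assignment). Stacking $N$ such blocks is the same as taking $N$ Euler steps of size $1/N$ on $\frac{dh}{dt} = -h$, so the output should approach the exact value $e^{-1}$ as $N$ grows.

```python
import math

def residual_stack(h0, num_blocks):
    """Apply num_blocks residual updates h <- h + F(h), with F(h) = -h / num_blocks."""
    h = h0
    for _ in range(num_blocks):
        h = h + (-h / num_blocks)   # skip connection plus a small residual branch
    return h

exact = math.exp(-1.0)   # solution of dh/dt = -h with h(0) = 1, evaluated at t = 1

# More blocks means smaller Euler steps and a closer match to the continuous solution
errors = {n: abs(residual_stack(1.0, n) - exact) for n in (2, 8, 32, 128)}
for n, err in errors.items():
    print(f"{n:4d} blocks: |error| = {err:.5f}")
```

The error shrinks as the number of blocks grows, which is the sense in which a neural ODE behaves like the continuous limit of a Resnet.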

Using this idea, let's build a neural ODE network. First, we will need to define the function. The following does so, defining our parameterized neural network $f_\theta(h(x),x)$.

In [None]:
class ODENN_Function(nn.Module):
    def __init__(self, units):
        super(ODENN_Function, self).__init__()
        self.model = nn.Sequential(
            nn.LazyLinear(units), 
            nn.ReLU(), 
            nn.LazyLinear(units), 
            nn.ReLU())

    def forward(self, t, x):
        return self.model(x)

Now that we have that, we need to create the actual ODE layer, which is similar to a Resnet block.

### Q4: Create the forward pass for the ODE layer

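The forward pass is just the output of the numerical ODE solver. The forward pass takes an input $x$. This input, along with the "range of points to solve for" (integration_time) and the function (f_x), is sent into an ODE solver. The solver returns the solution at every point in integration_time, and the value at the final point is the output of the ODE layer (hence the `out[1]` below). It is safest to first move integration_time to the same device and dtype as $x$.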

In [None]:
class ODEBlock(nn.Module):
    def __init__(self, f_x):
        super(ODEBlock, self).__init__()
        self.f_x = f_x
        self.integration_time = torch.Tensor([0,1]).float()

    def forward(self, x):
        ### Begin your code here

        ### End your code here
        return out[1]

### Q5: Make the forward pass for the neural ODE

The neural ODE is structured much like the Resnet, but because of the continuous layer structure, you don't need to manually define each layer as you do in Resnet. Using the Resnet as a template, create the forward pass for the neural ODE. Remember that we don't have a discrete number of Resnet blocks, but rather a single continuous ODE block.

In [None]:
class ODENet(nn.Module):
    def __init__(self, units, num_outputs):
        super(ODENet, self).__init__()
        fx = ODENN_Function(units)
        self.flatten = nn.Flatten()
        self.l1 = nn.LazyLinear(units)
        self.ode_block = ODEBlock(fx)
        self.classifier_layer = nn.LazyLinear(num_outputs)


    def forward(self, x):
        ### Begin your code here

        ### End your code here
        return x

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
num_classes = 10
num_epochs = 20
batch_size = 512

model = ODENet(units=64, num_outputs=num_classes).to(device)

#Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())  

#Train the model
total_step = len(train_loader)

And now we will train it:

In [None]:
total_step = len(train_loader)

for epoch in range(num_epochs):
    # print('Epoch: {}'.format(epoch))
    for i, (images, labels) in enumerate(train_loader): 
        print('Batch: {}'.format(i), end="\r") 
    
        #Move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)

        #Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        #Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        del images, labels, outputs
        torch.cuda.empty_cache()
        gc.collect()

    print ('Epoch [{}/{}], Loss: {:.4f}' 
                    .format(epoch+1, num_epochs, loss.item()))

#Validation
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in valid_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        del images, labels, outputs

    print('Accuracy of the network on the {} validation images: {} %'.format(5000, 100 * correct / total))

Quite similar to the Resnet! This example was fairly quick, but depending on your dataset, this could take much longer, or shorter, than a given Resnet model to train.  

Now let's look at one last thing. Because neural ODEs have a layer based upon an ODE solver, they are particularly good where there are dynamic equations (regular ODEs and the like) involved in the data. This is especially prevalent in physical processes, such as fluid dynamics, glacial accumulation, and other phenomena. Now we will take a quick look at an example on real world physical data. 

## Bonus Section: Physical Data

We will use a chemistry dataset, the Delaney solubility dataset, provided by the [DeepChem](https://github.com/deepchem/deepchem/tree/master) library. The task will be to predict the solubility of a molecule based upon its extended-connectivity fingerprint (ECFP). 

An ECFP is a topological fingerprint for molecular characterization, meaning it encodes information about the structure of a molecule [[1]](https://pubs.acs.org/doi/10.1021/ci100050t). The idea behind this task is that the features that can be extracted from the ECFP (features such as molecule length, etc.) may be related to the solubility, and a model can learn that dependence.

But why use a neural ODE? Well, while we can't say for certain it is the case here, many phenomena in chemistry are governed by ODEs and PDEs. Two examples related to solubility are the dependence of solubility on pressure, which is given by $\left(\frac{\partial \ln N_i}{\partial P}\right)_T = -\frac{V_{i, aq} - V_{i, cr}}{RT}$, and the rate of dissolution, which is given by $\frac{dm}{dt} = A \frac{D}{d}(C_s - C_b)$ (see [here](https://en.wikipedia.org/wiki/Solubility) for more information on solubility). Owing to this, a reasonable hunch is that the relationship between structure and solubility may also follow an ODE-like relation, and hence a neural ODE may be a good method to use.

Now that we've got a little of the theory out of the way, let's move into the practical. First, let's load the dataset, and then define the loss function to use. DeepChem recommends using an $L_2$ loss, and using the Pearson $R^2$ score as a metric:

In [None]:
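# Named delaney_datasets so it does not shadow the torchvision `datasets` import above
tasks, delaney_datasets, transformers = dc.molnet.load_delaney(featurizer='ECFP', splitter='random')
train_dataset, valid_dataset, test_dataset = delaney_datasets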
metric = dc.metrics.Metric(dc.metrics.pearson_r2_score)

Now that we have the dataset, we need to train our neural ODE to predict the solubility given an ECFP. DeepChem comes with a very nice wrapper for training models, and so all we need to do is create a model and load it into the wrapper, and we can go and train.
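
For intuition about the loss and the metric mentioned above, here is a plain-Python sketch of a mean squared error and a squared Pearson correlation. These are hand-written stand-ins for illustration, not DeepChem's own implementations, and the numbers are made up rather than taken from the Delaney dataset.

```python
def mean_squared_error(y_true, y_pred):
    """Average of the squared differences between targets and predictions."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def pearson_r2(y_true, y_pred):
    """Square of the Pearson correlation coefficient between the two sequences."""
    n = len(y_true)
    mean_t = sum(y_true) / n
    mean_p = sum(y_pred) / n
    cov = sum((t - mean_t) * (p - mean_p) for t, p in zip(y_true, y_pred))
    var_t = sum((t - mean_t) ** 2 for t in y_true)
    var_p = sum((p - mean_p) ** 2 for p in y_pred)
    return cov ** 2 / (var_t * var_p)

y_true = [-2.0, -1.0, 0.0, 1.0]          # made-up target values
y_pred = [-1.5, -1.0, 0.5, 1.5]          # made-up predictions
y_shifted = [t + 3.0 for t in y_true]    # perfectly correlated but offset by 3

print("MSE:", mean_squared_error(y_true, y_pred))
print("Pearson R^2:", pearson_r2(y_true, y_pred))
print("Shifted MSE:", mean_squared_error(y_true, y_shifted))
print("Shifted Pearson R^2:", pearson_r2(y_true, y_shifted))
```

The shifted predictions score a perfect Pearson $R^2$ despite a large squared error, so the metric measures how well predictions track the targets rather than how close they are in absolute terms. That is one reason to look at the loss alongside the metric.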

### Q6 (BONUS): Create a Neural ODE model with 512 neurons per layer, and a single output neuron

In [None]:

ODE_model = ### Your code here

deepchem_model = dc.models.TorchModel(ODE_model, dc.models.losses.L2Loss())

deepchem_model.fit(train_dataset, nb_epoch=50)
print('training set score:', deepchem_model.evaluate(train_dataset, [metric]))
print('test set score:', deepchem_model.evaluate(test_dataset, [metric]))

Hooray! It did decently well. While we can always do better, I hope the simplicity with which this could be done will inspire you to use neural ODEs in more problems!

## End Checklist

There are a total of 5 mandatory questions in this HW assignment, with 1 bonus question.

This list is provided to ensure you don't forget any of them:

Q1          []  

Q2          []  

Q3          []  

Q4          []  

Q5          []  

Q6 (BONUS)  []  