<a href="https://colab.research.google.com/github/GHes31415/Generative-Modeling/blob/main/NeuralODESMNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torchdiffeq

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchdiffeq
  Downloading torchdiffeq-0.2.3-py3-none-any.whl (31 kB)
Installing collected packages: torchdiffeq
Successfully installed torchdiffeq-0.2.3


In [2]:
import torch 
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchdiffeq import odeint_adjoint as odeint

import matplotlib.pyplot as plt
import numpy as np
import time




In [3]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(0.5,0.5)])
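
As a quick check of what this transform does: `ToTensor` scales pixel values to [0, 1], and `Normalize(0.5, 0.5)` then computes `(x - mean) / std` per pixel, which maps that range to [-1, 1]. The helper below is a hypothetical plain-Python stand-in for illustration, not the torchvision implementation.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Per-pixel arithmetic of Normalize(mean, std): (x - mean) / std."""
    return (pixel - mean) / std

# Black, mid-gray, and white pixels after ToTensor (values in [0, 1]).
print(normalize(0.0))  # -1.0
print(normalize(0.5))  # 0.0
print(normalize(1.0))  # 1.0
```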




In [4]:
trainset = torchvision.datasets.MNIST(root = '../data',train = True,
                                      download = True, transform = transform)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 237292098.26it/s]

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 42459058.47it/s]


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 48215196.79it/s]


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 22307410.74it/s]


Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw



In [5]:
trainloader = torch.utils.data.DataLoader(trainset,batch_size = 16,
                                          shuffle = True, num_workers=4)
testset = torchvision.datasets.MNIST(root = '../data', train = False,
                                     download = True, transform = transform)
testloader = torch.utils.data.DataLoader(testset,batch_size = 16,
                                         shuffle = True, num_workers = 4)



In [6]:
class MyNet(nn.Module):
	def __init__(self, path):
		super(MyNet, self).__init__()
		self.path = path

	def num_params(self):
		return sum(p.numel() for p in self.parameters() if p.requires_grad)

	def load(self):
		self.load_state_dict(torch.load('./' + self.path + '.pth'))
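
To get a feel for what `num_params` should report, the trainable parameters of common layers can be counted by hand. The sketch below uses hypothetical helper functions (not part of PyTorch) and the layer sizes that appear later in this notebook's ODENet; it assumes every layer keeps its default bias and affine parameters.

```python
def conv2d_params(in_ch, out_ch, kernel):
    """Weights (out_ch x in_ch x k x k) plus one bias per output channel."""
    return out_ch * in_ch * kernel * kernel + out_ch

def groupnorm_params(channels):
    """One learnable scale and one learnable shift per channel."""
    return 2 * channels

def linear_params(in_features, out_features):
    """Weight matrix plus one bias per output feature."""
    return in_features * out_features + out_features

total = (
    conv2d_params(1, 2, 3)           # conv1
    + groupnorm_params(2)            # gn after conv1
    + groupnorm_params(2)            # gn inside ODEFunc (shared by both calls)
    + conv2d_params(2, 2, 3)         # conv inside ODEFunc
    + linear_params(2 * 13 * 13, 10) # fc
)
print(total)  # 3456
```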

We consider an autonomous ODE $\dot{x} = f(x,a)$ with right-hand side $$f(x,a) = \left(\mathrm{gn}\circ \mathrm{conv}(a)\circ \mathrm{relu}\circ \mathrm{gn}\right)(x),$$

where gn is the GroupNorm function and $a$ is the set of parameters of the convolution.
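
Written as an initial value problem, with $x_0$ (a symbol introduced here for illustration) denoting the feature map that enters the ODE block, the block integrates from $t=0$ to $t=1$:

$$\frac{dx}{dt}(t) = f(x(t),a), \qquad x(0) = x_0, \qquad x(1) = x_0 + \int_0^1 f(x(t),a)\,dt.$$

The ODE is autonomous because $f$ does not depend on $t$ explicitly; the state $x(1)$ is what the block passes on to the next layer.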

In [7]:
class ODEFunc(nn.Module):
	def __init__(self, dim):
		super(ODEFunc, self).__init__()
		self.gn = nn.GroupNorm(min(32, dim), dim)
		self.conv = nn.Conv2d(dim, dim, 3, padding = 1)
		self.nfe = 0 # number of function evaluations made by the solver

	def forward(self, t, x):
		self.nfe += 1
		x = self.gn(x)
		x = F.relu(x)
		x = self.conv(x)
		x = self.gn(x)
		return x
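
The `nfe` counter records how many times the solver calls `forward`, which is a rough measure of the cost of one pass through the ODE. The sketch below shows the same counting idea in plain Python with a hypothetical fixed-step RK4 solver; this is an assumption made for illustration, since torchdiffeq's default solver is adaptive and its count depends on the tolerances.

```python
class CountedFunc:
    """Wraps a right-hand side f(t, x) and counts how often it is evaluated."""
    def __init__(self, f):
        self.f = f
        self.nfe = 0

    def __call__(self, t, x):
        self.nfe += 1
        return self.f(t, x)

def rk4_integrate(func, x0, t0, t1, steps):
    """Classical fixed-step Runge-Kutta 4: four evaluations of func per step."""
    h = (t1 - t0) / steps
    t, x = t0, x0
    for _ in range(steps):
        k1 = func(t, x)
        k2 = func(t + h / 2, x + h / 2 * k1)
        k3 = func(t + h / 2, x + h / 2 * k2)
        k4 = func(t + h, x + h * k3)
        x = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t = t + h
    return x

decay = CountedFunc(lambda t, x: -x)  # dx/dt = -x, exact solution exp(-t)
x1 = rk4_integrate(decay, 1.0, 0.0, 1.0, steps=10)
print(decay.nfe)  # 40
```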

Next comes the integrator of the neural ODE: it solves the ODE from t = 0 to t = 1 and returns the final state.

In [8]:
class ODEBlock(nn.Module):
	def __init__(self, odefunc):
		super(ODEBlock, self).__init__()
		self.odefunc = odefunc
		self.integration_time = torch.tensor([0, 1]).float()

	def forward(self, x):
		out = odeint(self.odefunc, x, self.integration_time, rtol=1e-1, atol=1e-1) # high tolerances for speed

		# out is indexed by time along its first dimension:
		# out[0] is x(0) and out[1] is x(1), so we just want the second
		return out[1]
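
The indexing `out[1]` follows from the solver returning one state per requested time point, with the initial state first. A minimal sketch of that convention, using a hypothetical forward-Euler helper in plain Python (an assumption for illustration, not how torchdiffeq integrates internally):

```python
import math

def solve_at_times(f, x0, times, substeps=1000):
    """Return the state at every requested time, starting with x0 itself."""
    states = [x0]
    x = x0
    for t_start, t_end in zip(times[:-1], times[1:]):
        h = (t_end - t_start) / substeps
        t = t_start
        for _ in range(substeps):
            x = x + h * f(t, x)  # forward Euler update
            t += h
        states.append(x)
    return states

# dx/dt = -x with x(0) = 1 has the exact solution x(t) = exp(-t).
out = solve_at_times(lambda t, x: -x, 1.0, [0.0, 1.0])
print(out[0])                  # the initial state x(0)
print(out[1], math.exp(-1.0))  # approximate and exact x(1)
```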


Now we create ODENet with this block. There are three parts to this ODENet.


1.   We take our 28-by-28 image and apply a 3-by-3 convolution without padding, with 2 output channels. Then we apply GroupNorm and ReLU.
2.   We apply the ODEBlock.
3.   We apply a 2-by-2 average pool and one fully connected linear layer.
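
The spatial sizes through these three parts can be checked with the standard output-size formula for convolution and pooling layers. The helper below is a hypothetical name introduced for illustration; the channel count of 2 is taken from the ODENet code in this notebook.

```python
def out_size(size, kernel, padding=0, stride=1):
    """Output width/height of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

after_conv1 = out_size(28, 3)                    # 3x3 conv, no padding
after_ode = out_size(after_conv1, 3, padding=1)  # ODEFunc conv keeps the size
after_pool = out_size(after_ode, 2, stride=2)    # 2x2 average pool

channels = 2
fc_inputs = channels * after_pool * after_pool   # input features of the linear layer

print(after_conv1, after_ode, after_pool, fc_inputs)  # 26 26 13 338
```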



In [11]:
class ODENet(MyNet):
  def __init__(self):
    super(ODENet, self).__init__('mnist_odenet')
    self.conv1 = nn.Conv2d(1, 2, 3) # (in channels, out channels, kernel size)
    self.gn = nn.GroupNorm(2, 2)
    self.odefunc = ODEFunc(2)
    self.odeblock = ODEBlock(self.odefunc)
    self.pool = nn.AvgPool2d(2)
    self.fc = nn.Linear(2* 13 * 13, 10)
  def forward(self,x):
    #26x26
    x = self.conv1(x)
    x = F.relu(self.gn(x))

    # stays 26x26
    x = self.odeblock(x)

    # 13x13
    x = self.pool(x)

    #fully connected layer
    x = x.view(-1,2*13*13)
    x = self.fc(x)

    return x


Now we define the training and testing methods.

In [23]:
def train(net):
  n = len(trainloader) // 5 # log five times per epoch (every 750 batches at batch size 16)
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.SGD(net.parameters(),lr = 0.001, momentum = 0.9)

  for epoch in range(10):
    running_loss = 0.0
    for i,data in enumerate(trainloader,0):
      # get the inputs; data is a list of [inputs, labels]
      inputs, labels = data

      # zero the parameter gradients
      optimizer.zero_grad()

      # forward + backward + optimize

      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()

      #print statistics

      running_loss += loss.item()
      if i%n == n-1:
        print('[%d,%5d] loss: %.3f'% (epoch +1, i+1, running_loss/n))
        running_loss = 0.0
        torch.save(net.state_dict(),'./'+ net.path +'.pth')
  print('Finished Training')
  torch.save(net.state_dict(),'./'+ net.path +'.pth')

def test(net):
  initial_time = time.time()
  correct = 0 
  total = 0
  with torch.no_grad():
    for data in testloader:
      images,labels = data
      outputs = net(images)
      # the predicted class is the index of the largest logit
      _, predicted = torch.max(outputs, 1)
      total += labels.size(0)
      correct += (predicted==labels).sum().item()
  final_time = time.time()
  print('Accuracy of the ' + net.path + ' network on the test set: %.2f %%' % (100 * correct / total))
  print('Time: %.2f seconds' % (final_time - initial_time))
  return 100 * correct / total
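
The logging interval in `train` follows from simple arithmetic: 60000 training images at batch size 16 give 3750 batches per epoch, and reporting five times per epoch means printing every 750 batches. A small sketch of that calculation (variable names are chosen for illustration):

```python
import math

dataset_size = 60000   # MNIST training images
batch_size = 16
logs_per_epoch = 5

batches_per_epoch = math.ceil(dataset_size / batch_size)
n = batches_per_epoch // logs_per_epoch

# Batch numbers (1-based) at which the running loss is printed and reset.
log_points = [i + 1 for i in range(batches_per_epoch) if i % n == n - 1]

print(batches_per_epoch, n)  # 3750 750
print(log_points)            # [750, 1500, 2250, 3000, 3750]
```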


In [19]:
odenet = ODENet()
train(odenet)

[1,  750] loss: 0.549
[1, 1500] loss: 0.370
[1, 2250] loss: 0.346
[1, 3000] loss: 0.363
[1, 3750] loss: 0.300
[2,  750] loss: 0.200
[2, 1500] loss: 0.168
[2, 2250] loss: 0.140
[2, 3000] loss: 0.141
[2, 3750] loss: 0.135
[3,  750] loss: 0.121
[3, 1500] loss: 0.115
[3, 2250] loss: 0.123
[3, 3000] loss: 0.117
[3, 3750] loss: 0.121
[4,  750] loss: 0.115
[4, 1500] loss: 0.111
[4, 2250] loss: 0.109
[4, 3000] loss: 0.109
[4, 3750] loss: 0.117
[5,  750] loss: 0.103
[5, 1500] loss: 0.107
[5, 2250] loss: 0.108
[5, 3000] loss: 0.105
[5, 3750] loss: 0.114
[6,  750] loss: 0.094
[6, 1500] loss: 0.106
[6, 2250] loss: 0.108
[6, 3000] loss: 0.111
[6, 3750] loss: 0.103
[7,  750] loss: 0.102
[7, 1500] loss: 0.101
[7, 2250] loss: 0.096
[7, 3000] loss: 0.098
[7, 3750] loss: 0.101
[8,  750] loss: 0.097
[8, 1500] loss: 0.110
[8, 2250] loss: 0.088
[8, 3000] loss: 0.102
[8, 3750] loss: 0.098
[9,  750] loss: 0.093
[9, 1500] loss: 0.103
[9, 2250] loss: 0.092
[9, 3000] loss: 0.102
[9, 3750] loss: 0.096
[10,  750]

In [24]:
test(odenet)

Accuracy of the mnist_odenet network on the test set: 96.91 %
Time: 16.66 seconds


96.91