In [2]:
import torch
from torch import nn
from torch import Tensor
from torch.nn  import functional as F 
from torch.autograd import Variable
import torchvision

from tensorboardX import SummaryWriter

import numpy as np
import math

import time
import zipfile
import os
from os.path import join, exists
from os import makedirs
import re
import time
import logging
from tqdm import tnrange, tqdm_notebook

###### Todo add formulas
Solver of the initial value problem of the differential equation

In [3]:
# TODO: use a more sophisticated method
def ode_solve(z0, t0, t1, f, h_max=0.05):
    '''
    Euler's ODE solver: integrate dz/dt = f(z, t) from t0 to t1.

    Args:
        z0: state at time t0.
        t0: start time; must be a tensor, since ``.max()`` is taken over
            ``abs(t1 - t0)``.
        t1: end time (tensor). It may be smaller than t0, in which case the
            steps are negative and integration runs backwards in time.
        f: callable ``f(z, t)`` returning dz/dt.
        h_max: upper bound on the absolute step size (defaults to the
            previously hard-coded 0.05).

    Returns:
        The approximated state at time t1.
    '''
    # Smallest number of equal steps such that no step exceeds h_max
    n_steps = math.ceil((abs(t1 - t0) / h_max).max().item())
    if n_steps == 0:
        # t0 == t1: nothing to integrate, and the step below would be 0 / 0
        return z0

    h = (t1 - t0) / n_steps
    t = t0
    z = z0

    for i_step in range(n_steps):
        z = z + h * f(z, t)
        t = t + h
    return z

#### adfdz
![adfdz](./images/dl_dh.png)
#### adfdt
![adfdt](./images/dl_dt.png)
#### adfdp
![adfdp](./images/dl_dp.png)

In [4]:
class ODEF(nn.Module):
    """Base class for the dynamics function f(z, t) of a neural ODE.

    Subclasses provide ``forward(z, t)``. This class adds the helpers that the
    adjoint method needs: vector-Jacobian products of f and a flat view of the
    parameters.
    """

    def forward_with_grad(self, z, t, grad_outputs):
        """Evaluate f(z, t) together with its vector-Jacobian products.

        Args:
            z: state tensor whose first dimension is the batch; must require grad.
            t: time tensor; must require grad.
            grad_outputs: the vector ``a`` the Jacobians are multiplied with
                (same shape as the output of ``forward``).

        Returns:
            Tuple ``(out, adfdz, adfdt, adfdp)``:
            ``out`` is ``forward(z, t)``;
            ``adfdz`` is the gradient w.r.t. ``z`` (``None`` if z is unused);
            ``adfdt`` is the gradient w.r.t. ``t`` expanded to (batch_size, 1)
            and divided by batch_size (``None`` if t is unused);
            ``adfdp`` is the flattened parameter gradient expanded to
            (batch_size, n_params) and divided by batch_size (``None`` if the
            module has no parameters).
        """
        batch_size = z.shape[0]

        out = self.forward(z, t)

        a = grad_outputs
        params = tuple(self.parameters())

        # autograd records the operations run by forward(), so the
        # vector-Jacobian products with `a` come from one grad() call over
        # all inputs. allow_unused=True yields None for any unused input.
        adfdz, adfdt, *adfdp = torch.autograd.grad(
            outputs=(out,), inputs=(z, t) + params, grad_outputs=(a,),
            allow_unused=True, retain_graph=True
        )
        if params:
            # A parameter that does not take part in forward() gets a None
            # gradient; replace it with zeros so the flat layout stays aligned
            # with flatten_parameters().
            adfdp = torch.cat([
                (p_grad if p_grad is not None else torch.zeros_like(p)).flatten()
                for p, p_grad in zip(params, adfdp)
            ]).unsqueeze(0)
            adfdp = adfdp.expand(batch_size, -1) / batch_size
        else:
            # Nothing to concatenate (torch.cat rejects an empty list)
            adfdp = None
        if adfdt is not None:
            adfdt = adfdt.expand(batch_size, 1) / batch_size
        return out, adfdz, adfdt, adfdp

    def flatten_parameters(self):
        """Return all parameters concatenated into a single 1-D tensor.

        The order is that of ``self.parameters()``. A module without
        parameters yields an empty tensor.
        """
        flat_parameters = [p.flatten() for p in self.parameters()]
        if not flat_parameters:
            return torch.zeros(0)
        return torch.cat(flat_parameters)

In [5]:
import numpy as np
import torch

![loss](./images/loss.png)
This explains the part of the backward function that computes the direct gradient:
    ```
    for i_t in range(time_len-1, 0, -1):
                z_i = z[i_t]
                t_i = t[i_t]
                f_i = func(z_i, t_i).view(bs, n_dim)
                
                # Compute direct gradients
                dLdz_i = dLdz[i_t]
    ```

In [6]:
# Demo: star-unpacking a tensor's size into leading dims and the remainder
v = torch.Tensor(np.array([[[1, 3, 4], [3, 5, 6]]]))
t, b, *sh = v.shape
print(t, b, *sh)

1 2 3


In [8]:
class ODEAdjoint(torch.autograd.Function):
    """Autograd function that integrates an ODE forward with ``ode_solve`` and
    computes gradients by solving an augmented ODE backwards in time, rather
    than back-propagating through the solver steps.
    """

    @staticmethod
    def forward(ctx, z0, t, flat_parameters, func):
        """Integrate the state from t[0] through every time point in t.

        Args:
            ctx: autograd context used to pass values to ``backward``.
            z0: initial state of size (bs, *z_shape).
            t: time points, indexed along dim 0.
            flat_parameters: parameters of ``func`` as one flat tensor. It is
                only saved here; backward returns a gradient for it.
            func: ``ODEF`` instance giving the dynamics dz/dt = func(z, t).

        Returns:
            Tensor of size (time_len, bs, *z_shape) with the state at every
            time in t.
        """
        assert isinstance(func, ODEF)
        bs, *z_shape = z0.size()
        time_len = t.size(0)

        # The solver steps are not recorded: gradients come from backward()
        with torch.no_grad():
            z = torch.zeros(time_len, bs, *z_shape).to(z0)
            z[0] = z0
            for i_t in range(time_len - 1):
                z0 = ode_solve(z0, t[i_t], t[i_t+1], func)
                z[i_t+1] = z0

        ctx.func = func
        ctx.save_for_backward(t, z.clone(), flat_parameters)
        return z

    @staticmethod
    def backward(ctx, dLdz):
        """
        Compute gradients by integrating the augmented system backwards.

        dLdz shape: time_len, batch_size, *z_shape

        Returns one gradient per argument of ``forward`` (after ctx):
        the adjoint for z0 of size (bs, *z_shape), the adjoint for t of size
        (time_len, bs, 1), the adjoint for the flat parameters of size
        (bs, n_params), and None for func.
        """
        func = ctx.func
        t, z, flat_parameters = ctx.saved_tensors
        time_len, bs, *z_shape = z.size()
        # int(): np.prod gives a numpy scalar (a float 1.0 for an empty shape),
        # which must be a plain int to be usable in slices and view()
        n_dim = int(np.prod(z_shape))
        n_params = flat_parameters.size(0)

        # Dynamics of augmented system to be calculated backwards in time
        def augmented_dynamics(aug_z_i, t_i):
            """
            tensors here are temporal slices
            t_i - is the time slice (e.g. size bs, 1 when t is given per sample)
            aug_z_i - is tensor with size: bs, n_dim*2 + n_params + 1
            """
            z_i, a = aug_z_i[:, :n_dim], aug_z_i[:, n_dim:2*n_dim]  # ignore parameters and time

            # Unflatten z and a
            z_i = z_i.view(bs, *z_shape)
            a = a.view(bs, *z_shape)
            # backward() runs under no_grad, so re-enable grad for the VJPs
            with torch.set_grad_enabled(True):
                t_i = t_i.detach().requires_grad_(True)
                z_i = z_i.detach().requires_grad_(True)
                func_eval, adfdz, adfdt, adfdp = func.forward_with_grad(z_i, t_i, grad_outputs=a)  # bs, *z_shape
                # Inputs that func does not use come back as None -> zeros
                adfdz = adfdz.to(z_i) if adfdz is not None else torch.zeros(bs, *z_shape).to(z_i)
                adfdp = adfdp.to(z_i) if adfdp is not None else torch.zeros(bs, n_params).to(z_i)
                adfdt = adfdt.to(z_i) if adfdt is not None else torch.zeros(bs, 1).to(z_i)

            # Flatten f and adfdz
            func_eval = func_eval.view(bs, n_dim)
            adfdz = adfdz.view(bs, n_dim)
            return torch.cat((func_eval, -adfdz, -adfdp, -adfdt), dim=1)

        dLdz = dLdz.view(time_len, bs, n_dim)  # flatten dLdz for convenience
        with torch.no_grad():
            ## Create placeholders for output gradients
            # Prev computed backwards adjoints to be adjusted by direct gradients
            adj_z = torch.zeros(bs, n_dim).to(dLdz)
            adj_p = torch.zeros(bs, n_params).to(dLdz)
            # In contrast to z and p we need to return gradients for all times
            adj_t = torch.zeros(time_len, bs, 1).to(dLdz)

            for i_t in range(time_len-1, 0, -1):
                z_i = z[i_t]
                t_i = t[i_t]
                f_i = func(z_i, t_i).view(bs, n_dim)

                # Compute direct gradients
                dLdz_i = dLdz[i_t]
                # Batched dot product dLdz_i . f_i -> (bs, 1)
                dLdt_i = torch.bmm(torch.transpose(dLdz_i.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]

                # Adjusting adjoints with direct gradients
                adj_z += dLdz_i
                adj_t[i_t] = adj_t[i_t] - dLdt_i

                # Pack augmented variable
                aug_z = torch.cat((z_i.view(bs, n_dim), adj_z, torch.zeros(bs, n_params).to(z), adj_t[i_t]), dim=-1)

                # Solve augmented system backwards
                aug_ans = ode_solve(aug_z, t_i, t[i_t-1], augmented_dynamics)

                # Unpack solved backwards augmented system
                adj_z[:] = aug_ans[:, n_dim:2*n_dim]
                adj_p[:] += aug_ans[:, 2*n_dim:2*n_dim + n_params]
                adj_t[i_t-1] = aug_ans[:, 2*n_dim + n_params:]

                del aug_z, aug_ans

            ## Adjust 0 time adjoint with direct gradients
            # The dynamics must be evaluated at (z[0], t[0]). Reusing f_i from
            # the loop would take f at (z[1], t[1]) and would be undefined
            # when time_len == 1.
            f_0 = func(z[0], t[0]).view(bs, n_dim)

            # Compute direct gradients
            dLdz_0 = dLdz[0]
            dLdt_0 = torch.bmm(torch.transpose(dLdz_0.unsqueeze(-1), 1, 2), f_0.unsqueeze(-1))[:, 0]

            # Adjust adjoints
            adj_z += dLdz_0
            adj_t[0] = adj_t[0] - dLdt_0
        return adj_z.view(bs, *z_shape), adj_t, adj_p, None

In [9]:
class NeuralODE(nn.Module):
    """Module wrapper that integrates an ``ODEF`` through ``ODEAdjoint``."""

    def __init__(self, func):
        super().__init__()
        assert isinstance(func, ODEF)
        self.func = func

    def forward(self, z0, t=torch.Tensor([0., 1.]), return_whole_sequence=False):
        """Integrate ``z0`` over the times ``t``.

        Returns the state at every time point when ``return_whole_sequence``
        is true, otherwise only the state at the last time point.
        """
        # Match the dtype/device of the state
        t = t.to(z0)
        trajectory = ODEAdjoint.apply(z0, t, self.func.flatten_parameters(), self.func)
        return trajectory if return_whole_sequence else trajectory[-1]

### Application

#### MNIST

In [10]:
def create_logdirs(modeldir):
    """Create timestamped log and run directories below ``modeldir``.

    The timestamp is ``time.ctime()`` with spaces replaced by ``_`` and
    colons by ``-``.

    Args:
        modeldir: base directory; created if it does not exist.

    Returns:
        Tuple ``(tb_dir, runs_dir)`` with the paths
        ``<modeldir>/logs/<timestamp>`` and ``<modeldir>/runs/<timestamp>``.
    """
    cur_time_str = time.ctime().replace(' ', '_').replace(':', '-')
    tb_dir = join(modeldir, 'logs', cur_time_str)
    runs_dir = join(modeldir, 'runs', cur_time_str)
    for directory in (tb_dir, runs_dir):
        # exist_ok avoids the check-then-create race of `if not exists(...)`
        makedirs(directory, exist_ok=True)
    return tb_dir, runs_dir

In [11]:
def add_time(in_tensor, t):
    """Append ``t`` to a 4-D tensor as one extra channel along dim 1.

    ``t`` is expanded to cover a single channel of the same batch and
    spatial size as ``in_tensor``.
    """
    n_batch, _, dim2, dim3 = in_tensor.shape
    time_channel = t.expand(n_batch, 1, dim2, dim3)
    return torch.cat((in_tensor, time_channel), dim=1)

In [12]:
class ConvODEF(ODEF):
    """Convolutional dynamics function: two conv -> ReLU -> BatchNorm stages,
    each fed with the time appended as an extra input channel.
    """

    def __init__(self, dim):
        super().__init__()
        # Each convolution takes `dim` feature channels plus one time channel
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=False)
        self.conv1 = nn.Conv2d(dim + 1, dim, **conv_kwargs)
        self.norm1 = nn.BatchNorm2d(dim)
        self.conv2 = nn.Conv2d(dim + 1, dim, **conv_kwargs)
        self.norm2 = nn.BatchNorm2d(dim)

    def forward(self, x, t):
        """Return dx/dt for state ``x`` at time ``t``."""
        hidden = self.conv1(add_time(x, t))
        hidden = self.norm1(torch.relu(hidden))
        dxdt = self.conv2(add_time(hidden, t))
        return self.norm2(torch.relu(dxdt))

In [44]:
class ContinuousNeuralMNISTClassifier(nn.Module):
    """Classifier: convolutional downsampling, the given ``ode`` module as
    feature block, then BatchNorm, global average pooling and a linear head
    with 10 outputs.
    """

    def __init__(self, ode):
        super().__init__()
        stem = [
            nn.Conv2d(1, 64, 5, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 5, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
        ]
        self.downsampling = nn.Sequential(*stem)
        self.feature = ode
        self.norm = nn.BatchNorm2d(64)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        """Return the 10 class scores for a batch of 1-channel images."""
        features = self.feature(self.downsampling(x))
        pooled = self.avg_pool(self.norm(features))
        # Collapse everything but the batch dimension before the linear head
        feature_dims = torch.tensor(pooled.shape[1:])
        n_features = torch.prod(feature_dims).item()
        return self.fc(pooled.view(-1, n_features))

In [None]:
class Net(nn.Module):
    """Plain CNN baseline: two conv + max-pool stages followed by two linear
    layers, returning log-probabilities over 10 classes.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Return log-softmax class scores for a batch of images."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        # Flatten to the 4*4*50 features that fc1 expects
        hidden = F.relu(self.fc1(x.view(-1, 4*4*50)))
        return F.log_softmax(self.fc2(hidden), dim=1)

In [45]:
class DiscreteNeuralMNISTClassifier(nn.Module):
    """Classifier with two ordinary conv + BatchNorm layers in place of an
    ODE block: convolutional downsampling, the two layers, global average
    pooling and a linear head with 10 outputs.
    """

    def __init__(self):
        super(DiscreteNeuralMNISTClassifier, self).__init__()
        self.downsampling = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
        )
        dim = 64
        # forward() feeds these layers the plain `dim`-channel feature map
        # (no time channel is appended), so they take `dim` input channels;
        # `dim + 1` made every forward pass fail with a channel mismatch.
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm1 = nn.BatchNorm2d(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm2 = nn.BatchNorm2d(dim)

        # NOTE(review): `norm` is registered but not applied in forward();
        # confirm whether it should run before the pooling.
        self.norm = nn.BatchNorm2d(64)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        """Return the 10 class scores for a batch of 1-channel images."""
        x = self.downsampling(x)
        x = self.norm1(self.conv1(x))
        # Second stage uses conv2 (conv1 was applied twice before)
        x = self.norm2(self.conv2(x))
        x = self.avg_pool(x)
        # Collapse everything but the batch dimension before the linear head
        x = x.view(x.size(0), -1)
        out = self.fc(x)
        return out

In [46]:
def train(epoch):
    """Run one training epoch of the global ``model`` over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer`` and ``train_loader``.

    Args:
        epoch: epoch number, used only in the progress message.

    Returns:
        Mean of the per-batch cross-entropy losses of this epoch.
    """
    train_losses = []

    model.train()
    criterion = nn.CrossEntropyLoss()
    # Put batches on whatever device the model lives on instead of assuming CUDA
    device = next(model.parameters()).device
    print(f"Training Epoch {epoch}...")
    for batch_idx, (data, target) in tqdm_notebook(enumerate(train_loader), total=len(train_loader)):
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
    mean_loss = np.mean(train_losses)
    print('Train loss: {:.5f}'.format(mean_loss))
    return mean_loss

In [47]:
def test():
    """Evaluate the global ``model`` on ``test_loader``.

    Uses the notebook-level ``model`` and ``test_loader``.

    Returns:
        Classification accuracy in percent.
    """
    accuracy = 0.0
    num_items = 0

    model.eval()
    # Put batches on whatever device the model lives on instead of assuming CUDA
    device = next(model.parameters()).device
    print("Testing...")
    with torch.no_grad():
        for batch_idx, (data, target) in tqdm_notebook(enumerate(test_loader),  total=len(test_loader)):
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            # Count predictions whose arg-max class equals the label
            accuracy += torch.sum(torch.argmax(output, dim=1) == target).item()
            num_items += data.shape[0]
    accuracy = accuracy * 100 / num_items
    print("Test Accuracy: {:.3f}%".format(accuracy))
    return accuracy

###### get the data

In [48]:
# Mean and standard deviation handed to Normalize below
img_std = 0.3081
img_mean = 0.1307

batch_size = 32

# Both splits share the same preprocessing: to-tensor, then normalise
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((img_mean,), (img_std,)),
])

train_set = torchvision.datasets.MNIST(
    "data/mnist", train=True, download=True, transform=mnist_transform
)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

test_set = torchvision.datasets.MNIST(
    "data/mnist", train=False, download=True, transform=mnist_transform
)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=True)

In [49]:
# Continuous classifier: 64-channel conv dynamics wrapped in a NeuralODE
func = ConvODEF(64)
ode = NeuralODE(func)
model = ContinuousNeuralMNISTClassifier(ode).cuda()

In [50]:
# Adam with default hyper-parameters over all parameters of the current model
optimizer = torch.optim.Adam(model.parameters())

In [None]:
n_epochs = 5
test()  # accuracy of the untrained model, as a baseline
train_losses = []
tb_dir, runs_dir = create_logdirs('./neural_ode/logs/')
writer = SummaryWriter(tb_dir)
for epoch in range(1, n_epochs + 1):
    train_loss = train(epoch)
    # append, not `+=`: train() returns a numpy scalar, and `list += scalar`
    # dispatches to numpy's __radd__, silently replacing the list with an
    # empty array so no loss is ever recorded
    train_losses.append(train_loss)
    writer.add_scalar('Train/loss', train_loss, epoch)
    accuracy = test()
    writer.add_scalar('Test/accuracy', accuracy, epoch)

Testing...


HBox(children=(IntProgress(value=0, max=79), HTML(value='')))


Test Accuracy: 9.730%
Training Epoch 1...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


Train loss: 0.17801
Testing...


HBox(children=(IntProgress(value=0, max=79), HTML(value='')))


Test Accuracy: 98.330%
Training Epoch 2...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


Train loss: 0.04941
Testing...


HBox(children=(IntProgress(value=0, max=79), HTML(value='')))


Test Accuracy: 98.860%
Training Epoch 3...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


Train loss: 0.03738
Testing...


HBox(children=(IntProgress(value=0, max=79), HTML(value='')))


Test Accuracy: 98.840%
Training Epoch 4...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

In [41]:
# Replace the global model with the plain CNN baseline; train()/test() read
# this global, so the cells below operate on Net from here on
model = Net()
# Last expression of the cell: moves the model to CUDA and displays its repr
model.cuda()

Net(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

In [42]:
# Fresh Adam optimizer bound to the parameters of the newly created model
optimizer = torch.optim.Adam(model.parameters())

In [43]:
n_epochs = 5
test()  # accuracy of the untrained model, as a baseline
train_losses = []
tb_dir, runs_dir = create_logdirs('./neural_ode/logs/')
writer = SummaryWriter(tb_dir)
for epoch in range(1, n_epochs + 1):
    train_loss = train(epoch)
    # append, not `+=`: train() returns a numpy scalar, and `list += scalar`
    # dispatches to numpy's __radd__, silently replacing the list with an
    # empty array so no loss is ever recorded
    train_losses.append(train_loss)
    writer.add_scalar('Ordinary NN Train/loss', train_loss, epoch)
    accuracy = test()
    writer.add_scalar('Ordinary NN Test/accuracy', accuracy, epoch)

Testing...


HBox(children=(IntProgress(value=0, max=79), HTML(value='')))


Test Accuracy: 14.250%
Training Epoch 1...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


Train loss: 0.10759
Testing...


HBox(children=(IntProgress(value=0, max=79), HTML(value='')))


Test Accuracy: 98.400%
Training Epoch 2...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


Train loss: 0.03996
Testing...


HBox(children=(IntProgress(value=0, max=79), HTML(value='')))


Test Accuracy: 99.090%
Training Epoch 3...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


Train loss: 0.02748
Testing...


HBox(children=(IntProgress(value=0, max=79), HTML(value='')))


Test Accuracy: 98.930%
Training Epoch 4...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


Train loss: 0.02204
Testing...


HBox(children=(IntProgress(value=0, max=79), HTML(value='')))


Test Accuracy: 98.840%
Training Epoch 5...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))


Train loss: 0.01650
Testing...


HBox(children=(IntProgress(value=0, max=79), HTML(value='')))


Test Accuracy: 98.850%
