# Imports

In [1]:
import torch
import torch.optim as optim
from torch import nn
from torchdiffeq import odeint_adjoint
import numpy as np
import loader
import training
import metrics
import autotune
import config

# GPU

In [2]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook can still run (slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data Loading

In [3]:
# Raw MNIST training images and their labels, resolved relative to the
# current working directory.
mnist_data_path, mnist_label_path = (
    './train-images-idx3-ubyte',
    './train-labels-idx1-ubyte',
)

In [4]:
# Load MNIST and generate a 5-fold partition of it. Each entry of `splits`
# holds the (train, val, test) segments of one fold.
mnist_data = loader.MNIST(
    mnist_data_path,
    mnist_label_path,
    5,  # number of folds
)
mnist_splits = mnist_data.splits

In [5]:
# Sanity-check the fold structure:
# splits -> (train, val, test) segments -> (data, labels) -> examples.
print("MNIST STATS")
print("Number of splits:", len(mnist_splits))
print("Number of segments per split (train, val, test):", len(mnist_splits[0]))
print("Info per segment (data, labels):", len(mnist_splits[0][0]))
# mnist_splits[0][0] is the first (train) segment of the first fold.
print("Size of train segment (num examples):", len(mnist_splits[0][0][0]))

MNIST STATS
Number of splits: 5
Number of segments per split (train, val, test): 3
Info per segment (data, labels): 2
Size of segement (num examples): 48000


# Model & Optimizer

In [6]:
# Cap on the number of adaptive (dopri5) solver steps per ODE solve; handed to
# odeint_adjoint through options={'max_num_steps': ...}.
MAX_NUM_STEPS = 10 ** 3

class NeuralODE(nn.Module):
    """Convolutional (augmented) Neural ODE classifier.

    Pipeline: ODE block -> global average pooling -> linear classifier.

    Some settings are constant throughout the paper's experiments and are
    hard-wired rather than exposed as arguments:
    time_dependent = True, non_linearity = 'relu', adjoint = True.

    Args:
        in_channels: number of channels of the input images.
        height, width: spatial size of the input images. They are kept for
            interface compatibility; the pooled classifier does not depend
            on them.
        num_filters: number of filters in the hidden convolutions of the
            ODE function.
        out_dim: number of output logits.
        augmented_dim: number of zero channels appended to the ODE state
            (ANODE); 0 gives a plain NODE.
        tolerance: rtol/atol used by the ODE solver.
    """

    def __init__(self, in_channels, height, width, num_filters,
                 out_dim=10, augmented_dim=0, tolerance=1e-3):
        super(NeuralODE, self).__init__()

        # The ODE state carries the input channels plus the augmented zero
        # channels. Global average pooling reduces each of them to a single
        # feature, so this is the classifier's input width. The previous
        # `augmented_dim + 1` was only correct for single-channel inputs.
        state_channels = in_channels + augmented_dim

        function = ODEConv(in_channels, num_filters, augmented_dim)

        self.block_ODE = ODEBlock(function, tolerance)
        self.block_linear = nn.Linear(state_channels, out_dim)
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        # Final ODE state: (batch, in_channels + augmented_dim, height, width).
        x = self.block_ODE(x)
        # Global average pool -> (batch, channels, 1, 1) -> (batch, channels).
        x = self.gap(x)
        x = x.view(x.size(0), -1)
        # Logits: (batch, out_dim).
        x = self.block_linear(x)

        return x


class ODEBlock(nn.Module):
    """Integrate an ODE function from t=0 to t=1 and return the final state.

    Settings fixed for the paper's experiments: is_conv = True. Gradients use
    the adjoint method (``odeint_adjoint``).

    Args:
        function: the ODE right-hand side, called as ``function(t, x)``. It
            must expose ``nfe`` (reset to 0 on every forward pass) and
            ``augmented_dim`` (number of zero channels appended to the input).
        tolerance: used as both ``rtol`` and ``atol`` for the dopri5 solver.
    """

    def __init__(self, function, tolerance):
        super(ODEBlock, self).__init__()
        self.function = function
        self.tolerance = tolerance

    # eval_times=None: only the final state is needed (the convolution
    # trajectory is not plotted).
    def forward(self, x):
        # Count function evaluations for this forward pass only.
        self.function.nfe = 0

        # Integrate over [0, 1]; only the endpoint state is used downstream.
        integration_time = torch.tensor([0, 1]).float().type_as(x)

        # ANODE: append zero-valued channels to the input state.
        if self.function.augmented_dim > 0:
            batch_size, channels, height, width = x.shape
            # Allocate the padding on x's device and dtype rather than a
            # hard-coded "cuda", so the block also runs on CPU and with
            # non-float32 inputs.
            aug = torch.zeros(batch_size, self.function.augmented_dim,
                              height, width, device=x.device, dtype=x.dtype)
            x_aug = torch.cat([x, aug], 1)
        else:
            x_aug = x

        states = odeint_adjoint(self.function, x_aug, integration_time,
                                rtol=self.tolerance, atol=self.tolerance,
                                method='dopri5',
                                options={'max_num_steps': MAX_NUM_STEPS})
        # One state per entry of integration_time; index 1 is the state at t=1.
        return states[1]

class ODEConv(nn.Module):
    """Right-hand side f(t, x) of the ODE.

    It applies three time-dependent 3x3 convolutions with a ReLU after each
    of the first two. Fixed settings: time_dependent = True,
    non_linearity = 'relu'.

    The output has ``in_channels + augmented_dim`` channels, so it matches
    the (possibly augmented) ODE state. ``nfe`` counts calls to ``forward``.
    """

    def __init__(self, in_channels, num_filters, augmented_dim):
        super(ODEConv, self).__init__()
        self.nfe = 0  # number of function evaluations
        self.augmented_dim = augmented_dim

        state_channels = in_channels + augmented_dim
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)

        self.block_conv1 = Conv2dTime(state_channels, num_filters, **conv_kwargs)
        self.block_conv2 = Conv2dTime(num_filters, num_filters, **conv_kwargs)
        self.block_conv3 = Conv2dTime(num_filters, state_channels, **conv_kwargs)

        self.block_non_linear = nn.ReLU(inplace=True)

    def forward(self, t, x):
        self.nfe += 1

        # The two hidden layers are each followed by the ReLU; the final
        # projection back to the state width is left linear.
        for hidden_conv in (self.block_conv1, self.block_conv2):
            x = self.block_non_linear(hidden_conv(t, x))
        return self.block_conv3(t, x)

# Time-dependent convolution as in Dupont et al. [2019].
class Conv2dTime(nn.Conv2d):
    """2D convolution conditioned on time.

    The time ``t`` is broadcast into a constant extra channel that is
    prepended to ``x`` before convolving. The layer is therefore constructed
    with ``in_channels + 1`` input channels.
    """

    def __init__(self, in_channels, *args, **kwargs):
        super().__init__(in_channels + 1, *args, **kwargs)

    def forward(self, t, x):
        # (batch, 1, height, width) channel filled with the value t.
        time_channel = torch.ones_like(x[:, :1]) * t
        # (batch, channels + 1, height, width): time first, then the state.
        stacked = torch.cat((time_channel, x), dim=1)
        return super().forward(stacked)


In [7]:
# Pass the classes rather than instances. The trainer presumably instantiates
# them itself (e.g. once per fold) — TODO confirm against training.Trainer.
model, optimizer = NeuralODE, optim.Adam

# Training and Evaluation

## MNIST

### ANODE

In [8]:
# Experiment settings from config.config_mnist_modified: the model
# hyper-parameters, plus the 4-item 'train' entry unpacked into
# lr, epochs, batch, workers.
model_params = config.config_mnist_modified['model']
(lr, epochs, batch, workers) = config.config_mnist_modified['train']

In [9]:
anode_mnist_trainer = training.Trainer(
    model, optimizer, mnist_data, device,
)

In [None]:
anode_mnist_trainer.train(
    model_params, lr, epochs, batch, workers,
    verbose=False,
)

[Fold 1] Epoch:1 Training Acc:0.2839583333333333
[Fold 1] Epoch:1 Validation Acc:0.45366666666666666
[Fold 1] Epoch:2 Training Acc:0.5578541666666667
[Fold 1] Epoch:2 Validation Acc:0.6576666666666666
[Fold 1] Epoch:3 Training Acc:0.7570208333333334
[Fold 1] Epoch:3 Validation Acc:0.8405
[Fold 1] Epoch:4 Training Acc:0.878625
[Fold 1] Epoch:4 Validation Acc:0.9275
[Fold 1] Epoch:5 Training Acc:0.9203125
[Fold 1] Epoch:5 Validation Acc:0.9268333333333333
[Fold 1] Epoch:6 Training Acc:0.9345625
[Fold 1] Epoch:6 Validation Acc:0.9376666666666666


In [None]:
anode_mnist_trainer.test(
    model_params, batch, workers,
)

# Plots

In [None]:
# Label this run's validation metrics for the plot legend, then collect them
# into the list consumed by the plotter.
anode_mnist_trainer.val_metrics['legend'] = 'ANODE'
out_metrics = []
out_metrics.append(anode_mnist_trainer.val_metrics)

In [None]:
# Plotter over the list of labelled validation metrics.
plt = metrics.Plotter(
    out_metrics,
)

In [None]:
plt.plotLoss(
    "Optimized Model Validation Loss Comparisons"
)

In [None]:
plt.plotAccuracy(
    "Optimized Model Validation Accuracy Comparisons"
)

In [None]:
plt.plotNFE(
    "Loss vs NFE",
    style='loss',
)

In [None]:
plt.plotNFE(
    "Accuracy vs NFE",
    style='accuracy',
)