# Setup

In [1]:
# Ask TensorFlow which compute devices it can see; as the last expression of
# the cell, the returned list is what the notebook displays below.
from tensorflow.python.client import device_lib
device_lib.list_local_devices()

[name: "/device:CPU:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 }
 incarnation: 12873304705088487595, name: "/device:GPU:0"
 device_type: "GPU"
 memory_limit: 14638920512
 locality {
   bus_id: 1
   links {
   }
 }
 incarnation: 8589819103854549468
 physical_device_desc: "device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5"]

# misc

In [2]:
import torch
import torch.nn as nn
!pip install torchdiffeq
from torchdiffeq import odeint_adjoint as odeint
import numpy as np
!pip install einops
from einops import rearrange, repeat
import time
import torch.optim as optim
import glob
import imageio
import numpy as np
import torch
from math import pi
from random import random
from torch.utils.data import Dataset, DataLoader
from torch.distributions import Normal
from torchvision import datasets, transforms
# Module-level settings.
# `tol` is passed as both rtol and atol to odeint by NODEintegrate and NODElayer.
tol = 1e-3
# NOTE(review): gpu / niters / lr are not read by the code visible in this
# notebook excerpt; the training run uses args.gpu / args.niters / args.lr.
gpu = 0
niters = 10
lr = 1e-3
# Format [time, batch, diff, vector]

def count_parameters(model):
    """Return the total number of elements over the trainable parameters of ``model``.

    Parameters whose ``requires_grad`` flag is False are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

class ArgumentParser:
    """Minimal notebook stand-in for an argument parser.

    Nothing is parsed: every registered option is stored on the instance as an
    attribute holding its default value.
    """

    def add_argument(self, str, type, default):
        """Store ``default`` under the option name.

        :param str: option flag such as ``'--lr'``; its first two characters are dropped
            to form the attribute name
        :param type: accepted for call compatibility, not used
        :param default: value stored for the option
        """
        option = str[2:]
        vars(self)[option] = default

    def parse_args(self):
        """Return the parser itself; its attributes are the stored defaults."""
        return self

def str_rec(names, data, unit=None, sep=', ', presets='{}'):
    """Format a one-line record of ``name: value unit`` cells.

    Each cell is ``"<name>: <value> <unit><sep>"`` where the value is ``str(value)``
    cut to its first six characters. The concatenated cells are then inserted into
    ``presets`` through ``presets.format(...)``.

    Every cell is built in a single formatting pass, so curly braces that occur in a
    name, value, unit or ``sep`` are emitted literally instead of being re-read as
    format fields.

    :param names: sized sequence of column labels; one cell is produced per label
    :param data: iterable of values, at least as long as ``names`` (extras are ignored)
    :param unit: iterable of unit suffixes, at least as long as ``names``;
        ``None`` means an empty suffix for every cell
    :param sep: text appended after every cell (including the last one)
    :param presets: format string with one ``{}`` placeholder receiving the cells
    :return: the formatted line
    :raises IndexError: if ``data`` or ``unit`` has fewer entries than ``names``
    """
    if unit is None:
        unit = [''] * len(names)
    else:
        unit = list(unit)
    values = [str(item)[:6] for item in data]
    cells = []
    for idx, name in enumerate(names):
        # Index (rather than zip) so that a too-short data/unit list still raises.
        cells.append("{}: {} {}{}".format(name, values[idx], unit[idx], sep))
    return presets.format("".join(cells))

# Column labels and the matching unit suffixes that train() hands to str_rec
# for its per-epoch "Train||" line.
rec_names = ["iter", "loss", "nfe", "time/iter", "time"]
rec_unit = ["","","","s","min"]

Collecting torchdiffeq
  Downloading https://files.pythonhosted.org/packages/90/e4/5e483dc28a0a520e403f4dade7ad120d739471693afe83eaf36c9cc09cb0/torchdiffeq-0.2.0-py3-none-any.whl
Installing collected packages: torchdiffeq
Successfully installed torchdiffeq-0.2.0
Collecting einops
  Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl
Installing collected packages: einops
Successfully installed einops-0.3.0


# base

In [3]:


class Example_df(nn.Module):
    """Small fully connected derivative function: Linear -> LeakyReLU(0.2) -> Linear."""

    def __init__(self, input_dim, output_dim, nhid=20):
        """
        :param input_dim: size of the last dimension of the input
        :param output_dim: size of the last dimension of the output
        :param nhid: width of the hidden layer
        """
        super().__init__()
        self.dense1 = nn.Linear(input_dim, nhid)
        self.dense2 = nn.Linear(nhid, output_dim)
        self.lrelu = nn.LeakyReLU(0.2)

    def forward(self, t, x):
        """
        Evaluate the network on ``x``.
        :param t: time; accepted so the module matches the ``f(t, x)`` call signature, not used
        :param x: input tensor whose last dimension is ``input_dim``
        :return: tensor whose last dimension is ``output_dim``
        """
        hidden = self.lrelu(self.dense1(x))
        return self.dense2(hidden)


class NODEintegrate(nn.Module):
    """Integrate ``x' = df(t, x)`` from an initial condition with ``odeint``."""

    def __init__(self, df=None, x0=None):
        """
        Set up the initial value problem
            x' = df(x)
            x(t0) = x0
        :param df: a function that computes derivative. input & output shape [batch, channel, feature]
        :param x0: initial condition.
            - if x0 is set to be nn.parameter then it can be trained.
            - if x0 is set to be nn.Module then it can be computed through some network.
        """
        super().__init__()
        self.df = df
        self.x0 = x0

    def forward(self, initial_condition, evaluation_times, x0stats=None):
        """
        Solve the ODE and return the state at the requested times.

        The starting state is chosen in this order: ``self.x0(x0stats)`` when
        ``x0stats`` is given, else ``self.x0`` when ``initial_condition`` is None,
        else ``initial_condition`` itself.
        :param initial_condition: shape [batch, channel, feature]. Set to None while training.
        :param evaluation_times: time stamps where method evaluates, shape [time]
        :param x0stats: statistics to compute x0 when self.x0 is a nn.Module, shape required by self.x0
        :return: prediction by ode at evaluation_times, shape [time, batch, channel, feature]
        """
        if x0stats is not None:
            start = self.x0(x0stats)
        elif initial_condition is None:
            start = self.x0
        else:
            start = initial_condition
        # the module-level `tol` is used for both the relative and absolute tolerance
        return odeint(self.df, start, evaluation_times, rtol=tol, atol=tol)

    @property
    def nfe(self):
        """Evaluation counter exposed by the wrapped derivative function (``df.nfe``)."""
        return self.df.nfe


class NODElayer(nn.Module):
    """Layer that maps a state ``x0`` to the ODE solution at the second evaluation time."""

    def __init__(self, df, evaluation_times=(0.0, 1.0)):
        """
        :param df: derivative function ``df(t, x)`` handed to ``odeint``
        :param evaluation_times: time stamps given to the solver; ``forward`` returns
            the solution at index 1 of this sequence
        """
        super().__init__()
        self.df = df
        # Registered as a buffer so that .to()/.cuda()/.cpu() -- including calls made
        # on a parent container such as nn.Sequential -- move it together with the
        # parameters, and so that .to() keeps returning the module itself.
        # persistent=False keeps it out of state_dict, leaving checkpoints unchanged.
        self.register_buffer('evaluation_times', torch.as_tensor(evaluation_times), persistent=False)

    def forward(self, x0):
        """
        :param x0: state at ``evaluation_times[0]``
        :return: solver output at ``evaluation_times[1]``
        """
        # the module-level `tol` is used for both the relative and absolute tolerance
        out = odeint(self.df, x0, self.evaluation_times, rtol=tol, atol=tol)
        return out[1]


class NODE(nn.Module):
    """Wrapper around a derivative function that counts how often it is evaluated."""

    def __init__(self, df=None, **kwargs):
        """
        :param df: callable ``df(t, x)`` computing the derivative
        :param kwargs: extra attributes written directly into the instance dict
            (this bypasses ``nn.Module.__setattr__``)
        """
        super().__init__()
        vars(self).update(kwargs)
        self.df = df
        self.nfe = 0  # number of forward() calls since the counter was last reset

    def forward(self, t, x):
        """Increment the evaluation counter and return ``df(t, x)``."""
        self.nfe += 1
        derivative = self.df(t, x)
        return derivative


class SONODE(NODE):
    """Second-order ODE written as a first-order system over the stacked state [y, y']."""

    def forward(self, t, x):
        """
        Compute [y y']' = [y' y''] = [y' df(t, y, y')]
        :param t: time, shape [1]
        :param x: [y y'], shape [batch, 2, vec]
        :return: [y y']', shape [batch, 2, vec]
        """
        self.nfe += 1
        # everything after index 0 along dim 1 is the velocity part y'
        velocity = x[:, 1:, :]
        # df receives the full stacked state and supplies y''
        acceleration = self.df(t, x)
        return torch.cat((velocity, acceleration), dim=1)


class HeavyBallODE(NODE):
    """Heavy-ball dynamics over a stacked state [theta, m] with a damping term."""

    def __init__(self, df, gamma=None):
        """
        :param df: callable ``df(t, theta)`` providing the forcing term
        :param gamma: pre-sigmoid damping coefficient; when None a trainable
            ``nn.Parameter`` initialised to -4.0 is created
        """
        super().__init__(df)
        self.gamma = nn.Parameter(torch.Tensor([-4.0])) if gamma is None else gamma

    def forward(self, t, x):
        """
        Compute [theta' m'] for the heavy ball system
        $$ theta' = -m $$
        $$ m' = df(t, theta) - sigmoid(gamma) m $$
        Related formulation: https://www.jmlr.org/papers/volume21/18-808/18-808.pdf
        (that formulation also carries a state v; here the state is split into
        exactly two parts, theta and m)
        :param t: time, shape [1]
        :param x: [theta m], shape [batch, 2, dim]
        :return: [theta' m'], shape [batch, 2, dim]
        """
        self.nfe += 1
        # size-1 chunks along dim 1; unpacking into two names requires exactly two of them
        theta, m = torch.split(x, 1, dim=1)
        force = self.df(t, theta)
        friction = torch.sigmoid(self.gamma) * m
        return torch.cat((-m, force - friction), dim=1)

# train

In [4]:
def train(model, optimizer, trdat, tsdat):
    """Train ``model`` with cross-entropy for ``args.niters`` epochs and print progress.

    One "Train||" line is printed per epoch (mean loss, mean nfe per batch, epoch
    duration, total elapsed minutes). After every second epoch the model is
    evaluated on ``tsdat`` and a "Test ||" line is printed (mean loss, accuracy,
    mean nfe per batch, trainable parameter count).

    Reads the module-level ``args`` (``niters``, ``gpu``), ``rec_names``/``rec_unit``
    and the helpers ``str_rec`` and ``count_parameters``.

    :param model: indexable model (e.g. nn.Sequential) whose element 1 exposes
        ``df.nfe``, the evaluation counter that is reset and read here
    :param optimizer: optimizer over the model parameters
    :param trdat: iterable of ``(x, y)`` training batches
    :param tsdat: iterable of ``(x, y)`` evaluation batches
    :return: None
    """
    epoch = 0
    itrcnt = 0
    loss_func = nn.CrossEntropyLoss()
    itr_arr = np.zeros(args.niters)
    loss_arr = np.zeros(args.niters)
    nfe_arr = np.zeros(args.niters)
    time_arr = np.zeros(args.niters)
    # stepped once per batch below, so the learning rate decays every 1000 batches
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.95)

    # training
    start_time = time.time()
    while epoch < args.niters:
        epoch += 1
        iter_start_time = time.time()
        for x, y in trdat:
            itrcnt += 1
            # reset so the counter read after backward() covers this batch only
            model[1].df.nfe = 0
            optimizer.zero_grad()
            # forward in time and solve ode
            pred_y = model(x.to(device=args.gpu))
            # compute loss
            loss = loss_func(pred_y, y.to(device=args.gpu))
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            # make arrays
            itr_arr[epoch - 1] = epoch
            # .item() yields a plain Python float, so no (possibly CUDA) tensor is
            # mixed into the NumPy accumulator
            loss_arr[epoch - 1] += loss.item()
            nfe_arr[epoch - 1] += model[1].df.nfe
        iter_end_time = time.time()
        time_arr[epoch - 1] = iter_end_time - iter_start_time
        # itrcnt counts batches over all epochs so far, hence epoch / itrcnt is
        # 1 / (average number of batches per epoch): sums become per-batch means
        loss_arr[epoch - 1] *= 1.0 * epoch / itrcnt
        nfe_arr[epoch - 1] *= 1.0 * epoch / itrcnt
        printouts = [epoch, loss_arr[epoch-1], nfe_arr[epoch-1], time_arr[epoch-1], (time.time()-start_time)/60]
        print(str_rec(rec_names, printouts, rec_unit, presets="Train|| {}"))
        if epoch % 2 == 0:
            model[1].df.nfe = 0
            loss = 0
            acc = 0
            dsize = 0
            bcnt = 0
            # evaluation needs no gradients; skip autograd bookkeeping
            with torch.no_grad():
                for x, y in tsdat:
                    # forward in time and solve ode
                    dsize += y.shape[0]
                    y = y.to(device=args.gpu)
                    pred_y = model(x.to(device=args.gpu))
                    pred_l = torch.argmax(pred_y, dim=1)
                    acc += torch.sum((pred_l == y).float())
                    bcnt += 1
                    # compute loss (weighted by batch size so the mean is per sample)
                    loss += loss_func(pred_y, y).detach() * y.shape[0]

            loss /= dsize
            acc /= dsize
            printouts = [epoch, loss.detach().cpu().numpy(), acc.detach().cpu().numpy(), str(model[1].df.nfe / bcnt), str(count_parameters(model))]
            names = ["iter", "loss", "acc", "nfe", "param cnt"]
            print(str_rec(names, printouts, presets="Test || {}"))

# anode_data

In [5]:
def cifar(batch_size=64, size=32, path_to_data='../cifar_data'):
    """CIFAR-10 train and test dataloaders.

    Images are passed through ``transforms.Resize(size)`` and converted to tensors.
    The training split is downloaded to ``path_to_data`` when needed; the test
    split is then read from the same location.

    Parameters
    ----------
    batch_size : int
        Number of images per batch for both loaders.
    size : int
        Value given to ``transforms.Resize``. Default is 32.
    path_to_data : string
        Path to CIFAR data files.

    Returns
    -------
    (train_loader, test_loader) : tuple of DataLoader
        The training loader is shuffled; the test loader keeps dataset order so
        evaluation passes are repeatable.
    """
    all_transforms = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor()
    ])

    train_data = datasets.CIFAR10(path_to_data, train=True, download=True,
                                transform=all_transforms)
    test_data = datasets.CIFAR10(path_to_data, train=False,
                               transform=all_transforms)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # shuffling brings nothing to evaluation; keep a fixed order
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

# anode_cifar

In [6]:
# Build the run configuration with the ArgumentParser stand-in from the "misc"
# cell: each add_argument call just stores its default under the option name.
parser = ArgumentParser()
# NOTE(review): the solver calls visible in this notebook use the module-level
# `tol`, not args.tol.
parser.add_argument('--tol', type=float, default=1e-3)
parser.add_argument('--adjoint', type=eval, default=False)
parser.add_argument('--visualize', type=eval, default=True)
parser.add_argument('--niters', type=int, default=40)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--gpu', type=int, default=0)
args = parser.parse_args()

# shape: [time, batch, derivatives, channel, x, y]
# CIFAR-10 train / test loaders with 256 images per batch.
trdat, tsdat = cifar(batch_size=256)


class anode_initial_velocity(nn.Module):
    """Zero-pad the channel dimension of an image batch up to ``aug`` channels."""

    def __init__(self, in_channels, aug):
        """
        :param in_channels: number of channels of the incoming images (stored only)
        :param aug: number of channels of the padded output
        """
        super(anode_initial_velocity, self).__init__()
        self.aug = aug
        self.in_channels = in_channels

    def forward(self, x0):
        """
        :param x0: image batch, shape [batch, channel, x, y]
        :return: float tensor of shape [batch, 1, aug, x, y] holding ``x0`` in its
            leading channels and zeros elsewhere, on the same device as ``x0``
        """
        x0 = rearrange(x0.float(), 'b c x y -> b 1 c x y')
        outshape = list(x0.shape)
        outshape[2] = self.aug
        # allocate next to the input (same device and dtype) instead of on a
        # globally configured device, so the layer follows wherever the data lives
        out = x0.new_zeros(outshape)
        # copy into as many leading channels as the input actually has
        n_channels = x0.shape[2]
        out[:, :, :n_channels] += x0
        return out


class DF(nn.Module):
    """Convolutional derivative function that feeds time in as an extra channel."""

    def __init__(self, in_channels, nhid, out_channels=None):
        """
        :param in_channels: number of channels after merging the two leading
            non-batch dimensions of the state
        :param nhid: number of hidden channels
        :param out_channels: number of output channels; defaults to ``in_channels``
        """
        super(DF, self).__init__()
        if out_channels is None:
            out_channels = in_channels
        self.activation = nn.ReLU(inplace=True)
        # every convolution sees one additional channel carrying the time value
        self.fc1 = nn.Conv2d(in_channels + 1, nhid, kernel_size=1, padding=0)
        self.fc2 = nn.Conv2d(nhid + 1, nhid, kernel_size=3, padding=1)
        self.fc3 = nn.Conv2d(nhid + 1, out_channels, kernel_size=1, padding=0)

    def forward(self, t, x0):
        """
        :param t: current time; multiplied into a constant one-channel image
        :param x0: state, shape [batch, d, channel, x, y]
        :return: derivative, shape [batch, 1, out_channels, x, y]
        """
        x0 = rearrange(x0, 'b d c x y -> b (d c) x y')
        # ones_like already matches x0's device and dtype, so no explicit device
        # transfer is needed and the module is not tied to one configured GPU
        t_img = torch.ones_like(x0[:, :1, :, :]) * t
        out = torch.cat([x0, t_img], dim=1)
        out = self.fc1(out)
        out = self.activation(out)
        out = torch.cat([out, t_img], dim=1)
        out = self.fc2(out)
        out = self.activation(out)
        out = torch.cat([out, t_img], dim=1)
        out = self.fc3(out)
        out = rearrange(out, 'b c x y -> b 1 c x y')
        return out

class predictionlayer(nn.Module):
    """Linear read-out producing 10 scores from a flattened feature tensor."""

    def __init__(self, in_channels, truncate=False):
        """
        :param in_channels: channel count; the linear layer expects
            ``in_channels * 32 * 32`` flattened features
        :param truncate: when True only index 0 along dim 1 of the input is used
        """
        super().__init__()
        self.dense = nn.Linear(in_channels * 32 * 32, 10)
        self.truncate = truncate

    def forward(self, x):
        """
        :param x: feature tensor with the batch in dim 0
        :return: tensor of shape [batch, 10]
        """
        features = x[:, 0] if self.truncate else x
        # collapse everything except the batch dimension
        flat = rearrange(features, 'b ... -> b (...)')
        return self.dense(flat)

# Assemble and train the classifier:
#   anode_initial_velocity pads the 3 image channels with zeros up to `dim`,
#   NODElayer integrates the counted DF derivative over its evaluation times,
#   predictionlayer flattens the result and maps it to 10 scores.
dim = 13
nhid = 64
layer = NODElayer(NODE(DF(dim, nhid)))
model = nn.Sequential(anode_initial_velocity(3, dim), layer, predictionlayer(dim)).to(device=args.gpu)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.00)
# number of trainable parameters
print(count_parameters(model))
train(model, optimizer, trdat, tsdat)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../cifar_data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../cifar_data/cifar-10-python.tar.gz to ../cifar_data
172452




Train|| iter: 1 , loss: 1.7850 , nfe: 40.244 , time/iter: 259.20 s, time: 4.3200 min, 
Train|| iter: 2 , loss: 1.5362 , nfe: 41.653 , time/iter: 272.70 s, time: 8.8651 min, 
Test || iter: 2 , loss: 1.5090 , acc: 0.4669 , nfe: 20.0 , param cnt: 172452 , 
Train|| iter: 3 , loss: 1.4264 , nfe: 45.999 , time/iter: 316.10 s, time: 14.326 min, 
Train|| iter: 4 , loss: 1.3294 , nfe: 46.642 , time/iter: 322.72 s, time: 19.705 min, 
Test || iter: 4 , loss: 1.3508 , acc: 0.5173 , nfe: 20.0 , param cnt: 172452 , 
Train|| iter: 5 , loss: 1.2471 , nfe: 47.316 , time/iter: 329.67 s, time: 25.393 min, 
Train|| iter: 6 , loss: 1.1437 , nfe: 47.102 , time/iter: 327.98 s, time: 30.860 min, 
Test || iter: 6 , loss: 1.2091 , acc: 0.5712 , nfe: 20.0 , param cnt: 172452 , 
Train|| iter: 7 , loss: 1.0713 , nfe: 47.499 , time/iter: 331.64 s, time: 36.580 min, 
Train|| iter: 8 , loss: 1.0072 , nfe: 50.346 , time/iter: 360.10 s, time: 42.581 min, 
Test || iter: 8 , loss: 1.1581 , acc: 0.5931 , nfe: 20.0 , param