# Setup and imports

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')
%cd gdrive/MyDrive/optml

In [1]:
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
import random
from tqdm import tqdm_notebook as tqdm
import multiprocessing
import os.path
import csv
import copy
import joblib
from torchvision import datasets
import torchvision
import seaborn as sns; sns.set(color_codes=True)
import itertools
sns.set_style("white")
from pdb import set_trace as bp
from torch.optim.lr_scheduler import StepLR
from IPython.display import clear_output

In [2]:
USE_CUDA = torch.cuda.is_available()

def w(v):
    """Place `v` on the default CUDA device when one is available.

    Accepts anything with a `.cuda()` method (tensors, modules). On a CPU-only
    machine the argument is returned unchanged.
    """
    if not USE_CUDA:
        return v
    return v.cuda()

In [3]:
from meta_module import *

In [4]:
# On-disk memoization (joblib) for the expensive meta-training runs below.
os.makedirs('_cache', exist_ok=True)
cache = joblib.Memory(location='_cache', verbose=0)

# Hamiltonian

In [5]:
! pip install git+https://github.com/rtqichen/torchdiffeq
from torchdiffeq import odeint_adjoint as odeint
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

Collecting git+https://github.com/rtqichen/torchdiffeq
  Cloning https://github.com/rtqichen/torchdiffeq to /tmp/pip-req-build-5yujgzzq
  Running command git clone -q https://github.com/rtqichen/torchdiffeq /tmp/pip-req-build-5yujgzzq


In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Hamiltonian(nn.Module):
    """Learned Hamiltonian-inspired vector field on a state ``x = [q, dq]``.

    Both halves of the last dimension of ``x`` have width ``input_dim``.
    Submodules (names matter for ``state_dict`` compatibility):
      g -- MLP on q giving a forcing term,
      v -- MLP on q used in place of dH/dq,
      D -- bias-free linear map applied to ``[dH/dq, dq]``,
      L -- linear layer whose weight W gives the matrix M = W W^T (bias unused).
    """

    def __init__(self, input_dim: int):
        super().__init__()
        self.input_dim = input_dim
        width = int(input_dim)

        def mlp():
            return nn.Sequential(nn.Linear(width, width), nn.ReLU(), nn.Linear(width, width))

        # Creation order g, v, D, L is kept so initialisation consumes the RNG identically.
        self.g = mlp()
        self.v = mlp()
        self.D = nn.Linear(2 * input_dim, 2 * input_dim, bias=False)
        self.L = nn.Linear(input_dim, input_dim)
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    def forward(self, t, x):
        """Return ``[dq - D_q, (-dH/dq - D_p + g(q)) @ M]`` shaped like ``x``.

        ``t`` is accepted for an ODE-solver style signature but is not used.
        """
        with torch.enable_grad():
            q, dq = torch.chunk(x, 2, dim=-1)
            forcing = self.g(q)
            mass = self.L.weight @ self.L.weight.t()
            grad_h = self.v(q)
            damp_q, damp_p = torch.chunk(self.D(torch.cat((grad_h, dq), dim=-1)), 2, dim=-1)
            velocity = dq - damp_q
            momentum = torch.matmul(-grad_h - damp_p + forcing, mass)
            out = torch.cat((velocity, momentum), dim=-1).view_as(x)
        return out

# Meta-training and evaluation helpers

In [7]:
def detach_var(v):
    """Return a new leaf tensor that shares ``v``'s storage and requires grad.

    This cuts ``v``'s autograd history. ``Variable(x.data, ...)`` detaches a
    fresh ``.data`` alias. ``retain_grad`` keeps ``.grad`` populated even when
    ``w`` returns a non-leaf copy via ``.cuda()``.
    """
    fresh = w(v.data.detach().requires_grad_(True))
    fresh.retain_grad()
    return fresh

import functools

def rsetattr(obj, attr, val):
    """Set a possibly dotted attribute path, e.g. ``rsetattr(m, 'a.b.c', x)``.

    The owner object is resolved with ``rgetattr`` only when the path contains a dot.
    """
    parent_path, _, leaf = attr.rpartition('.')
    owner = obj
    if parent_path:
        owner = rgetattr(obj, parent_path)
    return setattr(owner, leaf, val)

# using wonder's beautiful simplification: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427

def rgetattr(obj, attr, *args):
    """Follow a dotted attribute path, e.g. ``rgetattr(m, 'a.b')``.

    An optional default (``*args``) is passed to every ``getattr`` hop.
    """
    current = obj
    for part in attr.split('.'):
        current = getattr(current, part, *args)
    return current

def one_step_fit_LSTM(opt_net, meta_opt, target_cls, target_to_opt, unroll, optim_it, n_epochs, out_mul, should_train=True):
    """Run one optimisation episode of a fresh optimizee driven by the LSTM optimizer.

    Args:
        opt_net: learned ``Optimizer`` (two stacked LSTM cells, coordinate-wise).
        meta_opt: optimizer over ``opt_net``'s parameters. It is only used when
            ``should_train`` is True and may be ``None`` otherwise.
        target_cls: loss-provider class, instantiated as ``target_cls(training=...)``.
        target_to_opt: zero-argument factory for the optimizee. Its result must
            provide ``all_named_parameters()`` and accept ``load_state_dict`` of
            those names.
        unroll: number of inner steps per meta-update / graph truncation. It is
            forced to 1 when evaluating.
        optim_it: number of inner optimisation steps.
        n_epochs: unused; kept for call-site compatibility.
        out_mul: scale applied to the network output before adding it to the parameters.
        should_train: if True, meta-train ``opt_net`` every ``unroll`` steps.

    Returns:
        List of ``optim_it`` per-step optimizee losses (NumPy scalar arrays).

    NOTE(review): ``loss.backward`` also runs in evaluation mode, so gradients
    accumulate in ``opt_net``'s ``.grad`` buffers. They are cleared by the
    ``meta_opt.zero_grad()`` at the start of the next training episode.
    """
    if should_train:
        opt_net.train()
    else:
        opt_net.eval()
        unroll = 1

    target = target_cls(training=should_train)
    optimizee = w(target_to_opt())
    n_params = sum(int(np.prod(p.size())) for _, p in optimizee.all_named_parameters())
    # One (hidden, cell) pair per LSTM layer, one row per optimizee coordinate.
    hidden_states = [w(torch.zeros(n_params, opt_net.hidden_sz)) for _ in range(2)]
    cell_states = [w(torch.zeros(n_params, opt_net.hidden_sz)) for _ in range(2)]
    all_losses_ever = []
    if should_train:
        meta_opt.zero_grad()
    all_losses = None
    for iteration in range(1, optim_it + 1):
        loss = optimizee(target)
        # Accumulate out of place. `+=` would mutate the first step's loss tensor,
        # whose NumPy copy recorded below shares memory with it on CPU.
        all_losses = loss if all_losses is None else all_losses + loss

        all_losses_ever.append(loss.data.cpu().numpy())
        # While training, keep the graph so the unrolled meta-loss can backprop
        # through every step of the current window.
        loss.backward(retain_graph=should_train)

        offset = 0
        result_params = {}
        hidden_states2 = [w(torch.zeros(n_params, opt_net.hidden_sz)) for _ in range(2)]
        cell_states2 = [w(torch.zeros(n_params, opt_net.hidden_sz)) for _ in range(2)]
        for name, p in optimizee.all_named_parameters():
            cur_sz = int(np.prod(p.size()))
            # Detach the gradient input (no backprop through p.grad), while the
            # update below stays connected to opt_net's parameters.
            gradients = detach_var(p.grad.view(cur_sz, 1))
            updates, new_hidden, new_cell = opt_net(
                gradients,
                [h[offset:offset + cur_sz] for h in hidden_states],
                [c[offset:offset + cur_sz] for c in cell_states],
            )
            for i in range(len(new_hidden)):
                hidden_states2[i][offset:offset + cur_sz] = new_hidden[i]
                cell_states2[i][offset:offset + cur_sz] = new_cell[i]
            result_params[name] = p + updates.view(*p.size()) * out_mul
            result_params[name].retain_grad()
            offset += cur_sz

        if iteration % unroll == 0:
            if should_train:
                meta_opt.zero_grad()
                all_losses.backward()
                meta_opt.step()
            all_losses = None

            # Truncate: rebuild the optimizee from the updated values and cut
            # the recurrent state's history.
            optimizee = w(target_to_opt())
            optimizee.load_state_dict(result_params)
            optimizee.zero_grad()
            hidden_states = [detach_var(v) for v in hidden_states2]
            cell_states = [detach_var(v) for v in cell_states2]
        else:
            # Inside an unroll window: swap in the graph-connected updated tensors.
            for name, _ in optimizee.all_named_parameters():
                rsetattr(optimizee, name, result_params[name])
            assert len(list(optimizee.all_named_parameters()))
            hidden_states = hidden_states2
            cell_states = cell_states2
    return all_losses_ever


@cache.cache
def fit_LSTM(target_cls, target_to_opt, preproc=False, unroll=20, optim_it=100, n_epochs=20, n_tests=100, lr=0.001, out_mul=1.0, n_train_steps=20):
    """Meta-train an LSTM ``Optimizer`` and keep the best checkpoint.

    Each of the ``n_epochs`` epochs runs ``n_train_steps`` meta-training
    episodes, then ``n_tests`` evaluation episodes. The ``state_dict`` with the
    lowest mean evaluation loss is kept. Results are memoised on disk by ``cache``.

    Returns:
        ``(best_loss, best_net)``. ``best_net`` is ``None`` (and ``best_loss``
        is ``inf``) if no evaluation loss compared below infinity, e.g. all NaN.
    """
    opt_net = w(Optimizer(preproc=preproc))
    meta_opt = optim.Adam(opt_net.parameters(), lr=lr)
    best_net = None
    best_loss = float('inf')

    for _ in range(n_epochs):
        for _ in range(n_train_steps):
            one_step_fit_LSTM(opt_net, meta_opt, target_cls, target_to_opt, unroll, optim_it, n_epochs, out_mul, should_train=True)

        loss = np.mean([
            np.mean(one_step_fit_LSTM(opt_net, meta_opt, target_cls, target_to_opt, unroll, optim_it, n_epochs, out_mul, should_train=False))
            for _ in tqdm(range(n_tests), 'tests')
        ])
        if loss < best_loss:
            best_loss = loss
            best_net = copy.deepcopy(opt_net.state_dict())
    return best_loss, best_net

In [8]:
def one_step_fit_HNN(opt_net, meta_opt, target_cls, target_to_opt, unroll, optim_it, n_epochs, out_mul, should_train=True):
    """Run one optimisation episode of a fresh optimizee driven by the HNN optimizer.

    Same protocol as ``one_step_fit_LSTM``, except there is no LSTM state. For
    each coordinate the network receives ``[gradient, previous update]``, where
    the previous update is the raw network output before ``out_mul``.

    Args:
        opt_net: learned ``Optimizer_HNN``.
        meta_opt: optimizer over ``opt_net``'s parameters. It is only used when
            ``should_train`` is True and may be ``None`` otherwise.
        target_cls: loss-provider class, instantiated as ``target_cls(training=...)``.
        target_to_opt: zero-argument optimizee factory providing ``all_named_parameters()``.
        unroll: inner steps per meta-update / graph truncation (forced to 1 when evaluating).
        optim_it: number of inner optimisation steps.
        n_epochs: unused; kept for call-site compatibility.
        out_mul: scale applied to the network output before adding it to the parameters.
        should_train: if True, meta-train ``opt_net`` every ``unroll`` steps.

    Returns:
        List of ``optim_it`` per-step optimizee losses (NumPy scalar arrays).
    """
    if should_train:
        opt_net.train()
    else:
        opt_net.eval()
        unroll = 1

    target = target_cls(training=should_train)
    optimizee = w(target_to_opt())
    n_params = sum(int(np.prod(p.size())) for _, p in optimizee.all_named_parameters())
    # Previous update of every coordinate; zero before the first step.
    derivative_input = w(torch.zeros(n_params, 1))
    all_losses_ever = []
    if should_train:
        meta_opt.zero_grad()
    all_losses = None
    for iteration in range(1, optim_it + 1):
        loss = optimizee(target)
        # Accumulate out of place. `+=` would mutate the first step's loss tensor,
        # whose NumPy copy recorded below shares memory with it on CPU.
        all_losses = loss if all_losses is None else all_losses + loss

        all_losses_ever.append(loss.data.cpu().numpy())
        loss.backward(retain_graph=should_train)

        offset = 0
        result_params = {}
        derivative_input2 = w(torch.zeros(n_params, 1))
        for name, p in optimizee.all_named_parameters():
            cur_sz = int(np.prod(p.size()))
            # Detach the gradient input; the update itself stays connected to opt_net.
            gradients = detach_var(p.grad.view(cur_sz, 1))
            inp = torch.cat((gradients, derivative_input[offset:offset + cur_sz]), dim=-1)
            updates = opt_net(inp)
            # Store the raw (unscaled) update as the next step's second input column.
            derivative_input2[offset:offset + cur_sz, 0] = updates.reshape(-1)
            result_params[name] = p + updates.view(*p.size()) * out_mul
            result_params[name].retain_grad()
            offset += cur_sz

        if iteration % unroll == 0:
            if should_train:
                meta_opt.zero_grad()
                all_losses.backward()
                meta_opt.step()
            all_losses = None

            # Truncate: rebuild the optimizee and cut the history of the fed-back updates.
            optimizee = w(target_to_opt())
            optimizee.load_state_dict(result_params)
            optimizee.zero_grad()
            derivative_input = detach_var(derivative_input2)
        else:
            for name, _ in optimizee.all_named_parameters():
                rsetattr(optimizee, name, result_params[name])
            assert len(list(optimizee.all_named_parameters()))
            derivative_input = derivative_input2
    return all_losses_ever

@cache.cache
def fit_HNN(target_cls, target_to_opt, preproc=False, unroll=1, optim_it=1, n_epochs=20, n_tests=100, lr=0.001, out_mul=1.0, n_train_steps=20):
    """Meta-train an ``Optimizer_HNN`` and keep the best checkpoint.

    Each of the ``n_epochs`` epochs runs ``n_train_steps`` meta-training
    episodes, then ``n_tests`` evaluation episodes. The ``state_dict`` with the
    lowest mean evaluation loss is kept. Results are memoised on disk by ``cache``.
    NOTE(review): the defaults ``unroll=1, optim_it=1`` differ from ``fit_LSTM``'s;
    callers in this notebook always pass both explicitly.

    Returns:
        ``(best_loss, best_net)``. ``best_net`` is ``None`` (and ``best_loss``
        is ``inf``) if no evaluation loss compared below infinity, e.g. all NaN.
    """
    opt_net = w(Optimizer_HNN(preproc=preproc))
    meta_opt = optim.Adam(opt_net.parameters(), lr=lr)
    best_net = None
    best_loss = float('inf')

    for _ in range(n_epochs):
        for _ in range(n_train_steps):
            one_step_fit_HNN(opt_net, meta_opt, target_cls, target_to_opt, unroll, optim_it, n_epochs, out_mul, should_train=True)

        loss = np.mean([
            np.mean(one_step_fit_HNN(opt_net, meta_opt, target_cls, target_to_opt, unroll, optim_it, n_epochs, out_mul, should_train=False))
            for _ in range(n_tests)
        ])
        if loss < best_loss:
            best_loss = loss
            best_net = copy.deepcopy(opt_net.state_dict())
    return best_loss, best_net


def find_best_parameters_HNN(target_cls, target_to_opt, preproc=False,
                             lrs=(0.1, 0.01, 0.001, 0.003, 0.0001),
                             out_muls=(0.1, 0.01, 0.001, 0.0001)):
    """Grid-search the meta learning rate and output multiplier of the HNN optimizer.

    Every ``(lr, out_mul)`` pair is meta-trained with ``fit_HNN`` (unroll=20,
    optim_it=100, n_epochs=20, n_tests=10) and scored by its best mean
    evaluation loss. Progress is printed and then cleared from the cell output.

    Returns:
        ``(best_loss, best_lr, best_out_mul)``. If no pair scores below
        infinity (e.g. every run is NaN), ``best_lr`` and ``best_out_mul`` stay 0.0.
    """
    best_loss = float('inf')
    best_lr = 0.0
    best_out_mul = 0.0
    for lr, out_mul in itertools.product(lrs, out_muls):
        print('Trying:', lr, out_mul)
        loss = fit_HNN(target_cls, target_to_opt, preproc=preproc, unroll=20, optim_it=100,
                       n_epochs=20, n_tests=10, lr=lr, out_mul=out_mul)[0]
        if loss < best_loss:
            best_loss, best_lr, best_out_mul = loss, lr, out_mul
        print(best_loss, best_lr, best_out_mul)
    clear_output()
    return best_loss, best_lr, best_out_mul

In [9]:
@cache.cache
def fit_normal(target_cls, target_to_opt, opt_class, n_tests=100, n_epochs=100, **kwargs):
    """Baseline: optimise fresh optimizees with a hand-designed torch optimizer.

    ``n_tests`` independent runs. Each run builds ``target_cls(training=False)``
    and a new optimizee, then takes ``n_epochs`` steps of
    ``opt_class(params, **kwargs)``.

    Returns:
        List of ``n_tests`` lists, each holding the ``n_epochs`` losses
        (NumPy scalars) recorded before the corresponding update.
    """
    def run_once():
        target = target_cls(training=False)
        optimizee = w(target_to_opt())
        optimizer = opt_class(optimizee.parameters(), **kwargs)
        trace = []
        for _ in range(n_epochs):
            loss = optimizee(target)
            trace.append(loss.data.cpu().numpy())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return trace

    return [run_once() for _ in range(n_tests)]

In [10]:
def preprocess_gradients(inp, preproc_factor, preproc_threshold):
    """Log/sign preprocessing of a column of values, shared by both learned optimizers.

    For each entry g of ``inp[:, 0]`` (``inp`` assumed to be (n, 1)):
      |g| >= threshold -> (log(|g| + 1e-8) / factor, sign(g))
      otherwise        -> (-1, exp(factor) * g)
    The input is read through ``.data``, so no gradient flows back through it.
    Returns a new (n, 2) tensor on ``inp``'s device.
    """
    grad = inp.data[:, 0]
    out = torch.zeros(grad.size(0), 2, device=grad.device)
    # A 1-D row mask (rather than `.squeeze()` of an (n, 1) mask) keeps shapes
    # correct when n == 1, i.e. for single-element parameters.
    keep = grad.abs() >= preproc_threshold
    out[keep, 0] = torch.log(grad[keep].abs() + 1e-8) / preproc_factor
    out[keep, 1] = torch.sign(grad[keep])
    out[~keep, 0] = -1
    out[~keep, 1] = float(np.exp(preproc_factor)) * grad[~keep]
    return out


class Optimizer(nn.Module):
    """Coordinate-wise learned optimizer built from two stacked LSTM cells.

    Each row of the input is one optimizee coordinate, processed with shared
    weights. ``forward(inp, hidden, cell)`` maps an (n, 1) gradient column and
    the per-layer (hidden, cell) states to an (n, 1) update plus the new states.
    """

    def __init__(self, preproc=False, hidden_sz=20, preproc_factor=10.0):
        super().__init__()
        self.hidden_sz = hidden_sz
        # Preprocessing turns each gradient into 2 features.
        self.recurs = nn.LSTMCell(2 if preproc else 1, hidden_sz)
        self.recurs2 = nn.LSTMCell(hidden_sz, hidden_sz)
        self.output = nn.Linear(hidden_sz, 1)
        self.preproc = preproc
        self.preproc_factor = preproc_factor
        self.preproc_threshold = np.exp(-preproc_factor)

    def forward(self, inp, hidden, cell):
        if self.preproc:
            # Preprocessing from "Appendix A" (presumably of Andrychowicz et al.,
            # "Learning to learn by gradient descent by gradient descent" -- confirm).
            # Callers already pass detached gradients, so no gradient is lost here.
            inp = preprocess_gradients(inp, self.preproc_factor, self.preproc_threshold)
        hidden0, cell0 = self.recurs(inp, (hidden[0], cell[0]))
        hidden1, cell1 = self.recurs2(hidden0, (hidden[1], cell[1]))
        return self.output(hidden1), (hidden0, hidden1), (cell0, cell1)


class Optimizer_HNN(nn.Module):
    """Learned optimizer: one evaluation of a ``Hamiltonian`` field plus a linear read-out.

    ``forward(inp)`` takes an (n, 2) input whose columns are the gradient and the
    previous update of each coordinate, and returns an (n, 1) update.
    """

    def __init__(self, preproc=False, preproc_factor=10.0):
        super().__init__()
        # Construction/assignment order kept (field first, then output) for
        # identical initialisation and state_dict layout.
        if preproc:
            gdefunc = Hamiltonian(2)
            self.output = nn.Linear(4, 1)
        else:
            gdefunc = Hamiltonian(1)
            self.output = nn.Linear(2, 1)
        self.gde = gdefunc
        self.preproc = preproc
        self.preproc_factor = preproc_factor
        self.preproc_threshold = np.exp(-preproc_factor)

    def forward(self, inp):
        if self.preproc:
            # Preprocess gradient and previous-update columns separately:
            # (n, 1) + (n, 1) -> (n, 2) + (n, 2).
            grad_col, prev_update_col = torch.chunk(inp, 2, dim=-1)
            inp = torch.cat((
                preprocess_gradients(grad_col, self.preproc_factor, self.preproc_threshold),
                preprocess_gradients(prev_update_col, self.preproc_factor, self.preproc_threshold),
            ), dim=-1)
        # Single evaluation of the vector field (no ODE integration); t is not used.
        h = self.gde(t=1, x=inp)
        return self.output(h)

# Model size and speed comparison (LSTM vs. HNN optimizer)

In [15]:
# Parameter counts without gradient preprocessing.
num_lstm_params, num_hnn_params = (
    sum(param.numel() for param in model.parameters())
    for model in (Optimizer(), Optimizer_HNN())
)
print('Number of LSTM parameters: {}\nNumber HNN parameters: {}'.format(num_lstm_params, num_hnn_params))

Number of LSTM parameters: 5221
Number HNN parameters: 17


In [17]:
# Parameter counts with gradient preprocessing (2 features per input column).
num_lstm_params, num_hnn_params = (
    sum(param.numel() for param in model.parameters())
    for model in (Optimizer(preproc=True), Optimizer_HNN(preproc=True))
)
print('Number of LSTM parameters: {}\nNumber HNN parameters: {}'.format(num_lstm_params, num_hnn_params))

Number of LSTM parameters: 5301
Number HNN parameters: 51


In [47]:
# Untrained optimizers, used only for the speed comparison below.
lstm, hnn = w(Optimizer(preproc=False)), w(Optimizer_HNN(preproc=False))

In [48]:
n_params = 1000  # coordinates processed per optimizer call
inputs = w(torch.ones(n_params, 1))
hidden = [w(torch.zeros(n_params, lstm.hidden_sz)) for _ in range(2)]
cell = [w(torch.zeros(n_params, lstm.hidden_sz)) for _ in range(2)]
# Warm-up call of the HNN optimizer on [gradient, previous update] columns.
output = hnn(torch.cat([inputs, inputs], dim=-1))

In [49]:
%timeit output = lstm(inputs, hidden, cell)

259 µs ± 14 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [50]:
hnn_input = torch.cat([inputs, inputs], dim=-1)  # [gradient, previous update] columns

In [51]:
%timeit output = hnn(hnn_input)

576 µs ± 18.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# Quadratic loss


In [11]:
class QuadraticLoss:
    """Random least-squares problem ``||W theta - y||^2`` with W (10x10) and y (10,) from randn.

    Keyword arguments (e.g. ``training``) are accepted and ignored, so the class
    fits the ``target_cls(training=...)`` calls used by the training loops.
    """

    def __init__(self, **kwargs):
        self.W = w(torch.randn(10, 10))
        self.y = w(torch.randn(10))

    def get_loss(self, theta):
        """Sum of squared residuals of ``W @ theta - y``."""
        return torch.sum((self.W.matmul(theta) - self.y) ** 2)


class QuadOptimizee(MetaModule):
    """Optimizee for ``QuadraticLoss``: a single 10-dim vector ``theta``, initialised to zeros."""

    def __init__(self, theta=None):
        # NOTE(review): `theta` is accepted but ignored; theta always starts at zeros.
        super().__init__()
        # Stored as a buffer (not nn.Parameter), which presumably lets the meta-loop
        # assign graph-connected tensors to it via rsetattr. `w` (instead of a
        # hard-coded `.cuda()`) keeps this usable on CPU-only machines.
        self.register_buffer('theta', to_var(w(torch.zeros(10)), requires_grad=True))

    def forward(self, target):
        return target.get_loss(self.theta)

    def all_named_parameters(self):
        return [('theta', self.theta)]

In [None]:
# Grid-search the HNN optimizer's meta learning rate and output multiplier on quadratics.
search_result = find_best_parameters_HNN(QuadraticLoss, QuadOptimizee)
best_loss, best_lr, best_out_mul = search_result
print(*search_result)

Trying: 0.1 0.1
3.0375912 0.1 0.1
Trying: 0.1 0.01
0.63969946 0.1 0.01
Trying: 0.1 0.001
0.63969946 0.1 0.01
Trying: 0.1 0.0001


In [None]:
# Meta-train the HNN optimizer on quadratics with the best hyper-parameters found above.
loss_HNN, quad_optimizer_HNN = fit_HNN(
    QuadraticLoss, QuadOptimizee,
    unroll=20, optim_it=100, n_epochs=100, n_tests=10,
    lr=best_lr, out_mul=best_out_mul,
)

In [None]:
# Meta-train the LSTM optimizer on quadratics (default out_mul=1.0).
loss_LSTM, quad_optimizer_LSTM = fit_LSTM(
    QuadraticLoss, QuadOptimizee,
    unroll=20, optim_it=100, n_epochs=100, n_tests=10,
    lr=0.003,
)

In [None]:
# Best mean evaluation loss reached during meta-training, per learned optimizer.
for label, best in (('Best loss of LSTM = ', loss_LSTM), ('Best loss of HNN = ', loss_HNN)):
    print(label, best)

In [None]:
# Evaluate every optimizer on fresh quadratics:
# fit_data[run, step, k], k = 0 LSTM, 1 HNN, 2..5 hand-designed baselines.
N_RUNS, N_STEPS = 100, 100
fit_data = np.zeros((N_RUNS, N_STEPS, 6))
np.random.seed(0)
torch.manual_seed(0)  # QuadraticLoss draws its problems with torch.randn

lstm_opt = w(Optimizer())
lstm_opt.load_state_dict(quad_optimizer_LSTM)
fit_data[:, :, 0] = np.array([
    one_step_fit_LSTM(lstm_opt, None, QuadraticLoss, QuadOptimizee, unroll=1, optim_it=N_STEPS,
                      n_epochs=N_STEPS, out_mul=1.0, should_train=False)  # trained with out_mul=1.0
    for _ in range(N_RUNS)
])
hnn_opt = w(Optimizer_HNN())
hnn_opt.load_state_dict(quad_optimizer_HNN)
fit_data[:, :, 1] = np.array([
    one_step_fit_HNN(hnn_opt, None, QuadraticLoss, QuadOptimizee, unroll=1, optim_it=N_STEPS,
                     n_epochs=N_STEPS, out_mul=best_out_mul, should_train=False)
    for _ in range(N_RUNS)
])
QUAD_LRS = [0.1, 0.03, 0.01, 0.01]
NORMAL_OPTS = [(optim.Adam, {}), (optim.RMSprop, {}), (optim.SGD, {'momentum': 0.9}), (optim.SGD, {'nesterov': True, 'momentum': 0.9})]
for i, (opt_class, opt_kwargs) in enumerate(NORMAL_OPTS):
    # Each baseline gets its own extra kwargs (momentum / nesterov).
    fit_data[:, :, 2 + i] = np.array(fit_normal(QuadraticLoss, QuadOptimizee, opt_class, n_tests=N_RUNS,
                                                n_epochs=N_STEPS, lr=QUAD_LRS[i], **opt_kwargs))

In [None]:
# Mean loss curve over the evaluation runs for each optimizer.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fit_data[:, :, 0].mean(axis=0), color='b', label='LSTM')
ax.plot(fit_data[:, :, 1].mean(axis=0), color='r', label='HNN')
OPT_NAMES = ['ADAM', 'RMSprop', 'SGD', 'NAG']
for i, opt in enumerate(OPT_NAMES):
    ax.plot(fit_data[:, :, 2 + i].mean(axis=0), label=opt)
ax.set_xlabel('steps')
ax.set_ylabel('loss')
ax.set_title('Quadratic functions')
ax.legend()
ax.grid()
plt.show()

# MNIST

In [None]:
# Download location for torchvision's MNIST. exist_ok tolerates re-runs without
# the bare `except:` that also hid real errors such as missing permissions.
os.makedirs('data', exist_ok=True)

class MNISTLoss:
    """Stream of MNIST mini-batches (size 128) drawn from one half of the MNIST training split.

    The training-split indices are shuffled with a fixed seed (10).
    ``training=True`` uses the first half and ``training=False`` the second, so
    meta-training and meta-evaluation see disjoint images.
    """

    def __init__(self, training=True):
        dataset = datasets.MNIST(
            'data', train=True, download=True,
            transform=torchvision.transforms.ToTensor()
        )
        order = list(range(len(dataset)))
        np.random.RandomState(10).shuffle(order)
        half = len(order) // 2
        subset = order[:half] if training else order[half:]

        self.loader = torch.utils.data.DataLoader(
            dataset, batch_size=128,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(subset))

        self.batches = []
        self.cur_batch = 0

    def sample(self):
        """Return the next cached ``(images, labels)`` batch.

        When the cache is exhausted, it is refilled by one full pass over the loader.
        """
        if self.cur_batch >= len(self.batches):
            self.batches, self.cur_batch = [], 0
            self.batches.extend(self.loader)
        batch = self.batches[self.cur_batch]
        self.cur_batch += 1
        return batch

class MNISTNet(MetaModule):
    """Sigmoid MLP classifier on flattened 28x28 images, used as the MNIST optimizee.

    It has ``n_layers`` hidden layers of width ``layer_size``, stored as
    ``layers['mat_i']``, followed by ``layers['final_mat']`` with 10 outputs.
    ``forward`` draws a batch from an ``MNISTLoss`` and returns its NLL loss.
    """

    def __init__(self, layer_size=20, n_layers=1, **kwargs):
        super().__init__()

        layers = {}
        in_features = 28 * 28
        for idx in range(n_layers):
            layers[f'mat_{idx}'] = MetaLinear(in_features, layer_size)
            in_features = layer_size
        layers['final_mat'] = MetaLinear(in_features, 10)
        self.layers = nn.ModuleDict(layers)

        self.activation = nn.Sigmoid()
        self.loss = nn.NLLLoss()

    def all_named_parameters(self):
        return list(self.named_parameters())

    def forward(self, loss):
        images, labels = loss.sample()
        x = w(images.view(images.size(0), 28 * 28))
        labels = w(labels)

        depth = 0
        while f'mat_{depth}' in self.layers:
            x = self.activation(self.layers[f'mat_{depth}'](x))
            depth += 1

        log_probs = F.log_softmax(self.layers['final_mat'](x), dim=1)
        return self.loss(log_probs, labels)

In [None]:
# Grid-search the HNN optimizer's hyper-parameters on MNIST (with gradient preprocessing).
search_result = find_best_parameters_HNN(MNISTLoss, MNISTNet, preproc=True)
best_loss, best_lr, best_out_mul = search_result
print(*search_result)

In [None]:
# Meta-train the HNN optimizer on MNIST with the best hyper-parameters found above.
loss_HNN, MNIST_optimizer_HNN = fit_HNN(
    MNISTLoss, MNISTNet, preproc=True,
    unroll=20, optim_it=100, n_epochs=50, n_tests=10,
    lr=best_lr, out_mul=best_out_mul,
)

In [None]:
# Meta-train the LSTM optimizer on MNIST (out_mul=0.1 must be reused at evaluation).
loss_LSTM, MNIST_optimizer_LSTM = fit_LSTM(
    MNISTLoss, MNISTNet, preproc=True,
    unroll=20, optim_it=100, n_epochs=50, n_tests=10,
    lr=0.01, out_mul=0.1,
)

In [None]:
# Best mean evaluation loss reached during meta-training on MNIST.
for label, best in (('Best loss of LSTM = ', loss_LSTM), ('Best loss of HNN = ', loss_HNN)):
    print(label, best)

In [None]:
# Evaluate every optimizer on fresh MNIST nets for 200 steps (longer than the
# 100-step meta-training horizon):
# fit_data[run, step, k], k = 0 LSTM, 1 HNN, 2..5 hand-designed baselines.
N_RUNS, N_STEPS = 100, 200
fit_data = np.zeros((N_RUNS, N_STEPS, 6))
np.random.seed(0)
torch.manual_seed(0)  # net initialisation and SubsetRandomSampler draw from torch's RNG

lstm_opt = w(Optimizer(preproc=True))
lstm_opt.load_state_dict(MNIST_optimizer_LSTM)
# Two constraints on these calls:
# - optim_it must equal N_STEPS, so each run fills a full row of fit_data.
# - out_mul must match the value the LSTM was meta-trained with (0.1).
fit_data[:, :, 0] = np.array([
    one_step_fit_LSTM(lstm_opt, None, MNISTLoss, MNISTNet, unroll=1, optim_it=N_STEPS,
                      n_epochs=N_STEPS, out_mul=0.1, should_train=False)
    for _ in range(N_RUNS)
])
hnn_opt = w(Optimizer_HNN(preproc=True))
hnn_opt.load_state_dict(MNIST_optimizer_HNN)
fit_data[:, :, 1] = np.array([
    one_step_fit_HNN(hnn_opt, None, MNISTLoss, MNISTNet, unroll=1, optim_it=N_STEPS,
                     n_epochs=N_STEPS, out_mul=best_out_mul, should_train=False)
    for _ in range(N_RUNS)
])
MNIST_LRS = [0.03, 0.01, 1.0, 1.0]
NORMAL_OPTS = [(optim.Adam, {}), (optim.RMSprop, {}), (optim.SGD, {'momentum': 0.9}), (optim.SGD, {'nesterov': True, 'momentum': 0.9})]
for i, (opt_class, opt_kwargs) in enumerate(NORMAL_OPTS):
    # Each baseline gets its own extra kwargs (momentum / nesterov).
    fit_data[:, :, 2 + i] = np.array(fit_normal(MNISTLoss, MNISTNet, opt_class, n_tests=N_RUNS,
                                                n_epochs=N_STEPS, lr=MNIST_LRS[i], **opt_kwargs))

In [None]:
# Mean loss curve over the evaluation runs for each optimizer, including the
# hand-designed baselines computed above (columns 2..5 of fit_data).
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(np.mean(fit_data[:, :, 0], axis=0), color='b', label='LSTM')
ax.plot(np.mean(fit_data[:, :, 1], axis=0), color='r', label='HNN')
for i, opt_name in enumerate(['ADAM', 'RMSprop', 'SGD', 'NAG']):
    ax.plot(np.mean(fit_data[:, :, 2 + i], axis=0), label=opt_name)
ax.set_xlabel('steps')
ax.set_ylabel('loss')
ax.set_title('MNIST')
ax.legend()
ax.grid()
plt.show()