## Learning to learn by gradient descent by gradient descent

In this notebook, we try Self-Attention instead of LSTM networks as optimizers

In [None]:
# IPython magics: enable the autoreload extension (mode 2) and render
# matplotlib figures inline in the notebook output.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import glob
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets
import torchvision
from torchsummary import summary

import matplotlib.pyplot as plt
import random
from tqdm import tqdm_notebook as tqdm
import multiprocessing
import os.path
import csv
import copy
import joblib
import seaborn as sns; sns.set(color_codes=True)
sns.set_style("white")
from pdb import set_trace as bp

In [None]:
# Decided once at import time; every call to `w` below consults this flag.
USE_CUDA = torch.cuda.is_available()


def w(v):
    """Return `v.cuda()` when CUDA is available, otherwise `v` unchanged."""
    return v.cuda() if USE_CUDA else v

In [None]:
# Create the cache directory from Python instead of the `!mkdir` shell escape:
# this does not print an error when the directory already exists (for example
# when the notebook is re-run) and does not depend on a shell being available.
os.makedirs('cache', exist_ok=True)
# Disk-backed memoizer rooted at ./cache; used below as the `@cache.cache`
# decorator on the expensive fitting functions.
cache = joblib.Memory(location='cache', verbose=0)

In [None]:
from meta_module import *

## Gradient detach

As we perform operations, Pytorch builds the computational graph of the operations we perform. However, there are some variables that we want to detach from the graph at various points, specifically we want to pretend that the **gradients are inputs** (as specified in the previous image) that come from nowhere, instead of coming from the rest of the computational graph as they really do: this means we want to **detach** the gradients from the graph. Likewise, when every 20 steps we perform backpropagation on the optimizer network, we want the current hidden states and cell states, as well as the parameters of the optimizee to "forget" that they are dependent on previous steps in the graph. For all of this, I created a function called `detach_var` which creates a new Variable from the current variable's data, and makes sure that its gradients are still kept. This is different from the `.detach()` function in Pytorch which does not quite forget the original graph and also does not guarantee that the gradients will be there.

In [None]:
def detach_var(v):
    """Rewrap `v.data` as a new gradient-tracking variable.

    The returned variable is built from the raw data of `v` with
    `requires_grad=True`, passed through the device helper `w`, and has
    `retain_grad()` called on it before being returned.
    """
    fresh = Variable(v.data, requires_grad=True)
    fresh = w(fresh)
    # Ask autograd to keep .grad on this variable after a backward pass.
    fresh.retain_grad()
    return fresh

import functools

def rsetattr(obj, attr, val):
    """Set a possibly dotted attribute path (e.g. ``'a.b.c'``) on `obj`.

    Everything before the last dot is resolved with `rgetattr`; the final
    component is then assigned with `setattr`, whose result is returned.
    """
    parent_path, _, leaf = attr.rpartition('.')
    if parent_path:
        holder = rgetattr(obj, parent_path)
    else:
        # No dot in the path: assign directly on the object itself.
        holder = obj
    return setattr(holder, leaf, val)

# using wonder's beautiful simplification: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427

def rgetattr(obj, attr, *args):
    """Resolve a dotted attribute path on `obj`, one component at a time.

    Any extra positional arguments are forwarded to every `getattr` call, so
    an optional default applies at each step of the path.
    """
    current = obj
    for part in attr.split('.'):
        current = getattr(current, part, *args)
    return current

In [None]:
def do_fit(optimizer_net, meta_opt, cost_cls, target_to_opt, unroll, optim_it, n_epochs, out_mul, preproc = True, should_train=True):

    """Run one optimization episode of an optimizee driven by `optimizer_net`.

    A fresh target (`cost_cls`) and optimizee (`target_to_opt`) are created.
    For `optim_it` iterations the optimizee's gradients are fed to
    `optimizer_net`; its output, scaled by `out_mul`, is added to the optimizee
    parameters. When `should_train` is true, the losses accumulated over each
    window of `unroll` iterations are backpropagated and `meta_opt` takes a step.

    Args:
        optimizer_net:  optimizer network instance; it is switched between
                        train()/eval() here and called on the gradients.
        meta_opt:       optimizer that updates `optimizer_net`; only touched
                        when `should_train` is true.
        cost_cls:       callable building the cost/loss target, called as
                        `cost_cls(training=should_train)`.
        target_to_opt:  callable building the optimizee, called with no arguments.
        unroll:         number of iterations between two updates of the
                        optimizer network; forced to 1 when not training.
        optim_it:       number of optimizee iterations in the episode.
        n_epochs:       not used inside this function.
        out_mul:        multiplier applied to the optimizer network's output
                        before it is added to the optimizee parameters.
        preproc:        when true, gradients go through `preprocess_gradient`
                        before being fed to the optimizer network.
        should_train:   whether to train the optimizer network or only evaluate it.

    Returns:
        A list with one entry per iteration: the optimizee loss as a NumPy value.
    """
    
    if should_train:
        optimizer_net.train()
    else:
        optimizer_net.eval()
        unroll = 1
    
    target = cost_cls(training=should_train)
    optimizee = w(target_to_opt())
    # NOTE(review): n_params is computed here but never read afterwards.
    n_params = 0
    
    for name, p in optimizee.all_named_parameters():
        n_params += int(np.prod(p.size()))
        
    all_losses_ever = []

    if should_train:
        meta_opt.zero_grad()

    # Compute the loss of the optimizee and accumulate it over the current
    # unroll window.
    all_losses = None
    for iteration in range(1, optim_it + 1):
        loss = optimizee(target)
                    
        # NOTE(review): on the first iteration of a window `all_losses` is the
        # same tensor object as `loss`, so the later `+=` accumulates in place
        # on that first loss tensor.
        if all_losses is None:
            all_losses = loss
        else:
            all_losses += loss
            
        # Record the current individual loss (as a NumPy value) in the
        # list that is returned.
        all_losses_ever.append(loss.data.cpu().numpy())
        
        # Backpropagate through the optimizee; the graph is retained only when
        # training, because it is needed again for `all_losses.backward()`.
        loss.backward(retain_graph=should_train)

        # Compute an update for every entry of the optimizee's
        # "all_named_parameters".
        
        result_params = {}
        for name, p in optimizee.all_named_parameters():
            cur_sz = int(np.prod(p.size()))

            # We do this so the gradients are disconnected from the graph but we still get
            # gradients from the rest
            gradients = detach_var(p.grad.view(cur_sz, 1))
            if preproc == True:
                gradients = preprocess_gradient(gradients)
            
            # The gradients are fed to the optimizer network as a column of
            # shape (cur_sz, 1), or as the output of preprocess_gradient.
            updates = optimizer_net(gradients)
                
            # Updated parameters of the optimizee function    
            result_params[name] = p + updates.view(*p.size()) * out_mul
            
            # The resulting variable isn't a leaf, which means it won't retain grads by default.
            result_params[name].retain_grad()
            
            
        # At the end of each unroll window, update the optimizer network (when
        # training) and restart from a fresh optimizee.
        if iteration % unroll == 0:
            if should_train:
                meta_opt.zero_grad()
                all_losses.backward()
                meta_opt.step()

            # Restart the losses    
            all_losses = None

            # Build a new optimizee, load the latest parameter values into it
            # and clear its gradients.
            optimizee = w(target_to_opt())
            optimizee.load_state_dict(result_params)
            optimizee.zero_grad()
            
        else:
            # Inside a window: swap the updated (graph-connected) tensors into
            # the existing optimizee by attribute name.
            for name, p in optimizee.all_named_parameters():
                rsetattr(optimizee, name, result_params[name])
            assert len(list(optimizee.all_named_parameters()))
            
    return all_losses_ever

def preprocess_gradient(gradients, p=10.0):
    """Rescale raw gradients into a bounded two-channel (log, sign) encoding.

    Args:
        gradients: `Tensor` of gradients with shape `[d_1, ..., d_n]`.
        p: positive scalar controlling how small gradients are disregarded;
           magnitudes below `exp(-p)` saturate the log channel at -1.
    Returns:
        `Tensor` with shape `[d_1, ..., d_n-1, 2 * d_n]`. The first `d_n`
        elements along the last dimension are the `log output`
        `clamp(log|g| / p, -1, 1)`; the remaining `d_n` elements are the
        `sign output` `clamp(exp(p) * g, -1, 1)`.
    """
    log_magnitude = torch.log(torch.abs(gradients))
    clamp_log = torch.clamp(log_magnitude / p, min=-1.0, max=1.0)
    # exp(p) must be a plain scalar factor. `torch.Tensor(10)` would allocate
    # an *uninitialized* tensor of length 10 (not the value 10), which both
    # corrupts the values and broadcasts the sign channel to width 10.
    scale = float(np.exp(p))
    clamp_sign = torch.clamp(scale * gradients, min=-1.0, max=1.0)
    return torch.cat((clamp_log, clamp_sign), dim=-1)


@cache.cache
def fit_optimizer(cost_cls, target_to_opt, preproc=False, unroll=20, optim_it=100, n_epochs=20, n_tests=100, lr=0.001, out_mul=1.0):
    """Meta-train a Transformer-encoder optimizer and keep its best weights.

    Every epoch runs 20 training episodes with `do_fit`, then `n_tests`
    evaluation episodes; the mean over tests of the summed per-step losses is
    the epoch score. The state dict of the lowest-scoring epoch is kept.

    Returns:
        `(best_loss, best_net)`, where `best_net` is a deep copy of the best
        state dict, or None if no epoch scored below the initial sentinel.

    NOTE(review): `preproc` is accepted but never forwarded to `do_fit`, so
    `do_fit`'s own default (`preproc=True`) is what actually applies.
    NOTE(review): the encoder uses `d_model=512`, whereas `do_fit` feeds
    gradients viewed as `(cur_sz, 1)` (optionally preprocessed); the source
    itself says this setup does not work yet.
    """
    # The optimizer network: a two-layer Transformer encoder.
    encoder_layer = w(nn.TransformerEncoderLayer(d_model=512, nhead=1))
    optimizer_net = w(nn.TransformerEncoder(encoder_layer, num_layers=2))

    # The meta-optimizer, i.e. the optimizer that trains the Transformer.
    meta_opt = optim.Adam(optimizer_net.parameters(), lr=lr)

    def run_episode(train):
        # Single place that spells out the (long) do_fit argument list.
        return do_fit(optimizer_net, meta_opt, cost_cls, target_to_opt, unroll, optim_it, n_epochs, out_mul, should_train=train)

    best_net = None
    best_loss = 100000000000000000

    for _ in tqdm(range(n_epochs), 'epochs'):
        for _ in tqdm(range(20), 'iterations'):
            run_episode(True)

        test_totals = []
        for _ in tqdm(range(n_tests), 'tests'):
            test_totals.append(np.sum(run_episode(False)))
        loss = np.mean(test_totals)

        print(loss)
        if loss < best_loss:
            print(best_loss, loss)
            best_loss = loss
            best_net = copy.deepcopy(optimizer_net.state_dict())

    return best_loss, best_net

## Optimizer network: PyTorch Transformer

In this new re-implementation we use a Transformer as the optimizer instead of an LSTM.
For now, I am trying to use the PyTorch Transformer module; however, it is not working yet.

What are the advantages of a Transformer over LSTM networks?
Transformers are non-sequential and use positional embeddings in place of recurrence. LSTMs, on the other hand, are sequential: each step depends on the previous hidden and cell states, so they cannot be parallelized across steps. Our hypothesis is that transformers can speed up the process.

In [None]:
# Instantiate a full encoder-decoder Transformer with explicit hyperparameters
# and print its layer structure for inspection.
# NOTE(review): `Optimizer` is bound to an nn.Transformer *instance*, not a
# class, so `Optimizer()` invokes its forward pass rather than building a new
# model. `fit_optimizer` builds its own nn.TransformerEncoder and does not use
# this object.
Optimizer = torch.nn.Transformer(
    d_model=512, 
    nhead=8, 
    dim_feedforward=2048, 
    dropout=0.1, 
    activation='relu')

# The bare string below is a reference note on the constructor arguments;
# it is an expression statement with no effect.
r""""
    d_model – the number of expected features in the encoder/decoder inputs (default=512).
    nhead – the number of heads in the multiheadattention models (default=8).
    num_encoder_layers – the number of sub-encoder-layers in the encoder (default=6).
    num_decoder_layers – the number of sub-decoder-layers in the decoder (default=6).
    dim_feedforward – the dimension of the feedforward network model (default=2048).
    dropout – the dropout value (default=0.1).
    activation – the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu).
    custom_encoder – custom encoder (default=None).
    custom_decoder – custom decoder (default=None).
"""


print(Optimizer)

# Optimizee Network: Quadratic functions

The optimizer is supposed to find a 10-element vector called $\theta$ that, when multiplied by a 10x10 matrix called $W$, is as close as possible to a 10-element vector called $y$. Both $y$ and $W$ are generated randomly. The error is simply the squared error.

## Class and function definitions

In [None]:
class QuadraticLoss:
    """Random quadratic objective ``sum((W @ theta - y) ** 2)``.

    `W` (10x10) and `y` (10) are drawn from a standard normal when the
    instance is created. Keyword arguments are accepted and ignored.
    """

    def __init__(self, **kwargs):
        weight = torch.randn(10, 10)
        offset = torch.randn(10)
        self.W = w(Variable(weight))
        self.y = w(Variable(offset))

    def get_loss(self, theta):
        """Return the summed squared residual of ``W @ theta`` against ``y``."""
        residual = torch.matmul(self.W, theta) - self.y
        return residual.pow(2).sum()
    
class QuadOptimizee(MetaModule):
    """Optimizee holding the 10-element vector `theta` for the quadratic task.

    `theta` starts at zero and is registered as a buffer named ``'theta'``.
    The `theta` constructor argument is accepted but currently ignored.
    """

    def __init__(self, theta=None):
        super().__init__()
        # Place the tensor with the file's device helper `w` rather than an
        # unconditional `.cuda()`, which raises on machines without CUDA.
        self.register_buffer('theta', to_var(w(torch.zeros(10)), requires_grad=True))

    def forward(self, target):
        """Return `target.get_loss` evaluated at the current `theta`."""
        return target.get_loss(self.theta)

    def parameters(self):
        """Return the optimizable tensors as a single-element list."""
        return [self.theta]

    def all_named_parameters(self):
        """Return ``[('theta', theta)]``, the (name, tensor) pairs `do_fit` iterates."""
        return [('theta', self.theta)]

## Find best learning rate for meta_optimizer
The experiment below fits various learning rates that are used in the meta_optimizer (ADAM).

In [None]:
# Sweep candidate meta-optimizer learning rates and print, for each one, the
# first element returned by fit_optimizer (its best recorded loss).
for lr in tqdm([1.0, 0.3, 0.1, 0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001], 'all'):
    print('Learning rate:', lr)
    print(fit_optimizer(QuadraticLoss, QuadOptimizee, lr=lr)[0])

The experiment shows that 0.003 is a promising learning rate. It is not the lowest, but we are training on only 20 epochs by default and will then retrain with 100 epochs, so it is good to have a slightly lower learning rate for training for longer.

Next, the final model is trained with the learning rate (lr = 0.003) found in the previous block and the number of epochs are increased to 100.

In [None]:
# Retrain with the chosen learning rate for 100 epochs; `quad_optimizer` is the
# second element returned by fit_optimizer (the best state dict it kept).
loss, quad_optimizer = fit_optimizer(QuadraticLoss, QuadOptimizee, lr=0.003, n_epochs=100)
print('The transformer model loss with the best found learning rate is: ', loss)

## Find best learning rate for conventional optimizers

The following two functions are used to find the best learning rate for conventional optimizers: ADAM, RMSProp, SGD and NAG. 

In [None]:
@cache.cache
def fit_normal(target_cls, target_to_opt, opt_class, n_tests=100, n_epochs=100, **kwargs):
    """Optimize fresh problems with a conventional optimizer and record losses.

    For each of `n_tests` runs, a new target and optimizee are created and
    optimized for `n_epochs` steps with `opt_class(optimizee.parameters(), **kwargs)`.

    Returns:
        A list of `n_tests` lists, each holding the `n_epochs` per-step loss
        values (recorded before the corresponding optimizer step).
    """
    def run_single_test():
        problem = target_cls(training=False)
        model = w(target_to_opt())
        solver = opt_class(model.parameters(), **kwargs)

        history = []
        for _ in range(n_epochs):
            step_loss = model(problem)
            history.append(step_loss.data.cpu().numpy())

            solver.zero_grad()
            step_loss.backward()
            solver.step()
        return history

    return [run_single_test() for _ in tqdm(range(n_tests), 'tests')]

def find_best_lr_normal(target_cls, target_to_opt, opt_class, **extra_kwargs):
    """Grid-search the learning rate of a conventional optimizer.

    Each candidate rate is scored by the mean over tests of the summed
    per-step losses returned by `fit_normal`. A rate whose run raises
    `RuntimeError` is skipped.

    Returns:
        `(best_loss, best_lr)`; the initial values `(1e15, 0.0)` are returned
        if no candidate scores below the initial bound.
    """
    candidates = [1.0, 0.3, 0.1, 0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001]
    best_loss = 1000000000000000.0
    best_lr = 0.0
    for lr in tqdm(candidates, 'Learning rates'):
        try:
            runs = fit_normal(target_cls, target_to_opt, opt_class, lr=lr, **extra_kwargs)
            loss = np.mean([np.sum(s) for s in runs])
        except RuntimeError:
            # A failed run can never be the best one; move on to the next rate.
            continue
        if loss < best_loss:
            best_loss, best_lr = loss, lr
    return best_loss, best_lr

In [None]:
# Baseline optimizers as (optimizer class, extra constructor kwargs) pairs.
NORMAL_OPTS = [(optim.Adam, {}), (optim.RMSprop, {}), (optim.SGD, {'momentum': 0.9}), (optim.SGD, {'nesterov': True, 'momentum': 0.9})]
# Display names, listed in the same order as NORMAL_OPTS.
OPT_NAMES = ['ADAM', 'RMSprop', 'SGD', 'NAG']

In [None]:
# NB: the momentum value for Nesterov was taken from
# https://github.com/torch/optim/blob/master/nag.lua, since the paper states:
# "When an optimizer has more parameters than just a learning rate (e.g. decay
# coefficients for ADAM) we use the default values from the optim package in
# Torch7."

# Print the (best_loss, best_lr) pair found for each baseline optimizer.
for opt, kwargs in NORMAL_OPTS:
    print(find_best_lr_normal(QuadraticLoss, QuadOptimizee, opt, **kwargs))

In the cell below:

- QUAD_LRS are the best learning rates obtained for the conventional optimizers
- fit_data is initialized to 0 and the third dimension has length equal to all conventional opt. + Transformer
- The data is fitted with the best learning rate for conventional optimizers.
- The state_dict of the optimizer network found in the previous section is loaded
- Why is it fitted again I don't know    

In [None]:
# Learning rates used for the baselines, listed in the order of NORMAL_OPTS.
QUAD_LRS = [0.1, 0.03, 0.01, 0.01]
# Axes: (test run, optimization step, optimizer); the last slot along the
# third axis is reserved for the learned Transformer optimizer.
fit_data = np.zeros((100, 100, len(OPT_NAMES) + 1))
for i, ((opt, extra_kwargs), lr) in enumerate(zip(NORMAL_OPTS, QUAD_LRS)):
    # QuadraticLoss draws its problems with torch.randn, so the torch RNG has
    # to be seeded as well; np.random.seed alone does not affect those draws.
    np.random.seed(0)
    torch.manual_seed(0)
    fit_data[:, :, i] = np.array(fit_normal(QuadraticLoss, QuadOptimizee, opt, lr=lr, **extra_kwargs))

# Rebuild the architecture that fit_optimizer trains (a 2-layer
# nn.TransformerEncoder) so the saved state dict matches. The module-level
# `Optimizer` is an nn.Transformer instance: calling it with no arguments
# fails, and its state-dict layout differs from `quad_optimizer`.
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=1)
opt = w(nn.TransformerEncoder(encoder_layer, num_layers=2))
opt.load_state_dict(quad_optimizer)
np.random.seed(0)
torch.manual_seed(0)
# Evaluation only: no meta-optimizer is needed, hence `None`.
fit_data[:, :, len(OPT_NAMES)] = np.array([do_fit(opt, None, QuadraticLoss, QuadOptimizee, 1, 100, 100, out_mul=1.0, should_train=False) for _ in range(100)])

## Graphical results

Here, our results are shown:

In [None]:
# One dashed curve per baseline optimizer; the learned optimizer (last slot of
# fit_data, filled by the Transformer) is drawn solid.
ax = sns.tsplot(data=fit_data[:, :, :], condition=OPT_NAMES + ['Transformer'], linestyle='--', color=['r', 'b', 'g', 'k', 'y'])
ax.lines[-1].set_linestyle('-')
ax.legend()
plt.yscale('log')
plt.xlabel('steps')
plt.ylabel('loss')
plt.title('Quadratic functions')

# Save before show(): once the figure has been shown it may be closed, and a
# later savefig would then write an empty image.
plt.savefig('quadratic_results.png')
plt.show()

The results obtained in the paper with the LSTM optimizer are:

![image.png](attachment:image.png)



# Optimizee network: MNIST

In [None]:
class MNISTLoss:
    """Mini-batch source over one half of the MNIST training set.

    The 'train' split is shuffled with a fixed seed (10) and cut in two:
    `training=True` uses the first half, `training=False` the second half.

    Args:
        training: selects which half of the shuffled indices is sampled.
        root: directory where MNIST is stored / downloaded to.
        batch_size: number of examples per batch.
    """

    def __init__(self, training=True, root='/home/chenwy/mnist', batch_size=128):
        dataset = datasets.MNIST(
            root, train=True, download=True,
            transform=torchvision.transforms.ToTensor()
        )
        indices = list(range(len(dataset)))
        # Fixed seed so the two halves are the same on every construction.
        np.random.RandomState(10).shuffle(indices)
        half = len(indices) // 2
        indices = indices[:half] if training else indices[half:]

        self.loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices))

        self.batches = []
        self.cur_batch = 0  # index of the next batch to hand out

    def sample(self):
        """Return the next batch, reloading all batches from the loader when exhausted."""
        if self.cur_batch >= len(self.batches):
            # Materialize a full new pass over the loader.
            self.batches = list(self.loader)
            self.cur_batch = 0
        batch = self.batches[self.cur_batch]
        self.cur_batch += 1
        return batch

class MNISTNet(MetaModule):
    """MLP optimizee for MNIST built from `MetaLinear` layers.

    `n_layers` hidden layers of width `layer_size` with sigmoid activations
    are followed by a 10-way output layer; the loss is NLL on log-softmax
    outputs. Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, layer_size=20, n_layers=1, **kwargs):
        super().__init__()

        blocks = {}
        in_features = 28*28
        for idx in range(n_layers):
            blocks[f'mat_{idx}'] = MetaLinear(in_features, layer_size)
            in_features = layer_size
        blocks['final_mat'] = MetaLinear(in_features, 10)
        self.layers = nn.ModuleDict(blocks)

        self.activation = nn.Sigmoid()
        self.loss = nn.NLLLoss()

    def all_named_parameters(self):
        """Return the (name, parameter) pairs from `named_parameters()` as a list."""
        return [(name, param) for name, param in self.named_parameters()]

    def forward(self, loss):
        """Draw a batch from `loss.sample()` and return the NLL loss on it."""
        images, labels = loss.sample()
        batch_len = images.size()[0]
        hidden = w(Variable(images.view(batch_len, 28*28)))
        labels = w(Variable(labels))

        # Walk the hidden layers mat_0, mat_1, ... for as long as they exist.
        depth = 0
        while f'mat_{depth}' in self.layers:
            hidden = self.activation(self.layers[f'mat_{depth}'](hidden))
            depth += 1

        log_probs = F.log_softmax(self.layers['final_mat'](hidden), dim=1)
        return self.loss(log_probs, labels)