# Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import warnings
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
import numpy as np
import torchvision.transforms.functional as TF
import random 
import torch.nn.functional as F
import math 
import functools

# Use matplotlib's default style and silence UserWarning messages for the whole session.
plt.style.use('default')
warnings.filterwarnings("ignore", category=UserWarning) 

# Shared Softplus activation module; Simple_nn and Meta_LR_Model both call it in forward().
softplus = torch.nn.Softplus()

# seed = 888
# torch.manual_seed(seed)

# helper functions

# Convert images to tensors, then normalize each pixel as (x - 1.0) / 0.5.
# NOTE(review): assuming ToTensor yields values in [0, 1], a mean of 1.0 maps pixels to
# [-2, 0]; 0.5 is the more common choice -- confirm that 1.0 is intended.
mnist_trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((1.0,), (0.5,))])

# if not exist, download mnist dataset
root = './data'

# Train and test splits share the same transform; download=True fetches the data into `root` when missing.
mnist_train_set = datasets.MNIST(root=root, train=True, transform=mnist_trans, download=True)
mnist_test_set = datasets.MNIST(root=root, train=False, transform=mnist_trans, download=True)

def rsetattr(obj, attr, val):
    """Set a possibly nested attribute given a dotted path such as ``'fc1.weight'``.

    Everything before the last dot is resolved with ``rgetattr`` to find the
    owning object; the final path component is then assigned on that owner.
    A path without dots is assigned directly on ``obj``.

    Returns whatever ``setattr`` returns (``None``).
    """
    parent_path, _, leaf_name = attr.rpartition('.')
    if parent_path:
        owner = rgetattr(obj, parent_path)
    else:
        owner = obj
    return setattr(owner, leaf_name, val)

# using wonder's beautiful simplification: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427

def rgetattr(obj, attr, *args):
    """Resolve a dotted attribute path (e.g. ``'a.b.c'``) starting from ``obj``.

    Each path component is looked up with ``getattr`` on the result of the
    previous lookup. Any extra positional argument is forwarded to every
    ``getattr`` call, so a single optional default is applied at each level.
    """
    current = obj
    for component in attr.split('.'):
        current = getattr(current, component, *args)
    return current

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Meta model stuff

In [None]:
class MetaLinear(nn.Module):
    """Linear layer whose weight and bias are stored as buffers, not nn.Parameters.

    Both tensors are created with ``requires_grad=True`` and registered through
    ``register_buffer``, so they appear in ``named_buffers()`` rather than in
    ``parameters()``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: when False, no bias buffer is registered and ``self.bias`` is None.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        # Allocate the tensors as buffers; their values are overwritten by the
        # uniform initialization below.
        self.register_buffer('weight', torch.zeros(out_features, in_features, requires_grad=True))
        if bias:
            self.register_buffer('bias', torch.zeros(out_features, requires_grad=True))
        else:
            self.bias = None

        # Uniform initialization in [-bound, bound] with bound = 2 / sqrt(fan_in), see
        # https://discuss.pytorch.org/t/how-are-layer-weights-and-biases-initialized-by-default/13073
        bound = 2. / math.sqrt(in_features)
        with torch.no_grad():
            for tensor in (self.weight, self.bias):
                if tensor is not None:
                    tensor.uniform_(-bound, bound)

    def forward(self, x):
        """Apply the affine map ``x @ weight.T + bias`` using the current buffers."""
        return F.linear(x, self.weight, self.bias)

# Simple Neural Network

In [None]:
# Output widths of the four fully connected layers of Simple_nn.
layer_1_neurons = 64
layer_2_neurons = 64
layer_3_neurons = 64
layer_4_neurons = 10  # width of the final layer, i.e. the number of network outputs

# Whether every MetaLinear layer in Simple_nn gets a bias buffer.
biases = True
class Simple_nn(nn.Module):
    """Four-layer fully connected network built from MetaLinear layers.

    The layers are exposed as ``fc1`` .. ``fc4``; their widths come from the
    module-level ``layer_*_neurons`` constants and ``biases`` decides whether
    they carry a bias. Softplus is applied after every layer, including the last.
    """

    def __init__(self):
        super().__init__()

        widths = (28*28*1, layer_1_neurons, layer_2_neurons, layer_3_neurons, layer_4_neurons)
        # Create fc1..fc4 in order, each mapping one width to the next.
        for index, (fan_in, fan_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f'fc{index}', MetaLinear(fan_in, fan_out, bias=biases))

    def forward(self, x):
        """Flatten the input to rows of 28*28 values and run it through fc1..fc4."""
        hidden = x.view(-1, 28*28*1)
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = softplus(layer(hidden))
        return hidden


# Batch and loaders

In [None]:
# Number of samples per batch for both the training and the test loader.
batch_size = 600

# Training batches are reshuffled on every pass over the data.
mnist_train_loader = torch.utils.data.DataLoader(
                 dataset=mnist_train_set,
                 batch_size=batch_size,
                 shuffle=True)
# Test batches keep the dataset order.
mnist_test_loader = torch.utils.data.DataLoader(
                dataset=mnist_test_set,
                batch_size=batch_size,
                shuffle=False)                

# Meta LR model

In [None]:
class Meta_LR_Model(nn.Module):
    """Small MLP mapping one scalar input to one positive learning rate.

    Architecture: 1 -> 128 -> 128 -> 128 -> 1 with Softplus after every layer.
    The final Softplus keeps the output positive; it is then scaled by 1e-3.
    """

    def __init__(self):
        super().__init__()

        hidden_width = 128
        self.fc1 = nn.Linear(1, hidden_width)
        self.fc2 = nn.Linear(hidden_width, hidden_width)
        self.fc3 = nn.Linear(hidden_width, hidden_width)
        self.fc4 = nn.Linear(hidden_width, 1)

    def forward(self, x):
        """Return the predicted learning rate for input ``x`` (last dimension of size 1)."""
        for layer in (self.fc1, self.fc2, self.fc3):
            x = softplus(layer(x))
        # Positive output, scaled down by 1e-3.
        return softplus(self.fc4(x)) * 1e-3

# Instantiate the meta model on the same device as everything else in the notebook
# (a hard-coded .cuda() fails on CPU-only machines) and optimize it with Adam.
meta_lr_model = Meta_LR_Model().to(device)
meta_model_opt = optim.Adam(meta_lr_model.parameters(), lr=1e-2)

# Epochs

In [None]:
meta_num_train = 30 # Number of outer-loop iterations, i.e. updates of the meta learner
mnist_num_train = 3 # Number of full passes over the MNIST training loader inside each outer-loop iteration

# Loss function

In [None]:
# Classification loss applied to the Simple_nn outputs, for both training and testing.
criterion = nn.CrossEntropyLoss()

# Custom Optimizer

In [None]:
class CustomeOptimizer():
    """Differentiable, clipped gradient-descent update over a model's buffers.

    The model's trainable tensors are expected to be registered as buffers
    (see ``MetaLinear``). ``step`` does not modify them in place: each buffer is
    replaced by a new tensor ``param - lr * clip(grad)``, so the autograd graph
    links the new value to the learning rate that produced it.

    Args:
        model: module whose buffers are updated.
    """

    def __init__(self, model):
        self.model = model
        # Kept for backward compatibility only. This is a one-shot generator and the
        # buffers are replaced on every step, so the methods below always query
        # self.model.named_buffers() afresh instead of using it.
        self.named_buffers = model.named_buffers()

    def zero_grad(self):
        """Zero the gradient of every current buffer that has one."""
        for _, param in self.model.named_buffers():
            # `if param.grad:` would raise for tensors with more than one element.
            if param.grad is not None:
                param.grad.zero_()

    def step(self, meta_output_lr, clipping_value=1e-2):
        """Replace every buffer with ``param - meta_output_lr * clipped_grad``.

        Args:
            meta_output_lr: tensor holding the learning rate; it is moved to each
                buffer's device before use. Gradients flow through it.
            clipping_value: gradients are clipped element-wise to
                ``[-clipping_value, clipping_value]``.

        Buffers without a gradient are left untouched.
        """
        # Snapshot the buffers first because they are reassigned inside the loop.
        for name, param in list(self.model.named_buffers()):
            if param.grad is None:
                continue

            # The gradient is detached: the update is differentiable with respect to
            # the learning rate and the previous parameter value only.
            clipped_gradient = torch.clip(param.grad.detach(), min=-clipping_value, max=clipping_value)

            new_param = param.clone() - meta_output_lr.to(param.device) * clipped_gradient
            # new_param is a non-leaf tensor; keep its .grad for the next step.
            new_param.retain_grad()
            rsetattr(self.model, name, new_param)

''' def step(self, layer): 
    for name, param in self.model.named_buffers():
      if layer in name: # layer = 'spec'/'fc'. Basically, only update the parameters for either the specs or the fcs (2 optimizers)
        layer_name = name.split('.')[0]
        new_param = (param.clone() - learning_rates_dictionary[str(layer_name)].to(device) * param.grad.detach().clone())
        new_param.retain_grad()
        rsetattr(self.model, name, new_param)

The original param is a leaf node, so if we call .backward() it will receive a gradient.
We therefore create new_param to replace that value while still propagating gradients through the new param (otherwise we get an error saying that we changed a leaf node that requires grad).
We use param.clone() to propagate the gradients from the previous new_param (or from param if it is the first iteration),
but the main reason for using ".clone()" is to avoid performing in-place operations on leaf nodes.
We call new_param.retain_grad() because otherwise PyTorch discards the gradients of intermediate (non-leaf) tensors when we call .backward().
The gradient of the meta-learning model is passed mainly through learning_rates_dictionary.
It could also pass through the gradient stored in param.grad, but the paper says
that this signal is weak, so it is not really needed.
The buffers are just "tensor holders".
That is, we use them because otherwise PyTorch gives us an error that we are trying to change an nn.Module in place.
So we create a new set of parameters/tensors, using a custom layer, that are not registered as nn.Module parameters.'''
      

In [None]:
%%time

test_losses = []    # mean test loss after each outer-loop iteration
meta_losses = []    # summed inner-loop training loss of each outer-loop iteration
predicted_lrs = []  # last learning rate predicted in each outer-loop iteration

# META training loop: every iteration trains a fresh Simple_nn with learning rates
# predicted by meta_lr_model, then updates meta_lr_model from the summed training loss.
for outer_loop_epoch in range(meta_num_train):

    # reset the meta loss; it accumulates every training loss of this iteration
    meta_loss = 0

    # reset networks
    simple_nn = Simple_nn().to(device)

    # maintain grad on the parameters of the simple nn
    for name, param in simple_nn.named_buffers():
        param.retain_grad()

    # reset optimizer
    opt = CustomeOptimizer(simple_nn)

    # batch indices run from 0 to len(loader) - 1
    last_batch_idx = len(mnist_train_loader) - 1

    # MNIST training loop (simple_nn)
    for inner_loop_epoch in range(mnist_num_train):

        # get a batch
        for batch_idx, (x, target) in enumerate(mnist_train_loader):
            x, target = x.to(device), target.to(device)

            # send inputs to the simple nn
            out = simple_nn(x)
            fc_loss = criterion(out, target)

            # add losses to meta loss for stronger signal
            meta_loss += fc_loss

            # The meta model's input is the current loss value. .item() drops the
            # autograd history because this is only an input.
            meta_model_input = torch.tensor([fc_loss.item()]).to(device)

            # predict the learning rate for this step from the current loss
            meta_output_lr = meta_lr_model(meta_model_input)

            opt.zero_grad()
            # retain_graph=True: the graph is needed again when meta_loss is back-propagated below
            fc_loss.backward(retain_graph=True)
            opt.step(meta_output_lr)

            # print the loss on every 200th batch and on the last batch of the epoch
            if batch_idx % 200 == 0 or batch_idx == last_batch_idx:
                print ('==>>> outer loop epoch: {} , inner loop epoch: {} batch index: {}, train loss: {:.6f}'.format(outer_loop_epoch+1,inner_loop_epoch+1, batch_idx, fc_loss.item()))

    ######################################################################################################################################################################################################
    # Compute meta loss
    ######################################################################################################################################################################################################

    forgetfulness_loss = meta_loss

    # zero_grad first: the inner-loop backward calls above also reach the meta model's parameters
    meta_model_opt.zero_grad()
    forgetfulness_loss.backward()
    torch.nn.utils.clip_grad_norm_(meta_lr_model.parameters(), 1e-1)
    meta_model_opt.step()

    ######################################################################################################################################################################################################
    # test
    ######################################################################################################################################################################################################
    with torch.no_grad():
        test_loss = []
        test_accuracy = []

        for batch_idx, (x, target) in enumerate(mnist_test_loader):
            x, target = x.to(device), target.to(device)
            outputs = simple_nn(x)
            _, predicted = torch.max(outputs.data, 1)
            batch_test_loss = criterion(outputs, target)
            test_loss.append(batch_test_loss.item())
            # per-batch accuracy; batches are averaged unweighted below
            test_accuracy.append((predicted == target).sum().item() / predicted.size(0))

    ######################################################################################################################################################################################################
    # appending stuff (last batch)
    ######################################################################################################################################################################################################
    print('test loss: {}, test accuracy: {}'.format(np.mean(test_loss), np.mean(test_accuracy)))
    print('meta_loss', meta_loss.item())
    # meta_output_lr is the prediction made for the last training batch, before the meta update
    print(f'current predicted learning rate is: {meta_output_lr.item()}')

    test_losses.append(np.mean(test_loss))
    meta_losses.append(meta_loss.item())
    predicted_lrs.append(meta_output_lr.item())
    print()

# meta losses

In [None]:
# Plot the meta loss recorded after each outer-loop iteration.
plt.plot(meta_losses, c='teal', label='meta_losses')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.legend()

# test losses

In [None]:
# Plot the mean test loss recorded after each outer-loop iteration.
plt.plot(test_losses, c='teal', label='test_losses')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.legend()

# meta model learning rates predictions per loss 

In [None]:
# Sample the learned loss -> learning-rate mapping on 1001 evenly spaced loss values
# from 0 to 1 inclusive (step 1e-3). torch.linspace replaces the deprecated torch.range,
# and the grid is created on `device` instead of assuming CUDA.
loss_grid = torch.linspace(0, 1, 1001, device=device).unsqueeze(1)  # one loss value per row

# No gradients are needed for plotting, so skip building the autograd graph.
with torch.no_grad():
    meta_function = meta_lr_model(loss_grid).squeeze(1).tolist()

In [None]:
# Plot the sampled learning rates; the x axis is the sample index into meta_function.
# plt.xticks([])
plt.xlabel('loss steps (0 to 1 in 1e-3 step size)')
plt.ylabel('learning rate values')
plt.plot(meta_function)
plt.show()

# Save meta model weights

In [None]:
# weights and optimizer
# Saved as one dict: 'model' holds the meta model state dict and
# 'optimizer_state_dict' holds the Adam optimizer state dict.
torch.save({'model': meta_lr_model.state_dict(), 'optimizer_state_dict': meta_model_opt.state_dict()}, 'models_weights.pt')

# Compare the fixed learning rates and the one we just learned

In [None]:
# train simple neural network using the fixed meta model (not training the meta model)

# set seed
seed = 888
torch.manual_seed(seed)

# epochs
mnist_num_train = 10

# lists to hold results
# NOTE(review): despite its name, test_losses_meta_trained collects test ACCURACIES
# (see the append at the end of the epoch loop); it is plotted as accuracy later.
test_losses_meta_trained = []
# NOTE(review): this resets the meta_losses history of the meta-training cell and is
# not used below -- confirm whether it can be dropped.
meta_losses = []

# set network
simple_nn = Simple_nn().to(device)

# set optimizer
opt = CustomeOptimizer(simple_nn)

# batch indices run from 0 to len(loader) - 1
last_batch_idx = len(mnist_train_loader) - 1

# MNIST training loop (simple_nn)
for inner_loop_epoch in range(mnist_num_train):

    # maintain grad on the parameters of the simple nn
    for name, param in simple_nn.named_buffers():
        param.retain_grad()

    # get a batch
    for batch_idx, (x, target) in enumerate(mnist_train_loader):
        x, target = x.to(device), target.to(device)

        # send inputs to the simple nn
        out = simple_nn(x)
        fc_loss = criterion(out, target)

        # The meta model's input is the current loss value. .item() drops the
        # autograd history because this is only an input.
        meta_model_input = torch.tensor([fc_loss.item()]).to(device)

        # predict the learning rate from the current loss; the meta model is frozen
        # here, so no autograd graph is needed for this forward pass
        with torch.no_grad():
            meta_output_lr = meta_lr_model(meta_model_input)

        # overwrite existing learning rate
        simple_neural_network_lr = meta_output_lr.item() # plain float, no gradient

        opt.zero_grad()
        fc_loss.backward(retain_graph=False)
        opt.step(torch.tensor(simple_neural_network_lr))

        # print the loss on every 200th batch and on the last batch of the epoch
        if batch_idx % 200 == 0 or batch_idx == last_batch_idx:
            print ('==>>> inner loop epoch: {} batch index: {}, train loss: {:.6f}'.format(inner_loop_epoch+1, batch_idx, fc_loss.item()))

    # evaluate on the test set after every epoch
    with torch.no_grad():
        test_loss = []
        test_accuracy = []

        for batch_idx, (x, target) in enumerate(mnist_test_loader):
            x, target = x.to(device), target.to(device)
            outputs = simple_nn(x)
            _, predicted = torch.max(outputs.data, 1)
            batch_test_loss = criterion(outputs, target)
            test_loss.append(batch_test_loss.item())
            # per-batch accuracy; batches are averaged unweighted below
            test_accuracy.append((predicted == target).sum().item() / predicted.size(0))

        print('test loss: {}, test accuracy: {}'.format(np.mean(test_loss), np.mean(test_accuracy)))

        test_losses_meta_trained.append(np.mean(test_accuracy))

# Load previous accuracies
#### You need to run the simple_neural_network.py 3 times for this!! Once with each unique learning rate.
#### That is, run it once with lr 1e-3, once with 1e-2, and once with 1e-1. Change the lr on line 60 in the simple_neural_network.py file.

In [None]:
# Load the test accuracies saved by the fixed-learning-rate baseline runs,
# one file per learning rate ('test_accuracies_lr_<lr>.pt').
_baseline_accuracies = []
for simple_neural_network_lr in (1e-3, 1e-2, 1e-1):
    _baseline_accuracies.append(torch.load(f'test_accuracies_lr_{simple_neural_network_lr}.pt'))

lr_1e_neg_3_accuracy, lr_1e_neg_2_accuracy, lr_1e_neg_1_accuracy = _baseline_accuracies

In [None]:
# Compare test accuracy of the three fixed learning rates with the meta-learned one.
plt.xlabel('Epoch')
plt.ylabel('Test Accuracy')
plt.plot(lr_1e_neg_3_accuracy, c='purple', label='1e-3')
plt.plot(lr_1e_neg_2_accuracy, c='red', label='1e-2')
plt.plot(lr_1e_neg_1_accuracy, c='black', label='1e-1')
# test_losses_meta_trained holds per-epoch test accuracies despite its name
plt.plot(test_losses_meta_trained, c='teal', label='meta_lr')
plt.legend()
plt.show()

In [None]:
# Show the installed PyTorch version as the cell output.
torch.__version__