# MAML - MODEL-AGNOSTIC META-LEARNING

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')
#%cd drive/MyDrive/'Colab Notebooks'
#%cd meta-learning-course-notebooks/1_MAML/
#!ls

In [None]:
#!pip install import_ipynb --quiet

In [None]:
#!pip install learn2learn --quiet

In [None]:
import import_ipynb
import utils
import models
utils.hide_toggle('Imports 1')

In [None]:
from IPython import display
import torch
import torch.nn as nn
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
from l2lutils import KShotLoader
from IPython import display
utils.hide_toggle('Imports 2')

# Pre-trained Models

In [None]:
# Generate the synthetic "euclidean" data set through the project helper. The three returned
# objects are bound as the meta-train set, the meta-test set and a loader named full_loader
# (their exact contents are defined in utils.euclideanDataset).
meta_train_ds, meta_test_ds, full_loader = utils.euclideanDataset(
    n_samples=10000,
    n_features=20,
    n_classes=10,
    batch_size=32,
)

In [None]:
# Define an MLP network. The input dimension must equal the data dimension (20 here, the
# n_features used when generating the data). The final dimension must equal the number of
# classes for classification (10 here, the n_classes used above), or one for regression.
# Uncomment the seed below to fix the random initialisation.
#torch.manual_seed(10)
net = models.MLP(dims=[20,32,32,10])

In [None]:
# Train the network on the full loader. The network is trained in place, so running this
# cell again trains it further rather than starting over.
# Train returns the network plus two more values, bound here as `loss` and `accs`
# (presumably training-loss / accuracy records -- confirm in models.Train).
# NOTE(review): the name `loss` is rebound to an MSELoss module in a later cell.
net,loss,accs=models.Train(net,full_loader,lr=1e-2,epochs=5,verbose=True)

In [None]:
# Training accuracy: evaluate the trained network on the meta-training samples and labels.
models.accuracy(net,meta_train_ds.samples,meta_train_ds.labels,verbose=True)

In [None]:
# Test accuracy: the same evaluation on the meta-test samples and labels.
models.accuracy(net,meta_test_ds.samples,meta_test_ds.labels)

# Second-order Differentiation using Autograd

Second-order derivatives as needed for MAML

In [None]:
def network(x, w):
    """Toy linear model without bias: return the matrix product ``x @ w``."""
    return x @ w


# Mean-squared-error criterion used for both the inner and the outer loss of the toy example.
loss = torch.nn.MSELoss()

In [None]:
# Inner-loop ("training") data of the toy problem: inputs of 1 with targets of 2,
# i.e. the ideal weight is 2.
Z = torch.ones(3, 1, dtype=torch.float32)
z = torch.full((3, 1), 2.0, dtype=torch.float32)

In [None]:
# Outer-loop ("test") data of the toy problem: inputs of 1.5 with targets of 3,
# so the ideal weight is again 2.
Zt = torch.full((3, 1), 1.5, dtype=torch.float32)
zt = torch.full((3, 1), 3.0, dtype=torch.float32)

In [None]:
# Meta-parameter of the toy problem: a single trainable weight initialised to 1.
# It is a leaf tensor, so gradients can be requested with respect to it.
w0 = torch.ones(1, 1, dtype=torch.float32, requires_grad=True)

In [None]:
# Task-specific weight: starts as a copy of w0. clone() is recorded by autograd, so
# gradients taken through w1 can flow back to w0.
w1=w0.clone()

In [None]:
# Inner ("training") loss: MSE between the prediction Z @ w1 and the target z.
L=loss(network(Z,w1),z)

In [None]:
# Gradient of the inner loss L w.r.t. w1. create_graph=True also records the graph of this
# gradient computation, so that g can itself be differentiated later (the second-order
# part of the meta-gradient).
# Alternative kept for reference: gradient taken directly w.r.t. w0, without create_graph.
#g=torch.autograd.grad(L,w0)[0]
g=torch.autograd.grad(L,w1,create_graph=True)[0]
# Alternative the author marks as "Not good" for this purpose:
#L.backward(create_graph=True)# Not good

In [None]:
# Inspect the state after the first gradient computation. torch.autograd.grad returns the
# gradient instead of accumulating it, so neither .grad attribute has been populated here.
# (w1 is a non-leaf tensor, so reading w1.grad may also emit a PyTorch warning.)
w1.grad, w0.grad, L, w0, w1,w1.requires_grad,g

In [None]:
# Inner (adaptation) step with step size 0.1. The update is written out-of-place so that
# the new w1 stays connected, through g and the old w1, to w0 in the autograd graph.
w1 = w1 - 0.1*g

In [None]:
# Outer ("test") loss: MSE of the ADAPTED weight w1 on the second data set (Zt, zt).
L1=loss(network(Zt,w1),zt)
# Same quantity with the inner step written out in closed form for this example:
#L1=loss(network(Zt,w0-0.1*(2.0*(w0-2.0))),zt)

In [None]:
# Meta-gradient: differentiate the outer loss w.r.t. the ORIGINAL weight w0, i.e. through
# the inner update step. Both forms below are OK; L1.backward() is the one to use
# together with optimizer.step(), since it stores the result in w0.grad.
g1=torch.autograd.grad(L1,w0)[0]
#L1.backward()

In [None]:
# Display the meta-gradient dL1/dw0 computed by autograd.
g1

Working this out manually:

$w_0=1, L=(w_0-2)^2, dL=2\times(w_0-2)=-2,w_1=w_0-0.1\times(-2)=1.2$

$L_1=(w_1\times1.5-3)^2 = ((w_0-0.1\times(2\times(w_0-2)))\times1.5-3)^2 = (-1.2)^2$

$dL_1 = 2 \times (-1.2) \times (1.5 \times (1-0.2)) = -2.88$

In [None]:
# Hand-derived dL1/dw0 for w0 = 1: outer residual (-1.2), input scale 1.5, and the factor
# (1 - 0.2) contributed by differentiating the inner update. Should agree with g1.
2*(-1.2)*(1.5*(1-.2))

In [None]:
# torch.autograd.grad does not write to .grad, so both are expected to be unset here
# (w0.grad would only be filled if L1.backward() had been used instead).
w0.grad,w1.grad

# Meta-Learning: Tasks

Generate a k-shot n-way loader using the meta-training dataset

In [None]:
# Split the ten class labels into disjoint halves: labels 0-4 are used for meta-training
# tasks and labels 5-9 for meta-test tasks.
classes_train = list(range(5))
classes_test = list(range(5, 10))
classes_train, classes_test

In [None]:
# Task sampler over the meta-training set: 1 shot per class, 5 classes ("ways") per task.
meta_train_kloader=KShotLoader(meta_train_ds,shots=1,ways=5)

Sample a task - each task has a k-shot n-way training set and a similar test set

In [None]:
# Draw one task; the returned pair is bound as the task's training set and test set.
d_train,d_test=meta_train_kloader.get_task()

Let's try learning directly from the task's training set, despite its small size: create a dataset and a loader from it, and train the earlier network on it with the Train function.

In [None]:
# Wrap the task's training samples (d_train[0]) and labels (d_train[1]) in a dataset object.
taskds = utils.MyDS(d_train[0],d_train[1])

In [None]:
# Loader over the task's training examples: one example per batch, reshuffled each epoch.
d_train_loader = torch.utils.data.DataLoader(dataset=taskds,batch_size=1,shuffle=True)

In [None]:
# Continue training the earlier network on this single task (lr 0.1, 10 epochs).
# Train works in place, so `net` starts from the weights it already has.
net,losses,accs=models.Train(net,d_train_loader,lr=1e-1,epochs=10,verbose=True)

How does it do on the test set of the sampled task?

In [None]:
# Accuracy of the fine-tuned network on the sampled task's test samples and labels.
models.accuracy(net,d_test[0],d_test[1])

# MAML - Model-Agnostic Meta-Learning

In [None]:
import learn2learn as l2l
import torch.optim as optim

In [None]:
# Negative log-likelihood criterion used for both the training and the adaptation loss.
lossfn = torch.nn.NLLLoss()

# Wrap the network for MAML; lr is the step size of the inner (adaptation) update.
maml = l2l.algorithms.MAML(net, lr=1e-1)

# Outer-loop optimizer acting on the original network's parameters.
optimizer = optim.Adam(net.parameters(), lr=1e-3)

The MAML class above wraps our nn.Module network, providing parameter cloning among other things, as shown below. One iteration of the MAML algorithm starts by sampling a training task. Note that each of d_train and d_test is a tuple consisting of a set of samples and the corresponding labels.

In [None]:
# Sample a new task: d_train is used for adaptation, d_test for the adaptation (meta) loss.
d_train,d_test=meta_train_kloader.get_task()

In [None]:
# Clone of the wrapped network: its parameters can be adapted without changing those of `net`.
learner = maml.clone()

The learner object above is a 'clone' of our network with copies of its parameters, so that we can change these without changing the parameters of the original network. We apply the learner to the training data of the task and compute the TRAINING loss w.r.t. that training data, i.e., d_train.

In [None]:
# Forward pass of the (not yet adapted) learner on the task's training samples ...
train_preds = learner(d_train[0])
# ... and the TRAINING loss against the task's training labels.
train_loss = lossfn(train_preds,d_train[1])

In [None]:
# First-layer weights of the original network, for comparison with the learner's.
net.layers[0].weight

In [None]:
# First-layer weights of the cloned learner, for comparison with the original network's.
learner.layers[0].weight

Note that at this point both the learner and the original net have the same parameters. Let's see what the gradients w.r.t. the TRAINING loss are. (We use PyTorch's autograd functions directly.)

In [None]:
from torch.autograd import grad

In [None]:
# Gradient of the training loss w.r.t. the learner's first-layer weights, for inspection only.
# retain_graph keeps the graph alive so that train_loss can be differentiated again later
# (learner.adapt needs it); create_graph makes the returned gradient itself differentiable.
train_grad = grad(
    train_loss,
    learner.layers[0].weight,
    retain_graph=True,
    create_graph=True,
    allow_unused=True,
)
train_grad[0]

Next we ADAPT the learner by taking one step on the CLONED parameters in the direction of the negative gradient of the TRAINING loss above. This is the part that the l2l library does for us, as per the MAML algorithm.

In [None]:
# One adaptation step: update the learner's (cloned) parameters using the gradient of train_loss.
learner.adapt(train_loss)

We can check what has happened:

In [None]:
# First-layer weights of the learner after the adaptation step.
learner.layers[0].weight

In [None]:
# Element-wise ratio of the parameter change to the training gradient. For a plain gradient
# step this should equal the adaptation step size given to MAML (lr=1e-1); entries where the
# gradient is exactly zero will not be meaningful.
(net.layers[0].weight - learner.layers[0].weight)/train_grad[0]

So one step along the gradient (w.r.t. train_loss) has been taken, scaled by the adaptation learning rate. Next we compute the loss of this ADAPTED learner w.r.t. the TEST data of the task, i.e., d_test:

In [None]:
# Forward pass of the ADAPTED learner on the task's test samples ...
test_preds = learner(d_test[0])
# ... and the adaptation (meta) loss against the task's test labels.
adapt_loss = lossfn(test_preds,d_test[1])

The main MAML update to the original network net takes place now, by back-propagating through the (cumulative) adaptation loss (across possibly many tasks, here there was just one):

In [None]:
# Number of tasks contributing to this meta-update (just the one sampled above).
task_count = 1
# Clear gradients left over from earlier steps before accumulating new ones.
optimizer.zero_grad()
# Average the adaptation loss over the tasks ...
total_loss = adapt_loss/task_count
# ... and back-propagate it through the adaptation step.
total_loss.backward()

In [None]:
# First-layer weights of the original network before the optimizer step.
net.layers[0].weight

In [None]:
# Apply the meta-update: the optimizer changes the original network's parameters.
optimizer.step()

In [None]:
# First-layer weights of the original network after the optimizer step.
net.layers[0].weight

So, the original parameters have been updated by a gradient step based on all the task adaptation losses.

# Putting it all together: MAML Algorithm
Now let's put all of the above in a loop - the MAML algorithm:

In [None]:
import learn2learn as l2l
import torch.optim as optim

# Task shape: `shots` examples for each of `ways` classes.
shots = 5
ways = 2

# Fresh classifier with one output per class in a task.
net = models.MLP(dims=[20, 64, 32, ways])
#net = models.RNN(n_classes=3,dim=10,n_layers=2)

# MAML wrapper (inner-loop step size 1e-2) and the outer-loop optimizer over its parameters.
maml = l2l.algorithms.MAML(net, lr=1e-2)
optimizer = optim.Adam(maml.parameters(), lr=5e-3)
lossfn = torch.nn.NLLLoss()

# Task sampler restricted to the meta-training classes.
meta_train_kloader = KShotLoader(
    meta_train_ds,
    shots=shots,
    ways=ways,
    num_tasks=1000,
    classes=classes_train,
)

In [None]:
# Meta-training schedule:
#   n_epochs   - number of epochs (one meta-update each)
#   task_count - tasks sampled per meta-update
#   fas        - fast-adaptation (inner-loop) gradient steps per task
n_epochs, task_count, fas = 10, 32, 5

Note: In practice we use more than one gradient step for adaptation; this is called 'fast adaptation'.

In [None]:
# MAML meta-training. Every epoch performs ONE meta-update built from `task_count` tasks:
#   1. clone the meta-model and adapt the clone on the task's training set for `fas` steps;
#   2. evaluate the adapted clone on the task's test set and accumulate that loss;
#   3. back-propagate the mean test loss through the adaptation steps and let the outer
#      optimizer update the meta-parameters.
for epoch in range(n_epochs):
    adapt_loss = 0.0
    test_acc = 0.0
    for task in range(task_count):
        # Sample a task and start from a fresh clone of the current meta-parameters.
        d_train, d_test = meta_train_kloader.get_task()
        learner = maml.clone()

        # Inner loop (fast adaptation) on the task's training set.
        for fas_step in range(fas):
            train_preds = learner(d_train[0])
            train_loss = lossfn(train_preds, d_train[1])
            learner.adapt(train_loss)

        # Outer objective: loss of the adapted learner on the task's test set. This must
        # stay inside the autograd graph, it is what the meta-update differentiates.
        test_preds = learner(d_test[0])
        adapt_loss += lossfn(test_preds, d_test[1])

        # Accuracy is only reported, never differentiated, so no graph is needed for it.
        learner.eval()
        with torch.no_grad():
            test_acc += models.accuracy(learner, d_test[0], d_test[1], verbose=False)
        learner.train()

    # Optimise the MEAN adaptation loss over the task batch, which is also the value
    # reported below, so the objective does not scale with task_count.
    total_loss = adapt_loss / task_count
    print('Epoch  % 2d Loss: %2.5e Avg Acc: %2.5f'%(epoch,total_loss.item(),test_acc/task_count))
    # wait=True defers the clearing until the next output arrives, so the line above
    # stays visible until the following epoch prints.
    display.clear_output(wait=True)

    # Meta-update of the original parameters.
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    

Now test the trained MAML network by applying the adaptation steps to tasks sampled from the meta_test_ds dataset:

In [None]:
# Meta-test: adapt the trained meta-model to tasks drawn from the classes in classes_test
# and report the average accuracy on those tasks' test sets.
meta_test_kloader=KShotLoader(meta_test_ds,shots=shots,ways=ways,classes=classes_test)
test_acc = 0.0
# NOTE: this rebinds the global task_count that the meta-training cell also reads;
# re-run the hyper-parameter cell before training again.
task_count = 20
adapt_steps = 5
maml.eval()
for task in range(task_count):
    # Sample a task and start from a fresh clone of the trained meta-parameters.
    d_train, d_test = meta_test_kloader.get_task()
    learner = maml.clone()
    learner.eval()

    # Adaptation still needs gradients, so it runs outside no_grad.
    for adapt_step in range(adapt_steps):
        train_preds = learner(d_train[0])
        train_loss = lossfn(train_preds, d_train[1])
        learner.adapt(train_loss)

    # Evaluation only: no graph is needed for the accuracy computation.
    with torch.no_grad():
        test_acc += models.accuracy(learner, d_test[0], d_test[1], verbose=False)
# Restore training mode on the meta-model itself; it is `maml` that was switched to eval
# above, and the per-task clones are discarded.
maml.train()
print('Avg Acc: %2.5f'%(test_acc/task_count))