# Model-Agnostic Meta Learning - Pytorch Implementation

This is an implementation of the paper "Model-Agnostic Meta Learning for Adaptation of Deep Networks"

Paper link: https://arxiv.org/pdf/1703.03400.pdf

## Import Libraries

In [1]:
import torch

In [2]:
import math
import random
import torch # v0.4.1
from torch import nn
from torch.nn import functional as F
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt

def net(x, params):
    x = F.linear(x, params[0], params[1])
    x = F.relu(x)

    x = F.linear(x, params[2], params[3])
    x = F.relu(x)

    x = F.linear(x, params[4], params[5])
    return x

params = [
    torch.Tensor(32, 1).uniform_(-1., 1.).requires_grad_(),
    torch.Tensor(32).zero_().requires_grad_(),

    torch.Tensor(32, 32).uniform_(-1./math.sqrt(32), 1./math.sqrt(32)).requires_grad_(),
    torch.Tensor(32).zero_().requires_grad_(),

    torch.Tensor(1, 32).uniform_(-1./math.sqrt(32), 1./math.sqrt(32)).requires_grad_(),
    torch.Tensor(1).zero_().requires_grad_(),
]

opt = torch.optim.SGD(params, lr=1e-2)
n_inner_loop = 5
alpha = 3e-2

for it in range(100000):
    b = 0 if random.choice([True, False]) else math.pi

    x = torch.rand(4, 1)*4*math.pi - 2*math.pi
    y = torch.sin(x + b)

    v_x = torch.rand(4, 1)*4*math.pi - 2*math.pi
    v_y = torch.sin(v_x + b)

    opt.zero_grad()

    new_params = params
    for k in range(n_inner_loop):
        f = net(x, new_params)
        loss = F.l1_loss(f, y)

        # create_graph=True because computing grads here is part of the forward pass.
        # We want to differentiate through the SGD update steps and get higher order
        # derivatives in the backward pass.
        grads = torch.autograd.grad(loss, new_params, create_graph=True)
        print(len(grads))
        new_params = [(new_params[i] - alpha*grads[i]) for i in range(len(params))]

        if it % 100 == 0: print('Iteration %d -- Inner loop %d -- Loss: %.4f' % (it, k, loss))

    v_f = net(v_x, new_params)
    loss2 = F.l1_loss(v_f, v_y)
    loss2.backward()

    opt.step()

    if it % 100 == 0: print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))

t_b = math.pi #0

t_x = torch.rand(4, 1)*4*math.pi - 2*math.pi
t_y = torch.sin(t_x + t_b)

opt.zero_grad()

t_params = params
for k in range(n_inner_loop):
    t_f = net(t_x, t_params)
    t_loss = F.l1_loss(t_f, t_y)

    grads = torch.autograd.grad(t_loss, t_params, create_graph=True)
    t_params = [(t_params[i] - alpha*grads[i]) for i in range(len(params))]


test_x = torch.arange(-2*math.pi, 2*math.pi, step=0.01).unsqueeze(1)
test_y = torch.sin(test_x + t_b)

test_f = net(test_x, t_params)

plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
plt.legend()
plt.savefig('maml-sine.png')


6
Iteration 0 -- Inner loop 0 -- Loss: 0.6935
6
Iteration 0 -- Inner loop 1 -- Loss: 0.5947
6
Iteration 0 -- Inner loop 2 -- Loss: 0.6119
6
Iteration 0 -- Inner loop 3 -- Loss: 0.5867
6
Iteration 0 -- Inner loop 4 -- Loss: 0.5635
Iteration 0 -- Outer Loss: 0.5566
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6


6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
Iteration 700 -- Inner loop 0 -- Loss: 0.5854
6
Iteration 700 -- Inner loop 1 -- Loss: 0.5136
6
Iteration 700 -- Inner loop 2 -- Loss: 0.4386
6
Iteration 700 -- Inner loop 3 -- Loss: 0.3822
6
Iteration 700 -- Inner loop 4 -- Loss: 0.3748
Iteration 700 -- Outer Loss: 0.3258
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6


KeyboardInterrupt: 

In [None]:
params

In [3]:
import numpy as np
import torch
import torch.nn as nn


def idcg(n_rel):
    # Assuming binary relevance.
    nums = np.ones(n_rel)
    denoms = np.log2(np.arange(n_rel) + 1 + 1)
    return (nums / denoms).sum()


# Data.
input_dim = 50
n_docs = 20
n_rel = 5
n_irr = n_docs - n_rel

doc_features = np.random.randn(n_docs, input_dim)

# Model.
model = torch.nn.Sequential(
    nn.Linear(input_dim, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 1),
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Document scores.
docs = torch.from_numpy(np.array(doc_features, dtype="float32"))
print(docs)
docs = docs.to(device)
doc_scores = model(docs)

# Document ranks.
(sorted_scores, sorted_idxs) = doc_scores.sort(dim=0, descending=True)
doc_ranks = torch.zeros(n_docs).to(device)
doc_ranks[sorted_idxs] = 1 + torch.arange(n_docs).view((n_docs, 1)).to(device).float()
doc_ranks = doc_ranks.view((n_docs, 1))
print(doc_ranks)

# Compute lambdas.
# See equation (6) in [2] and equation (9) in [1].
score_diffs = doc_scores[:n_rel] - doc_scores[n_rel:].view(n_irr)
exped = score_diffs.exp()
N = 1 / idcg(n_rel)
dcg_diffs = 1 / (1 + doc_ranks[:n_rel]).log2() - (1 / (1 + doc_ranks[n_rel:]).log2()).view(n_irr)
lamb_updates = 1 / (1 + exped) * N * dcg_diffs.abs()
lambs = torch.zeros((n_docs, 1)).to(device)
lambs[:n_rel] += lamb_updates.sum(dim=1, keepdim=True)
lambs[n_rel:] -= lamb_updates.sum(dim=0, keepdim=True).t()

# Accumulate lambda scaled gradients.
model.zero_grad()
doc_scores.backward(lambs)

# Update model weights.
lr = 0.00001
with torch.no_grad():
    for param in model.parameters():
        param += lr * param.grad

tensor([[-8.7096e-01,  1.4488e+00,  5.4451e-01, -1.4318e+00, -7.9269e-01,
          1.0654e+00,  1.3484e+00,  1.3097e+00,  4.4482e-03,  2.2589e+00,
         -1.5502e+00,  9.9209e-01, -1.6088e+00,  1.0710e+00, -7.1378e-01,
         -1.1627e+00,  5.7741e-02,  9.3084e-01,  1.4069e+00,  1.7520e-01,
          8.9120e-01, -1.4027e+00, -2.9129e-01, -2.9930e-01, -9.2963e-01,
          5.6893e-01,  2.8202e+00, -1.0937e+00,  8.7788e-01,  7.4125e-01,
         -5.0514e-01, -6.5769e-01,  1.4799e-01, -6.7905e-01,  7.5637e-01,
          1.3447e+00,  5.9508e-02, -3.0331e-02,  7.0334e-01, -5.3269e-01,
          9.4237e-01, -1.0163e-01,  4.3839e-01, -3.3120e-01,  2.3810e+00,
          1.6967e+00,  8.7406e-01,  2.3900e+00,  1.6642e+00, -2.3882e-01],
        [-1.3552e+00, -1.0470e+00, -1.7552e+00,  7.9643e-01, -1.7400e+00,
          6.5742e-01,  7.5502e-01,  2.8012e-01,  9.7678e-01,  6.0380e-01,
         -1.2133e+00,  6.6936e-02, -6.1286e-01,  1.0915e-01, -1.2869e+00,
         -5.1862e-01, -4.6349e-01, -6