In [1]:
import torch
import torch.nn as nn
from functools import partial

# Meta Learning in PyTorch

## General setting

In bi-level meta-learning, the meta-objective to be minimized is:


$$J(\theta) = \sum_{i=1}^b \overbrace{\mathcal{L} (\mathrm{Alg}_\theta(D^{\rm tr}_i), D^{\rm te}_i)}^{=J_i(\theta)},$$

i.e. we optimize the meta-parameters $\theta$ of the parameterized algorithm $\mathrm{Alg}$ such that $\mathrm{Alg}_\theta$, applied to a training dataset $D^{\rm tr}_i$, provides a model that generalizes well to the corresponding test dataset $D^{\rm te}_i$, as measured by a loss $\mathcal{L}$.

* The loss $\mathcal{L}$ measures the discrepancy between a model with parameters $\phi = \mathrm{Alg}_\theta(D^{\rm tr}_i)$ and the test dataset $D^{\rm te}_i$
* The loss is computed over a meta-batch containing $b$ dataset pairs $(D^{\rm tr}_i, D^{\rm te}_i)$
* The datasets $D^{\rm tr}_i$ and $D^{\rm te}_i$ are generated by the same system
* The datasets $D^{\rm tr}_i$, $D^{\rm tr}_j$ with $i \neq j$ are generated by different, yet *related* systems. They may be thought of as sampled from a probability distribution $p(D)$
* We optimize with minibatch techniques and resample the datasets from $p(D)$ at each iteration.

## Static regression setting

To make things more concrete, let us consider a static regression problem.

* Each dataset $D$ contains $K$ input-output pairs $D = (\mathbf{x}, \mathbf{y}) = (x_1, y_1, \dots, x_K, y_K)$, $x_j \in \mathbb{R}^{n_x}$, $y_j \in \mathbb{R}^{n_y}$.
* We are given a model structure $\hat {\mathbf{y}} = f(\phi, \mathbf{x})$ that is suitable to describe the dependency $x \rightarrow y$ *for all possible systems* in $p(D)$, given a system-specific choice of the parameters $\phi$. 
* We aim at *learning an algorithm* $\mathrm{Alg}_\theta$ that provides such parameters $\phi$ from a training dataset:
$\phi(D) = \mathrm{Alg}_\theta(D)$ by minimizing $J(\theta).$


Having introduced $f$, we can write the loss $\mathcal{L}$ more explicitly as:
$$\mathcal{L}(\phi, D) = \ell(\overbrace{f(\phi, \mathbf{x})}^{=\hat {\mathbf{y}}}, \mathbf{y}),$$
where $\ell(\cdot, \cdot)$ is the measure of discrepancy, e.g. the mean squared error.

## Model-Agnostic Meta Learning (MAML)

The celebrated MAML algorithm is based on:

$$
\phi(\theta, D_i^{\rm tr}) = \mathrm{Alg}_\theta(D_i^{\rm tr}) = \theta - \alpha \nabla_1 \mathcal{L}(\theta, D_i^{\rm tr}),
$$

i.e. it learns model parameters $\phi$ as the result of a single gradient-descent step initialized at $\theta$. Note that $\nabla_1$ denotes here the gradient with respect to the first argument.

Let us drop the $i$ subscript on the datasets and define $D^{\rm tr} = (\mathbf{x}^{\rm tr}, \mathbf{y}^{\rm tr})$, $D^{\rm te} = (\mathbf{x}^{\rm te}, \mathbf{y}^{\rm te})$. Overall, the MAML loss for a single dataset is:
$$
J_i(\theta) = \mathcal{L} \big(\theta - \alpha \nabla_1 \mathcal{L}(\theta, \mathbf{x}^{\rm tr}, \mathbf{y}^{\rm tr}), \mathbf{x}^{\rm te}, \mathbf{y}^{\rm te} \big)
$$

We see that the loss $J_i$ is considerably more involved than the ones we have seen before. To evaluate it, we need to:

1. Compute $\mathcal{L}(\theta, \mathbf{x}^{\rm tr}, \mathbf{y}^{\rm tr}) = \ell(f(\theta, \mathbf{x}^{\rm tr}), \mathbf{y}^{\rm tr})$
2. Differentiate $\mathcal{L}(\theta, \mathbf{x}^{\rm tr}, \mathbf{y}^{\rm tr})$ w.r.t. $\theta$ to obtain $\nabla_1 \mathcal{L}(\theta, \mathbf{x}^{\rm tr}, \mathbf{y}^{\rm tr})$
3. Run a gradient descent step to compute $\phi = \theta - \alpha \nabla_1 \mathcal{L}(\theta, \mathbf{x}^{\rm tr}, \mathbf{y}^{\rm tr})$
4. Compute $\mathcal{L}(\phi, \mathbf{x}^{\rm te}, \mathbf{y}^{\rm te}) = \ell(f(\phi, \mathbf{x}^{\rm te}), \mathbf{y}^{\rm te})$

Steps 2 and 3 require computing derivatives and updating model parameters *within the forward pass*. You will need some more advanced PyTorch to do that!
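
To make the four steps concrete, here is a minimal sketch for a scalar model $f(\theta, x) = \theta x$ with one training point and one test point, written in plain Python so that every derivative is visible. The numerical values and the function names are assumptions chosen for illustration only. The sketch also evaluates the derivative of the meta-objective by the chain rule and compares it with a finite-difference estimate.

```python
# Scalar model f(theta, x) = theta * x with squared-error loss, one point per set.
def loss(theta, x, y):
    return (theta * x - y) ** 2

def loss_grad(theta, x, y):
    # d/dtheta of (theta * x - y)^2
    return 2.0 * x * (theta * x - y)

def maml_objective(theta, x_tr, y_tr, x_te, y_te, alpha):
    phi = theta - alpha * loss_grad(theta, x_tr, y_tr)  # steps 1-3: inner gradient step
    return loss(phi, x_te, y_te)                        # step 4: test loss at phi

def maml_objective_grad(theta, x_tr, y_tr, x_te, y_te, alpha):
    phi = theta - alpha * loss_grad(theta, x_tr, y_tr)
    dphi_dtheta = 1.0 - 2.0 * alpha * x_tr ** 2         # derivative of the inner step
    return loss_grad(phi, x_te, y_te) * dphi_dtheta     # chain rule

# Illustrative values (assumptions, not taken from the notebook)
theta, alpha = 1.0, 0.1
x_tr, y_tr, x_te, y_te = 2.0, 0.0, 1.0, 1.0

J = maml_objective(theta, x_tr, y_tr, x_te, y_te, alpha)
dJ = maml_objective_grad(theta, x_tr, y_tr, x_te, y_te, alpha)

# Central finite-difference check of the meta-gradient
eps = 1e-6
dJ_fd = (maml_objective(theta + eps, x_tr, y_tr, x_te, y_te, alpha)
         - maml_objective(theta - eps, x_tr, y_tr, x_te, y_te, alpha)) / (2 * eps)
print(J, dJ, dJ_fd)
```

With these values the inner step gives $\phi = 0.2$, so $J = 0.64$ and $\mathrm{d}J/\mathrm{d}\theta = -0.32$. The factor `dphi_dtheta` is the contribution of differentiating *through* the inner update, which is what an automatic differentiation framework has to handle for us.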

## MAML Implementation in PyTorch

### Meta dataset 

When working with synthetic data, it might be challenging to define a reasonable dataset distribution. Here we just work with random data, for illustration.

In [2]:
batch_size = 32 # number of *datasets* in a batch
K = 5 # number of data points in each dataset
nx = 1 # number of inputs
ny = 1 # number of outputs

# The meta batch
batch_x = torch.randn(batch_size, 2*K, nx)
batch_y = torch.randn(batch_size, 2*K, ny)

# support set, aka context, training set: first K points of each dataset
batch_x_tr = batch_x[:, :K]
batch_y_tr = batch_y[:, :K]
# query set, aka test set: remaining K points of each dataset
batch_x_te = batch_x[:, K:]
batch_y_te = batch_y[:, K:]

# (batch_x_tr[i], batch_y_tr[i]) and (batch_x_te[i], batch_y_te[i]) are portions of the same dataset, hence generated by the same mechanism  
# (batch_x[i], batch_y[i]) and (batch_x[j], batch_y[j]) are generated by different mechanisms
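
Purely random data carries no structure shared across datasets, so there is little for a meta-learner to exploit. As a hedged illustration of a more structured distribution, the sketch below samples sinusoids with a random amplitude and phase per dataset, a common toy task family in the meta-learning literature. The function name `sample_sine_datasets` and the sampling ranges are assumptions made for this example and are not used in the rest of the notebook.

```python
import math
import torch

def sample_sine_datasets(batch_size, K):
    # Hypothetical task family: y = A * sin(x + phase), one (A, phase) pair per dataset
    amplitude = 0.1 + 4.9 * torch.rand(batch_size, 1, 1)  # A in [0.1, 5.0)
    phase = math.pi * torch.rand(batch_size, 1, 1)        # phase in [0, pi)
    x = -5.0 + 10.0 * torch.rand(batch_size, 2 * K, 1)    # inputs in [-5, 5)
    y = amplitude * torch.sin(x + phase)                  # broadcast over the 2*K points
    return x, y

torch.manual_seed(0)
sine_x, sine_y = sample_sine_datasets(batch_size=32, K=5)

# Split each dataset into a support half and a query half
sine_x_tr, sine_x_te = sine_x[:, :5], sine_x[:, 5:]
sine_y_tr, sine_y_te = sine_y[:, :5], sine_y[:, 5:]
```

Here all points of dataset `i` share the same amplitude and phase, while different datasets come from different, yet related, sinusoids.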

In [3]:
# algorithm hyperparameters
alpha = 0.1 # inner loop learning rate
lr = 1e-3 # outer loop learning rate
iters = 1000 # outer loop iterations

### From modules to pure functions

Implementing Meta Learning in PyTorch is more convenient with a purely *functional approach*. We need models as pure Python functions, where the output depends *explicitly* on parameters and inputs, i.e. something like:

```python
def mlp_fn(params, x):
    ...
    return y
```

We have seen that ``torch.nn`` provides convenient tools to define neural networks as ``modules``, which are Python objects.

In [4]:
hidden_size = 40
input_size = 1 # number of features
output_size = 1

class MLP(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(1, hidden_size)
        self.act1 = nn.Tanh()
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act1(x)
        y = self.fc2(x)
        return y

In [5]:
mlp = MLP(hidden_size)
x = torch.randn(32, 1)
y_hat = mlp(x) # mlp called with input x, parameters are hidden inside the module
y_hat.shape

torch.Size([32, 1])


If we do not want to write neural networks as functions from scratch (as we did in the first lecture), we can adapt an existing `nn.Module` to behave "functionally".
We first need to extract the model parameters as a dictionary of tensors:

In [6]:
params = dict(mlp.named_parameters()) # dictionary {parameter_name: parameter_tensor}

Then, we can call the model as a pure function:

In [7]:
def mlp_fn(params, x):
    return torch.func.functional_call(mlp, params, x)

mlp_fn(params, x).shape

torch.Size([32, 1])
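
As a small sanity check of what "functional" means here, the sketch below calls the same module with two different parameter dictionaries. It uses a hypothetical stand-in network `net` rather than the `MLP` above, so that the block runs on its own. With the stored parameters the functional call reproduces the usual module call, while with all-zero parameters the output is zero, and the parameters stored in the module are left untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(1, 8), nn.Tanh(), nn.Linear(8, 1))  # stand-in for the MLP
x_demo = torch.randn(4, 1)

stored = dict(net.named_parameters())
zeros = {name: torch.zeros_like(p) for name, p in stored.items()}

out_module = net(x_demo)                                         # regular module call
out_stored = torch.func.functional_call(net, stored, (x_demo,))  # same parameters, passed explicitly
out_zeros = torch.func.functional_call(net, zeros, (x_demo,))    # different parameters, same module
```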

### MAML Loss implementation

Let us define the loss as a function of parameters, inputs and outputs. Basically, we are defining $\mathcal{L}(\theta, \mathbf{x}, \mathbf{y})$ used in steps 1 and 4 of the MAML loss computation.

In [8]:
def loss_fn(params, x, y):
    y_hat =  mlp_fn(params, x) # functional call to the mlp
    loss = torch.mean((y_hat - y) ** 2) # MSE loss
    return loss

loss_fn(params, batch_x_tr[0], batch_y_tr[0]) # loss for the first training dataset in the batch

tensor(0.2445, grad_fn=<MeanBackward0>)

Let us define the loss derivative with respect to the parameters as a function. Basically, we are defining $\nabla_1 \mathcal{L}(\theta, \mathbf{x}, \mathbf{y})$ needed in step 2.

In [9]:
grad_fn = torch.func.grad(loss_fn, argnums=0) # takes a function in and returns the function derivative out
params_grad = grad_fn(params, batch_x_tr[0],  batch_y_tr[0]) # a dict of gradients, one per parameter
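
To see what `torch.func.grad` returns for a dictionary of parameters, here is a minimal sketch on a linear model small enough to differentiate by hand. The model, the names `quad_loss` and `params_demo`, and the data values are assumptions made for this example. The result has the same keys as the input dictionary, and each entry matches the manually derived gradient.

```python
import torch

def quad_loss(params, x, y):
    # Linear model y_hat = w * x + b with MSE loss
    y_hat = params["w"] * x + params["b"]
    return torch.mean((y_hat - y) ** 2)

params_demo = {"w": torch.tensor(2.0), "b": torch.tensor(-1.0)}
x_demo = torch.tensor([0.0, 1.0, 2.0])
y_demo = torch.tensor([1.0, 1.0, 1.0])

# Gradient with respect to the first argument, returned as a dict with the same keys
g = torch.func.grad(quad_loss, argnums=0)(params_demo, x_demo, y_demo)

# Manual gradients of the MSE for comparison
residual = params_demo["w"] * x_demo + params_demo["b"] - y_demo
g_w_manual = torch.mean(2 * residual * x_demo)
g_b_manual = torch.mean(2 * residual)
```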

Let us define the gradient update $\phi = \theta - \alpha \nabla_1 \mathcal{L}(\theta, \mathbf{x}, \mathbf{y})$ needed in step 3.

In [10]:
def inner_update(p, x, y):
    g = torch.func.grad(loss_fn)(p, x, y)
    p = {k: p[k] - alpha * g[k] for k in p.keys()} # apply gd to each tensor in the dictionary
    #gd_fun = lambda z, dz: z - alpha * dz
    #p = torch.utils._pytree.tree_map(gd_fun, p, g) # apply one step of gradient descent
    return p

p_updated = inner_update(params, batch_x_tr[0], batch_y_tr[0]) # adapt the parameters to the first training dataset in the batch

We can finally define the MAML loss for a single dataset! This implements:

$$J(\theta, \mathbf{x}_1, \mathbf{y}_1, \mathbf{x}_2, \mathbf{y}_2) = \mathcal{L}(\theta - \alpha \nabla_1 \mathcal{L}(\theta, \mathbf{x}_1, \mathbf{y}_1), \mathbf{x}_2, \mathbf{y}_2)$$

In [11]:
def maml_loss(p, x1, y1, x2, y2):
    p2 = inner_update(p, x1, y1)
    return loss_fn(p2, x2, y2)

maml_loss(params, batch_x_tr[0], batch_y_tr[0], batch_x_te[0], batch_y_te[0]) # it works, and it is differentiable wrt p. We are almost there!

tensor(2.7155, grad_fn=<MeanBackward0>)

We now need to vectorize with respect to the datasets in the meta-batch. The ``vmap`` function does the job!

In [12]:
def batched_maml_loss(p, x1_b, y1_b, x2_b, y2_b):
    maml_loss_batch = torch.func.vmap(maml_loss, in_dims=(None, 0, 0, 0, 0)) # vectorize wrt all but the first argument
    batch_losses = maml_loss_batch(p, x1_b, y1_b, x2_b, y2_b)
    return torch.mean(batch_losses)

batched_maml_loss(params, batch_x_tr, batch_y_tr, batch_x_te, batch_y_te)

tensor(2.2658, grad_fn=<MeanBackward0>)
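
The role of `in_dims` can be checked on a toy function: `None` marks an argument that is shared across the batch, while `0` marks an argument that is split along its first dimension. The sketch below, with a hypothetical `per_dataset_loss` and made-up data, compares the vectorized result with an explicit Python loop over datasets.

```python
import torch

def per_dataset_loss(w, x, y):
    # w is shared across datasets; x and y belong to a single dataset
    return torch.mean((x * w - y) ** 2)

torch.manual_seed(0)
w = torch.tensor(0.5)
xs = torch.randn(3, 5)  # 3 datasets, 5 points each
ys = torch.randn(3, 5)

# Vectorized over the dataset dimension of xs and ys, with w broadcast
losses_vmap = torch.func.vmap(per_dataset_loss, in_dims=(None, 0, 0))(w, xs, ys)

# Reference: explicit loop over the datasets
losses_loop = torch.stack([per_dataset_loss(w, xs[i], ys[i]) for i in range(3)])
```

Both give one loss per dataset. The `vmap` version avoids the Python loop, which matters once the per-dataset function contains a full inner gradient step.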

### MAML Training loop

This part is pretty standard!

In [13]:
mlp = MLP(hidden_size)
params_maml = dict(mlp.named_parameters())

In [14]:
def sample_datasets(batch_size, K):
    # Dummy data here. In practice, this could call a simulator from a well-tuned distribution
    # or retrieve some real datasets of similar systems
    batch_x = torch.randn(batch_size, 2*K, nx)
    batch_y = torch.randn(batch_size, 2*K, ny)
    return batch_x, batch_y

In [15]:
opt = torch.optim.Adam(params_maml.values(), lr=lr)
losses = []

for i in range(iters):
    batch_x, batch_y = sample_datasets(batch_size=batch_size, K=K) # 2*K points per dataset
    batch_x_tr = batch_x[:, :K]
    batch_y_tr = batch_y[:, :K]
    batch_x_te = batch_x[:, K:]
    batch_y_te = batch_y[:, K:]


    opt.zero_grad()
    loss = batched_maml_loss(params_maml, batch_x_tr, batch_y_tr, batch_x_te, batch_y_te)
    loss.backward()
    opt.step()
    losses.append(loss.item())

## Training models starting from the MAML initialization

Given a *new* dataset, we would then execute one step of gradient descent starting from the MAML initialization. In other words, run MAML's ``inner_update`` starting from the learned MAML initialization, instead of training from scratch.

In [16]:
x_new = torch.randn(K, nx) # a single new dataset with K data points
y_new = torch.randn(K, ny)

In [17]:
trained_params_maml = inner_update(params_maml, x_new, y_new) # adapt the MAML initialization to the new dataset
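
Once adapted, the parameters are used like any other parameter dictionary. The sketch below is a self-contained illustration of the adapt-then-predict pattern. The stand-in network `net`, the dictionary `theta` (playing the role of a learned initialization) and the helper names are assumptions made for this example. Note that the adaptation returns a new dictionary and leaves the initialization unchanged, so the same initialization can be reused for the next dataset.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(1, 16), nn.Tanh(), nn.Linear(16, 1))  # stand-in model
theta = dict(net.named_parameters())  # stand-in for a learned initialization
theta_before = {k: v.detach().clone() for k, v in theta.items()}

def model_fn(params, x):
    return torch.func.functional_call(net, params, (x,))

def mse(params, x, y):
    return torch.mean((model_fn(params, x) - y) ** 2)

def adapt(params, x, y, step_size=0.1):
    # One gradient-descent step on the support data, returning new parameters
    grads = torch.func.grad(mse)(params, x, y)
    return {k: params[k] - step_size * grads[k] for k in params}

x_support, y_support = torch.randn(5, 1), torch.randn(5, 1)  # new dataset
x_query = torch.linspace(-2, 2, 7).unsqueeze(-1)             # inputs to predict at

phi = adapt(theta, x_support, y_support)
with torch.no_grad():
    y_query_hat = model_fn(phi, x_query)  # predictions of the adapted model
```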

## Black-box meta learning with Hypernetworks

MAML is just one possible meta-learning algorithm. For instance, we could parameterize 
$$\phi(\theta, D_i^{\rm tr}) = \mathrm{Alg}_\theta(D_i^{\rm tr})$$
as a neural network with parameters $\theta$ that processes a training dataset and returns the 
parameters of *another* neural network, which models the dependency $x \rightarrow y$. Such a network is called a Hyper-Network!

To implement such a Hyper-Network, we need a function that transforms a flat tensor into model parameters.

In [18]:
mlp = MLP(hidden_size)
flat_params_mlp = nn.utils.parameters_to_vector(mlp.parameters()) # mlp parameters flattened to a vector
n_params_mlp = flat_params_mlp.shape[0] 
n_params_mlp

121

This function does the job:

In [19]:
def unflatten_like(flat_tensor: torch.Tensor, model: torch.nn.Module):
    param_shapes = [p.shape for p in model.parameters()]
    param_numels = [p.numel() for p in model.parameters()]
    param_names = [name for name, _ in model.named_parameters()]
    
    split_tensors = torch.split(flat_tensor, param_numels)
    params = {name: t.view(shape) for t, shape, name in zip(split_tensors, param_shapes, param_names)}
    return params

params_mlp_unf = unflatten_like(flat_params_mlp, mlp)  # check that the unflattening works
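
A quick way to convince ourselves that flattening and unflattening are inverse operations is a round trip on a small model. The sketch below repeats the same split-and-view logic on a hypothetical stand-in network `net` (4 + 4 + 4 + 1 = 13 parameters) and compares each recovered tensor with the one stored in the module.

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(1, 4), nn.Tanh(), nn.Linear(4, 1))  # stand-in model
flat = nn.utils.parameters_to_vector(net.parameters())            # all parameters in one vector

names = [n for n, _ in net.named_parameters()]
shapes = [p.shape for p in net.parameters()]
numels = [p.numel() for p in net.parameters()]

# Split the flat vector into chunks and give each chunk its original shape
chunks = torch.split(flat, numels)
unflat = {n: c.view(s) for n, c, s in zip(names, chunks, shapes)}
```

The round trip relies on `parameters()` and `named_parameters()` iterating in the same order, which is also the order used by `parameters_to_vector`.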

A Deep Set neural network that acts as a Hyper-Network:

In [20]:
# PyTorch implementation of DeepSet for generating MLP parameters
class DeepSet(nn.Module):
    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        # Shared MLP for set elements
        self.fc1 = nn.Linear(2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        # MLP for aggregated representation
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, output_size)

    def forward(self, x, y):
        z = torch.cat([x, y], dim=-1)
        z = self.fc1(z)
        z = torch.relu(z)
        z = self.fc2(z)
        z = torch.relu(z)
        # Aggregate (sum) over set elements
        z = torch.sum(z, dim=-2)
        z = self.fc3(z)
        z = torch.relu(z)
        z = self.fc4(z) * 0.1  # scale for decent MLP initialization
        
        # Reshape as mlp parameters
        z = unflatten_like(z, mlp)
        return z

In [21]:
# Instantiate the hypernetwork
hypernet = DeepSet(hidden_size=128, output_size=n_params_mlp)

In [22]:
# Hypernet usage: dataset in, MLP parameters out
params_mlp_ = hypernet(batch_x_tr[0], batch_y_tr[0])  # Dataset in, MLP parameters out!
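
The reason for summing over set elements is that a dataset is a *set* of input-output pairs: the generated parameters should not depend on the order in which the pairs are listed. The sketch below checks this property on a reduced, hypothetical encoder `SumPoolEncoder` that keeps only the per-element network and the sum pooling. It is not the `DeepSet` class above, only an illustration of the same aggregation idea.

```python
import torch
import torch.nn as nn

class SumPoolEncoder(nn.Module):  # hypothetical, reduced version of a Deep Set
    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.phi = nn.Sequential(nn.Linear(2, hidden_size), nn.ReLU())  # shared across elements
        self.rho = nn.Linear(hidden_size, output_size)                  # acts on the pooled vector

    def forward(self, x, y):
        z = self.phi(torch.cat([x, y], dim=-1))  # one embedding per (x, y) pair
        return self.rho(torch.sum(z, dim=-2))    # sum over set elements, then map to the output

torch.manual_seed(0)
enc = SumPoolEncoder(hidden_size=16, output_size=3)
x_set, y_set = torch.randn(5, 1), torch.randn(5, 1)

perm = torch.randperm(5)                  # shuffle the order of the data points
out = enc(x_set, y_set)
out_perm = enc(x_set[perm], y_set[perm])  # same set, different ordering
```

The two outputs agree up to floating-point rounding, since only the order of the terms in the sum changes.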

In [23]:
# We will need to make functional calls to the hypernetwork
params_hypernet = dict(hypernet.named_parameters())
params_mlp_ = torch.func.functional_call(hypernet, params_hypernet, args=(batch_x_tr[0], batch_y_tr[0]))  # Get the parameters for the MLP

Let us define and test the Hypernet loss for a single dataset:

In [24]:
def hypernet_loss(ph, x1, y1, x2, y2):

    # Generate the weights using the hypernetwork on the support set
    pm = torch.func.functional_call(hypernet, ph, args=(x1, y1))
    # Compute the loss on the query set
    return loss_fn(pm, x2, y2)

hypernet_loss(params_hypernet, batch_x_tr[0], batch_y_tr[0], batch_x_te[0], batch_y_te[0])

tensor(1.2821, grad_fn=<MeanBackward0>)

In [25]:
def batched_hypernet_loss(p, x1_b, y1_b, x2_b, y2_b):
    hypernet_loss_cfg = partial(hypernet_loss, p) # fix first argument
    hypernet_loss_batch = torch.func.vmap(hypernet_loss_cfg) # vmap over the rest
    batch_losses = hypernet_loss_batch(x1_b, y1_b, x2_b, y2_b)
    return torch.mean(batch_losses)

batched_hypernet_loss(params_hypernet, batch_x_tr, batch_y_tr, batch_x_te, batch_y_te)

tensor(0.9917, grad_fn=<MeanBackward0>)

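The rest is the same as for MAML: sample a meta-batch, evaluate the batched loss, backpropagate, and update the hypernetwork parameters!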
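
For completeness, here is a compact, self-contained sketch of a hypernetwork training loop on dummy data. It assumes a reduced hypernetwork (`HyperNet`, one per-point layer, sum pooling, one linear head) and a small target network, both hypothetical stand-ins for the `DeepSet` and `MLP` classes above. Sizes and the number of iterations are kept small so that the block runs quickly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
K, nx, ny, batch_size = 5, 1, 1, 8

# Target model whose parameters are generated by the hypernetwork
target = nn.Sequential(nn.Linear(nx, 8), nn.Tanh(), nn.Linear(8, ny))
names = [n for n, _ in target.named_parameters()]
shapes = [p.shape for p in target.parameters()]
numels = [p.numel() for p in target.parameters()]

class HyperNet(nn.Module):  # hypothetical reduced hypernetwork
    def __init__(self, hidden, n_out):
        super().__init__()
        self.point = nn.Sequential(nn.Linear(nx + ny, hidden), nn.ReLU())
        self.head = nn.Linear(hidden, n_out)

    def forward(self, x, y):
        z = self.point(torch.cat([x, y], dim=-1)).sum(dim=-2)  # pool over data points
        flat = self.head(z) * 0.1                              # small initial weights
        chunks = torch.split(flat, numels)
        return {n: c.reshape(s) for n, c, s in zip(names, chunks, shapes)}

hyper = HyperNet(hidden=32, n_out=sum(numels))
theta = dict(hyper.named_parameters())  # meta-parameters to be learned

def task_loss(ph, x1, y1, x2, y2):
    pm = torch.func.functional_call(hyper, ph, (x1, y1))    # support set -> target parameters
    y_hat = torch.func.functional_call(target, pm, (x2,))   # predict on the query inputs
    return torch.mean((y_hat - y2) ** 2)

def batched_loss(ph, x1_b, y1_b, x2_b, y2_b):
    per_task = torch.func.vmap(task_loss, in_dims=(None, 0, 0, 0, 0))
    return per_task(ph, x1_b, y1_b, x2_b, y2_b).mean()

opt = torch.optim.Adam(theta.values(), lr=1e-3)
history = []
for _ in range(20):
    # Dummy data: 2*K points per dataset, first K for support, last K for query
    x = torch.randn(batch_size, 2 * K, nx)
    y = torch.randn(batch_size, 2 * K, ny)
    opt.zero_grad()
    loss = batched_loss(theta, x[:, :K], y[:, :K], x[:, K:], y[:, K:])
    loss.backward()
    opt.step()
    history.append(loss.item())
```

Since the data are random, the loss is not expected to decrease meaningfully here. The point of the sketch is only that gradients flow from the query loss, through the generated parameters, back to the hypernetwork parameters `theta`.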