In [1]:
import torch
import torch.nn as nn
from functools import partial
import matplotlib.pyplot as plt

# Meta Learning


* In meta learning, we have access to a *meta-dataset*, namely a (possibly infinite) collection of datasets:

$$\mathcal{D} = \{D_1, D_2, \dots\}$$

* In the case of static regression, each dataset $D_i$ is an unordered collection of $K$ input-output pairs.
$$D_i = \{(x_{i,1}, y_{i,1}), (x_{i,2}, y_{i,2}), \dots, (x_{i,K}, y_{i,K})\}, \qquad x_{i,j} \in \mathbb{R}^{n_x},\; y_{i,j} \in \mathbb{R}^{n_y}$$

* The datasets $D_i$ are assumed to be *similar* to each other. They are thought of as realizations of a probability distribution $p(D)$.

Meta learning aims to improve our ability to model the $x \rightarrow y$ relationship as more datasets from $p(D)$ are observed.

## Dataset Splitting
Most meta learning algorithms require splitting each dataset of the meta-dataset into a training portion (support set, context) and a test portion (query set):

$$\mathcal{D} = \{(D^{\rm tr}_1, D^{\rm te}_1), (D^{\rm tr}_2, D^{\rm te}_2), \dots \}$$
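Later in this notebook, the split takes the first $K$ points of each sampled dataset as the support set and the remaining points as the query set. That works when the points are drawn i.i.d. The sketch below draws a random split instead, which may matter if stored datasets follow some ordering. The helper `split_support_query` is a hypothetical name, not a library function. For simplicity it shares one random permutation across all datasets in the batch.

```python
import torch

def split_support_query(batch_x, batch_y, n_support, generator=None):
    """Randomly split each dataset in a batch into support (training) and query (test) parts.

    batch_x: (batch_size, n_points, n_x), batch_y: (batch_size, n_points, n_y).
    One random permutation of the point indices is shared by all datasets in the batch,
    so each (x, y) pair stays together.
    """
    n_points = batch_x.shape[1]
    perm = torch.randperm(n_points, generator=generator)
    idx_tr, idx_te = perm[:n_support], perm[n_support:]
    support = (batch_x[:, idx_tr], batch_y[:, idx_tr])
    query = (batch_x[:, idx_te], batch_y[:, idx_te])
    return support, query

# Example: 4 datasets with 10 points each, split into 5 support and 5 query points
x = torch.randn(4, 10, 1)
y = torch.randn(4, 10, 1)
(x_tr, y_tr), (x_te, y_te) = split_support_query(x, y, n_support=5)
print(x_tr.shape, x_te.shape)
```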

## In-context learning

In-context learning is a meta-learning approach that trains a single meta-model (the in-context learner) for all the datasets in the meta-dataset $\mathcal{D}$ (equivalently, for the distribution $p(D)$).

In the case of static regression, an in-context learner processes:

* An input-output training dataset $D$
* An input value $x$

and produces an output prediction $\hat y$ for the input value $x$. Formally, we have:

\begin{equation*}
\hat y = \mathcal{M}(D, x)
\end{equation*}
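To make the signature $\hat y = \mathcal{M}(D, x)$ concrete, here is a hand-designed (not learned) stand-in for $\mathcal{M}$. It uses Nadaraya-Watson kernel regression, which likewise takes a whole dataset and a query input. The function name and the `bandwidth` value are illustrative assumptions. The in-context learner built in this notebook replaces such a fixed rule with trainable networks.

```python
import torch

def kernel_regression_icl(x_tr, y_tr, x, bandwidth=0.5):
    """Hand-designed in-context 'learner' y_hat = M(D, x): Nadaraya-Watson kernel regression.

    x_tr: (K, n_x) and y_tr: (K, n_y) form the dataset D; x: (M, n_x) are query inputs.
    Returns y_hat of shape (M, n_y). Nothing is trained: D is consumed at prediction time.
    """
    sq_dist = torch.cdist(x, x_tr) ** 2                                # (M, K) squared distances
    weights = torch.softmax(-sq_dist / (2 * bandwidth ** 2), dim=-1)  # normalized Gaussian kernel weights
    return weights @ y_tr                                              # (M, n_y) weighted average of outputs

# The same function M, fed with two different datasets, implements two different x -> y maps
x_q = torch.tensor([[0.0], [1.0]])
D1 = (torch.tensor([[0.0], [1.0]]), torch.tensor([[0.0], [1.0]]))  # data from y = x
D2 = (torch.tensor([[0.0], [1.0]]), torch.tensor([[1.0], [0.0]]))  # data from y = 1 - x
y1 = kernel_regression_icl(*D1, x_q)
y2 = kernel_regression_icl(*D2, x_q)
print(y1.squeeze(-1), y2.squeeze(-1))
```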

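Training can be performed in a standard supervised setting by minimizing, over a batch of $b$ datasets, the loss

$$J(\theta) = \sum_{i=1}^b \sum_{j=1}^K \ell(y^{\rm te}_{i,j}, \mathcal{M}(D^{\rm tr}_i, x^{\rm te}_{i,j})),$$

where $\theta$ collects the parameters of $\mathcal{M}$, $\ell$ is a point-wise loss (e.g., the squared error), and the inner sum runs over the $K$ query points of each dataset. The support set $D^{\rm tr}_i$ enters as an *input* of the model, while only the query targets $y^{\rm te}_{i,j}$ enter the loss.

The peculiar aspect of the in-context learner is that it digests whole training datasets instead of just individual data points!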
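The loss $J(\theta)$ can be computed directly on batched tensors, sketched below assuming the squared error as $\ell$. The training loop later in this notebook uses `torch.mean` instead of a sum. That only rescales $J$ by the constant $b K n_y$, so the minimizer is the same and the gradients are scaled by a constant. `meta_loss` is a hypothetical helper name.

```python
import torch

def meta_loss(y_te, y_te_hat):
    """J = sum over datasets i and query points j of l(y_te[i, j], y_te_hat[i, j]),
    with the squared-error loss l(y, y_hat) = ||y - y_hat||^2 (an illustrative choice).

    y_te, y_te_hat: tensors of shape (b, K, n_y).
    """
    per_point = ((y_te_hat - y_te) ** 2).sum(dim=-1)  # (b, K): loss of each query point
    return per_point.sum()                             # sum over points j and datasets i

b, K, n_y = 32, 5, 1
y_te = torch.randn(b, K, n_y)
y_te_hat = torch.randn(b, K, n_y)
J = meta_loss(y_te, y_te_hat)
# torch.mean, as used in the training loop, equals J / (b * K * n_y)
mse = torch.mean((y_te_hat - y_te) ** 2)
print(J / (b * K * n_y), mse)
```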

## Architecture for static regression

In practice, for in-context static regression, a deep set network may be used to process the training dataset $D^{\rm tr}$ and produce a fixed-size embedding $c \in \mathbb{R}^{n_{ctx}}$. Then, a multi-layer perceptron (MLP) can process $c$ and $x$ to generate the output prediction $\hat y$:

\begin{align*}
c &= \mathrm{DeepSet}(D^{\rm tr}) \\
\hat y &= \mathrm{MLP}(c, x).
\end{align*}

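Overall, this may be seen as an encoder-decoder architecture (deep set encoder, MLP decoder). The deep set is a suitable encoder thanks to its *permutation invariance*: its output does not depend on the order of the pairs in $D^{\rm tr}$, which matches the fact that each dataset is an unordered collection.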
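Written out, the encoder follows the usual deep set form. Here $\phi$ and $\rho$ are names introduced for the per-element network and the post-aggregation network. In the `DeepSet` class below, $\phi$ corresponds to `fc1` and `fc2` with ReLU activations, and $\rho$ corresponds to `fc3` and `fc4` with a ReLU in between.

\begin{align*}
c &= \rho\Big(\sum_{k=1}^{K} \phi(x_k, y_k)\Big), \\
\sum_{k=1}^{K} \phi(x_{\pi(k)}, y_{\pi(k)}) &= \sum_{k=1}^{K} \phi(x_k, y_k) \quad \text{for any permutation } \pi \text{ of } \{1, \dots, K\}.
\end{align*}

A sum does not depend on the order of its terms, so $c$ is unchanged when the $K$ pairs of $D^{\rm tr}$ are reordered.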

For time series data (and thus, for system identification) encoder-decoder Transformer architectures have been used in [recent works](https://ieeexplore.ieee.org/abstract/document/10324309?casa_token=QUPQER5d90cAAAAA:lg7hDHBnE-0HXc5kizO96eJmpkQFYcB6Q9fcEeneJ3aTCBXyZ7quQ07ykWgGspiWF5XoMc7qUQ).

## PyTorch implementation

General meta-learning settings

In [2]:
batch_size = 32 # number of *datasets* in a batch
K = 5 # number of data points in each dataset
n_x = 1 # number of inputs
n_y = 1 # number of outputs

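A crucial part of meta-learning is the code that generates data from a meaningful distribution. Here we just generate random data for illustration: inputs and outputs are independent standard Gaussian samples, so there is no actual $x \rightarrow y$ relationship to learn.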

In [3]:
def sample_datasets(batch_size, K):
    # Dummy data here. In practice, this could call a simulator from a well-tuned distribution
    # or retrieve some real datasets of similar systems
    batch_x = torch.randn(batch_size, K, n_x)
    batch_y = torch.randn(batch_size, K, n_y)
    return batch_x, batch_y
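As an example of a more meaningful $p(D)$, the sketch below samples from a family of sine functions, a common toy benchmark in meta learning. Each dataset has its own amplitude and phase, shared by all of its $K$ points. The function name and the parameter ranges are illustrative assumptions, and the sketch assumes $n_x = n_y = 1$. Because it keeps the `(batch_size, K)` signature of `sample_datasets`, it could be swapped into the cells below.

```python
import math
import torch

def sample_sine_datasets(batch_size, K, amp_range=(0.1, 5.0), phase_range=(0.0, math.pi), x_range=(-5.0, 5.0)):
    """Each dataset D_i contains K points of y = A_i * sin(x + phi_i), with (A_i, phi_i) drawn per dataset.

    Assumes scalar input and output (n_x = n_y = 1). Returns tensors of shape (batch_size, K, 1).
    """
    amp = torch.empty(batch_size, 1, 1).uniform_(*amp_range)      # one amplitude per dataset
    phase = torch.empty(batch_size, 1, 1).uniform_(*phase_range)  # one phase per dataset
    batch_x = torch.empty(batch_size, K, 1).uniform_(*x_range)    # K inputs per dataset
    batch_y = amp * torch.sin(batch_x + phase)                    # amp and phase broadcast over the K points
    return batch_x, batch_y

bx, by = sample_sine_datasets(batch_size=32, K=10)
print(bx.shape, by.shape)
```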

In [4]:
batch_x, batch_y = sample_datasets(batch_size=batch_size, K=2*K)  # 2K points per dataset: K support + K query
# support set, aka context, training set
batch_x_tr = batch_x[:, :K]
batch_y_tr = batch_y[:, :K]
# query set, aka test set
batch_x_te = batch_x[:, K:]
batch_y_te = batch_y[:, K:]

In [5]:
batch_x_tr.shape

torch.Size([32, 5, 1])

In [6]:
batch_y_tr.shape

torch.Size([32, 5, 1])

In [7]:
n_ctx = 20 # size of the context embedding ctx = encoder(x_tr, y_tr)
hidden_size = 40 # size of all hidden layers

Let us define the encoder architecture as a deep set network.

In [8]:
class DeepSet(nn.Module):
    def __init__(self, n_x, n_y, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        # Shared MLP for set elements
        self.fc1 = nn.Linear(n_x + n_y, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        # MLP for aggregated representation
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, output_size)

    def forward(self, x, y):
        z = torch.cat([x, y], dim=-1)
        z = self.fc1(z)
        z = torch.relu(z)
        z = self.fc2(z)
        z = torch.relu(z)
        # Aggregate (sum) over set elements
        z = torch.sum(z, dim=-2)
        z = self.fc3(z)
        z = torch.relu(z)
        z = self.fc4(z)
        
        return z

In [9]:
encoder = DeepSet(n_x, n_y, hidden_size, output_size=n_ctx)
ctx = encoder(batch_x_tr, batch_y_tr)
ctx.shape # describe each dataset as a vector of size n_ctx

torch.Size([32, 20])

Permutation-invariance illustration: the order of the $K$ observations in each dataset does not influence the encoder output, thanks to the deep set architecture.

In [10]:
p = torch.randperm(K)
ctx_p = encoder(batch_x_tr[:, p, :], batch_y_tr[:, p, :])
torch.allclose(ctx, ctx_p, atol=1e-6)  # same embedding up to floating-point round-off

True
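For contrast, consider a hypothetical encoder that flattens the $K$ pairs into a single vector and feeds it to an MLP. Such an encoder is in general *not* permutation invariant, and it also only works for a fixed $K$. The sketch below reverses the order of the points and compares the embeddings. Variable names are chosen so they do not overwrite `ctx` and `p` from the cells above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, K, n_x, n_y, n_ctx = 32, 5, 1, 1, 20

# Hypothetical order-sensitive encoder: concatenate all K pairs into one flat vector
flat_encoder = nn.Sequential(
    nn.Flatten(start_dim=1),            # (batch_size, K, n_x + n_y) -> (batch_size, K * (n_x + n_y))
    nn.Linear(K * (n_x + n_y), 40),
    nn.ReLU(),
    nn.Linear(40, n_ctx),
)

x_tr = torch.randn(batch_size, K, n_x)
y_tr = torch.randn(batch_size, K, n_y)
z = torch.cat([x_tr, y_tr], dim=-1)
rev = torch.flip(torch.arange(K), dims=[0])  # reverse the order of the K points (never the identity for K > 1)
ctx_flat = flat_encoder(z)
ctx_flat_rev = flat_encoder(z[:, rev, :])
print(torch.allclose(ctx_flat, ctx_flat_rev, atol=1e-6))
```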

Let us define the decoder architecture as a simple MLP.

In [11]:
class MLP(nn.Module):
    def __init__(self, n_x, n_ctx, n_y, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(n_x + n_ctx, hidden_size)
        self.act1 = nn.Tanh()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.act2 = nn.Tanh()
        self.fc3 = nn.Linear(hidden_size, n_y)

    def forward(self, x, c):
        """
        x: input data (batch_size, K, n_x)
        c: context embedding (batch_size, n_ctx)
        """
        
        c_rep = c.unsqueeze(-2).repeat(1, x.shape[1], 1)
        xc = torch.cat((x, c_rep), dim=-1)
        z = self.fc1(xc)
        z = self.act1(z)
        z = self.fc2(z)
        z = self.act2(z)
        z = self.fc3(z)
        return z

In [12]:
decoder = MLP(n_x, n_ctx, n_y, hidden_size)
y_hat = decoder(batch_x_te, ctx)
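In the decoder, `repeat` copies the context embedding once per query point. A possible alternative is `expand`, which returns a broadcast view without copying memory; the subsequent `torch.cat` still allocates its own output. The sketch below checks that both produce the same decoder input.

```python
import torch

batch_size, K, n_x, n_ctx = 32, 5, 1, 20
x = torch.randn(batch_size, K, n_x)
c = torch.randn(batch_size, n_ctx)

# Approach used in MLP.forward: copy the context K times
c_rep = c.unsqueeze(-2).repeat(1, x.shape[1], 1)
# Alternative: broadcast view of shape (batch_size, K, n_ctx), no copy until torch.cat
c_exp = c.unsqueeze(-2).expand(-1, x.shape[1], -1)

xc_rep = torch.cat((x, c_rep), dim=-1)
xc_exp = torch.cat((x, c_exp), dim=-1)
print(xc_rep.shape, torch.equal(xc_rep, xc_exp))
```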

Let us jointly train the encoder and decoder networks on the dataset distribution.

In [13]:
iters = 100
lr = 1e-3

In [14]:
opt = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=lr)
losses = []


for itr in range(iters):
    
    batch_x, batch_y = sample_datasets(batch_size=batch_size, K=2*K)
    batch_x_tr = batch_x[:, :K]
    batch_y_tr = batch_y[:, :K]
    batch_x_te = batch_x[:, K:]
    batch_y_te = batch_y[:, K:]

    opt.zero_grad()
    ctx = encoder(batch_x_tr, batch_y_tr)
    batch_y_te_hat = decoder(batch_x_te, ctx)
    loss = torch.mean((batch_y_te_hat - batch_y_te) ** 2)
    loss.backward()
    opt.step()

    losses.append(loss.item())
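The loop above only records training losses. One simple way to evaluate the model is to measure the query-set MSE on freshly sampled datasets without updating the parameters, as in the sketch below. `evaluate` is a hypothetical helper written for the encoder/decoder interface used in this notebook. With the dummy data of `sample_datasets`, $y$ is independent of $x$ and has unit variance, so an MSE close to 1 is the best achievable there. A meaningful data distribution is needed to see it drop further.

```python
import torch
import torch.nn as nn

@torch.no_grad()  # no gradient tracking: evaluation only
def evaluate(encoder: nn.Module, decoder: nn.Module, sample_datasets, K, n_batches=10, batch_size=32):
    """Average query-set MSE over n_batches freshly sampled batches of datasets."""
    encoder.eval()
    decoder.eval()  # no effect for these layers, but good practice (e.g. if dropout is added)
    total = 0.0
    for _ in range(n_batches):
        batch_x, batch_y = sample_datasets(batch_size=batch_size, K=2 * K)
        ctx = encoder(batch_x[:, :K], batch_y[:, :K])  # support set -> context embedding
        y_hat = decoder(batch_x[:, K:], ctx)           # predictions on the query set
        total += torch.mean((y_hat - batch_y[:, K:]) ** 2).item()
    encoder.train()
    decoder.train()
    return total / n_batches
```

With the objects defined above, this would be called as `evaluate(encoder, decoder, sample_datasets, K)`.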