In [None]:
# Only for Google Collab

# Set the runtime session to a GPU/TPU session first!
# Clone the repository
!git clone https://github.com/RobvanGastel/meta-in-context-learning.git

# Change directory to the cloned repository
%cd meta-in-context-learning

# Potentially this is the only dependency not supported yet
!pip install einops

### Outline of GPICL
Given a dataset $D = \{x_i, y_i\}$, we sample a random linear projection $A \in \mathbb{R}^{N_x \times N_x}$ with $A_{ij} \sim \mathcal{N}(0, 1/N_x)$ and a permutation $\rho$ of the output labels, and build the augmented task $\bar{D} = \{A x_i, \rho(y_i)\}$.

This is done to reduce the number of unique tasks necessary to train our meta-learned model. The loss is the cross-entropy between the last label $y_j$ and the model's prediction given the entire series with that last label left out. This is essentially the same setup as in the other notebook: only the last label is left out, and the remaining samples are provided as context.

In [None]:
import jax
import optax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from flax.training import train_state
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

from meta_icl.vision_transformer import ViT
from meta_icl.data import FewShotDataset, FewShotBatchSampler

# Parameters
num_epochs = 40000  # iterations of the outer meta-training loop
n_way = 3           # forwarded to FewShotBatchSampler
k_shot = 2          # forwarded to FewShotBatchSampler
batch_size = 8      # sequences per batch (sampler and dummy init batch)
seq_length = 6      # samples per sequence seen by the model
seed = 42           # seed of the notebook's root PRNG key

# Vision transformer that is meta-trained below.
v = ViT(
    image_size=28,
    patch_size=(14, 14),
    num_classes=10,
    emb_dim=256,
    seq_length=seq_length,
    channels=1,
    num_layers=4,
    num_heads=8,
    mlp_dim=512,
)

In [None]:

# Root PRNG key of the notebook. Split it so that every consumer gets its own
# sub-key; JAX keys must not be reused for independent draws.
key = jax.random.key(seed)
key, x_key, y_key = jax.random.split(key, 3)
init_rngs = {'params': jax.random.key(1)}

# Dummy batch, used only to trace shapes during parameter initialisation.
X = jax.random.normal(x_key, (batch_size, seq_length, 28*28))
# Context labels: one fewer than the number of samples, and integer class ids
# like the labels the model receives during training (not Gaussian floats).
y = jax.random.randint(y_key, (batch_size, seq_length-1), 0, v.num_classes)

params = v.init(init_rngs, X, y)
# Smoke-test forward pass with the freshly initialised parameters.
output = v.apply(params, X, y, rngs=init_rngs)

class TrainState(train_state.TrainState):
    """Notebook-local train state; adds nothing on top of flax's ``TrainState``."""

# Bundle the model's apply function, the initial parameters and an AdamW
# optimiser (learning rate 1e-4) into the state consumed by `train_step`.
state = TrainState.create(
    tx=optax.adamw(learning_rate=1e-4),
    params=params,
    apply_fn=v.apply,
)

@jax.jit
def train_step(state, X, y):
    """Run one GPICL meta-training step on a batch of few-shot sequences.

    Args:
        state: Flax ``TrainState`` carrying parameters, optimiser state and
            the step counter.
        X: Image batch of shape ``(batch, seq, height, width)``.
        y: Labels of shape ``(batch, seq)`` (assumed to be integer class
            ids). All labels but the last are handed to the model as
            context; the last one is the prediction target.

    Returns:
        Tuple ``(state, loss)`` with the updated train state and the scalar
        mean cross-entropy loss of the batch.
    """
    # The module-level `key` is baked into the compiled function as a
    # constant, so using it directly would draw the *same* projection on
    # every call. Folding in the step counter yields a fresh key per step.
    step_key = jax.random.fold_in(key, state.step)

    def loss_fn(params):
        batch_size, seq, height, width = X.shape
        num_features = height * width
        X_flat = jnp.reshape(X, (batch_size, seq, num_features))

        # Random linear projection A in R^{Nx x Nx}, A_ij ~ N(0, 1/Nx),
        # one independent matrix per sequence in the batch.
        A = jax.random.normal(
            step_key, (batch_size, num_features, num_features), dtype=jnp.float32
        ) * (num_features ** -0.5)
        # (A x)_i = sum_j A_ij x_j, applied to every sample of the sequence.
        X_bar = jnp.einsum("bij,bsj->bsi", A, X_flat)

        # TODO: Should have permutation \rho(y)

        logits = state.apply_fn(params, X_bar, y[:, :-1])

        # NOTE(review): assumes the model emits exactly one prediction per
        # sequence, i.e. logits of shape (batch, C) or (batch, 1, C) -- confirm
        # against the ViT. Any other shape makes this reshape fail loudly
        # instead of silently broadcasting against the labels.
        num_classes = logits.shape[-1]
        logits = jnp.reshape(logits, (batch_size, num_classes))

        # The target is the last label of each sequence, whatever its length.
        targets = jax.nn.one_hot(y[:, -1], num_classes)
        per_example = optax.softmax_cross_entropy(logits=logits, labels=targets)
        return jnp.mean(per_example)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss


# MNIST training split wrapped for few-shot sampling; the batch sampler is
# configured with the n_way / k_shot / batch_size values set above.
train_dataset = FewShotDataset(dataset=MNIST, train=True)
episode_sampler = FewShotBatchSampler(
    train_dataset.y, n_way, k_shot, batch_size=batch_size
)
data_loader = DataLoader(train_dataset, batch_sampler=episode_sampler)

# Meta-training loop
losses = []
for epoch in range(num_epochs):
    # Sum of the batch losses seen during this pass over the loader.
    cumulative_loss = 0

    for X_batch, y_batch in data_loader:
        state, batch_loss = train_step(state, X_batch.numpy(), y_batch.numpy())
        cumulative_loss = cumulative_loss + batch_loss.mean()

    losses.append(float(cumulative_loss))

    should_log = epoch % 500 == 0
    if should_log:
        print(f"epoch {epoch}/{num_epochs}: loss: {cumulative_loss}")

In [None]:
# Plot against the number of epochs actually recorded, so the cell also works
# when training was interrupted or `num_epochs` changed after the run.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(range(len(losses)), losses, label='Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training Loss over Epochs')
ax.legend()
ax.grid()
plt.show()