In [None]:
import math
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch import nn, einsum
from einops import rearrange

from utils.data import DATA_PATH,get_mnist_data_loaders, get_emnist_data_loaders, randomize_targets, select_from_classes
from utils.visualization import show_imgs, get_model_dot, LivePlot
from utils.others import measure_alloc_mem, count_parameters
from utils.timing import func_timer
from utils.metrics import get_accuracy

import wandb
from IPython.display import clear_output
import tqdm
from livelossplot import PlotLosses
import lovely_tensors as lt

lt.monkey_patch()
torch.set_printoptions(precision=3, linewidth=180)
%env "WANDB_NOTEBOOK_NAME" "main.ipynb"
wandb.login()

np.random.seed(0)
torch.manual_seed(0)

In [None]:
# Experiment settings shared by the data, model and training cells.
config = dict(
    batch_size=18,               # sequences (contexts) per batch
    seq_len=40,                  # samples per sequence
    num_of_tasks=2**8,           # size of the task pool to sample from
    permuted_labels_frac=0.1,    # share of sequences whose labels stay permuted
    whole_seq_prediction=True,   # predict a label at every position vs. only the last
    lr=3e-4,                     # Adam learning rate
    eps=1e-16,                   # Adam epsilon
    device="cuda" if torch.cuda.is_available() else "cpu",
)

print(f"... Running on {config['device']} ...")

# Data

In [None]:
def create_tasks(num_of_tasks, input_dim=784, num_classes=10):
    """Sample a pool of random tasks: one input projection and one label permutation each.

    Args:
        num_of_tasks: number of tasks to generate.
        input_dim: length of a flattened input; every projection matrix is
            ``input_dim x input_dim``. Defaults to 784.
        num_classes: number of labels that each task permutes. Defaults to 10.

    Returns:
        proj_matrices: float tensor of shape ``(num_of_tasks, 1, input_dim, input_dim)``;
            every matrix is rescaled to unit Frobenius norm.
        label_perms: integer tensor of shape ``(num_of_tasks, num_classes)``;
            every row is a random permutation of ``range(num_classes)``.
    """
    proj_matrices = torch.distributions.Normal(0, 1 / input_dim).sample(
        (num_of_tasks, 1, input_dim, input_dim)
    )
    # normalise each matrix over its last two dimensions
    proj_matrices /= torch.norm(proj_matrices, dim=(2, 3), keepdim=True)
    label_perms = torch.cat(
        [torch.randperm(num_classes).unsqueeze(0) for _ in range(num_of_tasks)], dim=0
    )
    return proj_matrices, label_perms

# Draw the pool of tasks once; later cells index into it by task id.
proj_matrices, label_perms = create_tasks(config["num_of_tasks"])
# echo both tensors as the cell output
(proj_matrices, label_perms)

In [None]:
def get_context_seqs(X, y, proj_matrices=None, label_perms=None, seq_len=32, labels_shifted_by_one=False,
                     num_classes=10):
    """Turn a flat batch of samples into sequences of ``(input, one-hot label)`` tokens.

    Args:
        X: inputs with the sample dimension first; ``X.shape[0]`` must be a
            multiple of ``seq_len``. Each sample is flattened.
        y: integer labels, one per sample.
        proj_matrices: optional per-sequence projections applied to the flattened
            inputs; must broadcast against ``(num_seqs, seq_len, 1, features)``
            (callers in this notebook pass ``(num_seqs, 1, features, features)``).
        label_perms: optional ``(num_seqs, num_classes)`` lookup; label ``c`` of
            sequence ``i`` is replaced by ``label_perms[i, c]``.
        seq_len: number of samples per sequence.
        labels_shifted_by_one: if True, token ``t`` carries the label of sample
            ``t - 1`` (zeros for the first token) and the targets are the labels
            of every position. If False, every token carries its own label except
            the last one (zeroed) and the target is the last label only.
        num_classes: width of the one-hot label part of each token.

    Returns:
        ``(tokens, targets)``: ``tokens`` has shape
        ``(num_seqs, seq_len, features + num_classes)``; ``targets`` has shape
        ``(num_seqs, seq_len)`` when ``labels_shifted_by_one`` else ``(num_seqs,)``.

    Raises:
        ValueError: if ``X.shape[0]`` is not a multiple of ``seq_len``.
    """
    if X.shape[0] % seq_len != 0:
        raise ValueError(
            f"number of samples ({X.shape[0]}) must be a multiple of seq_len ({seq_len})"
        )
    num_seqs = X.shape[0] // seq_len

    # flatten the images + split into sequences
    _X = X.view(num_seqs, seq_len, 1, -1)  # (num_seqs, seq_len, 1, features)

    # apply projection matrices
    if proj_matrices is not None:
        _X = _X @ proj_matrices.transpose(-1, -2)
        # rescale the projected values to [0, 1] using the min/max of the whole batch
        _X = (_X - _X.min()) / (_X.max() - _X.min())

    _y = y.view(num_seqs, seq_len)  # (num_seqs, seq_len)
    # apply label permutations
    if label_perms is not None:
        _y = torch.gather(label_perms, dim=1, index=_y)

    # NOTE: num_classes is passed explicitly. Without it F.one_hot infers the width
    # from the largest label present, so a batch that happens to lack the highest
    # class would yield narrower tokens than the model expects.
    if labels_shifted_by_one:
        seqs_y = _y.clone()  # target labels
        # append labels to images - labels shifted by one to the right
        _y = F.one_hot(_y[:, :-1], num_classes=num_classes)
        _y = torch.cat([torch.zeros(size=(_y.shape[0], 1, _y.shape[-1]), device=_y.device), _y], dim=1)
        _X = torch.concat((_X.squeeze(2), _y), dim=-1)
    else:
        # get the target label for each sequence (last label in the sequence)
        seqs_y = _y[:, -1]
        # append labels to images ((x1,y1), (x2,y2), ..., (xn-1, yn-1), (xn, 0)) - all except the last one
        _y = F.one_hot(_y, num_classes=num_classes)
        _y[:, -1, :] = 0  # remove the last label from the sequence (to be predicted)
        _X = torch.concat((_X.squeeze(2), _y), dim=-1)

    return _X, seqs_y


In [None]:
# One loader batch holds batch_size sequences of seq_len images each.
train_loader, test_loader, classes = get_mnist_data_loaders(
    batch_size=config["batch_size"] * config["seq_len"],
    flatten=False,
    only_classes=None,
    img_size=28,
)

# Build context sequences from the first training batch using the first
# batch_size tasks of the pool.
X, y = next(iter(train_loader))
X_rand, y_rand = get_context_seqs(
    X,
    y,
    proj_matrices=proj_matrices[:config["batch_size"]],
    label_perms=label_perms[:config["batch_size"]],
    seq_len=config["seq_len"],
    labels_shifted_by_one=False,
)

# Show the last five raw images next to their projected counterparts. The last
# token carries no label, so its title comes from the sequence target instead.
show_imgs(
    imgs=torch.cat(
        [X[-5:], X_rand[-1, -5:, :784].view(-1, 1, 28, 28)],
        dim=0,
    ),
    titles=torch.cat(
        [y[-5:], X_rand[-1, -5:-1, 784:].argmax(-1), y_rand[-1].unsqueeze(0)],
        dim=0,
    ).tolist(),
)

# Model

In [None]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention with an optional causal mask.

    Args:
        dim: size of the input and output features.
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: dropout probability applied to the attention weights.
        causal: if True, a position cannot attend to later positions.
    """

    def __init__(self, dim, heads=2, dim_head=16, dropout=0., causal=False):
        super().__init__()
        self.heads = heads
        self.causal = causal
        self.scale = dim_head ** -0.5

        hidden = dim_head * heads
        # a single bias-free projection produces queries, keys and values
        self.to_qkv = nn.Linear(dim, hidden * 3, bias=False)
        self.to_out = nn.Linear(hidden, dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq, dim) to a tensor of the same shape."""
        # split the fused projection into q, k, v and expose the head axis
        q, k, v = (
            rearrange(part, "b n (h d) -> b h n d", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )
        scores = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale

        if self.causal:
            # True strictly above the diagonal, i.e. for keys that lie in the future
            future = torch.ones(size=scores.shape[-2:], device=scores.device).triu_(1).bool()
            scores.masked_fill_(future, float("-inf"))

        weights = self.dropout(scores.softmax(dim=-1))  # (batch, heads, query, key)

        mixed = einsum("b h i j, b h j d -> b h i d", weights, v)
        merged = rearrange(mixed, "b h n d -> b n (h d)", h=self.heads)  # merge heads
        return self.to_out(merged)

class Transformer(nn.Module):
    """Token embedding followed by ``depth`` pre-norm residual blocks.

    Each block is ``x = x + attn(norm(x))`` followed by ``x = x + mlp(norm(x))``.

    Args:
        dim: model width.
        depth: number of blocks.
        heads: attention heads per block.
        dim_head: feature size of each attention head.
        token_dim: size of the incoming tokens (projected to ``dim``).
        inner_dim: hidden width of the MLP; defaults to ``4 * dim``.
        dropout: dropout probability used in attention and after the MLP.
        causal: forwarded to ``Attention``.
    """

    def __init__(
        self,
        dim,
        depth,
        heads,
        dim_head,
        token_dim=784 + 10,
        inner_dim=None,
        dropout=0.,
        causal=False,
    ):
        super().__init__()
        self.embed_proj = nn.Linear(token_dim, dim)
        self.layers = nn.ModuleList([])
        inner_dim = inner_dim or 4 * dim
        for _ in range(depth):
            # Layout (and therefore state_dict keys) is kept as
            # [attention, mlp norm, mlp, attention norm].
            self.layers.append(nn.ModuleList([
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout, causal=causal),
                nn.LayerNorm(dim),
                nn.Sequential(
                    nn.Linear(dim, inner_dim),
                    nn.GELU(),
                    nn.Linear(inner_dim, dim),
                    nn.Dropout(dropout)
                ),
                nn.LayerNorm(dim)
            ]))

    def forward(self, x):
        """Embed ``x`` (..., token_dim) and run it through all blocks; returns (..., dim)."""
        x = self.embed_proj(x)
        for attn, mlp_norm, mlp, attn_norm in self.layers:
            # The second LayerNorm of each layer used to be constructed but never
            # applied, leaving attention on un-normalised activations and the norm
            # as dead parameters. It now serves as the attention pre-norm; the
            # first LayerNorm keeps its existing role in front of the MLP.
            x = x + attn(attn_norm(x))
            x = x + mlp(mlp_norm(x))
        return x

class InContextLearner(nn.Module):
    """Transformer backbone with a linear classification head.

    Args:
        dim: model width.
        depth: number of transformer blocks.
        heads: attention heads per block.
        dim_head: feature size of each attention head.
        inner_dim: hidden width of the MLPs; defaults to ``4 * dim``.
        dropout: dropout probability passed to the transformer.
        whole_seq_prediction: if True, attention is causal and logits are
            returned for every position; otherwise only for the last position.
        num_classes: number of output classes of the head (default 10).
        token_dim: size of the input tokens (default ``784 + 10``).
    """

    def __init__(
        self,
        dim,
        depth=2,
        heads=4,
        dim_head=16,
        inner_dim=None,
        dropout=0.1,
        whole_seq_prediction=False,
        num_classes=10,
        token_dim=784 + 10,
    ):
        super().__init__()
        inner_dim = inner_dim or 4 * dim
        self.whole_seq_prediction = whole_seq_prediction
        self.transformer = Transformer(
            dim=dim,
            depth=depth,
            heads=heads,
            dim_head=dim_head,
            token_dim=token_dim,
            inner_dim=inner_dim,
            dropout=dropout,
            causal=whole_seq_prediction,
        )
        self.final_classifier = nn.Linear(dim, num_classes)

    def forward(self, x):
        """Return logits: (batch, seq_len, num_classes) or (batch, num_classes)."""
        hidden = self.transformer(x)  # (batch, seq_len, dim)
        if not self.whole_seq_prediction:
            # only the last position is classified
            hidden = hidden[:, -1, :]
        return self.final_classifier(hidden)

# Instantiate the in-context learner and place it on the configured device.
model_dim = 256
model = InContextLearner(
    dim=model_dim,
    depth=4,
    heads=6,
    dim_head=32,
    inner_dim=4 * model_dim,
    dropout=0.1,
    whole_seq_prediction=config["whole_seq_prediction"],
)
model = model.to(config["device"])

# summary of the architecture and its size
print(model)
print(f"{count_parameters(model)} trainable parameters")

# Training

In [None]:
# eval model
# NOTE: the name shadows the builtin `eval`; it is kept because the training cell calls it.
def eval(model, test_loader, apply_proj=False):
    """Evaluate ``model`` on ``test_loader`` and return per-batch averaged metrics.

    Args:
        model: module mapping context sequences to logits.
        test_loader: iterable of ``(X, y)`` batches; must support ``len()``.
        apply_proj: if True, a random task (projection + label permutation) is
            drawn from the global pool for every sequence; otherwise the raw
            inputs and labels are used.

    Returns:
        ``(loss, acc, acc_over_seq, acc_max_improvement_within_seq)`` where
        ``acc_over_seq`` is a list with one accuracy per sequence position and
        ``acc_max_improvement_within_seq`` is the mean of (best correctness at
        positions >= 1) minus (correctness at position 0). The last two stay at
        zero unless ``config["whole_seq_prediction"]`` is set.
    """
    model.eval()
    with torch.no_grad():
        loss, acc, acc_max_improvement_within_seq = 0, 0, 0
        acc_over_seq = np.array([0.] * config["seq_len"])
        for X, y in test_loader:
            X, y = X.to(config["device"]), y.to(config["device"])

            curr_proj_matrices, curr_label_perms = None, None
            if apply_proj:
                # Randomly sample one task per sequence/context. The count comes
                # from the batch itself: a final, smaller batch holds fewer than
                # config["batch_size"] sequences and the projection matmul would
                # otherwise fail to broadcast.
                num_seqs = X.shape[0] // config["seq_len"]
                task_idxs = np.random.randint(0, config["num_of_tasks"], size=num_seqs)
                curr_proj_matrices, curr_label_perms = proj_matrices[task_idxs].to(config["device"]), label_perms[task_idxs].to(config["device"])
            X, y = get_context_seqs(X, y, proj_matrices=curr_proj_matrices, label_perms=curr_label_perms,
                seq_len=config["seq_len"], labels_shifted_by_one=config["whole_seq_prediction"]) # (batch, seq_len, dim)

            y_hat = model(X)
            if config["whole_seq_prediction"]:
                loss += F.cross_entropy(y_hat.view(-1, 10), y.view(-1)).item()
                acc_over_seq += (y_hat.argmax(dim=-1) == y).float().mean(dim=0).cpu().numpy() # (seq_len,)
                acc_max_improvement_within_seq += \
                    ((y_hat[:,1:,:].argmax(dim=-1) == y[:,1:]).float().max(dim=-1).values \
                    - (y_hat[:,0,:].argmax(dim=-1) == y[:,0]).float()).mean().item()
            else:
                loss += F.cross_entropy(y_hat, y).item()
            acc += (y_hat.argmax(dim=-1) == y).float().mean().item()
        # every batch contributes with equal weight to the averages
        loss /= len(test_loader)
        acc /= len(test_loader)
        acc_over_seq = list(acc_over_seq / len(test_loader))
        acc_max_improvement_within_seq /= len(test_loader)
        print(f"loss: {loss:.4f}, acc: {acc:.4f}")
    return loss, acc, acc_over_seq, acc_max_improvement_within_seq


In [None]:
# Adam over all model parameters; learning rate and epsilon come from the config.
model_optim = torch.optim.Adam(
    model.parameters(),
    lr=config["lr"],
    eps=config["eps"],
)

# Metric groups for the live plot; the per-position metrics are only tracked
# (and the figure made taller) in whole-sequence prediction mode.
groups = ["train_loss", "train_acc", "eval_loss", "eval_acc"]
if config["whole_seq_prediction"]:
    groups += [
        "train_acc_over_seq",
        "train_acc_max_improvement_within_seq",
        "eval_acc_over_seq",
        "eval_acc_max_improvement_within_seq",
    ]
live_plot = LivePlot(
    figsize=(26, 24) if config["whole_seq_prediction"] else (26, 14),
    use_seaborn=False,
    groups=groups,
)

In [None]:
# The task pool (proj_matrices, label_perms) is not moved to the device as a
# whole; only the tasks sampled for the current batch are (saves GPU memory).

for epoch in tqdm.tqdm(range(200)):
    model.train()
    for i, (X, y) in enumerate(train_loader):
        X, y = X.to(config["device"]), y.to(config["device"])

        # Randomly sample one task per sequence/context. The count comes from
        # the batch itself: a final, smaller batch holds fewer than
        # config["batch_size"] sequences and the projection matmul would
        # otherwise fail to broadcast.
        num_seqs = X.shape[0] // config["seq_len"]
        task_idxs = np.random.randint(0, config["num_of_tasks"], size=num_seqs)
        curr_proj_matrices, curr_label_perms = proj_matrices[task_idxs].to(config["device"]), label_perms[task_idxs].to(config["device"])
        # Only the first `permuted_labels_frac` of the sequences keep their sampled
        # permutation; the remaining ones use the identity mapping (no permutation).
        num_permuted = int(config["permuted_labels_frac"] * num_seqs)
        curr_label_perms[num_permuted:] = torch.tensor(np.arange(curr_label_perms.shape[-1])).to(config["device"])
        X, y = get_context_seqs(X, y, proj_matrices=curr_proj_matrices, label_perms=curr_label_perms,
            seq_len=config["seq_len"], labels_shifted_by_one=config["whole_seq_prediction"]) # (batch, seq_len, dim)

        y_hat = model(X)
        if config["whole_seq_prediction"]:
            loss = F.cross_entropy(y_hat.view(-1, 10), y.view(-1))
        else:
            loss = F.cross_entropy(y_hat, y)
        loss.backward()
        model_optim.step()
        model_optim.zero_grad()

        # update the plot every 20 batches with the metrics of the current batch
        if i % 20 == 19:
            if config["whole_seq_prediction"]:
                acc_over_seq = (y_hat.argmax(dim=-1) == y).float().mean(dim=0) # (seq_len,)
                # best correctness at positions >= 1 minus correctness at position 0
                acc_max_improvement_within_seq = \
                    ((y_hat[:,1:,:].argmax(dim=-1) == y[:,1:]).float().max(dim=-1).values \
                    - (y_hat[:,0,:].argmax(dim=-1) == y[:,0]).float()).mean().item()
                live_plot.update({"train_acc_over_seq": acc_over_seq.tolist()}, reset=True)
                live_plot.update({"train_acc_max_improvement_within_seq": acc_max_improvement_within_seq})
            live_plot.update({"train_loss": loss.item(), "train_acc": (y_hat.argmax(dim=-1) == y).float().mean().item()})
            live_plot.draw()

    # end of epoch: evaluate on the test set without projections / permutations
    loss, acc, acc_over_seq, acc_max_improvement_within_seq = eval(model, test_loader, apply_proj=False)
    live_plot.update({"eval_loss": loss, "eval_acc": acc})
    if config["whole_seq_prediction"]:
        live_plot.update({"eval_acc_max_improvement_within_seq": acc_max_improvement_within_seq})
        live_plot.update({"eval_acc_over_seq": acc_over_seq}, reset=True)
    live_plot.draw()