In [None]:
!pip install einops fancy_einsum torchvision



In [None]:
from __future__ import annotations

from dataclasses import dataclass
from collections import Counter
from functools import cache
import itertools as it
from operator import eq
import pathlib
import random

import einops
import pandas as pd
from plotly import express as px, graph_objects as go
import torch as t
import torch.nn.functional as F
from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau
from torchvision.datasets import MNIST
from tqdm import tqdm

## Model

In [None]:
class ContrastiveCNN(t.nn.Module):
    """Two conv/pool stages followed by a linear projection to an embedding.

    Each stage is a 5x5 "same"-padded convolution, a ReLU and a 2x2 max-pool
    with stride 2. The flattened result is mapped to `d_embed` features by
    `fc`, whose input size (7 * 7 * 64) corresponds to 1x28x28 inputs.
    """

    def __init__(self, d_embed: int = 64) -> None:
        super().__init__()
        self.d_embed = d_embed

        # Settings shared by both stages.
        conv_kwargs = dict(kernel_size=5, stride=1, padding="same")
        pool_kwargs = dict(kernel_size=2, stride=2)

        self.conv1 = t.nn.Conv2d(1, 32, **conv_kwargs)
        self.pool1 = t.nn.MaxPool2d(**pool_kwargs)
        self.conv2 = t.nn.Conv2d(32, 64, **conv_kwargs)
        self.pool2 = t.nn.MaxPool2d(**pool_kwargs)
        self.fc = t.nn.Linear(7 * 7 * 64, d_embed)

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Embed images `x` of shape [batch 1 28 28] into [batch d_embed]."""
        stages = ((self.conv1, self.pool1), (self.conv2, self.pool2))
        for conv, pool in stages:
            x = pool(F.relu(conv(x)))
        return self.fc(x.flatten(1))

## Dataset

In [None]:
# Descriptive aliases for the four tensors that make up one training batch.
BatchX = t.Tensor
BatchY = t.Tensor
DiffPairInds = t.Tensor
SamePairInds = t.Tensor
TrainBatch = tuple[BatchX, BatchY, DiffPairInds, SamePairInds]

@dataclass(frozen=True, slots=True)
class Dataset:
    """Preprocessed MNIST train/test images together with their labels."""

    # Images are the output of `_preprocess_batch`; labels are the MNIST targets.
    train_x: t.Tensor
    train_y: t.Tensor
    test_x: t.Tensor
    test_y: t.Tensor

    @classmethod
    def load(cls) -> Dataset:
        """Download MNIST into ./data (if not already there) and preprocess it.

        Each split is passed through `_preprocess_batch` separately, so each
        is normalised with its own statistics.
        """
        DATA_PATH = pathlib.Path("data")
        # exist_ok avoids the check-then-create race of `if not exists: mkdir()`.
        DATA_PATH.mkdir(parents=True, exist_ok=True)
        train = MNIST(str(DATA_PATH), train=True, download=True)
        test = MNIST(str(DATA_PATH), train=False, download=True)
        return cls(
            train_x=_preprocess_batch(train.data),
            train_y=train.targets,
            test_x=_preprocess_batch(test.data),
            test_y=test.targets,
        )

    @cache
    def get_train_batches(self, batch_size: int = 32) -> list[TrainBatch]:
        """`TrainBatch = tuple[BatchX, BatchY, DiffPairInds, SamePairInds]`

        Splits the training set, in stored order, into consecutive batches of
        exactly `batch_size` images; a trailing partial batch is dropped. Each
        batch carries the pair indices from `get_diff_and_sim_pair_inds`.

        The result is memoised per `(self, batch_size)`. `functools.cache`
        keeps a strong reference to `self` and to every returned list for the
        life of the process, and all callers share the same list object.

        Raises:
            ValueError: if `batch_size` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        n_batches = len(self.train_x) // batch_size
        batches: list[TrainBatch] = []
        for batch_i in range(n_batches):
            batch_slice = slice(batch_i * batch_size, (batch_i + 1) * batch_size)
            batch_y = self.train_y[batch_slice]
            batches.append(
                (
                    self.train_x[batch_slice],
                    batch_y,
                    *get_diff_and_sim_pair_inds(batch_y),
                )
            )
        return batches



def get_diff_and_sim_pair_inds(y: t.Tensor) -> tuple[t.Tensor, t.Tensor]:
    """Index pairs of differently- and identically-labelled batch elements.

    Args:
        y: 1-D int64 tensor of labels, one per batch element.

    Returns:
        `(diff_pair_inds, same_pair_inds)`, each of shape [2, n]. Every column
        holds a pair `(i, j)` of positions in `y` with `y[i] != y[j]` (first
        tensor) or `y[i] == y[j]` (second tensor). Pairs are ordered, so both
        `(i, j)` and `(j, i)` appear, and the same-label tensor also contains
        every `(i, i)`.
    """
    assert y.ndim == 1
    assert y.dtype == t.int64
    # Broadcasting compares every y[i] with every y[j] without first
    # materialising a [b, b] copy of the labels.
    is_diff = y.unsqueeze(1) != y.unsqueeze(0)  # [b b]
    diff_pair_inds = is_diff.nonzero().T  # [2 _n]
    # "Same" is exactly the complement of "different"; no second comparison needed.
    same_pair_inds = (~is_diff).nonzero().T  # [2 _n]
    assert diff_pair_inds.ndim == 2 and len(diff_pair_inds) == 2
    assert same_pair_inds.ndim == 2 and len(same_pair_inds) == 2
    return diff_pair_inds, same_pair_inds


def _preprocess_batch(batch: t.Tensor) -> t.Tensor:
    assert batch.ndim == 3
    assert eq(*batch.shape[1:])
    batch_dim, im_dim = batch.shape[:2]
    processed_batch = (
        batch.to(dtype=t.float32).unsqueeze(-1).reshape(batch_dim, 1, im_dim, im_dim)
    )
    return (processed_batch - processed_batch.mean()) / processed_batch.std()

# Download MNIST on first use and build the preprocessed train/test tensors.
ds = Dataset.load()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 156677223.10it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 113849336.30it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 41698168.27it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 7611078.21it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [None]:
def contrastive_loss(
    embeds: t.Tensor,  # [batch d_embed]
    diff_pair_inds: t.Tensor, # [2 _n]
    same_pair_inds: t.Tensor, # [2 _n]
    *,
    epsilon: float = .5,
) -> t.Tensor:
    """Contrastive loss over index pairs of one batch of embeddings.

    Args:
        embeds: embeddings, one row per batch element.
        diff_pair_inds: [2, n] row indices of pairs that should be far apart.
        same_pair_inds: [2, n] row indices of pairs that should be close.
        epsilon: margin for the different-pair term.

    Returns:
        Scalar tensor `max(0, epsilon - rms_diff)**2 + mean(same_diffs**2)`,
        where `rms_diff` is the root-mean-square of the embedding differences
        taken jointly over all different pairs and all embedding dimensions.
        The margin is therefore applied once to that aggregate, not per pair.
        A pair set with no entries contributes 0 instead of NaN.
    """

    def _pair_diffs(pair_inds: t.Tensor) -> t.Tensor:
        # Row-wise difference between the two members of every pair.
        return embeds.index_select(0, pair_inds[0]) - embeds.index_select(0, pair_inds[1])

    diff_pair_diffs = _pair_diffs(diff_pair_inds)
    same_pair_diffs = _pair_diffs(same_pair_inds)

    # `.mean()` of an empty tensor is NaN, which would poison the whole loss
    # (e.g. a batch in which every label is identical has no different pairs).
    zero = embeds.new_zeros(())
    if diff_pair_diffs.numel() == 0:
        diff_term_loss = zero
    else:
        rms_diff = diff_pair_diffs.pow(2).mean().sqrt()
        diff_term_loss = (epsilon - rms_diff).clamp(min=0).pow(2)
    if same_pair_diffs.numel() == 0:
        same_term_loss = zero
    else:
        same_term_loss = same_pair_diffs.pow(2).mean()
    return diff_term_loss + same_term_loss

In [None]:
def train(
    model: ContrastiveCNN,
    ds: Dataset,
    optimizer: t.optim.Optimizer,
    *,
    n_epochs: int = 10,
    batch_size: int = 128,
    verbose: bool = True,
    scheduler: LRScheduler | ReduceLROnPlateau | None = None,
) -> None:  # TrainingHistory?:
    """Train `model` in place on `ds` with `contrastive_loss`.

    Args:
        model: network whose parameters `optimizer` updates.
        ds: dataset providing `get_train_batches(batch_size)`.
        optimizer: optimizer stepped once per batch.
        n_epochs: maximum number of passes over the batches.
        batch_size: forwarded to `ds.get_train_batches`.
        verbose: also print the per-epoch mean loss, NaN parameter counts and
            the early-termination message.
        scheduler: optional LR scheduler, stepped at every logging point
            (roughly ten times per epoch). A `ReduceLROnPlateau` receives the
            running mean loss as its metric.

    Training stops early once an epoch's mean loss is exactly zero.
    """
    batches = ds.get_train_batches(batch_size)
    # At least 1, so the modulo below is defined with fewer than 10 batches.
    batch_log_and_sched_freq = max(1, len(batches) // 10)

    for epoch_i in range(1, n_epochs + 1):
        print(f"Epoch {epoch_i} / {n_epochs}")
        batch_losses: list[float] = []

        for batch_i, (batch_x, _batch_y, diff_pair_inds, same_pair_inds) in enumerate(batches):
            optimizer.zero_grad()
            batch_embeds = model(batch_x)
            batch_loss = contrastive_loss(batch_embeds, diff_pair_inds, same_pair_inds)
            batch_loss.backward()
            optimizer.step()
            batch_losses.append(batch_loss.item())
            if batch_i % batch_log_and_sched_freq == 0:
                # Mean over the batches seen since (about) the last logging point.
                running_loss = t.tensor(batch_losses[-batch_log_and_sched_freq:]).mean().item()
                print(f"\t[Batch {batch_i}]: {running_loss=}")
                # ReduceLROnPlateau is tested first: in newer PyTorch releases
                # it is itself an LRScheduler subclass, so the generic branch
                # would call step() without the metric it needs.
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(running_loss)
                elif isinstance(scheduler, LRScheduler):
                    scheduler.step()

        epoch_loss = t.tensor(batch_losses).mean().item()
        if verbose:
            print(f"\tEpoch {epoch_i} mean loss = {epoch_loss}\n\n---\n")
            for name, p in model.named_parameters():
                if n_nans := p.isnan().sum():
                    print(f"{name}: {n_nans} NaNs")
        # Early stopping must not depend on whether progress is being printed.
        if epoch_loss == 0:
            if verbose:
                print(
                    f"Achieved zero loss, terminating training after {epoch_i} epochs"
                )
            break

In [None]:
SEED = 42
batch_size = 512
n_epochs = 4
lr = 1e-4

# Seed before building the model so that weight initialisation, and with it
# the whole training run, is reproducible after Restart & Run All.
t.manual_seed(SEED)

model = ContrastiveCNN()
# `ds` is the dataset loaded in the Dataset section; loading it a second time
# would only repeat the disk read and preprocessing.
optimizer = t.optim.AdamW(model.parameters(), lr)
scheduler = ReduceLROnPlateau(optimizer, patience=4)

train(
    model,
    ds,
    optimizer,
    scheduler=scheduler,
    n_epochs=n_epochs,
    batch_size=batch_size,
)


Epoch 1 / 4
	[Batch 0]: running_loss=0.15125998854637146
	[Batch 11]: running_loss=0.09931919723749161
	[Batch 22]: running_loss=0.06754554808139801
	[Batch 33]: running_loss=0.056445930153131485
	[Batch 44]: running_loss=0.0471959225833416
	[Batch 55]: running_loss=0.04287444055080414
	[Batch 66]: running_loss=0.039449773728847504
	[Batch 77]: running_loss=0.03676426783204079
	[Batch 88]: running_loss=0.03443353250622749
	[Batch 99]: running_loss=0.03285261243581772
	[Batch 110]: running_loss=0.027982724830508232
	Epoch 1 mean loss = 0.04815696179866791

---

Epoch 2 / 4
	[Batch 0]: running_loss=0.0273450817912817
	[Batch 11]: running_loss=0.02706645056605339
	[Batch 22]: running_loss=0.026015568524599075
	[Batch 33]: running_loss=0.027934378013014793
	[Batch 44]: running_loss=0.025032956153154373
	[Batch 55]: running_loss=0.023871885612607002
	[Batch 66]: running_loss=0.022820716723799706
	[Batch 77]: running_loss=0.022454949095845222
	[Batch 88]: running_loss=0.021436573937535286
	[

In [None]:
# Report, per parameter tensor, how many entries are NaN after training.
for n, p in model.named_parameters():
    n_nans = p.isnan().sum()
    print(n, n_nans if n_nans else "no nans")

conv1.weight no nans
conv1.bias no nans
conv2.weight no nans
conv2.bias no nans
fc.weight no nans
fc.bias no nans


In [None]:
SEED = 42
# Number of training images sampled per digit label by get_k_label_embeds.
k = 4000
# get_k_label_embeds samples with Python's `random`, so seed it; torch is
# seeded alongside it.
random.seed(SEED)
t.manual_seed(SEED)

def get_k_label_embeds(model: ContrastiveCNN, ds: Dataset, k: int) -> t.Tensor:
    """Estimate one reference embedding per digit from `k` sampled images each.

    For every label 0-9, `k` training images with that label are drawn without
    replacement using Python's `random` module. All sampled images are then
    embedded in a single forward pass under `no_grad`, and each label's
    embeddings are averaged.

    Returns:
        Tensor with one row per label (row index == label) holding that
        label's mean embedding.
    """
    sampled_inds: list[int] = []
    for label in range(10):
        label_inds = (ds.train_y == label).nonzero().flatten().tolist()
        sampled_inds.extend(random.sample(label_inds, k=k))
    repr_x = ds.train_x.index_select(0, t.tensor(sampled_inds))
    with t.no_grad():
        repr_embeds = model(repr_x)
    # Rows are grouped by label in blocks of k, so each chunk is one label.
    return t.stack([chunk.mean(0) for chunk in repr_embeds.split(k)])

def get_all_label_embeds(model: ContrastiveCNN, ds: Dataset) -> t.Tensor:
    """Mean embedding of all training images of each digit label.

    Returns:
        A [10, model.d_embed] tensor whose row `label` is the mean of `model`'s
        embeddings over every `ds.train_x` image with that label. The result
        carries no autograd history.
    """
    label_embeds = t.zeros(10, model.d_embed, requires_grad=False)
    # Inference only: without no_grad the in-place row assignments would attach
    # the autograd graph to `label_embeds` and keep the activations of every
    # forward pass alive.
    with t.no_grad():
        for label in tqdm(range(10)):
            repr_embeds = model(ds.train_x[ds.train_y == label])
            label_embeds[label] = repr_embeds.mean(0)
    return label_embeds

In [None]:
# Per-label reference embeddings, each averaged over k sampled training images.
label_embeds = get_k_label_embeds(model, ds, k)

In [None]:
# Attach the reference embeddings to the model; predict() reads model.label_embeds.
model.label_embeds = label_embeds

In [None]:
def predict(model: ContrastiveCNN, x: t.Tensor) -> t.Tensor:
    """Classify `x` by nearest per-label reference embedding.

    Args:
        model: callable producing [batch d_embed] embeddings; must carry a
            tensor attribute `label_embeds` of shape [n_labels d_embed].
        x: batch of inputs accepted by `model`.

    Returns:
        [batch] tensor with, for each input, the index of the row of
        `model.label_embeds` with the smallest mean squared difference to
        the input's embedding. An empty batch yields an empty result.
    """
    assert hasattr(model, "label_embeds")
    assert isinstance(model.label_embeds, t.Tensor)
    with t.no_grad():
        x_embeds = model(x)  # [batch d_embed]
        # Broadcast [batch 1 d_embed] against [1 n_labels d_embed] so that all
        # distances come from one tensor op instead of a Python loop per sample.
        embed_label_diffs = (
            (x_embeds.unsqueeze(1) - model.label_embeds.unsqueeze(0)).pow(2).mean(-1)
        )  # [batch n_labels]
    return embed_label_diffs.argmin(1)

# Spot check: predictions (first element) next to the true labels (second
# element) for the first 10 test images.
predict(model, ds.test_x[:10]),ds.test_y[:10]

(tensor([7, 8, 1, 0, 7, 1, 9, 9, 5, 4]),
 tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9]))

In [None]:
# Every training image whose label is 6.
x_6 = ds.train_x[ds.train_y == 6]
# Sanity check: the mask selects exactly as many images as there are 6 labels.
assert len(x_6) == Counter(ds.train_y.tolist())[6]
predict(model, x_6)

tensor([6, 6, 6,  ..., 0, 0, 0])

In [None]:
# One dict per true label: predicted label -> fraction of that true label's
# training images given that prediction (keys in ascending order).
train_pred_counts: list[dict[int, float]] = []

for label in tqdm(range(10)):
    x = ds.train_x[ds.train_y == label]
    preds = predict(model, x)
    pred_freqs = Counter(preds.tolist())
    train_pred_counts.append(
        {pred: pred_freqs[pred] / len(x) for pred in sorted(pred_freqs)}
    )

100%|██████████| 10/10 [00:56<00:00,  5.66s/it]


In [None]:
# Rows are true labels, columns predicted labels. reindex (rather than column
# selection) keeps a digit that was never predicted as an all-zero column
# instead of raising KeyError.
train_pred_df = (
    pd.DataFrame(train_pred_counts).reindex(columns=list(range(10))).fillna(0)
)
px.imshow(train_pred_df)

In [None]:
# The diagonal holds, per label, the fraction of that label's images predicted
# correctly; this is their unweighted mean over the 10 labels (labels are not
# weighted by how many images they have).
train_pred_acc = train_pred_df.values.diagonal().mean()
print(f"{train_pred_acc = :.2%}")

train_pred_acc = 62.60%


In [None]:
# One dict per true label: predicted label -> fraction of that true label's
# test images given that prediction (keys in ascending order).
test_pred_counts: list[dict[int, float]] = []

for label in tqdm(range(10)):
    x = ds.test_x[ds.test_y == label]
    preds = predict(model, x)
    pred_freqs = Counter(preds.tolist())
    test_pred_counts.append(
        {pred: pred_freqs[pred] / len(x) for pred in sorted(pred_freqs)}
    )

100%|██████████| 10/10 [00:13<00:00,  1.31s/it]


In [None]:
# Rows are true labels, columns predicted labels. reindex (rather than column
# selection) keeps a digit that was never predicted as an all-zero column
# instead of raising KeyError.
test_pred_df = (
    pd.DataFrame(test_pred_counts).reindex(columns=list(range(10))).fillna(0)
)
px.imshow(test_pred_df)

In [None]:
# The diagonal holds, per label, the fraction of that label's images predicted
# correctly; this is their unweighted mean over the 10 labels (labels are not
# weighted by how many images they have).
test_pred_acc = test_pred_df.values.diagonal().mean()
print(f"{test_pred_acc = :.2%}")

test_pred_acc = 63.36%
