### Imports

In [1]:
import src.simclr as simclr
from src.data_loading import (
    get_data_loader,
    get_image_dataset,
    subset_classes,
    transforms_image_net,
)
from src.models import ResNet
from src.train import training_loop
from src.utils import accuracy, show, update_ewma

In [8]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


def train_simclr(
    data: DataLoader,
    model: nn.Module,
    lr: float,
    decay: float,
    n_epochs: int = 100,
    plot_every: int = 10,
    print_every: int = 10,
    write_every: int = 10,
):
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=decay)
    transform = transforms_image_net(is_image=False, crop=True, crop_size=28)
    writer = SummaryWriter("data/logs/simclr")
    running_acc = None
    running_loss = None
    for e_ix in range(n_epochs):
        for (step_ix, batch) in enumerate(data):
            x = batch[0]
            package, loss, logits_labels = simclr.step(x, model, transform)
            (t1, t2, s1, s2) = package
            logits, labels = logits_labels
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            acc = accuracy(logits, labels)
            running_acc = update_ewma(acc.item(), running_acc, 0.9)
            running_loss = update_ewma(loss.item(), running_loss, 0.9)

        if write_every and (e_ix + 1) % write_every == 0:
            writer.add_scalar("Accuracy", running_acc, e_ix)
            writer.add_scalar("Loss", running_loss, e_ix)
        if plot_every and ((e_ix == 1) or (e_ix + 1) % plot_every == 0):
            show(torch.cat([x[:5], t1[:5], t2[:5]]))
        if print_every and (e_ix + 1) % print_every == 0:
            print(f"loss: {loss.item()}")
            print(f"contrastive accuracy: {acc}")

### Get MNIST data

In [9]:
train_dataset = get_image_dataset("mnist", train=True,)
train_two_balanced = subset_classes(
    train_dataset, classes_retained=[0, 1], class_probas=[0.1, 0.1]
)
train_two_imbalanced = subset_classes(
    train_dataset, classes_retained=[0, 1], class_probas=[0.5, 0.05]
)

### Params

In [10]:
b_size = 128
n_epochs = 100
lr = 0.1

In [11]:
train_loader, _ = get_data_loader(
    train_two_imbalanced, val_share=0.0, batch_size=b_size, single_batch=False
)
eval_loader, _ = get_data_loader(
    train_two_balanced, val_share=0.0, batch_size=b_size, single_batch=False
)

Dataset lengths: train-3274, val-0
Dataset lengths: train-1288, val-0


In [12]:
encoder = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.ReLU(),
    # nn.LayerNorm((16, 14, 14)),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.ReLU(),
    # nn.LayerNorm((32, 7, 7)),
    nn.Conv2d(32, 64, kernel_size=1, padding=0),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
)
head = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64))
model = simclr.ContrastiveLearner(encoder=encoder, projection=head)

In [None]:
train_simclr(
    train_loader,
    model,
    n_epochs=n_epochs,
    plot_every=0,
    print_every=10,
    write_every=1,
    lr=lr,
    decay=1e-6,
)

loss: 4.828437805175781
contrastive accuracy: 0.013513513840734959
loss: 4.7747721672058105
contrastive accuracy: 0.013513513840734959
loss: 5.1745076179504395
contrastive accuracy: 0.006756756920367479


In [None]:
encoder = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.ReLU(),
    # nn.LayerNorm((16, 14, 14)),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.ReLU(),
    # nn.LayerNorm((32, 7, 7)),
    nn.Conv2d(32, 64, kernel_size=1, padding=0),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
)
training_loop(
    "test-mnist",
    model=encoder,
    opt=torch.optim.SGD(encoder.parameters(), lr=lr, weight_decay=1e-6),
    scheduler=None,
    train_loader=train_loader,
    eval_loader=eval_loader,
    loss_fn=nn.CrossEntropyLoss(),
    device="cpu",
    n_epochs=100,
    print_every=10,
    write_every=10,
    plot_every=0,
    check_every=0,
    metric_fn=accuracy,
)