# Representation Classifier

dataset: `imagenet-1k`

Models based on AvgPool2d + Linear, one classifier per representation (trainable parameter counts):
* conv_in: 321k params
* down_blocks[0]: 321k params
* down_blocks[1]: 641k params
* down_blocks[2]: 1281k params
* down_blocks[3]: 1281k params
* mid_block: 1281k params
* up_blocks[0]: 1281k params
* up_blocks[1]: 1281k params
* up_blocks[2]: 641k params
* up_blocks[3]: 321k params
* conv_out: 5k params

In [None]:
from datasets import load_from_disk
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from trainplot import plot, TrainPlot
from trainplot.trainplot import TrainPlotPlotlyExperimental
from collections import defaultdict
import numpy as np

In [None]:
# Read the saved dataset from disk in "torch" format, keep only the first
# 10,000 rows, and hold out 10% of them (fixed seed) as the test split.
dataset = (
    load_from_disk("...")
    .with_format("torch")
    .select(range(10_000))
    .train_test_split(test_size=0.1, seed=42)
)
# Same batch size for both splits; only the training split is shuffled.
train_loader, test_loader = (
    DataLoader(dataset[split], batch_size=64, shuffle=(split == 'train'))
    for split in ('train', 'test')
)

In [None]:
class SimpleClassifier(nn.Module):
    """Linear probe: average-pool a feature map, flatten it, and apply one linear layer.

    Args:
        channel_size: input width of the linear layer (number of channels of
            the feature map when the pool reduces it to 1x1).
        spatial_size: kernel size of the average pool.
        num_classes: number of output logits.
    """

    def __init__(self, channel_size=1280, spatial_size=8, num_classes=1000):
        super().__init__()
        self.fc = nn.Linear(in_features=channel_size, out_features=num_classes)
        self.pool = nn.AvgPool2d(kernel_size=spatial_size)

    def forward(self, x):
        """Return class logits for a batch of feature maps.

        Assumes ``x`` is (batch, channel_size, spatial_size, spatial_size) so
        that pooling leaves a single value per channel.
        """
        pooled = self.pool(x)
        features = pooled.flatten(1)
        return self.fc(features)


class SimpleCNNClassifier(nn.Module):
    """Per-position classifier: 1x1 convolution to class maps, then max-pool and flatten.

    Args:
        channel_size: number of input channels of the 1x1 convolution.
        spatial_size: kernel size of the max pool.
        num_classes: number of output channels of the convolution (logits).
    """

    def __init__(self, channel_size=1280, spatial_size=8, num_classes=1000):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channel_size,
            out_channels=num_classes,
            kernel_size=1,
            padding=0,
        )
        self.pool = nn.MaxPool2d(kernel_size=spatial_size)

    def forward(self, x):
        """Return class logits for a batch of feature maps.

        Assumes ``x`` is (batch, channel_size, spatial_size, spatial_size) so
        that pooling leaves a single value per class.
        """
        class_maps = self.conv(x)
        return self.pool(class_maps).flatten(start_dim=1)


# Build one SimpleClassifier per non-label entry of a training batch. The
# batch is drawn only to read each entry's shape: dim 1 sets the linear
# layer's input width, the last dim sets the pooling kernel size.
models = {
    key: SimpleClassifier(channel_size=feats.shape[1], spatial_size=feats.shape[-1]).cuda()
    for key, feats in next(iter(train_loader)).items()
    if key != 'label'
}

# Report how many trainable parameters each classifier has (in thousands).
print(f'Created {len(models)} models:')
for name, model in models.items():
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'  {name}: {n_trainable//1000}k params')

# A single shared loss, and one independent Adam optimizer per classifier.
criterion = nn.CrossEntropyLoss()
optimizers = {
    name: optim.Adam(model.parameters(), lr=0.001)
    for name, model in models.items()
}

In [None]:
# Two separate live plots: one updated during training, one after each test pass.
tp_train, tp_test = TrainPlotPlotlyExperimental(), TrainPlotPlotlyExperimental()

In [None]:
# Train every classifier in `models` side by side on the same batches for 20
# epochs, then evaluate all of them on the test split after each epoch.
#
# Produces:
#   accuracies_ema - per-model exponential moving average of train batch accuracy
#   test_accuracy  - per-model test accuracy of the most recent epoch
epoch_progress = tqdm(range(20), desc="Epoch")
step_progress = tqdm(total=len(train_loader), desc="Step")
model_progress = tqdm(models.keys(), desc="Model")
total_train_step = 0
accuracies_ema = defaultdict(float)
for epoch in epoch_progress:
    # ---- train ----
    for model in models.values():
        model.train()
    accuracies = defaultdict(list)
    # reset(total=...) is tqdm's API for reusing a bar with a new length;
    # assigning `.total` directly does not go through it.
    step_progress.reset(total=len(train_loader))
    step_progress.set_description("Train")
    for batch in train_loader:
        model_progress.reset()
        # Labels are shared by every model: move them to the GPU once per batch.
        y = batch["label"].cuda()
        for name in models:
            model_progress.set_postfix_str(name)
            optimizers[name].zero_grad()
            x = batch[name].cuda()
            y_pred = models[name](x)
            loss = criterion(y_pred, y)
            accuracy = (y_pred.argmax(dim=1) == y).float().mean().item()
            accuracies[name].append(accuracy)
            # Smooth the noisy per-batch accuracy before plotting it.
            accuracies_ema[name] = 0.8 * accuracies_ema[name] + 0.2 * accuracy
            loss.backward()
            optimizers[name].step()
            model_progress.update()
        tp_train(step=total_train_step, **accuracies_ema)
        total_train_step += 1
        step_progress.update()

    # ---- test ----
    for model in models.values():
        model.eval()
    # Count correct predictions and samples seen instead of averaging per-batch
    # means, so a smaller final batch is not over-weighted.
    test_correct = defaultdict(int)
    test_seen = 0
    step_progress.reset(total=len(test_loader))
    step_progress.set_description("Test")
    # No gradients are needed for evaluation; skip building autograd graphs.
    with torch.no_grad():
        for batch in test_loader:
            model_progress.reset()
            y = batch["label"].cuda()
            test_seen += y.size(0)
            for name in models:
                model_progress.set_postfix_str(name)
                x = batch[name].cuda()
                y_pred = models[name](x)
                test_correct[name] += (y_pred.argmax(dim=1) == y).sum().item()
                model_progress.update()
            step_progress.update()
    test_accuracy = {name: correct / test_seen for name, correct in test_correct.items()}
    tp_test(**test_accuracy)

In [None]:
# for name, model in models.items():
#     torch.save(model.state_dict(), f"../classifier-models/{name}-10k-20ep.pth")

In [None]:
import matplotlib.pyplot as plt

# Bar chart of the latest test accuracy, one bar per classifier.
fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(list(test_accuracy), list(test_accuracy.values()))
ax.set_title("Test accuracy after 20 epochs")
ax.set_xlabel("Model")
ax.set_ylabel("Accuracy")
# Tilt the model names on the x-axis by 45 degrees.
ax.tick_params(axis="x", labelrotation=45)
plt.show()