# Representation Classifier

dataset: `imagenet-1k`

A single classifier head, based on either AvgPool2d + Linear or a 1x1 Conv2d followed by max pooling.

In [None]:
from datasets import load_from_disk
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from trainplot import plot, TrainPlot
from trainplot.trainplot import TrainPlotPlotlyExperimental

In [None]:
# Read the stored dataset from disk and switch it to the "torch" format.
dataset = load_from_disk("...")
dataset = dataset.with_format("torch")

# Batches of 32 for both splits; only the training split is shuffled.
train_loader = DataLoader(dataset=dataset['train'], shuffle=True, batch_size=32)
test_loader = DataLoader(dataset=dataset['test'], shuffle=False, batch_size=32)

In [None]:
class SimpleClassifier(nn.Module):
    """Classifier head: 2D average pooling followed by a single linear layer.

    Args:
        channel_size: Number of input features of the linear layer (the
            channel count of the incoming feature map).
        spatial_size: Kernel size of the average-pooling window.
        num_classes: Number of output features of the linear layer.
    """

    def __init__(self, channel_size=1280, spatial_size=8, num_classes=1000):
        super().__init__()
        self.fc = nn.Linear(in_features=channel_size, out_features=num_classes)
        self.pool = nn.AvgPool2d(kernel_size=spatial_size)

    def forward(self, x):
        """Pool ``x``, flatten everything after the batch dimension, apply the linear layer.

        The linear layer accepts exactly ``channel_size`` features, so the
        pooled map has to flatten to that width (i.e. be 1x1 spatially).
        Assumes ``x`` is laid out as (batch, channel_size, H, W) -- TODO confirm
        against the dataset's "mid_block" tensors.

        Returns:
            Tensor of shape (batch, num_classes).
        """
        pooled = self.pool(x)
        features = pooled.flatten(start_dim=1)
        return self.fc(features)


class SimpleCNNClassifier(nn.Module):
    """Classifier head: a 1x1 convolution followed by 2D max pooling.

    The convolution maps ``channel_size`` input channels to ``num_classes``
    output channels at every spatial position; max pooling then reduces each
    of those channel maps spatially.

    Args:
        channel_size: Number of input channels of the convolution.
        spatial_size: Kernel size of the max-pooling window.
        num_classes: Number of output channels of the convolution.
    """

    def __init__(self, channel_size=1280, spatial_size=8, num_classes=1000):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channel_size,
            out_channels=num_classes,
            kernel_size=1,
            padding=0,
        )
        self.pool = nn.MaxPool2d(kernel_size=spatial_size)

    def forward(self, x):
        """Apply the 1x1 convolution, max-pool, and flatten all non-batch dimensions.

        The result has ``num_classes`` columns only when pooling leaves a 1x1
        map; assumes the spatial extent of ``x`` matches ``spatial_size`` --
        TODO confirm against the dataset's "mid_block" tensors.
        """
        class_maps = self.conv(x)
        pooled = self.pool(class_maps)
        return pooled.flatten(start_dim=1)


# Build the 1x1-conv classifier with its default sizes and move it to the GPU.
model = SimpleCNNClassifier()
model = model.cuda()

# Report how many parameters will actually receive gradients.
print('Parameter count:', sum(param.numel() for param in model.parameters() if param.requires_grad))

# Cross-entropy loss on the model outputs, optimized with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Plot object from trainplot; the training loop calls it once per epoch with
# the train and test accuracy.
tp = TrainPlotPlotlyExperimental()

In [None]:
# Train for 20 epochs. After each epoch, report the training accuracy (measured
# on the batches as they were trained on) and the accuracy on the test split.
#
# Accuracy is computed as (correct predictions) / (samples seen) rather than as
# the mean of per-batch accuracies, so a smaller final batch is not
# over-weighted.
for epoch in tqdm(range(20)):
    # ---- training pass ----
    model.train()
    train_correct = 0  # becomes a 0-dim tensor after the first batch
    train_seen = 0
    for batch in train_loader:
        optimizer.zero_grad()
        x = batch["mid_block"].cuda()
        y = batch["label"].cuda()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        # Kept as a tensor here; converted to a Python number once per epoch.
        train_correct += (y_pred.argmax(dim=1) == y).sum()
        train_seen += y.size(0)
        loss.backward()
        optimizer.step()
    # max(..., 1) avoids a ZeroDivisionError if the loader yields no batches.
    accuracy = float(train_correct) / max(train_seen, 1)

    # ---- evaluation pass ----
    model.eval()
    test_correct = 0
    test_seen = 0
    # No gradients are needed for evaluation, so skip autograd bookkeeping.
    with torch.no_grad():
        for batch in test_loader:
            x = batch["mid_block"].cuda()
            y = batch["label"].cuda()
            y_pred = model(x)
            test_correct += (y_pred.argmax(dim=1) == y).sum()
            test_seen += y.size(0)
    test_accuracy = float(test_correct) / max(test_seen, 1)

    tp(accuracy=accuracy, test_accuracy=test_accuracy)
    print(f"Epoch {epoch}: accuracy {accuracy:.2%}, test accuracy {test_accuracy:.2%}")