In [18]:
%reload_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from adic_components.prototype2 import P2Encoder

# Define a simple classifier that maps embeddings to class scores
class Classifier(nn.Module):
    """Linear classification head applied to a flattened feature map.

    The head is sized for an input holding ``d_model`` values at each position
    of an ``(input_height // downsample_factor)`` by
    ``(input_width // downsample_factor)`` grid.

    Args:
        d_model: Number of feature values per spatial position.
        input_width: Width of the images the features were computed from.
        input_height: Height of the images the features were computed from.
        num_classes: Number of class scores to produce.
        downsample_factor: Spatial reduction between image and feature map.
            Defaults to 16, the value that used to be hard-coded here.
            NOTE(review): this must match the encoder feeding the head.

    Raises:
        ValueError: If ``downsample_factor`` is not positive, or if either
            image side is smaller than ``downsample_factor`` (which would
            leave the head with zero input features).
    """

    def __init__(self, d_model: int, input_width: int, input_height: int,
                 num_classes: int, downsample_factor: int = 16):
        super().__init__()
        if downsample_factor <= 0:
            raise ValueError(
                f"downsample_factor must be positive, got {downsample_factor}"
            )
        self.h = input_height // downsample_factor
        self.w = input_width // downsample_factor
        if self.h == 0 or self.w == 0:
            raise ValueError(
                f"input size {input_width}x{input_height} is smaller than "
                f"downsample_factor={downsample_factor}"
            )
        self.fchead = nn.Linear(d_model * self.h * self.w, num_classes)

    def forward(self, x):
        """Flatten each sample of ``x`` and return ``(batch, num_classes)`` scores."""
        # reshape() copies when x is non-contiguous (view() would raise), and
        # naming the feature count explicitly lets an empty batch through,
        # where "-1" cannot be inferred from zero elements.
        flat = x.reshape(x.size(0), self.fchead.in_features)
        return self.fchead(flat)

# Parameters shared by the encoder and the classifier head
input_channels = 3                  # colour channels per image
input_width = input_height = 32     # image size in pixels
d_model = 128                       # feature width passed to both models
num_classes = 10                    # number of class scores produced


In [19]:
batch_size = 256
epochs = 10

# Seed PyTorch's RNG before any shuffling or weight initialisation so that
# re-running this cell rebuilds the same starting point. GPU kernels may
# still introduce small run-to-run differences.
SEED = 42
torch.manual_seed(SEED)

# Data transforms for CIFAR-10: convert to a tensor, then shift/scale each
# channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Download CIFAR-10 dataset (stored under ./data)
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                            download=True, transform=transform)

# Only the training split is shuffled; the test split keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Create model instances
encoder = P2Encoder(input_channels, input_width, input_height, d_model)
classifier = Classifier(d_model, input_width, input_height, num_classes)

# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(device)
# Left as the last expression on purpose: the cell displays the classifier's repr.
classifier.to(device)

Classifier(
  (fchead): Linear(in_features=512, out_features=10, bias=True)
)

In [20]:
# Show which torch device ("cuda" or "cpu") was selected for the models.
print(device)

cuda


In [25]:
# Loss and optimizer
import tqdm

criterion = nn.CrossEntropyLoss()
# A single Adam instance updates the encoder and the classifier together.
trainable_params = [*encoder.parameters(), *classifier.parameters()]
optimizer = optim.Adam(trainable_params, lr=0.001)

# NOTE(review): the models are built in another cell, so running this cell
# again continues from the current weights with a fresh optimizer rather
# than starting from scratch.

# Training loop
encoder.train()
classifier.train()
num_train_samples = len(train_loader.dataset)
for epoch in range(epochs):
    running_loss = 0.0
    for images, labels in tqdm.tqdm(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = classifier(encoder(images))
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight each batch loss by its size so the epoch figure is a
        # per-sample average even when the last batch is smaller.
        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / num_train_samples
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}')

100%|██████████| 196/196 [01:24<00:00,  2.32it/s]


Epoch [1/10], Loss: 0.0570


100%|██████████| 196/196 [01:25<00:00,  2.30it/s]


Epoch [2/10], Loss: 0.0491


100%|██████████| 196/196 [01:25<00:00,  2.30it/s]


Epoch [3/10], Loss: 0.0437


100%|██████████| 196/196 [01:24<00:00,  2.31it/s]


Epoch [4/10], Loss: 0.0460


100%|██████████| 196/196 [01:25<00:00,  2.30it/s]


Epoch [5/10], Loss: 0.0519


100%|██████████| 196/196 [01:22<00:00,  2.37it/s]


Epoch [6/10], Loss: 0.0430


100%|██████████| 196/196 [01:22<00:00,  2.39it/s]


Epoch [7/10], Loss: 0.0312


100%|██████████| 196/196 [01:21<00:00,  2.39it/s]


Epoch [8/10], Loss: 0.0314


100%|██████████| 196/196 [01:21<00:00,  2.40it/s]


Epoch [9/10], Loss: 0.0304


100%|██████████| 196/196 [01:21<00:00,  2.39it/s]

Epoch [10/10], Loss: 0.0461





In [26]:
# Evaluation on the test set
# Put both modules in evaluation mode before scoring.
encoder.eval()
classifier.eval()
correct = 0
total = 0
with torch.no_grad():  # no gradients are needed for scoring
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        embeddings = encoder(images)
        outputs = classifier(embeddings)
        # argmax rather than `_, predicted = torch.max(...)`: assigning to `_`
        # in a notebook shadows IPython's last-output variable.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Test Accuracy: {100 * correct / total:.2f}%')

Test Accuracy: 73.75%
