In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from adic_components.prototype2 import P2Encoder

# Define a simple classifier that maps embeddings to class scores
class Classifier(nn.Module):
    """Linear classification head.

    Projects an embedding of size ``d_model`` onto ``num_classes`` raw
    scores using a single ``nn.Linear`` layer, stored as ``self.fc``.

    Args:
        d_model: Size of the last dimension of the incoming embedding.
        num_classes: Number of scores produced per embedding.
    """

    def __init__(self, d_model: int, num_classes: int):
        super().__init__()
        self.fc = nn.Linear(in_features=d_model, out_features=num_classes)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the resulting scores."""
        scores = self.fc(x)
        return scores

# Parameters
# Input geometry forwarded to P2Encoder: channel count, width, height.
input_channels, input_width, input_height = 3, 32, 32
# Embedding size shared by the encoder output and the classifier input.
d_model = 128
# Number of scores the classifier emits per sample.
num_classes = 10


100%|██████████| 170M/170M [00:17<00:00, 9.77MB/s] 


Classifier(
  (fc): Linear(in_features=128, out_features=10, bias=True)
)

In [10]:
# Training hyper-parameters
batch_size = 256
epochs = 10

# Preprocessing for CIFAR-10: convert each image to a tensor, then normalise
# every one of the three channels with mean 0.5 and std 0.5.
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# Train and test splits share the same root directory, download flag and transform.
cifar_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = torchvision.datasets.CIFAR10(train=True, **cifar_kwargs)
test_dataset = torchvision.datasets.CIFAR10(train=False, **cifar_kwargs)

# Only the training loader shuffles; both use the same batch size and num_workers=2.
loader_kwargs = dict(batch_size=batch_size, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the encoder first and the classification head second.
encoder = P2Encoder(input_channels, input_width, input_height, d_model)
classifier = Classifier(d_model, num_classes)

# Move both modules to the selected device; the trailing expression is what
# the cell displays as its output.
encoder.to(device)
classifier.to(device)

Classifier(
  (fc): Linear(in_features=128, out_features=10, bias=True)
)

In [11]:
# Show which torch.device was selected ("cuda" when available, otherwise "cpu").
print(device)

cuda


In [None]:
# Loss and optimizer
import tqdm

# Cross-entropy on the classifier outputs; a single Adam instance updates the
# parameters of both the encoder and the classifier.
criterion = nn.CrossEntropyLoss()
trainable_params = [*encoder.parameters(), *classifier.parameters()]
optimizer = optim.Adam(trainable_params, lr=0.001)

# Training loop
# Put both modules into training mode before iterating.
for module in (encoder, classifier):
    module.train()

# Denominator for the per-epoch average loss.
num_train_samples = len(train_loader.dataset)

for epoch in range(epochs):
    running_loss = 0.0
    for images, labels in tqdm.tqdm(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass: encoder -> embeddings -> classifier scores.
        optimizer.zero_grad()
        embeddings = encoder(images)
        outputs = classifier(embeddings)

        # Backward pass and parameter update.
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so that dividing by the
        # dataset length below yields a per-sample average.
        running_loss += images.size(0) * loss.item()

    epoch_loss = running_loss / num_train_samples
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}')

100%|██████████| 196/196 [01:13<00:00,  2.68it/s]


Epoch [1/10], Loss: 1.3925


100%|██████████| 196/196 [01:14<00:00,  2.63it/s]


Epoch [2/10], Loss: 0.9097


100%|██████████| 196/196 [01:14<00:00,  2.62it/s]


Epoch [3/10], Loss: 0.7255


100%|██████████| 196/196 [01:12<00:00,  2.70it/s]


Epoch [4/10], Loss: 0.5898


100%|██████████| 196/196 [01:11<00:00,  2.74it/s]


Epoch [5/10], Loss: 0.4787


100%|██████████| 196/196 [01:11<00:00,  2.75it/s]


Epoch [6/10], Loss: 0.3678


100%|██████████| 196/196 [01:11<00:00,  2.74it/s]


Epoch [7/10], Loss: 0.2727


100%|██████████| 196/196 [01:12<00:00,  2.71it/s]


Epoch [8/10], Loss: 0.1976


100%|██████████| 196/196 [01:14<00:00,  2.64it/s]


Epoch [9/10], Loss: 0.1227


100%|██████████| 196/196 [01:13<00:00,  2.67it/s]

Epoch [10/10], Loss: 0.1342





: 

In [None]:
# Evaluation on the test set
# Switch both modules to evaluation mode before measuring accuracy.
for module in (encoder, classifier):
    module.eval()

correct = 0
total = 0
# Gradients are not needed while scoring the test batches.
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        embeddings = encoder(images)
        outputs = classifier(embeddings)

        # The index of the largest score along dim 1 is the predicted class.
        predicted = torch.max(outputs, 1).indices

        total += labels.shape[0]
        correct += predicted.eq(labels).sum().item()

print(f'Test Accuracy: {100 * correct / total:.2f}%')

Test Accuracy: 76.81%
