In [2]:
import warnings

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets, models

# Step 1: Load the pretrained model.
# `weights=` replaces the `pretrained=True` flag deprecated in torchvision 0.13;
# IMAGENET1K_V1 is the checkpoint that `pretrained=True` loaded.
# NOTE(review): this checkpoint's classifier head scores the ImageNet classes,
# not the 10 CIFAR-10 labels, so its argmax predictions cannot be compared to
# CIFAR-10 labels directly. TODO: replace model.classifier[-1] with a 10-way
# layer and fine-tune on train_loader before trusting the evaluation step.
model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
model.eval()  # evaluation mode (changes the behavior of layers such as dropout)

# Step 2: Prepare the CIFAR10 dataset.
DATA_ROOT = "./data"
BATCH_SIZE = 64
# Per-channel normalization statistics of the ImageNet training set, which the
# pretrained weights were fitted with.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Upscale, center-crop to the 224x224 input size AlexNet expects, then normalize.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Load CIFAR10 (downloaded into DATA_ROOT on first run).
train_dataset = datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(
    root=DATA_ROOT, train=False, download=True, transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# No shuffling for evaluation so the pass over the test set is deterministic.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


# Step 3: Define the evaluation metric
def accuracy(outputs, labels):
    """Return the fraction of samples whose top-scoring class equals the label.

    Args:
        outputs: Score tensor whose dim 1 indexes classes, e.g. logits of shape
            (batch, num_classes). The prediction is the argmax over dim 1
            (the first maximal index on ties).
        labels: Integer class indices with the same shape as the predictions,
            i.e. (batch,).

    Returns:
        A 0-dim float32 tensor in [0, 1], on the same device as the inputs.

    Raises:
        ValueError: if the batch is empty, or if ``labels`` does not have the
            predictions' shape (e.g. (batch, 1)), which would otherwise
            broadcast ``preds == labels`` to (batch, batch) and miscount.
    """
    if labels.numel() == 0:
        raise ValueError("accuracy() needs a non-empty batch")
    preds = outputs.argmax(dim=1)
    if preds.shape != labels.shape:
        raise ValueError(
            f"labels shape {tuple(labels.shape)} does not match "
            f"predictions shape {tuple(preds.shape)}"
        )
    # Mean of the boolean hit mask, computed on-device with no .item() round-trip.
    return (preds == labels).float().mean()


# Step 4: Evaluate the model on the test set.
# total_accuracy accumulates accuracy * batch size (i.e. the number of correct
# predictions), so dividing once at the end weights a short final batch correctly.
num_dataset_classes = len(test_loader.dataset.classes)
total_accuracy = 0.0
total_samples = 0
head_mismatch_warned = False

with torch.no_grad():  # inference only: no autograd graph is needed
    for images, labels in test_loader:
        outputs = model(images)
        if outputs.size(1) != num_dataset_classes and not head_mismatch_warned:
            # The predicted class indices live in a different label space than
            # the dataset's labels, so matches between them are coincidental.
            warnings.warn(
                f"model outputs {outputs.size(1)} classes but the dataset has "
                f"{num_dataset_classes}; the accuracy below is not meaningful "
                "until the classifier head is replaced and trained."
            )
            head_mismatch_warned = True
        n_in_batch = images.size(0)
        total_accuracy += accuracy(outputs, labels).item() * n_in_batch
        total_samples += n_in_batch

if total_samples == 0:
    raise RuntimeError("test_loader produced no samples; cannot compute accuracy")

# Percent format keeps sub-1% values visible (":.2f" on a fraction showed 0.00).
print(f"Accuracy: {total_accuracy / total_samples:.2%}")

Files already downloaded and verified
Shape of images: torch.Size([10, 3, 224, 224])
Labels: tensor([3, 8, 8, 0, 6, 6, 1, 6, 3, 1])
Model outputs: tensor([[ 0.6175,  1.2068,  0.4165, -0.0327,  1.0109, -0.4699, -0.0486,  0.7053,
          0.5276,  0.8022],
        [ 0.0945,  1.4550,  0.2541, -1.1134,  0.7000, -0.1805, -0.6558,  0.4922,
          0.5821,  0.4114],
        [ 1.2710,  0.4897,  0.0044, -0.5899,  1.3828, -0.0667, -0.6573,  1.0354,
          0.8861,  0.7321],
        [-0.0291,  0.9328,  0.6547, -0.8127,  1.1128, -0.0328, -0.3471,  0.5694,
          0.1890,  0.6997],
        [-0.3130,  1.1678,  0.7404, -0.0439,  0.6904,  0.4394, -0.4430,  0.8653,
          0.2046,  1.4180],
        [ 0.0153,  0.6478,  0.3468,  0.7277,  1.3743,  0.3285,  0.3513,  0.2923,
          0.5868,  1.1714],
        [ 0.3656,  0.8359,  0.8290,  0.3245,  1.1494, -0.0671,  0.1716, -0.2312,
          0.9066,  1.3271],
        [-0.1869,  0.9709,  0.4063, -0.1205,  1.0908, -0.0941, -0.7757,  0.8036,
         