In [1]:
import torchvision.models as models

alexnet = models.alexnet(pretrained=True)



Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth


100%|██████████| 233M/233M [00:01<00:00, 153MB/s]


In [2]:
for param in alexnet.features.parameters():
  param.requires_grad = False

for param in alexnet.classifier[:-1].parameters():
  param.requires_grad = False

In [13]:
import torch.nn as nn

num_classes = 10
alexnet.classifier[6] = nn.Linear(4096, num_classes)

In [14]:
import torch.optim as optim
optimizer = optim.Adam(alexnet.classifier[6].parameters(), lr=0.001)

In [15]:
import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize(224),  # Resize to match AlexNet input
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [16]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(alexnet.parameters(), lr=0.01)

In [17]:
for epoch in range(10):
  alexnet.train()
  for xb, yb in train_loader:
    pred = alexnet(xb)
    loss = loss_fn(pred, yb)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

  alexnet.eval()
  correct = 0
  total = 0
  with torch.no_grad():
    for xb, yb in test_loader:
      pred = alexnet(xb)
      _, predicted = torch.max(pred, 1)
      correct += (predicted==yb).sum().item()
      total += yb.size(0)
  accuracy = correct/total
  print(f'Epoch {epoch+1} accuracy {accuracy:.4f}')

Epoch 1 accuracy 0.6960
Epoch 2 accuracy 0.6958
Epoch 3 accuracy 0.7085
Epoch 4 accuracy 0.6984
Epoch 5 accuracy 0.7507
Epoch 6 accuracy 0.7578
Epoch 7 accuracy 0.7304
Epoch 8 accuracy 0.7208
Epoch 9 accuracy 0.7329
Epoch 10 accuracy 0.7348
