In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import time
from timm.data.mixup import Mixup

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "CPU")

In [25]:
# Define data augmentation and normalization (using ImageNet statistics)
transform_train = transforms.Compose([
    transforms.Resize(224),                     # Resize CIFAR-10 images to 224x224
    transforms.RandomCrop(224, padding=4),      # Random crop with padding
    transforms.RandomHorizontalFlip(),          # Random horizontal flip
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet mean
                         std=[0.229, 0.224, 0.225])   # ImageNet std
])

transform_test = transforms.Compose([
    transforms.Resize(224),                     # Resize for consistency
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

In [26]:
train_data = torchvision.datasets.CIFAR10(root = './data', download = True, train = True, transform = transform_train)
test_data = torchvision.datasets.CIFAR10(root = './data', download = True, train = False, transform = transform_test)

In [27]:
train_data_batches = DataLoader(train_data, shuffle = True, batch_size = 8, pin_memory=True, num_workers = 2)
test_data_batches = DataLoader(test_data, shuffle = True, batch_size = 8, pin_memory=True, num_workers = 2)

In [28]:
model = torchvision.models.resnet18(pretrained = True)

In [29]:
num_features = model.fc.in_features
model.fc

Linear(in_features=512, out_features=1000, bias=True)

In [30]:
model.fc = nn.Linear(num_features, 10)

In [31]:
# Optionally, freeze all layers except the final fully connected layer:
for name, param in model.named_parameters():
    # Freeze everything except layers in the last block (or the final two layers)
    if "layer4" not in name and "fc" not in name:
        param.requires_grad = False

In [32]:
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()

epochs = 30

# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

optimizer = torch.optim.AdamW([
    {'params': model.fc.parameters(), 'lr': 3e-4},
    {'params': [param for name, param in model.named_parameters() if "layer4" in name], 'lr': 1e-4}
], weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=3e-4, steps_per_epoch=len(train_data_batches), epochs=epochs
)

# Training loop
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_data_batches):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    avg_loss = running_loss / len(train_data_batches)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

Epoch [1/30], Loss: 0.7214
Epoch [2/30], Loss: 0.4353
Epoch [3/30], Loss: 0.3639
Epoch [4/30], Loss: 0.3168
Epoch [5/30], Loss: 0.2773
Epoch [6/30], Loss: 0.2448
Epoch [7/30], Loss: 0.2220
Epoch [8/30], Loss: 0.1991
Epoch [9/30], Loss: 0.1816
Epoch [10/30], Loss: 0.1626
Epoch [11/30], Loss: 0.1519
Epoch [12/30], Loss: 0.1357
Epoch [13/30], Loss: 0.1192
Epoch [14/30], Loss: 0.1146
Epoch [15/30], Loss: 0.1035
Epoch [16/30], Loss: 0.0941
Epoch [17/30], Loss: 0.0854
Epoch [18/30], Loss: 0.0790
Epoch [19/30], Loss: 0.0745
Epoch [20/30], Loss: 0.0700
Epoch [21/30], Loss: 0.0630
Epoch [22/30], Loss: 0.0612
Epoch [23/30], Loss: 0.0562
Epoch [24/30], Loss: 0.0545
Epoch [25/30], Loss: 0.0505
Epoch [26/30], Loss: 0.0479
Epoch [27/30], Loss: 0.0442
Epoch [28/30], Loss: 0.0425
Epoch [29/30], Loss: 0.0416
Epoch [30/30], Loss: 0.0361


In [33]:
# Evaluate on test set
def check_accuracy(model, batch):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in batch:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    test_acc = 100 * correct / total
    print(f"Test Accuracy: {test_acc:.2f}%")
    return test_acc

check_accuracy(model, test_data_batches)

Test Accuracy: 93.81%


93.81

In [35]:
list(model.fc.parameters())

[Parameter containing:
 tensor([[-3.8974e-02,  2.9298e-02, -1.3958e-02,  ..., -4.6387e-02,
          -1.2887e-02,  1.2016e-02],
         [-1.0419e-01,  3.4756e-02,  6.4834e-02,  ...,  3.3782e-02,
          -6.7517e-02,  2.9924e-02],
         [ 3.4154e-02,  5.3276e-03,  5.3004e-02,  ..., -7.9814e-02,
           3.5800e-02,  5.7835e-02],
         ...,
         [-9.1951e-03, -8.7283e-02,  2.6646e-02,  ...,  2.1281e-02,
           4.0279e-03, -6.6018e-02],
         [-1.2005e-01, -5.0391e-02,  5.4703e-02,  ...,  6.6629e-02,
           6.4773e-02,  3.9974e-02],
         [-8.2902e-02,  6.2400e-03, -9.1404e-02,  ..., -3.7850e-03,
          -3.6153e-02, -1.1456e-04]], device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([ 0.0363,  0.0100,  0.0260,  0.0085,  0.0251,  0.0221, -0.0104,  0.0187,
         -0.0086, -0.0113], device='cuda:0', requires_grad=True)]

In [38]:
for name, param in model.named_parameters():
    print(name)

conv1.weight
bn1.weight
bn1.bias
layer1.0.conv1.weight
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.conv2.weight
layer1.0.bn2.weight
layer1.0.bn2.bias
layer1.1.conv1.weight
layer1.1.bn1.weight
layer1.1.bn1.bias
layer1.1.conv2.weight
layer1.1.bn2.weight
layer1.1.bn2.bias
layer2.0.conv1.weight
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.conv2.weight
layer2.0.bn2.weight
layer2.0.bn2.bias
layer2.0.downsample.0.weight
layer2.0.downsample.1.weight
layer2.0.downsample.1.bias
layer2.1.conv1.weight
layer2.1.bn1.weight
layer2.1.bn1.bias
layer2.1.conv2.weight
layer2.1.bn2.weight
layer2.1.bn2.bias
layer3.0.conv1.weight
layer3.0.bn1.weight
layer3.0.bn1.bias
layer3.0.conv2.weight
layer3.0.bn2.weight
layer3.0.bn2.bias
layer3.0.downsample.0.weight
layer3.0.downsample.1.weight
layer3.0.downsample.1.bias
layer3.1.conv1.weight
layer3.1.bn1.weight
layer3.1.bn1.bias
layer3.1.conv2.weight
layer3.1.bn2.weight
layer3.1.bn2.bias
layer4.0.conv1.weight
layer4.0.bn1.weight
layer4.0.bn1.bias
layer4.0.conv2.we