In [1]:
from torchvision import datasets
from torchvision import transforms

In [2]:
train_data = datasets.CIFAR10('CIFAR10', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))

100%|██████████| 170M/170M [00:01<00:00, 105MB/s]


In [3]:
test_data = datasets.CIFAR10('CIFAR10', train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))

In [4]:
len(train_data), len(test_data)

(50000, 10000)

In [None]:

# --------------------------
# Imports
# --------------------------
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
from torchvision import models
import numpy as np

# --------------------------
# Dataset splitting
validation_size = 0.2
training_size = len(train_data)
indices = list(range(training_size))
np.random.shuffle(indices)
index_split = int(np.floor(training_size * validation_size))

validation_indices, training_indices = indices[:index_split], indices[index_split:]

training_sample = SubsetRandomSampler(training_indices)
validation_sample = SubsetRandomSampler(validation_indices)

batch_size = 16
train_loader = DataLoader(train_data, batch_size=batch_size, sampler=training_sample)
valid_loader = DataLoader(train_data, batch_size=batch_size, sampler=validation_sample)
test_loader = DataLoader(train_data, batch_size=batch_size)

# --------------------------
# Transfer Learning with VGG16
# --------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pretrained VGG16 model
vgg16 = models.vgg16(pretrained=True)

# Freeze convolutional layers (only train classifier)
for param in vgg16.features.parameters():
    param.requires_grad = False

# Modify classifier for 10 classes
num_features = vgg16.classifier[6].in_features
vgg16.classifier[6] = nn.Linear(num_features, 10)

model = vgg16.to(device)

# --------------------------
# Loss and Optimizer
# --------------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.classifier[6].parameters(), lr=0.001, momentum=0.9)
n_epochs = 10

# --------------------------
# Accuracy function
# --------------------------
def accuracy(preds, y):
    pred = preds.argmax(dim=1)
    correct = pred.eq(y)
    return correct.sum().item() / len(y)

# --------------------------
# Training + Validation Loop
# --------------------------
for epoch in range(1, n_epochs+1):
    train_loss, valid_loss = 0.0, 0.0

    # Training
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * data.size(0)

    # Validation
    model.eval()
    with torch.no_grad():
        for data, target in valid_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            valid_loss += loss.item() * data.size(0)

    train_loss = train_loss / len(train_loader.sampler)
    valid_loss = valid_loss / len(valid_loader.sampler)

    print(f'| Epoch: {epoch:02} | Train Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f} |')

# --------------------------
# Testing the model
# --------------------------
model.eval()
test_loss, test_acc = 0.0, 0.0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = criterion(output, target)
        test_loss += loss.item() * data.size(0)
        test_acc += accuracy(output, target) * data.size(0)

test_loss = test_loss / len(test_loader.sampler)
test_acc = test_acc / len(test_loader.sampler)

print(f'\nTest Loss: {test_loss:.3f} | Test Accuracy: {test_acc:.3f}')




Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


100%|██████████| 528M/528M [00:04<00:00, 122MB/s] 


| Epoch: 01 | Train Loss: 2.116 | Val. Loss: 1.571 |
