In [None]:
import torch
import torch.nn
from torch.optim import Adam
from torch.optim import SGD
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR100
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Open both CIFAR-100 splits under ./data (download=True fetches the archive
# when it is not already there). Samples are returned as tensors at this point.
train_dataset = CIFAR100(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_dataset = CIFAR100(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

In [None]:
# Report how many samples each split holds (one line per split).
print(
    f'{len(train_dataset)} training samples',
    f'{len(test_dataset)} testing samples',
    sep='\n',
)

In [None]:
# Peek at the first training sample: print its tensor shape and label, then
# leave the PIL image as the cell's last expression so the notebook renders it.
img, label = train_dataset[0]
print(f'Image shape: {img.shape}', f'Label: {label}', sep='\n')
transforms.functional.to_pil_image(img)

In [None]:
# Per-channel statistics handed to Normalize by both pipelines.
channel_mean = (0.5071, 0.4865, 0.4409)
channel_std = (0.2673, 0.2564, 0.2762)

# Training pipeline: random padded crop and horizontal flip, then tensor
# conversion and normalisation.
train_transforms = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])
train_dataset.transform = train_transforms

# Test pipeline: no random augmentation, only tensor conversion and the same
# normalisation.
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])
test_dataset.transform = test_transforms

In [None]:
batch_size = 128

# Only the training loader shuffles; the test loader keeps a fixed order.
train_loader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=batch_size,
)
test_loader = DataLoader(
    dataset=test_dataset,
    shuffle=False,
    batch_size=batch_size,
)

In [None]:
# Execute the companion notebook inside this kernel. It is expected to provide
# the `VGG16` class instantiated by later cells (it is not defined anywhere in
# this notebook) -- TODO confirm against VGG16.ipynb.
%run VGG16.ipynb

In [None]:
# --- Training configuration -------------------------------------------------
num_epochs = 12  # kept small for a short training run; ideally far more
lr = 1e-4
seed = 0

# Seed torch's global RNG so weight initialisation, loader shuffling and the
# random augmentations start from the same state on every run.
# NOTE: some CUDA kernels are non-deterministic, so GPU runs may still differ.
torch.manual_seed(seed)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Per-epoch loss histories: train() and test() append to these, and the plotting
# cell reads them. Re-running this cell resets both.
train_losses = []
validation_losses = []

# VGG16 is not defined in this notebook; presumably it comes from the
# `%run VGG16.ipynb` cell, and 100 is the number of output classes -- TODO confirm.
model = VGG16(100)

optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5)
# Alternative optimizer, kept for experimentation:
# optimizer = SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-5)

# Referenced through `torch.nn` because this notebook only does `import torch.nn`
# and never binds a bare `nn` name itself.
criterion = torch.nn.CrossEntropyLoss()

# Last expression of the cell: moves the model to `device` and displays it.
model.to(device)

In [None]:
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and report the results.

    Works on the notebook-level ``model``, ``optimizer``, ``criterion`` and
    ``device``. Appends the mean per-batch loss of the pass to
    ``train_losses`` and prints that loss together with the training accuracy.

    Args:
        epoch: Epoch index; only used in the printed header.

    Raises:
        ZeroDivisionError: If ``len(train_loader)`` is zero.
    """
    print("\nEpoch: %d" % epoch)
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for data, targets in train_loader:
        data = data.to(device=device)
        targets = targets.to(device=device)

        optimizer.zero_grad()
        scores = model(data)
        loss = criterion(scores, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predictions = torch.max(scores, 1)
        total += targets.size(0)
        # .item() keeps the running count a plain Python int instead of a
        # tensor that is only unwrapped after the loop.
        correct += (predictions == targets).sum().item()

    train_loss = train_loss / len(train_loader)  # mean of the per-batch losses
    accuracy = correct / total
    train_losses.append(train_loss)  # kept for the loss plot

    print(f"Train Loss: {train_loss:.4f}, Percentage Train Acc: {100*accuracy:.4f}")

In [None]:
def test(epoch):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Works on the notebook-level ``model``, ``criterion`` and ``device``.
    Appends the mean per-batch loss to ``validation_losses`` and prints that
    loss together with the test accuracy.

    Args:
        epoch: Accepted so the call mirrors ``train(epoch)``; not used.

    Raises:
        ZeroDivisionError: If ``len(test_loader)`` is zero.
    """
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, targets in test_loader:
            data = data.to(device=device)
            targets = targets.to(device=device)

            scores = model(data)
            loss = criterion(scores, targets)

            test_loss += loss.item()
            _, predictions = torch.max(scores, 1)
            total += targets.size(0)
            # .item() keeps the running count a plain Python int instead of a
            # tensor that is only unwrapped after the loop.
            correct += (predictions == targets).sum().item()

    test_loss = test_loss / len(test_loader)  # mean of the per-batch losses
    accuracy = correct / total
    validation_losses.append(test_loss)  # kept for the loss plot

    print(f"Test Loss: {test_loss:.4f}, Percentage Test Acc: {100*accuracy:.4f}")

In [None]:
# Each epoch runs one training pass followed by one evaluation pass.
for epoch in range(num_epochs):
    for run_phase in (train, test):
        run_phase(epoch)

In [None]:
# Plot the loss histories recorded once per epoch by train() and test().
# The axes are labelled so the figure can be read without the surrounding code.
fig, ax = plt.subplots()
ax.plot(train_losses, label='train loss')
ax.plot(validation_losses, label='validation loss')
ax.set_xlabel('epoch')
ax.set_ylabel('mean loss per batch')
ax.set_title('VGG16 on CIFAR-100: loss per epoch')
ax.legend()
plt.show()

In [None]:
# Persist only the model's state dict; the final-validation cell reloads this file.
checkpoint_state = model.state_dict()
torch.save(checkpoint_state, 'CIFAR100_VGG16.pt')

In [None]:
def _count_correct(net, loader, device):
    """Return ``(correct, total)`` top-1 prediction counts of ``net`` over ``loader``.

    Runs under ``torch.no_grad()``; each batch is moved to ``device`` first.
    The caller is responsible for putting ``net`` in eval mode.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


# load an empty model for 'final validation'
loaded_model = VGG16(100) # based on model
loaded_model.eval().to(device)

# Baseline: accuracy of the freshly constructed model, before the checkpoint
# is loaded.
print('Before loading model')
correct, total = _count_correct(loaded_model, test_loader, device)
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU still loads on a CPU-only kernel.
loaded_model.load_state_dict(torch.load('CIFAR100_VGG16.pt', map_location=device)) # based on model
loaded_model.eval().to(device)

print('After loading model')
correct, total = _count_correct(loaded_model, test_loader, device)
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))