In [1]:
import os
from time import perf_counter
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm, trange

from models import CNNSmall, CNNModerate, CNNBatchnorm, CNNBig, VGG, VGG_11, VGG_16
from utils import get_torch_device, get_torch_device_as_string


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Data loading -----------------------------------------------------------
DATASET_PATH: str = './data'
NUM_WORKERS_DATALOADER: int = 8

# --- Optimisation hyper-parameters -------------------------------------------
BATCH_SIZE: int = 100
LEARNING_RATE: float = 0.001
# NOTE(review): MOMENTUM is not read by any cell visible in this notebook
# (the optimizer is built with only `lr`); kept so existing references keep working.
MOMENTUM: float = 0.9

NUM_EPOCHS: int = 100

# Device selection is delegated to the project helper in `utils`.
DEVICE = get_torch_device(include_mps=False)

# Toggle Weights & Biases logging for the whole notebook.
INCLUDE_WANDB: bool = True

# When True, later cells cut the datasets down to a small slice for quick runs.
SMALL_DATASET: bool = False

# --- Reproducibility ----------------------------------------------------------
# Seed the RNGs before the model is created and before any DataLoader shuffles,
# so weight initialisation and batch order repeat across Restart & Run All.
# (Some GPU kernels may still be non-deterministic.)
SEED: int = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:

# Every architecture this notebook can train, keyed by the label that is later
# used as the W&B run name and inside the checkpoint file name.
model_builders = {
    'CNN Small': CNNSmall,
    'CNN Moderate': CNNModerate,
    'CNN Batchnorm': CNNBatchnorm,
    'CNN Big': CNNBig,
    'VGG 11': lambda: VGG(VGG_11),
    'VGG 16': lambda: VGG(VGG_16),
}

# Change this key to train a different architecture.
model_type = 'CNN Big'
model = model_builders[model_type]().to(DEVICE)



In [4]:
# Register this run with Weights & Biases and attach the hyper-parameters to it.
# `wandb` is already imported in the first cell, so no local import is needed.
if INCLUDE_WANDB:
    dataset_size_label = 'small' if SMALL_DATASET else 'all'
    config_dict = dict(
        batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        num_epochs=NUM_EPOCHS,
        training_device=get_torch_device_as_string(),
        num_workers_dataloader=NUM_WORKERS_DATALOADER,
        dataset_size=dataset_size_label,
    )
    # This key contains a space, so it cannot be passed as a keyword above.
    config_dict['model type'] = model_type
    wandb.init(config=config_dict, name=model_type, project='CNN')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmy-god-its-full-of-stars[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# AdamW over every parameter of the selected model, using the configured learning rate.
optimizer = optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

# Multi-class cross-entropy criterion, moved to the same device as the model.
cross_entropy = nn.CrossEntropyLoss()
loss_function = cross_entropy.to(DEVICE)

In [6]:

# Human-readable class names; the evaluation code indexes this list with the
# integer label, so the order must match the dataset's label ids.
image_classes: List[str] = [
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

# Pre-processing applied to every image: convert to a tensor, then normalise
# each of the three channels with mean 0.5 and std 0.5.
to_tensor_step = transforms.ToTensor()
normalize_step = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform_images = transforms.Compose([to_tensor_step, normalize_step])


In [7]:
def _cifar10_split(use_training_files: bool):
    """Build a CIFAR-10 dataset object rooted at DATASET_PATH with the shared transform.

    ``use_training_files`` is forwarded as the ``train`` flag of
    ``torchvision.datasets.CIFAR10``; files are downloaded when missing.
    """
    return torchvision.datasets.CIFAR10(
        root=DATASET_PATH,
        train=use_training_files,
        download=True,
        transform=transform_images,
    )


# The validation set is a second, independent object over the *training* files;
# a later cell carves the two apart.
train_dataset = _cifar10_split(True)
validation_dataset = _cifar10_split(True)
test_dataset = _cifar10_split(False)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Optional quick-run mode: keep only a small slice of each dataset.
# CIFAR10 stores images in `.data` and labels in `.targets` and pairs them by
# index, so both must be sliced with the same range. Slicing only `.data` with
# a non-zero start (as for the validation set) would pair each image with the
# label of a different image.
if SMALL_DATASET:
    train_dataset.data = train_dataset.data[:1000]
    train_dataset.targets = train_dataset.targets[:1000]

    validation_dataset.data = validation_dataset.data[1000:2000]
    validation_dataset.targets = validation_dataset.targets[1000:2000]

    test_dataset.data = test_dataset.data[:250]
    test_dataset.targets = test_dataset.targets[:250]

In [9]:
# Hold out the last 20 % of the training images for validation.
validation_split: int = round(len(train_dataset) * 0.8)

# CIFAR10 pairs `.data[i]` with `.targets[i]`, so the labels have to be cut with
# exactly the same range as the images. Without this, validation image i
# (originally at index validation_split + i) would be scored against the label
# of training image i.
train_dataset.data = train_dataset.data[:validation_split]
train_dataset.targets = train_dataset.targets[:validation_split]

validation_dataset.data = validation_dataset.data[validation_split:]
validation_dataset.targets = validation_dataset.targets[validation_split:]

In [10]:

# Options common to all three loaders.
shared_loader_options = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS_DATALOADER,
)

# Only the training loader shuffles; validation and test keep a fixed order.
train_dataloader = DataLoader(train_dataset, shuffle=True, **shared_loader_options)
validation_dataloader = DataLoader(validation_dataset, shuffle=False, **shared_loader_options)
test_dataloader = DataLoader(test_dataset, shuffle=False, **shared_loader_options)


In [11]:
def get_test_accuracy(data_loader: DataLoader, is_validation: bool = False, epoch: int = -1):
    """Evaluate the notebook-level ``model`` on ``data_loader`` and report accuracy.

    Relies on the notebook globals ``model``, ``DEVICE``, ``image_classes`` and
    ``INCLUDE_WANDB``.

    The model is switched to evaluation mode for the duration of the call and
    its previous training/eval mode is restored afterwards, so the result does
    not depend on the mode the caller left the model in.

    Args:
        data_loader: Loader yielding ``(images, labels)`` batches.
        is_validation: When True the metrics are labelled ``validation`` and
            nothing is printed; when False they are labelled ``test`` and the
            overall and per-class accuracies are printed.
        epoch: Step passed to ``wandb.log``. ``-1`` (the default) lets W&B pick
            the step itself.

    Returns:
        None. Results are printed and/or sent to W&B.
    """
    validation_string: str = 'validation' if is_validation else 'test'
    num_classes = len(image_classes)

    # Per-class counters, indexed by integer label.
    correct_per_class = torch.zeros(num_classes, dtype=torch.long)
    total_per_class = torch.zeros(num_classes, dtype=torch.long)

    total_execution_time = perf_counter()

    was_training = model.training
    model.eval()
    try:
        # No gradients are needed while evaluating.
        with torch.no_grad():
            for (images, labels) in tqdm(data_loader, desc=f'{validation_string} loop'):
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                y_predictions = model(images)
                # The class with the highest score is the prediction.
                _, predicted = torch.max(y_predictions, 1)

                # Count once per batch on the CPU instead of comparing
                # device tensors element by element in Python.
                labels_cpu = labels.cpu()
                hits = (predicted == labels).cpu()
                total_per_class += torch.bincount(labels_cpu, minlength=num_classes)
                correct_per_class += torch.bincount(labels_cpu[hits], minlength=num_classes)
    finally:
        # Hand the model back in the mode it was received in.
        model.train(was_training)

    elapsed = perf_counter() - total_execution_time

    correct = int(correct_per_class.sum())
    total = int(total_per_class.sum())
    # An empty loader yields 0.0 rather than a ZeroDivisionError.
    accuracy = round(100 * correct / total, 4) if total else 0.0
    if not is_validation:
        print(
            f'Accuracy of the network on the {len(data_loader.dataset)} test images: {accuracy} %')

    metrics = {}
    if INCLUDE_WANDB:
        if is_validation:
            metrics[f'{validation_string} accuracy'] = accuracy
            metrics[f'{validation_string} total execution time'] = elapsed
        wandb.run.summary[f'{validation_string} dataset overall accuracy %'] = accuracy

    # Per-class accuracy.
    per_class_counts = zip(image_classes, correct_per_class.tolist(), total_per_class.tolist())
    for classname, class_correct, class_total in per_class_counts:
        # A class that never occurred in this loader is reported as 0.0.
        class_accuracy = round(100 * class_correct / class_total, 4) if class_total else 0.0
        if not is_validation:
            print(f'Accuracy for class: {classname:5s} is {class_accuracy:.1f} %')

        if INCLUDE_WANDB:
            key = f'{validation_string} {classname} accuracy'
            metrics[key] = class_accuracy
            wandb.run.summary[key] = class_accuracy

    if INCLUDE_WANDB and metrics:
        # One log call per evaluation keeps all metrics of this run on one step.
        step_kwargs = {} if epoch == -1 else {'step': epoch}
        wandb.log(metrics, **step_kwargs)


In [12]:
# Train for NUM_EPOCHS epochs; after each epoch, log training metrics and run
# a validation pass.
total_training_time = perf_counter()
for epoch in range(NUM_EPOCHS):
    # Set the mode explicitly instead of relying on whatever state a previous
    # (possibly interrupted) cell execution left the model in.
    model.train()

    correct_predictions: int = 0
    total_predictions: int = 0

    training_loop_execution_time = perf_counter()
    for inputs, labels in tqdm(train_dataloader, desc='train loop'):
        # zero the parameter gradients
        optimizer.zero_grad()
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        # forward pass
        y_hat = model(inputs)

        # calculate the loss, backpropagate, and update the parameters
        loss = loss_function(y_hat, labels)
        loss.backward()
        optimizer.step()

        # running training accuracy over the epoch
        _, y_predictions = torch.max(y_hat, 1)
        correct_predictions += torch.sum(y_predictions == labels).item()
        total_predictions += labels.size(0)

    if INCLUDE_WANDB:
        training_loop_execution_time = perf_counter() - training_loop_execution_time
        train_accuracy = 100 * correct_predictions / total_predictions
        wandb.log({
            # loss of the last batch of the epoch
            'loss': loss.item(),
            'training accuracy in %': train_accuracy,
            'epoch execution time': training_loop_execution_time,
        }, step=epoch)

    # Validate in evaluation mode. The next iteration switches back to training
    # mode at the top of the loop, and after the final epoch the model stays in
    # evaluation mode, ready for saving and test-set evaluation.
    model.eval()
    get_test_accuracy(validation_dataloader, is_validation=True, epoch=epoch)

if INCLUDE_WANDB:
    wandb.run.summary['total training time'] = round(perf_counter() - total_training_time, 4)

train loop: 100%|██████████| 400/400 [00:02<00:00, 139.32it/s]
validation loop: 100%|██████████| 100/100 [00:00<00:00, 144.75it/s]
train loop: 100%|██████████| 400/400 [00:01<00:00, 203.58it/s]
validation loop: 100%|██████████| 100/100 [00:00<00:00, 149.73it/s]
train loop: 100%|██████████| 400/400 [00:01<00:00, 200.38it/s]
validation loop: 100%|██████████| 100/100 [00:00<00:00, 139.15it/s]
train loop: 100%|██████████| 400/400 [00:02<00:00, 193.23it/s]
validation loop: 100%|██████████| 100/100 [00:00<00:00, 140.52it/s]
train loop: 100%|██████████| 400/400 [00:01<00:00, 204.16it/s]
validation loop: 100%|██████████| 100/100 [00:00<00:00, 138.11it/s]
train loop: 100%|██████████| 400/400 [00:02<00:00, 193.38it/s]
validation loop: 100%|██████████| 100/100 [00:00<00:00, 147.37it/s]
train loop: 100%|██████████| 400/400 [00:02<00:00, 195.44it/s]
validation loop: 100%|██████████| 100/100 [00:00<00:00, 133.50it/s]
train loop: 100%|██████████| 400/400 [00:02<00:00, 190.15it/s]
validation loop: 100

In [13]:
# Persist the trained weights (state dict only) under ./models/.
MODEL_PATH: str = f'./models/{model_type}_cifar_net.pth'
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)

In [14]:
# Put the model in evaluation mode before the final test pass, regardless of
# the mode earlier cells left it in. Layers such as dropout or batch norm (if
# the selected architecture has any) behave differently in training mode.
model.eval()
get_test_accuracy(test_dataloader)

test loop: 100%|██████████| 100/100 [00:00<00:00, 134.74it/s]

Accuracy of the network on the 10000 test images: 59.8 %
Accuracy for class: plane is 64.5 %
Accuracy for class: car   is 72.1 %
Accuracy for class: bird  is 44.6 %
Accuracy for class: cat   is 38.3 %
Accuracy for class: deer  is 55.4 %
Accuracy for class: dog   is 48.3 %
Accuracy for class: frog  is 70.5 %
Accuracy for class: horse is 62.9 %
Accuracy for class: ship  is 71.1 %
Accuracy for class: truck is 70.3 %



