In [None]:
import os
from time import perf_counter
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm, trange

from models import CNNSmall, CNNModerate, VGG, VGG_11, VGG_16
from utils import get_torch_device, get_torch_device_as_string


In [None]:
DATASET_PATH: str = './data'
NUM_WORKERS_DATALOADER: int = 4

BATCH_SIZE: int = 64
LEARNING_RATE: float = 0.001
MOMENTUM: float = 0.9

NUM_EPOCHS: int = 3

DEVICE = get_torch_device(include_mps=False)

INCLUDE_WANDB: bool = True

SMALL_DATASET: bool = True

In [None]:
if INCLUDE_WANDB:
    import wandb
    config_dict = {
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'num_epochs': NUM_EPOCHS,
        'training_device': get_torch_device_as_string(),
        'num_workers_dataloader': NUM_WORKERS_DATALOADER,
        'dataset_size': 'small' if SMALL_DATASET else 'all',
        'model type': 'CNN Small'
    }
    wandb.init(project='CNN', name='test', config=config_dict)

In [None]:
model = CNNSmall().to(DEVICE)
# model = CNNModerate().to(DEVICE)
# model = VGG(VGG_11).to(DEVICE)
# model = VGG(VGG_16).to(DEVICE)

In [None]:
loss_function = nn.CrossEntropyLoss().to(DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [None]:

image_classes: List[str] = ['plane', 'car', 'bird', 'cat',
                            'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

transform_images = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


In [None]:
train_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=True, download=True, transform=transform_images)
validation_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=True, download=True, transform=transform_images)
test_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, download=True, transform=transform_images)

In [None]:
if SMALL_DATASET:
    train_dataset.data = train_dataset.data[:1000]
    validation_dataset.data = validation_dataset.data[1000:2000]
    test_dataset.data = test_dataset.data[:250]

In [None]:
validation_split: int = round(len(train_dataset) * 0.8)

train_dataset.data = train_dataset.data[:validation_split]
validation_dataset.data = validation_dataset.data[validation_split:]


In [None]:

train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS_DATALOADER)

validation_dataloader = DataLoader(
    validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS_DATALOADER
)

test_dataloader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS_DATALOADER)


In [None]:
def get_test_accuracy(data_loader: DataLoader, is_validation: bool = False):
    correct: int = 0
    total: int = 0
    validation_string: str = 'validation' if is_validation else 'test'

    # prepare to count predictions for each class
    correct_pred = {classname: 0 for classname in image_classes}
    total_pred = {classname: 0 for classname in image_classes}

    total_execution_time = perf_counter()

    # the gradients don't get calculated while testing
    with torch.no_grad():
        for (images, labels) in tqdm(data_loader, desc=f'{validation_string} loop'):
            batch_execution_time = perf_counter()
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            # calculate outputs by running images through the network
            y_predictions = model(images)
            # the class with the highest probability (energy) is what we choose as prediction
            _, predicted = torch.max(y_predictions, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # collect the correct predictions for each class
            for label, prediction in zip(labels, predicted):
                if label == prediction:
                    correct_pred[image_classes[label]] += 1
                total_pred[image_classes[label]] += 1

            if INCLUDE_WANDB:
                wandb.log(
                    {f'{validation_string} batch execution time': perf_counter() - batch_execution_time})

    accuracy = round(100 * correct / total, 4)
    if not is_validation:
        print(
            f'Accuracy of the network on the {len(data_loader.dataset)} test images: {accuracy} %')

    if INCLUDE_WANDB:
        if is_validation:
            wandb.log({
                f'{validation_string} accuracy': accuracy,
                f'{validation_string} total execution time': perf_counter() - total_execution_time
            })
        wandb.run.summary[f'{validation_string} dataset overall accuracy %'] = accuracy

    # print accuracy for each class
    for classname, correct_count in correct_pred.items():
        accuracy = round(100 * correct_count / total_pred[classname], 4)
        if not is_validation:
            print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

        if INCLUDE_WANDB:
            key = f'{validation_string} {classname} accuracy'
            wandb.log({
                key: accuracy
            })
            wandb.run.summary[key] = accuracy


In [None]:
total_training_time = perf_counter()
for epoch in range(NUM_EPOCHS):
    
    correct_predictions: int = 0
    total_predictions: int = 0
    
    training_loop_execution_time = perf_counter()
    for idx, (inputs, labels) in enumerate(tqdm(train_dataloader, desc='train loop'), 0):
        batch_execution_time = perf_counter()
        # zero the parameter gradients
        optimizer.zero_grad()
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        
        # forward pass
        y_hat = model(inputs)
        
        # calculate losses
        loss = loss_function(y_hat, labels)
        # backpropagate the new gradients
        loss.backward()
        # optimize the gradients
        optimizer.step()
        
        _, y_predictions = torch.max(y_hat, 1)
        correct_predictions += torch.sum(y_predictions == labels).item()
        total_predictions += labels.size(0)
        
        if INCLUDE_WANDB:
            wandb.log({'batch execution time': perf_counter() - batch_execution_time})
        
    if INCLUDE_WANDB:
        training_loop_execution_time = perf_counter() - training_loop_execution_time
        train_accuracy = 100 * correct_predictions / total_predictions
        wandb.log({
            'loss': loss.item(), 
            'training accuracy in %': train_accuracy,
            'epoch execution time': training_loop_execution_time,
            })
        
    model.eval()
    
    get_test_accuracy(validation_dataloader, is_validation=True)
    # get_test_accuracy_of_each_class(validation_dataloader, is_validation=True)
    model.train()
    
if INCLUDE_WANDB:
    wandb.summary['total training time'] = round(perf_counter() - total_training_time, 4)

In [None]:
MODEL_PATH: str = f'./models/cifar_net.pth'
torch.save(model.state_dict(), MODEL_PATH)

In [None]:
get_test_accuracy(test_dataloader)