In [1]:
from pathlib import Path
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm, trange

from models import CNNSmall, CNNModerate, VGG, VGG_11, VGG_16
from utils import get_torch_device, get_torch_device_as_string


  from .autonotebook import tqdm as notebook_tqdm


In [23]:
# --- Data ---------------------------------------------------------------
# Root directory handed to the CIFAR-10 dataset constructors.
DATASET_PATH: str = './data'
# Number of worker processes used by every DataLoader.
NUM_WORKERS_DATALOADER: int = 4
# When True, the datasets are truncated to a small subset for quick runs.
SMALL_DATASET: bool = True

# --- Optimisation -------------------------------------------------------
BATCH_SIZE: int = 64
LEARNING_RATE: float = 0.001
# NOTE(review): MOMENTUM is not read by any cell visible in this notebook
# (the optimizer is AdamW, constructed with `lr` only) - confirm it is needed.
MOMENTUM: float = 0.9
NUM_EPOCHS: int = 1

# --- Runtime ------------------------------------------------------------
# Device all tensors and the model are moved to.
# presumably `include_mps=False` rules out the Apple MPS backend - TODO confirm in utils
DEVICE = get_torch_device(include_mps=False)

# When True, metrics and run configuration are sent to Weights & Biases.
INCLUDE_WANDB: bool = True

In [None]:
# Instantiate the network, then move it to the configured device.
# Other architectures imported from `models` that can be swapped in here:
#   CNNModerate()
#   VGG(VGG_11)
#   VGG(VGG_16)
model = CNNSmall()
model = model.to(DEVICE)

In [24]:
if INCLUDE_WANDB:
    # `wandb` is already imported in the notebook's import cell, so it is not
    # re-imported here.
    # Hyper-parameters and run metadata stored with the wandb run.
    config_dict = {
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'num_epochs': NUM_EPOCHS,
        'training_device': get_torch_device_as_string(),
        'num_workers_dataloader': NUM_WORKERS_DATALOADER,
        'dataset_size': 'small' if SMALL_DATASET else 'all',
        # Derived from the instantiated model so the recorded value cannot go
        # stale when a different architecture is selected in the model cell.
        'model type': type(model).__name__,
    }
    wandb.init(project='CNN', config=config_dict)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmy-god-its-full-of-stars[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [25]:

# Human-readable class names; the evaluation code indexes this list with the
# integer label of a sample.
image_classes: List[str] = [
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# Pre-processing applied to every image: convert to a tensor, then normalise
# each of the three channels as (x - 0.5) / 0.5.
transform_images = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])


In [26]:
# Constructor arguments shared by all three CIFAR-10 dataset objects.
_cifar10_common = dict(root=DATASET_PATH, download=True, transform=transform_images)

# Both the training and the validation dataset are created from the training
# split (train=True); later cells slice them into separate ranges.
train_dataset = torchvision.datasets.CIFAR10(train=True, **_cifar10_common)
validation_dataset = torchvision.datasets.CIFAR10(train=True, **_cifar10_common)
# The held-out test split.
test_dataset = torchvision.datasets.CIFAR10(train=False, **_cifar10_common)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [27]:
if SMALL_DATASET:
    # Truncate each dataset to a small prefix for quick experiments.
    # Images (`data`) and labels (`targets`) are stored in two parallel
    # containers, so both are cut to the same length; otherwise any later
    # slicing that does not start at index 0 pairs images with wrong labels.
    for dataset, subset_size in (
        (train_dataset, 1000),
        (validation_dataset, 1000),
        (test_dataset, 250),
    ):
        dataset.data = dataset.data[:subset_size]
        dataset.targets = dataset.targets[:subset_size]

In [28]:
# Index that separates the training part (first 80 %) from the validation
# part (remaining 20 %) of the training split.
validation_split: int = round(len(train_dataset) * 0.8)

# Slice images AND labels with the same bounds. Slicing only `data` leaves
# `targets` starting at index 0, so validation image i would be evaluated
# against the label of training image i.
train_dataset.data = train_dataset.data[:validation_split]
train_dataset.targets = train_dataset.targets[:validation_split]
validation_dataset.data = validation_dataset.data[validation_split:]
validation_dataset.targets = validation_dataset.targets[validation_split:]


In [29]:

# Settings shared by all three loaders.
_loader_options = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS_DATALOADER)

# Only the training data is shuffled; validation and test keep dataset order.
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_options)
validation_dataloader = DataLoader(validation_dataset, shuffle=False, **_loader_options)
test_dataloader = DataLoader(test_dataset, shuffle=False, **_loader_options)


In [31]:
# Cross-entropy classification loss, moved to the training device.
loss_function = nn.CrossEntropyLoss()
loss_function = loss_function.to(DEVICE)

# AdamW over all parameters of the model, using the configured learning rate.
optimizer = optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [32]:
def get_test_accuracy(data_loader: DataLoader, is_validation: bool = False) -> None:
    """Compute the overall accuracy of the global `model` on `data_loader`.

    The model is switched to evaluation mode for the duration of the pass and
    its previous mode is restored afterwards, so the result does not depend
    on the mode the caller left the model in.

    Args:
        data_loader: Loader yielding `(images, labels)` batches.
        is_validation: If True the pass is treated as a validation run: nothing
            is printed and the accuracy is logged to wandb under
            'validation accuracy'. If False the accuracy is printed and stored
            in the wandb run summary as the test accuracy.

    Returns:
        None. Results are reported via `print` and, when `INCLUDE_WANDB` is
        set, via wandb.
    """
    correct = 0
    total = 0
    validation_string = 'validation' if is_validation else 'test'

    # Remember the current mode so it can be restored after evaluating.
    was_training = model.training
    model.eval()
    try:
        # the gradients don't get calculated while testing
        with torch.no_grad():
            for images, labels in tqdm(data_loader, desc=f'{validation_string} loop'):
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                # calculate outputs by running images through the network
                y_predictions = model(images)
                # the class with the highest score is what we choose as prediction
                _, predicted = torch.max(y_predictions, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    # An empty loader would otherwise raise ZeroDivisionError below.
    if total == 0:
        print(f'No samples in the {validation_string} data loader; accuracy not computed.')
        return

    # True division keeps the fractional part (floor division rounded down to
    # a whole percent).
    test_accuracy = 100 * correct / total
    if not is_validation:
        print(f'Accuracy of the network on the {len(data_loader.dataset)} test images: {test_accuracy:.1f} %')

    if INCLUDE_WANDB:
        if is_validation:
            wandb.log({'validation accuracy': test_accuracy})
        else:
            # Only a real test run may write the test-accuracy summary entry;
            # validation runs must not overwrite it.
            wandb.run.summary['test dataset overall accuracy %'] = test_accuracy

In [33]:
def get_test_accuracy_of_each_class(data_loader: DataLoader, is_validation: bool = False) -> None:
    """Compute the accuracy of the global `model` separately for every class.

    The model is switched to evaluation mode for the duration of the pass and
    its previous mode is restored afterwards.

    Args:
        data_loader: Loader yielding `(images, labels)` batches; labels are
            used as indices into `image_classes`.
        is_validation: If True nothing is printed and the wandb keys are
            prefixed with 'validation'; otherwise the per-class accuracies are
            printed and the keys are prefixed with 'test'.

    Returns:
        None. Results are reported via `print` and, when `INCLUDE_WANDB` is
        set, via wandb. Classes without any sample in `data_loader` are
        reported as 'n/a' and are not logged.
    """
    num_classes = len(image_classes)
    # Per-class counters, indexed by integer label.
    correct_per_class = torch.zeros(num_classes, dtype=torch.long)
    total_per_class = torch.zeros(num_classes, dtype=torch.long)

    validation_string = 'validation' if is_validation else 'test'

    # Remember the current mode so it can be restored after evaluating.
    was_training = model.training
    model.eval()
    try:
        # again no gradients needed
        with torch.no_grad():
            for images, labels in tqdm(data_loader, desc=f'{validation_string} loop'):
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                _, predictions = torch.max(outputs, 1)
                # Count on the CPU in one vectorised step per batch instead of
                # comparing sample by sample in Python.
                labels_cpu = labels.cpu()
                hits = predictions.cpu() == labels_cpu
                total_per_class += torch.bincount(labels_cpu, minlength=num_classes)
                correct_per_class += torch.bincount(labels_cpu[hits], minlength=num_classes)
    finally:
        model.train(was_training)

    metrics = {}
    for class_index, classname in enumerate(image_classes):
        class_total = int(total_per_class[class_index])
        # A class may be absent from a small subset; avoid dividing by zero.
        if class_total == 0:
            if not is_validation:
                print(f'Accuracy for class: {classname:5s} is n/a (no samples)')
            continue

        accuracy = 100 * float(correct_per_class[class_index]) / class_total
        if not is_validation:
            print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
        metrics[f'{validation_string} {classname} accuracy'] = accuracy

    if INCLUDE_WANDB and metrics:
        # One log call for all classes, so a single evaluation occupies a
        # single wandb step instead of one step per class.
        wandb.log(metrics)
        for key, accuracy in metrics.items():
            wandb.run.summary[key] = accuracy
            

In [34]:
for epoch in range(NUM_EPOCHS):
    # Make sure the model is in training mode at the start of every epoch.
    model.train()

    correct_predictions: int = 0
    total_predictions: int = 0
    # Sum of per-sample losses over the epoch, used for the epoch mean.
    running_loss: float = 0.0

    for inputs, labels in tqdm(train_dataloader):
        # zero the parameter gradients
        optimizer.zero_grad()
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        # forward pass
        y_hat = model(inputs)

        # calculate losses
        loss = loss_function(y_hat, labels)
        # backpropagate the new gradients
        loss.backward()
        # update the parameters
        optimizer.step()

        num_samples = labels.size(0)
        # The loss is a batch mean (CrossEntropyLoss default reduction), so it
        # is weighted by the batch size to obtain a per-sample epoch mean.
        running_loss += loss.item() * num_samples

        _, y_predictions = torch.max(y_hat, 1)
        correct_predictions += torch.sum(y_predictions == labels).item()
        total_predictions += num_samples

    # Skip logging when no batch was processed (avoids division by zero).
    if INCLUDE_WANDB and total_predictions > 0:
        train_accuracy = 100 * correct_predictions / total_predictions
        wandb.log({
            # Mean loss over the whole epoch rather than only the last batch.
            'loss': running_loss / total_predictions,
            'training accuracy in %': train_accuracy,
            })

    # Validate on the first epoch and then every 50th epoch.
    if epoch % 50 == 0:
        model.eval()
        get_test_accuracy(validation_dataloader, is_validation=True)
        get_test_accuracy_of_each_class(validation_dataloader, is_validation=True)
        model.train()
    

100%|██████████| 13/13 [00:22<00:00,  1.71s/it]
validation loop: 100%|██████████| 4/4 [00:21<00:00,  5.46s/it]
validation loop: 100%|██████████| 4/4 [00:21<00:00,  5.44s/it]

Accuracy for class: plane is 14.3 %
Accuracy for class: car   is 10.7 %
Accuracy for class: bird  is 0.0 %
Accuracy for class: cat   is 0.0 %
Accuracy for class: deer  is 0.0 %
Accuracy for class: dog   is 0.0 %
Accuracy for class: frog  is 33.3 %
Accuracy for class: horse is 9.5 %
Accuracy for class: ship  is 13.3 %
Accuracy for class: truck is 13.0 %





In [35]:
# File the trained weights (state dict) are written to.
MODEL_PATH: str = './models/cifar_net.pth'
# torch.save does not create missing parent directories, so make sure the
# target directory exists before writing.
Path(MODEL_PATH).parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)

In [36]:
# Overall accuracy on the held-out test loader (not a validation run).
get_test_accuracy(data_loader=test_dataloader, is_validation=False)

test loop: 100%|██████████| 4/4 [00:21<00:00,  5.47s/it]

Accuracy of the network on the 250 test images: 24 %





In [37]:
# Per-class accuracy on the held-out test loader (not a validation run).
get_test_accuracy_of_each_class(data_loader=test_dataloader, is_validation=False)

test loop: 100%|██████████| 4/4 [00:21<00:00,  5.48s/it]

Accuracy for class: plane is 40.0 %
Accuracy for class: car   is 10.0 %
Accuracy for class: bird  is 0.0 %
Accuracy for class: cat   is 0.0 %
Accuracy for class: deer  is 10.5 %
Accuracy for class: dog   is 0.0 %
Accuracy for class: frog  is 65.5 %
Accuracy for class: horse is 8.3 %
Accuracy for class: ship  is 38.2 %
Accuracy for class: truck is 50.0 %



