# CIFAR-10 Image Classification with CNNs

In [None]:
!pip install torchinfo

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchinfo import summary

# Preprocessing for CIFAR-10: convert images to tensors, then normalize
# every channel with mean 0.5 / std 0.5 (maps values to [-1, 1]).
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# Mini-batch size shared by the training and test loaders
batch_size = 64
data_root = './data'

# Training split: downloaded on first use, reshuffled every epoch
trainset = torchvision.datasets.CIFAR10(
    root=data_root, train=True, download=True, transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True,
)

# Test split: same preprocessing, served in a fixed order
testset = torchvision.datasets.CIFAR10(
    root=data_root, train=False, download=True, transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False,
)

# Human-readable class names, indexed by integer label
classes = trainset.classes

In [None]:
# Run on an NVIDIA GPU ("cuda") when PyTorch can see one, otherwise on the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
def show_images(img, mean, std):
    """Un-normalize an image tensor and draw it with ``plt.imshow``.

    Args:
        img: Image tensor in channels-first layout; it is permuted with
            ``(1, 2, 0)`` so the channel axis comes last before plotting.
        mean: Per-channel means that were used to normalize ``img``.
        std: Per-channel standard deviations used to normalize ``img``.
    """
    # Tensor.numpy() only works for CPU tensors that do not require grad,
    # so detach and move to host memory before doing anything else.
    img = img.detach().cpu()
    mean = torch.tensor(mean, device=img.device)
    std = torch.tensor(std, device=img.device)
    # Channels last, so the per-channel stats broadcast over the final axis
    img = img.permute(1, 2, 0)
    # Invert the normalization (x * std + mean); clamp float round-off so the
    # values stay inside the [0, 1] range imshow expects for float RGB data.
    img = (img * std + mean).clamp(0, 1)
    plt.imshow(img.numpy())


# Grab the first batch from the training loader for a quick visual check
dataiter = iter(trainloader)
images, labels = next(dataiter)

# Show at most four images, un-normalized with the transform's Normalize stats
nimages = min(batch_size, 4)
norm = transform.transforms[-1]
preview_grid = torchvision.utils.make_grid(images[:nimages])
show_images(preview_grid, norm.mean, norm.std)

# print labels
label_names = ['%5s' % classes[labels[j]] for j in range(nimages)]
print(' '.join(label_names))

## Fully-connected network

In [None]:
# TODO: class FullyConnectedNet(nn.Module):

## Train and test functions

In [None]:
def _model_device(model):
    """Return the device of the model's first parameter (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    if first_param is None:
        return torch.device("cpu")
    return first_param.device


def evaluate_model(model, dataloader):
    """Compute the classification accuracy of ``model`` over ``dataloader``.

    Args:
        model: Network whose output has one score per class along dim 1.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Fraction of correctly classified samples in [0, 1]; 0.0 when the
        dataloader yields no samples.
    """
    target_device = _model_device(model)
    was_training = model.training
    correct = 0
    total = 0
    model.eval()
    try:
        # No gradients are needed for evaluation
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs = inputs.to(target_device)
                labels = labels.to(target_device)
                predictions = model(inputs).argmax(dim=1)
                correct += (predictions == labels).sum().item()
                total += labels.size(0)
    finally:
        # Leave the model in whatever mode the caller had it in
        model.train(was_training)
    return correct / total if total else 0.0


def train_test(model, criterion, optimizer, epochs=5, train_loader=None, test_loader=None):
    """Train ``model`` and record train/test accuracy after every epoch.

    Args:
        model: Network to optimize; batches are moved to its device.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of passes over the training data.
        train_loader: Training batches; defaults to the module-level
            ``trainloader``.
        test_loader: Evaluation batches; defaults to the module-level
            ``testloader``.

    Returns:
        Tuple ``(train_accuracies, test_accuracies)`` of NumPy arrays with one
        accuracy fraction per epoch. The training accuracy is accumulated
        over the batches seen during the epoch (while weights are still being
        updated); the test accuracy comes from ``evaluate_model``.
    """
    if train_loader is None:
        train_loader = trainloader
    if test_loader is None:
        test_loader = testloader

    target_device = _model_device(model)
    train_accuracies = []
    test_accuracies = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(target_device)
            labels = labels.to(target_device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
            num_batches += 1

        train_accuracy = correct / total if total else 0.0
        test_accuracy = evaluate_model(model, test_loader)
        train_accuracies.append(train_accuracy)
        test_accuracies.append(test_accuracy)

        # max(..., 1) avoids a division by zero for an empty training loader
        mean_loss = running_loss / max(num_batches, 1)
        print(f"Epoch {epoch + 1}: Training Loss = {mean_loss:.4f}, Training Accuracy = {train_accuracy * 100:.2f}%, Test Accuracy = {test_accuracy * 100:.2f}%")

    return np.array(train_accuracies), np.array(test_accuracies)

In [None]:
def plot_accuracy(train_accuracies, test_accuracies, epochs, title):
    """Plot per-epoch training and test accuracy as percentages.

    Args:
        train_accuracies: Sequence or array of training accuracy fractions,
            one entry per epoch.
        test_accuracies: Sequence or array of test accuracy fractions, one
            entry per epoch.
        epochs: Number of epochs; the x-axis runs from 1 to ``epochs``.
        title: Prefix for the figure title.
    """
    # Coerce to arrays first: multiplying a plain list by 100 would repeat the
    # list 100 times instead of scaling its values.
    train_percent = np.asarray(train_accuracies, dtype=float) * 100
    test_percent = np.asarray(test_accuracies, dtype=float) * 100
    epoch_axis = range(1, epochs + 1)

    plt.figure(figsize=(10, 6))
    plt.ylim(0, 100)
    plt.plot(epoch_axis, train_percent, label='Training Accuracy', marker='o')
    plt.plot(epoch_axis, test_percent, label='Test Accuracy', marker='o')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.title(f'{title} - Training and Test Accuracy')
    plt.legend()
    plt.grid()
    plt.show()

## Train fully-connected network

In [None]:
# Instantiate the fully-connected baseline and print its layer summary, in the
# same way as the cells for the convolutional models.
fc_model = FullyConnectedNet()
# NOTE(review): assumes FullyConnectedNet accepts (batch, 3, 32, 32) image
# batches as served by the data loaders -- confirm once the class is written.
print(summary(fc_model, input_size=(batch_size, 3, 32, 32)))
# print(fc_model)
fc_model.to(device)

## Small Convolutional Neural Network

In [None]:
#TODO: class ConvNet3(nn.Module):

In [None]:
# Build the small CNN and print its torchinfo summary for one input batch
conv3_model = ConvNet3()
conv3_input_size = (batch_size, 3, 32, 32)
conv3_stats = summary(conv3_model, input_size=conv3_input_size)
print(conv3_stats)
# print(conv3_model)
# Move the parameters to the selected device
conv3_model.to(device)

## Fancy VGG-style CNN

In [None]:
#TODO: class VGGStyleNet(nn.Module):


In [None]:
# Build the VGG-style CNN and print its torchinfo summary for one input batch
vgg_model = VGGStyleNet()
vgg_input_size = (batch_size, 3, 32, 32)
vgg_stats = summary(vgg_model, input_size=vgg_input_size)
print(vgg_stats)
# print(vgg_model)
# Move the parameters to the selected device
vgg_model.to(device)

## Train models

In [None]:
# Candidate networks and their display names, kept in matching order
models = [
    fc_model,
    conv3_model,
    vgg_model,
]
model_names = [
    "Fully-connected Model",
    "3-layer CNN",
    "VGG-style CNN",
]
# Placeholders for the winner, to be filled in by the training loop below
best_model, best_accuracy = None, 0.0
# TODO: train models and find best one

## Show predictions

In [None]:
# Report the winning model; best_accuracy is multiplied by 100 to display it
# as a percentage.
print(f"The best model was: {best_model}\n with an accuracy of {best_accuracy*100:.2f}%")

In [None]:
# Grab the first batch from the test loader
dataiter = iter(testloader)
images, labels = next(dataiter)

# print images
nimages = min(batch_size, 8)
norm = testset.transform.transforms[-1]
preview_grid = torchvision.utils.make_grid(images[:nimages])
show_images(preview_grid, norm.mean, norm.std)

# Class names of the true labels, in the same order as the displayed images
truth_names = ['%5s' % testset.classes[labels[j]] for j in range(nimages)]
print('GroundTruth: ', ' '.join(truth_names))

# TODO: get predictions of best model


# Print the predicted labels
# print('Predicted: ', ' '.join('%5s' % testset.classes[preds[j]] for j in range(nimages)))