# CNN for Image Classification

### Import Libraries

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import random_split, ConcatDataset

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torcheval.metrics.functional import (multiclass_accuracy, 
                                          multiclass_confusion_matrix, 
                                          multiclass_precision, 
                                          multiclass_recall)
from sklearn.metrics import ConfusionMatrixDisplay, recall_score, precision_score, accuracy_score

### Load the data

In [None]:
# Mini-batch size shared by the training, validation and test DataLoaders.
batch_size = 64

# Fractions of the combined dataset assigned to the training, validation and
# test sets. They should add up to 1.0 so the three splits cover every sample.
train_split, valid_split, test_split = 0.6, 0.2, 0.2

In [None]:
# Preprocessing applied to every CIFAR-10 image:
#   * ToTensor turns the image into a float tensor with values in [0, 1];
#   * Normalize subtracts mean 0.5 and divides by std 0.5 on each of the three
#     channels, i.e. (x - 0.5) / 0.5, which rescales [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [None]:
# Load both official CIFAR-10 splits from ./data, downloading them there if
# they are not already present. Both splits use the same `transform`.
train_data = datasets.CIFAR10(root='./data', train=True,
                              transform=transform, download=True)
test_data = datasets.CIFAR10(root='./data', train=False,
                             transform=transform, download=True)

In [None]:
# Pool the official train and test splits into one dataset so it can be
# re-split with custom proportions.
full_dataset = ConcatDataset((train_data, test_data))
len_full_dataset = len(full_dataset)

print("Full dataset length", len_full_dataset)

In [None]:
# Split the pooled data into training, validation and test datasets.

# Seed the generator so the same split is produced every time.
split_generator = torch.Generator().manual_seed(42)

# Fail early with a clear message if the configured fractions do not describe
# a full partition of the data. Otherwise the remainder-based test size below
# would silently absorb (or go negative for) the mismatch.
_split_total = train_split + valid_split + test_split
if not np.isclose(_split_total, 1.0):
    raise ValueError(
        f"train/valid/test splits must sum to 1.0, got {_split_total}")
del _split_total

train_size = int(np.floor(train_split * len_full_dataset))
valid_size = int(np.floor(valid_split * len_full_dataset))
# The test set takes whatever remains, rather than being floored on its own.
# Flooring all three fractions independently can leave a few samples
# unassigned, and random_split raises when the lengths do not add up to the
# dataset size.
test_size = len_full_dataset - train_size - valid_size

train_dataset, valid_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, valid_size, test_size],
    generator=split_generator,
)

In [None]:
# Report how many samples ended up in each split.
for _label, _subset in (("Train dataset length: ", train_dataset),
                        ("Validation dataset length: ", valid_dataset),
                        ("Test dataset length: ", test_dataset)):
    print(_label, len(_subset))
# Drop the loop temporaries so they do not linger in the notebook namespace.
del _label, _subset

In [None]:
# Only the training loader is shuffled. Reshuffling every epoch helps SGD,
# while validation and test passes should visit samples in a fixed order so
# that evaluation runs, and any per-sample inspection of predictions, are
# reproducible. Order-independent aggregates such as accuracy are unaffected.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                           shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size,
                                           shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                          shuffle=False)

In [None]:
# Human-readable CIFAR-10 class names, in the same order as the dataset's
# integer labels 0-9 (so classes[label] gives the name).
classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]