# Compare Your Network with State-of-the-Art Networks

### Import Libraries

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from torchvision import datasets, transforms
from torchvision.models import resnet50, ResNet50_Weights, vgg19, VGG19_Weights
from torch.utils.data import random_split, ConcatDataset

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torcheval.metrics.functional import (multiclass_accuracy, 
                                          multiclass_confusion_matrix, 
                                          multiclass_precision, 
                                          multiclass_recall)
from sklearn.metrics import ConfusionMatrixDisplay, recall_score, precision_score, accuracy_score

### Load the Data

In [None]:
# Batch size used by the training, validation and test data loaders
batch_size = 32

# Fractions (not percentages) of the pooled dataset assigned to the
# training, validation and test sets; they are expected to sum to 1.
train_split = 0.6
valid_split = 0.2
test_split = 0.2

In [None]:
# Per-channel normalization statistics (the standard ImageNet mean/std values
# commonly paired with torchvision's pretrained weights).
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Resize (shorter side -> 256), center-crop to 224x224, convert to a tensor,
# then normalize each channel.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)

In [None]:
# Load both official CIFAR-10 splits (downloading into ./data if absent),
# applying the same preprocessing pipeline to each.
train_data, test_data = (
    datasets.CIFAR10('./data', train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

In [None]:
# Pool the two CIFAR-10 splits into one dataset so it can be re-split
# with custom proportions.
full_dataset = ConcatDataset((train_data, test_data))

len_full_dataset = len(full_dataset)
print("Full dataset length", len_full_dataset)

In [None]:
# Split the pooled data into training, validation and test subsets.

if not np.isclose(train_split + valid_split + test_split, 1.0):
    raise ValueError("train_split, valid_split and test_split must sum to 1")

# Seed the generator so the split is identical every time
split_generator = torch.Generator().manual_seed(42)

train_size = int(np.floor(train_split * len_full_dataset))
valid_size = int(np.floor(valid_split * len_full_dataset))
# random_split requires the lengths to add up to the dataset length exactly.
# Flooring all three fractions independently can leave samples unassigned and
# make it raise, so the test set takes whatever remains.
test_size = len_full_dataset - train_size - valid_size

train_dataset, valid_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, valid_size, test_size],
    generator=split_generator,
)

In [None]:
# Report how many samples landed in each split.
for split_label, split_ds in (
    ("Train dataset length: ", train_dataset),
    ("Validation dataset length: ", valid_dataset),
    ("Test dataset length: ", test_dataset),
):
    print(split_label, len(split_ds))

In [None]:
# Only the training loader shuffles. Validation loss and test metrics do not
# depend on sample order. A fixed order keeps evaluation reproducible and
# keeps predictions gathered in separate passes aligned with their labels.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False)

In [None]:
# Human-readable class names; presumably listed in CIFAR-10 label order
# (classes[i] is the name of integer label i) — confirm against the dataset.
classes = 'airplane automobile bird cat deer dog frog horse ship truck'.split()

---
### ResNet50 Model

In [None]:
# Load the pretrained ResNet-50 (ResNet50_Weights.DEFAULT).
# This cell rebinds the name `resnet50`, imported from torchvision.models as
# the constructor function, to the model instance. Calling the constructor
# through the `models` namespace keeps the cell safe to re-run; calling the
# bare name again would invoke the model itself and fail.
from torchvision import models

resnet50 = models.resnet50(weights=ResNet50_Weights.DEFAULT)

In [None]:
# Show the layer-by-layer structure of the loaded model
print(str(resnet50))

In [None]:
# Inspect the final fully connected layer (its in_features sizes the new head)
print(str(resnet50.fc))

In [None]:
# Freeze parameters -- intentionally disabled.
# While this stays commented out, the backbone parameters keep their default
# requires_grad=True, so the whole pretrained network is fine-tuned together
# with the new classifier head. Uncomment to train only the replacement head
# (and build the optimizer from resnet50.fc.parameters() to match).
# for param in resnet50.parameters():
#     param.requires_grad = False

### ResNet50 Custom Classifier

In [None]:
# Replace the model's final layer with a small MLP classifier head:
# in_features -> 1024 -> 512 -> 10 logits, with ReLU between linear layers.
# (A Dropout(0.5) after each ReLU was tried and is currently disabled.)
n_inputs = resnet50.fc.in_features

head_layers = [
    nn.Linear(n_inputs, 1024), nn.ReLU(),
    nn.Linear(1024, 512), nn.ReLU(),
    nn.Linear(512, 10),
]
fc = nn.Sequential(*head_layers)

resnet50.fc = fc

In [None]:
# Decide once whether a CUDA device is usable; later cells branch on this flag.
train_on_gpu = torch.cuda.is_available()
print('Training on GPU ...' if train_on_gpu else 'Training on CPU ...')

In [None]:
# Move the model to the GPU when one was detected (Module.cuda() returns the
# same module object, so rebinding the name is a no-op rename).
if train_on_gpu:
    resnet50 = resnet50.cuda()

In [None]:
# Confirm the custom classifier head replaced the original final layer
print(str(resnet50.fc))

### Loss Function & Optimizer

In [None]:
# Multi-class classification loss computed on the head's raw logits
criterion = nn.CrossEntropyLoss()

# Adam over every parameter (full fine-tuning). To train only the new head,
# use optim.Adam(resnet50.fc.parameters(), lr=0.001) instead.
learning_rate = 0.001
optimizer = optim.Adam(params=resnet50.parameters(), lr=learning_rate)

### Train the Network

In [None]:
# Fine-tune the network for a fixed number of epochs. Each epoch records the
# average per-sample training and validation loss, and the weights are saved
# to 'model_resnet50.pt' whenever the validation loss does not increase.
n_epochs = 10

# Best validation loss seen so far. Starting at +infinity means the first
# epoch is always saved. (Lowercase np.inf: the np.Inf alias was removed in
# NumPy 2.0 and raises AttributeError there.)
valid_loss_min = np.inf

train_losses, valid_losses = [], []

for epoch in range(1, n_epochs+1):
    train_loss = 0.0
    valid_loss = 0.0

    ############# Train the model #############
    resnet50.train()
    for images, labels in train_loader:
        if train_on_gpu:
            images, labels = images.cuda(), labels.cuda()

        optimizer.zero_grad()

        output = resnet50(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        # The criterion returns a batch-mean loss (default reduction).
        # Re-weighting by the batch size lets the epoch sum be averaged per
        # sample even when the last batch is smaller.
        train_loss += loss.item() * images.size(0)

    ############# Validate the model #############
    # eval() puts layers with train/eval-dependent behavior into inference
    # mode; no_grad() skips building the autograd graph.
    resnet50.eval()
    with torch.no_grad():
        for images, labels in valid_loader:
            if train_on_gpu:
                images, labels = images.cuda(), labels.cuda()

            output = resnet50(images)
            loss = criterion(output, labels)
            valid_loss += loss.item() * images.size(0)

    # Calculate average per-sample losses
    train_loss = train_loss/len(train_loader.sampler)
    valid_loss = valid_loss/len(valid_loader.sampler)

    train_losses.append(train_loss)
    valid_losses.append(valid_loss)

    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        epoch, train_loss, valid_loss))

    # Save model when validation loss decreases (or ties the best so far)
    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(valid_loss_min,valid_loss))

        torch.save(resnet50.state_dict(), 'model_resnet50.pt')
        valid_loss_min = valid_loss

In [None]:
# Plot the per-epoch training and validation loss curves.
# Epochs are numbered from 1, matching the training log. Ticks come from the
# recorded history rather than a hard-coded count, so they stay correct when
# the number of epochs changes.
epoch_numbers = range(1, len(train_losses) + 1)
plt.plot(epoch_numbers, train_losses, label='Training loss')
plt.plot(epoch_numbers, valid_losses, label='Validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(frameon=False)
plt.xticks(epoch_numbers)
plt.show()

### Load the Model with Lowest Validation Loss

In [None]:
# Restore the weights from the epoch with the lowest validation loss (the
# checkpoint written by the training loop). map_location='cpu' allows a
# checkpoint saved during a CUDA run to be loaded on a CPU-only machine.
# load_state_dict then copies the values into the model's existing parameters,
# wherever the model currently lives.
resnet50.load_state_dict(torch.load('model_resnet50.pt', map_location='cpu'))