# Model Optimization

### Mount Google Drive

In [None]:
# Mount the user's Google Drive into the Colab filesystem; every dataset, model
# and log path used in this notebook lives under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

### Import required libraries

In [None]:
pip install --upgrade nni

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.tensorboard import SummaryWriter
from torch.nn.utils.prune import identity
from datetime import datetime
from datetime import date
from itertools import product
import os
import torchvision.models as tmodels
from functools import partial
import collections
import nni
from nni.algorithms.compression.pytorch.quantization import NaiveQuantizer
from nni.algorithms.compression.pytorch.pruning import LevelPruner
import time
from nni.algorithms.compression.pytorch.pruning import AGPPruner

### Prepare data

In [None]:
# Folder name of the image dataset, appended to the KASHIKO/DATASET root.
dataset_name = "TRG_3_FINAL"
# Value looked up in the "Dataset" column of the normalisation-parameter CSV.
norm_param_dataset_ref = 3

In [None]:
# Retrieve normalisation parameters (per-channel mean and std) for the
# reference dataset from the CSV stored on the drive.

norm_param_df = pd.read_csv('/content/drive/MyDrive/KASHIKO/DATASET/TRG_DATASET_NORM_PARAM.csv')

# Build the row selector once instead of once per column. Both sides are
# compared as strings so the lookup works whether pandas parsed the "Dataset"
# column as integers or as text.
norm_param_mask = norm_param_df["Dataset"].astype(str) == str(norm_param_dataset_ref)
norm_param_row = norm_param_df.loc[norm_param_mask]
if len(norm_param_row) != 1:
    # .item() below needs exactly one matching row; fail with a clear message.
    raise ValueError(
        f'Expected exactly one row for dataset {norm_param_dataset_ref!r} '
        f'in the normalisation parameter file, found {len(norm_param_row)}')

meanR, meanG, meanB = (norm_param_row[column].item() for column in ("meanR", "meanG", "meanB"))
stdR, stdG, stdB = (norm_param_row[column].item() for column in ("stdR", "stdG", "stdB"))

In [None]:
# Prepare normalized dataset: images are converted to tensors and normalised
# per channel with the mean/std values retrieved above.
dataset = datasets.ImageFolder(
    '/content/drive/MyDrive/KASHIKO/DATASET/' + dataset_name,
    transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((meanR, meanG, meanB), (stdR, stdG, stdB))
    ])
)

# Number of images held out for validation; the rest is used for training.
val_split_size = 1000
if len(dataset) <= val_split_size:
    raise ValueError(
        f'Dataset holds {len(dataset)} images; more than {val_split_size} '
        'are needed to leave a non-empty training split')

# NOTE(review): the split is drawn randomly on every run, so the validation
# images may overlap with images the loaded checkpoints were trained on -
# confirm whether a fixed split/seed is required.
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [len(dataset) - val_split_size, val_split_size])

trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=100,
        shuffle=True,
        num_workers=2,
        drop_last=True)
# Validation only aggregates counts, so sample order is irrelevant: no
# shuffling, and no batch is ever dropped.
valloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        drop_last=False)

### Load models

In [None]:
class Net(nn.Module):
    """Small convolutional classifier that outputs two logits per image.

    Two stages of Conv2d(kernel 5) -> BatchNorm2d -> ReLU -> 2x2 max-pool are
    followed by three fully connected layers. The first linear layer expects
    24 feature maps of 53x53 after the second pooling step.
    """

    # Flattened feature size fed to the first fully connected layer.
    _FLAT_FEATURES = 24 * 53 * 53

    def __init__(self):
        super().__init__()
        # First convolutional stage: 3 -> 12 channels.
        self.conv1 = nn.Conv2d(3, 12, 5)
        self.bn1 = nn.BatchNorm2d(12)
        self.pool1 = nn.MaxPool2d(2, 2)
        # Second convolutional stage: 12 -> 24 channels.
        self.conv2 = nn.Conv2d(12, 24, 5)
        self.bn2 = nn.BatchNorm2d(24)
        self.pool2 = nn.MaxPool2d(2, 2)
        # Classifier head ending in two output units.
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 2)

    def forward(self, x):
        """Return the raw (pre-softmax) class scores for the batch ``x``."""
        features = self.conv1(x)
        features = self.bn1(features)
        features = self.pool1(F.relu(features))

        features = self.conv2(features)
        features = self.bn2(features)
        features = self.pool2(F.relu(features))

        # Flatten everything except the (inferred) batch dimension.
        flat = features.view(-1, self._FLAT_FEATURES)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [None]:
# Model 1: rebuild the architecture and restore the saved weights.
net1 = Net()
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only runtime; this notebook runs every model on the CPU.
state_dict1 = torch.load('/content/drive/MyDrive/KASHIKO/MODELS/model_2021-05-29_12:16:11_ trg_dataset1 batch_size=100 learning_rate=0.001 scheduler_step_size=5 scheduler_gamma=1 weight_decay=0 epoch_number=11 accuracy=97.6.pth', map_location='cpu')
net1.load_state_dict(state_dict1)

In [None]:
# Model 2: rebuild the architecture and restore the saved weights.
# The model is bound to `net2`, the name the export cell further down uses.
net2 = Net()
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only runtime.
state_dict2 = torch.load('/content/drive/MyDrive/KASHIKO/MODELS/model_2021-05-29_18:41:53_ trg_dataset2 batch_size=100 learning_rate=0.001 scheduler_step_size=5 scheduler_gamma=1 weight_decay=0 epoch_number=16 accuracy=98.1.pth', map_location='cpu')
# `net` is the working model used by the evaluation/pruning cells; keep it
# pointing at the model loaded by this cell, as before.
net = net2
net.load_state_dict(state_dict2)

In [None]:
# Model 3: rebuild the architecture and restore the saved weights.
# The model is bound to `net3`, the name the export cell further down uses.
net3 = Net()
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only runtime.
state_dict3 = torch.load('/content/drive/MyDrive/KASHIKO/MODELS/model_2021-05-30_08:24:13_ trg_dataset3 batch_size=100 learning_rate=0.001 scheduler_step_size=5 scheduler_gamma=1 weight_decay=0 epoch_number=19 accuracy=98.2.pth', map_location='cpu')
# `net` is the working model used by the evaluation/pruning cells; keep it
# pointing at the model loaded by this cell, as before.
net = net3
net.load_state_dict(state_dict3)


In [None]:
# Temperature-scaling coefficients, one per model; the logits are divided by
# the coefficient before the softmax. Values were computed previously.
temp_factor_net3 = 3.913
temp_factor_net2 = 4.289
temp_factor_net1 = 3.661

### Re-save models in the legacy serialization format so that older PyTorch versions can load them

In [None]:
# Write the three models' weights with the new zipfile serialization disabled.
export_dir = '/content/drive/MyDrive/KASHIKO/MODELS/'
save_legacy_format = partial(torch.save, _use_new_zipfile_serialization=False)

save_legacy_format(net1.state_dict(), export_dir + 'best_model1.pth')
save_legacy_format(net2.state_dict(), export_dir + 'best_model2.pth')
save_legacy_format(net3.state_dict(), export_dir + 'best_model3.pth')

### Test performance (accuracy and speed) of pre-pruning model

In [None]:
# Evaluate the unpruned model on the validation loader and time the pass.
# time.clock() was removed in Python 3.8; perf_counter() is the replacement
# for measuring elapsed time.
start_time = time.perf_counter()

# Set metrics to 0
total_all = 0            # all validation samples
correct_all = 0
total_sure = 0           # samples whose plain softmax confidence is > 0.97
correct_sure = 0
total_sure_temp = 0      # samples whose temperature-scaled confidence is > 0.73
correct_sure_temp = 0
# Define softmax function
m = nn.Softmax(dim=1)

# Perform a forward pass of the dataset through the unpruned model
# and compute performance and time to complete processing
net.eval()  # evaluation mode only needs to be set once, not per batch
with torch.no_grad():
    for images, labels in valloader:
        out = net(images)
        _, predicted = torch.max(out.data, 1)
        batch_total = labels.size(0)
        batch_correct = (predicted == labels).sum().item()

        predicted_soft = m(out)
        # NOTE(review): temp_factor_net2 is applied to whichever model `net`
        # currently holds - confirm it is the coefficient fitted for it.
        predicted_soft_temp = m(out/temp_factor_net2)
        if predicted_soft.max().item() > 0.97:
            total_sure += batch_total
            correct_sure += batch_correct
        if predicted_soft_temp.max().item() > 0.73:
            total_sure_temp += batch_total
            correct_sure_temp += batch_correct
        total_all += batch_total
        correct_all += batch_correct

# Report NaN instead of raising ZeroDivisionError when no sample falls into a
# group (e.g. nothing exceeds a confidence threshold).
test_accuracy_all = 100 * correct_all / total_all if total_all else float('nan')
test_accuracy_sure = 100 * correct_sure / total_sure if total_sure else float('nan')
test_accuracy_sure_temp = 100 * correct_sure_temp / total_sure_temp if total_sure_temp else float('nan')

print(test_accuracy_all)
print(test_accuracy_sure)
print(100 * total_sure/total_all if total_all else float('nan'))
print(test_accuracy_sure_temp)
print(100 * total_sure_temp/total_all if total_all else float('nan'))

end_time = time.perf_counter()
print(end_time - start_time)

### Apply LevelPruner

In [None]:
# Define pruning parameters
config_list = [{ 'sparsity': 0.3, 'op_types': ['default'] }]

# Apply pruning to model
pruner = LevelPruner(net, config_list)
pruner.compress()

### Test performance (accuracy and speed) of post-pruning pre-retraining model

In [None]:
# Evaluate the pruned (not yet retrained) model on the validation loader and
# time the pass.
# time.clock() was removed in Python 3.8; perf_counter() is the replacement
# for measuring elapsed time.
start_time = time.perf_counter()

# Set metrics to 0
total_all = 0            # all validation samples
correct_all = 0
total_sure = 0           # samples whose plain softmax confidence is > 0.97
correct_sure = 0
total_sure_temp = 0      # samples whose temperature-scaled confidence is > 0.73
correct_sure_temp = 0
# Define softmax function
m = nn.Softmax(dim=1)

# Perform a forward pass of the dataset through the pruned (but not retrained) model
# and compute performance and time to complete processing
net.eval()  # evaluation mode only needs to be set once, not per batch
with torch.no_grad():
    for images, labels in valloader:
        out = net(images)
        _, predicted = torch.max(out.data, 1)
        batch_total = labels.size(0)
        batch_correct = (predicted == labels).sum().item()

        predicted_soft = m(out)
        # NOTE(review): temp_factor_net2 is applied to whichever model `net`
        # currently holds - confirm it is the coefficient fitted for it.
        predicted_soft_temp = m(out/temp_factor_net2)
        if predicted_soft.max().item() > 0.97:
            total_sure += batch_total
            correct_sure += batch_correct
        if predicted_soft_temp.max().item() > 0.73:
            total_sure_temp += batch_total
            correct_sure_temp += batch_correct
        total_all += batch_total
        correct_all += batch_correct

# Report NaN instead of raising ZeroDivisionError when no sample falls into a
# group (e.g. nothing exceeds a confidence threshold).
test_accuracy_all = 100 * correct_all / total_all if total_all else float('nan')
test_accuracy_sure = 100 * correct_sure / total_sure if total_sure else float('nan')
test_accuracy_sure_temp = 100 * correct_sure_temp / total_sure_temp if total_sure_temp else float('nan')

print(test_accuracy_all)
print(test_accuracy_sure)
print(100 * total_sure/total_all if total_all else float('nan'))
print(test_accuracy_sure_temp)
print(100 * total_sure_temp/total_all if total_all else float('nan'))

end_time = time.perf_counter()
print(end_time - start_time)

### Retraining pruned model 

In [None]:
# Define optimizer
optimizer = torch.optim.Adam(net1.parameters(), lr=0.001)

In [None]:
# Load pruned model
net1_pruned = LevelPruner(net1, config_list)#, optimizer, pruning_algorithm='level')
net1_pruned.compress()

In [None]:
# Set training hyperparameters
parameters = dict(learning_rate = [0.001],
                  batch_size = [100],
                  weight_decay = [0],
                  epoch_number = [5],
                  scheduler_step_size = [5],
                  scheduler_gamma = [1]   )
param_values = [v for v in parameters.values()]
trg_dataset_ref = 3
valloader_size = 1000

In [None]:
# Perform model training
for learning_rate, batch_size, weight_decay, epoch_number, scheduler_step_size, scheduler_gamma in product(*param_values): 

    # Define Optimizer and scheduler
    optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay = weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step_size, gamma=scheduler_gamma)
    criterion = nn.CrossEntropyLoss()

    # Initialize tensorboard SummaryWriter file/directory
    date_now = str(date.today())
    time_now = datetime.now().strftime("%H:%M:%S")    
    log_dir_root = os.path.join('/content/drive/My Drive/KASHIKO/RUNS', date_now + '_' + time_now + '_')
    comment = f' trg_dataset{trg_dataset_ref} batch_size={batch_size} learning_rate={learning_rate} scheduler_step_size={scheduler_step_size} scheduler_gamma={scheduler_gamma} weight_decay={weight_decay} epoch_number={epoch_number}'
    log_dir = log_dir_root + comment
    tb = SummaryWriter(log_dir)
    
    best_accuracy = 0.0

    for epoch in range(epoch_number):  # loop over the dataset multiple times

        net.train()
        net.requires_grad = True
        trg_running_loss = 0.0
        trg_epoch_loss = 0.0

        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            trg_running_loss += loss.item()
            trg_epoch_loss += loss.item()
            if i % 10 == 9:    # print every 10 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, trg_running_loss / 10))
                trg_running_loss = 0.0

        # At the end of each epoch, check the performance of the network using the validation dataset
        correct = 0.0
        total = 0.0
        TP = 0.0
        TN = 0.0
        FP = 0.0
        FN = 0.0
        val_loss = 0.0
        with torch.no_grad():
            for data in valloader:
                net.eval()
                images, labels = data
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                TP += (predicted == labels & labels == 0).sum().item()
                TN += (predicted == labels & labels == 1).sum().item()
                FP += (predicted != labels & predicted == 0).sum().item()
                FN += (predicted != labels & predicted == 1).sum().item()
                val_loss += criterion(outputs, labels)
        val_accuracy = 100 * correct / total
        val_sensitivity_TPR = TP/(FN+TP) if (FN+TP)!=0.0 else -1.0
        val_specificity_TNR = TN/(TN+FP) if (TN+FP)!=0.0 else -1.0
        val_FPR = FP/(TN+FP) if (TN+FP)!=0.0 else -1.0
        val_FNR = FN/(FN+TP) if (FN+TP)!=0.0 else -1.0
        val_precision = TP/(TP+FP) if (TP+FP)!=0.0 else -1.0
        val_recall = TP/(FN+TP) if (FN+TP)!=0.0 else -1.0
        inv_val_recall = 1/val_recall if val_recall!=0.0 else -1.0
        inv_val_precision = 1/val_precision if val_precision!=0.0 else -1.0
        val_F1_score = 2/(inv_val_precision + inv_val_recall) if (inv_val_precision + inv_val_recall)!=0.0 else -1.0
        print(f'Accuracy of the network on the 1000 test images:{val_accuracy}')

        # Store metrics and other parameters in tensorboard SummaryWriter
        # Metrics
        tb.add_scalar('Training Loss', trg_epoch_loss/(int((len(dataset)-valloader_size)/batch_size) * batch_size), epoch+1)
        tb.add_scalar('Validation Loss', val_loss/valloader_size, epoch+1)
        tb.add_scalar('Accuracy', val_accuracy, epoch+1)
        tb.add_scalar('Sensitivity TPR', val_sensitivity_TPR, epoch+1)
        tb.add_scalar('Specificity TNR', val_specificity_TNR, epoch+1)        
        tb.add_scalar('FPR', val_FPR, epoch+1)
        tb.add_scalar('FNR', val_FNR, epoch+1)        
        tb.add_scalar('Precision', val_precision, epoch+1)
        tb.add_scalar('Recall', val_recall, epoch+1)        
        tb.add_scalar('F1 Score', val_F1_score, epoch+1)
        # DEBUG
        tb.add_scalar('False Positive', FP, epoch+1)        
        tb.add_scalar('False Negative', FN, epoch+1)
        tb.add_scalar('True Positive', TP, epoch+1)        
        tb.add_scalar('True Negative', TN, epoch+1)
        # Training parameters
        tb.add_scalar('Learning rate (scheduler)', optimizer.param_groups[0]["lr"], epoch+1)
        # NN Layers parameters
        #tb.add_histogram('conv1.bias', net.conv1.bias, epoch+1)
        #tb.add_histogram('conv1.weight', net.conv1.weight, epoch+1)
        #tb.add_histogram('conv1.weight.grad',net.conv1.weight.grad,epoch+1)
        #tb.add_histogram('bn1.bias', net.bn1.bias, epoch+1)
        #tb.add_histogram('bn1.weight', net.bn1.weight, epoch+1)
        #tb.add_histogram('bn1.weight.grad',net.bn1.weight.grad,epoch+1)      
        #tb.add_histogram('conv2.bias', net.conv2.bias, epoch+1)
        #tb.add_histogram('conv2.weight', net.conv2.weight, epoch+1)
        #tb.add_histogram('conv2.weight.grad',net.conv2.weight.grad,epoch+1)
        #tb.add_histogram('bn2.bias', net.bn2.bias, epoch+1)
        #tb.add_histogram('bn2.weight', net.bn2.weight, epoch+1)
        #tb.add_histogram('bn2.weight.grad',net.bn2.weight.grad,epoch+1)  
        #tb.add_histogram('fc1.bias', net.fc1.bias, epoch+1)
        #tb.add_histogram('fc1.weight', net.fc1.weight, epoch+1)
        #tb.add_histogram('fc1.weight.grad',net.fc1.weight.grad,epoch+1)
        #tb.add_histogram('fc2.bias', net.fc2.bias, epoch+1)
        #tb.add_histogram('fc2.weight', net.fc2.weight, epoch+1)
        #tb.add_histogram('fc2.weight.grad',net.fc2.weight.grad,epoch+1)
        #tb.add_histogram('fc3.bias', net.fc3.bias, epoch+1)
        #tb.add_histogram('fc3.weight', net.fc3.weight, epoch+1)
        #tb.add_histogram('fc3.weight.grad',net.fc3.weight.grad,epoch+1)
        
        
        # Save models
        comment = f' trg_dataset{trg_dataset_ref} batch_size={batch_size} learning_rate={learning_rate} scheduler_step_size={scheduler_step_size} scheduler_gamma={scheduler_gamma} weight_decay={weight_decay} epoch_number={epoch} accuracy={val_accuracy}'
        torch.save(net.state_dict(),'/content/drive/MyDrive/KASHIKO/MODELS/pruning_testmodel_' + date_now + '_' + time_now + '_' + comment + '_SAVE.pth')
        pruner.export_model(model_path='/content/drive/MyDrive/KASHIKO/MODELS/pruning_testmodel_' + date_now + '_' + time_now + '_' + comment + '_EXPORT.pth', mask_path='/content/drive/MyDrive/KASHIKO/MODELS/pruning_testmodel_' + date_now + '_' + time_now + '_' + comment + '_MASK.pth')
        
        # Update learning rate
        scheduler.step()
        
    # At the end of the training, close the tensorboard SummaryWriter and save the model to the drive
    tb.close()
    print('Training Completed')

In [None]:
# Export the pruned model's weights and the pruning mask to the drive.
# NOTE(review): the model file is named for model 3 while the mask file is
# named for model 1 - confirm which name is intended.
final_model_path = '/content/drive/MyDrive/KASHIKO/MODELS/last_model3_pruned.pth'
final_mask_path = '/content/drive/MyDrive/KASHIKO/MODELS/mask_last_model1_pruned.pth'
pruner.export_model(model_path=final_model_path, mask_path=final_mask_path)

In [None]:
# Save the working model's state dict with the new zipfile serialization
# disabled.
# NOTE(review): `net` has been wrapped by the pruner at this point - confirm
# the saved keys load into a plain `Net` where this file is consumed.
torch.save(
    obj=net.state_dict(),
    f='/content/drive/MyDrive/KASHIKO/MODELS/best_model3_pruned.pth',
    _use_new_zipfile_serialization=False,
)

### Test performance (accuracy and speed) of post-pruning post-retraining model

In [None]:
# Initialise model
net3_pruned = Net()

In [None]:
# Load pruned model
state_dict3_pruned = torch.load('/content/drive/MyDrive/KASHIKO/MODELS/last_model3_pruned.pth')
net3_pruned.load_state_dict(state_dict3_pruned)

In [None]:
# Evaluate the pruned and retrained model on the validation loader and time
# the pass.
# time.clock() was removed in Python 3.8; perf_counter() is the replacement
# for measuring elapsed time.
start_time = time.perf_counter()

# Set metrics to 0
total_all = 0            # all validation samples
correct_all = 0
total_sure = 0           # samples whose plain softmax confidence is > 0.97
correct_sure = 0
total_sure_temp = 0      # samples whose temperature-scaled confidence is > 0.73
correct_sure_temp = 0
# Define softmax function
m = nn.Softmax(dim=1)

# Perform a forward pass of the dataset through the pruned and retrained model
# and compute performance and time to complete processing
net3_pruned.eval()  # evaluation mode only needs to be set once, not per batch
with torch.no_grad():
    for images, labels in valloader:
        out = net3_pruned(images)
        _, predicted = torch.max(out.data, 1)
        batch_total = labels.size(0)
        batch_correct = (predicted == labels).sum().item()

        predicted_soft = m(out)
        # NOTE(review): the model is named for model 3 but is scaled with
        # temp_factor_net1 - confirm which coefficient is intended.
        predicted_soft_temp = m(out/temp_factor_net1)
        if predicted_soft.max().item() > 0.97:
            total_sure += batch_total
            correct_sure += batch_correct
        if predicted_soft_temp.max().item() > 0.73:
            total_sure_temp += batch_total
            correct_sure_temp += batch_correct
        total_all += batch_total
        correct_all += batch_correct

# Report NaN instead of raising ZeroDivisionError when no sample falls into a
# group (e.g. nothing exceeds a confidence threshold).
test_accuracy_all = 100 * correct_all / total_all if total_all else float('nan')
test_accuracy_sure = 100 * correct_sure / total_sure if total_sure else float('nan')
test_accuracy_sure_temp = 100 * correct_sure_temp / total_sure_temp if total_sure_temp else float('nan')

print(test_accuracy_all)
print(test_accuracy_sure)
print(100 * total_sure/total_all if total_all else float('nan'))
print(test_accuracy_sure_temp)
print(100 * total_sure_temp/total_all if total_all else float('nan'))

end_time = time.perf_counter()
print(end_time - start_time)