In [None]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset

import random
from copy import deepcopy

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np

from imblearn.metrics import (geometric_mean_score, sensitivity_score, 
                              specificity_score)
from sklearn.metrics import (balanced_accuracy_score, precision_score, 
                             recall_score, f1_score)


import scipy as sc
import matplotlib.style

# Publication-style plotting defaults: 14pt text (13pt y-ticks), a dotted
# black grid, and Times/STIX serif fonts for both text and math.
params = {
    **{key: 14 for key in ('legend.fontsize',
                           'axes.labelsize',
                           'axes.titlesize',
                           'xtick.labelsize')},
    'ytick.labelsize': 13,
    'grid.color': 'k',
    'grid.linestyle': ':',
    'grid.linewidth': 0.8,
    'mathtext.fontset': 'stix',
    'mathtext.rm': 'serif',
    'font.family': 'serif',
    'font.serif': "Times New Roman",  # or "Times"
}
matplotlib.rcParams.update(params)

In [None]:
# Pick the compute device. The original always selected "cuda", which fails
# on the first `.to(device)` on a machine without a GPU; fall back to CPU.
if torch.cuda.is_available():
    print("life is good")
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
dataset = 'CIFAR-10'   # 'CIFAR-10' or 'CIFAR-100'
batch_size = 32
val_size = 0.2         # fraction of the training set held out for validation
epochs = 200

# Seed every RNG before anything stochastic runs. The train/val split and the
# model's weight initialisation both happen in cells *before* the training
# cell re-seeds, so without this they differ from run to run.
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def seed_worker(worker_id):
    """Seed numpy and `random` inside a DataLoader worker process.

    Passed as ``worker_init_fn``; the per-worker seed is derived from torch's
    initial seed for that worker. Only invoked when ``num_workers > 0``.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)  # original used the unbound name `numpy`
    random.seed(worker_seed)


def build_loaders(train_dataset, eval_dataset, test_dataset, generator,
                  split_seed, val_fraction, loader_batch_size):
    """Stratified train/validation split plus DataLoaders for all three sets.

    Args:
        train_dataset: training set with the (augmenting) training transform.
        eval_dataset: the *same* training images/targets, but with the
            evaluation transform; validation samples are drawn from it so the
            validation loss is not computed on augmented images.
        test_dataset: held-out test set.
        generator: torch.Generator driving the loaders' shuffling.
        split_seed: random_state for the split, so it is reproducible.
        val_fraction: fraction of training samples used for validation.
        loader_batch_size: batch size for all loaders.

    Returns:
        (train_split, val_split, trainloader, validloader, testloader)
    """
    train_indices, val_indices = train_test_split(
        list(range(len(train_dataset))),
        stratify=train_dataset.targets,
        test_size=val_fraction,
        random_state=split_seed,
    )

    train_split = Subset(train_dataset, train_indices)
    val_split = Subset(eval_dataset, val_indices)

    loader_kwargs = dict(batch_size=loader_batch_size,
                         worker_init_fn=seed_worker, generator=generator)
    trainloader = DataLoader(train_split, shuffle=True, **loader_kwargs)
    # Order does not affect an averaged loss; keep evaluation deterministic.
    validloader = DataLoader(val_split, shuffle=False, **loader_kwargs)
    testloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    return train_split, val_split, trainloader, validloader, testloader


if dataset == 'CIFAR-10':
    g = torch.Generator()
    g.manual_seed(1)

    # NOTE(review): `util` is never imported in this notebook, so this line
    # raises NameError as written. Import the module providing `Cutout`
    # (presumably a local util.py) in the imports cell — TODO confirm.
    train_transform = transforms.Compose([
        util.Cutout(num_cutouts=2, size=8, p=0.8),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    print('Datasets are being downloaded...')

    trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    # Same training images, un-augmented: the source of the validation subset.
    valset = datasets.CIFAR10(root='./data', train=True, download=True, transform=test_transform)
    testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

    train_split, val_split, trainloader, validloader, testloader = build_loaders(
        trainset, valset, testset, g,
        split_seed=1, val_fraction=val_size, loader_batch_size=batch_size)

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    output_size_network = len(classes)

    print('Download finished!')

elif dataset == 'CIFAR-100':
    g = torch.Generator()
    g.manual_seed(2)

    # No augmentation here, so one transform serves train, validation and test.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    print('Datasets are being downloaded...')

    trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
    testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

    train_split, val_split, trainloader, validloader, testloader = build_loaders(
        trainset, trainset, testset, g,
        split_seed=2, val_fraction=val_size, loader_batch_size=batch_size)

    classes = ('apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus',
               'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab',
               'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp',
               'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange',
               'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray',
               'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower',
               'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm')

    output_size_network = len(classes)

    print('Download finished!')

else:
    # Previously an unknown name silently left every loader undefined.
    raise ValueError(f"Unsupported dataset {dataset!r}; expected 'CIFAR-10' or 'CIFAR-100'")

In [None]:
class ResidualBlock(nn.Module):
    """
    Two-convolution residual block: conv-BN-ReLU -> conv-BN-ReLU, plus a shortcut.

    Note: unlike the original He et al. block, the second ReLU is applied
    *before* the shortcut is added, i.e. the output is
    ``relu(bn2(conv2(relu(bn1(conv1(x)))))) + shortcut(x)``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        kernel_size: kernel size of both convolutions.
        padding: padding of both convolutions.
        stride: stride of the first convolution only. When it is not 1, or
            when ``in_channels != out_channels``, the shortcut is a 1x1
            conv + BN projection so the two branches can be summed.
        bn_momentum: BatchNorm running-statistics momentum in PyTorch's
            convention, i.e. the weight given to the *current* batch. The
            original hard-coded 0.9 (a TensorFlow-style value), which makes the
            eval-time statistics essentially those of the last training batch.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, stride, bn_momentum=0.1):
        super().__init__()
        self.conv_res1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                   padding=padding, stride=stride, bias=False)
        self.conv_res1_bn = nn.BatchNorm2d(num_features=out_channels, momentum=bn_momentum)
        self.conv_res2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size,
                                   padding=padding, bias=False)
        self.conv_res2_bn = nn.BatchNorm2d(num_features=out_channels, momentum=bn_momentum)

        if stride != 1 or in_channels != out_channels:
            # Project the shortcut whenever its shape would not match the main
            # branch (spatial downsampling and/or a channel change).
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(num_features=out_channels, momentum=bn_momentum)
            )
        else:
            self.downsample = None

        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        residual = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.conv_res1_bn(self.conv_res1(x)))
        out = self.relu(self.conv_res2_bn(self.conv_res2(out)))

        # Out-of-place add: ReLU keeps its output for the backward pass, so the
        # original in-place `out += residual` made autograd raise at backward().
        return out + residual


class Net(nn.Module):
    """
    ResNet-9 style convolutional network sized for 3x32x32 inputs.

    Four 2x2 max-pools reduce 32x32 to 2x2, so the flattened feature vector
    has 256 * 2 * 2 = 1024 entries.

    Args:
        num_classes: size of the output layer. Defaults to the notebook-level
            ``output_size_network`` set by the data-loading cell.
        bn_momentum: momentum of this network's own BatchNorm layers, in
            PyTorch's convention (weight of the current batch). The original
            hard-coded 0.9, a TensorFlow-style value that makes eval-time
            statistics track only the most recent training batch.
    """
    def __init__(self, num_classes=None, bn_momentum=0.1):
        super().__init__()
        if num_classes is None:
            num_classes = output_size_network

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=64, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=128, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ResidualBlock(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=256, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=256, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ResidualBlock(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.fc = nn.Linear(in_features=1024, out_features=num_classes, bias=True)

    def forward(self, x):
        out = self.conv(x)
        # Keep the batch dimension explicit; `view(-1, C*H*W)` would silently
        # fold a wrong-sized input into a different batch size.
        out = torch.flatten(out, 1)
        return self.fc(out)

# Build the network, wrap it in DataParallel (so its state_dict keys carry the
# wrapper's "module." prefix), and move it to the selected device.
model = nn.DataParallel(Net()).to(device)

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

# Best validation loss seen so far; the training loop saves a checkpoint
# whenever this is matched or beaten. `np.Inf` was removed in NumPy 2.0.
val_loss_min1 = np.inf

In [None]:
# Per-epoch mean training / validation losses, for plotting later.
train_loss_hist = []
val_loss_hist = []

print('-------------------')
print('MODEL: ResNet9')
print('-------------------')

# Re-seed right before training so batch order and augmentation are repeatable.
np.random.seed(1)
torch.manual_seed(1)
random.seed(1)
torch.cuda.manual_seed(1)

for epoch in range(1, epochs + 1):
    train_loss = 0.0
    val_loss = 0.0

    # --- optimisation pass -------------------------------------------------
    model.train()
    for data, labels in trainloader:
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # --- validation pass ---------------------------------------------------
    model.eval()
    with torch.no_grad():
        for data, labels in validloader:
            data, labels = data.to(device), labels.to(device)
            outputs = model(data)
            loss = criterion(outputs, labels)

            val_loss += loss.item()

    # Batch-averaged losses (mean of per-batch means).
    train_loss = train_loss / len(trainloader)
    val_loss = val_loss / len(validloader)
    train_loss_hist.append(train_loss)
    val_loss_hist.append(val_loss)

    print('Epoch: {} \tTraining Loss: {:.3f} \tValidation Loss: {:.3f}'.format(
        epoch, train_loss, val_loss))

    # Checkpoint on the best validation loss so far. The original compared
    # against `val_loss_min[network][seed-1]`, names that do not exist in this
    # notebook (NameError on the first epoch).
    if val_loss <= val_loss_min1:
        print('Validation loss decreased ({:.3f} --> {:.3f}).  Saving model ...'.format(
            val_loss_min1,
            val_loss))
        # torch.save serialises immediately, so no defensive deepcopy is needed.
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
        }, 'balance_cifar.pt')
        val_loss_min1 = val_loss

In [None]:
# Restore the best checkpoint written by the training loop. The original
# loaded into `model1` / `optimizer1`, which were never defined; they are now
# aliases of the trained `model` / `optimizer` (same DataParallel wrapper, so
# the checkpoint's state_dict keys match). `device` is reused from the setup
# cell instead of being hard-coded to "cuda" again.
model1 = model
optimizer1 = optimizer

checkpoint1 = torch.load('balance_cifar.pt', map_location=device)
model1.load_state_dict(checkpoint1['model_state_dict'])
optimizer1.load_state_dict(checkpoint1['optimizer_state_dict'])
epoch1 = checkpoint1['epoch']
model1.to(device);

In [None]:
# Name under which the evaluation cell refers to the restored network.
model_updated: nn.Module = model1

In [None]:
# Results of the test-set evaluation below. The *_hist / global_accuracy lists
# hold one entry per evaluation run.
global_accuracy = []
test_loss_hist = []
labels_list = []   # one numpy array of true labels per test batch
pred_list = []     # one numpy array of predicted labels per test batch
class_accuracy_model = []

test_loss = 0.0
correct = 0
total = 0

correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

model_updated.eval()
with torch.no_grad():
    for data, labels in testloader:
        images, labels = data.to(device), labels.to(device)
        output = model_updated(images)

        loss = criterion(output, labels)
        test_loss += loss.item()

        _, pred = torch.max(output, 1)

        total += labels.size(0)
        correct += (pred == labels).sum().item()

        # Move to host once per batch. Indexing `classes` with individual
        # device tensors forces a device sync for every sample.
        pred = pred.cpu().numpy()
        labels = labels.cpu().numpy()

        for label, p in zip(labels, pred):
            classname = classes[label]
            if label == p:
                correct_pred[classname] += 1
            total_pred[classname] += 1

        pred_list.append(pred)
        labels_list.append(labels)

test_loss = test_loss / len(testloader)
test_loss_hist.append(test_loss)

accuracy = 100 * correct / total
global_accuracy.append(accuracy)

for classname, correct_count in correct_pred.items():
    class_total = total_pred[classname]
    # A class absent from the test set would otherwise raise ZeroDivisionError.
    model_seed_accuracy = 100 * float(correct_count) / class_total if class_total else float('nan')
    class_accuracy_model.append(model_seed_accuracy)

In [None]:
# The history lists hold one value per evaluation run; report the latest.
# (Formatting the list itself with `{:.3f}` raises TypeError.)
print('Test Loss for ResNet9: {:.3f}'.format(test_loss_hist[-1]))
print('Accuracy for ResNet9: {:.3f}%'.format(global_accuracy[-1]))

In [None]:
# `pred_list` / `labels_list` hold one numpy array per test batch. Flatten them
# into plain lists of Python ints for the sklearn/imblearn metric functions.
# The original called `.tolist()` on a Python list (AttributeError) and
# replaced `labels_list` with the single int `labels_list[0]`.
# np.hstack also accepts an already-flat list, so re-running this cell is safe.
pred_model1 = np.hstack(pred_list).tolist()
labels_list = np.hstack(labels_list).tolist()

In [None]:
# Each score is computed in two variants: micro- and macro-averaged for most
# metrics, and plain vs. chance-adjusted for balanced accuracy.
f1_micro, f1_macro = (f1_score(labels_list, pred_model1, average=avg)
                      for avg in ('micro', 'macro'))

gmean_micro, gmean_macro = (geometric_mean_score(labels_list, pred_model1, average=avg)
                            for avg in ('micro', 'macro'))

bac, bac_adj = (balanced_accuracy_score(labels_list, pred_model1, adjusted=flag)
                for flag in (False, True))

sens_micro, sens_macro = (sensitivity_score(labels_list, pred_model1, average=avg)
                          for avg in ('micro', 'macro'))

spec_micro, spec_macro = (specificity_score(labels_list, pred_model1, average=avg)
                          for avg in ('micro', 'macro'))

prec_micro, prec_macro = (precision_score(labels_list, pred_model1, average=avg)
                          for avg in ('micro', 'macro'))

rec_micro, rec_macro = (recall_score(labels_list, pred_model1, average=avg)
                        for avg in ('micro', 'macro'))

In [None]:
# Metric values and their display names, kept in matching order.
metrics_list = [
    f1_micro, f1_macro,
    gmean_micro, gmean_macro,
    bac, bac_adj,
    sens_micro, sens_macro,
    spec_micro, spec_macro,
    prec_micro, prec_macro,
    rec_micro, rec_macro,
]

names = [
    "F1 Micro", "F1 Macro",
    "GMean Micro", "GMean Macro",
    "Balanced Accuracy", "Adjusted Balanced Accuracy",
    "Sensitivity Micro", "Sensitivity Macro",
    "Specificity Micro", "Specificity Macro",
    "Precision Micro", "Precision Macro",
    "Recall Micro", "Recall Macro",
]

# One framed section per metric, value shown to two decimals.
for name, metric in zip(names, metrics_list):
    print('---------------')
    print('MODEL: ResNet9')
    print('METRIC:', name)
    print('---------------')
    print(' {:.2f}'.format(metric))