# Classifier

## Import libraries

In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
from tqdm.notebook import tqdm, trange
from time import sleep
from torchvision.io import read_image
from torchvision.transforms import ToTensor
from torchvision import transforms
from torchsummary import summary
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from collections import Counter
import numpy as np
from sklearn.metrics import roc_curve, auc
from sklearn import datasets
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import StratifiedKFold


In [29]:
seed = 23
save_model = True
load_model = False

models_directory = "models"
metrics_directory = "metrics"
metrics_name = "metrics.csv"
dataset_name = "Fer2013_merge_selected_uniform"
train_df_path = f"../datasets/{dataset_name}/train"
test_df_path = f"../datasets/{dataset_name}/test"

batch_size = 64

## Functions definition

In [30]:
def get_default_device():
    """Pick GPU if available, else CPU"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')
    
def to_device(data, device):
    """Move tensor(s) to chosen device"""
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

In [31]:
@torch.no_grad()
def evaluate(model, val_loader):
    model.eval()
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD, scheduler_func=None):
    history = []
    optimizer = opt_func(model.parameters(), lr, betas=(0.5,0.999))
    scheduler = scheduler_func(optimizer, mode='min', patience=3)
    for epoch in tqdm(range(epochs), desc = "Current Epoch"):
        # Training Phase 
        model.train()
        train_losses = []
        for batch in tqdm(train_loader, desc = f"Epoch: {epoch}", leave= False):
            optimizer.zero_grad()
            loss = model.training_step(batch)
            train_losses.append(loss)
            loss.backward()
            optimizer.step()
        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()

        scheduler.step(result['val_loss'])
        result['lrate'] = get_lr(optimizer)

        model.epoch_end(epoch, result)
        history.append(result)
    return history

## Class definition

In [32]:
class DeviceDataLoader():
    """Wrap a dataloader to move data to a device"""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device
        
    def __iter__(self):
        """Yield a batch of data after moving it to device"""
        for b in self.dl: 
            yield to_device(b, self.device)

    def __len__(self):
        """Number of batches"""
        return len(self.dl)

In [33]:
class Metric():
    def __init__(self, name):
        self.name = name
    
    def eval(self, outputs, labels):
        pass

In [34]:
class Accuracy(Metric):
    
    def eval(self, outputs, labels):
        _, preds = torch.max(outputs, dim=1)
        return torch.tensor(torch.sum(preds == labels).item() / len(preds))

In [35]:
class ImageClassificationBase(nn.Module):
    
    def __init__(self, loss_function, metrics):
        super().__init__()
        self.loss_function = loss_function
        self.metrics = metrics
    
    def training_step(self, batch):
        images, labels = batch 
        out = self(images)                  # Generate predictions
        loss = self.loss_function(out, labels) # Calculate loss
        return loss
    
    def validation_step(self, batch):
        images, labels = batch 
        out = self(images)                    # Generate predictions
        loss = self.loss_function(out, labels)   # Calculate loss
        result = {'val_loss': loss.detach()}
        
        for m in self.metrics:
            result[m.name] = m.eval(out, labels)           # Calculate metrics
            
        return result
        
    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        
        result = {'val_loss': epoch_loss.item()}
        
        for m in self.metrics:
            batch = [x[m.name] for x in outputs]
            epoch = torch.stack(batch).mean()      # Combine metrics
            result[m.name] = epoch.item()
            
        return result
    
    def epoch_end(self, epoch, result):
        out = f"Epoch [{epoch}]"
        vals = list(result.keys())
        for v in vals:
            out += f", {v}: {result[v]:.3e}"
        print(out)
        

In [36]:
class Net(ImageClassificationBase):
    
    def __init__(self, loss_function, metrics, out_size):
        super().__init__(loss_function, metrics)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.pool1 = nn.MaxPool2d(kernel_size=3,stride=2, padding=1)
        self.norm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.pool2 = nn.MaxPool2d(kernel_size=3,stride=2, padding=1)
        self.norm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.norm3 = nn.BatchNorm2d(128)
        self.fc1 = nn.Linear(in_features=128*5*5, out_features=256)
        #self.fc2 = nn.Linear(in_features=256, out_features=256)
        self.fc = nn.Linear(256, out_size)

    def forward(self, input):
        output = self.norm1(self.pool1(F.relu(self.conv1(input))))
        output = self.norm2(self.pool2(F.relu(self.conv2((output)))))
        output = self.norm3(self.pool3(F.relu(self.conv3((output)))))
        output = output.view(-1, 128*5*5)
        output = F.relu(self.fc1(output))
        #output = F.relu(self.fc2(output))
        output = self.fc(output)
        output = F.softmax(input = output, dim=-1)
        return output

In [37]:
device = get_default_device()

## Dataset loading

In [38]:
df_train = ImageFolder(root=train_df_path, transform=transforms.Compose([transforms.Grayscale(num_output_channels=1),
                                     transforms.ToTensor()]))
df_test = ImageFolder(root=test_df_path, transform=transforms.Compose([transforms.Grayscale(num_output_channels=1),
                                     transforms.ToTensor()]))

In [39]:
train_dl = DeviceDataLoader(DataLoader(df_train, batch_size=batch_size),device)
test_dl = DeviceDataLoader(DataLoader(df_test, batch_size=batch_size, shuffle=False),device)

In [40]:
classes = df_train.classes
classes

['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']

In [41]:
counts = dict(Counter(df_train.targets))
print(counts)
weights = np.array(list(counts.values()))
weights = torch.Tensor( min(weights)/weights)
print(weights)

{0: 2376, 1: 173, 2: 2460, 3: 4394, 4: 2999, 5: 2938, 6: 1901}
tensor([0.0728, 1.0000, 0.0703, 0.0394, 0.0577, 0.0589, 0.0910])


## Model design and Training

In [42]:
loss_function = nn.CrossEntropyLoss(weight = weights)
metrics = []
optimizer = optim.Adam
scheduler = optim.lr_scheduler.ReduceLROnPlateau

lr = 0.0001
num_epochs = 20

In [43]:
splits = 4
skf = StratifiedKFold(n_splits=splits, random_state=seed, shuffle=True)
histories = []

for fold,(train_idx,val_idx) in enumerate(skf.split(df_train,train_dl.dl.dataset.targets)):
    print('------------fold no---------{}----------------------'.format(fold))
    
    net = Net(loss_function, metrics, len(classes))
    net.to(device)
    
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
    val_subsampler = torch.utils.data.SubsetRandomSampler(val_idx)

    train_fold_dl = DeviceDataLoader(DataLoader(df_train, batch_size=batch_size,  sampler=train_subsampler),device)
    val_fold_dl = DeviceDataLoader(DataLoader(df_train, batch_size=batch_size,  sampler=val_subsampler),device)
    #print(np.unique(np.array(train_fold_dl.dl.dataset.targets)[train_subsampler.indices],return_counts=True))
    #print(np.unique(np.array(train_fold_dl.dl.dataset.targets)[test_subsampler.indices],return_counts=True))
    #print(train_dl[train_subsampler.indices])
    history = dict()
    history["losses"] = fit(num_epochs, lr, net, train_fold_dl, val_fold_dl, optimizer, scheduler)
    history["model"] = net
    histories.append(history)

summary(net,(1,48,48))

------------fold no---------0----------------------


Current Epoch:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch: 0:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [0], val_loss: 1.711e+00, train_loss: 1.792e+00, lrate: 1.000e-04


Epoch: 1:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [1], val_loss: 1.661e+00, train_loss: 1.653e+00, lrate: 1.000e-04


Epoch: 2:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [2], val_loss: 1.621e+00, train_loss: 1.575e+00, lrate: 1.000e-04


Epoch: 3:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [3], val_loss: 1.577e+00, train_loss: 1.525e+00, lrate: 1.000e-04


Epoch: 4:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [4], val_loss: 1.597e+00, train_loss: 1.486e+00, lrate: 1.000e-04


Epoch: 5:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [5], val_loss: 1.557e+00, train_loss: 1.464e+00, lrate: 1.000e-04


Epoch: 6:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [6], val_loss: 1.567e+00, train_loss: 1.435e+00, lrate: 1.000e-04


Epoch: 7:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [7], val_loss: 1.588e+00, train_loss: 1.415e+00, lrate: 1.000e-04


Epoch: 8:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [8], val_loss: 1.542e+00, train_loss: 1.397e+00, lrate: 1.000e-04


Epoch: 9:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [9], val_loss: 1.564e+00, train_loss: 1.370e+00, lrate: 1.000e-04


Epoch: 10:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [10], val_loss: 1.536e+00, train_loss: 1.362e+00, lrate: 1.000e-04


Epoch: 11:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [11], val_loss: 1.532e+00, train_loss: 1.347e+00, lrate: 1.000e-04


Epoch: 12:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [12], val_loss: 1.527e+00, train_loss: 1.329e+00, lrate: 1.000e-04


Epoch: 13:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [13], val_loss: 1.529e+00, train_loss: 1.318e+00, lrate: 1.000e-04


Epoch: 14:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [14], val_loss: 1.534e+00, train_loss: 1.307e+00, lrate: 1.000e-04


Epoch: 15:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [15], val_loss: 1.526e+00, train_loss: 1.301e+00, lrate: 1.000e-04


Epoch: 16:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [16], val_loss: 1.729e+00, train_loss: 1.292e+00, lrate: 1.000e-04


Epoch: 17:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [17], val_loss: 1.527e+00, train_loss: 1.338e+00, lrate: 1.000e-04


Epoch: 18:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [18], val_loss: 1.602e+00, train_loss: 1.286e+00, lrate: 1.000e-04


Epoch: 19:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [19], val_loss: 1.526e+00, train_loss: 1.287e+00, lrate: 1.000e-05
------------fold no---------1----------------------


Current Epoch:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch: 0:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [0], val_loss: 1.697e+00, train_loss: 1.781e+00, lrate: 1.000e-04


Epoch: 1:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [1], val_loss: 1.626e+00, train_loss: 1.635e+00, lrate: 1.000e-04


Epoch: 2:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [2], val_loss: 1.599e+00, train_loss: 1.566e+00, lrate: 1.000e-04


Epoch: 3:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [3], val_loss: 1.580e+00, train_loss: 1.521e+00, lrate: 1.000e-04


Epoch: 4:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [4], val_loss: 1.614e+00, train_loss: 1.485e+00, lrate: 1.000e-04


Epoch: 5:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [5], val_loss: 1.558e+00, train_loss: 1.465e+00, lrate: 1.000e-04


Epoch: 6:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [6], val_loss: 1.606e+00, train_loss: 1.433e+00, lrate: 1.000e-04


Epoch: 7:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [7], val_loss: 1.612e+00, train_loss: 1.416e+00, lrate: 1.000e-04


Epoch: 8:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [8], val_loss: 1.607e+00, train_loss: 1.403e+00, lrate: 1.000e-04


Epoch: 9:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [9], val_loss: 1.552e+00, train_loss: 1.386e+00, lrate: 1.000e-04


Epoch: 10:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [10], val_loss: 1.594e+00, train_loss: 1.365e+00, lrate: 1.000e-04


Epoch: 11:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [11], val_loss: 1.542e+00, train_loss: 1.362e+00, lrate: 1.000e-04


Epoch: 12:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [12], val_loss: 1.582e+00, train_loss: 1.336e+00, lrate: 1.000e-04


Epoch: 13:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [13], val_loss: 1.544e+00, train_loss: 1.328e+00, lrate: 1.000e-04


Epoch: 14:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [14], val_loss: 1.554e+00, train_loss: 1.318e+00, lrate: 1.000e-04


Epoch: 15:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [15], val_loss: 1.537e+00, train_loss: 1.310e+00, lrate: 1.000e-04


Epoch: 16:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [16], val_loss: 1.580e+00, train_loss: 1.300e+00, lrate: 1.000e-04


Epoch: 17:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [17], val_loss: 1.531e+00, train_loss: 1.303e+00, lrate: 1.000e-04


Epoch: 18:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [18], val_loss: 1.533e+00, train_loss: 1.288e+00, lrate: 1.000e-04


Epoch: 19:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [19], val_loss: 1.537e+00, train_loss: 1.281e+00, lrate: 1.000e-04
------------fold no---------2----------------------


Current Epoch:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch: 0:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [0], val_loss: 1.724e+00, train_loss: 1.809e+00, lrate: 1.000e-04


Epoch: 1:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [1], val_loss: 1.662e+00, train_loss: 1.658e+00, lrate: 1.000e-04


Epoch: 2:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [2], val_loss: 1.676e+00, train_loss: 1.577e+00, lrate: 1.000e-04


Epoch: 3:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [3], val_loss: 1.621e+00, train_loss: 1.534e+00, lrate: 1.000e-04


Epoch: 4:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [4], val_loss: 1.617e+00, train_loss: 1.493e+00, lrate: 1.000e-04


Epoch: 5:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [5], val_loss: 1.575e+00, train_loss: 1.463e+00, lrate: 1.000e-04


Epoch: 6:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [6], val_loss: 1.582e+00, train_loss: 1.435e+00, lrate: 1.000e-04


Epoch: 7:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [7], val_loss: 1.560e+00, train_loss: 1.413e+00, lrate: 1.000e-04


Epoch: 8:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [8], val_loss: 1.555e+00, train_loss: 1.392e+00, lrate: 1.000e-04


Epoch: 9:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [9], val_loss: 1.623e+00, train_loss: 1.373e+00, lrate: 1.000e-04


Epoch: 10:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [10], val_loss: 1.551e+00, train_loss: 1.370e+00, lrate: 1.000e-04


Epoch: 11:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [11], val_loss: 1.561e+00, train_loss: 1.342e+00, lrate: 1.000e-04


Epoch: 12:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [12], val_loss: 1.543e+00, train_loss: 1.333e+00, lrate: 1.000e-04


Epoch: 13:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [13], val_loss: 1.620e+00, train_loss: 1.320e+00, lrate: 1.000e-04


Epoch: 14:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [14], val_loss: 1.586e+00, train_loss: 1.320e+00, lrate: 1.000e-04


Epoch: 15:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [15], val_loss: 1.650e+00, train_loss: 1.311e+00, lrate: 1.000e-04


Epoch: 16:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [16], val_loss: 1.545e+00, train_loss: 1.315e+00, lrate: 1.000e-05


Epoch: 17:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [17], val_loss: 1.545e+00, train_loss: 1.284e+00, lrate: 1.000e-05


Epoch: 18:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [18], val_loss: 1.541e+00, train_loss: 1.283e+00, lrate: 1.000e-05


Epoch: 19:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [19], val_loss: 1.540e+00, train_loss: 1.284e+00, lrate: 1.000e-05
------------fold no---------3----------------------


Current Epoch:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch: 0:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [0], val_loss: 1.735e+00, train_loss: 1.792e+00, lrate: 1.000e-04


Epoch: 1:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [1], val_loss: 1.666e+00, train_loss: 1.639e+00, lrate: 1.000e-04


Epoch: 2:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [2], val_loss: 1.619e+00, train_loss: 1.572e+00, lrate: 1.000e-04


Epoch: 3:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [3], val_loss: 1.599e+00, train_loss: 1.523e+00, lrate: 1.000e-04


Epoch: 4:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [4], val_loss: 1.654e+00, train_loss: 1.488e+00, lrate: 1.000e-04


Epoch: 5:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [5], val_loss: 1.594e+00, train_loss: 1.471e+00, lrate: 1.000e-04


Epoch: 6:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [6], val_loss: 1.585e+00, train_loss: 1.438e+00, lrate: 1.000e-04


Epoch: 7:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [7], val_loss: 1.621e+00, train_loss: 1.418e+00, lrate: 1.000e-04


Epoch: 8:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [8], val_loss: 1.580e+00, train_loss: 1.409e+00, lrate: 1.000e-04


Epoch: 9:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [9], val_loss: 1.605e+00, train_loss: 1.389e+00, lrate: 1.000e-04


Epoch: 10:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [10], val_loss: 1.561e+00, train_loss: 1.381e+00, lrate: 1.000e-04


Epoch: 11:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [11], val_loss: 1.561e+00, train_loss: 1.358e+00, lrate: 1.000e-04


Epoch: 12:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [12], val_loss: 1.581e+00, train_loss: 1.346e+00, lrate: 1.000e-04


Epoch: 13:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [13], val_loss: 1.556e+00, train_loss: 1.336e+00, lrate: 1.000e-04


Epoch: 14:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [14], val_loss: 1.555e+00, train_loss: 1.324e+00, lrate: 1.000e-04


Epoch: 15:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [15], val_loss: 1.562e+00, train_loss: 1.313e+00, lrate: 1.000e-04


Epoch: 16:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [16], val_loss: 1.567e+00, train_loss: 1.311e+00, lrate: 1.000e-04


Epoch: 17:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [17], val_loss: 1.556e+00, train_loss: 1.309e+00, lrate: 1.000e-04


Epoch: 18:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [18], val_loss: 1.562e+00, train_loss: 1.295e+00, lrate: 1.000e-05


Epoch: 19:   0%|          | 0/203 [00:00<?, ?it/s]

Epoch [19], val_loss: 1.550e+00, train_loss: 1.287e+00, lrate: 1.000e-05
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 46, 46]             320
         MaxPool2d-2           [-1, 32, 23, 23]               0
       BatchNorm2d-3           [-1, 32, 23, 23]              64
            Conv2d-4           [-1, 64, 21, 21]          18,496
         MaxPool2d-5           [-1, 64, 11, 11]               0
       BatchNorm2d-6           [-1, 64, 11, 11]             128
            Conv2d-7            [-1, 128, 9, 9]          73,856
         MaxPool2d-8            [-1, 128, 5, 5]               0
       BatchNorm2d-9            [-1, 128, 5, 5]             256
           Linear-10                  [-1, 256]         819,456
           Linear-11                    [-1, 7]           1,799
Total params: 914,375
Trainable params: 914,375
Non-trainable params: 0
----------------------

## Save/Load model

In [44]:
if not os.path.isdir(models_directory):
    os.makedirs(name = models_directory)
models_directory_data = os.path.join(models_directory,dataset_name)
if not os.path.isdir(models_directory_data):
    os.makedirs(name = models_directory_data)

In [45]:
if save_model:
    for fold,h in enumerate(histories):
        model_path = os.path.join(models_directory_data,f"classifier_fold_{fold}.pt")
        torch.save(h["model"].state_dict(), model_path)

In [46]:
if load_model:
    histories = []

    for fold, model in enumerate(os.listdir(models_directory_data)):
        net = Net(loss_function, metrics, len(classes))
        net.load_state_dict(torch.load(os.path.join(models_directory_data,model)))
        net.eval()
        net.cuda()
        histories.append( {"model": net } )

## Model testing

In [47]:
def get_model_predicitons(model, dataset):
    y_test = []
    y_scores = []

    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in tqdm(dataset):
            images, labels = data[0].to(device), data[1].to(device)
            # calculate outputs by running images through the network
            outputs = model(images)

            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)

            y_test.extend(labels.tolist())
            y_scores.extend(outputs.tolist())

    y_test = np.array(y_test)
    y_scores = np.array(y_scores)

    return y_test, y_scores


### Network Accuracy

In [48]:
def evaluate_accuracy(y_test, y_scores, classes, verbose=False):
    # Overall accuracy

    n_test = y_test.shape[0]
    predicted = np.argmax(y_scores,axis=1)
    mean_acc = np.sum(predicted == y_test) / n_test
    class_acc = []
    
    # Accuracy per class

    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}

    # collect the correct predictions for each class
    for label, prediction in zip(y_test, predicted):
        if label == prediction:
            correct_pred[classes[label]] += 1
        total_pred[classes[label]] += 1

    for classname, correct_count in correct_pred.items():
        accuracy = float(correct_count) / total_pred[classname]
        class_acc.append(accuracy)

    if verbose:
        print(f'Overall accuracy: {100 * mean_acc } %')

        print("Accuracy per class:")
        for i,c in enumerate(classes):
            print(f'{c:9s} : {class_acc[i] * 100:.1f} %')
        
    return mean_acc, class_acc

# ROC/AUC

In [49]:

def evaluate_roc_auc(y_test, y_scores, classes, plot=False, verbose=False):
    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    ths = dict()
    roc_auc = dict()
    for i,_ in enumerate(classes):
        fpr[i], tpr[i], ths[i] = roc_curve(y_test == i, y_scores[:,i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    if plot:
        # Plot of a ROC curve for a specific class
        for i,c in enumerate(classes):
            plt.figure()
            plt.plot(fpr[i], tpr[i], label='ROC curve (area = %0.2f)' % roc_auc[i])
            plt.plot([0, 1], [0, 1], 'k--')
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title(f'ROC {c}')
            plt.legend(loc="lower right")
            plt.show()

    mean_auc = np.mean(list(roc_auc.values()))

    if verbose:
        print(f"Mean AUC: { mean_auc * 100} %")

        print("Per classs AUC:")
        for i,c in enumerate(classes):
            print(f'{c:9s} : {roc_auc[i]*100:.2f} %')

    return mean_auc, roc_auc

In [50]:
y_test = dict()
y_scores = dict()

for fold,h in enumerate(histories):
    y_test[fold], y_scores[fold] = get_model_predicitons(h["model"],test_dl)

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

In [51]:
columns=["mean AUC", "angry AUC", "disgust AUC", "fear AUC", "happy AUC", "neautral AUC", "sad AUC", "surprise AUC", "mean ACC", "angry ACC", "disgust ACC", "fear ACC", "happy ACC", "neautral ACC", "sad ACC", "surprise ACC"]
df = pd.DataFrame(columns = columns)

for fold,_ in enumerate(histories):
    mean_acc,class_acc = evaluate_accuracy(y_test[fold], y_scores[fold], classes)
    mean_auc,roc_auc = evaluate_roc_auc(y_test[fold], y_scores[fold], classes)

    exp = [mean_auc]
    exp += list(roc_auc.values())
    exp += [mean_acc]
    exp += class_acc

    df_new = pd.DataFrame( np.array([exp]),[f"{dataset_name} fold {fold}"], columns=columns)
    df = pd.concat([df,df_new])
    df = df.astype('float')

mean_df = df.mean(0).to_frame().transpose().set_index(pd.Index([f"{dataset_name} mean"]))
df = pd.concat([df,mean_df])
df


Unnamed: 0,mean AUC,angry AUC,disgust AUC,fear AUC,happy AUC,neautral AUC,sad AUC,surprise AUC,mean ACC,angry ACC,disgust ACC,fear ACC,happy ACC,neautral ACC,sad ACC,surprise ACC
Fer2013_merge_selected_uniform fold 0,0.882229,0.843333,0.925,0.78235,0.9534,0.8904,0.84335,0.937767,0.612857,0.56,0.58,0.47,0.83,0.6,0.45,0.8
Fer2013_merge_selected_uniform fold 1,0.871894,0.82995,0.893383,0.773133,0.964908,0.871783,0.818133,0.951967,0.591429,0.52,0.55,0.43,0.84,0.52,0.49,0.79
Fer2013_merge_selected_uniform fold 2,0.882504,0.850033,0.914767,0.789017,0.967658,0.887017,0.8265,0.942533,0.627143,0.58,0.61,0.42,0.81,0.64,0.54,0.79
Fer2013_merge_selected_uniform fold 3,0.879571,0.867983,0.8953,0.786083,0.961667,0.8941,0.809567,0.9423,0.62,0.61,0.57,0.46,0.79,0.63,0.51,0.77
Fer2013_merge_selected_uniform mean,0.879049,0.847825,0.907112,0.782646,0.961908,0.885825,0.824387,0.943642,0.612857,0.5675,0.5775,0.445,0.8175,0.5975,0.4975,0.7875


## Save metrics

In [52]:
try:
    stored_df = pd.read_csv(os.path.join(metrics_directory, metrics_name),index_col=0)
except FileNotFoundError:
    stored_df = pd.DataFrame(columns = columns)
stored_df

Unnamed: 0,mean AUC,angry AUC,disgust AUC,fear AUC,happy AUC,neautral AUC,sad AUC,surprise AUC,mean ACC,angry ACC,disgust ACC,fear ACC,happy ACC,neautral ACC,sad ACC,surprise ACC
Fer2013_uniform xval-4,0.845617,0.7855,0.907362,0.730771,0.923177,0.822044,0.810358,0.940106,0.548571,0.475,0.4925,0.3375,0.725,0.5375,0.52,0.7525
Fer2013_uniform xval-4,0.846888,0.773717,0.911138,0.737354,0.920221,0.829625,0.81185,0.94431,0.547857,0.475,0.465,0.375,0.735,0.5275,0.5,0.7575
Fer2013_uniform xval-4,0.847053,0.782483,0.909817,0.731229,0.925108,0.839454,0.799546,0.941735,0.543214,0.475,0.4825,0.37,0.7175,0.5375,0.455,0.765
Fer2013_uniform xval-4,0.849133,0.775371,0.909379,0.739429,0.917608,0.848187,0.81475,0.939204,0.544286,0.455,0.45,0.3675,0.73,0.55,0.495,0.7625
Fer2013_Aug_Disgust_uniform_e_03_identity xval-4,0.846339,0.776829,0.874633,0.745138,0.919633,0.842625,0.821512,0.944002,0.542143,0.44,0.4825,0.3575,0.715,0.5525,0.485,0.7625
Fer2013_uniform_Aug_filtered xval-4,0.853964,0.804608,0.916829,0.731154,0.918488,0.842688,0.822029,0.94195,0.556429,0.49,0.51,0.3725,0.72,0.5725,0.48,0.75


In [53]:
stored_df = pd.concat([stored_df,mean_df.set_index(pd.Index([f"{dataset_name} xval-{splits}"]))])
stored_df

Unnamed: 0,mean AUC,angry AUC,disgust AUC,fear AUC,happy AUC,neautral AUC,sad AUC,surprise AUC,mean ACC,angry ACC,disgust ACC,fear ACC,happy ACC,neautral ACC,sad ACC,surprise ACC
Fer2013_uniform xval-4,0.845617,0.7855,0.907362,0.730771,0.923177,0.822044,0.810358,0.940106,0.548571,0.475,0.4925,0.3375,0.725,0.5375,0.52,0.7525
Fer2013_uniform xval-4,0.846888,0.773717,0.911138,0.737354,0.920221,0.829625,0.81185,0.94431,0.547857,0.475,0.465,0.375,0.735,0.5275,0.5,0.7575
Fer2013_uniform xval-4,0.847053,0.782483,0.909817,0.731229,0.925108,0.839454,0.799546,0.941735,0.543214,0.475,0.4825,0.37,0.7175,0.5375,0.455,0.765
Fer2013_uniform xval-4,0.849133,0.775371,0.909379,0.739429,0.917608,0.848187,0.81475,0.939204,0.544286,0.455,0.45,0.3675,0.73,0.55,0.495,0.7625
Fer2013_Aug_Disgust_uniform_e_03_identity xval-4,0.846339,0.776829,0.874633,0.745138,0.919633,0.842625,0.821512,0.944002,0.542143,0.44,0.4825,0.3575,0.715,0.5525,0.485,0.7625
Fer2013_uniform_Aug_filtered xval-4,0.853964,0.804608,0.916829,0.731154,0.918488,0.842688,0.822029,0.94195,0.556429,0.49,0.51,0.3725,0.72,0.5725,0.48,0.75
Fer2013_merge_selected_uniform xval-4,0.879049,0.847825,0.907112,0.782646,0.961908,0.885825,0.824387,0.943642,0.612857,0.5675,0.5775,0.445,0.8175,0.5975,0.4975,0.7875


In [54]:
if not os.path.isdir(metrics_directory):
    os.makedirs(name = metrics_directory)
stored_df.to_csv(os.path.join(metrics_directory, metrics_name))