<h1>Results Script: DAN Evaluation on the FER2013 Test Split</h1>

In [None]:
#import libraries and setup
import torch
from torchvision import transforms
import torch.utils.data as data
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import pandas as pd
import numpy as np
from PIL import Image
from networks.dan import DAN
import os

%matplotlib inline

<h3>Function Definitions</h3>

In [None]:
# Evaluation-time preprocessing. This must mirror the validation pipeline in
# fer2013train.py so test images reach the network exactly as they did during
# validation in training. The mean/std values are the standard ImageNet
# per-channel statistics.
_imagenet_channel_mean = [0.485, 0.456, 0.406]
_imagenet_channel_std = [0.229, 0.224, 0.225]

data_transforms_val = transforms.Compose([
    transforms.Resize((224, 224)),  # fixed 224x224 input size
    transforms.ToTensor(),          # PIL image -> float tensor scaled to [0, 1]
    transforms.Normalize(mean=_imagenet_channel_mean, std=_imagenet_channel_std),
])

# Dataset root, laid out as in fer2013train.py (the dataset class below reads
# EmoLabel/ferEmoLabellist.txt and an images/ folder beneath this directory).
fer2013_path = 'datasets/fer2013'


class RafDataSet(data.Dataset):
    """Train or test split of the FER2013 images listed in the label file.

    The name suggests this loader was adapted from a RAF-DB dataset class; it is
    kept unchanged so existing callers keep working, but it reads FER2013 data.

    Args:
        fer2013_path: dataset root containing ``EmoLabel/ferEmoLabellist.txt``
            (space-separated ``name label`` rows, no header) and an ``images/``
            directory holding the ``.jpg`` files.
        phase: ``'train'`` selects rows whose name starts with ``'Train'``; any
            other value selects rows whose name starts with ``'Test'``.
        transform: optional callable applied to each RGB PIL image in
            ``__getitem__``.

    Attributes:
        label: numpy array of 0-based class indices (file labels minus 1).
        sample_counts: number of samples per distinct label, in sorted label order.
        file_paths: image path for each sample, aligned with ``label``.
    """

    def __init__(self, fer2013_path, phase, transform = None):
        self.phase = phase
        self.transform = transform
        self.fer2013_path = fer2013_path

        df = pd.read_csv(os.path.join(self.fer2013_path, 'EmoLabel/ferEmoLabellist.txt'), sep=' ', header=None, names=['name', 'label'])

        # The split is encoded in the file-name prefix.
        prefix = 'Train' if phase == 'train' else 'Test'
        self.data = df[df['name'].str.startswith(prefix)]

        file_names = self.data.loc[:, 'name'].values
        # Labels in the file are 1-based; shift to 0-based class indices.
        # 0:Surprise, 1:Fear, 2:Disgust, 3:Happiness, 4:Sadness, 5:Anger, 6:Neutral
        # NOTE(review): this is the RAF-DB ordering; confirm it matches how
        # ferEmoLabellist.txt was generated for FER2013.
        self.label = self.data.loc[:, 'label'].values - 1

        _, self.sample_counts = np.unique(self.label, return_counts=True)

        # Whatever extension the label file lists, the images on disk are .jpg;
        # everything after the first '.' in the listed name is dropped.
        self.file_paths = [
            os.path.join(self.fer2013_path, 'images/', name.split(".")[0] + ".jpg")
            for name in file_names
        ]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        path = self.file_paths[idx]
        # The context manager closes the file handle deterministically instead
        # of relying on garbage collection (matters with many loader workers).
        with Image.open(path) as img:
            image = img.convert('RGB')
        label = self.label[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label
    

# Evaluation split: phase='test' selects the 'Test*' rows of the label file.
val_dataset = RafDataSet(fer2013_path, phase='test', transform=data_transforms_val)

print('Validation set size:', len(val_dataset))

# NOTE(review): with the 'spawn' multiprocessing start method (Windows, macOS),
# worker processes cannot import a Dataset class defined inside the notebook;
# set workers = 0 if the loader errors or hangs there.
workers = 4
batch_size = 128

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    num_workers=workers,
    shuffle=False,  # deterministic order for reproducible evaluation
    # Pinned host memory only speeds up host->GPU copies; skip it when no GPU
    # is present to avoid the "pin_memory but no accelerator" warning.
    pin_memory=torch.cuda.is_available(),
)

In [None]:
def evaluate_model(model, test_loader, device=None):
    """Run ``model`` over ``test_loader`` and score its predictions.

    Args:
        model: a ``torch.nn.Module``; this function switches it to eval mode.
            Its forward may return either a logits tensor or a tuple/list whose
            first element is the logits tensor.
        test_loader: iterable yielding ``(images, labels)`` tensor batches.
        device: optional device for inference. When given, the model is moved
            there. When omitted, the device of the model's first parameter is
            used (CPU for a parameter-less model). Input batches are always
            moved to this device so model and data cannot disagree.

    Returns:
        ``(accuracy, conf_mat, class_report)``: scikit-learn accuracy, the
        confusion matrix (rows = true labels, columns = predictions) and the
        ``output_dict=True`` classification report.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    else:
        model.to(device)

    all_preds = []
    all_labels = []

    model.eval()  # disable dropout / use running batch-norm statistics
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(device))
            # DAN's forward returns a tuple (presumably logits, features, head
            # outputs -- confirm against networks/dan.py); the class logits are
            # the first element. Plain classifiers return the tensor directly.
            logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
            predicted_labels = logits.argmax(dim=1)

            all_preds.extend(predicted_labels.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if not all_labels:
        raise ValueError('test_loader yielded no samples; nothing to evaluate')

    accuracy = accuracy_score(all_labels, all_preds)
    conf_mat = confusion_matrix(all_labels, all_preds)
    class_report = classification_report(all_labels, all_preds, output_dict=True)

    return accuracy, conf_mat, class_report

In [None]:
# Checkpoint from the fer2013 batch-256 training run. Per the file name it is
# epoch 36 with acc 0.7079 and (presumably balanced) acc 0.6905.
CHECKPOINT_PATH = './results/fer2013-batch256/fer2013_epoch36_acc0.7079_bacc0.6905.pth'

model = DAN(num_head=4, num_class=7, pretrained=False)
# map_location='cpu' lets a checkpoint saved from GPU load on CPU-only machines.
# The model stays on the CPU, the same device val_loader's batches arrive on.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'], strict=True)

accuracy, conf_mat, class_report = evaluate_model(model, val_loader)

In [None]:
# 1. Accuracy bar chart. Only one model is evaluated, so there is one bar. The
#    x values must be labels (strings) and the heights a sequence; passing the
#    DAN class object as x is a bug.
model_names = ['DAN']
model_accuracies = [accuracy]

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(model_names, model_accuracies, color='skyblue')
ax.set_ylim(0, 1)  # accuracy_score returns a fraction in [0, 1]
ax.set_ylabel('Accuracy')
ax.set_title('Model Comparison: Accuracy on Test Data')
ax.tick_params(axis='x', labelrotation=45)
plt.show()

# 2. Confusion matrix: rows are true class indices, columns are predicted ones.
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues', ax=ax)
ax.set_title('Confusion Matrix: Dan')
ax.set_ylabel('Actual')
ax.set_xlabel('Predicted')
plt.show()

# 3. Per-class precision / recall / F1 / support table, one row per class plus
#    the aggregate rows scikit-learn adds.
class_report_df = pd.DataFrame(class_report).transpose()
print('Classification Report for dan:')
print(class_report_df)