In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from ipywidgets import widgets
import pickle

# Confusion Matrix, Classification Report, and ROC Curve

In [2]:
# Load pickle function
def load_results(file_path, net, method):
    """Load pickled evaluation results and publish them as notebook globals.

    The pickle must contain a dict with the keys ``predictions``, ``labels``,
    ``probs``, ``fpr``, ``tpr`` and ``roc_auc``. Each value is bound to a
    module-level name ``<key>_<net>_<method>`` (for example
    ``probs_alexnet_lf``), which later cells reference directly. Any other
    keys in the dict are ignored.

    Args:
        file_path: Path to the ``.pkl`` results file.
        net: Network tag used in the generated names (e.g. ``'alexnet'``).
        method: Fusion tag used in the generated names (e.g. ``'lf'``).

    Raises:
        KeyError: If any required key is missing. All keys are checked before
            any global is written, so a malformed file never leaves a model
            half-loaded (some of its ``*_<net>_<method>`` names updated and
            others stale or undefined).
    """
    result_keys = ('predictions', 'labels', 'probs', 'fpr', 'tpr', 'roc_auc')

    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load result files produced by trusted runs.
    with open(file_path, 'rb') as file:
        data = pickle.load(file)

    missing = [key for key in result_keys if key not in data]
    if missing:
        raise KeyError(f"{file_path} is missing result key(s): {', '.join(missing)}")

    # globals() inside a function is the defining module's namespace, i.e. the
    # notebook namespace when this is defined in a notebook cell.
    namespace = globals()
    for key in result_keys:
        namespace[f'{key}_{net}_{method}'] = data[key]


In [3]:
# Plot interactive confusion matrix
def plot_confusion_matrix(probs, labels, threshold_w, cm_filename,
                          network=None, class_names=('AtmoNu', 'PDK')):
    """Draw, save and show the confusion matrix obtained at one threshold.

    A sample is predicted as class 1 when its score is ``>= threshold_w``,
    otherwise as class 0. Each cell is annotated with its count and with its
    percentage of *all* samples (not of its row).

    Args:
        probs: Per-sample scores compared against the threshold (array-like).
        labels: True labels; assumed to be encoded as 0/1 like the thresholded
            predictions — TODO confirm against the result pickles.
        threshold_w: Decision threshold.
        cm_filename: Output path; the file is overwritten on every call.
        network: Name shown in the title. ``None`` falls back to the
            notebook-level ``NETWORK`` variable (the original behaviour).
        class_names: Display names for class 0 and class 1, in that order.
    """
    if network is None:
        network = NETWORK  # legacy fallback to the notebook global

    adjusted_predictions = (np.asarray(probs) >= threshold_w).astype(int)
    # Pin the label set to [0, 1] so the matrix is always 2x2. Without it, a
    # run where both the true labels and the thresholded predictions contain
    # a single class yields a 1x1 matrix that no longer matches the two tick
    # labels.
    cm = confusion_matrix(labels, adjusted_predictions, labels=[0, 1])

    total = cm.sum()
    # Guard the empty case so the percentages do not become NaN (0 / 0).
    cm_percentages = cm / total * 100 if total else np.zeros(cm.shape)

    # Annotate each cell with its count and overall percentage.
    annotations = np.array([
        [f"{cm[i, j]}\n({cm_percentages[i, j]:.2f}%)" for j in range(cm.shape[1])]
        for i in range(cm.shape[0])
    ])

    fig, ax = plt.subplots(figsize=(6, 6))
    sns.heatmap(cm, annot=annotations, fmt='', cmap=plt.cm.Blues,
                xticklabels=list(class_names), yticklabels=list(class_names),
                cbar=False, annot_kws={"size": 16}, ax=ax)
    ax.set_xlabel('Predicted labels', fontsize=14)
    ax.set_ylabel('True labels', fontsize=14)
    ax.set_title(f"{network.capitalize()} - Confusion Matrix at Threshold = {threshold_w:.2f}", fontsize=16)
    fig.savefig(cm_filename)
    plt.show()

def interactive_plot_confusion_matrix(probs, labels, cm_filename):
    """Display a threshold slider wired to ``plot_confusion_matrix``.

    The slider spans 0.0-1.0 in steps of 0.05 and starts at 0.5. Because
    ``continuous_update=False``, the plot is redrawn only when the slider is
    released, not while it is dragged. Every redraw calls
    ``plot_confusion_matrix``, which re-saves the figure to ``cm_filename``.
    The file on disk therefore holds the most recently selected threshold.
    """
    slider = widgets.FloatSlider(
        value=0.5,
        min=0.0,
        max=1.0,
        step=0.05,
        description='Threshold:',
        continuous_update=False,
    )
    slider_row = widgets.HBox([slider])

    # Only the threshold is live; everything else is held fixed.
    plot_output = widgets.interactive_output(
        plot_confusion_matrix,
        {
            'probs': widgets.fixed(probs),
            'labels': widgets.fixed(labels),
            'threshold_w': slider,
            'cm_filename': widgets.fixed(cm_filename),
        },
    )

    # NOTE: `display` is never imported in this notebook; this relies on the
    # builtin provided by the IPython kernel.
    display(slider_row, plot_output)

In [4]:
def save_classification_report_as_image(report, filename):
    """Render a text report onto a blank 10x6 figure and save it to ``filename``.

    The text is drawn left-aligned in a monospace font, which keeps the
    column layout of ``classification_report`` output intact. The figure is
    saved with a tight bounding box and is not closed here.
    """
    figure, axis = plt.subplots(figsize=(10, 6))
    axis.set_axis_off()  # no ticks or frame, just the text
    axis.text(0.01, 0.5, str(report),
              family='monospace', ha='left', va='center', fontsize=12)
    figure.savefig(filename, bbox_inches='tight')

In [5]:
# Class display names in label order (index 0 -> AtmoNu, index 1 -> PDK);
# assumes the 0/1 label encoding used by the confusion-matrix plot — confirm.
class_names = ["AtmoNu", "PDK"]

In [6]:
# Reference decision threshold. The cells below only interpolate it into the
# confusion-matrix file name; the interactive slider has its own 0.5 default.
threshold_w = 0.50

## Late Fusion

### AlexNet

In [7]:
# Network under evaluation; used in plot titles, printed headers and output
# file names in the cells below.
NETWORK = "alexnet"

In [8]:
# Load the late-fusion results for the current network. This defines
# predictions_/labels_/probs_/fpr_/tpr_/roc_auc_alexnet_lf as globals.
load_results(
    file_path="pkl/roc_results_alexnet_late_fusion.pkl",
    net=NETWORK,
    method='lf',
)

In [None]:
# Interactive confusion matrix for the late-fusion model.
# NOTE(review): the file name embeds the notebook-level `threshold_w`, but the
# image is re-saved at whatever threshold the slider is moved to. The name can
# therefore disagree with the plotted threshold.
interactive_plot_confusion_matrix(
    probs=probs_alexnet_lf,
    labels=labels_alexnet_lf,
    cm_filename=f"{NETWORK.capitalize()}_Late_Fusion_Confusion_Matrix_at_Threshold = {threshold_w:.2f}.png",
)

In [None]:
# Classification report built from the predictions stored in the pickle
# (not from the interactive slider's threshold).
report = classification_report(
    labels_alexnet_lf,
    predictions_alexnet_lf,
    target_names=class_names,
)
print(f"{NETWORK.capitalize()} - Late Fusion - Classification Report:\n\n{report}")

In [None]:
# Save the report printed above as a PNG in the working directory.
save_classification_report_as_image(
    report=report,
    filename=f"{NETWORK.capitalize()}_Late_Fusion_Classification_Report.png",
)

### VGG11

### VGG19

## Early Fusion

### AlexNet

### VGG11

### VGG19

## ROC Curve - Both Early and Late Fusion

In [None]:
# Plot ROC curves side by side: Early Fusion (left) and Late Fusion (right).
# Curves are looked up by the `fpr_<net>_<method>`, `tpr_<net>_<method>` and
# `roc_auc_<net>_<method>` names that load_results() publishes. A model whose
# results have not been loaded is reported and skipped, so the cell no longer
# dies with a NameError on a fresh kernel when only some pickles are loaded.
roc_networks = [('alexnet', 'AlexNet'), ('vgg11', 'VGG11'), ('vgg19', 'VGG19')]
roc_methods = [('ef', 'Early Fusion'), ('lf', 'Late Fusion')]

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

for ax, (method, method_title) in zip(axes, roc_methods):
    for net, net_label in roc_networks:
        names = [f'{prefix}_{net}_{method}' for prefix in ('fpr', 'tpr', 'roc_auc')]
        if not all(name in globals() for name in names):
            print(f"Skipping {net_label} ({method_title}): results not loaded.")
            continue
        fpr, tpr, roc_auc = (globals()[name] for name in names)
        ax.plot(fpr, tpr, label=f'{net_label} (AUC = {roc_auc:.2f})')
    # Chance-level diagonal.
    ax.plot([0, 1], [0, 1], color='gray', linestyle='--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title(method_title)
    # Only add a legend when at least one labelled curve was drawn; legend()
    # with no labelled artists just emits a warning.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc='lower right')

fig.suptitle('ROC Curve', fontsize=16)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig('ROC_Curve_EF&LF.png')
plt.show()