In [1]:
# Multi-class classification
# Plot one-vs-rest ROC curves and a confusion matrix when only the predicted labels and the actual labels are available (no class probabilities)
# Ashok K Sharma

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score, confusion_matrix

# Per-class (one-vs-rest) sensitivity / specificity / accuracy, ROC curves and a confusion
# matrix heatmap, computed from hard predicted labels versus actual labels.

# Actual and predicted labels for the training split.
# NOTE(review): `BNN_pred_three_Class` is not defined in this cell; it must be created by an
# earlier cell (presumably a DataFrame holding these two columns — confirm), otherwise a
# Restart & Run All fails here.
actual_labels = BNN_pred_three_Class['Train_actual_label']
predicted_labels = BNN_pred_three_Class['Train_pred_label']
assert len(actual_labels) == len(predicted_labels), "actual and predicted labels differ in length"
assert len(actual_labels) > 0, "no labels to evaluate"

# Class set = sorted union of the label values seen in either column. Iterating over the real
# label values (instead of assuming 0..n-1) keeps the per-class loops, the confusion matrix,
# and the tick/legend labels aligned for any label encoding, including classes that occur
# only in the predictions.
class_labels = np.union1d(np.asarray(actual_labels), np.asarray(predicted_labels))
n_classes = len(class_labels)


def _one_vs_rest(labels, positive_label):
    """Return a boolean array that is True where `labels` equals `positive_label`."""
    return np.asarray(labels) == positive_label


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN (without a warning) when the denominator is 0."""
    return numerator / denominator if denominator else np.nan


# (actual_is_class, predicted_is_class) boolean masks, one pair per class, computed once and
# reused by both the metric table and the ROC curves.
class_masks = [
    (_one_vs_rest(actual_labels, label), _one_vs_rest(predicted_labels, label))
    for label in class_labels
]

# Calculate Sensitivity, Specificity, and Accuracy.
# For multi-class data these are computed for each class separately (one-vs-rest).
sensitivity = np.zeros(n_classes)
specificity = np.zeros(n_classes)
accuracy = np.zeros(n_classes)

for idx, (is_actual, is_predicted) in enumerate(class_masks):
    # labels=[False, True] forces a 2x2 matrix even when one side is constant for this class,
    # so the 4-way unpack cannot fail.
    tn, fp, fn, tp = confusion_matrix(is_actual, is_predicted, labels=[False, True]).ravel()
    sensitivity[idx] = _safe_ratio(tp, tp + fn)
    specificity[idx] = _safe_ratio(tn, tn + fp)
    accuracy[idx] = _safe_ratio(tp + tn, tp + tn + fp + fn)

# Plot ROC Curve and Confusion Matrix Heatmap side by side
fig, axes = plt.subplots(1, 2, figsize=(12, 6))  # Set same size for both plots

# ROC curves. The "scores" here are hard 0/1 predictions, so each curve has a single operating
# point and its AUC equals (sensitivity + specificity) / 2 for that class.
fpr = dict()
tpr = dict()
roc_auc = dict()
for idx, (is_actual, is_predicted) in enumerate(class_masks):
    if is_actual.all() or not is_actual.any():
        # ROC is undefined when the class is absent from (or is the only class in) the
        # actual labels; roc_auc_score would raise, so record NaN and skip plotting it.
        fpr[idx] = np.array([])
        tpr[idx] = np.array([])
        roc_auc[idx] = np.nan
        continue
    fpr[idx], tpr[idx], _ = roc_curve(is_actual, is_predicted)
    roc_auc[idx] = roc_auc_score(is_actual, is_predicted)

roc_ax = axes[0]
for idx, label in enumerate(class_labels):
    if np.isnan(roc_auc[idx]):
        continue
    roc_ax.plot(fpr[idx], tpr[idx], label=f'ROC curve class {label} (AUC = {roc_auc[idx]:.2f})')

roc_ax.plot([0, 1], [0, 1], color='red', linestyle='--')  # chance-level diagonal
roc_ax.set_xlabel('False Positive Rate', fontsize=14, fontweight='bold')
roc_ax.set_ylabel('True Positive Rate', fontsize=14, fontweight='bold')
roc_ax.set_title('ROC Curve', fontsize=16, fontweight='bold')
roc_ax.legend(fontsize=14)

# Add table for Sensitivity, Specificity, and Accuracy
table_data = [
    ["Metric", *[f"Class {label}" for label in class_labels]],
    ["Sensitivity", *[f"{s:.2f}" for s in sensitivity]],
    ["Specificity", *[f"{s:.2f}" for s in specificity]],
    ["Accuracy", *[f"{s:.2f}" for s in accuracy]]
]

# Table size in ROC-axes coordinates; the bbox anchors the table to the top-right corner.
table_height = 0.4  # fraction of the axes height; adjust to taste
table_width = min(1.0, 0.2 * (n_classes + 1))  # 0.2 per column, clamped to the axes width
table = roc_ax.table(
    cellText=table_data,
    cellLoc='center',
    colWidths=[0.2] * (n_classes + 1),
    bbox=[1.0 - table_width, 1.0 - table_height, table_width, table_height],
)
table.auto_set_font_size(False)
table.set_fontsize(12)

# Confusion Matrix Heatmap; rows = actual label, columns = predicted label, both ordered by
# `class_labels` so the matrix is always n_classes x n_classes.
conf_matrix = confusion_matrix(actual_labels, predicted_labels, labels=class_labels)
cm_ax = axes[1]
im = cm_ax.imshow(conf_matrix, interpolation='nearest', cmap='YlGn')  # lighter colormap 'YlGn'
cm_ax.set_title('Confusion Matrix', fontsize=16, fontweight='bold')

# Annotate each cell with its count (x = predicted column, y = actual row).
for row in range(n_classes):
    for col in range(n_classes):
        cm_ax.text(col, row, str(conf_matrix[row, col]), horizontalalignment='center', verticalalignment='center', color='red', fontsize=18)

fig.colorbar(im, ax=cm_ax)
cm_ax.set_xlabel('Predicted Label', fontsize=14, fontweight='bold')
cm_ax.set_ylabel('True Label', fontsize=14, fontweight='bold')
tick_names = [f"Class {label}" for label in class_labels]
cm_ax.set_xticks(np.arange(n_classes))
cm_ax.set_yticks(np.arange(n_classes))
cm_ax.set_xticklabels(tick_names, fontsize=14)
cm_ax.set_yticklabels(tick_names, fontsize=14)

plt.tight_layout()
plt.show()