In [9]:
# Imports for Evaluation
import torch
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from torchmetrics.classification import MulticlassROC, MulticlassAUROC
from torchmetrics.classification import MulticlassConfusionMatrix

# Evaluation configuration
MODEL_NAME = ""  # Fill in before running; interpolated into plot titles and output filenames
class_names = [
    "MildDemented",
    "ModerateDemented",
    "NonDemented",
    "VeryMildDemented",
]
num_classes = 4
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Save loss/acc graphs

In [None]:
# Read the per-epoch training history that was written to CSV.
visualisation_arrays = pd.read_csv("visualisation_arrays.csv")

# Pull each history column out as a plain Python list (one entry per CSV row).
train_loss_array, test_loss_array, train_acc_array, test_acc_array = (
    visualisation_arrays[column_name].tolist()
    for column_name in ("Train loss", "Test loss", "Train acc", "Test acc")
)

In [None]:
# Draw the train/test loss curves on the current figure and write it to disk.
loss_curves = (
    (train_loss_array, "Train Loss", "-", "blue", "o"),
    (test_loss_array, "Test Loss", "--", "orange", "s"),
)
for curve_values, curve_label, curve_linestyle, curve_color, curve_marker in loss_curves:
    plt.plot(
        curve_values,
        label=curve_label,
        linestyle=curve_linestyle,
        color=curve_color,
        marker=curve_marker,
        markersize=6,
        linewidth=2,
    )

# NOTE(review): the "v" prefix in the title and filename below is not used by the
# accuracy / ROC / confusion-matrix outputs — confirm it is intentional.
plt.title(f"Loss Graph: v{MODEL_NAME}", weight="bold")
plt.xlabel("Epochs")
plt.ylabel("Loss")

loss_legend_style = dict(fontsize=12, frameon=True, shadow=True, fancybox=True, borderpad=1)
plt.legend(loc="upper right", **loss_legend_style)

# Thin reference line marking y = 0.
plt.axhline(y=0, color="gray", linestyle="-", linewidth=1, alpha=0.7)
plt.savefig(f"loss_graph_v{MODEL_NAME}.png")

In [None]:
# Draw the train/test accuracy curves on the current figure and write it to disk.
acc_line_style = dict(markersize=6, linewidth=2)
plt.plot(
    train_acc_array,
    label="Train Accuracy",
    linestyle="-",
    color="blue",
    marker="o",
    **acc_line_style,
)
plt.plot(
    test_acc_array,
    label="Test Accuracy",
    linestyle="--",
    color="orange",
    marker="s",
    **acc_line_style,
)

plt.title(f"Accuracy Graph: {MODEL_NAME}", weight="bold")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend(
    loc="lower right",
    fontsize=12,
    frameon=True,
    shadow=True,
    fancybox=True,
    borderpad=1,
)
# Thin reference line marking y = 0.
plt.axhline(y=0, color="gray", linestyle="-", linewidth=1, alpha=0.7)
plt.savefig(f"acc_graph_{MODEL_NAME}.png")

# Other evaluation outputs: ROC curve, confusion matrix and classification metrics

In [None]:
# Load the saved evaluation outputs for visualisation.
# map_location puts every stored tensor on `device`, so a file saved on a GPU
# machine still loads on a CPU-only one, and the tensors end up on the same
# device as the torchmetrics objects created with `.to(device)`.
# NOTE(review): torch.load unpickles the file — only load files produced by a
# trusted training run.
loaded_data = torch.load("visualisation_tensors.pt", map_location=device)

# Unpack the individual entries (presumably tensors, since they are fed to
# torchmetrics later — confirm against the code that wrote the file).
all_probabilities = loaded_data["all_probabilities"]
all_predictions = loaded_data["all_predictions"]
all_labels = loaded_data["all_labels"]

In [None]:
# Save ROC curve: one curve per class plus the chance diagonal.
roc_metric = MulticlassROC(num_classes=num_classes).to(device)
auc_metric = MulticlassAUROC(num_classes=num_classes, average=None).to(device)

colors = ["blue", "yellow", "green", "red"]

# The metrics live on `device`; move the inputs there too so a CPU-loaded
# tensor is never mixed with a CUDA metric.
roc_probabilities = all_probabilities.to(device)
roc_labels = all_labels.to(device)

fpr, tpr, _ = roc_metric(roc_probabilities, roc_labels)
auc_score = auc_metric(roc_probabilities, roc_labels)

for i, class_name in enumerate(class_names):
    plt.plot(
        # matplotlib cannot plot CUDA tensors, so hand it host-side numpy arrays.
        fpr[i].detach().cpu().numpy(),
        tpr[i].detach().cpu().numpy(),
        # .item() gives a plain float; otherwise the legend shows "tensor(...)".
        label=f"Class {class_name} (AUC: {auc_score[i].item():.3f})",
        color=colors[i],
        markersize=6,
        linewidth=2,
    )

# Diagonal = performance of a random classifier.
plt.plot([0, 1], [0, 1], "k--")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title(f"ROC Curve: {MODEL_NAME}", weight="bold")
plt.legend(loc="lower right", fontsize=12, frameon=True, shadow=True, fancybox=True, borderpad=1)
plt.axhline(y=0, color="gray", linestyle="-", linewidth=1, alpha=0.7)
plt.savefig(f"roc_curve_{MODEL_NAME}.png")

In [None]:
# Save Confusion Matrix
conf_matrix_metric = MulticlassConfusionMatrix(num_classes=num_classes).to(device)

# Move the inputs to the metric's device before updating it, then bring the
# resulting matrix back to the host as a numpy array for seaborn/numpy.
conf_matrix = conf_matrix_metric(
    all_predictions.to(device),
    all_labels.to(device),
).cpu().numpy()

# Rows are labelled as true classes and columns as predicted classes below.
sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="mako_r", xticklabels=class_names, yticklabels=class_names)
plt.xticks(rotation=45)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.title(f"Confusion Matrix: {MODEL_NAME}", weight="bold")
# bbox_inches="tight" keeps the long, rotated class names from being cut off
# at the edge of the saved image.
plt.savefig(f"conf_matrix_{MODEL_NAME}.png", bbox_inches="tight")

In [None]:
# Calculate precision for each class from the confusion matrix.
# Column sums are used as the per-class "predicted as this class" counts,
# matching the heatmap where the x-axis is labelled "Predicted Labels".
true_positive = np.diag(conf_matrix)
false_positive = np.sum(conf_matrix, axis=0) - true_positive

# A class that is never predicted has TP + FP == 0; report its precision as 0.0
# instead of letting 0/0 produce NaN (which would also poison the macro average).
predicted_positive = true_positive + false_positive
precision_per_class = np.divide(
    true_positive,
    predicted_positive,
    out=np.zeros(true_positive.shape, dtype=float),
    where=predicted_positive != 0,
)

# Macro Precision - average precision across all classes
macro_precision = np.mean(precision_per_class)

# Micro Precision - global precision across all classes
total_true_positive = np.sum(true_positive)
total_false_positive = np.sum(false_positive)
total_predicted_positive = total_true_positive + total_false_positive
micro_precision = (
    total_true_positive / total_predicted_positive if total_predicted_positive else 0.0
)

precision_per_class, macro_precision, micro_precision


In [None]:
def _safe_divide(numerator, denominator):
    """Return ``numerator / denominator`` element-wise, using 0.0 where the denominator is 0.

    Avoids the NaN (and RuntimeWarning) that a plain division produces for a
    class with no predictions / no samples. Scalar inputs give a scalar back.
    """
    numerator = np.asarray(numerator, dtype=float)
    denominator = np.asarray(denominator, dtype=float)
    quotient = np.zeros(np.broadcast(numerator, denominator).shape, dtype=float)
    np.divide(numerator, denominator, out=quotient, where=denominator != 0)
    return quotient[()] if quotient.ndim == 0 else quotient


# Calculate TP, FP, FN, TN per class.
# Despite their names, `true_prediction` holds the true positives (diagonal)
# and `false_prediction` the false positives (column sum minus diagonal).
true_prediction = np.diag(conf_matrix)
false_prediction = np.sum(conf_matrix, axis=0) - true_prediction
false_negative = np.sum(conf_matrix, axis=1) - true_prediction
true_negative = np.sum(conf_matrix) - (true_prediction + false_prediction + false_negative)

# Calculate Precision: TP / (TP + FP)
precision_per_class = _safe_divide(true_prediction, true_prediction + false_prediction)
macro_precision = np.mean(precision_per_class)
micro_precision = _safe_divide(
    np.sum(true_prediction), np.sum(true_prediction) + np.sum(false_prediction)
)

# Calculate Recall (Sensitivity): TP / (TP + FN)
recall_per_class = _safe_divide(true_prediction, true_prediction + false_negative)
macro_recall = np.mean(recall_per_class)
micro_recall = _safe_divide(
    np.sum(true_prediction), np.sum(true_prediction) + np.sum(false_negative)
)

# Calculate Specificity: TN / (TN + FP)
specificity_per_class = _safe_divide(true_negative, true_negative + false_prediction)
macro_specificity = np.mean(specificity_per_class)
micro_specificity = _safe_divide(
    np.sum(true_negative), np.sum(true_negative) + np.sum(false_prediction)
)

# Calculate F1-score: harmonic mean of precision and recall
# (0.0 when both are 0, instead of NaN).
f1_per_class = _safe_divide(
    2 * (precision_per_class * recall_per_class), precision_per_class + recall_per_class
)
macro_f1 = np.mean(f1_per_class)
micro_f1 = _safe_divide(2 * (micro_precision * micro_recall), micro_precision + micro_recall)

# Display results
precision_per_class, macro_precision, micro_precision, recall_per_class, macro_recall, micro_recall, specificity_per_class, macro_specificity, micro_specificity, f1_per_class, macro_f1, micro_f1


In [None]:

# Per-class metrics table and grouped bar chart.
# The per-class arrays are indexed like the confusion matrix, whose axes are
# labelled with `class_names`; reuse that list so every row gets the right
# label (a hand-written "Healthy, Stage 1, ..." list does not follow that order).
classes = list(class_names)
metrics_df = pd.DataFrame({
    "Class": classes,
    "Precision": precision_per_class,
    "Recall": recall_per_class,
    "Specificity": specificity_per_class,
    "F1-Score": f1_per_class
})
print(metrics_df)

fig, ax = plt.subplots(figsize=(10, 6))
# Transposed: each metric is a group on the x-axis and each class is one bar,
# so the legend entries are classes.
metrics_df.set_index("Class").T.plot(kind='bar', ax=ax)
plt.title("Classification Metrics per Class")
plt.ylabel("Score")
plt.ylim(0, 1)
plt.xticks(rotation=0)
plt.legend(title="Class", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(axis='y', linestyle="--", alpha=0.7)
plt.tight_layout()

plt.show()


In [None]:
# Radar chart: one spoke per class, one polygon per metric.
# Spoke labels follow `class_names`, i.e. the index order of the per-class
# metric arrays derived from the confusion matrix.
categories = list(class_names)
N = len(categories)

# One entry per metric; each array holds that metric's value for every class.
# (Only numeric series belong here — class labels go on the spokes instead.)
values = {
    "Precision": precision_per_class,
    "Recall": recall_per_class,
    "Specificity": specificity_per_class,
    "F1-Score": f1_per_class
}

# Set up the radar chart: N evenly spaced angles, with the first repeated at
# the end so each polygon closes.
angles = np.linspace(0, 2 * np.pi, N, endpoint=False).tolist()
angles += angles[:1]  # Close the radar chart loop

fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))

# Plot each metric
for metric_name, vals in values.items():
    # Build a new, closed series instead of `vals += vals[:1]`: on a numpy
    # array that is an in-place broadcast add which corrupts the metric values
    # and leaves the length unchanged (mismatching `angles`).
    closed_vals = np.append(np.asarray(vals, dtype=float), vals[0])
    ax.plot(angles, closed_vals, label=metric_name, linewidth=2)
    ax.fill(angles, closed_vals, alpha=0.1)

# Format the chart
ax.set_xticks(angles[:-1])
ax.set_xticklabels(categories, fontsize=12)
# Fix the radial tick positions so the labels below line up with them.
ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_yticklabels(["0.2", "0.4", "0.6", "0.8", "1.0"], fontsize=10)
ax.set_ylim(0, 1)

plt.title("Radar Chart of Classification Metrics", fontsize=14, fontweight='bold')
plt.legend(loc="upper right", bbox_to_anchor=(1.3, 1.1))
plt.show()
