# Confusion Matrix Gallery + Visual Cleanup (Cathy)



---

In [5]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os

# ======================================================
# Hard-coded confusion-matrix counts, one 2x2 matrix per model.
# The plotting section below reuses each key as the PNG file name.
# NOTE(review): the heatmap labels rows "True Label" and columns
# "Predicted Label"; these numbers are assumed to follow that layout —
# confirm against wherever they were copied from.
# ======================================================
cms = {
    model_key: np.array(counts)
    for model_key, counts in {
        "lr_confusion_matrix": [[39, 26], [27, 93]],
        "rf_confusion_matrix": [[38, 27], [19, 101]],
        "xgb_baseline_cm": [[39, 26], [19, 101]],
        "xgb_tuned_cm": [[34, 31], [17, 103]],
    }.items()
}

# Output directory for the PNGs. It is a relative path, so it resolves
# against the current working directory; create it up front so saving
# never hits a missing folder.
save_path = "docs/final_visuals/confusion_matrices/"
os.makedirs(save_path, exist_ok=True)

# ======================================================
# Function to plot confusion matrices consistently
# ======================================================
def plot_confusion_matrix(cm, title, filename, save_dir=None, labels="auto"):
    """Render a matrix as an annotated heatmap and save it as a PNG.

    Parameters
    ----------
    cm : array-like
        2-D matrix to plot. Rows are drawn along the y-axis ("True Label")
        and columns along the x-axis ("Predicted Label"). Integer data is
        annotated as integers; any other dtype (e.g. a normalized matrix)
        is annotated with two decimals.
    title : str
        Figure title.
    filename : str
        Output file name without the ``.png`` extension.
    save_dir : str, optional
        Directory to write into; created if it does not exist. Defaults to
        the module-level ``save_path``.
    labels : "auto" or sequence of str, optional
        Tick labels used for both axes (e.g. ``["Neg", "Pos"]``). The
        default ``"auto"`` keeps seaborn's integer indices, i.e. the same
        output as before this parameter existed.

    Returns
    -------
    str
        Path of the PNG file that was written.
    """
    if save_dir is None:
        save_dir = save_path
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    matrix = np.asarray(cm)
    # fmt="d" raises ValueError on float data, so pick the format by dtype.
    fmt = "d" if np.issubdtype(matrix.dtype, np.integer) else ".2f"

    fig = plt.figure(figsize=(6, 5))
    try:
        sns.heatmap(
            matrix,
            annot=True,
            fmt=fmt,
            cmap="Blues",
            cbar=False,
            linewidths=1,
            linecolor="white",
            annot_kws={"size": 14},
            xticklabels=labels,
            yticklabels=labels,
        )

        plt.title(title, fontsize=16, fontweight="bold")
        plt.xlabel("Predicted Label", fontsize=14)
        plt.ylabel("True Label", fontsize=14)
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12, rotation=0)

        plt.tight_layout()
        # os.path.join works whether or not save_dir ends with a separator.
        out_file = os.path.join(save_dir, f"{filename}.png")
        fig.savefig(out_file, dpi=300)
    finally:
        # Close this exact figure even if plotting or saving fails;
        # otherwise pyplot keeps a reference to it and it leaks.
        plt.close(fig)
    return out_file


# ======================================================
# Generate all confusion matrices: one PNG per model, named after its
# key in `cms`, rendered in the order listed here.
# ======================================================
for key, title in (
    ("lr_confusion_matrix", "Logistic Regression Confusion Matrix"),
    ("rf_confusion_matrix", "Random Forest Confusion Matrix"),
    ("xgb_baseline_cm", "XGBoost Baseline Confusion Matrix"),
    ("xgb_tuned_cm", "XGBoost Tuned Confusion Matrix"),
):
    plot_confusion_matrix(cms[key], title, key)
