In [None]:
import pandas as pd
import numpy as np 
import glob
import matplotlib.pyplot as plt 
import seaborn as sns

In [None]:
def plot_confusion_matrix(
    confusion_matrix,
    class_names,
    normalize=True,
    cmap="Blues",
    fontsize=12
):
    """
    Plot a confusion matrix as an annotated heatmap and show the figure.

    Rows are labelled "True Class" and columns "Predicted Class". Each cell
    is annotated with its row-normalized percentage and raw count (or with
    the raw count only when ``normalize`` is False). The annotation colour
    is switched between white and black depending on the brightness of the
    cell underneath.

    Parameters
    ----------
    confusion_matrix : array-like
        Square confusion matrix of shape (N, N).
    class_names : sequence of str
        Class labels (length N), used for both axes.
    normalize : bool, optional
        If True, colour the cells by row-normalized percentage and show the
        percentage above the raw count. Rows whose total is zero are shown
        as 0.0%. If False, colour by raw count. Default is True.
    cmap : str, optional
        Colormap for heatmap. Default is "Blues".
    fontsize : int, optional
        Font size for annotations, axis labels and tick labels.
        Default is 12.

    Raises
    ------
    ValueError
        If ``confusion_matrix`` is not a square 2-D array, or if its size
        does not match ``len(class_names)``.
    """

    confusion_matrix = np.asarray(confusion_matrix)
    class_names = list(class_names)
    n_classes = len(class_names)

    # Validate up front: a size mismatch would otherwise yield a plot whose
    # tick labels and annotations do not line up with the cells.
    if (
        confusion_matrix.ndim != 2
        or confusion_matrix.shape[0] != confusion_matrix.shape[1]
    ):
        raise ValueError(
            "confusion_matrix must be a square 2-D array, got shape "
            f"{confusion_matrix.shape}"
        )
    if confusion_matrix.shape[0] != n_classes:
        raise ValueError(
            f"confusion_matrix has {confusion_matrix.shape[0]} classes but "
            f"{n_classes} class names were given"
        )

    if normalize:
        row_totals = confusion_matrix.sum(axis=1, keepdims=True)
        # Rows with no samples would divide by zero; leave them at 0 instead
        # of producing NaN cells.
        cm_percent = (
            np.divide(
                confusion_matrix,
                row_totals,
                out=np.zeros(confusion_matrix.shape, dtype=float),
                where=row_totals != 0,
            )
            * 100
        )
    else:
        cm_percent = confusion_matrix.copy()

    fig, ax = plt.subplots()

    sns.heatmap(
        cm_percent,
        annot=False,
        cmap=cmap,
        xticklabels=class_names,
        yticklabels=class_names,
        linecolor="black",
        linewidths=0.5,
        cbar=False,
        ax=ax
    )

    # The heatmap mesh stores one RGBA colour per cell in row-major order.
    # Force the colour mapping to be computed so the values are valid even
    # if the figure has not been drawn yet.
    mesh = ax.collections[0]
    mesh.update_scalarmappable()
    cell_colors = mesh.get_facecolors()

    # Add annotations: percentage + raw count
    for i in range(n_classes):
        for j in range(n_classes):
            if normalize:
                text = f"{cm_percent[i, j]:.1f}%\n{confusion_matrix[i, j]}"
            else:
                text = f"{confusion_matrix[i, j]}"

            # Weighted RGB sum (luma) as a brightness estimate: dark cells
            # get white text, light cells get black text.
            r, g, b, _ = cell_colors[i * n_classes + j]
            brightness = 0.299 * r + 0.587 * g + 0.114 * b
            text_color = "white" if brightness < 0.65 else "black"

            # Heatmap cells are unit squares, so +0.5 is the cell centre.
            ax.text(
                j + 0.5,
                i + 0.5,
                text,
                ha="center",
                va="center",
                fontsize=fontsize,
                color=text_color
            )

    ax.set_xlabel("Predicted Class", fontsize=fontsize)
    ax.set_ylabel("True Class", fontsize=fontsize)

    ax.tick_params(axis="x", labelsize=fontsize)
    ax.tick_params(axis="y", labelsize=fontsize)

    plt.tight_layout()
    plt.show()


In [None]:
# Locate the scores table. glob returns a (possibly empty) list of paths.
filename = glob.glob('all_scores.csv')
if not filename:
    raise FileNotFoundError("all_scores.csv was not found in the working directory")

# The following cells read `data`, but no cell defined it, so the notebook
# failed on a fresh kernel. Load it here from the matched file.
# NOTE(review): assumes all_scores.csv is the intended source of `data`
# (with the '*_score' and 'actual_type' columns used below) - confirm.
data = pd.read_csv(filename[0])

In [None]:
# Columns whose name contains '_score'
score_columns = [name for name in data.columns if '_score' in name]

# For every row, take the name of the highest-scoring column and keep the
# first run of the characters I/V found in that name as the predicted type.
best_score_column = data[score_columns].idxmax(axis=1)
data['pred_type'] = best_score_column.str.extract(r'([IV]+)')

In [None]:
# Calculate the confusion matrix (rows: actual type, columns: predicted type)

type_mapping = { 'I': 0, 'II': 1, 'III': 2, 'IV': 3, 'V': 4 }
n_types = len(type_mapping)

# Translate labels to matrix indices in one vectorized pass. Labels that are
# not in type_mapping (including missing values) become NaN.
actual_index = data['actual_type'].map(type_mapping)
pred_index = data['pred_type'].map(type_mapping)

# Only rows where both labels are known contribute to the matrix.
is_known = actual_index.notna() & pred_index.notna()

confusion_matrix = np.zeros((n_types, n_types), dtype=int)
# np.add.at accumulates repeated (row, column) pairs; a plain fancy-indexed
# `+= 1` would count each distinct pair only once.
np.add.at(
    confusion_matrix,
    (
        actual_index[is_known].to_numpy(dtype=int),
        pred_index[is_known].to_numpy(dtype=int),
    ),
    1,
)

In [None]:
# Plot the matrix with one label per type, I through V.
plot_confusion_matrix(
    confusion_matrix,
    class_names=["I", "II", "III", "IV", "V"],
)

In [None]:
# Plot the matrix restricted to the four labelled types. The matrix has to be
# reduced to the same classes as the labels; passing the full five-class
# matrix with four labels leaves rows and columns misaligned with their ticks.
# Percentages are therefore normalized over the remaining classes only.
# NOTE(review): assumes dropping Type II was the intent of this four-label
# plot - confirm.
kept_types = ['I', "III", "IV", "V"]
kept_index = [type_mapping[name] for name in kept_types]
plot_confusion_matrix(confusion_matrix[np.ix_(kept_index, kept_index)], kept_types)

In [None]:
# Load the predictions CSV into a DataFrame
csv_path = "predictions_for_chime_bursts.csv"
df = pd.read_csv(csv_path)

# Sanity check: report the column names and the row count
column_names = list(df.columns)
row_count = df.shape[0]
print("Columns:", column_names)
print("Number of rows:", row_count)

# Preview the first rows (kept as the last expression so the cell renders it)
df.head()

In [None]:
# Calculate the confusion matrix excluding Type I
# Here B corresponds to type II, C to type III, C1 to type IV and C2 to type V

type_mapping = { 'B': 0 , 'C': 1 , 'C1': 2, 'C2': 3 }
n_types = len(type_mapping)

# Translate labels to matrix indices in one vectorized pass. Labels that are
# not in type_mapping (including missing values) become NaN.
actual_index = df['correct_type'].map(type_mapping)
pred_index = df['pred_type_w_A'].map(type_mapping)

# Only rows where both labels are known contribute to the matrix.
is_known = actual_index.notna() & pred_index.notna()

confusion_matrix = np.zeros((n_types, n_types), dtype=int)
# np.add.at accumulates repeated (row, column) pairs; a plain fancy-indexed
# `+= 1` would count each distinct pair only once.
np.add.at(
    confusion_matrix,
    (
        actual_index[is_known].to_numpy(dtype=int),
        pred_index[is_known].to_numpy(dtype=int),
    ),
    1,
)


In [None]:
# Plot the four-class matrix with labels II through V.
plot_confusion_matrix(
    confusion_matrix,
    class_names=["II", "III", "IV", "V"],
)

In [None]:
# Calculate the confusion matrix excluding Type II
# Here A corresponds to type I, C to type III, C1 to type IV and C2 to type V

type_mapping = { 'A': 0 , 'C': 1 , 'C1': 2, 'C2': 3 }
n_types = len(type_mapping)

# Translate labels to matrix indices in one vectorized pass. Labels that are
# not in type_mapping (including missing values) become NaN.
actual_index = df['correct_type'].map(type_mapping)
pred_index = df['pred_type_w_B'].map(type_mapping)

# Only rows where both labels are known contribute to the matrix.
is_known = actual_index.notna() & pred_index.notna()

confusion_matrix = np.zeros((n_types, n_types), dtype=int)
# np.add.at accumulates repeated (row, column) pairs; a plain fancy-indexed
# `+= 1` would count each distinct pair only once.
np.add.at(
    confusion_matrix,
    (
        actual_index[is_known].to_numpy(dtype=int),
        pred_index[is_known].to_numpy(dtype=int),
    ),
    1,
)

In [None]:
# Plot the four-class matrix with labels I, III, IV and V.
plot_confusion_matrix(
    confusion_matrix,
    class_names=["I", "III", "IV", "V"],
)

In [None]:
# Calculate the confusion matrix excluding Type I and Type II
# Here C corresponds to type III, C1 to type IV and C2 to type V

type_mapping = { 'C': 0 , 'C1': 1, 'C2': 2 }
n_types = len(type_mapping)

# Translate labels to matrix indices in one vectorized pass. Labels that are
# not in type_mapping (including missing values) become NaN.
actual_index = df['correct_type'].map(type_mapping)
pred_index = df['pred_type_w_AB'].map(type_mapping)

# Only rows where both labels are known contribute to the matrix.
is_known = actual_index.notna() & pred_index.notna()

confusion_matrix = np.zeros((n_types, n_types), dtype=int)
# np.add.at accumulates repeated (row, column) pairs; a plain fancy-indexed
# `+= 1` would count each distinct pair only once.
np.add.at(
    confusion_matrix,
    (
        actual_index[is_known].to_numpy(dtype=int),
        pred_index[is_known].to_numpy(dtype=int),
    ),
    1,
)

In [None]:
# Plot the three-class matrix with labels III, IV and V.
plot_confusion_matrix(
    confusion_matrix,
    class_names=["III", "IV", "V"],
)