In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, auc, average_precision_score
import pandas as pd
import os

In [6]:
def compute_and_plot_auprc(y_true, y_pred_proba, save_path, average='macro',
                           class_names=None):
    """
    Compute one-vs-rest AUPRC per class and save the precision-recall curves.

    Each class i is scored by treating "label == i" as the positive class and
    column i of `y_pred_proba` as its score. All curves are drawn on a single
    figure that is written to `save_path`; nothing is displayed inline.

    Args:
        y_true (array-like): Integer class indices in [0, n_classes),
            shape (n_samples,). Integral floats (e.g. 2.0) are accepted.
        y_pred_proba (array-like): Predicted scores, shape
            (n_samples, n_classes); column i belongs to class index i.
        save_path (str or path-like): File the figure is saved to (dpi=300).
        average (str or None): Forwarded unchanged to
            `sklearn.metrics.average_precision_score`. With None, sklearn
            returns one score per class instead of a single number.
        class_names (sequence of str, optional): Display names indexed by
            class. Defaults to ('Control', 'MCI', 'ADRD', 'Other'). Classes
            without a name fall back to their integer index.

    Returns:
        dict: 'class_<name>' -> area under that class's PR curve, plus
        'average' -> the value returned by `average_precision_score`
        (a float, or a per-class array when `average` is None).

    Raises:
        ValueError: If the inputs are empty, have inconsistent shapes, or
            `y_true` holds anything other than valid integer class indices.
    """
    y_pred_proba = np.asarray(y_pred_proba, dtype=float)
    if y_pred_proba.ndim != 2:
        raise ValueError(
            "y_pred_proba must be 2-D (n_samples, n_classes), "
            f"got shape {y_pred_proba.shape}"
        )
    n_samples, n_classes = y_pred_proba.shape
    if n_samples == 0 or n_classes == 0:
        raise ValueError("y_pred_proba is empty; nothing to score")

    y_true = np.asarray(y_true)
    if y_true.shape != (n_samples,):
        raise ValueError(
            f"y_true has shape {y_true.shape}, expected ({n_samples},) "
            "to match y_pred_proba"
        )

    # Labels produced by pandas `.map()` become float (with NaN) as soon as a
    # single key is missing, so validate through float before casting to int.
    try:
        labels_as_float = y_true.astype(float)
    except (TypeError, ValueError) as exc:
        raise ValueError("y_true must contain integer class indices") from exc
    if (not np.isfinite(labels_as_float).all()
            or (labels_as_float != np.floor(labels_as_float)).any()):
        raise ValueError("y_true contains missing or non-integer class labels")
    class_idx = labels_as_float.astype(int)
    # Negative indices would silently wrap around in the one-hot lookup below.
    if class_idx.min() < 0 or class_idx.max() >= n_classes:
        raise ValueError(
            f"y_true labels must lie in [0, {n_classes - 1}], "
            f"got range [{class_idx.min()}, {class_idx.max()}]"
        )

    # One-hot encode true labels
    y_true_bin = np.eye(n_classes)[class_idx]

    provided = ('Control', 'MCI', 'ADRD', 'Other') if class_names is None else class_names
    provided = list(provided)
    names = [str(provided[i]) if i < len(provided) else str(i)
             for i in range(n_classes)]

    colors = plt.cm.tab10.colors  # 10-colour palette, cycled for more classes
    auprc_scores = {}

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        for i, name in enumerate(names):
            precision, recall, _ = precision_recall_curve(
                y_true_bin[:, i], y_pred_proba[:, i]
            )
            # NOTE: `auc` integrates the curve with the trapezoidal rule, whereas
            # the overall score below uses average precision (a step-wise sum),
            # so the title value is not the mean of the legend values.
            auprc = auc(recall, precision)
            auprc_scores[f'class_{name}'] = auprc

            ax.plot(recall, precision, color=colors[i % len(colors)],
                    label=f"{name} (AUPRC = {auprc:.3f})")

        # Overall average score
        average_score = average_precision_score(y_true_bin, y_pred_proba, average=average)
        auprc_scores['average'] = average_score

        # average=None yields a per-class array, which cannot be formatted
        # with ':.3f'; only put a number in the title when there is one.
        if np.ndim(average_score) == 0:
            title = f"Precision-Recall Curve (Average AUPRC: {average_score:.3f})"
        else:
            title = "Precision-Recall Curve"

        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        ax.set_title(title)
        ax.legend()
        ax.grid(True)
        fig.tight_layout()

        fig.savefig(save_path, dpi=300)
    finally:
        # Always release the figure, even if scoring or saving failed, so
        # repeated calls do not accumulate open figures in the kernel.
        plt.close(fig)

    return auprc_scores

In [3]:
# Maps each one-hot diagnosis column to the integer class index used for scoring.
column_map = {'diagnosis_control':0,'diagnosis_mci':1,'diagnosis_adrd':2,'diagnosis_other':3}

# Ground-truth labels. '../data/post_data/Test_labels.csv' is an alternative
# label file that can be substituted for the path below.
labels = (
    pd.read_csv('../data/post_data/All_labels.csv')
    # Rows are keyed by the part of `uid` before the first underscore.
    .assign(unique_id=lambda df: df['uid'].apply(lambda x: str(x).split('_')[0]))
    .drop(columns='uid')
    # Keep one row per unique_id (first non-null value of each column).
    .groupby('unique_id')
    .first()
)

# Take the argmax over the diagnosis columns only: any other column present in
# the CSV must not be able to win the argmax and turn the label into NaN.
labels['diagnosis'] = labels[list(column_map)].idxmax(axis=1).map(column_map)

In [4]:
def calculate_auprc(base_path):
    """
    Score one results directory against the global `labels` and save its PR plot.

    Reads `submission_whisper_final.csv` from `base_path`, averages the
    predicted probabilities over all rows whose `uid` shares the same text
    before the first underscore, and writes `auprc.png` into `base_path`.
    Only ids present in both the predictions and `labels` are scored.

    Args:
        base_path (str): Directory containing `submission_whisper_final.csv`;
            the plot is saved next to it.

    Returns:
        dict: The scores returned by `compute_and_plot_auprc`.

    Raises:
        ValueError: If no prediction id matches any label id.
    """
    # Order the probability columns by class index so that column i of the
    # score matrix always corresponds to label value i.
    proba_columns = sorted(column_map, key=column_map.get)

    results = pd.read_csv(os.path.join(base_path,'submission_whisper_final.csv'))
    results['unique_id'] = results['uid'].apply(lambda x: str(x).split('_')[0])
    # Average only the probability columns; other columns are never used and
    # non-numeric ones would make a blanket `.mean()` fail.
    results = results.groupby('unique_id')[proba_columns].mean()

    filtered_labels = labels[labels.index.isin(results.index)]
    if filtered_labels.empty:
        raise ValueError(
            f"No ids in {base_path!r} predictions match the label index"
        )
    # Re-index the predictions by the label index so both sides contain the
    # same ids in the same order; predictions without a label are dropped.
    aligned_results = results.loc[filtered_labels.index, proba_columns]

    return compute_and_plot_auprc(
        filtered_labels['diagnosis'],
        aligned_results,
        os.path.join(base_path,'auprc.png')
    )

In [7]:
# Result directories to evaluate, in order; calculate_auprc writes an
# auprc.png into each one.
result_dirs = [
    '../results/whisper_metadata_age_multiplier',
    '../results/whisper_metadata_age_fullaudio',
    '../results/whisper_metadata',
    '../results/whisper-large_fullaudio',
    '../results/distilled_whisper_base_v0_2',
]

for result_dir in result_dirs:
    calculate_auprc(result_dir)