In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score


# Settings

In [None]:
# Data root directory. It is not referenced by the evaluation code in this
# part of the notebook.
data_root_dir = "/nfs_share/students/jinhyun/TCGA/BRCA"

# Result workbooks that can be evaluated. These used to be assigned to
# `result_file_path` one after another, which meant only the last assignment
# ever took effect and the choice of run depended on line order.
available_result_files = [
    "results/InternVL2_5-78B-MPO_randompatch224_20_rep1.xlsx",
    "results/InternVL2_5-78B-MPO_randompatch224_20_rep2.xlsx",
    "results/InternVL2_5-78B-MPO_randompatch448_20.xlsx",
    "results/InternVL2_5-78B-MPO_randompatch224_20_with_explain.xlsx",
    "results/InternVL2_5-78B-MPO_multiscale_patch448_20.xlsx",
]

# Workbook evaluated by the cells below. The last entry is selected, matching
# the value the original chain of assignments ended on; change the index to
# evaluate a different run.
result_file_path = available_result_files[-1]


# Evaluate

In [None]:
def categorize_prognosis(value, trunc_val=7):
    """Cap a year index at ``trunc_val``.

    Args:
        value: Numeric year index to cap, e.g. the output of
            ``extract_numeric_part_and_to_year``. May be NaN.
        trunc_val: Value that every input above ``trunc_val - 1`` is collapsed
            to. Defaults to 7, the cap that was previously hard-coded.

    Returns:
        ``value`` itself when ``value <= trunc_val - 1``, otherwise
        ``trunc_val``.

    Note:
        NaN compares false with everything, so a NaN input does not pass the
        ``<=`` test and is returned as ``trunc_val`` rather than as NaN. The
        same applies to non-integer values between ``trunc_val - 1`` and
        ``trunc_val`` (e.g. 6.5 with the default cap becomes 7).
        NOTE(review): whether unparseable labels should really land in the top
        bucket is a modelling decision — confirm with the notebook owner.
    """
    # Keep the positive `<=` form: rewriting this as `min(value, trunc_val)` or
    # as `value >= trunc_val` would change the NaN / fractional behaviour
    # described above.
    return value if value <= trunc_val - 1 else trunc_val

def extract_numeric_part_and_to_year(label, days_per_year=373):
    """Convert a day-range label such as ``"373-746"`` into a 1-based year index.

    The text before the first ``'-'`` is read as a day count, divided by
    ``days_per_year``, rounded to the nearest integer and shifted by one, so
    ``"0-373"`` maps to 1 and ``"373-746"`` maps to 2.

    Args:
        label: Prognosis label. Anything that is not a ``str`` is treated as
            unparseable.
        days_per_year: Number of days that make up one year bucket. Defaults
            to 373, the constant that was previously hard-coded.

    Returns:
        The year index as an ``int``, or ``np.nan`` when ``label`` is not a
        string or the part before the first ``'-'`` is not a plain unsigned
        integer.
    """
    # Non-string values (for example a float NaN) have no `.split`; report them
    # as unparseable instead of raising AttributeError.
    if not isinstance(label, str):
        return np.nan

    # Split once and tolerate surrounding whitespace (" 373-746", "373 - 746").
    leading = label.split('-')[0].strip()

    # isdecimal() instead of isdigit(): isdigit() also accepts characters such
    # as superscript digits, which float() cannot parse.
    if not leading.isdecimal():
        return np.nan

    # NOTE(review): round() rounds to the nearest integer, so a day count that
    # is not a multiple of `days_per_year` may map to the following year
    # (e.g. 200 days -> 2). This is harmless if labels always start on a bucket
    # boundary — confirm against the label set.
    return round(float(leading) / days_per_year) + 1

def prepare_df(result_file_path):
    """Load a result workbook and add the year columns used for evaluation.

    Steps:
        1. Read the Excel file at ``result_file_path``.
        2. Reduce every ``'Predicted Prognosis'`` cell to the text before its
           first newline.
        3. Drop rows whose prediction is exactly ``"Not supplied"`` and print
           how many rows remain.
        4. Add ``'Actual Year'`` / ``'Predicted Year'`` with
           ``extract_numeric_part_and_to_year`` and ``'Actual Truncated'`` /
           ``'Predicted Truncated'`` with ``categorize_prognosis``.

    Args:
        result_file_path: Path of the ``.xlsx`` file to read. The sheet must
            provide ``'Actual Prognosis'`` and ``'Predicted Prognosis'``
            columns.

    Returns:
        A copy of the filtered frame with the four derived columns added.
    """
    df = pd.read_excel(result_file_path)

    # Only the first line of each answer is kept as the label.
    # NOTE(review): astype(str) presumably turns empty cells into the literal
    # string "nan", which is not removed by the "Not supplied" filter below —
    # confirm against a workbook that has empty predictions.
    df['Predicted Prognosis'] = df['Predicted Prognosis'].astype(str).str.split('\n').str[0]
    df_filtered = df[df['Predicted Prognosis'] != "Not supplied"].copy()
    print(f"Found {len(df_filtered)}/{len(df)} predictions from VLM")

    df_filtered['Actual Year'] = df_filtered['Actual Prognosis'].apply(extract_numeric_part_and_to_year)
    df_filtered['Predicted Year'] = df_filtered['Predicted Prognosis'].apply(extract_numeric_part_and_to_year)

    # Make unparseable predictions visible. A NaN year fails the `<=` test in
    # categorize_prognosis and therefore ends up in the top bucket, where it is
    # indistinguishable from a genuine late-year prediction.
    n_unparsed = int(df_filtered['Predicted Year'].isna().sum())
    if n_unparsed:
        print(
            f"Warning: {n_unparsed} of those predictions have no leading day count "
            "and could not be converted to a year"
        )

    df_filtered['Actual Truncated'] = df_filtered['Actual Year'].apply(categorize_prognosis)
    df_filtered['Predicted Truncated'] = df_filtered['Predicted Year'].apply(categorize_prognosis)

    return df_filtered

def evaluate_model(df, tolerance_n=1, target_value_name="Truncated"):
    """Print classification metrics and plot the confusion matrix for ``df``.

    Args:
        df: Frame holding the columns ``"Actual <target_value_name>"`` and
            ``"Predicted <target_value_name>"`` (as added by ``prepare_df``).
            Both columns are assumed to be numeric and free of NaN; this is
            not validated here.
        tolerance_n: A prediction counts towards the tolerance accuracy when
            ``abs(predicted - actual) <= tolerance_n``.
        target_value_name: Column family to score: ``"Truncated"`` (default,
            previously hard-coded) or ``"Year"``.

    Returns:
        None. Accuracy, tolerance accuracy, weighted F1 and the confusion
        matrix are printed, and the matrix is shown as a heatmap.
    """
    y_true = df[f"Actual {target_value_name}"]
    y_pred = df[f"Predicted {target_value_name}"]

    # Classes for the confusion matrix, in ascending order. The union of both
    # columns is used so that a class which was predicted but never occurs in
    # the ground truth still gets a row/column; restricting `labels` to the
    # actual classes would drop those samples from the matrix.
    labels = np.union1d(y_true.to_numpy(), y_pred.to_numpy())
    # Show whole-number classes without a trailing ".0" on the axes.
    tick_labels = [int(v) if float(v).is_integer() else v for v in labels]

    accuracy = accuracy_score(y_true, y_pred)
    # Weighted average to account for class imbalance.
    f1 = f1_score(y_true, y_pred, average='weighted')

    # Share of predictions that are at most `tolerance_n` classes away from
    # the true class.
    tolerance_n_accuracy = np.mean(np.abs(y_pred - y_true) <= tolerance_n)

    print(f"Accuracy: {accuracy:.4f}")
    print(f"Tolerance-{tolerance_n} accuracy: {tolerance_n_accuracy}")
    print(f"F1 Score: {f1:.4f}")

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    print(cm)

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        ax=ax,
    )
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    ax.set_title("Confusion Matrix")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    plt.setp(ax.get_yticklabels(), rotation=45)
    plt.show()


In [None]:
# Load the selected workbook and derive the year / truncated columns.
df = prepare_df(result_file_path)

# Sanity check of the label mapping: list every distinct label together with
# the year and truncated value it was converted to, first for the ground truth
# and then for the predictions.
# (Optional class-balance check: df[["Actual Truncated"]].value_counts())
for side in ("Actual", "Predicted"):
    print(df[[f"{side} Prognosis", f"{side} Year", f"{side} Truncated"]].drop_duplicates())

In [None]:
# Print the metrics and draw the confusion matrix for the frame prepared above,
# using evaluate_model's default arguments.
evaluate_model(df)