# Imports

In [1]:
# tiatoolbox-simple environment
import os
import math

import numpy as np
import pandas as pd
from pandas.plotting import parallel_coordinates
import matplotlib.pyplot as plt
import seaborn as sns

from source.constants import ORIGINAL_2_PRETTY_MODEL_NAMES

In [2]:
# Every slide contributes `num_original_patches` patches, each augmented
# `num_augs` times. `num_connections` counts the unordered patch pairs; a
# fully evaluated slide has TP + FP + FN + TN equal to this number.
num_original_patches = 50
num_augs = 10
num_patches = num_original_patches * num_augs
num_connections = math.comb(num_patches, 2)  # == n * (n - 1) // 2

# Lung cohorts evaluated (DHMC_20x / DHMC_40x are currently excluded).
lung_datasets = [
    'TCIA-CPTAC',
    'TCGA-lung',
    'ouh_batch1_20x',
    'ouh_batch1_40x',
    'ouh_batch2_20x',
    'ouh_batch3_40x',
    *(f'DART_{idx:03d}' for idx in range(1, 5)),
]
all_datasets = [*lung_datasets, 'CAMELYON16']

# Long metric names from the CSV header -> shorter names used in the tables/plots.
original_2_shorter_metric_names = {
    'Adjusted Rand Index (ARI)': 'Adjusted Rand Index',
    'Normalized Mutual Information (NMI)': 'Normalized Mutual Info',
}
OPTIMIZING_METRIC = 'Fowlkes-Mallows Index'

# Plot colour per model; all ResNet18 variants share black.
model_2_color = {
    'Phikon-v2': 'orange',
    'Prov-GigaPath': 'blue',
    **dict.fromkeys(
        ['ResNet18-SimCLR', 'ResNet18-camelyon16-20x', 'ResNet18-lung-10x'],
        'black',
    ),
    'UNI': 'red',
    'Virchow-v1-Concat': 'yellow',
}

In [None]:
def get_eval_df(dataset: str):
    """Load the per-slide clustering evaluation CSV of one dataset and tidy it.

    The first CSV column holds a composite ``model#img_norm#wsi_id`` key. It is
    split into ``model`` (mapped through ORIGINAL_2_PRETTY_MODEL_NAMES),
    ``img_norm`` and ``wsi_id`` columns.

    Rows for ResNet18 models with ``imagenet`` normalisation are dropped. A
    ``total_connections`` column (TP + FP + FN + TN) is added.

    Args:
        dataset: Dataset name; used in the results file name and stored in the
            ``dataset`` column.

    Returns:
        dict with keys:
            'df': tidied DataFrame, columns ordered info -> confusion matrix -> metrics.
            'info_columns': ['model', 'img_norm', 'wsi_id', 'dataset'].
            'conf_matrix_columns': ['TP', 'FP', 'FN', 'TN', 'total_connections'].
            'metric_columns': every remaining column.
            'metric_columns_wo_precision': metric columns whose name does not
                contain the (case-sensitive) substring 'precision'.

    Raises:
        FileNotFoundError: if the results CSV does not exist.
        KeyError: if a model name is missing from ORIGINAL_2_PRETTY_MODEL_NAMES.
    """
    eval_csv_path = f'eval_results/dataset={dataset}#extractor_name=all#img_norm=all#distance_metric=cosine#dimensionality_reduction=none#clustering=kmeans.csv'
    eval_df = pd.read_csv(eval_csv_path)

    # initial transformations
    eval_df['dataset'] = dataset
    key_column = 'model#img_norm#wsi_id'
    eval_df = eval_df.rename(columns={eval_df.columns[0]: key_column})
    eval_df = eval_df.rename(columns=original_2_shorter_metric_names)

    # info columns: split the composite key once, then pick its three parts
    key_parts = eval_df[key_column].apply(lambda key: key.split('#'))
    eval_df['model'] = key_parts.apply(lambda parts: ORIGINAL_2_PRETTY_MODEL_NAMES[parts[0]])
    eval_df['img_norm'] = key_parts.apply(lambda parts: parts[1])
    eval_df['wsi_id'] = key_parts.apply(lambda parts: parts[2])
    eval_df = eval_df.drop(columns=[key_column])
    info_columns = ['model', 'img_norm', 'wsi_id', 'dataset']

    # Drop the ResNet18 + imagenet-normalisation combination. .copy() makes the
    # filtered frame independent, so the column assignment below does not
    # trigger pandas' SettingWithCopyWarning.
    remove_condition = eval_df['model'].str.startswith('ResNet18') & (eval_df['img_norm'] == 'imagenet')
    eval_df = eval_df[~remove_condition].copy()

    # confusion matrix columns
    conf_matrix_columns = ['TP', 'FP', 'FN', 'TN']
    eval_df['total_connections'] = eval_df[conf_matrix_columns].sum(axis=1)
    conf_matrix_columns.append('total_connections')

    # metric columns: everything that is neither an info nor a confusion-matrix
    # column (exclusion set built once instead of per iteration)
    non_metric_columns = set(info_columns) | set(conf_matrix_columns)
    metric_columns = [col for col in eval_df.columns if col not in non_metric_columns]

    # metric columns without any (case-sensitive) 'precision' in the name
    metric_columns_wo_precision = [col for col in metric_columns if 'precision' not in col]

    # reorder
    eval_df = eval_df[info_columns + conf_matrix_columns + metric_columns]

    return {
        'df': eval_df,
        'info_columns': info_columns,
        'conf_matrix_columns': conf_matrix_columns,
        'metric_columns': metric_columns,
        'metric_columns_wo_precision': metric_columns_wo_precision,
    }

def get_eval_df_from_list(datasets: list[str]):
    """Load several datasets with get_eval_df and stack their tables.

    The column-group lists are taken from the first dataset in ``datasets``.
    Raises ValueError (from pd.concat) when ``datasets`` is empty.
    """
    loaded = [get_eval_df(name) for name in datasets]
    stacked = pd.concat([entry['df'] for entry in loaded])
    template = loaded[0]
    group_keys = (
        'info_columns',
        'conf_matrix_columns',
        'metric_columns',
        'metric_columns_wo_precision',
    )
    return {'df': stacked, **{key: template[key] for key in group_keys}}

def get_full_slides_condition(eval_df: pd.DataFrame, expected_connections: "int | None" = None):
    """Return a boolean mask selecting fully evaluated slides.

    A slide counts as fully evaluated when its ``total_connections``
    (TP + FP + FN + TN) equals the expected number of patch pairs.

    Args:
        eval_df: Frame with a ``total_connections`` column.
        expected_connections: Number of connections a complete slide has.
            Defaults to the module-level ``num_connections``.

    Returns:
        Boolean Series aligned with ``eval_df``'s index.
    """
    if expected_connections is None:
        expected_connections = num_connections
    return eval_df['total_connections'] == expected_connections

def get_agg_df(eval_df: pd.DataFrame, condition: pd.Series, agg: str):
    """Aggregate the rows selected by ``condition`` per (model, img_norm).

    ``wsi_id`` and ``dataset`` are dropped first, so every remaining column is
    reduced with ``agg`` (e.g. 'mean' or 'std'). The result is sorted by model.
    A stable sort keeps groupby's img_norm order within each model; the default
    quicksort is not stable. ``model`` and ``img_norm`` are returned as columns.
    """
    grouped = (
        eval_df[condition]
        .drop(columns=['wsi_id', 'dataset'])
        .groupby(['model', 'img_norm'])
        .agg(agg)
    )
    return grouped.sort_values('model', kind='stable').reset_index()

def get_agg_df_per_dataset(eval_df: pd.DataFrame, condition: pd.Series, agg: str):
    """Aggregate the rows selected by ``condition`` per (model, img_norm, dataset).

    ``wsi_id`` is dropped first, so every remaining column is reduced with
    ``agg`` (e.g. 'mean' or 'std'). The result is sorted by model. A stable sort
    keeps groupby's (img_norm, dataset) order within each model; the default
    quicksort is not stable. The group keys are returned as columns.
    """
    grouped = (
        eval_df[condition]
        .drop(columns=['wsi_id'])
        .groupby(['model', 'img_norm', 'dataset'])
        .agg(agg)
    )
    return grouped.sort_values('model', kind='stable').reset_index()


def get_mean_and_std_dfs(dataset: str):
    """Return (mean, std) tables per (model, img_norm) for a single dataset.

    Only rows passing get_full_slides_condition are aggregated.
    """
    frame = get_eval_df(dataset)['df']
    is_full = get_full_slides_condition(frame)
    return tuple(get_agg_df(frame, is_full, stat) for stat in ('mean', 'std'))

# for dataset in all_datasets:
#     mean_agg_df, std_agg_df = get_mean_and_std_dfs(dataset)
#     display(mean_agg_df)
#     display(std_agg_df)
#     print('='*80)

def get_mean_and_std_dfs_from_list(datasets_list: list[str]):
    """Pool the given datasets and return (mean, std) tables per (model, img_norm).

    Prints the dataset list and the fraction of rows that pass the
    full-slide filter, as a quick sanity check.
    """
    print(datasets_list)
    pooled = pd.concat([get_eval_df(name)['df'] for name in datasets_list])
    keep = get_full_slides_condition(pooled)
    print(f"full_slides_condition.mean() {keep.mean()}")
    mean_table, std_table = (get_agg_df(pooled, keep, stat) for stat in ('mean', 'std'))
    return mean_table, std_table

def get_mean_and_std_dfs_per_dataset_from_list(datasets_list: list[str]):
    """Load the given datasets and return (mean, std) tables per (model, img_norm, dataset).

    Prints the dataset list and the fraction of rows that pass the
    full-slide filter, as a quick sanity check.
    """
    print(datasets_list)
    pooled = pd.concat([get_eval_df(name)['df'] for name in datasets_list])
    keep = get_full_slides_condition(pooled)
    print(f"full_slides_condition.mean() {keep.mean()}")
    mean_table, std_table = (
        get_agg_df_per_dataset(pooled, keep, stat) for stat in ('mean', 'std')
    )
    return mean_table, std_table

# (mean, std) tables per cohort; OUH and DART pool all of their batches.
mean_tcga_lung_df, std_tcga_lung_df = get_mean_and_std_dfs_from_list(['TCGA-lung'])
mean_tcia_cptac_df, std_tcia_cptac_df = get_mean_and_std_dfs_from_list(['TCIA-CPTAC'])
mean_ouh_df, std_ouh_df = get_mean_and_std_dfs_from_list(
    [name for name in all_datasets if name.startswith('ouh')]
)
mean_dart_df, std_dart_df = get_mean_and_std_dfs_from_list(
    [name for name in all_datasets if name.startswith('DART')]
)
mean_camelyon16_df, std_camelyon16_df = get_mean_and_std_dfs_from_list(['CAMELYON16'])

mean_tcga_lung_df

In [None]:
def plot_parallel_coordinates(mean_df: pd.DataFrame, metric_columns: list[str], title: str = 'Clustering Metrics', baseline=0.5):
    """Plot one parallel-coordinates line per row of ``mean_df``, coloured by model.

    Args:
        mean_df: Aggregated metrics with a 'model' column (used as the class
            column for colours/legend) and the columns in ``metric_columns``.
        metric_columns: Columns drawn as the parallel axes, in this order.
        title: Plot title; also used (spaces -> underscores) in the file name.
        baseline: Lower y-axis limit; the upper limit is fixed at 1.0.

    The figure is saved to ./figures/parallel_coordinates_<title>.png
    (./figures is created if missing) and then shown.
    """
    plt.figure(figsize=(7.5, 3))
    parallel_coordinates(
        mean_df,
        'model',
        cols=metric_columns,
        color=plt.cm.Set2.colors,
    )
    plt.title(title)
    plt.xticks(rotation=30, ha='right')
    plt.legend(loc='lower right', ncol=3, borderaxespad=0.)
    plt.ylim(baseline, 1.0)
    # savefig does not create missing directories; ensure the target exists.
    os.makedirs('./figures', exist_ok=True)
    plt.savefig(f"./figures/parallel_coordinates_{title.replace(' ', '_')}.png", bbox_inches='tight')
    plt.show()
    

# Strong results on the pooled OUH and DART cohorts: y-axis zoomed to [0.6, 1].
plot_parallel_coordinates(
    mean_df=mean_ouh_df,
    metric_columns=get_eval_df_from_list([name for name in all_datasets if name.startswith('ouh')])['metric_columns'],
    title='OUH Lung Clustering Metrics',
    baseline=0.6,
)
plot_parallel_coordinates(
    mean_df=mean_dart_df,
    metric_columns=get_eval_df_from_list([name for name in all_datasets if name.startswith('DART')])['metric_columns'],
    title='DART Lung Clustering Metrics',
    baseline=0.6,
)
# Moderate results on TCGA lung.
plot_parallel_coordinates(
    mean_df=mean_tcga_lung_df,
    metric_columns=get_eval_df('TCGA-lung')['metric_columns'],
    title='TCGA Lung Clustering Metrics',
    baseline=0.5,
)
# Weak results on TCIA-CPTAC and CAMELYON16: y-axis widened down to 0.1.
plot_parallel_coordinates(
    mean_df=mean_tcia_cptac_df,
    metric_columns=get_eval_df('TCIA-CPTAC')['metric_columns'],
    title='TCIA-CPTAC Lung Clustering Metrics',
    baseline=0.1,
)
plot_parallel_coordinates(
    mean_df=mean_camelyon16_df,
    metric_columns=get_eval_df('CAMELYON16')['metric_columns'],
    title='CAMELYON16 Clustering Metrics',
    baseline=0.1,
)

In [None]:
radar_mean_df, radar_std_df = get_mean_and_std_dfs_per_dataset_from_list(all_datasets)
# ResNet18 rows with imagenet normalisation are already removed inside get_eval_df().
# Keep only the columns the radar chart needs. .copy() makes each subset an
# independent frame, so overwriting 'model' below does not raise pandas'
# SettingWithCopyWarning.
subset_columns = ['dataset', 'model', OPTIMIZING_METRIC]
radar_mean_df = radar_mean_df[subset_columns].copy()
radar_std_df = radar_std_df[subset_columns].copy()
# Collapse every ResNet18 variant onto the single label 'ResNet18-SimCLR'.
# NOTE(review): if one dataset has several ResNet18 variants, this creates
# duplicate (dataset, model) rows — confirm each dataset has only one.
radar_mean_df['model'] = radar_mean_df['model'].apply(lambda x: 'ResNet18-SimCLR' if x.startswith('ResNet18') else x)
radar_std_df['model'] = radar_std_df['model'].apply(lambda x: 'ResNet18-SimCLR' if x.startswith('ResNet18') else x)
radar_mean_df

In [None]:
def plot_radar_chart_complete_data(radar_mean_df, optimizing_metric, baseline_value):
    """Draw a radar (polar) chart of one metric per model across datasets.

    Each dataset is one spoke and each model is one closed, lightly filled
    polygon.

    Args:
        radar_mean_df: DataFrame with 'model', 'dataset' and
            ``optimizing_metric`` columns. Every model must have a row for every
            dataset: the first matching row is used, and a missing row raises
            IndexError.
        optimizing_metric: Column plotted on the radial axis; also the title.
        baseline_value: Lower limit of the radial axis; the upper limit is 1.

    The figure is saved to ./figures/radar_chart_baseline_<metric>.png
    (./figures is created if missing) and then shown.
    """
    radar_models = radar_mean_df['model'].unique()
    print(radar_models)

    radar_datasets = radar_mean_df['dataset'].unique()
    print(radar_datasets)

    df_list = [radar_mean_df[radar_mean_df['dataset'] == dataset] for dataset in radar_datasets]

    # One spoke per dataset; repeat the first angle so each polygon closes.
    num_vars = len(radar_datasets)
    angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))

    # Loop through models to plot their performance
    for model in radar_models:
        model_values = [
            df[df['model'] == model][optimizing_metric].values[0]
            for df in df_list
        ]
        model_values += model_values[:1]  # close the polygon
        (line,) = ax.plot(angles, model_values, label=model, linewidth=2)
        # Reuse the line's colour so each shaded area matches its outline.
        ax.fill(angles, model_values, color=line.get_color(), alpha=0.1)

    # Add labels and legend
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(radar_datasets)
    ax.set_title(optimizing_metric, size=20, pad=20)
    ax.legend(loc='upper right', bbox_to_anchor=(1.4, 1.2))

    # Radial axis starts at the baseline value
    ax.set_ylim(baseline_value, 1)

    # savefig does not create missing directories; ensure the target exists.
    os.makedirs('./figures', exist_ok=True)
    fig.savefig(f"./figures/radar_chart_baseline_{optimizing_metric.replace(' ', '_')}.png", bbox_inches='tight')
    plt.show()

# Draw the radar chart of the optimizing metric; its radial axis starts at
# this baseline (adjust as needed).
baseline_value = 0.3
plot_radar_chart_complete_data(
    radar_mean_df=radar_mean_df,
    optimizing_metric=OPTIMIZING_METRIC,
    baseline_value=baseline_value,
)