In [7]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import shutil
import torch

from sklearn.metrics import roc_auc_score, precision_recall_curve, auc

In [8]:

def load_pickle(file_path):
    with open(file_path, 'rb') as f:
        return pickle.load(f)

def load_modality_output(seed, projection_dim, batch_size, modality, task) : 
    result_dir = Path(f'Results/Linear_{modality}/Seed_{seed}/{task}_proj_{projection_dim}_batch_{batch_size}')
    output = load_pickle(result_dir / 'outputs' / 'best_epoch.pkl')
    etf = Path(f'Results/ETF/ETF_{projection_dim}_IN_2_OUT_Seed_{seed}.pt')
    etf = torch.load(etf)
    return output, etf

def load_fusion_ouptut(seed, projection_dim, batch_size, fusion_method, task) : 
    result_dir = Path(f'Results/Linear_MultiModal/Fusion_{fusion_method}/Seed_{seed}/{task}_proj_{projection_dim}_batch_{batch_size}')
    output = load_pickle(result_dir / 'outputs' / 'best_epoch.pkl')
    return output

def unimodal_process(output, key):
    data = output.get(key)
    features = data['features']
    return features

def bimodal_process(output1, output2, key, weight1=0.5, weight2=0.5) : 
    data1 = output1.get(key)
    data2 = output2.get(key)
    features1 = data1['features']
    features2 = data2['features']
    # combined_features = (features1 + features2) / 2
    combined_features = (features1 * weight1 + features2 * weight2) / (weight1 + weight2)
    return features1, features2, combined_features

def trimodal_process(output1, output2, output3, key, weight1=0.5, weight2=0.5, weight3=0.5) : 
    data1 = output1.get(key)
    data2 = output2.get(key)
    data3 = output3.get(key)
    features1 = data1['features']
    features2 = data2['features']
    features3 = data3['features']
    feature_12 = (features1 * weight1 + features2 * weight2) / (weight1 + weight2)
    feature_13 = (features1 * weight1 + features3 * weight3) / (weight1 + weight3)
    feature_23 = (features2 * weight2 + features3 * weight3) / (weight2 + weight3)
    combined_features = (features1 * weight1 + features2 * weight2 + features3 * weight3) / (weight1 + weight2 + weight3)
    return features1, features2, features3, feature_12, feature_13, feature_23, combined_features

def whole_process(output1, output2, output3, key) : 
    data1 = output1.get(key, None)
    data2 = output2.get(key, None)
    data3 = output3.get(key, None)
    # ignore none and agg the rest
    features_list = [data['features'] for data in [data1, data2, data3] if data is not None]
    features = np.mean(features_list, axis=0) if features_list else None
    return features

def get_output(modality, total_outputs, etf) : 
    labels, features = [], []
    labels = [values['labels'] for values in total_outputs.values() if values[modality] is not None]
    features = [[values[modality]] for values in total_outputs.values() if values[modality] is not None]
    labels = np.array(labels)
    features = np.concatenate(features, axis=0)
    probs = features @ etf.numpy()
    print(f"{modality} - Features shape: {features.shape}, Probs shape: {probs.shape}, Labels shape: {labels.shape}")
    return labels, probs

def get_performance(labels, probs):
    roc_auc = roc_auc_score(labels, probs)
    # pr_auc = average_precision_score(labels, probs)
    precision, recall, _ = precision_recall_curve(labels, probs)
    pr_auc = auc(recall, precision)
    return roc_auc, pr_auc, len(labels)

def get_score(modality, total_outputs, etf) : 
    labels, probs = get_output(modality, total_outputs, etf)
    roc_auc, pr_auc, n = get_performance(labels, probs[:,1])
    number_of_cases = (labels == 1).sum()
    return roc_auc, pr_auc, n, number_of_cases

def get_fusion_score(fusion_method, total_outputs) : 
    labels, probs = [], []
    labels = [values['labels'] for values in total_outputs.values()]
    probs = [values[fusion_method + '_prob'] for values in total_outputs.values()]
    labels = np.array(labels)
    probs = np.stack(probs, axis=0)
    print(f"{fusion_method} - Probs shape: {probs.shape}, Labels shape: {labels.shape}")
    roc_auc, pr_auc, n = get_performance(labels, probs[:,1])
    number_of_cases = (labels == 1).sum()
    return roc_auc, pr_auc, n, number_of_cases

In [9]:
projection_dim=128
batch_size=512
seed=2026
# task='mortality_90days'
task='readmission_15days'

total_scores = []
for seed in range(2026, 2029) : 

    tabular_outputs, proto = load_modality_output(seed, projection_dim, batch_size, 'tabular', task)
    lab_outputs, _ = load_modality_output(seed, projection_dim, batch_size, 'lab', task)
    note_outputs, _ = load_modality_output(seed, projection_dim, batch_size, 'note', task)
    sum_outputs = load_fusion_ouptut(seed, projection_dim, batch_size, 'Sum', task)
    weighted_sum_outputs = load_fusion_ouptut(seed, projection_dim, batch_size, 'WeightedFusion', task)
    attn_masked_outputs = load_fusion_ouptut(seed, projection_dim, batch_size, 'AttnMaskedFusion', task)
    split_mode = 'test'

    tab_output = tabular_outputs[split_mode]['outputs']
    lab_output = lab_outputs[split_mode]['outputs']
    note_output = note_outputs[split_mode]['outputs']
    sum_output = sum_outputs[split_mode]['outputs']
    weighted_sum_output = weighted_sum_outputs[split_mode]['outputs']
    attn_masked_output = attn_masked_outputs[split_mode]['outputs']

    tab_keys = set(tab_output.keys())
    lab_keys = set(lab_output.keys())
    note_keys = set(note_output.keys())

    # 3개의 조합이면, 나올 수 있는 경우의 수
    ## Tab > Lab & Tab > Note 포함관계
    # 1.tab_total
    tab_keys
    # 2.lab_total
    lab_keys
    # 3.note_total
    note_keys
    # 4. tab - lab - note
    tab_only_keys = tab_keys - lab_keys - note_keys
    # 5. lab - note
    lab_only_keys = lab_keys - note_keys
    # 6. note - lab
    note_only_keys = note_keys - lab_keys
    # 7. lab ∩ note
    lab_note_keys = lab_keys.intersection(note_keys)
    # 8. total
    total_keys = tab_keys.union(lab_keys).union(note_keys)

    print(f"1. tab total: {len(tab_keys)}")
    print(f"2. lab total: {len(lab_keys)}")
    print(f"3. note total: {len(note_keys)}")
    print(f"4. tab only: {len(tab_only_keys)}")
    print(f"5. lab only: {len(lab_only_keys)}")
    print(f"6. note only: {len(note_only_keys)}")
    print(f"7. lab & note: {len(lab_note_keys)}")
    print(f"8. total: {len(total_keys)}")

    total_outputs = {}
    for key in total_keys:
        ### Unimodal feature extraction
        # 1. tab total
        tab_features = unimodal_process(tab_output, key) if key in tab_keys else None
        # 2. lab total
        lab_features = unimodal_process(lab_output, key) if key in lab_keys else None
        # 3. note total
        note_features = unimodal_process(note_output, key) if key in note_keys else None    
        # 4. tab only
        tab_only_features = unimodal_process(tab_output, key) if key in tab_only_keys else None
        
        ### Bimodal feature extraction
        # 5. lab only : lab_only -> lab_only를 tab이 본 관점
        lab_only_lab_features, lab_only_tab_features, lab_only_features = bimodal_process(lab_output, tab_output, key, weight1=0.5, weight2=0.5) if key in lab_only_keys else (None, None, None)
        # lab_only_lab_features, lab_only_tab_features, lab_only_features = bimodal_process(lab_output, tab_output, key, weight1=lab_score, weight2=tab_score) if key in lab_only_keys else (None, None, None)
        
        # 6. note only : note_only -> note_only를 tab이 본 관점
        note_only_note_features, note_only_tab_features, note_only_features = bimodal_process(note_output, tab_output, key, weight1=0.5, weight2=0.5) if key in note_only_keys else (None, None, None)
        # note_only_note_features, note_only_tab_features, note_only_features = bimodal_process(note_output, tab_output, key, weight1=note_score, weight2=tab_score) if key in note_only_keys else (None, None, None)

        ### Trimodal feature extraction
        # 7. lab ∩ note
        mm_tab_features, mm_lab_features, mm_note_features, mm_tab_lab_features, mm_tab_note_features, mm_lab_note_features, mm_all_features = trimodal_process(tab_output, lab_output, note_output, key, weight1=0.5, weight2=0.5, weight3=0.5) if key in lab_note_keys else (None, None, None, None, None, None, None)
        # mm_tab_features, mm_lab_features, mm_note_features, mm_tab_lab_features, mm_tab_note_features, mm_lab_note_features, mm_all_features = trimodal_process(tab_output, lab_output, note_output, key, weight1=tab_score, weight2=lab_score, weight3=note_score) if key in lab_note_keys else (None, None, None, None, None, None, None)

        # 8. total
        total_features = whole_process(tab_output, lab_output, note_output, key)
        
        # 9. Fusion - Sum
        fusion_sum_prob = sum_output.get(int(key)).get('probs')
        
        # 10. Fusion - Weighted Sum
        fusion_weighted_sum_prob = weighted_sum_output.get(int(key)).get('probs')
        
        # 11. Fusion - Attn Masked
        fusion_attn_masked_prob = attn_masked_output.get(int(key)).get('probs')
        
        # Label
        label = tab_output.get(key)['labels']
        
        # Store results
        total_outputs[key] = {
            'tabular': tab_features,
            'lab': lab_features,
            'note': note_features,
            'tab_only': tab_only_features,
            'lab_only_lab': lab_only_lab_features, 'lab_only_tab': lab_only_tab_features, 'lab_only': lab_only_features,
            'note_only_note': note_only_note_features, 'note_only_tab': note_only_tab_features, 'note_only': note_only_features,
            'mm_tab': mm_tab_features, 'mm_lab': mm_lab_features, 'mm_note': mm_note_features,
            'mm_tab_lab': mm_tab_lab_features, 'mm_tab_note': mm_tab_note_features, 'mm_lab_note': mm_lab_note_features,
            'mm_all': mm_all_features,
            'total': total_features,
            'fusion_sum_prob': fusion_sum_prob,
            'fusion_weighted_sum_prob': fusion_weighted_sum_prob,
            'fusion_attn_masked_prob': fusion_attn_masked_prob,
            'labels' : label,
        }
    # Unimodal
    total_scores.append([seed, 'Tabular', *get_score('tabular', total_outputs, proto)])
    total_scores.append([seed, 'Lab', *get_score('lab', total_outputs, proto)])
    total_scores.append([seed, 'Note', *get_score('note', total_outputs, proto)])
    # Bimodal
    total_scores.append([seed, 'Tab Only', *get_score('tab_only', total_outputs, proto)])
    total_scores.append([seed, 'Lab Only', *get_score('lab_only', total_outputs, proto)])
    total_scores.append([seed, 'Note Only', *get_score('note_only', total_outputs, proto)])
    total_scores.append([seed, 'MM Tab', *get_score('mm_tab', total_outputs, proto)])
    total_scores.append([seed, 'MM Lab', *get_score('mm_lab', total_outputs, proto)])
    total_scores.append([seed, 'MM Note', *get_score('mm_note', total_outputs, proto)])
    # Trimodal
    total_scores.append([seed, 'MM Tab & Lab', *get_score('mm_tab_lab', total_outputs, proto)])
    total_scores.append([seed, 'MM Tab & Note', *get_score('mm_tab_note', total_outputs, proto)])
    total_scores.append([seed, 'MM Lab & Note', *get_score('mm_lab_note', total_outputs, proto)])
    total_scores.append([seed, 'MM All', *get_score('mm_all', total_outputs, proto)])
    # Total
    total_scores.append([seed, 'Simple Average', *get_score('total', total_outputs, proto)])
    total_scores.append([seed, 'Fusion - Sum', *get_fusion_score('fusion_sum', total_outputs)])
    total_scores.append([seed, 'Fusion - Weighted Sum', *get_fusion_score('fusion_weighted_sum', total_outputs)])
    total_scores.append([seed, 'Fusion - Attn Masked', *get_fusion_score('fusion_attn_masked', total_outputs)])

total_score_df = pd.DataFrame(total_scores, columns=['Seed', 'Modality', 'ROC-AUC', 'PR-AUC', 'N', 'Number of Cases'])
total_score_df_balanced = total_score_df.copy()

1. tab total: 46018
2. lab total: 36839
3. note total: 27202
4. tab only: 7530
5. lab only: 11286
6. note only: 1649
7. lab & note: 25553
8. total: 46018
tabular - Features shape: (46018, 128), Probs shape: (46018, 2), Labels shape: (46018,)
lab - Features shape: (36839, 128), Probs shape: (36839, 2), Labels shape: (36839,)
note - Features shape: (27202, 128), Probs shape: (27202, 2), Labels shape: (27202,)
tab_only - Features shape: (7530, 128), Probs shape: (7530, 2), Labels shape: (7530,)
lab_only - Features shape: (11286, 128), Probs shape: (11286, 2), Labels shape: (11286,)
note_only - Features shape: (1649, 128), Probs shape: (1649, 2), Labels shape: (1649,)
mm_tab - Features shape: (25553, 128), Probs shape: (25553, 2), Labels shape: (25553,)
mm_lab - Features shape: (25553, 128), Probs shape: (25553, 2), Labels shape: (25553,)
mm_note - Features shape: (25553, 128), Probs shape: (25553, 2), Labels shape: (25553,)
mm_tab_lab - Features shape: (25553, 128), Probs shape: (25553, 2

In [10]:
# avg
avg_scores = total_score_df_balanced.groupby('Modality').agg({'ROC-AUC': ['mean', 'std'], 'PR-AUC': ['mean', 'std'], 'N': 'mean', 'Number of Cases': 'mean'}).reset_index()
avg_scores.columns = ['Modality', 'ROC-AUC Mean', 'ROC-AUC Std', 'PR-AUC Mean', 'PR-AUC Std', 'N', 'Number of Cases']
avg_scores[['N', 'Number of Cases']] = avg_scores[['N', 'Number of Cases']].astype(int)
avg_scores['case_ratio'] = avg_scores['Number of Cases'] / avg_scores['N']
# avg_scores = avg_scores.sort_values(by='ROC-AUC Mean', ascending=False)
avg_scores[['Modality', 'Number of Cases', 'N', 'case_ratio', 'ROC-AUC Mean', 'PR-AUC Mean']]

target_modality_and_order = ['Tabular', 'Lab', 'Note', # Unimodal
                             'MM All', # Trimodal
                             'Tab Only', 'Lab Only', 'Note Only', # Bimodal - only
                             'Simple Average', # Total
                             'Fusion - Sum', 'Fusion - Weighted Sum', 'Fusion - Attn Masked' # Fusion
                             ]
avg_scores = avg_scores.set_index('Modality').loc[target_modality_and_order].reset_index()
avg_scores[['ROC-AUC Mean', 'PR-AUC Mean']] = avg_scores[['ROC-AUC Mean', 'PR-AUC Mean']].apply(lambda x : round(x * 100, 2))
avg_scores[['Modality', 'Number of Cases', 'N', 'case_ratio', 'ROC-AUC Mean', 'PR-AUC Mean']]
avg_scores[['Modality', 'Number of Cases', 'N', 'case_ratio', 'ROC-AUC Mean', 'PR-AUC Mean']].to_csv(f'Results/Whole_{task.capitalize()}.csv', index=False)
avg_scores

Unnamed: 0,Modality,ROC-AUC Mean,ROC-AUC Std,PR-AUC Mean,PR-AUC Std,N,Number of Cases,case_ratio
0,Tabular,67.5,0.003535,43.21,0.005102,46018,11266,0.244817
1,Lab,59.97,0.003978,32.02,0.004675,36801,8665,0.235456
2,Note,63.24,0.005991,34.77,0.004372,27242,6311,0.231664
3,MM All,65.26,0.003228,37.32,0.000894,25577,5993,0.234312
4,Tab Only,76.77,0.007126,59.6,0.009962,7550,2283,0.302384
5,Lab Only,68.99,0.004362,45.03,0.010913,11224,2671,0.237972
6,Note Only,65.81,0.017141,34.37,0.027228,1665,318,0.190991
7,Simple Average,68.72,0.003384,44.0,0.0047,46018,11266,0.244817
8,Fusion - Sum,68.7,0.004586,44.26,0.004951,46018,11266,0.244817
9,Fusion - Weighted Sum,68.39,0.007125,43.99,0.005871,46018,11266,0.244817


In [11]:
# avg
avg_scores = total_score_df_balanced.groupby('Modality').agg({'ROC-AUC': ['mean', 'std'], 'PR-AUC': ['mean', 'std'], 'N': 'mean', 'Number of Cases': 'mean'}).reset_index()
avg_scores.columns = ['Modality', 'ROC-AUC Mean', 'ROC-AUC Std', 'PR-AUC Mean', 'PR-AUC Std', 'N', 'Number of Cases']
avg_scores[['N', 'Number of Cases']] = avg_scores[['N', 'Number of Cases']].astype(int)
avg_scores['case_ratio'] = avg_scores['Number of Cases'] / avg_scores['N']
# avg_scores = avg_scores.sort_values(by='ROC-AUC Mean', ascending=False)
avg_scores[['Modality', 'Number of Cases', 'N', 'case_ratio', 'ROC-AUC Mean', 'PR-AUC Mean']]

target_modality_and_order = ['MM Tab', 'MM Lab', 'MM Note', # Bimodal - cross
                             'MM Tab & Lab', 'MM Tab & Note', 'MM Lab & Note', # Trimodal - pair
                            #  'Tabular', 'Lab', 'Note', # Unimodal
                             'MM All', # Trimodal - all
                             ]
avg_scores = avg_scores.set_index('Modality').loc[target_modality_and_order].reset_index()
avg_scores[['ROC-AUC Mean', 'PR-AUC Mean']] = avg_scores[['ROC-AUC Mean', 'PR-AUC Mean']].apply(lambda x : round(x * 100, 2))
avg_scores[['Modality', 'Number of Cases', 'N', 'case_ratio', 'ROC-AUC Mean', 'PR-AUC Mean']]
avg_scores[['Modality', 'Number of Cases', 'N', 'case_ratio', 'ROC-AUC Mean', 'PR-AUC Mean']].to_csv(f'Results/MM_{task.capitalize()}.csv', index=False)
avg_scores

Unnamed: 0,Modality,ROC-AUC Mean,ROC-AUC Std,PR-AUC Mean,PR-AUC Std,N,Number of Cases,case_ratio
0,MM Tab,62.8,0.002042,34.46,0.002475,25577,5993,0.234312
1,MM Lab,59.01,0.004327,30.84,0.003276,25577,5993,0.234312
2,MM Note,63.1,0.006077,34.89,0.003749,25577,5993,0.234312
3,MM Tab & Lab,63.58,0.002844,35.39,0.001022,25577,5993,0.234312
4,MM Tab & Note,64.86,0.004037,36.97,0.00166,25577,5993,0.234312
5,MM Lab & Note,63.96,0.004562,35.48,0.003814,25577,5993,0.234312
6,MM All,65.26,0.003228,37.32,0.000894,25577,5993,0.234312
