In [1]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import shutil
import torch

from sklearn.metrics import roc_auc_score, precision_recall_curve, auc

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def load_pickle(file_path):
    """Deserialize and return the object stored in the pickle file at ``file_path``.

    The file is opened in binary mode and is closed before the function
    returns, whether or not unpickling succeeds.

    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    handle = open(file_path, 'rb')
    try:
        return pickle.load(handle)
    finally:
        handle.close()

def load_modality_output(seed, projection_dim, batch_size, modality, task):
    """Load one modality's saved best-epoch outputs together with the ETF tensor.

    Args:
        seed, projection_dim, batch_size, modality, task: values interpolated
            into the result-directory and ETF file names below.

    Returns:
        (output, etf): the object unpickled from ``outputs/best_epoch.pkl``
        inside the run directory, and whatever ``torch.load`` returns for the
        ETF file of the same ``projection_dim`` and ``seed``.
    """
    pickle_path = Path(f'Results/Linear_{modality}/Seed_{seed}/{task}_proj_{projection_dim}_batch_{batch_size}') / 'outputs' / 'best_epoch.pkl'
    etf_path = Path(f'Results/ETF/ETF_{projection_dim}_IN_2_OUT_Seed_{seed}.pt')
    # Tuple elements are evaluated left to right: pickle first, then the ETF file.
    return load_pickle(pickle_path), torch.load(etf_path)

def unimodal_process(output, key):
    """Return the ``'features'`` entry stored under ``key`` in ``output``.

    ``output`` is looked up with ``.get``; a missing key therefore yields
    ``None`` and the subsequent subscript raises ``TypeError``.
    """
    return output.get(key)['features']

def bimodal_process(output1, output2, key, weight1=0.5, weight2=0.5):
    """Fuse the features of two modalities for ``key`` as a weighted mean.

    Args:
        output1, output2: mappings whose ``key`` entry holds a ``'features'`` value.
        key: sample identifier looked up in both mappings.
        weight1, weight2: weights applied to the first / second feature.

    Returns:
        (features1, features2, combined): the two untouched feature objects
        and the flattened weighted mean
        ``(f1 * w1 + f2 * w2) / (w1 + w2)``.
    """
    first, second = (source.get(key)['features'] for source in (output1, output2))
    weighted_sum = first * weight1 + second * weight2
    fused = (weighted_sum / (weight1 + weight2)).flatten()
    return first, second, fused

def trimodal_process(output1, output2, output3, key, weight1=0.5, weight2=0.5, weight3=0.5):
    """Fuse three modalities for ``key``: every pair and all three together.

    Each fusion is a weighted mean ``sum(f_i * w_i) / sum(w_i)`` over the
    participating modalities, flattened afterwards.

    Returns:
        A 7-tuple ``(features1, features2, features3, fused_12, fused_13,
        fused_23, fused_all)``; the first three are the untouched inputs.
    """
    feats = [source.get(key)['features'] for source in (output1, output2, output3)]
    weights = [weight1, weight2, weight3]

    def fuse(*indices):
        # Accumulate strictly left to right so the arithmetic matches the
        # written-out expression ``f_a * w_a + f_b * w_b (+ f_c * w_c)``.
        head, *rest = indices
        numerator = feats[head] * weights[head]
        denominator = weights[head]
        for idx in rest:
            numerator = numerator + feats[idx] * weights[idx]
            denominator = denominator + weights[idx]
        return (numerator / denominator).flatten()

    return feats[0], feats[1], feats[2], fuse(0, 1), fuse(0, 2), fuse(1, 2), fuse(0, 1, 2)

def whole_process(output1, output2, output3, key, weight1, weight2, weight3):
    """Sum the weighted features of every modality that has an entry for ``key``.

    Modalities whose mapping has no ``key`` entry are skipped. Unlike
    ``bimodal_process`` / ``trimodal_process``, the result is a plain weighted
    sum: it is NOT divided by the sum of the weights.

    Args:
        output1, output2, output3: mappings from sample key to a dict holding
            a ``'features'`` value.
        key: sample identifier to look up in each mapping.
        weight1, weight2, weight3: multiplier for the matching mapping's feature.

    Returns:
        The flattened sum of the available weighted features, or ``None`` when
        none of the three mappings contains ``key``. ``None`` is the marker the
        rest of the notebook uses for "modality not available", and
        ``get_output`` filters such entries out. Previously this case crashed
        with ``AttributeError`` because ``.flatten()`` was called on an empty list.
    """
    weighted = [
        entry['features'] * weight
        for entry, weight in zip(
            (output1.get(key, None), output2.get(key, None), output3.get(key, None)),
            (weight1, weight2, weight3),
        )
        if entry is not None
    ]

    if not weighted:
        return None

    # Element-wise sum across the available modalities, then collapse to 1-D.
    return np.sum(weighted, axis=0).flatten()

def get_output(modality, total_outputs, etf):
    """Collect labels and class scores for every sample that has ``modality``.

    Entries of ``total_outputs`` whose ``modality`` value is ``None`` are
    skipped. The remaining features are stacked row-wise and multiplied by
    ``etf.numpy()``; the resulting shapes are printed.

    Returns:
        (labels, probs): the labels as a NumPy array and the matrix product
        ``features @ etf.numpy()``.
    """
    present = [entry for entry in total_outputs.values() if entry[modality] is not None]
    labels = np.array([entry['labels'] for entry in present])
    # Wrapping each feature in a list gives it a leading axis to concatenate along.
    features = np.concatenate([[entry[modality]] for entry in present], axis=0)
    probs = features @ etf.numpy()
    print(f"{modality} - Features shape: {features.shape}, Probs shape: {probs.shape}, Labels shape: {labels.shape}")
    return labels, probs

def get_performance(labels, probs):
    """Return ``(ROC-AUC, PR-AUC, number of samples)`` for the given scores.

    PR-AUC is the area under the precision-recall curve computed with
    ``auc(recall, precision)``, not ``average_precision_score``.
    """
    area_under_roc = roc_auc_score(labels, probs)
    pr_points = precision_recall_curve(labels, probs)  # (precision, recall, thresholds)
    area_under_pr = auc(pr_points[1], pr_points[0])
    return area_under_roc, area_under_pr, len(labels)

def get_score(modality, total_outputs, etf):
    """Evaluate one modality group.

    Scores are taken from column 1 of the matrix returned by ``get_output``.

    Returns:
        (roc_auc, pr_auc, n, number_of_cases) where ``number_of_cases`` is the
        count of labels equal to 1.
    """
    labels, probs = get_output(modality, total_outputs, etf)
    metrics = get_performance(labels, probs[:, 1])
    return (*metrics, (labels == 1).sum())


def get_conf(feature, etf):
    """Compute a scalar confidence weight for a single feature vector.

    The feature gets a leading axis, is L2-normalised along axis 1, and is
    projected with ``etf.numpy()``. The confidence is the log-sum-exp of the
    projected values divided by 10.

    Returns:
        An array reshaped to ``(-1, 1)``.
    """
    row = np.expand_dims(feature, axis=0)
    unit_row = row / np.linalg.norm(row, axis=1, keepdims=True)
    logits = unit_row @ etf.numpy()
    log_sum_exp = np.log(np.exp(logits).sum(axis=1))
    return (log_sum_exp / 10).reshape(-1, 1)


: 

In [None]:
# Experiment configuration; `task` is also read by the summary cells further down.
projection_dim=128
batch_size=512
seed=2026  # placeholder only: rebound by the `for seed in ...` loop below
task='mortality_90days'
# task='readmission_15days'



# One row per (seed, modality group): [seed, name, ROC-AUC, PR-AUC, N, number of cases].
total_scores = []
for seed in range(2026, 2029) : 
    
    # Load the saved outputs of each modality. Only the ETF tensor returned with the
    # tabular run is kept (`proto`); the ones returned for lab/note are discarded.
    tabular_outputs, proto = load_modality_output(seed, projection_dim, batch_size, 'tabular', task)
    lab_outputs, _ = load_modality_output(seed, projection_dim, batch_size, 'lab', task)
    note_outputs, _ = load_modality_output(seed, projection_dim, batch_size, 'note', task)
    split_mode = 'test'

    # Per-sample outputs of the chosen split, keyed by sample identifier.
    tab_output = tabular_outputs[split_mode]['outputs']
    lab_output = lab_outputs[split_mode]['outputs']
    note_output = note_outputs[split_mode]['outputs']

    tab_keys = set(tab_output.keys())
    lab_keys = set(lab_output.keys())
    note_keys = set(note_output.keys())

    # With three modalities, enumerate the sample subsets that can occur.
    ## Assumed containment: Tab ⊇ Lab and Tab ⊇ Note (labels below are read from tab_output only).
    # 1. tab total (the bare expressions in 1-3 have no effect inside the loop)
    tab_keys
    # 2. lab total
    lab_keys
    # 3. note total
    note_keys
    # 4. tab - lab - note
    tab_only_keys = tab_keys - lab_keys - note_keys
    # 5. lab - note
    lab_only_keys = lab_keys - note_keys
    # 6. note - lab
    note_only_keys = note_keys - lab_keys
    # 7. lab ∩ note
    lab_note_keys = lab_keys.intersection(note_keys)
    # 8. total
    total_keys = tab_keys.union(lab_keys).union(note_keys)

    print(f"1. tab total: {len(tab_keys)}")
    print(f"2. lab total: {len(lab_keys)}")
    print(f"3. note total: {len(note_keys)}")
    print(f"4. tab only: {len(tab_only_keys)}")
    print(f"5. lab only: {len(lab_only_keys)}")
    print(f"6. note only: {len(note_only_keys)}")
    print(f"7. lab & note: {len(lab_note_keys)}")
    print(f"8. total: {len(total_keys)}")

    # Per-sample confidence of each modality; used below as fusion weights.
    tab_confs = {key: get_conf(unimodal_process(tab_output, key), proto) for key in tab_keys}
    lab_confs = {key: get_conf(unimodal_process(lab_output, key), proto) for key in lab_keys}
    note_confs = {key: get_conf(unimodal_process(note_output, key), proto) for key in note_keys}

    # For every sample, store the feature of each modality group (None when the sample is not in that group).
    total_outputs = {}
    for key in total_keys:
        ### Unimodal feature extraction
        # 1. tab total
        tab_features = unimodal_process(tab_output, key) if key in tab_keys else None
        # 2. lab total
        lab_features = unimodal_process(lab_output, key) if key in lab_keys else None
        # 3. note total
        note_features = unimodal_process(note_output, key) if key in note_keys else None    
        # 4. tab only
        tab_only_features = unimodal_process(tab_output, key) if key in tab_only_keys else None
        
        ### Bimodal feature extraction
        # 5. lab only: samples with lab but no note, fused with the tabular view of the same sample
        # Weights fall back to 0.5 when a modality has no confidence for this key.
        weight1 = lab_confs[key] if key in lab_confs else 0.5
        weight2 = tab_confs[key] if key in tab_confs else 0.5
        lab_only_lab_features, lab_only_tab_features, lab_only_features = bimodal_process(lab_output, tab_output, key, weight1=weight1, weight2=weight2) if key in lab_only_keys else (None, None, None)
        # lab_only_lab_features, lab_only_tab_features, lab_only_features = bimodal_process(lab_output, tab_output, key, weight1=lab_score, weight2=tab_score) if key in lab_only_keys else (None, None, None)
        
        # 6. note only: samples with note but no lab, fused with the tabular view of the same sample
        weight1 = note_confs[key] if key in note_confs else 0.5
        weight2 = tab_confs[key] if key in tab_confs else 0.5
        note_only_note_features, note_only_tab_features, note_only_features = bimodal_process(note_output, tab_output, key, weight1=weight1, weight2=weight2) if key in note_only_keys else (None, None, None)
        # note_only_note_features, note_only_tab_features, note_only_features = bimodal_process(note_output, tab_output, key, weight1=note_score, weight2=tab_score) if key in note_only_keys else (None, None, None)

        ### Trimodal feature extraction
        # 7. lab ∩ note
        # From here on weight1/weight2/weight3 are the tab/lab/note weights (rebound from the bimodal steps above).
        weight1 = tab_confs[key] if key in tab_confs else 0.5
        weight2 = lab_confs[key] if key in lab_confs else 0.5
        weight3 = note_confs[key] if key in note_confs else 0.5
        mm_tab_features, mm_lab_features, mm_note_features, mm_tab_lab_features, mm_tab_note_features, mm_lab_note_features, mm_all_features = trimodal_process(tab_output, lab_output, note_output, key, weight1=weight1, weight2=weight2, weight3=weight3) if key in lab_note_keys else (None, None, None, None, None, None, None)
        # mm_tab_features, mm_lab_features, mm_note_features, mm_tab_lab_features, mm_tab_note_features, mm_lab_note_features, mm_all_features = trimodal_process(tab_output, lab_output, note_output, key, weight1=tab_score, weight2=lab_score, weight3=note_score) if key in lab_note_keys else (None, None, None, None, None, None, None)

        # 8. total: reuses the tab/lab/note weights set in step 7; whole_process skips modalities missing this key
        total_features = whole_process(tab_output, lab_output, note_output, key, weight1=weight1, weight2=weight2, weight3=weight3)
        # _, _, _, _, _, _, total_features = trimodal_process(tab_output, lab_output, note_output, key, weight1=weight1, weight2=weight2, weight3=weight3)
        
        # Label: always read from the tabular output, so every key must exist in tab_output
        label = tab_output.get(key)['labels']
        
        # Store results
        total_outputs[key] = {
            'tabular': tab_features,
            'lab': lab_features,
            'note': note_features,
            'tab_only': tab_only_features,
            'lab_only_lab': lab_only_lab_features, 'lab_only_tab': lab_only_tab_features, 'lab_only': lab_only_features,
            'note_only_note': note_only_note_features, 'note_only_tab': note_only_tab_features, 'note_only': note_only_features,
            'mm_tab': mm_tab_features, 'mm_lab': mm_lab_features, 'mm_note': mm_note_features,
            'mm_tab_lab': mm_tab_lab_features, 'mm_tab_note': mm_tab_note_features, 'mm_lab_note': mm_lab_note_features,
            'mm_all': mm_all_features,
            'total': total_features,
            'labels' : label,
        }
        

    # modality keys (grouping as used for the score rows below):
    ## 1. unimodal : tabular, lab, note
    ## 2. bimodal : tab_only, lab_only, note_only, mm_tab, mm_lab, mm_note
    ## 3. trimodal : mm_tab_lab, mm_tab_note, mm_lab_note, mm_all
    ## 4. total : total


    # Each get_score call evaluates only the samples whose entry for that key is not None.
    # Unimodal
    total_scores.append([seed, 'Tabular', *get_score('tabular', total_outputs, proto)])
    total_scores.append([seed, 'Lab', *get_score('lab', total_outputs, proto)])
    total_scores.append([seed, 'Note', *get_score('note', total_outputs, proto)])
    # Bimodal
    total_scores.append([seed, 'Tab Only', *get_score('tab_only', total_outputs, proto)])
    total_scores.append([seed, 'Lab Only', *get_score('lab_only', total_outputs, proto)])
    total_scores.append([seed, 'Note Only', *get_score('note_only', total_outputs, proto)])
    total_scores.append([seed, 'MM Tab', *get_score('mm_tab', total_outputs, proto)])
    total_scores.append([seed, 'MM Lab', *get_score('mm_lab', total_outputs, proto)])
    total_scores.append([seed, 'MM Note', *get_score('mm_note', total_outputs, proto)])
    # Trimodal
    total_scores.append([seed, 'MM Tab & Lab', *get_score('mm_tab_lab', total_outputs, proto)])
    total_scores.append([seed, 'MM Tab & Note', *get_score('mm_tab_note', total_outputs, proto)])
    total_scores.append([seed, 'MM Lab & Note', *get_score('mm_lab_note', total_outputs, proto)])
    total_scores.append([seed, 'MM All', *get_score('mm_all', total_outputs, proto)])
    # Total
    total_scores.append([seed, 'Total', *get_score('total', total_outputs, proto)])

# Long-format score table; the summary cells below aggregate the `_balanced` copy.
total_score_df = pd.DataFrame(total_scores, columns=['Seed', 'Modality', 'ROC-AUC', 'PR-AUC', 'N', 'Number of Cases'])
total_score_df_balanced = total_score_df.copy()

1. tab total: 46018
2. lab total: 36839
3. note total: 27202
4. tab only: 7530
5. lab only: 11286
6. note only: 1649
7. lab & note: 25553
8. total: 46018
tabular - Features shape: (46018, 128), Probs shape: (46018, 2), Labels shape: (46018,)
lab - Features shape: (36839, 128), Probs shape: (36839, 2), Labels shape: (36839,)
note - Features shape: (27202, 128), Probs shape: (27202, 2), Labels shape: (27202,)
tab_only - Features shape: (7530, 128), Probs shape: (7530, 2), Labels shape: (7530,)
lab_only - Features shape: (11286, 128), Probs shape: (11286, 2), Labels shape: (11286,)
note_only - Features shape: (1649, 128), Probs shape: (1649, 2), Labels shape: (1649,)
mm_tab - Features shape: (25553, 128), Probs shape: (25553, 2), Labels shape: (25553,)
mm_lab - Features shape: (25553, 128), Probs shape: (25553, 2), Labels shape: (25553,)
mm_note - Features shape: (25553, 128), Probs shape: (25553, 2), Labels shape: (25553,)
mm_tab_lab - Features shape: (25553, 128), Probs shape: (25553, 2

: 

In [70]:
# Whole-cohort summary: average every metric over seeds, keep the selected
# modality groups in presentation order, export the table and display it.
# Reads `total_score_df_balanced` and `task` from the evaluation cell.
avg_scores = (
    total_score_df_balanced
    .groupby('Modality')
    .agg(**{
        'ROC-AUC Mean': ('ROC-AUC', 'mean'),
        'ROC-AUC Std': ('ROC-AUC', 'std'),
        'PR-AUC Mean': ('PR-AUC', 'mean'),
        'PR-AUC Std': ('PR-AUC', 'std'),
        'N': ('N', 'mean'),
        'Number of Cases': ('Number of Cases', 'mean'),
    })
    # Seed-averaged counts are truncated to integers before the ratio is taken.
    .astype({'N': int, 'Number of Cases': int})
    .assign(case_ratio=lambda frame: frame['Number of Cases'] / frame['N'])
)

target_modality_and_order = ['Tabular', 'Lab', 'Note', # Unimodal
                             'MM All', # Trimodal
                             'Tab Only', 'Lab Only', 'Note Only', # Bimodal - only
                             'Total', # Total
                             ]
avg_scores = avg_scores.loc[target_modality_and_order].reset_index()

# Express means AND standard deviations in percent (2 decimals). Previously
# only the means were scaled, so the table mixed percent with raw fractions.
metric_columns = ['ROC-AUC Mean', 'ROC-AUC Std', 'PR-AUC Mean', 'PR-AUC Std']
avg_scores[metric_columns] = avg_scores[metric_columns].mul(100).round(2)

# The exported file contains the same columns as before (no Std columns).
export_columns = ['Modality', 'Number of Cases', 'N', 'case_ratio', 'ROC-AUC Mean', 'PR-AUC Mean']
avg_scores[export_columns].to_csv(f'Results/Whole_{task.capitalize()}.csv', index=False)
avg_scores

Unnamed: 0,Modality,ROC-AUC Mean,ROC-AUC Std,PR-AUC Mean,PR-AUC Std,N,Number of Cases,case_ratio
0,Tabular,67.6,,43.19,,46018,11166,0.242644
1,Lab,59.88,,31.89,,36839,8610,0.23372
2,Note,63.06,,34.34,,27202,6222,0.228733
3,MM All,65.07,,36.95,,25553,5927,0.231949
4,Tab Only,76.4,,58.92,,7530,2261,0.300266
5,Lab Only,69.19,,44.72,,11286,2683,0.237728
6,Note Only,65.64,,31.35,,1649,295,0.178896
7,Total,65.92,,41.79,,46018,11166,0.242644


In [34]:
# avg
avg_scores = total_score_df_balanced.groupby('Modality').agg({'ROC-AUC': ['mean', 'std'], 'PR-AUC': ['mean', 'std'], 'N': 'mean', 'Number of Cases': 'mean'}).reset_index()
avg_scores.columns = ['Modality', 'ROC-AUC Mean', 'ROC-AUC Std', 'PR-AUC Mean', 'PR-AUC Std', 'N', 'Number of Cases']
avg_scores[['N', 'Number of Cases']] = avg_scores[['N', 'Number of Cases']].astype(int)
avg_scores['case_ratio'] = avg_scores['Number of Cases'] / avg_scores['N']
# avg_scores = avg_scores.sort_values(by='ROC-AUC Mean', ascending=False)
avg_scores[['Modality', 'Number of Cases', 'N', 'case_ratio', 'ROC-AUC Mean', 'PR-AUC Mean']]

target_modality_and_order = ['MM Tab', 'MM Lab', 'MM Note', # Bimodal - cross
                             'MM Tab & Lab', 'MM Tab & Note', 'MM Lab & Note', # Trimodal - pair
                             'Tabular', 'Lab', 'Note', # Unimodal
                             'MM All', # Trimodal - all
                             ]
avg_scores = avg_scores.set_index('Modality').loc[target_modality_and_order].reset_index()
avg_scores[['ROC-AUC Mean', 'PR-AUC Mean']] = avg_scores[['ROC-AUC Mean', 'PR-AUC Mean']].apply(lambda x : round(x * 100, 2))
avg_scores[['Modality', 'Number of Cases', 'N', 'case_ratio', 'ROC-AUC Mean', 'PR-AUC Mean']]
avg_scores[['Modality', 'Number of Cases', 'N', 'case_ratio', 'ROC-AUC Mean', 'PR-AUC Mean']].to_csv(f'Results/MM_{task.capitalize()}.csv', index=False)
avg_scores

Unnamed: 0,Modality,ROC-AUC Mean,ROC-AUC Std,PR-AUC Mean,PR-AUC Std,N,Number of Cases,case_ratio
0,MM Tab,63.08,0.001304,34.99,0.003504,25577,5993,0.234312
1,MM Lab,59.15,0.002796,30.84,0.004195,25577,5993,0.234312
2,MM Note,63.1,0.006077,34.89,0.003749,25577,5993,0.234312
3,MM Tab & Lab,63.8,0.002406,35.66,0.003493,25577,5993,0.234312
4,MM Tab & Note,65.06,0.001727,37.16,0.001695,25577,5993,0.234312
5,MM Lab & Note,63.98,0.004199,35.56,0.003691,25577,5993,0.234312
6,Tabular,67.65,0.000919,43.57,0.003401,46018,11266,0.244817
7,Lab,60.07,0.001742,32.23,0.003371,36801,8665,0.235456
8,Note,63.24,0.005991,34.77,0.004372,27242,6311,0.231664
9,MM All,65.42,0.00244,37.41,0.002314,25577,5993,0.234312
