In [13]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import shutil
import torch

from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve, auc

In [14]:

def load_pickle(file_path):
    """Deserialize and return the object stored in the pickle file at ``file_path``.

    NOTE: unpickling can execute arbitrary code embedded in the file, so only
    point this at files written by a trusted pipeline.
    """
    with open(file_path, mode='rb') as handle:
        payload = pickle.load(handle)
    return payload

def load_modality_output(seed, projection_dim, batch_size, modality, task) : 
    """Load one run's saved outputs together with the matching ETF object.

    Args:
        seed: Seed used in both the result directory and the ETF file name.
        projection_dim: Projection size used in both paths.
        batch_size: Batch size used in the result directory name.
        modality: Suffix of the ``Results/Linear_<modality>`` directory.
        task: Task prefix of the run directory.

    Returns:
        ``(output, etf)`` where ``output`` is the unpickled content of
        ``outputs/best_epoch.pkl`` and ``etf`` is whatever ``torch.load``
        restores from the ETF file. The ETF path depends only on
        ``projection_dim`` and ``seed``, not on ``modality`` or ``task``.
    """
    result_dir = Path(f'Results/Linear_{modality}/Seed_{seed}/{task}_proj_{projection_dim}_batch_{batch_size}')
    output = load_pickle(result_dir / 'outputs' / 'best_epoch.pkl')
    etf_path = Path(f'Results/ETF/ETF_{projection_dim}_IN_2_OUT_Seed_{seed}.pt')
    # Restore onto CPU: get_output() calls ``etf.numpy()``, which raises for a
    # tensor that was saved from (and would otherwise be restored to) a GPU.
    etf = torch.load(etf_path, map_location='cpu')
    return output, etf

def unimodal_process(output, key):
    """Return the ``'features'`` entry stored under ``key`` in ``output``.

    The lookup uses ``.get``, so a missing ``key`` surfaces as a ``TypeError``
    (subscripting ``None``) rather than a ``KeyError``.
    """
    return output.get(key)['features']

def bimodal_process(output1, output2, key, weight1=0.5, weight2=0.5) : 
    """Fetch the features of ``key`` from two outputs and blend them.

    Returns:
        ``(features1, features2, combined_features)`` where the last item is
        the weighted average ``(f1*w1 + f2*w2) / (w1 + w2)``.
    """
    # Look both entries up first, then read their features, in argument order.
    records = [source.get(key) for source in (output1, output2)]
    features1, features2 = [record['features'] for record in records]
    weighted_sum = features1 * weight1 + features2 * weight2
    combined_features = weighted_sum / (weight1 + weight2)
    return features1, features2, combined_features

def trimodal_process(output1, output2, output3, key, weight1=0.5, weight2=0.5, weight3=0.5) : 
    """Fetch the features of ``key`` from three outputs and build every blend.

    Returns:
        A 7-tuple: the three individual feature objects, the three pairwise
        weighted averages (1&2, 1&3, 2&3) and the three-way weighted average.
        Each average divides the weighted sum by the sum of the weights used.
    """
    # Look all entries up first, then read their features, in argument order.
    records = [source.get(key) for source in (output1, output2, output3)]
    features1, features2, features3 = [record['features'] for record in records]

    # Each weighted term is shared by several blends, so compute it once.
    scaled1 = features1 * weight1
    scaled2 = features2 * weight2
    scaled3 = features3 * weight3

    feature_12 = (scaled1 + scaled2) / (weight1 + weight2)
    feature_13 = (scaled1 + scaled3) / (weight1 + weight3)
    feature_23 = (scaled2 + scaled3) / (weight2 + weight3)
    combined_features = (scaled1 + scaled2 + scaled3) / (weight1 + weight2 + weight3)
    return features1, features2, features3, feature_12, feature_13, feature_23, combined_features

def quadromodal_process(output1, output2, output3, output4, key, weight1=0.5, weight2=0.5, weight3=0.5, weight4=0.5) :
    """Return the weighted average of the features of ``key`` across four outputs.

    The result is ``(f1*w1 + f2*w2 + f3*w3 + f4*w4) / (w1 + w2 + w3 + w4)``.
    """
    # Look all entries up first, then read their features, in argument order.
    records = [source.get(key) for source in (output1, output2, output3, output4)]
    features1, features2, features3, features4 = [record['features'] for record in records]
    numerator = features1 * weight1 + features2 * weight2 + features3 * weight3 + features4 * weight4
    denominator = weight1 + weight2 + weight3 + weight4
    return numerator / denominator

def whole_process(output1, output2, output3, output4, key, include_output4=False) : 
    """Average the ``'features'`` of every output that contains ``key``.

    Args:
        output1, output2, output3: Mappings from sample key to a record with
            a ``'features'`` entry. Outputs that lack ``key`` are skipped.
        output4: Fourth mapping. It is only consulted when
            ``include_output4`` is true; by default it is ignored, which is
            the behaviour this function has always had even though the
            argument was accepted.
        key: Sample key to look up.
        include_output4: Opt-in switch to also average ``output4``'s features.

    Returns:
        The element-wise mean (``np.mean(..., axis=0)``) of the available
        feature arrays, or ``None`` when no consulted output contains ``key``.
    """
    sources = [output1, output2, output3]
    if include_output4:
        sources.append(output4)
    entries = [source.get(key) for source in sources]
    # Ignore outputs that do not have this key and aggregate the rest.
    features_list = [entry['features'] for entry in entries if entry is not None]
    if not features_list:
        return None
    return np.mean(features_list, axis=0)

def get_output(modality, total_outputs, etf) : 
    """Collect labels and projected scores for every sample that has ``modality``.

    Samples whose ``modality`` entry is ``None`` are skipped. The remaining
    feature vectors are stacked row-wise and multiplied by ``etf.numpy()``.

    Returns:
        ``(labels, probs)``: the labels as a NumPy array and the matrix
        product of the stacked features with the ETF matrix.
    """
    kept_labels = []
    kept_features = []
    for entry in total_outputs.values():
        if entry[modality] is None:
            continue
        kept_labels.append(entry['labels'])
        # Wrap in a list so each sample contributes one row when concatenated.
        kept_features.append([entry[modality]])
    labels = np.array(kept_labels)
    features = np.concatenate(kept_features, axis=0)
    probs = features @ etf.numpy()
    print(f"{modality} - Features shape: {features.shape}, Probs shape: {probs.shape}, Labels shape: {labels.shape}")
    return labels, probs

def get_performance(labels, probs):
    """Score ``probs`` against ``labels``.

    Returns:
        ``(roc_auc, pr_auc, n)`` where ``roc_auc`` comes from
        ``roc_auc_score``, ``pr_auc`` is ``auc`` applied to the
        precision-recall curve (not ``average_precision_score``), and ``n``
        is ``len(labels)``.
    """
    area_under_roc = roc_auc_score(labels, probs)
    precision_pts, recall_pts, _thresholds = precision_recall_curve(labels, probs)
    area_under_pr = auc(recall_pts, precision_pts)
    return area_under_roc, area_under_pr, len(labels)

def get_score(modality, total_outputs, etf) : 
    """Evaluate one modality key of ``total_outputs``.

    Column 1 of the projected scores returned by ``get_output`` is used as
    the score passed to ``get_performance``.

    Returns:
        ``(roc_auc, pr_auc, n, number_of_cases)`` where ``number_of_cases``
        is the count of labels equal to 1.
    """
    labels, class_scores = get_output(modality, total_outputs, etf)
    area_roc, area_pr, sample_count = get_performance(labels, class_scores[:,1])
    positives = (labels == 1).sum()
    return area_roc, area_pr, sample_count, positives


In [18]:
# Evaluation cell: for each seed, load the per-modality outputs, build every
# unimodal / blended feature variant per sample, and score each variant.
# projection_dim, batch_size and task select the result directory (and ETF file)
# that load_modality_output reads.
projection_dim=128
batch_size=512
# NOTE(review): this assignment is overwritten by the `for seed in ...` loop
# below before it is read anywhere in this cell.
seed=2026
task='mortality_90days' # options listed by the author: 'mortality_90days', 'readmission_15days'
# task='readmission_15days'



# One row per (seed, variant): [seed, display name, ROC-AUC, PR-AUC, N, number of cases].
total_scores = []
for seed in range(2026, 2029) : 
    # `proto` (the ETF loaded with the tabular run) is used to project every
    # variant below; the ETF returned with the lab / note runs is discarded.
    # The ETF path depends only on projection_dim and seed, so it is the same file.
    tabular_outputs, proto = load_modality_output(seed, projection_dim, batch_size, 'tabular', task)
    lab_outputs, _ = load_modality_output(seed, projection_dim, batch_size, 'lab', task)
    note_outputs, _ = load_modality_output(seed, projection_dim, batch_size, 'note', task)
    recon_outputs = load_modality_output(seed, projection_dim, batch_size, 'Multimodal', task)[0]
    split_mode = 'test'

    # The unimodal runs keep their per-sample records under ['outputs'];
    # the Multimodal run is indexed by split only.
    tab_output = tabular_outputs[split_mode]['outputs']
    lab_output = lab_outputs[split_mode]['outputs']
    note_output = note_outputs[split_mode]['outputs']
    recon_output = recon_outputs[split_mode]
    # Drop the metric entries in place; presumably this leaves only per-sample
    # keys in recon_output — TODO confirm against the saved pickle layout.
    del recon_output['auroc'], recon_output['auprc'], recon_output['accuracy']

    tab_keys = set(tab_output.keys())
    lab_keys = set(lab_output.keys())
    note_keys = set(note_output.keys())
    recon_keys = set(recon_output.keys())
    print(f"Tabular keys: {len(tab_keys)}, Lab keys: {len(lab_keys)}, Note keys: {len(note_keys)}, Recon keys: {len(recon_keys)}")
    # With three modalities, these are the sample groups that can occur.
    ## Containment assumed by the author: Tab contains Lab and Tab contains Note.
    # NOTE(review): the three bare names below are expression statements with no
    # effect; they only mark groups 1-3 in the enumeration.
    # 1.tab_total
    tab_keys
    # 2.lab_total
    lab_keys
    # 3.note_total
    note_keys
    # 4. tab - lab - note  (samples with neither lab nor note)
    tab_only_keys = tab_keys - lab_keys - note_keys
    # 5. lab - note  (samples with lab but no note)
    lab_only_keys = lab_keys - note_keys
    # 6. note - lab  (samples with note but no lab)
    note_only_keys = note_keys - lab_keys
    # 7. lab ∩ note  (samples with both lab and note)
    lab_note_keys = lab_keys.intersection(note_keys)
    # 8. total  (union of all three key sets)
    total_keys = tab_keys.union(lab_keys).union(note_keys)

    print(f"1. tab total: {len(tab_keys)}")
    print(f"2. lab total: {len(lab_keys)}")
    print(f"3. note total: {len(note_keys)}")
    print(f"4. tab only: {len(tab_only_keys)}")
    print(f"5. lab only: {len(lab_only_keys)}")
    print(f"6. note only: {len(note_only_keys)}")
    print(f"7. lab & note: {len(lab_note_keys)}")
    print(f"8. total: {len(total_keys)}")

    # Per-sample dictionary of every feature variant; a variant is None when the
    # sample does not belong to the group that variant is defined on.
    total_outputs = {}
    for key in total_keys:
        ### Unimodal feature extraction
        # 1. tab total
        tab_features = unimodal_process(tab_output, key) if key in tab_keys else None
        # 2. lab total
        lab_features = unimodal_process(lab_output, key) if key in lab_keys else None
        # 3. note total
        note_features = unimodal_process(note_output, key) if key in note_keys else None    
        # 4. tab only
        tab_only_features = unimodal_process(tab_output, key) if key in tab_only_keys else None
        
        ### Bimodal feature extraction
        # 5. lab only : the lab-only samples, also as seen from the tabular view,
        #    plus the equal-weight blend of the two.
        lab_only_lab_features, lab_only_tab_features, lab_only_features = bimodal_process(lab_output, tab_output, key, weight1=0.5, weight2=0.5) if key in lab_only_keys else (None, None, None)
        # Disabled alternative using score-based weights (lab_score / tab_score are not defined in this cell):
        # lab_only_lab_features, lab_only_tab_features, lab_only_features = bimodal_process(lab_output, tab_output, key, weight1=lab_score, weight2=tab_score) if key in lab_only_keys else (None, None, None)
        
        # 6. note only : the note-only samples, also as seen from the tabular view,
        #    plus the equal-weight blend of the two.
        note_only_note_features, note_only_tab_features, note_only_features = bimodal_process(note_output, tab_output, key, weight1=0.5, weight2=0.5) if key in note_only_keys else (None, None, None)
        # Disabled alternative using score-based weights (note_score / tab_score are not defined in this cell):
        # note_only_note_features, note_only_tab_features, note_only_features = bimodal_process(note_output, tab_output, key, weight1=note_score, weight2=tab_score) if key in note_only_keys else (None, None, None)

        ### Trimodal feature extraction
        # 7. lab ∩ note : individual, pairwise and three-way blends of tab / lab / note.
        mm_tab_features, mm_lab_features, mm_note_features, mm_tab_lab_features, mm_tab_note_features, mm_lab_note_features, mm_all_features = trimodal_process(tab_output, lab_output, note_output, key, weight1=0.5, weight2=0.5, weight3=0.5) if key in lab_note_keys else (None, None, None, None, None, None, None)
        # Disabled alternative using score-based weights (tab_score / lab_score / note_score are not defined in this cell):
        # mm_tab_features, mm_lab_features, mm_note_features, mm_tab_lab_features, mm_tab_note_features, mm_lab_note_features, mm_all_features = trimodal_process(tab_output, lab_output, note_output, key, weight1=tab_score, weight2=lab_score, weight3=note_score) if key in lab_note_keys else (None, None, None, None, None, None, None)

        # Four-way equal-weight blend that also includes the Multimodal run's features.
        recon_total_features = quadromodal_process(tab_output, lab_output, note_output, recon_output, key, weight1=0.5, weight2=0.5, weight3=0.5, weight4=0.5) if key in lab_note_keys else None
        
        # 8. total : mean over whichever of tab / lab / note contain this key.
        # NOTE(review): recon_output is passed, but whole_process as written only
        # averages its first three outputs.
        total_features = whole_process(tab_output, lab_output, note_output, recon_output, key)
        
        # Label: always read from the tabular output, so this raises a TypeError
        # for any key missing there (relies on the containment noted above).
        label = tab_output.get(key)['labels']
        
        # Store results
        total_outputs[key] = {
            'tabular': tab_features,
            'lab': lab_features,
            'note': note_features,
            'tab_only': tab_only_features,
            'lab_only_lab': lab_only_lab_features, 'lab_only_tab': lab_only_tab_features, 'lab_only': lab_only_features,
            'note_only_note': note_only_note_features, 'note_only_tab': note_only_tab_features, 'note_only': note_only_features,
            'mm_tab': mm_tab_features, 'mm_lab': mm_lab_features, 'mm_note': mm_note_features,
            'mm_tab_lab': mm_tab_lab_features, 'mm_tab_note': mm_tab_note_features, 'mm_lab_note': mm_lab_note_features,
            'mm_all': mm_all_features,
            'recon_total': recon_total_features,
            'total': total_features,
            'labels' : label,
        }

        

    # Variant keys stored per sample in total_outputs:
    ## 1. tabular, lab, note        : single-modality features on every sample that has that modality
    ## 2. tab_only                  : tabular features on samples with neither lab nor note
    ## 3. lab_only* / note_only*    : lab (or note) features, tabular features and their blend,
    ##                                on samples with lab but no note (or note but no lab)
    ## 4. mm_*                      : tab / lab / note features, pairwise blends and the three-way
    ##                                blend on samples with both lab and note
    ## 5. recon_total               : four-way blend including the Multimodal run, same samples as 4
    ## 6. total                     : mean of the available tab / lab / note features


    # Single-modality features on each modality's full key set
    total_scores.append([seed, 'Tabular', *get_score('tabular', total_outputs, proto)])
    total_scores.append([seed, 'Lab', *get_score('lab', total_outputs, proto)])
    total_scores.append([seed, 'Note', *get_score('note', total_outputs, proto)])
    # Subset scores: tab-only features, lab/note blended with tab, and the
    # single-modality features restricted to the lab-and-note samples
    # (the original section label here was "Bimodal").
    total_scores.append([seed, 'Tab Only', *get_score('tab_only', total_outputs, proto)])
    total_scores.append([seed, 'Lab Only', *get_score('lab_only', total_outputs, proto)])
    total_scores.append([seed, 'Note Only', *get_score('note_only', total_outputs, proto)])
    total_scores.append([seed, 'MM Tab', *get_score('mm_tab', total_outputs, proto)])
    total_scores.append([seed, 'MM Lab', *get_score('mm_lab', total_outputs, proto)])
    total_scores.append([seed, 'MM Note', *get_score('mm_note', total_outputs, proto)])
    # Blends on the lab-and-note samples: pairwise, three-way, and four-way with
    # the Multimodal run (the original section label here was "Trimodal").
    total_scores.append([seed, 'MM Tab & Lab', *get_score('mm_tab_lab', total_outputs, proto)])
    total_scores.append([seed, 'MM Tab & Note', *get_score('mm_tab_note', total_outputs, proto)])
    total_scores.append([seed, 'MM Lab & Note', *get_score('mm_lab_note', total_outputs, proto)])
    total_scores.append([seed, 'MM All', *get_score('mm_all', total_outputs, proto)])
    total_scores.append([seed, 'Recon Total', *get_score('recon_total', total_outputs, proto)])
    # Total
    total_scores.append([seed, 'Total', *get_score('total', total_outputs, proto)])
    # NOTE(review): this unconditional break ends the loop after the first seed
    # (2026), so seeds 2027 and 2028 are never evaluated and any per-modality
    # standard deviation computed from total_scores is NaN. Confirm whether it
    # is a debugging leftover before reporting multi-seed statistics.
    break

total_score_df = pd.DataFrame(total_scores, columns=['Seed', 'Modality', 'ROC-AUC', 'PR-AUC', 'N', 'Number of Cases'])
# Plain copy of the score table; no rebalancing is applied in this cell.
total_score_df_balanced = total_score_df.copy()

Tabular keys: 32945, Lab keys: 27286, Note keys: 20273, Recon keys: 32945
1. tab total: 32945
2. lab total: 27286
3. note total: 20273
4. tab only: 4574
5. lab only: 8098
6. note only: 1085
7. lab & note: 19188
8. total: 32945
tabular - Features shape: (32945, 128), Probs shape: (32945, 2), Labels shape: (32945,)
lab - Features shape: (27286, 128), Probs shape: (27286, 2), Labels shape: (27286,)
note - Features shape: (20273, 128), Probs shape: (20273, 2), Labels shape: (20273,)
tab_only - Features shape: (4574, 128), Probs shape: (4574, 2), Labels shape: (4574,)
lab_only - Features shape: (8098, 128), Probs shape: (8098, 2), Labels shape: (8098,)
note_only - Features shape: (1085, 128), Probs shape: (1085, 2), Labels shape: (1085,)
mm_tab - Features shape: (19188, 128), Probs shape: (19188, 2), Labels shape: (19188,)
mm_lab - Features shape: (19188, 128), Probs shape: (19188, 2), Labels shape: (19188,)
mm_note - Features shape: (19188, 128), Probs shape: (19188, 2), Labels shape: (191

In [19]:
# Per-modality summary over seeds: mean and std of both AUCs, plus the mean
# sample count and mean positive count cast to integers.
avg_scores = (
    total_score_df_balanced
    .groupby('Modality')
    .agg(**{
        'ROC-AUC Mean': ('ROC-AUC', 'mean'),
        'ROC-AUC Std': ('ROC-AUC', 'std'),
        'PR-AUC Mean': ('PR-AUC', 'mean'),
        'PR-AUC Std': ('PR-AUC', 'std'),
        'N': ('N', 'mean'),
        'Number of Cases': ('Number of Cases', 'mean'),
    })
    .reset_index()
    .astype({'N': int, 'Number of Cases': int})
)
# Optional ranking, left disabled by the author:
# avg_scores = avg_scores.sort_values(by='ROC-AUC Mean', ascending=False)
avg_scores

Unnamed: 0,Modality,ROC-AUC Mean,ROC-AUC Std,PR-AUC Mean,PR-AUC Std,N,Number of Cases
0,Lab,0.743799,,0.30473,,27286,2871
1,Lab Only,0.858032,,0.436915,,8098,687
2,MM All,0.852808,,0.533828,,19188,2184
3,MM Lab,0.733729,,0.313623,,19188,2184
4,MM Lab & Note,0.825667,,0.486805,,19188,2184
5,MM Note,0.816789,,0.470518,,19188,2184
6,MM Tab,0.822762,,0.455683,,19188,2184
7,MM Tab & Lab,0.840173,,0.480592,,19188,2184
8,MM Tab & Note,0.846754,,0.509744,,19188,2184
9,Note,0.821116,,0.475371,,20273,2246


In [14]:
# Write the per-modality summary to Results/, named after the task with its
# first character upper-cased by str.capitalize().
avg_scores.to_csv(
    path_or_buf=f'Results/Whole_{task.capitalize()}.csv',
    index=False,
)