In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import shutil
import torch

from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
from sklearn.metrics import log_loss, brier_score_loss

In [None]:
def softmax(x):
    """Row-wise softmax of a 2-D array-like (normalises along axis 1).

    The per-row maximum is subtracted before exponentiating so that large
    logits cannot overflow; this does not change the result.
    """
    row_max = np.max(x, axis=1, keepdims=True)
    exps = np.exp(x - row_max)
    row_totals = np.sum(exps, axis=1, keepdims=True)
    return exps / row_totals

def softmax_by_sample(x):
    """Softmax over all elements of ``x`` (one sample's logits), shifted by the global max for stability."""
    exps = np.exp(x - np.max(x))
    return exps / np.sum(exps)
    

def load_pickle(file_path):
    """Unpickle and return the object stored at ``file_path``.

    Only call this on trusted files: unpickling can execute arbitrary code.
    """
    with open(file_path, mode='rb') as handle:
        obj = pickle.load(handle)
    return obj

def load_modality_output(seed, projection_dim, batch_size, modality, task) : 
    """Load a unimodal linear model's best-epoch outputs together with the ETF prototype.

    Args:
        seed: training seed used in the result directory names.
        projection_dim: projection size, used in both the run and the ETF file names.
        batch_size: training batch size, used in the run directory name.
        modality: modality folder suffix (``Results/Linear_{modality}``).
        task: task name prefix of the run directory.

    Returns:
        ``(output, etf)``: the unpickled ``outputs/best_epoch.pkl`` of the run, and the
        object stored in ``Results/ETF/ETF_{projection_dim}_IN_2_OUT_Seed_{seed}.pt``.
        The ETF path depends only on ``seed`` and ``projection_dim``, not on ``modality``.
        NOTE(review): the ETF is presumably a (projection_dim, 2) tensor, since it is used
        as ``features @ etf`` downstream — confirm.
    """
    result_dir = Path(f'Results/Linear_{modality}/Seed_{seed}/{task}_proj_{projection_dim}_batch_{batch_size}')
    output = load_pickle(result_dir / 'outputs' / 'best_epoch.pkl')
    etf_path = Path(f'Results/ETF/ETF_{projection_dim}_IN_2_OUT_Seed_{seed}.pt')
    # Load onto the CPU. An ETF saved from GPU training would otherwise be restored onto
    # CUDA, where the later ``.numpy()`` conversion fails. Without a GPU it would not load at all.
    etf = torch.load(etf_path, map_location='cpu')
    return output, etf

def load_fusion_ouptut(seed, projection_dim, batch_size, fusion_method, task) : 
    """Return the unpickled best-epoch outputs of a multimodal fusion run.

    The run lives under ``Results/Linear_MultiModal/Fusion_{fusion_method}/Seed_{seed}/``.
    The function name keeps its historical 'ouptut' spelling so existing callers keep working.
    """
    run_dir = Path(f'Results/Linear_MultiModal/Fusion_{fusion_method}/Seed_{seed}/{task}_proj_{projection_dim}_batch_{batch_size}')
    return load_pickle(run_dir.joinpath('outputs', 'best_epoch.pkl'))

def load_e2e_mm_output(seed, projection_dim, batch_size, fusion_method, task) : 
    """Return the unpickled best-epoch outputs of an end-to-end multimodal fusion run.

    The run lives under ``Results/Linear_MultiModal/E2E_Fusion_{fusion_method}/Seed_{seed}/``.
    """
    run_dir = Path(f'Results/Linear_MultiModal/E2E_Fusion_{fusion_method}/Seed_{seed}/{task}_proj_{projection_dim}_batch_{batch_size}')
    return load_pickle(run_dir.joinpath('outputs', 'best_epoch.pkl'))

def unimodal_process(output, key):
    """Return the ``'features'`` entry stored for sample ``key`` in a per-sample output dict.

    A missing ``key`` makes ``output.get`` return None, so indexing raises TypeError.
    """
    return output.get(key)['features']

def bimodal_process(output1, output2, key, weight1=0.5, weight2=0.5) : 
    """Fetch one sample's features from two per-sample output dicts and blend them.

    Returns:
        ``(features1, features2, blended)``, where ``blended`` is the weighted mean
        ``(features1 * weight1 + features2 * weight2) / (weight1 + weight2)``.
    """
    first = output1.get(key)['features']
    second = output2.get(key)['features']
    weight_total = weight1 + weight2
    blended = (first * weight1 + second * weight2) / weight_total
    return first, second, blended

def trimodal_process(output1, output2, output3, key, weight1=0.5, weight2=0.5, weight3=0.5) : 
    """Fetch one sample's features from three output dicts and build every weighted blend.

    Returns:
        ``(f1, f2, f3, f12, f13, f23, f123)``. Each blend is a weighted mean of the listed
        inputs, normalised by the sum of their weights.
    """
    def _pair(a, wa, b, wb):
        # Weighted mean of two feature vectors.
        return (a * wa + b * wb) / (wa + wb)

    f1, f2, f3 = (out.get(key)['features'] for out in (output1, output2, output3))
    f12 = _pair(f1, weight1, f2, weight2)
    f13 = _pair(f1, weight1, f3, weight3)
    f23 = _pair(f2, weight2, f3, weight3)
    f123 = (f1 * weight1 + f2 * weight2 + f3 * weight3) / (weight1 + weight2 + weight3)
    return f1, f2, f3, f12, f13, f23, f123

def whole_process(output1, output2, output3, key) : 
    """Average the features of ``key`` over whichever of the three output dicts contain it.

    Returns None when no dict has an entry for ``key``.
    """
    present = []
    for out in (output1, output2, output3):
        entry = out.get(key, None)
        if entry is not None:
            present.append(entry['features'])
    if not present:
        return None
    return np.mean(present, axis=0)

def get_output(modality, total_outputs, etf) : 
    """Project one modality's stored features onto the ETF and return labels and probabilities.

    Records whose ``modality`` entry is None are skipped.

    Args:
        modality: key into each per-sample record of ``total_outputs`` (e.g. 'tabular', 'mm_all').
        total_outputs: dict mapping sample key -> record dict holding ``modality`` and 'labels'.
        etf: projection matrix applied as ``features @ etf``. Either a torch tensor (detached
            and moved to the CPU before conversion) or anything ``np.asarray`` accepts.

    Returns:
        ``(labels, probs)``: one label per kept record and the row-wise softmax of the
        projected features.

    Raises:
        ValueError: if no record has a non-None entry for ``modality``.
    """
    kept = [record for record in total_outputs.values() if record[modality] is not None]
    if not kept:
        raise ValueError(f"No samples with non-None '{modality}' features to score.")
    labels = np.array([record['labels'] for record in kept])
    # Stacking along a new leading axis is identical to concatenating [[feature], ...] along axis 0.
    features = np.stack([record[modality] for record in kept], axis=0)
    if isinstance(etf, torch.Tensor):
        # .numpy() refuses CUDA tensors and tensors that require grad.
        etf = etf.detach().cpu().numpy()
    else:
        etf = np.asarray(etf)
    probs = softmax(features @ etf)
    print(f"{modality} - Features shape: {features.shape}, Probs shape: {probs.shape}, Labels shape: {labels.shape}")
    return labels, probs

def get_performance(labels, probs):
    """Return ``(roc_auc, pr_auc, n)`` for binary ``labels`` and positive-class scores ``probs``.

    PR-AUC is the trapezoidal area under the precision-recall curve (sklearn ``auc``).
    This can differ from ``average_precision_score``, which uses a step-wise sum.
    """
    roc = roc_auc_score(labels, probs)
    precision, recall, _thresholds = precision_recall_curve(labels, probs)
    return roc, auc(recall, precision), len(labels)

def get_score(modality, total_outputs, etf) : 
    """Score one ETF-projected feature set.

    Returns:
        ``(roc_auc, pr_auc, n, number_of_positive_labels, log_loss, brier_loss)``. Column 1
        of the probabilities is used as the positive-class score.
    """
    labels, probs = get_output(modality, total_outputs, etf)
    positive_probs = probs[:, 1]
    roc_auc, pr_auc, n = get_performance(labels, positive_probs)
    return (
        roc_auc,
        pr_auc,
        n,
        (labels == 1).sum(),
        log_loss(labels, probs),
        brier_score_loss(labels, positive_probs),
    )

def get_fusion_score(fusion_method, total_outputs) : 
    """Score a fusion model from the per-sample probabilities stored under ``'<fusion_method>_prob'``.

    Returns:
        ``(roc_auc, pr_auc, n, number_of_positive_labels, log_loss, brier_loss)``. Column 1
        of the stacked probabilities is used as the positive-class score.
    """
    records = list(total_outputs.values())
    prob_key = fusion_method + '_prob'
    labels = np.array([record['labels'] for record in records])
    probs = np.stack([record[prob_key] for record in records], axis=0)
    print(f"{fusion_method} - Probs shape: {probs.shape}, Labels shape: {labels.shape}")
    positive_probs = probs[:, 1]
    roc_auc, pr_auc, n = get_performance(labels, positive_probs)
    return (
        roc_auc,
        pr_auc,
        n,
        (labels == 1).sum(),
        log_loss(labels, probs),
        brier_score_loss(labels, positive_probs),
    )

In [None]:
# ---- Experiment configuration -------------------------------------------------
projection_dim = 128
batch_size = 512
task = 'mortality_90days'
# task = 'readmission_15days'
seeds = range(2026, 2029)   # every seed is evaluated; later cells average over seeds
split_mode = 'test'         # split read from every pickled output below

# Report rows scored from ETF-projected features: (display name, record key in total_outputs).
FEATURE_SCORE_SPECS = [
    ('Tabular', 'tabular'), ('Lab', 'lab'), ('Note', 'note'),               # each modality, all its samples
    ('Tab Only', 'tab_only'), ('Lab Only', 'lab_only'), ('Note Only', 'note_only'),  # disjoint subsets
    ('MM Tab', 'mm_tab'), ('MM Lab', 'mm_lab'), ('MM Note', 'mm_note'),    # single modality on lab ∩ note
    ('MM Tab & Lab', 'mm_tab_lab'), ('MM Tab & Note', 'mm_tab_note'),
    ('MM Lab & Note', 'mm_lab_note'),                                       # pairwise blends on lab ∩ note
    ('MM All', 'mm_all'),                                                   # all three on lab ∩ note
    ('Simple Average', 'total'),                                            # mean of available modalities
]
# Report rows scored from fusion-model probabilities:
# (display name, record prefix '<prefix>_prob', loader, fusion method, outputs indexed by int(key)?)
# The extra-training fusion outputs are indexed by int(key); the E2E outputs use the raw key.
FUSION_SPECS = [
    ('Fusion - Sum', 'fusion_sum', load_fusion_ouptut, 'Sum', True),
    ('Fusion - Weighted Sum', 'fusion_weighted_sum', load_fusion_ouptut, 'WeightedFusion', True),
    ('Fusion - Attn Masked', 'fusion_attn_masked', load_fusion_ouptut, 'AttnMaskedFusion', True),
    ('E2E Fusion - Sum', 'E2E_fusion_sum', load_e2e_mm_output, 'Sum', False),
    ('E2E Fusion - Weighted Sum', 'E2E_fusion_weighted_sum', load_e2e_mm_output, 'WeightedFusion', False),
    ('E2E Fusion - Attn Masked', 'E2E_fusion_attn_masked', load_e2e_mm_output, 'AttnMaskedFusion', False),
]

total_scores = []
for seed in seeds:
    # ---- Unimodal outputs (the ETF path depends only on seed / projection_dim) ----
    tabular_outputs, proto = load_modality_output(seed, projection_dim, batch_size, 'tabular', task)
    lab_outputs, _ = load_modality_output(seed, projection_dim, batch_size, 'lab', task)
    note_outputs, _ = load_modality_output(seed, projection_dim, batch_size, 'note', task)
    tab_output = tabular_outputs[split_mode]['outputs']
    lab_output = lab_outputs[split_mode]['outputs']
    note_output = note_outputs[split_mode]['outputs']

    # ---- Fusion outputs (extra training and end-to-end), keyed by record prefix ----
    fusion_outputs = {
        prefix: loader(seed, projection_dim, batch_size, method, task)[split_mode]['outputs']
        for _name, prefix, loader, method, _int_keyed in FUSION_SPECS
    }

    # ---- Sample-key subsets of the three modalities ----
    tab_keys = set(tab_output.keys())
    lab_keys = set(lab_output.keys())
    note_keys = set(note_output.keys())
    tab_only_keys = tab_keys - lab_keys - note_keys   # 4. tabular only
    lab_only_keys = lab_keys - note_keys              # 5. lab without note
    note_only_keys = note_keys - lab_keys             # 6. note without lab
    lab_note_keys = lab_keys & note_keys              # 7. lab ∩ note (all three modalities)
    total_keys = tab_keys | lab_keys | note_keys      # 8. union

    print(f"1. tab total: {len(tab_keys)}")
    print(f"2. lab total: {len(lab_keys)}")
    print(f"3. note total: {len(note_keys)}")
    print(f"4. tab only: {len(tab_only_keys)}")
    print(f"5. lab only: {len(lab_only_keys)}")
    print(f"6. note only: {len(note_only_keys)}")
    print(f"7. lab & note: {len(lab_note_keys)}")
    print(f"8. total: {len(total_keys)}")

    # The subsets above and the label lookup below assume Tab ⊇ Lab and Tab ⊇ Note.
    if not (lab_keys <= tab_keys and note_keys <= tab_keys):
        raise ValueError(
            f"Seed {seed}: every lab/note sample must also have tabular outputs "
            f"({len(lab_keys - tab_keys)} lab and {len(note_keys - tab_keys)} note keys are missing)"
        )

    total_outputs = {}
    for key in total_keys:
        record = {
            # 1-3. each modality on all of its samples
            'tabular': unimodal_process(tab_output, key) if key in tab_keys else None,
            'lab': unimodal_process(lab_output, key) if key in lab_keys else None,
            'note': unimodal_process(note_output, key) if key in note_keys else None,
            # 4. tabular-only samples
            'tab_only': unimodal_process(tab_output, key) if key in tab_only_keys else None,
        }
        # 5. lab-without-note samples: lab features, tab features and their weighted blend
        lab_only_parts = (bimodal_process(lab_output, tab_output, key, weight1=0.5, weight2=0.5)
                          if key in lab_only_keys else (None,) * 3)
        record.update(zip(('lab_only_lab', 'lab_only_tab', 'lab_only'), lab_only_parts))
        # 6. note-without-lab samples: note features, tab features and their weighted blend
        note_only_parts = (bimodal_process(note_output, tab_output, key, weight1=0.5, weight2=0.5)
                           if key in note_only_keys else (None,) * 3)
        record.update(zip(('note_only_note', 'note_only_tab', 'note_only'), note_only_parts))
        # 7. samples with all three modalities: single, pairwise and full blends
        mm_parts = (trimodal_process(tab_output, lab_output, note_output, key,
                                     weight1=0.5, weight2=0.5, weight3=0.5)
                    if key in lab_note_keys else (None,) * 7)
        record.update(zip(('mm_tab', 'mm_lab', 'mm_note',
                           'mm_tab_lab', 'mm_tab_note', 'mm_lab_note', 'mm_all'), mm_parts))
        # 8. mean over whichever modalities are present
        record['total'] = whole_process(tab_output, lab_output, note_output, key)
        # 9-14. per-sample softmax of each fusion model's stored 'probs'
        for _name, prefix, _loader, _method, int_keyed in FUSION_SPECS:
            fusion_key = int(key) if int_keyed else key
            record[f'{prefix}_prob'] = softmax_by_sample(fusion_outputs[prefix][fusion_key]['probs'])
        # Labels come from the tabular outputs, which cover every key (checked above).
        record['labels'] = tab_output[key]['labels']
        total_outputs[key] = record

    for display_name, record_key in FEATURE_SCORE_SPECS:
        total_scores.append([seed, display_name, *get_score(record_key, total_outputs, proto)])
    for display_name, prefix, _loader, _method, _int_keyed in FUSION_SPECS:
        total_scores.append([seed, display_name, *get_fusion_score(prefix, total_outputs)])

total_score_df = pd.DataFrame(total_scores, columns=['Seed', 'Modality', 'ROC-AUC', 'PR-AUC', 'N', 'Number of Cases', 'NLL Loss', 'Brier Loss'])
total_score_df_balanced = total_score_df.copy()

In [None]:
# Seed-averaged metrics for the unimodal models, the simple feature average and the fusion models.
avg_scores = (
    total_score_df_balanced
    .groupby('Modality')
    .mean()
    .reset_index()
    .drop(columns=['Seed'])
)
# Ratio from the averaged counts, computed before they are converted to integers.
avg_scores['case_ratio'] = avg_scores['Number of Cases'] / avg_scores['N']
# Test-split sizes may differ between seeds, so averaged counts can be fractional.
# Round to the nearest integer instead of truncating toward zero.
count_columns = ['N', 'Number of Cases']
avg_scores[count_columns] = avg_scores[count_columns].round().astype(int)
loss_columns = ['Brier Loss', 'NLL Loss']
avg_scores[loss_columns] = avg_scores[loss_columns].round(4)
avg_scores = avg_scores[['Modality', 'Number of Cases', 'N', 'case_ratio', 'ROC-AUC', 'PR-AUC', 'Brier Loss', 'NLL Loss']]

# Rows (and their order) written to the report.
target_modality_and_order = [
                            'Tabular', 'Lab', 'Note', # Unimodal
                            'E2E Fusion - Sum',
                            'E2E Fusion - Weighted Sum',
                            'E2E Fusion - Attn Masked',
                            # 'MM All', # Trimodal
                            # 'Tab Only', 'Lab Only', 'Note Only', # Bimodal - only
                            'Simple Average', # Total
                            'Fusion - Sum', 
                            'Fusion - Weighted Sum', 
                            'Fusion - Attn Masked', # Fusion,
                             ]
avg_scores = avg_scores.set_index('Modality').loc[target_modality_and_order].reset_index()
metric_columns = ['ROC-AUC', 'PR-AUC']
avg_scores[metric_columns] = (avg_scores[metric_columns] * 100).round(2)  # as percentages
avg_scores.to_csv(f'Results/Whole_{task.capitalize()}.csv', index=False)
avg_scores

In [None]:
# Seed-averaged metrics on the multimodal (lab ∩ note) subsets, plus the mean of ROC-AUC and PR-AUC.
avg_scores = (
    total_score_df_balanced
    .groupby('Modality')
    .mean()
    .reset_index()
    .drop(columns=['Seed'])
)

target_modality_and_order = ['MM Tab', 'MM Lab', 'MM Note', # Bimodal - cross
                             'MM Tab & Lab', 'MM Tab & Note', 'MM Lab & Note', # Trimodal - pair
                             'MM All', # Trimodal - all
                             ]
avg_scores = avg_scores.set_index('Modality').loc[target_modality_and_order].reset_index()
avg_scores['AVG'] = avg_scores[['ROC-AUC', 'PR-AUC']].mean(axis=1)
percent_columns = ['ROC-AUC', 'PR-AUC', 'AVG']
avg_scores[percent_columns] = (avg_scores[percent_columns] * 100).round(2)
loss_columns = ['NLL Loss', 'Brier Loss']
avg_scores[loss_columns] = avg_scores[loss_columns].round(4)
avg_scores = avg_scores[['Modality', 'ROC-AUC', 'PR-AUC', 'AVG', 'NLL Loss', 'Brier Loss']]
avg_scores.to_csv(f'Results/MM_{task.capitalize()}.csv', index=False)
avg_scores

In [None]:
# Seed-averaged metrics for each modality on all its samples vs. its subset without overlap.
avg_scores = (
    total_score_df_balanced
    .groupby('Modality')
    .mean()
    .reset_index()
    .drop(columns=['Seed'])
)

target_modality_and_order = ['Tabular', 'Lab', 'Note', # Unimodal
                             'Tab Only', 'Lab Only', 'Note Only', # Unimodal - no intersection
                             ]
avg_scores = avg_scores.set_index('Modality').loc[target_modality_and_order].reset_index()
avg_scores['AVG'] = avg_scores[['ROC-AUC', 'PR-AUC']].mean(axis=1)
percent_columns = ['ROC-AUC', 'PR-AUC', 'AVG']
avg_scores[percent_columns] = (avg_scores[percent_columns] * 100).round(2)
loss_columns = ['NLL Loss', 'Brier Loss']
avg_scores[loss_columns] = avg_scores[loss_columns].round(4)
avg_scores = avg_scores[['Modality', 'ROC-AUC', 'PR-AUC', 'AVG', 'NLL Loss', 'Brier Loss']]
# Use a dedicated file. 'Results/Whole_<Task>.csv' holds the whole-comparison table,
# and writing this table there would silently overwrite it.
avg_scores.to_csv(f'Results/Unimodal_{task.capitalize()}.csv', index=False)
avg_scores

In [None]:
# Admission-id splits of the 90-day mortality task for seed 2026, plus the '_noise_0.01'
# variant of the same file. The next cells read the latter as per-split labels (TODO confirm contents).
MORTALITY_TASK_DIR = Path('/home/data/2025_MIMICIV_processed/mimic4/task:mortality_90days')
sample = load_pickle(MORTALITY_TASK_DIR / 'admission_ids_seed_2026.pkl')
label = load_pickle(MORTALITY_TASK_DIR / 'admission_ids_seed_2026_noise_0.01.pkl')

In [None]:
# Per-split admission ids for each modality (tabular = code ids, note = discharge-note ids).
(train_code_ids, train_lab_ids, train_note_ids,
 valid_code_ids, valid_lab_ids, valid_note_ids,
 test_code_ids, test_lab_ids, test_note_ids) = (
    sample[split_key]
    for split_key in (
        'train_code_ids', 'train_lab_ids', 'train_discharge_ids',
        'val_code_ids', 'val_lab_ids', 'val_discharge_ids',
        'test_code_ids', 'test_lab_ids', 'test_discharge_ids',
    )
)

In [None]:
# Same per-split layout as the id file, read from the noisy-label pickle.
(train_code_labels, train_lab_labels, train_note_labels,
 valid_code_labels, valid_lab_labels, valid_note_labels,
 test_code_labels, test_lab_labels, test_note_labels) = (
    label[split_key]
    for split_key in (
        'train_code_ids', 'train_lab_ids', 'train_discharge_ids',
        'val_code_ids', 'val_lab_ids', 'val_discharge_ids',
        'test_code_ids', 'test_lab_ids', 'test_discharge_ids',
    )
)

In [None]:
# Pool the train/valid/test ids of each modality into one collection.
# NOTE(review): assumes the split containers are Python lists, where `+` concatenates;
# if they were numpy arrays, `+` would add element-wise instead — confirm.
code_ids = train_code_ids + valid_code_ids + test_code_ids
lab_ids = train_lab_ids + valid_lab_ids + test_lab_ids
note_ids = train_note_ids + valid_note_ids + test_note_ids

In [None]:
# Per-admission records for the 90-day horizon; the cells below read each record's `.mortality`.
# load_pickle closes the file handle that a bare pickle.load(open(...)) would leak.
admission_dict = load_pickle(Path('/home/data/2025_MIMICIV_processed/mimic4/hosp_adm_dict_90days.pkl'))

In [None]:
# Spot-check: the mortality value stored for a single admission id.
admission_dict['23196014'].mortality

In [None]:
# Map admission id -> mortality, keeping only admissions whose mortality is not None.
# Iterate the dict directly: routing the keys through np.array turned them into NumPy
# scalars and could coerce mixed-type keys (e.g. ints to strings), breaking the lookup.
new_dict = {
    adm_id: record.mortality
    for adm_id, record in admission_dict.items()
    if record.mortality is not None
}

In [None]:
# Modality -> id source: Tabular -> code_ids, Lab -> lab_ids, Note -> note_ids.
# Subsets: Tabular only = code - lab - note; Lab only = lab - note;
# Note only = note - lab; Multimodal = code ∩ lab ∩ note.
# Each dict maps admission id -> mortality label (KeyError if an id was dropped from new_dict).
code_id_set, lab_id_set, note_id_set = set(code_ids), set(lab_ids), set(note_ids)


def _mortality_for(ids):
    """Return {id: new_dict[id]} for every id in ``ids``."""
    return {adm_id: new_dict[adm_id] for adm_id in ids}


tabular = _mortality_for(code_ids)
lab = _mortality_for(lab_ids)
note = _mortality_for(note_ids)
tabular_only = _mortality_for(code_id_set - lab_id_set - note_id_set)
lab_only = _mortality_for(lab_id_set - note_id_set)
note_only = _mortality_for(note_id_set - lab_id_set)
multimodal = _mortality_for(code_id_set & lab_id_set & note_id_set)

In [None]:
# Sample size, positive count and positive ratio (%) for each modality subset.
distribution_rows = []
for name, dist_dict in [
    ('Tabular', tabular),
    ('Lab', lab),
    ('Note', note),
    ('Tabular Only', tabular_only),
    ('Lab Only', lab_only),
    ('Note Only', note_only),
    ('Multimodal', multimodal),
]:
    n_samples = len(dist_dict)
    n_positive = sum(dist_dict.values())
    # Guard empty subsets instead of raising ZeroDivisionError.
    ratio = n_positive / n_samples if n_samples else float('nan')
    print(f"{name} - Sample Size: {n_samples}, Positive Cases: {n_positive}, Positive Ratio: {ratio:.4f}")
    distribution_rows.append({
        'Modality': name,
        'Positive Cases': n_positive,
        'Sample Size': n_samples,
        # Scale before rounding; round(r, 4) * 100 leaves artefacts like 12.340000000000002.
        'Positive Ratio': round(ratio * 100, 2),
    })
# Build the frame once from records instead of concatenating one-row frames.
final_df = pd.DataFrame(distribution_rows)
final_df