In [1]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import shutil

In [2]:
batch_size=512
projection_dim=128


total_results = []

for task in ['mortality_90days', 'readmission_15days'] :
    for seed in range(2026, 2029) : 
        for modality in ['tabular', 'lab', 'note'] :
            print(f"Task : {task} | Seed : {seed} | Modality : {modality}")
            result_dir = Path(f'Results/Linear_{modality}/Seed_{seed}/{task}_proj_{projection_dim}_batch_{batch_size}')
            exp_dir = result_dir
            output_dir = exp_dir / 'outputs'
            model_dir = exp_dir / 'models'
            best_output_save_path = output_dir / 'best_epoch.pkl'
            best_score_save_path = output_dir / 'best_epoch_scores.pkl'
            
            if not best_score_save_path.exists()  :
                score_paths = list(output_dir.glob('epoch*_scores.pkl'))
                if len(score_paths) == 0 :
                    if best_output_save_path.exists() and best_score_save_path.exists() :
                        pass
                    else :
                        print(f"No output paths for {result_dir}, skipping...")
                        continue
                else : 
                    score_paths = sorted(score_paths, key=lambda x: int(x.stem.split('_')[0].replace('epoch', '')))
                    # last epoch is the best
                    best_epoch_path = score_paths[-1]
                    best_epoch = int(best_epoch_path.stem.split('_')[0].replace('epoch', ''))
                    best_epoch_score = pickle.load(open(best_epoch_path, 'rb'))
                    
                    shutil.copy(output_dir / f'epoch{best_epoch:03d}.pkl', best_output_save_path)
                    shutil.copy(output_dir / f'epoch{best_epoch:03d}_scores.pkl', best_score_save_path)
                    # model_path = model_paths[best_epoch]
                    model_path_dict = {int(p.stem.split('_')[1].replace('epoch', '')):p for p in model_dir.glob('linear_epoch*.pth')}
                    model_path = model_path_dict[best_epoch]
                    best_model_save_path = model_dir / 'best_model.pth'
                    shutil.copy(model_path, best_model_save_path)
                    print(f"Copied best epoch output to {best_output_save_path} & {best_score_save_path}")
            
            if best_score_save_path.exists() :
                best_epoch_output = pickle.load(open(best_score_save_path, 'rb'))
                
                total_results.append({
                    'seed' : seed,
                    'task' : task,
                    'modality' : modality,
                    # 'best_epoch' : best_epoch,
                    'valid_auroc' : best_epoch_output['valid']['auroc'],
                    'valid_auprc' : best_epoch_output['valid']['auprc'],
                    'test_auroc' : best_epoch_output['test']['auroc'],
                    'test_auprc' : best_epoch_output['test']['auprc'],
                })
total_results_df = pd.DataFrame(total_results)
total_results_df

Task : mortality_90days | Seed : 2026 | Modality : tabular
Task : mortality_90days | Seed : 2026 | Modality : lab
Task : mortality_90days | Seed : 2026 | Modality : note
Task : mortality_90days | Seed : 2027 | Modality : tabular
Task : mortality_90days | Seed : 2027 | Modality : lab
Task : mortality_90days | Seed : 2027 | Modality : note
Task : mortality_90days | Seed : 2028 | Modality : tabular
Task : mortality_90days | Seed : 2028 | Modality : lab
Task : mortality_90days | Seed : 2028 | Modality : note
Task : readmission_15days | Seed : 2026 | Modality : tabular
Task : readmission_15days | Seed : 2026 | Modality : lab
Task : readmission_15days | Seed : 2026 | Modality : note
Task : readmission_15days | Seed : 2027 | Modality : tabular
Task : readmission_15days | Seed : 2027 | Modality : lab
Task : readmission_15days | Seed : 2027 | Modality : note
Task : readmission_15days | Seed : 2028 | Modality : tabular
Task : readmission_15days | Seed : 2028 | Modality : lab
Task : readmission_1

Unnamed: 0,seed,task,modality,valid_auroc,valid_auprc,test_auroc,test_auprc
0,2026,mortality_90days,tabular,0.819079,0.421106,0.820207,0.420561
1,2026,mortality_90days,lab,0.75802,0.313863,0.748357,0.303633
2,2026,mortality_90days,note,0.817043,0.48489,0.821119,0.475583
3,2027,mortality_90days,tabular,0.817039,0.413502,0.817422,0.420754
4,2027,mortality_90days,lab,0.747232,0.28859,0.747946,0.287073
5,2027,mortality_90days,note,0.820925,0.487611,0.822796,0.48171
6,2028,mortality_90days,tabular,0.828941,0.437157,0.815528,0.420176
7,2028,mortality_90days,lab,0.75065,0.290956,0.745017,0.29336
8,2028,mortality_90days,note,0.824293,0.493744,0.821927,0.474583
9,2026,readmission_15days,tabular,0.670377,0.425669,0.674797,0.430504


In [3]:
batch_size=512
projection_dim=128


total_results = []

for task in ['mortality_90days', 'readmission_15days'] :
    for prefix in ['', 'E2E_'] :
        for fusion_method in ['Sum', 'WeightedFusion', 'AttnMaskedFusion'] : 
            for seed in range(2026, 2029) : 
                result_dir = Path(f'Results/Linear_MultiModal/{prefix}Fusion_{fusion_method}/Seed_{seed}/{task}_proj_{projection_dim}_batch_{batch_size}')
                exp_dir = result_dir
                output_dir = exp_dir / 'outputs'
                model_dir = exp_dir / 'models'
                
                best_output_save_path = output_dir / 'best_epoch.pkl'
                best_score_save_path = output_dir / 'best_epoch_scores.pkl'
                
                score_paths = list(output_dir.glob('epoch*_scores.pkl'))
                if len(score_paths) == 0 :
                    print(f"No output paths for {result_dir}, skipping...")
                    continue
                score_paths = sorted(score_paths, key=lambda x: int(x.stem.split('_')[0].replace('epoch', '')))
                # last epoch is the best
                best_epoch_path = score_paths[-1]
                best_epoch = int(best_epoch_path.stem.split('_')[0].replace('epoch', ''))
                best_epoch_score = pickle.load(open(best_epoch_path, 'rb'))
                
                model_paths = list(model_dir.glob('linear_epoch*.pth'))
                # {epoch : model_path}
                model_paths = {int(p.stem.split('_')[1].replace('epoch', '')) : p for p in model_paths}
                model_path = model_paths[best_epoch]
                
                shutil.copy(output_dir / f'epoch{best_epoch:03d}.pkl', best_output_save_path)
                shutil.copy(output_dir / f'epoch{best_epoch:03d}_scores.pkl', best_score_save_path)
                model_path = model_paths[best_epoch]
                best_model_save_path = model_dir / 'best_model.pth'
                shutil.copy(model_path, best_model_save_path)
                # print(f"Copied Best Epoch Output {best_epoch_path} \n\t\t\t-> {best_output_save_path}")
                total_results.append({
                    'seed' : seed,
                    'task' : task,
                    'fusion' : fusion_method,
                    'best_epoch' : best_epoch,
                    'valid_auroc' : best_epoch_score['valid']['auroc'],
                    'valid_auprc' : best_epoch_score['valid']['auprc'],
                    'test_auroc' : best_epoch_score['test']['auroc'],
                    'test_auprc' : best_epoch_score['test']['auprc'],
                })
total_results_df = pd.DataFrame(total_results)
total_results_df

Unnamed: 0,seed,task,fusion,best_epoch,valid_auroc,valid_auprc,test_auroc,test_auprc
0,2026,mortality_90days,Sum,15,0.857454,0.509091,0.85412,0.499081
1,2027,mortality_90days,Sum,22,0.850209,0.494792,0.854534,0.487895
2,2028,mortality_90days,Sum,67,0.858354,0.507784,0.848482,0.498859
3,2026,mortality_90days,WeightedFusion,18,0.846221,0.486732,0.845259,0.480934
4,2027,mortality_90days,WeightedFusion,98,0.841039,0.487442,0.844389,0.488295
5,2028,mortality_90days,WeightedFusion,18,0.856876,0.502152,0.841947,0.488933
6,2026,mortality_90days,AttnMaskedFusion,4,0.849313,0.492513,0.84737,0.477881
7,2027,mortality_90days,AttnMaskedFusion,99,0.847692,0.482475,0.850826,0.481795
8,2028,mortality_90days,AttnMaskedFusion,25,0.856904,0.494231,0.842898,0.479016
9,2026,mortality_90days,Sum,3,0.830917,0.464679,0.823889,0.456495
