In [1]:
from utils import get_runs_df
from survkit.configs import WandbConfig
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# W&B run groups to fetch (the names suggest one mixture-experiment group per dataset)
groups = ['mnist_mixture', 'support2_mixture', 'sepsis_mixture']

In [3]:
# Fetch the runs of every W&B group into its own run table, keyed by group name.
# A comprehension avoids leaking loop variables into the notebook namespace -- in
# particular `df`, which would otherwise hold only the last group's runs.
wandb_config = WandbConfig()
entity_project = f'{wandb_config.entity}/{wandb_config.project}'
dfs = {group: get_runs_df(entity_project, group) for group in groups}

100%|██████████| 30/30 [00:09<00:00,  3.21it/s]
100%|██████████| 30/30 [00:09<00:00,  3.28it/s]
100%|██████████| 30/30 [00:09<00:00,  3.27it/s]


In [4]:
# combine all per-group run tables into one frame; ignore_index gives a fresh 0..n-1 index
df = pd.concat(dfs.values(), ignore_index=True)

In [5]:
# check that every (dataset, model) pair has the expected number of runs, failing loudly
# instead of relying on eyeballing the table
EXPECTED_RUNS_PER_MODEL = 5
run_counts = df.groupby(['dataset', 'model']).size()
unexpected_counts = run_counts[run_counts != EXPECTED_RUNS_PER_MODEL]
assert unexpected_counts.empty, f"unexpected run counts:\n{unexpected_counts}"
run_counts

dataset   model           
MNIST     FFNet               5
          FFNetMixture        5
          FFNetMixtureMTLR    5
          FFNetTimeWarpMoE    5
          SKSurvCox           5
          SKSurvRF            5
SUPPORT2  FFNet               5
          FFNetMixture        5
          FFNetMixtureMTLR    5
          FFNetTimeWarpMoE    5
          SKSurvCox           5
          SKSurvRF            5
Sepsis    FFNet               5
          FFNetMixture        5
          FFNetMixtureMTLR    5
          FFNetTimeWarpMoE    5
          SKSurvCox           5
          SKSurvRF            5
dtype: int64

In [6]:
# MNIST generation settings that define the "Survival MNIST" configuration
SURVIVAL_MNIST_MEANS = [1, 5, 9, 13, 17, 21, 25, 29, 33, 37]
SURVIVAL_MNIST_STDS = [1, 2, 3, 1, 2, 3, 1, 1, 1, 1]


def _as_list(value):
    """Return tuples/array-likes as plain lists so they compare equal to list constants."""
    if isinstance(value, tuple):
        return list(value)
    if hasattr(value, 'tolist'):
        return value.tolist()
    return value


def dataset_configuration(row):
    """Return the display label of the dataset configuration a run used.

    Args:
        row: one run record (e.g. a DataFrame row) with a 'dataset' entry and, for
            MNIST runs, 'mnist_means' and 'mnist_stds'.

    Returns:
        The raw dataset name for non-MNIST runs, 'Survival MNIST' for MNIST runs with
        the Survival MNIST means/stds, and 'MNIST' for any other MNIST configuration
        (returning None there would make groupby silently drop those runs).
    """
    if row['dataset'] != 'MNIST':
        return row['dataset']
    if (_as_list(row['mnist_means']) == SURVIVAL_MNIST_MEANS
            and _as_list(row['mnist_stds']) == SURVIVAL_MNIST_STDS):
        return 'Survival MNIST'
    return row['dataset']

In [7]:
# swap each run's raw dataset name for its configuration label (e.g. 'Survival MNIST')
df = df.assign(dataset=df.apply(dataset_configuration, axis=1))

In [8]:
# display names for the model identifiers stored in the run table
model_map = {
    'FFNet': 'MTLR',
    'FFNetMixture': 'Fixed MoE (ours)',
    'FFNetTimeWarpMoE': 'Adjustable MoE (ours)',
    'FFNetMixtureMTLR': 'Personalized MoE (ours)',
    'SKSurvCox': 'CoxPH',
    'SKSurvRF': 'RSF',
}
# capitalise the grouping columns for presentation, then apply the display names
df = df.rename(columns={'model': 'Model', 'dataset': 'Dataset'})
df['Model'] = df['Model'].replace(model_map)

In [9]:
# raw test-metric column -> name of its difference column (also the output column order)
DIFF_COLUMNS = {
    'test_ece_equal_mass': 'ece_diff',
    'test_concordance': 'concordance_diff',
    'test_loss': 'loss_diff',
    'test_brier@25th': 'brier_25th_diff',
    'test_brier@50th': 'brier_50th_diff',
    'test_brier@75th': 'brier_75th_diff',
}


def compute_diffs_per_seed(group_df, reference_model='MTLR'):
    """Compute each run's test-metric differences from the reference model's run.

    Intended to be applied to one (Dataset, seed) group of the run table.

    Args:
        group_df: runs of one group; must contain 'Model' and every key of DIFF_COLUMNS.
        reference_model: value of 'Model' identifying the reference run (default 'MTLR').

    Returns:
        DataFrame with 'Model' followed by one ``*_diff`` column per metric
        (run value minus reference value), one row per row of `group_df`, with a
        fresh 0..n-1 index. Empty if the group has no reference run.

    Raises:
        ValueError: if the group holds more than one reference run, which would otherwise
            silently duplicate every row of the result.
    """
    reference_df = group_df.loc[group_df['Model'] == reference_model, list(DIFF_COLUMNS)]
    if len(reference_df) > 1:
        raise ValueError(
            f"expected at most one {reference_model!r} run per group, found {len(reference_df)}"
        )
    # the cross join attaches the single reference row (suffix '_ref') to every run
    diff_df = group_df.merge(reference_df, how='cross', suffixes=('', '_ref'))
    for metric, diff_name in DIFF_COLUMNS.items():
        diff_df[diff_name] = diff_df[metric] - diff_df[f'{metric}_ref']
    # keep only relevant columns
    return diff_df[['Model', *DIFF_COLUMNS.values()]]

In [10]:
# Diff every run against the reference run with the same Dataset and seed.
# compute_diffs_per_seed does not need the grouping columns (reset_index restores them),
# so include_groups=False avoids the pandas>=2.2 DeprecationWarning about applying on them.
df_grouped = (
    df.groupby(['Dataset', 'seed'], sort=False)
    .apply(compute_diffs_per_seed, include_groups=False)
    .reset_index()
    # level_2 is the inner row index returned by compute_diffs_per_seed; not needed
    .drop(columns=['level_2'])
)

  df_grouped = df.groupby(['Dataset', 'seed'], sort=False).apply(compute_diffs_per_seed).reset_index()


In [11]:
# per-(Dataset, seed, Model) differences of the test metrics from the MTLR run of the same seed
df_grouped

Unnamed: 0,Dataset,seed,Model,ece_diff,concordance_diff,loss_diff,brier_25th_diff,brier_50th_diff,brier_75th_diff
0,Survival MNIST,42,Personalized MoE (ours),-0.001800,-0.000886,-0.002267,-0.001321,-0.000377,-0.000837
1,Survival MNIST,42,Adjustable MoE (ours),0.002745,-0.001292,0.067735,0.001824,0.001957,0.000850
2,Survival MNIST,42,Fixed MoE (ours),0.001454,0.007326,0.015079,0.000668,0.001314,-0.000608
3,Survival MNIST,42,MTLR,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
4,Survival MNIST,42,CoxPH,0.024513,-0.136767,1.184422,0.082950,0.124600,0.058676
...,...,...,...,...,...,...,...,...,...
85,Sepsis,46,Personalized MoE (ours),-0.014842,0.009501,-0.049829,-0.003369,-0.003472,-0.002674
86,Sepsis,46,CoxPH,0.617852,-0.147653,16.737344,0.261544,0.511418,0.728602
87,Sepsis,46,RSF,0.584014,-0.060676,14.979344,0.232373,0.570550,0.769413
88,Sepsis,46,MTLR,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000


In [12]:
# merge the diffs back into the run table on Dataset, seed, and Model.
# Any diff columns from a previous execution are dropped first so re-running this cell
# does not create suffixed duplicates (ece_diff_x / ece_diff_y); validate guards against
# duplicated (Dataset, seed, Model) rows silently multiplying runs.
merge_keys = ['Dataset', 'seed', 'Model']
diff_columns = [col for col in df_grouped.columns if col not in merge_keys]
df = (
    df.drop(columns=diff_columns, errors='ignore')
    .merge(df_grouped, on=merge_keys, how='left', validate='one_to_one')
)

In [13]:
def _metric_agg_spec():
    """Build the named-aggregation spec for the per-(Model, Dataset) summary.

    For every metric: the mean and std of its raw test column and the mean of its
    per-seed difference column, followed by the mean parameter count.
    """
    metric_sources = [
        # (display name, raw test column, difference column)
        ('ECE', 'test_ece_equal_mass', 'ece_diff'),
        ('Concordance', 'test_concordance', 'concordance_diff'),
        ('Brier (25th)', 'test_brier@25th', 'brier_25th_diff'),
        ('Brier (50th)', 'test_brier@50th', 'brier_50th_diff'),
        ('Brier (75th)', 'test_brier@75th', 'brier_75th_diff'),
        ('Loss', 'test_loss', 'loss_diff'),
    ]
    spec = {}
    for display_name, raw_col, diff_col in metric_sources:
        spec[display_name] = (raw_col, 'mean')
        spec[display_name + ' Std'] = (raw_col, 'std')
        spec[display_name + ' Diff'] = (diff_col, 'mean')
    spec['Parameters'] = ('num_model_parameters', 'mean')
    return spec


# one row per (Model, Dataset) with the summary statistics above
metrics_df = df.groupby(['Model', 'Dataset'], as_index=False).agg(**_metric_agg_spec())

In [14]:
# scale concordance to percentage.
# A flag on the frame makes the cell idempotent: re-running it does not multiply by 100
# again, while a freshly aggregated metrics_df (no flag) is scaled exactly once.
scale_cols = ['Concordance', 'Concordance Std', 'Concordance Diff']
if not metrics_df.attrs.get('concordance_scaled', False):
    for col in scale_cols:
        metrics_df[col] = metrics_df[col] * 100
    metrics_df.attrs['concordance_scaled'] = True

In [15]:
# Report every metric to 3 decimal places by default, combining each value with
# its ' Diff' column as "score (diff)".
def format_scores(row, metrics, format_dict=None):
    """Render one row of aggregated metrics as display strings.

    Args:
        row: one row of the aggregated metrics table; must contain 'Model', 'Dataset',
            each name in `metrics`, and ``metric + ' Diff'`` for every metric other
            than 'Parameters'.
        metrics: metric column names to format, in output order.
        format_dict: optional mapping metric -> format spec (e.g. '.2f'); metrics not
            listed use '.3f'. Not modified.

    Returns:
        pd.Series with 'Model', 'Dataset' and one formatted string per metric.
        'Parameters' is rendered as a comma-grouped integer, or '-' when missing (NaN).
    """
    if format_dict is None:
        format_dict = {}
    final_row = {'Model': row['Model'], 'Dataset': row['Dataset']}
    for metric in metrics:
        if metric == 'Parameters':
            # parameter counts have no Diff column and may be missing
            if pd.isna(row[metric]):
                final_row[metric] = '-'
            else:
                final_row[metric] = f"{int(row[metric]):,}"
            continue
        format_str = format_dict.get(metric, '.3f')
        final_row[metric] = f"{row[metric]:{format_str}} ({row[metric + ' Diff']:{format_str}})"
    return pd.Series(final_row)

In [16]:
# the percentage-scaled concordance columns get 2 decimals; everything else uses the '.3f' default
format_dict = dict.fromkeys(scale_cols, '.2f')
cols = ['ECE', 'Concordance', 'Brier (25th)', 'Brier (50th)', 'Brier (75th)']
output_df = metrics_df.apply(format_scores, axis=1, args=(cols, format_dict))

In [17]:
# columns of the final results table, in display order
reporting_cols = ['Dataset', 'Model', 'ECE', 'Concordance', 'Brier (25th)', 'Brier (50th)', 'Brier (75th)']

In [18]:
# keep only the reported columns, in reporting order
output_df = output_df.loc[:, reporting_cols]

In [19]:
# Fix the row order of the results table. reindex keeps only the listed (Dataset, Model)
# pairs and shows any pair without results as '-'.
dataset_order = ['Survival MNIST', 'SUPPORT2', 'Sepsis']
model_order = ['CoxPH', 'RSF', 'MTLR', 'Fixed MoE (ours)', 'Adjustable MoE (ours)', 'Personalized MoE (ours)']
row_order = pd.MultiIndex.from_product([dataset_order, model_order], names=['Dataset', 'Model'])
output_df = output_df.set_index(['Dataset', 'Model']).reindex(row_order, fill_value='-')

In [20]:
# Display the ordered results table. Its rows are already restricted to, and ordered by,
# dataset_order x model_order through the reindex above, so no further filtering is needed.
output_df

Unnamed: 0_level_0,Unnamed: 1_level_0,ECE,Concordance,Brier (25th),Brier (50th),Brier (75th)
Dataset,Model,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Survival MNIST,CoxPH,0.030 (0.024),79.16 (-13.61),0.112 (0.083),0.159 (0.125),0.069 (0.059)
Survival MNIST,RSF,0.057 (0.051),90.06 (-2.72),0.048 (0.019),0.073 (0.038),0.025 (0.015)
Survival MNIST,MTLR,0.006 (0.000),92.77 (0.00),0.029 (0.000),0.034 (0.000),0.010 (0.000)
Survival MNIST,Fixed MoE (ours),0.009 (0.003),93.24 (0.47),0.030 (0.001),0.035 (0.000),0.010 (-0.000)
Survival MNIST,Adjustable MoE (ours),0.008 (0.002),92.61 (-0.16),0.030 (0.001),0.036 (0.002),0.011 (0.001)
Survival MNIST,Personalized MoE (ours),0.005 (-0.001),92.61 (-0.16),0.029 (-0.000),0.036 (0.001),0.010 (0.000)
SUPPORT2,CoxPH,0.187 (0.130),78.89 (-1.01),0.212 (0.055),0.209 (0.060),0.236 (0.088)
SUPPORT2,RSF,0.186 (0.129),79.76 (-0.14),0.207 (0.050),0.203 (0.055),0.232 (0.085)
SUPPORT2,MTLR,0.058 (0.000),79.90 (0.00),0.156 (0.000),0.149 (0.000),0.148 (0.000)
SUPPORT2,Fixed MoE (ours),0.053 (-0.005),79.84 (-0.06),0.157 (0.001),0.146 (-0.003),0.145 (-0.003)


# Table with additional metrics: number of model parameters

In [21]:
# parameter counts per (Dataset, Model), excluding the CoxPH and RSF rows
parameter_df = metrics_df.loc[
    ~metrics_df['Model'].isin(['CoxPH', 'RSF']),
    ['Dataset', 'Model', 'Parameters'],
]

In [22]:
# parameter counts before reordering (index carries over from metrics_df)
parameter_df

Unnamed: 0,Dataset,Model,Parameters
0,SUPPORT2,Adjustable MoE (ours),69435.0
1,Sepsis,Adjustable MoE (ours),63579.0
2,Survival MNIST,Adjustable MoE (ours),194883.0
6,SUPPORT2,Fixed MoE (ours),69480.0
7,Sepsis,Fixed MoE (ours),63008.0
8,Survival MNIST,Fixed MoE (ours),209844.0
9,SUPPORT2,MTLR,68521.0
10,Sepsis,MTLR,62945.0
11,Survival MNIST,MTLR,187189.0
12,SUPPORT2,Personalized MoE (ours),62141.0


In [23]:
# Order the parameter table like the results table (CoxPH/RSF excluded). reindex keeps only
# the listed (Dataset, Model) pairs, already in display order, and shows missing pairs as '-'.
dataset_order = ['Survival MNIST', 'SUPPORT2', 'Sepsis']
model_order = ['MTLR', 'Fixed MoE (ours)', 'Adjustable MoE (ours)', 'Personalized MoE (ours)']
parameter_index = pd.MultiIndex.from_product([dataset_order, model_order], names=['Dataset', 'Model'])
parameter_df = parameter_df.set_index(['Dataset', 'Model']).reindex(parameter_index, fill_value='-')
parameter_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Parameters
Dataset,Model,Unnamed: 2_level_1
Survival MNIST,MTLR,187189.0
Survival MNIST,Fixed MoE (ours),209844.0
Survival MNIST,Adjustable MoE (ours),194883.0
Survival MNIST,Personalized MoE (ours),195891.0
SUPPORT2,MTLR,68521.0
SUPPORT2,Fixed MoE (ours),69480.0
SUPPORT2,Adjustable MoE (ours),69435.0
SUPPORT2,Personalized MoE (ours),62141.0
Sepsis,MTLR,62945.0
Sepsis,Fixed MoE (ours),63008.0
