# Figure 5

Requires the cross-validation models to be trained first; see `cv_scm.sh` for details.

In [1]:
import os

# Move to the repository root so that `src` is importable and the relative
# ./outputs and figures/ paths used below resolve. Guarded so that re-running
# this cell (or starting the kernel already at the root) does not keep
# climbing one directory up per execution.
if not os.path.isdir("src"):
    os.chdir("../")

In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import glob
import json
from src import utils
from IPython.display import clear_output
from matplotlib.ticker import PercentFormatter

# Global seaborn look for every plot below: "ticks" style in the small "paper" context.
sns.set_theme(style='ticks', context='paper', font_scale=1)

In [3]:
# Experiment prefix of the CV report files read from ./outputs/cv_reports/.
name = "mimic_transitions_cv"
# Distribution suffix of the report files; also used in the saved figure's filename.
distribution = "multigaussian"
# Width passed to utils.get_fig_dim to size the figure
# (presumably the LaTeX document width in pt -- confirm).
width_pt = 397

In [4]:
# Collect one record per CV report: the two Lipschitz settings and the final CV loss.
pattern = "./outputs/cv_reports/{name}_*_{distribution}.json".format(name=name, distribution=distribution)
# glob order is filesystem-dependent; sort so raw_df's row order is reproducible.
files = sorted(glob.glob(pattern))
if not files:
    # Fail loudly here instead of producing an empty raw_df that breaks later cells.
    raise FileNotFoundError(
        "No CV reports match {} -- train the models first (see cv_scm.sh).".format(pattern)
    )

report_keys = ('lipschitz_loc', 'lipschitz_scale', 'crossval_last_loss')
dicts = []
for fl_id, fl in enumerate(files, start=1):
    clear_output(wait=True)
    print('Reading file ' + str(fl_id)+'/'+str(len(files)))
    with open(fl, "r", encoding="utf-8") as f:
        js = json.load(f)
    dicts.append({key: js[key] for key in report_keys})

raw_df = pd.DataFrame(dicts)

Reading file 101/101


In [5]:
# Absolute CV loss of every Lipschitz-constrained run, as a
# lipschitz_scale (rows) x lipschitz_loc (columns) grid for the heatmap.
# The baseline run has lipschitz_loc == 'none'; it is excluded here, which also
# lets every remaining column be cast to float.
# pivot's arguments are keyword-only since pandas 2.0.
input_df = (
    raw_df[raw_df['lipschitz_loc'] != 'none']
    .astype(float)
    .pivot(index="lipschitz_scale", columns="lipschitz_loc", values="crossval_last_loss")
)

In [6]:
# The baseline is the unconstrained run, stored with lipschitz_loc == 'none'.
baseline_df = raw_df[raw_df['lipschitz_loc'] == 'none']
if len(baseline_df) != 1:
    raise ValueError(
        "Expected exactly one baseline run (lipschitz_loc == 'none'), found {}".format(len(baseline_df))
    )
baseline = float(baseline_df['crossval_last_loss'].iloc[0])

# Percent change of each constrained run's CV loss relative to the baseline.
# The absolute-loss grid is rebuilt from raw_df (same filter/pivot as the
# previous cell) rather than transforming input_df in place, so re-running
# this cell does not apply the percent transform twice.
loss_df = (
    raw_df[raw_df['lipschitz_loc'] != 'none']
    .astype(float)
    .pivot(index="lipschitz_scale", columns="lipschitz_loc", values="crossval_last_loss")
)
input_df = (loss_df - baseline) / baseline * 100

In [7]:
utils.latexify() # Computer Modern, with TeX

fig_width, fig_height = utils.get_fig_dim(width_pt, fraction=0.6)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(fig_width, fig_height))

# Heatmap of input_df (percent change in CV loss vs. the baseline); the
# colorbar formatter is passed at creation time and shows whole percents.
sns.heatmap(
    data=input_df,
    ax=ax,
    cmap=sns.color_palette("rocket", as_cmap=True),
    cbar_kws={"format": PercentFormatter(decimals=0)},
)
sns.despine(ax=ax)

# Columns of input_df are lipschitz_loc (x), rows are lipschitz_scale (y).
ax.set_xlabel(r"$L_h$")
ax.set_ylabel(r"$L_\phi$")
ax.tick_params(axis="x", labelrotation=40)

fig.tight_layout()
# savefig does not create missing directories.
os.makedirs("figures", exist_ok=True)
fig.savefig('figures/loss_{distribution}.pdf'.format(distribution=distribution), dpi=300)