In [36]:
import numpy as np
import pandas as pd
# Sweep results file to analyse; the load cell reads it from ../out/.
file = 'sensitivity_10_1e5.csv'

# Print numpy floats with 3 decimal places.
np.set_printoptions(precision=3)

def mad(df):
    return (df - df.median()).abs().median()

In [37]:
# Load the raw sweep and summarise every (r1, r2, r3, t1, t2, t3) combination
# by the MAD and median of `det` and `det_true`. The result has two-level
# columns, e.g. ('det_true', 'mad'); the group keys become ('r1', '') etc.
sens = pd.DataFrame(np.loadtxt(f'../out/{file}'), columns=['r1', 'r2', 'r3', 't1', 't2', 't3', 'det', 'det_true'])
# Pass the string 'median' rather than np.median: pandas >= 2.1 emits a
# FutureWarning for NumPy callables in agg (today it silently swaps in the
# groupby method; a future version will call the NumPy function directly).
# The string keeps the current behaviour and the same 'median' column label.
sens = sens.groupby(['r1', 'r2', 'r3', 't1', 't2', 't3']).agg([mad, 'median']).reset_index()

# Reference value per (t1, t2, t3): the median det_true of the row whose
# t1..t3 equal r1..r3, relabelled as ('det_opt', 'median').
t_keys = [('t1', ''), ('t2', ''), ('t3', '')]
at_opt = (sens.t1 == sens.r1) & (sens.t2 == sens.r2) & (sens.t3 == sens.r3)
opt = (
    sens[at_opt]
    .filter(t_keys + [('det_true', 'median')])
    .rename(columns={'det_true': 'det_opt'})
)

# Attach the reference to every row with the same t-triple. validate= raises
# if a t-triple ever had more than one reference row (instead of silently
# duplicating rows).
# NOTE(review): merge defaults to an inner join, so rows whose t-triple has no
# r == t row are dropped without warning.
sens = sens.merge(opt, on=t_keys, validate='many_to_one')
# Spread of det_true relative to the reference median (a fraction, not %).
sens['ratio'] = sens[('det_true', 'mad')] / sens[('det_opt', 'median')]
sens

Unnamed: 0_level_0,r1,r2,r3,t1,t2,t3,det,det,det_true,det_true,det_opt,ratio
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,mad,median,mad,median,median,Unnamed: 12_level_1
0,0.1,0.1,0.1,0.1,0.1,0.1,0.022266,14.312590,0.022266,14.312590,14.312590,0.001556
1,0.1,0.1,1.0,0.1,0.1,0.1,0.002636,10.469039,0.001892,14.341682,14.312590,0.000132
2,0.1,0.1,10.0,0.1,0.1,0.1,0.013415,5.379578,0.057451,14.037036,14.312590,0.004014
3,0.1,1.0,0.1,0.1,0.1,0.1,0.001512,9.070700,0.011350,14.111499,14.312590,0.000793
4,0.1,1.0,1.0,0.1,0.1,0.1,0.001117,7.873545,0.004329,14.155475,14.312590,0.000302
...,...,...,...,...,...,...,...,...,...,...,...,...
724,10.0,1.0,1.0,10.0,10.0,10.0,0.002845,1.001078,0.003379,0.635604,0.644773,0.005240
725,10.0,1.0,10.0,10.0,10.0,10.0,0.005069,0.793223,0.005613,0.641022,0.644773,0.008705
726,10.0,10.0,0.1,10.0,10.0,10.0,0.002610,0.737011,0.004818,0.641343,0.644773,0.007472
727,10.0,10.0,1.0,10.0,10.0,10.0,0.003447,0.734944,0.007786,0.640197,0.644773,0.012075


In [38]:
import plotly.graph_objects as go
from plotly.colors import DEFAULT_PLOTLY_COLORS
from plotly.subplots import make_subplots
import re

# Grid of heatmaps of `ratio` (shown in %), restricted to t1 = t2 = t3.
# Panel column = that common t value (eta); panel row = r3; within a panel
# x = r1 and y = r2. Labels use the figure's notation: r1 -> r_2, r2 -> r_3,
# r3 -> r_4.
grid = [0.1, 1, 10]
grid_labels = [f'{v:g}' for v in grid]  # '0.1', '1', '10'
n_grid = len(grid)

fig = make_subplots(n_grid, n_grid, subplot_titles=[
    # Raw f-string: \quad and \eta are LaTeX. In a plain f-string '\q' and '\e'
    # are invalid escape sequences (SyntaxWarning since Python 3.12).
    rf'$r_4 = {r3} \quad \eta = {eta}$' for r3 in grid for eta in grid
], horizontal_spacing=0.05, vertical_spacing=0.1)
for j, eta in enumerate(grid):
    sens1 = sens[(sens.t1 == sens.t2) & (sens.t2 == sens.t3) & (sens.t1 == eta)]
    # Select each r3 explicitly so panel row i always matches the r3 in its
    # title, rather than relying on groupby order and every group being present.
    for i, r3 in enumerate(grid):
        d1 = sens1[sens1.r3 == r3]
        # Rows of z follow r2, columns follow r1 (both increasing).
        z = d1.sort_values(['r2', 'r1']).ratio.to_numpy().reshape(n_grid, n_grid)
        # flipud pairs with the reversed y labels below, so z row k is still
        # drawn against the r2 label y[k].
        z = np.flipud(z)
        fig.add_trace(go.Heatmap(
            z=z*100,
            x=grid_labels,
            y=grid_labels[::-1],
            coloraxis='coloraxis',
            text=z*100,
            texttemplate='%{text:.2f}%',
            textfont={'size': 8}
        ), row=i+1, col=j+1)
fig.update_layout(
    coloraxis={'colorscale': 'viridis'},
    height=750,
)
fig.update_xaxes(title='$r_2$', row=n_grid)
fig.update_yaxes(title='$r_3$', col=1)
fig.show()
fig.write_image('../figures/sensitivity_mad.png')

In [39]:
import plotly.express as px
from functools import reduce
from plotly.colors import DEFAULT_PLOTLY_COLORS

# Ratio vs. one r at a time (log x), with t1 = t2 = t3 = eta and the other two
# r's held equal to t1. Colour = which r is varied; opacity = eta (1/3, 2/3, 1);
# legend entries are grouped by eta.
fig = go.Figure()
for i, r in enumerate([1, 2, 3]):
    # The two r columns that are not being varied must equal t1.
    cond = reduce(lambda a, b: a & b, [sens[f'r{rx}'] == sens.t1 for rx in [1, 2, 3] if rx != r])
    sens_temp = sens[(sens.t1 == sens.t2) & (sens.t2 == sens.t3) & cond]
    # Channel strings from e.g. 'rgb(31, 119, 180)' so an alpha can be appended.
    color = DEFAULT_PLOTLY_COLORS[i].split('(')[1].split(')')[0].split(', ')

    for j, eta in enumerate([0.1, 1, 10]):
        sens_eta = sens_temp[sens_temp.t1 == eta].sort_values(f'r{r}')
        fig.add_trace(go.Scatter(
            x=sens_eta[f'r{r}'], y=sens_eta.ratio,
            # legendgroup is a string property; make the conversion explicit.
            legendgroup=str(eta),
            # Raw f-string: '\e' in a plain f-string is an invalid escape
            # sequence (SyntaxWarning since Python 3.12).
            legendgrouptitle_text=rf'$\eta = {eta}$',
            # Column r{k} is labelled r_{k+1}, matching the heatmap figure's
            # notation (r1 -> r_2, r2 -> r_3, r3 -> r_4).
            name=f'$r_{r + 1}$', line_color=f'rgba({", ".join(color)}, {(j+1)/3})'
        ))

fig.update_xaxes(type='log', title='value of the varied r')
fig.update_yaxes(title='ratio')
fig.update_layout(legend=dict(groupclick="toggleitem"))
fig.show()
# fig.write_image('r2.png')