In [1]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import pandas as pd
import numpy as np

import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from fig_utils import *

In [5]:
base_dir = '../simulated_SL/sl_results/23Aug22/1/'
bulk_dir = '../simulated_SL/sl_results/28Aug22/1/'
# NOTE(review): bulk and sfai results currently point at the same run folder — confirm intended.
sfai_dir = '../simulated_SL/sl_results/28Aug22/1/'

# Figure variant to build: '' (main text), '_bulk', or '_sfai'.
# The same suffix names the output PNG in the plotting cell, so the data
# selection below is derived from this one setting to keep the two in sync.
plot_type = ''
plot_types = [plot_type]

# Per figure variant: the (zt, tc, sigma) SL result CSVs, in the left-to-right
# column order used by the plotting cell.
SL_RESULT_FILES = {
    '': (
        base_dir + 'starrydata_zt/starrydata_zt-1-496-55-90-100.csv',
        base_dir + 'starrydata_tc/starrydata_tc-1-496-55-0-10.csv',
        base_dir + 'starrydata_sigma/starrydata_sigma-1-559-62-90-100.csv',
    ),
    '_bulk': (
        bulk_dir + 'starrydata_zt_bulk/starrydata_zt_bulk-1-218-24-90-100.csv',
        bulk_dir + 'starrydata_tc_bulk/starrydata_tc_bulk-1-211-23-0-10.csv',
        bulk_dir + 'starrydata_sigma_bulk/starrydata_sigma_bulk-1-234-26-90-100.csv',
    ),
    '_sfai': (
        sfai_dir + 'starrydata_zt_sfai/starrydata_zt_sfai-1-524-58-90-100.csv',
        sfai_dir + 'starrydata_tc_sfai/starrydata_tc_sfai-1-518-57-0-10.csv',
        sfai_dir + 'starrydata_sigma_sfai/starrydata_sigma_sfai-1-587-65-90-100.csv',
    ),
}

fs = []
for pt in plot_types:
    # Fail loudly instead of silently reusing paths from a previous iteration
    # (or hitting a NameError) when the plot type is misspelled.
    if pt not in SL_RESULT_FILES:
        raise ValueError('Unknown plot type {!r}; expected one of {}'.format(pt, sorted(SL_RESULT_FILES)))
    path_zt, path_tc, path_sigma = SL_RESULT_FILES[pt]
    fs.extend([path_zt, path_tc, path_sigma])
    

In [6]:
# Echo the resolved result CSVs so a wrong run folder or renamed file is
# visible before the plotting cell tries to read them.
print(str(fs))

['../simulated_SL/sl_results/23Aug22/1/starrydata_zt/starrydata_zt-1-496-55-90-100.csv', '../simulated_SL/sl_results/23Aug22/1/starrydata_tc/starrydata_tc-1-496-55-0-10.csv', '../simulated_SL/sl_results/23Aug22/1/starrydata_sigma/starrydata_sigma-1-559-62-90-100.csv']


In [7]:
# Figure 6: SL metrics for the starrydata tasks.
# One column per result file in `fs` (zt, tc, sigma in the order built above).
# Rows: (1) fraction of targets found vs. iteration, with a perfect-strategy
# reference line; (2) NDME vs. iteration; (3) discovery probability vs. NDME.
N_ROWS, N_COLS = 3, 3

# Each file gets its own column; more files than columns would otherwise fail
# deep inside add_trace with an opaque out-of-range error.
if len(fs) > N_COLS:
    raise ValueError('fig6 has {} columns but got {} result files: {}'.format(N_COLS, len(fs), fs))

fig = make_subplots(
    rows=N_ROWS, cols=N_COLS,
    print_grid=True,
    shared_yaxes=True,
    vertical_spacing=0.2,
    subplot_titles=['(a) ZT', r'$\text{(b)}   \kappa_{total} (W/mK)$', r'$\text{(c)  log(}\sigma_{E0}\text{)} (S/m)$',
                    '(d)', '(e)', '(f)',
                    '(g)', '(h)', '(i)',
                    ],
)

for col, f in enumerate(fs, start=1):
    # Only the first column contributes legend entries, so each strategy is
    # listed exactly once regardless of how the files are named.
    show_legend = col == 1
    dp_traces = get_discovery_probability_traces([f], x_metric='NDME', legend=False)
    i1_traces = get_avg_and_std_trace(f, 'fraction_of_targets_found', legend=show_legend)
    i2_traces = get_avg_and_std_trace(f, 'NDME', legend=False)

    for it in i1_traces:
        fig.add_trace(it, row=1, col=col)
    for it in i2_traces:
        fig.add_trace(it, row=2, col=col)
    for ft in dp_traces:
        fig.add_trace(ft, row=3, col=col)

    # A perfect selection strategy would select a target at every iteration,
    # so its fraction of targets found is iteration / total_targets.
    total_targets = get_total_targets(pd.read_csv(f))
    if not total_targets > 0:
        raise ValueError('No targets in {}; cannot draw the perfect-strategy line'.format(f))
    benchmark_x_vals = np.linspace(0, total_targets)
    slope = 1/total_targets
    print('Total targets: {} {}'.format(total_targets, slope))
    benchmark_y_vals = slope * benchmark_x_vals
    benchmark_trace = go.Scatter(x=benchmark_x_vals, y=benchmark_y_vals, name='perfect SL strategy',
                                 line=dict(color='forestgreen', width=4, dash='dash'),
                                 showlegend=False,
                                 )
    # Drawn in the same panel as this file's fraction-of-targets-found traces.
    fig.add_trace(benchmark_trace, row=1, col=col)

fig.update_yaxes(title_text='$DY_n$', row=1, col=1)
fig.update_yaxes(title_text='NDME', row=2, col=1)
fig.update_yaxes(title_text='$DP_n$', row=3, col=1)

for col in range(1, N_COLS + 1):
    for row in (1, 2):
        fig.update_xaxes(title_text='Iteration', row=row, col=col)
    # Row 3 uses a reversed x-axis (1.1 on the left, 0.5 on the right),
    # presumably so decreasing NDME reads left-to-right — TODO confirm with caption.
    fig.update_xaxes(title_text='NDME', range=[1.1, 0.5], row=3, col=col)
    fig.update_yaxes(range=[0, 0.75], row=3, col=col)

fig.update_layout(legend=dict(
    yanchor="bottom",
    y=0.04,
    xanchor="right",
    x=1.01,
    tracegroupgap=0,
))

fig.update_annotations(font=dict(family="Barlow Semi Condensed", size=24))
fig.update_layout(font=dict(family="Barlow Semi Condensed", size=20), width=1000, height=900, showlegend=True)

# write_image does not create missing directories, so make sure static/ exists.
fig6_path = Path("static") / "fig6-sl_metrics-starrydata{}.png".format(plot_type)
fig6_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(str(fig6_path), scale=2)
fig.show()

This is the format of your plot grid:
[ (1,1) x,y   ]  [ (1,2) x2,y2 ]  [ (1,3) x3,y3 ]
[ (2,1) x4,y4 ]  [ (2,2) x5,y5 ]  [ (2,3) x6,y6 ]
[ (3,1) x7,y7 ]  [ (3,2) x8,y8 ]  [ (3,3) x9,y9 ]

File:  ../simulated_SL/sl_results/23Aug22/1/starrydata_zt/starrydata_zt-1-496-55-90-100.csv
arrary sliced: iterations > # of targets 49 49
arrary sliced: iterations > # of targets 49 49
arrary sliced: iterations > # of targets 49 49
arrary sliced: iterations > # of targets 49 49
arrary sliced: iterations > # of targets 49 49
arrary sliced: iterations > # of targets 49 49
arrary sliced: iterations > # of targets 49 49
arrary sliced: iterations > # of targets 49 49
arrary sliced: iterations > # of targets 49 49
arrary sliced: iterations > # of targets 49 49
arrary sliced: iterations > # of targets 49 49
arrary sliced: iterations > # of targets 49 49
arrary sliced: iterations > # of targets 49 49
arrary sliced: iterations > # of targets 49 49
arrary sliced: iterations > # of targets 49 49
arrary sliced: