In [1]:
import os
import numpy as np
import plotly.graph_objects as go

import gzip
import pickle as pkl

import collections

def loadall_results2(path, n_folds):
    regret = []
    history_actions = []
    history_outcomes = []
    perf = []
    cm = []
    with gzip.open(  path ,'rb') as f:
        for i in range(n_folds+1):
            try:
                data = pkl.load(f)
            except EOFError:
                break

            if type(data) == dict:
                regret.append( data['regret'] ) 
                history_actions.append( data['action_history'] )
                history_outcomes.append( data['outcome_history'] ) 
                perf.append( data['pred'] ) 
                cm.append( data['cm'] ) 

    return regret, history_actions, history_outcomes, perf, cm

In [24]:
n_folds = 25
horizon = 9999
context = 'MNISTbinary'
model = 'MLP'
case = 'case1'
agent_name = 'EEneuralcbpside_v6' #ineural6, neuronal6

direct = './results/'
path = os.path.join(direct, '{}_{}_{}_{}_{}_{}.pkl.gz'.format(case,model,context,horizon,n_folds,agent_name) )
regret, action_history, outcome_history, perf = loadall_results2(path, n_folds)


In [2]:
def process_cm(cm):
    vals = []
    for i in range(n_folds):        
        tn, fp, fn, tp = cm[i].ravel()
        vals.append( [ tp / cm[i].sum(), tn / cm[i].sum(), fp / cm[i].sum() , fn / cm[i].sum() ,  ] )
    return vals

n_folds = 25
horizon = 9999
data = 'MNISTbinary'
model = 'MLP'
case = 'case1'
agent_name = 'EEneuralcbpside_v6'

final = {}

for data in ['MNISTbinary', 'MagicTelescope', 'adult']:

    final_data = {}
    for case in ['case1', 'case1b']:
        direct = './results/'
        path = os.path.join(direct, '{}_{}_{}_{}_{}_{}.pkl.gz'.format(case, model, data, horizon,n_folds,agent_name) )
        print(path)
        result, history_actions, history_outcomes, perf, cm = loadall_results2(path, n_folds)
        vals = process_cm(cm)
        # values = [ tns, fps, fns, tps ]
        final_data[case] = np.array(vals)
    final[data] = final_data

# import pandas as pd
# df = pd.DataFrame(final_data)
# df.columns = [ 'case', 'dataset', 'tn_mean', 'fp_mean', 'fn_mean', 'tp_mean', 'tn_std', 'fp_std', 'fn_std', 'tp_std' ]

./results/case1_MLP_MNISTbinary_9999_25_EEneuralcbpside_v6.pkl.gz
./results/case1b_MLP_MNISTbinary_9999_25_EEneuralcbpside_v6.pkl.gz
./results/case1_MLP_MagicTelescope_9999_25_EEneuralcbpside_v6.pkl.gz
./results/case1b_MLP_MagicTelescope_9999_25_EEneuralcbpside_v6.pkl.gz
./results/case1_MLP_adult_9999_25_EEneuralcbpside_v6.pkl.gz
./results/case1b_MLP_adult_9999_25_EEneuralcbpside_v6.pkl.gz


In [334]:
final

{'MNISTbinary': {'case1': array([[0.4706, 0.022 , 0.021 , 0.4864],
         [0.4541, 0.0385, 0.02  , 0.4874],
         [0.4538, 0.0388, 0.011 , 0.4964],
         [0.4696, 0.023 , 0.0179, 0.4895],
         [0.4746, 0.018 , 0.021 , 0.4864],
         [0.4708, 0.0218, 0.0372, 0.4702],
         [0.4756, 0.017 , 0.026 , 0.4814],
         [0.465 , 0.0276, 0.0238, 0.4836],
         [0.4741, 0.0185, 0.0291, 0.4783],
         [0.4762, 0.0164, 0.0218, 0.4856],
         [0.4784, 0.0142, 0.0235, 0.4839],
         [0.4634, 0.0292, 0.0151, 0.4923],
         [0.4633, 0.0293, 0.0147, 0.4927],
         [0.4891, 0.0035, 0.1003, 0.4071],
         [0.4829, 0.0097, 0.0342, 0.4732],
         [0.4613, 0.0313, 0.0177, 0.4897],
         [0.4856, 0.007 , 0.0562, 0.4512],
         [0.4692, 0.0234, 0.0283, 0.4791],
         [0.4612, 0.0314, 0.0199, 0.4875],
         [0.4747, 0.0179, 0.0222, 0.4852],
         [0.4624, 0.0302, 0.0195, 0.4879],
         [0.4824, 0.0102, 0.0482, 0.4592],
         [0.4705, 0.0221, 0.01

In [348]:
import plotly.graph_objects as go
import numpy as np

# Assuming 'total_cases' is your dictionary containing the datasets

fig = go.Figure()

# Define the order of datasets and initialize a counter for x-axis positioning
datasets = ['MNISTbinary', 'MagicTelescope', 'adult']
x_position = 0  # Counter for x-axis position

# Iterate through each dataset, then each column, and finally each case
names = []
for dataset_name in datasets:
    for i in range(4):  # Assuming each dataset has 5 columns
        for case_name in ['case1', 'case1b']:
            case_data = final[dataset_name][case_name]
            name = f"{case_name}_{dataset_name}_Col{i+1}"
            names.append(name)
            bool = True if name in ['case1_MNISTbinary_Col1', 'case1b_MNISTbinary_Col1'] else False
            col = '#6495ed' if case_name == 'case1' else '#d891ef'
            legend_name = 'unit' if case_name=='case1' else 'cost'
            fig.add_trace(go.Box(
                y=case_data[:, i],
                name=legend_name,
                marker_color=col,
                boxpoints=False,  # Hide outliers
                boxmean=True,     # Show mean
                width=0.7,        # Adjust the width of the boxes
                x=[x_position] * len(case_data[:, i]),  # Set x-axis positions
                showlegend=bool
            ))
            x_position += 1

        # Introduce space after each column
        x_position += 1

# Adjust x-axis range to account for the extra space introduced
fig.update_xaxes(range=[-0.5, x_position - 0.5])

# Vertical lines to separate datasets
# for i, dataset in enumerate(names):
#     if i in [14, 28]:
fig.add_vline(x=11, line_width=1, line_color="black")
fig.add_vline(x=23, line_width=1, line_color="black")

# Add annotations for each dataset (adjust positions as necessary)
# annotations = ['MNISTbinary', 'adult', 'MagicTelescope']

fig.add_annotation(
        x=3,  # Adjust x position for each dataset
        y=1,  # Y-coordinate on the plot
        text=f'<span style="text-decoration: underline;">MNIST binary</span>',
        showarrow=False,
        arrowhead=0,
        ax=0,
        ay=0
    )

fig.add_annotation(
        x=15,  # Adjust x position for each dataset
        y=1,  # Y-coordinate on the plot
        text=f'<span style="text-decoration: underline;">MagicTelescope</span>',
        showarrow=False,
        arrowhead=0,
        ax=0,
        ay=0
    )

fig.add_annotation(
        x=25,  # Adjust x position for each dataset
        y=1,  # Y-coordinate on the plot
        text=f'<span style="text-decoration: underline;">Adult</span>',
        showarrow=False,
        arrowhead=0,
        ax=0,
        ay=0
    )

siz = 11
# Calculate midpoints for tick labels
tickvals = [x + 0.5 for x in range(0, x_position - 1, 2)]

# Update layout and display the plot
fig.update_layout(
    width=450,
    height=180,
    plot_bgcolor='white',
    paper_bgcolor='white',
    margin=dict(l=0, r=0, t=0, b=0),
    # title="Box Plot of Datasets by Cases and Columns",
    xaxis_title="Action",
    yaxis_title="Count",
    boxmode='group',
    showlegend=True,

    xaxis=dict(
        tickvals=[0.5, 3.5, 6.5, 9.5, 12.5, 15.5, 18.5, 21.5, 24.5, 27.5, 30.5, 33.5, 36.5, 39.5, 42.5],
        ticktext=['TP', 'TN', 'FP', 'FN'] * 3,  # Adjust ticktext as needed
        tickfont=dict(size=siz-2),
        title_standoff=5,
        title_font=dict(size=siz),
    ),

    yaxis=dict(
        gridcolor='lightgrey',
        title_standoff=5,
        title_font=dict(size=siz),
        tickfont=dict(size=siz)
    ),


    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=-0.4,  # Adjust this value to position the legend
        xanchor="center",
        x=0.5,
        bgcolor='rgba(0,0,0,0)',
        font=dict(size=siz)  # Increase legend font size
    ),

)
fig.show()

# fig.write_image("./figures/test.pdf" )
# fig.write_image("./figures/cost_action_count.pdf" )

In [122]:
import plotly.graph_objects as go

# Prepare data for the bar plot
datasets = df['dataset'].unique()
metrics = [  'tp_mean', 'tn_mean', 'fp_mean', 'fn_mean',]
std_metrics = [  'tp_std', 'tn_std', 'fp_std', 'fn_std',]
cases = ['case1', 'case1b']

# Create subplots
fig = go.Figure()

# Add bars for each dataset
bar_width = 0.35  # Adjust the bar width here
for d_idx, dataset in enumerate(datasets):
    for m_idx, metric in enumerate(metrics):
        for c_idx, case in enumerate(cases):
            # Filter data for the specific dataset and case
            subset_df = df[(df['dataset'] == dataset) & (df['case'] == case)]

            # Calculate the x position for the bar
            x_position = d_idx * (len(metrics) + 1) + m_idx + c_idx * bar_width
            val = subset_df[ std_metrics[m_idx] ] #.tolist()[0]
            # print(val)
            val = 2.575 * val / np.sqrt(n_folds)
            # Add bar for the metric and case
            fig.add_trace(go.Bar(
                x=[x_position],
                y=subset_df[metric],  # Assuming you want the mean value
                name=f"{case} - {metric}",
                width=bar_width,
                error_y=dict(type='data', array=val, visible=True, thickness=1 ),
            ))



# Customize x-axis
x_ticks = []
x_ticktext = []
for d_idx, dataset in enumerate(datasets):
    for m_idx, metric in enumerate(metrics):
        x_ticks.append(d_idx * (len(metrics) + 1) + m_idx)
        x_ticktext.append(f"   {metric.split('_')[0]}")

fig.add_annotation(
    x=0.35,       # X-coordinate on the plot (adjust as necessary)
    y=0.6,      # Y-coordinate on the plot (adjust as necessary)
    text='<span style="text-decoration: underline;">MNISTbinary</span>',  # Simulated underline with HTML
    # text='MNISTbinary',
    showarrow=False,
    arrowhead=0,  # Ensuring no arrow head
    ax=0,         # Ensuring no arrow x-position
    ay=0          # Ensuring no arrow y-position
)

# Adding text annotation for each bar
fig.add_annotation(
    x=5.2,         # X-coordinate on the plot (adjust as necessary)
    y=0.6,      # Y-coordinate on the plot (adjust as necessary)
    # text='adult',
    text='<span style="text-decoration: underline;">Adult</span>',  # Simulated underline with HTML
    showarrow=False,
    arrowhead=0,  # Ensuring no arrow head
    ax=0,         # Ensuring no arrow x-position
    ay=0          # Ensuring no arrow y-position
)


# Adding text annotation for each bar
fig.add_annotation(
    x=11.25,         # X-coordinate on the plot (adjust as necessary)
    y=0.6,      # Y-coordinate on the plot (adjust as necessary)
    text='<span style="text-decoration: underline;">MagicTelescope</span>',  # Simulated underline with HTML
    # text='MagicTelescope',
    showarrow=False,
    arrowhead=0,  # Ensuring no arrow head
    ax=0,         # Ensuring no arrow x-position
    ay=0          # Ensuring no arrow y-position
)


# datasets = [ 'case1-MNISTbinary', 'case1b-MNISTbinary', 'case1-adult', 'case1b-adult', 'case1-MagicTelescope', 'case1b-MagicTelescope' ]
# print(datasets)
for i, dataset in enumerate(x_ticktext[:-1]):
    if i in [4,9]:
        fig.add_vline(x=i + 0.5, line_width=2, line_color="black")

fig.update_layout(

    width=440,
    height=180,
    plot_bgcolor='white',  # Sets the plot background color
    paper_bgcolor='white',  # Sets the overall figure background color
    margin=dict(l=0, r=5, t=0, b=0),  # Small margins

    barmode='group',
    # title='Mean and Standard Deviation of Metrics Across Different Datasets',
    yaxis_title='Value',
    legend_title='Metrics',

    # xaxis=dict( 
    #     tickvals= current_ticks,
    #     ticktext = ['tn', 'fp', 'fn', 'tp',]*4,
    #     title_standoff=10, ),
    
    xaxis=dict(
        tickvals=x_ticks,
        ticktext=x_ticktext,
    ),

    showlegend = False
)

fig.show()

In [3]:
def action_counter(history_actions, history_outcomes, n_folds):

    total = []
    for s in range(len(history_actions)):

        final_history = []
        for i,j in zip(history_actions[s], history_outcomes[s]):
            if i == 0:
                final_history.append(0)
            elif i == 1 and j == 0:
                final_history.append(1)
            elif i == 1 and j == 1:
                final_history.append(2)
            elif i == 2 and j == 0:
                final_history.append(3)
            elif i == 2 and j == 1:
                final_history.append(4)



        # print(final_history)
        counter = collections.Counter(final_history)
        elements_to_count = [0, 1, 4, 2, 3]

        # Counting specific elements
        counts = {element: counter[element] for element in elements_to_count}

        array_counter = np.array( list(counts.values()) )
        # print(array_counter)
        total.append(array_counter)


    total = np.array(total)
    # print(total.shape)
    mean = np.mean( total, 0 )
    std =  np.std( total, 0 )
    return mean, std

n_folds = 25
horizon = 9999
model = 'MLP'

material = {
    'EEneuralcbpside_v6': {'color': [255, 0, 0], 'label': 'Neural-CBP'},  # Red
}


fig = go.Figure( )

data_cases = {}
meanaction_cases = {}
stdaction_cases = {}
for case in ['case1', 'case1b']:

    n_folds = 25
    datasets = ['MNISTbinary', 'MagicTelescope', 'adult', ]

    data_regrets = {} 
    data_meanactions = {} 
    data_stdactions = {} 
    for data in datasets: 
            
        final_regrets = {}
        final_meanactions = {}
        final_stdactions = {}

        for agent_name in material.keys():

            color, l_label = material[agent_name]['color'], material[agent_name]['label']

            r,g,b = color

            direct = './results/'
            path = os.path.join(direct, '{}_{}_{}_{}_{}_{}.pkl.gz'.format(case, model, data, horizon,n_folds,agent_name) )
            print(path)
            result, history_actions, history_outcomes, perf, cm = loadall_results2(path, n_folds)
            result = np.array(result)
            result = result.astype(np.float32)
            
            mean, std = action_counter(history_actions,history_outcomes, n_folds) 
            print('hey')
            final_regrets[l_label] = result[:,-1]
            final_meanactions[l_label] = mean
            final_stdactions[l_label] = 2.576 * std / np.sqrt(n_folds)

        if data == 'MNISTbinary':
            data = 'MNISTbinary'
        if data == 'FASHION':
            data = 'Fashion'
            
        data_regrets[data] = final_regrets
        data_meanactions[data] = final_meanactions
        data_stdactions[data] = final_stdactions
        
    data_cases[case] = data_regrets
    meanaction_cases[case] = data_meanactions
    stdaction_cases[case] = data_stdactions
    

./results/case1_MLP_MNISTbinary_9999_25_EEneuralcbpside_v6.pkl.gz
hey
./results/case1_MLP_MagicTelescope_9999_25_EEneuralcbpside_v6.pkl.gz
hey
./results/case1_MLP_adult_9999_25_EEneuralcbpside_v6.pkl.gz
hey
./results/case1b_MLP_MNISTbinary_9999_25_EEneuralcbpside_v6.pkl.gz
hey
./results/case1b_MLP_MagicTelescope_9999_25_EEneuralcbpside_v6.pkl.gz
hey
./results/case1b_MLP_adult_9999_25_EEneuralcbpside_v6.pkl.gz
hey


In [185]:
import plotly.graph_objects as go

# Create a figure
fig = go.Figure()

# Define the order of categories
categories_order = ['MNISTbinary', 'adult', 'MagicTelescope']


action = [
    {'name': 'Ask label', 'color': '#aacaff'},
    {'name': 'TP', 'color': '#ffc1cc'},
    {'name': 'TN', 'color': '#bf94e4'},
    {'name': 'FP', 'color': '#6495ed'},
    {'name': 'FN', 'color': '#c6e2ff'}
]


# Add traces in the specified order
for category in categories_order:
    for case in ['case1', 'case1b']:
        values = meanaction_cases[case][category]['Neural-CBP']
        for i, value in enumerate(values):
            if case == 'case1' and category =='MNISTbinary':
                show=True
                # print(values)
            else:
                show = False

            fig.add_trace(go.Bar(
                x=[f'{case}-{category}'],
                y=[value],
                name=action[i]['name'],
                offsetgroup=i,
                marker=dict(color=action[i]['color']),
                error_y=dict(type='data', array=[ stdaction_cases[case][category]['Neural-CBP'][i] ], visible=True,thickness=1 ),
                showlegend = show            ))

datasets = [ 'case1-MNISTbinary', 'case1b-MNISTbinary', 'case1-adult', 'case1b-adult', 'case1-MagicTelescope', 'case1b-MagicTelescope' ]
print(datasets)
for i, dataset in enumerate(datasets[:-1]):
    if i in [1,3]:
        fig.add_vline(x=i + 0.5, line_width=2, line_color="black")


fig.add_annotation(
    x=0.2,       # X-coordinate on the plot (adjust as necessary)
    y=6500,      # Y-coordinate on the plot (adjust as necessary)
    text='<span style="text-decoration: underline;">MNISTbinary</span>',  # Simulated underline with HTML
    # text='MNISTbinary',
    showarrow=False,
    arrowhead=0,  # Ensuring no arrow head
    ax=0,         # Ensuring no arrow x-position
    ay=0          # Ensuring no arrow y-position
)

# Adding text annotation for each bar
fig.add_annotation(
    x=1.9,         # X-coordinate on the plot (adjust as necessary)
    y=6500,      # Y-coordinate on the plot (adjust as necessary)
    # text='adult',
    text='<span style="text-decoration: underline;">Adult</span>',  # Simulated underline with HTML
    showarrow=False,
    arrowhead=0,  # Ensuring no arrow head
    ax=0,         # Ensuring no arrow x-position
    ay=0          # Ensuring no arrow y-position
)


# Adding text annotation for each bar
fig.add_annotation(
    x=4.25,         # X-coordinate on the plot (adjust as necessary)
    y=6500,      # Y-coordinate on the plot (adjust as necessary)
    text='<span style="text-decoration: underline;">MagicTelescope</span>',  # Simulated underline with HTML
    # text='MagicTelescope',
    showarrow=False,
    arrowhead=0,  # Ensuring no arrow head
    ax=0,         # Ensuring no arrow x-position
    ay=0          # Ensuring no arrow y-position
)


siz = 12
# Update layout
fig.update_layout(
    width=440,
    height=180,
    plot_bgcolor='white',  # Sets the plot background color
    paper_bgcolor='white',  # Sets the overall figure background color
    margin=dict(l=0, r=0, t=0, b=0),  # Small margins
    barmode='group',
    xaxis_title='Case and dataset',
    yaxis_title='Action count',
    # yaxis_range=[0,4],
    showlegend=True,

    xaxis=dict( 
        tickvals= datasets,
        ticktext = ['unit', 'cost', 'unit', 'cost', 'unit', 'cost'],
        tickfont=dict(size=siz-2), 
        title_standoff=0,
        title_font=dict(size=siz),   ),
        
    yaxis=dict(
        # type="log",
        gridcolor='lightgrey',
        title_standoff=5,
        title_font=dict(size=siz),
        tickfont=dict(size=siz)  ),

    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=-0.37,  # Adjust this value to position the legend
        xanchor="center",
        x=0.5,
        font=dict(size=siz)  # Increase legend font size
    ),
)

fig.show()
# fig.write_image("./figures/cost_action_count.pdf" )

# fig.write_image("./figures/cost_action_count.pdf" )

['case1-MNISTbinary', 'case1b-MNISTbinary', 'case1-adult', 'case1b-adult', 'case1-MagicTelescope', 'case1b-MagicTelescope']


In [179]:
import plotly.graph_objects as go

# Create a figure
fig = go.Figure()

# Define the order of categories
categories_order = ['MNISTbinary', 'adult', 'MagicTelescope']
actions = ['Ask label', 'TP', 'TN', 'FP', 'FN']

# Add traces in the specified order
for category in categories_order:
    for i, action in enumerate(actions):
        for case in ['case1', 'case1b']:
            value = meanaction_cases[case][category]['Neural-CBP'][i]
            show_legend = (category == 'MNISTbinary' and action == 'Ask label' and case in ['case1', 'case1b'] )
            color = '#6495ed' if case == 'case1' else '#c6e2ff'
            fig.add_trace(go.Bar(
                x=[f'{category}-{action}'],
                y=[value],
                name=case,
                offsetgroup=case,
                marker=dict(color=color),
                error_y=dict(type='data', array=[stdaction_cases[case][category]['Neural-CBP'][i] ], visible=True, thickness=1),
                showlegend=show_legend
            ))

# New datasets list for the x-axis
datasets = []
for category in categories_order:
    for action in actions:
        datasets.append(f'{category}-{action}')

# Vertical lines to separate datasets
for i, dataset in enumerate(datasets):
    if dataset.split('-')[1] in ['FN']:  # Adjust the condition based on where you want the lines
        fig.add_vline(x=i + 0.5, line_width=2, line_color="black")

# Add annotations for each dataset (adjust positions as necessary)
annotations = ['MNISTbinary', 'adult', 'MagicTelescope']
for i, annotation in enumerate(annotations):
    fig.add_annotation(
        x=i * 5 + 2.2,  # Adjust x position for each dataset
        y=6500,  # Y-coordinate on the plot
        text=f'<span style="text-decoration: underline;">{annotation}</span>',
        showarrow=False,
        arrowhead=0,
        ax=0,
        ay=0
    )

# Update layout
siz = 12
fig.update_layout(
    width=440,
    height=180,
    plot_bgcolor='white',
    paper_bgcolor='white',
    margin=dict(l=0, r=0, t=0, b=0),
    barmode='group',
    xaxis_title='Dataset and action',
    yaxis_title='Action count',
    showlegend=True,
    xaxis=dict( 
        tickvals=datasets,
        ticktext=[ 'Expert', 'TP', 'TN', 'FP', 'FN'] * 3,  # Update ticktext with new dataset-action pairs
        tickfont=dict(size=siz-2),
        title_standoff=0,
        title_font=dict(size=siz),
    ),
    yaxis=dict(
        gridcolor='lightgrey',
        title_standoff=5,
        title_font=dict(size=siz),
        tickfont=dict(size=siz)
    ),
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=-0.37,
        xanchor="center",
        x=0.5,
        font=dict(size=siz)
    ),
)

fig.show()

In [4]:
def action_counter(history_actions, history_outcomes, n_folds):

    total = []
    for s in range(len(history_actions)):

        final_history = []
        for i,j in zip(history_actions[s], history_outcomes[s]):
            if i == 0:
                final_history.append(0)
            elif i == 1 and j == 0:
                final_history.append(1)
            elif i == 1 and j == 1:
                final_history.append(2)
            elif i == 2 and j == 0:
                final_history.append(3)
            elif i == 2 and j == 1:
                final_history.append(4)



        # print(final_history)
        counter = collections.Counter(final_history)
        elements_to_count = [0, 1, 4, 2, 3]

        # Counting specific elements
        counts = {element: counter[element] for element in elements_to_count}

        array_counter = np.array( list(counts.values()) )
        # print(array_counter)
        total.append(array_counter)


    total = np.array(total)
    # print(total.shape)
    mean = np.mean( total, 0 )
    std =  np.std( total, 0 )
    return mean, std, total

n_folds = 25
horizon = 9999
model = 'MLP'

material = {
    'EEneuralcbpside_v6': {'color': [255, 0, 0], 'label': 'Neural-CBP'},  # Red
}


fig = go.Figure( )

total_cases = {}
meanaction_cases = {}
stdaction_cases = {}
for case in ['case1', 'case1b']:

    n_folds = 25
    datasets = ['MNISTbinary', 'MagicTelescope', 'adult', ]

    data_total = {} 
    data_meanactions = {} 
    data_stdactions = {}

    for data in datasets: 
        
        final_total = {}
        final_meanactions = {}
        final_stdactions = {}

        for agent_name in material.keys():

            color, l_label = material[agent_name]['color'], material[agent_name]['label']

            r,g,b = color

            direct = './results/'
            path = os.path.join(direct, '{}_{}_{}_{}_{}_{}.pkl.gz'.format(case, model, data, horizon,n_folds,agent_name) )
            print(path)
            result, history_actions, history_outcomes, perf, cm = loadall_results2(path, n_folds)
            result = np.array(result)
            result = result.astype(np.float32)
            
            mean, std, total = action_counter(history_actions,history_outcomes, n_folds) 
            print('hey')

            final_meanactions[l_label] = mean
            final_stdactions[l_label] = 2.576 * std / np.sqrt(n_folds)
            final_total = total

        if data == 'MNISTbinary':
            data = 'MNISTbinary'
        if data == 'FASHION':
            data = 'Fashion'
            

        data_total[data] = final_total
        data_meanactions[data] = final_meanactions
        data_stdactions[data] = final_stdactions
        
    total_cases[case] = data_total
    meanaction_cases[case] = data_meanactions
    stdaction_cases[case] = data_stdactions
    

./results/case1_MLP_MNISTbinary_9999_25_EEneuralcbpside_v6.pkl.gz
hey
./results/case1_MLP_MagicTelescope_9999_25_EEneuralcbpside_v6.pkl.gz
hey
./results/case1_MLP_adult_9999_25_EEneuralcbpside_v6.pkl.gz
hey
./results/case1b_MLP_MNISTbinary_9999_25_EEneuralcbpside_v6.pkl.gz
hey
./results/case1b_MLP_MagicTelescope_9999_25_EEneuralcbpside_v6.pkl.gz
hey
./results/case1b_MLP_adult_9999_25_EEneuralcbpside_v6.pkl.gz
hey


In [90]:
import plotly.graph_objects as go
import numpy as np

# Assuming 'total_cases' is your dictionary containing the datasets

fig = go.Figure()

# Define the order of datasets and initialize a counter for x-axis positioning
datasets = [ 'MNISTbinary' ] #'MagicTelescope', 'adult' 'MNISTbinary',
x_position = 0  # Counter for x-axis position

# Iterate through each dataset, then each column, and finally each case
names = []
for dataset_name in datasets:
    for i in range(5):  # Assuming each dataset has 5 columns
        for case_name in ['case1', 'case1b']:
            case_data = total_cases[case_name][dataset_name]
            name = f"{case_name}_{dataset_name}_Col{i+1}"
            names.append(name)
            bool = True if name in ['case1_MNISTbinary_Col1', 'case1b_MNISTbinary_Col1'] else False
            col = '#6495ed' if case_name == 'case1' else '#d891ef'
            legend_name = 'unit cost case' if case_name=='case1' else 'varied cost case'
            fig.add_trace(go.Box(
                y=case_data[:, i],
                name=legend_name,
                marker_color=col,
                line=dict(width=2),
                # marker=dict(color=col, line=dict(color='black', width=2)), 
                # line=dict(color='black', width=2), 
                boxpoints=False,  # Hide outliers
                boxmean=True,     # Show mean
                width=0.7,        # Adjust the width of the boxes
                x=[x_position] * len(case_data[:, i]),  # Set x-axis positions
                showlegend=bool
            ))
            x_position += 1

        # Introduce space after each column
        x_position += 1

# Adjust x-axis range to account for the extra space introduced
fig.update_xaxes(range=[-0.5, x_position - 0.5])


siz = 11
# Calculate midpoints for tick labels
tickvals = [x + 0.5 for x in range(0, x_position - 1, 2)]

siz = 13

def format_tick(val):
    """ Custom function to format the tick labels. """
    return f'{int(val / 1000)}k' if val >= 1000 else str(val)

# fig.add_vline(x=2, line_width=1, line_color="lightgrey")
# fig.add_vline(x=5, line_width=1, line_color="lightgrey")
# fig.add_vline(x=8, line_width=1, line_color="lightgrey")
# fig.add_vline(x=11, line_width=1, line_color="lightgrey")

# Update layout and display the plot
fig.update_layout(
    width=220,  # Adjusted for two subplots
    height=180,
    plot_bgcolor='white',
    paper_bgcolor='white',
    margin=dict(l=0, r=0, t=0, b=0),
    # title="Box Plot of Datasets by Cases and Columns",
    xaxis_title="Action",
    yaxis_title="Count",
    boxmode='group',
    showlegend=True,

    xaxis=dict(
        tickvals=[0.5, 3.5, 6.5, 9.5, 12.5, ],
        ticktext=['Expl', 'TP', 'TN', 'FP', 'FN'] ,  # Adjust ticktext as needed
        tickfont=dict(size=siz-2),
        title_standoff=5,
        title_font=dict(size=siz),
    ),

        yaxis=dict(
        # type="log",
        # gridcolor='lightgrey',
        title="Count",
        title_standoff=5,
        title_font=dict(size=siz),
        tickfont=dict(size=siz-3) , # Increase Y-axis tick font size
        tickmode='array',
        tickvals=[1000, 2000, 3000, 4000, 5000, 7000, 9000 ],  # Specify the values where you want ticks
        ticktext=[format_tick(val) for val in [1000, 2000, 3000, 4000, 5000, 7000, 9000]]  # Format those values
    ),


    legend=dict(
        orientation="v",
        yanchor="middle",
        y=0.5,  # Adjust this value to position the legend
        xanchor="center",
        x=0.5,
        bgcolor='rgba(0,0,0,0)',
        font=dict(size=siz+3),  # Increase legend font size
        title=dict(
            text='<b>Legend</b>',  # Set the text for your legend title
            font=dict(size=siz+3)  # Optionally set the font size for the legend title
        )
    )

)
fig.show()

# fig.write_image("./figures/test.pdf" )
# fig.write_image("./figures/cost_action_count_{}2.pdf".format(dataset_name) )
fig.write_image("./figures/cost_action_legend.png".format(dataset_name), scale=10 )