In [1]:
import os
import numpy as np
import plotly.graph_objects as go

import gzip
import pickle as pkl

import collections

def loadall_results2(path, n_folds):
    regret = []
    history_actions = []
    history_outcomes = []
    perf = []
    with gzip.open(  path ,'rb') as f:
        for i in range(n_folds+1):
            try:
                data = pkl.load(f)
            except EOFError:
                break

            if type(data) == dict:
                regret.append( data['regret'] ) 
                history_actions.append( data['action_history'] )
                history_outcomes.append( data['outcome_history'] ) 
                perf.append( data['pred'] ) 

    
    return regret, history_actions, history_outcomes, perf

In [24]:
n_folds = 25
horizon = 9999
context = 'MNISTbinary'
model = 'MLP'
case = 'case1'
agent_name = 'EEneuralcbpside_v6' #ineural6, neuronal6

direct = './results/'
path = os.path.join(direct, '{}_{}_{}_{}_{}_{}.pkl.gz'.format(case,model,context,horizon,n_folds,agent_name) )
regret, action_history, outcome_history, perf = loadall_results2(path, n_folds)


In [2]:
def action_counter(action_history, n_folds):
    total = []
    for i in range(n_folds):
        counter = collections.Counter(action_history[i])
        array_counter = np.array( list(counter.values()) )
        total.append(array_counter)
    return np.mean(total,0), np.std(total,0)

n_folds = 25
horizon = 9999
model = 'MLP'

material = {
    #'EEneuralcbpside_v5': {'color': [255, 255, 0], 'label': 'EEneuralcbpside_v5'},  # Red
    'EEneuralcbpside_v6': {'color': [255, 0, 0], 'label': 'Neural-CBP'},  # Red
    # 'ineural6': {'color': [51, 255, 255], 'label':'IneurAL (official)'},                    # Yellow
    # 'ineural3': {'color': [0, 0, 255], 'label':'IneurAL (tuned)'},                    # Cyan
    # 'neuronal6': {'color': [255, 0, 255], 'label':'Neuronal (official)'},                  # Magenta
    # 'neuronal3': {'color': [160, 160, 160], 'label':'Neuronal (tuned)'},                   # Orange
    # 'margin': {'color': [160, 160, 160], 'label':'Margin'},
    # 'cesa': {'color': [0, 0, 255], 'label':'Cesa'},
}


fig = go.Figure( )

data_cases = {}
meanaction_cases = {}
stdaction_cases = {}
for case in ['case1', 'case1b']:

    n_folds = 25
    datasets = ['MNISTbinary', 'MagicTelescope', 'adult', ]

    data_regrets = {} 
    data_meanactions = {} 
    data_stdactions = {} 
    for data in datasets: 
            
        final_regrets = {}
        final_meanactions = {}
        final_stdactions = {}
        for agent_name in material.keys():

            color, l_label = material[agent_name]['color'], material[agent_name]['label']

            r,g,b = color

            direct = './results/'
            path = os.path.join(direct, '{}_{}_{}_{}_{}_{}.pkl.gz'.format(case, model, data, horizon,n_folds,agent_name) )
            print(path)
            result, history_actions, history_outcomes, perf = loadall_results2(path, n_folds)
            result = np.array(result)
            result = result.astype(np.float32)
            mean, std = action_counter(history_actions, n_folds) 
            
            final_regrets[l_label] = result[:,-1]
            final_meanactions[l_label] = mean
            final_stdactions[l_label] = 2.576 * std / n_folds

        if data == 'MNISTbinary':
            data = 'MNISTbinary'
        if data == 'FASHION':
            data = 'Fashion'
            
        data_regrets[data] = final_regrets
        data_meanactions[data] = final_meanactions
        data_stdactions[data] = final_stdactions
        
    data_cases[case] = data_regrets
    meanaction_cases[case] = data_meanactions
    stdaction_cases[case] = data_stdactions
    


# fig.show()
# fig.write_image("./figures/case1_{}_{}.pdf".format(model, context) )

./results/case1_MLP_MNISTbinary_9999_25_EEneuralcbpside_v6.pkl.gz
./results/case1_MLP_MagicTelescope_9999_25_EEneuralcbpside_v6.pkl.gz
./results/case1_MLP_adult_9999_25_EEneuralcbpside_v6.pkl.gz
./results/case1b_MLP_MNISTbinary_9999_25_EEneuralcbpside_v6.pkl.gz
./results/case1b_MLP_MagicTelescope_9999_25_EEneuralcbpside_v6.pkl.gz
./results/case1b_MLP_adult_9999_25_EEneuralcbpside_v6.pkl.gz


In [125]:
stdaction_cases

{'case1': {'MNISTbinary': {'Neural-CBP': array([10.04359248, 21.55888602, 20.90077949])},
  'MagicTelescope': {'Neural-CBP': array([ 27.98088156, 310.19202685, 322.39539062])},
  'adult': {'Neural-CBP': array([  7.72064466, 478.81092516, 480.26885831])}},
 'case1b': {'MNISTbinary': {'Neural-CBP': array([12.2280419 , 19.50291037, 20.38600858])},
  'MagicTelescope': {'Neural-CBP': array([ 48.73379282, 273.87212673, 301.56533424])},
  'adult': {'Neural-CBP': array([ 16.55349706, 441.90683869, 449.162998  ])}}}

In [6]:
import plotly.graph_objects as go

# Create a figure
fig = go.Figure()

# Define the order of categories
categories_order = ['MNISTbinary', 'adult', 'MagicTelescope']


action = [
    {'name': 'Ask label', 'color': '#aacaff'},
    {'name': 'Class 1', 'color': '#008080'},
    {'name': 'Class 2', 'color': '#b0e0e6'}
]


# Add traces in the specified order
for category in categories_order:
    for case in ['case1', 'case1b']:
        values = meanaction_cases[case][category]['Neural-CBP']
        for i, value in enumerate(values):
            if case == 'case1' and category =='MNISTbinary':
                show=True
            else:
                show = False

            fig.add_trace(go.Bar(
                x=[f'{case}-{category}'],
                y=[value],
                name=action[i]['name'],
                offsetgroup=i,
                marker=dict(color=action[i]['color']),
                error_y=dict(type='data', array=stdaction_cases[case][category]['Neural-CBP'], visible=True,thickness=1 ),
                showlegend = show
                # base=i * 0.3  # This creates the side-by-side bars within each group
            ))

datasets = [ 'case1-MNISTbinary', 'case1b-MNISTbinary', 'case1-adult', 'case1b-adult', 'case1-MagicTelescope', 'case1b-MagicTelescope' ]
print(datasets)
for i, dataset in enumerate(datasets[:-1]):
    if i in [1,3]:
        fig.add_vline(x=i + 0.5, line_width=2, line_color="black")


fig.add_annotation(
    x=0.2,       # X-coordinate on the plot (adjust as necessary)
    y=6500,      # Y-coordinate on the plot (adjust as necessary)
    text='<span style="text-decoration: underline;">MNISTbinary</span>',  # Simulated underline with HTML
    # text='MNISTbinary',
    showarrow=False,
    arrowhead=0,  # Ensuring no arrow head
    ax=0,         # Ensuring no arrow x-position
    ay=0          # Ensuring no arrow y-position
)

# Adding text annotation for each bar
fig.add_annotation(
    x=1.9,         # X-coordinate on the plot (adjust as necessary)
    y=6500,      # Y-coordinate on the plot (adjust as necessary)
    # text='adult',
    text='<span style="text-decoration: underline;">Adult</span>',  # Simulated underline with HTML
    showarrow=False,
    arrowhead=0,  # Ensuring no arrow head
    ax=0,         # Ensuring no arrow x-position
    ay=0          # Ensuring no arrow y-position
)


# Adding text annotation for each bar
fig.add_annotation(
    x=4.25,         # X-coordinate on the plot (adjust as necessary)
    y=6500,      # Y-coordinate on the plot (adjust as necessary)
    text='<span style="text-decoration: underline;">MagicTelescope</span>',  # Simulated underline with HTML
    # text='MagicTelescope',
    showarrow=False,
    arrowhead=0,  # Ensuring no arrow head
    ax=0,         # Ensuring no arrow x-position
    ay=0          # Ensuring no arrow y-position
)


siz = 12
# Update layout
fig.update_layout(
    width=440,
    height=180,
    plot_bgcolor='white',  # Sets the plot background color
    paper_bgcolor='white',  # Sets the overall figure background color
    margin=dict(l=0, r=0, t=0, b=0),  # Small margins
    barmode='group',
    xaxis_title='Case and dataset',
    yaxis_title='Action count',
    # yaxis_range=[0,4],
    showlegend=True,

    xaxis=dict( 
        tickvals= datasets,
        ticktext = ['unit', 'cost', 'unit', 'cost', 'unit', 'cost'],
        tickfont=dict(size=siz-2), 
        title_standoff=0,
        title_font=dict(size=siz),   ),
        
    yaxis=dict(
        # type="log",
        gridcolor='lightgrey',
        title_standoff=5,
        title_font=dict(size=siz),
        tickfont=dict(size=siz)  ),

    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=-0.37,  # Adjust this value to position the legend
        xanchor="center",
        x=0.5,
        font=dict(size=siz)  # Increase legend font size
    ),
)

fig.show()
fig.write_image("./figures/cost_action_count.pdf" )
fig.write_image("./figures/cost_action_count.pdf" )

['case1-MNISTbinary', 'case1b-MNISTbinary', 'case1-adult', 'case1b-adult', 'case1-MagicTelescope', 'case1b-MagicTelescope']
