# In [21]:
import glob
import os

import pandas as pd
import altair as alt

alt.data_transformers.disable_max_rows()

# Load the CSV files produced by the abstract-stimuli simulations.
# Earlier result locations: 'results/abstract/*' and 'csv/abstract/*'.
abstract_results = glob.glob('csv_test/abstract/*.csv')
results = sorted(abstract_results)
if not results:
    # Without this check pd.concat([]) fails with an opaque
    # "No objects to concatenate" error.
    raise FileNotFoundError('no CSV files matched csv_test/abstract/*.csv')

# Read each CSV into its own DataFrame and tag it with the max-epochs setting,
# taken from the last '_'-separated token of the file's base name
# (e.g. 'run_mlp_64.csv' -> '64'). Only the base name is parsed, so underscores
# in directory names ('csv_test') cannot leak into the value.
results_lst = []
for filename in results:
    df = pd.read_csv(filename, index_col=0, header=0)
    stem = os.path.splitext(os.path.basename(filename))[0]
    # Kept as a string so it matches the string options of the epochs radio buttons.
    max_epochs = stem.rsplit('_', 1)[-1]
    df['Max Epochs'] = max_epochs
    results_lst.append(df)

# Stack all runs into one long table with a fresh 0..n-1 index.
abstract_results_df = pd.concat(results_lst, axis=0, ignore_index=True)
# Dump the combined table for inspection; writing `df` here would save only
# the last file that was read.
abstract_results_df.to_csv('csv_test/hm.csv')

# Base chart: plots 'Probability Correct' (y axis fixed to [0, 1]) against
# 'Epoch', drawing one coloured line per 'Type' value with a filled point at
# every observation.
# Why `point=` instead of chaining .mark_point().mark_line(): each mark_* call
# replaces the chart's mark, so chaining keeps only the last one. The point
# overlay on a line mark draws both lines and markers.
base = alt.Chart(abstract_results_df).mark_line(
    point=alt.OverlayMarkDef(filled=True)
).encode(
    x=alt.X('Epoch'),
    y=alt.Y('Probability Correct', scale=alt.Scale(domain=(0, 1))),
    color='Type:N',
)

# Interactive widgets. Each helper returns a (binding, selection) pair: the
# binding is the HTML widget, and the selection is a single-value selection on
# one data column that the widget drives.


def _radio_selection(field, options, name):
    """Return a radio-button binding over *options* and the selection it drives on *field*."""
    widget = alt.binding_radio(options=options)
    return widget, alt.selection_single(fields=[field], bind=widget, name=name)


def _slider_selection(field, lo, hi, step, name, label):
    """Return a range-slider binding labelled *label* and the selection it drives on *field*."""
    widget = alt.binding_range(min=lo, max=hi, step=step, name=label)
    return widget, alt.selection_single(fields=[field], bind=widget, name=name)


# --- Categorical settings: radio buttons -----------------------------------
# (the model options were previously ['resnet18', 'resnet152'])
models = ['mlp', 'alcove']
model_dropdown, model_select = _radio_selection('Model', models, 'model')

nets = ['resnet18', 'resnet152', 'vgg11']
net_dropdown, net_select = _radio_selection('Net', nets, 'net')

types = [1, 2, 3, 4, 5, 6]
type_dropdown, type_select = _radio_selection('Type', types, 'type')

losses = ['ll', 'hinge', 'mse', 'humble']
loss_dropdown, loss_select = _radio_selection('Loss Type', losses, 'loss')

# Strings, not ints: the 'Max Epochs' column holds substrings of file names.
epochs = ['16', '32', '64', '128']
epochs_dropdown, epochs_select = _radio_selection('Max Epochs', epochs, 'epochs')

# --- Numeric hyperparameters: range sliders --------------------------------
# min/max/step should line up with the grid of values in the loaded results;
# the ranges used by the previous sweep are noted on each line.
lr_assoc_slider, lr_assoc_select = _slider_selection(
    'LR-Association', lo=0.03, hi=0.035, step=0.005,
    name='lr_assoc', label='lr_assoc_s',
)  # previously 0.005 .. 0.015, step 0.005

lr_attn_slider, lr_attn_select = _slider_selection(
    'LR-Attention', lo=0.0033, hi=0.0043, step=0.001,
    name='lr_attn', label='lr_attn_s',
)  # previously 0.001 .. 0.003, step 0.001

c_slider, c_select = _slider_selection(
    'c', lo=6.5, hi=7, step=0.5,
    name='c', label='c_s',
)  # previously 0.5 .. 1.5, step 0.5

phi_slider, phi_select = _slider_selection(
    'phi', lo=2.5, hi=2.75, step=0.25,
    name='phi', label='phi_s',
)  # previously 0.25 .. 0.75, step 0.25

# Every selection listed here is both registered on the chart (so its widget is
# rendered) and applied as a data filter, in this order. To expose another
# widget, append its *_select object to the tuple.
# NOTE(review): net_select and type_select are defined but not applied here
# ('Type' is already encoded as colour) — confirm 'Net' should stay unfiltered.
active_selections = (
    loss_select,
    epochs_select,
    model_select,
    lr_assoc_select,
    lr_attn_select,
    c_select,
    phi_select,
)

interactive_plot = base.add_selection(*active_selections)
for sel in active_selections:
    interactive_plot = interactive_plot.transform_filter(sel)
interactive_plot = interactive_plot.properties(title='Abstract Stimuli')

# Render the chart: the bare expression on the cell's last line is displayed.
# A selection with no value chosen yet lets every row through (Altair's
# default empty='all'), so all runs are overlaid until each widget is set.
# TODO: pass init=... to the selection_single calls to choose a default view.
interactive_plot