In [2]:
import io
import warnings

import dill
import ipywidgets as widgets
import matplotlib.pyplot as plt
import pandas as pd
from aepsych.acquisition.monotonic_rejection import MonotonicMCLSE
from aepsych.plotting import plot_strat
from aepsych.server import AEPsychServer
from IPython.display import FileLink, clear_output, display

# Notebook-wide display settings: 10x10 figures and all warnings silenced.
plt.rcParams["figure.figsize"] = (10, 10)
warnings.filterwarnings("ignore")


# Shared AEPsych server used by every callback below; resume_server() rebinds
# this name to an unpickled server.
server = AEPsychServer()

# Decimal places used for parameter step sizes and for rounding suggestions.
par_precision = 3
# Number of parameter rows added so far (maintained by add_param / rem_param).
dim = 0
# Inducing points per parameter dimension (multiplied by the dimension in make_config).
inducing_scale = 50
# Strategy radio-button label -> AEPsych acquisition function name.
acq_dict = dict(
    [
        ("Exploration", "MonotonicMCPosteriorVariance"),
        ("Optimization", "qNoisyExpectedImprovement"),
        ("Threshold Finding", "MonotonicMCLSE"),
    ]
)
# Shared widget style passed to most widgets below.
style = dict(description_width="initial")
# Output files: the told data (CSV) and the pickled server used for resuming.
csv_file_name = "aepsych_data.csv"
strat_file_name = "aepsych_server.pkl"


def add_param(b):
    """Append a new parameter row to ``params_boxes``.

    Used as the "Add Parameter" button callback (``b`` is the clicked button and
    is unused) and called once with ``None`` at startup.  Each row is an HBox of
    [name Text, lower-bound FloatText, upper-bound FloatText, Monotonic Checkbox];
    ``make_config`` reads the rows back by these positions.  The module-level
    ``dim`` counter is incremented and used for the default name ``par{dim}``.
    """
    global dim
    dim += 1
    step = 10 ** -par_precision
    row = widgets.HBox(
        [
            widgets.Text(f"par{dim}", description="Name", style=style),
            widgets.FloatText(0.0, description="Lower Bound:", step=step, style=style),
            widgets.FloatText(1.0, description="Upper Bound:", step=step, style=style),
            widgets.Checkbox(value=False, description="Monotonic"),
        ]
    )
    # Widget children are stored as a tuple; build a new one with the row appended.
    params_boxes.children = tuple(params_boxes.children) + (row,)


def rem_param(b):
    """Drop the last parameter row, always keeping at least one.

    "Remove Parameter" button callback; ``b`` (the clicked button) is unused.
    The module-level ``dim`` counter is decremented alongside the rows.
    """
    global dim
    if dim <= 1:
        return
    dim -= 1
    params_boxes.children = tuple(params_boxes.children[:-1])


def start_server(b):
    """Configure the shared server from the setup form and show the ask/tell UI.

    "Start AEPsych" button callback; ``b`` (the clicked button) is unused.
    Stores the outcome labels on the server, clears previous data/plot output,
    builds one bounded input per parameter, redraws the cell output and
    requests the first suggestion.
    """
    server.configure(config_str=make_config())
    server.one_outcome = one_outcome.value
    server.zero_outcome = zero_outcome.value

    # Wipe any data table / plot left over from a previous session.
    for pane in (data_output, plot_output):
        with pane:
            clear_output()

    step = 10 ** -par_precision
    bounds = zip(server.parnames, server.strat.lb, server.strat.ub)
    tell_boxes.children = [
        widgets.BoundedFloatText(
            low, description=pname, min=low, max=high, step=step, style=style
        )
        for pname, low, high in bounds
    ]
    no_label, yes_label = zero_outcome.value, one_outcome.value
    outcome_box.options = [("", None), (no_label, 0), (yes_label, 1)]

    clear_output()
    display(server_download, params_cont, plot_data_cont)
    get_next(None)


def resume_server(change):
    """Resume a session from a pickled server uploaded via ``server_uploader``.

    Observer for the uploader's ``_counter`` trait; ``change`` is unused.  The
    module-level ``server`` is replaced by the unpickled object.  The parameter
    boxes and outcome labels are rebuilt from it, and the UI (data table, plot,
    next suggestion) is redrawn.

    SECURITY: ``dill.load`` can execute arbitrary code contained in the file.
    Only resume sessions from ``.pkl`` files you trust (e.g. ones written by
    this notebook's ``write_server``).
    """
    global server
    step = 10 ** -par_precision
    for _name, upload in server_uploader.value.items():
        with io.BytesIO(upload["content"]) as buf:
            server = dill.load(buf)
        # Pickling the server drops these attributes; restore them as None so
        # later attribute access does not fail (acknowledged workaround).
        server.socket = None
        server.db = None

        tell_boxes.children = [
            widgets.BoundedFloatText(
                low, description=pname, min=low, max=high, step=step, style=style
            )
            for pname, low, high in zip(
                server.parnames, server.strat.lb, server.strat.ub
            )
        ]
        zero_outcome.value = server.zero_outcome
        one_outcome.value = server.one_outcome
        outcome_box.options = [
            ("", None),
            (zero_outcome.value, 0),
            (one_outcome.value, 1),
        ]

        clear_output()
        display(server_download, params_cont, plot_data_cont)
        display_data()
        display_plot()
        get_next(None)

    server_uploader.value.clear()


def make_config():
    """Build an AEPsych config string from the current state of the setup widgets.

    Reads the parameter rows in ``params_boxes`` (name, lower bound, upper bound,
    Monotonic checkbox), the threshold, the number of initialization trials and
    the selected strategy.  Optimization uses ``GPClassificationModel`` with
    ``OptimizeAcqfGenerator``; the other strategies use ``MonotonicRejectionGP``
    with ``MonotonicRejectionGenerator``.

    Returns:
        str: INI-style text passed to ``server.configure(config_str=...)``.
    """
    rows = params_boxes.children
    n_dims = len(rows)
    pars = [row.children[0].value for row in rows]
    parnames = f"[{','.join(pars)}]"
    lbs = [row.children[1].value for row in rows]
    ubs = [row.children[2].value for row in rows]
    # Indices of parameters whose "Monotonic" checkbox is ticked.  Must read the
    # checkbox's ``.value``: the Checkbox widget object itself is always truthy,
    # which would mark every parameter as monotonic.
    monotonic = [i for i, row in enumerate(rows) if row.children[3].value]
    target = threshold_box.value
    n_sobol = n_sobol_box.value
    acq = acq_dict[strategy_btns.value]
    is_optimization = acq == "qNoisyExpectedImprovement"
    model = "GPClassificationModel" if is_optimization else "MonotonicRejectionGP"
    generator = (
        "OptimizeAcqfGenerator" if is_optimization else "MonotonicRejectionGenerator"
    )

    # NOTE(review): the trailing "#TODO" on the inducing_size lines is part of
    # the config text; it is only harmless if AEPsych's config parser strips
    # inline '#' comments -- confirm.
    config = f"""
        [common]
        parnames = {parnames}
        lb = {lbs}
        ub = {ubs}
        outcome_type = single_probit
        target = {target}
        strategy_names = [init_strat, opt_strat]

        [init_strat]
        n_trials = {n_sobol}
        generator = SobolGenerator

        [opt_strat]
        n_trials = -1
        refit_every = 1
        generator = {generator}

        [experiment]
        acqf = {acq}
        model = {model}
        
        [SobolGenerator]
        n_points = {n_sobol}
        
        [GPClassificationModel]
        inducing_size = {inducing_scale*n_dims} #TODO: find a better way to scale this

        [MonotonicRejectionGP]
        inducing_size = {inducing_scale*n_dims} #TODO: find a better way to scale this
        mean_covar_factory = monotonic_mean_covar_factory
        monotonic_idxs = {monotonic}
        """
    return config


def tell_model(b):
    """Send the currently entered parameters and outcome to the server.

    "Update Model" button callback; ``b`` is unused.  If no outcome is selected
    it only prints a reminder.  Otherwise it tells the server, resets every
    parameter box to its minimum and the outcome to blank, then fetches the next
    suggestion and refreshes the data table and plot.
    """
    if outcome_box.value is None:
        with upload_output:
            clear_output()
            print("Select an outcome for this set of parameters!")
        return

    with upload_output:
        clear_output()
    trial = {box.description: box.value for box in tell_boxes.children}
    server.tell(outcome_box.value, trial)
    for box in tell_boxes.children:
        box.value = box.min
    outcome_box.value = None
    get_next(None)
    display_data()
    display_plot()


def get_next(b):
    """Request the next suggested parameters and load them into ``tell_boxes``.

    Used as the "Next Parameters" button callback (``b`` is the clicked button,
    unused) and called directly with ``None`` after start/resume/tell.  While the
    request runs, the ask/tell controls are disabled.  They are re-enabled
    afterwards even if ``server.ask()`` raises, so the UI is never left locked.
    On success the server is pickled via ``write_server()``.
    """
    # Everything the user must not touch while a suggestion is being computed.
    controls = [tell_btn, ask_btn, uploader, outcome_box, *tell_boxes.children]
    for widget in controls:
        widget.disabled = True

    try:
        # No data has been told yet but the initialization-trial count has
        # already been reached: drop to a single init trial and reconfigure
        # before asking.  NOTE(review): presumably this covers the case where
        # 0 initialization trials were requested -- confirm.
        if server.strat.x is None and server.strat._count >= n_sobol_box.value:
            n_sobol_box.value = 1
            server.configure(config_str=make_config())
        next_pars = server.ask()
        # NOTE(review): assumes next_pars iterates in the same order as
        # tell_boxes (built from server.parnames) -- confirm.
        for child, value in zip(tell_boxes.children, next_pars.values()):
            child.value = round(value[0], par_precision)
    finally:
        for widget in controls:
            widget.disabled = False
    write_server()


def write_server():
    """Pickle the shared ``server`` (with dill) to ``strat_file_name``.

    The file is what ``server_download`` links to and what ``resume_server``
    can load back.  NOTE(review): ``server_download`` is an IPython ``FileLink``,
    not a widget, so toggling ``.disabled`` probably has no visible effect --
    confirm.
    """
    server_download.disabled = True
    with open(strat_file_name, mode="wb") as handle:
        dill.dump(server, handle)
    server_download.disabled = False


def display_data():
    """Write the told data to ``csv_file_name`` and show it in ``data_output``.

    Does nothing until the strategy holds data (``server.strat.x`` is None).
    Columns are one per parameter name plus ``outcome``.  The table is shown
    together with a download link to the CSV.
    """
    x = server.strat.x
    if x is None:
        return
    columns = {name: x[:, idx] for idx, name in enumerate(server.parnames)}
    columns["outcome"] = server.strat.y
    frame = pd.DataFrame(columns)
    frame.to_csv(csv_file_name, index=False)
    with data_output:
        clear_output()
        display(FileLink(csv_file_name), frame)


def display_plot():
    """Redraw the posterior plot in ``plot_output`` (1-D or 2-D problems only).

    Until the model-based strategy is active (``server.strat._strat_idx > 0``),
    a placeholder message is printed instead.  The threshold line is passed to
    ``plot_strat`` only when the current generator's acquisition is
    ``MonotonicMCLSE``.
    """
    with plot_output:
        clear_output()
        strat = server.strat
        if strat.dim > 2:
            print("Plotting currently only works for <=2D")
            return
        if not strat._strat_idx > 0:
            print(
                "\n\n\n\n\n Initializing model. Collect more data to plot posterior."
            )
            return
        acqf = strat._strat.generator.acqf
        plot_strat(
            strat,
            xlabel=server.parnames[0],
            ylabel=server.parnames[1] if strat.dim == 2 else None,
            target_level=threshold_box.value if acqf == MonotonicMCLSE else None,
            yes_label=one_outcome.value,
            no_label=zero_outcome.value,
        )


def mass_tell(change):
    """Bulk-tell the server the observations in a CSV uploaded via ``uploader``.

    Observer for the uploader's ``_counter`` trait; ``change`` is unused.  The
    CSV needs an ``outcome`` column plus one column per name in
    ``server.parnames``; each row is passed to ``server.tell``.
    - If reading or telling fails, an error message (with the exception) is
      shown in ``upload_output``.
    - Otherwise the next suggestion, data table and plot are refreshed.
    Finally the upload is cleared and the server is pickled.
    """
    for _name, upload in uploader.value.items():
        with io.BytesIO(upload["content"]) as buf:
            # Only parsing/telling is guarded: failures in get_next or plotting
            # should surface as real errors, not be reported as bad data.
            # ``Exception`` (not a bare except) lets KeyboardInterrupt through.
            try:
                data = pd.read_csv(buf)
                for _, row in data.iterrows():
                    server.tell(
                        row["outcome"], {par: row[par] for par in server.parnames}
                    )
                    # Advance the trial counter of the active sub-strategy
                    # for every told row.
                    idx = server.strat._strat_idx
                    server.strat.strat_list[idx]._count += 1
            except Exception as err:
                with upload_output:
                    clear_output()
                    print("Data is improperly formatted!")
                    print(f"{type(err).__name__}: {err}")
                continue
        with upload_output:
            clear_output()
        get_next(None)
        display_data()
        display_plot()
    uploader.value.clear()
    write_server()


# ---- Session resume: upload a pickled server written by write_server(). ----
server_uploader = widgets.FileUpload(
    description="Resume Session", accept=".pkl", multiple=False, style=style
)
# Observes the uploader's `_counter` trait (presumably incremented once per
# upload in the targeted ipywidgets version -- confirm).
server_uploader.observe(resume_server, names="_counter")

# ---- Labels shown for outcomes 0 and 1 (stored on the server at start). ----
outcome_label = widgets.Label(value='Outcome Labels:')
zero_outcome = widgets.Text("No Trial", description="0: ", style=style)
one_outcome = widgets.Text("Yes Trial", description="1: ", style=style)
outcomes_labels = widgets.VBox([outcome_label, zero_outcome, one_outcome])

# ---- Parameter rows. params_boxes must exist before add_param(None), which
# appends the first row to it. ----
# NOTE(review): params_label is created but not included in the display() call below.
params_label = widgets.Label(value="Parameters:")
params_boxes = widgets.VBox([])
add_param(None)

add_param_btn = widgets.Button(description="Add Parameter")
add_param_btn.on_click(add_param)

rem_param_btn = widgets.Button(description="Remove Parameter")
rem_param_btn.on_click(rem_param)

btns = widgets.HBox([add_param_btn, rem_param_btn])

# ---- Strategy settings read by make_config(). ----
strategy_btns = widgets.RadioButtons(
    options=["Threshold Finding", "Exploration", "Optimization"],
    value="Threshold Finding",
    description="Strategy:",
)

threshold_box = widgets.BoundedFloatText(
    value=0.75, min=0, max=1.0, step=0.05, description="Threshold:"
)

n_sobol_box = widgets.BoundedIntText(
    value=10, min=0, description="Initialization Trials:", style=style
)

start_server_btn = widgets.Button(description="Start AEPsych")
start_server_btn.on_click(start_server)

strat_settings = widgets.HBox([strategy_btns, threshold_box, n_sobol_box])

# Configure the server once with the default form values. This must come after
# every widget make_config() reads (params_boxes, strategy_btns, threshold_box,
# n_sobol_box) has been created.
config = make_config()
server.configure(config_str=config)

# ---- Ask/tell controls. tell_boxes is filled by start_server/resume_server;
# outcome_box options are replaced there with the user's labels. ----
tell_boxes = widgets.VBox()
outcome_box = widgets.Dropdown(
    options=[('No Trial', 0), ('Yes Trial', 1), ('', None)],
    value=None,
    description='Outcome:',
)

ask_btn = widgets.Button(description="Next Parameters")
ask_btn.on_click(get_next)

tell_btn = widgets.Button(description="Update Model")
tell_btn.on_click(tell_model)

# Bulk data upload (CSV) handled by mass_tell.
uploader = widgets.FileUpload(description="Upload Data", accept=".csv", multiple=False)
uploader.observe(mass_tell, names="_counter")

# Download link for the pickled server written by write_server().
server_download = FileLink(strat_file_name)

upload_output = widgets.Output()

ask_tell_cont = widgets.HBox([ask_btn, tell_btn, uploader, upload_output])

params_cont = widgets.VBox([ask_tell_cont, widgets.HBox([tell_boxes, outcome_box])])

# Side-by-side panes for the data table (display_data) and plot (display_plot).
data_output = widgets.Output()
plot_output = widgets.Output()
plot_data_cont = widgets.HBox([data_output, plot_output])

server_btns = widgets.HBox([start_server_btn, server_uploader])

# Initial setup screen; start_server/resume_server replace it via clear_output().
display(
    server_btns, strat_settings, outcomes_labels, btns, params_boxes,
)


2022-07-18 13:27:44,686 [INFO   ] Found DB at ./databases/default.db, appending!


HBox(children=(Button(description='Start AEPsych', style=ButtonStyle()), FileUpload(value={}, accept='.pkl', d…

HBox(children=(RadioButtons(description='Strategy:', options=('Threshold Finding', 'Exploration', 'Optimizatio…

VBox(children=(Label(value='Outcome Labels:'), Text(value='No Trial', description='0: ', style=DescriptionStyl…

HBox(children=(Button(description='Add Parameter', style=ButtonStyle()), Button(description='Remove Parameter'…

VBox(children=(HBox(children=(Text(value='par1', description='Name', style=DescriptionStyle(description_width=…

In [1]:
from aepsych.server import AEPsychServer
from aepsych.plotting import plot_strat
import ipywidgets as widgets
from IPython.display import display, HTML
from ipywidgets import interact, interactive, fixed, interact_manual
from ipywidgets import Layout, Label, Text
# !pip install voila
import os

# Database to replay.  Defaults to the original author's local file; set the
# AEPSYCH_DB_PATH environment variable to point the dashboard elsewhere.
database_path = os.environ.get(
    "AEPSYCH_DB_PATH",
    "/Users/ecortez/Work/Notes/webapp-proj/jupyter_dashboard/data_collection_analysis_tutorial.db",
)

# TODO: this should run whenever a new db is uploaded; currently it runs only
# once, when this cell executes, and upload() does not reload it.
serv = AEPsychServer(database_path=database_path)
strat = serv.get_strat_from_replay()

#---------- Style ----------
display(HTML('<h1 style="text-align: center;">AEPsych Visualization tool</h1>'))
# Style dict passed to every input widget below.
input_style = dict(description_width="initial", padding="10px")
# Shared Layout instance used by all three buttons.
btn_style = Layout(margin="20px 10px")

# Right-aligned container layout for the button rows.
btn_box_layout = Layout(display="flex", justify_content="flex-end")
# Small scrollable pane that echoes upload status and input changes.
file_output = widgets.Output(
    layout=dict(
        border="1px solid black",
        margin="10px",
        overflow="scroll",
        height="40px",
        padding="5px",
    )
)
#---------- Inputs -----------

def on_value_change(change):
    """Echo the new value of an observed widget into ``file_output``.

    ``change`` is the traitlets change dict; only its ``"new"`` entry is used.
    """
    new_value = change["new"]
    file_output.clear_output()
    with file_output:
        print(new_value)

# Database upload for resuming a previous session (see upload()).
uploader = widgets.FileUpload(
    description="Resume Session",
    accept=".db",
    multiple=False,
)

# Display labels for the two outcomes.  submit() passes input_one as
# plot_strat's ``yes_label`` and input_zero as ``no_label``.  So outcome 1
# defaults to the detected label and outcome 0 to the undetected one, matching
# the values reset() assigns.
input_zero = widgets.Dropdown(
    options=[('Yes Trial', "Detected Trial"), ('No Trial', "Undetected Trial")],
    value="Undetected Trial",
    description='0:',
    style=input_style,
)
input_one = widgets.Dropdown(
    options=[('No Trial', "Undetected Trial"), ('Yes Trial', "Detected Trial")],
    value="Detected Trial",
    description='1:',
    style=input_style,
)

# Levels in [0, 1] passed to plot_strat as target_level / cred_level.
target_level = widgets.BoundedFloatText(
    value=0.75, min=0, max=1, step=0.1,
    description='target_level:', disabled=False, style=input_style,
)
cred_level = widgets.BoundedFloatText(
    value=0.95, min=0, max=1, step=0.1,
    description='cred_level:', disabled=False, style=input_style,
)

# Axis labels passed to plot_strat as xlabel / ylabel.
x_axis = Text(
    value='Angle (degrees)', placeholder='x axis label',
    description='x_axis:', disabled=False, style=input_style,
)
y_axis = Text(
    value='Detection Probability', placeholder='y axis label',
    description='y_axis:', disabled=False, style=input_style,
)

# Echo upload changes into file_output.
uploader.observe(on_value_change, names='value')

# Bordered pane where submit() draws the plot.
inputs_output = widgets.Output(
    layout=dict(border="1px solid black", padding="20px", margin="20px")
)
# Every plot option also echoes its new value into file_output.
for _watched in (input_zero, input_one, target_level, cred_level, x_axis, y_axis):
    _watched.observe(on_value_change, names='value')
#------------ Buttons ------------

# (description, button_style, tooltip) for each button, created in this order.
button_upload, button_submit, button_reset = (
    widgets.Button(
        description=desc,
        disabled=False,
        button_style=kind,
        tooltip=tip,
        layout=btn_style,
    )
    for desc, kind, tip in (
        ('Upload', 'warning', 'Click to Upload'),
        ('Submit', 'warning', 'submit'),
        ('Reset', 'danger', 'reset'),
    )
)

#--------- Accordion -----------
# Three collapsible sections grouping the plot options.
accordion = widgets.Accordion(
    children=[
        widgets.VBox([input_zero, input_one]),
        widgets.VBox([target_level, cred_level]),
        widgets.VBox([x_axis, y_axis]),
    ]
)
for _idx, _title in enumerate(('Outcome Labels', 'Parameters', 'Axis Labels')):
    accordion.set_title(_idx, _title)

#------------ Tabs -------------
tab = widgets.Tab()

# Tab 0: file upload with status pane and Upload button.
_upload_page = widgets.VBox(
    [
        Label("To resume an experiment upload a file from a previous session."),
        uploader,
        file_output,
        widgets.Box([button_upload], layout=btn_box_layout),
    ]
)
# Tab 1: plot options, Submit/Reset buttons and the plot pane.
_plot_page = widgets.VBox(
    [
        accordion,
        widgets.Box([button_submit, button_reset], layout=btn_box_layout),
        Label("Plot output: "),
        inputs_output,
    ]
)
children = [_upload_page, _plot_page]
tab.children = children
for _idx, _title in enumerate(("Upload", "Plot")):
    tab.set_title(_idx, _title)
display(tab)


def upload():
    """Report in ``file_output`` whether a database file has been selected.

    NOTE(review): the uploaded file is not loaded here; ``strat`` is still the
    one replayed from ``database_path`` when the notebook cell ran.
    """
    with file_output:
        # Clear before printing; clearing afterwards erased the header immediately.
        file_output.clear_output()
        print('\n -----Now this is how your file looks like:----- \n')
        # `not value` is correct whether FileUpload.value is a dict (ipywidgets 7)
        # or a tuple (ipywidgets 8); comparing with {} only works for the former.
        if not uploader.value:
            print("No file uploaded")
        else:
            # TODO: validate the file format before accepting it.
            print('File uploaded successfully...')
        

def submit():
    """Redraw the posterior plot of the replayed ``strat`` in ``inputs_output``.

    Uses the current axis labels, outcome labels, credible level and target
    level from the widgets.
    """
    options = dict(
        xlabel=x_axis.value,
        ylabel=y_axis.value,
        yes_label=input_one.value,
        no_label=input_zero.value,
        cred_level=cred_level.value,
        target_level=target_level.value,
    )
    with inputs_output:
        inputs_output.clear_output()
        plot_strat(strat, **options)

        
def reset():
    """Clear the plot pane and restore the plot options to their defaults."""
    with inputs_output:
        inputs_output.clear_output()
        # Plotting options.
        x_axis.value = "Angle (degrees)"
        y_axis.value = "Detection Probability"
        input_one.value = "Detected Trial"
        input_zero.value = "Undetected Trial"
        # Same defaults the widgets are created with (cred 0.95, target 0.75);
        # these two values had been swapped.
        cred_level.value = 0.95
        target_level.value = 0.75
        
        
# Onclick handlers: ipywidgets passes the clicked Button, which is ignored.
def upload_clicked(b):
    """Run upload() when the Upload button is clicked."""
    return upload()


def submit_clicked(b):
    """Run submit() when the Submit button is clicked."""
    return submit()


def reset_clicked(b):
    """Run reset() when the Reset button is clicked."""
    return reset()


for _button, _handler in (
    (button_upload, upload_clicked),
    (button_submit, submit_clicked),
    (button_reset, reset_clicked),
):
    _button.on_click(_handler)

2022-07-18 13:27:33,674 [INFO   ] Found DB at /Users/ecortez/Work/Notes/webapp-proj/jupyter_dashboard/data_collection_analysis_tutorial.db, appending!


Tab(children=(VBox(children=(Label(value='To resume an experiment upload a file from a previous session.'), Fi…