## Dashboard

In this notebook, we interact with the Experiment Manager to configure, set up, and run experiments.

In [None]:
# Source Code
from __future__ import print_function
from ipywidgets import Layout, Box, VBox, HTML, HBox, GridBox
import ipywidgets as widgets


import io
import os
import subprocess
import yaml
import sys

sys.path.append('../..')
import runner.run as ibmfl_runner
import runner.postprocess as ibmfl_postproc
import pprint as pp
from string import Template

import json
from json import JSONDecodeError
import pandas as pd

from IPython.display import clear_output
from ipywidgets import GridspecLayout

%config Completer.use_jedi = False  # to avoid autocompletion errors

# Catalogue of supported fusion algorithm / dataset / model combinations,
# read from the CSV file that sits next to this notebook.
df = pd.read_csv(
    filepath_or_buffer='supported_models.csv',
    header=0,
    names=[
        'fusion_identifier',
        'fusion_algo',
        'dataset',
        'model_spec_name',
        'fl_model',
        'model_ui',
    ],
    skipinitialspace=True,
)

# Hyperparameter templates; rows are selected by their `model_identifier`.
df_hyperparams = pd.read_json(path_or_buf='hyperparams_to_models_map.json')

# Distinct model names offered in the model dropdown.
ui_model_choices = df['model_ui'].unique()

# Translates the model name shown in the UI into the short model id used when
# building hyperparameter keys and config-generation commands.
uimodel_modelid_dict = {
    'Keras': 'keras',
    'PyTorch': 'pytorch',
    'TensorFlow': 'tf',
    'Scikit-learn': 'sklearn',
    'None': 'none'
}

# Dictionary holding the configuration of the run, pre-filled with defaults.
# TODO: replace with a class later.
nb_config = {
    'split': {
        'ppp': 100,
        'method': 'Uniform Random Sampling',
    },
    'parties': 5,
    'quorum': 1,
}

exp_runner = ibmfl_runner.Runner()

# Heading shown above the model selection controls.
model_header = HTML(
    layout=Layout(grid_area='model_header', width='auto'),
    value='<{size}>Model details'.format(size='h4'),
)

# Model selection: a placeholder entry followed by every supported model.
model_dropdown = widgets.Dropdown(
    description='Model:',
    options=['Choose your model', *ui_model_choices],
    layout=Layout(grid_area='model_dr', width='40%'),
    disabled=False,
)


def model_dropdown_eventhandler(change):
    """Record the newly selected model under ``nb_config['model']``.

    Args:
        change: Change notification; only its ``new`` attribute is read.
    """
    nb_config['model'] = change.new


# Store the selected model in nb_config whenever the model dropdown changes.
# This is the first observer registered on the dropdown.
model_dropdown.observe(model_dropdown_eventhandler, names='value')


# Heading shown above the dataset controls.
dataset_header = HTML(value='<{size}>Dataset details'.format(size='h4'),
               layout=Layout(width='auto', grid_area='dataset_header'))


# Dataset selection; starts with a placeholder only, the options are replaced
# by update_supported_datasets once a model has been picked.
dataset_dropdown = widgets.Dropdown(
    options=['Choose your dataset'],  # placeholder until a model is chosen
    description='Dataset:',
    disabled=False,
    layout=Layout(width='80%',grid_area='dataset')
)


def update_supported_datasets(change):
    """Limit the dataset dropdown to the datasets listed for the chosen model.

    Args:
        change: Change notification whose ``new`` attribute is compared
            against the ``model_ui`` column of ``df``.
    """
    datasets_for_model = df.loc[df['model_ui'] == change.new, 'dataset']
    dataset_dropdown.options = list(datasets_for_model.unique())


# Refresh the dataset choices whenever a different model is selected.
model_dropdown.observe(update_supported_datasets, 'value')


def dataset_dropdown_eventhandler(change):
    """Record the newly selected dataset under ``nb_config['dataset']``.

    Args:
        change: Change notification; only its ``new`` attribute is read.
    """
    nb_config['dataset'] = change.new

# Store the selected dataset in nb_config whenever the dataset dropdown changes.
dataset_dropdown.observe(dataset_dropdown_eventhandler, names='value')


# Data Splitting Strategy: a label followed by the dropdown itself, so the
# dropdown is reachable as `splitting_dropdown.children[1]`.
splitting_dropdown = widgets.Box([
    widgets.Label(
        value='Data Split:',
        layout=Layout(width='auto')
    ),
    widgets.Dropdown(
        options=['Uniform Random Sampling', 'Stratified Sampling (per source class)'],
        disabled=False,
        layout=Layout(width = 'auto'),
        value='Uniform Random Sampling'  # same as the default in nb_config['split']['method']
    )
], grid_area='dataset_spl')


def splitting_dropdown_eventhandler(change):
    """Store the chosen data-splitting method in ``nb_config['split']['method']``.

    An empty selection only shows a prompt. The stored split settings are
    left untouched in that case, because the points slider handler and the
    data generation command both read ``nb_config['split']`` and would fail
    with ``KeyError`` if the entry were removed.

    Args:
        change: Change notification; only its ``new`` attribute is read.
    """
    split_chosen = change.new
    # `stderr` is not defined in this part of the file; it is used here as an
    # output area that can be cleared and entered as a context manager.
    stderr.clear_output()
    if split_chosen == '':
        with stderr:
            display('Please choose how you\'d like to split the data from the dropdown')
        return
    nb_config.setdefault('split', {})['method'] = split_chosen

# React to changes of the split-method dropdown (second child of the box).
splitting_dropdown.children[1].observe(splitting_dropdown_eventhandler, names='value')


# Points per party when splitting data: a label followed by the slider, so
# the slider is reachable as `points_slider.children[1]`.
points_slider = widgets.Box([
    widgets.Label(
        value='Points from each party:',
        layout=Layout(width = 'auto')
    ),
    widgets.IntSlider(
        min=100,
        max=1000,
        layout=Layout(width='50%'),
        value=100  # same as the default in nb_config['split']['ppp']
    )
], grid_area='ppp')


def points_slider_eventhandler(change):
    """Store the slider value as the number of points per party.

    Clears the ``stderr`` output area, then writes ``change.new`` to
    ``nb_config['split']['ppp']``.

    Args:
        change: Change notification; only its ``new`` attribute is read.
    """
    points_per_party = change.new
    stderr.clear_output()
    nb_config['split']['ppp'] = points_per_party
    

# React to changes of the points-per-party slider (second child of the box).
points_slider.children[1].observe(points_slider_eventhandler, names='value')


# Fusion algorithm selection: a heading followed by the dropdown, so the
# dropdown is reachable as `fusion_dropdown.children[1]`. It starts with a
# placeholder only; update_potential_fusion_algorithm fills in real options.
fusion_dropdown = widgets.Box([
    HTML(value='<{size}>Fusion Algorithm'.format(size='h4'),
         layout=Layout(width='auto')),
    widgets.Dropdown(
        options=['Choose your Fusion Algorithm'], disabled=False,
        layout=Layout(width='auto'))
], grid_area='fusion_dr')


def update_potential_fusion_algorithm(change):
    """Offer only the fusion algorithms listed for the current model and dataset.

    The model and dataset are read from ``nb_config`` rather than from
    ``change``, so the same callback serves both dropdowns.

    Args:
        change: Change notification; not used.
    """
    row_matches = (df['model_ui'] == nb_config['model']) & (df['dataset'] == nb_config['dataset'])
    algorithms = df.loc[row_matches, 'fusion_algo'].unique()
    fusion_dropdown.children[1].options = list(algorithms)


# Recompute the fusion algorithm choices when either the model or the dataset
# changes. These observers are registered after the ones that store the
# selection in nb_config, which the callback reads.
model_dropdown.observe(update_potential_fusion_algorithm, 'value')
dataset_dropdown.observe(update_potential_fusion_algorithm, 'value')


def fusion_dropdown_eventhandler(change):
    """Record the newly selected fusion algorithm under ``nb_config['fusion']``.

    Args:
        change: Change notification; only its ``new`` attribute is read.
    """
    nb_config['fusion'] = change.new


# Store the selected fusion algorithm whenever the dropdown changes.
fusion_dropdown.children[1].observe(fusion_dropdown_eventhandler, names='value')


# Heading shown above the participant controls.
header_parties = HTML(value='<{size}>Participants'.format(size='h4'), layout=Layout(width='auto', grid_area='header_parties'))


# Number of parties: a label followed by the slider, so the slider is
# reachable as `num_parties.children[1]`.
num_parties = widgets.Box([
    widgets.Label(
        value='Number of parties:',
        layout=Layout(width='auto')
    ),
    widgets.IntSlider(
        min=2,
        max=100,
        value=5,  # same as the default in nb_config['parties']
        layout=Layout(width='50%')
    )
], grid_area='parties')


def num_parties_eventhandler(change):
    """Record the slider value under ``nb_config['parties']``.

    Args:
        change: Change notification; only its ``new`` attribute is read.
    """
    nb_config['parties'] = change.new


# Keep nb_config['parties'] in step with the slider. Registered before the
# quorum-range observer so that nb_config['parties'] is already updated when
# the quorum handler divides by it.
num_parties.children[1].observe(num_parties_eventhandler, names='value')


# Number of parties in the quorum: a label followed by the slider, so the
# slider is reachable as `parties_in_quorum.children[1]`.
# The slider starts capped at the current number of parties; with a larger
# initial maximum the user could pick a quorum above the party count before
# ever touching the party slider.
# NOTE(review): this box uses the same grid_area ('parties') as num_parties -
# confirm against the grid template that this is intended.
parties_in_quorum = widgets.Box([
    widgets.Label(
        value='Number of parties in quorum',
        layout=Layout(width = 'auto')
    ),
    widgets.IntSlider(
        min=2,
        max=num_parties.children[1].value,
        value=num_parties.children[1].value,
        layout=Layout(width='50%')
    )
], grid_area='parties')


def update_quorum_range(*args):
    """Cap the quorum slider at the number of parties and select that cap.

    The quorum can contain at most all parties, so both the maximum and the
    current value of the quorum slider are set to the party slider's value.

    Args:
        *args: Ignored; accepted so the function can be used as an observer.
    """
    quorum_slider = parties_in_quorum.children[1]
    party_slider = num_parties.children[1]
    quorum_slider.max = party_slider.value
    quorum_slider.value = party_slider.value


# Re-cap the quorum slider whenever the number of parties changes.
num_parties.children[1].observe(update_quorum_range, 'value')


def parties_in_quorum_eventhandler(change):
    """Store the quorum as a fraction of the configured number of parties.

    The slider value is divided by ``nb_config['parties']`` and rounded to
    two decimal places before being written to ``nb_config['quorum']``.

    Args:
        change: Change notification; only its ``new`` attribute is read.
    """
    fraction = change.new / float(nb_config['parties'])
    nb_config['quorum'] = round(fraction, 2)
    

# Store the quorum fraction whenever the quorum slider changes.
parties_in_quorum.children[1].observe(parties_in_quorum_eventhandler, names='value')


# Heading for the postprocessing section.
header_postproc = HTML(value='<{size}>Postprocessing Details'.format(size='h4'), 
                       layout=Layout(width='auto', grid_area='header_postproc'))


# Hyperparameters for the current selection; rebound by determine_hyperparams.
hyperparams_dict = {}


# Text areas generated for the hyperparameters; refilled by populate_hyperparams.
params_widgets = []


# Heading plus the button that triggers hyperparameter lookup; the button is
# reachable as `gen_hyperparams.children[1]`.
gen_hyperparams = widgets.Box([
    HTML(value='<{size}>Hyperparameters'.format(size='h4'), layout=Layout(width='auto')),
    widgets.Button(
        description='Get Hyperparameters',
        disabled=False,
        button_style='warning',
        tooltip='Show available hyperparameters for the choices made',
        layout=Layout(width='auto', height='40px')
    )
], grid_area='gen_hyper')


# Container that receives the confirmation button once hyperparameters are shown.
confirmation_box = widgets.Box()


# Container that receives the generated hyperparameter text areas.
hyperparams_text = widgets.Box()

def populate_hyperparams(b):
    """Look up the hyperparameters for the current choices and show them.

    Rebuilds the hyperparameter text areas from scratch, places them in
    ``hyperparams_text`` and puts a fresh confirmation button into
    ``confirmation_box``.

    Args:
        b: The clicked button; not used.
    """
    confirm_butn = widgets.Button(
        description='Confirm Hyperparameters',
        disabled=False,
        button_style='warning',
        tooltip='Saves the hyperparameter changes',
        layout=Layout(width='auto', height='40px'),
    )

    # Resolve the hyperparameters first, then regenerate one widget per entry.
    determine_hyperparams()
    params_widgets.clear()
    generate_hyperparam_UI(hyperparams_dict)

    hyperparams_text.children = params_widgets
    confirmation_box.children = (confirm_butn,)
    for child in confirmation_box.children:
        child.on_click(confirmation_button_handler)


# Clicking "Get Hyperparameters" looks up and displays the hyperparameters.
gen_hyperparams.children[1].on_click(populate_hyperparams)


def determine_hyperparams():
    """Resolve the fusion identifier for the current choices and load its hyperparameters.

    Selects the rows of ``df`` matching the model, dataset and fusion
    algorithm stored in ``nb_config``. When several rows match, the first one
    is used. The identifier is stored in ``nb_config['fusion_identifier']``
    and the matching entry of ``df_hyperparams`` is stored in the module-level
    ``hyperparams_dict``.

    Raises:
        IndexError: If no row of ``df`` matches the current choices, or if
            ``df_hyperparams`` has no entry for the resulting model identifier.
        KeyError: If model, dataset or fusion algorithm has not been chosen yet.
    """
    global hyperparams_dict

    selection = ((df.model_ui == nb_config['model'])
                 & (df.dataset == nb_config['dataset'])
                 & (df.fusion_algo == nb_config['fusion']))
    # Select the column by name: integer keys on a label-indexed Series rely
    # on a deprecated positional fallback.
    matching_ids = df.loc[selection, 'fusion_identifier']
    if matching_ids.empty:
        raise IndexError(
            'No supported configuration for model={!r}, dataset={!r}, fusion={!r}'.format(
                nb_config['model'], nb_config['dataset'], nb_config['fusion']))
    nb_config['fusion_identifier'] = matching_ids.iloc[0]

    # Key used to look up the hyperparameters in df_hyperparams.
    model_hyperparams_key = nb_config['fusion_identifier'] + '_' + uimodel_modelid_dict[nb_config['model']]
    hyperparams_dict = df_hyperparams[df_hyperparams['model_identifier'] == model_hyperparams_key].hyperparams.values[0]


def generate_hyperparam_UI(parameter_dict):
    """Append one editable text area per top-level key of ``parameter_dict``.

    Every model has at most two top-level keys, ``global`` and ``local``.
    Each value is rendered with ``str()`` into its own text area, which is
    appended to the module-level ``params_widgets`` list.

    Nested dictionaries are deliberately NOT expanded into separate widgets:
    the confirmation handler copies each widget's description and text into
    ``nb_config``, and config generation later reads ``nb_config['global']``
    and ``nb_config['local']`` as whole blocks of text. (An earlier version
    compared ``type(value)`` with the string ``'dict'``, which is never true,
    so no recursion ever took place.)

    Args:
        parameter_dict: Mapping from section name to its hyperparameters.
    """
    for key in parameter_dict:
        params_widgets.append(
            widgets.Textarea(
                description=key,
                value=str(parameter_dict[key]),
                layout=Layout(width='400px', height='100px'),
                grid_area='hyperparams'))


# Choice between a local run and a run on remote machines: a heading followed
# by the dropdown, so the dropdown is reachable as `local_or_remote.children[1]`.
local_or_remote = widgets.Box([
                HTML(value = '<{size}>Run this experiment locally or on remote machines?'.format(size='h4'),
                     layout = Layout(width='auto')),
                widgets.Dropdown(
                    options=['Choose your option','Run Locally', 'Run on Remote Machines'],
                    description='',
                    disabled=False,
                    layout=Layout(width='200px')
                )
])


# dictionary for details of the run, which will get populated as fields get filled
run_details = {}


def network_details_tracker(change):
    """Copy an edited network field into ``run_details['machines']``.

    The target machine is the last word of the field's placeholder text and
    the key is derived from the field's description (text before the colon,
    lower-cased, spaces replaced by underscores).

    Args:
        change: Change notification; ``new`` is the entered text and
            ``owner`` is the text field that was edited.
    """
    value = change.new
    subkey = change.owner.description.split(':')[0].replace(' ', '_').lower()
    # The field is labelled "SSH user:", but the rest of the notebook stores
    # and fills this value under 'ssh_username'.
    if subkey == 'ssh_user':
        subkey = 'ssh_username'
    machine_key = change.owner.placeholder.split(' ')[-1]
    # setdefault also covers a machine that is not (or no longer) present,
    # e.g. after run_details was replaced from the JSON text box.
    run_details['machines'].setdefault(machine_key, {})[subkey] = value


def get_IPaddr_port(party_index=None):
    """Build the IP address, port and SSH user inputs for one machine.

    Args:
        party_index: Number appended to ``machine`` in each placeholder; the
            placeholder's last word is what the change handler uses to find
            the machine.

    Returns:
        A ``VBox`` with the three text fields, each observed by
        ``network_details_tracker``.
    """
    suffix = ' for machine' + str(party_index)
    fields = [
        widgets.Text(value='', placeholder='IP Address' + suffix, description='IP Address:'),
        widgets.Text(value='', placeholder='Port Number' + suffix, description='Port Number:'),
        widgets.Text(value='', placeholder='ssh username' + suffix, description='SSH user:'),
    ]

    machine_detail_vbox = widgets.VBox(children=fields)
    for field in machine_detail_vbox.children:
        field.observe(network_details_tracker, 'value')
    return machine_detail_vbox
    

def path_details_tracker(change):
    """Copy an edited directory field into ``run_details``.

    Fields whose placeholder contains ``local`` describe the local machine
    and are stored on the (single) experiment entry with a ``local_`` prefix.
    All other fields belong to the machine named by the last word of the
    placeholder and are stored under ``run_details['machines']``.

    Args:
        change: Change notification; ``new`` is the entered text and
            ``owner`` is the text field that was edited.
    """
    value = change.new
    subkey = change.owner.description.split(':')[0].replace(' ', '_').lower()
    placeholder = change.owner.placeholder

    if 'local' in placeholder:
        # there's only one trial for now
        run_details['experiments'][0]['local_' + subkey] = value
        return

    machine_key = placeholder.split(' ')[-1]  # to figure which machine is this for
    # setdefault also covers a machine that is not (or no longer) present,
    # e.g. after run_details was replaced from the JSON text box.
    run_details['machines'].setdefault(machine_key, {})[subkey] = value


def get_paths(party_index=None):
    """Build the staging and IBMFL directory inputs for one machine.

    Args:
        party_index: Machine number used in the placeholders; ``None``
            produces the fields for the local machine.

    Returns:
        A ``VBox`` with the two text fields, each observed by
        ``path_details_tracker``.
    """
    suffix = ' for local machine' if party_index is None else ' for machine' + str(party_index)
    fields = [
        widgets.Text(value='', placeholder='Staging Dir' + suffix, description='Staging Dir:'),
        widgets.Text(value='', placeholder='IBMFL Dir' + suffix, description='IBMFL Dir:'),
    ]

    machine_detail_vbox = widgets.VBox(children=fields)
    for field in machine_detail_vbox.children:
        field.observe(path_details_tracker, 'value')
    return machine_detail_vbox


# Container for the run-details form; its children are set by display_run_details.
networking_deets_box = widgets.VBox()


def venv_box_isConda_handler(change):
    """Store whether conda is used as a boolean in ``run_details['machines']``.

    Args:
        change: Change notification; ``new`` equal to ``'Yes'`` means conda,
            any other value means no conda.
    """
    run_details['machines']['venv_uses_conda'] = (change.new == 'Yes')


def venv_box_venvPath_handler(change):
    """Store the entered virtual environment under ``run_details['machines']['venv_dir']``.

    Args:
        change: Change notification; only its ``new`` attribute is read.
    """
    machines = run_details['machines']
    machines['venv_dir'] = change.new


def display_conda_venv_fields():
    """Build the controls for choosing the virtual environment.

    Returns:
        An ``HBox`` holding a Yes/No conda selector and a text field for the
        environment, wired to ``venv_box_isConda_handler`` and
        ``venv_box_venvPath_handler`` respectively.
    """
    conda_choice = widgets.RadioButtons(
        options=['Yes', 'No'],
        description='Use conda?',
    )
    env_field = widgets.Text(
        value='',
        placeholder='.venv or conda env name',
        description='virtual env:',
        layout=Layout(width='300px', height='auto'),
    )
    venv_box = widgets.HBox([conda_choice, env_field])

    conda_choice.observe(venv_box_isConda_handler, 'value')
    env_field.observe(venv_box_venvPath_handler, 'value')
    return venv_box


def run_details_text_handler(change):
    """Replace ``run_details`` with the JSON object typed into the text box.

    Empty text is ignored. Text that is not valid JSON, or that is valid JSON
    but not an object (a list, number, string or ``null``), is rejected with
    a message and leaves ``run_details`` unchanged, because the rest of the
    notebook indexes ``run_details`` by key.

    Args:
        change: Change notification; only its ``new`` attribute is read.
    """
    global run_details
    if change.new == '':
        return

    try:
        parsed = json.loads(change.new)
    except JSONDecodeError:
        parsed = None

    if isinstance(parsed, dict):
        run_details = parsed
    else:
        display('Incorrect JSON passed for remote details, check and retry!')
        ## Todo: use an Output widget here so the message goes away once the input JSON is changed


def machines_dropdown_eventhandler(change):
    """Record the aggregator machine and treat every other machine as a party.

    ``run_details['machines']`` also holds the virtual-environment settings
    (``venv_uses_conda``, ``venv_dir``) next to the per-machine dictionaries,
    so only entries whose value is a dictionary are considered machines.
    If the selection does not name a known machine (for example the blank
    option), the party list is reset to empty.

    Args:
        change: Change notification; ``new`` is the selected dropdown label,
            which is lower-cased to obtain the machine key.
    """
    agg_machine = change.new.lower()
    experiment = run_details['experiments'][0]  # there is only one trial for now
    experiment['agg_machine'] = agg_machine

    machine_names = [
        name for name, details in run_details['machines'].items()
        if isinstance(details, dict)
    ]
    if agg_machine in machine_names:
        experiment['party_machines'] = [name for name in machine_names if name != agg_machine]
    else:
        experiment['party_machines'] = []


def display_run_details(change):
    """Initialise ``run_details`` and show the form for a local or remote run.

    Resets ``run_details`` (virtual environment defaults, machines and a
    single experiment entry), fills ``networking_deets_box`` with the fields
    matching the chosen mode and finally displays ``partyDetails_grid`` in
    ``input_ui``.

    For a remote run there is one machine per party plus one for the
    aggregator, and the details can be entered either through individual
    fields or by editing a JSON template of ``run_details``.

    Args:
        change: Change notification from the local/remote dropdown; a ``new``
            value containing ``'Remote'`` selects remote execution, anything
            else selects local execution.
    """
    # Disable the dropdown that triggered this so the mode cannot be switched
    # after the form has been built.
    change.owner.disabled = True
    run_details['machines'] = {
        'venv_uses_conda': True,
        'venv_dir': '.venv',
    }
    run_details['experiments'] = []

    temp_exp_dict = {
        'local_staging_dir': '',
        'local_ibmfl_dir': '',
    }
    conda_fields = display_conda_venv_fields()

    if 'Remote' in change.new:
        ## remote execution
        ## initialise the run_details dictionary
        run_details['isLocalRun'] = False

        temp_exp_dict['agg_machine'] = ''
        temp_exp_dict['party_machines'] = []
        # Register the experiment BEFORE the JSON template is rendered. If it
        # were added afterwards, the template would show an empty
        # `experiments` list and an edited template would replace
        # run_details without the entry that later code reads at index 0.
        run_details['experiments'].append(temp_exp_dict)

        # Machines are numbered from 1; one per party plus the aggregator.
        machine_ids = range(1, nb_config['parties'] + 2)
        for machine_id in machine_ids:
            run_details['machines']['machine' + str(machine_id)] = {
                'ip_address': '',
                'port_number': '',
                'ssh_username': '',
                'staging_dir': '',
                'ibmfl_dir': '',
            }

        networking_header_1 = HTML(value='<{size}>Details for remote execution: Fill details into the textbox on the left or in individual fields on the right'.format(size='h4'), layout=Layout(width='auto'))

        # Left-hand side: the whole run_details structure as editable JSON.
        run_details_box = widgets.VBox([
            widgets.Label(value='Machine details:', layout=Layout(width='auto')),
            widgets.Textarea(value=json.dumps(run_details, indent=4), layout=Layout(width='300px', height='700px'))
        ])
        run_details_box.children[1].observe(run_details_text_handler, 'value')

        networking_header_2 = HTML(value='<center><{size}>OR'.format(size='h3'), layout=Layout(width='auto', margin='5px 15px 5px 15px'))

        # Right-hand side: one group of fields per machine.
        all_machines_tuple = ()
        for machine_id in machine_ids:
            machine_header = HTML(value='<{size}>Machine{id}'.format(size='h4', id=str(machine_id)))
            temp_machine_box = widgets.VBox()
            machine_IP = get_IPaddr_port(machine_id)
            machine_paths = get_paths(machine_id)
            temp_machine_box.children = (machine_header, widgets.HBox(children=[machine_IP, machine_paths]))
            all_machines_tuple = all_machines_tuple + (temp_machine_box,)

        machines_dropdown = widgets.Box([
            widgets.Label(
                value='Pick machine for running Aggregator:',
                layout=Layout(width='auto')
            ),
            widgets.Dropdown(
                options=[''] + ['Machine{id}'.format(id=machine_id) for machine_id in machine_ids],
                layout=Layout(width='auto')
            )])

        machines_dropdown.children[1].observe(machines_dropdown_eventhandler, 'value')

        temp_local_vbox = widgets.VBox()
        local_header = HTML(value='<{size}>Local Directories'.format(size='h4'))
        local_path_fields = get_paths()
        temp_local_vbox.children = (local_header, local_path_fields)

        networking_fields_vbox = widgets.VBox(layout=Layout(width='auto', border='0.5px solid black'))
        networking_fields_vbox.children = (conda_fields,) + all_machines_tuple + (machines_dropdown, temp_local_vbox,)
        networking_deets_hbox = widgets.HBox(children=[run_details_box, networking_header_2, networking_fields_vbox])
        save_generate_butn.layout = Layout(width='185px', height='40px', margin='5px 50px 5px 400px')
        networking_deets_box.children = (networking_header_1, networking_deets_hbox, save_generate_butn,)

    else:
        ## local execution
        run_details['isLocalRun'] = True
        temp_exp_dict['agg_machine'] = 'local0'
        temp_exp_dict['party_machines'] = ['local{id}'.format(id=i+1) for i in range(nb_config['parties'])]

        ## setup dicts to populate IP addr and port number from generated configs later
        run_details['machines']['local0'] = {}
        for party in temp_exp_dict['party_machines']:
            run_details['machines'][party] = {}

        networking_header = HTML(value = '<{size}>Details for local execution'.format(size='h4'), layout=Layout(width='auto'))

        local_paths = get_paths()
        save_generate_butn.layout = Layout(width='185px', height='40px', margin='5px 50px 5px 50px')
        networking_deets_box.children = (networking_header, conda_fields, local_paths, save_generate_butn)

        run_details['experiments'].append(temp_exp_dict)

    with input_ui:
        display(partyDetails_grid)
    

# Build the run-details form as soon as local or remote execution is chosen.
local_or_remote.children[1].observe(display_run_details, 'value')


def display_configs_before_run(b):
    """Generate the configs and show them for review, or report the failure.

    Args:
        b: The clicked button; not used.
    """
    input_ui.clear_output()
    agg_conf_path, party_conf_path = generate_update_configs()

    generation_failed = agg_conf_path is None or party_conf_path is None
    if generation_failed:
        with input_ui:
            display('Error generating configs. Exiting...')
        return

    display_configs(agg_conf_path, party_conf_path)
    with input_ui:
        display(display_grid_1)


# Button that starts config generation. display_run_details replaces its
# layout (different margins) depending on local or remote execution.
save_generate_butn = widgets.Button(
        description='Proceed to generate configs',
        disabled=False,
        button_style='warning', # 'success', 'info', 'warning', 'danger' or ''
        tooltip='Generates config files from the above details',
        layout=Layout(width='185px', height='40px', margin='10px')
    )


# Clicking the button generates the configs and shows them for review.
save_generate_butn.on_click(display_configs_before_run)


def confirmation_button_handler(b):
    """Save the edited hyperparameters and move on to the local/remote choice.

    Each hyperparameter text area is copied into ``nb_config`` with its
    description as key and its text as value. The input area is then cleared
    and the local/remote selector is shown.

    Args:
        b: The clicked confirmation button; it is disabled and relabelled.
    """
    b.disabled = True
    b.description = 'Confirming hyperparams...'

    nb_config.update((field.description, field.value) for field in params_widgets)

    input_ui.clear_output()
    with input_ui:
        display(local_or_remote)


def _parse_hyperparams_text(text):
    """Convert the text of a hyperparameter box back into a Python object.

    The text is first treated as JSON after replacing single quotes with
    double quotes. Text that is still not valid JSON (for example the
    ``str()`` of a dictionary containing ``True``, ``False`` or ``None``)
    is parsed as a Python literal instead.

    Raises:
        ValueError, SyntaxError: If the text is neither of the two.
    """
    import ast  # local import: only needed for the fallback below

    try:
        return json.loads(text.replace('\'', '"'))
    except JSONDecodeError:
        return ast.literal_eval(text)


def _metrics_recorder_section(output_dir, party_id):
    """Build the ``metrics_recorder`` section written into a party config."""
    return {
        'name': 'MetricsRecorder',
        'path': 'ibmfl.party.metrics.metrics_recorder',
        'output_file': '${config_dir}/metrics_party${id}'.replace('${config_dir}', output_dir).replace('${id}', str(party_id)),
        'output_type': 'json',
        'compute_pre_train_eval': False,
        'compute_post_train_eval': True,
    }


def generate_update_configs():
    """Generate data and configs for the trial and patch them with the UI values.

    Steps, in order:
      1. create a timestamped trial directory under the local staging directory;
      2. run ``examples/generate_data.py`` and ``examples/generate_configs.py``;
      3. parse the confirmed hyperparameter text and move the party count and
         quorum under ``nb_config['global']``;
      4. rewrite the aggregator config (hyperparameters, and for remote runs
         the connection details);
      5. rewrite every party config (connection details for remote runs, the
         metrics recorder section, and ``alpha`` for Fed+).

    For local runs the IP address, port and SSH user of each "machine" are
    copied from the generated configs into ``run_details`` instead.

    Progress and errors are shown in ``input_ui``.

    Returns:
        ``(aggregator_config_path, party_config_path_pattern)`` on success,
        where the pattern contains ``*`` in place of the party index, or
        ``(None, None)`` if a command fails or a config cannot be parsed.
    """
    # Get timestamp and add it to the given local staging directory:
    nb_config['timestamp_str'] = ibmfl_runner.Runner().generate_timestamp()
    trial_dir = run_details['experiments'][0]['local_staging_dir'] + '/' + nb_config['timestamp_str']

    # NOTE(review): the commands below are assembled by string concatenation
    # and executed through the shell, so directory names containing spaces or
    # shell metacharacters are interpreted by the shell. They are left as
    # shell strings because quoting them would also stop the shell from
    # expanding things like `~` in user-entered paths; revisit if these
    # values can come from an untrusted source.

    # Create the staging_directory:
    mkdir_cmd = 'mkdir -p ' + trial_dir
    process = subprocess.run(mkdir_cmd, shell=True,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
    if process.returncode != 0:
        with input_ui:
            display('Erred: ', process.stderr)
        return None, None

    # Generate Data
    with input_ui:
        display('Generating Data...')

    cmd_to_run = 'cd ../../; python3 examples/generate_data.py --num_parties ' + str(nb_config['parties']) + ' -d ' + nb_config['dataset'] + ' -pp ' + str(nb_config['split']['ppp']) + ' -p ' + trial_dir  # there's only one trial for now
    if 'Stratified' in nb_config['split']['method']:
        cmd_to_run = cmd_to_run + ' --stratify'

    process = subprocess.run(cmd_to_run, shell=True,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
    if process.returncode != 0:
        with input_ui:
            display('Erred: ', process.stderr)
        return None, None

    # path to get datasets from: taken from the text after the last
    # 'Data saved in' in the command output. stdout is bytes here, so its
    # str() form ends with an escaped newline and a closing quote, which the
    # final replace() removes.
    data_path = str(process.stdout).split('Data saved in')[-1].strip().replace('\\n\'', '')
    with input_ui:
        display('Datasets saved to: {}'.format(data_path))

    # Generate Configs:
    with input_ui:
        display('Generating Configs...')
    # Crypto fusion algorithms (crypto keras / crypto_multiclass_keras) need the -crypto flag.
    # Todo: Need to let user pick one of {Paillier, ThresholdPaillier}
    crypto_flag = ' -crypto Paillier' if 'crypto' in nb_config['fusion_identifier'] else ''
    cmd_to_run = ('cd ../../; python3 examples/generate_configs.py --num_parties ' + str(nb_config['parties'])
                  + ' -f ' + nb_config['fusion_identifier']
                  + ' -m ' + uimodel_modelid_dict[nb_config['model']]
                  + crypto_flag
                  + ' -d ' + nb_config['dataset']
                  + ' -p ' + data_path
                  + ' --config_path ' + trial_dir)  # there's only one trial for now

    process = subprocess.run(cmd_to_run, shell=True,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             universal_newlines=True)
    if process.returncode != 0:
        with input_ui:
            display('Erred: ', process.stderr)
        return None, None

    # save agg and party configs path; the directory is taken from the text
    # after the first ':' on the first line of the command output
    configs_path = os.path.dirname(process.stdout.split('\n')[0].split(':')[1].strip())
    path_to_save_agg_configs = configs_path + '/config_agg.yml'
    print('Aggregator configs saved to: {}'.format(path_to_save_agg_configs))
    path_to_save_party_configs = configs_path + '/config_party*.yml'
    print('Party configs saved to: {}'.format(path_to_save_party_configs))

    # Turn the confirmed hyperparameter text back into dictionaries.
    is_fedplus = nb_config['fusion_identifier'] == 'fedplus'
    nb_config['global'] = _parse_hyperparams_text(nb_config['global'])
    local_hyperparams = _parse_hyperparams_text(nb_config['local'])
    if is_fedplus:
        # alpha is written into each party config below, not into the
        # aggregator's local hyperparameters
        alpha = local_hyperparams['training'].pop('alpha')
    nb_config['local'] = local_hyperparams

    # add num_parties as a key under global, to match the structure in the agg yaml configs
    nb_config['global']['num_parties'] = nb_config.pop('parties')
    nb_config['global']['perc_quorum'] = nb_config.pop('quorum')

    # Load Aggregator Config
    with open(path_to_save_agg_configs, 'r') as stream:
        try:
            agg_config = yaml.safe_load(stream)
        except yaml.YAMLError as e:
            print(e)
            return None, None

    is_local_run = run_details['isLocalRun']

    # for local runs, update the dirs to all the "machines" (they're all local)
    if is_local_run:
        run_details['machines']['ibmfl_dir'] = run_details['experiments'][0]['local_ibmfl_dir']
        run_details['machines']['staging_dir'] = run_details['experiments'][0]['local_staging_dir']

    # Modify aggregator config with values captured from the UI:
    # - update the hyperparameters object with newer global and local objects as updated above
    # - update ip and port from the run_details object
    # - TODO: Update model spec when uploading model file is supported
    agg_config['hyperparams']['global'] = nb_config['global']
    agg_config['hyperparams']['local'] = nb_config['local']
    agg_machine = run_details['experiments'][0]['agg_machine']  # there's only one trial for now
    agg_details = run_details['machines'][agg_machine]

    if not is_local_run:
        agg_config['connection']['info']['ip'] = agg_details['ip_address']
        agg_config['connection']['info']['port'] = int(agg_details['port_number'])
    else:
        agg_details['ip_address'] = agg_config['connection']['info']['ip']
        agg_details['port_number'] = agg_config['connection']['info']['port']
        agg_details['ssh_username'] = os.getenv('USER')

    # Write this updated yaml to file
    with open(path_to_save_agg_configs, 'w') as out:
        yaml.safe_dump(agg_config, out, default_flow_style=False)
    with input_ui:
        display('Updated Aggregator config at {}'.format(path_to_save_agg_configs))

    # Modify party config with values accepted from the UI
    # - update IP address, port for agg and party as received from the UI (only remote runs)
    # - add metrics section (both remote and local run)
    # - add alpha, if model chosen is Fed+
    party_machines = run_details['experiments'][0]['party_machines']  # there's only one trial for now
    for party_id, machine_name in enumerate(party_machines):
        party_conf_file = path_to_save_party_configs.replace('*', str(party_id))

        # Load
        with open(party_conf_file) as stream:
            try:
                party_config = yaml.safe_load(stream)
            except yaml.YAMLError as e:
                print(e)
                return None, None

        machine_details = run_details['machines'][machine_name]
        if is_local_run:
            # save IP addr and port number from the party config, into `run_details` dict, for runner's use
            machine_details['ip_address'] = party_config['connection']['info']['ip']
            machine_details['port_number'] = party_config['connection']['info']['port']
            machine_details['ssh_username'] = os.getenv('USER')
            metrics_dir = trial_dir
        else:
            party_config['aggregator']['ip'] = agg_details['ip_address']
            # cast like every other port written to a config; the UI stores it as text
            party_config['aggregator']['port'] = int(agg_details['port_number'])
            party_config['connection']['info']['ip'] = machine_details['ip_address']
            party_config['connection']['info']['port'] = int(machine_details['port_number'])
            metrics_dir = machine_details['staging_dir']

        # Metrics section to add to each party config
        party_config['metrics_recorder'] = _metrics_recorder_section(metrics_dir, party_id)

        if is_fedplus:
            party_config['local_training']['info']['alpha'] = alpha

        # Finally, write updated party config to file
        with open(party_conf_file, 'w') as out:
            yaml.safe_dump(party_config, out, default_flow_style=False)

    with input_ui:
        display('Updated Party configs at {}'.format(path_to_save_party_configs))

    nb_config['local_conf_dir'] = str(os.path.dirname(path_to_save_agg_configs))

    return path_to_save_agg_configs, path_to_save_party_configs

# Container for the generated-config preview; its children are set by display_configs.
config_box = widgets.VBox(layout=Layout(width='auto'))


def _load_yaml_for_display(conf_path):
    """Parse the YAML file at `conf_path` so it can be shown in the dashboard.

    :param conf_path: path of the YAML file to read
    :return: the object produced by `yaml.safe_load`; if the file is not valid
        YAML, the error is printed and a message string describing the failure
        is returned instead, so the caller always has something to display
    """
    with open(conf_path) as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as e:
            print(e)
            return 'Could not parse {}: {}'.format(conf_path, e)


def display_configs(agg_conf_path, party_conf_path):
    """Show the generated aggregator and party-0 configs, followed by the Run button.

    Reads the aggregator config and the config of party 0 from disk, renders
    each one inside its own bordered Output widget, and replaces the children
    of `config_box` with the headers, the two config boxes, two notes about
    the remaining parties, and `run_butn`.

    A config file that cannot be parsed no longer aborts the function with an
    unbound-variable error; the parse error message is shown in its place.

    :param agg_conf_path: path of the aggregator YAML config file
    :param party_conf_path: path pattern of the party YAML config files; every
        '*' in it is replaced by '0' to locate the config of party 0
    :return: None
    """
    display_header = HTML(value='<{size}>Configs Generated:'.format(size='h4'), layout=Layout(width='auto'))

    # Aggregator config: header + bordered output area
    agg_conf_header = HTML(value='<{size}>Aggregator Config'.format(size='h4'), layout=Layout(width='auto'))
    agg_conf = widgets.Output(layout={'border': '0.5px solid black'})
    # The file is read outside the Output context, as before, so I/O errors
    # still propagate to the caller instead of being captured by the widget.
    agg_config = _load_yaml_for_display(agg_conf_path)
    with agg_conf:
        display(agg_config)

    # Party 0 config: header + bordered output area
    party_conf_header = HTML(value='<{size}>Party0 Config'.format(size='h4'), layout=Layout(width='auto'))
    party_conf = widgets.Output(layout={'border': '0.5px solid black'})
    party_config = _load_yaml_for_display(party_conf_path.replace('*', '0'))
    with party_conf:
        display(party_config)

    agg_box = widgets.HBox(children=[agg_conf_header, agg_conf], layout=Layout(width='auto', padding='20px'))
    party_box = widgets.HBox(children=[party_conf_header, party_conf], layout=Layout(width='auto', padding='10px'))
    party_disclmr_1 = HTML(value='<strong><center>Other parties follow config similar to Party0, except connection.info.[ip,port] and paths',
                           layout=Layout(width='auto'))
    party_disclmr_2 = HTML(value='<strong><center>Also, each party gets a separate dataset file, split from the chosen dataset',
                           layout=Layout(width='auto'))
    config_box.children = [display_header, agg_box, party_box, party_disclmr_1, party_disclmr_2, run_butn]


# Button that launches the experiment; display_configs() places it below the generated configs
run_butn = widgets.Button(
    description='Run Experiment',
    tooltip='Runs the experiment with above config',
    button_style='warning',
    disabled=False,
    # the large left margin positions the button centrally
    layout=Layout(margin='5px 50px 5px 400px', width='125px', height='40px'),
)

# Container that invoke_runner() fills with the output of the running experiment
monitoring_box = VBox()

# Button that shows the charts for the experiment that ran
plot_button = widgets.Button(
    description='Show Charts',
    tooltip='Displays the various plots for the experiment that ran',
    button_style='warning',  # one of 'success', 'info', 'warning', 'danger' or ''
    disabled=False,
    # the large left margin positions the button centrally
    layout=Layout(margin='5px 50px 5px 400px', width='120px', height='40px'),
)


def invoke_runner(b):
    """Click handler for `run_butn`: run the configured experiment(s) and show the outcome.

    Disables the clicked button (it is not re-enabled here), clears the input
    UI, shows `display_grid_2` with a fresh bordered Output widget, fills in
    the runner settings of the first experiment in `run_details`, and then
    runs every experiment in `run_details['experiments']` using the config
    files found in `nb_config['local_conf_dir']`. When the chosen model name
    contains 'Keras', `plot_button` is appended to the monitoring box;
    otherwise a note that no plots are available is shown.

    :param b: the Button widget that was clicked
    :return: None
    """
    b.disabled = True
    input_ui.clear_output()
    monitoring_out = widgets.Output(layout={'border': '0.5px solid black'})
    monitoring_box.children = [monitoring_out]
    display(display_grid_2)

    # some values needed by the Runner; there's only one trial for now
    experiment = run_details['experiments'][0]
    experiment['shuffle_party_machines'] = False
    experiment['n_trials'] = 1
    experiment['n_parties'] = nb_config['global']['num_parties']
    experiment['n_rounds'] = nb_config['global']['rounds']

    # values for postprocessing and showing default metrics
    # (precision/recall keys are currently not plotted)
    experiment['postproc_fn'] = 'gen_reward_vs_time_plots'
    experiment['postproc_x_key'] = 'post_train:ts'
    experiment['postproc_y_keys'] = ['post_train:eval:loss', 'post_train:eval:acc']

    # NOTE(review): the converted machine dict is never used -- run_experiment
    # below receives run_details['machines'] directly. The call is kept in case
    # it updates its argument in place; confirm against the Runner and drop it
    # if it is a pure conversion.
    exp_runner.convert_machine_dict_from_nb_to_cli(run_details['machines'])

    for exp_info in run_details['experiments']:
        conf_dir = nb_config['local_conf_dir']
        with open('{}/config_agg.yml'.format(conf_dir), 'r') as config_agg_file:
            config_agg = config_agg_file.read()

        # raw text of every party config, ordered by party index
        config_parties = []
        for party_idx in range(exp_info['n_parties']):
            with open('{}/config_party{}.yml'.format(conf_dir, party_idx), 'r') as config_party_file:
                config_parties.append(config_party_file.read())

        with monitoring_out:
            outcome = exp_runner.run_experiment(
                exp_info, run_details['machines'], config_agg, config_parties,
                ui_mode='nb', ts=nb_config['timestamp_str'])
            # a falsy result from the runner is shown as a plain completion message
            display(outcome or 'Finished!')

    if 'Keras' in nb_config['model']:
        monitoring_box.children = monitoring_box.children + (plot_button,)
    else:
        with monitoring_out:
            display('No plots to show for the chosen model')


# Container that get_plots() fills with the output widget holding the charts
plots_box = VBox()


def get_plots(b):
    """Click handler for `plot_button`: show the charts for the experiment that ran.

    Disables the clicked button, shows `display_grid_3` with a fresh bordered
    Output widget and, unless the chosen fusion algorithm is one of those
    listed as unsupported, invokes the runner's post-processing inside that
    widget.

    :param b: the Button widget that was clicked
    :return: None
    """
    b.disabled = True
    # fusion algorithms for which metrics processing is not in place
    no_plots_for_these = ['Federated Averaging', 'Gradient Averaging', 'Probabilistic Federated Neural Matching']
    plots_out = widgets.Output(layout={'border': '0.5px solid black'})
    plots_box.children = [plots_out]
    display(display_grid_3)
    if nb_config['fusion'] in no_plots_for_these:
        with plots_out:
            display('Plots for chosen fusion algorithm are not supported yet')
    else:
        # generate the plot
        with plots_out:
            # Called directly rather than through display(exp_info=...): a
            # keyword argument hands display() no object to render, so that
            # wrapper never showed anything. Whatever appears comes from the
            # post-processing call itself.
            # NOTE(review): if call_postproc_fn() returns something meant to be
            # rendered, pass it to display() positionally instead -- confirm
            # against the Runner.
            exp_runner.call_postproc_fn()


# Register the click handlers: "Run Experiment" starts the run ...
run_butn.on_click(invoke_runner)


# ... and "Show Charts" renders the plots for the experiment that ran.
plot_button.on_click(get_plots)


# GridBox layout for UI
# `grid` has 2 rows x 3 columns: row 0 holds the input widgets, row 1 the nested hyperparameter grid.
grid = GridspecLayout(2,3)

# Row 0 spans all columns. Widgets are positioned inside the GridBox via named grid areas
# (e.g. model_header / model_dr, which are the `grid_area` values of the model widgets);
# the other widgets are presumably tagged with the remaining area names where they are defined.
grid[0,:] = GridBox(children=[model_header, model_dropdown, #upload_model_file, 
                  dataset_header, dataset_dropdown, splitting_dropdown, points_slider,
                  fusion_dropdown,
                  header_parties, num_parties, parties_in_quorum,
#                   header_postproc, postproc_func, postproc_xkey, postproc_ykeys,
                  gen_hyperparams
                 ],
       # NOTE(review): two row tracks are declared below while the area template names eight rows --
       # presumably the remaining rows are sized implicitly; confirm the rendering is as intended.
       layout = Layout(
           width='100%',
           grid_template_rows='auto auto',
           grid_template_columns='48% 48%',
           grid_template_areas='''
           "model_header model_header"
           "model_dr model_dr"
           "dataset_header dataset_header"
           "dataset dataset_spl"
           "fusion_dr fusion_dr"
           "header_parties header_parties"
           "parties parties"
           "gen_hyper gen_hyper"
            ''')
       )
# Nested grid to vary spacing across various widgets
# (hyperparameter text across the whole first row, confirmation box in the middle cell of the second row)
sub_grid_hyperparams = GridspecLayout(2,3)
sub_grid_hyperparams[0,:] = hyperparams_text
sub_grid_hyperparams[1,1] = confirmation_box

grid[1, :] = sub_grid_hyperparams

# Output widget that hosts the whole input UI; invoke_runner() clears it when a run starts
input_ui = widgets.Output()

with input_ui:
    display(grid)

# grid for displaying networking fields -- IP addr, port, ssh user, paths
partyDetails_grid = GridspecLayout(1,3)
partyDetails_grid[0, :] = networking_deets_box

# grid for displaying generated configuration
display_grid_1 = GridspecLayout(1,3)
display_grid_1[0, :] = config_box

# grid for displaying progress of running experiment
display_grid_2 = GridspecLayout(1,1)
display_grid_2[0, :] = monitoring_box

# grid for displaying charts from collected metrics
display_grid_3 = GridspecLayout(1,1)
display_grid_3[0, :] = plots_box

# Bare expression: presumably the last statement of the notebook cell, so Jupyter renders the input UI
input_ui