In [None]:
import numpy as np
import pandas as pd
import ipywidgets as widgets
import sys
sys.path.append('../')
sys.path.append('../src')
import backend
import workload_builder as builder
from mbi import Domain, Dataset
import plots
import altair as alt
from IPython.display import display, clear_output


# HTML is taken from the public IPython.display module: importing display
# helpers from IPython.core.display is deprecated. `display` itself is already
# imported from IPython.display in the import list above.
from IPython.display import HTML

# Stretch the notebook's .container element to the full page width.
display(HTML("<style>.container { width:100% !important; }</style>"))


# Remove Altair's cap on the number of rows a chart's data may contain.
alt.data_transformers.disable_max_rows()

In [None]:
# Attribute names of the dataset and one shape entry per attribute, in the same order
# (presumably the number of discretized values of each attribute — confirm against mbi.Domain).
domain = Domain(attrs=('Incident Month','Incident Year','Operator','Species Name','Species Quantity'), shape=(12,15,6,7,5))
# Prompts shown one at a time in the prompt box: questions[0] is the introduction,
# the remaining entries are the five survey questions.
questions = ["In the past few years, the FAC has been observing a large increase in the number of incidents reported, they would like you to investigate. Type 'next' to start this activity.",
            "Question 1/5: When did this increase begin? Identify the spike year.",
            "Question 2/5: What was the number of incidents in the year prior to the spike year?",
            "Question 3/5: What was the number of incidents in the spike year?",
            "Question 4/5: Between the two years identified previously, which airline experienced the largest decrease in incidents? We will perform further analysis on this Airline in the next step.",
            "Question 5/5: We’ve been told by the authorities to investigate incidents during Spring months (MAR,APR,MAY) due a spike in incidents during that time period. How many incidents occur in that range involving the airline you identified previously?"]

# Column -> bin width mappings; they are passed as bin_widths when building histogram workloads.
yr_operator = {'Incident Year':1, 'Operator':1}
airline_month = {'Incident Month':1, 'Operator':1}

# visualizations[i] is the histogram view used while questions[i] is active.
# The introduction has no chart, hence the leading None.
visualizations = [None, yr_operator, yr_operator, yr_operator, yr_operator, airline_month]
# curr_spec holds the most recently displayed result (set by the click handlers).
# The module-level prev_spec is not reassigned by the visible code: binning() uses a
# local variable of the same name.
prev_spec = None
curr_spec = None
# epsilon_increments[i] is the eps used for a remeasurement while questions[i] is active.
epsilon_increments = [None,0.05,0.05,0.05,0.05,0.05]
budget = 1.0  # passed as budget= when the backend is (re)initialized further down
max_tries = 10  # upper bound on remeasure attempts; shared across all questions

index = 0  # position of the active prompt in `questions`
tries = 0  # remeasure attempts used so far
back_end = None  # placeholder until the backend is initialized
seed=10  # forwarded as seed= to the backend initializer


## Logging

In [None]:
import logging

class OutputWidgetHandler(logging.Handler):
    """ Custom logging handler sending logs to an output widget

    Each record is formatted with the handler's formatter and prepended to the
    widget's outputs, so the most recent message is listed first.
    """

    def __init__(self, *args, **kwargs):
        """ Create the handler and the Output widget it writes to.

        Positional and keyword arguments are forwarded to logging.Handler.
        """
        super(OutputWidgetHandler, self).__init__(*args, **kwargs)
        layout = {
            'width': '100%',
            'height': '160px',
            'border': '1px solid black'
        }
        # Widget that accumulates the log text; rendered on demand by show_logs().
        self.out = widgets.Output(layout=layout)

    def emit(self, record):
        """ Overload of logging.Handler method

        Formats the record and inserts it as a stdout stream entry at the front
        of the widget's outputs. Any failure while formatting or updating the
        widget is passed to handleError() instead of propagating into the code
        that issued the log call (for example a button callback).
        """
        try:
            formatted_record = self.format(record)
            new_output = {
                'name': 'stdout',
                'output_type': 'stream',
                'text': formatted_record+'\n'
            }
            # Prepend so the newest entry is shown at the top.
            self.out.outputs = (new_output, ) + self.out.outputs
        except Exception:
            self.handleError(record)

    def show_logs(self):
        """ Show the logs """
        display(self.out)

    def clear_logs(self):
        """ Clear the current logs """
        self.out.clear_output()


# Build the widget-backed handler and give it a timestamped format.
handler = OutputWidgetHandler()
handler.setFormatter(logging.Formatter('%(asctime)s  - [%(levelname)s] %(message)s'))

# Attach it to this notebook's logger and let INFO-level records through.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)

In [None]:
# Output area that the click handlers draw their charts into.
plot_output = widgets.Output()
# NOTE(review): back_end is assigned again in a later cell with budget=budget (1.0 in the
# configuration cell). When the notebook runs top to bottom that later instance replaces
# this one before any handler can use it — confirm which budget is intended and whether
# this first initialization is still needed.
back_end = backend.initialize_backend_wildlife(domain, 'discretized.csv', seed=seed, budget=2.0)

## Buttons

In [None]:
# Buttons driving the activity; their click handlers are attached at the end of this cell.
submit_btn = widgets.Button(description='Submit')  # records the typed answer and advances
make_it_better = widgets.Button(description='Remeasure')  # triggers a remeasurement of the current chart

def on_click_submit(obj):
    """Record the typed answer and move the activity to the next prompt.

    An empty answer box is ignored. Otherwise the text is stored in
    survey_answers at the current index and the answer box is reset to its
    placeholder. After the last prompt the chart is cleared, the submit button,
    the interface and the tries bar are closed, and a closing message is shown.
    For every other prompt the histogram for the new question is displayed; it
    is measured first only when the view differs from the previous prompt's.

    obj is the clicked button supplied by ipywidgets; it is not used.
    """
    global index, survey_answers, visualizations, epsilon_increments, curr_spec, back_end
    logger.info('Clicked submit')

    response = answer.value
    if response == '':
        return
    answer.value = answer.placeholder
    survey_answers[index] = response

    # Last prompt answered: tear the interface down and say goodbye.
    if index + 1 >= len(questions):
        with plot_output:
            clear_output()
        for closable in (submit_btn, interface, bar_label):
            closable.close()
        prompt.value = 'Thanks! Let\'s move on to the next activity!'
        display(prompt)
        return

    index += 1
    prompt.value = questions[index]

    view = visualizations[index]
    hist = builder.histogram_workload(domain.config, bin_widths=view)
    # Only take a fresh measurement when the view changed with this question.
    if view != visualizations[index-1]:
        back_end.measure_hdmm(workload=hist, eps=0.01, restarts=20)

    column_names = list(view.keys())
    curr_spec = back_end.display(hist)

    with plot_output:
        plot_output.clear_output()
        chart_data = curr_spec.reset_index(column_names)
        plot = plots.linked_hist(column_names[0], column_names[1], data=chart_data, display_true=False)
        display(plot)

def on_click_make_it_better(obj):
    """Remeasure the chart that belongs to the active question.

    Builds the left/right (column, bin width) pairs from the active view and
    hands them to binning() together with the epsilon increment configured for
    the active question.

    The Remeasure button is visible from the start, but the introduction prompt
    has no view (its visualizations entry is None), so a click at that point is
    logged and otherwise ignored instead of raising.

    obj is the clicked button supplied by ipywidgets; it is not used.
    """
    logger.info('Clicked remeasure')

    view = visualizations[index]
    if view is None:
        # Nothing is plotted yet, so there is nothing to remeasure.
        return

    key_val = list(view.items())
    measure_dict = {'left': key_val[0], 'right': key_val[1]}
    binning(measure_dict, epsilon = epsilon_increments[index])
    
# Wire each button to its click handler.
submit_btn.on_click(on_click_submit)
make_it_better.on_click(on_click_make_it_better)

## Text Boxes

In [None]:
# Read-only text area showing the active prompt; it starts on the introduction.
# (The former module-level `global index` statement was removed: a global
# declaration at module scope has no effect.)
prompt = widgets.Textarea(
    value=questions[0],
    placeholder='',
    description='',
    disabled=True,
    layout=widgets.Layout(width='1000px', height='50px')
)

# Free-text box the participant types an answer into; the submit handler resets
# it to its (empty) placeholder after each answer.
answer = widgets.Textarea(
    value='',
    placeholder='',
    description='Answer:',
    layout=widgets.Layout(width='300px', height='40px')
)


In [None]:
# Create the backend used by the handlers, with the budget from the configuration cell.
back_end = backend.initialize_backend_wildlife(domain, 'discretized.csv', seed=seed, budget=budget)

# Red progress bar counting remeasure attempts out of max_tries.
progress_bar = widgets.FloatProgress(min=0.0, max=max_tries)
progress_bar.style.bar_color = 'red'
progress_bar.description = '{}/{}'.format(tries, max_tries)
budget_spent = widgets.Label(value='Tries')

# One answer slot per prompt, holding 0 until the participant submits.
num_answers = len(questions)
survey_answers = [0 for _ in range(num_answers)]

def binning(measure_dict, epsilon=None, group_income=None):
    """Optionally remeasure the current two-column histogram and redraw it.

    Parameters
    ----------
    measure_dict : dict
        ``{'left': (column, bin_width), 'right': (column, bin_width)}``.
    epsilon : float, optional
        When given, one try is consumed and the histogram workload is measured
        with this eps before redrawing. When None the chart is only redrawn.
    group_income : optional
        Not used; kept so existing callers remain valid.

    The new result is joined with the previously displayed one (its value
    columns suffixed with ``_prev``) and passed to plots.linked_hist_test.

    Returns without doing anything when no result has been displayed yet, and
    without measuring when all tries have been used.
    """
    global tries, curr_spec, back_end

    # Without a previous result there is nothing to compare against; bail out
    # before a try and a measurement are spent on a call that cannot finish.
    if curr_spec is None:
        logger.warning('Remeasure requested before any result was displayed; ignoring')
        return

    left_col = measure_dict['left']
    right_col = measure_dict['right']

    widths = {left_col[0]:left_col[1], right_col[0]:right_col[1]}

    hist = builder.histogram_workload(domain.config, bin_widths=widths)

    if epsilon is not None:
        # The attempt that reaches max_tries also removes the button.
        if tries+1 >= max_tries:
            make_it_better.close()
        if tries+1 > max_tries:
            return
        tries += 1
        progress_bar.value = tries
        progress_bar.description = str(tries) + '/' + str(max_tries)
        back_end.measure_hdmm(workload=hist, eps=epsilon, restarts=20)

    # Rename into a new frame rather than in place: the object bound to the
    # global curr_spec must keep its column names in case display() below fails.
    prev_spec = curr_spec.rename(columns={'error': 'error_prev', 'plus_error': 'plus_error_prev', 'minus_error': 'minus_error_prev', 'true_count':'true_count_prev', 'noisy_count':'noisy_count_prev'})
    curr_spec = back_end.display(hist)
    spec = curr_spec.join(prev_spec, on=[left_col[0], right_col[0]]).reset_index([left_col[0], right_col[0]])
    spec = spec.round(0)

    with plot_output:
        plot_output.clear_output()
        plot = plots.linked_hist_test(left_col[0], right_col[0], data=spec, projection=True,label=False)
        display(plot)
    
# Column flexbox, items aligned to the start, half of the available width.
# The former color='black' argument was dropped: `color` is not a Layout
# trait, so it was never applied and only counted as an unrecognized
# constructor argument.
box_layout = widgets.Layout(display='flex',
    flex_flow='column',
    align_items='flex-start',
    width='50%')

# Tries label, progress bar and Remeasure button side by side.
bar_label = widgets.HBox([budget_spent, progress_bar, make_it_better])
# Answer box stacked above the Submit button.
prompt_answer = widgets.VBox([answer, submit_btn], layout=box_layout)

In [None]:
# Show the tries bar (label, progress bar, Remeasure button) above the main interface.
display(bar_label)
    
# Main interface: prompt on top, then the chart area, then the answer box with Submit.
# `interface` is closed by the submit handler once the last question is answered.
interface = widgets.VBox([prompt, plot_output, prompt_answer])
display(interface)

In [None]:
# Answers recorded so far, one entry per prompt; an entry stays 0 until that prompt is answered.
survey_answers

In [None]:
# Render the log widget; the handler prepends records, so the newest entry is listed first.
handler.show_logs()