In [14]:
import ipywidgets as widgets
from IPython.display import display

In [15]:
import pandas as pd
import numpy as np
import pickle as pkl
import json
import matplotlib.pyplot as plt
import seaborn as sns

In [16]:
# Names of the configuration sub-directories whose results are compared below.
config_list = [f'configCC7-{run_idx}' for run_idx in range(4)]

In [17]:
# JSON files describing the symptoms and the conditions; both live in the
# same experiment output directory.
_cc7_output_dir = '../ProJect-Conv-Agent/XPs/cat_dqn/output/CC7'
symptom_file = _cc7_output_dir + '/n_cc7_126_symptoms.json'
condition_file = _cc7_output_dir + '/n_cc7_20_conditions.json'

In [18]:
def clean(data):
    """Flatten a string by turning line breaks and commas into spaces.

    Parameters
    ----------
    data : str
        text to sanitise

    Returns
    -------
    str
        copy of `data` where each CRLF pair, lone carriage return, lone
        line feed and comma has been replaced by one space
    """
    # CRLF is handled before its individual characters so that a Windows
    # line break collapses to one space rather than two.
    cleaned = data
    for token in ("\r\n", "\r", "\n", ","):
        cleaned = cleaned.replace(token, " ")
    return cleaned

In [19]:
def load_and_check_data(data_filepath, provided_data, key_name):
    """Load the authorized data and check that the provided data comply with it.

    Parameters
    ----------
    data_filepath : str
        path to a json file containing the authorized data.
        the minimum structure of the data should be:
        {
            key_data1: {
                key_name: data-name1,
                ...
            },
            key_data2: {
                key_name: data-name2,
                ...
            },
            ...
        }
    provided_data : list
        list of names (e.g. symptoms as provided by the data from Synthea
        patient generation) that must all appear among the authorized names.
        An empty list disables the check.
    key_name : str
        the key used to access the information of the same meaning
        as the one in `provided_data`

    Returns
    -------
    index_2_key: list
        a sorted list containing all the keys of the authorized data
    name_2_index: dict
        a dict mapping the (cleaned) name associated to the authorized data
        to its index in `index_2_key`
    data: dict
        the authorized data, with the `key_name` entries cleaned by `clean`

    Raises
    ------
    ValueError
        if an element of `provided_data` is not an authorized name; the
        message reports the first such element.
    """

    with open(data_filepath) as fp:
        data = json.load(fp)

    index_2_key = sorted(data)
    # Names are cleaned in place so that the returned `data` and
    # `name_2_index` use the same spelling.
    for key in index_2_key:
        data[key][key_name] = clean(data[key][key_name])
    name_2_index = {
        data[key][key_name]: idx for idx, key in enumerate(index_2_key)
    }

    # `name_2_index` already holds every authorized name, so membership is a
    # hash lookup; stop at the first element that is not authorized.
    for elem in provided_data:
        if elem not in name_2_index:
            raise ValueError(
                "The provided symptom samples are not compliant with "
                + "authorized symptoms in the json file: {} : {}".format(
                    data_filepath, elem
                )
            )

    return index_2_key, name_2_index, data

In [20]:
def load_data(config_list, symptom_file, condition_file, config_base_dir='../ProJect-Conv-Agent/XPs/cat_dqn/output/CC7'):
    """Load the symptom/condition references and each configuration's outputs.

    Parameters
    ----------
    config_list : list
        names of the configuration sub-directories of `config_base_dir`
    symptom_file : str
        path to the symptom json file (names are read from the "name" key)
    condition_file : str
        path to the condition json file (names are read from the
        "condition_name" key)
    config_base_dir : str
        directory containing one sub-directory per configuration

    Returns
    -------
    result : dict
        maps each configuration name to a dict with the keys 'data'
        (unpickled content of evaluation_stats.pkl) and 'result'
        (parsed content of metric_results.json)
    symptom_infos : tuple
        value returned by `load_and_check_data` for `symptom_file`
    pathology_infos : tuple
        value returned by `load_and_check_data` for `condition_file`
    """
    # An empty `provided_data` list is passed, so nothing is validated here.
    symptom_infos = load_and_check_data(symptom_file, [], key_name="name")
    pathology_infos = load_and_check_data(condition_file, [], key_name="condition_name")

    result = {}
    for config in config_list:
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(f'{config_base_dir}/{config}/evaluation_stats.pkl', 'rb') as stats_fp:
            stats = pkl.load(stats_fp)
        with open(f'{config_base_dir}/{config}/metric_results.json') as metrics_fp:
            metrics = json.load(metrics_fp)
        result[config] = {'data': stats, 'result': metrics}
    return result, symptom_infos, pathology_infos

In [21]:
def print_confusion_matrix(confusion_matrix, class_names, normalize=False, figsize = (10,7), fontsize=14, ax=None):
    """Draws a confusion matrix, as returned by sklearn.metrics.confusion_matrix, as a heatmap.

    Arguments
    ---------
    confusion_matrix: numpy.ndarray
        The numpy.ndarray object returned from a call to sklearn.metrics.confusion_matrix.
        Similarly constructed ndarrays can also be used.
    class_names: list
        An ordered list of class names, in the order they index the given confusion matrix.
    normalize: bool
        When True, each row is divided by its sum and cells are annotated with
        two decimals. Rows whose sum is 0 are left at 0. Defaults to False,
        in which case cells are annotated as integers.
    figsize: tuple
        A 2-long tuple, the first value determining the horizontal size of the outputted figure,
        the second determining the vertical size. Only used when `ax` is None.
        Defaults to (10,7).
    fontsize: int
        Font size for the tick labels. Defaults to 14.
    ax: matplotlib.axes.Axes or None
        Axes to draw on. When None, a new figure and axes are created.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the heatmap (the value returned by seaborn.heatmap)

    Raises
    ------
    ValueError
        If seaborn cannot render the annotations, e.g. non-integer values
        while `normalize` is False.
    """
    if normalize:
        counts = np.asarray(confusion_matrix, dtype=float)
        row_totals = counts.sum(axis=1, keepdims=True)
        # A class without any sample has a zero row sum; skip the division
        # for such rows so they stay at 0 instead of becoming NaN.
        confusion_matrix = np.divide(
            counts, row_totals, out=np.zeros_like(counts), where=row_totals != 0
        )
    df_cm = pd.DataFrame(
        confusion_matrix, index=class_names, columns=class_names,
    )
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    fmt = '.2f' if normalize else 'd'
    try:
        heatmap = sns.heatmap(df_cm, annot=True, fmt=fmt, ax=ax)
    except ValueError as err:
        # Keep the original error attached for debugging.
        raise ValueError("Confusion matrix values must be integers.") from err
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize)
    heatmap.set_ylabel('True label')
    heatmap.set_xlabel('Predicted label')
    return heatmap

In [22]:
# Load everything once. Each *_info value is the tuple
# (index_2_key, name_2_index, data) produced by load_and_check_data.
all_data, all_symptoms_info, all_patho_info = load_data(
    config_list=config_list,
    symptom_file=symptom_file,
    condition_file=condition_file,
)
# Reverse lookups: key -> position of that key in the sorted key list.
symptoms_reverse_map = {key: pos for pos, key in enumerate(all_symptoms_info[0])}
pathos_reverse_map = {key: pos for pos, key in enumerate(all_patho_info[0])}

In [27]:
# Export, for pathologies and symptoms, the sorted key list and the
# name -> index mapping (the first two items of each info tuple).
data_info = {
    'pathos': list(all_patho_info[:2]),
    'symptoms': list(all_symptoms_info[:2]),
}
with open('./data_info.json', "w") as f:
    json.dump(data_info, f, indent=4)

In [25]:
# ---- Input controls -------------------------------------------------------
dropdown_config = widgets.Dropdown(description='Configuration', options=config_list)
dropdown_patho = widgets.Dropdown(
    description='Pathology', options=['All'] + all_patho_info[0]
)

# Bottom margin shared by the rows of the dashboard.
item_layout = widgets.Layout(margin='0 0 50px 0')

# ---- One Output area per displayed element --------------------------------
(
    output_config,
    output_metric,
    output_turn_stats,
    output_reward_stats,
    output_simulated_symp_stats,
    output_relevant_symp_stats,
    output_simulated_ratio_stats,
    output_relevancy_ratio_stats,
    output_cm,
    output_first_symptom_dist,
    output_inquire_dist,
) = (widgets.Output() for _ in range(11))

# ---- Layout ---------------------------------------------------------------
input_widgets = widgets.HBox(
    children=[dropdown_config, dropdown_patho], layout=item_layout
)

metric1_widgets = widgets.HBox(
    children=[
        output_metric,
        output_turn_stats,
        output_reward_stats,
        output_simulated_symp_stats,
    ],
    layout=item_layout,
)
# output_simulated_ratio_stats is not placed in any container, so whatever
# is written to it is never visible in the dashboard.
metric2_widgets = widgets.HBox(
    children=[output_relevant_symp_stats, output_relevancy_ratio_stats],
    layout=item_layout,
)
metrics = widgets.VBox(children=[metric1_widgets, metric2_widgets], layout=item_layout)

plots = widgets.Tab(
    children=[output_cm, output_first_symptom_dist, output_inquire_dist]
)
plots.set_title(0, 'Confusion Matrix')
plots.set_title(1, 'First Symptom Distribution')
plots.set_title(2, 'Inquired Symptom Distribution')

accordion = widgets.Accordion(children=[output_config, metrics, plots])
accordion.set_title(0, 'Data')
accordion.set_title(1, 'Metrics')
accordion.set_title(2, 'Plots')

# Top-level container: the dropdowns above the accordion sections.
dashboard = widgets.VBox(children=[input_widgets, accordion])

def generate_stats_keys(base):
    """Build the five summary-statistic key names derived from `base`.

    Parameters
    ----------
    base : str
        prefix of the statistic, e.g. 'turns'

    Returns
    -------
    list
        `base` followed by '_Avg', '_Std', '_Median', '_Min' and '_Max',
        in that order
    """
    stats_keys = []
    for suffix in ('_Avg', '_Std', '_Median', '_Min', '_Max'):
        stats_keys.append(base + suffix)
    return stats_keys

def plot_distribution(selected_data, key, title, figsize=None, fontsize=14):
    """Draw per-symptom values as bar charts, 20 symptoms per panel.

    A single panel is drawn alone; otherwise panels are laid out two per row.
    All panels share the same upper y limit (the largest value).

    Parameters
    ----------
    selected_data : dict
        metrics dict; `selected_data[key]` maps a symptom index, given as a
        numeric string, to the value to plot
    key : str
        entry of `selected_data` to plot
    title : str
        title written at the top of the figure
    figsize : list or tuple, optional
        figure size. When None and several panels are needed, defaults to
        [13, 5 * number_of_rows]; otherwise matplotlib's default is used.
    fontsize : int
        font size of the x tick labels. Defaults to 14.

    Returns
    -------
    matplotlib.figure.Figure
        the figure holding the panels
    """
    per_panel = 20
    counts = selected_data[key]
    # Keys are numeric strings: sort them numerically, not lexically.
    indices = sorted(int(k) for k in counts)
    y = [counts[str(idx)] for idx in indices]
    symp = [all_symptoms_info[0][idx] for idx in indices]

    if not symp:
        # Nothing to plot: return an empty titled figure instead of failing
        # on max() of an empty list.
        fig, ax = plt.subplots(figsize=figsize)
        fig.suptitle(title)
        ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes)
        ax.axis('off')
        return fig

    max_value = max(y)
    n_plots = -(-len(symp) // per_panel)  # ceiling division
    single = n_plots == 1
    n_cols = 1 if single else 2
    n_graphs = -(-n_plots // n_cols)
    if figsize is None and not single:
        figsize = [6.5 * 2, 5 * n_graphs]

    # squeeze=False keeps `axes` two-dimensional for every grid shape; without
    # it a 1x2 grid comes back as a 1-D array and 2-D indexing fails.
    fig, axes = plt.subplots(n_graphs, n_cols, figsize=figsize, squeeze=False)
    for panel, ax in enumerate(axes.flat):
        beg = panel * per_panel
        if beg >= len(symp):
            # Trailing cell of the grid with no symptoms left to show.
            ax.axis('off')
            continue
        end = min(beg + per_panel, len(symp))
        barplot = sns.barplot(
            x=symp[beg:end], y=y[beg:end], order=symp[beg:end],
            orient='v',
            ax=ax
        )
        barplot.xaxis.set_ticklabels(
            barplot.xaxis.get_ticklabels(),
            rotation=90,
            ha='center', fontsize=fontsize
        )
        barplot.set_ylim(top=max_value)

    fig.suptitle(title)
    if not single:
        fig.tight_layout()
    return fig
       

    
# Classification metrics read directly from the selected results.
metric_keys = ['accuracy', 'balanced_accuracy', 'precision', 'recall', 'f1']

# For turns and rewards the lookup keys (plural prefix) differ from the row
# labels shown in the tables (singular prefix, *_out).
turn_keys, turn_keys_out = generate_stats_keys('turns'), generate_stats_keys('turn')
reward_keys, reward_keys_out = generate_stats_keys('rewards'), generate_stats_keys('reward')

# Symptom statistics: the same list serves as lookup keys and row labels.
relevant_symp_keys, simulated_symp_keys = (
    generate_stats_keys('num_relevant_symptoms'),
    generate_stats_keys('num_simulated_symptoms'),
)
relevant_ratio_keys, simulated_ration_keys = (
    generate_stats_keys('relevancy_symptoms_ratio'),
    generate_stats_keys('simulated_symptoms_ratio'),
)

def common_filtering(config, patho):
    """Refresh every dashboard panel for the chosen configuration and pathology.

    Parameters
    ----------
    config : str
        key of `all_data` selecting the configuration to show
    patho : str
        'All' to show the global results, otherwise a pathology key present
        in `pathos_reverse_map`. The confusion matrix is only drawn for 'All'.
    """
    # Wipe whatever the previous selection rendered.
    for panel in (
        output_config,
        output_metric,
        output_turn_stats,
        output_reward_stats,
        output_simulated_symp_stats,
        output_relevant_symp_stats,
        output_simulated_ratio_stats,
        output_relevancy_ratio_stats,
        output_cm,
        output_first_symptom_dist,
        output_inquire_dist,
    ):
        panel.clear_output()

    show_all = patho == 'All'
    results = all_data[config]['result']
    if show_all:
        selected_data = results['global']
    else:
        selected_data = results['per_patho'][str(pathos_reverse_map[patho])]

    # (target output, keys looked up in selected_data, row labels displayed)
    stats_tables = (
        (output_metric, metric_keys, metric_keys),
        (output_turn_stats, turn_keys, turn_keys_out),
        (output_reward_stats, reward_keys, reward_keys_out),
        (output_relevant_symp_stats, relevant_symp_keys, relevant_symp_keys),
        (output_simulated_symp_stats, simulated_symp_keys, simulated_symp_keys),
        (output_relevancy_ratio_stats, relevant_ratio_keys, relevant_ratio_keys),
        (output_simulated_ratio_stats, simulated_ration_keys, simulated_ration_keys),
    )
    for panel, lookup_keys, row_labels in stats_tables:
        with panel:
            values = [selected_data[name] for name in lookup_keys]
            display(pd.DataFrame({'Value': values}, index=row_labels))

    if show_all:
        with output_cm:
            matrix = np.array(selected_data['confusion_matrix'])
            # A support index of -1 has no pathology key and is labelled "N/A".
            class_names = [
                "N/A" if idx == -1 else all_patho_info[0][idx]
                for idx in selected_data['confusion_matrix_support']
            ]
            print_confusion_matrix(matrix, class_names)
            plt.show()

    # (target output, key of the counts in selected_data, plot title)
    distributions = (
        (output_first_symptom_dist, 'first_symptoms_count', 'First symptom distribution'),
        (output_inquire_dist, 'inquired_symptoms_count', 'Inquired symptom distribution'),
    )
    for panel, count_key, plot_title in distributions:
        with panel:
            plot_distribution(selected_data, count_key, plot_title)
            plt.show()

    # Raw view of the selected results.
    with output_config:
        display(selected_data)
        
    
    


def dropdown_config_eventhandler(change):
    """Re-render the dashboard after the configuration dropdown changed.

    `change.new` is used as the configuration; the pathology is the current
    value of `dropdown_patho`.
    """
    common_filtering(config=change.new, patho=dropdown_patho.value)
    
def dropdown_patho_eventhandler(change):
    """Re-render the dashboard after the pathology dropdown changed.

    `change.new` is used as the pathology; the configuration is the current
    value of `dropdown_config`.
    """
    common_filtering(config=dropdown_config.value, patho=change.new)
    
# Call the matching handler whenever a dropdown's `value` trait changes.
dropdown_config.observe(handler=dropdown_config_eventhandler, names='value')
dropdown_patho.observe(handler=dropdown_patho_eventhandler, names='value')



In [26]:
# Render the assembled dashboard. In the code visible here, common_filtering
# is only called from the dropdown observers, so the panels stay empty until
# one of the dropdown values changes.
display(dashboard)

VBox(children=(HBox(children=(Dropdown(description='Configuration', options=('configCC7-0', 'configCC7-1', 'co…