<center><img width="450" height="120" src="../../../../assets/media/teaser-v6.png"></center>

<center><h1> Analysis Platform for CaC </h1></center>

In [None]:
# Reload edited modules (e.g. the `cac` package) automatically before each cell runs,
# so library changes show up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from os.path import dirname, join, exists, splitext, basename, isdir
from copy import deepcopy
from typing import List
import multiprocessing as mp
from glob import glob
import base64
from functools import partial

import torch
import numpy as np
import pandas as pd
from scipy.special import softmax
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import ListedColormap
import seaborn as sns
from tqdm import tqdm
from IPython.display import display, HTML, clear_output, Markdown, Audio
from ipywidgets import HBox, Label, VBox, Dropdown, Layout, Output, Image
import binascii
from io import BytesIO

from cac.config import Config, DATA_ROOT
from cac.utils.logger import set_logger, color
from cac.utils.metrics import PrecisionAtRecall
from cac.utils.widgets import define_text, define_button, define_dropdown, define_inttext
from cac.data.dataloader import get_dataloader
from cac.analysis.classification import ClassificationAnalyzer

In [None]:
import warnings
# NOTE(review): this silences *every* warning for the whole session, which can hide real
# problems (e.g. a possible divide-by-zero in the loss-contribution plot); consider
# narrowing the filter to specific categories/modules.
warnings.simplefilter('ignore')
# Print numpy floats in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)

In [None]:
# Dataloader settings used when extracting features for the embedding analysis.
BATCH_SIZE, NUM_WORKERS = 10, 10

In [None]:
# Metadata columns offered in the attribute-analysis and embedding-label dropdowns.
ATTRIBUTES_TO_TRACK = [
    'enroll_facility',
    'enroll_patient_age',
    'enroll_patient_gender',
    'enroll_state',
    'enroll_travel_history',
    'enroll_contact_with_confirmed_covid_case',
    'enroll_health_worker',
    'enroll_fever',
    'enroll_days_with_fever',
    'enroll_cough',
    'enroll_days_with_cough',
    'enroll_shortness_of_breath',
    'enroll_days_with_shortness_of_breath',
    'enroll_other_symptoms',
    'enroll_comorbidities',
    'enroll_habits',
    'enroll_patient_temperature',
    'enroll_data_collector_name',
    'testresult_covid_test_result',
    'testresult_patient_tested_date',
]

In [None]:
# The general attributes plus 'audio_type' and 'unique_id'.
# NOTE(review): not referenced by the embedding dropdown in log_everything (which uses
# ATTRIBUTES_TO_TRACK) — confirm which list is intended.
EMBEDDING_ATTRIBUTES_TO_TRACK = [*ATTRIBUTES_TO_TRACK, 'audio_type', 'unique_id']

In [None]:
# Input widgets for the first experiment row, pre-filled with a default experiment.
version_text = define_text(value='cough-clf/wiai/stable/lr-1e-2-adamw-v2.0.yml',
                           description='Version',
                           placeholder='Add version to analyze')
user_text = define_text(value='aman',
                        description='Username',
                        placeholder='Add username to which the version belongs')
epoch_text = define_inttext(value='11',
                            description='Best epoch',
                            placeholder='Add best epoch')
# One dict of input widgets per experiment row; "Add another" appends more rows.
input_texts = [dict(user=user_text, epoch=epoch_text, version=version_text)]

# Layout and style reused by the notebook's dropdowns.
dropdown_layout = Layout(width='300px')
dropdown_style = {'description_width': '110px'}
mode_dropdown = define_dropdown(['val', 'train', 'val-subset-1'], 'val-subset-1', desc='mode',
                                layout=dropdown_layout, style=dropdown_style)

# Control buttons; each one gets its own Layout instance.
submit_button = define_button('Submit', style={'button_color': 'lightgreen'},
                              layout=Layout(width='150px'))
add_button = define_button('Add another', style={'button_color': 'lightblue'},
                           layout=Layout(width='150px'))
reload_button = define_button('Reload', style={'button_color': 'yellow'},
                              layout=Layout(width='150px'))

In [None]:
# Per-row caches, indexed like `input_texts`: entry i holds the results for experiment row i.
(configs, analyzers, epochwise_logs, modes,
 epochs, features, embeddings, attributes) = ([] for _ in range(8))

In [None]:
def reset_input_fields():
    """Set the value of every input widget, in every experiment row, to ''."""
    for fields in input_texts:
        for widget in fields.values():
            widget.value = ''
    
def check_input_fields(texts=None):
    """Return True when every input widget holds a usable value.

    A field counts as missing when its value is None or an empty string; any other value
    (including the integer 0) is accepted.

    Args:
        texts: list of dicts mapping field name -> widget (anything with `.value`).
            Defaults to the notebook-wide `input_texts`.

    Returns:
        False as soon as a missing field is found, otherwise True.
    """
    if texts is None:
        texts = input_texts
    for fields in texts:
        for widget in fields.values():
            value = widget.value
            # Check None first: the original nested this test inside isinstance(value, str),
            # where it could never be true, so None values slipped through.
            if value is None:
                return False
            if isinstance(value, str) and not value:
                return False
    return True

In [None]:
def _get_html_plot():
    """Serialize the current pyplot figure into an HTML page with an inline base64 <img>.

    The figure is written with `plt.savefig`'s default format and embedded as an
    `image/png` data URI, left-aligned.
    """
    buffer = BytesIO()
    plt.savefig(buffer)
    encoded = base64.b64encode(buffer.getvalue()).decode('ascii')
    return ''.join([
        '<html><head></head><body>',
        '<img src="data:image/png;base64,{0}" align="left">'.format(encoded),
        '</body></html>',
    ])

In [None]:
def _check_predicted_labels(_predict_labels, _predict_probs, _thresholds, recall=0.9):
    """Derive per-epoch binary predicted labels in place when none are stored yet.

    Only acts when `_predict_labels` has exactly 3 columns (NOTE(review): presumably the
    non-epoch columns only — confirm). For every column of `_predict_probs` whose name
    contains 'epoch', the class-1 probability is thresholded at the threshold returned by
    `PrecisionAtRecall(recall=recall)`. The 0/1 labels are written to `_predict_labels[col]`
    and the threshold to `_thresholds[col]`; nothing is returned.
    """
    if len(_predict_labels.columns) != 3:
        return
    targets = torch.tensor(_predict_labels['targets'])
    epoch_columns = [column for column in _predict_probs.columns if 'epoch' in column]
    for column in tqdm(epoch_columns, desc='Creating predicted labels'):
        probs = torch.from_numpy(np.stack(_predict_probs[column].values))
        positive_probs = probs[:, 1]  # only valid for binary classification
        _, _, threshold = PrecisionAtRecall(recall=recall)(targets, positive_probs)
        _predict_labels[column] = positive_probs.ge(threshold).int().tolist()
        _thresholds[column] = threshold


def get_experiment_data(config, epoch, mode):
    """Load an experiment's epoch-wise logs for `mode`, deriving predicted labels if absent.

    Returns:
        (analyzer, logs): the ClassificationAnalyzer built for `config` at checkpoint
        `epoch`, and its logs mapping with 'predict_labels' / 'thresholds' possibly filled
        in by `_check_predicted_labels`.
    """
    analyzer = ClassificationAnalyzer(config, checkpoint=epoch, load_best=False)
    logs = analyzer.load_epochwise_logs(mode=mode)

    labels_frame = logs['predict_labels']
    probs_frame = logs['predict_probs']
    threshold_map = logs['thresholds']
    _check_predicted_labels(labels_frame, probs_frame, threshold_map)

    logs['predict_labels'] = labels_frame
    logs['thresholds'] = threshold_map
    return analyzer, logs


def get_modes_for_config(config):
    """Return the names of the sub-directories of `<config.output_dir>/logs`.

    Each sub-directory is treated as a mode (e.g. 'val', 'train') that has saved logs.
    Order follows `glob`, i.e. is not guaranteed to be sorted.
    """
    found = []
    for path in glob(join(config.output_dir, 'logs', '*')):
        if isdir(path):
            found.append(basename(path))
    return found


def plot_correct_prediction_matrix(predict_labels):
    """Plot a sample x epoch grid of correct (blue) and wrong (red) predictions.

    Args:
        predict_labels: DataFrame with a 'targets' column plus one column per epoch whose
            name contains 'epoch' and ends in '_<epoch number>' (e.g. 'epoch_3').

    Returns:
        HTML string with the rendered figure embedded as an image.
    """
    epoch_columns = [column for column in predict_labels.columns if 'epoch' in column]
    # x coordinate of each epoch column = the epoch number parsed from its name
    epoch_numbers = np.array([int(column.split('_')[-1]) for column in epoch_columns], dtype=int)
    # True where that epoch's prediction equals the row's target
    correct = predict_labels[epoch_columns].eq(predict_labels['targets'], axis=0).to_numpy(dtype=bool)
    row_labels = predict_labels.index.to_numpy()

    fig, ax = plt.subplots(figsize=(15, 10))
    # Two scatter calls in total (instead of two per sample): same points and colours,
    # far fewer matplotlib artists.
    for mask, color in ((correct, 'blue'), (~correct, 'red')):
        rows, cols = np.nonzero(mask)
        ax.scatter(epoch_numbers[cols], row_labels[rows], c=color, s=0.4)

    ax.set_title('Model prediction grid')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Samples')
    ax.set_xlim([0, len(epoch_columns)])
    ax.set_ylim([0, len(predict_labels)])
    ax.invert_yaxis()  # first sample at the top
    ax.grid()
    html_out = _get_html_plot()
    plt.close(fig)
    return html_out


def plot_confidence_scores(predict_probs):
    """Heat-map the class-1 probability of every sample across epochs.

    Rows of the heat map are samples and columns are the columns of `predict_probs` whose
    name contains 'epoch'. Each cell of those columns is stacked with `np.stack` and
    element 1 is taken. Returns the figure as an HTML string.
    """
    epoch_columns = [column for column in predict_probs.columns if 'epoch' in column]
    positive_scores = np.vstack([
        np.stack(predict_probs[column].values)[:, 1] for column in epoch_columns
    ]).T

    fig, ax = plt.subplots(figsize=(15, 10))
    palette = ListedColormap(sns.color_palette("coolwarm", 7))
    sns.heatmap(positive_scores, vmin=0, vmax=1, cbar=True, robust=True, ax=ax, cmap=palette)
    ax.set_title('Confidence scores over epochs')
    rendered = _get_html_plot()
    plt.close()
    return rendered


def plot_instance_loss_matrix(instance_losses):
    """Heat-map per-sample losses across epochs.

    Returns:
        (html, matrix): the rendered figure as HTML and the (samples x epoch-columns) loss
        array that was plotted.

    NOTE(review): vmin=0 / vmax=1 clip the colour scale, so losses above 1 all share the
    top colour; `robust=True` has no effect when both limits are given.
    """
    epoch_columns = [column for column in instance_losses.columns if 'epoch' in column]
    per_epoch = instance_losses[epoch_columns]
    loss_matrix = np.array([per_epoch.loc[row].values for row in per_epoch.index])

    fig, ax = plt.subplots(figsize=(15, 10))
    sns.heatmap(loss_matrix, vmin=0, vmax=1, cbar=True, robust=True, ax=ax,
                cmap=ListedColormap(sns.color_palette("coolwarm", 7)))
    ax.set_title('Instance losses over epochs')
    rendered = _get_html_plot()
    plt.close()
    return rendered, loss_matrix


def plot_loss_contribution_matrix(instance_loss_matrix, batch_losses):
    """Heat-map each sample's loss divided by the matching batch loss, per epoch.

    Args:
        instance_loss_matrix: (samples x epochs) array, as returned by
            `plot_instance_loss_matrix`.
        batch_losses: DataFrame whose 'epoch*' columns hold, presumably, the loss of the
            batch each sample belonged to (TODO confirm); must match the matrix shape.

    Returns:
        HTML string with the rendered figure embedded as an image.
    """
    epoch_columns = [column for column in batch_losses.columns if 'epoch' in column]
    per_epoch = batch_losses[epoch_columns]
    batch_loss_matrix = np.array([per_epoch.loc[row].values for row in per_epoch.index])
    # NOTE(review): a zero batch loss produces inf/nan here.
    contribution = instance_loss_matrix / batch_loss_matrix

    fig, ax = plt.subplots(figsize=(15, 10))
    sns.heatmap(contribution, vmin=0, vmax=1, cbar=True, robust=True, ax=ax,
                cmap=ListedColormap(sns.color_palette("coolwarm", 7)))
    ax.set_title('Loss contribution per batch over epochs')
    rendered = _get_html_plot()
    plt.close()
    return rendered


def _check_config_equal(c1, c2):
    if c1.version != c2.version: return False
    if c1.user != c2.user: return False
    return True


def log_instance_level(index, text, ignore_existing):
    """Run and display the instance-level analysis for experiment row `index`.

    Reads version/user/epoch from `text` and the mode from the row's dropdown. It then
    loads the epoch-wise logs and renders these plots into `outputs[index].children[0]`:
    the prediction grid, the confidence heat map, the instance-loss heat map, and the
    loss-contribution heat map (only when the batch-loss columns line up). Finally it
    caches the loaded state in the module-level per-row lists.

    Args:
        index: position of the row in `input_texts` / `input_box` / `outputs`.
        text: dict of the row's input widgets with 'version', 'user' and 'epoch' keys.
        ignore_existing: when True, recompute even if the row's inputs are unchanged.

    Returns:
        False when the inputs match the cached row and the analysis was skipped,
        True otherwise.
    """
    version, user, epoch = text['version'].value, text['user'].value, text['epoch'].value
    config = Config(version, user)
    row_mode_dropdown = input_box.children[index].children[-1]
    mode_value = row_mode_dropdown.value

    # Skip the (expensive) reload when this row was already analyzed with identical inputs.
    if (not ignore_existing and len(configs) > index
            and _check_config_equal(configs[index], config)
            and modes[index] == mode_value and epochs[index] == epoch):
        return False

    row_output = outputs[index]
    # Drop the widgets of all later analysis sections and clear the instance-level output.
    row_output.children = row_output.children[:1]
    for child in row_output.children:
        with child:
            clear_output()
    log_output = row_output.children[0]

    def _show(html):
        """Render an HTML snippet into this row's instance-level output."""
        with log_output:
            display(HTML(html))

    _show('<h4 style="color:salmon"> Instance-level analysis </br></h4>')
    _show('<h6 style="color:orange"> Processing </br></h6>')
    _show(f'<h6> version: {version} </br> user: {user} </br> mode: {mode_value}</h6>')

    # Offer only the modes that have logs for this config, then restore the selection.
    row_mode_dropdown.options = get_modes_for_config(config)
    row_mode_dropdown.disabled = False
    row_mode_dropdown.value = mode_value

    with log_output:
        analyzer, logs = get_experiment_data(config, epoch, mode_value)

    # Each plot is computed outside the output context and then displayed.
    _show('<h6 style="color:orange"> Plotting correct prediction matrix </h6>')
    _show(plot_correct_prediction_matrix(logs['predict_labels']))

    _show('<h6 style="color:orange"> Plotting confidence scores </h6>')
    _show(plot_confidence_scores(logs['predict_probs']))

    _show('<h6 style="color:orange"> Plotting instance loss matrix </h6>')
    instance_loss_html, instance_loss_matrix = plot_instance_loss_matrix(logs['instance_loss'])
    _show(instance_loss_html)

    # The contribution ratio needs one batch loss per instance loss entry.
    if len(logs['batch_loss'].columns) == len(logs['instance_loss'].columns):
        _show('<h6 style="color:orange"> Plotting loss contribution matrix </h6>')
        _show(plot_loss_contribution_matrix(instance_loss_matrix, logs['batch_loss']))
    else:
        _show('<h6 style="color:red"> Ignoring loss contribution</h6>')

    # Cache this row's state; rows are assumed to be processed in index order, so a new
    # row is appended exactly at position `index`.
    row_state = (config, analyzer, logs, mode_value, epoch)
    row_stores = (configs, analyzers, epochwise_logs, modes, epochs)
    already_cached = len(configs) > index
    for store, value in zip(row_stores, row_state):
        if already_cached:
            store[index] = value
        else:
            store.append(value)

    return True

In [None]:
def plot_distributions(df, attribute):
    """Side-by-side count plots of `attribute`, split by ground truth and by prediction.

    Args:
        df: DataFrame holding `attribute` plus 'targets' and 'predictions' columns.
        attribute: name of the column counted on the x axis.

    Returns:
        HTML string with the rendered figure embedded as an image.
    """
    assert 'targets' in df.columns
    assert 'predictions' in df.columns

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    panels = (('targets', 'ground truth'), ('predictions', 'predicted'))
    for ax, (hue, source) in zip(axes, panels):
        sns.countplot(data=df, x=attribute, hue=hue, ax=ax)
        ax.grid()
        # Both panels share the total sample count as y limit so they are comparable.
        ax.set_ylim([0, df.shape[0]])
        ax.xaxis.set_tick_params(rotation=45)
        ax.set_title('Distribution of {} based on {} labels'.format(attribute, source))
    plt.tight_layout()
    html_out = _get_html_plot()
    plt.close(fig)
    return html_out


def on_plot_selector_dropdown_change(display_metrics, _output):
    """Build a Dropdown observer that renders the selected entry of `display_metrics`.

    Args:
        display_metrics: mapping from option name to a displayable object; the
            'confusion_matrix' entry is drawn as an annotated seaborn heat map.
        _output: Output widget (any context manager works) to render into.

    Returns:
        A callback for `Dropdown.observe`. It only reacts to 'value' changes. Selections
        that are not keys of `display_metrics` (the blank '' option, None) are ignored
        instead of raising KeyError.
    """
    def on_select_plot_(change):
        if change['type'] != 'change' or change['name'] != 'value':
            return
        selected = change['new']
        if selected not in display_metrics:
            return
        with _output:
            clear_output()
            if selected == 'confusion_matrix':
                sns.heatmap(
                    display_metrics['confusion_matrix'],
                    annot=True, annot_kws={'fontsize': 13},
                    cmap='GnBu', cbar=False)
                plt.show()
            else:
                display(display_metrics[selected])

    return on_select_plot_


def map_to_int(values, return_map=False):
    """Encode each value as the index of its rank among the sorted unique values.

    Args:
        values: 1-D sequence of mutually comparable values (e.g. category labels).
        return_map: also return the sorted unique values, so `mapping[code]` recovers
            the original value.

    Returns:
        List of int codes, or `(codes, mapping)` when `return_map` is True.
    """
    # np.unique sorts, and return_inverse gives each element's position in that sorted
    # array in one O(n log n) pass. The previous list.index lookup was O(n * k).
    unique_values, codes = np.unique(values, return_inverse=True)
    new_values = codes.reshape(-1).tolist()

    if return_map:
        return new_values, unique_values.tolist()

    return new_values


def _to_pearson_input(values):
    """Return `values` as a float array, integer-encoding non-numeric (categorical) data."""
    array = np.asarray(values)
    if array.dtype.kind in 'biuf':
        return array.astype(float)
    if array.dtype.kind == 'O':
        try:
            # object columns that merely hold numbers are used as-is
            return array.astype(float)
        except (TypeError, ValueError):
            pass
    return np.asarray(map_to_int(array), dtype=float)


def get_correlation_coefficient(x1, x2):
    """Pearson correlation between two equal-length, non-empty sequences.

    Numeric inputs (including numpy integer/bool arrays) are used directly. Anything
    else is treated as categorical and encoded with `map_to_int` first. Inputs are
    converted to arrays, so a pandas Series is read positionally regardless of its index.

    Returns:
        The result of `scipy.stats.pearsonr`, unpackable as (coefficient, p_value).
    """
    assert len(x1) == len(x2)
    assert len(x1)
    return pearsonr(_to_pearson_input(x1), _to_pearson_input(x2))


def attribute_based_analysis(index, text, attribute):
    """Analyze model behaviour per value of `attribute` for experiment row `index`.

    Uses the row's cached logs at its selected epoch. It plots the attribute's
    distribution against targets and predictions into output slot 3, reports the Pearson
    correlation with the targets, and renders one metrics panel plus a plot selector per
    attribute value into slot 4.

    Args:
        index: experiment row whose cached logs/analyzer are used.
        text: the row's input widgets (unused here; kept for the callback signature).
        attribute: column of `logs['attributes']` to analyze.
    """
    logs = epochwise_logs[index]
    epoch_column = 'epoch_{}'.format(epochs[index])
    predicted_labels = logs['predict_labels'][epoch_column]
    predicted_proba = logs['predict_probs'][epoch_column]
    targets = logs['predict_probs']['targets']
    attribute_col = logs['attributes'][attribute]
    threshold = logs['thresholds'][epoch_column]

    df = pd.concat([attribute_col, predicted_labels, predicted_proba, targets], axis=1)
    df.columns = [attribute, 'predictions', 'predicted_probs', 'targets']

    with outputs[index].children[3]:
        clear_output()
        display(HTML('<h6 style="color:orange"> Plotting </h6>'))
        attributes_html = plot_distributions(df, attribute)
        # Take both columns from the aligned frame so the two inputs line up row by row.
        corr_coef, p_value = get_correlation_coefficient(df[attribute].values, df['targets'].values)
        display(HTML(attributes_html))
        display(HTML('<h6 style="color:orange"> Correlation coefficient: </h6>{} </br> <h6 style="color:orange"> p-value: </h6>{}'.format(
            corr_coef, p_value)))

    groups_box = outputs[index].children[4]
    grouping = df.groupby(attribute)
    n_groups = len(grouping.groups)

    # Drop panels left over from a previous attribute that had more distinct values.
    if len(groups_box.children) > n_groups:
        groups_box.children = groups_box.children[:n_groups]

    for _index, key in enumerate(grouping.groups):
        group_df = grouping.get_group(key)
        # predicted_probs holds one probability sequence per row; expand it to a 2-D array
        metrics, display_metrics = analyzers[index].compute_metrics(
            group_df.predicted_probs.apply(pd.Series).values,
            group_df.targets.values,
            threshold=threshold
        )

        metrics_df = pd.DataFrame(metrics.items(), columns=['Metric', 'Value'])
        metrics_df.Metric = metrics_df.Metric.apply(lambda x: x.upper())
        metrics_to_show = metrics_df.set_index('Metric')

        # Create the panel for this value on first use: [metrics output, (selector, plot output)].
        if len(groups_box.children) == _index:
            plot_selector = define_dropdown([], default=None, desc='Select plot')
            display_plots = VBox([Output(), HBox([plot_selector, Output()])])
            groups_box.children = list(groups_box.children) + [display_plots]

        group_panel = groups_box.children[_index]
        selector, plot_output = group_panel.children[1].children

        # Detach the observer from a previous run before resetting the options. Otherwise
        # handlers accumulate, and a stale one (bound to old display_metrics) fires on the
        # option reset.
        previous_handler = getattr(selector, '_plot_handler', None)
        if previous_handler is not None:
            selector.unobserve(previous_handler)
        selector.options = [''] + list(display_metrics.keys())
        # NOTE(review): `default` is not a standard Dropdown trait; verify it has an effect.
        selector.default = None
        handler = on_plot_selector_dropdown_change(display_metrics, plot_output)
        selector.observe(handler)
        selector._plot_handler = handler

        with group_panel.children[0]:
            clear_output()
            display(HTML('<h4> Value: {} </h4>'.format(key)))
            display(metrics_to_show.T)


def on_attribute_dropdown_change(change, index, text):
    """Attribute-dropdown observer: rerun the attribute analysis when the value changes."""
    is_value_change = change['type'] == 'change' and change['name'] == 'value'
    if not is_value_change:
        return
    attribute_based_analysis(index, text, change['new'])

In [None]:
def compute_values(index, text, method):
    """Compute features and a 2-component embedding for row `index`, caching the results.

    Builds a non-shuffled, non-augmented dataloader for the row's config and mode. It
    extracts features with `compute_features(..., last_layer_index=-1)` and embeds them
    with `compute_embeddings` using the method `method` (e.g. 'TSNE' or 'PCA') and
    `n_components=2`. Features, per-sample attributes and embeddings are stored in the
    module-level per-row lists. Progress is shown in output slot 7.

    Args:
        index: experiment row to process.
        text: the row's input widgets (unused; kept for the callback signature).
        method: embedding method name forwarded to `compute_embeddings`.
    """
    with outputs[index].children[7]:
        clear_output()
        display(HTML('<h6 style="color:orange"> Computing features and embeddings </h6>'))
        dataloader, _ = get_dataloader(
            configs[index].data, modes[index],
            BATCH_SIZE,
            use_augmentation=False,
            num_workers=NUM_WORKERS,
            shuffle=False,  # keep sample order aligned with the attributes frame
            drop_last=False)

        results = analyzers[index].compute_features(dataloader, last_layer_index=-1)
        _features = results['features']
        _attributes = pd.DataFrame(results['attributes'])

        embedding_method_cfg = {
            'name': method,
            # random_state presumably seeds the embedding estimator — TODO confirm
            'params': {'n_components': 2, 'random_state': 0}
        }
        _embeddings = analyzers[index].compute_embeddings(embedding_method_cfg, _features)

        if len(features) > index:
            features[index] = _features
            attributes[index] = _attributes
            embeddings[index] = _embeddings
        else:
            features.append(_features)
            attributes.append(_attributes)
            embeddings.append(_embeddings)

        
def scatter2d(x1, x2, row_values_: pd.DataFrame, label: str, legend: bool = True,
              title=None):
    """Scatter 2-D points coloured by the value of column `label` in `row_values_`.

    Rows whose `label` value is missing are dropped. Each distinct value gets a colour
    from the plasma colormap and a legend entry '<value> : <count>'.

    Args:
        x1, x2: array-likes of first/second coordinates, positionally aligned with the
            rows of `row_values_`.
        row_values_: DataFrame with per-point attributes; it is not modified.
        label: column of `row_values_` used for colouring.
        legend: whether to draw the legend.
        title: optional axes title.

    Returns:
        HTML string with the rendered figure embedded as an image.
    """
    # check that the label column exists and that all inputs are aligned
    assert label in row_values_.columns
    assert len(x1) == len(x2)
    assert len(x1) == len(row_values_)

    # Keep rows whose label is present. Select positionally with a boolean mask: x1/x2
    # are positional arrays, while row_values_ may carry arbitrary index labels.
    label_column = row_values_[label]
    keep = label_column.notna().to_numpy()
    x1 = np.asarray(x1)[keep]
    x2 = np.asarray(x2)[keep]
    labels = label_column.to_numpy()[keep]
    unique_labels = np.unique(labels)
    colors = cm.plasma(np.linspace(0, 1, len(unique_labels)))

    fig, ax = plt.subplots(1, figsize=(10, 10))
    for value, color in zip(unique_labels, colors):
        selected = labels == value
        ax.scatter(x1[selected], x2[selected],
                   label='{} : {}'.format(value, np.count_nonzero(selected)), color=color)

    ax.set_ylabel('Component 2')
    ax.set_xlabel('Component 1')

    if title is not None:
        ax.set_title(title)

    ax.grid()

    if legend:
        ax.legend(loc='best')

    html_out = _get_html_plot()
    plt.close(fig)
    return html_out


def embedding_based_analysis(index, text, attribute):
    """Plot row `index`'s cached 2-D embeddings, coloured by `attribute`, into output slot 8.

    `text` is unused; it is kept so the signature matches the other per-row callbacks.
    """
    points, point_attributes = embeddings[index], attributes[index]

    with outputs[index].children[8]:
        clear_output()
        display(HTML('<h6 style="color:orange"> Plotting </h6>'))
        plot_html = scatter2d(points[:, 0], points[:, 1], point_attributes, label=attribute,
                              title='Labelled by {}'.format(attribute))
        display(HTML(plot_html))

        
def on_embedding_dropdown_change(change, index, text):
    """Embedding-attribute observer: replot the embeddings when the value changes."""
    if not (change['type'] == 'change' and change['name'] == 'value'):
        return
    embedding_based_analysis(index, text, change['new'])


def on_dim_red_dropdown_change(change, index, text):
    """Method-dropdown observer: recompute embeddings with the new method, then replot.

    The replot uses whichever attribute is currently selected in the row's embedding
    attribute dropdown.
    """
    if not (change['type'] == 'change' and change['name'] == 'value'):
        return
    compute_values(index, text, change['new'])
    selected_attribute = outputs[index].children[6].children[1].value
    embedding_based_analysis(index, text, selected_attribute)

In [None]:
def _consistency_rates(target_df):
    """Return four agreement rates over rows of three per-user predictions.

    `target_df` has one row per user with 'cough_1_pred', 'cough_2_pred', 'cough_3_pred'
    and 'target' columns. The four rates, in order, are the fraction of users where:
    (1) all three predictions are equal; (2) all three are equal and equal the target;
    (3) at least two are equal; (4) at least two equal the target.
    """
    preds = target_df[['cough_1_pred', 'cough_2_pred', 'cough_3_pred']].to_numpy()
    target = target_df['target'].to_numpy()
    eq_12 = preds[:, 0] == preds[:, 1]
    eq_23 = preds[:, 1] == preds[:, 2]
    eq_13 = preds[:, 0] == preds[:, 2]
    all_equal = eq_12 & eq_23
    return (
        np.mean(all_equal),
        np.mean(all_equal & (preds[:, 1] == target)),
        # the (1, 3) pair must be included: e.g. predictions (0, 1, 0)
        np.mean(eq_12 | eq_23 | eq_13),
        # With three votes, "at least two agree and that value equals the target" is
        # exactly "at least two predictions equal the target". The old code mis-parsed
        # `(...) & a == b` as `((...) & a) == b`.
        np.mean((preds == target[:, None]).sum(axis=1) >= 2),
    )


def log_prediction_consistency(index, text):
    """Report how consistent the predictions for each user's three recordings are.

    Predictions at the row's selected epoch are grouped by 'unique_id' (presumably one
    id per user with three cough recordings — TODO confirm). If any id does not have
    exactly three predictions, the analysis is skipped with a notice. Otherwise, per
    target class, the group size and the rates from `_consistency_rates` are written to
    output slot 10.

    Args:
        index: experiment row whose cached logs are used.
        text: the row's input widgets (unused; kept for the callback signature).
    """
    consistency_output = outputs[index].children[10]
    logs = epochwise_logs[index]
    predict_labels = logs['predict_labels']

    results_df = pd.DataFrame({
        'prediction': np.stack(predict_labels['epoch_{}'.format(epochs[index])].values),
        'target': predict_labels['targets'].values,
        'user': predict_labels['unique_id']}
    )
    user_grouping = results_df.groupby('user')

    user_rows = []
    for user in user_grouping.groups:
        user_df = user_grouping.get_group(user)
        user_target = user_df['target'].values[0]
        user_preds = user_df['prediction'].values.tolist()
        if len(user_preds) != 3:
            with consistency_output:
                display(HTML('<h6 style="color:orange"> Ignoring prediction consistency </br> </h6>'))
            return
        user_rows.append([user, *user_preds, user_target])

    user_df = pd.DataFrame(user_rows, columns=['user', 'cough_1_pred', 'cough_2_pred', 'cough_3_pred', 'target'])
    target_grouping = user_df.groupby('target')

    class_index_to_label = {
        1: 'covid',
        0: 'non-covid'
    }
    for target in target_grouping.groups:
        target_df = target_grouping.get_group(target)
        with consistency_output:
            display(HTML('<h6 style="color:orange"> Target: {} </br> </h6>'.format(
                class_index_to_label[int(target)])))
            display(HTML('<h6 style="color:DodgerBlue"> Length: {} </br> </h6>'.format(len(target_df))))

        all_equal, all_equal_target, two_equal, two_equal_target = _consistency_rates(target_df)

        with consistency_output:
            display(HTML('<h6> All 3 predictions equal: {} </br> </h6>'.format(all_equal)))
            display(HTML('<h6> All 3 predictions equal + equal to target: {} </br> </h6>'.format(all_equal_target)))
            display(HTML('<h6> Atleast 2 predictions equal: {} </br> </h6>'.format(two_equal)))
            display(HTML('<h6> Atleast 2 predictions equal + equal to target: {} </br> </h6>'.format(two_equal_target)))

In [None]:
def log_everything(ignore_existing=False):
    """Run every analysis section for each experiment row, building widgets on first use.

    For each row, the instance-level analysis runs first. When it returns True, the
    attribute, embedding and prediction-consistency sections are (re)built and rendered
    into `outputs[index]`. That VBox's children are laid out as:
        0 instance-level, 1 attribute header, 2 attribute dropdown, 3 attribute plots,
        4 per-value metric panels, 5 embedding header, 6 (method, attribute) dropdowns,
        7 embedding progress, 8 embedding plot, 9 consistency header,
        10 consistency results.

    Args:
        ignore_existing: forwarded to `log_instance_level`; True forces recomputation of
            rows whose inputs are unchanged.
    """
    for index, text in enumerate(input_texts):
        # instance level
        if not log_instance_level(index, text, ignore_existing):
            continue
        row_output = outputs[index]

        # attribute level: create header, dropdown, plot output and panel box once per row
        if len(row_output.children) == 1:
            attributes_dropdown = define_dropdown(ATTRIBUTES_TO_TRACK, 'enroll_facility', desc='attribute', layout=dropdown_layout, style=dropdown_style)
            attributes_dropdown.observe(partial(on_attribute_dropdown_change, index=index, text=text))
            row_output.children = list(row_output.children) + [
                Output(), attributes_dropdown, Output(), VBox([])]

        with row_output.children[1]:
            clear_output()
            display(HTML('<h4 style="color:salmon"> Attribute based analysis </h4>'))

        attribute_based_analysis(index, text, row_output.children[2].value)

        # embeddings: create header, method/attribute dropdowns and outputs once per row
        if len(row_output.children) == 5:
            dim_red_method_dropdown = define_dropdown(['TSNE', 'PCA'], 'TSNE', desc='method', layout=dropdown_layout, style=dropdown_style)
            # NOTE(review): EMBEDDING_ATTRIBUTES_TO_TRACK (adds 'audio_type', 'unique_id') is
            # defined but unused; confirm which list this dropdown should offer.
            embedding_attribute_dropdown = define_dropdown(ATTRIBUTES_TO_TRACK, 'enroll_facility', desc='attribute', layout=dropdown_layout, style=dropdown_style)
            embedding_attribute_dropdown.observe(partial(on_embedding_dropdown_change, index=index, text=text))
            dim_red_method_dropdown.observe(partial(on_dim_red_dropdown_change, index=index, text=text))
            row_output.children = list(row_output.children) + [
                Output(), HBox([dim_red_method_dropdown, embedding_attribute_dropdown]), Output(), Output()]

        with row_output.children[5]:
            clear_output()
            display(HTML('<h4 style="color:salmon"> Embedding-level analysis </h4>'))

        method_selector, embedding_attribute_selector = row_output.children[6].children
        dim_red_method_name = method_selector.value
        attribute_to_label = embedding_attribute_selector.value

        compute_values(index, text, dim_red_method_name)
        embedding_based_analysis(index, text, attribute_to_label)

        # prediction consistency analysis: header and results outputs, created once per row
        if len(row_output.children) == 9:
            row_output.children = list(row_output.children) + [Output(), Output()]

        with row_output.children[9]:
            clear_output()
            display(HTML('<h4 style="color:salmon"> Prediction consistency analysis </h4>'))

        # log_prediction_consistency only appends to this output, so clear the previous
        # run's numbers first; otherwise each re-run stacks another copy underneath.
        with row_output.children[10]:
            clear_output()
        log_prediction_consistency(index, text)

In [None]:
def on_click_submit(change):
    """Submit-button callback: validate the inputs, then analyze rows whose inputs changed.

    Shows an error in `feedback_output` when a field is empty. Otherwise it clears any
    previous error and calls `log_everything()`, which skips unchanged rows.
    """
    if not check_input_fields():
        with feedback_output:
            clear_output()
            display(HTML('<h6 style="color:red"> ERROR: Certains fields are empty</h6>'))
        return

    # Remove an error message left over from an earlier invalid attempt.
    with feedback_output:
        clear_output()
    log_everything()

        
def on_click_reload(change):
    """Reload-button callback: validate the inputs, then re-run every analysis.

    Unlike Submit, rows are recomputed even when their inputs are unchanged
    (`ignore_existing=True`). An empty field shows an error in `feedback_output`;
    otherwise any previous error is cleared first.
    """
    if not check_input_fields():
        with feedback_output:
            clear_output()
            display(HTML('<h6 style="color:red"> ERROR: Certains fields are empty</h6>'))
        return

    # Remove an error message left over from an earlier invalid attempt.
    with feedback_output:
        clear_output()
    log_everything(ignore_existing=True)


def on_click_add(change):
    """'Add another' callback: append an empty input row and a matching output column.

    Refuses, with an error in `feedback_output`, while the number of loaded configs is
    smaller than the number of input rows (i.e. some row has not been analyzed yet).
    """
    # No `global` needed: module-level names are only read or mutated here, never rebound.
    if len(configs) < len(input_box.children):
        with feedback_output:
            clear_output()
            display(HTML('<h6 style="color:red"> ERROR: Empty inputs already exist </h6>'))
        return

    outputs.append(VBox([Output()]))
    outputs_box.children = outputs

    input_text = {
        'user': define_text(description='Username', placeholder='Add username to which the version belongs'),
        'epoch': define_inttext(description='Best epoch', placeholder='Add best epoch'),
        'version': define_text(description='Version', placeholder='Add version to analyze')
    }
    new_mode_dropdown = define_dropdown(['val', 'train'], 'val', desc='mode', layout=dropdown_layout, style=dropdown_style)
    input_texts.append(input_text)
    input_box.children = list(input_box.children) + [
        HBox([input_text['version'], input_text['user'], input_text['epoch'], new_mode_dropdown],
             layout=Layout(padding='0px 0px 0px 0px'))
    ]

In [None]:
# One VBox per experiment row, shown side by side; each holds that row's analysis widgets.
outputs = [VBox(children=[Output()])]
outputs_box = HBox(children=outputs, layout=Layout(padding='0px 0px 0px 50px'))

In [None]:
# Area for validation errors raised by the Submit / Reload / Add buttons.
feedback_output = Output()

In [None]:
# One HBox per experiment row: version, user, epoch inputs and the mode dropdown.
input_box = VBox([
    HBox([version_text, user_text, epoch_text, mode_dropdown],
         layout=Layout(padding='0px 0px 0px 0px')),
])

<h4 style="color:salmon; padding:0px 0px 0px 50px">  Choose the config to analyze </h4>

In [None]:
# Input rows on top, then the control buttons next to the feedback area,
# then one output column per experiment row.
display(input_box)

display(HBox([submit_button, add_button, reload_button, feedback_output], layout=Layout(margin='50px 50px 50px 50px')))

display(outputs_box)

In [None]:
# Connect the control buttons to their callbacks. Re-running this cell registers the
# handlers again, so run it once per kernel session.
submit_button.on_click(on_click_submit)
add_button.on_click(on_click_add)
# Reload re-runs every analysis even when the inputs are unchanged.
reload_button.on_click(on_click_reload)

In [None]:
# log_everything()