In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports edited modules
# automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from os.path import dirname, join, exists
from copy import deepcopy
from typing import List
import multiprocessing as mp
import torch
import numpy as np
import pandas as pd
from scipy.special import softmax
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
from tqdm import tqdm
from IPython.display import display, HTML, clear_output, Markdown, Audio
from ipywidgets import HBox, Label, VBox, Dropdown, Layout, Output, Image

from cac.config import Config, DATA_ROOT
from cac.utils.logger import set_logger, color
from cac.data.dataloader import get_dataloader
from cac.analysis.classification import ClassificationAnalyzer

In [None]:
import warnings
# Suppress every warning category for the remainder of the session.
# NOTE(review): this also hides deprecation/runtime warnings raised by the
# libraries used in later cells — confirm that is intended.
warnings.simplefilter('ignore')

### Define inputs

In [None]:
# Experiment under analysis: VERSION and USER are handed to Config, and
# BEST_EPOCH is the checkpoint passed to ClassificationAnalyzer further down.
VERSION = 'experiments/covid-detection/v9_4_cough_adam_1e-4.yml'
USER = 'piyush'
BEST_EPOCH = 99

In [None]:
# Batch size and worker count forwarded to get_dataloader for the 'val' split.
BATCH_SIZE = 10
NUM_WORKERS = 10

### Define config

In [None]:
# Build the experiment configuration from the version path and user name;
# both the dataloader and the analyzer below are created from this object.
config = Config(VERSION, USER)

### Load data

In [None]:
# Validation dataloader in a fixed order (no shuffling) that keeps the final
# partial batch. The second return value of get_dataloader is discarded.
# NOTE(review): val_dataloader is not referenced by any later cell visible
# in this part of the notebook — confirm it is still needed.
val_dataloader, _ = get_dataloader(
    config.data, 'val',
    BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
    drop_last=False
)

### Initialize the analyzer module

In [None]:
# Analyzer bound to the checkpoint at BEST_EPOCH.
# NOTE(review): load_best=False presumably means the explicit checkpoint is
# used instead of a "best" one, and debug=True presumably enables extra
# diagnostics — confirm against ClassificationAnalyzer.
analyzer = ClassificationAnalyzer(config, checkpoint=BEST_EPOCH, load_best=False, debug=True)

## Data Summary

In [None]:
# Epoch-wise validation logs; later cells index this object with the keys
# 'predict_labels', 'predict_probs' and 'attributes'.
# NOTE(review): get_metrics=False presumably skips metric computation — confirm.
val_logs = analyzer.load_epochwise_logs(mode='val', get_metrics=False)

In [None]:
# Quick sanity check: dimensions of the attributes table.
val_logs['attributes'].shape

In [None]:
def plot_distributions(df, attribute):
    """Show side-by-side count plots of ``attribute`` split by label.

    The left panel colors the counts by the ``targets`` column and the right
    panel by the ``predictions`` column. Both panels share a y-range of
    ``[0, number of rows in df]``.

    Args:
        df: DataFrame that must contain ``targets`` and ``predictions``
            columns in addition to ``attribute``.
        attribute: name of the column plotted on the x-axis.

    Raises:
        AssertionError: if ``targets`` or ``predictions`` is missing from ``df``.
    """
    assert 'targets' in df.columns
    assert 'predictions' in df.columns

    # (hue column, title template) for the left and right panel respectively
    panels = (
        ('targets', 'Distribution of {} based on ground truth labels'),
        ('predictions', 'Distribution of {} based on predicted labels'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    for axis, (hue_column, title_template) in zip(axes, panels):
        sns.countplot(data=df, x=attribute, hue=hue_column, ax=axis)
        axis.grid()
        # cap the y-axis at the total row count so both panels are comparable
        axis.set_ylim([0, df.shape[0]])
        axis.set_title(title_template.format(attribute))

    plt.show()

In [None]:
def define_dropdown(options, default=None, desc='Dropdown', layout=None, style=None):
    """Create an enabled ipywidgets ``Dropdown``.

    Args:
        options: choices offered by the dropdown (forwarded as ``options``).
        default: initially selected value (forwarded as ``value``).
        desc: text forwarded as the widget ``description``.
        layout: ``Layout`` for the widget. When omitted, a fresh ``Layout()``
            is created for this call.
        style: style mapping for the widget. When omitted, a fresh empty
            dict is used.

    Returns:
        The constructed ``Dropdown`` widget.
    """
    # Build the defaults per call. A `Layout()` / `{}` default in the
    # signature is evaluated once at definition time, so every dropdown
    # created without explicit arguments would share (and could mutate)
    # the same objects.
    if layout is None:
        layout = Layout()
    if style is None:
        style = {}

    return Dropdown(
        options=options,
        value=default,
        description=desc,
        disabled=False,
        layout=layout,
        style=style
    )

def on_select_plot(change):
    """Widget observer that renders the newly selected metric plot.

    Relies on module-level ``output`` (used as a context manager for the
    rendering) and ``display_metrics`` (indexed by the selected name).
    NOTE(review): neither global is assigned at module level in the notebook
    cells visible here — ``on_select_plot_wrapper`` is the variant that
    ``attribute_summary`` actually registers.

    Args:
        change: change dictionary with ``type``, ``name`` and ``new`` keys.
            Anything other than a ``'change'`` of ``'value'`` is ignored.
    """
    global output, display_metrics

    # only react to an actual change of the dropdown's selected value
    if change['type'] != 'change' or change['name'] != 'value':
        return

    selection = change['new']
    with output:
        clear_output()
        if selection == 'confusion_matrix':
            # the confusion matrix is drawn as an annotated heatmap
            sns.heatmap(display_metrics['confusion_matrix'], annot=True, annot_kws={'fontsize': 13}, cmap='GnBu', cbar=False)
            plt.show()
        else:
            display(display_metrics[selection])

def on_select_plot_wrapper(display_metrics, output):
    """Build a widget observer bound to one metrics mapping and output area.

    Args:
        display_metrics: mapping from plot name to the object to render; the
            ``'confusion_matrix'`` entry is drawn as a heatmap, every other
            entry is passed to ``display``.
        output: object used as a context manager around the rendering.

    Returns:
        A callback taking a change dictionary (``type``, ``name``, ``new``)
        that re-renders the selection and ignores anything other than a
        ``'change'`` of ``'value'``.
    """

    def on_select_plot_(change):
        # ignore events that are not a change of the selected value
        if change['type'] != 'change' or change['name'] != 'value':
            return

        selection = change['new']
        with output:
            clear_output()
            if selection == 'confusion_matrix':
                sns.heatmap(display_metrics['confusion_matrix'], annot=True, annot_kws={'fontsize': 13}, cmap='GnBu', cbar=False)
                plt.show()
            else:
                display(display_metrics[selection])

    return on_select_plot_

In [None]:
def attribute_summary(attribute, epoch, threshold=None, recall=0.9):
    """Summarize validation results separately for each value of an attribute.

    Reads the notebook-level ``val_logs`` and ``analyzer`` objects. First
    plots the attribute's distribution against targets and predictions, then,
    for every distinct attribute value, displays a heading, a one-row table
    of the computed metrics and a dropdown that renders the selected plot.

    Args:
        attribute: column of ``val_logs['attributes']`` to group by.
        epoch: epoch identifier; formatted into the ``'epoch_{}'`` column name.
        threshold: forwarded to ``analyzer.compute_metrics``.
        recall: forwarded to ``analyzer.compute_metrics``.
    """
    epoch_column = 'epoch_{}'.format(epoch)

    label_frame = val_logs['predict_labels'][[epoch_column]]
    proba_frame = val_logs['predict_probs'][[epoch_column]]
    target_frame = val_logs['predict_probs'][['targets']]
    attribute_frame = val_logs['attributes'][[attribute]]

    # the column names assigned below follow the order of this concatenation
    summary_df = pd.concat([attribute_frame, label_frame, proba_frame, target_frame], axis=1)
    summary_df.columns = [attribute, 'predictions', 'predicted_outputs', 'targets']
    plot_distributions(summary_df, attribute)

    rows_by_value = summary_df.groupby(attribute).groups

    for value, row_labels in rows_by_value.items():
        display(Markdown('### {}'.format(value)))
        display(Markdown('---'))

        subset = summary_df.loc[row_labels]

        # predicted_outputs holds one list per row; expanding each into a
        # pd.Series turns the column into a 2-D np.ndarray
        outputs_array = subset.predicted_outputs.apply(pd.Series).values
        metrics, display_metrics = analyzer.compute_metrics(
            outputs_array,
            subset.targets.values,
            threshold=threshold,
            recall=recall,
        )

        metrics_table = pd.DataFrame(metrics.items(), columns=['Metric', 'Value'])
        metrics_table['Metric'] = metrics_table['Metric'].apply(lambda name: name.upper())

        plot_selector = define_dropdown(display_metrics.keys(), desc='Select plot')
        output = Output()

        # metrics as a single wide row, followed by the selector and its output area
        display(metrics_table.set_index('Metric').T, VBox([plot_selector, output]))

        # each group gets a handler bound to its own metrics and output widget
        plot_selector.observe(on_select_plot_wrapper(display_metrics, output))

In [None]:
# Reuse BEST_EPOCH rather than a hard-coded '99' so the summarized epoch always
# matches the checkpoint the analyzer was initialized with; it formats to the
# same 'epoch_99' column name.
attribute_summary(attribute='enroll_fever', epoch=BEST_EPOCH)