In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from os.path import dirname, join, exists
from copy import deepcopy
from typing import List
import multiprocessing as mp
import torch
import numpy as np
import pandas as pd
from scipy.special import softmax
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
from tqdm import tqdm
from IPython.display import display, HTML, clear_output, Markdown, Audio
from ipywidgets import HBox, Label, VBox, Dropdown, Layout, Output, Image

from cac.config import Config, DATA_ROOT
from cac.utils.logger import set_logger, color
from cac.data.dataloader import get_dataloader
from cac.analysis.classification import ClassificationAnalyzer

In [None]:
import warnings
# NOTE(review): with no category argument this ignores every warning for the
# rest of the session, including deprecation warnings from pandas/matplotlib.
warnings.simplefilter('ignore')

### Define inputs

In [None]:
# Experiment config file and user name; both are passed to `Config` below.
VERSION = 'experiments/covid-detection/v9_4_cough_adam_1e-4.yml'
USER = 'piyush'
# Checkpoint epoch handed to `ClassificationAnalyzer(checkpoint=...)` below.
BEST_EPOCH = 99

In [None]:
# Batch size and `num_workers` passed to `get_dataloader` below.
BATCH_SIZE = 10
NUM_WORKERS = 10

### Define config

In [None]:
# Build the experiment configuration from the version file and user name.
config = Config(VERSION, USER)

### Load data

In [None]:
# Dataloader for the 'val' split; the second return value is not used here.
# shuffle=False / drop_last=False: presumably keeps every sample in a fixed
# order so features line up with their attributes — TODO confirm.
val_dataloader, _ = get_dataloader(
    config.data, 'val',
    BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
    drop_last=False
)

### Initialize the analyzer module

In [None]:
# Analyzer built with the explicit BEST_EPOCH checkpoint and load_best=False.
analyzer = ClassificationAnalyzer(config, checkpoint=BEST_EPOCH, load_best=False, debug=True)

### Compute model embeddings (penultimate layer)

In [None]:
# `results` is read below through its 'features' and 'attributes' keys.
# last_layer_index=-2 presumably selects the penultimate layer (per the heading) — TODO confirm.
results = analyzer.compute_features(val_dataloader, last_layer_index=-2)

In [None]:
# Inspect the shape of the extracted features.
results['features'].shape

### Apply dimensionality reduction on the embeddings

In [None]:
# Features that are fed to the dimensionality-reduction step below.
X = results['features']

In [None]:
# Dimensionality-reduction settings handed to `analyzer.compute_embeddings`:
# the method name plus the keyword parameters for that method.
method_cfg = dict(
    name='TSNE',
    params=dict(n_components=2, random_state=0),
)

In [None]:
# Reduced embeddings; columns 0 and 1 of Z are used as plot coordinates below.
Z = analyzer.compute_embeddings(method_cfg, X)

In [None]:
# Inspect the shape of the reduced embeddings.
Z.shape

### Convert attributes to a DataFrame

In [None]:
# Per-sample attributes returned alongside the features.
attributes = results['attributes']

In [None]:
# Number of attribute records.
len(attributes)

In [None]:
# One row per attribute record. The plots below pass `df` together with Z, so
# row i of `df` is assumed to describe row i of Z — TODO confirm.
df = pd.DataFrame(attributes)

### Plotting function (feel free to tweak the options)

In [None]:
def scatter2d(x1, x2, row_values_: pd.DataFrame, label: str, legend: bool = True,
              ignore_list: List[dict] = [
                  {
                      'key': 'audio_type',
                      'values': ['audio_1_to_10', 'breathing']
                  }
              ], annotate=False,
              title=None):
    """Scatter-plot 2D points, coloured by the values of one attribute column.

    Args:
        x1: 1D array-like with the first coordinate of every point.
        x2: 1D array-like with the second coordinate of every point.
        row_values_: attributes of the points; row ``i`` (by position, whatever
            the index labels are) describes point ``(x1[i], x2[i])``. It is
            not modified.
        label: column of ``row_values_`` used to group and colour the points.
            Rows where this column is missing are not plotted.
        legend: whether to draw a legend (one entry per group, with its count).
        ignore_list: filters of the form ``{'key': column, 'values': [...]}``;
            rows whose ``column`` value is in ``values`` are not plotted. The
            default list is only read, never mutated.
        annotate: if True, every point is annotated with ``P<k>`` where ``k``
            is the position of its group among the sorted unique labels.
        title: optional axes title.

    Raises:
        AssertionError: if ``label`` is not a column of ``row_values_`` or the
            coordinates and the frame do not have the same length.
        KeyError: if an ``ignore_list`` key is not a column of ``row_values_``.
    """
    # check if the label column exists
    assert label in row_values_.columns, \
        'column {!r} not found in row_values_'.format(label)

    x1 = np.asarray(x1)
    x2 = np.asarray(x2)
    assert len(x1) == len(x2) == len(row_values_), \
        'x1, x2 and row_values_ must have the same length'

    # Positional boolean mask of the rows to keep. Selecting by position
    # (rather than by index label) keeps coordinates and attributes aligned
    # even when the frame does not carry a default 0..n-1 index.
    keep = row_values_[label].notna().to_numpy()

    # ignore certain values in given columns
    for ignore_dict in ignore_list:
        key, values = ignore_dict['key'], ignore_dict['values']
        keep = keep & ~row_values_[key].isin(values).to_numpy()

    # retain only the relevant points and their labels
    x1 = x1[keep]
    x2 = x2[keep]
    labels = row_values_[label].loc[keep].values
    unique_labels = np.unique(labels)

    colors = cm.plasma(np.linspace(0, 1, len(unique_labels)))

    f, ax = plt.subplots(1, figsize=(10, 10))

    # Loop names are distinct from the `label` parameter and from the
    # module-level `color` import so neither is shadowed.
    for group_idx, (group, group_color) in enumerate(zip(unique_labels, colors)):
        indices = np.where(labels == group)
        num = len(indices[0])
        ax.scatter(x1[indices], x2[indices], label='{} : {}'.format(group, num), color=group_color)

        if annotate:
            for j in indices[0]:
                ax.annotate('P{}'.format(group_idx), (x1[j] + 0.1, x2[j] + 0.1))

    ax.set_ylabel('Component 2')
    ax.set_xlabel('Component 1')

    if title is not None:
        ax.set_title(title)

    ax.grid()

    # skip the legend when nothing was plotted (there would be no entries)
    if legend and len(unique_labels) > 0:
        ax.legend(loc='best')


### Sanity check plotting function

In [None]:
# Colour by audio type with no filters (ignore_list=[] overrides the default filter).
scatter2d(Z[:, 0], Z[:, 1], df, label='audio_type', ignore_list=[])

### Data source

In [None]:
# Build a "<state>: <facility>" label per row.
# Vectorised on purpose: the previous row-wise `apply` read the row with
# integer keys (`x[0]`, `x[1]`), which pandas deprecates for label-indexed
# Series, and raised a TypeError whenever `enroll_state` was missing. Here a
# missing state simply yields a missing label. The facility goes through
# `': {}'.format`, so it is formatted exactly as before.
df['data-source'] = df['enroll_state'] + df['enroll_facility'].map(': {}'.format)

In [None]:
# Filters handed to `scatter2d`. The audio-type filter excludes nothing; the
# test-result filter is re-pointed for each of the three views below.
ignore_list = [
    {'key': 'audio_type', 'values': []},
    {'key': 'testresult_covid_test_result', 'values': []},
]

# (plot title, test results to leave out) for each view, in plotting order.
for plot_title, excluded_results in (
    ('Cough sound embeddings by data source: All', []),
    ('Cough sound embeddings by data source: Negatives', ['Positive']),
    ('Cough sound embeddings by data source: Positives', ['Negative']),
):
    ignore_list[1]['values'] = excluded_results
    scatter2d(
        Z[:, 0], Z[:, 1], df,
        label='data-source',
        title=plot_title,
        ignore_list=ignore_list,
    )

### Disease status

In [None]:
# Prefer `disease_status` where that column exists, falling back to the test
# result for rows where it is missing; otherwise use the test result alone.
# `.str.upper()` leaves missing values missing, whereas calling `x.upper()`
# on each element raised AttributeError on the first NaN.
df['covid_status'] = (
    df['disease_status'].combine_first(df['testresult_covid_test_result'])
    if 'disease_status' in df
    else df['testresult_covid_test_result']
).str.upper()

In [None]:
# No ignore_list is passed, so scatter2d's default audio-type filter applies.
scatter2d(Z[:, 0], Z[:, 1], df, label='covid_status', title='Cough sound embeddings by disease status: V1.1 (Only NMCH)')

### Patient Identity

In [None]:
# Filters for the patient-identity plot (also reused by the Gender cell below).
# NOTE(review): scatter2d looks up every 'key' as a column, so `df` must have a
# 'dataset-name' column even though its 'values' list is empty — confirm it exists.
ignore_list = [
    {
        'key': 'audio_type',
        'values': ['audio_1_to_10', 'breathing']
    },
    {
        'key': 'dataset-name',
        'values': []
    }
]

# One colour per unique_id; points are annotated and the legend is switched off.
scatter2d(Z[:, 0], Z[:, 1], df, label='unique_id', annotate=True, legend=False, ignore_list=ignore_list, title='[V1.1] Coughs by patient IDs: TSNE')

### Gender

In [None]:
# Uses the `ignore_list` defined in the Patient Identity cell; run that cell first.
scatter2d(Z[:, 0], Z[:, 1], df, label='enroll_patient_gender', title='Cough sounds by gender: Facility', ignore_list=ignore_list)

### Symptoms

In [None]:
# Symptom names; each is prefixed with 'enroll_' below to form a column of `df`.
all_symptoms = ['cough', 'fever', 'shortness_of_breath']

In [None]:
# Filter out the two non-cough audio types (also reused by the Age plot below).
ignore_list: List[dict] = [
    {
        'key': 'audio_type',
        'values': ['audio_1_to_10', 'breathing']
    }
]

# One plot per symptom, coloured by the matching 'enroll_<symptom>' column.
for symptom in all_symptoms:
    scatter2d(Z[:, 0], Z[:, 1], df, label='enroll_' + symptom, title='Cough sounds by {}'.format(symptom), ignore_list=ignore_list)

### Age

In [None]:
# Bucket patient ages into ten-year intervals with edges 0, 10, ..., 100.
# Ages that fall outside the edges get a missing bucket.
age_bin_edges = list(range(0, 101, 10))
df['age_bucket'] = pd.cut(df['enroll_patient_age'], bins=age_bin_edges)

In [None]:
# Uses the `ignore_list` defined in the Symptoms cell; run that cell first.
scatter2d(Z[:, 0], Z[:, 1], df, label='age_bucket', title='Cough sounds by age-bucket', ignore_list=ignore_list)