In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from os.path import join
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.colors as mc
from matplotlib import rc
import matplotlib.cm as cm
import colorsys

from cac.analysis.utils import get_audio_type, get_unique_id
from cac.utils.audio import get_duration
from cac.utils.io import read_yml

In [None]:
rc('text', usetex=True)
rc("font", family="serif", serif='Computer Modern Roman')

In [None]:
FIGURES_SAVE_DIR = '/all-output/paper/iclrw/figures'
os.makedirs(FIGURES_SAVE_DIR, exist_ok=True)

#### Helper functions

In [None]:
def apply_filters(df, filters):
    """Return a filtered copy of ``df`` keeping rows that satisfy every filter.

    Args:
        df: source DataFrame; it is never modified.
        filters: mapping of column name -> wanted value. A list or numpy
            array means "column value is one of these"; anything else is
            compared with ``==``.

    Returns:
        A new DataFrame. Its index is reset to 0..n-1 after each applied
        filter, so an empty ``filters`` mapping returns a plain copy that
        still carries the original index.
    """
    filtered = df.copy()

    for column, wanted in filters.items():
        membership_test = isinstance(wanted, (list, np.ndarray))
        if membership_test:
            mask = filtered[column].isin(wanted)
        else:
            mask = filtered[column] == wanted
        # Keep matching rows and renumber so positional lookups like [0] work.
        filtered = filtered[mask].copy().reset_index(drop=True)

    return filtered

In [None]:
def custom_eval(x):
    """Parse a bracketed, comma-separated string into a list of trimmed tokens.

    Square brackets are removed wherever they occur, the remainder is split on
    commas and each piece is stripped of surrounding whitespace. Quotes are
    not removed. Any non-string input (e.g. NaN or None) yields ``['NA']``.
    """
    if not isinstance(x, str):
        return ['NA']

    without_brackets = x.replace('[', '').replace(']', '')
    return [piece.strip() for piece in without_brackets.split(',')]

In [None]:
def split_column_into_columns(df, column):
    """Expand a list-like string column into one 'Yes'/'No' column per value.

    ``df[column]`` is first parsed with ``custom_eval`` (the parsed lists
    replace the original strings). For every distinct value found, other than
    'NA' and the empty string, a column named after that value is added,
    holding 'Yes' for rows that list it and 'No' otherwise. A summary column
    ``any_<column>`` is 'Yes' when the row has at least one such value.

    Args:
        df: DataFrame to expand; it is modified in place.
        column: name of the column holding the list-like strings.

    Returns:
        The same ``df`` object, for convenient re-assignment.
    """
    df[column] = df[column].apply(custom_eval)

    unique_values = []

    for i in tqdm(range(len(df))):
        index = df.index[i]

        list_of_values = df.loc[index, column]

        for x in list_of_values:
            if (x != 'NA') and (x != ''):
                # Write into the frame that was passed in. The previous code
                # wrote into the notebook-level `attributes` frame, so any
                # other input frame was silently left unmodified.
                df.loc[index, x] = 'Yes'
                if x not in unique_values:
                    unique_values.append(x)

    if unique_values:
        df[unique_values] = df[unique_values].fillna('No')
        has_any = (df[unique_values] == 'Yes').any(axis=1)
        df[f'any_{column}'] = has_any.map({True: 'Yes', False: 'No'})
    else:
        # No row listed a usable value: nothing to expand, summary is all 'No'.
        df[f'any_{column}'] = 'No'
    return df

#### Load attributes

In [None]:
attributes = pd.read_csv('/data/wiai-facility/processed/attributes.csv')

In [None]:
attributes = split_column_into_columns(attributes, 'enroll_comorbidities')

In [None]:
attributes = split_column_into_columns(attributes, 'enroll_habits')

In [None]:
attributes.shape

#### Load annotations

In [None]:
annotations = pd.read_csv('/data/wiai-facility/processed/annotation.csv')

In [None]:
annotations['patient_id'] = annotations['users'].apply(get_unique_id)
annotations['audio_type'] = annotations['file'].apply(get_audio_type)

In [None]:
annotations.shape

#### Load current set

In [None]:
current_data_config = read_yml('/data/wiai-facility/processed/versions/v9.4.yml')

In [None]:
# Build one frame per split from the version config and tag every row with its split name.
df_train = pd.DataFrame(current_data_config['train'])
df_train['set'] = 'train'
df_val = pd.DataFrame(current_data_config['val'])
df_val['set'] = 'val'
df_test = pd.DataFrame(current_data_config['test'])
df_test['set'] = 'test'

current_set = pd.concat([df_train, df_val, df_test],axis=0)
# NOTE(review): reset_index() without drop=True keeps each row's previous
# index label as a new 'index' column — confirm that column is wanted.
current_set = current_set.reset_index()

# Both derived columns are computed from the 'file' value by helpers imported
# from cac.analysis.utils (their exact parsing rules are not visible here).
current_set['patient_id'] = current_set['file'].apply(get_unique_id)
current_set['audio_type'] = current_set['file'].apply(get_audio_type)

In [None]:
len(current_set['file'].unique())

In [None]:
current_set.shape, current_set.shape[0] // 3

In [None]:
current_set = pd.merge(current_set, attributes, on='patient_id')

In [None]:
current_set.shape

In [None]:
current_set = current_set.drop_duplicates(subset=['patient_id'])

In [None]:
current_set.shape

In [None]:
current_set['Age (years)'] = pd.cut(current_set.enroll_patient_age, bins=[0, 19, 29, 39, 49, 59, 69, 79, 89, 100])

In [None]:
current_set['Gender'] = current_set['enroll_patient_gender']

In [None]:
current_set['set'].value_counts()

In [None]:
current_set['label'].astype(str).value_counts() * 3

In [None]:
apply_filters(current_set, {'set': 'test'})['enroll_facility'].value_counts()

In [None]:
apply_filters(current_set, {'set': 'test'})['testresult_end_time'].min()

In [None]:
# taken from V9.4
cutoff_date = '2020-10-09T19:34:01.272GMT+05:30'

### Data splitting strategy

In [None]:
from datetime import datetime as dt

In [None]:
# Week number of each test result: take the date part before the 'T' and map it
# to its ISO calendar week (isocalendar()[1]).
# NOTE(review): ISO week numbers restart every year, so the later subtraction of
# the minimum week assumes all samples fall within one ISO year — TODO confirm.
current_set['week_number'] = current_set['testresult_end_time'].apply(
    lambda x: dt.strptime(x.split('T')[0], '%Y-%m-%d').isocalendar()[1]
)

In [None]:
# Prefix three facilities with 'Z-'; every other facility name is kept as is.
# NOTE(review): presumably this makes them sort last so they receive the highest
# category codes when 'facility_code' is later converted with .cat.codes, placing
# them above the site-based separation line in the grid plot — confirm.
current_set['facility_code'] = current_set['enroll_facility'].apply(lambda x: x if x not in ['NMCH', 'KIMS, Satara', 'DCH Baleshwar'] else f'Z-{x}')

In [None]:
current_set['facility_code'].value_counts()

In [None]:
# Week of the cutoff sample, re-based to the first week in the data. The [0]
# lookup works because apply_filters resets the index of its result; it raises
# KeyError if no row has exactly this timestamp. The extra -0.5 shifts the value
# half a week to the left of the cutoff week (used later as the x position of
# the time-based separation line).
cutoff_week_number = apply_filters(current_set, {'testresult_end_time': cutoff_date})['week_number'][0] - 0.5 - current_set['week_number'].min()

In [None]:
current_set['facility_code'] = current_set['facility_code'].astype('category').cat.codes

In [None]:
# Re-base week numbers so the first week of data collection is week 0.
current_set['week_number_recounted'] = current_set['week_number'] - current_set['week_number'].min()
all_weeks = sorted(current_set['week_number_recounted'].unique())
all_sites = sorted(current_set['facility_code'].unique())

# Count rows per (site, week) in one vectorised pass instead of incrementing a
# cell per row. reindex guarantees every site/week appears (0 where no rows),
# and rename_axis drops the axis names so the frame matches a plain
# DataFrame(0, index=all_sites, columns=all_weeks) filled with counts.
df = (
    pd.crosstab(current_set['facility_code'], current_set['week_number_recounted'])
    .reindex(index=all_sites, columns=all_weeks, fill_value=0)
    .rename_axis(index=None, columns=None)
)

In [None]:
TEST_SEPERATION_INDEX_FOR_SITE = 23.5
TEST_SEPERATION_INDEX_FOR_TIME = cutoff_week_number

In [None]:
# Data slicing grid: one marker per (week, site) cell of `df`, sized by the
# number of rows counted in that cell, with the time-based and site-based
# test regions shaded.
fig, ax = plt.subplots(1, 1, figsize=(13, 7))

ax.grid()
# ax.set_title("Data Slicing Grid", fontsize=35)
ax.set_xlabel("Time (No. weeks from start of data collection)", fontsize=27)
ax.set_ylabel("Collection site index", fontsize=27)
ax.set_xticks(df.columns)
ax.set_yticks(df.index)

# Marker area scales linearly with the cell count; a count of 0 gives size 0.
for week in df.columns:
    for site in df.index:
        count = df.at[site, week]
        ax.scatter(week, site, s= (count * 1.8 + 0), c='red')

# ax.axvline(x=TEST_SEPERATION_INDEX_FOR_TIME, label="Time-based", c='blue')
# ax.axhline(y=TEST_SEPERATION_INDEX_FOR_SITE, label="Site-based", c='darkgreen')
# Separation lines are drawn without labels; the legend entries come from the shaded spans below.
ax.axvline(x=TEST_SEPERATION_INDEX_FOR_TIME, c='blue')
ax.axhline(y=TEST_SEPERATION_INDEX_FOR_SITE, c='darkgreen')

ax.margins(x=0, y=0)
# NOTE(review): the span end points (30 weeks, site index 27) are hard-coded
# rather than derived from df.columns / df.index — confirm they still cover the data.
ax.axvspan(TEST_SEPERATION_INDEX_FOR_TIME, 30, alpha=0.1, color='blue', label='Time-based')
ax.axhspan(TEST_SEPERATION_INDEX_FOR_SITE, 27, alpha=0.1, color='green', label='Site-based')

plt.legend(loc='center left', fontsize=27, bbox_to_anchor=(0.0,0.35))

path = join(FIGURES_SAVE_DIR, 'data_slicing_grid_v3.pdf')
plt.savefig(path, bbox_inches='tight')
plt.show()

### Age, gender, facility, disease status and duration distributions

In [None]:
def lighten_color(color, amount=0.5):
    """
    Rescale the lightness of a color by multiplying (1 - lightness) by ``amount``.

    ``amount`` below 1 lightens the color, ``amount == 1`` leaves it unchanged
    and ``amount`` above 1 darkens it. Input can be a matplotlib color string,
    a hex string, or an RGB tuple.

    Examples:
    >>> lighten_color('g', 0.3)
    >>> lighten_color('#F034A3', 0.6)
    >>> lighten_color((.3,.55,.1), 0.5)

    Returns:
        An ``(r, g, b)`` tuple of floats.
    """
    try:
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # KeyError: not a named color (hex string, RGB tuple, short code such
        # as 'g'); TypeError: unhashable spec such as a list. Either way, let
        # mc.to_rgb interpret the value directly. Other exceptions propagate
        # rather than being swallowed by a bare except.
        c = color
    c = colorsys.rgb_to_hls(*mc.to_rgb(c))
    return colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2])

In [None]:
def plot_real_valued_distribution(df, column, title, legend=False, show_mean=True, show_xlabel=True, show_ylabel=True,
                                  loc='upper right', size=(20, 6), kde=False,
                                  ceil=100, save=False, name='sample', ylabel=None, xlabel=None):
    """Plot a histogram of a numeric column, with values capped at ``ceil``.

    Args:
        df: DataFrame holding the data; it is not modified.
        column: name of the numeric column to plot.
        title: axes title.
        legend: draw the legend (it contains the mean line when ``show_mean``).
        show_mean: draw a dashed vertical line at the mean of the capped values.
            The legend text is always suffixed with 'sec'.
        show_xlabel, show_ylabel: when False the corresponding label is blank.
        loc: legend location.
        size: figure size passed to ``plt.subplots``.
        kde: forwarded to ``sns.distplot``.
        ceil: upper cap; larger values are plotted as ``ceil``.
        save: write the figure to ``FIGURES_SAVE_DIR/<name>.pdf``.
        name: file stem used when ``save`` is True.
        ylabel, xlabel: label overrides; default to what seaborn sets.
    """
    fig, ax = plt.subplots(1, 1, figsize=size)

    ax.grid()
    # Cap on a local Series. The previous `df.at[mask, column] = ceil` misused
    # the scalar accessor `.at` with a boolean mask and overwrote the caller's
    # data; clip() gives the same capped values and keeps the Series name.
    values = df[column].clip(upper=ceil)

    g = sns.distplot(values, color=lighten_color('#FE7465', 1.0), ax=ax, kde=kde, hist_kws=dict(edgecolor="#FE7465", linewidth=1))
    ax.set_title(title, fontsize=25)

    if show_mean:
        # Mean of the capped values, as before.
        mean = np.round(values.mean(), 2)
        ax.axvline(x=mean, label=f'Mean: {mean} sec', linestyle='--', color='black', linewidth=1.5)

    ylabel = ylabel if ylabel is not None else g.get_ylabel()
    ylabel = '' if not show_ylabel else ylabel
    ax.set_ylabel(ylabel, fontsize=22)

    xlabel = xlabel if xlabel is not None else g.get_xlabel()
    xlabel = '' if not show_xlabel else xlabel
    ax.set_xlabel(xlabel, fontsize=22)

    ax.tick_params(axis="x", labelsize=18)
    ax.tick_params(axis="y", labelsize=18)

    if legend:
        plt.legend(loc=loc, fontsize=22)

    if save:
        path = join(FIGURES_SAVE_DIR, f'{name}.pdf')
        plt.savefig(path, bbox_inches='tight')

    plt.show()

In [None]:
def show_counts(graph, df, column=None, use_list=False, list_of_values=None):
    """Write a count above every bar of ``graph``.

    The i-th patch is paired with the i-th x tick label, and the count shown
    is the number of occurrences of that label.

    Args:
        graph: axes-like object exposing ``patches``, ``get_xticklabels()``
            and ``text()``.
        df: DataFrame used for counting when ``use_list`` is False.
        column: column of ``df`` whose ``value_counts()`` supplies the counts.
        use_list: count occurrences in ``list_of_values`` instead of ``df``.
        list_of_values: values to count; required when ``use_list`` is True.

    Raises:
        IndexError: if there are more patches than x tick labels.
        KeyError: if a tick label does not occur in the counted values.
    """
    # The counts do not depend on the patch, so compute the lookup once.
    if use_list:
        assert list_of_values is not None
        vals, counts = np.unique(list_of_values, return_counts=True)
        count_by_label = dict(zip(vals, counts))
    else:
        count_by_label = df[column].value_counts()

    # Read labels from the `graph` argument; the previous code used an
    # undefined global `g` here.
    tick_labels = [tick.get_text() for tick in graph.get_xticklabels()]

    for i, p in enumerate(graph.patches):
        count = count_by_label[tick_labels[i]]
        # Centre the text horizontally on the bar, 2 data units above its top.
        graph.text(p.get_x()+p.get_width()/2., p.get_height() + 2.0, count, ha="center")


def change_width(ax, new_value, num_hue=2, num_classes=2):
    """Set every bar in ``ax.patches`` to width ``new_value`` and shift it.

    The first two patches are moved right by half of the removed width, and
    all later patches are moved left by the same amount. ``num_hue`` and
    ``num_classes`` are accepted but currently not used.
    """
    for position, bar in enumerate(ax.patches):
        removed = bar.get_width() - new_value

        # Resize first, then nudge the left edge by half of what was removed.
        bar.set_width(new_value)

        half_shift = removed * 0.5
        if position < 2:
            bar.set_x(bar.get_x() + half_shift)
        else:
            bar.set_x(bar.get_x() - half_shift)


def plot_categorical_distribution(df, column, title, show_xlabel=True, show_ylabel=True, loc='upper right',
                                  show_hue=True, hue='testresult_covid_test_result',
                                  hue_order=['Positive', 'Negative'], size=(15, 6), rotation=0, palette='Blues_r',
                                  counts=False, ylabel=None, xlabel=None, save=False, name='sample',
                                  reduce_width=False, new_width=0.3, xticklabels=[]):
    """Draw a count plot of ``df[column]``, optionally split by COVID test result.

    Args:
        df: DataFrame to plot. A 'COVID' column (copy of
            'testresult_covid_test_result') is added to it in place.
        column: categorical column placed on the x axis.
        title: axes title.
        show_xlabel, show_ylabel: when False the corresponding label is blank.
        loc: legend location (the legend is only drawn when ``show_hue``).
        show_hue: split each bar by the 'COVID' column.
        hue: NOTE(review): not used — the hue column is always 'COVID'.
        hue_order: order of the hue levels.
        size: figure size passed to ``plt.subplots``.
        rotation: rotation of the x tick labels.
        palette: NOTE(review): overwritten inside the function with a fixed
            two-color list, so the passed value has no effect.
        counts: annotate bars through ``show_counts``.
        ylabel, xlabel: label overrides; default to what seaborn sets.
        save: write the figure to ``FIGURES_SAVE_DIR/<name>.pdf``.
        name: file stem used when ``save`` is True.
        reduce_width: narrow the bars to ``new_width`` via ``change_width``.
        new_width: bar width used when ``reduce_width`` is True.
        xticklabels: replacement tick labels; when empty the labels produced
            by seaborn are kept.
    """

    fig, ax = plt.subplots(1, 1, figsize=size)

    ax.set_title(title, fontsize=28)
    
    # Short column name for the hue; it is what sns.countplot is given as `hue` below.
    df['COVID'] = df['testresult_covid_test_result']
    # Fixed colors; this replaces whatever was passed as `palette`.
    palette = ['#FE7465', '#51B867']
    
    if show_hue:
        g = sns.countplot(x=column, data=df, ax=ax, hue='COVID', hue_order=hue_order, palette=palette)
    else:
        g = sns.countplot(x=column, data=df, ax=ax, palette=palette, label=list(df[column].unique()))
    
    # Fall back to the labels seaborn generated when no replacements were given.
    if not len(xticklabels):
        xticklabels = g.get_xticklabels()
    g.set_xticklabels(xticklabels, rotation=rotation)
    if counts:
        show_counts(g, df, column)
    
    ylabel = ylabel if ylabel is not None else g.get_ylabel()
    ylabel = '' if not show_ylabel else ylabel
    ax.set_ylabel(ylabel, fontsize=24)
    
    xlabel = xlabel if xlabel is not None else g.get_xlabel()
    xlabel = '' if not show_xlabel else xlabel
    ax.set_xlabel(xlabel, fontsize=24)

    if reduce_width:
        change_width(ax, new_width)

    ax.tick_params(axis="x", labelsize=21)
    ax.tick_params(axis="y", labelsize=21)

    ax.grid()
    if show_hue:
        plt.legend(loc=loc, fontsize=24)
    
    if save:
        path = join(FIGURES_SAVE_DIR, f'{name}.pdf')
        plt.savefig(path, bbox_inches='tight')
    

    plt.show()

In [None]:
def plot_pie_chart(df, column, title, size):
    """Draw a pie chart of the value counts of ``df[column]``.

    Args:
        df: DataFrame holding the data.
        column: categorical column whose value counts become the wedges.
        title: axes title.
        size: figure size passed to ``plt.subplots``.

    Wedges follow ``value_counts()`` order (most frequent first). Every wedge
    except the first is pulled out slightly; with two categories this is the
    original ``(0, 0.1)`` layout. The two base colors are reused cyclically by
    matplotlib when there are more wedges than colors.
    """
    fig, ax = plt.subplots(1, 1, figsize=size)

    _dict = dict(df[column].value_counts())
    labels = list(_dict.keys())
    colors = ['#fc4f30', '#008fd5']
    # One explode entry per wedge: ax.pie raises if the lengths differ.
    explode = tuple(0 if i == 0 else 0.1 for i in range(len(labels)))

    # '\\%' yields a literal backslash-percent, needed because usetex is on;
    # the earlier '\%' was an invalid escape sequence in a non-raw string.
    ax.pie(list(_dict.values()), startangle=90, colors=colors, wedgeprops={'edgecolor': 'grey'}, autopct=lambda x: f'{np.round(x, 2)} \\%',
           labels=labels, explode=explode, shadow=True, labeldistance=None, pctdistance=0.5, textprops={'fontsize': 15})
    ax.set_title(title)

    plt.legend(loc='upper right', fontsize=15)
    plt.show()

In [None]:
# plot_pie_chart(current_set, 'enroll_patient_gender', '', (12, 12))

In [None]:
plot_categorical_distribution(current_set, 'Age (years)', title='Age (Years)', show_xlabel=False, show_ylabel=False, ylabel='No. of individuals',
                              rotation=0, palette='ch:start=0.5,hue=3.5_r', save=True, name='age_v4', size=(14, 6))

In [None]:
plot_categorical_distribution(current_set, 'Gender', title='Sex', ylabel='No. of individuals', show_xlabel=False, show_ylabel=False,
                              rotation=0, palette='ch:start=0.5,hue=3.5_r', size=(6, 6), loc='upper right',
                              save=True, name='gender_v5', reduce_width=True, new_width=0.25)

In [None]:
current_set['Facility'] = current_set['enroll_facility']

In [None]:
num_facilities = len(current_set['Facility'].unique())

In [None]:
num_facilities

In [None]:
plot_categorical_distribution(current_set, 'Facility', title='LOCATION', ylabel='No. of individuals', show_xlabel=False, show_ylabel=False,
                              rotation=0, palette='ch:start=0.5,hue=3.5_r', size=(20, 6),
                              loc='best', save=True, name='facility_v6', xticklabels=[f'F{x}' for x in range(num_facilities)])

In [None]:
current_set['testresult_covid_test_result'].value_counts()

In [None]:
current_set['COVID Status'] = current_set['testresult_covid_test_result']
plot_categorical_distribution(current_set, 'COVID Status', title='Disease Status', ylabel='No. of individuals', show_xlabel=False, show_ylabel=True,
                              show_hue=False, size=(5, 6), loc='upper right', save=True, name='covid_v4', reduce_width=True, new_width=0.3)

In [None]:
current_set['Duration (seconds)'] = current_set['end']
plot_real_valued_distribution(current_set, 'Duration (seconds)', title='DURATION (seconds)', size=(12, 5), show_xlabel=False,
                              ylabel='No. of cough samples', legend=True, save=True, name='duration_v4')

In [None]:
current_set.enroll_patient_gender.value_counts()