In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from os.path import join, splitext, basename
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.colors as mc
from matplotlib import rc
import matplotlib.cm as cm
import colorsys

from cac.analysis.utils import get_audio_type, get_unique_id
from cac.utils.audio import get_duration
from cac.utils.io import read_yml
from cac.utils.pandas import custom_read_csv, apply_filters
from cac.utils.viz import lighten_color

In [None]:
# Output directory for the figures saved by this notebook. It is created up
# front (no error if it already exists) so later save calls do not fail on a
# missing folder.
FIGURES_SAVE_DIR = '/all-output/paper/iclrw/figures'
os.makedirs(FIGURES_SAVE_DIR, exist_ok=True)

In [None]:
# Folder holding the per-version data configs; each version is read from
# `<DATA_CONFIG_DIR>/<version>.yml` further below.
DATA_CONFIG_DIR = '/data/wiai-facility/processed/versions/'

#### Helper functions

#### Load attributes

In [None]:
# Load the attributes table; it is merged onto the per-split file lists by
# `patient_id` further below.
# NOTE(review): the second argument presumably names columns that need special
# parsing by `custom_read_csv` -- confirm against cac.utils.pandas.
attributes = custom_read_csv(
    '/data/wiai-facility/processed/attributes.csv',
    ['enroll_comorbidities', 'enroll_habits']
)

In [None]:
# Symptom columns that are summarised into two per-row "Yes"/"No" flags.
symptoms = ['enroll_cough', 'enroll_fever', 'enroll_shortness_of_breath']
symptom_answers = attributes[symptoms]

# `all_symptoms` is "Yes" only when no symptom column holds the literal "No"
# (any other value, including a missing one, does not count as "No").
attributes['all_symptoms'] = symptom_answers.apply(
    lambda answers: "No" if "No" in list(answers) else "Yes",
    axis=1,
)

# `any_symptoms` is "Yes" as soon as one symptom column holds the literal "Yes".
attributes['any_symptoms'] = symptom_answers.apply(
    lambda answers: "No" if "Yes" not in list(answers) else "Yes",
    axis=1,
)

In [None]:
# Size check: (rows, columns) of the attributes table, including the two
# symptom flags added above.
attributes.shape

#### Load annotations

In [None]:
# Load the annotation table with pandas' default CSV parsing.
annotations = pd.read_csv('/data/wiai-facility/processed/annotation.csv')

In [None]:
# Derive two helper columns by applying the cac.analysis.utils helpers
# element-wise to the raw `users` and `file` columns.
# NOTE(review): presumably `users` holds a raw user identifier and `file` an
# audio file name -- confirm against `get_unique_id` / `get_audio_type`.
annotations['patient_id'] = annotations['users'].apply(get_unique_id)
annotations['audio_type'] = annotations['file'].apply(get_audio_type)

In [None]:
# Size check: (rows, columns) of the annotation table with the derived columns.
annotations.shape

#### Load data

In [None]:
# Dataset versions to load. Each needs a `<version>.yml` under DATA_CONFIG_DIR
# and, for the plot at the end, an entry in `version_to_title`.
DATA_VERSIONS_TO_LOAD = ['v9.4', 'v9.7', 'v9.8']

In [None]:
def get_data_from_mode(data_config, mode, audio='cough'):
    """Build an attribute-annotated frame for one split of a data config.

    Args:
        data_config: mapping whose ``mode`` entry can be turned into a
            DataFrame that has a ``file`` column.
        mode: key of the split to read from ``data_config``; also stored in
            the ``set`` column of the result.
        audio: audio-type token; the file stem is cut at the first
            occurrence of ``_<audio>`` before the patient id is extracted.

    Returns:
        The split's frame with ``patient_id`` and ``set`` columns added,
        merged with the notebook-level ``attributes`` table on
        ``patient_id``. ``pd.merge`` is called with its default join, so
        rows without a matching ``patient_id`` in ``attributes`` are dropped.

    Note:
        If ``_<audio>`` does not occur in a file stem, ``str.find`` returns
        -1 and the slice removes only the last character of the stem.
    """
    marker = f'_{audio}'

    def _patient_id(path):
        # File stem without directory and extension.
        stem = splitext(basename(path))[0]
        # Keep everything before the audio marker ...
        prefix = stem[:stem.find(marker)]
        # ... and drop the last two underscore-separated tokens.
        return '_'.join(prefix.split('_')[:-2])

    frame = pd.DataFrame(data_config[mode])
    frame['patient_id'] = [_patient_id(path) for path in frame.file.values]
    frame['set'] = mode
    return pd.merge(frame, attributes, on=['patient_id'])

In [None]:
# Keys of a data config that hold file lists and get expanded into frames.
split_names = ['train', 'val', 'test', 'all']

# Maps a data version to its loaded config. Inside each config the split
# entries are replaced by attribute-annotated DataFrames.
data_cfgs = {}

for data_version in tqdm(DATA_VERSIONS_TO_LOAD, desc='Loading data versions'):
    data_cfg_path = join(DATA_CONFIG_DIR, data_version + '.yml')
    data_cfg = read_yml(data_cfg_path)
    data_cfgs[data_version] = data_cfg

    # The matching keys are collected first, then each one is overwritten in
    # place; every other key of the config is left as loaded.
    for mode in [key for key in data_cfg.keys() if key in split_names]:
        data_cfg[mode] = get_data_from_mode(data_cfg, mode, audio='cough')

In [None]:
# Confirm which data versions were loaded.
data_cfgs.keys()

In [None]:
# Size check: (rows, columns) of the v9.8 validation split after the attribute merge.
data_cfgs['v9.8']['val'].shape

In [None]:
# Compare the size of the 'all' split across the three loaded versions.
data_cfgs['v9.8']['all'].shape, data_cfgs['v9.7']['all'].shape, data_cfgs['v9.4']['all'].shape

In [None]:
# Distribution of the COVID test result over the rows of the v9.8 'all' split.
# Rows are not de-duplicated by patient at this point.
data_cfgs['v9.8']['all'].testresult_covid_test_result.value_counts()

In [None]:
# Render figure text through LaTeX with a Computer Modern serif font.
# NOTE: `usetex=True` needs a working LaTeX installation on the machine.
rc('text', usetex=True)
rc("font", family="serif", serif='Computer Modern Roman')

In [None]:
# Default font size for all following figures.
plt.rcParams['font.size'] = '23'

In [None]:
# NOTE(review): repeats the value assigned to FIGURES_SAVE_DIR near the top of
# the notebook; keep both in sync or drop one of them.
FIGURES_SAVE_DIR = '/all-output/paper/iclrw/figures'

In [None]:
# Subplot title used for each data version in the distribution figure below.
version_to_title = {
    'v9.4': "Time-based",
    'v9.7': "Site-based",
    'v9.8': "Random"
}

In [None]:
# One panel per loaded data version: number of unique patients with a
# Positive / Negative COVID test result in the Train and Test sets.
attribute = 'testresult_covid_test_result'
mode = 'all'
fig, ax = plt.subplots(1, len(data_cfgs), figsize=(8 * len(data_cfgs), 8))
# `plt.subplots` returns a bare Axes (not an array) when only one version is
# loaded; normalise so that `ax[i]` works for any number of versions.
ax = np.atleast_1d(ax)

for i, version in enumerate(data_cfgs.keys()):

    # Relabel the splits for display (this edits the `set` column of the
    # loaded frames in place). Validation rows get `None`, which is not in
    # `order` below, so they are not drawn; they still take part in the
    # patient-level de-duplication.
    data_cfgs[version]['train']['set'] = 'Train'
    data_cfgs[version]['val']['set'] = None
    data_cfgs[version]['test']['set'] = 'Test'

    df = pd.concat([data_cfgs[version]['train'], data_cfgs[version]['val'], data_cfgs[version]['test']])
    df['Test result'] = df[attribute]
    # Count patients rather than files: keep the first row of each patient.
    df = df.drop_duplicates(['patient_id'])

    order = ['Train', 'Test']
    hue_order = ['Positive', 'Negative']

    sns.countplot(
        data=df, x='set', ax=ax[i], hue='Test result',
        order=order, hue_order=hue_order, palette=['red', 'limegreen'],
    )
    ax[i].grid()
    ax[i].set_title(version_to_title[version], fontsize=30)
    ax[i].set_xlabel('')
    # Only the left-most panel keeps its y-axis label.
    if i > 0:
        ax[i].set_ylabel('')

    # Annotate every bar with its count, slightly right of the bar's left
    # edge and 25 units above its top.
    _ax = ax[i]
    patches = _ax.patches
    for patch in patches:
        x, _ = patch.xy
        counts = patch.get_height()
        # A bar without a finite height has nothing to label, and a
        # non-finite text position is not valid.
        if not np.isfinite(counts):
            continue
        # Heights are floats; show them as whole numbers ("123", not "123.0").
        _ax.text(x + 0.1, counts + 25, f'{int(round(counts))}')

plt.savefig(join(FIGURES_SAVE_DIR, 'data-dist-v2.pdf'), bbox_inches='tight')
plt.show()