In [None]:
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import os
import seaborn as sns
import numpy as np

In [None]:
# Preprocessing run to visualise. Set OPSUM_PREPRO_DIR / OPSUM_PREPRO_TIMESTAMP in the
# environment to point the notebook at another run or machine without editing this cell;
# the defaults reproduce the original hard-coded paths.
PREPRO_DIR = os.environ.get(
    'OPSUM_PREPRO_DIR',
    '/Users/jk1/temp/opsum_prepro_output/with_imaging/'
    'gsu_Extraction_20220815_prepro_24022024_133425_restricted_to_imaging',
)
PREPRO_TIMESTAMP = os.environ.get('OPSUM_PREPRO_TIMESTAMP', '24022024_133425')
data_path = os.path.join(PREPRO_DIR, f'preprocessed_features_{PREPRO_TIMESTAMP}.csv')
outcomes_path = os.path.join(PREPRO_DIR, f'preprocessed_outcomes_{PREPRO_TIMESTAMP}.csv')

In [None]:
# Load the preprocessed feature table and the outcome table of the same run.
data_df, outcomes_df = (pd.read_csv(csv_path) for csv_path in (data_path, outcomes_path))

In [None]:
# First rows of the feature table.
data_df.head(n=5)

In [None]:
# First rows of the outcome table.
outcomes_df.head(n=5)

In [None]:
# Pick one admission at random for the exploratory plots below; the fixed seed keeps
# the choice reproducible across Restart & Run All.
SEED = 42
rng = np.random.default_rng(SEED)
pa_id = rng.choice(data_df['case_admission_id'].unique())
pa_id

In [None]:
# Time course of a single feature for the selected admission.
vital_name = 'Glasgow Coma Scale'
temp = data_df.loc[
    data_df['case_admission_id'].eq(pa_id) & data_df['sample_label'].isin([vital_name])
].copy()

ax = sns.scatterplot(data=temp, x='relative_sample_date_hourly_cat', y='value', hue='value', legend=False)
ax.set(xlabel='Hours from admission', ylabel=vital_name, ylim=(-3.2, 3.2), title=vital_name)
ax.tick_params(axis="x", rotation=45)

plt.show()

In [None]:
data_df.loc[data_df['case_admission_id'].eq(pa_id) & data_df['sample_label'].isin(["cholesterol HDL"])]

In [None]:
# Distinct data sources (including their imputation variants) present in the table.
data_df['source'].unique()

In [None]:
# Every feature label present in the data; each one gets its own subplot in plot_features.
sample_labels = data_df['sample_label'].unique()
sample_labels

In [None]:

from matplotlib.lines import Line2D


def plot_features(pa_id, subject_df, sample_labels, outcome, outcome_name, plot_source=False, ncols=5):
    """Plot every feature of one admission as a grid of scatter plots.

    One subplot is drawn per entry of ``sample_labels``. Each subplot shows the rows
    of ``subject_df`` with that ``sample_label``, with
    ``relative_sample_date_hourly_cat`` on the x axis and ``value`` on the y axis.

    Args:
        pa_id: Admission identifier, shown in the figure title.
        subject_df: Feature rows of this admission. Must contain the columns
            ``sample_label``, ``relative_sample_date_hourly_cat`` and ``value``
            (plus ``source`` when ``plot_source`` is True).
        sample_labels: Feature labels to plot, one subplot each, in this order.
        outcome: Outcome value shown in the figure title.
        outcome_name: Outcome name shown in the figure title.
        plot_source: If True, colour points by their ``source`` using a fixed palette
            and add a source legend beside the last subplot. Otherwise colour points
            by ``value``.
        ncols: Number of subplot columns (default 5).

    Returns:
        The matplotlib ``Figure`` containing the grid.

    Raises:
        ValueError: If ``ncols`` is smaller than 1.
    """
    if ncols < 1:
        raise ValueError(f'ncols must be >= 1, got {ncols}')

    n_labels = len(sample_labels)
    # Ceiling division; keep at least one row so an empty label list still yields a
    # figure with non-zero height.
    nrows = max(1, -(-n_labels // ncols))
    fig = plt.figure(figsize=(5 * ncols, 5 * nrows))
    # Must run after the figure is created so the spacing applies to this figure,
    # not to whichever figure happened to be current before the call.
    fig.subplots_adjust(hspace=0.2)

    # Fixed colour per source so a given source has the same colour in every figure.
    # NOTE(review): 'stroke_registry' and 'notes' share 'darkgreen' (and several imputed
    # variants share colours), so they are indistinguishable in the plot — confirm intended.
    palette = {"EHR": "navy", "EHR_locf_imputed": "lavender", "EHR_pop_imputed": "magenta", "EHR_pop_imputed_locf_imputed": "pink",
               "stroke_registry": "darkgreen", "stroke_registry_locf_imputed": "paleturquoise", "stroke_registry_pop_imputed": "magenta", "stroke_registry_pop_imputed_locf_imputed": "pink",
               "notes": "darkgreen", "notes_locf_imputed": "paleturquoise", "notes_pop_imputed": "magenta", "notes_pop_imputed_locf_imputed": "pink",}

    for n, sample_label in enumerate(sample_labels):
        ax = fig.add_subplot(nrows, ncols, n + 1)
        label_df = subject_df[subject_df.sample_label.isin([sample_label])]

        if plot_source:
            ax = sns.scatterplot(x='relative_sample_date_hourly_cat', y='value', data=label_df,
                                 hue='source', legend=False, ax=ax, palette=palette)
            # One legend for the whole figure, placed to the right of the last subplot.
            if n == n_labels - 1:
                legend_elements = [
                    Line2D([0], [0], marker='o', color='w', label=source,
                           markerfacecolor=color, markersize=10)
                    for source, color in palette.items()
                ]
                ax.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1, 0.5))
        else:
            ax = sns.scatterplot(x='relative_sample_date_hourly_cat', y='value', data=label_df,
                                 hue='value', legend=False, ax=ax)

        ax.tick_params(axis="x", rotation=45)
        # Assumes feature values are standardised so that ±3 covers most points — TODO confirm.
        ax.set_ylim(-3, 3)
        ax.set_title(sample_label.upper())
        ax.set_xlabel("")

    fig.suptitle(f'{pa_id}; {outcome_name}: {outcome}', y=0.89, fontsize=25)
    return fig


In [None]:
# Feature rows and 3-month death outcome of the randomly selected admission.
subject_df = data_df[data_df['case_admission_id'] == pa_id]
# Admissions without an outcome row get NaN instead of raising IndexError on `[0]`.
subject_outcome_values = outcomes_df.loc[outcomes_df['case_admission_id'] == pa_id, '3M Death'].values
subject_outcome = subject_outcome_values[0] if len(subject_outcome_values) > 0 else np.nan

In [None]:
# The trailing ';' suppresses the returned Figure's repr. The inline backend already
# renders the open figure, so without it the grid would be displayed twice.
plot_features(pa_id, subject_df, sample_labels, subject_outcome, '3M Death', plot_source=True);

In [None]:
from modun.file_io import ensure_dir

# Per-patient figures are written to a 'data_visualisation' folder next to the input CSVs.
out_dir = os.path.join(os.path.split(data_path)[0], 'data_visualisation')
ensure_dir(out_dir)

# Create data visualisations for all patients

In [None]:
# Column of outcomes_df whose value is shown in each exported figure's title.
outcome = "3M mRS 0-2"

In [None]:
# Export one PDF of all feature trajectories per admission, titled with its outcome.
# Outcomes are indexed once (first row wins if an admission is duplicated) instead of
# filtering outcomes_df inside the loop; admissions without an outcome get NaN.
outcome_by_case = (
    outcomes_df
    .drop_duplicates(subset='case_admission_id', keep='first')
    .set_index('case_admission_id')[outcome]
)
# groupby(sort=False) visits admissions in first-appearance order, like .unique(), but
# splits data_df in one pass instead of building a boolean mask per admission. The loop
# names (case_id, case_df) avoid overwriting pa_id / subject_df used by earlier cells.
for case_id, case_df in tqdm(data_df.groupby('case_admission_id', sort=False)):
    case_outcome = outcome_by_case.get(case_id, np.nan)
    fig = plot_features(case_id, case_df, sample_labels, case_outcome, outcome, plot_source=True)
    try:
        # f-string so non-string ids (e.g. integers) still yield a valid file name.
        fig.savefig(os.path.join(out_dir, f'{case_id}.pdf'), bbox_inches='tight')
    finally:
        # Always release the figure, even if saving fails, so figures don't accumulate.
        plt.close(fig)