In [None]:
import pandas as pd
import os
import numpy as np
from preprocessing.lab_preprocessing.lab_preprocessing import preprocess_labs
from matplotlib.dates import DateFormatter
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Directory holding the semicolon-delimited lab extraction files.
# Overridable via the STROKE_LAB_DATA_PATH environment variable so the notebook can run
# on other machines; the default keeps the original location.
data_path = os.environ.get(
    'STROKE_LAB_DATA_PATH',
    '/Users/jk1/stroke_datasets/stroke_unit_dataset/per_value/Extraction_20211110',
)
# Only files in data_path whose names start with this prefix are loaded as lab files.
lab_file_start = 'labo'

In [None]:
# Load every lab file (name starts with `lab_file_start`) with all columns as strings and
# stack them into one frame. os.listdir order is arbitrary, so names are sorted to make
# the row order of lab_df reproducible across machines and runs.
lab_file_names = sorted(f for f in os.listdir(data_path) if f.startswith(lab_file_start))
if not lab_file_names:
    # pd.concat([]) would otherwise fail with an opaque "No objects to concatenate".
    raise FileNotFoundError(
        f"No files starting with {lab_file_start!r} found in {data_path!r}"
    )

lab_files = [pd.read_csv(os.path.join(data_path, f), delimiter=';', encoding='utf-8', dtype=str)
             for f in lab_file_names]

lab_df = pd.concat(lab_files, ignore_index=True)

In [None]:
# Run the project's lab preprocessing (preprocessing.lab_preprocessing) on the raw table,
# with its verbose flag switched off.
preprocessed_lab_df = preprocess_labs(
    lab_df,
    verbose=False,
)

In [None]:
# Parse the 'dd.mm.YYYY HH:MM' timestamps into datetimes, overwriting the column.
preprocessed_lab_df['sample_date'] = pd.to_datetime(
    preprocessed_lab_df['sample_date'],
    format='%d.%m.%Y %H:%M',
)
# Earliest sample timestamp per admission: a Series indexed by case_admission_id.
first_sample_dates_df = (
    preprocessed_lab_df
    .groupby('case_admission_id')['sample_date']
    .agg('min')
)
first_sample_dates_df.head(2)

In [None]:

# Attach each admission's earliest sample time to every row of that admission. The
# joined Series is also named 'sample_date', so rsuffix turns it into 'sample_date_first'.
preprocessed_lab_with_rel_dates_df = (
    preprocessed_lab_df
    .join(first_sample_dates_df, on='case_admission_id', rsuffix='_first')
    .copy()
)

In [None]:
# Hours elapsed between each sample and the first sample of the same admission.
# Both columns are already datetimes ('sample_date' was parsed earlier; 'sample_date_first'
# is its per-admission min), so they are subtracted directly rather than re-parsed.
# The dtype check catches running this after skipping the parsing cell: the min of raw
# 'dd.mm.YYYY HH:MM' strings is a lexicographic min, not the earliest date, and
# re-parsing here would silently mask that.
for _date_col in ('sample_date', 'sample_date_first'):
    assert pd.api.types.is_datetime64_any_dtype(preprocessed_lab_with_rel_dates_df[_date_col]), \
        f"{_date_col} must be datetime; run the sample_date parsing cell first"

SECONDS_PER_HOUR = 60 * 60
preprocessed_lab_with_rel_dates_df['relative_sample_date'] = (
    preprocessed_lab_with_rel_dates_df['sample_date']
    - preprocessed_lab_with_rel_dates_df['sample_date_first']
).dt.total_seconds() / SECONDS_PER_HOUR


In [None]:
# Plot one lab parameter over time for a single admission.
# The admission is drawn only from those that have measurements of that parameter;
# drawing from all admissions could pick one with none, giving an empty plot.
# A fixed seed makes the choice reproducible; change it (or use None) to see another one.
EXAMPLE_SEED = 42
dosage_label = 'sodium'
label_mask = preprocessed_lab_with_rel_dates_df['dosage_label'] == dosage_label
candidate_ids = preprocessed_lab_with_rel_dates_df.loc[label_mask, 'case_admission_id'].unique()
if len(candidate_ids) == 0:
    raise ValueError(f"No admissions have measurements for {dosage_label!r}")
pa_id = np.random.default_rng(EXAMPLE_SEED).choice(candidate_ids)
patient_label_df = preprocessed_lab_with_rel_dates_df[
    label_mask & (preprocessed_lab_with_rel_dates_df['case_admission_id'] == pa_id)
].copy()

fig, ax = plt.subplots()
sns.scatterplot(x='relative_sample_date', y='value', data=patient_label_df,
                hue='value', legend=False, ax=ax)
ax.set(title=f'{dosage_label} for admission {pa_id}',
       xlabel='hours since first sample', ylabel='value')
ax.tick_params(axis="x", rotation=45)

plt.show()

In [None]:
# All sodium measurements across admissions, plotted against hours since each admission's
# first sample. The low alpha lets dense regions show through the overplotting.
sodium_df = preprocessed_lab_with_rel_dates_df[
    preprocessed_lab_with_rel_dates_df['dosage_label'] == 'sodium'
].copy()
fig, ax = plt.subplots()
sns.scatterplot(x='relative_sample_date', y='value', data=sodium_df,
                hue='value', legend=False, alpha=0.1, ax=ax)
ax.set(title='sodium, all admissions', xlabel='hours since first sample', ylabel='value')
plt.show()

In [None]:
# One panel per lab parameter (dosage_label). Each panel gets its own y-axis scale
# (sharey=False), and the x-axis is limited to the first RELPLOT_MAX_HOURS hours.
RELPLOT_COL_WRAP = 10    # panels per row
RELPLOT_MAX_HOURS = 350  # x-axis upper limit, in hours since first sample
g = sns.relplot(x='relative_sample_date', y='value', col='dosage_label', col_wrap=RELPLOT_COL_WRAP,
                data=preprocessed_lab_with_rel_dates_df, hue='dosage_label', legend=False, alpha=0.1,
                facet_kws=dict(sharey=False))
g.set(xlim=(0, RELPLOT_MAX_HOURS))
g.set_axis_labels('hours since first sample', 'value')
g.set_titles('{col_name}')
plt.show()

In [None]:
# Histogram + KDE of the values of each lab parameter, one figure per dosage_label.
# groupby(sort=False) yields labels in order of first appearance (like .unique()). It also
# splits the frame once instead of re-masking the whole frame for every label.
SAVE_DISTRIBUTION_FIGURES = False  # set True to write one PNG per label to the working dir
for dosage_label, dosage_df in preprocessed_lab_with_rel_dates_df.groupby('dosage_label', sort=False):
    g = sns.displot(x="value", data=dosage_df, kde=True, legend=False)
    g.ax.set_title(dosage_label)
    if SAVE_DISTRIBUTION_FIGURES:
        # '/' would be read as a directory separator in the file name. Save before
        # plt.show(): some backends (e.g. Jupyter inline) close the figure once shown.
        g.savefig(f'{dosage_label.replace("/", "")}.png')
    plt.show()
    # Release the figure so one does not stay open per label.
    plt.close(g.fig)