# 03_occurrence_period_distribution

In [None]:
import os
import sys
import json
import pathlib
sys.path.append("..")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from datetime import timedelta
import traceback

# Load config.json from the parent of the current working directory.
current_dir = pathlib.Path.cwd()
parent_dir = current_dir.parent
config_path = parent_dir / "config.json"
with config_path.open() as config_file:
    cfg = json.load(config_file)

In [None]:
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

# Build the SQLAlchemy engine and session factory from the connection
# settings stored under the configured DBMS name in config.json.
driver = cfg["dbms"]
db_cfg = cfg[driver]
username = db_cfg["@user"]
password = db_cfg["@password"]
host = db_cfg["@server"]
port = db_cfg["@port"]
database = db_cfg["@database"]

# Configured DBMS name -> SQLAlchemy "dialect+driver" prefix.
sql_drivers = {
    "mssql": "mssql+pymssql",
    "postgresql": "postgresql+psycopg2",
}
if driver not in sql_drivers:
    # Fail with a clear message instead of a NameError on an unassigned
    # `sqldriver` when the URL is formatted.
    raise ValueError(
        f"Unsupported dbms {driver!r}; expected one of {sorted(sql_drivers)}"
    )
sqldriver = sql_drivers[driver]

# NOTE(review): credentials are interpolated verbatim; a password containing
# characters such as '@', ':' or '/' would need URL-escaping
# (e.g. urllib.parse.quote_plus) -- confirm how config.json stores it.
url = f"{sqldriver}://{username}:{password}@{host}:{port}/{database}"
engine = create_engine(url, echo=False)
sessionlocal = sessionmaker(autocommit=False, autoflush=True, bind=engine)

In [None]:
# Left-to-right order in which the drugs appear in the figures.
drug_order = [
    "Acetaminophen",
    "Vancomycin",
    "Naproxen",
    "Celecoxib",
    "Acyclovir",
]

In [None]:
def executeQuerynfetchall(engine, sql_query):
    """Execute a SQL query and return all fetched rows.

    Args:
        engine: Object whose ``connect()`` returns a context-managed
            connection exposing ``execute()`` (a SQLAlchemy engine in this
            notebook).
        sql_query: Query passed unchanged to ``conn.execute``.

    Returns:
        The value of ``fetchall()`` on the execution result, or ``None`` if
        connecting, executing, or fetching raised an exception (the
        traceback is printed in that case).
    """
    # NOTE(review): SQLAlchemy 2.x no longer accepts a plain ``str`` in
    # ``execute()`` and requires ``sqlalchemy.text()``; confirm the pinned
    # version before upgrading.
    try:
        with engine.connect() as conn:
            # Fetch inside the ``with`` so the rows are read before the
            # connection is released.
            return conn.execute(sql_query).fetchall()
    except Exception:
        # Best-effort helper: report the failure and signal it with None.
        # ``Exception`` (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        traceback.print_exc()
        return None

### 1. Distribution of abnormality duration

In [None]:
def get_result_of_occurrence_period(engine, drug_name, psm=False):
    """Fetch the non-NULL ``n_diff`` values from a drug's person table.

    Args:
        engine: Engine handed to ``executeQuerynfetchall``.
        drug_name: Drug name used to build the table name
            ``person_<drug_name>``.
        psm: When true, read from ``person_<drug_name>_psm`` instead.

    Returns:
        A list holding the first column of every returned row, or ``None``
        when ``executeQuerynfetchall`` returns ``None``.
    """
    if psm:
        table_name = f"person_{drug_name}_psm"
    else:
        table_name = f"person_{drug_name}"

    # Schema and table are identifiers, so they are substituted into the
    # statement text rather than bound as parameters.
    template = """SELECT n_diff FROM {@person_database_schema}.{@target_person_table} WHERE n_diff IS NOT NULL"""
    placeholders = {
        "@person_database_schema": db_cfg["@person_database_schema"],
        "@target_person_table": table_name,
    }
    rows = executeQuerynfetchall(engine, template.format(**placeholders))
    if rows is None:
        return None
    return [row[0] for row in rows]

In [None]:
# Collect the n_diff values for every drug listed under cfg["drug"], keeping
# only drugs whose query succeeded and returned at least 20 values.
occur_period_dict = {}
for drug_name in cfg["drug"]:
    values = get_result_of_occurrence_period(engine, drug_name)
    # Identity check against None; a failed query is reported as None.
    if values is not None and len(values) >= 20:
        occur_period_dict[drug_name] = values

In [None]:

# Write the collected values to result/occurrence_period_distribution.json,
# creating the result directory if it does not exist yet.
result_dir = current_dir / "result"
result_dir.mkdir(parents=True, exist_ok=True)

output_path = result_dir / "occurrence_period_distribution.json"
with output_path.open("w") as out_file:
    json.dump(occur_period_dict, out_file, indent=2)

In [None]:
# Reload the values from result/occurrence_period_distribution.json so the
# plotting cells can run without re-querying the database.
result_dir = current_dir / "result"
result_dir.mkdir(parents=True, exist_ok=True)

input_path = result_dir / "occurrence_period_distribution.json"
with input_path.open() as in_file:
    occur_period_dict = json.load(in_file)

In [None]:
# Draw one histogram panel per drug (in `drug_order`) and save the figure.
drug_counts = len(drug_order)
fig = plt.figure(figsize=(5*drug_counts, 2))
# enumerate() already yields the panel position, so no per-iteration
# drug_order.index() lookup is needed.
for drug_index, drug_name in enumerate(tqdm(drug_order)):
    ax = plt.subplot(1, drug_counts, drug_index + 1)
    # Drugs without stored values still get an empty panel so every drug
    # keeps its slot in the row.
    values = occur_period_dict.get(drug_name, [])
    sns.histplot(values, kde=False, bins=60, color="blue", alpha=0.3, label=drug_name, ax=ax)
    if len(values) > 0:
        median = np.median(values)
        count = len(values)
        iqr_low, iqr_high = np.quantile(values, [0.25, 0.75])
        ax.axvline(median, color="red", linestyle="--", label="median")
        # Annotations are placed in axes coordinates (top-right corner).
        ax.text(0.98, 0.95, f"median: {median:.1f} day ({iqr_low:.1f}-{iqr_high:.1f})", transform=ax.transAxes, fontsize=12, verticalalignment="top", horizontalalignment="right")
        ax.text(0.98, 0.8, f"{count} patients", transform=ax.transAxes, fontsize=12, verticalalignment="top", horizontalalignment="right")
    # Same axis range and labels on every panel, with or without data.
    ax.set_xlim(-1, 61)
    ax.set_xlabel("days")
    ax.set_ylabel("number of patients")

plt.savefig(result_dir.joinpath("occurrence_period_distribution.png"), dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# Draw an empty histogram panel per drug (in `drug_order`) and save the figure.
# NOTE(review): the original cell plotted an empty histogram in both branches
# of its membership test, so no stored values are ever drawn here -- confirm
# that a blank layout is the intended output.
drug_counts = len(drug_order)
fig = plt.figure(figsize=(5*drug_counts, 2))
# enumerate() already yields the panel position, so no per-iteration
# drug_order.index() lookup is needed.
for drug_index, drug_name in enumerate(tqdm(drug_order)):
    ax = plt.subplot(1, drug_counts, drug_index + 1)
    sns.histplot([], kde=False, bins=60, color="blue", alpha=0.3, label=drug_name, ax=ax)
    ax.set_xlim(-1, 61)
    ax.set_xlabel("days")
    ax.set_ylabel("number of patients")

plt.savefig(result_dir.joinpath("occurrence_period_distribution2.png"), dpi=300, bbox_inches="tight")
plt.show()