# 02_demographics

In [None]:
import os
import sys
import json
import pathlib
sys.path.append("..")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from datetime import timedelta
import traceback

# Load the shared DB config (one directory up) and this notebook's
# demographics item config (current directory).
current_dir = pathlib.Path.cwd()
parent_dir = current_dir.parent
# JSON is UTF-8 by specification; pin the encoding so reading does not depend
# on the platform default (e.g. cp949 on Korean Windows).
with open(parent_dir.joinpath("config.json"), encoding="utf-8") as file:
    cfg = json.load(file)
with open(current_dir.joinpath("demographics.json"), encoding="utf-8") as file:
    dg_cfg = json.load(file)

In [None]:
import pymssql
# Select the DB settings section named by cfg["dbms"] and open the connection.
# NOTE(review): the connection is always opened with pymssql, although the query
# helpers also branch on cfg["dbms"] == "postgresql" — confirm the intended driver.
driver = cfg["dbms"]
db_cfg = cfg[driver]
conn = pymssql.connect(
    server=db_cfg["@server"],
    user=db_cfg["@user"],
    password=db_cfg["@password"],
    port=db_cfg["@port"],
    database=db_cfg["@database"],
    as_dict=False,
)

In [None]:
# Drugs to summarize; each maps to a `person_<drug>` table and is processed in this order.
drug_order = [
    "Acetaminophen",
    "Vancomycin",
    "Naproxen",
    "Celecoxib",
    "Acyclovir",
]

In [None]:
def executeQuerynfetchall(conn, sql_query):
    """
    Execute a SQL query and return every fetched row.

    Args:
        conn: DB-API connection providing ``cursor()``, ``commit()`` and ``rollback()``.
        sql_query: fully formatted SQL text to execute.

    Returns:
        The result of ``cursor.fetchall()``, or None if executing/fetching failed.
        Failures are printed (traceback) rather than raised.
    """
    result = None
    try:
        with conn.cursor() as cursor:
            cursor.execute(sql_query)
            result = cursor.fetchall()
        conn.commit()
    except Exception:
        traceback.print_exc()
        # Roll back so a failed statement does not leave the connection inside
        # an aborted transaction (PostgreSQL rejects every later statement until
        # a rollback). A failing rollback is reported, not raised, to keep the
        # original best-effort contract.
        try:
            conn.rollback()
        except Exception:
            traceback.print_exc()
    return result

In [None]:
# 1. Number of patients for each drug
def get_result_of_num_patients(conn, drug_name):
    """
    Count distinct patients in the ``person_{drug_name}`` table, split by the
    ``is_abnormal`` label ('True' when ``n_diff`` is not NULL, else 'False').

    Returns:
        dict mapping is_abnormal label -> distinct patient count. When the query
        returns nothing (or fails), prints a not-found message and returns {}.
    """
    sql_template = """SELECT p.is_abnormal, COUNT(DISTINCT p.person_id) as n_patients
    FROM (SELECT *, CASE WHEN n_diff IS NULL THEN 'False' ELSE 'True' END as is_abnormal FROM {@person_database_schema}.{@target_person_table}) p
    GROUP BY p.is_abnormal;"""
    format_args = {
        "@person_database_schema": db_cfg["@person_database_schema"],
        "@target_person_table": f"person_{drug_name}",
    }
    rows = executeQuerynfetchall(conn, sql_template.format(**format_args))
    if rows:
        return {row[0]: row[1] for row in rows}
    print(f"{drug_name} is not found in the database.")
    return {}

In [None]:
# 2. Number of patients for each diagnosis (condition) code
def get_num_people_for_condition(conn, drug_name, concept_id_list):
    """
    Count distinct patients, split by abnormal label, who have a condition
    record for any of ``concept_id_list`` in the year before cohort start.

    The window is strict on both ends: after ``cohort_start_date - 1 year``
    and before ``cohort_start_date``.

    Args:
        conn: open database connection passed to executeQuerynfetchall.
        drug_name: suffix of the person table queried (``person_{drug_name}``).
        concept_id_list: iterable of condition concept IDs.

    Returns:
        dict mapping ``is_abnormal`` ('True' when ``n_diff`` is not NULL,
        otherwise 'False') to the distinct patient count. Returns {} when no
        concept IDs are given, or when the query returns no rows or fails.
    """
    concept_ids = ",".join([str(i) for i in concept_id_list])
    if not concept_ids:
        # "IN ()" is invalid SQL; there is nothing to count.
        return {}

    # "YEAR" must match the PostgreSQL rewrite below exactly; the previous
    # lower-case "year" meant the rewrite never happened on PostgreSQL.
    sql_template = """SELECT p.is_abnormal, COUNT(DISTINCT p.person_id) as n_patients
    FROM (SELECT *, CASE WHEN n_diff IS NULL THEN 'False' ELSE 'True' END as is_abnormal FROM {@person_database_schema}.{@target_person_table}) p
    JOIN {@cdm_database_schema}.condition_occurrence co ON p.person_id = co.person_id
    JOIN {@cdm_database_schema}.concept c ON co.condition_concept_id = c.concept_id
    WHERE c.concept_id in ({@concept_id_list}) and p.cohort_start_date > co.condition_start_date and DATEADD(YEAR, -1, p.cohort_start_date) < co.condition_start_date
    GROUP BY p.is_abnormal
    ORDER BY p.is_abnormal;"""
    if cfg["dbms"] == "postgresql":
        sql_template = sql_template.replace("DATEADD(YEAR, -1, p.cohort_start_date)", "p.cohort_start_date - INTERVAL '1 year'")

    sql_param_dict = {
        "@person_database_schema": db_cfg["@person_database_schema"],
        "@target_person_table": f"person_{drug_name}",
        "@cdm_database_schema": db_cfg["@cdm_database_schema"],
        "@concept_id_list": concept_ids,
    }

    query = sql_template.format(**sql_param_dict)
    result = executeQuerynfetchall(conn, query)
    if not result:
        return {}
    return {row[0]: row[1] for row in result}

In [None]:
# 3. Per-patient value list for each lab test (within 1 year before the index date)
def get_values_for_measurement(conn, drug_name, concept_id_list):
    """
    Collect, per abnormal label, each patient's first measurement value for
    the given measurement concepts within the year before cohort start.

    For every person in ``person_{drug_name}``, measurements whose
    ``measurement_concept_id`` is in ``concept_id_list`` and whose date falls
    between ``cohort_start_date - 1 year`` and ``cohort_start_date`` (inclusive)
    are ranked by ``measurement_date``. Only the earliest one is kept.

    Args:
        conn: open database connection passed to executeQuerynfetchall.
        drug_name: suffix of the person table queried (``person_{drug_name}``).
        concept_id_list: iterable of measurement concept IDs.

    Returns:
        dict mapping ``is_abnormal`` ('True' when ``n_diff`` is not NULL,
        otherwise 'False') to the list of ``value_as_number`` values. Returns {}
        when no concept IDs are given, or when the query returns no rows or fails.
    """
    concept_ids = ",".join([str(i) for i in concept_id_list])
    if not concept_ids:
        # "IN ()" is invalid SQL; there is nothing to collect.
        return {}

    # The concept filter uses the caller's IDs (it was previously hard-coded,
    # so every measurement item returned the same values). The row-number
    # alias is "rn" because "row" is an SQL keyword.
    sql_template = """
    SELECT is_abnormal, value_as_number
    FROM 
    (
        SELECT m.*, p.is_abnormal, ROW_NUMBER() OVER (PARTITION BY m.person_id ORDER BY m.measurement_date) as rn
        FROM (SELECT person_id, cohort_start_date, 
            CASE WHEN n_diff IS NULL THEN 'False' ELSE 'True' END AS is_abnormal
            FROM {@person_database_schema}.{@target_person_table}) p
        JOIN {@cdm_database_schema}.measurement m
            ON p.person_id = m.person_id
        WHERE m.measurement_concept_id IN ({@concept_id_list})
            AND m.measurement_date BETWEEN DATEADD(YEAR, -1, p.cohort_start_date) AND p.cohort_start_date 
    ) t
    WHERE t.rn = 1
    ORDER BY is_abnormal;
    """
    if cfg["dbms"] == "postgresql":
        sql_template = sql_template.replace("DATEADD(YEAR, -1, p.cohort_start_date)", "p.cohort_start_date - INTERVAL '1 year'")

    sql_param_dict = {
        "@person_database_schema": db_cfg["@person_database_schema"],
        "@target_person_table": f"person_{drug_name}",
        "@cdm_database_schema": db_cfg["@cdm_database_schema"],
        "@concept_id_list": concept_ids,
    }

    query = sql_template.format(**sql_param_dict)
    result = executeQuerynfetchall(conn, query)
    if not result:
        return {}

    # Group values by label without relying on the rows being sorted.
    group_dict = {}
    for row in result:
        group_dict.setdefault(row[0], []).append(row[1])
    return group_dict

In [None]:
def get_results_of_items_for_condition(conn, drug_name, items_cfg):
    """
    Return patient counts for every condition item configured for one drug.

    Args:
        conn: open database connection.
        drug_name: drug whose ``person_{drug_name}`` table is queried.
        items_cfg: mapping of item key -> item config. Each item config
            supplies "name" (output key) and "concept_id" (condition concept IDs).

    Returns:
        dict mapping item name -> result of get_num_people_for_condition.
    """
    return {
        item_cfg["name"]: get_num_people_for_condition(conn, drug_name, item_cfg["concept_id"])
        for item_cfg in items_cfg.values()
    }

def get_results_of_items_for_measurement(conn, drug_name, items_cfg):
    """
    Return measurement value lists for every measurement item configured for one drug.

    Args:
        conn: open database connection.
        drug_name: drug whose ``person_{drug_name}`` table is queried.
        items_cfg: mapping of item key -> item config. Each item config
            supplies "name" (output key) and "concept_id" (measurement concept IDs).

    Returns:
        dict mapping item name -> result of get_values_for_measurement.
    """
    return {
        item_cfg["name"]: get_values_for_measurement(conn, drug_name, item_cfg["concept_id"])
        for item_cfg in items_cfg.values()
    }

In [None]:
def get_demographics_for_condition(conn, drug_order, table_cfg):
    """
    Build condition patient counts for each drug in ``drug_order``.

    Returns:
        dict keyed by drug name (in ``drug_order`` order). Each value is the
        per-item result of get_results_of_items_for_condition for
        ``table_cfg["items"]``. Progress is shown with tqdm.
    """
    return {
        drug_name: get_results_of_items_for_condition(conn, drug_name, table_cfg["items"])
        for drug_name in tqdm(drug_order)
    }

def get_demographics_for_measurement(conn, drug_order, table_cfg):
    """
    Build measurement value lists for each drug in ``drug_order``.

    Returns:
        dict keyed by drug name (in ``drug_order`` order). Each value is the
        per-item result of get_results_of_items_for_measurement for
        ``table_cfg["items"]``. Progress is shown with tqdm.
    """
    return {
        drug_name: get_results_of_items_for_measurement(conn, drug_name, table_cfg["items"])
        for drug_name in tqdm(drug_order)
    }

In [None]:
# Run each configured table's extraction and write one JSON file per table
# into ./result.
result_dir = current_dir.joinpath("result")
result_dir.mkdir(exist_ok=True)

for table, table_cfg in dg_cfg.items():
    # Only config entries that carry a "name" are processed.
    if "name" not in table_cfg:
        continue

    name = table_cfg["name"]
    if name == "condition_occurrence":
        demographics_dict = get_demographics_for_condition(conn, drug_order, table_cfg)
    elif name == "measurement":
        demographics_dict = get_demographics_for_measurement(conn, drug_order, table_cfg)
    else:
        # Report unsupported table names instead of silently writing an empty result file.
        print(f"Skipping '{table}': unsupported table name '{name}'.")
        continue

    # Save results. ensure_ascii=False keeps non-ASCII item names readable;
    # the explicit UTF-8 encoding makes that safe on any platform.
    result_file = result_dir.joinpath(f"demographics_for_{name}.json")
    with open(result_file, "w", encoding="utf-8") as f:
        json.dump(demographics_dict, f, indent=2, ensure_ascii=False)