In [None]:
import os
import sys
from functools import reduce

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pandas as pd

sys.path.append("../..")

from gemini.constants import *
from drift_detector.plotter import errorfill, plot_roc, plot_pr, linestyles, markers, colors, brightness, colorscale
from gemini.utils import run_shift_experiment, get_gemini_data, import_dataset_hospital, get_dataset_hospital, reshape_inputs

### Load data ###

In [None]:
# Load the 2020 delirium extract and keep only rows for hospital_id 3.
data = pd.read_csv("/mnt/nfs/project/delirium/data/data_2020.csv").loc[
    lambda frame: frame["hospital_id"].isin([3])
]

### Plot Outcomes ###

In [None]:
# Overlay the outcome distributions. Each outcome is nudged by its own x-offset
# so that bars of different outcomes sit side by side instead of on top of each other.
OUTCOME_OFFSETS = {
    "los": -0.08,
    "palliative": -0.04,
    "mort_hosp": 0.0,
    "readmission_7": 0.04,
    "readmission_28": 0.08,
}
fig, ax = plt.subplots(figsize=(14, 4))
for outcome_name, shift in OUTCOME_OFFSETS.items():
    # mort_hosp is plotted unshifted (no arithmetic applied to it).
    values = data[outcome_name] if shift == 0.0 else data[outcome_name] + shift
    ax.hist(values, bins=50, alpha=0.5, width=0.04, label=outcome_name)
fig.legend(loc="upper right")
plt.show()

### ER LOS ###

In [None]:
# Distribution of ER length of stay.
fig, ax = plt.subplots(figsize=(14, 4))
ax.hist(data["los_er"], label="los_er", alpha=0.5, bins=50)
fig.legend(loc="upper right")
plt.show()

### Triage Level ###

In [None]:
# Triage-level distribution; values are cast to str before plotting.
fig, ax = plt.subplots(figsize=(14, 4))
triage_as_text = data["triage_level"].astype(str)
ax.hist(triage_as_text, label="triage_level", width=0.4, alpha=0.5, bins=50)
fig.legend(loc="upper right")
plt.show()

### ICD Codes ###

In [None]:
# For every ICD-10 chapter indicator, count the rows whose value is 0 and 1 and
# draw them as grouped bars: one bar per chapter at x=0 and again at x=1.
fig, ax = plt.subplots(figsize=(14, 6))
ICDS = [
    "icd10_C00_D49",
    "icd10_D50_D89",
    "icd10_E00_E89",
    "icd10_F01_F99",
    "icd10_G00_G99",
    "icd10_H00_H59",
    "icd10_H60_H95",
    "icd10_I00_I99",
    "icd10_J00_J99",
    "icd10_K00_K95",
    "icd10_L00_L99",
    "icd10_M00_M99",
    "icd10_N00_N99",
    "icd10_O00_O99",
    "icd10_Q00_Q99",
    "icd10_R00_R99",
    "icd10_S00_T88",
    "icd10_U07_U08",
    "icd10_Z00_Z99",
    "icd10_nan",
]
n = len(ICDS)
w = 0.04  # bar width; n * w (= 0.8 here) must stay below 1 so the two groups don't overlap
x = np.arange(2)  # group positions: x=0 for value 0, x=1 for value 1
for i, icd in enumerate(ICDS):
    # Count each value explicitly. value_counts() is ordered by frequency, so
    # its first entry is not necessarily the count of 0s (and a column that is
    # all 1s would have had its count drawn at x=0).
    icd_counts = [int((data[icd] == 0).sum()), int((data[icd] == 1).sum())]
    # Centre the n bars of each group on its x position.
    position = x + (w * (1 - n) / 2) + i * w
    ax.bar(position, icd_counts, width=w, alpha=0.5, label=icd)
ax.set_xticks(x)
ax.set_xticklabels(["0", "1"])
fig.legend(loc="upper right")
plt.show()

### Query Admin/Diagnosis Data ###

In [None]:
# Connect to the GEMINI database using the default YAML configuration file(s).
CONFIG_PATTERN = "../../configs/default/*.yaml"
cfg = config.read_config(CONFIG_PATTERN)
db = Database(cfg)


# Bret's discharge-disposition groupings: derived column -> discharge_disposition codes.
_BRET_DISCHARGE_GROUPS = {
    "dd_discharge": [4, 5.0, 30, 40, 90],
    "dd_acute": [1],  # Don't use
    "dd_mortality": [7, 66, 72, 73],
    "dd_transfer": [2, 3, 10, 20],
    "dd_leave_ama": [6, 12, 61, 62, 65],
    "dd_suicide": [67, 74],
}

# Lookbook discharge-disposition groupings: derived column -> discharge_disposition codes.
_LOOKBOOK_DISCHARGE_GROUPS = {
    "lb_home": [4, 5, 6, 12],
    "lb_lama": [61, 62, 65, 67],
    "lb_transfer": [20, 30, 40, 90],
    "lb_died": [7, 72, 73, 74],
    "lb_acute": [1, 10],
    "lb_other": [2, 3, 8, 9],
}

# Readmission flags: derived column -> exact string value of `readmission`.
# NOTE(review): raw values such as "Yes"/"No" (if present) match none of these.
_READMISSION_CODES = {
    "planned_acute": "1",
    "unplanned_readmission_7_acute": "2",
    "unplanned_readmission_28_acute": "3",
    "unplanned_readmission_7_surgery": "4",
    "new_acute": "5",
    "none": "9",
}

# Length-of-stay flags: derived column -> value `los_derived` must exceed.
_LOS_THRESHOLDS = {"los_3": 3, "los_14": 14, "los_30": 30, "los_60": 60}

# ER length-of-stay flags: derived column -> value `duration_er_stay_derived` must exceed.
_LOS_ER_THRESHOLDS = {"los_er_7": 7, "los_er_14": 14, "los_er_30": 30}

# Triage flags: derived column -> accepted `triage_level` strings.
_TRIAGE_GROUPS = {
    "resuscitation": ["1", "L1"],
    "emergent": ["2", "L2"],
    "urgent": ["3", "L3"],
    "less_urgent": ["4", "L4"],
    # NOTE(review): unlike levels 1-4 there is no "L5"/"L9" variant below —
    # confirm such values never occur in triage_level.
    "non_urgent": ["5"],
    "unknown": ["9"],
}

# City flag: derived column -> accepted `city` spellings.
_CITY_GROUPS = {"toronto": ["TORONTO", "toronto"]}

# COVID flags: derived column -> exact `diagnosis_code` value.
_COVID_CODES = {"covid_confirmed": "U071", "covid_suspected": "U072"}


def _add_isin_flags(data, source, groups):
    """Add a 0/1 column per ``{column: values}`` entry: 1 where ``data[source]`` is in ``values``."""
    for column, values in groups.items():
        data[column] = np.where(data[source].isin(values), 1, 0)


def _add_equality_flags(data, source, codes):
    """Add a 0/1 column per ``{column: code}`` entry: 1 where ``data[source] == code``."""
    for column, code in codes.items():
        data[column] = np.where(data[source] == code, 1, 0)


def _add_threshold_flags(data, source, thresholds):
    """Add a 0/1 column per ``{column: threshold}`` entry: 1 where ``data[source] > threshold``."""
    for column, threshold in thresholds.items():
        data[column] = np.where(data[source] > threshold, 1, 0)


def _add_derived_columns(data):
    """Add every derived 0/1 indicator column to ``data`` in place."""
    # Compare with `== True` so anything that is not exactly True (including
    # missing values) becomes 0.
    data["is_er_diagnosis"] = np.where(data["is_er_diagnosis"] == True, 1, 0)  # noqa: E712
    _add_isin_flags(data, "discharge_disposition", _BRET_DISCHARGE_GROUPS)
    _add_isin_flags(data, "discharge_disposition", _LOOKBOOK_DISCHARGE_GROUPS)
    _add_equality_flags(data, "readmission", _READMISSION_CODES)
    _add_threshold_flags(data, "los_derived", _LOS_THRESHOLDS)
    _add_threshold_flags(data, "duration_er_stay_derived", _LOS_ER_THRESHOLDS)
    _add_isin_flags(data, "triage_level", _TRIAGE_GROUPS)
    _add_isin_flags(data, "city", _CITY_GROUPS)
    _add_equality_flags(data, "diagnosis_code", _COVID_CODES)


def query_admin_diagnosis(db, years, hospitals):
    """Fetch inpatient admin + diagnosis + ER admin rows and add derived 0/1 indicator columns.

    Parameters
    ----------
    db : Database
        GEMINI database wrapper; must expose ``db.public.<table>`` and ``run_query``.
    years : iterable of str
        Admission years to keep, compared with the year extracted from
        ``ip_administrative.admit_date_time`` (e.g. ``["2018", "2019", "2020"]``).
    hospitals : iterable of str
        ``ip_administrative.hospital_id`` values to keep (e.g. ``["SMH"]``).

    Returns
    -------
    DataFrame
        The query result (as returned by ``db.run_query``). ``is_er_diagnosis`` is
        recoded to 0/1, and 0/1 columns are added for discharge disposition
        (``dd_*``, ``lb_*``), readmission, LOS, ER LOS, triage level, ``toronto``
        and COVID diagnosis codes.

    Notes
    -----
    Diagnosis rows are joined on ``genc_id``. If an encounter has several
    diagnosis rows (presumably common — verify), it appears once per diagnosis.
    Rates computed downstream are then per diagnosis row, not per encounter.
    Assuming SQLAlchemy ``Select.join`` semantics, both joins are inner joins, so
    encounters without an ``er_administrative`` row are excluded.
    """
    ip_admin = db.public.ip_administrative
    diagnosis = db.public.diagnosis
    er_admin = db.public.er_administrative

    query = (
        select(ip_admin.data, diagnosis.data, er_admin.data)
        .where(
            and_(
                ip_admin.hospital_id.in_(hospitals),
                extract("year", ip_admin.admit_date_time).in_(years),
            )
        )
        .join(diagnosis.data, ip_admin.genc_id == diagnosis.genc_id)
        .join(er_admin.data, er_admin.genc_id == diagnosis.genc_id)
    )

    data = db.run_query(query)
    _add_derived_columns(data)

    print(f"{len(data)} rows fetched!")
    return data


def plot_outcome_overtime(hosp, outcome):
    """Bar-plot, per admission month, the fraction of rows whose ``outcome`` equals 1.

    Parameters
    ----------
    hosp : DataFrame
        Must contain a datetime ``admit_date_time`` column and the ``outcome`` column.
    outcome : str
        Name of the column to plot.

    The denominator for a month is the number of non-null ``outcome`` values among
    that month's rows. Rows with a missing ``admit_date_time`` are not grouped and
    so are ignored. Months with no positive rows are drawn as 0 rather than left blank.
    """
    months = hosp["admit_date_time"].dt.to_period("M")
    # count() skips nulls, so rows with a missing outcome are left out of the denominator.
    total_counts = hosp[outcome].groupby(months, sort=True).count()
    # Summing the boolean mask gives 0 (instead of a missing entry, which
    # would become NaN after division) for months without positive rows.
    positive_counts = (hosp[outcome] == 1).groupby(months, sort=True).sum()
    rate = positive_counts / total_counts

    fig, ax = plt.subplots(figsize=(14, 4))
    ax.bar(
        rate.index.astype(str),
        rate.values,
        alpha=0.5,
        width=0.4,
        color="g",
        label="patients with outcome/total patients",
    )
    ax.set_title(outcome)
    ax.set_xlabel("admission month")
    ax.set_ylabel(f"fraction of rows with {outcome} == 1")
    ax.tick_params(axis="x", labelrotation=45)
    fig.legend(loc="upper right")
    plt.show()

In [None]:
# Cohort selection: admission years and hospital(s) to query.
# Hospital codes referenced in this notebook: SBK, THPC, THPM, MSH, UHNTG, UHNTW, SMH.
HOSPITALS = ["SMH"]
YEARS = ["2018", "2019", "2020"]

### COVID ###
# Notes on changes seen in the outcome plots around the COVID period, per hospital.

### DECREASES
# smh: acute + transfers + los_er_14 decrease; March = minimum home
# uhntw: acute + mortality + unplanned_readmission_7_acute + los_er_14 decrease; March = minimum home
# uhntg: slight ER + transfer + mortality decreases
# msh: ER + LAMA + acute + mortality + transfer decrease
# thpc: ER + acute + mortality + transfers + planned_acute + unplanned_readmission_28_acute decrease
# thpm: acute + mortality + transfers decrease
# sbk: mortality + transfers + planned acute decrease

### INCREASES
# smh: LAMA increases; March = maximum mortality
# uhntw: LAMA increases; March = maximum mortality; from_nursing_home_mapped
# uhntg: March = very high planned_acute
# msh: discharge + home increase
# thpc: LAMA + home increase
# thpm: LAMA + home + unplanned_readmission_7_acute increase
# sbk: ER + LAMA + unplanned_readmission_7_acute increase

HOSP = query_admin_diagnosis(db, years=YEARS, hospitals=HOSPITALS)

# IP Administrative #

### City ###

In [None]:
# Row counts for the cities ranked 2-30 by frequency. The top-ranked city is
# skipped — presumably Toronto, whose count would dwarf the rest (verify).
# Labels are taken from value_counts() itself so each bar is paired with its own
# count; unique() returns cities in order of appearance, which does not match.
city_counts = HOSP["city"].value_counts().iloc[1:30]
fig, ax = plt.subplots(figsize=(20, 8))
ax.bar(
    list(city_counts.index.astype(str)),
    list(city_counts.values),
    alpha=0.5,
    label="city",
)
ax.set_xlabel("city")
ax.set_ylabel("rows")
fig.legend(loc="upper right")
ax.tick_params(axis="x", labelrotation=45)
plt.show()

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="toronto")

### ER Diagnosis ### 

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="is_er_diagnosis")

### From Nursing Home Mapped ### 

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="from_nursing_home_mapped")

### From Acute Care Institution Mapped ### 

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="from_acute_care_institution_mapped")

## Discharge Disposition ##

In [None]:
# Discharge disposition codes on GEMINI.
# Bret's groupings (the dd_* columns built in query_admin_diagnosis):
#   discharge: [4, 5, 30, 40, 90]
#   acute:     [1]
#   mortality: [7, 66, 72, 73]
#   transfer:  [2, 3, 10, 20]
#   leave AMA: [6, 12, 61, 62, 65]
#   suicide:   [67, 74]
#   ignored:   [8, 9]
#   remaining: [66, 73]  # NOTE(review): both also listed under mortality — confirm intent

DISCHARGE_DISPOSITION_MAP = {
    1: "Transferred to acute care inpatient institution",
    2: "Transferred to continuing care",
    3: "Transferred to other",
    4: "Discharged to home or a home setting with support services",
    5: "Discharged home with no support services from an external agency required",
    6: "Signed out",
    7: "Died",
    8: "Cadaveric donor admitted for organ/tissue removal",
    9: "Stillbirth",
    10: "Transfer to another hospital",
    12: "Patient who does not return from a pass",
    20: "Transfer to another ED",
    # Transfer to long-term care home (24-hour nursing), mental health and/or
    # addiction treatment centre or hospice/palliative care facility.
    30: "Transfer to residential care",
    # Transfer to assisted living/supportive housing or transitional housing,
    # including shelters; these settings do not have 24-hour nursing care.
    40: "Transfer to group/supportive living",
    61: "Absent without leave AWOL",
    62: "AMA",
    65: "Did not return from pass/leave",
    66: "Died while on pass leave",
    67: "Suicide out of facility",
    72: "Died in facility",
    73: "MAID",
    74: "Suicide",
    90: "Transfer to correctional",
}

# Fetch the inpatient administrative lookup table and show the discharge codes.
_ip_lookup = db.public.lookup_ip_administrative
ip_admin_lookup_query = select(
    _ip_lookup.variable, _ip_lookup.value, _ip_lookup.description
).subquery()
admin_lookup_data = db.run_query(ip_admin_lookup_query)
is_discharge_row = admin_lookup_data["variable"].eq("discharge_disposition")
discharge_codes = admin_lookup_data[is_discharge_row]
print(discharge_codes)


### Leave AMA ### 

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="dd_leave_ama")

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="lb_lama")

### Discharge ### 

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="dd_discharge")

### Acute  ###

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="dd_acute")  # dd_acute is marked "Don't use" where it is derived

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="lb_acute")

### Mortality ###

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="dd_mortality")

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="lb_died")

###  Transfer ###

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="dd_transfer")

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="lb_transfer")

### Home ###

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="lb_home")

### Other ###

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="lb_other")

## Readmissions ##

In [None]:
# Readmission codes from the inpatient lookup table, shown after the raw
# `readmission` values present in HOSP.
is_readmission_row = admin_lookup_data["variable"].eq("readmission")
readmission_codes = admin_lookup_data[is_readmission_row]
print(HOSP["readmission"].unique())
pd.set_option("display.max_colwidth", 100)
print(readmission_codes)

# SQL CASE expressions defining readmission_7 / readmission_28 / palliative.
# Note they map readmission 'Yes' -> 9 and 'No' -> 5 before casting to an integer code.
## CASE when NULLIF(REPLACE(REPLACE(i.readmission, 'Yes', '9'), 'No', '5'), '')::numeric::integer = 2 or  NULLIF(REPLACE(REPLACE(i.readmission, 'Yes', '9'), 'No', '5'), '')::numeric::integer = 4  THEN 1 ELSE 0 END AS readmission_7,
## CASE when NULLIF(REPLACE(REPLACE(i.readmission, 'Yes', '9'), 'No', '5'), '')::numeric::integer = 2 or  NULLIF(REPLACE(REPLACE(i.readmission, 'Yes', '9'), 'No', '5'), '')::numeric::integer = 3 or  NULLIF(REPLACE(REPLACE(i.readmission, 'Yes', '9'), 'No', '5'), '')::numeric::integer = 4  THEN 1 ELSE 0 END AS readmission_28,
## CASE when g.pal =1 THEN 1 ELSE 0 END AS palliative,

### Planned Acute ###

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="planned_acute")  # readmission code "1"

### Unplanned Readmission 7 Days Following Acute ###

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="unplanned_readmission_7_acute")  # readmission code "2"

### Unplanned Readmission 28 Days Following Acute ###

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="unplanned_readmission_28_acute")  # readmission code "3"

### Unplanned Readmission 7 Days Following Surgery ###

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="unplanned_readmission_7_surgery")  # readmission code "4"

### New Acute ###

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="new_acute")  # readmission code "5"

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="none")  # readmission code "9"

## Length of Stay ##

In [None]:
# Histogram of los_derived, restricted to values below LOS_PLOT_MAX for readability.
# The cutoff applies only to the plotted copy. HOSP is left intact, so the los_*
# rate plots that follow still include the longest stays; previously those stays
# (and rows with a missing los_derived) were dropped, biasing los_60/los_30 downward.
LOS_PLOT_MAX = 100
los_for_plot = HOSP.loc[HOSP["los_derived"] < LOS_PLOT_MAX, "los_derived"]
fig, ax = plt.subplots(figsize=(14, 4))
ax.hist(los_for_plot, bins=200, alpha=0.5, label="los_derived")
fig.legend(loc="upper right")
plt.show()

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="los_3")

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="los_14")

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="los_30")

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="los_60")

# Diagnosis # 

In [None]:
# Diagnosis lookup table on GEMINI, turned into a value -> description map.
lookup_query = select(
    db.public.lookup_diagnosis.variable,
    db.public.lookup_diagnosis.value,
    db.public.lookup_diagnosis.description,
).subquery()
diagnosis_lookup_data = db.run_query(lookup_query)
print(diagnosis_lookup_data)
# NOTE(review): the map uses every row of the lookup table, not only one
# variable's rows; if two variables share a value, the later row's description
# wins — confirm that cannot happen.
diagnosis_type_map = dict(
    zip(diagnosis_lookup_data["value"], diagnosis_lookup_data["description"])
)

In [None]:
# Re-query HOSP (earlier cells may have narrowed it), translate diagnosis_type
# codes to descriptions, and plot how often each type occurs.
HOSP = query_admin_diagnosis(db, YEARS, HOSPITALS)
HOSP["diagnosis_type"] = HOSP["diagnosis_type"].map(diagnosis_type_map).astype(str)
# Take labels from value_counts() itself so each bar is paired with its own
# count; unique() returns types in order of appearance, which does not match.
diagnosis_type_counts = HOSP["diagnosis_type"].value_counts()
fig, ax = plt.subplots(figsize=(14, 4))
ax.bar(
    list(diagnosis_type_counts.index),
    list(diagnosis_type_counts.values),
    alpha=0.5,
    width=0.4,
    label="diagnosis_type",
)
fig.legend(loc="upper right")
ax.tick_params(axis="x", labelrotation=45)
plt.show()

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="covid_confirmed")  # diagnosis_code U071

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="covid_suspected")  # diagnosis_code U072

# ER Admin #

In [None]:
# ER administrative lookup table on GEMINI, turned into a value -> description map.
_er_lookup = db.public.lookup_er_administrative
lookup_query = select(
    _er_lookup.variable, _er_lookup.value, _er_lookup.description
).subquery()
er_admin_lookup_data = db.run_query(lookup_query)
print(er_admin_lookup_data)
# NOTE(review): as with the diagnosis map, values from every variable share one
# dict; a value used by two variables keeps the later row's description.
er_admin_type_map = dict(
    zip(er_admin_lookup_data["value"], er_admin_lookup_data["description"])
)

### Triage Level ###

In [None]:
# Triage levels are explored on UHNTG (other hospital codes: SBK, THPC, MSH, UHNTW, SMH).
HOSP = query_admin_diagnosis(db, YEARS, hospitals=["UHNTG"])
fig, ax = plt.subplots(figsize=(14, 4))
ax.hist(
    HOSP["triage_level"].astype(str),
    label="triage_level",
    width=0.4,
    alpha=0.5,
    bins=50,
)
fig.legend(loc="upper right")
plt.show()

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="resuscitation")  # triage level 1 / L1

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="emergent")  # triage level 2 / L2

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="urgent")  # triage level 3 / L3

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="less_urgent")  # triage level 4 / L4

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="non_urgent")  # triage level 5

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="unknown")  # triage level 9

## ER Length of Stay (duration_er_stay_derived) ##

In [None]:
# Histogram of duration_er_stay_derived, restricted to values below
# ER_STAY_PLOT_MAX for readability. The cutoff applies only to the plotted copy.
# HOSP is left intact, so the los_er_* rate plots that follow still include long
# ER stays and rows with a missing duration instead of silently dropping them.
ER_STAY_PLOT_MAX = 100
er_stay_for_plot = HOSP.loc[
    HOSP["duration_er_stay_derived"] < ER_STAY_PLOT_MAX, "duration_er_stay_derived"
]
fig, ax = plt.subplots(figsize=(14, 4))
ax.hist(er_stay_for_plot, bins=200, alpha=0.5, label="duration_er_stay_derived")
fig.legend(loc="upper right")
plt.show()

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="los_er_7")

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="los_er_14")

In [None]:
plot_outcome_overtime(hosp=HOSP, outcome="los_er_30")

## CCSR Lookup ##

In [None]:
# CCSR code/description lookup table on GEMINI.
ccsr_table = db.public.lookup_ccsr
lookup_query = select(ccsr_table.ccsr, ccsr_table.ccsr_desc).subquery()
ccsr_lookup_data = db.run_query(lookup_query)

# Show every row of the lookup table; this display option stays set for later cells.
pd.set_option("display.max_rows", None)

print(ccsr_lookup_data)