# In [None]:
from pathlib import Path

import polars as pl

ROOT_DIR = "/storage/shared/mimic-iv/meds_v0.3.2/"  # Replace with your actual root directory

# Show up to 100 characters of each string when polars prints a frame.
pl.Config.set_fmt_str_lengths(100)

# Code metadata table; the "description" and "code" columns are read below.
df = pl.read_parquet(f"{ROOT_DIR}/meds/metadata/codes.parquet")


def _codes_matching(*descriptions):
    """Return the codes whose description contains any of ``descriptions``.

    Each entry is matched as a literal substring (not a regex) against the
    "description" column of ``df``; rows matching at least one entry are kept
    and their "code" values are returned as a list.
    """
    first, *others = descriptions
    condition = pl.col("description").str.contains(first, literal=True)
    for text in others:
        condition = condition | pl.col("description").str.contains(text, literal=True)
    return df.filter(condition)["code"].to_list()


# Maps a lab name to a single "|"-joined string of its codes, which
# get_aces_config embeds as a regex alternation in the generated config.
lab_to_codes = {}

creatinine_codes = _codes_matching(
    "Creatinine [Mass/volume] in Blood",
    "Creatinine [Mass/volume] in Serum or Plasma",
)
lab_to_codes["creatinine"] = "|".join(creatinine_codes)

hemoglobin_codes = _codes_matching(
    "Hemoglobin [Mass/volume] in Blood by calculation",
    "Hemoglobin [Mass/volume] in Blood",
)
lab_to_codes["hemoglobin"] = "|".join(hemoglobin_codes)

hematocrit_codes = _codes_matching(
    "Hematocrit [Volume Fraction] of Blood by Automated count",
    "Hematocrit [Volume Fraction] of Blood by Estimated",
)
lab_to_codes["hematocrit"] = "|".join(hematocrit_codes)

leukocytes_codes = _codes_matching("Leukocytes [#/volume] in Blood by Automated count")
lab_to_codes["leukocytes"] = "|".join(leukocytes_codes)

platets_codes = _codes_matching("Platelets [#/volume] in Blood by Automated count")
lab_to_codes["platets"] = "|".join(platets_codes)


def get_aces_config(location, lab, time_interval, extrema):
    """Build the ACES task config (YAML text) for an abnormal-lab task.

    The config triggers on every ``{location}//...`` event, requires at least
    one measurement of ``lab`` within ``time_interval`` after the trigger, and
    uses the ``abnormal_lab`` predicate as the label.

    Args:
        location: Code prefix of the trigger event (e.g. "ICU_ADMISSION").
        lab: Key into the module-level ``lab_to_codes`` mapping.
        time_interval: Length of the target window, inserted verbatim after
            ``start + `` (e.g. "30d").
        extrema: ``(min_val, max_val)`` pair in which exactly one entry is not
            ``None``. ``min_val`` is a lower cutoff, ``max_val`` an upper one.

    Returns:
        The config as a string.

    Raises:
        ValueError: If ``extrema`` sets both bounds or neither bound, or if
            the code regex stored for ``lab`` is empty.
        KeyError: If ``lab`` is not a key of ``lab_to_codes``.
    """
    min_val, max_val = extrema
    # Exactly one side of the cutoff must be given.
    if min_val is None and max_val is None:
        raise ValueError("Defined neither min nor max")
    if min_val is not None and max_val is not None:
        raise ValueError("Can't define both min and max")

    lab_codes = lab_to_codes[lab]
    if not lab_codes:
        # Without any codes the predicates would be written with regex "",
        # which does not identify this lab at all.
        raise ValueError(f"No codes found for lab {lab!r}; cannot build a code regex")

    if min_val is not None:
        # Intentionally "max": a lower cutoff means the abnormal values are the
        # ones at or below it, so the cutoff is the largest value the predicate
        # should accept.
        extrema_type = "max"
        value = min_val
    else:
        # Mirror case: an upper cutoff is the smallest abnormal value.
        extrema_type = "min"
        value = max_val
    lab_requirement = f"""  abnormal_lab:
    code: {{regex: "{lab_codes}"}}
    value_{extrema_type}: {value}
    value_{extrema_type}_inclusive: True
    """
    return f"""#This config checks for an abnormal {lab} lab within {time_interval} after {location}
predicates:
  trigger_event:
    code: {{regex: "{location}//.*"}}
  lab:
    code: {{regex: "{lab_codes}"}}
{lab_requirement}

trigger: trigger_event

windows:
  input:
    start: NULL
    end: trigger
    start_inclusive: True
    end_inclusive: True
    index_timestamp: end
  target:
    start: input.end
    end: start + {time_interval}
    start_inclusive: True
    end_inclusive: True
    has:
      lab: (1, None)
    label: abnormal_lab

"""


results = {}
tasks = []
# Per-lab (min_val, max_val) cutoff handed to get_aces_config; exactly one
# side is set for each lab.
extrema = {
    "creatinine": (None, 2.0),
    # NOTE(review): identical to the creatinine cutoff — confirm this value is
    # really intended for hemoglobin.
    "hemoglobin": (None, 2.0),
    "hematocrit": (24, None),
    "leukocytes": (5, None),
    "platets": (20, None),
}

# Every (trigger location, lab, horizon) combination, location-major.
task_grid = [
    (location, lab, time_interval)
    for location in ("HOSPITAL_ADMISSION", "HOSPITAL_DISCHARGE", "ICU_ADMISSION", "ICU_DISCHARGE")
    for lab in ("creatinine", "hemoglobin", "hematocrit", "leukocytes", "platets")
    for time_interval in ("30d", "60d", "90d")
]

for location, lab, time_interval in task_grid:
    config = get_aces_config(location, lab, time_interval, extrema[lab])
    task_name = f"abnormal_lab/{location.lower()}/{lab}/{time_interval}"
    # One YAML file per task, nested by location and lab.
    fp = Path(f"../ZERO_SHOT_TUTORIAL/configs/tasks/eic/{task_name}.yaml")
    fp.parent.mkdir(parents=True, exist_ok=True)
    fp.write_text(config)
    # Task names are collected already double-quoted.
    tasks.append(f'"{task_name}"')

# Print the quoted task names, one per line.
print(*tasks, sep="\n")