In [46]:
from dataclasses import dataclass
from enum import StrEnum
from datetime import datetime, timedelta, date
import polars as pl
from collections import defaultdict, Counter
from typing import Tuple
from common.constants.column_types import (
    CPZP_SCHEMA,
    OZP_SCHEMA,
    POHLAVI_CPZP,
    TYP_UDALOSTI,
)
from common.constants.column_names import SHARED_COLUMNS, OZP_COLUMNS, CPZP_COLUMNS
import pickle
from common.constants.objects import (
    Person,
    Gender,
    AgeCohort,
    Prescription,
    PrescriptionType,
)
import matplotlib.pyplot as plt
import numpy as np
import os
from common.utils import (
    draw_chart,
    filter_by_date_range,
)

pl.Config.set_tbl_rows(20)
pl.Config.set_tbl_cols(60)

from typing import Any
import matplotlib.dates as mdates
from matplotlib.patches import Patch
from scipy.stats import fisher_exact

pl.Config.set_tbl_rows(-1)
import pandas as pd

# Insurer whose pickled persons file is loaded; the special value
# "both_companies" loads the cpzp and ozp files and concatenates them.
POJISTOVNA = "cpzp"
# Maximum distance, in days, between a vaccination date and the most frequent
# vaccination date of its cohort/dose for the vaccination to be analysed.
VAX_PERIOD_IN_DAYS = 30
# The skip_person_for_* filters drop anyone whose insurance started after
# START_DATE or ended before END_DATE.
START_DATE = date(2015, 1, 1)
END_DATE = date(2025, 1, 1)
# `lekova_forma` values that is_injection() treats as injections.
INJECTION_FORMS = {"Injekční suspenze", "Injekční/infuzní roztok"}
# Number of days taken on each side of the pivot date when building the
# before/after windows.
PERIOD = 365


def is_injection(form: str) -> bool:
    """Tell whether ``form`` is one of the dosage forms in ``INJECTION_FORMS``."""
    recognised_forms = INJECTION_FORMS
    return form in recognised_forms


def skip_person_for_novax(p) -> bool:
    """Decide whether ``p`` must be left out of the unvaccinated group.

    A person is skipped when any of the following holds:

    * their insurance started after ``START_DATE``;
    * their insurance has an end date and it is before ``END_DATE``;
    * ``p.died_at`` is truthy;
    * they have at least one vaccine record;
    * they have no prescriptions.

    Returns:
        ``True`` to skip the person, ``False`` to keep them. The result is
        always a real ``bool``: a bare ``or`` chain evaluates to one of its
        operands, so without the conversion the ``p.died_at`` value itself
        could be returned.
    """
    return bool(
        (p.zahajeni_pojisteni > START_DATE)
        or (p.ukonceni_pojisteni is not None and p.ukonceni_pojisteni < END_DATE)
        or p.died_at
        or p.vaccines
        or not p.prescriptions
    )


def skip_person_for_vax(p) -> bool:
    """Decide whether ``p`` must be left out of the vaccinated group.

    A person is skipped when any of the following holds:

    * their insurance started after ``START_DATE``;
    * their insurance has an end date and it is before ``END_DATE``;
    * ``p.died_at`` is truthy;
    * they have no vaccine record;
    * they have no prescriptions.

    Returns:
        ``True`` to skip the person, ``False`` to keep them. The result is
        always a real ``bool``: a bare ``or`` chain evaluates to one of its
        operands, so without the conversion the ``p.died_at`` value itself
        could be returned.
    """
    return bool(
        (p.zahajeni_pojisteni > START_DATE)
        or (p.ukonceni_pojisteni is not None and p.ukonceni_pojisteni < END_DATE)
        or p.died_at
        or (not p.vaccines)
        or (not p.prescriptions)
    )


def collapse_injections(prescriptions, min_gap_days: int = 14):
    """Yield prescriptions, dropping injections that closely follow a kept one.

    Non-injection prescriptions are always yielded. An injection (as decided
    by ``is_injection(pr.lekova_forma)``) is dropped when it lies fewer than
    ``min_gap_days`` days from the most recently *yielded* injection; a
    dropped injection does not move that reference date.

    Args:
        prescriptions: iterable of objects exposing ``lekova_forma`` and
            ``date``. NOTE(review): the reference date is the last injection
            kept in iteration order, so the input is presumably sorted by
            date - confirm against the code that builds the persons.
        min_gap_days: minimum spacing, in days, between two yielded
            injections. Defaults to 14, the value that used to be hard-coded.

    Yields:
        The prescriptions that survive the filter, in input order.
    """
    # Sentinel far in the past so the first injection is always kept for any
    # realistic gap.
    last_inj_date = date.min
    for pr in prescriptions:
        if is_injection(pr.lekova_forma):
            if abs((last_inj_date - pr.date).days) < min_gap_days:
                continue
            last_inj_date = pr.date
        yield pr


def safe_div(a: float, b: float) -> float:
    """Divide ``a`` by ``b``, yielding NaN instead of raising when ``b`` is zero."""
    if b == 0:
        return np.nan
    return a / b


def pvalue_from_df(df: pl.DataFrame) -> float:
    """One-sided Fisher exact p-value for the 2x2 group-by-period table.

    ``df`` needs a ``group`` column containing the rows ``"očkovaní"`` and
    ``"neočkovaní"`` plus the value columns ``"před"`` and ``"po"``. Every
    cell is passed through ``int()`` before testing, so fractional values are
    truncated. The test is run with ``alternative="less"`` and only the
    p-value is returned.
    """

    def counts_for(label):
        # First row of the requested group, as [before, after].
        group_rows = df.filter(pl.col("group") == label)
        return [int(group_rows["před"][0]), int(group_rows["po"][0])]

    contingency = [counts_for("očkovaní"), counts_for("neočkovaní")]
    outcome = fisher_exact(contingency, alternative="less")
    return outcome[1]


def before_after_df(
    vax_map: dict[int, float], novax_map: dict[int, float], pivot: int = 0
) -> pl.DataFrame:
    """Build the two-row before/after summary frame for both groups.

    Each map goes from a day offset to a value. Values whose offset is below
    ``pivot`` are summed into ``"před"``; values at or above it into ``"po"``.
    The frame has the columns ``group``, ``před`` and ``po``, with the
    vaccinated row first.
    """

    def split_totals(series):
        pre = sum(val for offset, val in series.items() if offset < pivot)
        post = sum(val for offset, val in series.items() if offset >= pivot)
        return pre, post

    vax_pre, vax_post = split_totals(vax_map)
    novax_pre, novax_post = split_totals(novax_map)

    columns = {
        "group": ["očkovaní", "neočkovaní"],
        "před": [vax_pre, novax_pre],
        "po": [vax_post, novax_post],
    }
    return pl.DataFrame(columns)


def before_after_sums(
    dates_map: dict[int, float | int], pivot_day: int = 0
) -> dict[str, float]:
    before = sum(v for d, v in dates_map.items() if d < pivot_day)
    after = sum(v for d, v in dates_map.items() if d >= pivot_day)
    return {"před": before, "po": after}

In [47]:
# Load the pickled Person lists selected by POJISTOVNA.
# NOTE: pickle.load executes code embedded in the file; only point this at
# pickles produced by this project.
if POJISTOVNA != "both_companies":
    # Single insurer: the file name is derived from the selector.
    with open(f"./DATACON_data/{POJISTOVNA}_persons.pkl", "rb") as f:
        persons: list[Person] = pickle.load(f)
else:
    # Both insurers: load each file and concatenate, cpzp first.
    with open("./DATACON_data/cpzp_persons.pkl", "rb") as f:
        cpzp_persons: list[Person] = pickle.load(f)
    with open("./DATACON_data/ozp_persons.pkl", "rb") as f:
        ozp_persons: list[Person] = pickle.load(f)
    persons = cpzp_persons + ozp_persons

In [None]:
# new_persons = []

# for p in persons:
#     person = p
#     for v in person.vaccines:
#         if v.dose_number == 2:
#             v.dose_number = 1
#     new_persons.append(person)

In [48]:
# --- Accumulators -------------------------------------------------------------
# vax_dates_map[cohort][dose]         -> list of vaccination dates
# max_vax_intensity_map[cohort][dose] -> most frequent vaccination date
vax_dates_map = defaultdict(lambda: defaultdict(list))
max_vax_intensity_map = defaultdict(dict)

# Unvaccinated group, keyed [cohort][prescription date].
novax_ppl_predpisy_map = defaultdict(lambda: defaultdict(int))
novax_ppl_prvopredpisy_map = defaultdict(lambda: defaultdict(int))
novax_ppl_prednison_equivs_map = defaultdict(lambda: defaultdict(float))
novax_ppl_imunosupresivy_map = defaultdict(lambda: defaultdict(int))

# Vaccinated group, keyed [cohort][dose][day offset from the vaccination].
vax_ppl_prvopredpisy_map = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
vax_ppl_predpisy_map = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
# float default (not int) so the accumulator has the same value type as
# novax_ppl_prednison_equivs_map.
vax_ppl_prednison_equivs_map = defaultdict(
    lambda: defaultdict(lambda: defaultdict(float))
)
vax_ppl_imunosupresivy_map = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))

# --- Peak vaccination date per cohort and dose --------------------------------
# Gather every vaccination date of people who have vaccines and a falsy
# `died_at`.
for person in persons:
    if not person.died_at and person.vaccines:
        for v in person.vaccines:
            vax_dates_map[v.age_cohort][v.dose_number].append(v.date)

# For each cohort/dose keep the single most frequent date.
for cohort, doses in vax_dates_map.items():
    for dose, dates in doses.items():
        if not dates:
            continue
        max_vax_intensity_map[cohort][dose] = Counter(dates).most_common(1)[0][0]

# --- NOVAX metrics -----------------------------------------------------------
# Per cohort and prescription date, accumulate the metrics for people kept by
# skip_person_for_novax (no vaccines, at least one prescription).
for p in persons:
    if skip_person_for_novax(p):
        continue

    for pr in collapse_injections(p.prescriptions):
        cohort = pr.age_cohort_at_prescription
        if pr.prescription_type == PrescriptionType.IMUNOSUPRESSIVE:
            novax_ppl_imunosupresivy_map[cohort][pr.date] += 1

        novax_ppl_predpisy_map[cohort][pr.date] += 1
        novax_ppl_prednison_equivs_map[cohort][pr.date] += pr.prednison_equiv

    # Earliest prescription by date - the same rule the VAX section uses.
    # Taking element 0 is only correct when the list happens to be sorted;
    # for a sorted list min() returns that same element.
    first = min(p.prescriptions, key=lambda x: x.date)
    novax_ppl_prvopredpisy_map[first.age_cohort_at_prescription][first.date] += 1


# --- VAX metrics -------------------------------------------------------------
# Per cohort, dose and day offset from the vaccination, accumulate the metrics
# for people kept by skip_person_for_vax (vaccinated, at least one prescription).
for p in persons:
    if skip_person_for_vax(p):
        continue

    # Neither value depends on the vaccine, so compute them once per person
    # instead of once per dose.
    kept_prescriptions = list(collapse_injections(p.prescriptions))
    first = min(p.prescriptions, key=lambda x: x.date)

    for v in p.vaccines:
        # Only analyse vaccinations within VAX_PERIOD_IN_DAYS of the most
        # frequent vaccination date for this cohort/dose.
        max_int_date = max_vax_intensity_map[v.age_cohort][v.dose_number]
        if abs((v.date - max_int_date).days) > VAX_PERIOD_IN_DAYS:
            continue

        # prescriptions relative to this vax
        for pr in kept_prescriptions:
            rel_day = (pr.date - v.date).days
            if pr.prescription_type == PrescriptionType.IMUNOSUPRESSIVE:
                vax_ppl_imunosupresivy_map[v.age_cohort][v.dose_number][rel_day] += 1

            vax_ppl_predpisy_map[v.age_cohort][v.dose_number][rel_day] += 1
            vax_ppl_prednison_equivs_map[v.age_cohort][v.dose_number][
                rel_day
            ] += pr.prednison_equiv

        # first prescription of the person, relative to this vax
        rel_first = (first.date - v.date).days
        vax_ppl_prvopredpisy_map[v.age_cohort][v.dose_number][rel_first] += 1

In [49]:
rows = []

# Day offsets analysed for BOTH groups: [-PERIOD, PERIOD), i.e. PERIOD days
# before the pivot and PERIOD days from the pivot onwards. Using one shared
# range keeps the vaccinated and unvaccinated windows the same length (the
# unvaccinated window used to include one extra day at +PERIOD).
window_days = range(-PERIOD, PERIOD)

for dose in [1, 2, 3]:
    for cohort in AgeCohort:
        # Pivot date: most frequent vaccination date for this cohort/dose.
        # Absent when nobody in the cohort has this dose recorded; skip the
        # combination instead of failing with a KeyError.
        rozhodne = max_vax_intensity_map.get(cohort, {}).get(dose)
        if rozhodne is None:
            continue

        def novax_window(src_map):
            # Date-keyed map -> {day offset from the pivot date: value}.
            return {
                day: src_map[cohort].get(rozhodne + timedelta(days=day), 0)
                for day in window_days
            }

        def vax_window(src_map):
            # Already keyed by day offset from each person's own vaccination.
            return {day: src_map[cohort][dose].get(day, 0) for day in window_days}

        # metric name -> (vaccinated window, unvaccinated window)
        metrics = {
            "predpisy": (
                vax_window(vax_ppl_predpisy_map),
                novax_window(novax_ppl_predpisy_map),
            ),
            "prvopredpisy": (
                vax_window(vax_ppl_prvopredpisy_map),
                novax_window(novax_ppl_prvopredpisy_map),
            ),
            "kortikoidy": (
                vax_window(vax_ppl_prednison_equivs_map),
                novax_window(novax_ppl_prednison_equivs_map),
            ),
            "imunosupresivy": (
                vax_window(vax_ppl_imunosupresivy_map),
                novax_window(novax_ppl_imunosupresivy_map),
            ),
        }

        for metric_name, (vax_map, novax_map) in metrics.items():
            df = before_after_df(vax_map, novax_map)

            vax_b, vax_a = (
                df.filter(pl.col("group") == "očkovaní").select(["před", "po"]).row(0)
            )
            nov_b, nov_a = (
                df.filter(pl.col("group") == "neočkovaní").select(["před", "po"]).row(0)
            )

            # "increase" is the after total as a percentage of the before
            # total; NaN when the before total is zero.
            vax_inc = safe_div(vax_a, vax_b) * 100 if vax_b else np.nan
            novax_inc = safe_div(nov_a, nov_b) * 100 if nov_b else np.nan
            diff = (
                (vax_inc - novax_inc)
                if (np.isfinite(vax_inc) and np.isfinite(novax_inc))
                else np.nan
            )
            p_val = pvalue_from_df(df)

            # Ratio of the two after/before ratios; stays NaN when either
            # before total is zero or the unvaccinated ratio is zero.
            ratio = np.nan
            if vax_b and nov_b and np.isfinite(vax_inc) and np.isfinite(novax_inc):
                ratio = safe_div(vax_a / vax_b, nov_a / nov_b)

            rows.append(
                {
                    "period_days": PERIOD,
                    "age_cohort": str(cohort),
                    "vax_dose": int(dose),
                    "metric": metric_name,
                    "vax_increase": vax_inc,
                    "novax_increase": novax_inc,
                    "diff": diff,
                    "p_value": p_val,
                    "vax_before": vax_b,
                    "vax_after": vax_a,
                    "novax_before": nov_b,
                    "novax_after": nov_a,
                    "vax_vs_novax_ratio": ratio,
                }
            )

# One row per (dose, cohort, metric), sorted by cohort and written to CSV.
huge_df = pl.DataFrame(rows).sort(["age_cohort"])


huge_df.write_csv(f"results_{POJISTOVNA}.csv")