In [1]:
from dataclasses import dataclass
from enum import StrEnum
from datetime import datetime, timedelta
import polars as pl
from collections import defaultdict, Counter
from typing import Tuple
from common.constants.column_types import (
    CPZP_SCHEMA,
    OZP_SCHEMA,
    POHLAVI_CPZP,
    TYP_UDALOSTI,
)
from common.constants.column_names import SHARED_COLUMNS, OZP_COLUMNS, CPZP_COLUMNS
import pickle
from common.constants.objects import (
    Person,
    Gender,
    AgeCohort,
    Prescription,
    PrescriptionType,
)
import matplotlib.pyplot as plt
import numpy as np
import os

# Polars display settings: render up to 20 rows and 60 columns of a table.
pl.Config.set_tbl_rows(20)
pl.Config.set_tbl_cols(60)

from typing import Any
import matplotlib.dates as mdates

In [7]:
# CONSTANTS
# Insurance company identifier; it is interpolated into the input pickle path
# and into the output directory of the generated charts.
POJISTOVNA = "ozp"
# A vaccination is analysed only when it lies at most this many days away from
# the peak vaccination date of its age cohort and dose number.
VAX_PERIOD_IN_DAYS = 30

In [8]:
# Load the previously serialized list of Person objects for the selected insurer.
# NOTE(review): pickle.load can execute arbitrary code while unpickling — only
# load files that were produced by this project's own pipeline.
with open(f"./DATACON_data/{POJISTOVNA}_persons.pkl", "rb") as f:
    persons: list[Person] = pickle.load(f)

In [9]:
# VARIABLES
# age cohort -> dose number -> list of all vaccination dates seen for that dose
vax_dates_map: dict[AgeCohort, dict[int, list[datetime]]] = defaultdict(dict)
# age cohort -> dose number -> the single date on which most vaccinations happened
max_vax_intensity_map: dict[AgeCohort, dict[int, datetime]] = defaultdict(dict)

# Accumulators for the unvaccinated group, keyed age cohort -> prescription date.
# Missing keys are created on first access and start at zero.
# Counted prescriptions per day.
novax_ppl_predpisy_map: dict[AgeCohort, dict[datetime, int]] = defaultdict(
    lambda: defaultdict(int)
)
# First prescriptions per day (one increment per person).
novax_ppl_prvopredpisy_map: dict[AgeCohort, dict[datetime, int]] = defaultdict(
    lambda: defaultdict(int)
)
# Float-valued per-day accumulator for the prednisone-equivalent metric.
novax_ppl_prednison_equivs_map: dict[AgeCohort, dict[datetime, float]] = defaultdict(
    lambda: defaultdict(float)
)


def nested_defaultdict():
    """Return a two-level defaultdict whose innermost missing values read as 0.

    The outer mapping creates an inner ``defaultdict(int)`` for every new key,
    so ``result[a][b] += 1`` works without any prior initialisation.
    """

    def _int_counter():
        return defaultdict(int)

    return defaultdict(_int_counter)


# Accumulators for the vaccinated group, keyed
# age cohort -> dose number -> prescription date -> value.
# All three levels are created on first access; missing values read as 0.
vax_ppl_prvopredpisy_map = defaultdict(nested_defaultdict)
vax_ppl_predpisy_map = defaultdict(nested_defaultdict)
vax_ppl_prednison_equivs_map = defaultdict(nested_defaultdict)

# === PEAK VACCINATION DATE PER COHORT AND DOSE ===
# Step 1: gather the vaccination dates of every living, vaccinated person,
# grouped by age cohort and dose number.
for person in persons:
    if not person.died_at and person.vaccines:
        for vax in person.vaccines:
            vax_dates_map[vax.age_cohort].setdefault(vax.dose_number, []).append(
                vax.date
            )

# Step 2: for each cohort/dose keep the date with the highest number of
# vaccinations (on a tie, the date that was encountered first wins).
for age_cohort, dates_by_dose in vax_dates_map.items():
    for dose, dose_dates in dates_by_dose.items():
        if dose_dates:
            per_day = Counter(dose_dates)
            max_vax_intensity_map[age_cohort][dose] = max(per_day, key=per_day.get)


# === UNVACCINATED PEOPLE METRICS ===
# For every living person without any vaccination who has at least one
# prescription, accumulate per age cohort and per day:
#   * the number of counted prescriptions,
#   * the sum of prednisone equivalents,
#   * the person's first prescription.
for p in persons:
    if p.died_at or p.vaccines or not p.prescriptions:
        continue

    # Sentinel far in the past so that the first prescription is never skipped.
    last_prescription_date = datetime.min.date()
    for prescription in p.prescriptions:
        age_cohort_at_prescription = prescription.age_cohort_at_prescription
        if (
            prescription.lekova_forma == "Injekční suspenze"
            or prescription.lekova_forma == "Injekční/infuzní roztok"
        ):
            # An injectable prescription is ignored when it lies less than
            # 14 days from the previously counted prescription.
            if abs((last_prescription_date - prescription.date).days) < 14:
                continue
        last_prescription_date = prescription.date

        novax_ppl_predpisy_map[age_cohort_at_prescription][prescription.date] += 1

        # Add the actual prednisone-equivalent dose (not a count of 1), exactly
        # as the vaccinated group does; prescriptions without a value add nothing.
        if prescription.prednison_equiv is not None:
            novax_ppl_prednison_equivs_map[age_cohort_at_prescription][
                prescription.date
            ] += prescription.prednison_equiv

    # Earliest prescription by date, independent of the list's ordering
    # (consistent with how the vaccinated group determines it).
    first_prescription = min(p.prescriptions, key=lambda x: x.date)
    first_prescription_date = first_prescription.date
    age_cohort_at_1st_prescription = first_prescription.age_cohort_at_prescription
    novax_ppl_prvopredpisy_map[age_cohort_at_1st_prescription][
        first_prescription_date
    ] += 1

# === VACCINATED PEOPLE METRICS ===
# For every living, vaccinated person with prescriptions, and for each of their
# doses administered within VAX_PERIOD_IN_DAYS of the cohort/dose peak date,
# accumulate prescription counts, prednisone equivalents and the first
# prescription under that dose's age cohort and dose number.
injectable_forms = ("Injekční suspenze", "Injekční/infuzní roztok")

for p in persons:
    if p.died_at or not p.vaccines or not p.prescriptions:
        continue

    # === FIRST PRESCRIPTION ===
    # Does not depend on the dose, so it is looked up once per person.
    first_prescription = min(p.prescriptions, key=lambda rx: rx.date)

    for vax in p.vaccines:
        max_int_date = max_vax_intensity_map[vax.age_cohort][vax.dose_number]
        if abs((vax.date - max_int_date).days) > VAX_PERIOD_IN_DAYS:
            continue

        cohort_key, dose_key = vax.age_cohort, vax.dose_number

        # === PRESCRIPTIONS ===
        # Sentinel far in the past so that the first prescription always counts.
        previous_date = datetime.min.date()
        for prescription in p.prescriptions:
            is_injectable = prescription.lekova_forma in injectable_forms
            # Injectables closer than 14 days to the previously counted
            # prescription are ignored.
            if is_injectable and abs((previous_date - prescription.date).days) < 14:
                continue
            previous_date = prescription.date

            # === PRESCRIPTIONS COUNT ===
            vax_ppl_predpisy_map[cohort_key][dose_key][prescription.date] += 1

            # === PREDNISON EQUIV ===
            equiv = prescription.prednison_equiv
            if equiv is not None:
                vax_ppl_prednison_equivs_map[cohort_key][dose_key][
                    prescription.date
                ] += equiv

        # === FIRST PRESCRIPTIONS ===
        vax_ppl_prvopredpisy_map[cohort_key][dose_key][first_prescription.date] += 1

In [10]:
class ChartDrawer:
    """Draws vaccinated vs. unvaccinated comparison panels on matplotlib axes."""

    def draw_2x2_block(
        self,
        vax_dates_map,
        novax_dates_map,
        rozhodne_datum,
        title,
        axes,
        row_offset,
        col_offset,
    ):
        """Fill a 2x2 sub-grid of ``axes`` with charts for one metric.

        The upper row shows the vaccinated group, the lower row the
        unvaccinated group; each row holds a daily scatter plot followed by a
        before/after bar chart.

        Args:
            vax_dates_map: day -> value mapping for the vaccinated group.
            novax_dates_map: day -> value mapping for the unvaccinated group.
            rozhodne_datum: reference date splitting "before" from "after".
            title: metric name used as the prefix of every panel title.
            axes: 2-D indexable grid of axes (``axes[row][col]``).
            row_offset: row index of the block's upper-left panel.
            col_offset: column index of the block's upper-left panel.
        """
        groups = (
            (vax_dates_map, "Očkovaná skupina"),
            (novax_dates_map, "Neočkovaná skupina"),
        )
        for row_shift, (daily_map, group_label) in enumerate(groups):
            period_sums = self.__get_before_after_sums(daily_map, rozhodne_datum)
            panel_title = f"{title} - {group_label}"
            axes_row = axes[row_offset + row_shift]

            self.__draw_scatter_plot(
                axes_row[col_offset],
                list(daily_map.keys()),
                list(daily_map.values()),
                panel_title,
                rozhodne_datum,
                sum(period_sums.values()),
            )
            self.__draw_bar_chart(
                axes_row[col_offset + 1],
                list(period_sums.keys()),
                list(period_sums.values()),
                panel_title,
            )

    def __draw_scatter_plot(self, ax, x_data, y_data, title, rozhodne_datum, total_sum):
        """Plot daily values with the reference date, mean lines and weekly means.

        Args:
            ax: axes to draw on.
            x_data: list of days.
            y_data: list of values, aligned with ``x_data``.
            title: panel title.
            rozhodne_datum: reference date; drawn as a vertical line and used
                to split the data for the "before" / "after" mean lines.
            total_sum: total shown in the text box in the upper-right corner.
        """
        points = list(zip(x_data, y_data))

        # Raw daily data.
        ax.plot(x_data, y_data, label="Data", alpha=0.7, marker="o", linestyle="None")

        # Vertical marker at the reference date.
        ax.axvline(
            x=rozhodne_datum,
            color="green",
            linestyle="--",
            linewidth=2,
            label=f"Rozhodné datum: {rozhodne_datum.strftime('%Y-%m-%d')}",
        )

        # Horizontal mean lines: strictly before the reference date, and from
        # the reference date onwards (inclusive).
        mean_lines = (
            ("Průměr před", "blue", [v for d, v in points if d < rozhodne_datum]),
            ("Průměr po", "purple", [v for d, v in points if d >= rozhodne_datum]),
        )
        for line_label, line_color, values in mean_lines:
            if values:
                ax.axhline(
                    sum(values) / len(values),
                    color=line_color,
                    linestyle="--",
                    label=line_label,
                )

        # Weekly means: bucket every day under the Monday of its week.
        values_by_week = defaultdict(list)
        for day, value in points:
            monday = day - timedelta(days=day.weekday())
            values_by_week[monday].append(value)

        if values_by_week:
            week_x = tuple(sorted(values_by_week))
            week_y = tuple(
                sum(values_by_week[week]) / len(values_by_week[week]) for week in week_x
            )
            ax.plot(week_x, week_y, color="orange", marker="s", label="Týdenní průměr")

        # Date formatting of the X axis.
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
        ax.xaxis.set_major_locator(mdates.AutoDateLocator())
        ax.tick_params(axis="x", rotation=45)

        # Labels and general appearance.
        ax.set_xlabel("Dny kolem max intenzity")
        ax.set_ylabel("Počet předpisů za den")
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Grand total in a small box in the upper-right corner.
        box_style = {"boxstyle": "round,pad=0.3", "fc": "white", "ec": "gray", "alpha": 0.5}
        ax.text(
            0.99,
            0.95,
            f"Celkem: {total_sum:,}",
            transform=ax.transAxes,
            ha="right",
            va="top",
            fontsize=10,
            bbox=box_style,
        )

    def __draw_bar_chart(self, ax, x_data, y_data, title):
        """Draw the before/after totals as bars annotated with value and percentage.

        The first bar is the 100 % baseline; the second bar's percentage is its
        value relative to the first one (0 when the baseline is 0).

        Args:
            ax: axes to draw on.
            x_data: bar labels.
            y_data: bar heights; the first two entries are used for the percentage.
            title: panel title.
        """
        bars = ax.bar(x_data, y_data, color="skyblue", edgecolor="black")
        ax.set_xlabel("Období před a po")
        ax.set_ylabel("Celkový počet předpisů")
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

        baseline = y_data[0]
        if baseline != 0:
            relative = (y_data[1] / baseline) * 100
        else:
            relative = 0
        percentages = [100, relative]

        for idx, (bar, value) in enumerate(zip(bars, y_data)):
            label_x = bar.get_x() + bar.get_width() / 2
            ax.text(
                label_x,
                bar.get_height(),
                f"{value:,} / {percentages[idx]:.2f}%",
                ha="center",
                va="bottom",
                fontsize=9,
            )

    def __get_before_after_sums(self, dates_map, rozhodne_datum):
        """Sum the values before and from ``rozhodne_datum`` onwards.

        Args:
            dates_map: day -> value mapping.
            rozhodne_datum: reference date; the day itself counts as "after".

        Returns:
            Dict with the keys ``"před"`` (before) and ``"po"`` (after).
        """
        before_sum = sum(v for d, v in dates_map.items() if d < rozhodne_datum)
        after_sum = sum(v for d, v in dates_map.items() if d >= rozhodne_datum)
        return {"před": before_sum, "po": after_sum}

In [11]:
# Render one figure per (window length, age cohort, dose number) combination
# and store it as out/<insurer>/<window length>/<cohort>-<dose>.png.
# ChartDrawer keeps no state, so a single instance serves all figures.
drawer = ChartDrawer()

for PRESCRIPTION_PERIOD_IN_DAYS in [30, 60, 90, 180, 364]:
    for age_cohort_at_prescription in AgeCohort:
        for dose_number in [1, 2, 3]:
            # Peak vaccination date of this cohort/dose. It is missing when no
            # vaccination of that dose was recorded for the cohort; without a
            # reference date there is nothing to plot, so skip instead of
            # aborting the whole export with a KeyError.
            rozhodne_datum = max_vax_intensity_map.get(
                age_cohort_at_prescription, {}
            ).get(dose_number)
            if rozhodne_datum is None:
                print(
                    f"Skipping {age_cohort_at_prescription.value} dose {dose_number}: "
                    "no vaccination data"
                )
                continue

            def filter_by_date_range(dates: dict[datetime, int]) -> dict[datetime, int]:
                """Return a dense day -> value mapping around the reference date.

                Covers every day from PRESCRIPTION_PERIOD_IN_DAYS before to
                PRESCRIPTION_PERIOD_IN_DAYS after ``rozhodne_datum`` (both ends
                inclusive); days missing from ``dates`` get the value 0.
                """
                start = rozhodne_datum - timedelta(days=PRESCRIPTION_PERIOD_IN_DAYS)
                end = rozhodne_datum + timedelta(days=PRESCRIPTION_PERIOD_IN_DAYS)
                return {
                    day: dates.get(day, 0)
                    for day in (
                        start + timedelta(days=i) for i in range((end - start).days + 1)
                    )
                }

            vax_predpisy_map = filter_by_date_range(
                vax_ppl_predpisy_map[age_cohort_at_prescription][dose_number]
            )
            novax_predpisy_map = filter_by_date_range(
                novax_ppl_predpisy_map[age_cohort_at_prescription]
            )

            vax_prvopredpisy_map = filter_by_date_range(
                vax_ppl_prvopredpisy_map[age_cohort_at_prescription][dose_number]
            )
            novax_prvopredpisy_map = filter_by_date_range(
                novax_ppl_prvopredpisy_map[age_cohort_at_prescription]
            )

            vax_kortikoidy_map = filter_by_date_range(
                vax_ppl_prednison_equivs_map[age_cohort_at_prescription][dose_number]
            )
            novax_kortikoidy_map = filter_by_date_range(
                novax_ppl_prednison_equivs_map[age_cohort_at_prescription]
            )

            fig, axes = plt.subplots(nrows=2, ncols=6, figsize=(60, 20))
            fig.tight_layout(pad=10.0)
            fig.suptitle(
                f" {POJISTOVNA} - {PRESCRIPTION_PERIOD_IN_DAYS} Dnů od {rozhodne_datum.strftime('%d.%m.%Y')} - {age_cohort_at_prescription.value} - Dose {dose_number}",
                fontsize=36,
            )

            # Three metrics side by side; each one occupies a 2x2 block, i.e.
            # two columns of the 2x6 grid.
            panels = (
                ("Předpisy", vax_predpisy_map, novax_predpisy_map),
                ("Prvopředpisy", vax_prvopredpisy_map, novax_prvopredpisy_map),
                ("Kortikoidové ekvivalenty", vax_kortikoidy_map, novax_kortikoidy_map),
            )
            for panel_index, (panel_title, vax_map, novax_map) in enumerate(panels):
                drawer.draw_2x2_block(
                    vax_dates_map=vax_map,
                    novax_dates_map=novax_map,
                    rozhodne_datum=rozhodne_datum,
                    title=panel_title,
                    axes=axes,
                    row_offset=0,
                    col_offset=2 * panel_index,
                )

            out_dir = f"out/{POJISTOVNA}/{PRESCRIPTION_PERIOD_IN_DAYS}"
            os.makedirs(out_dir, exist_ok=True)
            # Save through the figure object rather than pyplot's implicit
            # "current figure".
            fig.savefig(f"{out_dir}/{age_cohort_at_prescription.value}-{dose_number}.png")
            # Release the figure; many large figures are created in this loop.
            plt.close(fig)