In [None]:
Build a yearly population table from yearly_raw with:

Total unique patients

Male/Female counts

Mean/Median/Q1/Q3 age overall

Mean/Median/Q1/Q3 age for Male

Mean/Median/Q1/Q3 age for Female

Save it as CSV.

Plot one combined figure showing:

Total / Male / Female patient counts (line plot)

Mean and median age for Male and Female (second axis line plot)

Everything is shown in one combined figure; the counts and the ages can be split into separate plots if the combined view is too crowded.

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display, FileLink, HTML


def build_and_plot_yearly_population_details(
    cohort_label,
    yearly_raw,
    years,
    output_dir,
    dpi=600
):
    """
    Build a yearly population demographics table (NOT deaths) from yearly_raw and plot it.

    Table (one row per year that is present in both ``years`` and ``yearly_raw``):
      - Total / Male / Female unique patients (distinct ALF_E values)
      - Mean/Median/Q1/Q3 age overall, for Male and for Female,
        where age is computed as ``year - YOB``

    Plot (skipped when the table has no rows):
      - Left axis: patient counts (Total/Male/Female)
      - Right axis: mean/median ages for Male and Female

    Saves:
      - CSV table  -> <output_dir>/<bav|nonbav>_yearly_population_details.csv
      - JPEG plot  -> <output_dir>/Figures/<cohort_label>_Yearly_Population_Demographics.jpeg

    Parameters
    ----------
    cohort_label : str
        Used in the title and file names. "BAV" (case-insensitive, surrounding
        whitespace ignored) selects the "bav" CSV prefix; anything else "nonbav".
    yearly_raw : dict[int, pd.DataFrame]
        Per-year frames; each must provide ALF_E, YOB and GNDR_NAME columns.
    years : iterable[int]
        Years to report; years missing from ``yearly_raw`` are ignored.
    output_dir : str
        Created if it does not exist.
    dpi : int
        Resolution of the saved JPEG.

    Returns
    -------
    pop_df : pd.DataFrame
        Always carries the full column set, even when it has no rows.
    """

    os.makedirs(output_dir, exist_ok=True)

    years_sorted = sorted(y for y in years if y in yearly_raw)

    stat_names = ("Mean", "Median", "Q1", "Q3")
    group_names = ("All", "Male", "Female")

    # Explicit column list so an empty result still has the expected schema.
    columns = [
        "Cohort",
        "Year",
        "Total Unique Patients",
        "Male Unique Patients",
        "Female Unique Patients",
    ] + [f"{stat} Age ({group})" for group in group_names for stat in stat_names]

    def _stats(series):
        # Non-numeric values are coerced to NaN and dropped before summarising.
        s = pd.to_numeric(series, errors="coerce").dropna()
        if s.empty:
            return {"Mean": np.nan, "Median": np.nan, "Q1": np.nan, "Q3": np.nan}
        return {
            "Mean": round(float(s.mean()), 2),
            "Median": round(float(s.median()), 2),
            "Q1": round(float(s.quantile(0.25)), 2),
            "Q3": round(float(s.quantile(0.75)), 2),
        }

    rows = []
    for year in years_sorted:
        frame = yearly_raw[year]
        # Derived Series instead of a new column: no need to copy the whole frame.
        age = year - frame["YOB"]
        is_male = frame["GNDR_NAME"] == "Male"
        is_female = frame["GNDR_NAME"] == "Female"

        row = {
            "Cohort": cohort_label,
            "Year": year,
            "Total Unique Patients": int(frame["ALF_E"].nunique()),
            "Male Unique Patients": int(frame.loc[is_male, "ALF_E"].nunique()),
            "Female Unique Patients": int(frame.loc[is_female, "ALF_E"].nunique()),
        }

        ages_by_group = {"All": age, "Male": age[is_male], "Female": age[is_female]}
        for group in group_names:
            summary = _stats(ages_by_group[group])
            for stat in stat_names:
                row[f"{stat} Age ({group})"] = summary[stat]

        rows.append(row)

    pop_df = pd.DataFrame(rows, columns=columns)

    short = "bav" if str(cohort_label).strip().upper() == "BAV" else "nonbav"
    out_csv = os.path.join(output_dir, f"{short}_yearly_population_details.csv")
    pop_df.to_csv(out_csv, index=False)

    print(f"✅ Saved yearly population details table → {out_csv}")
    display(FileLink(out_csv))
    display(HTML(pop_df.to_html(index=False)))

    # Nothing to draw when none of the requested years is available.
    if pop_df.empty:
        print(f"⚠️ No yearly data available for {cohort_label}; skipping plot.")
        return pop_df

    # ---------------------------
    # Plot: counts + age (2 axes)
    # ---------------------------
    x = pop_df["Year"].tolist()

    fig, ax1 = plt.subplots(figsize=(14, 6))

    # Left axis: counts
    count_series = (
        ("Total Unique Patients", "Total unique patients"),
        ("Male Unique Patients", "Male unique patients"),
        ("Female Unique Patients", "Female unique patients"),
    )
    for column, label in count_series:
        ax1.plot(x, pop_df[column], marker="o", label=label)
    ax1.set_xlabel("Year")
    ax1.set_ylabel("Unique patients (count)")
    ax1.grid(axis="y", linestyle="--", alpha=0.3)

    # Right axis: ages (dashed = mean, dotted = median)
    ax2 = ax1.twinx()
    for sex in ("Male", "Female"):
        ax2.plot(x, pop_df[f"Mean Age ({sex})"], marker="o", linestyle="--", label=f"Mean age ({sex})")
        ax2.plot(x, pop_df[f"Median Age ({sex})"], marker="o", linestyle=":", label=f"Median age ({sex})")
    ax2.set_ylabel("Age (years)")

    ax1.set_title(f"Yearly population counts and age summary – {cohort_label}")
    ax1.set_xticks(x)
    ax1.set_xticklabels(x, rotation=45, ha="right")

    # Merge legends from both axes into a single box
    h1, l1 = ax1.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ax1.legend(h1 + h2, l1 + l2, loc="upper left", frameon=False)

    fig.tight_layout()

    # Save plot
    fig_dir = os.path.join(output_dir, "Figures")
    os.makedirs(fig_dir, exist_ok=True)
    out_fig = os.path.join(fig_dir, f"{cohort_label}_Yearly_Population_Demographics.jpeg")
    fig.savefig(out_fig, dpi=dpi, format="jpeg", bbox_inches="tight", facecolor="white")
    print(f"✅ Saved plot → {out_fig}")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    return pop_df


In [None]:
 # BAV

# Build, save and plot the yearly population table for the BAV cohort.
# NOTE(review): relies on `bav_results`, `years` and `output_dir` from cells
# not visible here — confirm they are defined earlier on a fresh kernel.
bav_pop_df = build_and_plot_yearly_population_details(
    cohort_label="BAV",
    yearly_raw=bav_results["yearly_raw"],
    years=years,
    output_dir=os.path.join(output_dir, "BAV"),
    dpi=600
)


In [None]:
# NON_BAV

# Build, save and plot the yearly population table for the Non-BAV cohort.
# NOTE(review): relies on `nonbav_results`, `years` and `output_dir` from cells
# not visible here — confirm they are defined earlier on a fresh kernel.
nonbav_pop_df = build_and_plot_yearly_population_details(
    cohort_label="Non-BAV",
    yearly_raw=nonbav_results["yearly_raw"],
    years=years,
    output_dir=os.path.join(output_dir, "Non-BAV"),
    dpi=600
)



In [None]:
One single combined figure per cohort

Uses:

Yearly total unique patients (from yearly_raw)

Deaths by merged year-blocks with suppression (from the suppressed death table we created)

The suppression rule is preserved by plotting deaths at the merged-block level (e.g., 2000–2001, 2002, …), while patient counts are plotted yearly on the same x-axis.

To make that readable on one axis, the function:

plots yearly patients as a line

plots merged-block deaths as bars placed at the centre of the block

uses a second y-axis to avoid scale problems

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

def plot_patients_and_deaths_combined(
    cohort_label,
    yearly_raw,
    death_blocks_df,
    years,
    output_dir=None,
    dpi=600
):
    """
    Single combined graph per cohort:
      - Line: yearly total unique patients (from yearly_raw)
      - Bars: deaths per merged year-block (from death_blocks_df with suppression)

    Deaths are taken as-is from death_blocks_df, so whatever suppression was
    applied when that table was built is what gets plotted.
    Bars are centred on the midpoint of each year-block and are 0.8 x the
    block length wide.

    Parameters
    ----------
    cohort_label : str
        Used in the title and in the saved file name.
    yearly_raw : dict[int, pd.DataFrame]
        yearly_raw[year] must provide an ALF_E column; distinct values are counted.
    death_blocks_df : pd.DataFrame
        Must contain: Start Year, End Year, Deaths, Year Block
    years : iterable[int]
        Years for the patient line; years missing from yearly_raw are ignored.
    output_dir : str or None
        If provided, the figure is saved there once as
        <cohort_label>_Patients_Deaths_Combined.jpeg.
    dpi : int
        Save dpi for JPEG.

    Returns
    -------
    None
    """

    years_sorted = sorted(y for y in years if y in yearly_raw)

    # --- Yearly unique patients ---
    patient_counts = [int(yearly_raw[y]["ALF_E"].nunique()) for y in years_sorted]

    # --- Death bars: one pass over the merged blocks ---
    block_midpoints = []
    block_widths = []
    block_deaths = []
    block_labels = []
    for _, block in death_blocks_df.iterrows():
        start_y = int(block["Start Year"])
        end_y = int(block["End Year"])
        block_midpoints.append((start_y + end_y) / 2.0)
        # Width = block length in years, shrunk to 80% to leave gaps between bars.
        block_widths.append((end_y - start_y + 1) * 0.8)
        block_deaths.append(float(block["Deaths"]))
        block_labels.append(str(block["Year Block"]))

    # --- Plot ---
    fig, ax1 = plt.subplots(figsize=(14, 6))

    # Patients line (left axis)
    ax1.plot(years_sorted, patient_counts, marker="o", linewidth=2, label="Total unique patients (yearly)")
    ax1.set_xlabel("Year")
    ax1.set_ylabel("Total unique patients (count)")
    ax1.set_title(f"Yearly patients and deaths (suppressed blocks) – {cohort_label}")
    ax1.grid(axis="y", linestyle="--", alpha=0.3)

    # Death bars (right axis, so the two scales do not fight each other)
    ax2 = ax1.twinx()
    ax2.bar(block_midpoints, block_deaths, width=block_widths, alpha=0.35, label="Deaths (merged blocks, threshold ≥10)")
    ax2.set_ylabel("Deaths (count, merged blocks)")

    # X limits and ticks
    if years_sorted:
        ax1.set_xlim(min(years_sorted) - 0.5, max(years_sorted) + 0.5)
    ax1.set_xticks(years_sorted)
    ax1.set_xticklabels(years_sorted, rotation=45, ha="right")

    # Block labels just above each bar; the offset is computed once, not per bar.
    if block_deaths:
        tallest = max(block_deaths)
        label_offset = 0.01 * tallest if tallest > 0 else 0.2
        for x_mid, n_deaths, label in zip(block_midpoints, block_deaths, block_labels):
            ax2.text(x_mid, n_deaths + label_offset,
                     label, ha="center", va="bottom", fontsize=9, rotation=90)

    # Merge legends from both axes
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper left", frameon=False)

    fig.tight_layout()

    # Save exactly once (the figure used to be written and announced twice).
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        save_path = os.path.join(output_dir, f"{cohort_label}_Patients_Deaths_Combined.jpeg")
        fig.savefig(save_path, dpi=dpi, format="jpeg", bbox_inches="tight", facecolor="white")
        print(f"✅ Saved combined plot → {save_path}")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


In [None]:
# BAV


# Build the suppressed death year-block table for BAV, then draw the combined
# patients-vs-deaths figure from it.
# NOTE(review): in the visible notebook, build_yearly_death_table_with_suppression
# is only defined in a later cell, so this cell fails on a fresh "Run All" unless
# that definition is moved above it. `bav_results`, `years` and `output_dir`
# also come from cells not shown here.
death_blocks_bav = build_yearly_death_table_with_suppression(
    cohort_label="BAV",
    death_master=bav_results["death_master"],
    years=years,
    output_dir=os.path.join(output_dir, "BAV"),
    threshold=10
)

plot_patients_and_deaths_combined(
    cohort_label="BAV",
    yearly_raw=bav_results["yearly_raw"],
    death_blocks_df=death_blocks_bav,
    years=years,
    output_dir=os.path.join(output_dir, "BAV", "Figures"),
    dpi=600
)


In [None]:
# NON-BAV

# Build the suppressed death year-block table for Non-BAV, then draw the
# combined patients-vs-deaths figure from it.
# NOTE(review): in the visible notebook, build_yearly_death_table_with_suppression
# is only defined in a later cell, so this cell fails on a fresh "Run All" unless
# that definition is moved above it. `nonbav_results`, `years` and `output_dir`
# also come from cells not shown here.
death_blocks_nonbav = build_yearly_death_table_with_suppression(
    cohort_label="Non-BAV",
    death_master=nonbav_results["death_master"],
    years=years,
    output_dir=os.path.join(output_dir, "Non-BAV"),
    threshold=10
)

plot_patients_and_deaths_combined(
    cohort_label="Non-BAV",
    yearly_raw=nonbav_results["yearly_raw"],
    death_blocks_df=death_blocks_nonbav,
    years=years,
    output_dir=os.path.join(output_dir, "Non-BAV", "Figures"),
    dpi=600
)


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display, FileLink, HTML
import os


def build_yearly_death_table_with_suppression(
    cohort_label,
    death_master,
    years,
    output_dir,
    threshold=10
):
    """
    Build a deaths-only table by year using death_master (NOT age-grouped),
    applying suppression by merging consecutive years until death count >= threshold.

    Logic:
      - ``years`` is de-duplicated and sorted first.
      - Start at the first year.
      - If deaths in that year < threshold, merge with next year and re-check.
      - Continue merging until >= threshold OR reach last year.
      - Output rows represent year blocks like:
            2000-2001, 2002, 2003-2005, ...
      - For each block compute:
            deaths, male deaths, female deaths, mean/median/Q1/Q3 age at death.

    The final block may still be below ``threshold`` when the years run out;
    a warning is printed in that case so it can be reviewed before release.

    Parameters
    ----------
    cohort_label : str
        Written to the "Cohort" column; "BAV" (case-insensitive) selects the
        "bav" CSV prefix, anything else "nonbav".
    death_master : pd.DataFrame
        Must provide DOD_YEAR, GNDR_NAME and AGE_AT_DEATH columns; each row is
        counted as one death.
    years : iterable[int]
        Years to cover. Deaths outside [min(years), max(years)] are ignored.
    output_dir : str
        Created if missing; the CSV is written here.
    threshold : int
        Minimum number of deaths a block should reach before it is closed.

    Returns
    -------
    df_blocks : pd.DataFrame
        Suppressed/merged year-block table. When ``death_master`` or ``years``
        is empty, an empty table with the expected columns is returned and no
        CSV is written.
    """
    os.makedirs(output_dir, exist_ok=True)

    block_columns = [
        "Cohort", "Year Block", "Start Year", "End Year",
        "Deaths", "Deaths Male", "Deaths Female",
        "Mean Age at Death", "Median Age at Death", "Q1 Age at Death", "Q3 Age at Death"
    ]

    dm = death_master.copy()
    if dm.empty:
        return pd.DataFrame(columns=block_columns)

    # De-duplicate: a repeated year used to make the index lookups loop forever.
    years = sorted(set(years))
    if not years:
        return pd.DataFrame(columns=block_columns)
    min_y, max_y = years[0], years[-1]

    # Filter death master to requested year range
    dm = dm[(dm["DOD_YEAR"] >= min_y) & (dm["DOD_YEAR"] <= max_y)].copy()

    # Helper for age stats
    def _stats(s):
        s = pd.to_numeric(s, errors="coerce").dropna()
        if s.empty:
            return np.nan, np.nan, np.nan, np.nan
        return (
            round(float(s.mean()), 2),
            round(float(s.median()), 2),
            round(float(s.quantile(0.25)), 2),
            round(float(s.quantile(0.75)), 2),
        )

    blocks = []
    last_idx = len(years) - 1
    start_idx = 0

    while start_idx <= last_idx:
        start_year = years[start_idx]
        end_idx = start_idx

        # Extend the block forward until the threshold is met or years run out.
        # Positions are tracked directly instead of repeated list.index() scans.
        while True:
            end_year = years[end_idx]
            dm_block = dm[(dm["DOD_YEAR"] >= start_year) & (dm["DOD_YEAR"] <= end_year)]
            n_deaths = int(len(dm_block))
            if n_deaths >= threshold or end_idx == last_idx:
                break
            end_idx += 1

        # finalize block stats (dm_block / n_deaths hold the final block here)
        n_male = int((dm_block["GNDR_NAME"] == "Male").sum())
        n_female = int((dm_block["GNDR_NAME"] == "Female").sum())

        mean_age, med_age, q1_age, q3_age = _stats(dm_block["AGE_AT_DEATH"])

        year_block_label = f"{start_year}" if start_year == end_year else f"{start_year}-{end_year}"

        blocks.append({
            "Cohort": cohort_label,
            "Year Block": year_block_label,
            "Start Year": start_year,
            "End Year": end_year,
            "Deaths": n_deaths,
            "Deaths Male": n_male,
            "Deaths Female": n_female,
            "Mean Age at Death": mean_age,
            "Median Age at Death": med_age,
            "Q1 Age at Death": q1_age,
            "Q3 Age at Death": q3_age,
        })

        # continue with the first year after this block
        start_idx = end_idx + 1

    # Only the trailing block can close without reaching the threshold.
    if blocks and blocks[-1]["Deaths"] < threshold:
        print(
            f"⚠️ {cohort_label}: final year block {blocks[-1]['Year Block']} is below the "
            f"threshold of {threshold} deaths; review before releasing this table."
        )

    df_blocks = pd.DataFrame(blocks, columns=block_columns)

    short = "bav" if str(cohort_label).strip().upper() == "BAV" else "nonbav"
    out_csv = os.path.join(output_dir, f"{short}_death_year_blocks_threshold_{threshold}.csv")
    df_blocks.to_csv(out_csv, index=False)

    print(f"✅ Saved death year-block table → {out_csv}")
    display(FileLink(out_csv))
    display(HTML(df_blocks.to_html(index=False)))

    return df_blocks


def plot_yearly_death_demographics(df_blocks, output_dir=None, cohort_label=None, dpi=300):
    """
    Draw two figures from the suppressed/merged death year-block table.

    Figures:
      1) Bar chart of "Deaths" per "Year Block"
      2) Line chart of "Mean Age at Death" and "Median Age at Death" per "Year Block"

    Parameters
    ----------
    df_blocks : pd.DataFrame or None
        Year-block table; nothing is drawn when it is None or empty.
    output_dir : str or None
        When given, each figure is also written there as a JPEG.
    cohort_label : str or None
        Label for titles and file names; defaults to the first value of the
        "Cohort" column when present, otherwise the literal "Cohort".
    dpi : int
        Resolution of the saved JPEGs.
    """
    nothing_to_plot = df_blocks is None or df_blocks.empty
    if nothing_to_plot:
        print("⚠️ No death blocks to plot.")
        return

    if cohort_label is None:
        if "Cohort" in df_blocks.columns:
            cohort_label = str(df_blocks["Cohort"].iloc[0])
        else:
            cohort_label = "Cohort"

    x_labels = df_blocks["Year Block"].astype(str).tolist()
    # Both figures grow with the number of blocks, but never below 10 wide.
    canvas = (max(10, len(x_labels) * 0.9), 5)

    def _export(file_name):
        # Write the current pyplot figure, but only if a target directory was given.
        if output_dir is None:
            return
        os.makedirs(output_dir, exist_ok=True)
        save_path = os.path.join(output_dir, file_name)
        plt.savefig(save_path, dpi=dpi, format="jpeg", bbox_inches="tight", facecolor="white")
        print(f"✅ Saved plot → {save_path}")

    # --- Figure 1: death counts ---
    plt.figure(figsize=canvas)
    plt.bar(x_labels, df_blocks["Deaths"].values)
    plt.xticks(rotation=45, ha="right")
    plt.ylabel("Deaths (count)")
    plt.xlabel("Year block (merged for suppression)")
    plt.title(f"Deaths by year block (threshold suppression) – {cohort_label}")
    plt.tight_layout()
    _export(f"{cohort_label}_Deaths_By_YearBlock.jpeg")
    plt.show()

    # --- Figure 2: mean and median age at death ---
    plt.figure(figsize=canvas)
    age_lines = (
        ("Mean Age at Death", "Mean age at death"),
        ("Median Age at Death", "Median age at death"),
    )
    for column, legend_text in age_lines:
        plt.plot(x_labels, df_blocks[column].values, marker="o", label=legend_text)
    plt.xticks(rotation=45, ha="right")
    plt.ylabel("Age (years)")
    plt.xlabel("Year block (merged for suppression)")
    plt.title(f"Age at death by year block – {cohort_label}")
    plt.legend()
    plt.tight_layout()
    _export(f"{cohort_label}_AgeAtDeath_By_YearBlock.jpeg")
    plt.show()


In [None]:
def plot_yearly_unique_patients(yearly_raw, cohort_label, output_dir=None, dpi=300):
    """
    Line plot of the number of distinct ALF_E values per year in yearly_raw.

    Parameters
    ----------
    yearly_raw : dict[int, pd.DataFrame]
        Per-year frames; each must provide an ALF_E column.
    cohort_label : str
        Used in the title and in the saved file name.
    output_dir : str or None
        When given, the figure is also saved there as a JPEG.
    dpi : int
        Resolution of the saved JPEG.
    """
    years_sorted = sorted(yearly_raw)

    counts = []
    for yr in years_sorted:
        counts.append(int(yearly_raw[yr]["ALF_E"].nunique()))

    plt.figure(figsize=(12, 4.5))
    plt.plot(years_sorted, counts, marker="o")
    plt.xlabel("Year")
    plt.ylabel("Unique patients (count)")
    plt.title(f"Yearly unique patients – {cohort_label}")
    plt.tight_layout()

    wants_file = output_dir is not None
    if wants_file:
        os.makedirs(output_dir, exist_ok=True)
        file_name = f"{cohort_label}_Yearly_UniquePatients.jpeg"
        save_path = os.path.join(output_dir, file_name)
        plt.savefig(
            save_path,
            dpi=dpi,
            format="jpeg",
            bbox_inches="tight",
            facecolor="white",
        )
        print(f"✅ Saved plot → {save_path}")

    plt.show()


In [None]:
# NOTE(review): this cell re-defines plot_yearly_unique_patients with the same
# behaviour as the definition in the preceding cell; one of the two can be removed.
def plot_yearly_unique_patients(yearly_raw, cohort_label, output_dir=None, dpi=300):
    """
    Plot, per year, how many distinct ALF_E values appear in yearly_raw.

    yearly_raw maps year -> DataFrame with an ALF_E column. The figure is shown
    and, when output_dir is given, also saved there as
    <cohort_label>_Yearly_UniquePatients.jpeg at the requested dpi.
    """
    # Year -> distinct-patient count, inserted in ascending year order.
    unique_by_year = {
        yr: int(yearly_raw[yr]["ALF_E"].nunique())
        for yr in sorted(yearly_raw.keys())
    }
    years_sorted = list(unique_by_year.keys())
    counts = list(unique_by_year.values())

    plt.figure(figsize=(12, 4.5))
    plt.plot(years_sorted, counts, marker="o")
    plt.xlabel("Year")
    plt.ylabel("Unique patients (count)")
    plt.title(f"Yearly unique patients – {cohort_label}")
    plt.tight_layout()

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        save_path = os.path.join(output_dir, f"{cohort_label}_Yearly_UniquePatients.jpeg")
        save_options = dict(dpi=dpi, format="jpeg", bbox_inches="tight", facecolor="white")
        plt.savefig(save_path, **save_options)
        print(f"✅ Saved plot → {save_path}")

    plt.show()


In [None]:
# ---- New: death-by-year blocks (suppression) table + plots ----
# After you run:
# bav_results = run_halomap_and_tables("BAV", ...)

# Build the suppressed death year-block table for BAV and draw the standalone
# death and unique-patient figures.
# NOTE(review): `death_blocks_bav` is built with the same arguments in an
# earlier cell of this notebook; running both rebuilds the table and rewrites
# the same CSV. `bav_results`, `years` and `output_dir` come from cells not shown here.
death_blocks_bav = build_yearly_death_table_with_suppression(
    cohort_label="BAV",
    death_master=bav_results["death_master"],
    years=years,
    output_dir=os.path.join(output_dir, "BAV"),
    threshold=10
)

# Deaths per block and mean/median age at death per block.
plot_yearly_death_demographics(
    death_blocks_bav,
    output_dir=os.path.join(output_dir, "BAV", "Figures"),
    cohort_label="BAV",
    dpi=600
)

# Yearly unique-patient counts.
plot_yearly_unique_patients(
    bav_results["yearly_raw"],
    cohort_label="BAV",
    output_dir=os.path.join(output_dir, "BAV", "Figures"),
    dpi=600
)



In [None]:
# After:
# nonbav_results = run_halomap_and_tables("Non-BAV", ...)

# Build the suppressed death year-block table for Non-BAV and draw the
# standalone death and unique-patient figures.
# NOTE(review): `death_blocks_nonbav` is built with the same arguments in an
# earlier cell of this notebook; running both rebuilds the table and rewrites
# the same CSV. `nonbav_results`, `years` and `output_dir` come from cells not shown here.
death_blocks_nonbav = build_yearly_death_table_with_suppression(
    cohort_label="Non-BAV",
    death_master=nonbav_results["death_master"],
    years=years,
    output_dir=os.path.join(output_dir, "Non-BAV"),
    threshold=10
)

# Deaths per block and mean/median age at death per block.
plot_yearly_death_demographics(
    death_blocks_nonbav,
    output_dir=os.path.join(output_dir, "Non-BAV", "Figures"),
    cohort_label="Non-BAV",
    dpi=600
)

# Yearly unique-patient counts.
plot_yearly_unique_patients(
    nonbav_results["yearly_raw"],
    cohort_label="Non-BAV",
    output_dir=os.path.join(output_dir, "Non-BAV", "Figures"),
    dpi=600
)


In [None]:
# NOTE(review): stray fragment. As a standalone cell the line below is a
# SyntaxError (a string literal cannot be an annotation target), so it is kept
# here as a comment. It looks like a dict entry meant to be added to a results
# dict elsewhere in the notebook — TODO confirm where it belongs, then delete this cell.
#     "death_blocks_df": death_blocks_df,


In [None]:
# Cohort HaloMap™ + Consistent Tables (Unique patients per period; deaths from death_master)
# ----------------------------------------------------------------------------------------------
# This version fixes the *death percentage logic* so that, within each period, the death sub-bubbles
# sum to 100% across the displayed age bins (unless you purposely exclude ages).
#
# Core concept (your requirement):
#   1) A death belongs to a PERIOD based on DOD_YEAR (death year).
#   2) Within that period, the death belongs to an AGE GROUP based on AGE_AT_DEATH.
#   3) "Deaths % of Period" = deaths_in_agebin_in_period / total_deaths_in_period * 100
#      -> If your age bins cover all possible ages-at-death, these percentages sum to 100 per period.
#
# IMPORTANT:
# If your bins stop at 80 (e.g., 66–80) but you have deaths at age >80, then the sum will be <100.
# To enforce sums of 100, either:
#   • include an 81+ bin, OR
#   • set the last bin upper bound to None (open-ended) and this code will treat it as 66+.
#
# Outputs (per cohort):
#   1) Period×Age-bin table (unique patients; FU_<year>; deaths from master)
#      - includes BOTH "Deaths % of Period" and "Deaths % of Bin" for transparency
#   2) Period totals table (overall + sex; death totals + sex; age stats + IQR)
#   3) Yearly follow-up table (unique patients/year + new admissions)
#   4) HaloMap plot using:
#        Main bubble: % of unique patients in period (age-bin / period unique total)
#        M/F bubbles: % male/female among known sex in the bin
#        D bubble:    % deaths in period (death distribution across age groups)

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from IPython.display import display, FileLink, HTML

# -----------------------------------------------------------
# 1) Loading utilities
# -----------------------------------------------------------

def load_yearly_raw_for_cohort(data_dir, years, cohort_label, require_cols=("ALF_E","YOB","GNDR_NAME","DOD")):
    """
    Load yearly PKL files (<data_dir>/<year>.pkl) for a cohort.

    Each pickle is expected to hold a mapping with a DataFrame under the
    cohort key. Years whose file is missing (a warning is printed) or whose
    cohort frame is empty are skipped.

    Parameters
    ----------
    data_dir : str
        Directory containing one ``<year>.pkl`` per year.
    years : iterable[int]
        Years to load.
    cohort_label : str
        "BAV" (case-insensitive, surrounding whitespace ignored) reads the
        "original" key; any other label reads "matched".
    require_cols : sequence[str]
        Columns that must exist and that are kept. The function itself uses
        "ALF_E" and "DOD", so both need to be included.

    Returns
    -------
    yearly_raw : dict[int, pd.DataFrame]
        yearly_raw[year] has one row per ALF_E for that year with the required
        columns plus "Year"; DOD is converted to datetime (NaT if unparseable).
        When a patient appears several times in a year, the row that comes
        first in the stored frame is the one kept.
    cohort_key : str
        "original" for BAV, "matched" otherwise.

    Raises
    ------
    ValueError
        If a required column is missing from a non-empty cohort frame, or if
        no year produced any data.
    """
    is_bav = str(cohort_label).strip().upper() == "BAV"
    cohort_key = "original" if is_bav else "matched"
    wanted_cols = list(require_cols)
    yearly_raw = {}

    for year in years:
        pkl_path = os.path.join(data_dir, f"{year}.pkl")
        if not os.path.exists(pkl_path):
            print(f"⚠️ Missing PKL for {cohort_label} in {year}: {pkl_path}")
            continue

        # SECURITY: unpickling can execute arbitrary code; only point data_dir
        # at files produced by a trusted pipeline.
        data = pd.read_pickle(pkl_path)
        cohort_df = data.get(cohort_key, pd.DataFrame())
        if cohort_df.empty:
            continue

        for col in wanted_cols:
            if col not in cohort_df.columns:
                raise ValueError(f"{cohort_label} {year}: missing column '{col}' in key '{cohort_key}'")

        tmp = cohort_df[wanted_cols].copy()
        tmp["Year"] = year
        tmp["DOD"] = pd.to_datetime(tmp["DOD"], errors="coerce")

        # One row per patient per year. The sort must be stable: with the
        # default (unstable) algorithm, which duplicate row ends up "first"
        # is arbitrary for larger frames.
        yearly_raw[year] = (
            tmp.sort_values("ALF_E", kind="mergesort")
               .drop_duplicates(subset=["ALF_E"], keep="first")
        )

    if not yearly_raw:
        raise ValueError(f"No data found for cohort={cohort_label} in any requested year from {data_dir}")

    rows_per_year = {yr: len(frame) for yr, frame in yearly_raw.items()}
    print(f"✅ Loaded {len(yearly_raw)} yearly PKLs for cohort={cohort_label} (cohort_key='{cohort_key}')")
    print(f"   Unique patients per year (first 5): {dict(list(rows_per_year.items())[:5])}")

    return yearly_raw, cohort_key


def build_death_master_from_yearly(yearly_raw, years):
    """
    Collapse the yearly frames into a death master with one row per ALF_E.

    Only rows with a non-null DOD whose year lies within
    [min(years), max(years)] are considered. Per patient, the latest DOD
    (and its year) is kept together with the first YOB and GNDR_NAME, and
    AGE_AT_DEATH is DOD_YEAR - YOB.

    Returns a DataFrame with columns:
      ALF_E, YOB, GNDR_NAME, DOD, DOD_YEAR, AGE_AT_DEATH
    (empty, with just these columns, when no death qualifies).
    """
    output_columns = ["ALF_E", "YOB", "GNDR_NAME", "DOD", "DOD_YEAR", "AGE_AT_DEATH"]

    combined = pd.concat(list(yearly_raw.values()), ignore_index=True)
    combined["DOD_YEAR"] = pd.to_datetime(combined["DOD"], errors="coerce").dt.year

    # Keep rows that have a death date falling inside the requested year range.
    has_dod = combined["DOD"].notna()
    in_range = combined["DOD_YEAR"].between(min(years), max(years))
    deaths = combined[has_dod & in_range]

    if deaths.empty:
        return pd.DataFrame(columns=output_columns)

    # Sorted by patient then date, so "last" picks the latest DOD per patient.
    per_patient_rule = {
        "YOB": "first",
        "GNDR_NAME": "first",
        "DOD": "last",
        "DOD_YEAR": "last",
    }
    ordered = deaths.sort_values(["ALF_E", "DOD"])
    death_master = ordered.groupby("ALF_E", as_index=False).agg(per_patient_rule)

    death_master["AGE_AT_DEATH"] = death_master["DOD_YEAR"] - death_master["YOB"]
    return death_master


# -----------------------------------------------------------
# 2) Stats + binning helpers
# -----------------------------------------------------------

def compute_stats(series):
    """
    Summarise a series as a dict with keys "Mean", "Median", "Q1" and "Q3".

    Values are coerced to numeric (unparseable entries are dropped) and each
    statistic is rounded to 2 decimals; every entry is NaN when nothing
    numeric remains.
    """
    stat_keys = ("Mean", "Median", "Q1", "Q3")

    numeric = pd.to_numeric(series, errors="coerce").dropna()
    if numeric.empty:
        return {key: np.nan for key in stat_keys}

    raw_values = (
        numeric.mean(),
        numeric.median(),
        numeric.quantile(0.25),
        numeric.quantile(0.75),
    )
    return {key: round(float(value), 2) for key, value in zip(stat_keys, raw_values)}

def make_age_bin_assigner(age_bins):
    """
    Build an age -> bin-label function from (lower, upper) pairs.

    Bounds are inclusive on both ends; an upper bound of None makes the bin
    open-ended, e.g. (66, None) is labelled "66+" and matches every age >= 66.
    Closed bins are labelled "<lower>-<upper>".

    Returns (assign_func, labels): assign_func(age) gives the label of the
    first matching bin, or None when the age is missing or matches no bin;
    labels lists the bin labels in input order.
    """
    bounds = [(lower, upper) for lower, upper in age_bins]
    labels = [
        f"{lower}+" if upper is None else f"{lower}-{upper}"
        for lower, upper in bounds
    ]

    def assign(age):
        if pd.isna(age):
            return None
        # First bin whose inclusive range contains the age wins.
        matches = (
            label
            for (lower, upper), label in zip(bounds, labels)
            if age >= lower and (upper is None or age <= upper)
        )
        return next(matches, None)

    return assign, labels


# -----------------------------------------------------------
# 3) Period × Age-bin table (canonical) + HaloMap df
# -----------------------------------------------------------

def build_period_agebin_table_unique(
    cohort_label,
    yearly_raw,
    death_master,
    periods,
    age_bins,
    years,
    output_dir
):
    """
    Build the Period×Age-bin table plus a slimmer frame for plotting, and save both as CSV.

    For every (start, end) period and every age bin that contains at least one
    patient, one row is produced with:
      - unique patients in the period whose age at period end (end - YOB)
        falls in the bin, split into Male / Female counts
      - deaths whose DOD_YEAR lies in the period and whose AGE_AT_DEATH falls
        in the bin, split into Male / Female counts
      - Deaths % of Period  = deaths in bin / deaths in the period that fall into any bin
      - Deaths % of Bin     = deaths in bin / patients in bin
      - Mean/Median/Q1/Q3 of age at period end (all, Male, Female)
      - FU_<year> columns: patients of this period+bin present in that year;
        years outside the period are filled with 0

    Parameters
    ----------
    cohort_label : str
        Written to the "Cohort" column; "BAV" (case-insensitive) selects the
        "bav" file prefix, anything else "nonbav".
    yearly_raw : dict[int, pd.DataFrame]
        Per-year frames providing ALF_E, YOB and GNDR_NAME.
    death_master : pd.DataFrame
        Provides DOD_YEAR, AGE_AT_DEATH and GNDR_NAME (may be empty).
    periods : iterable[tuple[int, int]]
        Inclusive (start, end) year pairs; periods with no available year are skipped.
    age_bins : iterable[tuple[int, int | None]]
        Passed to make_age_bin_assigner.
    years : iterable[int]
        Years considered; those missing from yearly_raw are ignored.
    output_dir : str
        Created if missing; both CSVs are written here.

    Returns
    -------
    (period_agebin_df, plot_df) : tuple[pd.DataFrame, pd.DataFrame]

    NOTE(review): a bin with no patients is skipped entirely, so deaths whose
    AGE_AT_DEATH falls in such a bin still count towards the period
    denominator but get no row; the emitted "Deaths % of Period" values can
    then sum to less than 100.
    NOTE(review): "Deaths % of Bin" divides deaths binned by AGE_AT_DEATH by
    patients binned by age at period end, i.e. two different binning bases —
    confirm this is the intended definition.
    """
    os.makedirs(output_dir, exist_ok=True)
    assign_bin, age_labels = make_age_bin_assigner(age_bins)

    dm = death_master.copy()
    if not dm.empty:
        # NOTE(review): the same column is recomputed on dm_period inside the loop below.
        dm["AgeGroup_Death"] = dm["AGE_AT_DEATH"].apply(assign_bin)

    all_years_sorted = sorted([y for y in years if y in yearly_raw])

    rows = []       # full table rows
    plot_rows = []  # reduced rows for the plottable frame

    for (start, end) in periods:
        period_label = f"{start}-{end}"
        period_years = [y for y in all_years_sorted if start <= y <= end]
        if not period_years:
            continue

        # Unique patients across the period (first occurrence of each ALF_E is kept)
        period_all = pd.concat([yearly_raw[y][["ALF_E","YOB","GNDR_NAME"]].copy() for y in period_years],
                               ignore_index=True)
        period_unique = period_all.drop_duplicates(subset=["ALF_E"]).copy()
        period_unique["AGE_AT_PERIOD_END"] = end - period_unique["YOB"]
        period_unique["AgeGroup_Period"] = period_unique["AGE_AT_PERIOD_END"].apply(assign_bin)

        # Deaths in this period (by DOD_YEAR)
        dm_period = dm[(dm["DOD_YEAR"] >= start) & (dm["DOD_YEAR"] <= end)].copy() if not dm.empty else dm

        # Denominator for "Deaths % of Period": count only deaths that fall into your bins
        if dm_period.empty:
            total_deaths_in_bins = 0
            deaths_outside_bins = 0
        else:
            dm_period["AgeGroup_Death"] = dm_period["AGE_AT_DEATH"].apply(assign_bin)
            total_deaths_in_bins = int(dm_period["AgeGroup_Death"].notna().sum())
            deaths_outside_bins = int(len(dm_period) - total_deaths_in_bins)

        if deaths_outside_bins > 0:
            print(f"⚠️ {cohort_label} {period_label}: {deaths_outside_bins} deaths fall outside your age_bins. "
                  f"Add an open-ended last bin like (66, None) or extend the upper bound to make death % sum to 100.")

        for label in age_labels:
            bin_df = period_unique[period_unique["AgeGroup_Period"] == label].copy()
            if bin_df.empty:
                # No patients in this bin for the period: no row is emitted (see docstring note).
                continue

            ids_in_bin = set(bin_df["ALF_E"])
            n_patients = int(len(ids_in_bin))

            n_male = int((bin_df["GNDR_NAME"] == "Male").sum())
            n_female = int((bin_df["GNDR_NAME"] == "Female").sum())

            # deaths in this bin (AGE_AT_DEATH bin + within period)
            if dm_period.empty:
                deaths_bin = 0
                death_male = 0
                death_female = 0
            else:
                dm_bin = dm_period[dm_period["AgeGroup_Death"] == label]
                deaths_bin = int(len(dm_bin))
                death_male = int((dm_bin["GNDR_NAME"] == "Male").sum())
                death_female = int((dm_bin["GNDR_NAME"] == "Female").sum())

            # Both percentages fall back to 0.0 when their denominator is zero.
            deaths_pct_of_period = round((deaths_bin / total_deaths_in_bins * 100) if total_deaths_in_bins else 0.0, 2)
            deaths_pct_of_bin = round((deaths_bin / n_patients * 100) if n_patients else 0.0, 2)

            stats_all = compute_stats(bin_df["AGE_AT_PERIOD_END"])
            stats_m = compute_stats(bin_df.loc[bin_df["GNDR_NAME"] == "Male", "AGE_AT_PERIOD_END"])
            stats_f = compute_stats(bin_df.loc[bin_df["GNDR_NAME"] == "Female", "AGE_AT_PERIOD_END"])

            # FU counts (within the period only)
            # NOTE(review): the per-year ID set is rebuilt for every age bin; it could be
            # computed once per year outside the bin loop.
            follow_cols = {}
            for yr in all_years_sorted:
                col = f"FU_{yr}"
                if start <= yr <= end:
                    ids_year = set(yearly_raw[yr]["ALF_E"])
                    follow_cols[col] = int(len(ids_in_bin & ids_year))
                else:
                    follow_cols[col] = 0

            row = {
                "Cohort": cohort_label,
                "Period": period_label,
                "Age Group": label,

                "Unique Patients in Period": n_patients,
                "Male": n_male,
                "Female": n_female,

                "Deaths in Period (age-bin)": deaths_bin,
                "Deaths % of Period": deaths_pct_of_period,  # sums to 100 if bins cover all deaths
                "Deaths % of Bin": deaths_pct_of_bin,        # deaths in bin / patients in bin

                "Mean Age (All, Period End)": stats_all["Mean"],
                "Median Age (All, Period End)": stats_all["Median"],
                "Q1 Age (All, Period End)": stats_all["Q1"],
                "Q3 Age (All, Period End)": stats_all["Q3"],

                "Mean Age (Male, Period End)": stats_m["Mean"],
                "Median Age (Male, Period End)": stats_m["Median"],
                "Q1 Age (Male, Period End)": stats_m["Q1"],
                "Q3 Age (Male, Period End)": stats_m["Q3"],

                "Mean Age (Female, Period End)": stats_f["Mean"],
                "Median Age (Female, Period End)": stats_f["Median"],
                "Q1 Age (Female, Period End)": stats_f["Q1"],
                "Q3 Age (Female, Period End)": stats_f["Q3"],

                "Deaths Male (age-bin)": death_male,
                "Deaths Female (age-bin)": death_female,
            }
            row.update(follow_cols)
            rows.append(row)

            # Reduced record with just the counts and percentages.
            plot_rows.append({
                "PeriodLabel": period_label,
                "AgeBinLabel": label,
                "Patients": n_patients,
                "Male": n_male,
                "Female": n_female,
                "Deaths": deaths_bin,
                "DeathsPctOfPeriod": float(deaths_pct_of_period),
                "DeathsPctOfBin": float(deaths_pct_of_bin),
            })

    period_agebin_df = pd.DataFrame(rows)
    plot_df = pd.DataFrame(plot_rows)

    # File prefix: "bav" for the BAV cohort, "nonbav" for everything else.
    short = "bav" if str(cohort_label).strip().upper() == "BAV" else "nonbav"
    out_csv = os.path.join(output_dir, f"{short}_period_agebin_unique_followup.csv")
    out_plot = os.path.join(output_dir, f"{short}_period_agebin_plottable.csv")

    period_agebin_df.to_csv(out_csv, index=False)
    plot_df.to_csv(out_plot, index=False)

    print(f"✅ Saved Period×Age-bin table → {out_csv}")
    print(f"✅ Saved HaloMap plottable df → {out_plot}")
    display(FileLink(out_csv))
    display(HTML(period_agebin_df.to_html(index=False)))

    return period_agebin_df, plot_df


# -----------------------------------------------------------
# 4) Period totals table (requested)
# -----------------------------------------------------------

def build_period_totals_table(cohort_label, yearly_raw, death_master, periods, years, output_dir):
    """
    Summarise each period: unique patient counts, death counts and age statistics by sex.

    Parameters
    ----------
    cohort_label : value written to the "Cohort" column. "BAV" (case-insensitive,
        surrounding whitespace ignored) selects the "bav" file prefix; any other
        value selects "nonbav".
    yearly_raw : mapping year -> DataFrame with at least ALF_E, YOB and GNDR_NAME.
    death_master : DataFrame of deaths. DOD_YEAR, GNDR_NAME and AGE_AT_DEATH are
        read only when it is non-empty.
    periods : iterable of (start, end) year pairs; both ends are inclusive.
    years : candidate years; only those present in yearly_raw are used.
    output_dir : directory that receives the CSV (created if missing).

    Returns
    -------
    pd.DataFrame
        One row per period that contains at least one available year. The same
        table is written to "<prefix>_period_totals_summary.csv" and displayed.
    """
    os.makedirs(output_dir, exist_ok=True)

    available_years = sorted(y for y in years if y in yearly_raw)
    deaths_all = death_master.copy()
    patient_cols = ["ALF_E", "YOB", "GNDR_NAME"]
    stat_names = ("Mean", "Median", "Q1", "Q3")

    records = []
    for start, end in periods:
        years_in_period = [y for y in available_years if start <= y <= end]
        if not years_in_period:
            continue

        # Stack every year of the period, then keep the first row seen per patient.
        stacked = pd.concat(
            [yearly_raw[y][patient_cols].copy() for y in years_in_period],
            ignore_index=True,
        )
        patients = stacked.drop_duplicates(subset=["ALF_E"]).copy()
        patients["AGE_AT_PERIOD_END"] = end - patients["YOB"]

        is_male = patients["GNDR_NAME"] == "Male"
        is_female = patients["GNDR_NAME"] == "Female"

        # Deaths whose year falls inside the period; an empty master is passed through as-is.
        if deaths_all.empty:
            deaths = deaths_all
        else:
            in_window = (deaths_all["DOD_YEAR"] >= start) & (deaths_all["DOD_YEAR"] <= end)
            deaths = deaths_all[in_window].copy()

        age_m = compute_stats(patients.loc[is_male, "AGE_AT_PERIOD_END"])
        age_f = compute_stats(patients.loc[is_female, "AGE_AT_PERIOD_END"])

        if deaths.empty:
            # No deaths in this period: zero counts and all-NaN age-at-death statistics.
            male_deaths = 0
            female_deaths = 0
            death_age_m = dict.fromkeys(stat_names, np.nan)
            death_age_f = dict.fromkeys(stat_names, np.nan)
        else:
            dead_male = deaths["GNDR_NAME"] == "Male"
            dead_female = deaths["GNDR_NAME"] == "Female"
            male_deaths = int(dead_male.sum())
            female_deaths = int(dead_female.sum())
            death_age_m = compute_stats(deaths.loc[dead_male, "AGE_AT_DEATH"])
            death_age_f = compute_stats(deaths.loc[dead_female, "AGE_AT_DEATH"])

        record = {
            "Cohort": cohort_label,
            "Period": f"{start}-{end}",
            "Unique Patients in Period": int(len(patients)),
            "Male Patients": int(is_male.sum()),
            "Female Patients": int(is_female.sum()),
            "Deaths in Period (master)": int(len(deaths)),
            "Deaths Male": male_deaths,
            "Deaths Female": female_deaths,
        }
        # Column order: male then female period-end ages, then male then female ages at death.
        for sex, stats in (("Male", age_m), ("Female", age_f)):
            for stat in stat_names:
                record[f"{stat} Age ({sex}, Period End)"] = stats[stat]
        for sex, stats in (("Male", death_age_m), ("Female", death_age_f)):
            for stat in stat_names:
                record[f"{stat} Age at Death ({sex})"] = stats[stat]
        records.append(record)

    totals_df = pd.DataFrame(records)

    prefix = "bav" if str(cohort_label).strip().upper() == "BAV" else "nonbav"
    csv_path = os.path.join(output_dir, f"{prefix}_period_totals_summary.csv")
    totals_df.to_csv(csv_path, index=False)

    print(f"✅ Saved Period totals table → {csv_path}")
    display(FileLink(csv_path))
    display(HTML(totals_df.to_html(index=False)))
    return totals_df


# -----------------------------------------------------------
# 5) Yearly follow-up table
# -----------------------------------------------------------

def build_yearly_followup_table_unique(cohort_label, yearly_raw, years, output_dir):
    """
    Build a per-year table of unique patients, first-time patients and age statistics.

    All counts and age statistics are computed on one row per patient (the first
    row seen for each ALF_E within the year), so the Male/Female columns describe
    patients, not records, and agree with "Total Unique Patients" when a patient
    has several rows in the same year.

    Parameters
    ----------
    cohort_label : value written to the "Cohort" column. "BAV" (case-insensitive,
        surrounding whitespace ignored) selects the "bav" file prefix; any other
        value selects "nonbav".
    yearly_raw : mapping year -> DataFrame with at least ALF_E, YOB and GNDR_NAME.
    years : candidate years; only those present in yearly_raw are used.
    output_dir : directory that receives the CSV (created if missing).

    Returns
    -------
    pd.DataFrame
        One row per available year. "New Admissions" counts patients whose
        earliest appearance among the available years is that year. The table is
        also written to "<prefix>_yearly_followup_unique.csv" and displayed.
    """
    os.makedirs(output_dir, exist_ok=True)
    sorted_years = sorted(y for y in years if y in yearly_raw)

    # Earliest available year in which each patient id appears (years visited ascending).
    first_year = {}
    for year in sorted_years:
        for pid in yearly_raw[year]["ALF_E"].unique():
            first_year.setdefault(pid, year)

    rows = []
    for year in sorted_years:
        # Collapse to one row per patient before counting by sex or summarising age;
        # counting raw rows would over-count patients with several records in the year.
        patients = yearly_raw[year].drop_duplicates(subset=["ALF_E"]).copy()
        patients["AGE"] = year - patients["YOB"]

        total_unique = int(patients["ALF_E"].nunique())
        male_total = int((patients["GNDR_NAME"] == "Male").sum())
        female_total = int((patients["GNDR_NAME"] == "Female").sum())

        # Patients seen for the first time in this year.
        is_new = patients["ALF_E"].map(first_year) == year
        new_patients = patients[is_new]
        new_unique = int(new_patients["ALF_E"].nunique())
        male_new = int((new_patients["GNDR_NAME"] == "Male").sum())
        female_new = int((new_patients["GNDR_NAME"] == "Female").sum())

        stats = compute_stats(patients["AGE"])
        rows.append({
            "Cohort": cohort_label,
            "Year": year,
            "Total Unique Patients": total_unique,
            "Male Total": male_total,
            "Female Total": female_total,
            "New Admissions": new_unique,
            "Male New": male_new,
            "Female New": female_new,
            "Mean Age": stats["Mean"],
            "Median Age": stats["Median"],
            "Q1 Age": stats["Q1"],
            "Q3 Age": stats["Q3"],
        })

    df = pd.DataFrame(rows)
    short = "bav" if str(cohort_label).strip().upper() == "BAV" else "nonbav"
    out_csv = os.path.join(output_dir, f"{short}_yearly_followup_unique.csv")
    df.to_csv(out_csv, index=False)
    print(f"✅ Saved yearly follow-up table → {out_csv}")
    display(FileLink(out_csv))
    display(HTML(df.to_html(index=False)))
    return df


# -----------------------------------------------------------
# 6) HaloMap plot (D bubble = Deaths % of Period)
# -----------------------------------------------------------

def plot_period_haloplot_unique(cohort_label, plot_df, periods, suppress_n_lt=10, save_dir=None, dpi=600):
    """
    Draw the period x age-bin "HaloMap" bubble grid for one cohort.

    Each (period, age-bin) cell that has a row in plot_df with at least
    suppress_n_lt patients gets a large bubble labelled with the cell's share of
    the period's summed "Patients". Below it sit three small bubbles: M and F
    (each sex as a share of male + female in the cell) and D (the row's
    DeathsPctOfPeriod). Bubble areas step up at 25 / 50 / 75 percent.

    Parameters
    ----------
    cohort_label : used in the title and in the saved file name.
    plot_df : DataFrame with PeriodLabel, AgeBinLabel, Patients, Male, Female and
        DeathsPctOfPeriod columns. It is copied, not modified.
    periods : iterable of (start, end) pairs; sets the x-axis order as "start-end" labels.
    suppress_n_lt : cells with fewer patients than this are left blank.
    save_dir : when not None, the figure is saved there as "<cohort_label>_HaloPlot.jpeg".
    dpi : resolution passed to savefig.

    Returns
    -------
    None
    """
    if plot_df.empty:
        print("⚠️ No data to plot.")
        return

    def _bin_start(bin_label):
        # "66+" -> 66, "56-65" -> 56; labels that cannot be parsed sort last.
        text = str(bin_label).replace("+", "")
        try:
            return int(text.split("-")[0])
        except Exception:
            return 9999

    def _stepped_area(pct, areas):
        # areas = (<=25, <=50, <=75, above); the first matching threshold wins.
        for limit, area in zip((25, 50, 75), areas):
            if pct <= limit:
                return area
        return areas[-1]

    main_areas = (4000, 5500, 7000, 9000)
    halo_areas = (900, 1200, 1500, 1800)

    bin_order = sorted(plot_df["AgeBinLabel"].unique(), key=_bin_start)
    x_labels = [f"{start}-{end}" for start, end in periods]

    # Work on a copy with ordered categoricals so sorting follows axis order.
    cells = plot_df.copy()
    cells["AgeBinLabel"] = pd.Categorical(cells["AgeBinLabel"], categories=bin_order, ordered=True)
    cells["PeriodLabel"] = pd.Categorical(cells["PeriodLabel"], categories=x_labels, ordered=True)
    cells = cells.sort_values(["AgeBinLabel", "PeriodLabel"])

    # Denominator for the main bubble: patients summed over all bins of a period.
    patients_per_period = cells.groupby("PeriodLabel")["Patients"].sum().to_dict()

    fig, ax = plt.subplots(figsize=(len(x_labels) * 4.0, len(bin_order) * 4.4))
    halo_colors = {"M": "#4A90E2", "F": "#E94E77", "D": "#F5A623"}

    for row_idx, age_bin in enumerate(bin_order):
        for col_idx, period in enumerate(x_labels):
            hit = cells[(cells["PeriodLabel"] == period) & (cells["AgeBinLabel"] == age_bin)]
            if hit.empty:
                continue

            n_total = int(hit["Patients"].iloc[0])
            if n_total < suppress_n_lt:
                continue

            n_male = int(hit["Male"].iloc[0])
            n_female = int(hit["Female"].iloc[0])
            # Sums to 100 per period if the bins cover all deaths.
            death_share = float(hit["DeathsPctOfPeriod"].iloc[0])

            denom = patients_per_period.get(period, 0)
            cell_share = round((n_total / denom * 100) if denom else 0.0, 1)

            n_known = n_male + n_female
            if n_known:
                male_share = round(n_male / n_known * 100, 1)
                female_share = round(n_female / n_known * 100, 1)
            else:
                male_share = 0.0
                female_share = 0.0

            # Rescale if rounding pushed the two shares past 100 in total.
            combined = male_share + female_share
            if combined > 100:
                male_share = round(male_share / combined * 100, 1)
                female_share = round(female_share / combined * 100, 1)

            # Main bubble with the cell's share of the period.
            ax.scatter(col_idx, row_idx, s=_stepped_area(cell_share, main_areas),
                       color="lightblue", edgecolor="black", alpha=0.85)
            ax.text(col_idx, row_idx, f"{cell_share:.1f}%",
                    fontsize=18, weight="bold", ha="center", va="center")

            # Small bubbles below the main one; the count of 1 for D means it is always drawn.
            halos = (
                (-0.20, "M", n_male, male_share),
                (0.00, "F", n_female, female_share),
                (0.20, "D", 1, round(death_share, 1)),
            )
            for offset, tag, count, share in halos:
                if tag in ("M", "F") and count <= 0:
                    continue
                hx = col_idx + offset
                hy = row_idx - 0.48
                ax.scatter(hx, hy, s=_stepped_area(share, halo_areas),
                           color=halo_colors[tag], edgecolor="black", alpha=0.95)
                ax.text(hx, hy, tag, fontsize=14, ha="center", va="center", weight="bold")
                ax.text(hx, hy + 0.10, f"{share:.1f}%",
                        fontsize=12, ha="center", va="bottom", weight="bold")

    # Extra headroom above the top row.
    y_low, y_high = ax.get_ylim()
    ax.set_ylim(y_low, y_high + 0.7)

    ax.set_xticks(range(len(x_labels)))
    ax.set_xticklabels(x_labels, fontsize=12)
    ax.set_yticks(range(len(bin_order)))
    ax.set_yticklabels(bin_order, fontsize=12)

    ax.set_title(f"Cohort HaloMap™ (Unique patients – {cohort_label})", fontsize=18, pad=20)
    ax.set_xlabel("Period", fontsize=13)
    ax.set_ylabel("Age Group", fontsize=13)
    ax.grid(axis="y", linestyle="--", alpha=0.3)

    def _legend_dot(text, face, size):
        # Round marker with a black edge, used as a legend proxy.
        return Line2D([0], [0], marker="o", color="w", label=text,
                      markerfacecolor=face, markersize=size, markeredgecolor="black")

    legend_handles = [
        _legend_dot("Unique % in Period", "lightblue", 18),
        _legend_dot("Male % of Known (bin)", halo_colors["M"], 15),
        _legend_dot("Female % of Known (bin)", halo_colors["F"], 15),
        _legend_dot("Deaths % of Period (sums to 100 if bins cover deaths)", halo_colors["D"], 15),
    ]
    ax.legend(handles=legend_handles, loc="lower center", bbox_to_anchor=(0.5, -0.25),
              ncol=4, frameon=False, fontsize=11)

    plt.tight_layout()

    # Optionally save the Halo plot as JPEG before showing it.
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f"{cohort_label}_HaloPlot.jpeg")
        plt.savefig(save_path, dpi=dpi, format="jpeg", bbox_inches="tight", facecolor="white")
        print(f"✅ Saved Halo plot → {save_path}")

    plt.show()


# -----------------------------------------------------------
# 7) Runner
# -----------------------------------------------------------

def run_halomap_and_tables(cohort_label, data_dir, years, periods, age_bins, output_dir):
    """
    Run the whole pipeline for one cohort and collect every intermediate result.

    Steps, in order: load the yearly data, derive the death master, build the
    period x age-bin table (plus its plottable frame), the period totals table and
    the yearly follow-up table, then draw the HaloMap into "<output_dir>/Figures"
    at 600 dpi.

    Parameters
    ----------
    cohort_label : cohort name forwarded to every step.
    data_dir : directory passed to the loader.
    years : years passed to the loader and the table builders.
    periods : (start, end) pairs passed to the period tables and the plot.
    age_bins : age-bin definition passed to the period x age-bin builder.
    output_dir : directory passed to the table builders; figures go in its
        "Figures" subdirectory.

    Returns
    -------
    dict
        Keys: yearly_raw, death_master, period_agebin_df, period_totals_df,
        yearly_followup_df, plot_df, cohort_key.
    """
    yearly_raw, cohort_key = load_yearly_raw_for_cohort(data_dir, years, cohort_label)
    death_master = build_death_master_from_yearly(yearly_raw, years)
    print(f"✅ Death master rows (deaths) for {cohort_label}: {len(death_master)}")

    results = {"yearly_raw": yearly_raw, "death_master": death_master}

    # Tables are built in a fixed order; each one saves and displays its own output.
    results["period_agebin_df"], plot_df = build_period_agebin_table_unique(
        cohort_label, yearly_raw, death_master, periods, age_bins, years, output_dir
    )
    results["period_totals_df"] = build_period_totals_table(
        cohort_label, yearly_raw, death_master, periods, years, output_dir
    )
    results["yearly_followup_df"] = build_yearly_followup_table_unique(
        cohort_label, yearly_raw, years, output_dir
    )

    figures_dir = os.path.join(output_dir, "Figures")
    plot_period_haloplot_unique(cohort_label, plot_df, periods, save_dir=figures_dir, dpi=600)

    results["plot_df"] = plot_df
    results["cohort_key"] = cohort_key
    return results


# -----------------------------------------------------------
# 8) Example usage (EDIT paths + bins)
# -----------------------------------------------------------

# Calendar years 2000 through 2019 (the range end is exclusive).
years = range(2000, 2020)

# Input directory handed to the loader, and root directory for all outputs.
data_dir = "../PKL_CAT_Jan_2025/Cat2_BAV_1_5"
output_dir = "Results/Section_1_Clustering/HaloMap_Unique_Consistent_v2"

# Age bins as (lower, upper) pairs. The last bin is open-ended (66+) so that
# deaths % can sum to 100.
age_bins = [
    (0, 55),
    (56, 65),
    (66, None),
]

# Reporting periods as (start_year, end_year) pairs.
periods = [
    (2000, 2007),
    (2008, 2012),
    (2013, 2016),
    (2017, 2019),
]

# Run the pipeline once per cohort; each cohort writes to its own subdirectory.
bav_results = run_halomap_and_tables(
    "BAV", data_dir, years, periods, age_bins, os.path.join(output_dir, "BAV")
)
nonbav_results = run_halomap_and_tables(
    "Non-BAV", data_dir, years, periods, age_bins, os.path.join(output_dir, "Non-BAV")
)
