In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
import re

# (start_year, end_year) spans, inclusive on both ends; one heatmap is drawn
# per span, cluster scope and group.
TIME_WINDOWS = [(first_year, 2025) for first_year in (1700, 1800, 1900, 1980)]

# Scope suffixes used to locate the cluster CSV files
# ("<cluster_prefix>_<scope>.csv").
SCOPES = ["top10", "top25", "top50"]

# Inclusive bounds a parsed number must fall within to be counted as a year.
MIN_NUM_AS_YEAR = 32
MAX_YEAR = 2025

# Number of discrete color bins used by the heatmap colormap.
N_COLORS = 15

# Exponent applied to the row-normalized counts before they are colored.
GAMMA = 0.40

# When there are more year columns than this, x-axis ticks are thinned out.
MAX_TICKS_X = 30

# Every input CSV and output directory below lives on the current user's
# Desktop.
_DESKTOP = Path.home() / "Desktop"

# Per-group configuration:
#   "input"          - CSV holding the filename and parsed_years* columns
#   "cluster_prefix" - stem of the cluster CSVs ("<prefix>_<scope>.csv")
#   "out_dir"        - directory that receives the rendered heatmaps
GROUPS = {
    "sports": {
        "input": _DESKTOP / "filename_dates_parsed_history_of_sports_tagged.csv",
        "cluster_prefix": "dbscan_sports",
        "out_dir": _DESKTOP / "cluster_heatmap_sports",
    },
    "objects": {
        "input": _DESKTOP / "filename_dates_parsed_historical_objects_tagged.csv",
        "cluster_prefix": "dbscan_objects",
        "out_dir": _DESKTOP / "cluster_heatmap_objects",
    },
    "ideologies": {
        "input": _DESKTOP / "filename_dates_parsed_history_of_ideologies_tagged.csv",
        "cluster_prefix": "dbscan_ideologies",
        "out_dir": _DESKTOP / "cluster_heatmap_ideologies",
    },
}

def filename_to_language(fn: str) -> str:
    """Derive a language label from a text-file name.

    Surrounding whitespace is trimmed, a single trailing ".txt" (matched
    case-insensitively) is dropped, and the remainder is trimmed again.
    Anything that is not a string maps to "Unknown".
    """
    if isinstance(fn, str):
        trimmed = fn.strip()
        has_txt_suffix = trimmed.lower().endswith(".txt")
        stem = trimmed[:-4] if has_txt_suffix else trimmed
        return stem.strip()
    return "Unknown"

def read_clusters(cluster_path: Path) -> pd.DataFrame:
    """Load a cluster-assignment CSV and normalize its key columns.

    Args:
        cluster_path: Location of the CSV (anything ``pd.read_csv`` accepts).
            It must provide ``language`` and ``cluster_id`` columns.

    Returns:
        The loaded frame with ``language`` converted to trimmed strings
        without a trailing ".txt", and ``cluster_id`` as nullable ``Int64``
        (values that cannot be parsed as numbers become ``<NA>``).

    Raises:
        RuntimeError: If either required column is absent.
    """
    df = pd.read_csv(cluster_path)
    if not {"language", "cluster_id"}.issubset(df.columns):
        raise RuntimeError(f"{cluster_path} 必须包含 language 和 cluster_id")

    # Strip the extension the same way filename_to_language() does: only a
    # trailing ".txt", in any letter case. A case-sensitive replace would
    # leave "Foo.TXT" intact, and that row would then never match the
    # language names derived from the file names.
    df["language"] = (
        df["language"]
        .astype(str)
        .str.strip()
        .str.replace(r"\.txt$", "", case=False, regex=True)
        .str.strip()
    )
    df["cluster_id"] = pd.to_numeric(df["cluster_id"], errors="coerce").astype("Int64")
    return df

def make_discrete_cmap(n_colors=15):
    """Build a discrete colormap by sampling the YlOrRd colormap.

    ``n_colors`` evenly spaced positions between 0.06 and 1.0 are sampled
    from ``plt.cm.YlOrRd`` and wrapped in a ``ListedColormap`` named
    ``"YlOrRd_<n_colors>"``.
    """
    sample_positions = np.linspace(0.06, 1.0, n_colors)
    sampled_colors = plt.cm.YlOrRd(sample_positions)
    cmap_name = f"YlOrRd_{n_colors}"
    return ListedColormap(sampled_colors, name=cmap_name)

def add_cluster_gaps(df: pd.DataFrame, cluster_col="cluster_id") -> pd.DataFrame:
    """Insert an all-NaN separator row wherever the cluster label changes.

    Rows are visited in their existing order, so ``df`` is expected to be
    sorted by ``cluster_col`` already. Each separator is named
    ``"gap_cluster_<label of the preceding row>"``.

    Args:
        df: Frame whose rows should be visually grouped.
        cluster_col: Name of the column holding the cluster label.

    Returns:
        A new frame containing the original rows with separator rows in
        between. An input without rows yields an empty frame that keeps the
        input's columns.
    """
    if len(df) == 0:
        # Building a frame from an empty row list would drop the columns.
        return df.iloc[0:0].copy()

    rows = []
    prev_label = None
    prev_missing = False
    seen_first = False
    for _, row in df.iterrows():
        label = row[cluster_col]
        missing = bool(pd.isna(label))
        if seen_first:
            # Missing labels (NaN / <NA>) never compare equal to themselves,
            # so compare "missingness" first. Consecutive rows without a
            # cluster then form one group instead of being split by a gap
            # after every row, and <NA> is never evaluated as a boolean.
            changed = (missing != prev_missing) or (not missing and label != prev_label)
            if changed:
                rows.append(
                    pd.Series(np.nan, index=df.columns, name=f"gap_cluster_{prev_label}")
                )
        rows.append(row)
        prev_label, prev_missing, seen_first = label, missing, True
    return pd.DataFrame(rows)

def draw_clustered_heatmap(pivot, clusters, out_path, group, scope, start, end):
    """Render a cluster-grouped heatmap of per-language year counts to a file.

    Rows are restricted to languages present in ``clusters``, ordered by
    cluster id and then language, and separated by blank gap rows between
    clusters. Each row is divided by its own maximum, raised to ``GAMMA`` and
    colored with ``N_COLORS`` discrete bins; gap rows are drawn white.

    Args:
        pivot: Frame indexed by language name whose columns are plotted left
            to right (the column labels become the x tick labels).
        clusters: Frame with string ``language`` and ``cluster_id`` columns.
        out_path: Destination handed to ``Figure.savefig``.
        group, scope, start, end: Used for the title and the warning text.

    Returns:
        None. Prints a warning and returns without writing a file when no
        rows remain after filtering.
    """
    # Normalize names on private copies: the caller reuses these frames, so
    # they must not be modified in place.
    pivot = pivot.copy()
    pivot.index = pivot.index.str.replace(".txt", "", regex=False).str.strip()

    # Only the two key columns take part in the merge. Any additional column
    # in the cluster file would otherwise end up in the matrix and be plotted
    # as if it were a year.
    clusters = clusters[["language", "cluster_id"]].copy()
    clusters["language"] = clusters["language"].str.replace(".txt", "", regex=False).str.strip()

    keep_langs = clusters["language"].unique()
    pivot = pivot.loc[pivot.index.intersection(keep_langs)]

    merged = pivot.reset_index().rename(columns={"index": "language"})
    merged = merged.merge(clusters, on="language", how="left")

    # An ordered categorical makes the sort group rows cluster by cluster.
    merged["cluster_id"] = pd.Categorical(merged["cluster_id"], ordered=True)
    merged = merged.sort_values(["cluster_id", "language"]).set_index("language")

    merged = add_cluster_gaps(merged, "cluster_id")

    mat = merged.drop(columns=["cluster_id"], errors="ignore")

    if mat.empty:
        print(f"[WARN] {group}-{scope} {start}-{end} 没有数据")
        return

    # Scale every row to its own maximum; all-zero rows are divided by 1 so
    # they stay zero instead of turning into NaN.
    row_max = mat.max(axis=1, skipna=True).replace(0, 1)
    mat_norm = mat.div(row_max, axis=0)

    # Gap rows are all-NaN; masking them lets the colormap's "bad" color
    # (white) paint them.
    mat_gamma = np.power(mat_norm.values.astype(float), GAMMA)
    mat_gamma = np.ma.masked_invalid(mat_gamma)

    boundaries = np.linspace(0.0, 1.0, N_COLORS + 1)
    cmap = make_discrete_cmap(N_COLORS)
    cmap.set_bad(color="white")
    norm = BoundaryNorm(boundaries, N_COLORS, clip=True)

    n_rows = len(mat_norm.index)
    fig_h = max(6.0, n_rows * 0.40)
    fig, ax = plt.subplots(figsize=(18, fig_h), dpi=160)
    try:
        ax.imshow(
            mat_gamma,
            aspect="auto",
            interpolation="nearest",
            cmap=cmap,
            norm=norm,
            origin="upper",
        )

        # Thin white separators above every row except the gap rows.
        for y, label in enumerate(mat.index[1:], start=1):
            if not str(label).startswith("gap_cluster_"):
                ax.axhline(y - 0.5, color="white", linewidth=0.7)

        # NOTE(review): gap rows keep their "gap_cluster_*" names as y tick
        # labels; confirm whether they should be blanked.
        ax.set_yticks(np.arange(n_rows))
        ax.set_yticklabels(mat_norm.index.tolist())

        years = mat_norm.columns.values
        num_years = len(years)
        if num_years <= MAX_TICKS_X:
            xticks = np.arange(num_years)
        else:
            step = max(1, num_years // MAX_TICKS_X)
            xticks = np.arange(0, num_years, step)
        ax.set_xticks(xticks)
        ax.set_xticklabels([str(years[i]) for i in xticks], rotation=90, ha="center")

        ax.set_title(f"Clustered Heatmap of {group.capitalize()} ({scope}, {start}-{end})", pad=14)

        fig.tight_layout()
        fig.savefig(out_path, dpi=160)
    finally:
        # Close this specific figure even when drawing or saving fails, so
        # figures do not pile up across the many group/scope/window calls.
        plt.close(fig)
    print(f"[OK] 已保存 {out_path}")

def process_group(group: str, cfg: dict):
    """Produce every clustered heatmap for one configured group.

    Reads the group's input CSV, extracts year mentions per language from all
    columns whose name starts with "parsed_years", counts them per
    (language, year), and for each time window and each cluster scope draws
    one heatmap into the group's output directory.

    Args:
        group: Group name; used in log messages, file names and titles.
        cfg: Mapping with the keys "input", "out_dir" and "cluster_prefix".

    Raises:
        RuntimeError: If the input CSV has no "parsed_years*" column.
    """
    source_csv = cfg["input"]
    target_dir = cfg["out_dir"]
    prefix = cfg["cluster_prefix"]

    target_dir.mkdir(parents=True, exist_ok=True)

    print(f"[INFO] 读取 {group}: {source_csv}")
    df = pd.read_csv(source_csv, dtype=str, encoding="utf-8")
    df["language"] = df["filename"].apply(filename_to_language)

    date_cols = [name for name in df.columns if name.startswith("parsed_years")]
    if len(date_cols) == 0:
        raise RuntimeError(f"{source_csv} 未发现 parsed_years 列")

    year_pattern = re.compile(r"\b(\d{3,4})\b")

    def years_in(cell):
        # Blank or non-string cells (e.g. missing values) contribute nothing.
        if not (isinstance(cell, str) and cell.strip()):
            return []
        candidates = (int(token) for token in year_pattern.findall(cell))
        return [yr for yr in candidates if MIN_NUM_AS_YEAR <= yr <= MAX_YEAR]

    # One (language, year) pair per accepted year mention.
    records = [
        (row["language"], yr)
        for _, row in df.iterrows()
        for col in date_cols
        for yr in years_in(row.get(col, ""))
    ]

    years_df = pd.DataFrame(records, columns=["language", "year"])
    counts = years_df.groupby(["language", "year"]).size().reset_index(name="count")

    for start, end in TIME_WINDOWS:
        inside_window = (counts["year"] >= start) & (counts["year"] <= end)
        window_counts = counts[inside_window]
        if window_counts.empty:
            continue

        # Wide table: one row per language, one column per year of the
        # window, with 0 for years that have no mentions.
        year_axis = np.arange(start, end + 1, dtype=int)
        wide = window_counts.pivot(index="language", columns="year", values="count")
        wide = wide.reindex(columns=year_axis).fillna(0)

        for scope in SCOPES:
            cluster_csv = Path.home() / "Desktop" / f"{prefix}_{scope}.csv"
            if cluster_csv.exists():
                draw_clustered_heatmap(
                    wide,
                    read_clusters(cluster_csv),
                    target_dir / f"heatmap_clustered_{group}_{scope}_{start}-{end}.png",
                    group,
                    scope,
                    start,
                    end,
                )
            else:
                print(f"[WARN] 缺少 {cluster_csv}")

def main():
    """Run the heatmap pipeline for every entry of GROUPS, in declaration order."""
    for name in GROUPS:
        process_group(name, GROUPS[name])

# Only run the pipeline when executed as a script, not when imported.
if __name__ == "__main__":
    main()