# In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
import re

# Largest number accepted as a year; every time window below also ends here.
MAX_YEAR = 2025
# Smallest number accepted as a year when scanning the date columns.
MIN_NUM_AS_YEAR = 32

# Inclusive (start, end) year ranges; one heatmap is produced per range.
# NOTE(review): the window starting at -1000 can only ever receive years
# >= MIN_NUM_AS_YEAR given the extraction filter -- confirm this is intended.
TIME_WINDOWS = [
    (first_year, MAX_YEAR) for first_year in (1700, 1800, 1900, 1980, -1000)
]

# Number of discrete colour bins in the heatmap palette.
N_COLORS = 15
# Exponent applied to the row-normalised counts before colour binning.
GAMMA = 0.40
# Upper target for the number of tick labels on the year axis.
MAX_TICKS_X = 30

# Every input CSV and output directory lives directly under the user's Desktop.
_DESKTOP = Path.home() / "Desktop"

# Per-group configuration: "input" is the tagged CSV to read, "out_dir" is
# the directory that receives the rendered heatmap images.
GROUPS = {
    "sports": {
        "input": _DESKTOP / "filename_dates_packed_history_of_sports_tagged.csv",
        "out_dir": _DESKTOP / "heatmap_sports_all",
    },
    "objects": {
        "input": _DESKTOP / "filename_dates_packed_historical_objects_tagged.csv",
        "out_dir": _DESKTOP / "heatmap_objects_all",
    },
    "ideologies": {
        "input": _DESKTOP / "filename_dates_packed_history_of_ideologies_tagged.csv",
        "out_dir": _DESKTOP / "heatmap_ideologies_all",
    },
}

def filename_to_language(fn: str) -> str:
    """Derive a language label from a file name.

    Surrounding whitespace is trimmed and a trailing ``.txt`` extension
    (matched case-insensitively) is dropped.

    Args:
        fn: File name taken from the CSV; any non-string value (for example
            a missing cell) is accepted.

    Returns:
        The cleaned label, or ``"Unknown"`` when *fn* is not a string.
    """
    if not isinstance(fn, str):
        return "Unknown"
    trimmed = fn.strip()
    has_txt_suffix = trimmed.lower().endswith(".txt")
    label = trimmed[:-4] if has_txt_suffix else trimmed
    # Strip again: removing the extension can expose whitespace before it.
    return label.strip()

def make_discrete_cmap(n_colors=15):
    """Build a discrete colormap by sampling matplotlib's YlOrRd palette.

    Args:
        n_colors: Number of evenly spaced colours to sample.

    Returns:
        A ``ListedColormap`` named ``YlOrRd_<n_colors>``.
    """
    # Sampling starts at 0.06 rather than 0.0, skipping the very start of the palette.
    sample_points = np.linspace(0.06, 1.0, n_colors)
    palette = plt.cm.YlOrRd(sample_points)
    return ListedColormap(palette, name=f"YlOrRd_{n_colors}")

def draw_heatmap(pivot, out_path, group, start, end):
    """Render a language-by-year count table as a discrete heatmap image.

    Every row is divided by its own maximum, raised to the power ``GAMMA``
    and binned into ``N_COLORS`` colours, so rows are comparable in shape
    rather than in absolute volume.

    Args:
        pivot: DataFrame with one row per language and one column per year,
            holding counts.
        out_path: Destination of the saved figure.
        group: Group name, used in the title and the log messages.
        start: First year of the window, used in the title and log messages.
        end: Last year of the window, used in the title and log messages.

    Returns:
        None. When *pivot* is empty a warning is printed and nothing is drawn.
    """
    if pivot.empty:
        print(f"[WARN] {group} {start}-{end} 没有数据")
        return

    # Scale each row by its own peak; all-zero rows divide by 1 to avoid 0/0.
    row_max = pivot.max(axis=1).replace(0, 1)
    mat_norm = pivot.div(row_max, axis=0)

    # A GAMMA below 1 lifts small ratios so sparse years remain visible.
    mat_gamma = np.power(mat_norm.values.astype(float), GAMMA)

    boundaries = np.linspace(0.0, 1.0, N_COLORS + 1)
    cmap = make_discrete_cmap(N_COLORS)
    norm = BoundaryNorm(boundaries, N_COLORS, clip=True)

    n_rows = len(mat_norm.index)
    fig_h = max(6.0, n_rows * 0.40)
    fig, ax = plt.subplots(figsize=(18, fig_h), dpi=160)
    try:
        ax.imshow(
            mat_gamma,
            aspect="auto",
            interpolation="nearest",
            cmap=cmap,
            norm=norm,
            origin="upper",
        )

        # Thin white separators between adjacent language rows.
        for y in range(1, n_rows):
            ax.axhline(y - 0.5, color="white", linewidth=0.7)

        ax.set_yticks(np.arange(n_rows))
        ax.set_yticklabels(mat_norm.index.tolist())

        years = mat_norm.columns.values
        num_years = len(years)
        # Ceiling division keeps the label count at or below MAX_TICKS_X;
        # floor division would let it grow to almost twice that limit.
        step = max(1, -(-num_years // MAX_TICKS_X))
        xticks = np.arange(0, num_years, step)
        ax.set_xticks(xticks)
        ax.set_xticklabels([str(years[i]) for i in xticks], rotation=90, ha="center")

        ax.set_title(f"Heatmap of {group.capitalize()} (All Languages, {start}-{end})", pad=14)

        fig.tight_layout()
        fig.savefig(out_path, dpi=160)
    finally:
        # Release this specific figure even if drawing or saving failed.
        plt.close(fig)
    print(f"[OK] 已保存 {out_path}")

# Stand-alone 3- or 4-digit numbers are the only tokens treated as year candidates.
_YEAR_TOKEN = re.compile(r"\b(\d{3,4})\b")

def _collect_year_records(df, date_cols):
    """Return a (language, year) pair for every in-range year token in *date_cols*."""
    records = []
    # Plain tuples avoid building a Series per row, which iterrows() would do.
    cell_rows = df[date_cols].itertuples(index=False, name=None)
    for lang, cells in zip(df["language"], cell_rows):
        for txt in cells:
            # Missing cells are not strings; blank cells carry no years.
            if not isinstance(txt, str) or not txt.strip():
                continue
            records.extend(
                (lang, year)
                for year in map(int, _YEAR_TOKEN.findall(txt))
                if MIN_NUM_AS_YEAR <= year <= MAX_YEAR
            )
    return records

def process_group(group: str, cfg: dict):
    """Build and save one heatmap per time window for a single group.

    Reads the group's CSV as strings, derives a language label from the
    ``filename`` column, extracts year tokens from every ``dates_part*``
    column, counts them per (language, year), and renders one heatmap for
    each entry of ``TIME_WINDOWS`` that contains at least one count.

    Args:
        group: Group name, used in log messages and output file names.
        cfg: Mapping with ``"input"`` (CSV path) and ``"out_dir"``
            (directory for the images; created if missing).

    Raises:
        RuntimeError: If the CSV has no column starting with ``dates_part``.
    """
    in_path = cfg["input"]
    out_dir = cfg["out_dir"]

    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"[INFO] 读取 {group}: {in_path}")
    df = pd.read_csv(in_path, dtype=str, encoding="utf-8")
    df["language"] = df["filename"].apply(filename_to_language)

    date_cols = [c for c in df.columns if c.startswith("dates_part")]
    if not date_cols:
        raise RuntimeError(f"{in_path} 未发现 dates_part* 列")

    years_df = pd.DataFrame(
        _collect_year_records(df, date_cols), columns=["language", "year"]
    )
    counts = years_df.groupby(["language", "year"]).size().reset_index(name="count")

    for start, end in TIME_WINDOWS:
        mask = (counts["year"] >= start) & (counts["year"] <= end)
        counts_window = counts[mask]
        if counts_window.empty:
            # Nothing to draw for this window.
            continue

        # Reindex over every year in the window so gaps appear as zero columns.
        all_years = np.arange(start, end + 1, dtype=int)
        pivot = (
            counts_window.pivot(index="language", columns="year", values="count")
            .reindex(columns=all_years)
            .fillna(0)
        )

        out_fig = out_dir / f"heatmap_{group}_all_{start}-{end}.png"
        draw_heatmap(pivot, out_fig, group, start, end)

def main():
    """Run the heatmap pipeline once for every entry in ``GROUPS``."""
    for name, settings in GROUPS.items():
        process_group(name, settings)

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()