# In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import PowerNorm
import re

# Shared base folder: every input CSV and output directory of this script sits
# directly on the user's Desktop.
_DESKTOP = Path.home() / "Desktop"

# Per-group settings consumed by process_group():
#   dates              - CSV with a "filename" column and "dates_part*" text columns
#   wordcount          - word-count CSV path
#   out_dir            - folder that receives the rendered PNG (created if missing)
#   gamma              - PowerNorm exponent applied to the colour scale
#   pct_low / pct_high - percentiles mapped to the bottom / top of the colour scale
GROUPS = {
    "sports": {
        "dates": _DESKTOP / "filename_dates_packed_history_of_sports_tagged.csv",
        "wordcount": _DESKTOP / "filename_wordcounts_sports.csv",
        "out_dir": _DESKTOP / "full_heatmap_sports",
        "gamma": 0.7,
        "pct_low": 5,
        "pct_high": 95,
    },
    "ideologies": {
        "dates": _DESKTOP / "filename_dates_packed_history_of_ideologies_tagged.csv",
        "wordcount": _DESKTOP / "filename_wordcounts_ideologies.csv",
        "out_dir": _DESKTOP / "full_heatmap_ideologies",
        "gamma": 0.7,
        "pct_low": 5,
        "pct_high": 95,
    },
    "objects": {
        "dates": _DESKTOP / "filename_dates_packed_historical_objects_tagged.csv",
        "wordcount": _DESKTOP / "filename_wordcounts_objects.csv",
        "out_dir": _DESKTOP / "full_heatmap_objects",
        # Lower gamma and wider percentile window than the other groups.
        "gamma": 0.4,
        "pct_low": 2,
        "pct_high": 98,
    },
}

# Standalone integers in [MIN_NUM_AS_YEAR, MAX_YEAR] are counted as years.
# NOTE(review): 32 presumably skips day-of-month values (1-31); confirm.
MIN_NUM_AS_YEAR = 32
MAX_YEAR = 2025
# Target number of year labels on the heatmap's x axis.
MAX_TICKS_X = 30

def filename_to_language(fn: str) -> str:
    """Derive a language label from a corpus file name.

    Surrounding whitespace is trimmed and a trailing ".txt" extension (matched
    case-insensitively) is removed, e.g. " French.TXT " -> "French".
    Any non-string value (such as a NaN read from a CSV cell) maps to "Unknown".
    """
    if isinstance(fn, str):
        label = fn.strip()
        has_txt_ext = label.lower().endswith(".txt")
        return (label[:-4] if has_txt_ext else label).strip()
    return "Unknown"

def extract_years(text: str):
    """Return the year-like integers found in *text*, in order of appearance.

    A token qualifies when it is a standalone run of 1-4 digits whose value
    lies in [MIN_NUM_AS_YEAR, MAX_YEAR]; duplicates are kept. Non-string or
    blank input yields an empty list.

    NOTE(review): because of the ``\\b`` anchors, digits glued to letters
    (e.g. "1990s", "5th") are not matched.
    """
    if not (isinstance(text, str) and text.strip()):
        return []
    numbers = map(int, re.findall(r"\b(\d{1,4})\b", text))
    return [year for year in numbers if MIN_NUM_AS_YEAR <= year <= MAX_YEAR]

def _hybrid_zscore(pivot):
    """Blend row- and column-wise z-scores of *pivot* (70% row, 30% column).

    Rows or columns whose standard deviation is 0 or undefined (e.g. a single
    value) contribute 0 instead of NaN/inf.
    """
    row_std = pivot.std(axis=1).replace(0, np.nan)
    col_std = pivot.std(axis=0).replace(0, np.nan)
    row_z = pivot.sub(pivot.mean(axis=1), axis=0).div(row_std, axis=0).fillna(0.0)
    col_z = pivot.sub(pivot.mean(axis=0), axis=1).div(col_std, axis=1).fillna(0.0)
    return 0.7 * row_z + 0.3 * col_z


def _scale_to_unit(vals, pct_low, pct_high):
    """Linearly map *vals* onto [0, 1] using two percentiles as the end points.

    Values below the ``pct_low`` percentile clip to 0 and values above the
    ``pct_high`` percentile clip to 1. Degenerate inputs (empty array,
    non-finite or coinciding percentiles) fall back to a unit-width range so
    the division is always well defined.
    """
    if vals.size == 0:
        lo, hi = 0.0, 1.0
    else:
        lo, hi = np.nanpercentile(vals, [pct_low, pct_high])
    if not np.isfinite(lo):
        lo = 0.0
    if not np.isfinite(hi) or hi - lo < 1e-12:
        hi = lo + 1.0
    return np.clip((vals - lo) / (hi - lo), 0, 1)


def _tick_step(num_items, max_ticks):
    """Return the smallest stride that labels at most *max_ticks* of *num_items* positions."""
    # Ceiling division; floor division would allow almost 2x max_ticks labels
    # (e.g. 59 items with max_ticks=30 -> step 1 -> 59 labels).
    return max(1, -(-num_items // max_ticks))


def draw_full_heatmap(pivot, years_sorted, out_path, title, gamma, pct_low, pct_high):
    """Render *pivot* as a language-by-year heatmap and save it to *out_path*.

    Parameters
    ----------
    pivot : pandas.DataFrame
        Mention counts with one row per language and one column per year.
    years_sorted : sequence
        Year labels for the columns of *pivot*, in column order. Columns are
        drawn at equal spacing, so gaps between years are not shown to scale.
    out_path : path-like
        Destination image file (PNG at 150 dpi).
    title : str
        Plot title.
    gamma : float
        Exponent of the ``PowerNorm`` applied to the colour scale.
    pct_low, pct_high : float
        Percentiles of the hybrid z-scores mapped to the bottom / top of the
        colour scale; values outside are clipped.
    """
    mat = _scale_to_unit(_hybrid_zscore(pivot).values, pct_low, pct_high)

    n_rows = len(pivot.index)
    fig, ax = plt.subplots(figsize=(18, max(6, n_rows * 0.40)))
    try:
        fig.patch.set_facecolor("white")
        ax.set_facecolor("white")

        if mat.shape[1] > 0:
            im = ax.imshow(
                mat,
                aspect="auto",
                cmap="plasma",
                norm=PowerNorm(gamma=gamma, vmin=0, vmax=1),
            )
        else:
            # No year columns: draw a blank placeholder so the colorbar has a mappable.
            im = ax.imshow(np.zeros((n_rows, 1)), aspect="auto", cmap="plasma")

        ax.set_yticks(np.arange(n_rows))
        ax.set_yticklabels(pivot.index.tolist(), color="black")

        num_years = len(years_sorted)
        if num_years > 0:
            xticks = np.arange(0, num_years, _tick_step(num_years, MAX_TICKS_X))
            ax.set_xticks(xticks)
            ax.set_xticklabels([str(years_sorted[i]) for i in xticks],
                               rotation=45, ha="right", color="black")
        else:
            ax.set_xticks([0])
            ax.set_xticklabels([""], color="black")

        ax.set_xlabel("Year", color="black")
        ax.set_ylabel("Language", color="black")
        ax.set_title(title, color="black")

        cbar = fig.colorbar(im, ax=ax)
        # NOTE(review): the colour values are percentile-scaled hybrid z-scores,
        # not raw frequencies; consider a more precise label.
        cbar.set_label("Normalized Frequency", color="black")
        # tick_params persists for labels created at draw time, unlike
        # recolouring the current Text objects.
        cbar.ax.tick_params(axis="y", labelcolor="black")

        fig.tight_layout()
        fig.savefig(out_path, dpi=150, facecolor=fig.get_facecolor(), edgecolor="none")
    finally:
        # Always release the figure, even if layout or saving fails.
        plt.close(fig)
    print(f"[OK] 已保存: {out_path}")

def process_group(name, cfg):
    """Build and save the all-languages year heatmap for one configured group.

    Reads ``cfg["dates"]`` (a CSV with a ``filename`` column and zero or more
    ``dates_part*`` text columns). It counts every year mention per language
    (language derived from the file name) and writes
    ``heatmap_<name>_all_languages.png`` into ``cfg["out_dir"]``, which is
    created if missing. ``cfg["gamma"]``, ``cfg["pct_low"]`` and
    ``cfg["pct_high"]`` are forwarded to ``draw_full_heatmap``.
    If no year mentions are found, prints a warning and returns without plotting.
    """
    print(f"\n[INFO] 处理 {name}")
    out_dir = cfg["out_dir"]
    out_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): cfg["wordcount"] is configured but not used. The counts
    # below are raw mention counts, not normalised by text length. Load and
    # join it here if normalisation is intended.

    df = pd.read_csv(cfg["dates"], dtype=str)
    df["language"] = df["filename"].apply(filename_to_language)
    date_cols = [c for c in df.columns if c.startswith("dates_part")]

    # One (language, year) record per year mention across all dates_part* cells.
    # itertuples yields plain tuples and avoids building a Series per row.
    records = []
    for lang, *cells in df[["language", *date_cols]].itertuples(index=False, name=None):
        for cell in cells:
            records.extend((lang, year) for year in extract_years(cell))

    if not records:
        print(f"[WARN] {name} 没有数据")
        return

    years_df = pd.DataFrame(records, columns=["language", "year"])
    counts = years_df.groupby(["language", "year"]).size().reset_index(name="count")

    years_sorted = sorted(counts["year"].unique())
    # Language x year matrix; (language, year) pairs never mentioned become 0.
    pivot = (
        counts.pivot(index="language", columns="year", values="count")
        .reindex(columns=years_sorted)
        .fillna(0)
    )

    out_fig = out_dir / f"heatmap_{name}_all_languages.png"
    draw_full_heatmap(
        pivot, years_sorted, out_fig,
        f"Full-line Heatmap of {name.capitalize()} (All Languages)",
        gamma=cfg["gamma"],
        pct_low=cfg["pct_low"],
        pct_high=cfg["pct_high"],
    )

def main():
    """Render the heatmap for every entry of GROUPS, in definition order."""
    for group_name in GROUPS:
        process_group(group_name, GROUPS[group_name])

# Run all groups only when executed as a script, not on import.
if __name__ == "__main__":
    main()