# In [None]:  (Jupyter export cell marker, kept as a comment so the file parses as Python)
"""
draw_clustered_heatmaps_1700_2025_spaced_numbered.py
---------------------------------------------------
Original style retained, but:
- Inserts 3 blank rows between clusters
- Moves cluster numbers [0]–[4] left by 10 units
---------------------------------------------------
Author: Samuel Jiang (2025)
"""

from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
import re

# --- Time axis --------------------------------------------------------------
START_YEAR = 1700  # first year kept by parse_years / first heatmap column (inclusive)
END_YEAR = 2025    # last year kept by parse_years / last heatmap column (inclusive)
BIN_WIDTH = 1      # step between consecutive year columns (np.arange step)

# --- Rendering --------------------------------------------------------------
N_COLORS = 15      # number of discrete colour bins in the colormap / BoundaryNorm
GAMMA = 0.4        # exponent applied to row-normalised counts (values in [0, 1])
MAX_TICKS_X = 14   # x tick step is len(years) // MAX_TICKS_X, so roughly this many labels
CLUSTER_GAP = 3    # number of blank (NaN) rows inserted to separate clusters

# --- Input / output locations -----------------------------------------------
_DESKTOP = Path.home() / "Desktop"
_CLUSTER_DIR = _DESKTOP / "spectral_clusters_cleaned"

# group key -> {"parsed": CSV with parsed years, "clusters": CSV with cluster labels}
GROUPS = {
    "objects": {
        "parsed": _DESKTOP / "parsed_years_historical_objects_tagged.csv",
        "clusters": _CLUSTER_DIR / "clusters_objects_cleaned.csv",
    },
    "ideologies": {
        "parsed": _DESKTOP / "parsed_years_history_of_ideologies_tagged.csv",
        "clusters": _CLUSTER_DIR / "clusters_ideologies_cleaned.csv",
    },
    "sports": {
        "parsed": _DESKTOP / "parsed_years_history_of_sports_tagged.csv",
        "clusters": _CLUSTER_DIR / "clusters_sports_cleaned.csv",
    },
}

OUT_DIR = _DESKTOP / "clustered_heatmaps_single"
# parents=True: do not fail with FileNotFoundError when ~/Desktop itself is
# missing (e.g. on a headless machine).
OUT_DIR.mkdir(parents=True, exist_ok=True)

def filename_to_language(fn: str) -> str:
    """Derive a language label from a file name.

    Surrounding whitespace is trimmed, one trailing ``.txt`` extension
    (matched case-insensitively) is removed, and the result is trimmed again.
    Anything that is not a string maps to ``"Unknown"``.
    """
    if isinstance(fn, str):
        trimmed = fn.strip()
        has_txt_suffix = trimmed.lower().endswith(".txt")
        stem = trimmed[:-4] if has_txt_suffix else trimmed
        return stem.strip()
    return "Unknown"


# An optionally negative run of 1-4 digits that is NOT part of a longer digit
# run. The look-arounds stop e.g. "18001" from being chopped into the phantom
# year "1800" plus "1". The sign is captured so that negative numbers parse as
# negative ints and fall outside the accepted range.
_YEAR_TOKEN_RE = re.compile(r"-?(?<!\d)\d{1,4}(?!\d)")


def parse_years(text: str):
    """Extract the years mentioned in *text* that fall in the plotted range.

    Args:
        text: Free-form string containing year numbers. Any non-string value
            (e.g. a NaN read from a CSV) is treated as "no years".

    Returns:
        A list of ints, in order of appearance (duplicates kept), restricted
        to ``START_YEAR <= year <= END_YEAR``.
    """
    if not isinstance(text, str):
        return []
    candidates = (int(token) for token in _YEAR_TOKEN_RE.findall(text))
    return [year for year in candidates if START_YEAR <= year <= END_YEAR]


def make_discrete_cmap(n_colors=15):
    """Build a discrete colormap by sampling ``magma_r`` at *n_colors* points.

    The samples are evenly spaced over [0.06, 1.0] of the ``magma_r`` scale
    and wrapped in a ``ListedColormap`` named ``magma_r_<n_colors>``.
    """
    sample_points = np.linspace(0.06, 1.0, n_colors)
    palette = plt.cm.magma_r(sample_points)
    cmap_name = f"magma_r_{n_colors}"
    return ListedColormap(palette, name=cmap_name)

def _stack_with_gaps(matrix, row_labels, cluster_ids, gap):
    """Group matrix rows by cluster and separate the groups with blank rows.

    Args:
        matrix: 2-D float array with one row per language.
        row_labels: Label for each row of ``matrix``.
        cluster_ids: Cluster id for each row of ``matrix`` (must not hold NaN).
        gap: Number of all-NaN rows inserted between consecutive clusters.

    Returns:
        ``(stacked, labels, sizes)``: the stacked matrix, its row labels
        (``""`` for gap rows) and a list of ``(cluster_id, n_rows)`` pairs in
        ascending cluster-id order.
    """
    cluster_ids = np.asarray(cluster_ids)
    blank = np.full((gap, matrix.shape[1]), np.nan)
    pieces, labels, sizes = [], [], []
    for cluster_id in np.unique(cluster_ids):
        members = np.flatnonzero(cluster_ids == cluster_id)
        if pieces:
            # Gaps go only *between* clusters, never after the last one.
            pieces.append(blank)
            labels.extend([""] * gap)
        pieces.append(matrix[members, :])
        labels.extend(row_labels[i] for i in members)
        sizes.append((cluster_id, len(members)))
    return np.vstack(pieces), labels, sizes


def _draw_clustered_heatmap(fig, ax, matrix, row_labels, cluster_sizes, gap, years, title):
    """Render the gapped heatmap, cluster numbers, ticks, title and colorbar."""
    cmap = make_discrete_cmap(N_COLORS)
    boundaries = np.linspace(0.0, 1.0, N_COLORS + 1)
    norm = BoundaryNorm(boundaries, N_COLORS, clip=True)

    im = ax.imshow(
        matrix,
        aspect="auto",
        interpolation="nearest",
        cmap=cmap,
        norm=norm,
        origin="upper",
    )

    # Thin white rule between every pair of adjacent rows.
    for boundary in range(1, len(matrix)):
        ax.axhline(boundary - 0.5, color="white", linewidth=0.4)

    first_row = 0
    for position, (cluster_id, n_rows) in enumerate(cluster_sizes):
        if position:
            # Thick rule on the top edge of every cluster except the first.
            ax.axhline(first_row - 0.5, color="white", linewidth=2.0)
        # Rows first_row .. first_row + n_rows - 1 are centred on integer
        # y values, so the block's vertical centre is first_row + (n_rows - 1) / 2.
        ax.text(-100, first_row + (n_rows - 1) / 2, f"[{int(cluster_id)}]",
                va="center", ha="right",
                fontsize=35, fontweight="bold", color="black")
        first_row += n_rows + gap

    ax.set_yticks(np.arange(len(row_labels)))
    ax.set_yticklabels(row_labels, fontsize=11)

    step = max(1, len(years) // MAX_TICKS_X)
    xticks = np.arange(0, len(years), step)
    ax.set_xticks(xticks)
    ax.set_xticklabels([str(years[i]) for i in xticks],
                       rotation=90, ha="center", fontsize=16)

    ax.set_title(title, pad=20, fontsize=30)

    cbar = fig.colorbar(im, ax=ax)
    # The label says "PowerNorm", but the gamma is applied by hand (np.power)
    # before plotting; the image itself uses a BoundaryNorm.
    cbar.set_label("Normalized Frequency (Row Max Scaling + PowerNorm)", fontsize=12)

    # Reserve the left 10% of the figure for the cluster numbers.
    fig.tight_layout(rect=[0.1, 0, 1, 1])


def build_and_draw_single(group: str, parsed_path: Path, cluster_path: Path):
    """Draw and save the clustered year-frequency heatmap for one group.

    Reads the parsed-years CSV (columns ``filename`` and ``parsed_years``) and
    the cluster CSV (columns ``Language`` and ``Cluster``), counts how often
    each year in ``START_YEAR..END_YEAR`` occurs per language, scales every
    row by its own maximum, applies the ``GAMMA`` exponent, orders the rows by
    cluster (then language) and writes a PNG into ``OUT_DIR``.

    Args:
        group: Group key; used in log messages, the title and the file name.
        parsed_path: Path of the CSV holding the parsed years.
        cluster_path: Path of the CSV holding the cluster assignments.

    Returns:
        None. When there is nothing to plot a ``[WARN]`` line is printed and
        no file is written.
    """
    print(f"[INFO] Processing {group}")
    df = pd.read_csv(parsed_path, dtype=str, encoding="utf-8")
    df["language"] = df["filename"].apply(filename_to_language)

    clusters = pd.read_csv(cluster_path)
    clusters["Cluster"] = pd.to_numeric(clusters["Cluster"], errors="coerce")
    # Rows whose cluster is not numeric (coerced to NaN) or whose language is
    # missing can neither be ordered nor labelled; drop them instead of
    # crashing later on int(NaN).
    n_rows_read = len(clusters)
    clusters = clusters.dropna(subset=["Language", "Cluster"]).copy()
    if len(clusters) < n_rows_read:
        print(f"[WARN] Ignoring {n_rows_read - len(clusters)} cluster rows "
              f"without a usable Language/Cluster for {group}")
    if clusters.empty:
        print(f"[WARN] No usable cluster assignments for {group}")
        return
    # Normalise exactly like the parsed side so the join keys agree.
    clusters["Language"] = clusters["Language"].apply(filename_to_language)

    if "parsed_years" in df.columns:
        year_texts = df["parsed_years"]
    else:
        year_texts = [""] * len(df)
    records = [
        (lang, year)
        for lang, text in zip(df["language"], year_texts)
        for year in parse_years(text)
    ]
    years_df = pd.DataFrame(records, columns=["language", "year"])
    if years_df.empty:
        print(f"[WARN] No data for {group}")
        return

    counts = years_df.groupby(["language", "year"]).size().reset_index(name="count")
    all_years = np.arange(START_YEAR, END_YEAR + 1, BIN_WIDTH)
    pivot = (
        counts.pivot(index="language", columns="year", values="count")
        .reindex(columns=all_years)
        .fillna(0)
    )
    year_columns = pivot.columns.tolist()

    merged = clusters.merge(pivot, left_on="Language", right_index=True, how="inner")
    if merged.empty:
        print(f"[WARN] No clustered languages with year data for {group}")
        return
    merged = merged.sort_values(["Cluster", "Language"], kind="stable")

    langs_ordered = merged["Language"].tolist()
    # Select the year columns explicitly so that any extra column of the
    # cluster CSV cannot leak into the heatmap matrix.
    year_counts = merged[year_columns]
    row_max = year_counts.max(axis=1).replace(0, 1)  # avoid division by zero
    mat_norm = year_counts.div(row_max, axis=0)
    mat_gamma = np.power(mat_norm.values.astype(float), GAMMA)

    mat_final, row_labels, cluster_sizes = _stack_with_gaps(
        mat_gamma, langs_ordered, merged["Cluster"].to_numpy(), CLUSTER_GAP
    )

    title_map = {
        "objects": "Historical Objects",
        "ideologies": "History of Ideologies",
        "sports": "History of Sports",
    }
    title = (f"{title_map.get(group, group)} "
             f"({START_YEAR}–{END_YEAR}, {len(cluster_sizes)} Clusters)")
    out_path = OUT_DIR / (
        f"{group}_clustered_heatmap_{START_YEAR}-{END_YEAR}_spaced_numbered.png"
    )

    fig_h = max(6, len(row_labels) * 0.40)
    fig, ax = plt.subplots(figsize=(18, fig_h), dpi=160)
    try:
        _draw_clustered_heatmap(
            fig, ax, mat_final, row_labels, cluster_sizes, CLUSTER_GAP,
            year_columns, title,
        )
        fig.savefig(out_path, dpi=160)
    finally:
        # Always release this figure, even when drawing or saving fails.
        plt.close(fig)
    print(f"[OK] Saved {out_path}")


def main():
    """Render one heatmap per entry of ``GROUPS``, in definition order."""
    for group_name in GROUPS:
        sources = GROUPS[group_name]
        build_and_draw_single(group_name, sources["parsed"], sources["clusters"])


if __name__ == "__main__":
    main()