# In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn_extra.cluster import KMedoids
from pathlib import Path
import re

def parse_year(item):
    """Extract one representative year from a free-form date string.

    The rules are tried in this order and the first one that yields a number
    wins:

    1. BC / BCE dates (case-insensitive substring ``"bc"``): a century
       expression such as ``"5th century BC"`` maps to the negated century
       midpoint (``-450``); otherwise the first run of 1-4 digits is negated
       (``"300 BC"`` -> ``-300``).
    2. Ranges (the string contains ``"to"`` or ``"-"``): the integer midpoint
       of the first two 3-4 digit numbers, or the single number if only one
       is present.  The ``"to"`` test is a case-sensitive substring match, so
       it also fires inside words such as ``"October"``.
    3. Centuries (``"century"``, case-insensitive): century ``N`` maps to
       ``N * 100 - 50``.
    4. Anything else: the first 3-4 digit number.

    Args:
        item: The raw value to parse; anything that is not a ``str`` is
            rejected.

    Returns:
        The parsed year as an ``int`` (negative for BC), or ``None`` when
        ``item`` is not a string, is blank, or contains no usable digits.
    """
    if not isinstance(item, str):
        return None
    text = item.strip()
    if not text:
        return None

    lowered = text.lower()
    is_bc = "bc" in lowered  # substring test, so "bce" is covered as well
    is_century = "century" in lowered

    if is_bc:
        if is_century:
            # Must be handled before the bare-year rule; otherwise
            # "5th century BC" would be read as the year -5.
            digits = re.search(r"\d+", text)
            if digits:
                return -(int(digits.group()) * 100 - 50)
        else:
            digits = re.search(r"\d{1,4}", text)
            if digits:
                return -int(digits.group())

    if "to" in text or "-" in text:
        years = re.findall(r"\d{3,4}", text)
        if len(years) >= 2:
            return (int(years[0]) + int(years[1])) // 2
        if years:
            return int(years[0])

    if is_century:
        digits = re.search(r"\d+", text)
        if digits:
            return int(digits.group()) * 100 - 50

    plain = re.search(r"\d{3,4}", text)
    return int(plain.group()) if plain else None
def _load_years_by_lang(dates_csv, langs, year_min, year_max):
    """Map each name in *langs* to the in-window years parsed from *dates_csv*."""
    wanted = set(langs)
    dates_df = pd.read_csv(dates_csv)
    date_cols = [col for col in dates_df.columns if col.startswith("dates_part")]

    years_by_lang = {}
    for _, row in dates_df.iterrows():
        filename = row["filename"]
        if not isinstance(filename, str):
            # A missing filename is read as NaN (a float); skip it rather
            # than crash on ``.replace``.
            continue
        lang = filename.replace(".txt", "")
        if lang not in wanted:
            continue
        years = []
        for col in date_cols:
            # str() turns a missing cell into "nan", which parse_year rejects.
            for item in str(row[col]).split(","):
                year = parse_year(item)
                if year is not None and year_min <= year <= year_max:
                    years.append(year)
        years_by_lang[lang] = years
    return years_by_lang


def _distance_to_cluster_median(langs, labels, years_by_lang):
    """Return each item's |median year - pooled median year of its cluster|."""
    dists = np.zeros(len(langs))
    for cluster_id in np.unique(labels):
        member_idx = np.where(labels == cluster_id)[0]
        pooled = [
            year
            for idx in member_idx
            for year in years_by_lang.get(langs[idx], [])
        ]
        if not pooled:
            # NOTE(review): members of a cluster with no dated years keep
            # distance 0 and therefore sort ahead of every dated item when
            # choosing what to keep -- confirm this is intended.
            continue
        centre = np.median(pooled)
        for idx in member_idx:
            years = years_by_lang.get(langs[idx], [])
            # Undated members of a dated cluster are pushed to the end.
            dists[idx] = np.abs(np.median(years) - centre) if years else np.inf
    return dists


def run_kmedoids_with_years(wasserstein_csv, dates_csv, n_clusters=6,
                            relative=False, outlier_frac=0.1, save_path=None,
                            year_range=(1800, 2025)):
    """Cluster items with k-medoids and plot each cluster's year density.

    The values of ``wasserstein_csv`` are passed to ``KMedoids`` as a
    precomputed distance matrix.  Years for every clustered item are parsed
    from the ``dates_part*`` columns of ``dates_csv`` with ``parse_year``.
    The items whose median year lies farthest from the pooled median year of
    their cluster are left out of the plot, and one KDE curve is drawn per
    cluster from the years of the remaining items.

    Args:
        wasserstein_csv: CSV read with its first column as the index; index
            entries (with any ``".txt"`` removed) are the item names.
        dates_csv: CSV with a ``filename`` column and zero or more columns
            whose names start with ``dates_part``, each holding
            comma-separated date strings.
        n_clusters: Number of medoids / clusters.
        relative: When true, each cluster's years are shifted by the integer
            median year of its medoid, and clusters whose medoid has no
            parsed years are not drawn.  When false, raw years are plotted
            and the x-axis is limited to ``year_range``.
        outlier_frac: Fraction of items excluded from the plot;
            ``max(1, int(n_items * (1 - outlier_frac)))`` items are kept.
        save_path: If given, the figure is written there at 300 dpi and
            closed; otherwise it is shown with ``plt.show()``.
        year_range: Inclusive ``(min, max)`` window; parsed years outside it
            are discarded.

    Returns:
        ``(labels, medoids)``: the fitted ``labels_`` array and the list of
        item names at ``medoid_indices_``.
    """
    year_min, year_max = year_range

    df = pd.read_csv(wasserstein_csv, index_col=0)
    langs = [name.replace(".txt", "") for name in df.index.tolist()]

    # random_state is fixed so repeated calls on the same matrix agree.
    kmed = KMedoids(n_clusters=n_clusters, metric="precomputed",
                    random_state=0).fit(df.values)
    labels = kmed.labels_
    medoids = [langs[m] for m in kmed.medoid_indices_]

    years_by_lang = _load_years_by_lang(dates_csv, langs, year_min, year_max)
    dists = _distance_to_cluster_median(langs, labels, years_by_lang)

    # Keep the items closest in time to their cluster; the rest are treated
    # as outliers and excluded from the plot.
    keep_n = max(1, int(len(dists) * (1 - outlier_frac)))
    keep_langs = {langs[i] for i in np.argsort(dists)[:keep_n]}

    cluster_years = {c: [] for c in range(n_clusters)}
    for i, lang in enumerate(langs):
        if lang in keep_langs:
            cluster_years[int(labels[i])].extend(years_by_lang.get(lang, []))

    fig, ax = plt.subplots(figsize=(10, 6))
    for c in range(n_clusters):
        data = cluster_years[c]
        if not data:
            continue
        if relative:
            medoid_years = years_by_lang.get(medoids[c], [])
            if not medoid_years:
                # Nothing to align against, so this cluster is not drawn.
                continue
            medoid_center = int(np.median(medoid_years))
            data = [year - medoid_center for year in data]
        sns.kdeplot(data, fill=True, alpha=0.4, ax=ax,
                    label=f"Cluster {c} (medoid={medoids[c]})")

    # round() rather than int(): 0.29 * 100 is 28.999... and would
    # otherwise be shown as 28%.
    pct = round(outlier_frac * 100)
    stem = Path(wasserstein_csv).stem
    if relative:
        ax.axvline(0, color="k", linestyle="--", label="Medoid Center")
        ax.set_xlabel("Relative Time (years from medoid)")
        ax.set_title(f"{stem} (relative, {pct}% outliers removed)")
    else:
        ax.set_xlim(year_min, year_max)
        ax.set_xlabel(f"Year ({year_min} to {year_max})")
        ax.set_title(f"{stem} (absolute, {pct}% outliers removed)")

    ax.set_ylabel("Density")
    ax.legend(loc="upper left", fontsize=6)

    if save_path:
        # Save this figure explicitly instead of whichever one is "current".
        fig.savefig(save_path, dpi=300)
        plt.close(fig)
        print(f"Saved: {save_path}")
    else:
        plt.show()

    return labels, medoids

# Batch driver: for every category/scope pair, cluster the matching
# Wasserstein matrix and save both an absolute-year and a medoid-relative
# timeline plot, collecting the medoid names of each run.
FOLDER = Path("/home/njian29/Desktop")
OUT_DIR = FOLDER / "timeline_plots_1800_2025"
OUT_DIR.mkdir(exist_ok=True)

categories = ["objects", "ideologies", "sports"]
scopes = ["top10", "top25", "top50"]

# One dates file per category; it is shared by all scopes of that category.
dates_files = {
    "objects": FOLDER / "filename_dates_packed_historical_objects_tagged.csv",
    "ideologies": FOLDER / "filename_dates_packed_history_of_ideologies_tagged.csv",
    "sports": FOLDER / "filename_dates_packed_history_of_sports_tagged.csv",
}

# Maps "<category>_<scope>_<mode>" to the list of medoid names for that run.
results = {}

for cat in categories:
    dates_file = dates_files[cat]
    for scope in scopes:
        wasserstein_file = FOLDER / f"wasserstein_{cat}_{scope}.csv"
        # Absolute first, then relative, so the result order is stable.
        for mode, is_relative in (("absolute", False), ("relative", True)):
            run_key = f"{cat}_{scope}_{mode}"
            _, medoid_names = run_kmedoids_with_years(
                wasserstein_file, dates_file, n_clusters=6,
                relative=is_relative, outlier_frac=0.1,
                save_path=OUT_DIR / f"{run_key}.png",
            )
            results[run_key] = medoid_names

for run_key, medoid_names in results.items():
    print(f"{run_key}: {medoid_names}")