In [None]:

# ============================================
# Notebook 02 - Topological Feature Extraction
# ============================================

from ripser import ripser
from persim import plot_diagrams
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os

# --------------------------------------------
# 1. Load preprocessed dataset
# --------------------------------------------
# Pickled DataFrame from the preprocessing step. The cells below read its
# "File", "Age_Group" and "RR_series" columns. Unpickling can execute code,
# so only point this at files you produced yourself.
RR_PICKLE_PATH = "/content/drive/MyDrive/Paper_TDA_HRV/data_preprocessed/RR_preprocessed.pkl"
df_rr = pd.read_pickle(RR_PICKLE_PATH)

embedding_dim = 3
tau = 10

# --------------------------------------------
# 2. Time-delay embedding
# --------------------------------------------
def embed_time_series(ts, d=embedding_dim, tau=tau):
    """Return the time-delay embedding of a 1-D series as a point cloud.

    Row ``j`` of the result is ``(ts[j], ts[j + tau], ..., ts[j + (d - 1) * tau])``.

    Parameters
    ----------
    ts : array-like, 1-D
        Input series (list, NumPy array or pandas Series; indexed positionally).
    d : int
        Embedding dimension. Defaults to the module-level ``embedding_dim``
        (bound when this function is defined).
    tau : int
        Delay in samples. Defaults to the module-level ``tau``.

    Returns
    -------
    numpy.ndarray
        Shape ``(N, d)`` with ``N = len(ts) - (d - 1) * tau``. If the series
        is too short to form even one delay vector (``N <= 0``), the result is
        an empty array of shape ``(0, d)``.
    """
    values = np.asarray(ts)
    n_vectors = len(values) - (d - 1) * tau
    if n_vectors <= 0:
        # Without this guard a negative n_vectors becomes a "count from the end"
        # slice bound below, yielding columns of unequal (or wrong) length.
        return np.empty((0, d), dtype=values.dtype)
    return np.column_stack([values[i * tau:i * tau + n_vectors] for i in range(d)])

# --------------------------------------------
# 3. Compute persistence diagrams
# --------------------------------------------
# One diagram per recording: ripser is run up to dimension 1 (maxdim=1) on the
# delay-embedded series and only its H1 entry ('dgms'[1]) is kept.
results = []
too_short = []  # files whose series yields no delay vectors at all
for _, row in df_rr.iterrows():
    emb = embed_time_series(row["RR_series"])
    if len(emb) == 0:
        # Nothing to filter: store an empty (0, 2) diagram instead of handing
        # zero points to ripser, and remember the file so it is reported below.
        too_short.append(row["File"])
        diagram_h1 = np.empty((0, 2))
    else:
        diagram_h1 = ripser(emb, maxdim=1)['dgms'][1]
    results.append({
        "File": row["File"],
        "Age_Group": row["Age_Group"],
        "Diagram_H1": diagram_h1,
    })

df_tda = pd.DataFrame(results)
if too_short:
    # For these files an empty diagram means "not computed", not "no loops found".
    print(f"⚠️ {len(too_short)} recording(s) too short for the delay embedding "
          f"(stored empty H1 diagrams): {too_short}")
print("✅ Persistence diagrams computed for all groups.")

# --------------------------------------------
# 4. Plot example persistence diagrams by group
# --------------------------------------------
group_order = [
    "Neonates (0–1 mo)",
    "Early Infancy (1–5 mo)",
    "Late Infancy (6–11 mo)",
    "Toddlers (1–2 yr)",
    "Preschoolers (3–5 yr)",
    "School-age (6–11 yr)",
    "Adolescents (12–17 yr)"
]

n_cols = 4
# Ceiling division: enough rows for every group (2 rows for the 7 groups above).
n_rows = -(-len(group_order) // n_cols)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows), squeeze=False)
# Hide the spare panel(s) left over once every group has a slot.
for ax in axs.flat[len(group_order):]:
    ax.axis("off")
for ax, group in zip(axs.flat, group_order):
    subset = df_tda[df_tda["Age_Group"] == group]
    if subset.empty:
        ax.axis("off")
        continue
    # The group's first recording serves as its illustrative example.
    diagram_h1 = np.asarray(subset.iloc[0]["Diagram_H1"], dtype=float).reshape(-1, 2)
    if diagram_h1.shape[0] == 0:
        # persim's plot_diagrams sizes its axes from the min/max of the plotted
        # points, which is impossible when there are none.
        ax.text(0.5, 0.5, "No H₁ features", ha="center", va="center",
                transform=ax.transAxes)
    else:
        # The empty H0 placeholder keeps H1 in persim's second colour and "H1" label.
        plot_diagrams([np.empty((0, 2)), diagram_h1], ax=ax, show=False)
    ax.set_title(group, fontsize=11)
    ax.set_xlabel("Birth")
    ax.set_ylabel("Death")
fig.suptitle("Persistence Diagrams (H₁) by Age Group", fontsize=16)
plt.tight_layout()
plt.show()

# --------------------------------------------
# 5. Plot average persistence landscapes
# --------------------------------------------
def compute_landscape(diagram, num_points=100, k=1, x_range=None):
    """Sample the k-th persistence landscape of a diagram on a uniform grid.

    At each grid value x, every (birth b, death d) pair contributes the tent
    ``max(0, min(x - b, d - x))``. The landscape value is the k-th largest of
    these tents, or 0 where the diagram has fewer than k pairs.

    Parameters
    ----------
    diagram : array-like of shape (n, 2)
        Persistence pairs (birth, death). May be empty.
    num_points : int
        Number of grid samples.
    k : int
        Landscape level, 1 = top landscape. Must be >= 1.
    x_range : tuple (float, float) or None
        Grid endpoints. ``None`` (default) uses this diagram's own
        ``[min birth, max death]``. Pass one shared range whenever landscapes
        of different diagrams are compared or averaged index-by-index;
        otherwise index i refers to a different scale in each landscape.

    Returns
    -------
    numpy.ndarray of shape (num_points,)

    Raises
    ------
    ValueError
        If ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    diagram = np.asarray(diagram, dtype=float).reshape(-1, 2)
    if diagram.shape[0] == 0:
        return np.zeros(num_points)
    if diagram.shape[0] < k:
        # Fewer than k tents: the k-th largest is 0 everywhere.
        return np.zeros(num_points)
    if x_range is None:
        x_range = (diagram[:, 0].min(), diagram[:, 1].max())
    xs = np.linspace(x_range[0], x_range[1], num_points)
    births, deaths = diagram[:, 0], diagram[:, 1]
    # tents[i, j] = height of pair j's tent at grid value xs[i].
    tents = np.maximum(0.0, np.minimum(xs[:, None] - births, deaths - xs[:, None]))
    # k-th largest per grid value = k-th entry from the end after an ascending sort.
    return np.sort(tents, axis=1)[:, -k]


# Number of samples per landscape; shared by the landscapes and their x-axis grid.
LANDSCAPE_POINTS = 100

# Landscapes can only be averaged index-by-index if every recording is sampled
# on the same filtration grid, so derive one range from all H1 diagrams together.
_h1_nonempty = [
    np.asarray(dgm, dtype=float).reshape(-1, 2)
    for dgm in df_tda["Diagram_H1"]
    if len(dgm) > 0
]
if _h1_nonempty:
    _all_pairs = np.vstack(_h1_nonempty)
    landscape_range = (float(_all_pairs[:, 0].min()), float(_all_pairs[:, 1].max()))
else:
    landscape_range = (0.0, 1.0)  # no H1 pairs anywhere: every landscape is zero
landscape_grid = np.linspace(landscape_range[0], landscape_range[1], LANDSCAPE_POINTS)

df_tda["Landscape_H1"] = df_tda["Diagram_H1"].apply(
    lambda dgm: compute_landscape(dgm, num_points=LANDSCAPE_POINTS, k=1, x_range=landscape_range)
)

land_cols = 4
land_rows = -(-len(group_order) // land_cols)  # ceiling division: 2 rows for 7 groups
fig, axs = plt.subplots(land_rows, land_cols, figsize=(15, 3 * land_rows), squeeze=False)
# Hide the spare panel(s) left over once every group has a slot.
for ax in axs.flat[len(group_order):]:
    ax.axis("off")
for ax, group in zip(axs.flat, group_order):
    sub = df_tda[df_tda["Age_Group"] == group]
    if sub.empty:
        ax.axis("off")
        continue
    landscapes = np.vstack(sub["Landscape_H1"].to_list())  # (n_recordings, LANDSCAPE_POINTS)
    mean_land = landscapes.mean(axis=0)
    std_land = landscapes.std(axis=0)
    ax.plot(landscape_grid, mean_land, color="navy", lw=2)
    ax.fill_between(landscape_grid, mean_land - std_land, mean_land + std_land,
                    color="steelblue", alpha=0.3)
    ax.set_title(group, fontsize=10)
    ax.set_xlabel("Filtration scale")
    ax.set_ylabel("Landscape Value")
fig.suptitle("Persistence Landscapes (H₁) Mean ± SD by Age Group", fontsize=15)
plt.tight_layout()
plt.show()