In [None]:
import numpy as np
from scipy.stats import entropy
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.decomposition import PCA
from kneed import KneeLocator

# === PARAMETERS ===
FOLDER = Path("/lustre/majlepy2/myproject/36files_source_localized")  # adjust if needed
segment_length = 50  # window length, counted along axis 2 of each loaded array
bins = 30  # number of histogram bins for the per-node entropy estimate
normalize_data = True  # z-score each axis-0 slice over axes (1, 2) before scoring
n_nodes = 68  # required size of axis 0; files with another size are skipped
SCALE = 1e12  # multiplier applied to the loaded data before scoring
# NOTE(review): SCALE looks like a unit conversion for very small source amplitudes — confirm.
# With normalize_data=True the z-scoring cancels it (apart from the 1e-8 epsilon).

# === OUTPUT ===
output_dir = Path("/lustre/majlepy2/myproject/50_epochs_tc3d")
# Appended to each input stem (with "_label_tc3d" removed) to build the output filename.
output_prefix = "_68localized_50epochs.npy"
output_dir.mkdir(parents=True, exist_ok=True)

# Figure path
figpath = output_dir / "group_metrics_50_epochs_segment.png"


def _entropy_score(segment, n_bins):
    """Return the mean over axis 0 of each slice's |value| histogram entropy (bits).

    Each slice ``segment[i]`` is flattened, its absolute values are binned into
    ``n_bins`` equal-width bins, and the base-2 Shannon entropy of the resulting
    bin probabilities is computed; empty bins are dropped first.
    """
    node_ents = []
    for node in segment:
        hist, _ = np.histogram(np.abs(node.ravel()), bins=n_bins, density=True)
        probs = hist / np.sum(hist)  # equal-width bins, so this yields bin probabilities
        probs = probs[probs > 0]
        node_ents.append(entropy(probs, base=2))
    return np.mean(node_ents)


def _pca_score(segment):
    """Return ``(score, knee_failed)`` for one window.

    PCA is fit with the axis-0 slices as features and every remaining position
    as a sample. ``score`` is the summed explained-variance ratio of the
    components up to the elbow found by ``KneeLocator``. When no elbow is found,
    or the knee search raises, a single component is used; ``knee_failed`` is
    True only in the raising case so the caller can report it.
    """
    reshaped = segment.reshape(segment.shape[0], -1)
    explained = PCA().fit(reshaped.T).explained_variance_ratio_
    try:
        knee = KneeLocator(
            range(1, len(explained) + 1),
            explained,
            curve="convex",
            direction="decreasing",
        )
    except Exception:  # best-effort: fall back to one component, but surface it to the caller
        return np.sum(explained[:1]), True
    # knee.knee is a 1-based component count (x starts at 1), or None when no elbow exists.
    elbow_idx = int(knee.knee or 1)
    return np.sum(explained[:elbow_idx]), False


# --- FIND FILES ---
files = sorted(FOLDER.glob("*_label_tc3d.npy"))
print(f"Found {len(files)} files in {FOLDER}")

# --- METRICS STORAGE ---
# One list per valid file, each holding one value per window start.
all_variances, all_entropies, all_pca_scores, valid_files = [], [], [], []

# --- PROCESS FILES ---
for file in files:
    X = np.load(file)
    print(f"\nLoaded {file.name} with shape {X.shape}")

    # Check ndim first: reading X.shape[2] on a 1-D/2-D array would raise
    # IndexError instead of skipping the file.
    if X.ndim != 3 or X.shape[0] != n_nodes or X.shape[2] < segment_length:
        print(f"Skipping {file.name} (shape mismatch)")
        continue

    X = X * SCALE
    if normalize_data:
        X = (X - X.mean(axis=(1, 2), keepdims=True)) / (X.std(axis=(1, 2), keepdims=True) + 1e-8)

    valid_files.append(file)

    subj_vars, subj_ents, subj_pca = [], [], []
    knee_failures = 0

    # Slide a window of segment_length along axis 2 with stride 1.
    for start in range(X.shape[2] - segment_length + 1):
        segment = X[..., start:start + segment_length]

        # --- Variance: per-slice variance over axes (1, 2), averaged over axis 0 ---
        subj_vars.append(segment.var(axis=(1, 2)).mean())

        # --- Entropy ---
        subj_ents.append(_entropy_score(segment, bins))

        # --- PCA Score (Elbow method) ---
        pca_score, knee_failed = _pca_score(segment)
        subj_pca.append(pca_score)
        knee_failures += knee_failed

    if knee_failures:
        print(f"  KneeLocator raised on {knee_failures} window(s) of {file.name}; used 1 component there")

    all_variances.append(subj_vars)
    all_entropies.append(subj_ents)
    all_pca_scores.append(subj_pca)

# --- FINAL CHECK ---
if not all_variances:
    raise ValueError("No valid data segments found.")

# --- ALIGN ---
# Files may differ in length along axis 2, so keep only the window starts that
# exist for every file; this also guarantees the chosen window fits every file.
min_windows = min(len(x) for x in all_variances)
V = np.array([x[:min_windows] for x in all_variances])
E = np.array([x[:min_windows] for x in all_entropies])
P = np.array([x[:min_windows] for x in all_pca_scores])

# --- GROUP METRICS ---
# Average each metric across files: one value per window start (not yet normalized).
group_var, group_ent, group_pca = V.mean(axis=0), E.mean(axis=0), P.mean(axis=0)

def normalize(x):
    """Min-max rescale the array *x* onto the unit interval [0, 1].

    A constant input has zero range and cannot be rescaled, so an array of
    zeros with the same shape and dtype as *x* is returned instead of dividing
    by zero.
    """
    lo = x.min()
    span = x.max() - lo
    if span == 0:
        return np.zeros_like(x)
    return (x - lo) / span

# Min-max scale each group curve to [0, 1] and average them with equal weight,
# so no metric dominates the combined score through its native scale.
var_norm, ent_norm, pca_norm = normalize(group_var), normalize(group_ent), normalize(group_pca)
combined = (var_norm + ent_norm + pca_norm) / 3

# --- SELECT BEST WINDOW ---
# Each index is a window start along axis 2; np.argmax picks the first maximum on ties.
best_var_idx, best_ent_idx = int(np.argmax(var_norm)), int(np.argmax(ent_norm))
best_pca_idx, best_comb_idx = int(np.argmax(pca_norm)), int(np.argmax(combined))

# --- PLOT (with matching colors) ---
colors = {"var": "tab:blue", "ent": "tab:orange", "pca": "tab:green", "comb": "tab:red"}

plt.figure(figsize=(12, 6))
plt.plot(var_norm, "-o", lw=2, markevery=20, label="Variance (normalized)", color=colors["var"])
plt.plot(ent_norm, "-s", lw=2, markevery=20, label="Entropy (normalized)", color=colors["ent"])
plt.plot(pca_norm, "-^", lw=2, markevery=20, label="PCA Score (elbow-based)", color=colors["pca"])
plt.plot(combined, "--", lw=3, label="Combined", color=colors["comb"])

plt.axvline(best_var_idx, linestyle=":", lw=1, color=colors["var"], label="Best Variance")
plt.axvline(best_ent_idx, linestyle=":", lw=1, color=colors["ent"], label="Best Entropy")
plt.axvline(best_pca_idx, linestyle=":", lw=1, color=colors["pca"], label="Best PCA")
plt.axvline(best_comb_idx, linestyle="-", lw=2, color=colors["comb"], label="Best Combined")

# Highlight the selected window. The segment extracted below always spans
# [best_comb_idx, best_comb_idx + segment_length); only the shading is clipped to
# the plotted range of window starts, so the legend reports the true extracted range.
start, end = best_comb_idx, best_comb_idx + segment_length
plt.axvspan(
    start,
    min(end, len(combined)),
    alpha=0.15,
    color=colors["comb"],
    label=f"Selected window [{start}:{end})",
)

plt.xlabel("Window start epoch index")
plt.ylabel("Group-level metric (normalized)")
plt.title("Group-Level Informative Segment Selection")
plt.grid(True)
plt.legend(
    loc="lower right",
    bbox_to_anchor=(1, 0.19),   # shifted up
    fontsize=9,
    frameon=True,
    ncol=2
)
plt.tight_layout()

# --- SAVE FIGURE ---
plt.savefig(figpath, dpi=300, bbox_inches="tight")
print(f"Saved figure to {figpath}")

plt.show()

# --- EXTRACT AND SAVE SELECTED SEGMENTS ---
# Each file is re-loaded, so the saved segment is the raw data: SCALE and the
# z-scoring were only used to choose the window.
chosen_start = best_comb_idx
for file in valid_files:
    X = np.load(file)
    X_sel = X[..., chosen_start:chosen_start + segment_length]

    original_stem = file.stem.replace("_label_tc3d", "")
    output_name = original_stem + output_prefix
    output_path = output_dir / output_name

    np.save(output_path, X_sel)
    print(f"Saved: {output_path} with shape {X_sel.shape}")
