In [None]:
# Modified version: pick a REAL matrix (medoid) per group as representative
# - After EM, for each group g, we compute the standardized group mean curve M_g (T,F).
# - Among samples whose hard assignment == g, choose the sample i minimizing ||Yz[i] - M_g||_F^2.
# - If a group is empty (rare), fall back to the nearest sample over ALL files (same Frobenius distance).
#
# Usage is the same:
#   res = run_multitrajectory_clustering(input_dir="...", n_groups=..., degree=2, output_dir="...")
#
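# Outputs (written to output_dir):
#   representative_group_{g}.csv  -> for each group g, the chosen REAL sample matrix in original units (not the model mean)
#   representatives_map.csv       -> for each group, which original file was chosen and its distance to the mean curve
#   assignments.csv               -> for every input file, its hard group and soft responsibilities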

import glob
import math
import os
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

def _read_csv_matrix(path: str) -> np.ndarray:
    """Load one CSV file as a dense float matrix of its numeric columns.

    Missing values are imputed column by column: forward-fill, then
    backward-fill (so leading gaps take the first observed value). A numeric
    column with no observed values at all cannot be filled that way (its mean
    would be NaN too), so it is set to 0.0 to keep NaN out of the downstream
    standardization.

    Args:
        path: Path to a CSV file with a header row.

    Returns:
        Float array of shape (n_rows, n_numeric_columns); non-numeric columns
        are dropped.
    """
    df = pd.read_csv(path)
    # Keep numeric columns first; filling is column-wise, so this does not
    # change the numeric values and avoids object-dtype downcast warnings.
    df = df.select_dtypes(include=[np.number])
    df = df.ffill().bfill()
    # Only all-missing columns can still contain NaN here.
    df = df.fillna(0.0)
    return df.to_numpy(dtype=float)

def _build_time_design(T: int, degree: int) -> np.ndarray:
    """Return the polynomial time design matrix of shape (T, degree + 1).

    Rows are T equally spaced time points on [0, 1]; column d holds t**d for
    d = 0..degree, so column 0 is the intercept.
    """
    grid = np.linspace(0.0, 1.0, T)
    powers = [grid ** power for power in range(degree + 1)]
    return np.column_stack(powers)

def _init_by_kmeans(vecs: np.ndarray, k: int, max_iter: int = 50, seed: int = 42) -> Tuple[np.ndarray, np.ndarray]:
    """Cluster row vectors with k-means (k-means++ seeding).

    Args:
        vecs: (N, D) array, one flattened sample per row.
        k: Number of clusters.
        max_iter: Maximum number of Lloyd iterations.
        seed: Seed for the ``np.random.RandomState`` used for seeding and for
            re-seeding empty clusters.

    Returns:
        ``(centroids, labels)`` with shapes (k, D) and (N,). If the very first
        assignment already equals the all-zero initial labels, the loop stops
        before refining, and the centroids are the seeded points.
    """
    rng = np.random.RandomState(seed)
    N, D = vecs.shape
    centroids = np.empty((k, D))
    centroids[0] = vecs[rng.randint(N)]
    for c in range(1, k):
        # Squared distance from every point to its nearest already-chosen centroid.
        dists = np.min(((vecs[:, None, :] - centroids[None, :c, :]) ** 2).sum(axis=2), axis=1)
        total = dists.sum()
        if total > 0:
            probs = dists / total
        else:
            # Every point coincides with a chosen centroid (duplicate samples or
            # k > number of distinct points): D^2 weighting is undefined, so
            # pick uniformly instead of passing an all-zero `p` to choice().
            probs = np.full(N, 1.0 / N)
        centroids[c] = vecs[rng.choice(N, p=probs)]
    labels = np.zeros(N, dtype=int)
    for _ in range(max_iter):
        d2 = ((vecs[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        new_labels = np.argmin(d2, axis=1)
        if np.all(new_labels == labels):
            break
        labels = new_labels
        for c in range(k):
            mask = labels == c
            if mask.any():
                centroids[c] = vecs[mask].mean(axis=0)
            else:
                # Empty cluster: re-seed it at a random sample.
                centroids[c] = vecs[rng.randint(N)]
    return centroids, labels

def _weighted_least_squares(X: np.ndarray, Y: np.ndarray, w: np.ndarray) -> Tuple[np.ndarray, float]:
    """Fit a ridge-stabilized weighted least-squares regression of Y on X.

    Args:
        X: (n, P) design matrix.
        Y: (n,) response vector.
        w: (n,) non-negative observation weights; 1e-12 is added to each so
            zero-weight rows keep the system well-defined.

    Returns:
        ``(beta, sigma2)``: the (P,) coefficients solved with a 1e-6 ridge,
        and the weighted residual variance using ``max(sum(w) - P, 1)``
        degrees of freedom, floored at 1e-8.
    """
    weights = np.asarray(w, dtype=float) + 1e-12
    gram = X.T @ (X * weights[:, None])
    rhs = X.T @ (weights * Y)
    ridge = 1e-6 * np.eye(gram.shape[0])
    beta = np.linalg.solve(gram + ridge, rhs)
    residuals = Y - X @ beta
    dof = max(weights.sum() - X.shape[1], 1.0)
    variance = float((weights * residuals ** 2).sum() / dof)
    return beta, max(variance, 1e-8)

def _group_loglik_matrix(
    Yz: np.ndarray, X: np.ndarray, beta: np.ndarray, sigma2: np.ndarray
) -> np.ndarray:
    """Return the (N, G) matrix of per-sample log-likelihoods under each group.

    Group g models feature f of every sample as independent Gaussian noise
    around the polynomial curve ``X @ beta[g, f]`` with variance ``sigma2[g, f]``.

    Args:
        Yz: (N, T, F) standardized samples.
        X: (T, P) time design matrix.
        beta: (G, F, P) regression coefficients per group and feature.
        sigma2: (G, F) residual variances per group and feature.
    """
    n_samples, n_times, _ = Yz.shape
    n_groups = beta.shape[0]
    ll = np.empty((n_samples, n_groups))
    for g in range(n_groups):
        means = X @ beta[g].T  # (T, F) mean curve of group g
        s2 = sigma2[g]  # (F,), broadcasts over the feature axis
        quad = (((Yz - means) ** 2) / s2).sum(axis=(1, 2))
        ll[:, g] = -0.5 * (n_times * np.log(2 * np.pi * s2).sum() + quad)
    return ll


def run_multitrajectory_clustering(
    input_dir: str,
    n_groups: int = 3,
    degree: int = 2,
    max_em_iter: int = 100,
    tol: float = 1e-4,
    seed: int = 42,
    output_dir: Optional[str] = None
) -> Dict[str, Any]:
    """Cluster equally-shaped CSV trajectories and save one real sample per group.

    Every ``*.csv`` in ``input_dir`` is loaded as a (T, F) matrix of its numeric
    columns. Features are standardized across all files and time points. A
    mixture of per-feature polynomial-in-time Gaussian regressions is then fit
    by EM, initialized from k-means. For each group, the member sample (by hard
    assignment) closest in squared Frobenius distance to the group's mean curve
    is chosen as representative. If a group is empty, the nearest sample over
    all files is chosen instead.

    Args:
        input_dir: Directory containing the input CSV files.
        n_groups: Number of mixture components.
        degree: Polynomial degree of the time trend.
        max_em_iter: Maximum number of EM iterations. With 0, the coefficients
            keep their small random initialization.
        tol: Relative tolerance on the log-likelihood change for convergence.
        seed: Seed for k-means initialization and the initial coefficients.
        output_dir: Where results are written; defaults to
            ``/mnt/data/gbtm_multitraj_results_k{n_groups}_deg{degree}``.

    Returns:
        Dict with ``"representative_paths"`` (one CSV path per group),
        ``"representatives_map_path"``, ``"assignments_path"``,
        ``"output_dir"`` and ``"n_iter"`` (EM iterations actually run).

    Raises:
        AssertionError: (via ``assert``) if ``input_dir`` is missing, holds no
            CSV files, or the files' numeric shapes differ.
    """
    assert os.path.isdir(input_dir), f"Directory not found: {input_dir}"
    paths = sorted(glob.glob(os.path.join(input_dir, "*.csv")))
    assert paths, f"No CSV files found in: {input_dir}"

    # Load
    matrices = [_read_csv_matrix(p) for p in paths]
    shapes = {m.shape for m in matrices}
    assert len(shapes) == 1, f"All CSVs must have the same shape, found shapes: {shapes}"
    T, F = matrices[0].shape
    N = len(matrices)
    Y = np.stack(matrices, axis=0)  # (N, T, F)

    # Standardize per-feature across (N, T); the epsilon keeps constant features finite.
    Y_flat = Y.reshape(N * T, F)
    meanF = Y_flat.mean(axis=0, keepdims=True)
    stdF = Y_flat.std(axis=0, keepdims=True) + 1e-12
    Yz = ((Y_flat - meanF) / stdF).reshape(N, T, F)

    X = _build_time_design(T, degree)  # (T, P)
    P = X.shape[1]

    # Initialize hard (one-hot) responsibilities from k-means on flattened samples.
    _, labels = _init_by_kmeans(Yz.reshape(N, T * F), n_groups, seed=seed)
    R = np.zeros((N, n_groups))
    R[np.arange(N), labels] = 1.0

    rng = np.random.RandomState(seed)
    beta = rng.normal(scale=0.01, size=(n_groups, F, P))
    sigma2 = np.ones((n_groups, F))

    # Loop invariants: the stacked design and the per-feature response columns.
    X_stack = np.tile(X, (N, 1))  # (N*T, P)
    y_stacks = [Yz[:, :, f].reshape(N * T) for f in range(F)]

    prev_loglik = -np.inf
    n_iter = 0
    for it in range(max_em_iter):
        n_iter = it + 1

        # M-step: mixing weights, then a weighted regression per (group, feature).
        pi = R.mean(axis=0) + 1e-12
        pi = pi / pi.sum()
        for g in range(n_groups):
            w_stack = np.repeat(R[:, g], T)  # each sample's weight on all its T rows
            for f, y_stack in enumerate(y_stacks):
                b, s2 = _weighted_least_squares(X_stack, y_stack, w_stack)
                beta[g, f, :] = b
                sigma2[g, f] = s2

        # E-step with a max-shift for numerical stability; the same terms give
        # the observed-data log-likelihood used for the convergence test.
        log_joint = _group_loglik_matrix(Yz, X, beta, sigma2) + np.log(pi + 1e-12)
        row_max = log_joint.max(axis=1, keepdims=True)
        shifted = np.exp(log_joint - row_max)
        row_sum = shifted.sum(axis=1, keepdims=True)
        R = shifted / row_sum
        loglik = float((row_max[:, 0] + np.log(row_sum[:, 0])).sum())

        if it > 0 and abs(loglik - prev_loglik) < tol * (1 + abs(prev_loglik)):
            break
        prev_loglik = loglik

    # Hard assignments
    hard = R.argmax(axis=1)

    # Choose a medoid per group: nearest REAL sample to the standardized mean curve.
    rep_indices = []
    rep_distances = []
    for g in range(n_groups):
        Mz = X @ beta[g].T  # (T, F)
        members = np.flatnonzero(hard == g)
        if members.size == 0:
            # Empty group: fall back to the nearest sample among all N (same distance).
            members = np.arange(N)
        dists = ((Yz[members] - Mz) ** 2).sum(axis=(1, 2))  # squared Frobenius norms
        best = int(np.argmin(dists))
        rep_indices.append(int(members[best]))
        rep_distances.append(float(dists[best]))

    # Save outputs
    if output_dir is None:
        output_dir = os.path.join("/mnt/data", f"gbtm_multitraj_results_k{n_groups}_deg{degree}")
    os.makedirs(output_dir, exist_ok=True)

    rep_paths = []
    for g, i_idx in enumerate(rep_indices, start=1):
        # Representatives are written in original (unstandardized) units.
        df_rep = pd.DataFrame(Y[i_idx], columns=[f"feat_{j}" for j in range(F)])
        out_path = os.path.join(output_dir, f"representative_group_{g}.csv")
        df_rep.to_csv(out_path, index=False)
        rep_paths.append(out_path)

    # Save assignment + which sample chosen per group (groups are 1-based in files).
    assign_df = pd.DataFrame({
        "file": paths,
        "group": (hard + 1),
        **{f"resp_{g+1}": R[:, g] for g in range(n_groups)}
    })
    assign_path = os.path.join(output_dir, "assignments.csv")
    assign_df.to_csv(assign_path, index=False)

    rep_map = pd.DataFrame({
        "group": list(range(1, n_groups + 1)),
        "chosen_index": rep_indices,
        "chosen_file": [paths[i] for i in rep_indices],
        "distance_to_mean_curve_z2": rep_distances
    })
    rep_map_path = os.path.join(output_dir, "representatives_map.csv")
    rep_map.to_csv(rep_map_path, index=False)

    return {
        "representative_paths": rep_paths,
        "representatives_map_path": rep_map_path,
        "assignments_path": assign_path,
        "output_dir": output_dir,
        "n_iter": n_iter
    }

if __name__ == "__main__":
    # Example run. The guard keeps importing this module from launching a
    # clustering job; a notebook kernel also runs as __main__, so this still executes there.
    res = run_multitrajectory_clustering(
        input_dir="./data/lorenz",
        n_groups=3,
        degree=2,
        output_dir="./test_res",
    )
    print("Representatives:", res["representative_paths"])
    print("Assignments:", res["assignments_path"])
    print("Map:", res["representatives_map_path"])


Representative paths: ['./test_res/representative_group_1.csv', './test_res/representative_group_2.csv', './test_res/representative_group_3.csv', './test_res/representative_group_4.csv', './test_res/representative_group_5.csv', './test_res/representative_group_6.csv', './test_res/representative_group_7.csv', './test_res/representative_group_8.csv', './test_res/representative_group_9.csv', './test_res/representative_group_10.csv', './test_res/representative_group_11.csv', './test_res/representative_group_12.csv', './test_res/representative_group_13.csv', './test_res/representative_group_14.csv', './test_res/representative_group_15.csv', './test_res/representative_group_16.csv', './test_res/representative_group_17.csv', './test_res/representative_group_18.csv', './test_res/representative_group_19.csv', './test_res/representative_group_20.csv', './test_res/representative_group_21.csv', './test_res/representative_group_22.csv', './test_res/representative_group_23.csv', './test_res/represen