# Time-course hierarchical clustering (first-pass metrics)

This notebook reproduces the outputs originally generated interactively in chat: PNG heatmaps, Excel workbooks, summary CSVs, and a ZIP archive of the output folder.

**You only need to edit the `INPUT_PATH`, `OUTPUT_DIR`, and `DROP_TIMEPOINTS` variables in the next cell.**


In [6]:
# ====== USER SETTINGS (edit these) ======
# Input table. Files ending in .xlsx/.xls are read as Excel (first sheet only);
# any other extension is read as CSV.
INPUT_PATH = r"J:/Cohen Lab/Maria Clara/2_Lab data/9_Napari/OUTPUT new code/New PCA soma/PCA cleaning/Data S1.csv"  # CSV or XLSX
# Output folder. WARNING: if it already exists, the pipeline cell deletes it
# (including everything inside) and recreates it. A ZIP with the same name is
# written next to it.
OUTPUT_DIR = r"J:/Cohen Lab/Maria Clara/2_Lab data/9_Napari/OUTPUT new code/hierarchical clustering/OUT_timecourse_clustering/OUT_timecourse_k12_DiffLayout"  # folder will be overwritten

# Put timepoints to drop here. Examples:
#   []            -> keep all
#   ["iPSCs"]     -> neurites-style
# Names must match the canonical names listed in TP_ORDER.
DROP_TIMEPOINTS = []

# Clustering + plotting parameters
# Canonical timepoint names, in the order used for heatmap columns and for the
# step-to-step differences. Rows whose normalized timepoint is not in this list
# are discarded.
TP_ORDER = ["iPSCs", "day7", "day14", "day21", "day28"]
# Upper bound on the number of clusters (passed to fcluster with criterion='maxclust').
KMAX = 12
# Minimum number of non-missing values a column needs to be kept as a metric.
MIN_NON_NA = 5
# Heatmap colour scale is clipped to the range [-VLIM, +VLIM] (z-score units).
VLIM = 2.5
# Resolution of the saved PNG files.
DPI = 170
# Matplotlib colormap name used for the heatmaps.
CMAP = "BrBG"


In [7]:
import re, shutil, warnings
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
from scipy.spatial.distance import pdist

warnings.filterwarnings('ignore')


In [8]:
def normalize_tp(x):
    """Map a raw timepoint label to its canonical name.

    Parameters
    ----------
    x : object
        Raw cell value (e.g. "Day 14", "d21_well3", "iPSC ctrl"). Converted
        with ``str`` and lower-cased before matching.

    Returns
    -------
    str or None
        "iPSCs" when the text contains "ipsc"; "day28"/"day21"/"day14"/"day7"
        when a "day<N>" or "d<N>" token is found (optionally separated by
        whitespace, "_" or "-"); ``None`` for missing values or no match.
    """
    if pd.isna(x):
        return None
    s = str(x).lower()
    if "ipsc" in s:
        return "iPSCs"
    # Use (?!\d) instead of \b after the number: "_" counts as a word
    # character, so \b rejected labels such as "day7_rep1" and those rows were
    # silently dropped. (?!\d) still prevents "day7" from matching "day70".
    for d in [28, 21, 14, 7]:
        if re.search(rf"day[\s_-]*{d}(?!\d)", s):
            return f"day{d}"
    # Short form: "d7", "d_14", "D-21", ... not glued to a preceding digit.
    for d in [28, 21, 14, 7]:
        if re.search(rf"(?<!\d)d[\s_-]*{d}(?!\d)", s):
            return f"day{d}"
    return None

def find_tp_col(df):
    """Locate the column that holds the timepoint / cell-type label.

    First looks for a column whose name (case-insensitive, "_" treated as "-")
    is "cell-type" or "celltype". If none exists, scans the first 30 columns
    and returns the first one in which more than 20% of the values contain a
    timepoint-like token ("day", "ipsc", "d7", "d14", "d21", "d28").

    Parameters
    ----------
    df : pandas.DataFrame

    Returns
    -------
    The matching column label, exactly as it appears in ``df.columns``.

    Raises
    ------
    ValueError
        If no column qualifies.
    """
    # After replace('_', '-') the name can only be one of these two spellings.
    known_names = ('cell-type', 'celltype')
    for c in df.columns:
        # str(c): column labels may be ints (e.g. header-less tables);
        # calling .lower() on them directly would raise AttributeError.
        if str(c).lower().replace('_', '-') in known_names:
            return c
    for c in df.columns[:30]:
        hits = df[c].astype(str).str.contains('day|ipsc|d7|d14|d21|d28', case=False, na=False)
        if hits.mean() > 0.2:
            return c
    raise ValueError("Couldn't find timepoint column. Expected something like 'cell-type'.")

def find_image_col(df):
    """Return the label of the image-name column, or ``None`` if absent.

    A column qualifies when its name (case-insensitive, "_" treated as "-")
    is "image-name" or "imagename".
    """
    # After replace('_', '-') the name can only be one of these two spellings.
    known_names = ('image-name', 'imagename')
    for c in df.columns:
        # str(c): tolerate non-string column labels instead of raising
        # AttributeError on .lower().
        if str(c).lower().replace('_', '-') in known_names:
            return c
    return None

def ensure_numeric(df, cols, min_non_na):
    """Select the columns in ``cols`` that can be used as numeric metrics.

    Columns that already have a numeric dtype are accepted as they are. Any
    other column is converted with ``pd.to_numeric(errors='coerce')``; it is
    accepted only when at least ``min_non_na`` values survive the conversion,
    in which case ``df`` is updated IN PLACE with the converted column.

    Returns the accepted column names in their original order.
    """
    accepted = []
    for name in cols:
        if not pd.api.types.is_numeric_dtype(df[name]):
            converted = pd.to_numeric(df[name], errors='coerce')
            if converted.notna().sum() < min_non_na:
                # Too few parsable values: leave the column untouched and skip it.
                continue
            df[name] = converted
        accepted.append(name)
    return accepted

def plot_heat(out_png, Z, labels_ord, matrix, tp_order, vlim, cmap,
              percell_mode=False, tp_idx_vec=None, dpi=170):
    """Save a timecourse heatmap PNG with:
      - dendrogram on the LEFT
      - a vertical cluster-color strip between dendrogram and heatmap
      - a clean legend for clusters (no overlaps)
      - timepoint labels ABOVE the heatmap
      - horizontal z-score colorbar BELOW

    Parameters
    ----------
    out_png : path-like
        Destination file passed to ``fig.savefig``.
    Z : ndarray or None
        Linkage matrix for the dendrogram; skipped when ``None`` or when
        ``matrix`` has a single row.
    labels_ord : sequence of int
        Cluster id per heatmap row, already in plotted (row) order.
    matrix : 2-D array
        Values to draw, one row per metric, rows already in plotted order.
    tp_order : list of str
        Timepoint names; used as column labels (median mode) or group labels
        (per-cell mode, indexed by the values in ``tp_idx_vec``).
    vlim : float
        Colour scale runs from ``-vlim`` to ``+vlim``.
    cmap : str
        Matplotlib colormap name.
    percell_mode : bool
        When True (and ``tp_idx_vec`` is given) columns are individual cells
        grouped by timepoint, with separators between groups.
    tp_idx_vec : array of int, optional
        Timepoint index of every column in per-cell mode.
    dpi : int
        Resolution of the saved figure.
    """
    import seaborn as sns
    import matplotlib.patches as mpatches

    fig = plt.figure(figsize=(16, 22))

    # Axes layout (figure coords):
    # dendrogram (left), cluster strip (thin), heatmap (center), colorbar (bottom under heatmap)
    ax_d = fig.add_axes([0.05, 0.12, 0.18, 0.78])   # dendrogram (left)
    ax_s = fig.add_axes([0.235, 0.12, 0.015, 0.78]) # cluster strip
    ax_h = fig.add_axes([0.255, 0.12, 0.69, 0.78])  # heatmap
    ax_c = fig.add_axes([0.255, 0.06, 0.69, 0.02])  # colorbar (below heatmap)

    # --- Dendrogram on the left, pointing RIGHT toward the heatmap
    if Z is not None and matrix.shape[0] > 1:
        dendrogram(Z, orientation='left', ax=ax_d, no_labels=True, color_threshold=None)
        # scipy places the first leaf at the BOTTOM of a 'left' dendrogram,
        # whereas imshow (origin='upper') draws row 0 at the TOP. Flip the
        # dendrogram's y-axis so every branch lines up with its heatmap row
        # and cluster-strip colour.
        ax_d.invert_yaxis()
    ax_d.set_xticks([]); ax_d.set_yticks([])

    # --- Cluster strip colors (based on labels_ord, which already matches the plotted order)
    clusters = sorted(np.unique(labels_ord))
    pal = sns.color_palette("tab20", n_colors=len(clusters))
    cl2color = {cl: pal[i] for i, cl in enumerate(clusters)}
    # One RGB pixel per row -> (n_rows, 1, 3) image.
    strip_rgb = np.array([cl2color[int(cl)] for cl in labels_ord]).reshape(-1, 1, 3)

    ax_s.imshow(strip_rgb, aspect='auto', interpolation='nearest')
    ax_s.set_xticks([]); ax_s.set_yticks([])

    # --- Heatmap
    im = ax_h.imshow(matrix, aspect='auto', interpolation='nearest', cmap=cmap, vmin=-vlim, vmax=vlim)
    ax_h.set_yticks([])

    # Put timepoint labels ABOVE the heatmap
    ax_h.xaxis.set_ticks_position('top')
    ax_h.xaxis.set_label_position('top')
    ax_h.tick_params(axis='x', top=True, bottom=False, labeltop=True, labelbottom=False)

    if percell_mode and tp_idx_vec is not None:
        # One label per timepoint group, centred on the group's columns, with a
        # thin separator line before every group except the first.
        x_pos, labs = [], []
        for tp in sorted(np.unique(tp_idx_vec)):
            idxs = np.where(tp_idx_vec == tp)[0]
            if len(idxs) == 0:
                continue
            if len(x_pos) > 0:
                ax_h.axvline(idxs[0] - 0.5, color='black', lw=0.5)
            x_pos.append(int((idxs[0] + idxs[-1]) / 2))
            labs.append(tp_order[tp])
        ax_h.set_xticks(x_pos); ax_h.set_xticklabels(labs, fontsize=9)
        ax_h.set_xlabel('Cells grouped by time point')
    else:
        ax_h.set_xticks(range(len(tp_order)))
        ax_h.set_xticklabels(tp_order, fontsize=10)
        ax_h.set_xlabel('Time point (median z)')

    # --- Horizontal colorbar BELOW
    cb = plt.colorbar(im, cax=ax_c, orientation='horizontal')
    cb.set_label('z-score')

    # --- Cluster legend (top-left, away from dendrogram)
    handles = [mpatches.Patch(color=cl2color[int(cl)], label=f"Cluster {int(cl)}") for cl in clusters]
    fig.legend(handles=handles, title="Metric clusters", loc="upper left",
               bbox_to_anchor=(0.255, 1.15), frameon=True, ncol=1)

    # bbox_inches='tight' expands the saved canvas to include the legend,
    # which is anchored above the figure area.
    fig.savefig(out_png, dpi=dpi, bbox_inches='tight', pad_inches=0.2)
    plt.close(fig)

def write_excel(path, overview_dict, metrics_flat, med_flat, tp_order,
                metrics_ord, labels_ord, med_ord, z_ord, dz_ord):
    """Write one clustering result to an Excel workbook.

    Sheets written:
      - ``overview``           : key/value pairs from ``overview_dict``
      - ``flat_metrics_list``  : names in ``metrics_flat``
      - ``flat_medians_byTP``  : ``med_flat`` per timepoint (only if any flat metric)
      - ``metric_to_cluster``  : one row per metric with cluster id, per-timepoint
                                 median and z, step-to-step dz, ``sum_abs_dz`` and
                                 its dense rank (1 = largest)
      - ``c<cluster id>``      : the ``metric_to_cluster`` rows of each cluster,
                                 sorted by that rank

    Parameters
    ----------
    path : path-like
        Destination .xlsx file.
    overview_dict : dict
        Run summary written to the ``overview`` sheet.
    metrics_flat, med_flat
        Names and per-timepoint medians of the flat (non-dynamic) metrics.
    tp_order : list of str
        Timepoint names; column ``i`` of the arrays corresponds to ``tp_order[i]``.
    metrics_ord, labels_ord, med_ord, z_ord, dz_ord
        Row-aligned metric names, cluster ids, medians, z-scores and
        consecutive-timepoint z differences (``dz_ord`` has one column fewer).
    """
    # sheet_name is passed by keyword throughout: positional use is deprecated
    # in pandas 2.x and rejected by pandas 3.
    with pd.ExcelWriter(path, engine="openpyxl") as xw:
        pd.DataFrame({'item': list(overview_dict.keys()), 'value': list(overview_dict.values())}).to_excel(
            xw, sheet_name='overview', index=False
        )
        pd.DataFrame({'flat_metric': metrics_flat}).to_excel(xw, sheet_name='flat_metrics_list', index=False)
        if len(metrics_flat):
            pd.DataFrame(med_flat, index=metrics_flat,
                         columns=[f'median_{tp}' for tp in tp_order]).to_excel(
                xw, sheet_name='flat_medians_byTP', index=True
            )

        master = pd.DataFrame({'metric': metrics_ord, 'cluster_id': labels_ord})
        for i, tp in enumerate(tp_order):
            master[f'median_{tp}'] = med_ord[:, i]
            master[f'z_{tp}'] = z_ord[:, i]
        # One dz column per pair of consecutive timepoints (later minus earlier).
        for i, (a, b) in enumerate(zip(tp_order[:-1], tp_order[1:])):
            master[f'dz_{b}-vs-{a}'] = dz_ord[:, i]
        master['sum_abs_dz'] = np.abs(dz_ord).sum(axis=1)
        master['rank_sum_abs_dz'] = master['sum_abs_dz'].rank(ascending=False, method='dense').astype(int)
        master.to_excel(xw, sheet_name='metric_to_cluster', index=False)

        for cid in sorted(pd.unique(labels_ord)):
            tab = master[master['cluster_id'] == cid].sort_values('rank_sum_abs_dz')
            # Excel limits sheet names to 31 characters.
            tab.to_excel(xw, sheet_name=f'c{cid}'[:31], index=False)


In [9]:
# ====== RUN PIPELINE ======
inp = Path(INPUT_PATH)
outdir = Path(OUTPUT_DIR)

# Validate the input BEFORE wiping the output folder, so a wrong path cannot
# destroy earlier results (or the input itself, if it lives inside OUTPUT_DIR).
if not inp.is_file():
    raise FileNotFoundError(f"INPUT_PATH does not point to a file: {inp}")
if outdir.resolve() in inp.resolve().parents:
    raise ValueError(
        f"INPUT_PATH is inside OUTPUT_DIR, which is deleted on every run: {outdir}"
    )

# Load (Excel: first sheet only; anything else is read as CSV)
if inp.suffix.lower() in ['.xlsx', '.xls']:
    xl = pd.ExcelFile(inp)
    df0 = xl.parse(xl.sheet_names[0])
else:
    df0 = pd.read_csv(inp)

# Input loaded successfully: now recreate the output folder from scratch.
if outdir.exists():
    shutil.rmtree(outdir)
outdir.mkdir(parents=True, exist_ok=True)

df0.columns = [str(c) for c in df0.columns]
tp_col = find_tp_col(df0)
img_col = find_image_col(df0)

# Canonical timepoint name per row (None when the label is not recognised).
df0['dataset_norm'] = df0[tp_col].apply(normalize_tp)

# Apply DROP
drop_set = set(DROP_TIMEPOINTS or [])
df = df0[df0['dataset_norm'].isin(TP_ORDER)].copy()
if drop_set:
    df = df[~df['dataset_norm'].isin(drop_set)].copy()

tp_order_use = [tp for tp in TP_ORDER if tp not in drop_set]
assert len(tp_order_use) >= 2, 'Need at least 2 timepoints after DROP.'
if df.empty:
    raise ValueError(
        f"No rows in column '{tp_col}' matched any of the timepoints {tp_order_use}."
    )

# Integer timepoint index (position in tp_order_use); rows sorted by it so that
# per-cell heatmap columns are grouped by timepoint.
tp_to_idx = {tp:i for i,tp in enumerate(tp_order_use)}
df['_tp_idx'] = df['dataset_norm'].map(tp_to_idx).astype(int)
df = df.sort_values('_tp_idx').reset_index(drop=True)

# Metrics: all numeric columns except ID columns
id_like = {'dataset_norm', '_tp_idx', tp_col}
if img_col:
    id_like.add(img_col)
cand_cols = [c for c in df.columns if c not in id_like]
metrics = ensure_numeric(df, cand_cols, MIN_NON_NA)
# Coverage filter also applies to columns that were numeric from the start.
cov = df[metrics].notna().sum(axis=0)
metrics_cov = cov[cov >= MIN_NON_NA].index.tolist()

# Median per timepoint (forced rows): reindex guarantees one row per expected
# timepoint, NaN-filled when a timepoint has no cells. Transposed to
# (n_metrics, n_timepoints).
med_df = df.groupby('_tp_idx')[metrics_cov].median(numeric_only=True).reindex(range(len(tp_order_use)))
med = med_df.to_numpy().T

# Drop all-NaN metrics
ok_any = np.isfinite(med).any(axis=1)
metrics2 = [m for m,ok in zip(metrics_cov, ok_any) if ok]
med = med[ok_any,:]

# Dynamic vs flat: a metric is "flat" when its timepoint medians have zero variance.
var_time = np.nanvar(med, axis=1)
mask_dyn = var_time > 0
metrics_dyn = [m for m,ok in zip(metrics2, mask_dyn) if ok]
metrics_flat = [m for m,ok in zip(metrics2, mask_dyn) if not ok]
med_dyn = med[mask_dyn,:]
med_flat = med[~mask_dyn,:] if len(metrics_flat) else np.zeros((0,len(tp_order_use)))

# z-score across timepoints (median-based)
mu = np.nanmean(med_dyn, axis=1, keepdims=True)
sd = np.nanstd(med_dyn, axis=1, ddof=1, keepdims=True)
# Guard the division: undefined or zero spread -> divide by 1 instead.
sd[~np.isfinite(sd)] = 1.0
sd[sd==0] = 1.0
# Missing medians become z = 0 so that clustering receives finite values.
z_med = np.nan_to_num((med_dyn-mu)/sd, nan=0.0)
# Change in z between consecutive timepoints: shape (n_metrics, n_timepoints - 1).
dz = z_med[:,1:] - z_med[:,:-1]

# per-cell z-score for visualization (each metric standardised over all cells)
tp_idx_vec = df['_tp_idx'].to_numpy()
percell = []
for m in metrics_dyn:
    v = df[m].to_numpy(float)
    mu1 = np.nanmean(v)
    sd1 = np.nanstd(v, ddof=1)
    if not np.isfinite(sd1) or sd1==0:
        sd1 = 1.0
    percell.append(np.nan_to_num((v-mu1)/sd1, nan=0.0))
percell = np.vstack(percell) if percell else np.zeros((0,len(df)))

print('Loaded rows:', len(df))
print('Timepoint column used:', tp_col)
print('Image column:', img_col)
print('Timepoints used:', tp_order_use)
print('Metrics (>= MIN_NON_NA):', len(metrics_cov))
print('Dynamic metrics:', len(metrics_dyn), '| Flat metrics:', len(metrics_flat))


Loaded rows: 395
Timepoint column used: cell_type
Image column: image_name
Timepoints used: ['iPSCs', 'day7', 'day14', 'day21', 'day28']
Metrics (>= MIN_NON_NA): 248
Dynamic metrics: 190 | Flat metrics: 58


In [10]:
# ====== CLUSTER + EXPORT ======
def run_diffward():
    """Ward clustering of the metrics on their step-to-step z changes (``dz``).

    Reads the module-level ``dz`` and ``KMAX``.

    Returns
    -------
    (Z, leaves, labels)
        Linkage matrix (``None`` when there are fewer than two metrics),
        dendrogram leaf order, and one cluster id per metric cut to at most
        ``KMAX`` clusters.
    """
    n_metrics = dz.shape[0]
    if n_metrics <= 1:
        # Nothing to link: identity order, everything in cluster 1.
        return None, list(range(n_metrics)), np.ones(n_metrics, int)
    condensed = pdist(dz, metric='euclidean')
    tree = linkage(condensed, method='ward', optimal_ordering=False)
    leaf_order = dendrogram(tree, no_plot=True)['leaves']
    cluster_ids = fcluster(tree, t=KMAX, criterion='maxclust')
    return tree, leaf_order, cluster_ids

def run_corravg():
    """Average-linkage clustering of the metrics on correlation distance.

    Distance between two metrics is ``1 - Pearson r`` of their rows in the
    module-level ``z_med``; undefined correlations are treated as distance 1.
    Reads the module-level ``KMAX``.

    Returns
    -------
    (Z, leaves, labels)
        Linkage matrix (``None`` when there are fewer than two metrics),
        dendrogram leaf order, and one cluster id per metric cut to at most
        ``KMAX`` clusters.
    """
    n_metrics = z_med.shape[0]
    if n_metrics <= 1:
        # Nothing to link: identity order, everything in cluster 1.
        return None, list(range(n_metrics)), np.ones(n_metrics, int)
    corr_dist = np.nan_to_num(1 - np.corrcoef(z_med), nan=1.0)
    # Strict upper triangle, row by row = condensed form expected by linkage().
    upper = np.triu_indices(n_metrics, 1)
    tree = linkage(corr_dist[upper], method='average', optimal_ordering=False)
    leaf_order = dendrogram(tree, no_plot=True)['leaves']
    cluster_ids = fcluster(tree, t=KMAX, criterion='maxclust')
    return tree, leaf_order, cluster_ids

def run_patterncomplete():
    """Complete-linkage clustering of the metrics on their up/down pattern.

    Each step-to-step change in the module-level ``dz`` is reduced to its sign
    (+1, 0, -1). The distance between two metrics is the fraction of steps
    whose signs differ (Hamming distance). Reads the module-level ``KMAX``.

    Returns
    -------
    (Z, leaves, labels)
        Linkage matrix (``None`` when there are fewer than two metrics),
        dendrogram leaf order, and one cluster id per metric cut to at most
        ``KMAX`` clusters.
    """
    # Sign pattern as integers; any value that is neither > 0 nor < 0 stays 0.
    sgn = (dz > 0).astype(int) - (dz < 0).astype(int)
    n = sgn.shape[0]
    if n <= 1:
        # Nothing to link: identity order, everything in cluster 1.
        return None, list(range(n)), np.ones(n,int)
    # pdist's 'hamming' metric is the proportion of differing components,
    # returned in condensed (i < j, row-major) order; it replaces an explicit
    # Python double loop over all pairs.
    d = pdist(sgn, metric='hamming')
    Z = linkage(d, method='complete', optimal_ordering=False)
    leaves = dendrogram(Z, no_plot=True)['leaves']
    labels = fcluster(Z, t=KMAX, criterion='maxclust')
    return Z, leaves, labels

def _remap_labels_by_first_appearance(labels_in_leaf_order):
    """Renumber cluster ids 1, 2, 3, ... in order of first appearance.

    Parameters
    ----------
    labels_in_leaf_order : array-like of int
        Cluster ids arranged in dendrogram leaf order.

    Returns
    -------
    (new, mapping)
        ``new`` is the renumbered array (same length and order as the input);
        ``mapping`` is the dict ``{old id: new id}``.
    """
    mapping = {}
    new = np.empty_like(labels_in_leaf_order, dtype=int)
    nxt = 1
    for i, old in enumerate(labels_in_leaf_order):
        old = int(old)
        if old not in mapping:
            mapping[old] = nxt
            nxt += 1
        new[i] = mapping[old]
    return new, mapping

# Clustering strategies to run; the key is used as the output file prefix.
methods = {
    'DIFF_WARD': run_diffward,
    'CORR_AVG': run_corravg,
    'PATTERN_COMPLETE': run_patterncomplete,
}

meta_rows=[]
for name, fn in methods.items():
    Z, leaves, labels = fn()
    # --- Make cluster IDs follow the plotted row order ---
    # The numeric ids returned by fcluster() need not follow the dendrogram
    # leaf order, so they can look "scrambled" next to the heatmap. Remap them
    # by order of first appearance along the leaf order (heatmap rows, top to
    # bottom) so PNG labels and Excel labels are identical.
    metrics_ord = np.array([metrics_dyn[i] for i in leaves])
    labels_leaf = labels[leaves]
    labels_ord, _label_map = _remap_labels_by_first_appearance(labels_leaf)
    # also remap the full label vector (for summary counts etc.)
    labels = np.array([_label_map[int(x)] for x in labels], dtype=int)
    # Reorder every per-metric array into leaf order so rows stay aligned.
    z_ord = z_med[leaves,:]
    dz_ord = dz[leaves,:]
    med_ord = med_dyn[leaves,:]
    percell_ord = percell[leaves,:]

    plot_heat(outdir/f"{name}_MEDIAN_HEAT_WITH_DENDRO_NUM.png", Z, labels_ord, z_ord,
              tp_order_use, VLIM, CMAP, percell_mode=False, dpi=DPI)
    plot_heat(outdir/f"{name}_PERCELL_HEAT_WITH_DENDRO_NUM.png", Z, labels_ord, percell_ord,
              tp_order_use, VLIM, CMAP, percell_mode=True, tp_idx_vec=tp_idx_vec, dpi=DPI)

    overview = {
        'input_file': inp.name,
        'tp_col_used': tp_col,
        'image_col': img_col,
        'rows_used': len(df),
        f'metrics_cov>={MIN_NON_NA}': len(metrics_cov),
        'dynamic_metrics': len(metrics_dyn),
        'flat_metrics': len(metrics_flat),
        'kmax': KMAX,
        'n_clusters': int(pd.Series(labels).nunique()),
        'timepoints_present': ', '.join(tp_order_use),
        'dropped_timepoints': ', '.join(DROP_TIMEPOINTS) if DROP_TIMEPOINTS else '',
    }

    write_excel(outdir/f"{name}_clusters_RAW_andZ.xlsx", overview,
                metrics_flat, med_flat, tp_order_use,
                metrics_ord, labels_ord, med_ord, z_ord, dz_ord)

    meta_rows.append({'method': name, 'n_clusters': overview['n_clusters']})

# Summary CSVs: cells per timepoint, and clusters found per method.
tp_counts = df['dataset_norm'].value_counts().reindex(tp_order_use).fillna(0).astype(int)
pd.DataFrame({'dataset': tp_order_use, 'n_cells': tp_counts.values}).to_csv(outdir/'tp_counts.csv', index=False)
pd.DataFrame(meta_rows).to_csv(outdir/'cluster_counts.csv', index=False)

# Zip: archive is written next to the output folder as "<outdir>.zip".
zip_path = shutil.make_archive(str(outdir), 'zip', str(outdir))
print('ZIP:', zip_path)


ZIP: J:\Cohen Lab\Maria Clara\2_Lab data\9_Napari\OUTPUT new code\hierarchical clustering\OUT_timecourse_clustering\OUT_timecourse_k12_DiffLayout.zip
