# Tian QC Diagnosis

Goal: diagnose ROI-level validity issues in Tian S3 subcortical timeseries before building FC.

This notebook:
- Computes per-subject, per-ROI time-series std.
- Flags invalid ROIs (std <= EPS, i.e. an effectively constant time series).
- Summarizes ROI validity rates and subject invalidity rates.
- Visualizes example FC matrices for “good” and “bad” subjects.

Notes:
- Uses only read-only operations.
- Avoid heavy computation on login nodes; if needed, run the notebook in an interactive session instead.


In [None]:
import os
from pathlib import Path
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Project root on shared storage; every input path below is derived from it.
ROOT = Path("/storage/bigdata/UKB/fMRI/gene-brain-CCA")
# Directory searched for time-series files; find_tian_file looks in one
# sub-directory per subject ID underneath it.
TS_ROOT = ROOT / "derived_schaefer_mdd" / "tian_timeseries"
# NumPy file holding the subject IDs to diagnose.
IDS_PATH = ROOT / "derived_schaefer_mdd" / "tian_subset" / "ids_tian_subset.npy"

# An ROI counts as valid only when its time-series std is strictly greater
# than this value.
EPS = 1e-6  # std threshold for valid ROI

# Tian S3 labels (order used in extract_tian_weights.py)
# Position i in this list names column i of every per-subject time series.
# NOTE(review): the names and the interleaved left/right ordering cannot be
# checked from this notebook -- confirm them against extract_tian_weights.py
# and the atlas label file before relying on them in reports.
TIAN_S3_LABELS = [
    "HIP-head-l", "HIP-head-r", "HIP-body-l", "HIP-body-r",
    "HIP-tail-l", "HIP-tail-r", "HIP-subiculum-l", "HIP-subiculum-r",
    "AMY-lateral-l", "AMY-lateral-r", "AMY-medial-l", "AMY-medial-r",
    "THA-VA-l", "THA-VA-r", "THA-VL-l", "THA-VL-r",
    "THA-VP-l", "THA-VP-r", "THA-IL-l", "THA-IL-r",
    "THA-MD-l", "THA-MD-r", "THA-PU-l", "THA-PU-r",
    "NAc-core-l", "NAc-core-r", "NAc-shell-l", "NAc-shell-r",
    "CAU-head-l", "CAU-head-r", "CAU-body-l", "CAU-body-r",
    "CAU-tail-l", "CAU-tail-r",
    "PUT-anterior-l", "PUT-anterior-r", "PUT-posterior-l", "PUT-posterior-r",
    "PUT-ventral-l", "PUT-ventral-r",
    "GP-internal-l", "GP-internal-r", "GP-external-l", "GP-external-r",
    "HTH-l", "HTH-r",
    "VTA-l", "VTA-r", "SN-l", "SN-r",
]

# Later cells assume exactly 50 ROI columns, so fail early if the list drifts.
assert len(TIAN_S3_LABELS) == 50

In [None]:
from typing import Optional

def find_tian_file(eid: str, root: Path) -> Optional[str]:
    """Locate the Tian S3 time-series file for one subject.

    The search is confined to ``root / eid``. A file whose name embeds the
    subject ID (``tian_s3_{eid}_*.npy``) is preferred; only when no such file
    exists does the search fall back to any ``tian_s3_*.npy`` file.

    Args:
        eid: Subject identifier; also the name of the subject's directory.
        root: Directory that contains one sub-directory per subject.

    Returns:
        The alphabetically first match of the first pattern that matches
        anything, as a string path, or ``None`` when nothing matches.
    """
    subject_dir = root / eid
    # Most specific pattern first: the generic pattern also matches every
    # subject-specific file name, so trying it first would make the
    # subject-specific pattern unreachable.
    patterns = (
        subject_dir / f"tian_s3_{eid}_*.npy",
        subject_dir / "tian_s3_*.npy",
    )
    for pattern in patterns:
        matches = sorted(glob(str(pattern)))
        if matches:
            return matches[0]
    return None

def _load_roi_std(path: str, n_roi: int) -> Optional[np.ndarray]:
    """Return the per-ROI std of the time series stored at *path*.

    Returns ``None`` when the array is not 2-D or has no axis usable as the
    ROI axis (i.e. its ROI dimension is not ``n_roi``).
    """
    ts = np.load(path)
    if ts.ndim != 2:
        return None
    # Orient as (time, ROI) from the known ROI count rather than assuming
    # there are always more time points than ROIs.
    if ts.shape[1] != n_roi and ts.shape[0] == n_roi:
        ts = ts.T
    if ts.shape[1] != n_roi:
        return None
    return ts.std(axis=0).astype(np.float32, copy=False)


# NOTE(review): allow_pickle=True will unpickle arbitrary objects. That is
# acceptable only because IDS_PATH is a trusted project file; switch to
# allow_pickle=False if the IDs are stored as a plain string/integer array.
ids = np.load(IDS_PATH, allow_pickle=True).astype(str)
print(f"Subjects in subset: {len(ids)}")

# Expected number of ROI columns, tied to the label list instead of a literal.
n_roi = len(TIAN_S3_LABELS)

stds = []       # one float32 vector of per-ROI std per usable subject
valid_ids = []  # subject IDs, row-aligned with `stds`
missing = []    # subjects with no file or with an unusable array shape

for i, eid in enumerate(ids):
    fp = find_tian_file(eid, TS_ROOT)
    sd = None if fp is None else _load_roi_std(fp, n_roi)
    if sd is None:
        missing.append(eid)
    else:
        stds.append(sd)
        valid_ids.append(eid)
    # Report progress on every 100th subject, including skipped ones.
    if (i + 1) % 100 == 0:
        print(f"Processed {i+1}/{len(ids)}")

if not stds:
    # np.vstack([]) would raise an opaque ValueError; keep the exception type
    # but say what actually went wrong.
    raise ValueError(
        f"No usable Tian time series found under {TS_ROOT} "
        f"({len(missing)} of {len(ids)} subjects missing/invalid)"
    )

stds = np.vstack(stds)
valid_ids = np.array(valid_ids, dtype=object)
print(f"Valid subjects: {len(valid_ids)}")
print(f"Missing/invalid: {len(missing)}")
print("stds shape:", stds.shape)

In [None]:
# An ROI is valid for a subject when its std is strictly above EPS.
valid_mask = stds > EPS

# Column means: fraction of subjects in which each ROI is valid.
# Row means: fraction of valid ROIs per subject, inverted to an invalid rate.
roi_valid_rate = valid_mask.mean(axis=0)
subj_invalid_rate = 1.0 - valid_mask.mean(axis=1)

# Summary
for heading, rates in (
    ("ROI valid rate (min/median/max):", roi_valid_rate),
    ("Subject invalid rate (min/median/max):", subj_invalid_rate),
):
    print(heading, rates.min(), np.median(rates), rates.max())

# Top problematic ROIs: the ten with the lowest valid rate, worst first.
problem_idx = np.argsort(roi_valid_rate)[:10]
print("Lowest valid-rate ROIs:")
for idx in problem_idx:
    print(f"{idx:2d} {TIAN_S3_LABELS[idx]:<18s} valid_rate={roi_valid_rate[idx]:.3f}")

# Quick heatmap for a random subset (fixed seed, at most 80 subjects).
rng = np.random.default_rng(42)
n_shown = min(80, len(valid_ids))
subset_idx = rng.choice(len(valid_ids), size=n_shown, replace=False)
fig, ax = plt.subplots(figsize=(12, 6))
sns.heatmap(valid_mask[subset_idx, :], cbar=False, ax=ax)
ax.set_title("ROI validity (subset of subjects)")
ax.set_xlabel("ROI")
ax.set_ylabel("Subject (subset)")
plt.show()

# ROI validity barplot with a reference line at the 90% level.
fig, ax = plt.subplots(figsize=(12, 4))
ax.bar(range(50), roi_valid_rate)
ax.axhline(0.90, color="red", linestyle="--", label="90% threshold")
ax.set_ylim(0, 1.05)
ax.set_title("ROI validity rate")
ax.set_xlabel("ROI index")
ax.set_ylabel("Valid fraction")
ax.legend()
plt.show()

In [None]:
def compute_fc(ts: np.ndarray) -> np.ndarray:
    """Compute the column-by-column Pearson correlation (FC) matrix.

    Args:
        ts: 2-D array whose columns are the variables to correlate
            (statistics are taken along axis 0).

    Returns:
        Square correlation matrix over the columns of ``ts``. Entries that
        come out as NaN (any pair involving a zero-variance column, including
        that column's own diagonal entry) are replaced by 0.0.
    """
    col_mean = ts.mean(axis=0, keepdims=True)
    col_std = ts.std(axis=0, keepdims=True)
    # Zero-variance columns get a unit divisor so they z-score to all zeros
    # instead of dividing by zero.
    safe_std = np.where(col_std == 0, 1.0, col_std)
    zscored = (ts - col_mean) / safe_std
    # corrcoef evaluates 0/0 for zero-variance columns; silence that warning
    # because the resulting NaNs are zeroed on return.
    with np.errstate(invalid="ignore"):
        fc = np.corrcoef(zscored, rowvar=False)
    return np.nan_to_num(fc, nan=0.0)

# Pick 3 good and 3 bad subjects based on invalid rate: after an ascending
# sort, the first three have the lowest invalid rate and the last three the
# highest.
sorted_idx = np.argsort(subj_invalid_rate)
example_good = sorted_idx[:3]
example_bad = sorted_idx[-3:]

# One flat list of (label, row index) pairs, good examples first.
examples = [("good", k) for k in example_good] + [("bad", k) for k in example_bad]

for label, idx in examples:
    eid = valid_ids[idx]
    fp = find_tian_file(eid, TS_ROOT)
    ts = np.load(fp)
    # Put the longer axis first, as in the loading cell.
    if ts.shape[0] < ts.shape[1]:
        ts = ts.T
    corr = compute_fc(ts)

    # One small square figure per subject, colour scale fixed to [-1, 1].
    fig, ax = plt.subplots(figsize=(4, 4))
    im = ax.imshow(corr, vmin=-1, vmax=1, cmap="coolwarm")
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    ax.set_title(f"{label} subject {eid}\ninvalid_rate={subj_invalid_rate[idx]:.2f}")
    fig.tight_layout()
    plt.show()