## 2.1 Partition of rs-fMRI time series

We follow the paper’s partition strategy: non-overlapping windows of length
L=20 timepoints; trailing timepoints that do not fill a complete window are
discarded. To ensure site-consistent window counts, each site is truncated to
the time-series length reported in the paper (KKI 119, NYU 171, OHSU 74,
NeuroIMAGE 257, Peking_1 231). Subjects shorter than the target length
are dropped.

In [19]:
import numpy as np
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

# Alias used to annotate array inputs. typing.Union with a single member
# collapses to that member, so this is equivalent to np.ndarray.
ArrayLike = Union[np.ndarray]

@dataclass
class WindowingResult:
    """Windows returned by `partition_time_series` together with its bookkeeping values."""
    windows: np.ndarray  # stacked windows, shape: (T, N, L)
    T: int               # number of windows in `windows`
    dropped: int         # timepoints left over after the last window (M when no window fits)
    L: int               # window length
    step: int            # hop between consecutive window starts
    N: int               # number of ROIs
    M: int               # original number of timepoints


def partition_time_series(
    X: ArrayLike,
    L: int = 20,
    *,
    overlap: Optional[int] = None,
    step: Optional[int] = None,
    time_axis: int = -1,
    drop_incomplete: bool = True,
    return_metadata: bool = True,
) -> Union[np.ndarray, "WindowingResult"]:
    """
    Sliding-window partitioning of ROI time series.

    X shape:
      - (N, M) if time_axis=-1 (default) or 1
      - (M, N) if time_axis=0 or -2

    Params:
      - L: window length (timepoints)
      - overlap: number of timepoints overlapped between consecutive windows (0..L-1)
      - step: hop size between window starts (1..L). If provided, it overrides overlap.
        When neither step nor overlap is given, step = L (non-overlapping).
      - drop_incomplete: if True, only keep full windows of length L (paper behavior).
        If False, windows that run past the end of the series are kept and
        zero-padded on the right up to length L.
      - return_metadata: if True return a WindowingResult, else only the windows array.

    Returns:
      - windows of shape (T, N, L), or a WindowingResult wrapping them.
        `dropped` is the number of timepoints after the end of the last window
        (M when no window was produced, 0 when the last window reaches the end).

    Raises:
      - ValueError if X is not 2D, L is not positive, time_axis is unsupported,
        or overlap/step are out of range.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"X must be 2D, got shape {X.shape}")
    if L <= 0:
        raise ValueError("L must be positive")

    # For a 2D array, axis 1 is the same axis as -1 and axis -2 the same as 0.
    if time_axis in (0, -2):
        X_nt = X.T
    elif time_axis in (-1, 1):
        X_nt = X
    else:
        raise ValueError("time_axis must be 0 (time first) or -1 (time last)")

    N, M = X_nt.shape

    if step is None:
        if overlap is None:
            step = L  # no overlap
        else:
            if not (0 <= overlap < L):
                raise ValueError(f"overlap must be in [0, {L-1}], got {overlap}")
            step = L - overlap
    elif not (1 <= step <= L):
        raise ValueError(f"step must be in [1, {L}], got {step}")

    # Candidate window starts; optionally keep only those with a full window.
    starts = np.arange(0, M, step, dtype=int)
    if drop_incomplete:
        starts = starts[starts + L <= M]
    T = len(starts)

    # Pre-allocate with zeros so that partial tail windows (only possible when
    # drop_incomplete=False) end up right-padded instead of breaking the stack.
    windows = np.zeros((T, N, L), dtype=X.dtype)
    for t, s in enumerate(starts):
        chunk = X_nt[:, s:s + L]
        windows[t, :, :chunk.shape[1]] = chunk

    if T == 0:
        dropped = M
    else:
        # Tail left after the end of the last window (never negative).
        dropped = max(0, M - (int(starts[-1]) + L))

    if return_metadata:
        return WindowingResult(windows, T, dropped, L, step, N, M)
    return windows

def partition_dataset(
    subjects: List[np.ndarray],
    L: int = 20,
    time_axis: int = -1,
) -> Tuple[List[np.ndarray], List[int]]:
    """
    Window every subject in `subjects` with `partition_time_series`.

    Each subject is partitioned with window length L along `time_axis`,
    using that function's defaults for everything else.

    Returns:
      - windows_per_subject: List of arrays, each (T_i, N, L)
      - T_per_subject: List[int], number of windows per subject
    """
    results = [
        partition_time_series(ts, L=L, time_axis=time_axis, return_metadata=True)
        for ts in subjects
    ]
    windows_per_subject = [r.windows for r in results]
    T_per_subject = [r.T for r in results]
    return windows_per_subject, T_per_subject



### Loading Athena ROI time series (.1D)
Athena provides AAL ROI mean time series in AFNI-style .1D files with header
columns (e.g., `File`, `Sub-brick`, and `Mean_*` ROI columns). We parse these
files and keep only `Mean_*` columns, yielding a timepoints × ROIs matrix.

In [None]:
import pandas as pd
from pathlib import Path
from collections import defaultdict

# ---- Paths ----
# PROJECT_ROOT is the parent of the current working directory, made absolute;
# relative time-series paths are later resolved against it.
PROJECT_ROOT = Path("..").resolve()
DATA_PROCESSED = PROJECT_ROOT / "data" / "processed"

# Train / test subject manifests (presumably one row per subject -- verify).
train_csv = DATA_PROCESSED / "subjects_train_paper.csv"
test_csv  = DATA_PROCESSED / "subjects_test_paper.csv"

df_train = pd.read_csv(train_csv)
df_test  = pd.read_csv(test_csv)

# Quick look at manifest sizes and column names.
print("Train rows:", len(df_train), "| cols:", list(df_train.columns))
print("Test rows :", len(df_test),  "| cols:", list(df_test.columns))

# ---- Find the time-series path column automatically ----
# Known column names for the time-series path, in priority order.
CANDIDATE_PATH_COLS = [
    "tc_path",
    "timeseries_path",
    "time_series_path",
    "path",
    "roi_ts_path",
    "roi_timeseries_path",
    "file",
    "filepath",
    "roi_path",
    "aal_path",
]

def infer_path_col(df: pd.DataFrame) -> str:
    """
    Return the name of the column in `df` that holds time-series file paths.

    The names in CANDIDATE_PATH_COLS are tried in order (case-insensitive);
    if none is present, the first column whose name contains "path" is used.
    Raises ValueError when no column qualifies.
    """
    by_lower = {name.lower(): name for name in df.columns}

    known = [by_lower[c.lower()] for c in CANDIDATE_PATH_COLS if c.lower() in by_lower]
    if known:
        return known[0]

    # fallback: any column mentioning 'path', first one wins
    path_like = [name for name in df.columns if "path" in name.lower()]
    if path_like:
        return path_like[0]

    raise ValueError(
        "Could not infer the time-series path column. "
        "Please set PATH_COL manually."
    )

# The path column is inferred from the training manifest only; the same
# name is then used for every manifest passed to partition_manifest.
PATH_COL = infer_path_col(df_train)
print("Inferred PATH_COL =", PATH_COL)

# ---- Loader for one subject time series ----
def load_subject_timeseries(tc_path: str, expected_N: int = 116) -> np.ndarray:
    """
    Load one subject's ROI time series from disk.

    Supports: .npy, .npz, .csv, .tsv, .txt, .1D (Athena/AFNI style with headers).
    Relative paths are resolved against PROJECT_ROOT.

    - .npy / .npz: the stored array is returned unchanged (for .npz, the first
      array in the archive); no shape checks are applied.
    - text formats: columns named `Mean_*` are kept when present; otherwise
      known metadata columns are dropped, the rest is coerced to numeric, and
      columns that contain values but none numeric are discarded. Rows that
      are entirely NaN are removed and the result is returned as float.

    Params:
      - tc_path: path to the time-series file.
      - expected_N: expected number of ROIs; only used to print a warning when
        neither dimension of a text-loaded array matches it.

    Returns a 2D numpy array for text formats (not yet forced to N x M).

    Raises:
      - FileNotFoundError if the file does not exist.
      - ValueError for an unsupported extension.
    """
    p = Path(tc_path)
    if not p.is_absolute():
        p = (PROJECT_ROOT / p).resolve()

    if not p.exists():
        raise FileNotFoundError(f"Time-series file not found: {p}")

    suf = p.suffix.lower()

    if suf == ".npy":
        return np.asarray(np.load(p))

    if suf == ".npz":
        # np.load returns an archive object that holds the file open; the
        # context manager closes it once the first array has been read.
        with np.load(p) as archive:
            return np.asarray(archive[archive.files[0]])

    # Text-like formats
    # NOTE(review): the first line is always taken as the header, so a
    # headerless file would lose its first timepoint -- confirm inputs have one.
    if suf == ".csv":
        df = pd.read_csv(p, sep=",")
    elif suf == ".tsv":
        df = pd.read_csv(p, sep="\t")
    elif suf in (".txt", ".1d"):
        # Athena .1D uses tabs or variable whitespace; \s+ covers both.
        # Also tolerate comment lines.
        df = pd.read_csv(p, sep=r"\s+", engine="python", comment="#")
    else:
        raise ValueError(f"Unsupported file extension: {suf} for {p}")

    # A single column from a .csv/.tsv usually means the separator did not
    # match; retry as whitespace-delimited. (.txt/.1d were already parsed so.)
    if df.shape[1] == 1 and suf in (".csv", ".tsv"):
        df = pd.read_csv(p, sep=r"\s+", engine="python", comment="#")

    # Prefer ROI columns named like Mean_XXXX
    mean_cols = [c for c in df.columns if str(c).startswith("Mean_")]

    if len(mean_cols) > 0:
        roi_df = df[mean_cols]
    else:
        # Otherwise drop obvious metadata columns and keep numeric columns
        drop_cols = [c for c in df.columns if str(c).lower() in ["file", "sub-brick", "subbrick", "brick", "index"]]
        tmp = df.drop(columns=drop_cols, errors="ignore")
        numeric = tmp.apply(pd.to_numeric, errors="coerce")
        # A column that had values but none of them numeric is metadata, not
        # an ROI; keeping it would add an all-NaN column to the matrix.
        # Columns that were empty to begin with are left in place.
        non_numeric = numeric.isna().all(axis=0) & tmp.notna().any(axis=0)
        roi_df = numeric.loc[:, ~non_numeric]

    # Drop rows that are all NaN (can happen if header/formatting issues)
    roi_df = roi_df.dropna(axis=0, how="all")

    X = roi_df.to_numpy(dtype=float)

    if X.ndim != 2:
        raise ValueError(f"Loaded time series must be 2D, got shape {X.shape} from {p}")

    # Sanity check: one of the two dimensions should be the ROI count.
    # Don't hard-fail; just warn so you can inspect one file quickly.
    if expected_N not in X.shape:
        print(f"[WARN] {p.name}: unexpected shape {X.shape} (expected one dim == {expected_N}).")

    return X

def ensure_N_by_M(X: np.ndarray, expected_N: int = 116) -> Tuple[np.ndarray, int]:
    """
    Ensures X is (N, M). If X is (M, N), transpose it.

    Params:
      - X: 2D array with one dimension equal to expected_N.
      - expected_N: number of ROIs.

    Returns (X_fixed, time_axis_used): time_axis_used is -1 when X was already
    (N, M) and is returned as-is, 0 when X was (M, N) and has been transposed.
    If both dimensions equal expected_N, X is treated as already (N, M).

    Raises ValueError if X is not 2D or neither dimension equals expected_N.
    """
    # Without this check a 1-D array of length N would be accepted as (N, M)
    # and other 1-D inputs would fail with an unrelated IndexError.
    if X.ndim != 2:
        raise ValueError(f"Expected a 2D array, got shape {X.shape}")
    if X.shape[0] == expected_N:
        return X, -1  # already (N, M)
    if X.shape[1] == expected_N:
        return X.T, 0  # was (M, N), transposed
    raise ValueError(f"Expected one dimension to be N={expected_N} ROIs, got shape {X.shape}")

# Site-specific number of timepoints that series are truncated to.
PAPER_M = {
    "KKI": 119,
    "NYU": 171,
    "OHSU": 74,
    "NeuroIMAGE": 257,
    "Peking_1": 231,
}

def truncate_to_paper_length(
    X_nm: np.ndarray,
    site: str,
    *,
    drop_if_shorter: bool = True
) -> Optional[np.ndarray]:
    """
    Cut a (N, M) series down to the site's length from PAPER_M.

    - Site not listed in PAPER_M: X_nm is returned untouched.
    - M at least the site length: the first `target` timepoints are returned.
    - M shorter than the site length: None when drop_if_shorter is True
      (the default, signalling that the subject should be dropped),
      otherwise X_nm unchanged.
    """
    paper_len = PAPER_M.get(site)
    if paper_len is None:
        return X_nm

    if X_nm.shape[1] < paper_len:
        # Too short for this site: drop, or pass through when asked to keep.
        return X_nm if not drop_if_shorter else None

    return X_nm[:, :paper_len]

# ---- Apply partitioning to a manifest ----
def _print_partition_summary(per_site, dropped_short, L, step):
    """Print per-site min/median/max of M, T and dropped tail, plus drop counts."""
    for site, vals in per_site.items():
        Ms = np.array([v[0] for v in vals])
        Ts = np.array([v[1] for v in vals])
        Ds = np.array([v[2] for v in vals])

        print(f"\nSite: {site}")
        print(f"  subjects kept: {len(vals)}")
        if dropped_short.get(site, 0) > 0:
            print(f"  dropped (shorter than paper M): {dropped_short[site]}")
        print(f"  M (timepoints): min={Ms.min()}, median={int(np.median(Ms))}, max={Ms.max()}")
        print(f"  T (windows, L={L}, step={step}): min={Ts.min()}, median={int(np.median(Ts))}, max={Ts.max()}")
        print(f"  dropped tail: min={Ds.min()}, median={int(np.median(Ds))}, max={Ds.max()}")

    # Also show if any site had only dropped subjects (edge case)
    for site, cnt in dropped_short.items():
        if site not in per_site:
            print(f"\nSite: {site}")
            print(f"  subjects kept: 0")
            print(f"  dropped (shorter than paper M): {cnt}")


# ---- Apply partitioning to a manifest ----
def partition_manifest(
    df,
    L=20,
    step=20,
    overlap=None,
    expected_N=116,
    max_subjects=None,
    drop_if_shorter=True,
):
    """
    Load, truncate and window every subject listed in a manifest, then print
    a per-site summary.

    Params:
      - df: manifest DataFrame; must contain the PATH_COL column (module-level
        name) and a "site" column.
      - L, step, overlap: forwarded to partition_time_series (step overrides
        overlap when both are given).
      - expected_N: number of ROIs, used both for the loader's shape warning
        and for orienting the array to (N, M).
      - max_subjects: if set, only the first `max_subjects` rows are processed.
      - drop_if_shorter: forwarded to truncate_to_paper_length; subjects for
        which it returns None are counted as dropped and skipped.

    Returns a defaultdict mapping site -> list of (M, T, dropped) tuples, one
    per kept subject. The windows themselves are not returned.
    """
    per_site = defaultdict(list)
    dropped_short = defaultdict(int)

    n = len(df) if max_subjects is None else min(len(df), max_subjects)

    # max(n, 0) keeps a negative max_subjects meaning "process nothing".
    for _, row in df.iloc[:max(n, 0)].iterrows():
        # Forward expected_N so the loader's shape warning checks the same
        # ROI count that ensure_N_by_M enforces below.
        X = load_subject_timeseries(row[PATH_COL], expected_N=expected_N)
        X_fixed, _ = ensure_N_by_M(X, expected_N=expected_N)

        site = row["site"]

        # Truncate to the paper's site-specific length (None => too short).
        X_fixed = truncate_to_paper_length(X_fixed, site, drop_if_shorter=drop_if_shorter)
        if X_fixed is None:
            dropped_short[site] += 1
            continue

        res = partition_time_series(
            X_fixed,
            L=L,
            step=step,
            overlap=overlap,
            time_axis=-1,
            drop_incomplete=True,
            return_metadata=True,
        )

        per_site[site].append((res.M, res.T, res.dropped))

    _print_partition_summary(per_site, dropped_short, L, step)

    return per_site

# ---- Run: non-overlapping (paper default) ----
# Only the printed per-site summary is of interest here; the returned
# statistics are discarded.
print("\n=== TRAIN | non-overlapping (L=20, step=20) ===")
_ = partition_manifest(df_train, L=20, step=20)

print("\n=== TEST | non-overlapping (L=20, step=20) ===")
_ = partition_manifest(df_test, L=20, step=20)


Train rows: 511 | cols: ['site', 'subject_id', 'tc_path', 'T', 'R', 'DX', 'dx_raw', 'label']
Test rows : 162 | cols: ['site', 'subject_id', 'tc_path', 'T', 'R', 'DX']
Inferred PATH_COL = tc_path

=== TRAIN | non-overlapping (L=20, step=20) ===

Site: NYU
  subjects kept: 216
  M (timepoints): min=171, median=171, max=171
  T (windows, L=20, step=20): min=8, median=8, max=8
  dropped tail: min=11, median=11, max=11

Site: NeuroIMAGE
  subjects kept: 48
  M (timepoints): min=257, median=257, max=257
  T (windows, L=20, step=20): min=12, median=12, max=12
  dropped tail: min=17, median=17, max=17

Site: KKI
  subjects kept: 83
  M (timepoints): min=119, median=119, max=119
  T (windows, L=20, step=20): min=5, median=5, max=5
  dropped tail: min=19, median=19, max=19

Site: Peking_1
  subjects kept: 85
  M (timepoints): min=231, median=231, max=231
  T (windows, L=20, step=20): min=11, median=11, max=11
  dropped tail: min=11, median=11, max=11

Site: OHSU
  subjects kept: 78
  dropped (sh

Note: After enforcing paper time-series lengths, OHSU test subjects were shorter than 74 timepoints in the Athena release and were excluded.

In [14]:
# Number of windows per subject for each site with L=20 and step=20, i.e.
# floor(site length / 20) for the lengths 119, 171, 74, 257 and 231.
T_PER_SITE = {
    "KKI": 5,
    "NYU": 8,
    "OHSU": 3,
    "NeuroIMAGE": 12,
    "Peking_1": 11
}


## 2.2 Dynamic functional connectivity generation (CNN + bilinear pooling)

For each window, we extract ROI-wise temporal features using a 3-layer CNN
(Conv → BN → ReLU → Dropout). We then construct the dynamic functional
connectivity matrix via bilinear pooling: A_t = H_t H_t^T, producing a
symmetric N×N FC matrix. Finally, we vectorize the upper triangular part
(excluding the diagonal) to obtain a 6670-D feature vector per window.

### Expected tensor shapes (paper-faithful)

Input window:        (N=116, L=20)

After CNN (Keras): (N=116, d, 1) → squeeze → (N=116, d)
With valid padding, each of the two (1×3) conv layers shortens the time axis by 2 and the final (1×1) conv leaves it unchanged: d = 20 − 2 − 2 = 16

After bilinear FC:   (116, 116)

Vectorized FC:       (116 * 115 / 2 = 6670)

Sequence per subject: (T, 6670)


In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model

def tdnet_cnn_block(N=116, L=20, dropout=0.3):
    """
    Build the 3-stage CNN applied to one window of shape (N, L, 1).

    Each stage is Conv2D -> BatchNormalization -> ReLU -> Dropout. The
    convolutions use kernels of height 1, so they never mix rows; with
    "valid" padding the two (1, 3) kernels each shorten the second axis by 2,
    and the final (1, 1) kernel keeps it. Output shape is (N, L - 4, 1).
    """
    # (filters, kernel_size) for the three stages, in order.
    stage_specs = ((4, (1, 3)), (2, (1, 3)), (1, (1, 1)))

    inp = layers.Input(shape=(N, L, 1))
    x = inp
    for n_filters, kernel in stage_specs:
        x = layers.Conv2D(n_filters, kernel, padding="valid", use_bias=True)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        x = layers.Dropout(dropout)(x)

    return Model(inp, x, name="TDNet_CNN_Block")

def upper_triangle_vectorize(A: tf.Tensor) -> tf.Tensor:
    """
    Flatten the strict upper triangle of each matrix in a batch.

    A: (B, N, N) FC matrices
    Returns: (B, N*(N-1)/2), entries above the diagonal in row-major order
    """
    n = tf.shape(A)[-1]
    idx = tf.range(n)
    # True exactly where row index < column index, i.e. above the diagonal.
    strict_upper = idx[:, tf.newaxis] < idx[tf.newaxis, :]   # (N, N)

    batch = tf.shape(A)[0]
    flat_vals = tf.reshape(A, (batch, -1))                   # (B, N*N)
    flat_mask = tf.reshape(strict_upper, (-1,))              # (N*N,)
    return tf.boolean_mask(flat_vals, flat_mask, axis=1)     # (B, N*(N-1)/2)

def fc_generation(
    cnn_model: tf.keras.Model,
    window_batch: tf.Tensor,
) -> tf.Tensor:
    """
    Turn a batch of windows into vectorized bilinear FC features.

    window_batch: (B, N, L, 1)
    Returns: (B, N*(N-1)/2), i.e. (B, 6670) for N=116

    The CNN is always called with training=False.
    """
    # (B, N, d, 1) -> (B, N, d); requires the CNN's last axis to have size 1.
    roi_feats = tf.squeeze(cnn_model(window_batch, training=False), axis=-1)

    # Bilinear pooling: per-sample Gram matrix of the ROI features, (B, N, N).
    gram = tf.matmul(roi_feats, roi_feats, transpose_b=True)

    # Keep only the entries above the diagonal.
    return upper_triangle_vectorize(gram)


In [16]:
import numpy as np

def subject_dfc_sequence(
    cnn_model: tf.keras.Model,
    windows_TNL: np.ndarray,
) -> np.ndarray:
    """
    Compute the dynamic-FC feature sequence for one subject.

    windows_TNL: (T, N, L)
    Returns: (T, N*(N-1)/2) as a numpy array, i.e. (T, 6670) for N=116
    """
    # Append a channel axis and cast to float32: (T, N, L, 1).
    batch = np.expand_dims(windows_TNL, axis=-1).astype(np.float32)

    # The T windows are passed through as one batch (B=T).
    features = fc_generation(cnn_model, tf.convert_to_tensor(batch))
    return features.numpy()


### Sanity check
We verify output shapes and FC symmetry on a single subject.

In [17]:
# Build the CNN block with its default input size (N=116, L=20).
cnn = tdnet_cnn_block(dropout=0.3)

# Take the first training subject: load, orient to (N, M), truncate to the site length.
# NOTE(review): truncate_to_paper_length returns None for a too-short subject,
# in which case the partition call below raises "X must be 2D".
one_subject = df_train.iloc[0]
X = load_subject_timeseries(one_subject["tc_path"])
X, _ = ensure_N_by_M(X, expected_N=116)
X = truncate_to_paper_length(X, one_subject["site"], drop_if_shorter=True)

# Non-overlapping windows of length 20.
res = partition_time_series(X, L=20, step=20, time_axis=-1, return_metadata=True)
windows = res.windows  # (T, 116, 20)

# One feature vector per window.
seq = subject_dfc_sequence(cnn, windows)
print("windows:", windows.shape)     # (T, 116, 20)
print("dfc seq:", seq.shape)         # (T, 6670)


windows: (8, 116, 20)
dfc seq: (8, 6670)


In [20]:
# Run the CNN on the first window only, with batch and channel axes added.
w = windows[0][np.newaxis, ..., np.newaxis].astype(np.float32)  # (1,116,20,1)
print("CNN output shape:", cnn(w, training=False).shape)         # expect (1,116,16,1)


CNN output shape: (1, 116, 16, 1)


In [21]:
# Rebuild the bilinear FC matrix for that window and measure its largest
# deviation from its own transpose.
feat = tf.squeeze(cnn(w, training=False), axis=-1)              # (1,116,16)
A = tf.matmul(feat, feat, transpose_b=True)                     # (1,116,116)
sym_err = tf.reduce_max(tf.abs(A - tf.transpose(A, perm=[0,2,1]))).numpy()
print("max symmetry error:", sym_err)  # should be ~0

max symmetry error: 0.0
