In [1]:
import numpy as np
import pandas as pd
import nibabel as nib

from dipy.io.streamline import load_tractogram, save_tractogram
from dipy.tracking.streamline import set_number_of_points
from dipy.segment.clustering import QuickBundles

from pathlib import Path
import re
import os
from dipy.io.stateful_tractogram import StatefulTractogram, Space
from dipy.tracking.streamline import transform_streamlines

from dipy.align.imaffine import (
    AffineRegistration,
    MutualInformationMetric,
    transform_centers_of_mass,
)
from dipy.align.transforms import TranslationTransform3D, RigidTransform3D, AffineTransform3D

In [73]:
# --- Notebook configuration: input/output locations and tunable parameters ---
trk_dir = Path("Problematic Track List/IPL")  # folder globbed for the input *.trk files
reference_nii = Path("MNI152_T1_2mm.nii.gz")  # reference image passed when loading/saving tractograms

# Destination for the atlas outputs written by later cells; created up front.
out_dir = Path("derivatives/atlas_trk_IPL")
out_dir.mkdir(parents=True, exist_ok=True)

n_points = 100          # number of points every streamline is resampled to before clustering
balanced_n = 200        # maximum number of streamlines kept per subject when pooling
qb_threshold = 10.0     # QuickBundles distance threshold (presumably mm, if the streamlines are in mm)
qb_max_clusters = 50    # upper bound on the number of QuickBundles clusters

# Seed handed to every NumPy random generator so the sub-sampling is repeatable.
rng_seed = 0


In [74]:
def _to_rasmm_inplace_compatible(sft):
    """Convert ``sft`` to RASMM space and return the converted tractogram.

    ``to_space`` may either mutate ``sft`` and return ``None`` or hand back a
    new object; whichever object ends up holding the RASMM data is returned.
    """
    converted = sft.to_space(Space.RASMM)
    if converted is not None:
        return converted
    # In-place variant: the input object itself now holds the converted data.
    return sft


def compute_native_to_mni_affine_dipy(
    native_b0_nii: str,
    mni_ref_nii: str,
    nbins: int = 32,
    level_iters=(1000, 200, 50),
    sigmas=(3.0, 1.0, 0.0),
    factors=(4, 2, 1),
) -> np.ndarray:
    """Compute the affine that maps native (b0) RASMM points to MNI RASMM points.

    The native b0 (moving) is registered to the MNI reference (static) with a
    mutual-information metric in four steps: centre-of-mass alignment, then
    translation, rigid and full affine optimisation, each stage starting from
    the result of the previous one.

    Parameters
    ----------
    native_b0_nii : str
        Path to the subject's native-space b0 NIfTI (moving image).
    mni_ref_nii : str
        Path to the MNI reference NIfTI (static image).
    nbins : int
        Number of histogram bins for the mutual-information metric.
    level_iters, sigmas, factors : sequence
        Multi-resolution schedule forwarded to ``AffineRegistration``.

    Returns
    -------
    np.ndarray
        Homogeneous matrix (4x4 for 3-D volumes) taking world coordinates in
        the native image to world coordinates in the MNI image, suitable for
        ``transform_streamlines``.

    Raises
    ------
    FileNotFoundError
        If either input image does not exist.
    """
    if not os.path.exists(native_b0_nii):
        raise FileNotFoundError(f"Native b0 NIfTI not found: {native_b0_nii}")
    if not os.path.exists(mni_ref_nii):
        raise FileNotFoundError(f"MNI reference NIfTI not found: {mni_ref_nii}")

    native_img = nib.load(native_b0_nii)
    mni_img = nib.load(mni_ref_nii)

    static = mni_img.get_fdata().astype(np.float32)     # target (MNI)
    moving = native_img.get_fdata().astype(np.float32)  # source (native)

    static_aff = mni_img.affine
    moving_aff = native_img.affine

    metric = MutualInformationMetric(nbins=nbins, sampling_proportion=None)
    affreg = AffineRegistration(metric=metric, level_iters=level_iters, sigmas=sigmas, factors=factors)

    # Coarse initialisation: align the centres of mass of the two volumes.
    com = transform_centers_of_mass(static, static_aff, moving, moving_aff)

    # Progressively richer transform models, each seeded by the previous result.
    starting_affine = com.affine
    for transform in (TranslationTransform3D(), RigidTransform3D(), AffineTransform3D()):
        stage = affreg.optimize(
            static, moving,
            transform=transform,
            params0=None,
            static_grid2world=static_aff,
            moving_grid2world=moving_aff,
            starting_affine=starting_affine,
        )
        starting_affine = stage.affine

    # The optimised matrix (final, full-affine stage) maps static (MNI) world
    # coordinates to moving (native) world coordinates -- that is the direction
    # needed to resample the moving image onto the static grid. Moving *points*
    # from native space into MNI space requires the inverse.
    return np.linalg.inv(starting_affine)


def default_subject_id_from_trk(trk_path: Path) -> str:
    """
    Return the run of digits that sits directly before the ``.trk`` extension
    of the file name, e.g. IFG_Orb_001.trk -> '001'.

    Raises ValueError when the file name does not end in ``<digits>.trk``.
    Swap in a different parser if your naming scheme differs.
    """
    filename = trk_path.name
    found = re.search(r"(\d+)(?=\.trk$)", filename)
    if found is None:
        raise ValueError(f"Cannot parse subject id from: {filename}")
    return found[1]


def find_subject_b0(subject_id: str, refs_root: Path) -> Path:
    """Locate the single b0 reference image for ``subject_id`` under ``refs_root``.

    Searches recursively for ``data<ID>_HARDI_DWIb0.nii`` / ``.nii.gz``, trying
    the id exactly as given and, when it is numeric, also without zero padding
    (e.g. '007' -> '7').

    Parameters
    ----------
    subject_id : str
        Subject identifier as parsed from the TRK file name.
    refs_root : Path
        Directory tree that contains the per-subject reference images.

    Returns
    -------
    Path
        The unique matching b0 file.

    Raises
    ------
    FileNotFoundError
        If no candidate is found.
    RuntimeError
        If more than one candidate is found.
    """
    refs_root = Path(refs_root)

    id_variants = [str(subject_id)]
    try:
        # In case the id is not zero-padded on disk.
        id_variants.append(str(int(subject_id)))
    except ValueError:
        # Non-numeric id (e.g. from a custom parser): there is no unpadded form.
        pass

    # dict.fromkeys drops repeated patterns (id without padding) but keeps order.
    patterns = list(dict.fromkeys(
        f"**/data{sid}_HARDI_DWIb0{ext}"
        for sid in id_variants
        for ext in (".nii", ".nii.gz")
    ))

    # De-duplicate across patterns and return candidates in a stable order.
    hits = sorted({hit for pat in patterns for hit in refs_root.glob(pat)})
    if not hits:
        raise FileNotFoundError(
            f"No b0 found for subject_id={subject_id} under {refs_root}.\n"
            f"Tried patterns: {patterns}"
        )
    if len(hits) > 1:
        raise RuntimeError(
            f"Multiple b0 candidates for subject_id={subject_id}:\n" +
            "\n".join(str(h) for h in hits)
        )
    return hits[0]


def register_trks_to_mni_per_subject(
    trk_paths,
    refs_root: str,
    mni_reference_nii: str,
    out_dir: str,
    suffix: str = "_MNI",
    overwrite: bool = False,
    id_parser=default_subject_id_from_trk,):
    """
    Register many .trk files into MNI space, using a matching per-subject b0 reference
    located somewhere under refs_root.

    - Caches per-subject affines (and the b0 path they were computed from), so
      each subject is searched for and registered at most once per call.
    - Saves each output TRK with MNI reference as ``<stem><suffix>.trk`` in out_dir.
    - Skips inputs whose output already exists unless ``overwrite`` is True.

    Parameters
    ----------
    trk_paths : iterable of path-like
        Native-space tractograms to register.
    refs_root : str
        Root directory searched recursively for each subject's b0 image.
    mni_reference_nii : str
        MNI reference NIfTI used as registration target and output reference.
    out_dir : str
        Directory for the registered tractograms; created if missing.
    suffix : str
        Text appended to the input stem to build the output file name.
    overwrite : bool
        Re-register and overwrite outputs that already exist.
    id_parser : callable
        Maps a TRK ``Path`` to its subject id.

    Returns
    -------
    dict
        ``subject_id -> (A_native_to_mni, native_b0_path)`` for the subjects
        processed in this call. Subjects whose outputs were all skipped do not
        appear, so the dict can be empty.

    Raises
    ------
    FileNotFoundError
        If the MNI reference is missing.
    RuntimeError
        If a tractogram fails to load.
    """
    refs_root = Path(refs_root)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if not os.path.exists(mni_reference_nii):
        raise FileNotFoundError(f"MNI reference NIfTI not found: {mni_reference_nii}")

    mni_img = nib.load(mni_reference_nii)
    affine_cache = {}  # subject_id -> (A_native_to_mni, native_b0_path)

    for p in trk_paths:
        p = Path(p)
        subj_id = id_parser(p)

        out_path = out_dir / f"{p.stem}{suffix}{p.suffix}"
        if out_path.exists() and not overwrite:
            continue

        # Locate the b0 and compute the affine only the first time a subject is seen;
        # later files of the same subject reuse both from the cache.
        if subj_id not in affine_cache:
            native_b0 = find_subject_b0(subj_id, refs_root)
            A = compute_native_to_mni_affine_dipy(
                native_b0_nii=str(native_b0),
                mni_ref_nii=str(mni_reference_nii),
            )
            affine_cache[subj_id] = (A, native_b0)
        A, native_b0 = affine_cache[subj_id]

        # Load tractogram using subject-specific native reference
        sft = load_tractogram(str(p), reference=str(native_b0), bbox_valid_check=False)
        if sft is None:
            raise RuntimeError(f"Failed to load {p} with reference {native_b0}")
        sft = _to_rasmm_inplace_compatible(sft)

        # Move the native-space points into MNI world coordinates.
        streams_mni = list(transform_streamlines(sft.streamlines, A))

        out_sft = StatefulTractogram(streams_mni, mni_img, Space.RASMM)
        save_tractogram(out_sft, str(out_path), bbox_valid_check=False)

    return affine_cache  # so you can inspect per-subject transforms


In [75]:
# Native-space tractograms to register, in a stable (sorted) order.
trk_paths = sorted(trk_dir.glob("*.trk"))

affines = register_trks_to_mni_per_subject(
    trk_paths=trk_paths,
    refs_root="derivatives/references",        # root folder containing data001/, data002/, ...
    mni_reference_nii=str(reference_nii),      # same reference the later cells load the outputs with
    out_dir="derivatives/registration/IPL",
    overwrite=False,
)

# Only subjects actually registered in this run are listed: with overwrite=False,
# files whose output already exists are skipped, so this can print an empty list.
print("Subjects registered:", sorted(affines.keys()))

Subjects registered: []


In [76]:
def load_subject_streamlines(trk_path, reference_nii, n_points):
    """Load one TRK file and resample each streamline to ``n_points`` points.

    Parameters
    ----------
    trk_path : path-like
        Tractogram to load.
    reference_nii : path-like
        Reference image handed to ``load_tractogram``.
    n_points : int
        Number of points every returned streamline is resampled to.

    Returns
    -------
    list
        Resampled streamlines in RASMM space, or an empty list when the file
        could not be loaded or holds no streamlines.
    """
    tg = load_tractogram(str(trk_path), reference=str(reference_nii), bbox_valid_check=False)
    if tg is None:
        return []
    # Shared helper copes with to_space() being in-place or returning a new object.
    tg = _to_rasmm_inplace_compatible(tg)
    streams = list(tg.streamlines)
    if not streams:
        return []
    return set_number_of_points(streams, n_points)

# Registered (MNI-space) tractograms written by the registration step.
trk_dir_mni = Path("derivatives/registration/IPL")
trk_paths = sorted(trk_dir_mni.glob("*_MNI.trk"))
# Report the directory that was actually searched.
assert len(trk_paths) > 0, f"No *_MNI.trk found in {trk_dir_mni}"

streams_by_subj = []
subj_ids = []  # file stems (e.g. 'IPL_001_MNI'), used as per-subject labels

for p in trk_paths:
    s = load_subject_streamlines(p, reference_nii, n_points)
    streams_by_subj.append(s)
    subj_ids.append(p.stem)

# Quick overview: the ten subjects with the most streamlines.
counts = [len(s) for s in streams_by_subj]
pd.DataFrame({"subject": subj_ids, "n_streamlines": counts}).sort_values("n_streamlines", ascending=False).head(10)

Unnamed: 0,subject,n_streamlines
6,IPL_008_MNI,353
5,IPL_006_MNI,278
4,IPL_005_MNI,249
3,IPL_004_MNI,229
2,IPL_003_MNI,215
7,IPL_009_MNI,189
1,IPL_002_MNI,93
0,IPL_001_MNI,68
8,IPL_010_MNI,56


In [77]:
def subject_balanced_pool(streams_by_subj, subj_ids, n_per_subj, seed=0):
    """Pool streamlines across subjects, capping each subject's contribution.

    Subjects with at most ``n_per_subj`` streamlines contribute all of them, in
    their original order; larger subjects contribute a random sample of
    ``n_per_subj`` drawn without replacement from a generator seeded with ``seed``.

    Returns
    -------
    (list, dict)
        The pooled streamlines (subjects in input order) and a mapping
        ``subject id -> number of streamlines kept``.
    """
    generator = np.random.default_rng(seed)
    pooled_streamlines = []
    kept_per_subject = {}

    for subject, subject_streams in zip(subj_ids, streams_by_subj):
        candidates = list(subject_streams)
        n_available = len(candidates)

        if n_available == 0:
            # Nothing to contribute, but still record the subject.
            kept_per_subject[subject] = 0
        elif n_available <= n_per_subj:
            pooled_streamlines += candidates
            kept_per_subject[subject] = n_available
        else:
            picks = generator.choice(n_available, size=n_per_subj, replace=False)
            pooled_streamlines += [candidates[k] for k in picks]
            kept_per_subject[subject] = n_per_subj

    return pooled_streamlines, kept_per_subject

# Build the subject-balanced pool: at most `balanced_n` streamlines per subject.
balanced_pool, kept_counts = subject_balanced_pool(
    streams_by_subj=streams_by_subj,
    subj_ids=subj_ids,
    n_per_subj=balanced_n,
    seed=rng_seed,
)
# Sanity check: total pool size and the kept counts of the first three subjects.
(len(balanced_pool), list(kept_counts.items())[:3])

(1406, [('IPL_001_MNI', 68), ('IPL_002_MNI', 93), ('IPL_003_MNI', 200)])

In [78]:
def save_trk_streamlines(streamlines, reference_nii, out_path):
    """Write ``streamlines`` to ``out_path`` as a tractogram.

    The streamlines are declared to be in RASMM space and are attached to the
    image loaded from ``reference_nii``; the bounding-box validity check is
    disabled on save.
    """
    reference_image = nib.load(str(reference_nii))
    tractogram = StatefulTractogram(streamlines, reference_image, Space.RASMM)
    save_tractogram(tractogram, str(out_path), bbox_valid_check=False)

In [79]:
# Atlas A: cluster the balanced pool with QuickBundles and keep one centroid per cluster.
qb = QuickBundles(threshold=qb_threshold, max_nb_clusters=qb_max_clusters)
clusters = qb.cluster(balanced_pool)

centroids = [cluster.centroid for cluster in clusters]
cluster_sizes = np.asarray([len(cluster) for cluster in clusters], dtype=int)
# Each cluster's weight is its share of all pooled streamlines.
weights = cluster_sizes / cluster_sizes.sum()

# Save centroids .trk
centroid_trk = out_dir.joinpath(f"atlasA_centroids_thr{qb_threshold}_K{len(centroids)}.trk")
save_trk_streamlines(centroids, reference_nii, centroid_trk)

# Save weights table, heaviest cluster first
dfA = pd.DataFrame(
    {
        "cluster_id": np.arange(len(centroids)),
        "cluster_size": cluster_sizes,
        "weight": weights,
    }
)
dfA = dfA.sort_values("weight", ascending=False)

dfA.to_csv(out_dir.joinpath(f"atlasA_centroids_thr{qb_threshold}_weights.csv"), index=False)

dfA.head(10)

Unnamed: 0,cluster_id,cluster_size,weight
15,15,284,0.201991
20,20,155,0.110242
27,27,148,0.105263
5,5,144,0.102418
12,12,134,0.095306
7,7,126,0.089616
13,13,65,0.04623
30,30,50,0.035562
2,2,39,0.027738
1,1,30,0.021337


In [80]:
# Atlas B: save the whole balanced population pool as a single tractogram.
pop_trk = out_dir.joinpath(
    f"atlasB_population_balanced_n{balanced_n}_total{len(balanced_pool)}.trk"
)
save_trk_streamlines(streamlines=balanced_pool, reference_nii=reference_nii, out_path=pop_trk)

pop_trk

WindowsPath('derivatives/atlas_trk_IPL/atlasB_population_balanced_n200_total1406.trk')

In [81]:
def downsample_streamlines(streamlines, n_total=10000, seed=0):
    """Return at most ``n_total`` streamlines.

    When the input already fits, a list copy is returned in the original
    order; otherwise ``n_total`` streamlines are drawn without replacement
    using a generator seeded with ``seed``.
    """
    generator = np.random.default_rng(seed)
    n_available = len(streamlines)
    if n_available > n_total:
        chosen = generator.choice(n_available, size=n_total, replace=False)
        return [streamlines[k] for k in chosen]
    return list(streamlines)

# Lightweight variant of the population atlas: capped at 10,000 streamlines.
pop_light = downsample_streamlines(streamlines=balanced_pool, n_total=10000, seed=rng_seed)
pop_light_trk = out_dir.joinpath("atlasB_population_balanced_light10k.trk")
save_trk_streamlines(streamlines=pop_light, reference_nii=reference_nii, out_path=pop_light_trk)
pop_light_trk

WindowsPath('derivatives/atlas_trk_IPL/atlasB_population_balanced_light10k.trk')