In [2]:
import os, sys
# adjust this path if your repo is elsewhere
proj_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if proj_root not in sys.path:
    sys.path.append(proj_root)

# confirm device
import torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)
print(sys.executable)

Using device: cuda
/home/spieterman/dev/projects/dwi-transformer/.venv/bin/python


In [None]:
# Shape of each 3D volume (after your resampling step)
IMG_SHAPE = (150, 150, 150)

# Latent dimension for all AEs
LATENT_DIM = 128

# Training settings
BATCH_SIZE = 4
LR = 1e-3
EPOCHS = 5

# Paths
DATA_DIR = os.path.join(proj_root, "data", "resampled_volumes")

In [4]:
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt

# Import your model modules
from models.custom_ae import Custom3dAE
from models.monai_ae  import get_monai_ae
#from models.resnet_ae import ResNetAE

# Registry to swap models by name
MODEL_REGISTRY = {
    "custom": lambda: Custom3dAE(latent_dim=LATENT_DIM, in_shape=(1,*IMG_SHAPE)).to(DEVICE),
    "monai":  lambda: get_monai_ae().to(DEVICE),
    #"resnet": lambda: ResNetAE(latent_dim=LATENT_DIM).to(DEVICE),
}

In [None]:
import os
from collections import Counter
import nibabel as nib

# test_path = "~/dev/projects/dwi-preprocessing/data/preproc/sub-OAS30001/ses-d0757/sub-OAS30001_ses-d0757_dwi_allruns.nii.gz"
DATA_DIR = "~/dev/projects/dwi-preprocessing/data/preproc/"   # adjust to your root folder

def find_nifti_files(root_dir):
    root_dir = os.path.abspath(os.path.expanduser(root_dir))
    for dirpath, _, filenames in os.walk(root_dir):
        for fn in filenames:
            if fn.endswith((".nii.gz")):
                yield os.path.join(dirpath, fn)

shape_counts = Counter()
errors = []

for path in find_nifti_files(DATA_DIR):
    try:
        img = nib.load(path)
        shape_counts[img.shape] += 1
    except Exception as e:
        errors.append((path, str(e)))

# Display results
print("🔍 Unique volume shapes found:")
for shape, count in shape_counts.items():
    print(f"{shape}: {count} file{'s' if count>1 else ''}")

if errors:
    print(f"\n⚠️  {len(errors)} file(s) failed to load:")
    for p, err in errors[:5]:
        print(f"    - {p}: {err}")
    if len(errors)>5:
        print(f"    ... and {len(errors)-5} more")


🔍 Unique volume shapes found:
(91, 109, 91, 26): 151 files
(91, 109, 91): 239 files
(91, 109, 91, 91): 61 files
(91, 109, 91, 65): 9 files
(91, 109, 91, 13): 5 files
(91, 109, 91, 52): 1 file
(91, 109, 91, 69): 2 files
(91, 109, 91, 2): 8 files
(91, 109, 91, 50): 1 file
(91, 109, 91, 48): 1 file


In [None]:
from pathlib import Path
import numpy as np, nibabel as nib, json, argparse
from tqdm import tqdm

ROOT   = Path("/home/spieterman/dev/projects/dwi-preprocessing/data/preproc")
CACHE  = Path("/home/spieterman/dev/projects/dwi-transformer/data/encoder")                                       # output root
CACHE.mkdir(exist_ok=True)

def find_sessions(root: Path):
    return sorted(root.rglob("*_dwi_allruns.nii.gz"))

def normalise_dwi(dwi_data: np.ndarray, bvals: np.ndarray):
    """
    Load a 4-D DWI series and apply robust, session-level gain and z-score normalisation.

    Parameters
    ----------
    dwi_path : Path
        Path to the cleaned 4-D DWI NIfTI (shape [X,Y,Z,N]).
    bval_path : Path
        Path to the matching .bval file (one row of N b-values).

    Returns
    -------
    dwi_norm : ndarray (float32)
        The fully normalised DWI data (same shape as input).
    stats : dict
        {
          'gain': float,       # scale factor applied to align b0 median → 1.0
          'mean': float,       # session-wide mean after gain scaling
          'std':  float        # session-wide std  after gain scaling
        }
    """

    # 1) Identify all b0 volumes (b-value == 0)
    b0_indices  = np.where(bvals == 0)[0]         # e.g. array([0, 10, 20])

    # 2) Extract those b0 volumes and form a union-mask of nonzero voxels
    #    (handles slight mis-alignments: if *any* run has signal, we treat it as brain)
    b0_volumes = np.take(dwi_data, b0_indices, axis=3)  # shape (X, Y, Z, N_b0)
    mask       = np.any(b0_volumes > 0, axis=3)         # boolean mask [X,Y,Z]

    # 3) Remove any stray zero-intensity voxels inside the union mask
    #    (e.g. holes due to warping) before computing the gain
    b0_values  = b0_volumes[mask].ravel()
    b0_values  = b0_values[b0_values > 0]               # drop zeros

    # 4) SESSION-GAIN NORMALISATION:
    #    Anchor the median of all b0 tissue intensities to 1.0,
    #    removing scanner-/coil-level scale differences across sessions
    gain       = 1.0 / (np.median(b0_values) + 1e-12)
    dwi_scaled = dwi_data * gain

    # 5) SESSION-WIDE Z-SCORE NORMALISATION:
    #    Compute mean/std across all brain voxels in all volumes,
    #    giving zero-mean/unit-variance inputs for the autoencoder,
    #    yet preserving relative shell attenuation patterns.
    dwi_values = dwi_scaled[mask, ...].ravel()
    dwi_values = dwi_values[dwi_values > 0]              # drop any sneaky zeros
    mean       = dwi_values.mean()
    std        = dwi_values.std() + 1e-6
    dwi_norm   = (dwi_scaled - mean) / std

    # 6) Return the normalised data plus the stats for reproducibility
    return dwi_norm.astype(np.float32)


def process_session(dwi_path: Path):
    """
    Process a single 4-D DWI session and cache its 3-D gradient volumes.

    Parameters
    ----------
    dwi_path : pathlib.Path
        Path to the cleaned 4-D DWI NIfTI file (shape [X, Y, Z, N]).

    Steps
    -----
    1) Derive the matching .bval and .bvec file paths.
    2) Parse patient ID (sub-XXX) and session ID (ses-YYY) from the filename.
    3) Create an output directory at CACHE/sub-XXX/ses-YYY/.
    4) Load the raw 4-D DWI data and corresponding b-values.
    5) Load the .bval file containing the b-values for each gradient.
    6) Normalize the entire volume (session-level gain + z-score).
    7) Split the normalized 4-D volume into N individual 3-D gradient arrays and save each gradient as a compressed .npz containing:
        - vol_data: 3-D image array
        - bval: single float b-value
        - bvec: 3-D b-vector (if available, currently not used)
        - affine: 4 x 4 spatial transform
        - patient, session tags
    """

    # 1) Derive file stems for .bval and .bvec
    #    We strip off the trailing ".nii.gz" by slicing off 7 chars
    base      = dwi_path.with_suffix("").with_suffix("")  # remove .nii.gz (with_suffix only removes one suffix at a time)
    bval_path = base.with_suffix(".bval")
    bvec_path = base.with_suffix(".bvec")

    # 2) Extract patient & session IDs from the BIDS-style filename
    #    Filename looks like "sub-XXX_ses-YYY_dwi_allruns.nii.gz"
    p_id = base.name.split("_")[0]  # e.g. "sub-0001"
    s_id = base.name.split("_")[1]  # e.g. "ses-01"

    # 3) Prepare the output folder for this session
    #    e.g. cache/sub-0001_ses-01/
    out_dir = CACHE / p_id / s_id
    out_dir.mkdir(parents=True, exist_ok=True)

    # 4) Load the 4-D DWI image (keeps affine + header for later)
    dwi_img = nib.load(dwi_path)                             # nibabel Nifti1Image
    dwi_raw = dwi_img.get_fdata().astype(np.float32)         # (X, Y, Z, N)

    # 5) Load acquisition metadata
    #    .bval: one row of N diffusion weightings
    #    .bvec: 3×N axis vectors
    bvals = np.loadtxt(bval_path)                            # shape (N,)
    # bvecs = np.loadtxt(bvec_path)                            # shape (3, N)

    # 6) Normalize the 4-D data with our robust session-level function
    #    Returns the normalized array
    dwi_norm = normalise_dwi(dwi_raw, bvals)

    # 7) Split the normalized 4-D volume into individual 3-D gradient volumes
    #    and save each as a compressed .npz with all relevant metadata.
    for g in range(dwi_norm.shape[3]):
        vol_data = dwi_norm[..., g]  # 3-D array (X, Y, Z)

        out_file = out_dir / f"{p_id}_{s_id}_grad{g:03d}.npz"
        np.savez_compressed(
            out_file,
            vol_data=vol_data,                          # 3D gradient volume (X, Y, Z)
            bval=np.float32(bvals[g]),                  # single b-value for this gradient
            # bvec=np.float32(bvecs[:, g]),             # 3D b-vector for this gradient TODO: bvecs are wrongly concatenated over runs
            affine=dwi_img.affine.astype(np.float32),   # preserves spatial orientation for nifti reconstruction
            patient=p_id,                               # for downstream grouping or sampling
            session=s_id                                # ditto
        )


# if __name__ == "__main__":
for dwi_file in tqdm(find_sessions(Path(ROOT))):
    process_session(dwi_file)


100%|██████████| 239/239 [11:33<00:00,  2.90s/it]


In [None]:
from torch.utils.data import Dataset

class AEVolumes(Dataset):
    def __init__(self, cache_root: str):
        # 1) Find all cached gradient files
        self.files = sorted(Path(cache_root).rglob("*grad*.npz"))

        # 2) Preload metadata for sampling weights
        self.patients = []
        self.sessions = []
        self.bvals    = []
        for f in self.files:
            data = np.load(f)
            # Extract scalar metadata
            self.patients.append(data["patient"].item())  # e.g. "sub-0001"
            self.sessions.append(data["session"].item())  # e.g. "ses-01"
            self.bvals.append(float(data["bval"].item()))
        
        print(f"Found {len(self.files)} cached gradient files in {cache_root}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # 1) Load the .npz
        data = np.load(self.files[idx])

        # 2) Volume: [1, X, Y, Z]
        vol = torch.from_numpy(data["vol_data"]).unsqueeze(0).float()

        # 3) Acquisition metadata
        bval = torch.tensor(float(data["bval"].item()), dtype=torch.float32)
        # bvec = torch.from_numpy(data["bvec"].astype(np.float32))

        # 4) Spatial metadata
        affine = torch.from_numpy(data["affine"].astype(np.float32))  # [4, 4]

        # 5) Provenance tags
        patient = data["patient"].item()  # string
        session = data["session"].item()  # string

        return {
            "vol":     vol,
            "bval":    bval,
            # "bvec":    bvec,
            "affine":  affine,
            "patient": patient,
            "session": session,
        }

Found 10431 cached gradient files in /home/spieterman/dev/projects/dwi-transformer/data/encoder
Dataset size: 10431


In [None]:
from collections import Counter, defaultdict
import torch
from torch.utils.data import WeightedRandomSampler, DataLoader

def make_balanced_sampler(dataset: AEVolumes) -> WeightedRandomSampler:
    """
    Build a sampler whose per-sample weight is

        w(p,s,b) = 1 / |S_p| * 1 / m_{p,s} * |B| / k_b

    where
        |S_p|     = number of sessions for patient p
        m_{p,s}   = number of volumes in session s of patient p
        k_b       = total volumes in shell b  (across entire dataset)
        |B|       = number of distinct shells

    Returns
    -------
    WeightedRandomSampler
        Can be fed into DataLoader(..., sampler=sampler)
    """

    # 1.  Gather per-sample metadata already stored by the Dataset
    patients = dataset.patients        # list len = len(dataset)
    sessions = dataset.sessions        # e.g. "ses-d0757"
    bvals    = dataset.bvals           # float or int

    # Create a composite session key "sub-XXX_ses-YYY"
    sess_keys = [f"{p}_{s}" for p, s in zip(patients, sessions)]

    # 2.  Pre-compute the three count tables
    #   a. how many sessions per patient  |S_p|
    patient_to_sessions = defaultdict(set)
    for p, s in zip(patients, sessions):
        patient_to_sessions[p].add(s)
    S_counts = {p: len(sset) for p, sset in patient_to_sessions.items()}

    #   b. how many volumes in each session  m_{p,s}
    sess_counts = Counter(sess_keys)               # m_{p,s}

    #   c. how many volumes in each shell  k_b
    shell_counts = Counter(bvals)                  # k_b
    B = len(shell_counts)                          # |B|

    # 3.  Build the per-sample weight list
    weights = []
    for p, s, key, b in zip(patients, sessions, sess_keys, bvals):
        w =  (1.0 / S_counts[p])          # patient balance
        w *= (1.0 / sess_counts[key])     # session balance
        w *= (B   / shell_counts[b])      # shell balance
        weights.append(w)

    print(weights[:10])  # debug: show first 10 weights

    # 4.  Wrap in a WeightedRandomSampler
    return WeightedRandomSampler(
        weights=torch.DoubleTensor(weights),
        num_samples=len(weights),     # one "epoch" = same expected size as dataset
        replacement=True              # draw with replacement → true probability sampling
    )


# TODO: Validate the sampler mathematically
ds = AEVolumes(cache_root=CACHE)
sampler = make_balanced_sampler(ds)

Found 10431 cached gradient files in /home/spieterman/dev/projects/dwi-transformer/data/encoder
[0.0018792282635930846, 0.0026959022286125087, 0.0028005974607916355, 0.0026959022286125087, 0.003770739064856712, 0.0026959022286125087, 0.0013448090371167296, 0.0015719974848040243, 0.0014755065905961047, 0.0013479511143062544]


In [None]:
x = next(iter(dl)).to(DEVICE)
for name, maker in MODEL_REGISTRY.items():
    model = maker()
    recon, z = model(x)
    print(f"{name:6} → recon shape: {tuple(recon.shape)}, latent shape: {tuple(z.shape)}")

In [None]:
def train_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0.0
    for batch in loader:
        batch = batch.to(DEVICE)
        recon, _ = model(batch)
        loss = criterion(recon, batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.size(0)
    return total_loss / len(loader.dataset)

def validate_epoch(model, loader, criterion):
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(DEVICE)
            recon, _ = model(batch)
            val_loss += criterion(recon, batch).item() * batch.size(0)
    return val_loss / len(loader.dataset)


In [None]:
import pandas as pd

results = []
for model_name in ["custom", "monai"]:  # add "resnet" once ready
    print(f"\n=== Training {model_name} AE ===")
    model = MODEL_REGISTRY[model_name]()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = nn.MSELoss()

    logs = {"epoch": [], "train_loss": [], "val_loss": []}
    for epoch in range(1, EPOCHS+1):
        tr = train_epoch(model, dl, optimizer, criterion)
        vl = validate_epoch(model, dl, criterion)
        logs["epoch"].append(epoch)
        logs["train_loss"].append(tr)
        logs["val_loss"].append(vl)
        print(f"Epoch {epoch:02d} | train: {tr:.4f} | val: {vl:.4f}")
    df = pd.DataFrame(logs)
    df["model"] = model_name
    results.append(df)
results_df = pd.concat(results, ignore_index=True)
results_df


In [None]:
import seaborn as sns
sns.lineplot(data=results_df, x="epoch", y="train_loss", hue="model", marker="o")
sns.lineplot(data=results_df, x="epoch", y="val_loss",   hue="model", marker="x", linestyle="--")
plt.title("AE Reconstruction Loss by Model")
plt.show()

In [None]:
model = MODEL_REGISTRY["custom"]()  # or “monai”
model.eval()
sample = next(iter(dl))[0].unsqueeze(0).to(DEVICE)  # single volume
with torch.no_grad():
    recon, _ = model(sample)

# show middle slice
slice_idx = IMG_SHAPE[2] // 2
fig, axes = plt.subplots(1,2, figsize=(8,4))
axes[0].imshow(sample.cpu().numpy()[0,0,:,:,slice_idx], cmap="gray")
axes[0].set_title("Input")
axes[1].imshow(recon.cpu().numpy()[0,0,:,:,slice_idx], cmap="gray")
axes[1].set_title("Reconstruction")
plt.show()