
# fastMRI Multicoil → Image-Space Gaussian Noise (Step-by-Step)

This notebook helps you:
1. **Load** fastMRI brain multicoil `.h5` files
2. **Reconstruct** per-coil magnitude images from k-space
3. **(Optional) Normalize** intensities by a percentile
4. **Add Gaussian noise** in image space with a chosen `sigma`
5. **Save** the noisy inputs, clean references, and noise-std maps (as NIfTI) for your inference script

> Tip: Run cell-by-cell and inspect outputs at each step.


In [1]:

import os
import glob
import h5py
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

# ------------- FFT helpers -------------
def ifft2c(kspace: np.ndarray) -> np.ndarray:
    """Apply a centered, orthonormal 2D inverse FFT.

    The transform acts on the last two axes of ``kspace``; any leading axes
    (e.g. slices, coils) are treated as a batch. The centered k-space is first
    un-shifted so the DC term sits at index 0, transformed with ``norm='ortho'``,
    and the resulting image is re-centered.

    Args:
        kspace: complex array whose last two axes are the 2D frequency grid.

    Returns:
        Complex image-space array with the same shape as ``kspace``.
    """
    spatial_axes = (-2, -1)
    uncentered = np.fft.ifftshift(kspace, axes=spatial_axes)
    image = np.fft.ifft2(uncentered, axes=spatial_axes, norm='ortho')
    return np.fft.fftshift(image, axes=spatial_axes)

def mag_per_coil_from_kspace(ks: np.ndarray) -> np.ndarray:
    """Compute per-coil magnitude images from multicoil k-space.

    Args:
        ks: complex k-space laid out as (S, C, H, W).

    Returns:
        float32 magnitudes reordered to (H, W, C, S).
    """
    coil_images = ifft2c(ks)  # complex, still (S, C, H, W)
    magnitudes = np.abs(coil_images).astype(np.float32)
    # Move the spatial axes to the front and slices to the back:
    # (S, C, H, W) -> (H, W, C, S).
    return magnitudes.transpose(2, 3, 1, 0)

def rss_combine(mag_coils: np.ndarray) -> np.ndarray:
    """Combine coils by root-sum-of-squares along axis 2.

    Args:
        mag_coils: per-coil magnitudes shaped (H, W, C); any axes after the
            coil axis are kept.

    Returns:
        The RSS image with the coil axis removed, e.g. (H, W).
    """
    coil_axis = 2
    sum_of_squares = np.square(mag_coils).sum(axis=coil_axis)
    return np.sqrt(sum_of_squares)

def normalize_by_percentile(vol: np.ndarray, percentile: float):
    """Scale ``vol`` so its ``percentile``-th value maps to 1.0, then clip to [0, 1].

    Args:
        vol: intensity array of any shape; the percentile is taken over all elements.
        percentile: percentile in [0, 100].

    Returns:
        (normalized, scale_factor):
            normalized   -- ``clip(vol * scale_factor, 0, 1)``; a float32 input
                            stays float32.
            scale_factor -- Python float ``1 / (p + 1e-12)``, where ``p`` is the
                            requested percentile of ``vol``.
    """
    p = float(np.percentile(vol, percentile))
    # The 1e-12 guard avoids a division by zero when the percentile is 0.
    # Keep sf a Python float: under NumPy 2 (NEP 50) multiplying a float32 array
    # by an np.float64 scalar promotes to float64, doubling memory for the volume.
    sf = 1.0 / (p + 1e-12)
    return np.clip(vol * sf, 0.0, 1.0), sf


In [2]:

# ==== USER CONFIG ====
# Set to a directory containing .h5 files, or to a single .h5 path.
# Both paths may also be overridden with the FASTMRI_INPUT_PATH / FASTMRI_OUTPUT_DIR
# environment variables; the literals below are the fallbacks.
INPUT_PATH = os.environ.get(
    "FASTMRI_INPUT_PATH",
    "/n/netscratch/zickler_lab/Lab/linbo/denoising_project/dataset/fastmri_brain_test/multicoil_test/file_brain_AXFLAIR_200_6002621.h5",
)     # <-- EDIT ME
OUTPUT_DIR = os.environ.get(
    "FASTMRI_OUTPUT_DIR",
    "/n/netscratch/zickler_lab/Lab/linbo/denoising_project/Restormer/simulate_noise/fastmri/data",
)         # <-- EDIT ME (will be created)
CAP_PERCENTILE = 99.8                            # intensity cap percentile for normalization
DO_NORMALIZE = True                              # set False to skip normalization
SIGMA = 0.02                                     # Gaussian noise std (if normalized, in [0,1] units)
SEED = 123                                       # set None for non-reproducible noise
SLICE_INDEX = 0                                  # which slice to visualize
COIL_INDEX = 0                                   # which coil to visualize
# =====================

os.makedirs(OUTPUT_DIR, exist_ok=True)
# Seeds NumPy's legacy global RNG, so that subsequent np.random.* draws are
# reproducible from this point on.
if SEED is not None:
    np.random.seed(SEED)

# Find .h5 files. A single-file path must actually exist; otherwise it is
# reported as "0 files" here instead of failing later inside h5py.
if os.path.isdir(INPUT_PATH):
    H5_FILES = sorted(glob.glob(os.path.join(INPUT_PATH, "*.h5")))
elif os.path.isfile(INPUT_PATH) and INPUT_PATH.endswith(".h5"):
    H5_FILES = [INPUT_PATH]
else:
    H5_FILES = []

print(f"Found {len(H5_FILES)} .h5 file(s).")
for i, p in enumerate(H5_FILES[:10]):
    print(f"[{i}] {p}")
if len(H5_FILES) == 0:
    print("→ Update INPUT_PATH above to a valid folder or .h5 file.")


Found 1 .h5 file(s).
[0] /n/netscratch/zickler_lab/Lab/linbo/denoising_project/dataset/fastmri_brain_test/multicoil_test/file_brain_AXFLAIR_200_6002621.h5


In [None]:

# Pick the first file (or change index here)
FILE_INDEX = 0
if len(H5_FILES) == 0:
    raise SystemExit("No .h5 files found. Please update INPUT_PATH and re-run.")

fpath = H5_FILES[FILE_INDEX]
print("Using file:", fpath)

# Read the whole 'kspace' dataset into memory; the file is closed when the block exits.
with h5py.File(fpath, 'r') as f:
    print("Keys in file:", list(f.keys()))
    if 'kspace' not in f:
        raise KeyError("This file doesn't contain 'kspace'.")
    ks = f['kspace'][()]  # expected shape (S, C, H, W), complex
    print("kspace shape (S, C, H, W):", ks.shape)

# Fail fast with a clear message instead of an opaque tuple-unpacking error
# (e.g. for a file whose k-space has no coil axis).
if ks.ndim != 4:
    raise ValueError(
        f"Expected multicoil k-space with 4 dims (S, C, H, W), got shape {ks.shape}."
    )

S, C, H, W = ks.shape


In [None]:

# Reconstruct per-coil magnitude images, laid out as (H, W, C, S)
mag_vol = mag_per_coil_from_kspace(ks)
print("Per-coil magnitude shape (H, W, C, S):", mag_vol.shape)
# Both percentiles from one call (same values as two separate calls)
clean_p10, clean_p90 = np.percentile(mag_vol, [10, 90])
print("Stats clean  p10=%.4g  mean=%.4g  p90=%.4g  max=%.4g" % (
    clean_p10, mag_vol.mean(), clean_p90, mag_vol.max()
))

# Clamp the requested indices into range; the clamped values are reused by later cells.
SLICE_INDEX = int(np.clip(SLICE_INDEX, 0, S - 1))
COIL_INDEX = int(np.clip(COIL_INDEX, 0, C - 1))

rss = rss_combine(mag_vol[..., SLICE_INDEX])
coil_img = mag_vol[..., COIL_INDEX, SLICE_INDEX]

# One figure per view: coil-combined RSS, then a single coil
for title, image in (
    (f"Clean RSS (slice {SLICE_INDEX})", rss),
    (f"Clean Coil {COIL_INDEX} (slice {SLICE_INDEX})", coil_img),
):
    plt.figure()
    plt.title(title)
    plt.imshow(image, cmap='gray')
    plt.axis('off')
    plt.show()


In [None]:

# Optionally rescale so the CAP_PERCENTILE-th intensity maps to 1.0 (values above are clipped);
# without normalization the raw magnitudes are used with a unit scale factor.
mag_norm, scale_factor = (
    normalize_by_percentile(mag_vol, CAP_PERCENTILE) if DO_NORMALIZE else (mag_vol, 1.0)
)

summary_lines = [
    "Normalization:",
    f"  enabled    : {DO_NORMALIZE}",
    f"  percentile : {CAP_PERCENTILE}",
    f"  scale_fact.: {scale_factor:.3e}",
    f"  stats      : min={mag_norm.min():.4g} max={mag_norm.max():.4g} mean={mag_norm.mean():.4g}",
]
print("\n".join(summary_lines))

plt.figure()
plt.title(f"Normalized RSS (slice {SLICE_INDEX})")
norm_rss = rss_combine(mag_norm[..., SLICE_INDEX])
plt.imshow(norm_rss, cmap='gray')
plt.axis('off')
plt.show()


In [None]:

# Additive, signal-independent Gaussian noise applied directly to the per-coil
# magnitude images (not complex-domain / Rician noise).
SIGMA = float(SIGMA)
if SIGMA < 0:
    raise ValueError(f"SIGMA must be non-negative, got {SIGMA}")

# Cell-local generator seeded from SEED: re-running just this cell reproduces
# the same noise, unlike the legacy global np.random state, which advances on
# every call. SEED=None draws fresh OS entropy on each run.
rng = np.random.default_rng(SEED)
# Draw directly in float32 to avoid a full-size float64 temporary.
noise = SIGMA * rng.standard_normal(mag_norm.shape, dtype=np.float32)
noisy = mag_norm + noise
# Clipping keeps intensities in range but truncates the noise near the bounds,
# so there the effective noise is no longer exactly N(0, SIGMA^2).
noisy = np.clip(noisy, 0.0, 1.0) if DO_NORMALIZE else np.clip(noisy, 0.0, None)

print("Noisy stats: min=%.4g max=%.4g mean=%.4g (SIGMA=%.4g)" % (noisy.min(), noisy.max(), noisy.mean(), SIGMA))

plt.figure()
plt.title(f"Noisy RSS (slice {SLICE_INDEX})")
plt.imshow(rss_combine(noisy[..., SLICE_INDEX]), cmap='gray')
plt.axis('off')
plt.show()

plt.figure()
plt.title(f"Noisy Coil {COIL_INDEX} (slice {SLICE_INDEX})")
plt.imshow(noisy[..., COIL_INDEX, SLICE_INDEX], cmap='gray')
plt.axis('off')
plt.show()


In [None]:

# Output paths share the input file's stem: <stem>_noisy / _noise_std / _clean .nii.gz
base = os.path.splitext(os.path.basename(fpath))[0]
out_noisy, out_noise, out_clean = (
    os.path.join(OUTPUT_DIR, f"{base}{suffix}.nii.gz")
    for suffix in ("_noisy", "_noise_std", "_clean")
)

# Constant noise std map with the un-expanded (H, W, C, S) shape
noise_std_map = np.full(mag_norm.shape, SIGMA, dtype=np.float32)

# Append a trailing singleton axis -> (H, W, C, S, 1) for the inference pipeline (N=1).
# NOTE(review): the std map stays 4-D while noisy/clean are 5-D; confirm that the
# inference script expects this.
noisy_5d = np.expand_dims(noisy, axis=-1)
clean_5d = np.expand_dims(mag_norm, axis=-1)

# Saved with an identity affine, i.e. no scanner geometry is attached.
for volume, path in ((noisy_5d, out_noisy), (noise_std_map, out_noise), (clean_5d, out_clean)):
    nib.save(nib.Nifti1Image(volume.astype(np.float32, copy=False), np.eye(4)), path)

print("Wrote:")
for path, shape_note in ((out_noisy, "(H,W,C,S,1)"), (out_noise, "(H,W,C,S)"), (out_clean, "(H,W,C,S,1)")):
    print("  ", path, "  " + shape_note)



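## Notes
- If you want per-slice **different** noise levels, change `noise_std_map` to vary along `S` and regenerate `noise` accordingly.
- If you disable normalization (`DO_NORMALIZE=False`), set `SIGMA` in **raw units**, i.e. on the magnitude scale produced by the orthonormal inverse FFT.
- Noisy values are clipped (to `[0, 1]` when normalized, to `>= 0` otherwise). Near those bounds the effective noise is therefore no longer zero-mean Gaussian, and the constant `noise_std_map` overstates it there.
- The normalization `scale_factor` is only printed, not saved; record it if you need to map results back to raw intensities.
- For batching many files, wrap the above workflow in a loop over `H5_FILES`.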
