In [1]:
import os
import warnings

import numpy as np
import nibabel as nib
from nibabel.processing import resample_from_to
from scipy import ndimage
from dipy.io import read_bvals_bvecs
from dipy.io.image import load_nifti, save_nifti
from dipy.core.gradients import gradient_table
from dipy.segment.mask import median_otsu
from dipy.reconst.dti import TensorModel, fractional_anisotropy, from_lower_triangular

def voxel_sizes_from_affine(affine):
    """Return the voxel edge lengths along the three voxel axes (i, j, k).

    Each length is the Euclidean norm of the corresponding column of the
    affine's upper-left 3x3 block, so a pure rotation in the affine does not
    change the result.
    """
    linear_block = affine[:3, :3]
    return np.linalg.norm(linear_block, axis=0)

def extract_b0_avg(dwi_data, bvals, b0_thresh=50):
    """Return the voxel-wise mean of all b0 volumes of a 4D DWI array.

    A volume counts as "b0" when its b-value is <= ``b0_thresh``.

    Parameters
    ----------
    dwi_data : array-like, shape (X, Y, Z, N)
        DWI data; the last axis indexes the N diffusion volumes.
    bvals : array-like, length N
        One b-value per volume, in the same order as the last axis.
    b0_thresh : float
        Upper bound (inclusive) on the b-value for a volume to be a b0.

    Returns
    -------
    ndarray, shape (X, Y, Z)
        Mean of the selected b0 volumes (a single b0 is returned as-is,
        since the mean of one value is that value).

    Raises
    ------
    ValueError
        If ``dwi_data`` is not 4D or ``len(bvals)`` does not match the
        number of volumes.
    RuntimeError
        If no volume has b <= ``b0_thresh``.
    """
    dwi_data = np.asarray(dwi_data)
    b0_mask = np.asarray(bvals).ravel() <= b0_thresh

    if dwi_data.ndim != 4:
        raise ValueError(
            f"dwi_data must be 4D (X, Y, Z, N); got shape {dwi_data.shape}.")
    if b0_mask.shape[0] != dwi_data.shape[-1]:
        raise ValueError(
            f"len(bvals)={b0_mask.shape[0]} does not match the number of "
            f"DWI volumes ({dwi_data.shape[-1]}).")
    if not np.any(b0_mask):
        raise RuntimeError("No se han encontrado volúmenes B0 (b <= {}).".format(b0_thresh))

    # Boolean indexing on the last axis always keeps it (length >= 1 here),
    # so a plain mean covers both the single- and multi-b0 cases.
    return dwi_data[..., b0_mask].mean(axis=-1)

def _fit_fa(data, gtab, mask):
    """Fit a DTI model restricted to ``mask`` and return a finite FA map."""
    tensor_model = TensorModel(gtab)
    try:
        tenfit = tensor_model.fit(data, mask=mask)
    except Exception as exc:
        # Best-effort fallback kept from the original script, but made visible
        # instead of silently discarding the error.
        warnings.warn(
            f"El ajuste DTI con máscara falló ({exc!r}); reintentando sin máscara.",
            RuntimeWarning,
        )
        tenfit = tensor_model.fit(data)
    fa = fractional_anisotropy(tenfit.evals)
    # FA may contain NaN/inf values; replace them with 0 so thresholds work.
    return np.nan_to_num(fa, nan=0.0, posinf=0.0, neginf=0.0)


def _clean_mask(mask, iterations):
    """Morphologically clean a binary mask: opening, then closing (no-op if iterations <= 0)."""
    if iterations <= 0:
        return mask
    opened = ndimage.binary_opening(mask, iterations=iterations)
    return ndimage.binary_closing(opened, iterations=iterations)


def _resample_masks_to_target(named_imgs, target_path, out_dir):
    """Nearest-neighbour resample each (name, image) onto the target's 3D grid, save, return paths."""
    tgt_img = nib.load(target_path)
    if tgt_img.ndim > 3:
        # 4D target: only its spatial (first three) dimensions define the grid.
        to_vox_map = (tuple(tgt_img.shape[:3]), tgt_img.affine)
    else:
        to_vox_map = tgt_img

    paths = {}
    for name, img in named_imgs:
        res = resample_from_to(img, to_vox_map, order=0)
        # Re-binarise so the saved file is strictly uint8 {0, 1}.
        data_res = (res.get_fdata() > 0).astype(np.uint8)
        out_img = nib.Nifti1Image(data_res, res.affine)
        out_img.set_data_dtype(np.uint8)
        out_path = os.path.join(out_dir, name)
        nib.save(out_img, out_path)
        paths[name] = out_path
    return paths


def make_masks_from_dwi(dwi_path, bval_path, bvec_path,
                        out_dir='masks_out',
                        brain_mask_method='median_otsu',  # only 'median_otsu' is implemented ('bet'/FSL is not)
                        median_otsu_radius=4, median_otsu_numpass=4,
                        fa_thr=0.2,  # FA threshold defining white matter
                        deep_mm=5.0,  # minimum distance (mm) to the WM border for deep WM
                        clean_iters=1,
                        resample_to_peaks_path=None):
    """Build brain, white-matter (WM) and deep-WM masks from a DWI acquisition.

    Pipeline:
      1. Average the b0 volumes (b <= 50) and run ``median_otsu`` on that
         average to obtain a brain mask.
      2. Fit a DTI model inside the brain mask and compute FA.
      3. WM mask = (FA > ``fa_thr``) inside the brain mask, then cleaned
         with a binary opening followed by a closing (``clean_iters``
         iterations each; 0 disables cleaning).
      4. Deep-WM mask = WM voxels whose Euclidean distance, in mm (voxel
         sizes taken from the affine), to the nearest non-WM voxel is
         >= ``deep_mm``.
      5. If ``resample_to_peaks_path`` is given, resample the three masks
         (nearest neighbour) onto that image's grid; a 4D target is
         supported by using only its first three dimensions.

    Parameters
    ----------
    dwi_path, bval_path, bvec_path : str
        DWI NIfTI file and the bval/bvec text files read by dipy's
        ``read_bvals_bvecs``.
    out_dir : str
        Output directory (created if missing).
    brain_mask_method : str
        Must be ``'median_otsu'``.
    median_otsu_radius, median_otsu_numpass : int
        Forwarded to ``median_otsu`` as ``median_radius`` / ``numpass``.
    fa_thr : float
        FA threshold for the WM mask (also used in the WM file name).
    deep_mm : float
        Minimum distance (mm) from the WM border for the deep-WM mask.
    clean_iters : int
        Iterations of the morphological opening and closing on the WM mask.
    resample_to_peaks_path : str or None
        Optional target image whose grid the masks are resampled onto.

    Returns
    -------
    dict
        Paths of the written files under keys ``'brain_mask'``, ``'fa'``,
        ``'wm_mask'``, ``'deep_wm_mask'`` and ``'resampled_masks'`` (a dict
        of output file name -> path; empty when no target is given).

    Raises
    ------
    ValueError
        If ``brain_mask_method`` is not ``'median_otsu'``.
    RuntimeError
        If no b0 volume (b <= 50) is found.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Load the DWI data and the gradient scheme.
    data, affine = load_nifti(dwi_path)
    bvals, bvecs = read_bvals_bvecs(bval_path, bvec_path)
    gtab = gradient_table(bvals=bvals, bvecs=bvecs)

    b0 = extract_b0_avg(data, bvals, b0_thresh=50)

    # Brain mask from the mean b0 image.
    if brain_mask_method != 'median_otsu':
        raise ValueError("brain_mask_method no soportado en este script.")
    _, brain_mask = median_otsu(b0, median_radius=median_otsu_radius,
                                numpass=median_otsu_numpass)

    brain_mask_img = nib.Nifti1Image(brain_mask.astype(np.uint8), affine)
    brain_mask_path = os.path.join(out_dir, 'brain_mask.nii.gz')
    nib.save(brain_mask_img, brain_mask_path)

    # FA from a DTI fit limited to the brain mask.
    fa = _fit_fa(data, gtab, brain_mask)
    fa_path = os.path.join(out_dir, 'fa.nii.gz')
    nib.save(nib.Nifti1Image(fa.astype(np.float32), affine), fa_path)

    # WM mask: FA above threshold inside the brain, then morphological cleanup.
    wm_mask = _clean_mask((fa > fa_thr) & (brain_mask > 0), clean_iters)
    wm_img = nib.Nifti1Image(wm_mask.astype(np.uint8), affine)
    wm_path = os.path.join(out_dir, f'wm_mask_fa{fa_thr:.2f}.nii.gz')
    nib.save(wm_img, wm_path)

    # For every WM voxel, distance to the nearest non-WM voxel; sampling=voxel
    # sizes makes it millimetres. Intersecting with wm_mask keeps non-WM voxels
    # (distance 0) out of the mask even when deep_mm <= 0.
    dist = ndimage.distance_transform_edt(
        wm_mask, sampling=voxel_sizes_from_affine(affine))
    deep_mask = wm_mask & (dist >= deep_mm)

    deep_img = nib.Nifti1Image(deep_mask.astype(np.uint8), affine)
    # ':g' keeps e.g. 5.0 -> '5' (as before) but 2.5 -> '2.5' instead of '2'.
    deep_path = os.path.join(out_dir, f'deep_wm_mask_{deep_mm:g}mm.nii.gz')
    nib.save(deep_img, deep_path)

    resampled_paths = {}
    if resample_to_peaks_path is not None:
        resampled_paths = _resample_masks_to_target(
            [('brain_mask_to_peaks.nii.gz', brain_mask_img),
             ('wm_mask_to_peaks.nii.gz', wm_img),
             ('deep_wm_mask_to_peaks.nii.gz', deep_img)],
            resample_to_peaks_path,
            out_dir,
        )

    return {
        'brain_mask': brain_mask_path,
        'fa': fa_path,
        'wm_mask': wm_path,
        'deep_wm_mask': deep_path,
        'resampled_masks': resampled_paths,
    }


# ---------------------------
# Example usage: adjust the paths
# ---------------------------
# Guarded so importing this module does not start a full processing run;
# in Jupyter, __name__ == "__main__", so the cell still executes as before.
if __name__ == "__main__":
    data_dir = '/home/riemann007/JupyterLab/Tesis/Datos Reales'
    subject_stem = 'fMONCHO_DWI_SEN_8B0_4x32_B2000_4x32_b2500'
    dwi_path = os.path.join(data_dir, subject_stem + '.nii')
    bval_path = os.path.join(data_dir, subject_stem + '.bval')
    bvec_path = os.path.join(data_dir, subject_stem + '.bvec')
    # peaks_path = '/home/riemann007/JupyterLab/Tesis/MLP/Train_1/generated_peaks_csd.nii.gz'  # your peaks file

    out = make_masks_from_dwi(dwi_path, bval_path, bvec_path,
                              out_dir='masks_validation',
                              fa_thr=0.22,
                              deep_mm=5.0,
                              clean_iters=1,
                              resample_to_peaks_path=None)

    print("Resultados guardados:", out)


Resultados guardados: {'brain_mask': 'masks_validation/brain_mask.nii.gz', 'fa': 'masks_validation/fa.nii.gz', 'wm_mask': 'masks_validation/wm_mask_fa0.22.nii.gz', 'deep_wm_mask': 'masks_validation/deep_wm_mask_5mm.nii.gz', 'resampled_masks': {}}
