In [4]:
import numpy as np
import nibabel as nib
import torch
from dipy.io.streamline import save_tractogram
from dipy.tracking.streamline import Streamlines
from dipy.io.stateful_tractogram import Space, StatefulTractogram


def extract_cubic_neighborhood(x, y, z, data, n=3):
    """Extrae vecindad 3D manejando bordes"""
    offset = n // 2
    if (x - offset < 0) or (x + offset + 1 > data.shape[0]):
        return None
    if (y - offset < 0) or (y + offset + 1 > data.shape[1]):
        return None
    if (z - offset < 0) or (z + offset + 1 > data.shape[2]):
        return None
    return data[x - offset:x + offset + 1, y - offset:y + offset + 1, z - offset:z + offset + 1]


class TractographyGenerator:
    def __init__(
        self,
        model_path,
        peaks_path,
        mask_path,
        nbh_path,
        deep_wm_mask_path=None,
        step_size=0.5
    ):
        # Cargar modelo
        self.model = self.load_model(model_path)
        self.model.eval()

        # Cargar datos de imagen
        self.peaks_img = nib.load(peaks_path)
        self.peaks_data = self.peaks_img.get_fdata()
        self.peaks_affine = self.peaks_img.affine

        # Máscara general y máscara de materia blanca profunda
        self.mask_data = nib.load(mask_path).get_fdata()
        if deep_wm_mask_path:
            # Cargar máscara de materia blanca profunda
            self.deep_mask_data = nib.load(deep_wm_mask_path).get_fdata().astype(bool)
        else:
            # Si no se proporciona, usar 
            # distancia a borde para generar máscara profunda (ejemplo 5 mm)
            from scipy.ndimage import distance_transform_edt
            binary = self.mask_data > 0
            # Suponiendo voxeles isotrópicos, obtener tamaño de voxel en mm
            voxel_sizes = np.sqrt((self.peaks_affine[:3, :3] ** 2).sum(axis=0))
            dist = distance_transform_edt(binary) * voxel_sizes[0]
            self.deep_mask_data = dist > 5.0

        self.step = step_size

        # Cargar parámetros de normalización
        train_data = np.load(nbh_path)
        self.X_mean = train_data['inputs'].mean(axis=0)
        self.X_std = train_data['inputs'].std(axis=0)

    def load_model(self, path):
        class MLP(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layers = torch.nn.Sequential(
                    torch.nn.Linear(408, 512),
                    torch.nn.BatchNorm1d(512),
                    torch.nn.ReLU(),
                    torch.nn.Dropout(0.4),
                    torch.nn.Linear(512, 256),
                    torch.nn.BatchNorm1d(256),
                    torch.nn.ReLU(),
                    torch.nn.Dropout(0.3),
                    torch.nn.Linear(256, 128),
                    torch.nn.BatchNorm1d(128),
                    torch.nn.ReLU(),
                    torch.nn.Linear(128, 3)
                )

            def forward(self, x):
                return self.layers(x)

        model = MLP()
        model.load_state_dict(torch.load(path))
        return model

    def generate_seeds(self, num_seeds):
        """Genera puntos semilla aleatorios en materia blanca profunda"""
        indices = np.argwhere(self.deep_mask_data)
        selected = np.random.choice(len(indices), size=num_seeds, replace=False)
        return [nib.affines.apply_affine(self.peaks_affine, idx) for idx in indices[selected]]

    def is_in_mask(self, point):
        """Verifica si un punto está dentro de la máscara"""
        voxel = nib.affines.apply_affine(np.linalg.inv(self.peaks_affine), point)
        x,y,z = np.round(voxel).astype(int)
        if 0 <= x < self.mask_data.shape[0] and \
           0 <= y < self.mask_data.shape[1] and \
           0 <= z < self.mask_data.shape[2]:
            return self.mask_data[x,y,z] > 0
        return False
    
    def predict_direction(self, current_point, prev_dir):
        """Predice la siguiente dirección usando el modelo"""
        voxel = nib.affines.apply_affine(np.linalg.inv(self.peaks_affine), current_point)
        x,y,z = np.round(voxel).astype(int)
        
        # Extraer vecindad
        neighborhood = extract_cubic_neighborhood(x,y,z, self.peaks_data)
        if neighborhood is None:
            return None
        
        # Crear input del modelo
        input_sample = np.concatenate([
            neighborhood.flatten(),
            prev_dir.flatten()
        ])
        
        # Normalizar
        input_norm = (input_sample - self.X_mean) / (self.X_std + 1e-8)
        
        # Predecir
        with torch.no_grad():
            tensor_input = torch.FloatTensor(input_norm).unsqueeze(0)
            direction = self.model(tensor_input).numpy().squeeze()
        
        return direction / np.linalg.norm(direction)

    def track(self, seed, max_steps=1000):
        streamline = [seed]
        current_point = np.array(seed, dtype=np.float32)
        
        voxel = nib.affines.apply_affine(np.linalg.inv(self.peaks_affine), seed)
        x,y,z = np.round(voxel).astype(int)
        peaks = self.peaks_data[x,y,z].reshape(5,3)
        prev_dir = peaks[0] / np.linalg.norm(peaks[0])+110
        
        # Configuración de umbral angular (60 grados en radianes)
        max_angle_radians = 1.0472  # 60° = 1.0472 rad
    
        for _ in range(max_steps):
            new_dir = self.predict_direction(current_point, prev_dir)
            if new_dir is None:
                break
            
            # Calcular ángulo entre dirección anterior y nueva
            cos_angle = np.dot(prev_dir, new_dir)
            if cos_angle > 1.0: cos_angle = 1.0
            if cos_angle < -1.0: cos_angle = -1.0
            angle = np.arccos(cos_angle)
            
            # Detener si el ángulo excede el umbral
            if angle > max_angle_radians:
                break
                
            # Actualizar posición
            new_point = current_point + self.step * new_dir
            
            if not self.is_in_mask(new_point):
                break
            if self.check_self_intersection(streamline, new_point):
                break
                
            streamline.append(new_point.copy())
            current_point = new_point
            prev_dir = new_dir  # Usar new_dir directamente (ya normalizado)
        
        return np.array(streamline)
    
    def check_self_intersection(self, streamline, point, threshold=0.1):
        """Detecta auto-intersecciones en el streamline"""
        if len(streamline) < 10:
            return False
        distances = np.linalg.norm(np.array(streamline[-10:]) - point, axis=1)
        return np.any(distances < threshold)
    
    def generate_tractogram(self, num_streamlines, output_path):
        """Genera múltiples streamlines y guarda el tractograma"""
        seeds = self.generate_seeds(num_streamlines)
        streamlines = []
        
        for seed in seeds:
            s = self.track(seed)
            if len(s) > 4:  # Descarta streamlines muy cortos
                streamlines.append(s)
        
        # Guardar en formato .tck
        sft = StatefulTractogram(streamlines, self.peaks_img, Space.RASMM)
        save_tractogram(sft, output_path, bbox_valid_check=False)




## Prueba con datos de entrenamiento

In [15]:
# Uso del código
generator = TractographyGenerator(
    model_path = '/home/riemann007/JupyterLab/Tesis/MLP/Train_1/best_model.pth',
    peaks_path = '/home/riemann007/JupyterLab/Tesis/Proyecto/Datos/Entrenamiento/ISMRM Challenge 2022/ismrm2015_withReversed__peaks.nii.gz',
    mask_path = '/home/riemann007/JupyterLab/Tesis/Proyecto/Datos/Entrenamiento/ISMRM Challenge 2022/ismrm2015_withReversed__local_seeding_mask.nii.gz',
    nbh_path = '/home/riemann007/JupyterLab/Tesis/Proyecto/Datos/Entrenamiento/Vecindades/data_100000_n3_k1.npz',
    step_size = 0.5
)

num_sl = 500000

generator.generate_tractogram(
    num_streamlines=num_sl,
    output_path=f'train_MLP1_{num_sl}.tck'
)

ValueError: Cannot take a larger sample than population when 'replace=False'

## Prueba con datos de validación

In [5]:
# Uso del código
generator = TractographyGenerator(
    model_path = '/home/riemann007/JupyterLab/Tesis/MLP/Train_1/best_model.pth',
    peaks_path = '/home/riemann007/JupyterLab/Tesis/MLP/Train_1/generated_peaks_csd.nii.gz',
    mask_path = '/home/riemann007/JupyterLab/Tesis/masks_validation/wm_mask_fa0.22.nii.gz',
    nbh_path = '/home/riemann007/JupyterLab/Tesis/Proyecto/Datos/Entrenamiento/Vecindades/data_100000_n3_k1.npz',
    step_size = 0.5
)

num_sl = 100

generator.generate_tractogram(
    num_streamlines=num_sl,
    output_path=f'validacion0.1_MLP1_{num_sl}.tck'
)

In [18]:
import time

num_sl = 40000

t1 = time.time()

generator.generate_tractogram(
    num_streamlines=num_sl,
    output_path=f'validacion0.1_MLP1_{num_sl}.tck'
)

t2 = time.time()



Tiempo de ejecución en hrs:  6.319953709178501


In [19]:
print("Tiempo de ejecución en hrs: ", (t2-t1)/3600)

Tiempo de ejecución en hrs:  0.6319953709178501
