In [28]:
#pip install connected-components-3d monai

In [29]:
"""This module contains functions for preprocessing MRI images and segmentations."""

import numpy as np
from skimage import exposure

def znorm_rescale(img):
    """Applies Z-score normalization and rescaling to a MRI image."""

    # Z-score norm
    movingNan=np.copy(img)
    movingNan[movingNan==0]=np.nan
    movingMean=np.nanmean(movingNan)
    movingSTD=np.nanstd(movingNan)
    moving=(img-movingMean)/movingSTD
    b=255/(1-(moving.max()/moving.min()))
    a=-b/moving.min()
    movingNorm=np.copy(moving)
    movingNorm=np.round((movingNorm*a)+b,2)

    # Rescaling
    p2, p98 = np.percentile(movingNorm, (1, 99)) # These parameters may not be optimal, further testing could be done
    moving_rescale = exposure.rescale_intensity(movingNorm, in_range=(p2, p98))

    return moving_rescale

# Crop ranges for center crop.
X_START, X_END, Y_START, Y_END, Z_START, Z_END = (56,184, 24,216, 14,142)

def center_crop(img):
    """Center crops a MRI image (or seg) to be (128, 192, 128)."""
    return img[X_START:X_END, Y_START:Y_END, Z_START:Z_END]

def undo_center_crop(input):
    """Undos center crop of a MRI image (or seg)."""
    out = np.zeros((240, 240, 155))
    out[X_START:X_END, Y_START:Y_END, Z_START:Z_END] = input 
    return out

postprocess.py

In [67]:
"""Enhanced postprocessing module with adaptive thresholds and morphology-aware processing"""

import numpy as np
import cc3d
from skimage.morphology import dilation, ball

def adaptive_rm_dust(pred_mat):
    """Enhanced dust removal with adaptive thresholds and morphology preservation"""
    # Initial conservative pass to preserve small structures
    pred_mat_clean = cc3d.dust(pred_mat, threshold=70, connectivity=26)
    
    # Secondary pass with label-specific processing
    for label in [1, 2, 3]:  # Process each label separately
        if label in pred_mat_clean:
            label_mask = (pred_mat_clean == label)
            
            # Label-specific thresholds
            thresholds = {1: 35, 2: 60, 3: 30}  # TC, WT, ET respectively
            structure = ball(1) if label == 3 else ball(2)  # Smaller for ET
            
            # Morphological closing to preserve structure
            label_mask = dilation(label_mask, structure)
            cleaned = cc3d.dust(label_mask, threshold=thresholds[label], connectivity=26)
            pred_mat_clean[np.logical_and(pred_mat_clean == label, ~cleaned)] = 0
    
    return pred_mat_clean

def get_tissue_wise_seg(pred_mat, tissue_type, dilation_size=0):
    """Enhanced with optional morphological dilation to bridge small gaps"""
    mask = np.zeros_like(pred_mat)
    
    if tissue_type == 'WT':
        mask = pred_mat > 0
    elif tissue_type == 'TC':
        mask = np.logical_or(pred_mat == 1, pred_mat == 3)
    elif tissue_type == 'ET':
        mask = pred_mat == 3
    
    if dilation_size > 0:
        mask = dilation(mask, ball(dilation_size))
    
    return mask.astype(np.uint16)

def enhanced_rm_tt_dust(pred_mat, tt):
    """Enhanced dust removal with adaptive morphology"""
    # Strategy parameters
    strategies = {
        'ET': {
            'threshold': 10,  # Very sensitive for ET
            'connectivity': 6,
            'dilation': 1,   # Small dilation to bridge gaps
            'min_volume': 10 # Minimum volume to preserve
        },
        'TC': {
            'threshold': 40,
            'connectivity': 18,
            'dilation': 0,
            'min_volume': 20
        },
        'WT': {
            'threshold': 60,
            'connectivity': 26,
            'dilation': 0,
            'min_volume': 30
        }
    }
    
    params = strategies[tt]
    pred_mat_tt = get_tissue_wise_seg(pred_mat, tt, params['dilation'])
    
    # Two-stage dust removal
    temp_clean = cc3d.dust(
        pred_mat_tt,
        threshold=params['min_volume'],
        connectivity=params['connectivity']
    )
    final_clean = cc3d.dust(
        temp_clean,
        threshold=params['threshold'],
        connectivity=params['connectivity']
    )
    
    rm_dust_mask = np.logical_and(pred_mat_tt==1, final_clean==0)
    pred_mat[rm_dust_mask] = 0
    return rm_dust_mask

def enhanced_fill_holes(pred_mat, tt, label, rm_dust_mask):
    """Enhanced hole filling with structure preservation"""
    hole_params = {
        'ET': {'threshold': 10, 'connectivity': 6, 'max_hole_size': 15},
        'TC': {'threshold': 20, 'connectivity': 18, 'max_hole_size': 30},
        'WT': {'threshold': 40, 'connectivity': 26, 'max_hole_size': 50}
    }
    
    params = hole_params[tt]
    pred_mat_tt = get_tissue_wise_seg(pred_mat, tt)
    
    # Detect holes more precisely
    tt_holes = 1 - pred_mat_tt
    tt_holes_rm = cc3d.dust(
        tt_holes,
        threshold=params['threshold'],
        connectivity=params['connectivity']
    )
    
    # Remove overly large holes
    large_holes = cc3d.dust(
        tt_holes_rm,
        threshold=params['max_hole_size'],
        connectivity=params['connectivity']
    )
    tt_holes_rm[large_holes > 0] = 0
    
    tt_filled = 1 - tt_holes_rm
    holes_mask = np.logical_and.reduce((
        tt_filled == 1,
        pred_mat == 0,
        rm_dust_mask,
        cc3d.dust(tt_filled, threshold=5, connectivity=6) > 0  # Ensure connected
    ))
    pred_mat[holes_mask] = label

def rm_dust_fh(pred_mat):
    """Optimized processing pipeline with ET-first strategy"""
    # Initial cleanup
    pred_mat = adaptive_rm_dust(pred_mat)
    
    # ET-specific processing
    rm_et_mask = enhanced_rm_tt_dust(pred_mat, 'ET')
    enhanced_fill_holes(pred_mat, 'TC', 1, rm_et_mask)
    
    # TC processing
    rm_tc_mask = enhanced_rm_tt_dust(pred_mat, 'TC')
    enhanced_fill_holes(pred_mat, 'WT', 2, rm_tc_mask)
    
    # Final WT processing
    _ = enhanced_rm_tt_dust(pred_mat, 'WT')
    
    # Final morphological smoothing
    for label in [1, 2, 3]:
        if label in pred_mat:
            label_mask = (pred_mat == label)
            pred_mat[label_mask & ~cc3d.dust(label_mask, threshold=10, connectivity=6)] = 0
    
    return pred_mat

brats_dataset.py

In [31]:
from torch.utils.data import Dataset
import os
import nibabel as nib
import numpy as np
import torch
import random
from scipy.ndimage import map_coordinates, gaussian_filter, zoom
#from ..processing.preprocess import znorm_rescale, center_crop

class BratsDataset(Dataset):
    """Dataset class for loading BraTS training and test data with advanced augmentations.
    
    Args:
        data_dir: Directory of training or test data.
        mode: Either 'train' or 'test' specifying which data is being loaded.
        augment: Whether to apply data augmentations (only for training mode).
    """
    
    def __init__(self, data_dir, mode, augment=False, max_subjects=None):
        """Initialize the dataset with augmentation parameters."""
        self.data_dir = data_dir
        self.subject_list = sorted(os.listdir(data_dir))  # sorted for reproducibility
        if max_subjects is not None:
            self.subject_list = self.subject_list[:max_subjects]
        self.mode = mode
        self.augment = augment and mode == 'train'  # Only augment training data
        
        # Enhanced augmentation parameters for stronger augmentation
        self.flip_prob = 0.7          # Increased probability of applying random flips
        self.rotation_range = 20       # Increased degrees (± range for random rotations)
        self.noise_std = 0.15          # Increased standard deviation of Gaussian noise
        self.gamma_range = (0.6, 1.4)  # Wider range for gamma correction
        self.elastic_alpha = (0., 1200.)  # Increased magnitude range for elastic deformation
        self.elastic_sigma = (8., 15.)   # Wider smoothness range for elastic deformation
        self.bias_field_scale = 0.4      # Increased strength of bias field artifact
        
        # New augmentation parameters for stronger augmentation
        self.scale_range = (0.85, 1.15)  # Random scaling factor
        self.contrast_range = (0.7, 1.3)  # Contrast adjustment range
        self.brightness_range = (-0.1, 0.1)  # Brightness adjustment range
        self.blur_prob = 0.2           # Probability of applying Gaussian blur
        self.blur_sigma = (0.5, 1.5)   # Blur sigma range
        self.cutout_prob = 0.15        # Probability of applying cutout
        self.cutout_size = (10, 30)    # Cutout size range
        
        # Modality-specific augmentation parameters
        self.modality_names = ['t1c', 't1n', 't2f', 't2w']
        self.modality_specific_prob = 0.3  # Probability of applying modality-specific augmentations
        
        # T1c-specific (contrast-enhanced): More aggressive contrast/brightness
        self.t1c_contrast_range = (0.6, 1.4)
        self.t1c_brightness_range = (-0.15, 0.15)
        
        # T1n-specific (native): More noise augmentation
        self.t1n_noise_multiplier = 1.5
        
        # T2f-specific (FLAIR): More bias field and gamma correction
        self.t2f_bias_multiplier = 1.3
        self.t2f_gamma_range = (0.5, 1.5)
        
        # T2w-specific: More blur and elastic deformation
        self.t2w_blur_multiplier = 1.5
        self.t2w_elastic_multiplier = 1.2

    def __len__(self):
        return len(self.subject_list)
    
    def load_nifti(self, subject_name, suffix):
        """Loads nifti file for given subject and suffix.
        
        Args:
            subject_name: Name of the subject directory.
            suffix: Modality suffix (e.g., 't1c', 'seg').
            
        Returns:
            Loaded nibabel nifti object.
        """
        nifti_filename = f'{subject_name}-{suffix}.nii'
        nifti_path = os.path.join(self.data_dir, subject_name, nifti_filename)
        return nib.load(nifti_path)
    
    def load_subject_data(self, subject_name):
        """Loads images and segmentation (if in train mode) for a subject.
        
        Args:
            subject_name: Name of the subject directory.
            
        Returns:
            For training: tuple of (modalities_data, seg_data)
            For testing: modalities_data
        """
        modalities_data = []
        for suffix in self.modality_names:  # All 4 standard BraTS modalities
            modality_data = self.load_nifti(subject_name, suffix).get_fdata()
            modalities_data.append(modality_data)

        if self.mode == 'train':
            seg_data = self.load_nifti(subject_name, 'seg').get_fdata()
            return modalities_data, seg_data
        return modalities_data
    
    def apply_augmentations(self, imgs, seg=None):
        """Apply random augmentations to images and segmentation with stronger augmentation.
        
        Args:
            imgs: List of modality images.
            seg: Optional segmentation mask.
            
        Returns:
            Augmented images and segmentation (if provided).
        """
        if not self.augment:
            return imgs, seg
            
        # Random flips (increased probability)
        if random.random() < self.flip_prob:
            axis = random.randint(0, 2)  # Random axis (0, 1, or 2)
            imgs = [np.flip(img, axis=axis) for img in imgs]
            if seg is not None:
                seg = np.flip(seg, axis=axis)
        
        # Random scaling (new augmentation)
        if random.random() < 0.4:
            scale_factor = random.uniform(*self.scale_range)
            imgs, seg = self.scale_images(imgs, seg, scale_factor)
                
        # Random rotation (increased probability)
        if random.random() < 0.4:
            angle = random.uniform(-self.rotation_range, self.rotation_range)
            imgs = [self.rotate_image(img, angle) for img in imgs]
            if seg is not None:
                seg = self.rotate_image(seg, angle, is_seg=True)
                
        # Elastic deformations (increased probability)
        if random.random() < 0.4:
            imgs, seg = self.elastic_deform(imgs, seg)
        
        # Cutout augmentation (new)
        if random.random() < self.cutout_prob:
            imgs = self.apply_cutout(imgs)
                
        # Intensity transformations (increased probabilities)
        for i in range(len(imgs)):
            modality = self.modality_names[i]
            
            # Standard intensity augmentations
            if random.random() < 0.4:  # Gaussian noise
                noise_std = self.noise_std
                if modality == 't1n':  # More noise for T1n
                    noise_std *= self.t1n_noise_multiplier
                noise = np.random.normal(0, noise_std, imgs[i].shape)
                imgs[i] = imgs[i] + noise
                
            if random.random() < 0.4:  # Gamma correction
                gamma_range = self.gamma_range
                if modality == 't2f':  # More aggressive gamma for T2f
                    gamma_range = self.t2f_gamma_range
                gamma = random.uniform(*gamma_range)
                imgs[i] = np.sign(imgs[i]) * (np.abs(imgs[i]) ** gamma)
                
            if random.random() < 0.4:  # Bias field artifact
                bias_scale = self.bias_field_scale
                if modality == 't2f':  # Stronger bias field for T2f
                    bias_scale *= self.t2f_bias_multiplier
                imgs[i] = self.add_bias_field(imgs[i], bias_scale)
            
            # New intensity augmentations
            if random.random() < 0.3:  # Contrast adjustment
                contrast_range = self.contrast_range
                if modality == 't1c':  # More aggressive contrast for T1c
                    contrast_range = self.t1c_contrast_range
                contrast = random.uniform(*contrast_range)
                imgs[i] = imgs[i] * contrast
                
            if random.random() < 0.3:  # Brightness adjustment
                brightness_range = self.brightness_range
                if modality == 't1c':  # More aggressive brightness for T1c
                    brightness_range = self.t1c_brightness_range
                brightness = random.uniform(*brightness_range)
                imgs[i] = imgs[i] + brightness
                
            if random.random() < self.blur_prob:  # Gaussian blur
                blur_sigma = random.uniform(*self.blur_sigma)
                if modality == 't2w':  # More blur for T2w
                    blur_sigma *= self.t2w_blur_multiplier
                imgs[i] = gaussian_filter(imgs[i], sigma=blur_sigma)
            
            # Apply modality-specific augmentations
            if random.random() < self.modality_specific_prob:
                imgs[i] = self.apply_modality_specific_augmentation(imgs[i], modality)
                
        return imgs, seg
    
    def apply_modality_specific_augmentation(self, img, modality):
        """Apply modality-specific augmentations.
        
        Args:
            img: Input image.
            modality: Modality name ('t1c', 't1n', 't2f', 't2w').
            
        Returns:
            Augmented image.
        """
        if modality == 't1c':
            # T1c: Simulate contrast agent variations
            if random.random() < 0.5:
                # Simulate uneven contrast enhancement
                enhancement_field = self.create_enhancement_field(img.shape)
                img = img * enhancement_field
                
        elif modality == 't1n':
            # T1n: Add more complex noise patterns
            if random.random() < 0.5:
                # Add Rician noise (common in MRI)
                img = self.add_rician_noise(img)
                
        elif modality == 't2f':
            # T2f: Simulate CSF flow artifacts
            if random.random() < 0.5:
                # Add flow artifacts
                img = self.add_flow_artifacts(img)
                
        elif modality == 't2w':
            # T2w: Add motion artifacts
            if random.random() < 0.5:
                # Simulate motion artifacts
                img = self.add_motion_artifacts(img)
                
        return img
    
    def create_enhancement_field(self, shape):
        """Create a random enhancement field for T1c modality."""
        # Create low-frequency field
        field = np.random.randn(*[s//8 + 1 for s in shape])
        field = zoom(field, 
                    [shape[0]/field.shape[0], 
                     shape[1]/field.shape[1],
                     shape[2]/field.shape[2]], 
                    order=1)
        return 1.0 + 0.3 * np.tanh(field)
    
    def add_rician_noise(self, img):
        """Add Rician noise to image."""
        noise_level = 0.05 * np.std(img)
        noise_real = np.random.normal(0, noise_level, img.shape)
        noise_imag = np.random.normal(0, noise_level, img.shape)
        return np.sqrt((img + noise_real)**2 + noise_imag**2)
    
    def add_flow_artifacts(self, img):
        """Add CSF flow artifacts to FLAIR images."""
        # Create periodic artifacts
        z_coords = np.arange(img.shape[2])
        artifact_pattern = 0.1 * np.sin(2 * np.pi * z_coords / 10)
        artifact_field = np.broadcast_to(artifact_pattern, img.shape)
        return img + img * artifact_field
    
    def add_motion_artifacts(self, img):
        """Add motion artifacts to images."""
        # Simulate motion by applying small random translations
        shift_x = random.uniform(-2, 2)
        shift_y = random.uniform(-2, 2)
        from scipy.ndimage import shift
        return shift(img, [shift_x, shift_y, 0], mode='constant')
    
    def scale_images(self, imgs, seg, scale_factor):
        """Scale images and segmentation by a factor."""
        scaled_imgs = []
        for img in imgs:
            scaled_img = zoom(img, scale_factor, order=3)
            # Crop or pad to original size
            scaled_img = self.resize_to_original(scaled_img, img.shape)
            scaled_imgs.append(scaled_img)
        
        if seg is not None:
            scaled_seg = zoom(seg, scale_factor, order=0)
            seg = self.resize_to_original(scaled_seg, seg.shape)
            
        return scaled_imgs, seg
    
    def resize_to_original(self, img, target_shape):
        """Resize image to target shape by cropping or padding."""
        current_shape = img.shape
        
        # Calculate padding/cropping for each dimension
        result = img.copy()
        for i in range(len(target_shape)):
            diff = target_shape[i] - current_shape[i]
            if diff > 0:  # Need to pad
                pad_before = diff // 2
                pad_after = diff - pad_before
                pad_width = [(0, 0)] * len(target_shape)
                pad_width[i] = (pad_before, pad_after)
                result = np.pad(result, pad_width, mode='constant')
            elif diff < 0:  # Need to crop
                crop_before = (-diff) // 2
                crop_after = current_shape[i] - (-diff) + crop_before
                slices = [slice(None)] * len(target_shape)
                slices[i] = slice(crop_before, crop_after)
                result = result[tuple(slices)]
        
        return result
    
    def apply_cutout(self, imgs):
        """Apply cutout augmentation to images."""
        cutout_imgs = []
        for img in imgs:
            img_copy = img.copy()
            
            # Random cutout parameters
            cutout_size = random.randint(*self.cutout_size)
            x = random.randint(0, max(1, img.shape[0] - cutout_size))
            y = random.randint(0, max(1, img.shape[1] - cutout_size))
            z = random.randint(0, max(1, img.shape[2] - cutout_size))
            
            # Apply cutout
            img_copy[x:x+cutout_size, y:y+cutout_size, z:z+cutout_size] = 0
            cutout_imgs.append(img_copy)
            
        return cutout_imgs
        
    def elastic_deform(self, imgs, seg):
        """Apply elastic deformation to images and segmentation with stronger deformation.
        
        Args:
            imgs: List of modality images.
            seg: Optional segmentation mask.
            
        Returns:
            Deformed images and segmentation (if provided).
        """
        shape = imgs[0].shape
        alpha = random.uniform(*self.elastic_alpha)  # Deformation magnitude
        sigma = random.uniform(*self.elastic_sigma)  # Deformation smoothness
        
        # Create random displacement fields
        dx = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma, mode='constant') * alpha
        dy = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma, mode='constant') * alpha
        dz = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma, mode='constant') * alpha

        # Create coordinate grid
        x, y, z = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2]), indexing='ij')
        coords = np.array([x + dx, y + dy, z + dz])
        
        # Apply deformation to each modality (cubic interpolation)
        deformed_imgs = [map_coordinates(img, coords, order=3, mode='reflect') for img in imgs]
        
        # Apply to segmentation (nearest neighbor interpolation)
        if seg is not None:
            seg = map_coordinates(seg, coords, order=0, mode='constant')
            
        return deformed_imgs, seg
        
    def add_bias_field(self, image, bias_scale=None):
        """Add MRI bias field artifact to an image.
        
        Args:
            image: Input image to modify.
            bias_scale: Optional bias field scale override.
            
        Returns:
            Image with simulated bias field.
        """
        if bias_scale is None:
            bias_scale = self.bias_field_scale
            
        shape = image.shape
        # Create low-frequency random field
        rand_field = np.random.randn(*[s//16 + 1 for s in shape])
        rand_field = zoom(rand_field, 
                         [shape[0]/rand_field.shape[0], 
                          shape[1]/rand_field.shape[1],
                          shape[2]/rand_field.shape[2]], 
                         order=1)
        
        # Create smooth multiplicative bias field
        bias_field = np.exp(bias_scale * rand_field)
        return image * bias_field
        
    def rotate_image(self, image, angle, is_seg=False):
        """Rotate image by specified angle around axial plane.
        
        Args:
            image: Input image to rotate.
            angle: Rotation angle in degrees.
            is_seg: Whether the image is a segmentation mask.
            
        Returns:
            Rotated image.
        """
        from scipy.ndimage import rotate
        axes = (0, 1)  # Rotate in axial plane
        order = 0 if is_seg else 3  # Nearest neighbor for seg, cubic for images
        return rotate(image, angle, axes=axes, reshape=False, order=order, mode='constant')
    
    def __getitem__(self, idx):
        """Load and process a single subject's data.
        
        Args:
            idx: Index of the subject to load.
            
        Returns:
            For training: (subject_name, modalities, segmentation)
            For testing: (subject_name, modalities)
        """
        subject_name = self.subject_list[idx]

        # Load the data
        if self.mode == 'train':
            imgs, seg = self.load_subject_data(subject_name)
        else:
            imgs = self.load_subject_data(subject_name)
            seg = None

        # Apply augmentations (only for training when augment=True)
        imgs, seg = self.apply_augmentations(imgs, seg)

        # Standard preprocessing
        imgs = [znorm_rescale(img) for img in imgs]  # Normalize each modality
        imgs = [center_crop(img) for img in imgs]    # Center crop
        
        # Convert to tensors
        imgs = [torch.from_numpy(img[None, ...].astype(np.float32)) for img in imgs]  # Add channel dim
        
        if self.mode == 'train':
            seg = center_crop(seg)
            seg = torch.from_numpy(seg[None, ...].astype(np.float32))  # Add channel dim
            return subject_name, imgs, seg
        
        return subject_name, imgs

general_utils.py

In [40]:

import torch
import numpy as np
import nibabel as nib
import os
#from ..processing.preprocess import undo_center_crop

def seg_to_one_hot_channels(seg):
    """Converts segmentation to 3 channels, each a one-hot encoding of a tumour region label.

    Args:
        seg: Tensor of shape B1HWD, where each entry is a voxel label.

    Returns:
        Tensor of shape B3HWD, where each channel is one-hot encoding of a disjoint region.
    """
    B, _, H, W, D = seg.shape
    seg3 = torch.zeros((B, 3, H, W, D), device=seg.device)
    
    # Squeeze the channel dimension for comparison
    seg = seg.squeeze(1)  # Now shape [B, H, W, D]
    
    for channel_value in [1, 2, 3]:
        # Create mask and add channel dimension back
        mask = (seg == channel_value).float().unsqueeze(1)  # Shape [B, 1, H, W, D]
        seg3[:, channel_value-1:channel_value, :, :, :] = mask
    
    return seg3

def disjoint_to_overlapping(seg_disjoint):
    """Converts tensor representing one-hot encoding of disjoint regions to that of overlapping ones.

    Args:
        seg_disjoint: Tensor of shape B3HWD, where each channel is one-hot encoding of a disjoint region. 

    Returns:
        Tensor of shape B3HWD, where each channel is one-hot encoding of an overlapping region.
    """
    seg_overlapping = torch.zeros_like(seg_disjoint)
    seg_overlapping[:,0] = seg_disjoint[:, 0] + seg_disjoint[:, 1] + seg_disjoint[:, 2] #WHOLE TUMOR
    seg_overlapping[:,1] = seg_disjoint[:, 0] + seg_disjoint[:, 2] #TUMOR CORE
    seg_overlapping[:,2] = seg_disjoint[:, 2] #ENHANCING TUMOR
    return seg_overlapping

def overlapping_probs_to_preds(output, t1=0.45, t2=0.4, t3=0.45):
    """Converts output of model trained on overlapping regions to one-hot encodings of disjoint regions.

    Args:
        output: Tensor of shape B3HWD. Output of model, representing probabilties each voxel belongs to each overlapping region.
        t1: Threshold for being in whole tumor (WT). Defaults to 0.45.
        t2: Threshold for being in tumor core (TC). Defaults to 0.4.
        t3: Threshold for being in enhancing tumor (ET). Defaults to 0.45.

    Returns:
        Tensor of shape B3HWD, where each channel is one-hot encoding of a disjoint region.
    """
    output = output.cpu().detach()
    c1, c2, c3 = output[:, 0] > t1, output[:, 1] > t2, output[:, 2] > t3
    preds = (c1 > 0).to(torch.uint8) # NCR
    preds[(c2 == False) * (c1 == True)] = 2 # ED
    preds[(c3 == True) * (c1 == True)] = 3 # ET
    output_plot = torch.zeros_like(output)
    output_plot[:, 0] = (preds == 1).to(torch.uint8) #NCR
    output_plot[:, 1] = (preds == 2).to(torch.uint8) #ED
    output_plot[:, 2] = (preds == 3).to(torch.uint8) #ET
    output_plot = output_plot.to(torch.uint8)
    return output_plot

def disjoint_probs_to_preds(output, t=0.5):
    """Converts output of model trained on disjoint regions to one-hot encodings of disjoint regions.

    Args:
        output: Tensor of shape B3HWD. Output of model, representing probabilties each voxel belongs to each disjoint region.
        t: Threshold value. If channel probability for a voxel is the maximum across all channels AND greater than this threshold, channel value will be encoded as 1, otherwise 0. Defaults to 0.5.

    Returns:
        Tensor of shape B3HWD, where each channel is one-hot encoding of a disjoint region.
    """
    output = output.cpu().detach()
    c1, c2, c3 = output[:, 0], output[:, 1], output[:, 2]
    max_label = torch.max(torch.max(c1, c2), c3)
    preds = torch.zeros_like(output)
    preds[:, 0] = torch.where(c1 < max_label, torch.tensor(0), max_label)
    preds[:, 1] = torch.where(c2 < max_label, torch.tensor(0), max_label)
    preds[:, 2] = torch.where(c3 < max_label, torch.tensor(0), max_label)
    output_plot = torch.zeros_like(output)
    for c in range(0, 3):
        output_plot[:, c] = torch.where(preds[:, c] > t, torch.tensor(1.), torch.tensor(0.))
    output_plot = output_plot.to(torch.uint8)
    return output_plot

def probs_to_preds(output, training_regions):
    """Converts tensor of voxel probabilities to tensor of disjoint region labels.

    Args:
        output: Tensor of shape B3HWD. Output of model, representing probabilties each voxel belongs to a region.
        training_regions: Whether probabilities relate to overlapping or disjoint regions.

    Returns:
        Tensor of shape B3HWD, where each channel is one-hot encoding of a disjoint region.
    """
    if training_regions == 'overlapping':
        preds = overlapping_probs_to_preds(output)
    elif training_regions == 'disjoint':
        preds = disjoint_probs_to_preds(output)

    return preds

def fetch_affine_header(subject_name, data_dir):
    """Finds affine and header of a modality nifti for given subject.

    Args:
        subject_name: Name of given subject. Will also be name of folder containing MRI niftis.
        data_dir: Parent directory of subject data folder.

    Returns:
        The affine and header objects from a modality nifti of the subject.
    """

    modality_nifti_filename = f'{subject_name}-t1c.nii'
    modality_nifti_path = os.path.join(data_dir, subject_name, modality_nifti_filename)
    nifti = nib.load(modality_nifti_path)
    
    return nifti.affine, nifti.header

def one_hot_channels_to_three_labels(pred):
    """Converts tensor of one-hot encodings of disjoint regions to be single channel. Each voxel is assigned a single disjoint region label.

    Args:
        pred: Array-like of shape 3HWD, where channels are one-hot encodings of disjoint regions.

    Returns:
        Array-like of shape HWD, associating to each voxel a single disjoint region label.
    """
    return pred[0] + pred[1]*2 + pred[2]*3

def save_pred_as_nifti(pred, save_dir, data_dir, subject_name, postprocess_function=None):
    """Saves predicted segmentation as nifti file with affine and header objects matching its MRI niftis.

    Args:
        pred: Tensor of shape HWD, associating to each voxel a single disjoint region label.
        save_dir: Directory in which to save the predicted segmentation nifti.
        data_dir: Parent directory of subject data folder.
        subject_name: Name of given subject. Will also be name of folder containing MRI niftis.
        postprocess_function: If provided, performs this postprocessing on the prediction. Defaults to None.
    """
    # Convert back from 3 one-hot encoded channels to 1 channel with 3 tumour region labels
    pred = np.array(pred)
    pred_for_nifti = one_hot_channels_to_three_labels(pred)
    pred_for_nifti = np.squeeze(pred_for_nifti)
    pred_for_nifti = undo_center_crop(pred_for_nifti)
    pred_for_nifti = pred_for_nifti.astype(np.uint8)

    if postprocess_function:
        pred_for_nifti = postprocess_function(pred_for_nifti)

    affine, header = fetch_affine_header(subject_name, data_dir)
    pred_nifti = nib.nifti1.Nifti1Image(pred_for_nifti, affine=affine, header=header)
    filename = f'{subject_name}-seg.nii'
    nib.nifti1.save(pred_nifti, os.path.join(save_dir, filename))

model_utils.py 

In [33]:
"""This module contains utility functions for training models."""

import os
import numpy as np
import torch
from torch.utils.data import DataLoader

#from ..datasets import brats_dataset
#from .general_utils import seg_to_one_hot_channels, disjoint_to_overlapping

def load_or_initialize_training(model, optimizer, latest_ckpt_path, train_with_val=False):
    """Loads training checkpoint if it exists, or initializes training from scratch.

    Args:
        model: The PyTorch model to be trained.
        optimizer: The optimizer used for training.
        latest_ckpt_path: The path to the latest model checkpoint.
        train_with_val: If True, also returns best saved validation loss and dice. Defaults to False.

    Returns:
        The starting epoch number.
        If 'train_with_val' is True, also returns best saved validation loss and dice.
    """

    if not os.path.exists(latest_ckpt_path):
        epoch_start = 1
        if train_with_val:
            best_vloss = float('inf')
            best_dice = 0
        print('No training checkpoint found. Will start training from scratch.')
    else:
        print('Training checkpoint found. Loading checkpoint...')
        checkpoint = torch.load(latest_ckpt_path,  weights_only=False)
        epoch_start = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model_sd'])
        optimizer.load_state_dict(checkpoint['optim_sd'])
        if train_with_val:
            best_vloss = checkpoint['vloss']
            best_dice = checkpoint['dice']
        print(f'Checkpoint loaded. Will continue training from epoch {epoch_start}.')

    if train_with_val:
        return epoch_start, best_vloss, best_dice
    return epoch_start

def make_dataloader(data_dir, shuffle, mode, augment, batch_size=1, max_subjects=None):
    dataset = BratsDataset(data_dir, mode=mode, augment=augment, max_subjects=max_subjects)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=True)
    return dataloader

def exp_decay_learning_rate(optimizer, epoch, init_lr, decay_rate):
    """Exponentially decays learning rate of optimizer at given epoch."""
    lr = init_lr * (decay_rate ** (epoch-1))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def compute_loss(output, seg, loss_functs, loss_weights):
    """Computes weighted loss between model output and ground truth, summed across each region."""
    loss = 0.
    for n, loss_function in enumerate(loss_functs):      
        temp = 0
        for i in range(3):
            temp += loss_function(output[:,i:i+1].cuda(), seg[:,i:i+1].cuda())

        loss += temp * loss_weights[n]
    return loss

def train_one_epoch(model, optimizer, train_loader, loss_functions, loss_weights, training_regions):
    """Performs one training loop of model according to given optimizer, loss functions and associated weights.

    Args:
        model: The PyTorch model to be trained.
        optimizer: The optimizer used for training.
        train_loader: The dataloader for training data.
        loss_functions: List of loss functions.
        loss_weights: List of associated weightings for each loss function.
        training_regions: String specifying whether 'disjoint' or 'overlapping' regions will be used for training.

    Returns:
        The average training loss over the epoch.
    """
    losses_over_epoch = []
    for _, imgs, seg in train_loader:
        model.train()

        # Move data to GPU.
        imgs = [img.cuda() for img in imgs] # img is B1HWD
        seg = seg.cuda()

        # Split segmentation into 3 channels.
        seg = seg_to_one_hot_channels(seg)
        # seg is B3HWD - each channel is one-hot encoding of a disjoint region

        if training_regions == 'overlapping':
            seg = disjoint_to_overlapping(seg)
            # seg is B3HWD - each channel is one-hot encoding of an overlapping region

        x_in = torch.cat(imgs, dim=1) # x_in is B4HWD
        outputs = model(x_in)
        
        # Handle both dictionary output (deep supervision) and single output cases
        if isinstance(outputs, dict):
            # Compute loss for each output and combine them
            loss = 0
            for output in outputs.values():
                output = output.float()
                loss += compute_loss(output, seg, loss_functions, loss_weights)
            # Normalize by number of outputs
            loss = loss / len(outputs)
        else:
            # Original single output case
            outputs = outputs.float()
            loss = compute_loss(outputs, seg, loss_functions, loss_weights)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses_over_epoch.append(loss.detach().cpu())

    # Compute loss from the epoch.
    average_epoch_loss = np.mean(losses_over_epoch)
    return average_epoch_loss



def freeze_layers(model, frozen_layers):
    """Freezes specified model layers. Afterwards parameters in these layers will not be updated when training.

    Args:
        model: The model to be trained.
        frozen_layers: List of strings specifying model layers.
    """

    for name, param in model.named_parameters():
        needs_freezing = False
        for layer in frozen_layers:
            if layer in name:
                needs_freezing = True
                break
        if needs_freezing:
            print(f'Freezing parameter {name}.')
            param.requires_grad = False

def check_frozen(model, frozen_layers):
    """Iterates through model layers and checks whether specified layers are frozen.

    Args:
        model: The model to be trained.
        frozen_layers: List of strings specifying model layers.
    """
    for name, param in model.named_parameters():
        needs_freezing = False
        for layer in frozen_layers:
            if layer in name:
                needs_freezing = True
                break
        if needs_freezing:
            if param.requires_grad:
                print(f'Warning! Param {name} should not require grad but does.')
                break
            else:
                print(f'Parameter {name} is frozen.')

# Example parts of unet_3d model to freeze
# 'encoder': ['Conv1', 'Conv2', 'Conv3', 'Conv4', 'Conv5', 'Conv6', 'Conv7'],
# 'decoder': ['Up6', 'Up_conv6', 'Up5', 'Up_conv5', 'Up4', 'Up_conv4', 'Up3', 'Up_conv3', 'Conv_1x13', 'Up2', 'Up_conv2', 'Conv_1x12', 'Up1', 'Up_conv1', 'Conv_1x11'],
# 'middle' : ['Conv5', 'Conv6', 'Conv7', 'Up6', 'Up_conv6', 'Up5', 'Up_conv5', 'Up4', 'Up_conv4'],
# 'none' : [],
# 'deep_decoder': ['Up6', 'Up_conv6', 'Up5', 'Up_conv5', 'Up4', 'Up_conv4']

enhanced3dunet.py 

In [34]:
"""Contains the architecture for the baseline 3D U-Net model, based on NVIDIA's optimized U-Net."""

import torch
import torch.nn as nn

# instance norm    
class conv_block(nn.Module):
    def __init__(self,ch_in,ch_out1,ch_out2,k1,k2,s1,s2):
        super(conv_block,self).__init__()
        self.conv = nn.Sequential(
            nn.InstanceNorm3d(ch_in),
            nn.Conv3d(ch_in, ch_out1, kernel_size=k1,stride=s1,padding=1,bias=True),
            nn.LeakyReLU(inplace=True),
            nn.InstanceNorm3d(ch_out1),
            nn.Conv3d(ch_out1, ch_out2, kernel_size=k2,stride=s2,padding=1,bias=True),
            nn.LeakyReLU(inplace=True)
        )

    def forward(self,x):
        x = self.conv(x)
        return x
      
class up_conv(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(up_conv,self).__init__()
        self.up = nn.Sequential(
            #nn.ConvTranspose3d(ch_in, ch_out,kernel_size=2,stride=2,padding=1,bias=True)
            nn.ConvTranspose3d(ch_in, ch_out,kernel_size=2,stride=2,padding=1,bias=True,output_padding=1,dilation=2),
            #nn.Upsample(scale_factor=2),
            nn.InstanceNorm3d(ch_in),
 	        nn.LeakyReLU(inplace=True)
        )

    def forward(self,x):
        #print(x.device)
        x = self.up(x)
        
        return x

# 3d unet with optimized instance norm    
class U_Net3d(nn.Module):
    def __init__(self,img_ch=4,output_ch=3):
        super(U_Net3d,self).__init__()
        nf= 8
        self.Maxpool = nn.MaxPool3d(kernel_size=2,stride=2).to(device='cuda:0')
        self.Conv1 = conv_block(ch_in=img_ch,ch_out1=nf*2,ch_out2=nf*2,k1=3,k2=3,s1=1,s2=1).to(device='cuda:0')
        self.Conv2 = conv_block(ch_in=nf*2,ch_out1=nf*3,ch_out2=nf*3,k1=3,k2=3,s1=2,s2=1).to(device='cuda:0')
        self.Conv3 = conv_block(ch_in=nf*3,ch_out1=nf*4,ch_out2=nf*4,k1=3,k2=3,s1=2,s2=1).to(device='cuda:0')
        self.Conv4 = conv_block(ch_in=nf*4,ch_out1=nf*6,ch_out2=nf*6,k1=3,k2=3,s1=2,s2=1).to(device='cuda:0')
        self.Conv5 = conv_block(ch_in=nf*6,ch_out1=nf*8,ch_out2=nf*8,k1=3,k2=3,s1=2,s2=1).to(device='cuda:0')
        self.Conv6 = conv_block(ch_in=nf*8,ch_out1=nf*12,ch_out2=nf*12,k1=3,k2=3,s1=2,s2=1).to(device='cuda:0')
        self.Conv7 = conv_block(ch_in=nf*12,ch_out1=nf*16,ch_out2=nf*16,k1=3,k2=3,s1=2,s2=1).to(device='cuda:0')

        self.Up6 = up_conv(ch_in=nf*16,ch_out=nf*12).to(device='cuda:0')
        self.Up_conv6 = conv_block(ch_in=nf*24, ch_out1=nf*12, ch_out2=nf*12,k1=3,k2=3,s1=1,s2=1).to(device='cuda:0')
        
        self.Up5 = up_conv(ch_in=nf*12,ch_out=nf*8).to(device='cuda:0')
        self.Up_conv5 = conv_block(ch_in=nf*16, ch_out1=nf*8, ch_out2=nf*8,k1=3,k2=3,s1=1,s2=1).to(device='cuda:0')

        self.Up4 = up_conv(ch_in=nf*8,ch_out=nf*6).to(device='cuda:0')
        self.Up_conv4 = conv_block(ch_in=nf*12, ch_out1=nf*6, ch_out2=nf*6,k1=3,k2=3,s1=1,s2=1).to(device='cuda:0')
        
        self.Up3 = up_conv(ch_in=nf*6,ch_out=nf*4).to(device='cuda:0')
        self.Up_conv3 = conv_block(ch_in=nf*8, ch_out1=nf*4,ch_out2=nf*4,k1=3,k2=3,s1=1,s2=1).to(device='cuda:0')
        self.Conv_1x13 = nn.Conv3d(nf*4,output_ch,kernel_size=1,stride=1,padding=0).to(device='cuda:0')
        
        self.Up2 = up_conv(ch_in=output_ch,ch_out=nf*3).to(device='cuda:0')
        self.Up_conv2 = conv_block(ch_in=nf*6 , ch_out1=nf*3,ch_out2=nf*3,k1=3,k2=3,s1=1,s2=1).to(device='cuda:0')
        self.Conv_1x12 = nn.Conv3d(nf*3,output_ch,kernel_size=1,stride=1,padding=0).to(device='cuda:0')
        
        self.Up1 = up_conv(ch_in=output_ch,ch_out=nf*2).to(device='cuda:0')
        self.Up_conv1 = conv_block(ch_in=nf*4, ch_out1=nf*2,ch_out2=nf*2,k1=3,k2=3,s1=1,s2=1).to(device='cuda:0')
        self.Conv_1x11 = nn.Conv3d(nf*2,output_ch,kernel_size=1,stride=1,padding=0).to(device='cuda:0')
        
        self.Sig =nn.Sigmoid().to(device='cuda:0')


    def forward(self,x):
        # encoding path
        x = x.to(device='cuda:0')
        x1 = self.Conv1(x)
 
        x2 = self.Conv2(x1)       
        
        x3 = self.Conv3(x2)
       
        x4 = self.Conv4(x3)       
       
        x5 = self.Conv5(x4)       
        
        x6 = self.Conv6(x5)

        x7 = self.Conv7(x6)
        

        # decoding + concat path
       
        d6 = self.Up6(x7.to(device='cuda:0'))
        d6 = torch.cat((x6.to(device='cuda:0'),d6),dim=1)
        d6 = self.Up_conv6(d6)

        d5 = self.Up5(d6)
        d5 = torch.cat((x5.to(device='cuda:0'),d5),dim=1)
        d5 = self.Up_conv5(d5)
        
        d4 = self.Up4(d5)
        d4 = torch.cat((x4.to(device='cuda:0'),d4),dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x3.to(device='cuda:0'),d3),dim=1)
        d3 = self.Up_conv3(d3)
        d3 = self.Conv_1x13(d3)
        d3 = self.Sig(d3)
        
        d2 = self.Up2(d3)
        d2 = torch.cat((x2.to(device='cuda:0'),d2),dim=1)
        d2 = self.Up_conv2(d2)
        d2 = self.Conv_1x12(d2)
        d2 = self.Sig(d2)
        
        d1 = self.Up1(d2)
        d1 = torch.cat((x1.to(device='cuda:0'),d1),dim=1)
        d1 = self.Up_conv1(d1)
        d1 = self.Conv_1x11(d1)
        d1 = self.Sig(d1)

        return d1.to(device='cuda:0')
    
    def __str__(self):
        num_params = sum(p.numel() for p in self.parameters())
        return f"UNet_3d (with {num_params:,} parameters)"

train.py

In [None]:
import os
import numpy as np
import torch 
from torch import optim
import csv
import torch.nn as nn

from monai.losses import DiceLoss, FocalLoss, DiceCELoss

#from ..utils.model_utils import load_or_initialize_training, make_dataloader, exp_decay_learning_rate, train_one_epoch
    
def train(data_dir, model, loss_functions, loss_weights, init_lr, max_epoch, training_regions='overlapping', out_dir=None, decay_rate=0.995, backup_interval=10, batch_size=2, scheduler_patience=5):
    """Runs basic training routine.

    Args:
        data_dir: Directory of training data.
        model: The PyTorch model to be trained.
        loss_functions: List of loss functions to be used for training.
        loss_weights: List of weights corresponding to each loss function.
        init_lr: Initial value of learning rate.
        max_epoch: Maximum number of epochs to train for.
        training_regions: Whether training on 'disjoint' or 'overlapping' regions. Defaults to 'overlapping'.
        out_dir: The directory to save model checkpoints and loss values. Defaults to None.
        decay_rate: Rate at which to decay the learning rate. Defaults to 0.995.
        backup_interval: How often to save a backup checkpoint. Defaults to 10.
        batch_size: Batch size of dataloader. Defaults to 1.
        scheduler_patience: Patience for ReduceLROnPlateau scheduler. Defaults to 5.
    """
    # Set up directories and paths.
    if out_dir is None:
        out_dir = os.getcwd()
    latest_ckpt_path = os.path.join(out_dir, 'latest_ckpt.pth.tar')
    training_loss_path = os.path.join(out_dir, 'training_loss.csv')
    backup_ckpts_dir = os.path.join(out_dir, 'backup_ckpts')
    if not os.path.exists(backup_ckpts_dir):
        os.makedirs(backup_ckpts_dir)
        os.system(f'chmod a+rwx {backup_ckpts_dir}')

    print("---------------------------------------------------")
    print(f"TRAINING SUMMARY")
    print(f"Data directory: {data_dir}")
    print(f"Model: {model}")
    print(f"Loss functions: {loss_functions}") 
    print(f"Loss weights: {loss_weights}")
    print(f"Initial learning rate: {init_lr}")
    print(f"Max epochs: {max_epoch}")
    print(f"Training regions: {training_regions}")
    print(f"Out directory: {out_dir}")
    print(f"Decay rate: {decay_rate}")
    print(f"Backup interval: {backup_interval}")
    print(f"Batch size: {batch_size}")
    print(f"Scheduler patience: {scheduler_patience}")
    print("---------------------------------------------------")

    # Enhanced optimizer: AdamW with proper weight decay
    optimizer = optim.AdamW(model.parameters(), lr=init_lr, weight_decay=1e-5, amsgrad=True)
    
    # Learning rate scheduler: ReduceLROnPlateau for adaptive learning rate adjustment
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, 
                                                     patience=scheduler_patience, verbose=True)

    # Check if training for first time or continuing from a saved checkpoint.
    epoch_start = load_or_initialize_training(model, optimizer, latest_ckpt_path)

    train_loader = make_dataloader(data_dir, shuffle=True, mode='train', augment=True, 
                                   batch_size=batch_size, max_subjects=60)

    print('Training starts.')
    for epoch in range(epoch_start, max_epoch+1):
        print(f'Starting epoch {epoch}...')

        # Note: We now use ReduceLROnPlateau instead of exponential decay
        # The scheduler will adjust LR based on loss plateau
        
        average_epoch_loss = train_one_epoch(model, optimizer, train_loader, loss_functions, loss_weights, training_regions)

        # Step the scheduler with the current loss
        scheduler.step(average_epoch_loss)

        # Save and report loss from the epoch.
        save_tloss_csv(training_loss_path, epoch, average_epoch_loss)
        print(f'Epoch {epoch} completed. Average loss = {average_epoch_loss:.4f}.')
        print(f'Current learning rate: {optimizer.param_groups[0]["lr"]:.6f}')

        print('Saving model checkpoint...')
        checkpoint = {
            'epoch': epoch,
            'model_sd': model.state_dict(),
            'optim_sd': optimizer.state_dict(),
            'scheduler_sd': scheduler.state_dict(),
            'model': model,
            'loss_functions': loss_functions,
            'loss_weights': loss_weights,
            'init_lr': init_lr,
            'training_regions': training_regions,
            'decay_rate': decay_rate,
            'scheduler_patience': scheduler_patience
        }
        torch.save(checkpoint, latest_ckpt_path)
        if epoch % backup_interval == 0:
            torch.save(checkpoint, os.path.join(backup_ckpts_dir, f'epoch{epoch}.pth.tar'))
        print('Checkpoint saved successfully.')

    
def save_tloss_csv(pathname, epoch, tloss):
    with open(pathname, mode='a', newline='') as csvfile:
        writer = csv.writer(csvfile)
        if epoch == 1:
            writer.writerow(['Epoch', 'Training Loss'])
        writer.writerow([epoch, tloss])

if __name__ == '__main__':

    #from ..models import unet3d, enhanced3dunet
    import torch.nn as nn

    data_dir = 'D:/brats2023_updated/ASNR-MICCAI-BraTS2023-SSA-Challenge-TrainingData_V2'
    model = U_Net3d()
    #model = Optimized3DUNet()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
    # Enhanced loss functions: DiceCE (combines Dice and CrossEntropy) + FocalLoss
    loss_functions = [
        DiceCELoss(include_background=False, softmax=True, lambda_dice=0.5, lambda_ce=0.5),
        FocalLoss(gamma=2.0, weight=torch.tensor([1.0, 2.0, 3.0], device=device))  # Higher weight for tumor classes
    ]
    loss_weights = [0.7, 0.3]  # Favor DiceCE over Focal

    lr = 3e-4
    max_epoch = 20
    #out_dir = '/home/mailab/Documents/brats2023_updated/Result'
    out_dir = 'D:/brats2023_updated/Result'

    train(data_dir, model, loss_functions, loss_weights, lr, max_epoch, out_dir=out_dir)

---------------------------------------------------
TRAINING SUMMARY
Data directory: /kaggle/input/brats-africa-dataset/BraTS-Africa Dataset/BraTS-Africa/95_Glioma
Model: UNet_3d (with 3,171,417 parameters)
Loss functions: [DiceCELoss(
  (dice): DiceLoss()
  (cross_entropy): CrossEntropyLoss()
  (binary_cross_entropy): BCEWithLogitsLoss()
), FocalLoss()]
Loss weights: [0.7, 0.3]
Initial learning rate: 0.0003
Max epochs: 20
Training regions: overlapping
Out directory: /kaggle/working/Result
Decay rate: 0.995
Backup interval: 10
Batch size: 2
Scheduler patience: 5
---------------------------------------------------
Training checkpoint found. Loading checkpoint...
Checkpoint loaded. Will continue training from epoch 4.
Training starts.
Starting epoch 4...
Epoch 4 completed. Average loss = 1.9303.
Current learning rate: 0.000100
Saving model checkpoint...
Checkpoint saved successfully.
Starting epoch 5...
Epoch 5 completed. Average loss = 1.8709.
Current learning rate: 0.000100
Saving mode

KeyboardInterrupt: 

infer.py

In [68]:
import os
import torch
#from ..utils.model_utils import make_dataloader
#from ..utils.general_utils import probs_to_preds, save_pred_as_nifti


def infer(data_dir, ckpt_path, out_dir=None, batch_size=1, postprocess_function=None):
    """Uses trained model to make predictions on test data.
    Args:
        data_dir: Directory of test data.
        ckpt_path: Path of trained model.
        out_dir: Directory in which to save predictions. Defaults to None.
        batch_size: Batch size of dataloader. Defaults to 1.
        postprocess_function: The postprocessing function to use. Defaults to None.
    """
    # Set up directories and paths.
    if out_dir is None:
        out_dir = os.getcwd()
    preds_dir = os.path.join(out_dir, 'preds')
    if not os.path.exists(preds_dir):
        os.makedirs(preds_dir)
        os.system(f'chmod a+rwx {preds_dir}')
    
    print(f"Loading model from {ckpt_path}...")
    checkpoint = torch.load(ckpt_path, weights_only=False)
    model = checkpoint['model']
    loss_functions = checkpoint['loss_functions']
    loss_weights = checkpoint['loss_weights']
    training_regions = checkpoint['training_regions']
    epoch = checkpoint['epoch']
    model_sd = checkpoint['model_sd']
    
    # Handle backward compatibility with older checkpoints
    scheduler_patience = checkpoint.get('scheduler_patience', 'Not available (older checkpoint)')
    
    model.load_state_dict(model_sd)
    print(f"Model loaded.")
    
    print("---------------------------------------------------")
    print(f"TRAINING SUMMARY")
    print(f"Model: {model}")
    print(f"Loss functions: {loss_functions}") 
    print(f"Loss weights: {loss_weights}")
    print(f"Training regions: {training_regions}")
    print(f"Epochs trained: {epoch}")
    print(f"Scheduler patience: {scheduler_patience}")
    print("---------------------------------------------------")
    
    print("INFERENCE SUMMARY")
    print(f"Data directory: {data_dir}")
    print(f"Trained model checkpoint path: {ckpt_path}")
    print(f"Out directory: {out_dir}")
    print(f"Batch size: {batch_size}")
    print(f"Postprocess function: {postprocess_function}")
    print("---------------------------------------------------")
    
    test_loader = make_dataloader(data_dir, shuffle=False, mode='test', augment=True, batch_size=batch_size)
    
    print('Inference starts.')
    with torch.no_grad():
        for subject_names, imgs in test_loader:
            model.eval()
            # Move data to GPU.
            imgs = [img.cuda() for img in imgs] # img is B1HWD
            x_in = torch.cat(imgs, dim=1) # x_in is B4HWD
            
            output = model(x_in)
            if isinstance(output, dict):
                output = output['final']
            output = output.float()
            
            preds = probs_to_preds(output, training_regions)
            # preds is B3HWD - each channel is one-hot encoding of a disjoint region
            
            # Iterate over batch and save each prediction.
            for i, subject_name in enumerate(subject_names):
                save_pred_as_nifti(preds[i], preds_dir, data_dir, subject_name, postprocess_function)
    
    print(f'Inference completed. Predictions saved in {preds_dir}.')

if __name__ == '__main__':
    #from ..processing.postprocess import rm_dust_fh
    
    data_dir = '/kaggle/input/validation/BraTS2024-SSA-Challenge-TestData'
    #ckpt_path = '/home/mailab/Documents/brats2023_updated/Result/latest_ckpt.pth.tar'
    ckpt_path = '/kaggle/working/Result/backup_ckpts/epoch10.pth.tar'
    out_dir = '/kaggle/working/prediction'
    postprocess_function = rm_dust_fh
    
    infer(data_dir, ckpt_path, out_dir=out_dir, postprocess_function=postprocess_function)

Loading model from /kaggle/working/Result/backup_ckpts/epoch10.pth.tar...
Model loaded.
---------------------------------------------------
TRAINING SUMMARY
Model: UNet_3d (with 3,171,417 parameters)
Loss functions: [DiceCELoss(
  (dice): DiceLoss()
  (cross_entropy): CrossEntropyLoss()
  (binary_cross_entropy): BCEWithLogitsLoss()
), FocalLoss()]
Loss weights: [0.7, 0.3]
Training regions: overlapping
Epochs trained: 10
Scheduler patience: 5
---------------------------------------------------
INFERENCE SUMMARY
Data directory: /kaggle/input/validation/BraTS2024-SSA-Challenge-TestData
Trained model checkpoint path: /kaggle/working/Result/backup_ckpts/epoch10.pth.tar
Out directory: /kaggle/working/prediction
Batch size: 1
Postprocess function: <function rm_dust_fh at 0x7fd9540fa5c0>
---------------------------------------------------
Inference starts.
Inference completed. Predictions saved in /kaggle/working/prediction/preds.


tester.ipynb

In [69]:
import os
import numpy as np
import pandas as pd
import nibabel as nib
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# MONAI imports
from monai.metrics import (
    DiceMetric, 
    HausdorffDistanceMetric, 
    SurfaceDistanceMetric,
    ConfusionMatrixMetric
)
from monai.transforms import (
    Compose, 
    LoadImage, 
    EnsureChannelFirst, 
    ToTensor,
    AsDiscrete
)
import torch

class BraTSEvaluator:
    def __init__(self, gt_folder, pred_folder):
        """
        Initialize BraTS evaluator
        
        Args:
            gt_folder (str): Path to ground truth folder
            pred_folder (str): Path to predictions folder
        """
        self.gt_folder = Path(gt_folder)
        self.pred_folder = Path(pred_folder)
        
        # Initialize MONAI metrics
        self.dice_metric = DiceMetric(include_background=False, reduction="mean_batch")
        self.hausdorff_metric = HausdorffDistanceMetric(include_background=False, percentile=95)
        self.surface_distance_metric = SurfaceDistanceMetric(include_background=False)
        self.confusion_matrix_metric = ConfusionMatrixMetric(include_background=False)
        
        # Transform for loading images
        self.transform = Compose([
            LoadImage(image_only=True),
            EnsureChannelFirst(),
            ToTensor()
        ])
        
        # BraTS label mappings
        self.label_mapping = {
            'ET': 3,  # Enhancing Tumor
            'TC': [1, 3],  # Tumor Core (ED + ET)
            'WT': [1, 2, 3]  # Whole Tumor (ED + NET + ET)
        }
        
        self.results = []
    
    def get_binary_mask(self, mask, labels):
        mask = mask.int()
        if isinstance(labels, int):
            return (mask == labels).float()
        else:
            binary_mask = torch.zeros_like(mask, dtype=torch.bool)
            for label in labels:
                binary_mask = binary_mask | (mask == label)
            return binary_mask.float()

    def compute_lesion_wise_metrics(self, gt_mask, pred_mask, region):
        """Compute lesion-wise metrics"""
        # Get connected components for ground truth
        gt_np = gt_mask.cpu().numpy().astype(np.uint8)
        pred_np = pred_mask.cpu().numpy().astype(np.uint8)
        
        # Simple connected components (you might want to use scipy.ndimage.label)
        from scipy.ndimage import label as scipy_label
        gt_labeled, num_gt_lesions = scipy_label(gt_np)
        
        if num_gt_lesions == 0:
            return {
                f'LesionWise_Dice_{region}': 0.0,
                f'LesionWise_NSD_0.5_{region}': 0.0,
                f'LesionWise_NSD_1.0_{region}': 0.0,
                f'LesionWise_Hausdorff95_{region}': np.inf
            }
        
        lesion_dices = []
        lesion_nsd_05 = []
        lesion_nsd_10 = []
        lesion_hausdorff = []
        
        for lesion_id in range(1, num_gt_lesions + 1):
            lesion_gt = (gt_labeled == lesion_id).astype(np.float32)
            lesion_pred = pred_np * lesion_gt  # Prediction in lesion region
            
            # Convert back to torch tensors
            lesion_gt_tensor = torch.from_numpy(lesion_gt).unsqueeze(0).unsqueeze(0)
            lesion_pred_tensor = torch.from_numpy(lesion_pred).unsqueeze(0).unsqueeze(0)
            
            # Compute metrics for this lesion
            if lesion_gt_tensor.sum() > 0:
                # Dice
                dice_val = self.dice_metric(lesion_pred_tensor, lesion_gt_tensor)
                lesion_dices.append(dice_val.item())
                
                # Surface distance metrics
                try:
                    surface_distances = self.surface_distance_metric(lesion_pred_tensor, lesion_gt_tensor)
                    nsd_05 = (surface_distances <= 0.5).float().mean()
                    nsd_10 = (surface_distances <= 1.0).float().mean()
                    lesion_nsd_05.append(nsd_05.item())
                    lesion_nsd_10.append(nsd_10.item())
                except:
                    lesion_nsd_05.append(0.0)
                    lesion_nsd_10.append(0.0)
                
                # Hausdorff distance
                try:
                    hausdorff_val = self.hausdorff_metric(lesion_pred_tensor, lesion_gt_tensor)
                    lesion_hausdorff.append(hausdorff_val.item())
                except:
                    lesion_hausdorff.append(np.inf)
        
        return {
            f'LesionWise_Dice_{region}': np.mean(lesion_dices) if lesion_dices else 0.0,
            f'LesionWise_NSD_0.5_{region}': np.mean(lesion_nsd_05) if lesion_nsd_05 else 0.0,
            f'LesionWise_NSD_1.0_{region}': np.mean(lesion_nsd_10) if lesion_nsd_10 else 0.0,
            f'LesionWise_Hausdorff95_{region}': np.mean(lesion_hausdorff) if lesion_hausdorff else np.inf
        }
    
    def compute_sensitivity_specificity(self, gt_mask, pred_mask):
        """Compute sensitivity and specificity"""
        # Flatten masks
        gt_flat = gt_mask.flatten()
        pred_flat = pred_mask.flatten()
        
        # Convert to binary (any tumor vs background)
        gt_binary = (gt_flat > 0).float()
        pred_binary = (pred_flat > 0).float()
        
        # Compute confusion matrix components
        tp = ((gt_binary == 1) & (pred_binary == 1)).sum().item()
        tn = ((gt_binary == 0) & (pred_binary == 0)).sum().item()
        fp = ((gt_binary == 0) & (pred_binary == 1)).sum().item()
        fn = ((gt_binary == 1) & (pred_binary == 0)).sum().item()
        
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
        
        return sensitivity, specificity
    
    def evaluate_case(self, gt_file, pred_file):
        """Evaluate a single case"""
        print(f"Evaluating: {gt_file.name}")
        
        # Load images
        gt_img = self.transform(gt_file)
        pred_img = self.transform(pred_file)
        
        # Ensure same shape and add batch dimension
        if gt_img.shape != pred_img.shape:
            print(f"Warning: Shape mismatch for {gt_file.name}")
            return None
        
        gt_img = gt_img.unsqueeze(0)  # Add batch dimension
        pred_img = pred_img.unsqueeze(0)
        
        case_results = {'case_id': gt_file.stem}
        
        # Evaluate each region (ET, TC, WT)
        for region, labels in self.label_mapping.items():
            # Get binary masks for current region
            gt_region = self.get_binary_mask(gt_img, labels)
            pred_region = self.get_binary_mask(pred_img, labels)
            
            # Standard metrics
            try:
                # Dice
                dice_val = self.dice_metric(pred_region, gt_region)
                case_results[f'Dice_{region}'] = dice_val.item()
                
                # Hausdorff Distance
                hausdorff_val = self.hausdorff_metric(pred_region, gt_region)
                case_results[f'Hausdorff95_{region}'] = hausdorff_val.item()
                
                # Surface Distance (NSD)
                surface_distances = self.surface_distance_metric(pred_region, gt_region)
                nsd_05 = (surface_distances <= 0.5).float().mean()
                nsd_10 = (surface_distances <= 1.0).float().mean()
                case_results[f'NSD_0.5_{region}'] = nsd_05.item()
                case_results[f'NSD_1.0_{region}'] = nsd_10.item()
                
            except Exception as e:
                print(f"Error computing standard metrics for {region}: {e}")
                case_results[f'Dice_{region}'] = 0.0
                case_results[f'Hausdorff95_{region}'] = np.inf
                case_results[f'NSD_0.5_{region}'] = 0.0
                case_results[f'NSD_1.0_{region}'] = 0.0
            
            # Lesion-wise metrics
            try:
                lesion_metrics = self.compute_lesion_wise_metrics(gt_region.squeeze(), pred_region.squeeze(), region)
                case_results.update(lesion_metrics)
            except Exception as e:
                print(f"Error computing lesion-wise metrics for {region}: {e}")
                case_results[f'LesionWise_Dice_{region}'] = 0.0
                case_results[f'LesionWise_NSD_0.5_{region}'] = 0.0
                case_results[f'LesionWise_NSD_1.0_{region}'] = 0.0
                case_results[f'LesionWise_Hausdorff95_{region}'] = np.inf
        
        # Sensitivity and Specificity (overall)
        try:
            sensitivity, specificity = self.compute_sensitivity_specificity(gt_img.squeeze(), pred_img.squeeze())
            case_results['sensitivity'] = sensitivity
            case_results['specificity'] = specificity
        except Exception as e:
            print(f"Error computing sensitivity/specificity: {e}")
            case_results['sensitivity'] = 0.0
            case_results['specificity'] = 0.0
        
        return case_results
    
    def run_evaluation(self):
        """Run evaluation on all cases"""
        # Get all ground truth files
        gt_files = list(self.gt_folder.glob('*.nii.gz'))
        if not gt_files:
            gt_files = list(self.gt_folder.glob('*.nii'))
        
        print(f"Found {len(gt_files)} ground truth files")
        
        for gt_file in gt_files:
            # Find corresponding prediction file
            pred_file = self.pred_folder / gt_file.name
            if not pred_file.exists():
                # Try different naming conventions
                possible_names = [
                    gt_file.name,
                    gt_file.name.replace('_seg', ''),
                    gt_file.name.replace('_gt', ''),
                    gt_file.stem + '_pred.nii.gz'
                ]
                
                for name in possible_names:
                    pred_file = self.pred_folder / name
                    if pred_file.exists():
                        break
                
                if not pred_file.exists():
                    print(f"Prediction file not found for {gt_file.name}")
                    continue
            
            # Evaluate case
            case_result = self.evaluate_case(gt_file, pred_file)
            if case_result:
                self.results.append(case_result)
        
        return self.results
    
    def save_results(self, output_file='brats_evaluation_results.csv'):
        """Save evaluation results to CSV"""
        if not self.results:
            print("No results to save")
            return
        
        df = pd.DataFrame(self.results)
        
        # Calculate summary statistics
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        summary_stats = df[numeric_cols].agg(['mean', 'std', 'median', 'min', 'max'])
        
        # Save detailed results
        df.to_csv(output_file, index=False)
        print(f"Detailed results saved to {output_file}")
        
        # Save summary statistics
        summary_file = output_file.replace('.csv', '_summary.csv')
        summary_stats.to_csv(summary_file)
        print(f"Summary statistics saved to {summary_file}")
        
        # Print summary
        print("\n=== EVALUATION SUMMARY ===")
        print(f"Total cases evaluated: {len(df)}")
        print("\nMean scores:")
        
        # Standard metrics
        for col in ['Dice_ET', 'Dice_TC', 'Dice_WT', 
                   'Hausdorff95_ET', 'Hausdorff95_TC', 'Hausdorff95_WT',
                   'sensitivity', 'specificity']:
            if col in df.columns:
                print(f"{col}: {df[col].mean():.4f} ± {df[col].std():.4f}")
        
        # Lesion-wise metrics
        print("\nLesion-wise metrics:")
        for col in ['LesionWise_Dice_ET', 'LesionWise_Dice_TC', 'LesionWise_Dice_WT',
                   'LesionWise_Hausdorff95_ET', 'LesionWise_Hausdorff95_TC', 'LesionWise_Hausdorff95_WT']:
            if col in df.columns:
                print(f"{col}: {df[col].mean():.4f} ± {df[col].std():.4f}")
        
        return df, summary_stats

def main():
    """Main function to run BraTS evaluation"""
    # Set your folder paths here
    gt_folder = "/kaggle/input/original-validation/original validation"
    pred_folder = "prediction/preds"
    
    # Initialize evaluator
    evaluator = BraTSEvaluator(gt_folder, pred_folder)
    
    # Run evaluation
    print("Starting BraTS 23 Africa Challenge Evaluation...")
    results = evaluator.run_evaluation()
    
    if results:
        # Save results
        df, summary = evaluator.save_results('brats23_africa_results.csv')
        print("Evaluation completed successfully!")
    else:
        print("No valid results found. Please check your file paths and formats.")

if __name__ == "__main__":
    main()

Starting BraTS 23 Africa Challenge Evaluation...
Found 35 ground truth files
Evaluating: BraTS-SSA-00158-000-seg.nii
Evaluating: BraTS-SSA-00134-000-seg.nii
Evaluating: BraTS-SSA-00169-000-seg.nii
Evaluating: BraTS-SSA-00132-000-seg.nii
Evaluating: BraTS-SSA-00138-000-seg.nii
Evaluating: BraTS-SSA-00154-000-seg.nii
Evaluating: BraTS-SSA-00218-000-seg.nii
Evaluating: BraTS-SSA-00180-000-seg.nii
Evaluating: BraTS-SSA-00155-000-seg.nii
Evaluating: BraTS-SSA-00139-000-seg.nii
Evaluating: BraTS-SSA-00136-000-seg.nii
Evaluating: BraTS-SSA-00163-000-seg.nii
Evaluating: BraTS-SSA-00130-000-seg.nii
Evaluating: BraTS-SSA-00198-000-seg.nii
Evaluating: BraTS-SSA-00226-000-seg.nii
Evaluating: BraTS-SSA-00140-000-seg.nii
Evaluating: BraTS-SSA-00143-000-seg.nii
Evaluating: BraTS-SSA-00157-000-seg.nii
Evaluating: BraTS-SSA-00137-000-seg.nii
Evaluating: BraTS-SSA-00129-000-seg.nii
Evaluating: BraTS-SSA-00225-000-seg.nii
Evaluating: BraTS-SSA-00126-000-seg.nii
Evaluating: BraTS-SSA-00171-000-seg.nii
Eva

In [66]:
import os
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

def overlay_mask_on_image(image, mask, color, alpha=0.4):
    """Return an RGB image with the mask overlaid in color."""
    img_rgb = np.stack([image]*3, axis=-1)
    colored_mask = np.zeros_like(img_rgb)
    for i in range(3):
        colored_mask[..., i] = color[i]
    # Normalize image to 0-1 for display
    img_norm = (image - image.min()) / (image.max() - image.min())
    img_rgb = np.stack([img_norm]*3, axis=-1)
    overlay = img_rgb.copy()
    overlay[mask > 0] = (1 - alpha) * img_rgb[mask > 0] + alpha * colored_mask[mask > 0]
    return overlay

def visualize_gt_pred(data_dir, subject_name, gt_dir, pred_dir, slice_index=None):
    """
    Visualize ground truth and prediction overlays for a subject.
    
    Args:
        data_dir: path to full brain MRI files (assumed nifti)
        subject_name: subject identifier (filename prefix)
        gt_dir: path to ground truth masks (nifti)
        pred_dir: path to predicted masks (nifti)
        slice_index: axial slice index to visualize (if None, use middle slice)
    """
    # Load full brain MRI (assumed t1ce modality for example)
    img_path = os.path.join(data_dir, f"{subject_name}-t1c.nii") 
    img_nii = nib.load(img_path)
    img = img_nii.get_fdata()
    
    # Load ground truth segmentation
    gt_path = os.path.join(gt_dir, f"{subject_name}-seg.nii") 
    gt_nii = nib.load(gt_path)
    gt = gt_nii.get_fdata()
    
    # Load prediction segmentation
    pred_path = os.path.join(pred_dir, f"{subject_name}-seg.nii") 
    pred_nii = nib.load(pred_path)
    pred = pred_nii.get_fdata()
    
    # Choose slice
    if slice_index is None:
        slice_index = img.shape[2] // 2
    
    img_slice = img[:, :, slice_index]
    gt_slice = gt[:, :, slice_index]
    pred_slice = pred[:, :, slice_index]
    
    # Define subregions masks for GT and pred
    def get_subregion_masks(seg):
        ET_mask = (seg == 3)
        TC_mask = np.isin(seg, [1,3])
        WT_mask = np.isin(seg, [1,2,3])
        return ET_mask, TC_mask, WT_mask
    
    gt_ET, gt_TC, gt_WT = get_subregion_masks(gt_slice)
    pred_ET, pred_TC, pred_WT = get_subregion_masks(pred_slice)
    
    # Colors for overlays: ET-red, TC-yellow, WT-green
    colors = {
        'ET': [1, 0, 0],     # red
        'TC': [1, 1, 0],     # yellow
        'WT': [0, 1, 0],     # green
    }
    
    # Create overlays
    gt_overlay = overlay_mask_on_image(img_slice, gt_WT, colors['WT'], alpha=0.2)
    gt_overlay = overlay_mask_on_image(gt_overlay, gt_TC, colors['TC'], alpha=0.3)
    gt_overlay = overlay_mask_on_image(gt_overlay, gt_ET, colors['ET'], alpha=0.4)
    
    pred_overlay = overlay_mask_on_image(img_slice, pred_WT, colors['WT'], alpha=0.2)
    pred_overlay = overlay_mask_on_image(pred_overlay, pred_TC, colors['TC'], alpha=0.3)
    pred_overlay = overlay_mask_on_image(pred_overlay, pred_ET, colors['ET'], alpha=0.4)
    
    # Plot side by side
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    axs[0].imshow(gt_overlay)
    axs[0].set_title(f"{subject_name} - Ground Truth")
    axs[1].imshow(pred_overlay)
    axs[1].set_title(f"{subject_name} - Prediction")
    
    for ax in axs:
        ax.axis('off')
    
    # Legend patches
    patches = [mpatches.Patch(color=colors[k], label=k) for k in ['ET', 'TC', 'WT']]
    plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc='upper left')
    
    plt.tight_layout()
    plt.show()


# Example usage:
data_dir = "/kaggle/input/brats-africa-dataset/BraTS-Africa Dataset/BraTS-Africa/95_Glioma"
gt_dir = "/kaggle/input/original-validation"
pred_dir = "/kaggle/working/prediction/preds"
subject_name = "BraTS-SSA-00125-000'"  # Replace with actual subject ID

visualize_gt_pred(data_dir, subject_name, gt_dir, pred_dir)


FileNotFoundError: No such file or no access: '/kaggle/input/brats-africa-dataset/BraTS-Africa Dataset/BraTS-Africa/95_Glioma/BraTS-SSA-00125-000'-t1c.nii'