# R-NET TBI Inference
## Intro
Given a folder of `.nii.gz` files, predicts a brain mask and an ROI (regions of interest) mask for each file with a pre-trained model.

## Libraries

In [29]:
# utils
from utils.utils import save_excel_table
from utils.nifti import estimate_volume

# visualization
from utils.vedo import plot_slicer_cloud

# neural imaging
import nibabel as nib

# tensorflow
import tensorflow as tf
from evaluation.metrics import *

# other
import os
import numpy as np

# Make numpy printouts easier to read.
np.set_printoptions(precision=3, suppress=True)
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# Print the TensorFlow version and, when available, the CUDA version it was built against.
print("Tensorflow version: ", tf.__version__)
# CPU-only TensorFlow builds do not include a 'cuda_version' key in the build info, so
# indexing it directly raises KeyError; fall back to a placeholder instead.
print("Cuda version: ", tf.sysconfig.get_build_info().get('cuda_version', 'n/a'))

Num GPUs Available:  1
Tensorflow version:  2.10.0
Cuda version:  64_112


## Configuration

**Folder structure** <br>
The script expects a standard folder structure similar to the following:
```python
C:/.../FOLDER
  FLASH/
    ID_MICE_A/
      Anat/
        ID_MICE_A_N4.nii.gz
        ...
    ID_MICE_B/
      Anat/
        ID_MICE_B_N4.nii.gz
        ...
  RARE/
    ID_MICE_F/
      Anat/
        ID_MICE_F_N4.nii.gz
        ...
  Other.../
```
In this structure, there is a folder for each group of mice, such as `FLASH` or `RARE`. 

Inside each group folder, it is **strictly required** that every file name inside a mouse folder starts with the name of that mouse folder.

For example, the `ID_MICE_A` folder should contain only files starting with `ID_MICE_A`. The subfolder (e.g., `Anat`) and the file-name postfix (e.g., `_N4`) can be changed in the configuration below.



**Parameters**<br>
The parameters below control how the predictions are made. <br>


Prediction mode:
- rodent_model (string): either 'mice' or 'rats'; selects the corresponding model
- domain_adaptation (boolean): uses the domain-adaptation model, which predicts a larger number of brain regions (available only for mice)
- default_hemisphere (string): either 'left' or 'right'; the hemisphere used for the hemisphere division and ventricle separation when no lesion is found

In [30]:
rodent_model = 'mice'  # 'mice' or 'rats'
domain_adaptation = False  # True: domain-adaptation model (mice only); False: original model
default_hemisphere = 'right'  # 'right' or 'left': hemisphere division used if no lesion is found

# Backward-compatible alias: the pipeline cells read the original (misspelled) variable name.
default_hemisfere = default_hemisphere

# Fail fast on typos, before the model and the data are loaded.
assert rodent_model in ('mice', 'rats'), "rodent_model must be 'mice' or 'rats'"
assert default_hemisphere in ('left', 'right'), "default_hemisphere must be 'left' or 'right'"

Paths:
- datasets (list[Dataset]): list of datasets; each one points to a folder containing the MRIs of its subjects

```python
# Example
class Dataset:
    def __init__(self, dataset_path, subfolder='Anat', dti_mode=False):
        self.dataset_path = dataset_path # The path to the dataset
        self.subfolder = subfolder # The default subfolder for each subject
        self.dti_mode = dti_mode # If the user wants to process only DTI 

datasets: list[Dataset] = []

# Append a dataset
datasets.append(Dataset(r'C:\...\data\mice\flash', subfolder='Anat', dti_mode=False))
```

In [None]:
from typing import List


class Dataset:
    """A group of subjects to run inference on.

    Args:
        dataset_path: folder that contains one sub-folder per subject.
        subfolder: sub-folder inside each subject folder that holds the images (default 'Anat').
        dti_mode: if True, 4D inputs are collapsed to 3D with create_3d_image_from_dti and
            the ROI prediction is not saved (only the brain mask is).
    """

    def __init__(self, dataset_path, subfolder='Anat', dti_mode=False):
        self.dataset_path = dataset_path
        self.subfolder = subfolder
        self.dti_mode = dti_mode


# `typing.List` (instead of the built-in `list[...]`) keeps this module-level annotation valid on
# Python < 3.9, where subscripting the built-in `list` raises TypeError when the cell runs.
datasets: List[Dataset] = []

# Add one or more datasets to the list with the append method.
# In a raw string (r'...') a single backslash is already literal, so do not double it.
datasets.append(Dataset(r'C:\Users\...\Desktop\...', subfolder='Anat', dti_mode=False))

Save results:
- save (boolean): if true, saves the predictions inside each subject's folder

In [32]:
save = True  # intended to toggle saving the predictions into each subject's folder — NOTE(review): check that the pipeline cells read this flag

3D preview:
- show_3d_preview (boolean): shows the predicted mask in an external window

In [33]:
show_3d_preview = True  # if True, each prediction is displayed with plot_slicer_cloud in an external window

Morphological smoothing:
- remove_small_objects (boolean): if true, removes small unconnected regions based on object_min_area
- object_min_area (int): the smallest allowable contiguous region size, in voxels
- fill_small_holes (boolean): if true, fills small holes
- holes_max_area (int): the maximum area, in voxels, of a contiguous hole that will be filled

In [34]:
# Morphological clean-up of the predicted brain mask; both areas are expressed in voxels.
remove_small_objects, object_min_area = True, 30000  # drop unconnected regions smaller than this
fill_small_holes, holes_max_area = True, 20000  # fill holes up to this size

Prediction parameters:
- patch_size (tuple): size of the sliding window used to extract patches from the image
- patch_resolution (tuple): target resolution, in mm, to which every image is resampled before patches are extracted (should match the training resolution of the model)
- stride (int): translation offset of the sliding window (smaller is better but requires more computation time)
- threshold (float): decision threshold passed to the predictor

Suggested stride values: 6,8,12,16

In [35]:
# Sliding-window patch size: 80^3 for the mice model, 96^3 otherwise.
patch_size = (96, 96, 96)
if rodent_model == 'mice':
    patch_size = (80, 80, 80)

patch_resolution = (0.1, 0.1, 0.1)  # mm, resampling target before patch extraction
stride = 20  # sliding-window step; smaller is slower but denser
threshold = 0.5  # handed to RandomCroppingPrediction

Input and output filenames:

In [36]:
# Postfix that selects the input images inside each subject folder.
input_postfix = '_anat_orig'

# Postfixes of the saved prediction masks; the domain-adaptation model adds a '_da' tag to the ROI file.
brain_prediction_postfix = '_brain_mask_r3dnet'
roi_prediction_postfix = '_regions_r3dnet_da' if domain_adaptation else '_regions_r3dnet'
output_extension = '.nii.gz'

# Name of the Excel table with the predicted volumes.
excel_name = 'predicted_r3dnet_da_volumes.xlsx' if domain_adaptation else 'predicted_r3dnet_volumes.xlsx'

**From here the code should remain unchanged**

In [37]:
# --- Original R-Net model -------------------------------------------------------------
# Network output channel -> label value in the output mask.
rnet_labels_mapping = {0: 0, 1: 1, 2: 3, 3: 21}

# Label values present after post-processing; includes the 'Ipsi Ventricle' label (13),
# presumably produced by the ipsi/contralateral division step.
rnet_postprocessed_mapping = {0: 0, 1: 1, 2: 3, 3: 13, 4: 21}

# Index -> display name and label value (handed to the volume table).
rnet_name_mapping = {
    idx: {'name': name, 'value': value}
    for idx, (name, value) in enumerate([
        ('Background', 0),
        ('Lesion', 1),
        ('Contra Ventricle', 3),
        ('Ipsi Ventricle', 13),
        ('Third Ventricle', 21),
    ])
}

# --- Domain-adaptation model ----------------------------------------------------------
da_labels_mapping = {0: 0, 1: 1, 2: 2, 3: 3, 4: 5, 5: 6, 6: 12, 7: 13, 8: 15, 9: 16, 10: 21}

da_name_mapping = {
    idx: {'name': name, 'value': value}
    for idx, (name, value) in enumerate([
        ('Background', 0),
        ('Lesion', 1),
        ('Contra CC', 2),
        ('Ipsi CC', 12),
        ('Contra Ventricle', 3),
        ('Ipsi Ventricle', 13),
        ('Contra Hippo', 5),
        ('Ipsi Hippo', 15),
        ('Contra Cortex', 6),
        ('Ipsi Cortex', 16),
        ('Third Ventricle', 21),
    ])
}

# No extra post-processing labels for this model: same object as the raw mapping.
da_postprocessed_mapping = da_labels_mapping


# Pick the mapping set that matches the selected model variant.
(postprocessed_mapping, name_mapping, labels_mapping) = (
    (da_postprocessed_mapping, da_name_mapping, da_labels_mapping)
    if domain_adaptation
    else (rnet_postprocessed_mapping, rnet_name_mapping, rnet_labels_mapping)
)

num_classes = len(labels_mapping)
print(f"Number of classes: {num_classes}")

Number of classes: 4


## Load the model
Load a previously trained model to start making predictions

In [38]:
import models.networks
from evaluation.metrics import *
from evaluation.losses import *

def get_model_path(domain_adaptation=False, rodent_model='mice', models_dir='../models'):
    """
    Return the path of the trained model file for the requested configuration.

    Args:
        domain_adaptation (bool): if True, select the domain-adaptation model
            ('rnet_da.h5'), which exists only for mice.
        rodent_model (str): 'mice' or 'rats'; selects 'rnet_mice.h5' or 'rnet_rats.h5'.
        models_dir (str): folder containing the model files. The default '../models'
            is resolved relative to the current working directory.

    Returns:
        str: '<models_dir>/<model file name>'.

    Raises:
        ValueError: if rodent_model is not 'mice' or 'rats', or if domain adaptation
            is requested for a model other than 'mice'.
    """
    if rodent_model not in ('mice', 'rats'):
        raise ValueError("rodent_model must be either 'mice' or 'rats'.")

    if domain_adaptation:
        if rodent_model != 'mice':
            raise ValueError("Domain adaptation is only available for the mice model.")
        model_name = 'rnet_da.h5'
    else:
        model_name = f"rnet_{rodent_model}.h5"

    # Keep the '/' separator so the default result stays exactly '../models/<name>'.
    return f'{models_dir}/{model_name}'

# Resolve the weights file for the current configuration.
model_path = get_model_path(domain_adaptation=domain_adaptation, rodent_model=rodent_model)

# Names referenced by the saved model that Keras must be able to resolve when deserializing it.
custom_objects = {
    "loss": diceCELoss(),
    "precision": precision_coefficient(),
    "sensitivity": sensitivity_coefficient(),
    "specificity": specificity_coefficient(),
    "K": tf.keras.backend,
    "training": False,
}

# compile=False: the model is loaded without being compiled (used only for inference here).
model = tf.keras.models.load_model(model_path, custom_objects=custom_objects, compile=False)

model.summary()

Model: "model_3"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_4 (InputLayer)           [(None, 80, 80, 80,  0           []                               
                                 1)]                                                              
                                                                                                  
 conv3d_172 (Conv3D)            (None, 80, 80, 80,   448         ['input_4[0][0]']                
                                16)                                                               
                                                                                                  
 batch_normalization_102 (Batch  (None, 80, 80, 80,   64         ['conv3d_172[0][0]']             
 Normalization)                 16)                                                         

## Preprocessing configuration
The standard pipeline order is:
1. (opt.) Cut and pad the image to a default matrix dimension
2. (opt.) Correct x10 intensity values
3. Apply N4 Bias Field correction
4. Copy the orientation from the ref. image
5. Resample to a target resolution
6. Normalize the intensity values with z-score (default)

In [39]:
from preprocessing.preprocessor import Preprocessor, Resample, Reorient, Normalize, CorrectX10, N4BiasFieldCorrection, SaveNifti

# Reference image for the Reorient step (its orientation is copied onto every input).
ref_img_path = os.path.join('../example', 'RARE', 'TBI_fm_19_49', 'Anat', 'TBI_fm_19_49_N4.nii.gz')
ref_img = nib.load(ref_img_path)

# Preprocessing steps, listed in pipeline order. SaveNifti writes an intermediate '_N4' copy
# (presumably by replacing input_postfix in the input file name).
preprocessing_steps = [
    CorrectX10(),
    N4BiasFieldCorrection(),
    SaveNifti(postfix='_N4', replace=input_postfix),
    Reorient(ref_img),
    # NOTE(review): confirm what interpolation=0 selects for an intensity image.
    Resample(target_resolution=patch_resolution, interpolation=0),
    Normalize(),
]
processor = Preprocessor(preprocessing_steps)

## Final Inference for random cropping
Makes the predictions by sliding a window of size `patch_size` across the input with a step of `stride` voxels (both set in the prediction parameters above).

In [40]:
from evaluation.inference import RandomCroppingPrediction

# Sliding-window predictor configured with the prediction parameters defined above.
predictor = RandomCroppingPrediction(
    model,
    patch_size=patch_size,
    stride=stride,
    threshold=threshold,
    num_classes=num_classes,
)

Make predictions for every subject inside the folder

In [41]:
import os
import nibabel as nib
import numpy as np
from evaluation.postprocessing import ipsi_contra_division_callback, morphology_refinement_callback
from utils.nifti import create_3d_image_from_dti

def is_valid_input_file(filename, postfix):
    """
    Tell whether `filename` should be processed as an input image for `postfix`.

    A name qualifies when it starts with `postfix` (which includes being equal to it),
    or when it ends with `postfix` followed by '.nii' or '.nii.gz'.

    NOTE(review): the prefix branch also accepts names without a NIfTI extension, and an
    empty postfix accepts every file — confirm this is intended.
    """
    if filename.startswith(postfix):
        return True
    return filename.endswith((postfix + '.nii', postfix + '.nii.gz'))

def load_and_preprocess_image(file_path, dti_mode, subject_folder, subject):
    """
    Load a NIfTI file and reduce it to a single 3D volume.

    Args:
        file_path: path of the image, loaded with nibabel.
        dti_mode: when True and the image has more than 3 dimensions, it is converted with
            create_3d_image_from_dti(output_path=subject_folder, name='<subject>_dti_out').
        subject_folder: output folder handed to create_3d_image_from_dti.
        subject: subject identifier used to build the DTI output name.

    Returns:
        The loaded image. If it still has more than 3 dimensions after the optional DTI
        conversion, a new Nifti1Image holding only index 0 of the last axis is returned,
        with the original affine and header.
    """
    image = nib.load(file_path)

    if dti_mode and len(image.shape) > 3:
        image = create_3d_image_from_dti(image, output_path=subject_folder, name=subject + '_dti_out')

    # The shape is re-checked because the DTI conversion above may already have produced a 3D image.
    if len(image.shape) <= 3:
        return image

    first_volume = image.get_fdata()[..., 0]
    return nib.Nifti1Image(first_volume, affine=image.affine, header=image.header)

def postprocess_and_save(x_prep, nii_img, y_mask, y_regions, subject, subject_folder, dataset):
    """
    Refine the raw predictions, optionally preview them, and map them back onto the input image.

    Steps:
      1. Morphological refinement of the brain mask (small-object removal / hole filling).
      2. Ipsi/contralateral division of the ROI mask (skipped when domain_adaptation is True),
         with default_hemisfere as the fallback hemisphere.
      3. Wrap both masks in NIfTI images using the affine and header of the preprocessed input.
      4. Optional 3D preview (show_3d_preview).
      5. processor.deprocess() maps the masks back onto nii_img. The ROI mask is skipped in DTI
         mode; otherwise its volumes are estimated with estimate_volume.

    The files '<subject><roi_prediction_postfix><output_extension>' and
    '<subject><brain_prediction_postfix><output_extension>' are written into subject_folder
    only when the global `save` flag is True.

    Args:
        x_prep: preprocessed input image; its affine and header are reused for the masks.
        nii_img: original (non-preprocessed) input image, handed to processor.deprocess.
        y_mask: raw brain-mask prediction.
        y_regions: raw ROI prediction.
        subject: subject identifier, used as the output file-name prefix.
        subject_folder: folder where the outputs are written.
        dataset: the Dataset being processed (its dti_mode flag is read).
    """
    # Apply post-processing
    y_mask = morphology_refinement_callback(
        fill_small_holes=fill_small_holes,
        holes_max_area=holes_max_area,
        remove_small_objects=remove_small_objects,
        object_min_area=object_min_area
    )(y_mask)

    if not domain_adaptation:
        y_regions = ipsi_contra_division_callback(
            use_centroids=True,
            default_hemisphere=default_hemisfere
        )(y_regions)

    # Wrap in NIfTI
    y_pred_nifti = nib.Nifti1Image(y_regions, affine=x_prep.affine, dtype=np.float64, header=x_prep.header)
    y_pred_mask_nifti = nib.Nifti1Image(y_mask, affine=x_prep.affine, dtype=np.float64, header=x_prep.header)

    if show_3d_preview:
        plot_slicer_cloud(x_prep, y_pred_nifti)

    roi_save_path = os.path.join(subject_folder, subject + roi_prediction_postfix + output_extension)
    mask_save_path = os.path.join(subject_folder, subject + brain_prediction_postfix + output_extension)

    if not dataset.dti_mode:
        # The ROI must be deprocessed for the volume estimate even when nothing is saved, so only
        # save_path is conditional. Assumes deprocess's save_path default means "don't save" — TODO confirm.
        roi_save_kwargs = {'save_path': roi_save_path} if save else {}
        final_image = processor.deprocess(y_pred_nifti, nii_img, postprocessed_mapping, verbose=False, **roi_save_kwargs)
        # NOTE(review): patch_resolution is passed for an image that was mapped back to the input space;
        # confirm the deprocessed image is still at patch_resolution (or that estimate_volume ignores it).
        estimate_volume(final_image, resolution=patch_resolution, verbose=True)

    # The brain-mask deprocess result is discarded, so the call is skipped when saving is disabled.
    if save:
        processor.deprocess(y_pred_mask_nifti, nii_img, postprocessed_mapping, save_path=mask_save_path, verbose=False)

def process_subject(subject_folder, subject, dataset):
    """
    Run preprocessing, inference and post-processing on every input image of one subject.

    Input images are the files in subject_folder accepted by is_valid_input_file() for
    input_postfix. They are processed in sorted order so that runs are reproducible.
    A failure on one file is reported with its traceback and does not stop the others.

    Args:
        subject_folder: folder containing the subject's images.
        subject: subject identifier (name of the subject folder).
        dataset: the Dataset the subject belongs to.
    """
    import traceback

    files = sorted(f for f in os.listdir(subject_folder) if is_valid_input_file(f, input_postfix))

    if not files:
        print(f"File with postfix {input_postfix} not found for {subject}, skipping")
        return

    for file_name in files:
        try:
            # '\\ ' prints a literal backslash followed by a space. The previous '\ ' was an invalid
            # escape sequence (DeprecationWarning; SyntaxWarning since Python 3.12) that printed the same text.
            print('\n|-', file_name, '-------------------\\ \n')
            file_path = os.path.join(subject_folder, file_name)

            nii_img = load_and_preprocess_image(file_path, dataset.dti_mode, subject_folder, subject)
            x_prep = processor.preprocess(nii_img, path=file_path)

            results = predictor.random_cropping_inference(x_prep, with_brain_mask=True)
            postprocess_and_save(x_prep, nii_img, results['brain_mask'], results['roi'], subject, subject_folder, dataset)

            print('----------------------------------------------------------// \n\n')

        except Exception as e:
            # Best effort: report which file failed, with the traceback, then continue with the next one.
            print(f"Error processing case {subject} ({file_name}): {e}")
            traceback.print_exc()

def run_prediction_pipeline():
    """
    Run inference on every subject of every configured dataset, then tabulate the volumes.

    For each Dataset in `datasets`:
      * a missing dataset folder is reported and skipped;
      * every entry of the dataset folder that contains a `<subject>/<dataset.subfolder>`
        directory is processed with process_subject, in sorted order for reproducibility;
      * when `save` is True, save_excel_table writes `excel_name` into the dataset folder.
    """
    for dataset in datasets:
        base_path = dataset.dataset_path
        if not os.path.isdir(base_path):
            print(f'Folder {base_path} not found, skipping')
            continue

        print(f'\n| Processing dataset: {dataset.dataset_path} --')

        for subject in sorted(os.listdir(base_path)):
            subject_folder = os.path.join(base_path, subject, dataset.subfolder)
            if not os.path.isdir(subject_folder):
                continue
            process_subject(subject_folder, subject, dataset)

        # The volume table is built from the saved prediction files (it receives their names),
        # so there is nothing fresh to tabulate when saving is disabled.
        if save:
            save_excel_table(
                base_path,
                sub_folder=dataset.subfolder,
                include_only_list=None,
                save_folder=base_path,
                pred_roi_name=roi_prediction_postfix + output_extension,
                pred_brain_name=brain_prediction_postfix + output_extension,
                name_mapping=name_mapping,
                file_name=excel_name,
                postfix_mode=True
            )

In [None]:
# Run the prediction pipeline on every dataset in `datasets`, using the configuration defined above.
run_prediction_pipeline()