MED-IPLAB/DeepSeg-MS

DeepSeg-MS: A 3D network based on hierarchical multi-scale learning for MRI multiple sclerosis lesion segmentation

This is the official implementation of the paper
"DeepSeg-MS: A 3D network based on hierarchical multi-scale learning for MRI multiple sclerosis lesion segmentation"

Abstract

Accurate 3D segmentation of Multiple Sclerosis (MS) lesions from multimodal Magnetic Resonance Imaging (MRI) is crucial for assessing disease progression and guiding clinical decisions. However, existing deep learning-based methods often struggle to capture the complex 3D spatial dependencies of MS lesions while ensuring robust hierarchical multi-scale feature learning. To address these challenges, we introduce DeepSeg-MS, a novel deep neural network specifically designed for volumetric MS lesion segmentation. Our architecture embeds Squeeze-and-Attention (SA), Deep Supervision (DS), and Atrous Spatial Pyramid Pooling (ASPP) modules to enhance feature extraction and overall segmentation accuracy. The SA blocks allow the model to focus on relevant lesion regions across spatial dimensions, while DS ensures effective hierarchical feature refinement and multi-scale representation learning. Additionally, ASPP aggregates multi-scale features by capturing contextual information at different receptive fields through parallel dilated convolutions, addressing the challenge of segmenting heterogeneous MS lesions in 3D MRI scans. DeepSeg-MS was extensively evaluated on two markedly different datasets from this field, demonstrating a significant improvement over existing state-of-the-art 3D segmentation methods. The effectiveness of our architecture in 3D MS lesion segmentation offers potential benefits for both clinical applications and automated disease monitoring.

Architectures

The architecture used for MS lesion segmentation is DeepSeg-MS, a novel deep learning-based network for fully automated MS lesion segmentation from multimodal MRI. DeepSeg-MS integrates three key components to improve segmentation precision: attention mechanisms (Squeeze-and-Attention blocks), Deep Supervision, and Atrous Spatial Pyramid Pooling (ASPP). Together, they address the variability of lesions across patients by leveraging 3D multimodal MRI.
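
ASPP applies several dilated (atrous) convolutions in parallel to the same feature map. Along each axis, a kernel of size k with dilation d spans k + (k - 1)(d - 1) voxels, and a padding of d(k - 1)/2 keeps the output the same size as the input, so the branch outputs can be concatenated. The sketch below checks this arithmetic for a hypothetical set of branches; the dilation rates are illustrative values (similar to those popularized by DeepLabv3), not the ones used in DeepSeg-MS.

```python
def effective_kernel_size(kernel_size: int, dilation: int) -> int:
    """Extent (in voxels, per axis) covered by one dilated convolution kernel."""
    return kernel_size + (kernel_size - 1) * (dilation - 1)


def same_padding(kernel_size: int, dilation: int) -> int:
    """Padding that keeps the spatial size unchanged for stride 1 and an odd kernel."""
    return dilation * (kernel_size - 1) // 2


def output_size(n: int, kernel_size: int, dilation: int, padding: int, stride: int = 1) -> int:
    """Standard convolution output-size formula along one axis."""
    return (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Hypothetical ASPP configuration: one 1x1x1 branch plus 3x3x3 branches at
# increasing dilation rates (illustrative values, not taken from this repo).
BRANCHES = [(1, 1), (3, 6), (3, 12), (3, 18)]  # (kernel_size, dilation)

if __name__ == "__main__":
    feature_map = 24  # example spatial size of a feature map, per axis
    for k, d in BRANCHES:
        p = same_padding(k, d)
        print(f"k={k} d={d}: extent={effective_kernel_size(k, d)} voxels, "
              f"padding={p}, output={output_size(feature_map, k, d, p)}")
```

Every branch keeps the 24-voxel size while covering a different extent, which is what lets ASPP mix context from several receptive fields in one concatenated tensor.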

The following figure shows the DeepSeg-MS architecture:

[Figure: DeepSeg-MS architecture]
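
Deep supervision attaches auxiliary outputs to intermediate decoder levels and adds their losses to the loss of the final output. A common way to combine them, assumed here rather than taken from this repository, is to halve the weight at each coarser level and normalize the weights to sum to 1. The per-level loss values in the example are made up for illustration.

```python
from typing import List, Sequence


def deep_supervision_weights(num_outputs: int) -> List[float]:
    """Weights 1, 1/2, 1/4, ... (full resolution first), normalized to sum to 1."""
    raw = [0.5 ** i for i in range(num_outputs)]
    total = sum(raw)
    return [w / total for w in raw]


def deep_supervision_loss(level_losses: Sequence[float]) -> float:
    """Weighted sum of per-level losses; level_losses[0] is the full-resolution output."""
    weights = deep_supervision_weights(len(level_losses))
    return sum(w * loss for w, loss in zip(weights, level_losses))


if __name__ == "__main__":
    # Hypothetical per-level losses (e.g. Dice loss) for three decoder outputs.
    losses = [0.30, 0.40, 0.55]
    print(deep_supervision_weights(3))
    print(deep_supervision_loss(losses))
```

The same weighted sum works unchanged if the per-level losses are scalar tensors from a deep learning framework, since it only relies on multiplication and addition.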

Setup

To start, create and activate the necessary Conda environment:

conda env create -f deepsegms.yml
conda activate deepsegms
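
After activating the environment, a quick sanity check is to confirm that the core packages can be imported. The sketch below uses only the standard library; the package list (torch, monai) is an assumption based on the MONAI data loader mentioned in this README, not read from deepsegms.yml.

```python
import importlib.util
from typing import Dict, Iterable


def check_packages(names: Iterable[str]) -> Dict[str, bool]:
    """Report whether each top-level package can be found in the current environment."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Assumed dependencies: MONAI (mentioned in this README) and PyTorch (required by MONAI).
    for name, ok in check_packages(["torch", "monai"]).items():
        print(f"{name:8s} {'found' if ok else 'MISSING'}")
```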

Dataset

Organise the dataset splits in a JSON file for the MONAI data loader with the following fields:

  • num_fold: the total number of folds
  • fold#: the fold to be trained, where # is replaced by the fold number (e.g. fold0)
    • train:
      • id: the unique patient identifier; it must be the same for all timepoints of that patient
      • data:
        • images: list of paths to the FLAIR, T1-w and T2-w images
        • mask: path to the ground-truth mask
    • val:
      • id: the unique patient identifier; it must be the same for all timepoints of that patient
      • data:
        • images: list of paths to the FLAIR, T1-w and T2-w images
        • mask: path to the ground-truth mask
    • test:
      • id: the unique patient identifier; it must be the same for all timepoints of that patient
      • data:
        • images: list of paths to the FLAIR, T1-w and T2-w images
        • mask: path to the ground-truth mask
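
The sketch below builds a split file with the fields listed above using only the standard library. It assumes that train, val and test are lists of entries, each holding an id and a data dictionary; the patient ids, timepoints, file names and the make_entry/build_folds helpers are hypothetical, so check the loader in train.py for the exact layout it expects. Splitting by patient id keeps every timepoint of a patient in the same subset.

```python
import json
from typing import Dict, List

# Hypothetical cases: (patient id, timepoint). All paths below are placeholders.
CASES = [("P01", "T1"), ("P01", "T2"), ("P02", "T1"), ("P03", "T1"), ("P04", "T1")]


def make_entry(patient_id: str, timepoint: str) -> Dict:
    """One entry: FLAIR, T1-w and T2-w image paths plus the ground-truth mask path."""
    prefix = f"{patient_id}/{timepoint}"
    return {
        "id": patient_id,  # same id for every timepoint of this patient
        "data": {
            "images": [f"{prefix}/flair.nii.gz", f"{prefix}/t1.nii.gz", f"{prefix}/t2.nii.gz"],
            "mask": f"{prefix}/mask.nii.gz",
        },
    }


def build_folds(split_by_fold: List[Dict[str, List[str]]]) -> Dict:
    """split_by_fold[i] maps 'train'/'val'/'test' to the patient ids of fold i."""
    folds: Dict = {"num_fold": len(split_by_fold)}
    for i, split in enumerate(split_by_fold):
        folds[f"fold{i}"] = {
            part: [make_entry(pid, tp) for pid, tp in CASES if pid in ids]
            for part, ids in split.items()
        }
    return folds


if __name__ == "__main__":
    # Patient-level split: both timepoints of P01 land in the same subset.
    folds = build_folds([
        {"train": ["P01", "P02"], "val": ["P03"], "test": ["P04"]},
        {"train": ["P03", "P04"], "val": ["P01"], "test": ["P02"]},
    ])
    # Save this output as the file passed to --folds_path.
    print(json.dumps(folds, indent=2))
```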

Training

To start training, use the following command. You can modify the training parameters by referring to train.py and adjusting the arguments.

python train.py \
    --root_dir      <path to the dataset> \
    --folds_path    <path to the JSON file with the fold split> \
    --logdir        <path where logs and models are saved>

Optional arguments:
    --configuration         Name of the training configuration; used to rename the log directory.
    --num_fold              The fold to be trained. Defaults to 0 (fold0).
    --number_modality       The number of input modalities. Defaults to 3 (FLAIR, T1-w and T2-w).
    --number_targets        The number of output targets. Defaults to 1.
    --feature               Six integers specifying the numbers of features. Defaults to [64, 64, 128, 256, 512, 64].
    --model_name            The model to resume training from. Defaults to '' to train from scratch; set it to 'final_model_0.xxxx.pt' to resume training, where xxxx stands for the Dice score of the last epoch trained.
    --use_SA                Set to True to use Squeeze-and-Attention blocks in the UNet. Defaults to False.
    --use_SAEncoder         Set to True to use Squeeze-and-Attention blocks in the encoder. Defaults to False.
    --use_deep_supervision  Set to True to use Deep Supervision. Defaults to False.
    --use_aspp              Set to True to use ASPP. Defaults to False.
    --use_MEEncoder         Set to True to use Multi-Encoder mode (one encoder per modality). Defaults to False.
    --resize                The crop size. Defaults to 96.
    --learning_rate         The learning rate. Defaults to 1e-4.
    --batch_size            The batch size. Defaults to 2.
    --max_epoch             The maximum number of epochs. Defaults to 1200.
    --val_every             Run validation every N epochs. Defaults to 10.
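
Checkpoint names embed a Dice score, the usual overlap measure for segmentation: Dice = 2|P ∩ T| / (|P| + |T|) for a predicted mask P and a ground-truth mask T. The function below is a plain-Python sketch on flattened binary masks; the repository may compute it differently (for example with MONAI's DiceMetric), and returning 1.0 when both masks are empty is a convention assumed here.

```python
from typing import Sequence


def dice_score(pred: Sequence[int], target: Sequence[int]) -> float:
    """Dice = 2|P ∩ T| / (|P| + |T|) for binary masks given as flat 0/1 sequences."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same number of voxels")
    intersection = sum(1 for p, t in zip(pred, target) if p and t)
    total = sum(1 for p in pred if p) + sum(1 for t in target if t)
    if total == 0:
        return 1.0  # assumption: two empty masks count as a perfect match
    return 2.0 * intersection / total


if __name__ == "__main__":
    pred = [0, 1, 1, 1, 0, 0]
    target = [0, 1, 1, 0, 1, 0]
    print(round(dice_score(pred, target), 4))  # 0.6667 (2 shared voxels, 3 + 3 positives)
```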

Make sure to explore the argument options in train.py to customize the training process, such as batch size, learning rate, and number of epochs.
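
To train with all three components described in the Architectures section (SA blocks, Deep Supervision and ASPP), the command can be assembled from Python as sketched below. The paths and the configuration name are placeholders, and passing an explicit True value to the boolean flags is an assumption based on the "Set to True" wording above; whether DeepSeg-MS also enables --use_SAEncoder is not stated here.

```python
import shlex
from typing import List


def train_command(root_dir: str, folds_path: str, logdir: str, fold: int = 0) -> List[str]:
    """Build a train.py invocation with SA, Deep Supervision and ASPP enabled."""
    return [
        "python", "train.py",
        "--root_dir", root_dir,
        "--folds_path", folds_path,
        "--logdir", logdir,
        "--configuration", "deepsegms",  # hypothetical name, used for the log directory
        "--num_fold", str(fold),
        # Assumption: boolean flags take an explicit value, per "Set to True" above.
        "--use_SA", "True",
        "--use_deep_supervision", "True",
        "--use_aspp", "True",
    ]


if __name__ == "__main__":
    cmd = train_command("/data/ms", "/data/ms/folds.json", "runs")
    print(shlex.join(cmd))  # inspect the command before launching it
    # To launch training instead of only printing the command:
    # import subprocess; subprocess.run(cmd, check=True)
```

If train.py declares these flags with action="store_true", drop the "True" values. If it declares them with type=bool, note that argparse would also treat the string "False" as true, so the safest way to disable a component is to omit its flag.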

Testing

Once training is complete, you can run the testing phase with the following command. As with training, you can modify the testing parameters by checking test.py and adjusting the arguments.

python test.py \
    --root_dir      <path to the dataset> \
    --folds_path    <path to the JSON file with the fold split> \
    --testdir       <path where predictions are saved>

Optional arguments:
    --configuration         The configuration used in training (the name used to rename the log directory).
    --num_fold              The fold to be tested. Defaults to 0 (fold0).
    --number_modality       The number of input modalities. Defaults to 3 (FLAIR, T1-w and T2-w).
    --number_targets        The number of output targets. Defaults to 1.
    --feature               Six integers specifying the numbers of features. Defaults to [64, 64, 128, 256, 512, 64].
    --use_SA                Set to True to use Squeeze-and-Attention blocks in the UNet. Defaults to False.
    --use_SAEncoder         Set to True to use Squeeze-and-Attention blocks in the encoder. Defaults to False.
    --use_deep_supervision  Set to True to use Deep Supervision. Defaults to False.
    --use_aspp              Set to True to use ASPP. Defaults to False.
    --use_MEEncoder         Set to True to use Multi-Encoder mode (one encoder per modality). Defaults to False.
    --model_name            The model to be tested. Defaults to 'best_model_0.xxxx.pt', the best model in validation, where xxxx stands for the Dice score of the best validation epoch.
    --batch_size            The batch size. Defaults to 1.
    --resize                The crop size. Defaults to 96.
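
By default test.py loads best_model_0.xxxx.pt, where xxxx encodes the Dice score of the best validation epoch. The sketch below picks the best-scoring checkpoint under a log directory by parsing file names; it assumes that the 0.xxxx part of the name is the Dice score itself and that checkpoints are stored somewhere below the --logdir used for training. The helper names are hypothetical.

```python
import re
from pathlib import Path
from typing import Optional


def dice_from_name(name: str, prefix: str = "best") -> Optional[float]:
    """Parse the Dice score from names like 'best_model_0.8123.pt'; None if no match."""
    match = re.fullmatch(rf"{re.escape(prefix)}_model_(\d+\.\d+)\.pt", name)
    return float(match.group(1)) if match else None


def best_checkpoint(logdir: str, prefix: str = "best") -> Optional[Path]:
    """Return the checkpoint below logdir whose name carries the highest Dice score."""
    root = Path(logdir)
    if not root.is_dir():
        return None
    scored = []
    for path in root.rglob("*.pt"):
        dice = dice_from_name(path.name, prefix)
        if dice is not None:
            scored.append((dice, path))
    if not scored:
        return None
    return max(scored, key=lambda item: item[0])[1]


if __name__ == "__main__":
    ckpt = best_checkpoint("runs")  # hypothetical log directory
    if ckpt is not None:
        print(ckpt.name)  # candidate value for --model_name
```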
  
    
