DeepSeg-MS: A 3D network based on hierarchical multi-scale learning for MRI multiple sclerosis lesion segmentation
This is the official implementation of the paper
"DeepSeg-MS: A 3D network based on hierarchical multi-scale learning for MRI multiple sclerosis lesion segmentation"
Accurate 3D segmentation of Multiple Sclerosis (MS) lesions from multimodal Magnetic Resonance Imaging (MRI) is crucial for assessing disease progression and guiding clinical decisions. However, existing deep learning-based methods often struggle to capture the complex 3D spatial dependencies of MS lesions while ensuring robust hierarchical multi-scale feature learning. To address these challenges, we introduce DeepSeg-MS, a novel deep neural network specifically designed for volumetric MS lesion segmentation. Our approach embeds Squeeze-and-Attention (SA), Deep Supervision (DS), and Atrous Spatial Pyramid Pooling (ASPP) modules in a deep 3D architecture to enhance feature extraction and overall segmentation accuracy. The SA blocks allow the model to focus on relevant lesion regions across spatial dimensions, while DS ensures effective hierarchical feature refinement and multi-scale representation learning. Additionally, ASPP aggregates multi-scale features by capturing contextual information at different receptive fields through parallel dilated convolutions, addressing the challenge of segmenting heterogeneous MS lesions in 3D MRI scans. DeepSeg-MS was extensively evaluated on two distinct datasets in this field, demonstrating a significant improvement over existing state-of-the-art 3D segmentation methods. The effectiveness of our architecture in 3D MS lesion segmentation offers potential benefits for both clinical applications and automated disease monitoring.
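The "different receptive fields" of ASPP follow from a simple formula: a convolution with kernel size k and dilation rate d spans k + (k - 1)(d - 1) voxels along each axis while keeping the same number of weights. The sketch below tabulates this for a 3×3×3 kernel; the dilation rates (1, 2, 4, 8) are illustrative assumptions, not the rates used in DeepSeg-MS.

```python
def effective_kernel_size(kernel_size: int, dilation: int) -> int:
    """Extent (in voxels, per axis) covered by one dilated convolution."""
    if kernel_size < 1 or dilation < 1:
        raise ValueError("kernel_size and dilation must be >= 1")
    return kernel_size + (kernel_size - 1) * (dilation - 1)


def aspp_branch_extents(kernel_size: int = 3, dilations=(1, 2, 4, 8)) -> dict:
    """Map each (hypothetical) ASPP dilation rate to its per-axis extent."""
    return {d: effective_kernel_size(kernel_size, d) for d in dilations}


if __name__ == "__main__":
    for rate, extent in aspp_branch_extents().items():
        # Every branch keeps 27 weights per channel pair but sees extent^3 voxels.
        print(f"dilation={rate}: {extent}x{extent}x{extent} voxels")
```

For the assumed rates the branches cover 3, 5, 9 and 17 voxels per axis. In the original DeepLab design the parallel branch outputs are concatenated and fused with a 1×1 convolution; how DeepSeg-MS fuses them is not specified here.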
The architecture used for MS lesion segmentation is DeepSeg-MS, a novel deep learning-based network for fully automated MS lesion segmentation from multimodal MRI. DeepSeg-MS integrates three key components to improve segmentation precision: attention mechanisms, Deep Supervision, and Atrous Spatial Pyramid Pooling (ASPP). Together they leverage 3D multimodal MRI to handle lesion variability across patients.
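Deep supervision attaches auxiliary outputs to intermediate decoder levels and adds their losses to the main one. Below is a minimal NumPy sketch, assuming a soft Dice loss, stride-based downsampling of the ground truth, and weights that halve at each coarser level (a common choice, not necessarily the one used in DeepSeg-MS). All function names are hypothetical.

```python
import numpy as np


def soft_dice_loss(prob, target, eps=1e-6):
    """1 - soft Dice between a probability map and a binary mask of the same shape."""
    intersection = np.sum(prob * target)
    denominator = np.sum(prob) + np.sum(target)
    return float(1.0 - (2.0 * intersection + eps) / (denominator + eps))


def downsample_mask(mask, factor):
    """Nearest-neighbour downsampling of a 3D mask by keeping every `factor`-th voxel."""
    return mask[::factor, ::factor, ::factor]


def deep_supervision_loss(outputs, target):
    """Weighted sum of per-level soft Dice losses.

    outputs[0] is the full-resolution probability map; outputs[i] is assumed to be
    2**i times smaller along each axis. Weights halve at each coarser level and are
    normalised to sum to 1.
    """
    weights = np.array([0.5 ** i for i in range(len(outputs))])
    weights /= weights.sum()
    total = 0.0
    for level, (weight, prob) in enumerate(zip(weights, outputs)):
        level_target = downsample_mask(target, 2 ** level)
        if prob.shape != level_target.shape:
            raise ValueError(f"level {level}: prediction {prob.shape} vs target {level_target.shape}")
        total += weight * soft_dice_loss(prob, level_target)
    return float(total)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    target = (rng.random((16, 16, 16)) > 0.9).astype(np.float32)
    perfect = [downsample_mask(target, 2 ** i) for i in range(3)]  # 16^3, 8^3, 4^3
    empty = [np.zeros_like(p) for p in perfect]
    print(deep_supervision_loss(perfect, target))  # close to 0 for perfect predictions
    print(deep_supervision_loss(empty, target))    # much larger when nothing is predicted
```

The coarse levels receive smaller weights because their targets are heavily subsampled, while still pushing the intermediate decoder features toward the lesion masks.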
The following image represents the DeepSeg-MS architecture:
To start, create and activate the necessary Conda environment:
conda env create -f deepsegms.yml
conda activate deepsegms
Organise the dataset in a JSON file for the MONAI dataloader with the following fields:
- num_fold: the total number of folds
- fold#: the fold to be trained (replace # with the fold number, e.g. fold0), containing:
  - train: the training samples, each described by
    - id: the unique patient identifier, which must be the same for all of the patient's timepoints
    - data:
      - images: list of paths to the FLAIR, T1-w and T2-w images
      - mask: path to the ground-truth mask
  - val: the validation samples, with the same id and data fields as train
  - test: the test samples, with the same id and data fields as train
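A short Python sketch that assembles this structure and checks that no patient appears in more than one split of a fold, which is what keeping the same id across timepoints makes possible. The exact nesting (a list of {id, data} entries per split), the relative file names, and the helper names are assumptions inferred from the field list above; check them against the dataloader before use.

```python
import json

# Assumed order of the image paths, matching "FLAIR, T1-w and T2-w".
MODALITIES = ("flair", "t1", "t2")
SPLITS = ("train", "val", "test")


def make_entry(patient_id, timepoint):
    """One scan in the assumed layout; the relative file names are hypothetical."""
    base = f"{patient_id}/{timepoint}"
    return {
        "id": patient_id,  # identical for every timepoint of the same patient
        "data": {
            "images": [f"{base}/{m}.nii.gz" for m in MODALITIES],
            "mask": f"{base}/mask.nii.gz",
        },
    }


def build_folds(folds):
    """folds maps a fold index to {"train": [...], "val": [...], "test": [...]}."""
    spec = {"num_fold": len(folds)}
    for index, splits in folds.items():
        spec[f"fold{index}"] = {name: list(splits.get(name, [])) for name in SPLITS}
    return spec


def leaked_patients(fold):
    """Patient ids that appear in more than one split of the same fold."""
    ids = {name: {entry["id"] for entry in fold.get(name, [])} for name in SPLITS}
    return (ids["train"] & ids["val"]) | (ids["train"] & ids["test"]) | (ids["val"] & ids["test"])


if __name__ == "__main__":
    fold0 = {
        # Patient P01 has two timepoints; both stay in the same split.
        "train": [make_entry("P01", "tp0"), make_entry("P01", "tp1"), make_entry("P02", "tp0")],
        "val": [make_entry("P03", "tp0")],
        "test": [make_entry("P04", "tp0")],
    }
    spec = build_folds({0: fold0})
    if leaked_patients(spec["fold0"]):
        raise SystemExit("a patient appears in more than one split")
    print(json.dumps(spec, indent=2))  # redirect to a file and pass it via --folds_path
```

Splitting by patient rather than by scan keeps different timepoints of the same patient from ending up in both training and evaluation data.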
To start training, use the following command. You can modify the training parameters by referring to train.py and adjusting the arguments.
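python train.py \
--root_dir /path/to/dataset \
--folds_path /path/to/folds.json \
--logdir /path/to/logdir
where --root_dir is the path to the dataset, --folds_path is the path to the JSON file with the fold splitting, and --logdir is the directory where logs and models are saved.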
Optional arguments:
--configuration Name of the training configuration; used to name the log directory.
--num_fold Index of the fold to be trained. Defaults to 0 (fold0).
--number_modality The number of input modalities. Defaults to 3 (FLAIR, T1-w and T2-w).
--number_targets The number of output targets. Defaults to 1.
--feature Six integers specifying the numbers of features. Defaults to [64, 64, 128, 256, 512, 64].
--model_name The model to resume training from. Defaults to '' to train from scratch; set it to 'final_model_0.xxxx.pt' to resume training, where xxxx stands for the Dice Score of the last trained epoch.
--use_SA Set to True to use Squeeze-and-Attention blocks in the UNet. Defaults to False.
--use_SAEncoder Set to True to use Squeeze-and-Attention blocks in the Encoder. Defaults to False.
--use_deep_supervision Set to True to use Deep Supervision. Defaults to False.
--use_aspp Set to True to use ASPP. Defaults to False.
--use_MEEncoder Set to True to use Multi-Encoder mode (one encoder per modality). Defaults to False.
--resize Specify crop size. Defaults to 96.
--learning_rate Specify the learning rate. Defaults to 1e-4.
--batch_size Specify the batch size. Defaults to 2.
--max_epoch Specify the maximum number of epochs. Defaults to 1200.
--val_every The number of epochs between validation runs. Defaults to 10.
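The boolean options above are documented as "Set to True", so their values presumably arrive as strings on the command line. The sketch below is a hypothetical argparse mirror of the documented options (not the actual train.py code) showing one safe way to parse such flags: argparse's type=bool would turn the string "False" into True, because any non-empty string is truthy. It also reads --feature as exactly six integers.

```python
import argparse


def str2bool(value):
    """Parse True/False-style strings. argparse's type=bool would map "False" to True."""
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Hypothetical mirror of the documented train.py options, not the repository code."""
    p = argparse.ArgumentParser(description="DeepSeg-MS training options (sketch)")
    p.add_argument("--root_dir", required=True)
    p.add_argument("--folds_path", required=True)
    p.add_argument("--logdir", required=True)
    p.add_argument("--configuration", default=None)  # default not documented
    p.add_argument("--num_fold", type=int, default=0)
    p.add_argument("--number_modality", type=int, default=3)
    p.add_argument("--number_targets", type=int, default=1)
    p.add_argument("--feature", type=int, nargs=6, default=[64, 64, 128, 256, 512, 64])
    p.add_argument("--model_name", default="")
    for flag in ("use_SA", "use_SAEncoder", "use_deep_supervision", "use_aspp", "use_MEEncoder"):
        p.add_argument(f"--{flag}", type=str2bool, default=False)
    p.add_argument("--resize", type=int, default=96)
    p.add_argument("--learning_rate", type=float, default=1e-4)
    p.add_argument("--batch_size", type=int, default=2)
    p.add_argument("--max_epoch", type=int, default=1200)
    p.add_argument("--val_every", type=int, default=10)
    return p


if __name__ == "__main__":
    args = build_parser().parse_args([
        "--root_dir", "/path/to/dataset", "--folds_path", "/path/to/folds.json",
        "--logdir", "/path/to/logdir", "--use_SA", "True", "--use_aspp", "False",
    ])
    print(args.use_SA, args.use_aspp, args.feature)  # True False [64, 64, 128, 256, 512, 64]
```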
Make sure to explore the argument options in train.py to customize the training process, such as batch size, learning rate, and number of epochs.
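To train every fold of a cross-validation split, the training command can be repeated with a different --num_fold. A hypothetical launcher sketch; it assumes the boolean options accept an explicit True/False value and only prints the commands unless dry_run=False. The flag combination in the example is illustrative, not the configuration that reproduces the paper.

```python
import shlex
import subprocess


def train_command(root_dir, folds_path, logdir, fold, extra=None):
    """Argument list for one train.py run; `extra` maps option names to values."""
    cmd = [
        "python", "train.py",
        "--root_dir", root_dir,
        "--folds_path", folds_path,
        "--logdir", logdir,
        "--num_fold", str(fold),
    ]
    for name, value in (extra or {}).items():
        cmd += [f"--{name}", str(value)]  # assumes e.g. --use_SA takes an explicit True/False
    return cmd


def run_all_folds(root_dir, folds_path, logdir, num_fold, extra=None, dry_run=True):
    """Print (and, unless dry_run, execute) one training command per fold."""
    commands = [train_command(root_dir, folds_path, logdir, fold, extra) for fold in range(num_fold)]
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # raises CalledProcessError on the first failing fold
    return commands


if __name__ == "__main__":
    example_flags = {"use_SA": True, "use_deep_supervision": True, "use_aspp": True}  # illustrative
    run_all_folds("/path/to/dataset", "/path/to/folds.json", "/path/to/logdir",
                  num_fold=5, extra=example_flags)  # 5 folds is a hypothetical split
```

Passing an argument list (rather than a single shell string) to subprocess.run avoids quoting problems with paths that contain spaces.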
Once the training is completed, you can run the testing phase with the following command. Similar to training, you can modify the testing parameters by checking test.py and adjusting the arguments.
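python test.py \
--root_dir /path/to/dataset \
--folds_path /path/to/folds.json \
--testdir /path/to/predictions
where --root_dir is the path to the dataset, --folds_path is the path to the JSON file with the fold splitting, and --testdir is the directory where predictions are saved.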
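To score saved predictions against the ground-truth masks yourself, the binary Dice coefficient 2|P ∩ G| / (|P| + |G|) is enough. A minimal NumPy sketch; the 0.5 threshold and the convention of returning 1.0 when both masks are empty are assumptions.

```python
import numpy as np


def dice_score(pred, gt, threshold=0.5):
    """Binary Dice 2|P & G| / (|P| + |G|) after thresholding both volumes.

    Returns 1.0 when both masks are empty (a convention; some tools return 0 or NaN).
    """
    p = np.asarray(pred) > threshold
    g = np.asarray(gt) > threshold
    denominator = int(p.sum()) + int(g.sum())
    if denominator == 0:
        return 1.0
    return 2.0 * int(np.logical_and(p, g).sum()) / denominator


if __name__ == "__main__":
    gt = np.zeros((4, 4, 4))
    pred = np.zeros((4, 4, 4))
    gt[0, 0, 1:3] = 1    # two lesion voxels
    pred[0, 0, :2] = 1   # two predicted voxels, one of them overlapping
    print(dice_score(pred, gt))  # 0.5
```

If the predictions are stored as NIfTI files (an assumption), they can be loaded with nibabel, e.g. nibabel.load(path).get_fdata(), before calling dice_score.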
Optional arguments:
--configuration Name of the configuration used in training; used to name the log directory.
--num_fold Index of the fold to be tested. Defaults to 0 (fold0).
--number_modality The number of input modalities. Defaults to 3 (FLAIR, T1-w and T2-w).
--number_targets The number of output targets. Defaults to 1.
--feature Six integers specifying the numbers of features. Defaults to [64, 64, 128, 256, 512, 64].
--use_SA Set to True to use Squeeze-and-Attention blocks in the UNet. Defaults to False.
--use_SAEncoder Set to True to use Squeeze-and-Attention blocks in the Encoder. Defaults to False.
--use_deep_supervision Set to True to use Deep Supervision. Defaults to False.
--use_aspp Set to True to use ASPP. Defaults to False.
--use_MEEncoder Set to True to use Multi-Encoder mode (one encoder per modality). Defaults to False.
--model_name The name of the model to be tested. Defaults to 'best_model_0.xxxx.pt', the model with the best validation score, where xxxx stands for the Dice Score of the best validation epoch.
--batch_size Specify the batch size. Defaults to 1.
--resize Specify crop size. Defaults to 96.
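Since checkpoint names embed the Dice Score (best_model_0.xxxx.pt, final_model_0.xxxx.pt), the file to pass to --model_name can be picked programmatically. A sketch assuming exactly that naming pattern; parse_dice and best_checkpoint are hypothetical helpers, not part of the repository, and the directory you point them at must be the one that actually holds the .pt files.

```python
import re
from pathlib import Path

# Naming pattern assumed from the documented examples, e.g. best_model_0.8123.pt.
CHECKPOINT_RE = re.compile(r"^(best|final)_model_(\d+(?:\.\d+)?)\.pt$")


def parse_dice(filename):
    """Dice Score encoded in a checkpoint name, or None if the name does not match."""
    match = CHECKPOINT_RE.match(filename)
    return float(match.group(2)) if match else None


def best_checkpoint(logdir, prefix="best"):
    """Name of the '<prefix>_model_*.pt' file in logdir with the highest Dice, or None."""
    candidates = []
    for path in Path(logdir).glob(f"{prefix}_model_*.pt"):
        dice = parse_dice(path.name)
        if dice is not None:
            candidates.append((dice, path.name))
    return max(candidates)[1] if candidates else None


if __name__ == "__main__":
    name = best_checkpoint("/path/to/logdir")
    print(name if name else "no matching checkpoint found")
```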
