Anonymous repository for One-shot Joint Extraction, Registration and Segmentation of Neuroimaging Data. This repository contains the implementation of JERS.
The corresponding files contain the source code and sample data of JERS.
- dataset : Sample dataset for testing the code
- main.py : Main code for JERS training
- model.py : Supporting models
- preprocess_data.py : Preprocessing of neuroimaging data
- train.py : Supporting training
- utils.py : Supporting functions
Note that all public datasets used in the paper can be found here:
Raw data can be preprocessed via preprocess_data.py.
The following command starts training:
python main.py
Parameters:
- main.py :
- train_set_name : file name of training set, default "LPBA40_train_sample.npy"
- val_set_name : file name of validation set, default "LPBA40_val_sample.npy"
- test_set_name : file name of test set, default "LPBA40_test_sample.npy"
- dice_label : dataset name of anatomical label, default "LPBA40"
- fixed_set_name : dataset name of the target (fixed) image, default "LPBA40"
- reg_loss_name : similarity loss function used for training, default "NCC"
- if_compute_dice : whether to compute the Dice score for evaluation, default True
- gamma : threshold of sigmoid function, default 10
- beta : weight of the segmentation loss term, default 0.1
- lamda_mask : value of mask smoothing regularization parameter, default 1.0
- mask_smooth_loss_func : loss function of mask smoothing, default first_Grad("l2")
- ext_stage : number of stages of extraction, default 5
- reg_stage : number of stages of registration, default 5
- if_train_aug : whether to apply data augmentation during training, default True
- batch_size : batch size, default 1
- img_size : size of input images, default 96
- num_epochs : number of epochs, default 1000
- learning_rate : learning rate, default 0.000001
- save_every_epoch : saving interval for results, default 1
- save_start_epoch : start point for results saving, default 0
- model_name : model, default JERS(img_size, ext_stage, reg_stage, gamma, beta)
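For reference, the scalar and string defaults listed above can be collected in a small argument parser. This is only an illustrative sketch: whether main.py reads its parameters from the command line or from variables in the script is not stated here, so `build_parser` and the flag names below are assumptions that simply mirror the parameter names in this list. The object-valued parameters (mask_smooth_loss_func and model_name) are left out because they are built in code.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented defaults of main.py."""
    p = argparse.ArgumentParser(description="JERS training options (illustrative sketch)")
    # Dataset files and dataset names
    p.add_argument("--train_set_name", default="LPBA40_train_sample.npy")
    p.add_argument("--val_set_name", default="LPBA40_val_sample.npy")
    p.add_argument("--test_set_name", default="LPBA40_test_sample.npy")
    p.add_argument("--dice_label", default="LPBA40")
    p.add_argument("--fixed_set_name", default="LPBA40")
    # Losses and regularization
    p.add_argument("--reg_loss_name", default="NCC")
    p.add_argument("--if_compute_dice", action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--gamma", type=float, default=10.0)
    p.add_argument("--beta", type=float, default=0.1)
    p.add_argument("--lamda_mask", type=float, default=1.0)
    # Model stages
    p.add_argument("--ext_stage", type=int, default=5)
    p.add_argument("--reg_stage", type=int, default=5)
    # Training schedule
    p.add_argument("--if_train_aug", action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--img_size", type=int, default=96)
    p.add_argument("--num_epochs", type=int, default=1000)
    p.add_argument("--learning_rate", type=float, default=0.000001)
    p.add_argument("--save_every_epoch", type=int, default=1)
    p.add_argument("--save_start_epoch", type=int, default=0)
    return p


if __name__ == "__main__":
    # Print the effective configuration, one parameter per line.
    args = build_parser().parse_args()
    for name, value in sorted(vars(args).items()):
        print(f"{name} = {value}")
```

`argparse.BooleanOptionalAction` requires Python 3.9 or later; with it, a boolean option such as if_train_aug can be switched off with `--no-if_train_aug`.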
The results can be found in the following locations after training:
- loss_log :
- model_name.txt : log file of the model
- model :
- model_name.pth : saved model
- sample_img :
- o : target images
- t : source images
- s_stage : extracted images by stage
- s_stage_mask : mask of extracted images by stage
- r_stage : warped (registered) images by stage
- segpred_am : predicted segmentation mask
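A predicted segmentation mask can be compared against an anatomical label map with the Dice score. The sketch below shows the standard per-label Dice computation on integer label volumes using NumPy. It is not taken from this repository: the function name `dice_per_label`, the treatment of label 0 as background, and the convention of returning 1.0 when a label is absent from both volumes are assumptions made for illustration, and the evaluation in the code may differ.

```python
import numpy as np


def dice_per_label(pred, target, labels=None):
    """Return {label: Dice} for two integer label volumes of equal shape."""
    pred = np.asarray(pred)
    target = np.asarray(target)
    if pred.shape != target.shape:
        raise ValueError("pred and target must have the same shape")
    if labels is None:
        # Assumption: 0 is background; score every other label in the ground truth.
        labels = [int(v) for v in np.unique(target) if v != 0]
    scores = {}
    for label in labels:
        p = pred == label
        t = target == label
        denom = int(p.sum()) + int(t.sum())
        if denom == 0:
            # Label absent from both volumes: treated as perfect agreement here.
            scores[label] = 1.0
        else:
            # Dice = 2 |P ∩ T| / (|P| + |T|)
            scores[label] = 2.0 * int(np.logical_and(p, t).sum()) / denom
    return scores


if __name__ == "__main__":
    # Toy 2x2x2 volumes with two anatomical labels.
    target = np.array([[[0, 1], [1, 1]], [[2, 2], [0, 0]]])
    pred = np.array([[[0, 1], [1, 0]], [[2, 2], [2, 0]]])
    scores = dice_per_label(pred, target)
    print(scores)
    print("mean Dice:", sum(scores.values()) / len(scores))
```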