
Models Genesis - Official PyTorch Implementation

We provide the official PyTorch implementation for training Models Genesis, together with instructions for using the pre-trained Models Genesis, described in the following papers:

Models Genesis: Generic Autodidactic Models for 3D Medical Image Analysis
Zongwei Zhou1, Vatsal Sodha1, Md Mahfuzur Rahman Siddiquee1,
Ruibin Feng1, Nima Tajbakhsh1, Michael B. Gotway2, and Jianming Liang1
1 Arizona State University, 2 Mayo Clinic
International Conference on Medical Image Computing and Computer Assisted Intervention (MICCAI), 2019
Young Scientist Award
paper | code | slides | poster | talk (YouTube, YouKu) | blog

Models Genesis
Zongwei Zhou1, Vatsal Sodha1, Jiaxuan Pang1, Michael B. Gotway2, and Jianming Liang1
1 Arizona State University, 2 Mayo Clinic
Medical Image Analysis (MedIA)
MedIA Best Paper Award
paper | code | slides | graphical abstract

Dependencies

  • Linux
  • Python 2.7+
  • PyTorch 1.3.1

Usage of the pre-trained Models Genesis

1. Clone the repository

$ git clone https://github.com/MrGiovanni/ModelsGenesis.git
$ cd ModelsGenesis/
$ pip install -r requirements.txt

2. Download the pre-trained Models Genesis

| Weight | Download | Description |
|---|---|---|
| Genesis_Chest_CT.h5 | link | pre-trained U-Net weights in Keras |
| Genesis_Chest_CT.pt | link | pre-trained U-Net weights in PyTorch |
| genesis_nnunet_luna16_006.model | link | pre-trained nnU-Net weights in PyTorch |

Download the pre-trained Genesis Chest CT weights and save them as ./pretrained_weights/Genesis_Chest_CT.pt.
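
To confirm the download, you can load the checkpoint and list its contents (a quick sanity check; the fine-tuning snippets below rely on the 'state_dict' entry shown here):

import torch

# load the checkpoint on CPU and inspect it
checkpoint = torch.load("pretrained_weights/Genesis_Chest_CT.pt", map_location="cpu")
print(checkpoint.keys())                    # expect a 'state_dict' entry
print(list(checkpoint['state_dict'])[:5])   # first few layer names, saved with a 'module.' prefix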

3. Fine-tune Models Genesis on your own target task

Models Genesis learn a general-purpose image representation that can be leveraged for a wide range of target tasks. Specifically, Models Genesis can be used to initialize the encoder for target classification tasks and the encoder-decoder for target segmentation tasks.

As for the target classification tasks, the 3D deep model can be initialized with the pre-trained encoder using the following example:

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import unet3d

# prepare your own data; `config` holds your hyper-parameters and `your_dataset` is your own torch Dataset
train_loader = DataLoader(your_dataset, batch_size=config.batch_size, shuffle=True)

# prepare the 3D model
class TargetNet(nn.Module):
    def __init__(self, base_model, n_class=1):
        super(TargetNet, self).__init__()

        self.base_model = base_model
        self.dense_1 = nn.Linear(512, 1024, bias=True)
        self.dense_2 = nn.Linear(1024, n_class, bias=True)

    def forward(self, x):
        self.base_model(x)
        # unet3d.UNet3D stores its deepest encoder feature map in base_model.out512
        self.base_out = self.base_model.out512
        # global average pooling over the spatial dimensions; expects channels-first
        # input of shape (N, C, D, H, W), where N = batch size and C = channels
        self.out_glb_avg_pool = F.avg_pool3d(self.base_out, kernel_size=self.base_out.size()[2:]).view(self.base_out.size()[0], -1)
        self.linear_out = self.dense_1(self.out_glb_avg_pool)
        final_out = self.dense_2(F.relu(self.linear_out))
        return final_out
        
base_model = unet3d.UNet3D()

#Load pre-trained weights
weight_dir = 'pretrained_weights/Genesis_Chest_CT.pt'
checkpoint = torch.load(weight_dir)
state_dict = checkpoint['state_dict']
unParalled_state_dict = {}
for key in state_dict.keys():
    unParalled_state_dict[key.replace("module.", "")] = state_dict[key]
base_model.load_state_dict(unParalled_state_dict)
target_model = TargetNet(base_model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
target_model.to(device)
target_model = nn.DataParallel(target_model, device_ids=[i for i in range(torch.cuda.device_count())])
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(target_model.parameters(), config.lr, momentum=0.9, weight_decay=0.0, nesterov=False)

# train the model; `scheduler` and `initial_epoch` are assumed to be defined by the user

for epoch in range(initial_epoch, config.nb_epoch):
    scheduler.step(epoch)
    target_model.train()
    for batch_ndx, (x, y) in enumerate(train_loader):
        x, y = x.float().to(device), y.float().to(device)
        pred = torch.sigmoid(target_model(x))
        loss = criterion(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
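
After fine-tuning, you may want a simple validation pass. Below is a minimal sketch, assuming a val_loader built the same way as train_loader and binary labels:

target_model.eval()
correct, total = 0, 0
with torch.no_grad():
    for x, y in val_loader:
        x, y = x.float().to(device), y.float().to(device)
        prob = torch.sigmoid(target_model(x))
        # threshold the predicted probabilities at 0.5 and count exact matches
        correct += ((prob > 0.5).float() == y).sum().item()
        total += y.numel()
print("validation accuracy: {:.3f}".format(correct / total))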

As for the target segmentation tasks, the 3D deep model can be initialized with the pre-trained encoder-decoder using the following example:

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import unet3d

# define the Dice loss
def torch_dice_coef_loss(y_true, y_pred, smooth=1.):
    y_true_f = torch.flatten(y_true)
    y_pred_f = torch.flatten(y_pred)
    intersection = torch.sum(y_true_f * y_pred_f)
    return 1. - ((2. * intersection + smooth) / (torch.sum(y_true_f) + torch.sum(y_pred_f) + smooth))

# prepare your own data; `config` holds your hyper-parameters and `your_dataset` is your own torch Dataset
train_loader = DataLoader(your_dataset, batch_size=config.batch_size, shuffle=True)

# prepare the 3D model

model = unet3d.UNet3D()

#Load pre-trained weights
weight_dir = 'pretrained_weights/Genesis_Chest_CT.pt'
checkpoint = torch.load(weight_dir)
state_dict = checkpoint['state_dict']
unParalled_state_dict = {}
for key in state_dict.keys():
    unParalled_state_dict[key.replace("module.", "")] = state_dict[key]
model.load_state_dict(unParalled_state_dict)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model = nn.DataParallel(model, device_ids=[i for i in range(torch.cuda.device_count())])
criterion = torch_dice_coef_loss
optimizer = torch.optim.SGD(model.parameters(), config.lr, momentum=0.9, weight_decay=0.0, nesterov=False)

# train the model; `scheduler` and `initial_epoch` are assumed to be defined by the user

for epoch in range(initial_epoch, config.nb_epoch):
    scheduler.step(epoch)
    model.train()
    for batch_ndx, (x, y) in enumerate(train_loader):
        x, y = x.float().to(device), y.float().to(device)
        pred = model(x)
        loss = criterion(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
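
After fine-tuning, you can estimate segmentation quality with a Dice score. Below is a minimal sketch, assuming a val_loader built the same way as train_loader and a network output already in [0, 1], as in the training loop above:

model.eval()
dice_scores = []
with torch.no_grad():
    for x, y in val_loader:
        x, y = x.float().to(device), y.float().to(device)
        pred = (model(x) > 0.5).float()
        intersection = (pred * y).sum()
        dice = (2. * intersection + 1.) / (pred.sum() + y.sum() + 1.)
        dice_scores.append(dice.item())
print("mean Dice: {:.4f}".format(sum(dice_scores) / len(dice_scores)))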

Prepare your own data: If the image modality in your target task is CT, we suggest clipping all intensity values to the interesting Hounsfield Unit range of [-1000, +1000] and then scaling them to [0, 1]. If the image modality is MRI, we suggest clipping the intensity values to the interesting range of [0, +4000] and then scaling them to [0, 1]. For any other modality, you may want to first clip to the meaningful intensity range and then scale to [0, 1].
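
For example, the suggested CT and MRI preprocessing can be written as follows (a minimal NumPy sketch; the clipping ranges are the ones mentioned above):

import numpy as np

def clip_and_scale(volume, lo, hi):
    """Clip a volume to [lo, hi] and scale the result to [0, 1]."""
    volume = np.clip(volume.astype(np.float32), lo, hi)
    return (volume - lo) / (hi - lo)

ct_volume = np.random.randint(-2000, 3000, size=(64, 64, 32))   # stand-in for a real CT crop
ct_volume = clip_and_scale(ct_volume, -1000., 1000.)            # CT: interesting HU range
mri_volume = np.random.randint(0, 6000, size=(64, 64, 32))      # stand-in for a real MRI crop
mri_volume = clip_and_scale(mri_volume, 0., 4000.)              # MRI: suggested intensity range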

We adopt input cubes shaped (N, 1, 64, 64, 32) during model pre-training, where N denotes the number of training samples. When fine-tuning the pre-trained Models Genesis, any input size is acceptable as long as each spatial dimension is divisible by 16 (=2^4), owing to the four down-sampling layers in V-Net. Therefore, to segment larger objects, such as the liver, kidney, or large nodules, you may want to try

input_channels, input_rows, input_cols, input_deps = 1, 128, 128, 64
input_channels, input_rows, input_cols, input_deps = 1, 160, 160, 96

or an even larger input size, as you wish.
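
If your volumes do not already satisfy this constraint, one option is to zero-pad each spatial dimension up to the next multiple of 16. Below is a sketch (not part of the official code):

import torch
import torch.nn.functional as F

def pad_to_multiple_of_16(x):
    """Zero-pad a (N, C, D1, D2, D3) tensor so every spatial dimension is divisible by 16."""
    pads = []
    for dim in reversed(x.shape[2:]):   # F.pad takes padding for the last dimension first
        extra = (-dim) % 16
        pads.extend([0, extra])         # pad only on one side of each dimension
    return F.pad(x, pads)

x = torch.zeros(1, 1, 130, 150, 70)
print(pad_to_multiple_of_16(x).shape)   # torch.Size([1, 1, 144, 160, 80])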

Learn Models Genesis from your own unlabeled data

1. Clone the repository

$ git clone https://github.com/MrGiovanni/ModelsGenesis.git
$ cd ModelsGenesis/
$ pip install -r requirements.txt

2. Create the data generator (LUNA-2016 for example)

For your convenience, we have provided our own extracted 3D cubes from LUNA16.

Download them from Google Drive or Baidu Wangpan (code: m8g4). Each sub-folder is named 'bat_N_s_64x64x32', where N denotes the number of cubes extracted from each patient. You may select the scale of training samples according to the resources at hand: a larger N demands longer training time and more powerful GPUs/CPUs, and may (or may not) yield a more generic visual representation. We adopted N=32 in our MICCAI paper.

  • The processed cubes directory structure
generated_cubes/
    |--  bat_32_s_64x64x32_0.npy: cubes extracted from subset0 in luna16
    |--  bat_32_s_64x64x32_1.npy: cubes extracted from subset1 in luna16
    |--  bat_32_s_64x64x32_2.npy: cubes extracted from subset2 in luna16
    |--  bat_32_s_64x64x32_3.npy: cubes extracted from subset3 in luna16
    |--  bat_32_s_64x64x32_4.npy: cubes extracted from subset4 in luna16
    |--  bat_32_s_64x64x32_5.npy: cubes extracted from subset5 in luna16
    |--  bat_32_s_64x64x32_6.npy: cubes extracted from subset6 in luna16
    |--  bat_32_s_64x64x32_7.npy: cubes extracted from subset7 in luna16
    |--  bat_32_s_64x64x32_8.npy: cubes extracted from subset8 in luna16
    |--  bat_32_s_64x64x32_9.npy: cubes extracted from subset9 in luna16
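
As a quick sanity check, you can load one of the generated .npy files and inspect it (a sketch; the exact array layout inside each file is an assumption based on the folder naming):

import numpy as np

cubes = np.load("generated_cubes/bat_32_s_64x64x32_0.npy")
print(cubes.shape)                         # expected to be roughly (number_of_cubes, 64, 64, 32)
print(cubes.dtype, cubes.min(), cubes.max())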

You can also extract 3D cubes on your own by following the two steps below:

Step 1: Download the LUNA-2016 dataset from the challenge website (https://luna16.grand-challenge.org/download/) and save it to the ./datasets/luna16 directory.

Step 2: Extract 3D cubes from the patient data by running the script below. The extracted 3D cubes will be saved to the ./generated_cubes directory.

for subset in `seq 0 9`
do
    python -W ignore infinite_generator_3D.py \
        --fold $subset \
        --scale 32 \
        --data datasets/luna16 \
        --save generated_cubes
done

3. Pre-train Models Genesis (LUNA-2016 for example)

python -W ignore pytorch/Genesis_Chest_CT.py

Your pre-trained Models Genesis will be saved at ./pytorch/pretrained_weights/Vnet-genesis_chest_ct.pt.
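
You can then fine-tune from your own pre-trained Models Genesis exactly as described in the sections above; only the checkpoint path changes (assuming the same 'state_dict' layout as the released weights):

import torch
import unet3d

model = unet3d.UNet3D()
checkpoint = torch.load('pytorch/pretrained_weights/Vnet-genesis_chest_ct.pt', map_location='cpu')
# strip the 'module.' prefix added by nn.DataParallel before loading
state_dict = {k.replace("module.", ""): v for k, v in checkpoint['state_dict'].items()}
model.load_state_dict(state_dict)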

Citation

If you use this code or our pre-trained weights in your research, please cite our papers:

@InProceedings{zhou2019models,
  author="Zhou, Zongwei and Sodha, Vatsal and Rahman Siddiquee, Md Mahfuzur and Feng, Ruibin and Tajbakhsh, Nima and Gotway, Michael B. and Liang, Jianming",
  title="Models Genesis: Generic Autodidactic Models for 3D Medical Image Analysis",
  booktitle="Medical Image Computing and Computer Assisted Intervention -- MICCAI 2019",
  year="2019",
  publisher="Springer International Publishing",
  address="Cham",
  pages="384--393",
  isbn="978-3-030-32251-9",
  url="https://link.springer.com/chapter/10.1007/978-3-030-32251-9_42"
}

@article{zhou2021models,
  title="Models Genesis",
  author="Zhou, Zongwei and Sodha, Vatsal and Pang, Jiaxuan and Gotway, Michael B and Liang, Jianming",
  journal="Medical Image Analysis",
  volume = "67",
  pages = "101840",
  year = "2021",
  issn = "1361-8415",
  doi = "https://doi.org/10.1016/j.media.2020.101840",
  url = "http://www.sciencedirect.com/science/article/pii/S1361841520302048",
}

@phdthesis{zhou2021towards,
  title={Towards Annotation-Efficient Deep Learning for Computer-Aided Diagnosis},
  author={Zhou, Zongwei},
  year={2021},
  school={Arizona State University}
}

Acknowledgement

We thank Jiaxuan Pang and Vatsal Sodha for their implementation of Models Genesis in PyTorch. We built the 3D U-Net architecture by referring to the released code at mattmacy/vnet.pytorch. This is a patent-pending technology.