diff --git a/3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb b/3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb
index 9314bd35f1..46a7fdb299 100644
--- a/3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb
+++ b/3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb
@@ -15,20 +15,32 @@
     "1. Define a new transform according to MONAI transform API.\n",
     "1. Load Nifti image with metadata, load a list of images and stack them.\n",
     "1. Randomly adjust intensity for data augmentation.\n",
-    "1. Cache IO and transforms to accelerate training and validation.\n",
+    "1. Cache IO and transforms, ThreadDataLoader, and AMP to accelerate training and validation.\n",
     "1. Swin UNETR model, DiceCE loss function, Mean Dice metric for multi-organ segmentation task.\n",
     "\n",
-    "For this tutorial, the dataset needs to be downloaded from: https://www.synapse.org/#!Synapse:syn3193805/wiki/217752. \n",
+    "For this tutorial, the dataset needs to be downloaded from: https://www.synapse.org/#!Synapse:syn3193805/wiki/217752. More details are provided in the \"Download dataset\" section below.\n",
     "\n",
-    "In addition, the json file for data splits needs to be downloaded from this [link](https://drive.google.com/file/d/1t4fIQQkONv7ArTSZe4Nucwkk1KfdUDvW/view?usp=sharing). Once downloaded, place the json file in the same folder as the dataset. \n",
+    "In addition, the json file for data splits needs to be downloaded from this [link](https://drive.google.com/file/d/1qcGh41p-rI3H_sQ0JwOAhNiQSXriQqGi/view?usp=sharing). Once downloaded, place the json file in the same folder as the dataset. \n",
     "\n",
     "For BTCV dataset, under Institutional Review Board (IRB) supervision, 50 abdomen CT scans of were randomly selected from a combination of an ongoing colorectal cancer chemotherapy trial, and a retrospective ventral hernia study. The 50 scans were captured during portal venous contrast phase with variable volume sizes (512 x 512 x 85 - 512 x 512 x 198) and field of views (approx. 280 x 280 x 280 mm3 - 500 x 500 x 650 mm3). The in-plane resolution varies from 0.54 x 0.54 mm2 to 0.98 x 0.98 mm2, while the slice thickness ranges from 2.5 mm to 5.0 mm. \n",
     "\n",
-    "Target: 13 abdominal organs including 1. Spleen 2. Right Kidney 3. Left Kideny 4.Gallbladder 5.Esophagus 6. Liver 7. Stomach 8.Aorta 9. IVC 10. Portal and Splenic Veins 11. Pancreas 12 Right adrenal gland 13 Left adrenal gland.\n",
-    "\n",
-    "Modality: CT\n",
-    "Size: 30 3D volumes (24 Training + 6 Testing) \n",
-    "Challenge: BTCV MICCAI Challenge\n",
+    "- Target: 13 abdominal organs including \n",
+    "  1. Spleen \n",
+    "  2. Right Kidney \n",
+    "  3. Left Kidney \n",
+    "  4. Gallbladder \n",
+    "  5. Esophagus \n",
+    "  6. Liver \n",
+    "  7. Stomach \n",
+    "  8. Aorta \n",
+    "  9. IVC \n",
+    "  10. Portal and Splenic Veins \n",
+    "  11. Pancreas \n",
+    "  12. Right adrenal gland \n",
+    "  13. 
Left adrenal gland.\n", + "- Modality: CT\n", + "- Size: 30 3D volumes (24 Training + 6 Testing)\n", + "- Challenge: BTCV MICCAI Challenge\n", "\n", "The following figure shows image patches with the organ sub-regions that are annotated in the CT (top left) and the final labels for the whole dataset (right).\n", "\n", @@ -92,7 +104,7 @@ "metadata": {}, "outputs": [], "source": [ - "!python -c \"import monai\" || pip install -q \"monai-weekly[nibabel, tqdm]\"\n", + "!python -c \"import monai; import nibabel; import tqdm\" || pip install -q \"monai-weekly[nibabel, tqdm]\"\n", "!python -c \"import matplotlib\" || pip install -q matplotlib\n", "%matplotlib inline" ] @@ -108,7 +120,6 @@ "import tempfile\n", "\n", "import matplotlib.pyplot as plt\n", - "import numpy as np\n", "from tqdm import tqdm\n", "\n", "from monai.losses import DiceCELoss\n", @@ -125,6 +136,7 @@ " ScaleIntensityRanged,\n", " Spacingd,\n", " RandRotate90d,\n", + " EnsureTyped,\n", ")\n", "\n", "from monai.config import print_config\n", @@ -132,10 +144,11 @@ "from monai.networks.nets import SwinUNETR\n", "\n", "from monai.data import (\n", - " DataLoader,\n", + " ThreadDataLoader,\n", " CacheDataset,\n", " load_decathlon_datalist,\n", " decollate_batch,\n", + " set_track_meta,\n", ")\n", "\n", "\n", @@ -171,17 +184,32 @@ "metadata": {}, "source": [ "## Setup transforms for training and validation\n", - "To save on GPU memory utilization, the num_samples can be reduced to 2. " + "To save on GPU memory utilization, the num_samples can be reduced to 2. \n", + "\n", + "A note on design related to MetaTensors:\n", + "\n", + "- Summary: using `EnsureTyped(..., track_meta=False)` (caching) and `set_track_meta(False)` (during training) speeds up training significantly.\n", + "\n", + "- We are moving towards the use of MONAI's MetaTensor in place of numpy arrays or PyTorch tensors. MetaTensors have the benefit of carrying the metadata directly with the tensor, but in some use cases (like here with training, where training data are only used for computing loss and metadata is not useful), we can safely disregard the metadata to improve speed.\n", + "\n", + "- Hence, you will see `EnsureTyped` being used before the first random transform in the training transform chain, which caches the result of deterministic transforms on GPU as Tensors (rather than MetaTensors), with `track_meta = False`. \n", + "\n", + "- On the other hand, in the following demos we will display example validation images, which uses metadata, so we use `EnsureTyped` with `track_meta = True`. Since there are no random transforms during validation, tracking metadata for validation images causes virtually no slowdown (~0.5%).\n", + "\n", + "- In the next section, you will see `set_track_meta(False)`. This is a global API introduced in MONAI 0.9.1, and it makes sure that random transforms will also be performed using Tensors rather than MetaTensors. Used together with `track_meta=False` in `EnsureTyped`, it results in all transforms being performed on Tensors, which we have found to speed up training." 
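As a minimal sketch of the MetaTensor design note above (assuming MONAI >= 0.9.1 and a CUDA device; the transform lists are abbreviated stand-ins for the full chains defined later in the notebook):

```python
# Sketch of the caching pattern described above (MONAI >= 0.9.1 assumed).
import torch
from monai.data import set_track_meta
from monai.transforms import Compose, EnsureTyped, LoadImaged, RandRotate90d

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training: everything up to EnsureTyped is deterministic and gets cached by
# CacheDataset; track_meta=False caches plain GPU Tensors instead of MetaTensors.
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=True),
        # ... deterministic spacing / intensity / cropping transforms ...
        EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
        # random transforms placed after EnsureTyped are re-applied every epoch
        RandRotate90d(keys=["image", "label"], prob=0.10, max_k=3),
    ]
)

# Validation: keep metadata so example images can be displayed later.
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=True),
        EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
    ]
)

# Called once after the CacheDatasets have been built: all subsequent (random)
# transforms then operate on Tensors rather than MetaTensors.
set_track_meta(False)
```

Placing `EnsureTyped` immediately before the first random transform is what lets `CacheDataset` cache the deterministic part of the pipeline on the GPU while the random augmentations still run per iteration.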
] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "num_samples = 4\n", "\n", + "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", "train_transforms = Compose(\n", " [\n", " LoadImaged(keys=[\"image\", \"label\"], ensure_channel_first=True),\n", @@ -200,6 +228,7 @@ " clip=True,\n", " ),\n", " CropForegroundd(keys=[\"image\", \"label\"], source_key=\"image\"),\n", + " EnsureTyped(keys=[\"image\", \"label\"], device=device, track_meta=False),\n", " RandCropByPosNegLabeld(\n", " keys=[\"image\", \"label\"],\n", " label_key=\"label\",\n", @@ -250,6 +279,7 @@ " keys=[\"image\"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True\n", " ),\n", " CropForegroundd(keys=[\"image\", \"label\"], source_key=\"image\"),\n", + " EnsureTyped(keys=[\"image\", \"label\"], device=device, track_meta=True),\n", " ]\n", ")" ] @@ -258,188 +288,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - " ## Download dataset and format in the folder.\n", - " 1. Download dataset from here: https://www.synapse.org/#!Synapse:syn3193805/wiki/89480\\n\n", - " 2. Put images in the ./data/imagesTr\n", - " 3. Put labels in the ./data/labelsTr\n", - " 4. make JSON file accordingly: ./data/dataset_0.json\n", - " Example of JSON file:\n", - " {\n", - " \"description\": \"btcv yucheng\",\n", - " \"labels\": {\n", - " \"0\": \"background\",\n", - " \"1\": \"spleen\",\n", - " \"2\": \"rkid\",\n", - " \"3\": \"lkid\",\n", - " \"4\": \"gall\",\n", - " \"5\": \"eso\",\n", - " \"6\": \"liver\",\n", - " \"7\": \"sto\",\n", - " \"8\": \"aorta\",\n", - " \"9\": \"IVC\",\n", - " \"10\": \"veins\",\n", - " \"11\": \"pancreas\",\n", - " \"12\": \"rad\",\n", - " \"13\": \"lad\"\n", - " },\n", - " \"licence\": \"yt\",\n", - " \"modality\": {\n", - " \"0\": \"CT\"\n", - " },\n", - " \"name\": \"btcv\",\n", - " \"numTest\": 20,\n", - " \"numTraining\": 80,\n", - " \"reference\": \"Vanderbilt University\",\n", - " \"release\": \"1.0 06/08/2015\",\n", - " \"tensorImageSize\": \"3D\",\n", - " \"test\": [\n", - " \"imagesTs/img0061.nii.gz\",\n", - " \"imagesTs/img0062.nii.gz\",\n", - " \"imagesTs/img0063.nii.gz\",\n", - " \"imagesTs/img0064.nii.gz\",\n", - " \"imagesTs/img0065.nii.gz\",\n", - " \"imagesTs/img0066.nii.gz\",\n", - " \"imagesTs/img0067.nii.gz\",\n", - " \"imagesTs/img0068.nii.gz\",\n", - " \"imagesTs/img0069.nii.gz\",\n", - " \"imagesTs/img0070.nii.gz\",\n", - " \"imagesTs/img0071.nii.gz\",\n", - " \"imagesTs/img0072.nii.gz\",\n", - " \"imagesTs/img0073.nii.gz\",\n", - " \"imagesTs/img0074.nii.gz\",\n", - " \"imagesTs/img0075.nii.gz\",\n", - " \"imagesTs/img0076.nii.gz\",\n", - " \"imagesTs/img0077.nii.gz\",\n", - " \"imagesTs/img0078.nii.gz\",\n", - " \"imagesTs/img0079.nii.gz\",\n", - " \"imagesTs/img0080.nii.gz\"\n", - " ],\n", - " \"training\": [\n", - " {\n", - " \"image\": \"imagesTr/img0001.nii.gz\",\n", - " \"label\": \"labelsTr/label0001.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0002.nii.gz\",\n", - " \"label\": \"labelsTr/label0002.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0003.nii.gz\",\n", - " \"label\": \"labelsTr/label0003.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0004.nii.gz\",\n", - " \"label\": \"labelsTr/label0004.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0005.nii.gz\",\n", - " \"label\": \"labelsTr/label0005.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": 
\"imagesTr/img0006.nii.gz\",\n", - " \"label\": \"labelsTr/label0006.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0007.nii.gz\",\n", - " \"label\": \"labelsTr/label0007.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0008.nii.gz\",\n", - " \"label\": \"labelsTr/label0008.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0009.nii.gz\",\n", - " \"label\": \"labelsTr/label0009.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0010.nii.gz\",\n", - " \"label\": \"labelsTr/label0010.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0021.nii.gz\",\n", - " \"label\": \"labelsTr/label0021.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0022.nii.gz\",\n", - " \"label\": \"labelsTr/label0022.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0023.nii.gz\",\n", - " \"label\": \"labelsTr/label0023.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0024.nii.gz\",\n", - " \"label\": \"labelsTr/label0024.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0025.nii.gz\",\n", - " \"label\": \"labelsTr/label0025.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0026.nii.gz\",\n", - " \"label\": \"labelsTr/label0026.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0027.nii.gz\",\n", - " \"label\": \"labelsTr/label0027.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0028.nii.gz\",\n", - " \"label\": \"labelsTr/label0028.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0029.nii.gz\",\n", - " \"label\": \"labelsTr/label0029.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0030.nii.gz\",\n", - " \"label\": \"labelsTr/label0030.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0031.nii.gz\",\n", - " \"label\": \"labelsTr/label0031.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0032.nii.gz\",\n", - " \"label\": \"labelsTr/label0032.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0033.nii.gz\",\n", - " \"label\": \"labelsTr/label0033.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0034.nii.gz\",\n", - " \"label\": \"labelsTr/label0034.nii.gz\"\n", - " }\n", - " ],\n", - " \"validation\": [\n", - " {\n", - " \"image\": \"imagesTr/img0035.nii.gz\",\n", - " \"label\": \"labelsTr/label0035.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0036.nii.gz\",\n", - " \"label\": \"labelsTr/label0036.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0037.nii.gz\",\n", - " \"label\": \"labelsTr/label0037.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0038.nii.gz\",\n", - " \"label\": \"labelsTr/label0038.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0039.nii.gz\",\n", - " \"label\": \"labelsTr/label0039.nii.gz\"\n", - " },\n", - " {\n", - " \"image\": \"imagesTr/img0040.nii.gz\",\n", - " \"label\": \"labelsTr/label0040.nii.gz\"\n", - " }\n", - " ]\n", - "}\n", - " " + "## Download dataset and format in the folder\n", + "1. Download dataset from here: https://www.synapse.org/#!Synapse:syn3193805/wiki/89480. After you open the link, navigate to the \"Files\" tab, then download Abdomen/RawData.zip.\n", + "\n", + " Note that you may need to register for an account on Synapse and consent to use agreements before being able to view/download this file. There are options to download directly from the browser or from the command line; please refer to Synapse API documentation for more info.\n", + "\n", + "\n", + "2. 
After downloading the zip file, unzip. Then put images from `RawData/Training/img` in `./data/imagesTr`, and put labels from `RawData/Training/label` in `./data/labelsTr`.\n", + "\n", + "\n", + "3. Make a JSON file to define train/val split and other relevant parameters. Place the JSON file at `./data/dataset_0.json`.\n", + "\n", + " You can download an example of the JSON file [here](https://drive.google.com/file/d/1qcGh41p-rI3H_sQ0JwOAhNiQSXriQqGi/view?usp=sharing), or, equivalently, use the following `wget` command. If you would like to use this directly, please move it into the `./data` folder." ] }, { @@ -448,7 +308,17 @@ "metadata": {}, "outputs": [], "source": [ - "data_dir = \"/data/\"\n", + "# uncomment this command to download the JSON file directly\n", + "# wget -O data/dataset_0.json 'https://drive.google.com/uc?export=download&id=1qcGh41p-rI3H_sQ0JwOAhNiQSXriQqGi'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_dir = \"data/\"\n", "split_JSON = \"dataset_0.json\"\n", "\n", "datasets = data_dir + split_JSON\n", @@ -461,15 +331,17 @@ " cache_rate=1.0,\n", " num_workers=8,\n", ")\n", - "train_loader = DataLoader(\n", - " train_ds, batch_size=1, shuffle=True, num_workers=8, pin_memory=True\n", - ")\n", + "train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=1, shuffle=True)\n", "val_ds = CacheDataset(\n", " data=val_files, transform=val_transforms, cache_num=6, cache_rate=1.0, num_workers=4\n", ")\n", - "val_loader = DataLoader(\n", - " val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True\n", - ")" + "val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)\n", + "\n", + "# as explained in the \"Setup transforms\" section above, we want cached training images to not have metadata, and validations to have metadata\n", + "# the EnsureTyped transforms allow us to make this distinction\n", + "# on the other hand, set_track_meta is a global API; doing so here makes sure subsequent transforms (i.e., random transforms for training)\n", + "# will be carried out as Tensors, not MetaTensors\n", + "set_track_meta(False)" ] }, { @@ -536,12 +408,12 @@ "source": [ "### Create Swin UNETR model\n", "\n", - "In this scetion, we create Swin UNETR model for the 14-class multi-organ segmentation. We use a feature size of 48 which is compatible with self-supervised pre-trained weights. We also use gradient checkpointing (use_checkpoint) for more memory-efficient training. " + "In this section, we create a Swin UNETR model for the 14-class multi-organ segmentation. We use a feature size of 48, which is compatible with the self-supervised pre-trained weights. We also use gradient checkpointing (use_checkpoint) for more memory-efficient training. " ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -563,22 +435,24 @@ "source": [ "### Initialize Swin UNETR encoder from self-supervised pre-trained weights\n", "\n", - "In this section, we intialize the Swin UNETR encoder from weights downloaded from this [link](https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/model_swinvit.pt). If training from scratch is desired, please skip this section." + "In this section, we intialize the Swin UNETR encoder from pre-trained weights. 
The weights can be downloaded using the wget command below, or by following [this link](https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/model_swinvit.pt) to GitHub. If training from scratch is desired, please skip this section." ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using pretrained self-supervied Swin UNETR backbone weights !\n" - ] - } - ], + "outputs": [], + "source": [ + "# uncomment to download the pre-trained weights\n", + "# !wget https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/model_swinvit.pt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "weight = torch.load(\"./model_swinvit.pt\")\n", "model.load_from(weights=weight)\n", @@ -594,13 +468,14 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.backends.cudnn.benchmark = True\n", "loss_function = DiceCELoss(to_onehot_y=True, softmax=True)\n", - "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)" + "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)\n", + "scaler = torch.cuda.amp.GradScaler()" ] }, { @@ -623,7 +498,8 @@ " with torch.no_grad():\n", " for step, batch in enumerate(epoch_iterator_val):\n", " val_inputs, val_labels = (batch[\"image\"].cuda(), batch[\"label\"].cuda())\n", - " val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model)\n", + " with torch.cuda.amp.autocast():\n", + " val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model)\n", " val_labels_list = decollate_batch(val_labels)\n", " val_labels_convert = [\n", " post_label(val_label_tensor) for val_label_tensor in val_labels_list\n", @@ -651,11 +527,14 @@ " for step, batch in enumerate(epoch_iterator):\n", " step += 1\n", " x, y = (batch[\"image\"].cuda(), batch[\"label\"].cuda())\n", - " logit_map = model(x)\n", - " loss = loss_function(logit_map, y)\n", - " loss.backward()\n", + " with torch.cuda.amp.autocast():\n", + " logit_map = model(x)\n", + " loss = loss_function(logit_map, y)\n", + " scaler.scale(loss).backward()\n", " epoch_loss += loss.item()\n", - " optimizer.step()\n", + " scaler.unscale_(optimizer)\n", + " scaler.step(optimizer)\n", + " scaler.update()\n", " optimizer.zero_grad()\n", " epoch_iterator.set_description(\n", " \"Training (%d / %d Steps) (loss=%2.5f)\"\n", @@ -689,9 +568,15 @@ " )\n", " )\n", " global_step += 1\n", - " return global_step, dice_val_best, global_step_best\n", - "\n", - "\n", + " return global_step, dice_val_best, global_step_best" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "max_iterations = 30000\n", "eval_num = 500\n", "post_label = AsDiscrete(to_onehot=14)\n", @@ -835,7 +720,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [
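The Swin UNETR construction cell itself is unchanged by this patch and is therefore only partially visible in the hunks above. For orientation, a minimal sketch of that section follows, using the parameters named in the text (feature size 48, 14 output classes, gradient checkpointing, a 96^3 ROI). The constructor arguments `img_size` and `in_channels` are assumptions here; the weight loading mirrors the `model_swinvit.pt` cell shown in the diff.

```python
# Sketch of the Swin UNETR setup described in the sections above; arguments other
# than feature_size=48, out_channels=14 and use_checkpoint=True are assumptions.
import torch
from monai.networks.nets import SwinUNETR

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SwinUNETR(
    img_size=(96, 96, 96),  # matches the sliding-window ROI used for inference
    in_channels=1,          # single-channel CT
    out_channels=14,        # 13 organs + background
    feature_size=48,        # compatible with the self-supervised pre-trained weights
    use_checkpoint=True,    # gradient checkpointing for memory-efficient training
).to(device)

# initialize the encoder from the self-supervised weights downloaded above
weight = torch.load("./model_swinvit.pt")
model.load_from(weights=weight)
```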