
Commit b22a2af

Add AMP inference logic in Brats tutorial (#335)
* [DLMED] add AMP inference
  Signed-off-by: Nic Ma <nma@nvidia.com>
* [DLMED] fix format issue
  Signed-off-by: Nic Ma <nma@nvidia.com>
1 parent 42856da commit b22a2af

File tree

1 file changed: +25 -19 lines

3d_segmentation/brats_segmentation_3d.ipynb

Lines changed: 25 additions & 19 deletions
@@ -435,6 +435,7 @@
 "source": [
 "max_epochs = 300\n",
 "val_interval = 1\n",
+"VAL_AMP = True\n",
 "\n",
 "# standard PyTorch program style: create SegResNet, DiceLoss and Adam optimizer\n",
 "device = torch.device(\"cuda:0\")\n",
@@ -457,9 +458,28 @@
 "    [EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)]\n",
 ")\n",
 "\n",
+"\n",
+"# define inference method\n",
+"def inference(input):\n",
+"\n",
+"    def _compute(input):\n",
+"        return sliding_window_inference(\n",
+"            inputs=input,\n",
+"            roi_size=(240, 240, 160),\n",
+"            sw_batch_size=1,\n",
+"            predictor=model,\n",
+"            overlap=0.5,\n",
+"        )\n",
+"\n",
+"    if VAL_AMP:\n",
+"        with torch.cuda.amp.autocast():\n",
+"            return _compute(input)\n",
+"    else:\n",
+"        return _compute(input)\n",
+"\n",
+"\n",
 "# use amp to accelerate training\n",
 "scaler = torch.cuda.amp.GradScaler()\n",
-"VAL_AMP = True\n",
 "# enable cuDNN benchmark\n",
 "torch.backends.cudnn.benchmark = True"
 ]
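For readers outside the notebook, a standalone sketch of the AMP inference pattern this hunk adds. The SegResNet arguments and the dummy input shape are assumptions drawn from the tutorial's other cells, not part of this diff:

    # Standalone sketch of the autocast-wrapped sliding-window inference above.
    # SegResNet arguments and the input shape are assumed from the tutorial,
    # not from this commit; adjust them to your own data.
    import torch
    from monai.inferers import sliding_window_inference
    from monai.networks.nets import SegResNet

    VAL_AMP = True
    device = torch.device("cuda:0")
    model = SegResNet(
        blocks_down=[1, 2, 2, 4],
        blocks_up=[1, 1, 1],
        init_filters=16,
        in_channels=4,
        out_channels=3,
    ).to(device)


    def inference(input):
        def _compute(input):
            return sliding_window_inference(
                inputs=input,
                roi_size=(240, 240, 160),
                sw_batch_size=1,
                predictor=model,
                overlap=0.5,
            )

        if VAL_AMP:
            # autocast runs eligible ops in float16, cutting memory and time
            with torch.cuda.amp.autocast():
                return _compute(input)
        else:
            return _compute(input)


    model.eval()
    with torch.no_grad():
        x = torch.rand(1, 4, 224, 224, 144, device=device)  # dummy 4-channel volume
        print(inference(x).shape)  # same spatial size, 3 output channels

Keeping the autocast branch inside one inference() helper is what lets the later hunks replace three separate sliding_window_inference call sites with a single call.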
@@ -531,17 +551,7 @@
 "                    val_data[\"image\"].to(device),\n",
 "                    val_data[\"label\"].to(device),\n",
 "                )\n",
-"\n",
-"                def _compute_sliding_window(input):\n",
-"                    return sliding_window_inference(\n",
-"                        inputs=input, roi_size=(240, 240, 160), sw_batch_size=1, predictor=model, overlap=0.5\n",
-"                    )\n",
-"\n",
-"                if VAL_AMP:\n",
-"                    with torch.cuda.amp.autocast():\n",
-"                        val_outputs = _compute_sliding_window(val_inputs)\n",
-"                else:\n",
-"                    val_outputs = _compute_sliding_window(val_inputs)\n",
+"                val_outputs = inference(val_inputs)\n",
 "                val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]\n",
 "                dice_metric(y_pred=val_outputs, y=val_labels)\n",
 "                dice_metric_batch(y_pred=val_outputs, y=val_labels)\n",
@@ -732,9 +742,7 @@
 "    val_input = val_ds[6][\"image\"].unsqueeze(0).to(device)\n",
 "    roi_size = (128, 128, 64)\n",
 "    sw_batch_size = 4\n",
-"    val_output = sliding_window_inference(\n",
-"        inputs=val_input, roi_size=(240, 240, 160), sw_batch_size=1, predictor=model, overlap=0.5\n",
-"    )\n",
+"    val_output = inference(val_input)\n",
 "    val_output = post_trans(val_output[0])\n",
 "    plt.figure(\"image\", (24, 6))\n",
 "    for i in range(4):\n",
@@ -835,9 +843,7 @@
 "with torch.no_grad():\n",
 "    for val_data in val_org_loader:\n",
 "        val_inputs = val_data[\"image\"].to(device)\n",
-"        val_data[\"pred\"] = sliding_window_inference(\n",
-"            inputs=val_inputs, roi_size=(240, 240, 160), sw_batch_size=1, predictor=model, overlap=0.5\n",
-"        )\n",
+"        val_data[\"pred\"] = inference(val_inputs)\n",
 "        val_data = [post_transforms(i) for i in decollate_batch(val_data)]\n",
 "        val_outputs, val_labels = from_engine([\"pred\", \"label\"])(val_data)\n",
 "        dice_metric(y_pred=val_outputs, y=val_labels)\n",
@@ -879,7 +885,7 @@
 ],
 "metadata": {
 "kernelspec": {
-"display_name": "Python 3",
+"display_name": "Python 3 (ipykernel)",
 "language": "python",
 "name": "python3"
 },
