|
435 | 435 | "source": [ |
436 | 436 | "max_epochs = 300\n", |
437 | 437 | "val_interval = 1\n", |
| 438 | + "VAL_AMP = True\n", |
438 | 439 | "\n", |
439 | 440 | "# standard PyTorch program style: create SegResNet, DiceLoss and Adam optimizer\n", |
440 | 441 | "device = torch.device(\"cuda:0\")\n", |
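For context, the cell body after the comment above (elided from this hunk) builds the network, loss, and optimizer. A minimal sketch of that standard MONAI setup; the hyperparameters here are assumptions, not values read from this commit:

    from monai.losses import DiceLoss
    from monai.networks.nets import SegResNet

    # assumed hyperparameters; the notebook's actual values sit outside this hunk
    model = SegResNet(
        blocks_down=[1, 2, 2, 4],
        blocks_up=[1, 1, 1],
        init_filters=16,
        in_channels=4,   # multimodal MRI input
        out_channels=3,  # tumor sub-region channels
        dropout_prob=0.2,
    ).to(device)
    loss_function = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)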
|
457 | 458 | " [EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)]\n", |
458 | 459 | ")\n", |
459 | 460 | "\n", |
| 461 | + "\n", |
| 462 | + "# define inference method\n", |
| 463 | + "def inference(input):\n", |
| 464 | + "\n", |
| 465 | + " def _compute(input):\n", |
| 466 | + " return sliding_window_inference(\n", |
| 467 | + " inputs=input,\n", |
| 468 | + " roi_size=(240, 240, 160),\n", |
| 469 | + " sw_batch_size=1,\n", |
| 470 | + " predictor=model,\n", |
| 471 | + " overlap=0.5,\n", |
| 472 | + " )\n", |
| 473 | + "\n", |
| 474 | + " if VAL_AMP:\n", |
| 475 | + " with torch.cuda.amp.autocast():\n", |
| 476 | + " return _compute(input)\n", |
| 477 | + " else:\n", |
| 478 | + " return _compute(input)\n", |
| 479 | + "\n", |
| 480 | + "\n", |
460 | 481 | "# use amp to accelerate training\n", |
461 | 482 | "scaler = torch.cuda.amp.GradScaler()\n", |
462 | | - "VAL_AMP = True\n", |
463 | 483 | "# enable cuDNN benchmark\n", |
464 | 484 | "torch.backends.cudnn.benchmark = True" |
465 | 485 | ] |
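The GradScaler created above pairs with autocast inside the training step later in the notebook. A minimal sketch of that standard torch.cuda.amp pattern; the variable names are assumed:

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales gradients; skips the step on inf/nan
    scaler.update()                # adjusts the scale factor for the next iteration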
|
531 | 551 | " val_data[\"image\"].to(device),\n", |
532 | 552 | " val_data[\"label\"].to(device),\n", |
533 | 553 | " )\n", |
534 | | - "\n", |
535 | | - " def _compute_sliding_window(input):\n", |
536 | | - " return sliding_window_inference(\n", |
537 | | - " inputs=input, roi_size=(240, 240, 160), sw_batch_size=1, predictor=model, overlap=0.5\n", |
538 | | - " )\n", |
539 | | - "\n", |
540 | | - " if VAL_AMP:\n", |
541 | | - " with torch.cuda.amp.autocast():\n", |
542 | | - " val_outputs = _compute_sliding_window(val_inputs)\n", |
543 | | - " else:\n", |
544 | | - " val_outputs = _compute_sliding_window(val_inputs)\n", |
| 554 | + " val_outputs = inference(val_inputs)\n", |
545 | 555 | " val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]\n", |
546 | 556 | " dice_metric(y_pred=val_outputs, y=val_labels)\n", |
547 | 557 | " dice_metric_batch(y_pred=val_outputs, y=val_labels)\n", |
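After this loop finishes, the accumulated metrics are reduced and cleared for the next validation round. A short sketch of the usual MONAI pattern, matching the metric objects used above:

    metric = dice_metric.aggregate().item()       # mean Dice across the validation set
    metric_batch = dice_metric_batch.aggregate()  # per-channel Dice, one value per sub-region
    dice_metric.reset()
    dice_metric_batch.reset()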
|
732 | 742 | " val_input = val_ds[6][\"image\"].unsqueeze(0).to(device)\n", |
733 | 743 | " roi_size = (128, 128, 64)\n", |
734 | 744 | " sw_batch_size = 4\n", |
735 | | - " val_output = sliding_window_inference(\n", |
736 | | - " inputs=val_input, roi_size=(240, 240, 160), sw_batch_size=1, predictor=model, overlap=0.5\n", |
737 | | - " )\n", |
| 745 | + " val_output = inference(val_input)\n", |
738 | 746 | " val_output = post_trans(val_output[0])\n", |
739 | 747 | " plt.figure(\"image\", (24, 6))\n", |
740 | 748 | " for i in range(4):\n", |
|
835 | 843 | "with torch.no_grad():\n", |
836 | 844 | " for val_data in val_org_loader:\n", |
837 | 845 | " val_inputs = val_data[\"image\"].to(device)\n", |
838 | | - " val_data[\"pred\"] = sliding_window_inference(\n", |
839 | | - " inputs=val_inputs, roi_size=(240, 240, 160), sw_batch_size=1, predictor=model, overlap=0.5\n", |
840 | | - " )\n", |
| 846 | + " val_data[\"pred\"] = inference(val_inputs)\n", |
841 | 847 | " val_data = [post_transforms(i) for i in decollate_batch(val_data)]\n", |
842 | 848 | " val_outputs, val_labels = from_engine([\"pred\", \"label\"])(val_data)\n", |
843 | 849 | " dice_metric(y_pred=val_outputs, y=val_labels)\n", |
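The post_transforms applied in this loop are defined earlier in the notebook, outside this diff. A hedged sketch of the inversion pipeline such an original-spacing evaluation typically relies on; the transform arguments and the val_org_transforms name are assumptions:

    from monai.transforms import Activationsd, AsDiscreted, Compose, Invertd

    post_transforms = Compose([
        Invertd(
            keys="pred",                   # map predictions back to original spacing
            transform=val_org_transforms,  # assumed name of the preprocessing pipeline
            orig_keys="image",
            meta_keys="pred_meta_dict",
            orig_meta_keys="image_meta_dict",
            nearest_interp=False,
            to_tensor=True,
        ),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])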
|
879 | 885 | ], |
880 | 886 | "metadata": { |
881 | 887 | "kernelspec": { |
882 | | - "display_name": "Python 3", |
| 888 | + "display_name": "Python 3 (ipykernel)", |
883 | 889 | "language": "python", |
884 | 890 | "name": "python3" |
885 | 891 | }, |
|