Add ONNX exporter #1826

Open
wants to merge 14 commits into master
12 changes: 8 additions & 4 deletions nnunetv2/inference/predict_from_raw_data.py
@@ -64,9 +64,13 @@ def __init__(self,
self.device = device
self.perform_everything_on_gpu = perform_everything_on_gpu

def initialize_from_trained_model_folder(self, model_training_output_dir: str,
use_folds: Union[Tuple[Union[int, str]], None],
checkpoint_name: str = 'checkpoint_final.pth'):
def initialize_from_trained_model_folder(
self,
model_training_output_dir: str,
use_folds: Union[Tuple[Union[int, str]], None],
checkpoint_name: str = 'checkpoint_final.pth',
disable_compilation: bool = False,
):
"""
This is used when making predictions with a trained model
"""
@@ -109,7 +113,7 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str,
self.allowed_mirroring_axes = inference_allowed_mirroring_axes
self.label_manager = plans_manager.get_label_manager(dataset_json)
if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \
and not isinstance(self.network, OptimizedModule):
and not isinstance(self.network, OptimizedModule) and not disable_compilation:
print('compiling network')
self.network = torch.compile(self.network)

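The new `disable_compilation` flag is what makes the exporter possible: when `nnUNet_compile` is set, the network is normally wrapped by `torch.compile` into an `OptimizedModule`, which is awkward to hand to `torch.onnx.export`; skipping compilation keeps the plain `nn.Module` on `self.network`. The sketch below is only an illustration of what the flag enables, not the PR's own `onnx_export` module; the predictor class name, the results path, and the patch/channel sizes are placeholder assumptions.

```python
import torch
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor  # class name assumed

predictor = nnUNetPredictor(device=torch.device("cpu"))
predictor.initialize_from_trained_model_folder(
    "/path/to/nnUNet_results/Dataset001/nnUNetTrainer__nnUNetPlans__3d_fullres",  # placeholder path
    use_folds=(0,),
    checkpoint_name="checkpoint_final.pth",
    disable_compilation=True,  # keep the uncompiled nn.Module so it can be traced for ONNX
)

# Dummy input: batch x channels x patch size (placeholder values, must match the plans).
dummy = torch.randn(1, 1, 128, 128, 128)
torch.onnx.export(
    predictor.network,
    dummy,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    # A dynamic batch axis corresponds to the entry point's "-b 0" option below.
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)
```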
205 changes: 174 additions & 31 deletions nnunetv2/model_sharing/entry_points.py
@@ -1,61 +1,204 @@
from pathlib import Path

from nnunetv2.model_sharing.model_download import download_and_install_from_url
from nnunetv2.model_sharing.model_export import export_pretrained_model
from nnunetv2.model_sharing.model_import import install_model_from_zip_file
from nnunetv2.model_sharing.onnx_export import export_onnx_model


def print_license_warning():
print('')
print('######################################################')
print('!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!')
print('######################################################')
print("Using the pretrained model weights is subject to the license of the dataset they were trained on. Some "
"allow commercial use, others don't. It is your responsibility to make sure you use them appropriately! Use "
"nnUNet_print_pretrained_model_info(task_name) to see a summary of the dataset and where to find its license!")
print('######################################################')
print('')
print("")
print("######################################################")
print("!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!")
print("######################################################")
print(
"Using the pretrained model weights is subject to the license of the dataset they were trained on. Some "
"allow commercial use, others don't. It is your responsibility to make sure you use them appropriately! Use "
"nnUNet_print_pretrained_model_info(task_name) to see a summary of the dataset and where to find its license!"
)
print("######################################################")
print("")


def download_by_url():
import argparse

parser = argparse.ArgumentParser(
description="Use this to download pretrained models. This script is intended to download models via url only. "
"CAREFUL: This script will overwrite "
"existing models (if they share the same trainer class and plans as "
"the pretrained model.")
parser.add_argument("url", type=str, help='URL of the pretrained model')
"CAREFUL: This script will overwrite "
"existing models (if they share the same trainer class and plans as "
"the pretrained model."
)
parser.add_argument("url", type=str, help="URL of the pretrained model")
args = parser.parse_args()
url = args.url
download_and_install_from_url(url)


def install_from_zip_entry_point():
import argparse

parser = argparse.ArgumentParser(
description="Use this to install a zip file containing a pretrained model.")
parser.add_argument("zip", type=str, help='zip file')
description="Use this to install a zip file containing a pretrained model."
)
parser.add_argument("zip", type=str, help="zip file")
args = parser.parse_args()
zip = args.zip
install_model_from_zip_file(zip)


def export_pretrained_model_entry():
import argparse

parser = argparse.ArgumentParser(
description="Use this to export a trained model as a zip file."
)
parser.add_argument("-d", type=str, required=True, help="Dataset name or id")
parser.add_argument("-o", type=str, required=True, help="Output file name")
parser.add_argument(
"-c",
nargs="+",
type=str,
required=False,
default=("3d_lowres", "3d_fullres", "2d", "3d_cascade_fullres"),
help="List of configuration names",
)
parser.add_argument(
"-tr", required=False, type=str, default="nnUNetTrainer", help="Trainer class"
)
parser.add_argument(
"-p", required=False, type=str, default="nnUNetPlans", help="plans identifier"
)
parser.add_argument(
"-f",
required=False,
nargs="+",
type=str,
default=(0, 1, 2, 3, 4),
help="list of fold ids",
)
parser.add_argument(
"-chk",
required=False,
nargs="+",
type=str,
default=("checkpoint_final.pth",),
help="List of checkpoint names to export. Default: checkpoint_final.pth",
)
parser.add_argument(
"--not_strict",
action="store_false",
default=False,
required=False,
help="Set this to allow missing folds and/or configurations",
)
parser.add_argument(
"--exp_cv_preds",
action="store_true",
required=False,
help="Set this to export the cross-validation predictions as well",
)
args = parser.parse_args()

export_pretrained_model(
dataset_name_or_id=args.d,
output_file=args.o,
configurations=args.c,
trainer=args.tr,
plans_identifier=args.p,
folds=args.f,
strict=not args.not_strict,
save_checkpoints=args.chk,
export_crossval_predictions=args.exp_cv_preds,
)


def export_pretrained_model_onnx_entry():
import argparse

parser = argparse.ArgumentParser(
description="Use this to export a trained model as a zip file.")
parser.add_argument('-d', type=str, required=True, help='Dataset name or id')
parser.add_argument('-o', type=str, required=True, help='Output file name')
parser.add_argument('-c', nargs='+', type=str, required=False,
default=('3d_lowres', '3d_fullres', '2d', '3d_cascade_fullres'),
help="List of configuration names")
parser.add_argument('-tr', required=False, type=str, default='nnUNetTrainer', help='Trainer class')
parser.add_argument('-p', required=False, type=str, default='nnUNetPlans', help='plans identifier')
parser.add_argument('-f', required=False, nargs='+', type=str, default=(0, 1, 2, 3, 4), help='list of fold ids')
parser.add_argument('-chk', required=False, nargs='+', type=str, default=('checkpoint_final.pth', ),
help='List of checkpoint names to export. Default: checkpoint_final.pth')
parser.add_argument('--not_strict', action='store_false', default=False, required=False, help='Set this to allow missing folds and/or configurations')
parser.add_argument('--exp_cv_preds', action='store_true', required=False, help='Set this to export the cross-validation predictions as well')
description="Use this to export a trained model to ONNX format."
"You are responsible for creating the ONNX pipeline yourself."
)
parser.add_argument("-d", type=str, required=True, help="Dataset name or id")
parser.add_argument("-o", type=Path, required=True, help="Output directory")
parser.add_argument(
"-c",
nargs="+",
type=str,
required=False,
default=("3d_lowres", "3d_fullres", "2d", "3d_cascade_fullres"),
help="List of configuration names",
)
parser.add_argument(
"-tr", required=False, type=str, default="nnUNetTrainer", help="Trainer class"
)
parser.add_argument(
"-p", required=False, type=str, default="nnUNetPlans", help="plans identifier"
)
parser.add_argument(
"-f", required=False, nargs="+", type=str, default=None, help="list of fold ids"
)
parser.add_argument(
"-b",
required=False,
type=int,
default=0,
help="Batch size. Set to 0 for dynamic axes. Default: 0",
)
parser.add_argument(
"-chk",
required=False,
nargs="+",
type=str,
default=("checkpoint_final.pth",),
help="List of checkpoint names to export. Default: checkpoint_final.pth",
)
parser.add_argument(
"--not_strict",
action="store_false",
default=False,
required=False,
help="Set this to allow missing folds and/or configurations",
)
parser.add_argument(
"-v",
action="store_false",
default=False,
required=False,
help="Set this to get verbose output",
)
args = parser.parse_args()

export_pretrained_model(dataset_name_or_id=args.d, output_file=args.o, configurations=args.c, trainer=args.tr,
plans_identifier=args.p, folds=args.f, strict=not args.not_strict, save_checkpoints=args.chk,
export_crossval_predictions=args.exp_cv_preds)
print("######################################################")
print("!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!")
print("######################################################")
print(
"Exported models are provided as-is, without any\n"
"guarantees, warranties and/or support from MIC-DKFZ,\n"
"any associated persons and/or other entities.\n"
)
print(
"You will bear sole responsibility for the proper\n"
"use of the exported models.\n"
)
print(
"You are responsible for creating and validating\n"
"the ONNX pipeline yourself. To this end we provide\n"
"the .onnx file, and a config.json containing any\n"
"details you might need."
)
print("######################################################\n")

export_onnx_model(
dataset_name_or_id=args.d,
output_dir=args.o,
configurations=args.c,
batch_size=args.b,
trainer=args.tr,
plans_identifier=args.p,
folds=args.f,
strict=not args.not_strict,
save_checkpoints=args.chk,
verbose=args.v,
)
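As the warning printed by the entry point states, the export only provides the .onnx file and a config.json; preprocessing, sliding-window tiling and postprocessing remain the user's responsibility. A minimal sketch of consuming such an export with onnxruntime follows; the file names inside the output directory and the config.json contents are assumptions, not specified by this diff.

```python
import json
from pathlib import Path

import numpy as np
import onnxruntime as ort

export_dir = Path("/path/to/onnx_export")  # the directory passed via -o
config = json.loads((export_dir / "config.json").read_text())  # file name assumed

session = ort.InferenceSession(str(export_dir / "model.onnx"))  # file name assumed
input_name = session.get_inputs()[0].name

# A single preprocessed patch; shape and dtype must match what was baked into the export.
patch = np.random.randn(1, 1, 128, 128, 128).astype(np.float32)
logits = session.run(None, {input_name: patch})[0]
print(logits.shape)
```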