-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathtrt_compiler.py
675 lines (596 loc) · 26.9 KB
/
trt_compiler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import inspect
import os
import tempfile
import threading
from collections import OrderedDict
from pathlib import Path
from types import MethodType
from typing import Any, Dict, List, Tuple, Union
import torch
from monai.apps.utils import get_logger
from monai.networks.utils import add_casts_around_norms, convert_to_onnx, get_profile_shapes
from monai.utils.module import optional_import
polygraphy, polygraphy_imported = optional_import("polygraphy")
if polygraphy_imported:
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.trt import (
CreateConfig,
Profile,
engine_bytes_from_network,
engine_from_bytes,
network_from_onnx_path,
)
trt, trt_imported = optional_import("tensorrt")
torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0")
cudart, _ = optional_import("cuda.cudart")
lock_sm = threading.Lock()
# Map of TRT dtype -> Torch dtype
def trt_to_torch_dtype_dict():
    """
    Builds the lookup table used to pick a torch dtype for a TensorRT tensor dtype.

    The table is built on every call rather than at import time because ``trt`` is an
    optional import in this module.

    Returns:
        dict mapping TensorRT data types to the torch dtype with the same bit layout.
    """
    return {
        trt.int32: torch.int32,
        trt.float32: torch.float32,
        trt.float16: torch.float16,
        # bf16 and fp16 are both 16 bits wide but have different exponent/mantissa splits:
        # a bf16 buffer must be viewed as torch.bfloat16, otherwise values are misread.
        trt.bfloat16: torch.bfloat16,
        trt.int64: torch.int64,
        trt.int8: torch.int8,
        trt.bool: torch.bool,
    }
def get_dynamic_axes(profiles):
    """
    Derives the ``dynamic_axes`` argument for onnx.export() from shape profiles.

    An axis is reported as dynamic when its minimum and maximum extents differ.

    Args:
        profiles: list of profiles; each profile maps an input name to
            ``[min_shape, opt_shape, max_shape]``. May be None or empty.

    Returns:
        dict mapping input name -> list of dynamic axis indices. Inputs whose shapes are
        fully static are omitted. When several profiles mention the same input, the
        last profile with dynamic axes for it wins.
    """
    result: dict[str, list[int]] = {}
    if not profiles:
        return result
    for shapes_by_input in profiles:
        for name in shapes_by_input:
            bounds = shapes_by_input[name]
            # compare min (index 0) against max (index 2); opt (index 1) is irrelevant here
            varying = [axis for axis, low in enumerate(bounds[0]) if low != bounds[2][axis]]
            if varying:
                result[name] = varying
    return result
def cuassert(cuda_ret):
    """
    Checks the status of a CUDA call and unwraps its result.

    Args:
        cuda_ret: sequence returned by a CUDA call: the status code first,
            optionally followed by result value(s).

    Returns:
        The second element of ``cuda_ret`` if there is one, otherwise None.

    Raises:
        RuntimeError: if the status code is non-zero.
    """
    status = cuda_ret[0]
    if status == 0:
        return cuda_ret[1] if len(cuda_ret) > 1 else None
    raise RuntimeError(f"CUDA ERROR: {status}")
class ShapeError(Exception):
    """
    Exception used to report errors from setting TRT plan input shapes.
    """
class TRTEngine:
    """
    An auxiliary class to implement running of TRT optimized engines.

    It loads a serialized engine, owns a single execution context, binds input and
    output tensors by name and runs inference on a caller-supplied CUDA stream.
    """

    def __init__(self, plan_path, logger=None):
        """
        Loads serialized engine, creates execution context and activates it.

        Args:
            plan_path: path to serialized TRT engine.
            logger: optional logger object.
        """
        self.plan_path = plan_path
        self.logger = logger or get_logger("monai.networks.trt_compiler")
        self.logger.info(f"Loading TensorRT engine: {self.plan_path}")
        self.engine = engine_from_bytes(bytes_from_path(self.plan_path))
        self.tensors = OrderedDict()  # output name -> output tensor, reused between calls
        self.cuda_graph_instance = None  # cuda graph
        self.context = self.engine.create_execution_context()
        self.input_names = []
        self.output_names = []
        self.dtypes = []  # torch dtypes of the outputs, index-aligned with output_names
        self.cur_profile = 0
        self.input_table = {}  # engine input name -> key to look up in the feed dict
        self._input_refs = {}  # input tensors whose addresses are currently bound
        dtype_dict = trt_to_torch_dtype_dict()
        for idx in range(self.engine.num_io_tensors):
            binding = self.engine[idx]
            mode = self.engine.get_tensor_mode(binding)
            if mode == trt.TensorIOMode.INPUT:
                self.input_names.append(binding)
            elif mode == trt.TensorIOMode.OUTPUT:
                self.output_names.append(binding)
                self.dtypes.append(dtype_dict[self.engine.get_tensor_dtype(binding)])
        self.logger.info(
            f"Loaded TensorRT engine: {self.plan_path}.\nInputs: {self.input_names}\nOutputs: {self.output_names}"
        )

    def allocate_buffers(self, device):
        """
        Allocates outputs to run TRT engine.

        An existing output tensor is reused when its shape still matches the shape
        reported by the execution context; otherwise a new one is allocated.

        Args:
            device: GPU device to allocate memory on.
        """
        ctx = self.context
        for i, binding in enumerate(self.output_names):
            shape = list(ctx.get_tensor_shape(binding))
            if binding not in self.tensors or list(self.tensors[binding].shape) != shape:
                t = torch.empty(shape, dtype=self.dtypes[i], device=device).contiguous()
                self.tensors[binding] = t
                ctx.set_tensor_address(binding, t.data_ptr())

    def set_inputs(self, feed_dict, stream):
        """
        Sets input bindings for TRT engine according to feed_dict.

        If the current optimization profile rejects an input shape, the remaining
        profiles of the engine are tried in round-robin order.

        Args:
            feed_dict: a dictionary [str->Tensor]; inputs that are missing or None are skipped.
            stream: CUDA stream to use.

        Raises:
            ShapeError: if no optimization profile accepts the input shapes, or if the
                context still reports unresolved tensors after all inputs are set.
        """
        e = self.engine
        ctx = self.context
        last_profile = self.cur_profile

        def try_set_inputs():
            held = {}
            for binding in self.input_names:
                t = feed_dict.get(self.input_table[binding], None)
                if t is not None:
                    t = t.contiguous()
                    # set_input_shape() reports failure through its return value
                    if not ctx.set_input_shape(binding, t.shape):
                        raise ShapeError(
                            f"Input '{binding}' with shape {list(t.shape)} "
                            f"was rejected by optimization profile {self.cur_profile}"
                        )
                    ctx.set_tensor_address(binding, t.data_ptr())
                    # contiguous() may have produced a copy that nothing else references;
                    # keep it alive while its address is bound to the context
                    held[binding] = t
            return held

        while True:
            try:
                self._input_refs = try_set_inputs()
                break
            except ShapeError:
                next_profile = (self.cur_profile + 1) % e.num_optimization_profiles
                if next_profile == last_profile:
                    # every profile has been tried
                    raise
                self.cur_profile = next_profile
                ctx.set_optimization_profile_async(self.cur_profile, stream)

        left = ctx.infer_shapes()
        if len(left) != 0:
            raise ShapeError(f"Shapes could not be inferred for tensors: {list(left)}")

    def infer(self, stream, use_cuda_graph=False):
        """
        Runs TRT engine.

        Args:
            stream: CUDA stream to run on.
            use_cuda_graph: use CUDA graph. Note: requires all inputs to be the same GPU memory between calls.

        Returns:
            Ordered dictionary of output name -> output tensor. The same dictionary object
            is returned on every call.

        Raises:
            ValueError: if TensorRT reports that inference failed.
            RuntimeError: if a CUDA runtime call fails.
        """
        if use_cuda_graph:
            if self.cuda_graph_instance is not None:
                cuassert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream))
                cuassert(cudart.cudaStreamSynchronize(stream))
            else:
                # do inference before CUDA graph capture
                noerror = self.context.execute_async_v3(stream)
                if not noerror:
                    raise ValueError("ERROR: inference failed.")
                # the outputs returned from this call come from the run above, so wait for it
                cuassert(cudart.cudaStreamSynchronize(stream))
                # capture cuda graph
                cuassert(
                    cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal)
                )
                self.context.execute_async_v3(stream)
                graph = cuassert(cudart.cudaStreamEndCapture(stream))
                self.cuda_graph_instance = cuassert(cudart.cudaGraphInstantiate(graph, 0))
                self.logger.info("CUDA Graph captured!")
        else:
            noerror = self.context.execute_async_v3(stream)
            cuassert(cudart.cudaStreamSynchronize(stream))
            if not noerror:
                raise ValueError("ERROR: inference failed.")
        # the stream has been synchronized on every path above: inputs may be released
        self._input_refs = {}
        return self.tensors
def make_tensor(d):
    """
    Returns ``d`` unchanged if it already is a torch.Tensor; otherwise builds a
    tensor from it with torch.tensor() and moves that tensor to CUDA.
    """
    if isinstance(d, torch.Tensor):
        return d
    return torch.tensor(d).cuda()
def unroll_input(input_names, input_example):
    """
    Simulates the list/tuple unrolling that happens during ONNX export.

    Args:
        input_names: names to look up in ``input_example``, in order.
        input_example: mapping of input name -> value. Every name in ``input_names``
            must be present.

    Returns:
        dict of flat name -> result of make_tensor(). A list/tuple input ``name``
        contributes one entry per element under ``f"{name}_{i}"``. Inputs whose value
        is None are left out.
    """
    flat = {}
    for name in input_names:
        value = input_example[name]
        if value is None:
            continue
        if isinstance(value, (list, tuple)):
            for position, item in enumerate(value):
                flat[f"{name}_{position}"] = make_tensor(item)
        else:
            flat[name] = make_tensor(value)
    return flat
def parse_groups(
    ret: List[torch.Tensor], output_lists: List[List[int]]
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], ...]:
    """
    Implements parsing of 'output_lists' arg of trt_compile().

    Args:
        ret: plain list of Tensors
        output_lists: list of output group sizes: to form some Lists/Tuples out of 'ret' List, this will be a list
                      of group dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list.
            Format: [[group_n] | [], ...]
              [] or group_n == 0 : next output from ret is a scalar
              group_n > 0  :  next output from ret is a list of group_n length
              group_n == -1:  next output is a dynamic list. This entry can be at any
                              position in output_lists, but can appear only once.

    Returns:
        Tuple of Union[torch.Tensor, List[torch.Tensor]], according to the grouping in output_lists

    Raises:
        ValueError: if an entry of output_lists has more than one element, if a group size
            is smaller than -1, or if -1 appears more than once.
    """
    # Validate with explicit exceptions: asserts are stripped under `python -O`,
    # and a size below -1 would otherwise be skipped without any output.
    sizes: List[int] = []
    for spec in output_lists:
        if len(spec) > 1:
            raise ValueError(f"Each entry of output_lists must be [] or [group_n], got: {spec}")
        size = spec[0] if len(spec) == 1 else 0
        if size < -1:
            raise ValueError(f"Invalid group size in output_lists: {size}")
        sizes.append(size)
    if sizes.count(-1) > 1:
        raise ValueError("Two -1 lists in output")

    groups: list = []
    cur = 0
    for pos, size in enumerate(sizes):
        if size == -1:
            # The dynamic group takes whatever remains once the groups after it have
            # been carved off the end of `ret`, so consume those from the back first.
            tail: list = []
            rcur = len(ret)
            for rsize in reversed(sizes[pos + 1 :]):
                if rsize == 0:
                    rcur -= 1
                    tail.append(ret[rcur])
                else:
                    rcur -= rsize
                    tail.append(ret[rcur : rcur + rsize])
            groups.append(ret[cur:rcur])
            groups.extend(reversed(tail))
            break
        if size == 0:
            groups.append(ret[cur])
            cur += 1
        else:
            groups.append(ret[cur : cur + size])
            cur += size
    return tuple(groups)
class TrtCompiler:
    """
    This class implements:
      - TRT lazy persistent export
      - Running TRT with optional fallback to Torch
        (for TRT engines with limited profiles)
    """

    def __init__(
        self,
        model,
        plan_path,
        precision="fp16",
        method="onnx",
        input_names=None,
        output_names=None,
        output_lists=None,
        export_args=None,
        build_args=None,
        input_profiles=None,
        dynamic_batchsize=None,
        use_cuda_graph=False,
        timestamp=None,
        fallback=False,
        forward_override=None,
        logger=None,
    ):
        """
        Initialization method:
         Saves its arguments for lazy TRT build/load on first forward() call.
         Removes an existing plan file that is older than `timestamp`.

        Args:
            model: Model to "wrap".
            plan_path : Path where to save persistent serialized TRT engine.
            precision: TRT builder precision of the engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'.
            method: One of 'onnx'|'torch_trt'.
                    Default is 'onnx' (torch.onnx.export()->TRT). This is the most stable and efficient option.
                    'torch_trt' may not work for some nets. Also AMP must be turned off for it to work.
            input_names: Optional list of input names. If None, will be read from the function signature.
            output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary.
            output_lists: Optional list of output group sizes: when forward() returns Lists/Tuples, this will be a list
                          of their dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list.
            export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details.
            build_args: Optional args to pass to TRT builder. See polygraphy.Config for details.
            input_profiles: Optional list of profiles for TRT builder and ONNX export.
                            Each profile is a map of the form : {"input id" : [min_shape, opt_shape, max_shape], ...}.
            dynamic_batchsize: A sequence with three elements to define the batch size range of the input for the model to be
                               converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH].
            [note]: If neither input_profiles nor dynamic_batchsize specified, static shapes will be used to build TRT engine.
            use_cuda_graph: Use CUDA Graph for inference. Note: all inputs have to be the same GPU memory between calls!
            timestamp: Optional timestamp to rebuild TRT engine (e.g. if config file changes).
            fallback: Allow to fall back to Pytorch when TRT inference fails (e.g, shapes exceed max profile).
            forward_override: accepted for interface compatibility; not used by this class.
            logger: Optional logger for diagnostics.

        Raises:
            ValueError: if `method` or `precision` is not one of the supported values.
        """
        method_vals = ["onnx", "torch_trt"]
        if method not in method_vals:
            raise ValueError(f"trt_compile(): 'method' should be one of {method_vals}, got: {method}.")
        precision_vals = ["fp32", "tf32", "fp16", "bf16"]
        if precision not in precision_vals:
            raise ValueError(f"trt_compile(): 'precision' should be one of {precision_vals}, got: {precision}.")

        self.plan_path = plan_path
        self.precision = precision
        self.method = method
        self.return_dict = output_names is not None
        self.output_names = output_names or []
        self.output_lists = output_lists or []
        self.profiles = input_profiles or []
        self.dynamic_batchsize = dynamic_batchsize
        self.export_args = export_args or {}
        self.build_args = build_args or {}
        self.engine: TRTEngine | None = None
        self.use_cuda_graph = use_cuda_graph
        self.fallback = fallback
        self.disabled = False

        self.logger = logger or get_logger("monai.networks.trt_compiler")
        self.argspec = inspect.getfullargspec(model.forward)
        # Normally we read input_names from forward() but can be overridden
        if input_names is None:
            input_names = self.argspec.args[1:]
        self.defaults = {}
        if self.argspec.defaults is not None:
            # argspec.defaults lines up with the *trailing* entries of argspec.args,
            # hence the negative indexing
            for i in range(len(self.argspec.defaults)):
                d = self.argspec.defaults[-i - 1]
                if d is not None:
                    d = make_tensor(d)
                    self.defaults[self.argspec.args[-i - 1]] = d

        self.input_names = input_names
        self.old_forward = model.forward

        # Force engine rebuild if older than the timestamp
        if timestamp is not None and os.path.exists(self.plan_path) and os.path.getmtime(self.plan_path) < timestamp:
            os.remove(self.plan_path)

    def _inputs_to_dict(self, input_example):
        """
        Maps positional inputs to input names, in the order given by self.input_names.
        """
        return {self.input_names[i]: inp for i, inp in enumerate(input_example)}

    def _unroll_inputs(self, inputs):
        """
        Unrolls `inputs` for the engine/exporter, skipping names that were not supplied.
        """
        # unroll_input() indexes the mapping directly; an optional argument that was neither
        # passed nor given a stored default must not turn into a KeyError
        present = [name for name in self.input_names if name in inputs]
        return unroll_input(present, inputs)

    def _load_engine(self):
        """
        Loads TRT plan from disk and activates its execution context.

        Loading is best-effort: any exception is logged and self.engine is left as it was.
        """
        try:
            self.engine = TRTEngine(self.plan_path, self.logger)
            # Make sure we have names correct
            input_table = {}
            for name in self.engine.input_names:
                # an engine input "__foo" that is not itself a known input name is looked up as "foo"
                if name.startswith("__") and name not in self.input_names:
                    orig_name = name[2:]
                else:
                    orig_name = name
                input_table[name] = orig_name
            self.engine.input_table = input_table
            self.logger.info(f"Engine loaded, inputs:{self.engine.input_table}")
        except Exception as e:
            self.logger.info(f"Exception while loading the engine:\n{e}")

    def forward(self, model, argv, kwargs):
        """
        Main forward method:
         Builds TRT engine if not available yet.
         Tries to run TRT engine
         If exception thrown and self.fallback==True: falls back to original Pytorch

        Args: Passing through whatever args wrapped module's forward() has
        Returns: Passing through wrapped module's forward() return value(s)
        """
        # Work on a copy: updating self.defaults in place would make the inputs of this
        # call act as defaults for every later call.
        args = dict(self.defaults)
        args.update(kwargs)
        if len(argv) > 0:
            args.update(self._inputs_to_dict(argv))

        if self.engine is None and not self.disabled:
            # Restore original forward for export
            new_forward = model.forward
            model.forward = self.old_forward
            try:
                self._load_engine()
                if self.engine is None:
                    build_args = args.copy()
                    with torch.no_grad():
                        self._build_and_save(model, build_args)
                        # This will reassign input_names from the engine
                        self._load_engine()
                        # _load_engine() only logs its failures, so check the outcome explicitly
                        if self.engine is None:
                            raise RuntimeError(f"Failed to load TensorRT engine from {self.plan_path}")
            except Exception as e:
                if self.fallback:
                    self.logger.info(f"Failed to build engine: {e}")
                    self.disabled = True
                else:
                    # NOTE(review): on this path model.forward stays set to the original
                    # forward (the TRT hook is not restored) - confirm this is intended.
                    raise
            if not self.disabled and not self.fallback:
                # The former `for param in model.parameters(): del param` loop only unbound
                # its own loop variable and released nothing, so it was dropped.
                # Call empty_cache to release GPU memory
                torch.cuda.empty_cache()
            # restore TRT hook
            model.forward = new_forward

        # Run the engine
        try:
            if self.engine is not None:
                # forward_trt is not thread safe as we do not use per-thread execution contexts
                with lock_sm:
                    device = torch.cuda.current_device()
                    stream = torch.cuda.Stream(device=device)
                    self.engine.set_inputs(self._unroll_inputs(args), stream.cuda_stream)
                    self.engine.allocate_buffers(device=device)
                    # Need this to synchronize with Torch stream
                    stream.wait_stream(torch.cuda.current_stream())
                    ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph)
                    # if output_names is not None, return dictionary
                    if not self.return_dict:
                        ret = list(ret.values())
                        if self.output_lists:
                            ret = parse_groups(ret, self.output_lists)
                        elif len(ret) == 1:
                            ret = ret[0]
                    return ret
        except Exception as e:
            if self.fallback:
                self.logger.info(f"Exception: {e}\nFalling back to Pytorch ...")
            else:
                raise
        return self.old_forward(*argv, **kwargs)

    def _onnx_to_trt(self, onnx_path):
        """
        Builds TRT engine from ONNX file at onnx_path.

        Returns:
            The serialized engine; the caller writes it to self.plan_path.
        """
        profiles = []
        for profile in self.profiles:
            p = Profile()
            for input_id, val in profile.items():
                p.add(input_id, min=val[0], opt=val[1], max=val[2])
            profiles.append(p)

        build_args = self.build_args.copy()
        build_args["tf32"] = self.precision != "fp32"
        if self.precision == "fp16":
            build_args["fp16"] = True
        elif self.precision == "bf16":
            build_args["bf16"] = True

        self.logger.info(f"Building TensorRT engine for {onnx_path}: {self.plan_path}")
        network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM])
        return engine_bytes_from_network(network, config=CreateConfig(profiles=profiles, **build_args))

    def _build_and_save(self, model, input_example):
        """
        If TRT engine is not ready, exports model to ONNX,
        builds TRT engine and saves serialized TRT engine to the disk.

        Args:
            model: the module to export.
            input_example: dictionary of input name -> value, passed to the export method.

        Raises:
            ValueError: if both dynamic_batchsize and input_profiles are set, or if
                dynamic_batchsize does not have exactly three elements.
        """
        if self.engine is not None:
            return

        # Copy before adding "dynamic_axes": the dict comes from the caller and the same
        # object may be handed to several TrtCompiler instances.
        export_args = dict(self.export_args)
        engine_bytes = None

        add_casts_around_norms(model)

        if self.method == "torch_trt":
            enabled_precisions = [torch.float32]
            if self.precision == "fp16":
                enabled_precisions.append(torch.float16)
            elif self.precision == "bf16":
                enabled_precisions.append(torch.bfloat16)
            inputs = list(input_example.values())

            def get_torch_trt_input(input_shape, dynamic_batchsize):
                min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize)
                return torch_tensorrt.Input(
                    min_shape=min_input_shape, opt_shape=opt_input_shape, max_shape=max_input_shape
                )

            tt_inputs = [get_torch_trt_input(i.shape, self.dynamic_batchsize) for i in inputs]
            engine_bytes = torch_tensorrt.convert_method_to_trt_engine(
                model, "forward", arg_inputs=tt_inputs, enabled_precisions=enabled_precisions, **export_args
            )
        else:
            dbs = self.dynamic_batchsize
            if dbs:
                if len(self.profiles) > 0:
                    raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!")
                if len(dbs) != 3:
                    raise ValueError("dynamic_batchsize has to have len ==3 ")
                profile = {}

                def add_profile(name, val):
                    # only the leading (batch) dimension varies; tensors without dimensions get no profile
                    sh = val.shape
                    if len(sh) > 0:
                        sh = sh[1:]
                        profile[name] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]]

                for name, val in input_example.items():
                    if isinstance(val, (list, tuple)):
                        # same "<name>_<i>" naming as unroll_input()
                        for i, item in enumerate(val):
                            add_profile(f"{name}_{i}", item)
                    elif isinstance(val, torch.Tensor):
                        add_profile(name, val)
                self.profiles = [profile]

            self.dynamic_axes = get_dynamic_axes(self.profiles)

            if len(self.dynamic_axes) > 0:
                export_args.update({"dynamic_axes": self.dynamic_axes})

            # Use temporary directory for easy cleanup in case of external weights
            with tempfile.TemporaryDirectory() as tmpdir:
                unrolled_input = self._unroll_inputs(input_example)
                onnx_path = str(Path(tmpdir) / "model.onnx")
                self.logger.info(
                    f"Exporting to {onnx_path}:\nunrolled_inputs={list(unrolled_input.keys())}\n"
                    + f"output_names={self.output_names}\ninput_names={self.input_names}\nexport args: {export_args}"
                )
                convert_to_onnx(
                    model,
                    input_example,
                    filename=onnx_path,
                    input_names=list(unrolled_input.keys()),
                    output_names=self.output_names,
                    **export_args,
                )
                self.logger.info("Export to ONNX successful.")
                engine_bytes = self._onnx_to_trt(onnx_path)

        if engine_bytes:
            # context manager: the plan is flushed and closed before it is loaded back
            with open(self.plan_path, "wb") as plan_file:
                plan_file.write(engine_bytes)
def trt_forward(self, *argv, **kwargs):
    """
    Replacement for the patched module's forward().
    Hands the module itself, the positional arguments (as a tuple) and the keyword
    arguments (as a dict) to the forward() of the TrtCompiler stored in ``self._trt_compiler``.
    """
    compiler = self._trt_compiler
    return compiler.forward(self, argv, kwargs)
def trt_compile(
    model: torch.nn.Module,
    base_path: str,
    args: Dict[str, Any] | None = None,
    submodule: Union[str, List[str]] | None = None,
    logger: Any | None = None,
) -> torch.nn.Module:
    """
    Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook.
    Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x.
        NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5.
        Review the TensorRT Support Matrix for which GPUs are supported.
    Args:
      model: module to patch with TrtCompiler object.
      base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path.
                 dirname(base_path) must exist, base_path does not have to.
                 If base_path does point to existing file (e.g. associated checkpoint),
                 that file becomes a dependency - its mtime is added to args["timestamp"].
      args: Optional dict : unpacked and passed to TrtCompiler() - see TrtCompiler above for details.
      submodule: Optional hierarchical id(s) of submodule to patch, e.g. ['image_decoder.decoder']
                  If None, TrtCompiler patch is applied to the whole model.
                  Otherwise, submodule (or list of) is being patched.
      logger: Optional logger for diagnostics.
    Returns:
      Always returns same model passed in as argument. This is for ease of use in configs.
    """
    compiler_args: Dict[str, Any] = {
        "method": "onnx",
        "precision": "fp16",
        "build_args": {"builder_optimization_level": 5, "precision_constraints": "obey"},
    }
    # caller-supplied entries replace the defaults key by key
    compiler_args.update(args or {})

    if not (trt_imported and polygraphy_imported and torch.cuda.is_available()):
        logger = logger or get_logger("monai.networks.trt_compiler")
        logger.warning("TensorRT and/or polygraphy packages are not available! trt_compile() has no effect.")
        return model

    # if base_path points to an existing file (e.g. checkpoint)
    # it's also treated as dependency
    if os.path.exists(base_path):
        stamp = int(os.path.getmtime(base_path))
        if "timestamp" in compiler_args:
            stamp = max(int(compiler_args["timestamp"]), stamp)
        compiler_args["timestamp"] = stamp

    def instrument(module, plan_prefix):
        # a module that already carries a compiler is left untouched
        if hasattr(module, "_trt_compiler"):
            return
        module.orig_forward = module.forward
        module._trt_compiler = TrtCompiler(module, plan_prefix + ".plan", logger=logger, **compiler_args)
        module.forward = MethodType(trt_forward, module)

    def locate(root, dotted):
        # walk every dotted component except the last; return (owner, attribute name)
        *ancestors, leaf = dotted.split(".")
        owner = root
        for part in ancestors:
            owner = getattr(owner, part)
        return owner, leaf

    if submodule is None:
        instrument(model, base_path)
    else:
        targets = [submodule] if isinstance(submodule, str) else submodule
        for dotted in targets:
            owner, leaf = locate(model, dotted)
            instrument(getattr(owner, leaf), base_path + "." + dotted)
    return model