# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities and types for defining networks; these depend on PyTorch.
"""
from __future__ import annotations
import io
import re
import tempfile
import warnings
from collections import OrderedDict
from collections.abc import Callable, Mapping, Sequence
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, Iterable
import numpy as np
import torch
import torch.nn as nn
from monai.apps.utils import get_logger
from monai.config import PathLike
from monai.utils.misc import ensure_tuple, save_obj, set_determinism
from monai.utils.module import look_up_option, optional_import
from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor
onnx, _ = optional_import("onnx")
onnxreference, _ = optional_import("onnx.reference")
onnxruntime, _ = optional_import("onnxruntime")
polygraphy, polygraphy_imported = optional_import("polygraphy")
torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0")
__all__ = [
"one_hot",
"predict_segmentation",
"normalize_transform",
"to_norm_affine",
"CastTempType",
"normal_init",
"icnr_init",
"pixelshuffle",
"pixelunshuffle",
"eval_mode",
"train_mode",
"get_state_dict",
"copy_model_state",
"save_state",
"convert_to_onnx",
"convert_to_torchscript",
"convert_to_trt",
"meshgrid_ij",
"meshgrid_xy",
"replace_modules",
"replace_modules_temp",
"look_up_named_module",
"set_named_module",
"has_nvfuser_instance_norm",
"get_profile_shapes",
]
logger = get_logger(module_name=__name__)
_has_nvfuser = None
def get_profile_shapes(input_shape: Sequence[int], dynamic_batchsize: Sequence[int] | None):
"""
Given a sample input shape, calculate min/opt/max shapes according to dynamic_batchsize.
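Example (illustrative only; the sample shape and batch range below are arbitrary):
.. code-block:: python
from monai.networks.utils import get_profile_shapes
# the batch dimension is scaled to the [min, opt, max] batch sizes
print(get_profile_shapes([4, 1, 64, 64], [1, 4, 8]))
# ([1, 1, 64, 64], [4, 1, 64, 64], [8, 1, 64, 64])
print(get_profile_shapes([4, 1, 64, 64], None))
# ([4, 1, 64, 64], [4, 1, 64, 64], [4, 1, 64, 64])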
"""
def scale_batch_size(input_shape: Sequence[int], scale_num: int):
scale_shape = [*input_shape]
scale_shape[0] = scale_num
return scale_shape
# Use the dynamic batchsize range to generate the min, opt and max model input shape
if dynamic_batchsize:
min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0])
opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1])
max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2])
else:
min_input_shape = opt_input_shape = max_input_shape = input_shape
return min_input_shape, opt_input_shape, max_input_shape
def has_nvfuser_instance_norm():
"""whether the current environment has InstanceNorm3dNVFuser
https://github.com/NVIDIA/apex/blob/23.05-devel/apex/normalization/instance_norm.py#L15-L16
"""
global _has_nvfuser
if _has_nvfuser is not None:
return _has_nvfuser
_, _has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")
if not _has_nvfuser:
return False
try:
import importlib
importlib.import_module("instance_norm_nvfuser_cuda")
except ImportError:
_has_nvfuser = False
return _has_nvfuser
def look_up_named_module(name: str, mod, print_all_options=False):
"""
get the named module in `mod` by the attribute name,
for example ``look_up_named_module("features.3.1.attn", net)``
Args:
name: a string representing the module attribute.
mod: a pytorch module to be searched (in ``mod.named_modules()``).
print_all_options: whether to print all named modules when `name` is not found in `mod`. Defaults to False.
Returns:
the corresponding pytorch module's subcomponent such as ``net.features[3][1].attn``
"""
name_str = look_up_option(
name, {n[0] for n in mod.named_modules()}, default=None, print_all_options=print_all_options
)
if name_str is None:
return None
if name_str == "":
return mod
for n in name_str.split("."):
if n.isdigit():
mod = mod[int(n)]
else:
n = look_up_option(n, {item[0] for item in mod.named_modules()}, default=None, print_all_options=False)
if n is None:
return None
mod = getattr(mod, n)
return mod
def set_named_module(mod, name: str, new_layer):
"""
look up `name` in `mod` and replace the layer with `new_layer`, return the updated `mod`.
Args:
mod: a pytorch module to be updated.
name: a string representing the target module attribute.
new_layer: a new module replacing the corresponding layer at ``mod.name``.
Returns:
an updated ``mod``
See also: :py:func:`monai.networks.utils.look_up_named_module`.
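Example (a minimal sketch; the toy network below is illustrative, not from the library):
.. code-block:: python
import torch.nn as nn
from monai.networks.utils import set_named_module
net = nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU())
net = set_named_module(net, "1", nn.Identity())  # replace the ReLU at index "1" with an Identity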
"""
mods_attr = name.rsplit(".", 1)
submods, attr = mods_attr if len(mods_attr) == 2 else ("", name)
if not attr:
return new_layer
_mod = look_up_named_module(submods, mod)
setattr(_mod, attr, new_layer)
return mod
def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor:
"""
For every value v in `labels`, the value in the output will be either 1 or 0. Each vector along the `dim`-th
dimension has the "one-hot" format, i.e., it has a total length of `num_classes`,
with a one and `num_classes-1` zeros.
Note that this will include the background label, thus a binary mask should be treated as having two classes.
Args:
labels: input tensor of integers to be converted into the 'one-hot' format. Internally `labels` will be
converted into integers `labels.long()`.
num_classes: number of output channels, the corresponding length of `labels[dim]` will be converted to
`num_classes` from `1`.
dtype: the data type of the output one_hot label.
dim: the dimension to be converted to `num_classes` channels from `1` channel, should be non-negative number.
Example:
For a tensor `labels` of dimensions [B]1[spatial_dims], return a tensor of dimensions `[B]N[spatial_dims]`
when `num_classes=N` number of classes and `dim=1`.
.. code-block:: python
from monai.networks.utils import one_hot
import torch
a = torch.randint(0, 2, size=(1, 2, 2, 2))
out = one_hot(a, num_classes=2, dim=0)
print(out.shape) # torch.Size([2, 2, 2, 2])
a = torch.randint(0, 2, size=(2, 1, 2, 2, 2))
out = one_hot(a, num_classes=2, dim=1)
print(out.shape) # torch.Size([2, 2, 2, 2, 2])
"""
# if `dim` is bigger, add singleton dim at the end
if labels.ndim < dim + 1:
shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape))
labels = torch.reshape(labels, shape)
sh = list(labels.shape)
if sh[dim] != 1:
raise AssertionError("labels should have a channel with length equal to one.")
sh[dim] = num_classes
o = torch.zeros(size=sh, dtype=dtype, device=labels.device)
labels = o.scatter_(dim=dim, index=labels.long(), value=1)
return labels
def predict_segmentation(logits: torch.Tensor, mutually_exclusive: bool = False, threshold: float = 0.0) -> Any:
"""
Given the logits from a network, compute the segmentation by thresholding the values for a multi-label task,
or by computing the `argmax` along the channel axis for a multi-class task.
`logits` is expected to have shape `BCHW[D]`.
Args:
logits: raw data of model output.
mutually_exclusive: if True, `logits` will be converted into a binary matrix using
`argmax`, which is suitable for a multi-class task. Defaults to False.
threshold: the threshold value used for a multi-label task.
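Example (illustrative; the logits shape below is arbitrary):
.. code-block:: python
import torch
from monai.networks.utils import predict_segmentation
logits = torch.randn(1, 3, 32, 32)  # BCHW multi-class logits
print(predict_segmentation(logits, mutually_exclusive=True).shape)   # torch.Size([1, 1, 32, 32])
print(predict_segmentation(logits, mutually_exclusive=False).shape)  # torch.Size([1, 3, 32, 32])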
"""
if not mutually_exclusive:
return (logits >= threshold).int()
if logits.shape[1] == 1:
warnings.warn("single channel prediction, `mutually_exclusive=True` ignored, use threshold instead.")
return (logits >= threshold).int()
return logits.argmax(1, keepdim=True)
def normalize_transform(
shape,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
align_corners: bool = False,
zero_centered: bool = False,
) -> torch.Tensor:
"""
Compute an affine matrix according to the input shape.
The transform normalizes the homogeneous image coordinates to the
range of `[-1, 1]`. Currently the following source coordinates are supported:
- `align_corners=False`, `zero_centered=False`, normalizing from ``[-0.5, d-0.5]``.
- `align_corners=True`, `zero_centered=False`, normalizing from ``[0, d-1]``.
- `align_corners=False`, `zero_centered=True`, normalizing from ``[-(d-1)/2, (d-1)/2]``.
- `align_corners=True`, `zero_centered=True`, normalizing from ``[-d/2, d/2]``.
Args:
shape: input spatial shape, a sequence of integers.
device: device on which the returned affine will be allocated.
dtype: data type of the returned affine
align_corners: if True, consider -1 and 1 to refer to the centers of the
corner pixels rather than the image corners.
See also: https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample
zero_centered: whether the coordinates are normalized from a zero-centered range, default to `False`.
Setting this flag and `align_corners` will jointly specify the normalization source range.
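Example (a minimal sketch; the 2D spatial shape below is arbitrary):
.. code-block:: python
import torch
from monai.networks.utils import normalize_transform
affine = normalize_transform((32, 64), dtype=torch.float32, align_corners=True)
print(affine.shape)  # torch.Size([1, 3, 3]), a batched homogeneous affine for 2D coordinates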
"""
shape = convert_to_tensor(shape, torch.float64, device=device, wrap_sequence=True, track_meta=False)
norm = shape.clone().detach().to(dtype=torch.float64, device=device) # no in-place change
if align_corners:
norm[norm <= 1.0] = 2.0
norm = 2.0 / (norm if zero_centered else norm - 1.0)
norm = torch.diag(torch.cat((norm, torch.ones((1,), dtype=torch.float64, device=device))))
if not zero_centered: # else shift is 0
norm[:-1, -1] = -1.0
else:
norm[norm <= 0.0] = 2.0
norm = 2.0 / (norm - 1.0 if zero_centered else norm)
norm = torch.diag(torch.cat((norm, torch.ones((1,), dtype=torch.float64, device=device))))
if not zero_centered:
norm[:-1, -1] = 1.0 / shape - 1.0
norm = norm.unsqueeze(0).to(dtype=dtype)
norm.requires_grad = False
return norm # type: ignore
def to_norm_affine(
affine: torch.Tensor,
src_size: Sequence[int],
dst_size: Sequence[int],
align_corners: bool = False,
zero_centered: bool = False,
) -> torch.Tensor:
"""
Given ``affine`` defined for coordinates in the pixel space, compute the corresponding affine
for the normalized coordinates.
Args:
affine: Nxdxd batched square matrix
src_size: source image spatial shape
dst_size: target image spatial shape
align_corners: if True, consider -1 and 1 to refer to the centers of the
corner pixels rather than the image corners.
See also: https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample
zero_centered: whether the coordinates are normalized from a zero-centered range, default to `False`.
See also: :py:func:`monai.networks.utils.normalize_transform`.
Raises:
TypeError: When ``affine`` is not a ``torch.Tensor``.
ValueError: When ``affine`` is not Nxdxd.
ValueError: When ``src_size`` or ``dst_size`` dimensions differ from ``affine``.
"""
if not isinstance(affine, torch.Tensor):
raise TypeError(f"affine must be a torch.Tensor but is {type(affine).__name__}.")
if affine.ndimension() != 3 or affine.shape[1] != affine.shape[2]:
raise ValueError(f"affine must be Nxdxd, got {tuple(affine.shape)}.")
sr = affine.shape[1] - 1
if sr != len(src_size) or sr != len(dst_size):
raise ValueError(f"affine suggests {sr}D, got src={len(src_size)}D, dst={len(dst_size)}D.")
src_xform = normalize_transform(src_size, affine.device, affine.dtype, align_corners, zero_centered)
dst_xform = normalize_transform(dst_size, "cpu", affine.dtype, align_corners, zero_centered)
return src_xform @ affine @ convert_to_dst_type(np.linalg.inv(dst_xform.numpy()), dst=affine)[0] # monai#5983
def normal_init(
m, std: float = 0.02, normal_func: Callable[[torch.Tensor, float, float], Any] = torch.nn.init.normal_
) -> None:
"""
Initialize the weight and bias tensors of `m' and its submodules to values from a normal distribution with a
stddev of `std'. Weight tensors of convolution and linear modules are initialized with a mean of 0, batch
norm modules with a mean of 1. The callable `normal_func', used to assign values, should have the same arguments
as its default normal_(). This can be used with `nn.Module.apply` to visit submodules of a network.
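Example (illustrative; any network can be visited the same way via ``Module.apply``):
.. code-block:: python
import torch.nn as nn
from monai.networks.utils import normal_init
net = nn.Sequential(nn.Conv2d(1, 8, 3), nn.BatchNorm2d(8))
net.apply(normal_init)  # re-initializes the conv and batch norm weights in place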
"""
cname = m.__class__.__name__
if getattr(m, "weight", None) is not None and (cname.find("Conv") != -1 or cname.find("Linear") != -1):
normal_func(m.weight.data, 0.0, std)
if getattr(m, "bias", None) is not None:
nn.init.constant_(m.bias.data, 0.0)
elif cname.find("BatchNorm") != -1:
normal_func(m.weight.data, 1.0, std)
nn.init.constant_(m.bias.data, 0)
def icnr_init(conv, upsample_factor, init=nn.init.kaiming_normal_):
"""
ICNR initialization for 2D/3D kernels adapted from Aitken et al., 2017, "Checkerboard artifact free
sub-pixel convolution".
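Example (illustrative; the layer below is sized for a 2x pixel shuffle, i.e. out_channels = in_channels * 4):
.. code-block:: python
import torch.nn as nn
from monai.networks.utils import icnr_init
conv = nn.Conv2d(16, 64, kernel_size=3, padding=1)
icnr_init(conv, upsample_factor=2)  # initialize so a following pixel shuffle starts artifact-free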
"""
out_channels, in_channels, *dims = conv.weight.shape
scale_factor = upsample_factor ** len(dims)
oc2 = int(out_channels / scale_factor)
kernel = torch.zeros([oc2, in_channels] + dims)
kernel = init(kernel)
kernel = kernel.transpose(0, 1)
kernel = kernel.reshape(oc2, in_channels, -1)
kernel = kernel.repeat(1, 1, scale_factor)
kernel = kernel.reshape([in_channels, out_channels] + dims)
kernel = kernel.transpose(0, 1)
conv.weight.data.copy_(kernel)
def pixelshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch.Tensor:
"""
Apply pixel shuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`.
See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution
Using an Efficient Sub-Pixel Convolutional Neural Network."
See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".
Args:
x: Input tensor with shape BCHW[D]
spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
scale_factor: factor to rescale the spatial dimensions by, must be >=1
Returns:
Reshuffled version of `x`.
Raises:
ValueError: When input channels of `x` are not divisible by (scale_factor ** spatial_dims)
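Example (illustrative; a 2D tensor with scale factor 2):
.. code-block:: python
import torch
from monai.networks.utils import pixelshuffle
x = torch.randn(1, 8, 16, 16)  # channels must be divisible by scale_factor ** spatial_dims
print(pixelshuffle(x, spatial_dims=2, scale_factor=2).shape)  # torch.Size([1, 2, 32, 32])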
"""
dim, factor = spatial_dims, scale_factor
input_size = list(x.size())
batch_size, channels = input_size[:2]
scale_divisor = factor**dim
if channels % scale_divisor != 0:
raise ValueError(
f"Number of input channels ({channels}) must be evenly "
f"divisible by scale_factor ** dimensions ({factor}**{dim}={scale_divisor})."
)
org_channels = int(channels // scale_divisor)
output_size = [batch_size, org_channels] + [d * factor for d in input_size[2:]]
indices = list(range(2, 2 + 2 * dim))
indices = indices[dim:] + indices[:dim]
permute_indices = [0, 1]
for idx in range(dim):
permute_indices.extend(indices[idx::dim])
x = x.reshape([batch_size, org_channels] + [factor] * dim + input_size[2:])
x = x.permute(permute_indices).reshape(output_size)
return x
def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch.Tensor:
"""
Apply pixel unshuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`.
Inverse operation of pixelshuffle.
See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution
Using an Efficient Sub-Pixel Convolutional Neural Network."
See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".
Args:
x: Input tensor with shape BCHW[D]
spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
scale_factor: factor to reduce the spatial dimensions by, must be >=1
Returns:
Unshuffled version of `x` with shape (B, C*(r**d), H/r, W/r) for 2D
or (B, C*(r**d), D/r, H/r, W/r) for 3D, where r is the scale_factor
and d is spatial_dims.
Raises:
ValueError: When spatial dimensions are not divisible by scale_factor
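Example (illustrative; shows that this inverts ``pixelshuffle``):
.. code-block:: python
import torch
from monai.networks.utils import pixelshuffle, pixelunshuffle
x = torch.randn(1, 4, 16, 16)
y = pixelunshuffle(x, spatial_dims=2, scale_factor=2)
print(y.shape)  # torch.Size([1, 16, 8, 8])
assert torch.equal(pixelshuffle(y, spatial_dims=2, scale_factor=2), x)  # round trip recovers x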
"""
dim, factor = spatial_dims, scale_factor
input_size = list(x.size())
batch_size, channels = input_size[:2]
scale_factor_mult = factor**dim
new_channels = channels * scale_factor_mult
if any(d % factor != 0 for d in input_size[2:]):
raise ValueError(
f"All spatial dimensions must be divisible by factor {factor}. " f", spatial shape is: {input_size[2:]}"
)
output_size = [batch_size, new_channels] + [d // factor for d in input_size[2:]]
reshaped_size = [batch_size, channels] + sum([[d // factor, factor] for d in input_size[2:]], [])
permute_indices = [0, 1] + [(2 * i + 3) for i in range(spatial_dims)] + [(2 * i + 2) for i in range(spatial_dims)]
x = x.reshape(reshaped_size).permute(permute_indices)
x = x.reshape(output_size)
return x
@contextmanager
def eval_mode(*nets: nn.Module):
"""
Set network(s) to eval mode and then return to original state at the end.
Args:
nets: Input network(s)
Examples:
.. code-block:: python
t=torch.rand(1,1,16,16)
p=torch.nn.Conv2d(1,1,3)
print(p.training) # True
with eval_mode(p):
print(p.training) # False
print(p(t).sum().backward())  # correctly raises an exception because gradients are not tracked in eval_mode
"""
# Get original state of network(s).
# Check the training attribute in case it's TensorRT based models which don't have this attribute.
training = [n for n in nets if hasattr(n, "training") and n.training]
try:
# set to eval mode
with torch.no_grad():
yield [n.eval() if hasattr(n, "eval") else n for n in nets]
finally:
# Return required networks to training
for n in training:
if hasattr(n, "train"):
n.train()
@contextmanager
def train_mode(*nets: nn.Module):
"""
Set network(s) to train mode and then return to original state at the end.
Args:
nets: Input network(s)
Examples:
.. code-block:: python
t=torch.rand(1,1,16,16)
p=torch.nn.Conv2d(1,1,3)
p.eval()
print(p.training) # False
with train_mode(p):
print(p.training) # True
print(p(t).sum().backward()) # No exception
"""
# Get original state of network(s)
# Check the training attribute in case it's TensorRT based models which don't have this attribute.
eval_list = [n for n in nets if hasattr(n, "training") and (not n.training)]
try:
# set to train mode
with torch.set_grad_enabled(True):
yield [n.train() if hasattr(n, "train") else n for n in nets]
finally:
# Return required networks to eval_list
for n in eval_list:
if hasattr(n, "eval"):
n.eval()
def get_state_dict(obj: torch.nn.Module | Mapping):
"""
Get the state dict of the input object if it has `state_dict`; otherwise, return the object directly.
For a data parallel model, automatically convert it to a regular model first.
Args:
obj: input object to check and get the state_dict.
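Example (a minimal sketch):
.. code-block:: python
import torch.nn as nn
from monai.networks.utils import get_state_dict
net = nn.Linear(4, 2)
sd = get_state_dict(net)  # equivalent to net.state_dict()
sd = get_state_dict(sd)   # a plain mapping is returned unchanged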
"""
if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
obj = obj.module
return obj.state_dict() if hasattr(obj, "state_dict") else obj
def copy_model_state(
dst: torch.nn.Module | Mapping,
src: torch.nn.Module | Mapping,
dst_prefix="",
mapping=None,
exclude_vars=None,
inplace=True,
filter_func=None,
):
"""
Compute a module state_dict, of which the keys are the same as `dst`. The values of `dst` are overwritten
by the ones from `src` whenever their keys match. The method provides additional `dst_prefix` for
the `dst` key when matching them. `mapping` can be a `{"src_key": "dst_key"}` dict, indicating
`dst[dst_prefix + dst_key] = src[src_key]`.
This function is mainly to return a model state dict
for loading the `src` model state into the `dst` model, `src` and `dst` can have different dict keys, but
their corresponding values normally have the same shape.
Args:
dst: a pytorch module or state dict to be updated.
src: a pytorch module or state dict used to get the values used for the update.
dst_prefix: `dst` key prefix, so that `dst[dst_prefix + src_key]`
will be assigned to the value of `src[src_key]`.
mapping: a `{"src_key": "dst_key"}` dict, indicating that `dst[dst_prefix + dst_key]`
to be assigned to the value of `src[src_key]`.
exclude_vars: a regular expression to match the `dst` variable names,
so that their values are not overwritten by `src`.
inplace: whether to set the `dst` module with the updated `state_dict` via `load_state_dict`.
This option is only available when `dst` is a `torch.nn.Module`.
filter_func: a filter function used to filter the weights to be loaded.
See 'filter_swinunetr' in "monai.networks.nets.swin_unetr.py".
Examples:
.. code-block:: python
from monai.networks.nets import BasicUNet
from monai.networks.utils import copy_model_state
model_a = BasicUNet(in_channels=1, out_channels=4)
model_b = BasicUNet(in_channels=1, out_channels=2)
model_a_b, changed, unchanged = copy_model_state(
model_a, model_b, exclude_vars="conv_0.conv_0", inplace=False)
# dst model updated: 76 of 82 variables.
model_a.load_state_dict(model_a_b)
# <All keys matched successfully>
Returns: an OrderedDict of the updated `dst` state, the changed, and unchanged keys.
"""
src_dict = get_state_dict(src)
dst_dict = OrderedDict(get_state_dict(dst))
to_skip = {s_key for s_key in src_dict if exclude_vars and re.compile(exclude_vars).search(s_key)}
# update dst with items from src
all_keys, updated_keys = list(dst_dict), list()
for s, val in src_dict.items():
dst_key = f"{dst_prefix}{s}"
if dst_key in dst_dict and dst_key not in to_skip and dst_dict[dst_key].shape == val.shape:
dst_dict[dst_key] = val
updated_keys.append(dst_key)
for s in mapping if mapping else {}:
dst_key = f"{dst_prefix}{mapping[s]}"
if dst_key in dst_dict and dst_key not in to_skip:
if dst_dict[dst_key].shape != src_dict[s].shape:
warnings.warn(f"Param. shape changed from {dst_dict[dst_key].shape} to {src_dict[s].shape}.")
dst_dict[dst_key] = src_dict[s]
updated_keys.append(dst_key)
if filter_func is not None:
for key, value in src_dict.items():
new_pair = filter_func(key, value)
if new_pair is not None and new_pair[0] not in to_skip:
dst_dict[new_pair[0]] = new_pair[1]
updated_keys.append(new_pair[0])
updated_keys = sorted(set(updated_keys))
unchanged_keys = sorted(set(all_keys).difference(updated_keys))
logger.info(f"'dst' model updated: {len(updated_keys)} of {len(dst_dict)} variables.")
if inplace and isinstance(dst, torch.nn.Module):
if isinstance(dst, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
dst = dst.module
dst.load_state_dict(dst_dict) # type: ignore
return dst_dict, updated_keys, unchanged_keys
def save_state(src: torch.nn.Module | dict, path: PathLike, **kwargs):
"""
Save the state dict of input source data with PyTorch `save`.
It can save `nn.Module`, `state_dict`, or a dictionary of `nn.Module` or `state_dict`,
and automatically converts a data parallel module to a regular module.
For example::
save_state(net, path)
save_state(net.state_dict(), path)
save_state({"net": net, "opt": opt}, path)
net_dp = torch.nn.DataParallel(net)
save_state(net_dp, path)
Refer to: https://pytorch.org/ignite/v0.4.8/generated/ignite.handlers.DiskSaver.html.
Args:
src: input data to save, can be `nn.Module`, `state_dict`, a dictionary of `nn.Module` or `state_dict`.
path: target file path to save the input object.
kwargs: other args for the `save_obj` except for the `obj` and `path`.
default `func` is `torch.save()`, details of the args:
https://pytorch.org/docs/stable/generated/torch.save.html.
"""
ckpt: dict = {}
if isinstance(src, dict):
for k, v in src.items():
ckpt[k] = get_state_dict(v)
else:
ckpt = get_state_dict(src)
save_obj(obj=ckpt, path=path, **kwargs)
def convert_to_onnx(
model: nn.Module,
inputs: Sequence[Any],
input_names: Sequence[str] | None = None,
output_names: Sequence[str] | None = None,
opset_version: int | None = None,
dynamic_axes: Mapping[str, Mapping[int, str]] | Mapping[str, Sequence[int]] | None = None,
filename: Any | None = None,
verify: bool = False,
device: torch.device | None = None,
use_ort: bool = False,
ort_provider: Sequence[str] | None = None,
rtol: float = 1e-4,
atol: float = 0.0,
use_trace: bool = True,
do_constant_folding: bool = True,
constant_size_threshold: int = 16 * 1024 * 1024 * 1024,
**kwargs,
):
"""
Utility to convert a model into an ONNX model and optionally verify with ONNX or onnxruntime.
See also: https://pytorch.org/docs/stable/onnx.html for how to convert a PyTorch model to ONNX.
Args:
model: source PyTorch model to save.
inputs: input sample data used by pytorch.onnx.export. It is also used in ONNX model verification.
input_names: optional input names of the ONNX model.
output_names: optional output names of the ONNX model.
opset_version: version of the (ai.onnx) opset to target. Must be >= 7 and not exceed
the latest opset version supported by PyTorch, for more details:
https://github.com/onnx/onnx/blob/main/docs/Operators.md and
https://github.com/pytorch/pytorch/blob/master/torch/onnx/_constants.py
dynamic_axes: specifies axes of tensors as dynamic (i.e. known only at run-time). If set to None,
the exported model will have the shapes of all input and output tensors set to match given
ones, for more details: https://pytorch.org/docs/stable/onnx.html#torch.onnx.export.
filename: optional filename to save the ONNX model, if None, don't save the ONNX model.
verify: whether to verify the ONNX model with ONNX or onnxruntime.
device: target PyTorch device to verify the model, if None, use CUDA if available.
use_ort: whether to use onnxruntime to verify the model.
ort_provider": onnxruntime provider to use, default is ["CPUExecutionProvider"].
rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model.
atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model.
use_trace: whether to use `torch.jit.trace` to export the torchscript model.
do_constant_folding: passed to onnx.export(). If True, an extra polygraphy constant-folding pass is also run.
constant_size_threshold: passed to the polygraphy constant-folding pass, defaults to ``16 * 1024 * 1024 * 1024``.
kwargs: if use_trace=True: additional arguments to pass to torch.onnx.export()
else: other arguments except `obj` for `torch.jit.script()` to convert model, for more details:
https://pytorch.org/docs/master/generated/torch.jit.script.html.
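Example (a minimal sketch; requires the optional ``onnx`` dependency, and the toy model and input shape below are arbitrary):
.. code-block:: python
import torch
import torch.nn as nn
from monai.networks.utils import convert_to_onnx
net = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.ReLU())
onnx_model = convert_to_onnx(net, inputs=[torch.randn(1, 1, 16, 16)], input_names=["x"], output_names=["y"], verify=True)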
"""
model.eval()
with torch.no_grad():
torch_versioned_kwargs = {}
if use_trace:
# let torch.onnx.export to trace the model.
mode_to_export = model
torch_versioned_kwargs = kwargs
if "dynamo" in kwargs and kwargs["dynamo"] and verify:
torch_versioned_kwargs["verify"] = verify
verify = False
else:
mode_to_export = torch.jit.script(model, **kwargs)
if torch.is_tensor(inputs) or isinstance(inputs, dict):
onnx_inputs = (inputs,)
else:
onnx_inputs = tuple(inputs)
temp_file = None
if filename is None:
temp_file = tempfile.NamedTemporaryFile()
f = temp_file.name
else:
f = filename
print(f"torch_versioned_kwargs={torch_versioned_kwargs}")
torch.onnx.export(
mode_to_export,
onnx_inputs,
f=f,
input_names=input_names,
output_names=output_names or None,
dynamic_axes=dynamic_axes,
opset_version=opset_version,
do_constant_folding=do_constant_folding,
**torch_versioned_kwargs,
)
onnx_model = onnx.load(f)
if do_constant_folding and polygraphy_imported:
from polygraphy.backend.onnx.loader import fold_constants, save_onnx
onnx_model = fold_constants(onnx_model, size_threshold=constant_size_threshold)
save_onnx(onnx_model, f)
if verify:
if isinstance(inputs, dict):
inputs = list(inputs.values())
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inputs = [i.to(device) if isinstance(i, torch.Tensor) else i for i in inputs]
model = model.to(device)
with torch.no_grad():
set_determinism(seed=0)
torch_out = ensure_tuple(model(*inputs), True)
set_determinism(seed=0)
model_input_names = [i.name for i in onnx_model.graph.input]
input_dict = dict(zip(model_input_names, [i.cpu().numpy() for i in inputs]))
if use_ort:
ort_sess = onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=ort_provider if ort_provider else ["CPUExecutionProvider"]
)
onnx_out = ort_sess.run(None, input_dict)
else:
sess = onnxreference.ReferenceEvaluator(onnx_model)
onnx_out = sess.run(None, input_dict)
set_determinism(seed=None)
# compare onnx/ort and PyTorch results
for r1, r2 in zip(torch_out, onnx_out):
if isinstance(r1, torch.Tensor):
torch.testing.assert_close(r1.cpu(), convert_to_tensor(r2, dtype=r1.dtype), rtol=rtol, atol=atol) # type: ignore
return onnx_model
def convert_to_torchscript(
model: nn.Module,
filename_or_obj: Any | None = None,
extra_files: dict | None = None,
verify: bool = False,
inputs: Sequence[Any] | None = None,
device: torch.device | None = None,
rtol: float = 1e-4,
atol: float = 0.0,
use_trace: bool = False,
**kwargs,
):
"""
Utility to convert a model into a TorchScript model and save to file,
with optional input / output data verification.
Args:
model: source PyTorch model to save.
filename_or_obj: if not None, specify a file-like object (has to implement write and flush)
or a string containing a file path name to save the TorchScript model.
extra_files: map from filename to contents which will be stored as part of the save model file.
for more details: https://pytorch.org/docs/stable/generated/torch.jit.save.html.
verify: whether to verify the input and output of TorchScript model.
if `filename_or_obj` is not None, load the saved TorchScript model and verify.
inputs: input test data to verify model, should be a sequence of data, every item maps to an argument
of `model()` function.
device: target device to verify the model, if None, use CUDA if available.
rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model.
atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model.
use_trace: whether to use `torch.jit.trace` to export the TorchScript model.
kwargs: other arguments except `obj` for `torch.jit.script()` or `torch.jit.trace()` (if use_trace is True)
to convert model, for more details: https://pytorch.org/docs/master/generated/torch.jit.script.html.
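Example (a minimal sketch; the toy model, file name and input shape below are arbitrary):
.. code-block:: python
import torch
import torch.nn as nn
from monai.networks.utils import convert_to_torchscript
net = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.ReLU())
ts_net = convert_to_torchscript(net, filename_or_obj="net.ts", verify=True, inputs=[torch.randn(1, 1, 16, 16)])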
"""
model.eval()
with torch.no_grad():
if use_trace:
if inputs is None:
raise ValueError("Missing input data for tracing convert.")
script_module = torch.jit.trace(model, example_inputs=inputs, **kwargs)
else:
script_module = torch.jit.script(model, **kwargs)
if filename_or_obj is not None:
torch.jit.save(m=script_module, f=filename_or_obj, _extra_files=extra_files)
if verify:
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if inputs is None:
raise ValueError("Missing input data for verification.")
inputs = [i.to(device) if isinstance(i, torch.Tensor) else i for i in inputs]
ts_model = torch.jit.load(filename_or_obj) if filename_or_obj is not None else script_module
ts_model.eval().to(device)
model = model.to(device)
with torch.no_grad():
set_determinism(seed=0)
torch_out = ensure_tuple(model(*inputs))
set_determinism(seed=0)
torchscript_out = ensure_tuple(ts_model(*inputs))
set_determinism(seed=None)
# compare TorchScript and PyTorch results
for r1, r2 in zip(torch_out, torchscript_out):
if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor):
torch.testing.assert_close(r1, r2, rtol=rtol, atol=atol) # type: ignore
return script_module
def _onnx_trt_compile(
onnx_model,
min_shape: Sequence[int],
opt_shape: Sequence[int],
max_shape: Sequence[int],
device: int,
precision: str,
input_names: Sequence[str] | None,
output_names: Sequence[str] | None,
):
"""
This function takes an ONNX model as input, exports it to a TensorRT engine, wraps the engine
in a TensorRT engine-based TorchScript module, and returns that TorchScript model.
Args:
onnx_model: the source ONNX model to compile.
min_shape: the minimum input shape of the converted TensorRT model.
opt_shape: the optimization input shape of the model, on which the TensorRT optimizes.
max_shape: the maximum input shape of the converted TensorRT model.
device: the target GPU index to convert and verify the model.
precision: the weight precision of the converted TensorRT engine-based TorchScript model.
Should be 'fp32' or 'fp16'.
input_names: optional input names of the ONNX model. Should be a sequence like
`['input_0', 'input_1', ..., 'input_N']` where N equals the number of the
model inputs.
output_names: optional output names of the ONNX model. Should be a sequence like
`['output_0', 'output_1', ..., 'output_N']` where N equals the number of
the model outputs.
"""
trt, _ = optional_import("tensorrt", "8.5.3")
input_shapes = (min_shape, opt_shape, max_shape)
# default to an empty list to fit the `torch_tensorrt.ts.embed_engine_in_new_module` function.
input_names = [] if not input_names else input_names
output_names = [] if not output_names else output_names
# set up the TensorRT builder
torch.cuda.set_device(device)
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
profile = builder.create_optimization_profile()
if input_names:
profile.set_shape(input_names[0], *input_shapes)
# parse the ONNX model
parser = trt.OnnxParser(network, logger)
success = parser.parse(onnx_model.SerializeToString())
if not success:
parser_error_message = ""
for idx in range(parser.num_errors):
parser_error_message += parser.get_error(idx).desc() + "\n"
raise Exception(f"TensorRT cannot parse the ONNX model, due to:\n{parser_error_message}")
# set up the conversion configuration
config = builder.create_builder_config()
config.add_optimization_profile(profile)
if precision == "fp16":
config.set_flag(trt.BuilderFlag.FP16)
serialized_engine = builder.build_serialized_network(network, config)
f = io.BytesIO()
f.write(serialized_engine)
# wrap the serialized TensorRT engine back to a TorchScript module.
trt_model = torch_tensorrt.ts.embed_engine_in_new_module(
f.getvalue(),
device=torch_tensorrt.Device(f"cuda:{device}"),
input_binding_names=input_names,
output_binding_names=output_names,
)
return trt_model
def convert_to_trt(
model: nn.Module,
precision: str,
input_shape: Sequence[int],
dynamic_batchsize: Sequence[int] | None = None,
use_trace: bool = False,
filename_or_obj: Any | None = None,
verify: bool = False,
device: int | None = None,
use_onnx: bool | None = False,
onnx_input_names: Sequence[str] | None = ("input_0",),
onnx_output_names: Sequence[str] | None = ("output_0",),
rtol: float = 1e-2,
atol: float = 0.0,
**kwargs,
):
"""
Utility to export a model into a TensorRT engine-based TorchScript model with optional input / output data verification.
There are two ways to export a model:
1, Torch-TensorRT way: PyTorch module ---> TorchScript module ---> TensorRT engine-based TorchScript.
2, ONNX-TensorRT way: PyTorch module ---> TorchScript module ---> ONNX model ---> TensorRT engine --->
TensorRT engine-based TorchScript.
When exporting through the first way, some models suffer from a slowdown problem, since Torch-TensorRT
may only convert a small part of the PyTorch model to the TensorRT engine. However, when exporting through
the second way, some Python data structures like `dict` are not supported, and some TorchScript models are
not supported by ONNX export if converted through `torch.jit.script`.
Args:
model: a source PyTorch model to convert.
precision: the weight precision of the converted TensorRT engine based TorchScript models. Should be 'fp32' or 'fp16'.
input_shape: the input shape that is used to convert the model. Should be a list like [N, C, H, W] or
[N, C, H, W, D].
dynamic_batchsize: a sequence with three elements to define the batch size range of the input for the model to be
converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH]. After conversion, the batch size of the
model input should be between `MIN_BATCH` and `MAX_BATCH`, and `OPT_BATCH` is the batch size that TensorRT
optimizes for. `OPT_BATCH` should be the most frequently used input batch size in the application,
default to None.
use_trace: whether to use `torch.jit.trace` to convert the PyTorch model to a TorchScript model and then convert to
a TensorRT engine based TorchScript model or an ONNX model (if `use_onnx` is True), default to False.
filename_or_obj: if not None, specify a file-like object (has to implement write and flush) or a string containing a
file path name to load the TensorRT engine based TorchScript model for verifying.
verify: whether to verify the input and output of the TensorRT engine based TorchScript model.
device: the target GPU index to convert and verify the model. If None, use #0 GPU.
use_onnx: whether to use the ONNX-TensorRT way to export the TensorRT engine-based TorchScript model.
onnx_input_names: optional input names of the ONNX model. This arg is only useful when `use_onnx` is True. Should be
a sequence like `('input_0', 'input_1', ..., 'input_N')` where N equals the number of the model inputs. If not
given, will use `('input_0',)`, which assumes the model only has one input.
onnx_output_names: optional output names of the ONNX model. This arg is only useful when `use_onnx` is True. Should be
a sequence like `('output_0', 'output_1', ..., 'output_N')` where N equals the number of the model outputs. If
not given, will use `('output_0',)`, which assumes the model only has one output.
rtol: the relative tolerance when comparing the outputs between the PyTorch model and TensorRT model.
atol: the absolute tolerance when comparing the outputs between the PyTorch model and TensorRT model.
kwargs: other arguments except `module`, `inputs`, `enabled_precisions` and `device` for `torch_tensorrt.compile()`
to compile model, for more details: https://pytorch.org/TensorRT/py_api/torch_tensorrt.html#torch-tensorrt-py.
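Example (a minimal sketch; requires a CUDA GPU and the optional TensorRT/torch_tensorrt dependencies, and the toy model and shapes below are arbitrary):
.. code-block:: python
import torch.nn as nn
from monai.networks.utils import convert_to_trt
net = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.ReLU())
trt_net = convert_to_trt(net, precision="fp16", input_shape=[1, 1, 64, 64], dynamic_batchsize=[1, 4, 8])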
"""
if not torch.cuda.is_available():
raise Exception("Cannot find any GPU devices.")
if not input_shape:
raise ValueError("Missing the input shape for model convert.")
if not dynamic_batchsize:
warnings.warn(f"There is no dynamic batch range. The converted model only takes {input_shape} shape input.")
if (dynamic_batchsize is not None) and (len(dynamic_batchsize) != 3):
warnings.warn(f"The dynamic batch range sequence should have 3 elements, but got {dynamic_batchsize} elements.")