diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py
index 063a1fded1..4c7263c6d5 100644
--- a/monai/networks/blocks/patchembedding.py
+++ b/monai/networks/blocks/patchembedding.py
@@ -80,7 +80,7 @@ def __init__(
             if self.pos_embed == "perceptron" and m % p != 0:
                 raise ValueError("patch_size should be divisible by img_size for perceptron.")
         self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)])
-        self.patch_dim = in_channels * np.prod(patch_size)
+        self.patch_dim = int(in_channels * np.prod(patch_size))
 
         self.patch_embeddings: nn.Module
         if self.pos_embed == "conv":
diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index cf837c5a6f..db92111d14 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -14,7 +14,7 @@
 
 from monai.utils import optional_import
 
-einops, _ = optional_import("einops")
+Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 
 
 class SABlock(nn.Module):
@@ -43,17 +43,20 @@ def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0)
         self.num_heads = num_heads
         self.out_proj = nn.Linear(hidden_size, hidden_size)
         self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)
+        self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
+        self.out_rearrange = Rearrange("b h l d -> b l (h d)")
         self.drop_output = nn.Dropout(dropout_rate)
         self.drop_weights = nn.Dropout(dropout_rate)
         self.head_dim = hidden_size // num_heads
         self.scale = self.head_dim**-0.5
 
     def forward(self, x):
-        q, k, v = einops.rearrange(self.qkv(x), "b h (qkv l d) -> qkv b l h d", qkv=3, l=self.num_heads)
+        output = self.input_rearrange(self.qkv(x))
+        q, k, v = output[0], output[1], output[2]
         att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
         att_mat = self.drop_weights(att_mat)
         x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
-        x = einops.rearrange(x, "b h l d -> b l (h d)")
+        x = self.out_rearrange(x)
         x = self.out_proj(x)
         x = self.drop_output(x)
         return x
diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py
index 62e92603ab..a5f7963eca 100644
--- a/monai/networks/nets/vit.py
+++ b/monai/networks/nets/vit.py
@@ -25,6 +25,8 @@ class ViT(nn.Module):
     """
     Vision Transformer (ViT), based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale "
+
+    ViT supports TorchScript but only works for PyTorch versions after 1.8.
""" def __init__( @@ -99,7 +101,7 @@ def __init__( def forward(self, x): x = self.patch_embedding(x) - if self.classification: + if hasattr(self, "cls_token"): cls_token = self.cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_token, x), dim=1) hidden_states_out = [] @@ -107,6 +109,6 @@ def forward(self, x): x = blk(x) hidden_states_out.append(x) x = self.norm(x) - if self.classification: + if hasattr(self, "classification_head"): x = self.classification_head(x[:, 0]) return x, hidden_states_out diff --git a/tests/test_vit.py b/tests/test_vit.py index 870e4010ec..d5ae209e50 100644 --- a/tests/test_vit.py +++ b/tests/test_vit.py @@ -16,6 +16,7 @@ from monai.networks import eval_mode from monai.networks.nets.vit import ViT +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save TEST_CASE_Vit = [] for dropout_rate in [0.6]: @@ -27,7 +28,7 @@ for mlp_dim in [3072]: for num_layers in [4]: for num_classes in [8]: - for pos_embed in ["conv"]: + for pos_embed in ["conv", "perceptron"]: for classification in [False, True]: for nd in (2, 3): test_case = [ @@ -133,6 +134,17 @@ def test_ill_arg(self): dropout_rate=0.3, ) + @parameterized.expand(TEST_CASE_Vit) + @SkipIfBeforePyTorchVersion((1, 9)) + def test_script(self, input_param, input_shape, _): + net = ViT(**(input_param)) + net.eval() + with torch.no_grad(): + torch.jit.script(net) + + test_data = torch.randn(input_shape) + test_script_save(net, test_data) + if __name__ == "__main__": unittest.main()