Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3432 make vit support torchscript #3782

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/networks/blocks/patchembedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
if self.pos_embed == "perceptron" and m % p != 0:
raise ValueError("patch_size should be divisible by img_size for perceptron.")
self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)])
self.patch_dim = in_channels * np.prod(patch_size)
self.patch_dim = int(in_channels * np.prod(patch_size))

self.patch_embeddings: nn.Module
if self.pos_embed == "conv":
Expand Down
9 changes: 6 additions & 3 deletions monai/networks/blocks/selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from monai.utils import optional_import

einops, _ = optional_import("einops")
Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")


class SABlock(nn.Module):
Expand Down Expand Up @@ -43,17 +43,20 @@ def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0)
self.num_heads = num_heads
self.out_proj = nn.Linear(hidden_size, hidden_size)
self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)
self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim**-0.5

def forward(self, x):
q, k, v = einops.rearrange(self.qkv(x), "b h (qkv l d) -> qkv b l h d", qkv=3, l=self.num_heads)
output = self.input_rearrange(self.qkv(x))
q, k, v = output[0], output[1], output[2]
att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = einops.rearrange(x, "b h l d -> b l (h d)")
x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)
return x
4 changes: 2 additions & 2 deletions monai/networks/nets/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ def __init__(

def forward(self, x):
x = self.patch_embedding(x)
if self.classification:
if hasattr(self, "cls_token"):
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
hidden_states_out = []
for blk in self.blocks:
x = blk(x)
hidden_states_out.append(x)
x = self.norm(x)
if self.classification:
if hasattr(self, "classification_head"):
x = self.classification_head(x[:, 0])
return x, hidden_states_out
13 changes: 12 additions & 1 deletion tests/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from monai.networks import eval_mode
from monai.networks.nets.vit import ViT
from tests.utils import test_script_save

TEST_CASE_Vit = []
for dropout_rate in [0.6]:
Expand All @@ -27,7 +28,7 @@
for mlp_dim in [3072]:
for num_layers in [4]:
for num_classes in [8]:
for pos_embed in ["conv"]:
for pos_embed in ["conv", "perceptron"]:
for classification in [False, True]:
for nd in (2, 3):
test_case = [
Expand Down Expand Up @@ -133,6 +134,16 @@ def test_ill_arg(self):
dropout_rate=0.3,
)

@parameterized.expand(TEST_CASE_Vit)
def test_script(self, input_param, input_shape, _):
net = ViT(**(input_param))
net.eval()
with torch.no_grad():
torch.jit.script(net)

test_data = torch.randn(input_shape)
test_script_save(net, test_data)


if __name__ == "__main__":
unittest.main()