From 709a8493142a461613e1fe551db3e85398239c08 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 7 Apr 2022 20:34:46 +0100 Subject: [PATCH 1/3] extend mlp Signed-off-by: Wenqi Li --- monai/networks/blocks/mlp.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py index a1728365cf..afb67fad51 100644 --- a/monai/networks/blocks/mlp.py +++ b/monai/networks/blocks/mlp.py @@ -9,8 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Tuple, Union + import torch.nn as nn +from monai.networks.layers import get_act_layer +from monai.utils import look_up_option + +SUPPORTED_DROPOUT_MODE = {"vit", "swin"} + class MLPBlock(nn.Module): """ @@ -18,12 +25,21 @@ class MLPBlock(nn.Module): An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " """ - def __init__(self, hidden_size: int, mlp_dim: int, dropout_rate: float = 0.0) -> None: + def __init__( + self, + hidden_size: int, + mlp_dim: int, + dropout_rate: float = 0.0, + act: Optional[Union[Tuple, str]] = "GELU", + dropout_mode="vit", + ) -> None: """ Args: hidden_size: dimension of hidden layer. mlp_dim: dimension of feedforward layer. dropout_rate: faction of the input units to drop. + act: activation type and arguments. Defaults to GELU. + dropout_mode: dropout mode, can be "vit" or "swin". """ @@ -34,9 +50,15 @@ def __init__(self, hidden_size: int, mlp_dim: int, dropout_rate: float = 0.0) -> self.linear1 = nn.Linear(hidden_size, mlp_dim) self.linear2 = nn.Linear(mlp_dim, hidden_size) - self.fn = nn.GELU() + self.fn = get_act_layer(act) self.drop1 = nn.Dropout(dropout_rate) - self.drop2 = nn.Dropout(dropout_rate) + dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE) + if dropout_opt == "vit": + self.drop2 = nn.Dropout(dropout_rate) + elif dropout_opt == "swin": + self.drop2 = self.drop1 + else: + raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}") def forward(self, x): x = self.fn(self.linear1(x)) From 55703fe0183c0e766763223405e1f030e71d8dfe Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 7 Apr 2022 20:49:20 +0100 Subject: [PATCH 2/3] 0 mlp_dim Signed-off-by: Wenqi Li --- monai/networks/blocks/mlp.py | 8 ++++---- tests/test_mlp.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py index afb67fad51..423d20baaa 100644 --- a/monai/networks/blocks/mlp.py +++ b/monai/networks/blocks/mlp.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Tuple, Union import torch.nn as nn @@ -30,13 +30,13 @@ def __init__( hidden_size: int, mlp_dim: int, dropout_rate: float = 0.0, - act: Optional[Union[Tuple, str]] = "GELU", + act: Union[Tuple, str] = "GELU", dropout_mode="vit", ) -> None: """ Args: hidden_size: dimension of hidden layer. - mlp_dim: dimension of feedforward layer. + mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used. dropout_rate: faction of the input units to drop. act: activation type and arguments. Defaults to GELU. dropout_mode: dropout mode, can be "vit" or "swin". @@ -47,7 +47,7 @@ def __init__( if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") - + mlp_dim = mlp_dim or hidden_size self.linear1 = nn.Linear(hidden_size, mlp_dim) self.linear2 = nn.Linear(mlp_dim, hidden_size) self.fn = get_act_layer(act) diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 6fec5b6854..737762cfb1 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -21,7 +21,7 @@ TEST_CASE_MLP = [] for dropout_rate in np.linspace(0, 1, 4): for hidden_size in [128, 256, 512, 768]: - for mlp_dim in [512, 1028, 2048, 3072]: + for mlp_dim in [0, 1028, 2048, 3072]: test_case = [ {"hidden_size": hidden_size, "mlp_dim": mlp_dim, "dropout_rate": dropout_rate}, From 92c49040bfd96ca8bade1a3d4771feca60c8bc20 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 8 Apr 2022 17:11:48 +0100 Subject: [PATCH 3/3] update based on comments Signed-off-by: Wenqi Li --- monai/networks/blocks/mlp.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py index 423d20baaa..0feeb044f3 100644 --- a/monai/networks/blocks/mlp.py +++ b/monai/networks/blocks/mlp.py @@ -40,6 +40,11 @@ def __init__( dropout_rate: faction of the input units to drop. act: activation type and arguments. Defaults to GELU. dropout_mode: dropout mode, can be "vit" or "swin". + "vit" mode uses two dropout instances as implemented in + https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87 + "swin" corresponds to one instance as implemented in + https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23 + """