37 changes: 32 additions & 5 deletions monai/networks/blocks/mlp.py
@@ -9,34 +9,61 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from typing import Tuple, Union
+
import torch.nn as nn

+from monai.networks.layers import get_act_layer
+from monai.utils import look_up_option
+
+SUPPORTED_DROPOUT_MODE = {"vit", "swin"}


class MLPBlock(nn.Module):
    """
    A multi-layer perceptron block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
    """

-    def __init__(self, hidden_size: int, mlp_dim: int, dropout_rate: float = 0.0) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        mlp_dim: int,
+        dropout_rate: float = 0.0,
+        act: Union[Tuple, str] = "GELU",
+        dropout_mode="vit",
+    ) -> None:
        """
        Args:
            hidden_size: dimension of hidden layer.
-            mlp_dim: dimension of feedforward layer.
+            mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used.
            dropout_rate: fraction of the input units to drop.
+            act: activation type and arguments. Defaults to GELU.
+            dropout_mode: dropout mode, can be "vit" or "swin".
+                "vit" mode uses two dropout instances as implemented in
+                https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
+                "swin" corresponds to one instance as implemented in
+                https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23
+
+
        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

+        mlp_dim = mlp_dim or hidden_size
        self.linear1 = nn.Linear(hidden_size, mlp_dim)
        self.linear2 = nn.Linear(mlp_dim, hidden_size)
-        self.fn = nn.GELU()
+        self.fn = get_act_layer(act)
        self.drop1 = nn.Dropout(dropout_rate)
-        self.drop2 = nn.Dropout(dropout_rate)
+        dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE)
+        if dropout_opt == "vit":
+            self.drop2 = nn.Dropout(dropout_rate)
+        elif dropout_opt == "swin":
+            self.drop2 = self.drop1
+        else:
+            raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}")

    def forward(self, x):
        x = self.fn(self.linear1(x))
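For reference when reviewing, here is a minimal usage sketch of the patched block. The `forward` body is truncated in the hunk above, but the second linear layer projects back to `hidden_size`, so the output shape matches the input. The import path mirrors the file being changed; the tensor shapes are illustrative assumptions:

```python
import torch

from monai.networks.blocks.mlp import MLPBlock

# Default "vit" mode: two independent nn.Dropout instances, GELU activation.
vit_mlp = MLPBlock(hidden_size=256, mlp_dim=1024, dropout_rate=0.1)

# "swin" mode, with mlp_dim=0 falling back to hidden_size
# via the new `mlp_dim = mlp_dim or hidden_size` line.
swin_mlp = MLPBlock(hidden_size=256, mlp_dim=0, dropout_rate=0.1, act="GELU", dropout_mode="swin")

x = torch.randn(2, 16, 256)  # (batch, tokens, hidden_size)
print(vit_mlp(x).shape)   # torch.Size([2, 16, 256])
print(swin_mlp(x).shape)  # torch.Size([2, 16, 256])
print(swin_mlp.drop1 is swin_mlp.drop2)  # True: one shared dropout instance
```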
2 changes: 1 addition & 1 deletion tests/test_mlp.py
@@ -21,7 +21,7 @@
TEST_CASE_MLP = []
for dropout_rate in np.linspace(0, 1, 4):
for hidden_size in [128, 256, 512, 768]:
for mlp_dim in [512, 1028, 2048, 3072]:
for mlp_dim in [0, 1028, 2048, 3072]:

test_case = [
{"hidden_size": hidden_size, "mlp_dim": mlp_dim, "dropout_rate": dropout_rate},
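The updated matrix exercises the new `mlp_dim=0` fallback, but the `act` and `dropout_mode` arguments are not covered by this test change. A sketch of the kind of assertions that could cover them (hypothetical, not part of this PR, and assuming `look_up_option` rejects unknown modes with a `ValueError`):

```python
import torch

from monai.networks.blocks.mlp import MLPBlock

block = MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=0.25, dropout_mode="swin")
assert block.drop1 is block.drop2  # "swin" mode shares a single dropout instance

out = block(torch.randn(1, 8, 128))
assert out.shape == (1, 8, 128)  # output returns to hidden_size

# An unsupported mode should fail fast inside look_up_option.
try:
    MLPBlock(hidden_size=128, mlp_dim=512, dropout_mode="bad")
    raise AssertionError("expected a ValueError for an unsupported dropout_mode")
except ValueError:
    pass
```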