Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions monai/networks/blocks/selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ class SABlock(nn.Module):
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
"""

def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0) -> None:
def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None:
"""
Args:
hidden_size: dimension of hidden layer.
num_heads: number of attention heads.
dropout_rate: faction of the input units to drop.
qkv_bias: bias term for the qkv linear layer.

"""

Expand All @@ -42,7 +43,7 @@ def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0)

self.num_heads = num_heads
self.out_proj = nn.Linear(hidden_size, hidden_size)
self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)
self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.drop_output = nn.Dropout(dropout_rate)
Expand Down
7 changes: 5 additions & 2 deletions monai/networks/blocks/transformerblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@ class TransformerBlock(nn.Module):
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
"""

def __init__(self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0) -> None:
def __init__(
self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False
) -> None:
"""
Args:
hidden_size: dimension of hidden layer.
mlp_dim: dimension of feedforward layer.
num_heads: number of attention heads.
dropout_rate: faction of the input units to drop.
qkv_bias: apply bias term for the qkv linear layer

"""

Expand All @@ -41,7 +44,7 @@ def __init__(self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate:

self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
self.norm1 = nn.LayerNorm(hidden_size)
self.attn = SABlock(hidden_size, num_heads, dropout_rate)
self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias)
self.norm2 = nn.LayerNorm(hidden_size)

def forward(self, x):
Expand Down
3 changes: 3 additions & 0 deletions monai/networks/nets/unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
res_block: bool = True,
dropout_rate: float = 0.0,
spatial_dims: int = 3,
qkv_bias: bool = False,
) -> None:
"""
Args:
Expand All @@ -56,6 +57,7 @@ def __init__(
res_block: bool argument to determine if residual block is used.
dropout_rate: faction of the input units to drop.
spatial_dims: number of spatial dims.
qkv_bias: apply the bias term for the qkv linear layer in self attention block

Examples::

Expand Down Expand Up @@ -96,6 +98,7 @@ def __init__(
classification=self.classification,
dropout_rate=dropout_rate,
spatial_dims=spatial_dims,
qkv_bias=qkv_bias,
)
self.encoder1 = UnetrBasicBlock(
spatial_dims=spatial_dims,
Expand Down
4 changes: 3 additions & 1 deletion monai/networks/nets/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
dropout_rate: float = 0.0,
spatial_dims: int = 3,
post_activation="Tanh",
qkv_bias: bool = False,
) -> None:
"""
Args:
Expand All @@ -61,6 +62,7 @@ def __init__(
spatial_dims: number of spatial dimensions.
post_activation: add a final acivation function to the classification head when `classification` is True.
Default to "Tanh" for `nn.Tanh()`. Set to other values to remove this function.
qkv_bias: apply bias to the qkv linear layer in self attention block

Examples::

Expand Down Expand Up @@ -95,7 +97,7 @@ def __init__(
spatial_dims=spatial_dims,
)
self.blocks = nn.ModuleList(
[TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)]
[TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias) for i in range(num_layers)]
)
self.norm = nn.LayerNorm(hidden_size)
if self.classification:
Expand Down