28 changes: 19 additions & 9 deletions monai/networks/blocks/transformerblock.py
@@ -46,11 +46,20 @@ def __init__(
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias (bool, optional): apply a bias term to the qkv linear layer. Defaults to False.
save_attn (bool, optional): whether to make the attention matrix accessible. Defaults to False.
causal (bool, optional): whether to apply causal masking in self-attention. Defaults to False.
sequence_length (int | None, optional): sequence length required for causal masking. Defaults to None.
with_cross_attention (bool, optional): whether to include cross-attention layers that attend to an
external context tensor. When False, norm_cross_attn and cross_attn are not instantiated.
Defaults to False.
use_flash_attention: if True, use PyTorch's built-in flash attention for a memory-efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
include_fc: whether to include the final linear layer. Defaults to True.
use_combined_linear: whether to use a single linear layer for the qkv projection. Defaults to True.

Raises:
ValueError: if dropout_rate is not in [0, 1].
ValueError: if hidden_size is not divisible by num_heads.

"""

super().__init__()
@@ -78,15 +87,16 @@ def __init__(
self.norm2 = nn.LayerNorm(hidden_size)
self.with_cross_attention = with_cross_attention

self.norm_cross_attn = nn.LayerNorm(hidden_size)
self.cross_attn = CrossAttentionBlock(
hidden_size=hidden_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
qkv_bias=qkv_bias,
causal=False,
use_flash_attention=use_flash_attention,
)
if with_cross_attention:
self.norm_cross_attn = nn.LayerNorm(hidden_size)
self.cross_attn = CrossAttentionBlock(
hidden_size=hidden_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
qkv_bias=qkv_bias,
causal=False,
use_flash_attention=use_flash_attention,
)

def forward(
self, x: torch.Tensor, context: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None
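For reference, a minimal usage sketch of the behavior introduced above (not part of the diff). It assumes einops is installed and imports TransformerBlock from the module path shown in the file header; the constructor arguments mirror the tests below:

import torch
from monai.networks.blocks.transformerblock import TransformerBlock

# Disabled (the default): the cross-attention submodules are never instantiated,
# so no unused cross-attention parameters are registered on the block.
self_attn_only = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=False)
assert not hasattr(self_attn_only, "cross_attn")
assert not hasattr(self_attn_only, "norm_cross_attn")

# Enabled: the block can attend to an external context tensor passed to forward().
block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=True)
x = torch.randn(2, 16, 128)       # (batch, sequence length, hidden_size)
context = torch.randn(2, 8, 128)  # context may have a different sequence length
out = block(x, context=context)   # output keeps the shape of x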
28 changes: 28 additions & 0 deletions tests/networks/blocks/test_transformerblock.py
@@ -53,6 +53,34 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4)

@skipUnless(has_einops, "Requires einops")
def test_cross_attention_params_not_registered_when_disabled(self):
block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=False)
param_names = [name for name, _ in block.named_parameters()]
self.assertFalse(any("cross_attn" in n for n in param_names))
self.assertFalse(any("norm_cross_attn" in n for n in param_names))
self.assertFalse(hasattr(block, "cross_attn"))
self.assertFalse(hasattr(block, "norm_cross_attn"))

@skipUnless(has_einops, "Requires einops")
def test_cross_attention_params_registered_when_enabled(self):
block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=True)
self.assertTrue(hasattr(block, "cross_attn"))
self.assertTrue(hasattr(block, "norm_cross_attn"))
param_names = [name for name, _ in block.named_parameters()]
self.assertTrue(any("cross_attn" in n for n in param_names))
self.assertTrue(any("norm_cross_attn" in n for n in param_names))

@skipUnless(has_einops, "Requires einops")
def test_cross_attention_forward_with_context(self):
hidden_size = 128
block = TransformerBlock(hidden_size=hidden_size, mlp_dim=256, num_heads=4, with_cross_attention=True)
x = torch.randn(2, 16, hidden_size)
context = torch.randn(2, 8, hidden_size)
with eval_mode(block):
out = block(x, context=context)
self.assertEqual(out.shape, x.shape)

@skipUnless(has_einops, "Requires einops")
def test_access_attn_matrix(self):
# input format
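The new tests can be run on their own with, for example (assuming a pytest-based setup and einops available):

python -m pytest tests/networks/blocks/test_transformerblock.py -k cross_attention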