diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index b93d81bdef..f597221c18 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -78,15 +78,16 @@ def __init__( self.norm2 = nn.LayerNorm(hidden_size) self.with_cross_attention = with_cross_attention - self.norm_cross_attn = nn.LayerNorm(hidden_size) - self.cross_attn = CrossAttentionBlock( - hidden_size=hidden_size, - num_heads=num_heads, - dropout_rate=dropout_rate, - qkv_bias=qkv_bias, - causal=False, - use_flash_attention=use_flash_attention, - ) + if with_cross_attention: + self.norm_cross_attn = nn.LayerNorm(hidden_size) + self.cross_attn = CrossAttentionBlock( + hidden_size=hidden_size, + num_heads=num_heads, + dropout_rate=dropout_rate, + qkv_bias=qkv_bias, + causal=False, + use_flash_attention=use_flash_attention, + ) def forward( self, x: torch.Tensor, context: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None diff --git a/tests/networks/blocks/test_transformerblock.py b/tests/networks/blocks/test_transformerblock.py index b977a38e73..ea5706f1a4 100644 --- a/tests/networks/blocks/test_transformerblock.py +++ b/tests/networks/blocks/test_transformerblock.py @@ -53,6 +53,26 @@ def test_ill_arg(self): with self.assertRaises(ValueError): TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4) + @skipUnless(has_einops, "Requires einops") + def test_no_cross_attention_params_when_disabled(self): + """When with_cross_attention=False, no cross-attention parameters should be registered.""" + block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=False) + param_names = {n for n, _ in block.named_parameters()} + self.assertFalse( + any("cross_attn" in n or "norm_cross_attn" in n for n in param_names), + f"Unexpected cross-attention parameters found: {[n for n in param_names if 'cross' in n]}", + ) + + @skipUnless(has_einops, "Requires einops") + def test_cross_attention_params_when_enabled(self): + """When with_cross_attention=True, cross-attention parameters should be registered.""" + block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=True) + param_names = {n for n, _ in block.named_parameters()} + self.assertTrue( + any("cross_attn" in n for n in param_names), + "Expected cross-attention parameters not found when with_cross_attention=True", + ) + @skipUnless(has_einops, "Requires einops") def test_access_attn_matrix(self): # input format