In [1]:
from unetr_pp.network_architecture.synapse.unetr_pp_synapse import UNETR_PP


model = UNETR_PP(in_channels=1,
                 out_channels=14,
                 img_size=[64, 128, 128],
                 feature_size=16,
                 num_heads=4,
                 depths=[3, 3, 3, 3],
                 dims=[32, 64, 128, 256],
                 do_ds=True,
                 )

In [2]:
import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table

input = torch.randn((1, 1, 64, 128, 128))
flops = FlopCountAnalysis(model, input)
print(flop_count_table(flops))

| module                                  | #parameters or shape   | #flops     |
|:----------------------------------------|:-----------------------|:-----------|
| model                                   | 42.953M                | 47.936G    |
|  unetr_pp_encoder                       |  27.387M               |  12.703G   |
|   unetr_pp_encoder.downsample_layers    |   0.346M               |   0.158G   |
|    unetr_pp_encoder.downsample_layers.0 |    1.088K              |    38.797M |
|    unetr_pp_encoder.downsample_layers.1 |    16.512K             |    68.42M  |
|    unetr_pp_encoder.downsample_layers.2 |    65.792K             |    33.882M |
|    unetr_pp_encoder.downsample_layers.3 |    0.263M              |    16.859M |
|   unetr_pp_encoder.stages               |   27.041M              |   12.545G  |
|    unetr_pp_encoder.stages.0            |    9.623M              |    6.943G  |
|    unetr_pp_encoder.stages.1            |    2.312M              |    3.258G  |
|    unetr_pp_en

In [1]:
from unetr_pp.network_architecture.synapse.unetr_pp_synapse import UNETR_PP


model = UNETR_PP(in_channels=1,
                 out_channels=14,
                 img_size=[64, 128, 128],
                 feature_size=16,
                 num_heads=4,
                 depths=[3, 3, 3, 3],
                 dims=[32, 64, 128, 256],
                 do_ds=True,
                 )

import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table

input = torch.randn((1, 1, 64, 128, 128))
flops = FlopCountAnalysis(model, input)
print(flop_count_table(flops))

| module                                  | #parameters or shape   | #flops     |
|:----------------------------------------|:-----------------------|:-----------|
| model                                   | 19.983M                | 46.223G    |
|  unetr_pp_encoder                       |  15.779M               |  11.838G   |
|   unetr_pp_encoder.downsample_layers    |   0.346M               |   0.158G   |
|    unetr_pp_encoder.downsample_layers.0 |    1.088K              |    38.797M |
|    unetr_pp_encoder.downsample_layers.1 |    16.512K             |    68.42M  |
|    unetr_pp_encoder.downsample_layers.2 |    65.792K             |    33.882M |
|    unetr_pp_encoder.downsample_layers.3 |    0.263M              |    16.859M |
|   unetr_pp_encoder.stages               |   15.432M              |   11.68G   |
|    unetr_pp_encoder.stages.0            |    0.186M              |    6.37G   |
|    unetr_pp_encoder.stages.1            |    0.729M              |    3.052G  |
|    unetr_pp_en

In [2]:
model

UNETR_PP(
  (unetr_pp_encoder): UnetrPPEncoder(
    (downsample_layers): ModuleList(
      (0): Sequential(
        (0): Convolution(
          (conv): Conv3d(1, 32, kernel_size=(2, 4, 4), stride=(2, 4, 4), bias=False)
        )
        (1): GroupNorm(1, 32, eps=1e-05, affine=True)
      )
      (1): Sequential(
        (0): Convolution(
          (conv): Conv3d(32, 64, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
        )
        (1): GroupNorm(32, 64, eps=1e-05, affine=True)
      )
      (2): Sequential(
        (0): Convolution(
          (conv): Conv3d(64, 128, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
        )
        (1): GroupNorm(64, 128, eps=1e-05, affine=True)
      )
      (3): Sequential(
        (0): Convolution(
          (conv): Conv3d(128, 256, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
        )
        (1): GroupNorm(128, 256, eps=1e-05, affine=True)
      )
    )
    (stages): ModuleList(
      (0): Sequential(
        (0): Transformer

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional

class Efficient3DHydraAttention(nn.Module):
    """
    Memory-efficient Hydra Attention for 3D medical image segmentation
    Combines multiple specialized attention heads with different computational patterns
    """
    
    def __init__(
        self,
        embed_dim: int = 768,
        num_heads: int = 12,
        patch_size: Tuple[int, int, int] = (8, 8, 8),
        window_sizes: Tuple[int, int, int] = (32, 64, 128),
        dropout: float = 0.1,
        use_flash_attention: bool = True
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.patch_size = patch_size
        self.window_sizes = window_sizes
        self.use_flash_attention = use_flash_attention
        
        # Hydra Heads - each specialized for different tasks
        self.local_head = LocalBoundaryHead(embed_dim, num_heads // 4, dropout)
        self.regional_head = RegionalContextHead(embed_dim, num_heads // 4, dropout)
        self.global_head = GlobalAnatomyHead(embed_dim, num_heads // 4, dropout)
        self.cross_slice_head = CrossSliceHead(embed_dim, num_heads // 4, dropout)
        
        # Adaptive routing - learns which tokens go to which head
        self.routing_gate = nn.Linear(embed_dim, 4)
        
        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(
        self, 
        x: torch.Tensor, 
        volume_shape: Tuple[int, int, int],
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: Input tensor [B, N, C] where N = D*H*W patches
            volume_shape: Original 3D volume dimensions (D, H, W)
            mask: Optional attention mask
        """
        B, N, C = x.shape
        
        # Adaptive routing - decide which tokens need which type of attention
        routing_scores = F.softmax(self.routing_gate(x), dim=-1)  # [B, N, 4]
        
        # Split routing scores for each head
        local_weight = routing_scores[:, :, 0:1]      # [B, N, 1]
        regional_weight = routing_scores[:, :, 1:2]   # [B, N, 1]
        global_weight = routing_scores[:, :, 2:3]     # [B, N, 1]
        cross_weight = routing_scores[:, :, 3:4]      # [B, N, 1]
        
        # Apply each hydra head with adaptive weighting
        local_out = self.local_head(x, volume_shape) * local_weight
        regional_out = self.regional_head(x, volume_shape) * regional_weight
        global_out = self.global_head(x, volume_shape) * global_weight
        cross_out = self.cross_slice_head(x, volume_shape) * cross_weight
        
        # Combine outputs
        hydra_out = local_out + regional_out + global_out + cross_out
        
        # Final projection
        out = self.out_proj(hydra_out)
        return self.dropout(out)


class LocalBoundaryHead(nn.Module):
    """Focuses on fine-grained boundaries and edges"""
    
    def __init__(self, embed_dim: int, num_heads: int, dropout: float):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // (num_heads * 4)  # Smaller head for efficiency
        self.scale = self.head_dim ** -0.5
        
        self.qkv = nn.Linear(embed_dim, self.head_dim * num_heads * 3)
        self.window_size = 8  # Small local window
        
    def forward(self, x: torch.Tensor, volume_shape: Tuple[int, int, int]) -> torch.Tensor:
        B, N, C = x.shape
        D, H, W = volume_shape
        
        # Calculate patch dimensions
        patch_d, patch_h, patch_w = D // self.patch_size[0], H // self.patch_size[1], W // self.patch_size[2]
        
        # Verify dimensions match
        expected_patches = patch_d * patch_h * patch_w
        if N != expected_patches:
            raise ValueError(f"Expected {expected_patches} patches, got {N}")
        
        # Reshape to 3D grid of patches
        x_3d = x.view(B, patch_d, patch_h, patch_w, C)
        
        # Apply windowed attention - process small 2x2x2 windows of patches
        return self._windowed_attention_3d(x_3d, 2).view(B, N, -1)  # Use 2x2x2 windows of patches
    
    def _windowed_attention_3d(self, x: torch.Tensor, window_size: int) -> torch.Tensor:
        B, patch_D, patch_H, patch_W, C = x.shape
        
        # Pad if necessary to make divisible by window_size
        pad_d = (window_size - patch_D % window_size) % window_size
        pad_h = (window_size - patch_H % window_size) % window_size  
        pad_w = (window_size - patch_W % window_size) % window_size
        
        if pad_d > 0 or pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h, 0, pad_d))
        
        patch_D, patch_H, patch_W = x.shape[1:4]
        
        # Create windows of patches
        x = x.view(B, patch_D//window_size, window_size, patch_H//window_size, window_size, 
                  patch_W//window_size, window_size, C)
        x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
        x = x.view(-1, window_size**3, C)
        
        # Apply attention within each window
        qkv = self.qkv(x).reshape(-1, window_size**3, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # Efficient attention computation
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(-1, window_size**3, self.num_heads * self.head_dim)
        
        # Reshape back
        num_windows = (patch_D//window_size) * (patch_H//window_size) * (patch_W//window_size)
        out = out.view(B, num_windows, window_size**3, -1)
        out = out.view(B, patch_D//window_size, patch_H//window_size, patch_W//window_size,
                      window_size, window_size, window_size, -1)
        out = out.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
        out = out.view(B, patch_D, patch_H, patch_W, -1)
        
        # Remove padding
        original_D = patch_D - pad_d
        original_H = patch_H - pad_h  
        original_W = patch_W - pad_w
        if pad_d > 0 or pad_h > 0 or pad_w > 0:
            out = out[:, :original_D, :original_H, :original_W, :]
            
        return out


class RegionalContextHead(nn.Module):
    """Focuses on organ-level structures with medium-range attention"""
    
    def __init__(self, embed_dim: int, num_heads: int, dropout: float):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // (num_heads * 4)
        self.scale = self.head_dim ** -0.5
        
        self.qkv = nn.Linear(embed_dim, self.head_dim * num_heads * 3)
        self.stride = 2  # Strided attention for efficiency
        
    def forward(self, x: torch.Tensor, volume_shape: Tuple[int, int, int]) -> torch.Tensor:
        B, N, C = x.shape
        
        # Downsample for efficiency (every 2nd token in each dimension)
        indices = self._get_strided_indices(volume_shape, self.stride)
        x_strided = x[:, indices, :]  # [B, N//8, C]
        
        # Apply standard attention on downsampled tokens
        qkv = self.qkv(x_strided).reshape(B, -1, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out_strided = (attn @ v).transpose(1, 2).reshape(B, -1, self.num_heads * self.head_dim)
        
        # Upsample back to original size
        out = self._upsample_to_original(out_strided, indices, N)
        return out
    
    def _get_strided_indices(self, volume_shape: Tuple[int, int, int], stride: int) -> torch.Tensor:
        D, H, W = volume_shape
        # Convert to patch grid dimensions
        patch_d, patch_h, patch_w = D // 8, H // 8, W // 8  # Assuming patch_size = (8,8,8)
        
        indices = []
        for d in range(0, patch_d, stride):
            for h in range(0, patch_h, stride):
                for w in range(0, patch_w, stride):
                    indices.append(d * patch_h * patch_w + h * patch_w + w)
        return torch.tensor(indices, dtype=torch.long)
    
    def _upsample_to_original(self, x_strided: torch.Tensor, indices: torch.Tensor, original_N: int) -> torch.Tensor:
        B, _, C = x_strided.shape
        device = x_strided.device
        
        # Create output tensor
        out = torch.zeros(B, original_N, C, device=device)
        
        # Fill strided positions
        out[:, indices, :] = x_strided
        
        # Simple interpolation for missing positions (could be more sophisticated)
        mask = torch.zeros(original_N, dtype=torch.bool, device=device)
        mask[indices] = True
        
        # Linear interpolation for empty positions
        for i in range(original_N):
            if not mask[i]:
                # Find nearest filled positions
                left_idx = indices[indices < i].max() if len(indices[indices < i]) > 0 else indices[0]
                right_idx = indices[indices > i].min() if len(indices[indices > i]) > 0 else indices[-1]
                
                # Simple average (could use distance weighting)
                out[:, i, :] = (out[:, left_idx, :] + out[:, right_idx, :]) / 2
        
        return out


class GlobalAnatomyHead(nn.Module):
    """Captures global anatomical relationships with linear attention"""
    
    def __init__(self, embed_dim: int, num_heads: int, dropout: float):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // (num_heads * 4)
        
        # Linear attention for O(N) complexity
        self.to_q = nn.Linear(embed_dim, self.head_dim * num_heads)
        self.to_k = nn.Linear(embed_dim, self.head_dim * num_heads)
        self.to_v = nn.Linear(embed_dim, self.head_dim * num_heads)
        
    def forward(self, x: torch.Tensor, volume_shape: Tuple[int, int, int]) -> torch.Tensor:
        B, N, C = x.shape
        
        q = self.to_q(x).view(B, N, self.num_heads, self.head_dim)
        k = self.to_k(x).view(B, N, self.num_heads, self.head_dim)
        v = self.to_v(x).view(B, N, self.num_heads, self.head_dim)
        
        # Linear attention: O(N) instead of O(N²)
        # Compute k^T @ v first, then q @ (k^T @ v)
        kv = k.transpose(-2, -1) @ v  # [B, num_heads, head_dim, head_dim]
        out = q @ kv  # [B, N, num_heads, head_dim]
        
        # Normalize
        k_sum = k.sum(dim=1, keepdim=True)  # [B, 1, num_heads, head_dim]
        normalizer = q @ k_sum.transpose(-2, -1)  # [B, N, num_heads, 1]
        out = out / (normalizer + 1e-6)
        
        return out.reshape(B, N, -1)


class CrossSliceHead(nn.Module):
    """Ensures consistency across different slice orientations"""
    
    def __init__(self, embed_dim: int, num_heads: int, dropout: float):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // (num_heads * 4)
        self.scale = self.head_dim ** -0.5
        
        self.qkv = nn.Linear(embed_dim, self.head_dim * num_heads * 3)
        
    def forward(self, x: torch.Tensor, volume_shape: Tuple[int, int, int]) -> torch.Tensor:
        B, N, C = x.shape
        D, H, W = volume_shape
        
        # Convert to patch grid dimensions
        patch_d, patch_h, patch_w = D // 8, H // 8, W // 8  # Assuming patch_size = (8,8,8)
        
        # Reshape to 3D patch grid
        x_3d = x.view(B, patch_d, patch_h, patch_w, C)
        
        # Sample representative slices from each orientation
        # Axial (xy planes) - sample every few slices in depth
        axial_step = max(1, patch_d // 4)
        axial_slices = x_3d[:, ::axial_step, :, :, :].reshape(B, -1, C)
        
        # Sagittal (yz planes) - sample every few slices in height
        sagittal_step = max(1, patch_h // 4)
        sagittal_slices = x_3d[:, :, ::sagittal_step, :, :].reshape(B, -1, C)
        
        # Coronal (xz planes) - sample every few slices in width
        coronal_step = max(1, patch_w // 4)
        coronal_slices = x_3d[:, :, :, ::coronal_step, :].reshape(B, -1, C)
        
        # Combine all slice samples
        slice_tokens = torch.cat([axial_slices, sagittal_slices, coronal_slices], dim=1)
        
        # Apply attention across slice samples
        qkv = self.qkv(slice_tokens).reshape(B, -1, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        slice_out = (attn @ v).transpose(1, 2).reshape(B, -1, self.num_heads * self.head_dim)
        
        # Broadcast back to all positions (simplified - could be more sophisticated)
        global_context = slice_out.mean(dim=1, keepdim=True)  # [B, 1, C]
        out = global_context.expand(B, N, -1)
        
        return out


# Example usage and efficiency comparison
def compare_attention_efficiency():
    """Compare FLOP counts for different attention mechanisms"""
    
    # Typical 3D medical image parameters
    batch_size = 2
    volume_shape = (64, 64, 64)  # D, H, W
    patch_size = (8, 8, 8)
    embed_dim = 768
    
    # Calculate number of patches
    num_patches = (volume_shape[0] // patch_size[0]) * \
                  (volume_shape[1] // patch_size[1]) * \
                  (volume_shape[2] // patch_size[2])
    
    print(f"Volume shape: {volume_shape}")
    print(f"Number of patches: {num_patches}")
    print(f"Sequence length: {num_patches}")
    
    # FLOP comparison
    print("\n=== FLOP Comparison ===")
    
    # Standard attention: O(N²)
    standard_flops = 2 * num_patches * num_patches * embed_dim
    print(f"Standard Attention FLOPs: {standard_flops:,} ({standard_flops/1e9:.2f}B)")
    
    # Hydra attention breakdown:
    # Local head: windowed attention
    window_size = 8
    num_windows = num_patches // (window_size**3)
    local_flops = 2 * num_windows * (window_size**3) * (window_size**3) * (embed_dim//4)
    print(f"Local Head FLOPs: {local_flops:,} ({local_flops/1e6:.1f}M)")
    
    # Regional head: strided attention  
    strided_patches = num_patches // 8  # stride=2 in each dim
    regional_flops = 2 * strided_patches * strided_patches * (embed_dim//4)
    print(f"Regional Head FLOPs: {regional_flops:,} ({regional_flops/1e6:.1f}M)")
    
    # Global head: linear attention
    global_flops = 2 * num_patches * embed_dim * (embed_dim//4)  # O(N*d²)
    print(f"Global Head FLOPs: {global_flops:,} ({global_flops/1e6:.1f}M)")
    
    # Cross-slice head: sparse sampling
    slice_samples = 64  # Representative slices
    cross_flops = 2 * slice_samples * slice_samples * (embed_dim//4)
    print(f"Cross-slice Head FLOPs: {cross_flops:,} ({cross_flops/1e3:.1f}K)")
    
    total_hydra_flops = local_flops + regional_flops + global_flops + cross_flops
    print(f"\nTotal Hydra FLOPs: {total_hydra_flops:,} ({total_hydra_flops/1e6:.1f}M)")
    print(f"Speedup: {standard_flops/total_hydra_flops:.1f}x")
    print(f"Memory reduction: ~{standard_flops/total_hydra_flops:.1f}x")


if __name__ == "__main__":
    # Test the implementation
    compare_attention_efficiency()
    
    # Create model
    model = Efficient3DHydraAttention(
        embed_dim=768,
        num_heads=12,
        patch_size=(8, 8, 8),
        dropout=0.1
    )
    
    # Test forward pass
    batch_size = 2
    volume_shape = (64, 64, 64)
    patch_size = (8, 8, 8)
    
    # Calculate correct number of patches
    num_patches = (volume_shape[0] // patch_size[0]) * \
                  (volume_shape[1] // patch_size[1]) * \
                  (volume_shape[2] // patch_size[2])
    print(f"Calculated num_patches: {num_patches}")
    
    x = torch.randn(batch_size, num_patches, 768)
    
    with torch.no_grad():
        output = model(x, volume_shape)
        print(f"\nInput shape: {x.shape}")
        print(f"Output shape: {output.shape}")
        print("✓ Forward pass successful!")

Volume shape: (64, 64, 64)
Number of patches: 512
Sequence length: 512

=== FLOP Comparison ===
Standard Attention FLOPs: 402,653,184 (0.40B)
Local Head FLOPs: 100,663,296 (100.7M)
Regional Head FLOPs: 1,572,864 (1.6M)
Global Head FLOPs: 150,994,944 (151.0M)
Cross-slice Head FLOPs: 1,572,864 (1572.9K)

Total Hydra FLOPs: 254,803,968 (254.8M)
Speedup: 1.6x
Memory reduction: ~1.6x
Calculated num_patches: 512


AttributeError: 'LocalBoundaryHead' object has no attribute 'patch_size'