In [1]:
from unetr_pp.network_architecture.synapse.unetr_pp_synapse import UNETR_PP


model = UNETR_PP(in_channels=1,
                 out_channels=14,
                 img_size=[64, 128, 128],
                 feature_size=16,
                 num_heads=4,
                 depths=[3, 3, 3, 3],
                 dims=[32, 64, 128, 256],
                 do_ds=True,
                 )

In [2]:
import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table

input = torch.randn((1, 1, 64, 128, 128))
flops = FlopCountAnalysis(model, input)
print(flop_count_table(flops))

| module                                  | #parameters or shape   | #flops     |
|:----------------------------------------|:-----------------------|:-----------|
| model                                   | 42.953M                | 47.936G    |
|  unetr_pp_encoder                       |  27.387M               |  12.703G   |
|   unetr_pp_encoder.downsample_layers    |   0.346M               |   0.158G   |
|    unetr_pp_encoder.downsample_layers.0 |    1.088K              |    38.797M |
|    unetr_pp_encoder.downsample_layers.1 |    16.512K             |    68.42M  |
|    unetr_pp_encoder.downsample_layers.2 |    65.792K             |    33.882M |
|    unetr_pp_encoder.downsample_layers.3 |    0.263M              |    16.859M |
|   unetr_pp_encoder.stages               |   27.041M              |   12.545G  |
|    unetr_pp_encoder.stages.0            |    9.623M              |    6.943G  |
|    unetr_pp_encoder.stages.1            |    2.312M              |    3.258G  |
|    unetr_pp_en

In [1]:
from unetr_pp.network_architecture.synapse.unetr_pp_synapse import UNETR_PP


model = UNETR_PP(in_channels=1,
                 out_channels=14,
                 img_size=[64, 128, 128],
                 feature_size=16,
                 num_heads=4,
                 depths=[3, 3, 3, 3],
                 dims=[32, 64, 128, 256],
                 do_ds=True,
                 )

import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table

input = torch.randn((1, 1, 64, 128, 128))
flops = FlopCountAnalysis(model, input)
print(flop_count_table(flops))

| module                                  | #parameters or shape   | #flops     |
|:----------------------------------------|:-----------------------|:-----------|
| model                                   | 19.983M                | 46.223G    |
|  unetr_pp_encoder                       |  15.779M               |  11.838G   |
|   unetr_pp_encoder.downsample_layers    |   0.346M               |   0.158G   |
|    unetr_pp_encoder.downsample_layers.0 |    1.088K              |    38.797M |
|    unetr_pp_encoder.downsample_layers.1 |    16.512K             |    68.42M  |
|    unetr_pp_encoder.downsample_layers.2 |    65.792K             |    33.882M |
|    unetr_pp_encoder.downsample_layers.3 |    0.263M              |    16.859M |
|   unetr_pp_encoder.stages               |   15.432M              |   11.68G   |
|    unetr_pp_encoder.stages.0            |    0.186M              |    6.37G   |
|    unetr_pp_encoder.stages.1            |    0.729M              |    3.052G  |
|    unetr_pp_en

In [2]:
model

UNETR_PP(
  (unetr_pp_encoder): UnetrPPEncoder(
    (downsample_layers): ModuleList(
      (0): Sequential(
        (0): Convolution(
          (conv): Conv3d(1, 32, kernel_size=(2, 4, 4), stride=(2, 4, 4), bias=False)
        )
        (1): GroupNorm(1, 32, eps=1e-05, affine=True)
      )
      (1): Sequential(
        (0): Convolution(
          (conv): Conv3d(32, 64, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
        )
        (1): GroupNorm(32, 64, eps=1e-05, affine=True)
      )
      (2): Sequential(
        (0): Convolution(
          (conv): Conv3d(64, 128, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
        )
        (1): GroupNorm(64, 128, eps=1e-05, affine=True)
      )
      (3): Sequential(
        (0): Convolution(
          (conv): Conv3d(128, 256, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
        )
        (1): GroupNorm(128, 256, eps=1e-05, affine=True)
      )
    )
    (stages): ModuleList(
      (0): Sequential(
        (0): Transformer

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional

class Efficient3DHydraAttention(nn.Module):
    """
    Simplified and robust Hydra Attention for 3D medical image segmentation
    """
    
    def __init__(
        self,
        embed_dim: int = 768,
        num_heads: int = 12,
        patch_size: Tuple[int, int, int] = (8, 8, 8),
        dropout: float = 0.1
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.patch_size = patch_size
        
        # Each hydra head gets equal number of attention heads
        heads_per_hydra = num_heads // 4
        
        # Hydra Heads - each specialized for different tasks
        self.local_head = LocalBoundaryHead(embed_dim, heads_per_hydra, dropout)
        self.regional_head = RegionalContextHead(embed_dim, heads_per_hydra, dropout)
        self.global_head = GlobalAnatomyHead(embed_dim, heads_per_hydra, dropout)
        self.cross_slice_head = CrossSliceHead(embed_dim, heads_per_hydra, dropout)
        
        # Adaptive routing - learns which tokens go to which head
        self.routing_gate = nn.Linear(embed_dim, 4)
        
        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(
        self, 
        x: torch.Tensor, 
        volume_shape: Tuple[int, int, int],
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: Input tensor [B, N, C] where N = D*H*W patches
            volume_shape: Original 3D volume dimensions (D, H, W)
            mask: Optional attention mask
        """
        B, N, C = x.shape
        
        # Adaptive routing - decide which tokens need which type of attention
        routing_scores = F.softmax(self.routing_gate(x), dim=-1)  # [B, N, 4]
        
        # Apply each hydra head with adaptive weighting
        local_out = self.local_head(x, volume_shape)
        regional_out = self.regional_head(x, volume_shape)
        global_out = self.global_head(x, volume_shape)
        cross_out = self.cross_slice_head(x, volume_shape)
        
        # Weight and combine outputs
        local_weight = routing_scores[:, :, 0:1]
        regional_weight = routing_scores[:, :, 1:2]
        global_weight = routing_scores[:, :, 2:3]
        cross_weight = routing_scores[:, :, 3:4]
        
        hydra_out = (local_out * local_weight + 
                    regional_out * regional_weight + 
                    global_out * global_weight + 
                    cross_out * cross_weight)
        
        # Final projection
        out = self.out_proj(hydra_out)
        return self.dropout(out)


class LocalBoundaryHead(nn.Module):
    """Focuses on fine-grained boundaries using windowed attention"""
    
    def __init__(self, embed_dim: int, num_heads: int, dropout: float):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor, volume_shape: Tuple[int, int, int]) -> torch.Tensor:
        B, N, C = x.shape
        
        # For simplicity, use local windowed attention with fixed window size
        window_size = min(64, N // 8)  # Adaptive window size
        
        if window_size >= N:
            # If sequence is short, use full attention
            return self._full_attention(x)
        else:
            # Use sliding window attention
            return self._windowed_attention(x, window_size)
    
    def _full_attention(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        
        return self.proj(self.dropout(out))
    
    def _windowed_attention(self, x: torch.Tensor, window_size: int) -> torch.Tensor:
        B, N, C = x.shape
        
        # Pad sequence to be divisible by window_size
        pad_len = (window_size - N % window_size) % window_size
        if pad_len > 0:
            x_padded = F.pad(x, (0, 0, 0, pad_len))
        else:
            x_padded = x
        
        N_padded = x_padded.shape[1]
        num_windows = N_padded // window_size
        
        # Reshape to windows
        x_windows = x_padded.view(B, num_windows, window_size, C)
        x_windows = x_windows.view(-1, window_size, C)  # [B*num_windows, window_size, C]
        
        # Apply attention within each window
        qkv = self.qkv(x_windows).reshape(-1, window_size, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(-1, window_size, C)
        
        # Reshape back
        out = out.view(B, num_windows, window_size, C)
        out = out.view(B, N_padded, C)
        
        # Remove padding
        if pad_len > 0:
            out = out[:, :N, :]
        
        return self.proj(self.dropout(out))


class RegionalContextHead(nn.Module):
    """Focuses on medium-range context using strided attention"""
    
    def __init__(self, embed_dim: int, num_heads: int, dropout: float):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor, volume_shape: Tuple[int, int, int]) -> torch.Tensor:
        B, N, C = x.shape
        
        # Strided sampling - take every 4th token for efficiency
        stride = 4
        indices = torch.arange(0, N, stride, device=x.device)
        x_strided = x[:, indices, :]  # [B, N//4, C]
        
        # Apply attention on strided tokens
        qkv = self.qkv(x_strided).reshape(B, -1, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out_strided = (attn @ v).transpose(1, 2).reshape(B, -1, C)
        
        # Interpolate back to full resolution
        out = torch.zeros_like(x)
        out[:, indices, :] = out_strided
        
        # Simple interpolation for missing positions
        for i in range(N):
            if i not in indices:
                # Find nearest strided positions
                left_idx = indices[indices <= i]
                right_idx = indices[indices > i]
                
                if len(left_idx) > 0 and len(right_idx) > 0:
                    left = left_idx[-1]
                    right = right_idx[0]
                    # Linear interpolation
                    alpha = (i - left) / (right - left)
                    out[:, i, :] = (1 - alpha) * out[:, left, :] + alpha * out[:, right, :]
                elif len(left_idx) > 0:
                    out[:, i, :] = out[:, left_idx[-1], :]
                elif len(right_idx) > 0:
                    out[:, i, :] = out[:, right_idx[0], :]
        
        return self.proj(self.dropout(out))


class GlobalAnatomyHead(nn.Module):
    """Captures global relationships using linear attention"""
    
    def __init__(self, embed_dim: int, num_heads: int, dropout: float):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        self.to_q = nn.Linear(embed_dim, embed_dim)
        self.to_k = nn.Linear(embed_dim, embed_dim)
        self.to_v = nn.Linear(embed_dim, embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor, volume_shape: Tuple[int, int, int]) -> torch.Tensor:
        B, N, C = x.shape
        
        q = self.to_q(x)  # [B, N, C]
        k = self.to_k(x)  # [B, N, C] 
        v = self.to_v(x)  # [B, N, C]
        
        # Linear attention: compute global context then apply to each query
        # Global context = sum(k_i * v_i) / sum(k_i)
        k_sum = k.sum(dim=1, keepdim=True)  # [B, 1, C]
        kv_sum = (k.unsqueeze(-1) * v.unsqueeze(-2)).sum(dim=1)  # [B, C, C]
        
        # Normalize to avoid division by zero
        k_norm = k_sum / (k_sum.norm(dim=-1, keepdim=True) + 1e-6)
        
        # Apply to each query position
        out = torch.zeros_like(q)
        for i in range(N):
            q_i = q[:, i:i+1, :]  # [B, 1, C]
            out[:, i:i+1, :] = q_i @ kv_sum / (q_i @ k_norm.transpose(-2, -1) + 1e-6)
        
        return self.proj(self.dropout(out))


class CrossSliceHead(nn.Module):
    """Ensures consistency across slice orientations"""
    
    def __init__(self, embed_dim: int, num_heads: int, dropout: float):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor, volume_shape: Tuple[int, int, int]) -> torch.Tensor:
        B, N, C = x.shape
        
        # Sample a subset of tokens for cross-slice attention
        sample_rate = min(1.0, 256 / N)  # Sample up to 256 tokens
        num_samples = max(16, int(N * sample_rate))
        
        # Random sampling for diversity
        indices = torch.randperm(N, device=x.device)[:num_samples]
        indices = indices.sort()[0]  # Sort to maintain some spatial locality
        
        x_sampled = x[:, indices, :]  # [B, num_samples, C]
        
        # Apply attention on sampled tokens
        qkv = self.qkv(x_sampled).reshape(B, num_samples, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out_sampled = (attn @ v).transpose(1, 2).reshape(B, num_samples, C)
        
        # Create global context and broadcast to all positions
        global_context = out_sampled.mean(dim=1, keepdim=True)  # [B, 1, C]
        out = global_context.expand(B, N, C)
        
        return self.proj(self.dropout(out))


def compare_attention_efficiency():
    """Compare FLOP counts for different attention mechanisms"""
    
    # Typical 3D medical image parameters
    batch_size = 2
    volume_shape = (64, 64, 64)  # D, H, W
    patch_size = (8, 8, 8)
    embed_dim = 768
    
    # Calculate number of patches
    num_patches = (volume_shape[0] // patch_size[0]) * \
                  (volume_shape[1] // patch_size[1]) * \
                  (volume_shape[2] // patch_size[2])
    
    print(f"Volume shape: {volume_shape}")
    print(f"Patch size: {patch_size}")
    print(f"Number of patches: {num_patches}")
    
    # FLOP comparison
    print("\n=== FLOP Comparison ===")
    
    # Standard attention: O(N²)
    standard_flops = 2 * num_patches * num_patches * embed_dim
    print(f"Standard Attention FLOPs: {standard_flops:,} ({standard_flops/1e9:.2f}B)")
    
    # Hydra attention breakdown:
    # Local head: windowed attention (window size = 64)
    window_size = min(64, num_patches // 8)
    num_windows = num_patches // window_size
    local_flops = 2 * num_windows * (window_size * window_size) * embed_dim
    print(f"Local Head FLOPs: {local_flops:,} ({local_flops/1e6:.1f}M)")
    
    # Regional head: strided attention (stride=4)
    strided_patches = num_patches // 4
    regional_flops = 2 * strided_patches * strided_patches * embed_dim
    print(f"Regional Head FLOPs: {regional_flops:,} ({regional_flops/1e6:.1f}M)")
    
    # Global head: linear attention O(N)
    global_flops = 2 * num_patches * embed_dim  # Simplified linear attention
    print(f"Global Head FLOPs: {global_flops:,} ({global_flops/1e6:.1f}M)")
    
    # Cross-slice head: sparse sampling
    slice_samples = min(256, num_patches)
    cross_flops = 2 * slice_samples * slice_samples * embed_dim
    print(f"Cross-slice Head FLOPs: {cross_flops:,} ({cross_flops/1e6:.1f}M)")
    
    total_hydra_flops = local_flops + regional_flops + global_flops + cross_flops
    print(f"\nTotal Hydra FLOPs: {total_hydra_flops:,} ({total_hydra_flops/1e6:.1f}M)")
    print(f"Speedup: {standard_flops/total_hydra_flops:.1f}x")
    print(f"Memory reduction: ~{standard_flops/total_hydra_flops:.1f}x")


if __name__ == "__main__":
    # Test the implementation
    compare_attention_efficiency()
    
    # Create model
    model = Efficient3DHydraAttention(
        embed_dim=768,
        num_heads=12,
        patch_size=(8, 8, 8),
        dropout=0.1
    )
    
    # Test forward pass
    batch_size = 2
    volume_shape = (64, 64, 64)
    patch_size = (8, 8, 8)
    
    # Calculate correct number of patches
    num_patches = (volume_shape[0] // patch_size[0]) * \
                  (volume_shape[1] // patch_size[1]) * \
                  (volume_shape[2] // patch_size[2])
    print(f"\nTesting with:")
    print(f"Volume shape: {volume_shape}")
    print(f"Patch size: {patch_size}")
    print(f"Number of patches: {num_patches}")
    
    x = torch.randn(batch_size, num_patches, 768)
    
    with torch.no_grad():
        output = model(x, volume_shape)
        print(f"\nInput shape: {x.shape}")
        print(f"Output shape: {output.shape}")
        print("✓ Forward pass successful!")
        
        # Test that output has correct properties
        print(f"Output mean: {output.mean().item():.4f}")
        print(f"Output std: {output.std().item():.4f}")
        print(f"Output min: {output.min().item():.4f}")
        print(f"Output max: {output.max().item():.4f}")

Volume shape: (64, 64, 64)
Patch size: (8, 8, 8)
Number of patches: 512

=== FLOP Comparison ===
Standard Attention FLOPs: 402,653,184 (0.40B)
Local Head FLOPs: 50,331,648 (50.3M)
Regional Head FLOPs: 25,165,824 (25.2M)
Global Head FLOPs: 786,432 (0.8M)
Cross-slice Head FLOPs: 100,663,296 (100.7M)

Total Hydra FLOPs: 176,947,200 (176.9M)
Speedup: 2.3x
Memory reduction: ~2.3x

Testing with:
Volume shape: (64, 64, 64)
Patch size: (8, 8, 8)
Number of patches: 512

Input shape: torch.Size([2, 512, 768])
Output shape: torch.Size([2, 512, 768])
✓ Forward pass successful!
Output mean: 1.8712
Output std: 991.7621
Output min: -91737.8672
Output max: 91863.3047


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import math
from typing import Tuple, Optional
import numpy as np

# EPA Implementation (from UNETR++)
class EPA(nn.Module):
    """
    Efficient Paired Attention Block, based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
    """
    def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False,
                 channel_attn_drop=0.1, spatial_attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))
        # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel)
        self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias)
        # E and F are projection matrices with shared weights used in spatial attention module to project
        # keys and values from HWD-dimension to P-dimension
        self.E = self.F = nn.Linear(input_size, proj_size)
        self.attn_drop = nn.Dropout(channel_attn_drop)
        self.attn_drop_2 = nn.Dropout(spatial_attn_drop)
        self.out_proj = nn.Linear(hidden_size, int(hidden_size // 2))
        self.out_proj2 = nn.Linear(hidden_size, int(hidden_size // 2))
        
    def forward(self, x):
        B, N, C = x.shape
        qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads)
        qkvv = qkvv.permute(2, 0, 3, 1, 4)
        q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3]
        q_shared = q_shared.transpose(-2, -1)
        k_shared = k_shared.transpose(-2, -1)
        v_CA = v_CA.transpose(-2, -1)
        v_SA = v_SA.transpose(-2, -1)
        k_shared_projected = self.E(k_shared)
        v_SA_projected = self.F(v_SA)
        q_shared = torch.nn.functional.normalize(q_shared, dim=-1)
        k_shared = torch.nn.functional.normalize(k_shared, dim=-1)
        attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature
        attn_CA = attn_CA.softmax(dim=-1)
        attn_CA = self.attn_drop(attn_CA)
        x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C)
        attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2
        attn_SA = attn_SA.softmax(dim=-1)
        attn_SA = self.attn_drop_2(attn_SA)
        x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C)
        # Concat fusion
        x_SA = self.out_proj(x_SA)
        x_CA = self.out_proj2(x_CA)
        x = torch.cat((x_SA, x_CA), dim=-1)
        return x
    
    @torch.jit.ignore
    def no_weight_decay(self):
        return {'temperature', 'temperature2'}


# Simplified Hydra Attention (from previous implementation)
class Efficient3DHydraAttention(nn.Module):
    """Hydra Attention for 3D medical image segmentation"""
    
    def __init__(self, embed_dim: int = 768, num_heads: int = 12, dropout: float = 0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        # Each hydra head gets equal number of attention heads
        heads_per_hydra = num_heads // 4
        
        # Hydra Heads - each specialized for different tasks
        self.local_head = LocalBoundaryHead(embed_dim, heads_per_hydra, dropout)
        self.regional_head = RegionalContextHead(embed_dim, heads_per_hydra, dropout)
        self.global_head = GlobalAnatomyHead(embed_dim, heads_per_hydra, dropout)
        self.cross_slice_head = CrossSliceHead(embed_dim, heads_per_hydra, dropout)
        
        # Adaptive routing - learns which tokens go to which head
        self.routing_gate = nn.Linear(embed_dim, 4)
        
        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor, volume_shape: Tuple[int, int, int] = (64, 64, 64)) -> torch.Tensor:
        B, N, C = x.shape
        
        # Adaptive routing
        routing_scores = F.softmax(self.routing_gate(x), dim=-1)
        
        # Apply each hydra head with adaptive weighting
        local_out = self.local_head(x, volume_shape)
        regional_out = self.regional_head(x, volume_shape)
        global_out = self.global_head(x, volume_shape)
        cross_out = self.cross_slice_head(x, volume_shape)
        
        # Weight and combine outputs
        local_weight = routing_scores[:, :, 0:1]
        regional_weight = routing_scores[:, :, 1:2]
        global_weight = routing_scores[:, :, 2:3]
        cross_weight = routing_scores[:, :, 3:4]
        
        hydra_out = (local_out * local_weight + 
                    regional_out * regional_weight + 
                    global_out * global_weight + 
                    cross_out * cross_weight)
        
        # Final projection
        out = self.out_proj(hydra_out)
        return self.dropout(out)


# Hydra Head implementations (simplified for benchmarking)
class LocalBoundaryHead(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float):
        super().__init__()
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        
    def forward(self, x: torch.Tensor, volume_shape: Tuple[int, int, int]) -> torch.Tensor:
        B, N, C = x.shape
        window_size = min(64, N // 8)
        
        if window_size >= N:
            return self._full_attention(x)
        else:
            return self._windowed_attention(x, window_size)
    
    def _full_attention(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        
        return self.proj(out)
    
    def _windowed_attention(self, x: torch.Tensor, window_size: int) -> torch.Tensor:
        B, N, C = x.shape
        
        # Simplified windowed attention
        pad_len = (window_size - N % window_size) % window_size
        if pad_len > 0:
            x_padded = F.pad(x, (0, 0, 0, pad_len))
        else:
            x_padded = x
        
        N_padded = x_padded.shape[1]
        num_windows = N_padded // window_size
        
        x_windows = x_padded.view(B, num_windows, window_size, C)
        x_windows = x_windows.view(-1, window_size, C)
        
        qkv = self.qkv(x_windows).reshape(-1, window_size, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(-1, window_size, C)
        
        out = out.view(B, N_padded, C)
        if pad_len > 0:
            out = out[:, :N, :]
        
        return self.proj(out)


class RegionalContextHead(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float):
        super().__init__()
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        
    def forward(self, x: torch.Tensor, volume_shape: Tuple[int, int, int]) -> torch.Tensor:
        B, N, C = x.shape
        
        # Strided sampling
        stride = 4
        indices = torch.arange(0, N, stride, device=x.device)
        x_strided = x[:, indices, :]
        
        qkv = self.qkv(x_strided).reshape(B, -1, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out_strided = (attn @ v).transpose(1, 2).reshape(B, -1, C)
        
        # Simple interpolation back
        out = torch.zeros_like(x)
        out[:, indices, :] = out_strided
        
        # Fill missing positions with nearest neighbor
        for i in range(N):
            if i not in indices:
                nearest = indices[torch.argmin(torch.abs(indices - i))]
                out[:, i, :] = out[:, nearest, :]
        
        return self.proj(out)


class GlobalAnatomyHead(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float):
        super().__init__()
        self.to_q = nn.Linear(embed_dim, embed_dim)
        self.to_k = nn.Linear(embed_dim, embed_dim)
        self.to_v = nn.Linear(embed_dim, embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, x: torch.Tensor, volume_shape: Tuple[int, int, int]) -> torch.Tensor:
        B, N, C = x.shape
        
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)
        
        # Linear attention approximation
        k_sum = k.sum(dim=1, keepdim=True)
        kv = (k.transpose(-2, -1) @ v) / N
        out = q @ kv
        
        return self.proj(out)


class CrossSliceHead(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float):
        super().__init__()
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        
    def forward(self, x: torch.Tensor, volume_shape: Tuple[int, int, int]) -> torch.Tensor:
        B, N, C = x.shape
        
        # Sample subset for efficiency
        sample_rate = min(1.0, 256 / N)
        num_samples = max(16, int(N * sample_rate))
        
        indices = torch.randperm(N, device=x.device)[:num_samples]
        x_sampled = x[:, indices, :]
        
        qkv = self.qkv(x_sampled).reshape(B, num_samples, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out_sampled = (attn @ v).transpose(1, 2).reshape(B, num_samples, C)
        
        # Broadcast global context
        global_context = out_sampled.mean(dim=1, keepdim=True)
        out = global_context.expand(B, N, C)
        
        return self.proj(out)


# Standard Multi-Head Attention for comparison
class StandardAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        
    def forward(self, x: torch.Tensor, volume_shape: Tuple[int, int, int] = (64, 64, 64)) -> torch.Tensor:
        out, _ = self.attention(x, x, x)
        return out


def count_flops(model, x, volume_shape=(64, 64, 64)):
    """Count FLOPs for attention mechanisms"""
    B, N, C = x.shape
    
    if isinstance(model, EPA):
        # EPA FLOP calculation
        # qkvv projection: 4 * N * C * C
        qkvv_flops = 4 * N * C * C
        
        # E and F projections: 2 * num_heads * C/num_heads * N * proj_size
        proj_size = getattr(model, 'E').out_features
        input_size = getattr(model, 'E').in_features
        ef_flops = 2 * model.num_heads * (C // model.num_heads) * input_size * proj_size
        
        # Channel attention: Q @ K^T = (C/num_heads)^2 * num_heads
        ca_attn_flops = model.num_heads * (C // model.num_heads) * (C // model.num_heads) * B
        # Channel attention: Attn @ V = (C/num_heads)^2 * num_heads  
        ca_out_flops = model.num_heads * (C // model.num_heads) * (C // model.num_heads) * B
        
        # Spatial attention: Q @ K_proj = N * proj_size * num_heads
        sa_attn_flops = model.num_heads * N * proj_size * B
        # Spatial attention: Attn @ V_proj = N * proj_size * num_heads
        sa_out_flops = model.num_heads * N * proj_size * B
        
        # Output projections: 2 * N * C * C/2
        out_proj_flops = 2 * N * C * (C // 2)
        
        total_flops = (qkvv_flops + ef_flops + ca_attn_flops + ca_out_flops + 
                      sa_attn_flops + sa_out_flops + out_proj_flops) * B
        
    elif isinstance(model, StandardAttention):
        # Standard attention: O(N^2 * C)
        # Q@K^T: N^2 * C, Attn@V: N^2 * C, plus linear projections: 3 * N * C^2
        total_flops = B * (2 * N * N * C + 4 * N * C * C)  # QKV + attention + output proj
        
    elif isinstance(model, Efficient3DHydraAttention):
        # Hydra attention approximation
        # Local head: windowed attention
        window_size = min(64, N // 8)
        num_windows = N // window_size if window_size < N else 1
        local_flops = num_windows * (2 * window_size * window_size * C + 3 * window_size * C * C)
        
        # Regional head: strided attention  
        strided_N = N // 4
        regional_flops = 2 * strided_N * strided_N * C + 3 * strided_N * C * C
        
        # Global head: linear attention O(N * C^2)
        global_flops = 3 * N * C * C
        
        # Cross-slice head: sparse sampling
        sample_N = min(256, N)
        cross_flops = 2 * sample_N * sample_N * C + 3 * sample_N * C * C
        
        # Routing and projections
        routing_flops = N * C * 4 + N * C * C
        
        total_flops = B * (local_flops + regional_flops + global_flops + cross_flops + routing_flops)
        
    return total_flops


def benchmark_models():
    """Comprehensive benchmark of different attention mechanisms"""
    
    print("="*80)
    print("COMPREHENSIVE ATTENTION BENCHMARK: HYDRA vs EPA vs STANDARD")
    print("="*80)
    
    # Test configurations for 3D medical imaging
    configs = [
        {"name": "Small Volume (32³)", "batch_size": 4, "volume_shape": (32, 32, 32), "patch_size": (4, 4, 4)},
        {"name": "Medium Volume (64³)", "batch_size": 2, "volume_shape": (64, 64, 64), "patch_size": (8, 8, 8)},
        {"name": "Large Volume (128³)", "batch_size": 1, "volume_shape": (128, 128, 128), "patch_size": (16, 16, 16)},
    ]
    
    embed_dim = 768
    num_heads = 12
    dropout = 0.1
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    print()
    
    results = {}
    
    for config in configs:
        print(f"\n{'='*50}")
        print(f"TESTING: {config['name']}")
        print(f"{'='*50}")
        
        batch_size = config["batch_size"]
        volume_shape = config["volume_shape"]
        patch_size = config["patch_size"]
        
        # Calculate number of patches
        num_patches = np.prod([v // p for v, p in zip(volume_shape, patch_size)])
        input_size = num_patches  # For EPA
        proj_size = min(256, num_patches // 4)  # Projection size for EPA
        
        print(f"Volume: {volume_shape}, Patches: {num_patches}, Batch: {batch_size}")
        print(f"Input tensor: [{batch_size}, {num_patches}, {embed_dim}]")
        print()
        
        # Create models
        models = {
            "Standard": StandardAttention(embed_dim, num_heads, dropout).to(device),
            "EPA": EPA(input_size, embed_dim, proj_size, num_heads).to(device),
            "Hydra": Efficient3DHydraAttention(embed_dim, num_heads, dropout).to(device)
        }
        
        # Create input
        x = torch.randn(batch_size, num_patches, embed_dim, device=device)
        
        # Benchmark each model
        config_results = {}
        
        for name, model in models.items():
            print(f"\n--- {name} Attention ---")
            
            # Parameter count
            param_count = sum(p.numel() for p in model.parameters())
            print(f"Parameters: {param_count:,} ({param_count/1e6:.2f}M)")
            
            # FLOP count
            flops = count_flops(model, x, volume_shape)
            print(f"FLOPs: {flops:,} ({flops/1e6:.1f}M)")
            
            # Memory usage
            model.train()
            torch.cuda.empty_cache() if device.type == "cuda" else None
            
            try:
                # Warmup
                with torch.no_grad():
                    if name == "Hydra":
                        _ = model(x, volume_shape)
                    else:
                        _ = model(x)
                
                # Memory measurement
                if device.type == "cuda":
                    torch.cuda.empty_cache()
                    torch.cuda.reset_peak_memory_stats()
                
                # Forward pass timing
                model.eval()
                times = []
                
                with torch.no_grad():
                    for _ in range(10):  # Multiple runs for stable timing
                        torch.cuda.synchronize() if device.type == "cuda" else None
                        start_time = time.time()
                        
                        if name == "Hydra":
                            output = model(x, volume_shape)
                        else:
                            output = model(x)
                        
                        torch.cuda.synchronize() if device.type == "cuda" else None
                        end_time = time.time()
                        
                        times.append(end_time - start_time)
                
                avg_time = np.mean(times[2:])  # Exclude first 2 for warmup
                std_time = np.std(times[2:])
                
                # Peak memory
                if device.type == "cuda":
                    peak_memory = torch.cuda.max_memory_allocated() / (1024**3)  # GB
                else:
                    peak_memory = 0
                
                print(f"Inference Time: {avg_time*1000:.2f} ± {std_time*1000:.2f} ms")
                print(f"Peak Memory: {peak_memory:.3f} GB")
                print(f"Output Shape: {output.shape}")
                
                # Store results
                config_results[name] = {
                    "params": param_count,
                    "flops": flops,
                    "time_ms": avg_time * 1000,
                    "time_std": std_time * 1000,
                    "memory_gb": peak_memory,
                    "throughput": batch_size / avg_time
                }
                
            except Exception as e:
                print(f"❌ Failed: {e}")
                config_results[name] = None
        
        results[config['name']] = config_results
        
        # Comparison summary for this configuration
        print(f"\n{'='*30} SUMMARY {'='*30}")
        valid_results = {k: v for k, v in config_results.items() if v is not None}
        
        if len(valid_results) > 1:
            baseline = "Standard"
            if baseline in valid_results:
                print(f"Speedup vs {baseline}:")
                for name, result in valid_results.items():
                    if name != baseline:
                        speedup = valid_results[baseline]["time_ms"] / result["time_ms"]
                        flop_reduction = valid_results[baseline]["flops"] / result["flops"]
                        param_ratio = result["params"] / valid_results[baseline]["params"]
                        print(f"  {name}: {speedup:.2f}x faster, {flop_reduction:.2f}x fewer FLOPs, {param_ratio:.2f}x params")
    
    # Overall summary
    print(f"\n{'='*80}")
    print("FINAL COMPARISON SUMMARY")
    print(f"{'='*80}")
    
    for config_name, config_results in results.items():
        print(f"\n{config_name}:")
        valid_results = {k: v for k, v in config_results.items() if v is not None}
        
        if "Standard" in valid_results and "Hydra" in valid_results:
            hydra = valid_results["Hydra"]
            standard = valid_results["Standard"]
            speedup = standard["time_ms"] / hydra["time_ms"]
            flop_reduction = standard["flops"] / hydra["flops"]
            print(f"  Hydra vs Standard: {speedup:.1f}x faster, {flop_reduction:.1f}x fewer FLOPs")
        
        if "EPA" in valid_results and "Hydra" in valid_results:
            hydra = valid_results["Hydra"]
            epa = valid_results["EPA"]
            speedup = epa["time_ms"] / hydra["time_ms"]
            flop_ratio = hydra["flops"] / epa["flops"]
            print(f"  Hydra vs EPA: {speedup:.1f}x faster" if speedup > 1 else f"  EPA vs Hydra: {1/speedup:.1f}x faster")
            print(f"    FLOP ratio: {flop_ratio:.1f}x")
    
    return results


if __name__ == "__main__":
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    
    # Run comprehensive benchmark
    benchmark_results = benchmark_models()
    
    print(f"\n{'='*80}")
    print("BENCHMARK COMPLETED!")
    print(f"{'='*80}")

COMPREHENSIVE ATTENTION BENCHMARK: HYDRA vs EPA vs STANDARD
Device: cuda


TESTING: Small Volume (32³)
Volume: (32, 32, 32), Patches: 512, Batch: 4
Input tensor: [4, 512, 768]


--- Standard Attention ---
Parameters: 2,362,368 (2.36M)
FLOPs: 6,442,450,944 (6442.5M)
Inference Time: 1.95 ± 0.01 ms
Peak Memory: 0.156 GB
Output Shape: torch.Size([4, 512, 768])

--- EPA Attention ---
Parameters: 3,015,576 (3.02M)
FLOPs: 6,469,189,632 (6469.2M)
Inference Time: 2.21 ± 0.00 ms
Peak Memory: 0.147 GB
Output Shape: torch.Size([4, 512, 768])

--- Hydra Attention ---
Parameters: 10,043,140 (10.04M)
FLOPs: 11,884,560,384 (11884.6M)
Inference Time: 48.00 ± 0.15 ms
Peak Memory: 0.130 GB
Output Shape: torch.Size([4, 512, 768])

Speedup vs Standard:
  EPA: 0.88x faster, 1.00x fewer FLOPs, 1.28x params
  Hydra: 0.04x faster, 0.54x fewer FLOPs, 4.25x params

TESTING: Medium Volume (64³)
Volume: (64, 64, 64), Patches: 512, Batch: 2
Input tensor: [2, 512, 768]


--- Standard Attention ---
Parameters: 2,362,

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional

class OptimizedHydraAttention(nn.Module):
    """
    Highly optimized Hydra Attention designed to outperform both Standard and EPA attention.
    Key optimizations:
    - Parallel processing of all heads
    - Fused operations to reduce memory overhead
    - Efficient sparse patterns
    - Minimal parameter overhead
    """
    
    def __init__(
        self,
        embed_dim: int = 768,
        num_heads: int = 12,
        dropout: float = 0.1,
        efficiency_mode: str = "balanced"  # "fast", "balanced", "memory"
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.efficiency_mode = efficiency_mode
        
        # Single unified QKV projection for all heads (more efficient than separate projections)
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        
        # Lightweight routing mechanism (much simpler than before)
        self.route_proj = nn.Linear(embed_dim, num_heads, bias=False)
        
        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        
        # Precompute attention patterns for efficiency
        self._register_attention_patterns()
        
    def _register_attention_patterns(self):
        """Precompute sparse attention patterns to avoid runtime computation"""
        # These will be populated during first forward pass based on sequence length
        self.register_buffer('local_mask', None, persistent=False)
        self.register_buffer('strided_mask', None, persistent=False)
        self.register_buffer('global_indices', None, persistent=False)
        
    def _create_attention_masks(self, seq_len: int, device: torch.device):
        """Create efficient sparse attention masks"""
        if self.local_mask is not None and self.local_mask.size(0) == seq_len:
            return  # Already created for this sequence length
            
        # Local attention pattern (sliding window)
        window_size = min(64, seq_len // 4)
        local_mask = torch.zeros(seq_len, seq_len, device=device, dtype=torch.bool)
        for i in range(seq_len):
            start = max(0, i - window_size // 2)
            end = min(seq_len, i + window_size // 2 + 1)
            local_mask[i, start:end] = True
            
        # Strided attention pattern
        stride = max(1, seq_len // 128)  # Adaptive stride
        strided_mask = torch.zeros(seq_len, seq_len, device=device, dtype=torch.bool)
        for i in range(0, seq_len, stride):
            for j in range(0, seq_len, stride):
                strided_mask[i, j] = True
                
        # Global attention indices (sample key positions)
        global_indices = torch.linspace(0, seq_len-1, min(32, seq_len), device=device).long()
        
        # Register as buffers
        self.register_buffer('local_mask', local_mask, persistent=False)
        self.register_buffer('strided_mask', strided_mask, persistent=False)
        self.register_buffer('global_indices', global_indices, persistent=False)
    
    def forward(self, x: torch.Tensor, volume_shape: Tuple[int, int, int] = (64, 64, 64)) -> torch.Tensor:
        B, N, C = x.shape
        
        # Create attention masks if needed
        self._create_attention_masks(N, x.device)
        
        # Unified QKV projection (more efficient than separate projections)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, num_heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # Lightweight routing - which heads use which patterns
        route_weights = torch.softmax(self.route_proj(x.mean(dim=1)), dim=-1)  # [B, num_heads]
        
        # Parallel attention computation with different patterns per head
        attn_out = torch.zeros_like(q)  # [B, num_heads, N, head_dim]
        
        # Distribute heads across different attention patterns
        heads_per_pattern = self.num_heads // 4
        
        # Pattern 1: Local attention (fine details)
        local_heads = slice(0, heads_per_pattern)
        attn_out[:, local_heads] = self._local_attention(
            q[:, local_heads], k[:, local_heads], v[:, local_heads]
        )
        
        # Pattern 2: Strided attention (medium range)  
        strided_heads = slice(heads_per_pattern, 2 * heads_per_pattern)
        attn_out[:, strided_heads] = self._strided_attention(
            q[:, strided_heads], k[:, strided_heads], v[:, strided_heads]
        )
        
        # Pattern 3: Global attention (long range) - Linear attention for efficiency
        global_heads = slice(2 * heads_per_pattern, 3 * heads_per_pattern)
        attn_out[:, global_heads] = self._global_linear_attention(
            q[:, global_heads], k[:, global_heads], v[:, global_heads]
        )
        
        # Pattern 4: Random sparse attention (cross-slice consistency)
        sparse_heads = slice(3 * heads_per_pattern, self.num_heads)
        attn_out[:, sparse_heads] = self._sparse_attention(
            q[:, sparse_heads], k[:, sparse_heads], v[:, sparse_heads]
        )
        
        # Apply routing weights
        route_weights = route_weights.unsqueeze(-1).unsqueeze(-1)  # [B, num_heads, 1, 1]
        attn_out = attn_out * route_weights
        
        # Combine heads and project
        out = attn_out.transpose(1, 2).reshape(B, N, C)
        out = self.out_proj(out)
        
        return self.dropout(out)
    
    def _local_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Efficient local attention using sparse mask"""
        B, H, N, D = q.shape
        
        # Compute attention scores
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, H, N, N]
        
        # Apply local mask (only attend to nearby tokens)
        attn = attn.masked_fill(~self.local_mask.unsqueeze(0).unsqueeze(0), -float('inf'))
        attn = F.softmax(attn, dim=-1)
        
        # Apply attention
        out = attn @ v  # [B, H, N, D]
        return out
    
    def _strided_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Efficient strided attention"""
        B, H, N, D = q.shape
        
        # Use precomputed strided mask
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.masked_fill(~self.strided_mask.unsqueeze(0).unsqueeze(0), -float('inf'))
        attn = F.softmax(attn, dim=-1)
        
        out = attn @ v
        return out
    
    def _global_linear_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Linear attention for global context - O(N) complexity"""
        B, H, N, D = q.shape
        
        # Normalize queries and keys for stability
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        
        # Linear attention: compute k^T @ v first (more efficient)
        kv = k.transpose(-2, -1) @ v  # [B, H, D, D]
        out = q @ kv  # [B, H, N, D]
        
        # Normalize by key sum to maintain attention property
        k_sum = k.sum(dim=-2, keepdim=True)  # [B, H, 1, D]
        normalizer = (q * k_sum).sum(dim=-1, keepdim=True) + 1e-6  # [B, H, N, 1]
        out = out / normalizer
        
        return out
    
    def _sparse_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Sparse random attention for cross-connections"""
        B, H, N, D = q.shape
        
        # Sample a subset of key positions for each query
        if self.efficiency_mode == "fast":
            # Very sparse for maximum speed
            sample_ratio = min(0.1, 32 / N)
        elif self.efficiency_mode == "balanced":
            sample_ratio = min(0.25, 64 / N)
        else:  # memory mode
            sample_ratio = min(0.5, 128 / N)
            
        num_samples = max(4, int(N * sample_ratio))
        
        # Use global indices for consistent sampling
        indices = self.global_indices[:num_samples]
        
        # Sample keys and values
        k_sampled = k[:, :, indices, :]  # [B, H, num_samples, D]
        v_sampled = v[:, :, indices, :]  # [B, H, num_samples, D]
        
        # Compute attention with sampled keys
        attn = (q @ k_sampled.transpose(-2, -1)) * self.scale  # [B, H, N, num_samples]
        attn = F.softmax(attn, dim=-1)
        
        # Apply to sampled values
        out = attn @ v_sampled  # [B, H, N, D]
        return out


class UltraFastHydra(nn.Module):
    """
    Ultra-optimized version focusing purely on speed
    Trades some accuracy for maximum performance
    """
    
    def __init__(self, embed_dim: int = 768, num_heads: int = 12, dropout: float = 0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        # Single QKV projection with minimal overhead
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        
        # Fixed patterns - no learnable routing for maximum speed
        self.window_size = 32  # Fixed small window
        self.stride = 8        # Fixed stride
        self.global_tokens = 16  # Fixed number of global tokens
        
    def forward(self, x: torch.Tensor, volume_shape: Tuple[int, int, int] = (64, 64, 64)) -> torch.Tensor:
        B, N, C = x.shape
        
        # Single QKV computation
        qkv = self.qkv(x).view(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # Split heads evenly across patterns
        heads_per_pattern = self.num_heads // 3
        
        out = torch.zeros_like(q)
        
        # Pattern 1: Local sliding window (fast)
        if heads_per_pattern > 0:
            local_slice = slice(0, heads_per_pattern)
            out[:, local_slice] = self._fast_local(q[:, local_slice], k[:, local_slice], v[:, local_slice], N)
        
        # Pattern 2: Strided attention (medium range)
        if heads_per_pattern > 0:
            strided_slice = slice(heads_per_pattern, 2 * heads_per_pattern)
            out[:, strided_slice] = self._fast_strided(q[:, strided_slice], k[:, strided_slice], v[:, strided_slice], N)
        
        # Pattern 3: Linear attention (global)
        remaining_heads = self.num_heads - 2 * heads_per_pattern
        if remaining_heads > 0:
            global_slice = slice(2 * heads_per_pattern, self.num_heads)
            out[:, global_slice] = self._fast_linear(q[:, global_slice], k[:, global_slice], v[:, global_slice])
        
        # Combine and project
        out = out.transpose(1, 2).reshape(B, N, C)
        return self.out_proj(out)
    
    def _fast_local(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, N: int) -> torch.Tensor:
        """Extremely fast local attention with fixed window"""
        B, H, _, D = q.shape
        
        # Pad sequence to be divisible by window_size
        pad_size = (self.window_size - N % self.window_size) % self.window_size
        if pad_size > 0:
            q = F.pad(q, (0, 0, 0, pad_size))
            k = F.pad(k, (0, 0, 0, pad_size))
            v = F.pad(v, (0, 0, 0, pad_size))
        
        N_padded = q.size(2)
        num_windows = N_padded // self.window_size
        
        # Reshape to windows
        q = q.view(B, H, num_windows, self.window_size, D)
        k = k.view(B, H, num_windows, self.window_size, D)
        v = v.view(B, H, num_windows, self.window_size, D)
        
        # Attention within each window
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = attn @ v
        
        # Reshape back
        out = out.view(B, H, N_padded, D)
        
        # Remove padding
        if pad_size > 0:
            out = out[:, :, :N, :]
            
        return out
    
    def _fast_strided(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, N: int) -> torch.Tensor:
        """Fast strided attention"""
        # Sample every stride-th token
        indices = torch.arange(0, N, self.stride, device=q.device)
        
        q_strided = q[:, :, indices, :]
        k_strided = k[:, :, indices, :]
        v_strided = v[:, :, indices, :]
        
        # Attention on strided tokens
        attn = (q_strided @ k_strided.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out_strided = attn @ v_strided
        
        # Interpolate back to full resolution
        out = torch.zeros_like(q)
        out[:, :, indices, :] = out_strided
        
        # Simple nearest neighbor for missing positions
        for i in range(N):
            if i not in indices:
                nearest_idx = indices[torch.argmin(torch.abs(indices - i))]
                out[:, :, i, :] = out[:, :, nearest_idx, :]
                
        return out
    
    def _fast_linear(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Ultra-fast linear attention"""
        # Normalize for stability
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        
        # Linear attention
        kv = k.transpose(-2, -1) @ v
        out = q @ kv
        
        # Simple normalization
        out = out / math.sqrt(q.size(-1))
        
        return out


# Benchmark the optimized versions
def benchmark_optimized():
    """Test the optimized Hydra implementations"""
    
    print("="*80)
    print("OPTIMIZED HYDRA BENCHMARK")
    print("="*80)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 2
    seq_len = 512
    embed_dim = 768
    num_heads = 12
    
    x = torch.randn(batch_size, seq_len, embed_dim, device=device)
    
    models = {
        "Standard": nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).to(device),
        "Optimized Hydra": OptimizedHydraAttention(embed_dim, num_heads, efficiency_mode="balanced").to(device),
        "Ultra-Fast Hydra": UltraFastHydra(embed_dim, num_heads).to(device),
    }
    
    for name, model in models.items():
        print(f"\n--- {name} ---")
        
        # Parameter count
        params = sum(p.numel() for p in model.parameters())
        print(f"Parameters: {params:,} ({params/1e6:.2f}M)")
        
        # Timing
        model.eval()
        times = []
        
        with torch.no_grad():
            # Warmup
            for _ in range(5):
                if name == "Standard":
                    _ = model(x, x, x)[0]
                else:
                    _ = model(x)
            
            # Actual timing
            for _ in range(20):
                torch.cuda.synchronize() if device.type == "cuda" else None
                start = time.time()
                
                if name == "Standard":
                    output = model(x, x, x)[0]
                else:
                    output = model(x)
                    
                torch.cuda.synchronize() if device.type == "cuda" else None
                times.append(time.time() - start)
        
        avg_time = np.mean(times[5:]) * 1000  # Convert to ms, skip warmup
        std_time = np.std(times[5:]) * 1000
        
        print(f"Time: {avg_time:.2f} ± {std_time:.2f} ms")
        print(f"Output shape: {output.shape}")
        
        # Memory
        if device.type == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            
            with torch.no_grad():
                if name == "Standard":
                    _ = model(x, x, x)[0]
                else:
                    _ = model(x)
                    
            peak_mem = torch.cuda.max_memory_allocated() / (1024**3)
            print(f"Peak Memory: {peak_mem:.3f} GB")


if __name__ == "__main__":
    import time
    import numpy as np
    benchmark_optimized()

OPTIMIZED HYDRA BENCHMARK

--- Standard ---
Parameters: 2,362,368 (2.36M)
Time: 1.22 ± 0.26 ms
Output shape: torch.Size([2, 512, 768])
Peak Memory: 0.079 GB

--- Optimized Hydra ---
Parameters: 2,368,512 (2.37M)
Time: 1.34 ± 0.11 ms
Output shape: torch.Size([2, 512, 768])
Peak Memory: 0.065 GB

--- Ultra-Fast Hydra ---
Parameters: 2,359,296 (2.36M)
Time: 51.31 ± 1.56 ms
Output shape: torch.Size([2, 512, 768])
Peak Memory: 0.057 GB


In [None]:
class ECD(nn.Module):
    """
    Efficient Channel-Depth attention block - Optimized for reduced FLOPs
    """
    def __init__(self, depth_size: int, hidden_size, proj_size, num_heads=4, qkv_bias=False,
                 channel_attn_drop=0.1, spatial_attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size//num_heads
        
        
        # Original layers
        self.qkv_c = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
        self.qkv_d = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.attn_drop = nn.Dropout(channel_attn_drop)
        self.attn_drop_2 = nn.Dropout(spatial_attn_drop)
        self.out_proj = nn.Linear(hidden_size, int(hidden_size // 2))
        self.out_proj2 = nn.Linear(hidden_size, int(hidden_size // 2))
    
    
    
    def forward(self, x):
        B, C, D, H, W = x.shape
        N = H * W
        
        x_reshaped = x.permute(0, 2, 3, 4, 1).view(B, D, N, C)
        
        qkv_c = self.qkv_c(x_reshaped).reshape(B, D, N, 3, C).permute(3, 0, 1, 2, 4)
        q_d = self.qkv_d(x_reshaped).reshape(B, D, N, C)
        q_c, k_c, v_c = qkv_c[0], qkv_c[1], qkv_c[2]
        k_shared, v_shared = k_c, v_c
        
        
        # Channel attention 
        q_c = q_c.reshape(B*D, N, self.num_heads, C//self.num_heads).permute(0, 2, 1, 3)
        k_c = k_c.reshape(B*D, N, self.num_heads, C//self.num_heads).permute(0, 2, 1, 3)
        v_c = v_c.reshape(B*D, N, self.num_heads, C//self.num_heads).permute(0, 2, 1, 3)
        
        q_c = q_c.transpose(-2, -1)
        k_c = k_c.transpose(-2, -1)
        v_c = v_c.transpose(-2, -1)
        q_c = F.normalize(q_c, dim=-1)
        k_c = F.normalize(k_c, dim=-1)
        
        attn_CA = (q_c @ k_c.transpose(-2, -1)) / (N ** 0.5)
        
        attn_CA = attn_CA.softmax(dim=-1)
        attn_CA = self.attn_drop(attn_CA)
        x_CA = (attn_CA @ v_c).reshape(B, D, C, N).permute(0, 1, 3, 2)
        x_CA = self.out_proj(x_CA)
        
        # Depth attention
        q_d = q_d.permute(0, 3, 2, 1).reshape(B*C, N, D) 
        k_d = k_shared.permute(0, 3, 2, 1).reshape(B*C, N, D) 
        v_d = v_shared.permute(0, 3, 2, 1).reshape(B*C, N, D) 
        
        q_d = q_d.transpose(-2, -1)
        k_d = k_d.transpose(-2, -1)
        v_d = v_d.transpose(-2, -1)
        q_d = F.normalize(q_d, dim=-1)
        k_d = F.normalize(k_d, dim=-1)
        
        attn_D = (q_d @ k_d.transpose(-2, -1)) / (N ** 0.5)
        
        
        attn_D = attn_D.softmax(dim=-1)
        attn_D = self.attn_drop_2(attn_D)
        x_D = (attn_D @ v_d)
        x_D = x_D.reshape(B, C, D, N).permute(0, 2, 3, 1)
        x_D = self.out_proj2(x_D)
        
        x = torch.cat((x_CA, x_D), dim=-1).permute(0, 3, 1, 2).reshape(B, C, D, H, W)
        
        return x