In [None]:
# # HANS-Net Components Implementation

# This notebook implements the core building blocks for the Hyperbolic Attention Network for Segmentation (HANS-Net):
# - **WaveletDecomposition**: 2D DWT for multi-frequency feature extraction
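# - **SynapticPlasticity & PlasticConvBlock**: Convolution blocks whose outputs are rescaled by a Hebbian-inspired, activity-dependent per-channel gain
# - **HyperbolicConvBlock**: Convolution in the tangent space of a Poincaré ball (features pass through exponential/logarithmic maps first)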
# - **TemporalAttention**: Cross-frame attention for temporal context
# - **INRBranch**: Implicit Neural Representation for boundary refinement

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional

# Report the runtime environment up front so the printed results can be tied to a setup.
print(
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)

## 1. Wavelet Decomposition

Applies a 2D Discrete Wavelet Transform with Haar wavelets, decomposing each input channel into four half-resolution subbands:
- **LL**: Low-Low (approximation coefficients)
- **LH**: Low-High (horizontal details)
- **HL**: High-Low (vertical details)  
- **HH**: High-High (diagonal details)

In [None]:
class WaveletDecomposition(nn.Module):
    """
    2D Discrete Wavelet Transform using Haar wavelets.
    
    Decomposes each input channel into 4 subbands (LL, LH, HL, HH),
    converting [B, C, H, W] -> [B, C*4, H//2, W//2].
    Output channels are grouped per input channel: channel ``4*c + k`` holds
    subband ``k`` (0=LL, 1=LH, 2=HL, 3=HH) of input channel ``c``.
    With odd H or W the last row/column is dropped (stride-2 conv, no padding).
    This captures multi-frequency information useful for segmentation.
    """
    
    def __init__(self):
        super().__init__()
        # Separable Haar filters: low-pass [1, 1] / sqrt(2), high-pass [1, -1] / sqrt(2).
        # Their 2D outer products are the 2x2 kernels below (each scaled by 1/2);
        # the four kernels are mutually orthonormal.
        ll = torch.tensor([[1, 1], [1, 1]], dtype=torch.float32) / 2.0   # LL: local average
        lh = torch.tensor([[1, 1], [-1, -1]], dtype=torch.float32) / 2.0  # LH: horizontal detail (top row - bottom row)
        hl = torch.tensor([[1, -1], [1, -1]], dtype=torch.float32) / 2.0  # HL: vertical detail (left col - right col)
        hh = torch.tensor([[1, -1], [-1, 1]], dtype=torch.float32) / 2.0  # HH: diagonal detail
        
        # Stack as conv2d weights: [out_channels=4, in_channels=1, 2, 2]
        filters = torch.stack([ll, lh, hl, hh], dim=0).unsqueeze(1)
        self.register_buffer('filters', filters)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [B, C, H, W] (may be non-contiguous)
        Returns:
            Wavelet coefficients [B, C*4, H//2, W//2]
        """
        B, C, H, W = x.shape
        
        # Fold channels into the batch so every channel is filtered independently.
        # reshape (not view) also accepts non-contiguous inputs, e.g. permuted tensors.
        x_reshape = x.reshape(B * C, 1, H, W)
        
        # Stride-2 conv with the 4 Haar kernels: [B*C, 4, H//2, W//2]
        weight = self.filters.to(dtype=x.dtype)
        coeffs = F.conv2d(x_reshape, weight, stride=2, padding=0)
        
        # Unfold back: [B, C*4, H//2, W//2]
        _, _, H_out, W_out = coeffs.shape
        return coeffs.reshape(B, C * 4, H_out, W_out)


class WaveletReconstruction(nn.Module):
    """
    Inverse 2D Haar Discrete Wavelet Transform (optional, for reconstruction).
    
    Converts [B, C*4, H, W] -> [B, C, H*2, W*2]. Channels are expected grouped
    per source channel as (LL, LH, HL, HH): channel ``4*c + k`` is subband ``k``
    of output channel ``c``. Because the four 2x2 Haar kernels are orthonormal,
    a stride-2 transposed convolution with the same kernels is the exact
    inverse of the stride-2 forward convolution.
    """
    
    def __init__(self):
        super().__init__()
        ll = torch.tensor([[1, 1], [1, 1]], dtype=torch.float32) / 2.0
        lh = torch.tensor([[1, 1], [-1, -1]], dtype=torch.float32) / 2.0
        hl = torch.tensor([[1, -1], [1, -1]], dtype=torch.float32) / 2.0
        hh = torch.tensor([[1, -1], [-1, 1]], dtype=torch.float32) / 2.0
        
        # conv_transpose2d weights are [in_channels, out_channels, kH, kW]:
        # 4 input channels (one per subband) -> 1 output channel.
        filters = torch.stack([ll, lh, hl, hh], dim=0).unsqueeze(1)  # [4, 1, 2, 2]
        self.register_buffer('filters', filters)
    
    def forward(self, coeffs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            coeffs: Wavelet coefficients [B, C*4, H, W]
        Returns:
            Reconstructed tensor [B, C, H*2, W*2]
        Raises:
            ValueError: if the channel count is not divisible by 4.
        """
        B, C4, H, W = coeffs.shape
        if C4 % 4 != 0:
            raise ValueError(f"Expected a channel count divisible by 4, got {C4}")
        C = C4 // 4
        
        # Group the 4 subbands of each channel: [B*C, 4, H, W]
        coeffs = coeffs.reshape(B * C, 4, H, W)
        
        # Stride-2 transposed conv scatters each coefficient quadruple back to a 2x2 block
        weight = self.filters.to(dtype=coeffs.dtype)
        x = F.conv_transpose2d(coeffs, weight, stride=2, padding=0)
        
        # [B, C, H*2, W*2]
        return x.reshape(B, C, H * 2, W * 2)

## 2. Synaptic Plasticity & PlasticConvBlock

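Implements Hebbian-inspired, activity-dependent feature modulation:
- **SynapticPlasticity**: Learns a per-channel gain and rescales each channel according to its mean activation relative to a learned threshold
- **PlasticConvBlock**: Two convolution + BatchNorm stages whose output features are rescaled by SynapticPlasticity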

In [None]:
class SynapticPlasticity(nn.Module):
    """
    Activity-dependent per-channel gain, loosely inspired by Hebbian learning.
    
    Every channel c of each sample is multiplied by
        gain_c * (1 + sigmoid(plasticity_rate * (mean_c - threshold_c)))
    where mean_c is that channel's spatial mean. Channels whose mean activation
    exceeds their learned threshold therefore get amplified more strongly.
    gain, threshold and plasticity_rate are all learnable parameters.
    """
    
    def __init__(self, channels: int):
        super().__init__()
        # Multiplicative per-channel gain, initialised to 1.
        self.gain = nn.Parameter(torch.ones(channels))
        # Per-channel level that the spatial mean activation is compared against.
        self.threshold = nn.Parameter(torch.zeros(channels))
        # Shared scalar steepness of the sigmoid gate.
        self.plasticity_rate = nn.Parameter(torch.ones(1) * 0.1)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Rescale each channel of ``x`` by its activity-dependent gain.
        
        Args:
            x: Input features [B, C, H, W]
        Returns:
            Modulated features [B, C, H, W]
        """
        batch, channels, _, _ = x.shape
        
        # Spatial mean of every channel, kept broadcastable: [B, C, 1, 1]
        activity = x.mean(dim=(2, 3)).reshape(batch, channels, 1, 1)
        
        # Sigmoid gate in (0, 1), larger for channels above their threshold
        offset = activity - self.threshold.view(1, -1, 1, 1)
        gate = torch.sigmoid(self.plasticity_rate * offset)
        
        per_channel_scale = self.gain.view(1, -1, 1, 1) * (1 + gate)
        return x * per_channel_scale


class PlasticConvBlock(nn.Module):
    """
    Convolutional block with synaptic plasticity, batch normalization, and GELU activation.
    
    Layout:
        conv1 -> BN -> GELU -> conv2 -> BN -> [plasticity] -> [+ residual]
        -> [Dropout2d] -> GELU
    
    This is the fundamental building block for encoder/decoder stages.
    
    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Conv kernel size; padding is ``kernel_size // 2``, which keeps
            the spatial size for odd kernels when ``stride == 1``.
        stride: Stride of the first conv (and of the residual projection).
        use_plasticity: Apply SynapticPlasticity after the second BatchNorm.
        use_residual: Add a skip connection. When channels or stride differ,
            the skip path goes through a 1x1 strided conv projection.
        dropout_p: Dropout2d probability applied before the final activation;
            0 disables dropout.
    """
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        use_plasticity: bool = True,
        use_residual: bool = False,
        dropout_p: float = 0.0
    ):
        super().__init__()
        self.use_plasticity = use_plasticity
        # Honour the residual request for any shapes; residual_proj (below)
        # reconciles channel/stride mismatches. Previously this flag was False
        # exactly when a projection was needed, leaving the projection unused.
        self.use_residual = use_residual
        
        padding = kernel_size // 2
        
        # Main convolution path
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # Synaptic plasticity (applied after the second conv + BN)
        if use_plasticity:
            self.plasticity = SynapticPlasticity(out_channels)
        
        # Activation and optional channel-wise dropout (Identity when disabled)
        self.act = nn.GELU()
        self.dropout = nn.Dropout2d(dropout_p) if dropout_p > 0.0 else nn.Identity()
        
        # 1x1 (strided) projection so the skip path matches the main path's shape
        if use_residual and (in_channels != out_channels or stride != 1):
            self.residual_proj = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False)
        else:
            self.residual_proj = None
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [B, C_in, H, W]
        Returns:
            Output tensor [B, C_out, H', W'] (H', W' reduced by ``stride``)
        """
        identity = x
        
        # First conv stage
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act(out)
        
        # Second conv stage (activation deferred until after the residual add)
        out = self.conv2(out)
        out = self.bn2(out)
        
        # Apply synaptic plasticity
        if self.use_plasticity:
            out = self.plasticity(out)
        
        # Residual connection, projected when shapes differ
        if self.use_residual:
            if self.residual_proj is not None:
                identity = self.residual_proj(identity)
            out = out + identity
        
        out = self.dropout(out)
        out = self.act(out)
        
        return out

## 3. Hyperbolic Convolution Block

Operations in the Poincaré ball model of hyperbolic space:
- **exp_map_zero**: Maps Euclidean vectors to the Poincaré ball (from origin)
- **log_map_zero**: Maps Poincaré ball vectors back to Euclidean (to origin)
- **HyperbolicConvBlock**: Maps features onto the Poincaré ball and back, then convolves in the tangent space at the origin; intended to better capture hierarchical structures

In [None]:
# ============================================================================
# Hyperbolic Geometry Helpers (Poincaré Ball Model)
# ============================================================================

def exp_map_zero(v: torch.Tensor, c: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """
    Exponential map from the origin in the Poincaré ball.
    
    Maps a Euclidean (tangent) vector v to a point on the Poincaré ball of
    curvature -c, i.e. the ball of radius 1/sqrt(c).
    Formula: exp_0(v) = tanh(sqrt(c) * ||v||) * v / (sqrt(c) * ||v||)
    
    Args:
        v: Euclidean vector [*, D] (tangent vector at origin); the norm is
           taken over the last dimension.
        c: Curvature magnitude (positive scalar tensor); the ball has curvature -c.
        eps: Lower bound on ||v|| so the division is defined for v = 0.
    Returns:
        Point on Poincaré ball [*, D]. For large ||v|| the result can reach the
        boundary in floating point (tanh rounds to 1), so callers may follow
        up with project_to_ball.
    """
    sqrt_c = torch.sqrt(c)
    v_norm = v.norm(dim=-1, keepdim=True).clamp(min=eps)
    scaled_norm = sqrt_c * v_norm
    
    # Cap the tanh argument; tanh(15) is already 1.0 in float32.
    tanh_arg = scaled_norm.clamp(max=15.0)
    
    # v_norm >= eps, so the division is safe without an extra "+ eps". Adding one
    # biased the map: tiny vectors were shrunk by sqrt(c) / (sqrt(c) + 1).
    return torch.tanh(tanh_arg) * v / scaled_norm


def log_map_zero(y: torch.Tensor, c: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """
    Logarithmic map to the origin in the Poincaré ball.
    
    Maps a point y on the Poincaré ball (radius 1/sqrt(c)) back to the
    Euclidean tangent space at the origin; inverse of exp_map_zero.
    Formula: log_0(y) = arctanh(sqrt(c) * ||y||) * y / (sqrt(c) * ||y||)
    
    Args:
        y: Point on Poincaré ball [*, D]; the norm is taken over the last dimension.
        c: Curvature magnitude (positive scalar tensor).
        eps: Lower bound on ||y|| and margin below 1 for the arctanh argument.
    Returns:
        Euclidean vector [*, D]. Points on or outside the boundary map to
        vectors of norm arctanh(1 - eps) / sqrt(c) in the direction of y.
    """
    sqrt_c = torch.sqrt(c)
    y_norm = y.norm(dim=-1, keepdim=True).clamp(min=eps)
    scaled_norm = sqrt_c * y_norm
    
    # Keep the arctanh argument strictly below 1 (arctanh(1) = inf, arctanh(>1) = NaN).
    # No lower clamp is needed: scaled_norm > 0, and clamping it up to eps
    # distorted small vectors whenever c < 1.
    atanh_arg = scaled_norm.clamp(max=1.0 - eps)
    
    # y_norm >= eps, so the division is safe without an extra "+ eps" (which biased the map).
    return torch.arctanh(atanh_arg) * y / scaled_norm


def project_to_ball(x: torch.Tensor, c: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """
    Pull points back inside the Poincaré ball of radius 1/sqrt(c).
    
    Vectors whose norm exceeds (1 - eps) / sqrt(c) are rescaled onto that
    radius; vectors already inside are returned unchanged.
    
    Args:
        x: Points [*, D]; the norm is taken over the last dimension.
        c: Curvature magnitude (positive scalar tensor).
        eps: Relative margin kept from the boundary (also the minimum norm
             used to avoid division by zero).
    Returns:
        Points strictly inside the ball [*, D]
    """
    radius = (1.0 - eps) / torch.sqrt(c)
    norms = torch.clamp(x.norm(dim=-1, keepdim=True), min=eps)
    
    # Shrink factor <= 1: only points beyond the radius get scaled.
    shrink = (radius / norms).clamp(max=1.0)
    return x * shrink


def mobius_add(x: torch.Tensor, y: torch.Tensor, c: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """
    Möbius addition x ⊕_c y in the Poincaré ball (hyperbolic analog of vector addition).
    
        x ⊕ y = ((1 + 2c<x,y> + c|y|^2) x + (1 - c|x|^2) y)
                / (1 + 2c<x,y> + c^2 |x|^2 |y|^2)
    
    Args:
        x, y: Points on Poincaré ball [*, D]; inner products use the last dimension.
        c: Curvature magnitude (positive scalar tensor).
        eps: Lower bound on the denominator, also passed to project_to_ball.
    Returns:
        Result of Möbius addition [*, D], projected back inside the ball.
    """
    # Per-vector inner products over the last axis
    x_dot_x = (x * x).sum(dim=-1, keepdim=True)
    y_dot_y = (y * y).sum(dim=-1, keepdim=True)
    x_dot_y = (x * y).sum(dim=-1, keepdim=True)
    
    coef_x = 1 + 2 * c * x_dot_y + c * y_dot_y
    coef_y = 1 - c * x_dot_x
    denominator = torch.clamp(1 + 2 * c * x_dot_y + c * c * x_dot_x * y_dot_y, min=eps)
    
    return project_to_ball((coef_x * x + coef_y * y) / denominator, c, eps)


class HyperbolicConvBlock(nn.Module):
    """
    Convolution block with a Poincaré-ball round trip on its input.
    
    The forward pass:
    1. GroupNorm the input and multiply by a learnable factor |scale| (init 0.1)
    2. Map each pixel's channel vector onto the Poincaré ball (exp map at the
       origin) and project it strictly inside the ball
    3. Map it back to the tangent space at the origin (log map)
    4. Convolve in that tangent space, add a bias, GroupNorm, GELU
    The output stays in Euclidean (tangent) space; it is not mapped back to the ball.
    
    Note: exp_0 followed by log_0 is the identity up to numerical clipping, so
    steps 2-3 mainly bound very large per-pixel feature norms (to roughly
    arctanh(1 - eps) / sqrt(c)). The motivation is that hyperbolic geometry
    suits hierarchical structures in medical images
    (organ → region → tissue → boundary).
    
    Curvature: ``self.curvature`` stores the pre-softplus value in both the
    learnable (Parameter) and fixed (buffer) cases; the effective curvature used
    in forward is ``softplus(self.curvature)`` clamped to [0.1, 10].
    
    Raises:
        ValueError: if ``curvature`` is not positive.
    """
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        curvature: float = 1.0,
        learnable_curvature: bool = True
    ):
        super().__init__()
        
        if curvature <= 0:
            raise ValueError(f"curvature must be positive, got {curvature}")
        
        # Store softplus^{-1}(curvature) in both modes so that forward's
        # softplus(self.curvature) recovers `curvature` either way.
        raw_curvature = torch.tensor(self._inverse_softplus(curvature))
        if learnable_curvature:
            self.curvature = nn.Parameter(raw_curvature)
        else:
            self.register_buffer('curvature', raw_curvature)
        
        # Input normalization to control feature magnitudes
        self.input_norm = nn.GroupNorm(self._num_groups(in_channels), in_channels)
        
        # Convolution in tangent space (bias added separately)
        padding = kernel_size // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False)
        self.bias = nn.Parameter(torch.zeros(out_channels))
        
        # Output normalization and activation
        self.output_norm = nn.GroupNorm(self._num_groups(out_channels), out_channels)
        self.act = nn.GELU()
        
        # Scale factor to keep features small before the hyperbolic mapping
        self.scale = nn.Parameter(torch.ones(1) * 0.1)
    
    @staticmethod
    def _inverse_softplus(value: float) -> float:
        """Return r with softplus(r) == value, without overflowing exp() for large values."""
        # log(exp(v) - 1) == v + log(1 - exp(-v)); expm1 keeps precision for small v.
        return value + math.log(-math.expm1(-value))
    
    @staticmethod
    def _num_groups(channels: int, max_groups: int = 8) -> int:
        """Largest group count <= max_groups that divides `channels` (required by GroupNorm)."""
        groups = min(max_groups, channels)
        while channels % groups:
            groups -= 1
        return groups
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [B, C_in, H, W]
        Returns:
            Output tensor [B, C_out, H, W]
        """
        # Effective curvature: positive via softplus, bounded for stability
        c = F.softplus(self.curvature).clamp(min=0.1, max=10.0)
        
        # Normalize, then shrink features (important for hyperbolic stability)
        x = self.input_norm(x)
        x_scaled = x * torch.abs(self.scale)
        
        # Channel vectors last for the per-pixel maps: [B, H, W, C]
        x_bhwc = x_scaled.permute(0, 2, 3, 1).contiguous()
        
        # Map onto the ball, keep strictly inside, and map back to tangent space
        x_hyp = exp_map_zero(x_bhwc, c)
        x_hyp = project_to_ball(x_hyp, c)
        x_tangent = log_map_zero(x_hyp, c)
        
        # Back to conv layout: [B, C, H, W]
        x_tangent = x_tangent.permute(0, 3, 1, 2).contiguous()
        
        # Convolution + bias in tangent space
        out = self.conv(x_tangent)
        out = out + self.bias.view(1, -1, 1, 1)
        
        # Normalize and activate (stay in Euclidean space for stability)
        out = self.output_norm(out)
        out = self.act(out)
        
        return out

## 4. Temporal Attention

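Cross-frame attention that aggregates information from T=3 consecutive slices and outputs features for the center slice. The center slice provides the queries and all slices provide keys/values; attention is computed per pixel, so each center-slice position attends to the same position in every slice.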

In [None]:
class TemporalAttention(nn.Module):
    """
    Cross-frame attention for aggregating temporal context from T consecutive slices.
    
    Uses the center frame (index T // 2) as the query and all frames as
    keys/values. Attention is computed independently per spatial position:
    each pixel of the center frame attends to the same pixel in each of the T
    frames. The result is projected, added to the (position-embedded) center
    frame and layer-normalized over channels.
    
    Input: [B, T, C, H, W] - T frames of features
    Output: [B, C, H, W] - Temporally-attended center frame features
    
    Args:
        embed_dim: Channel count C of every frame; must be divisible by num_heads.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        num_frames: Maximum number of frames T supported by the learnable
            temporal position embeddings (default 3).
    """
    
    def __init__(self, embed_dim: int, num_heads: int = 4, dropout: float = 0.1, num_frames: int = 3):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        # Query projection (for center frame only)
        self.q_proj = nn.Conv2d(embed_dim, embed_dim, 1, bias=False)
        
        # Key and Value projections (for all frames)
        self.k_proj = nn.Conv2d(embed_dim, embed_dim, 1, bias=False)
        self.v_proj = nn.Conv2d(embed_dim, embed_dim, 1, bias=False)
        
        # Output projection
        self.out_proj = nn.Conv2d(embed_dim, embed_dim, 1, bias=False)
        
        # Learnable temporal position embeddings, one per supported frame
        self.temporal_pos = nn.Parameter(torch.randn(1, num_frames, embed_dim, 1, 1) * 0.02)
        
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(embed_dim)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [B, T, C, H, W]
        Returns:
            Output tensor [B, C, H, W] (center frame with temporal context)
        Raises:
            ValueError: if T exceeds the number of position embeddings.
        """
        B, T, C, H, W = x.shape
        max_frames = self.temporal_pos.shape[1]
        if T > max_frames:
            raise ValueError(f"TemporalAttention supports at most {max_frames} frames, got {T}")
        N = H * W
        
        # Add temporal position embeddings
        x = x + self.temporal_pos[:, :T]
        
        # Center frame provides the queries
        center_idx = T // 2
        q_input = x[:, center_idx]  # [B, C, H, W]
        q = self.q_proj(q_input).reshape(B, self.num_heads, self.head_dim, N)
        
        # Keys/values from all frames, projected in one batched pass
        x_flat = x.reshape(B * T, C, H, W)
        k = self.k_proj(x_flat).reshape(B, T, self.num_heads, self.head_dim, N)
        v = self.v_proj(x_flat).reshape(B, T, self.num_heads, self.head_dim, N)
        
        # Per-position temporal attention layouts
        q = q.permute(0, 1, 3, 2)      # [B, heads, N, head_dim]
        k = k.permute(0, 2, 3, 4, 1)   # [B, heads, head_dim, N, T]
        v = v.permute(0, 2, 3, 4, 1)   # [B, heads, head_dim, N, T]
        
        # Attention over T for every spatial position: [B, heads, N, T]
        attn = torch.einsum('bnsd,bndst->bnst', q, k) * self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        
        # Weighted sum over frames: [B, heads, N, head_dim]
        out = torch.einsum('bnst,bndst->bnsd', attn, v)
        
        # Merge heads back to channels: [B, C, H, W]
        out = out.permute(0, 1, 3, 2).reshape(B, C, H, W)
        
        # Output projection + residual with the center frame
        out = self.out_proj(out)
        out = out + q_input
        
        # LayerNorm over channels (channels-last for LN)
        out = self.norm(out.permute(0, 2, 3, 1))  # [B, H, W, C]
        return out.permute(0, 3, 1, 2)            # [B, C, H, W]

## 5. INR Branch (Implicit Neural Representation)

Continuous coordinate-based refinement for precise boundary delineation:
- **PositionalEncoding**: Fourier feature encoding of 2D coordinates
- **INRBranch**: MLP that predicts per-pixel refinement logits from coordinates + features

In [None]:
class PositionalEncoding(nn.Module):
    """
    Sinusoidal (Fourier feature) encoding of 2D coordinates.
    
    Each coordinate value p in [-1, 1] is expanded into sin(2^k * pi * p) and
    cos(2^k * pi * p) for k = 0 .. L-1, which lets a downstream MLP represent
    high-frequency detail such as sharp boundaries.
    Output layout along the last axis: the raw (x, y) first when
    ``include_input`` is set, then for x and then for y the (sin, cos) pair of
    every frequency in increasing order.
    
    Based on: "Fourier Features Let Networks Learn High Frequency Functions"
    """
    
    def __init__(self, num_frequencies: int = 10, include_input: bool = True):
        super().__init__()
        self.num_frequencies = num_frequencies
        self.include_input = include_input
        
        # Powers of two: 2^0, 2^1, ..., 2^(L-1)
        exponents = torch.linspace(0, num_frequencies - 1, num_frequencies)
        self.register_buffer('freq_bands', 2.0 ** exponents)
        
        # 2 coords x L frequencies x (sin, cos), plus the 2 raw coords if kept
        self.out_dim = 4 * num_frequencies + (2 if include_input else 0)
    
    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Args:
            coords: Normalized coordinates [*, 2] in range [-1, 1]
        Returns:
            Encoded coordinates [*, out_dim]
        """
        phases = coords.unsqueeze(-1) * self.freq_bands * math.pi     # [..., 2, L]
        sincos = torch.stack((phases.sin(), phases.cos()), dim=-1)     # [..., 2, L, 2]
        flat = sincos.reshape(*coords.shape[:-1], -1)                   # [..., 4L]
        
        if not self.include_input:
            return flat
        return torch.cat((coords, flat), dim=-1)


def make_coord_grid(H: int, W: int, device: Optional[torch.device] = None) -> torch.Tensor:
    """
    Build an H x W grid of normalized (x, y) coordinates spanning [-1, 1].
    
    Args:
        H, W: Grid dimensions (rows, columns)
        device: Target device (None means the default device)
    Returns:
        Coordinate grid [H, W, 2] where [..., 0] is x (varies along columns)
        and [..., 1] is y (varies along rows)
    """
    row_pos = torch.linspace(-1, 1, H, device=device)
    col_pos = torch.linspace(-1, 1, W, device=device)
    
    # Broadcast the 1D axes to full H x W planes
    grid_x = col_pos.unsqueeze(0).expand(H, W)
    grid_y = row_pos.unsqueeze(1).expand(H, W)
    
    return torch.stack((grid_x, grid_y), dim=-1)


class INRBranch(nn.Module):
    """
    Implicit Neural Representation branch for boundary refinement.
    
    For every pixel, concatenates a Fourier encoding of its normalized (x, y)
    coordinate with a 1x1-projected feature vector and runs a per-pixel MLP
    that outputs one refinement logit. This is meant to help recover fine
    boundary details lost in downsampling.
    
    Input: Features [B, C, H, W] + coordinate grid (generated if not given)
    Output: Refinement logits [B, 1, H, W]
    
    Args:
        feature_dim: Channel count C of the input features.
        hidden_dim: MLP width; features are projected to hidden_dim // 2 channels.
        num_frequencies: Number of frequency bands of the positional encoding.
        num_layers: Number of Linear layers in the MLP (must be >= 1).
    Raises:
        ValueError: if ``num_layers < 1``.
    """
    
    def __init__(
        self,
        feature_dim: int,
        hidden_dim: int = 256,
        num_frequencies: int = 10,
        num_layers: int = 3
    ):
        super().__init__()
        if num_layers < 1:
            # With no layers the output would silently have input_dim channels instead of 1.
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        
        # Positional encoding for coordinates
        self.pos_encoder = PositionalEncoding(num_frequencies, include_input=True)
        coord_dim = self.pos_encoder.out_dim
        
        # Feature projection (reduce channel dimension)
        self.feature_proj = nn.Conv2d(feature_dim, hidden_dim // 2, 1)
        
        # MLP: coord_encoding + projected_features -> refinement logit
        input_dim = coord_dim + hidden_dim // 2
        
        layers = []
        for i in range(num_layers):
            is_last = i == num_layers - 1
            in_dim = input_dim if i == 0 else hidden_dim
            out_dim = 1 if is_last else hidden_dim
            
            layers.append(nn.Linear(in_dim, out_dim))
            if not is_last:
                layers.append(nn.GELU())
        
        self.mlp = nn.Sequential(*layers)
    
    def forward(self, features: torch.Tensor, coords: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            features: Image features [B, C, H, W]
            coords: Optional pre-computed coordinates, either shared [H, W, 2]
                    or per-sample [B, H, W, 2]. If None, a [-1, 1] grid is generated.
        Returns:
            Refinement logits [B, 1, H, W]
        """
        B, _, H, W = features.shape
        
        # Generate coordinate grid if not provided
        if coords is None:
            coords = make_coord_grid(H, W, features.device)  # [H, W, 2]
        
        # Encode coordinates: [H, W, coord_dim] or [B, H, W, coord_dim]
        coord_enc = self.pos_encoder(coords)
        
        # Share one grid across the batch when it is unbatched
        if coord_enc.dim() == 3:
            coord_enc = coord_enc.unsqueeze(0).expand(B, -1, -1, -1)
        
        # Project features and move channels last: [B, H, W, hidden_dim//2]
        feat_proj = self.feature_proj(features).permute(0, 2, 3, 1)
        
        # The generated grid is float32; match the feature dtype so half-precision
        # models do not hit a dtype mismatch inside the MLP.
        coord_enc = coord_enc.to(dtype=feat_proj.dtype)
        
        # Concatenate coordinates and features: [B, H, W, input_dim]
        combined = torch.cat([coord_enc, feat_proj], dim=-1)
        
        # Per-pixel MLP: [B, H, W, 1]
        out = self.mlp(combined)
        
        # Reshape to standard format: [B, 1, H, W]
        return out.permute(0, 3, 1, 2)

## 6. Validation Tests

Instantiate each module with dummy inputs to verify shapes and forward passes work correctly.

In [None]:
@torch.no_grad()
def test_all_modules():
    """
    Smoke-test every implemented module with random dummy data.
    
    For each module this prints input/output shapes and asserts the expected
    output shape (plus finiteness for the numerically delicate modules); the
    first failing check raises AssertionError. Finishes by printing parameter
    counts. Runs under ``torch.no_grad()`` because only forward results are
    checked, so no autograd graph is kept for the 128x128 activations.
    Inputs are unseeded: values (not shapes) differ between runs.
    """
    def _section(title: str) -> None:
        # Sub-header printed before each module's check.
        print("-" * 40)
        print(title)
        print("-" * 40)
    
    print("=" * 60)
    print("HANS-Net Module Validation Tests")
    print("=" * 60)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nUsing device: {device}\n")
    
    # Test parameters
    B, T, C, H, W = 2, 3, 1, 128, 128  # Batch, Time, Channels, Height, Width
    base_ch = 32
    
    # 1. WaveletDecomposition
    _section("1. WaveletDecomposition")
    wavelet = WaveletDecomposition().to(device)
    x_in = torch.randn(B, C, H, W, device=device)
    x_wav = wavelet(x_in)
    print(f"   Input shape:  {list(x_in.shape)}")
    print(f"   Output shape: {list(x_wav.shape)}")
    print(f"   Expected:     [B={B}, C*4={C*4}, H/2={H//2}, W/2={W//2}]")
    assert x_wav.shape == (B, C * 4, H // 2, W // 2), "Shape mismatch!"
    print("   ✓ PASSED\n")
    
    # 2. SynapticPlasticity
    _section("2. SynapticPlasticity")
    plasticity = SynapticPlasticity(channels=base_ch).to(device)
    x_in = torch.randn(B, base_ch, H, W, device=device)
    x_out = plasticity(x_in)
    print(f"   Input shape:  {list(x_in.shape)}")
    print(f"   Output shape: {list(x_out.shape)}")
    print(f"   Expected:     [B={B}, C={base_ch}, H={H}, W={W}]")
    assert x_out.shape == x_in.shape, "Shape mismatch!"
    print("   ✓ PASSED\n")
    
    # 3. PlasticConvBlock
    _section("3. PlasticConvBlock")
    plastic_conv = PlasticConvBlock(in_channels=base_ch, out_channels=base_ch * 2).to(device)
    x_in = torch.randn(B, base_ch, H, W, device=device)
    x_out = plastic_conv(x_in)
    print(f"   Input shape:  {list(x_in.shape)}")
    print(f"   Output shape: {list(x_out.shape)}")
    print(f"   Expected:     [B={B}, C={base_ch*2}, H={H}, W={W}]")
    assert x_out.shape == (B, base_ch * 2, H, W), "Shape mismatch!"
    print("   ✓ PASSED\n")
    
    # 4. Hyperbolic helpers: exp/log maps and their round trip
    _section("4. Hyperbolic Maps (exp_map_zero, log_map_zero)")
    c = torch.tensor(1.0, device=device)
    v = torch.randn(B, H, W, base_ch, device=device) * 0.1  # Small vectors
    
    v_hyp = exp_map_zero(v, c)
    print(f"   exp_map_zero input:  {list(v.shape)}")
    print(f"   exp_map_zero output: {list(v_hyp.shape)}")
    
    # Points must lie inside the unit Poincaré ball (c = 1)
    norms = v_hyp.norm(dim=-1)
    print(f"   Max norm (should be < 1): {norms.max().item():.4f}")
    assert norms.max() < 1.0, "Points outside Poincaré ball!"
    
    v_back = log_map_zero(v_hyp, c)
    print(f"   log_map_zero output: {list(v_back.shape)}")
    
    reconstruction_error = (v - v_back).abs().max().item()
    print(f"   Round-trip error: {reconstruction_error:.6f}")
    assert reconstruction_error < 1e-4, "Round-trip reconstruction failed!"
    print("   ✓ PASSED\n")
    
    # 5. HyperbolicConvBlock
    _section("5. HyperbolicConvBlock")
    hyp_conv = HyperbolicConvBlock(in_channels=base_ch, out_channels=base_ch * 2).to(device)
    x_in = torch.randn(B, base_ch, H // 4, W // 4, device=device) * 0.1
    x_out = hyp_conv(x_in)
    print(f"   Input shape:  {list(x_in.shape)}")
    print(f"   Output shape: {list(x_out.shape)}")
    print(f"   Expected:     [B={B}, C={base_ch*2}, H={H//4}, W={W//4}]")
    assert x_out.shape == (B, base_ch * 2, H // 4, W // 4), "Shape mismatch!"
    assert torch.isfinite(x_out).all(), "Non-finite values in HyperbolicConvBlock output!"
    # Report the curvature forward() actually uses (softplus, then clamped to [0.1, 10])
    effective_c = F.softplus(hyp_conv.curvature).clamp(min=0.1, max=10.0)
    print(f"   Learned curvature: {effective_c.item():.4f}")
    print("   ✓ PASSED\n")
    
    # 6. TemporalAttention
    _section("6. TemporalAttention")
    temp_attn = TemporalAttention(embed_dim=base_ch * 2, num_heads=4).to(device)
    x_in = torch.randn(B, T, base_ch * 2, H // 2, W // 2, device=device)
    x_out = temp_attn(x_in)
    print(f"   Input shape:  {list(x_in.shape)} (B, T, C, H, W)")
    print(f"   Output shape: {list(x_out.shape)} (B, C, H, W)")
    print(f"   Expected:     [B={B}, C={base_ch*2}, H={H//2}, W={W//2}]")
    assert x_out.shape == (B, base_ch * 2, H // 2, W // 2), "Shape mismatch!"
    assert torch.isfinite(x_out).all(), "Non-finite values in TemporalAttention output!"
    print("   ✓ PASSED\n")
    
    # 7. INRBranch
    _section("7. INRBranch")
    inr = INRBranch(feature_dim=base_ch, hidden_dim=128, num_frequencies=10).to(device)
    x_in = torch.randn(B, base_ch, H, W, device=device)
    x_out = inr(x_in)
    print(f"   Input shape:  {list(x_in.shape)}")
    print(f"   Output shape: {list(x_out.shape)}")
    print(f"   Expected:     [B={B}, 1, H={H}, W={W}]")
    assert x_out.shape == (B, 1, H, W), "Shape mismatch!"
    print("   ✓ PASSED\n")
    
    # Summary
    print("=" * 60)
    print("ALL TESTS PASSED! ✓")
    print("=" * 60)
    
    # Parameter counts
    print("\nParameter counts:")
    modules = {
        'WaveletDecomposition': wavelet,
        'SynapticPlasticity': plasticity,
        'PlasticConvBlock': plastic_conv,
        'HyperbolicConvBlock': hyp_conv,
        'TemporalAttention': temp_attn,
        'INRBranch': inr
    }
    
    total = 0
    for name, module in modules.items():
        params = sum(p.numel() for p in module.parameters())
        total += params
        print(f"   {name}: {params:,} parameters")
    print(f"   Total: {total:,} parameters")


# Run the tests: prints a PASSED line per module and raises AssertionError at the first failed check
test_all_modules()

## 7. HANSNet - Complete U-Net Architecture

The full HANS-Net model combining all components:
- **Encoder**: Wavelet decomposition + PlasticConvBlocks with downsampling
- **Temporal Attention**: Fuses T=3 slices at mid-level, outputs center slice features
- **Bottleneck**: HyperbolicConvBlock for hierarchical feature learning
- **Decoder**: Upsampling + skip connections from center slice encoder features
- **INR Refinement**: Boundary refinement via implicit neural representation

In [None]:
class HANSNet(nn.Module):
    """
    Hyperbolic Attention Network for Segmentation (HANS-Net).
    
    A U-Net style architecture for liver/tumor CT segmentation that combines:
    - Wavelet decomposition for multi-frequency feature extraction
    - Synaptic plasticity for adaptive feature learning
    - Temporal attention to fuse information from adjacent CT slices
    - Hyperbolic convolutions in the bottleneck for hierarchical representations
    - INR branch for boundary refinement
    
    Input:  [B, T=3, 1, H, W] - 3 consecutive CT slices. H and W must be
            multiples of 16 (wavelet 2x + three 2x max-pools) so that every
            decoder upsampling lines up with its encoder skip connection.
    Output: [B, num_classes, H, W] - Segmentation logits for the center slice
            ([B, 1, H, W] with the default num_classes=1).
    
    Args:
        base_channels: Channel width of the first encoder level; deeper levels
            use 2x, 4x and 8x this value.
        num_classes: Number of output channels of the coarse segmentation head.
        dropout_p: Dropout probability forwarded to every decoder
            PlasticConvBlock (default 0.3).
    """
    
    # Total spatial downsampling between the input and the bottleneck:
    # wavelet decomposition (2x) followed by three 2x max-pools.
    DOWNSAMPLE_FACTOR = 16
    
    def __init__(self, base_channels: int = 32, num_classes: int = 1, dropout_p: float = 0.3):
        super().__init__()
        
        self.base_channels = base_channels
        self.num_classes = num_classes
        self.dropout_p = dropout_p
        
        # Channel progression: base -> 2x -> 4x -> 8x (32/64/128/256 by default)
        c1 = base_channels
        c2 = base_channels * 2
        c3 = base_channels * 4
        c4 = base_channels * 8
        
        # =====================================================================
        # ENCODER (processes all T slices)
        # =====================================================================
        
        # Level 1: Wavelet decomposition + initial conv
        # Input: [B*T, 1, H, W] -> Wavelet: [B*T, 4, H/2, W/2] -> Conv: [B*T, c1, H/2, W/2]
        self.wavelet = WaveletDecomposition()
        self.enc1 = PlasticConvBlock(4, c1, use_plasticity=True)  # 4 wavelet subbands -> c1
        
        # Level 2: Downsample + conv
        # [B*T, c1, H/2, W/2] -> [B*T, c2, H/4, W/4]
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = PlasticConvBlock(c1, c2, use_plasticity=True)
        
        # Level 3: Downsample + conv (temporal attention applied here)
        # [B*T, c2, H/4, W/4] -> [B*T, c3, H/8, W/8]
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = PlasticConvBlock(c2, c3, use_plasticity=True)
        
        # =====================================================================
        # TEMPORAL ATTENTION (fuses T slices -> center slice only)
        # =====================================================================
        
        # Input: [B, T, c3, H/8, W/8] -> Output: [B, c3, H/8, W/8]
        self.temporal_attn = TemporalAttention(embed_dim=c3, num_heads=4)
        
        # =====================================================================
        # BOTTLENECK (hyperbolic convolution)
        # =====================================================================
        
        # [B, c3, H/8, W/8] -> [B, c4, H/16, W/16]
        self.pool3 = nn.MaxPool2d(2)
        self.bottleneck = HyperbolicConvBlock(c3, c4, curvature=1.0, learnable_curvature=True)
        
        # =====================================================================
        # DECODER (U-Net style with skip connections from center slice)
        # =====================================================================
        
        # Decoder Level 3: [B, c4, H/16, W/16] -> [B, c3, H/8, W/8]; skip from enc3
        self.up3 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.dec3 = PlasticConvBlock(c3 + c3, c3, use_plasticity=True, dropout_p=dropout_p)
        
        # Decoder Level 2: [B, c3, H/8, W/8] -> [B, c2, H/4, W/4]; skip from enc2
        self.up2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.dec2 = PlasticConvBlock(c2 + c2, c2, use_plasticity=True, dropout_p=dropout_p)
        
        # Decoder Level 1: [B, c2, H/4, W/4] -> [B, c1, H/2, W/2]; skip from enc1
        self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.dec1 = PlasticConvBlock(c1 + c1, c1, use_plasticity=True, dropout_p=dropout_p)
        
        # Final upsample to original resolution: [B, c1, H/2, W/2] -> [B, c1, H, W]
        self.final_up = nn.ConvTranspose2d(c1, c1, kernel_size=2, stride=2)
        self.final_conv = PlasticConvBlock(c1, c1, use_plasticity=True, dropout_p=dropout_p)
        
        # =====================================================================
        # OUTPUT HEADS
        # =====================================================================
        
        # Coarse segmentation head
        self.seg_head = nn.Conv2d(c1, num_classes, kernel_size=1)
        
        # INR refinement branch
        self.inr_branch = INRBranch(
            feature_dim=c1,
            hidden_dim=128,
            num_frequencies=10,
            num_layers=3
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of HANS-Net.
        
        Args:
            x: Input tensor [B, T=3, 1, H, W] - 3 consecutive CT slices,
               with H and W multiples of 16
        Returns:
            Segmentation logits [B, num_classes, H, W] for center slice
        Raises:
            AssertionError: if T != 3 or C != 1
            ValueError: if H or W is not a multiple of 16
        """
        B, T, C, H, W = x.shape
        assert T == 3, f"Expected T=3 slices, got T={T}"
        assert C == 1, f"Expected C=1 channel, got C={C}"
        
        # Other sizes either crash in torch.cat (skip/upsample mismatch) or
        # silently return a map that is not H x W, so fail early and clearly.
        factor = self.DOWNSAMPLE_FACTOR
        if H % factor != 0 or W % factor != 0:
            raise ValueError(
                f"Expected H and W to be multiples of {factor}, got H={H}, W={W}"
            )
        
        # =====================================================================
        # ENCODER - Process all T slices together
        # =====================================================================
        
        # Merge temporal and batch dimensions: [B*T, 1, H, W].
        # reshape (not view) so non-contiguous inputs (e.g. after permute) work.
        x_flat = x.reshape(B * T, C, H, W)
        
        # Level 1: [B*T, 1, H, W] -> [B*T, 4, H/2, W/2] -> [B*T, c1, H/2, W/2]
        e1 = self.enc1(self.wavelet(x_flat))
        
        # Level 2: [B*T, c1, H/2, W/2] -> [B*T, c2, H/4, W/4]
        e2 = self.enc2(self.pool1(e1))
        
        # Level 3: [B*T, c2, H/4, W/4] -> [B*T, c3, H/8, W/8]
        e3 = self.enc3(self.pool2(e2))
        
        # =====================================================================
        # TEMPORAL ATTENTION - Fuse slices, extract center
        # =====================================================================
        
        # [B*T, c3, H/8, W/8] -> [B, T, c3, H/8, W/8] -> [B, c3, H/8, W/8]
        e3_temporal = e3.reshape(B, T, *e3.shape[1:])
        f_center = self.temporal_attn(e3_temporal)
        
        # =====================================================================
        # BOTTLENECK - Hyperbolic convolution
        # =====================================================================
        
        # [B, c3, H/8, W/8] -> [B, c4, H/16, W/16]
        bottleneck = self.bottleneck(self.pool3(f_center))
        
        # =====================================================================
        # EXTRACT CENTER SLICE FEATURES FOR SKIP CONNECTIONS
        # =====================================================================
        
        center_idx = T // 2  # Index 1 for T=3
        
        # [B*T, c, h, w] -> [B, c, h, w] for the center slice only
        e1_center = e1.reshape(B, T, *e1.shape[1:])[:, center_idx]
        e2_center = e2.reshape(B, T, *e2.shape[1:])[:, center_idx]
        
        # Level-3 skip uses the temporally fused features ([B, c3, H/8, W/8])
        e3_center = f_center
        
        # =====================================================================
        # DECODER - U-Net style upsampling with skip connections
        # =====================================================================
        
        # [B, c4, H/16, W/16] -> [B, c3, H/8, W/8]
        d3 = self.up3(bottleneck)
        d3 = self.dec3(torch.cat([d3, e3_center], dim=1))  # concat: [B, c3+c3, H/8, W/8]
        
        # [B, c3, H/8, W/8] -> [B, c2, H/4, W/4]
        d2 = self.up2(d3)
        d2 = self.dec2(torch.cat([d2, e2_center], dim=1))  # concat: [B, c2+c2, H/4, W/4]
        
        # [B, c2, H/4, W/4] -> [B, c1, H/2, W/2]
        d1 = self.up1(d2)
        d1 = self.dec1(torch.cat([d1, e1_center], dim=1))  # concat: [B, c1+c1, H/2, W/2]
        
        # [B, c1, H/2, W/2] -> [B, c1, H, W]
        dec_out = self.final_conv(self.final_up(d1))
        
        # =====================================================================
        # OUTPUT - Coarse logits + INR refinement
        # =====================================================================
        
        # Coarse segmentation logits: [B, num_classes, H, W]
        coarse_logits = self.seg_head(dec_out)
        
        # INR boundary refinement. The INRBranch test checks a single-channel
        # [B, 1, H, W] output; for num_classes > 1 it is broadcast over all
        # class channels by the addition below.
        refine_logits = self.inr_branch(dec_out)
        
        return coarse_logits + refine_logits


# =============================================================================
# Test HANSNet
# =============================================================================

def test_hansnet():
    """
    Smoke-test the complete HANS-Net model.

    Builds HANSNet(base_channels=32, num_classes=1), prints parameter counts,
    and checks in eval mode that a [2, 3, 1, 128, 128] input yields a
    [2, 1, 128, 128] output. It also checks that inputs of side 64, 128 and
    256 map to outputs of the same spatial size. Raises AssertionError on any
    shape mismatch.

    Returns:
        The instantiated model, left in eval mode.
    """
    print("=" * 60)
    print("HANS-Net Complete Model Test")
    print("=" * 60)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nUsing device: {device}\n")
    
    # Instantiate model
    model = HANSNet(base_channels=32, num_classes=1).to(device)
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters:     {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print()
    
    # Create dummy input: [B=2, T=3, C=1, H=128, W=128]
    x_dummy = torch.randn(2, 3, 1, 128, 128, device=device)
    print(f"Input shape:  {list(x_dummy.shape)} [B, T, C, H, W]")
    
    # Forward pass
    model.eval()
    with torch.no_grad():
        output = model(x_dummy)
    
    print(f"Output shape: {list(output.shape)} [B, num_classes, H, W]")
    print(f"Expected:     [2, 1, 128, 128]")
    
    assert output.shape == (2, 1, 128, 128), f"Shape mismatch! Got {output.shape}"
    print("\n✓ HANS-Net forward pass successful!")
    
    # Test with different spatial sizes
    print("\n" + "-" * 40)
    print("Testing with different spatial sizes...")
    print("-" * 40)
    
    for size in [64, 128, 256]:
        x_test = torch.randn(1, 3, 1, size, size, device=device)
        with torch.no_grad():
            out_test = model(x_test)
        # Actually verify the shape before printing the check mark.
        assert out_test.shape == (1, 1, size, size), (
            f"Shape mismatch for size {size}! Got {out_test.shape}"
        )
        print(f"   Input [{1}, 3, 1, {size}, {size}] -> Output {list(out_test.shape)} ✓")
    
    print("\n" + "=" * 60)
    print("ALL TESTS PASSED!")
    print("=" * 60)
    
    return model


# Run the end-to-end HANSNet smoke test and keep the returned model instance.
hansnet_model = test_hansnet()

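### **Dropout Check**

Runs several forward passes on the same input in train mode; with dropout active, the outputs should differ.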

In [None]:
model = HANSNet().cpu()
model.train()  # VERY IMPORTANT — dropout active

x = torch.randn(1, 3, 1, 128, 128)

# Only the outputs are compared, so skip autograd bookkeeping for the passes.
with torch.no_grad():
    preds = [model(x).numpy() for _ in range(4)]

# Compare every consecutive pair of samples; with dropout active all should be False.
print("Outputs equal?",
      *[(preds[i] == preds[i + 1]).all() for i in range(len(preds) - 1)])

In [None]:
model = HANSNet().cpu()
x = torch.randn(1, 3, 1, 128, 128)

from torch.nn.functional import sigmoid

model.train()  # activate dropout

# MC sampling needs no gradients; no_grad avoids keeping 8 autograd graphs alive.
with torch.no_grad():
    samples = torch.stack([sigmoid(model(x)) for _ in range(8)])  # [8, 1, 1, 128, 128]

mean_map = samples.mean(dim=0)  # [1, 1, 128, 128]
var_map = samples.var(dim=0)    # [1, 1, 128, 128]

print("Mean:", mean_map.shape)
print("Variance:", var_map.shape)
print("Mean-variance sum:", var_map.sum().item())

### **Sanity Check (For NaNs)**

In [None]:
model = HANSNet().cpu()
model.train()

x = torch.randn(1, 3, 1, 128, 128)

# Inspection only (no backward pass follows), so don't build an autograd graph.
with torch.no_grad():
    logits = model(x)

print("Any NaNs in logits?", torch.isnan(logits).any().item())
print("Any Infs in logits?", torch.isinf(logits).any().item())


In [None]:
model = HANSNet().cpu()
model.train()  # activate dropout

# Simulate normalized CT-like input (0–1)
x = torch.rand(1, 3, 1, 128, 128)

# MC sampling needs no gradients; no_grad avoids keeping 8 autograd graphs alive.
with torch.no_grad():
    samples = torch.stack([torch.sigmoid(model(x)) for _ in range(8)], dim=0)  # [8,1,1,128,128]

mean_map = samples.mean(dim=0)  # [1,1,128,128]
var_map = samples.var(dim=0)    # [1,1,128,128]

print("Mean:", mean_map.shape)
print("Variance:", var_map.shape)
print("Any NaNs in mean?", torch.isnan(mean_map).any().item())
print("Any NaNs in var?", torch.isnan(var_map).any().item())
print("Mean-variance sum:", var_map.sum().item())


In [None]:
# =============================================================================
# Comprehensive NaN/Stability Test After Fix
# =============================================================================

def test_stability():
    """
    Check HyperbolicConvBlock and HANSNet for NaN/Inf values on CPU.

    Runs five checks:
      1. HyperbolicConvBlock(128, 256) on random input,
      2. one HANSNet forward pass in eval mode,
      3. ten HANSNet forward passes in train mode (dropout active),
      4. MC-dropout mean/variance maps over 8 sigmoid samples,
      5. gradients of a BCE-with-logits loss after one backward pass.

    Every check raises AssertionError on failure, so the final
    "ALL STABILITY TESTS PASSED!" banner is printed only if all of them passed.
    """
    print("=" * 60)
    print("Stability Test - Checking for NaN/Inf values")
    print("=" * 60)
    
    device = torch.device('cpu')
    
    # Test 1: HyperbolicConvBlock in isolation
    print("\n1. Testing HyperbolicConvBlock isolation...")
    hyp = HyperbolicConvBlock(128, 256).to(device)
    x_hyp = torch.randn(2, 128, 8, 8, device=device)
    
    out_hyp = hyp(x_hyp)
    print(f"   Input range: [{x_hyp.min():.3f}, {x_hyp.max():.3f}]")
    print(f"   Output range: [{out_hyp.min():.3f}, {out_hyp.max():.3f}]")
    print(f"   NaN in output: {torch.isnan(out_hyp).any().item()}")
    print(f"   Inf in output: {torch.isinf(out_hyp).any().item()}")
    assert not torch.isnan(out_hyp).any(), "NaN in HyperbolicConvBlock!"
    assert not torch.isinf(out_hyp).any(), "Inf in HyperbolicConvBlock!"
    print("   ✓ PASSED")
    
    # Test 2: Full HANSNet forward pass
    print("\n2. Testing HANSNet single forward pass...")
    model = HANSNet(base_channels=32).to(device)
    model.eval()
    
    x = torch.randn(1, 3, 1, 128, 128, device=device)
    with torch.no_grad():
        out = model(x)
    
    print(f"   Input range: [{x.min():.3f}, {x.max():.3f}]")
    print(f"   Output range: [{out.min():.3f}, {out.max():.3f}]")
    print(f"   NaN in output: {torch.isnan(out).any().item()}")
    print(f"   Inf in output: {torch.isinf(out).any().item()}")
    assert not torch.isnan(out).any(), "NaN in HANSNet output!"
    assert not torch.isinf(out).any(), "Inf in HANSNet output!"
    print("   ✓ PASSED")
    
    # Test 3: Multiple forward passes with dropout (train mode).
    # A failure must stop the run; it must not be reported and then "pass".
    print("\n3. Testing HANSNet with dropout (train mode)...")
    model.train()
    
    with torch.no_grad():
        for i in range(10):
            x = torch.randn(1, 3, 1, 128, 128, device=device)
            out = model(x)
            assert not torch.isnan(out).any(), f"NaN found at iteration {i+1}!"
            assert not torch.isinf(out).any(), f"Inf found at iteration {i+1}!"
    
    print("   No NaN in 10 forward passes")
    print("   ✓ PASSED")
    
    # Test 4: MC Dropout uncertainty estimation
    print("\n4. Testing MC Dropout uncertainty estimation...")
    model.train()
    
    x = torch.rand(1, 3, 1, 128, 128, device=device)  # Normalized input [0, 1]
    with torch.no_grad():
        samples = torch.stack([torch.sigmoid(model(x)) for _ in range(8)], dim=0)
    mean_map = samples.mean(dim=0)
    var_map = samples.var(dim=0)
    
    print(f"   Mean range: [{mean_map.min():.4f}, {mean_map.max():.4f}]")
    print(f"   Variance range: [{var_map.min():.6f}, {var_map.max():.6f}]")
    print(f"   NaN in mean: {torch.isnan(mean_map).any().item()}")
    print(f"   NaN in variance: {torch.isnan(var_map).any().item()}")
    print(f"   Variance sum: {var_map.sum().item():.4f}")
    
    assert not torch.isnan(mean_map).any(), "NaN in mean!"
    assert not torch.isnan(var_map).any(), "NaN in variance!"
    print("   ✓ PASSED")
    
    # Test 5: Gradient flow check
    print("\n5. Testing gradient flow...")
    model.train()
    model.zero_grad(set_to_none=True)
    x = torch.randn(1, 3, 1, 64, 64, device=device)
    target = torch.randint(0, 2, (1, 1, 64, 64), device=device).float()
    
    out = model(x)
    loss = F.binary_cross_entropy_with_logits(out, target)
    loss.backward()
    
    grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None}
    # No gradients at all would mean the graph is disconnected: also a failure.
    assert grads, "No parameter received a gradient!"
    for name, grad in grads.items():
        assert not torch.isnan(grad).any(), f"NaN gradient in {name}!"
        assert not torch.isinf(grad).any(), f"Inf gradient in {name}!"
    
    print("   No NaN gradients found")
    print("   ✓ PASSED")
    
    print("\n" + "=" * 60)
    print("ALL STABILITY TESTS PASSED!")
    print("=" * 60)

# Run the NaN/Inf stability checks defined above (all on CPU).
test_stability()

## 8. Loss Functions & MC-Dropout Inference Utilities

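Helper functions for training and uncertainty-aware inference:
- **dice_loss**: Soft Dice loss for segmentation training
- **combined_loss**: Weighted sum of BCE-with-logits and Dice loss
- **mc_predict** / **mc_predict_with_samples**: Monte Carlo Dropout for uncertainty estimation (the latter also returns every individual sample)
- **visualize_uncertainty** / **visualize_mc_samples**: Simple visualization of predictions, uncertainty maps, and individual MC samples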

In [None]:
def dice_loss(probs: torch.Tensor, targets: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    Compute soft Dice loss for binary segmentation.
    
    Dice coefficient measures the overlap between predicted probabilities
    and ground truth masks. Dice loss = 1 - Dice coefficient.
    
    Args:
        probs: Model output after sigmoid, shape [B, ...] (typically [B, 1, H, W]),
            values in [0, 1]. May be non-contiguous.
        targets: Ground-truth binary mask with the same shape as ``probs``,
            values in {0, 1}. Float, integer or bool dtype (cast to ``probs.dtype``).
        eps: Small constant for numerical stability. It also makes a sample
            where both prediction and mask are empty count as a perfect match.
    
    Returns:
        Scalar tensor: mean Dice loss over the batch
    
    Raises:
        AssertionError: if ``probs`` and ``targets`` have different shapes.
    
    Formula:
        dice_coeff = (2 * intersection + eps) / (sum_probs + sum_targets + eps)
        dice_loss = 1 - dice_coeff
    """
    assert probs.shape == targets.shape, f"Shape mismatch: {probs.shape} vs {targets.shape}"
    
    B = probs.shape[0]
    
    # Flatten everything but the batch dim: [B, ...] -> [B, N].
    # reshape (not view) so non-contiguous inputs (e.g. transposed/permuted) work.
    probs_flat = probs.reshape(B, -1)
    targets_flat = targets.reshape(B, -1).to(probs_flat.dtype)
    
    # Per-sample intersection and sums: each [B]
    intersection = (probs_flat * targets_flat).sum(dim=1)
    sum_probs = probs_flat.sum(dim=1)
    sum_targets = targets_flat.sum(dim=1)
    
    dice_coeff = (2.0 * intersection + eps) / (sum_probs + sum_targets + eps)  # [B]
    
    # Dice loss = 1 - dice_coeff, averaged over batch
    return 1.0 - dice_coeff.mean()


def combined_loss(logits: torch.Tensor, targets: torch.Tensor, 
                  bce_weight: float = 0.5, dice_weight: float = 0.5) -> torch.Tensor:
    """
    Combined BCE + Dice loss for segmentation.
    
    Args:
        logits: Raw model output (before sigmoid), shape [B, 1, H, W]
        targets: Ground-truth binary mask, shape [B, 1, H, W]. Float, integer
            or bool dtype; cast to ``logits.dtype`` before use.
        bce_weight: Weight for BCE loss
        dice_weight: Weight for Dice loss
    
    Returns:
        Scalar tensor: bce_weight * BCE + dice_weight * Dice
    """
    # binary_cross_entropy_with_logits needs floating targets matching the
    # logits' dtype; casting lets callers pass raw bool/int masks.
    targets = targets.to(dtype=logits.dtype)
    
    # BCE with logits (numerically stable)
    bce = F.binary_cross_entropy_with_logits(logits, targets)
    
    # Dice loss (needs probabilities)
    probs = torch.sigmoid(logits)
    dice = dice_loss(probs, targets)
    
    return bce_weight * bce + dice_weight * dice

In [None]:
def _mc_dropout_samples(model: nn.Module, x: torch.Tensor, n_samples: int,
                        dropout_only: bool) -> torch.Tensor:
    """Run ``n_samples`` stochastic forward passes and stack their sigmoid outputs.

    Returns a tensor of shape [n_samples, *model(x).shape]. Every module's
    ``training`` flag is restored exactly afterwards, even if a forward pass raises.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    # Snapshot per-module modes: one model-level flag cannot restore mixed states
    # (e.g. a submodule deliberately kept in eval inside a training model).
    saved_modes = [(module, module.training) for module in model.modules()]
    try:
        if dropout_only:
            # Keep everything else (e.g. normalization layers) in eval mode and
            # make only dropout layers stochastic.
            model.eval()
            for module in model.modules():
                if isinstance(module, nn.modules.dropout._DropoutNd):
                    module.train()
        else:
            # Default: the whole model in train mode.
            model.train()

        with torch.no_grad():
            draws = [torch.sigmoid(model(x)) for _ in range(n_samples)]
    finally:
        for module, mode in saved_modes:
            module.training = mode

    return torch.stack(draws, dim=0)


def _mc_mean_var(samples: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Mean and unbiased variance over dim 0; variance is zero for a single sample."""
    mean_probs = samples.mean(dim=0)
    if samples.shape[0] > 1:
        var_probs = samples.var(dim=0)
    else:
        # The unbiased variance of one sample is undefined (NaN); report zero spread.
        var_probs = torch.zeros_like(mean_probs)
    return mean_probs, var_probs


def mc_predict(model: nn.Module, x: torch.Tensor, n_samples: int = 10,
               dropout_only: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Monte Carlo Dropout prediction for uncertainty estimation.
    
    Performs multiple stochastic forward passes with dropout enabled
    to estimate prediction uncertainty (epistemic uncertainty).
    
    Args:
        model: HANSNet instance (or any model with Dropout layers)
        x: Input tensor [B, T=3, 1, H, W] (3 consecutive CT slices)
        n_samples: Number of stochastic forward passes (default: 10, must be >= 1)
        dropout_only: If False (default), the whole model is put in train() mode
            during sampling. If True, the model is put in eval() mode and only
            torch.nn dropout modules are switched to train(), so layers such as
            BatchNorm keep using (and do not update) their running statistics.
    
    Returns:
        mean_probs: Mean predicted probabilities [B, 1, H, W]
        var_probs: Variance of predictions (uncertainty) [B, 1, H, W];
            all zeros when n_samples == 1
    
    Raises:
        ValueError: if n_samples < 1
    
    Usage:
        >>> model = HANSNet(base_channels=32)
        >>> x = torch.rand(1, 3, 1, 128, 128)
        >>> mean_probs, var_probs = mc_predict(model, x, n_samples=10)
        >>> # High variance regions indicate model uncertainty
    
    Note:
        - Gradients are disabled during sampling for efficiency
        - Every submodule's train/eval mode is restored afterwards
        - Higher n_samples gives more accurate uncertainty estimates but takes longer
    """
    samples = _mc_dropout_samples(model, x, n_samples, dropout_only)
    return _mc_mean_var(samples)


def mc_predict_with_samples(model: nn.Module, x: torch.Tensor, n_samples: int = 10,
                            dropout_only: bool = False
                            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Extended MC Dropout that also returns all individual samples.
    
    Useful for more detailed uncertainty analysis or custom aggregation.
    
    Args:
        model: HANSNet instance
        x: Input tensor [B, T=3, 1, H, W]
        n_samples: Number of stochastic forward passes (must be >= 1)
        dropout_only: Same meaning as in ``mc_predict``
    
    Returns:
        mean_probs: Mean predicted probabilities [B, 1, H, W]
        var_probs: Variance of predictions [B, 1, H, W] (zeros when n_samples == 1)
        all_samples: All prediction samples [n_samples, B, 1, H, W]
    
    Raises:
        ValueError: if n_samples < 1
    """
    samples = _mc_dropout_samples(model, x, n_samples, dropout_only)
    mean_probs, var_probs = _mc_mean_var(samples)
    return mean_probs, var_probs, samples

In [None]:
def visualize_uncertainty(ct_slice: torch.Tensor, 
                          mean_probs: torch.Tensor, 
                          var_probs: torch.Tensor,
                          threshold: float = 0.5,
                          figsize: Tuple[int, int] = (15, 4)) -> None:
    """
    Visualize CT slice, predicted segmentation, and uncertainty map.
    
    Creates a 1x4 subplot showing:
    - Original CT slice (grayscale)
    - Mean prediction probability (segmentation)
    - Binary prediction (thresholded)
    - Variance map (uncertainty heatmap)
    
    Args:
        ct_slice: CT image, shape [H, W] or [1, H, W] or [1, 1, H, W]
        mean_probs: Mean predicted probabilities [1, H, W] or [1, 1, H, W]
        var_probs: Variance (uncertainty) [1, H, W] or [1, 1, H, W]
        threshold: Threshold for binary prediction (default: 0.5)
        figsize: Figure size (width, height)
    
    Note:
        - Inputs may be tensors (any device, any dtype including half/bfloat16)
          or array-likes; they are converted to float32 numpy arrays on CPU
        - Requires matplotlib (imported inside function)
    """
    import matplotlib.pyplot as plt
    
    def to_numpy(t):
        # torch.as_tensor also accepts numpy arrays / lists. .float() avoids the
        # numpy conversion error for bfloat16 and makes bool masks displayable.
        t = torch.as_tensor(t).detach().cpu().float()
        return t.squeeze().numpy()
    
    ct_np = to_numpy(ct_slice)
    mean_np = to_numpy(mean_probs)
    var_np = to_numpy(var_probs)
    
    # Create binary prediction
    binary_pred = (mean_np > threshold).astype(float)
    
    fig, axes = plt.subplots(1, 4, figsize=figsize)
    
    # CT slice
    axes[0].imshow(ct_np, cmap='gray')
    axes[0].set_title('CT Slice (Input)')
    axes[0].axis('off')
    
    # Mean prediction (probability)
    im1 = axes[1].imshow(mean_np, cmap='hot', vmin=0, vmax=1)
    axes[1].set_title('Mean Probability')
    axes[1].axis('off')
    fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
    
    # Binary prediction
    axes[2].imshow(binary_pred, cmap='gray', vmin=0, vmax=1)
    axes[2].set_title(f'Binary Pred (τ={threshold})')
    axes[2].axis('off')
    
    # Variance (uncertainty)
    im3 = axes[3].imshow(var_np, cmap='viridis')
    axes[3].set_title('Uncertainty (Variance)')
    axes[3].axis('off')
    fig.colorbar(im3, ax=axes[3], fraction=0.046, pad=0.04)
    
    fig.tight_layout()
    plt.show()


def visualize_mc_samples(ct_slice: torch.Tensor,
                         samples: torch.Tensor,
                         n_show: int = 4,
                         figsize: Tuple[int, int] = (16, 4)) -> None:
    """
    Visualize individual MC Dropout samples alongside the CT slice.
    
    Args:
        ct_slice: CT image, shape [H, W] or [1, H, W] (tensor or array-like)
        samples: MC samples [n_samples, 1, 1, H, W] or [n_samples, 1, H, W]
        n_show: Number of samples to display (default: 4); clamped to
            [0, samples.shape[0]]
        figsize: Figure size
    """
    import matplotlib.pyplot as plt
    
    def to_numpy(t):
        # torch.as_tensor also accepts numpy arrays / lists. .float() avoids the
        # numpy conversion error for bfloat16 tensors.
        t = torch.as_tensor(t).detach().cpu().float()
        return t.squeeze().numpy()
    
    ct_np = to_numpy(ct_slice)
    n_samples = max(0, min(n_show, samples.shape[0]))
    
    # squeeze=False keeps `axes` indexable even when only the CT panel is drawn
    # (plt.subplots would otherwise return a bare Axes for a single panel).
    fig, axes = plt.subplots(1, n_samples + 1, figsize=figsize, squeeze=False)
    axes = axes[0]
    
    # CT slice
    axes[0].imshow(ct_np, cmap='gray')
    axes[0].set_title('CT Slice')
    axes[0].axis('off')
    
    # Individual samples
    for i in range(n_samples):
        sample_np = to_numpy(samples[i])
        axes[i + 1].imshow(sample_np, cmap='hot', vmin=0, vmax=1)
        axes[i + 1].set_title(f'Sample {i + 1}')
        axes[i + 1].axis('off')
    
    fig.tight_layout()
    plt.show()

### MC-Dropout Demo & Validation

Test the MC-Dropout inference pipeline with a dummy input to verify:
- Correct output shapes
- No NaN/Inf values
- Variance is non-zero (a warning is printed otherwise, since zero variance suggests dropout is inactive)

In [None]:
# =============================================================================
# MC-Dropout Demo & Validation
# =============================================================================

def test_mc_dropout():
    """
    Exercise the MC-Dropout inference pipeline and the loss functions.

    Builds HANSNet(base_channels=32) and runs mc_predict with 8 samples on a
    uniform [0, 1) input of shape [1, 3, 1, 128, 128]. It then checks:
    - the mean/variance shapes,
    - that no NaN/Inf values appear,
    - that probabilities lie in [0, 1] and variances are >= 0.
    It warns if the variance is zero, and finally evaluates dice_loss and
    combined_loss. Raises AssertionError on any failed check.

    Returns:
        (model, x, mean_probs, var_probs); the model is left in eval mode.
    """
    print("=" * 60)
    print("MC-Dropout Inference Demo")
    print("=" * 60)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nUsing device: {device}\n")
    
    # Create model
    model = HANSNet(base_channels=32).to(device)
    print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")
    
    # Create dummy CT-like input (normalized to [0, 1])
    x = torch.rand(1, 3, 1, 128, 128, device=device)
    print(f"\nInput shape: {list(x.shape)} [B, T, C, H, W]")
    print(f"Input range: [{x.min():.3f}, {x.max():.3f}]")
    
    # Run MC-Dropout prediction
    print("\n" + "-" * 40)
    print("Running MC-Dropout with n_samples=8...")
    print("-" * 40)
    
    n_samples = 8
    mean_probs, var_probs = mc_predict(model, x, n_samples=n_samples)
    
    # Check shapes
    print(f"\nMean probs shape: {list(mean_probs.shape)}")
    print(f"Var probs shape:  {list(var_probs.shape)}")
    
    expected_shape = (1, 1, 128, 128)
    assert mean_probs.shape == expected_shape, f"Mean shape mismatch: {mean_probs.shape}"
    assert var_probs.shape == expected_shape, f"Var shape mismatch: {var_probs.shape}"
    print("✓ Shapes correct")
    
    # Check for NaN/Inf
    print(f"\nNaN in mean_probs: {torch.isnan(mean_probs).any().item()}")
    print(f"NaN in var_probs:  {torch.isnan(var_probs).any().item()}")
    print(f"Inf in mean_probs: {torch.isinf(mean_probs).any().item()}")
    print(f"Inf in var_probs:  {torch.isinf(var_probs).any().item()}")
    
    assert not torch.isnan(mean_probs).any(), "NaN in mean_probs!"
    assert not torch.isnan(var_probs).any(), "NaN in var_probs!"
    assert not torch.isinf(mean_probs).any(), "Inf in mean_probs!"
    assert not torch.isinf(var_probs).any(), "Inf in var_probs!"
    print("✓ No NaN/Inf values")
    
    # Check value ranges (and actually enforce them, not just print them)
    print(f"\nMean probs range: [{mean_probs.min():.4f}, {mean_probs.max():.4f}]")
    print(f"Var probs range:  [{var_probs.min():.6f}, {var_probs.max():.6f}]")
    print(f"Var probs sum:    {var_probs.sum().item():.4f}")
    
    # Sigmoid outputs, and therefore their mean, must lie in [0, 1]
    assert ((mean_probs >= 0) & (mean_probs <= 1)).all(), "mean_probs outside [0, 1]!"
    assert (var_probs >= 0).all(), "Negative values in var_probs!"
    
    # Variance should be > 0 (dropout is introducing randomness)
    if var_probs.sum() > 0:
        print("✓ Variance > 0 (dropout is working)")
    else:
        print("⚠ Warning: Variance is 0 - dropout might not be active")
    
    # Test dice_loss
    print("\n" + "-" * 40)
    print("Testing dice_loss...")
    print("-" * 40)
    
    # Create dummy target
    target = (torch.rand(1, 1, 128, 128, device=device) > 0.7).float()
    
    loss = dice_loss(mean_probs, target)
    print(f"Dice loss: {loss.item():.4f}")
    assert not torch.isnan(loss), "NaN in dice_loss!"
    assert 0 <= loss.item() <= 1, f"Dice loss out of range: {loss.item()}"
    print("✓ Dice loss computed correctly")
    
    # Test combined_loss
    print("\n" + "-" * 40)
    print("Testing combined_loss...")
    print("-" * 40)
    
    # Need logits for combined loss
    model.eval()
    with torch.no_grad():
        logits = model(x)
    
    comb_loss = combined_loss(logits, target)
    print(f"Combined loss (BCE + Dice): {comb_loss.item():.4f}")
    assert not torch.isnan(comb_loss), "NaN in combined_loss!"
    print("✓ Combined loss computed correctly")
    
    print("\n" + "=" * 60)
    print("ALL MC-DROPOUT TESTS PASSED!")
    print("=" * 60)
    
    return model, x, mean_probs, var_probs


# Run the MC-dropout demo; keep the model, input, and mean/variance maps it returns.
mc_model, mc_input, mc_mean, mc_var = test_mc_dropout()