# üß† EGM-Net: Energy-Gated Gabor Mamba Network

**Medical Image Segmentation with Implicit Neural Representations**

## 1Ô∏è‚É£ Setup

In [1]:
# Install dependencies
%pip install -q torch torchvision numpy matplotlib tqdm gdown nibabel scikit-image monai

[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m2.7/2.7 MB[0m [31m71.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[?25h

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os, glob, json

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"üñ•Ô∏è Device: {device}")
if device == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

üñ•Ô∏è Device: cuda
   GPU: Tesla T4


---
## 2Ô∏è‚É£ Model Architecture

All model code is defined inline below.

### 2.1 Mamba Block

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional

class DepthwiseSeparableConv2d(nn.Module):

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size=kernel_size,
            padding=kernel_size // 2, groups=in_channels, bias=False
        )
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x

class DirectionalScanner(nn.Module):

    def __init__(self, channels: int, scan_dim: int = 64):
        
        super().__init__()
        self.channels = channels
        self.scan_dim = scan_dim
        
        # Learnable projection to scan_dim for each direction
        self.proj_in = nn.Linear(channels, scan_dim)
        
        # GRU cell for sequential state processing (simulates SSM)
        self.gru_cell = nn.GRUCell(scan_dim, scan_dim)
        
        # Project back to original channels
        self.proj_out = nn.Linear(scan_dim, channels)
        
    def _scan_direction(self, x: torch.Tensor, direction: str) -> torch.Tensor:
        
        B, C, H, W = x.shape
        
        # Prepare sequence based on direction
        if direction == "right":
            # Scan left-to-right: (B, H*W, C) after reshape
            x = x.permute(0, 2, 3, 1).reshape(B * H, W, C)  # (B*H, W, C)
        elif direction == "down":
            # Scan top-to-bottom
            x = x.permute(0, 3, 2, 1).reshape(B * W, H, C)  # (B*W, H, C)
        elif direction == "left":
            # Scan right-to-left (reverse)
            x = x.permute(0, 2, 3, 1).flip(1).reshape(B * H, W, C)  # (B*H, W, C)
        elif direction == "up":
            # Scan bottom-to-top (reverse)
            x = x.permute(0, 3, 2, 1).flip(1).reshape(B * W, H, C)  # (B*W, H, C)
        else:
            raise ValueError(f"Unknown direction: {direction}")
        
        # Project to scan dimension
        x = self.proj_in(x)  # (*, W/H, scan_dim)
        
        # Apply GRU cell sequentially (simulates SSM forward pass)
        outputs = []
        h = torch.zeros(x.shape[0], self.scan_dim, device=x.device, dtype=x.dtype)
        
        for t in range(x.shape[1]):
            h = self.gru_cell(x[:, t], h)  # GRU step
            outputs.append(h)
        
        x = torch.stack(outputs, dim=1)  # (*, W/H, scan_dim)
        
        # Project back to original channels
        x = self.proj_out(x)  # (*, W/H, C)
        
        # Reshape back to (B, C, H, W)
        if direction == "right":
            x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
        elif direction == "down":
            x = x.reshape(B, W, H, C).permute(0, 3, 2, 1)
        elif direction == "left":
            x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).flip(-1)
        elif direction == "up":
            x = x.reshape(B, W, H, C).permute(0, 3, 2, 1).flip(-2)
        
        return x
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        
        # Scan in all 4 directions
        scan_right = self._scan_direction(x, "right")
        scan_down = self._scan_direction(x, "down")
        scan_left = self._scan_direction(x, "left")
        scan_up = self._scan_direction(x, "up")
        
        # Aggregate by averaging
        output = (scan_right + scan_down + scan_left + scan_up) / 4.0
        
        return output

class VSSBlock(nn.Module):

    def __init__(self, channels: int, hidden_dim: Optional[int] = None, 
                 scan_dim: int = 64, expansion_ratio: float = 2.0):
        
        super().__init__()
        self.channels = channels
        hidden_dim = hidden_dim or int(channels * expansion_ratio)
        
        # Preprocessing: expand channels
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
        self.conv_expand = nn.Conv2d(channels, hidden_dim, kernel_size=1, bias=True)
        
        # Directional scanning
        self.scanner = DirectionalScanner(hidden_dim, scan_dim=scan_dim)
        
        # Postprocessing: contract channels back
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=hidden_dim, eps=1e-6)
        self.conv_contract = nn.Conv2d(hidden_dim, channels, kernel_size=1, bias=True)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        
        residual = x
        
        # Preprocessing
        x = self.norm1(x)
        x = self.conv_expand(x)
        x = F.gelu(x)
        
        # Directional scanning (core SSM-like operation)
        x = self.scanner(x)
        
        # Postprocessing
        x = self.norm2(x)
        x = self.conv_contract(x)
        
        # Residual connection
        output = x + residual
        
        return output

class MambaBlockStack(nn.Module):

    def __init__(self, channels: int, depth: int = 2, **kwargs):
        
        super().__init__()
        self.blocks = nn.ModuleList([
            VSSBlock(channels, **kwargs) for _ in range(depth)
        ])
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x

if __name__ == "__main__":
    # Test VSSBlock
    batch_size, channels, height, width = 2, 64, 64, 64
    x = torch.randn(batch_size, channels, height, width)
    
    vss_block = VSSBlock(channels, scan_dim=32)
    output = vss_block(x)
    
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Module parameters: {sum(p.numel() for p in vss_block.parameters())}")


Input shape: torch.Size([2, 64, 64, 64])
Output shape: torch.Size([2, 64, 64, 64])
Module parameters: 31648


### 2.2 Spectral Layers

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

class SpectralGating(nn.Module):

    def __init__(self, channels: int, height: int, width: int, 
                 threshold: float = 0.1, complex_init: str = "kaiming"):
        
        super().__init__()
        self.channels = channels
        self.height = height
        self.width = width
        self.threshold = threshold
        
        # Create learnable complex weights for frequency domain
        # Shape: (channels, height, width//2 + 1) for rfft2
        self.register_buffer(
            "freq_shape",
            torch.tensor([channels, height, width // 2 + 1], dtype=torch.long)
        )
        
        # Real and Imaginary parts of complex weights
        self.weight_real = nn.Parameter(
            torch.zeros(channels, height, width // 2 + 1)
        )
        self.weight_imag = nn.Parameter(
            torch.zeros(channels, height, width // 2 + 1)
        )
        
        self._init_weights(complex_init)
        
    def _init_weights(self, strategy: str = "kaiming"):
        
        if strategy == "identity":
            # Initialize close to identity (magnitude ~1, phase ~0)
            nn.init.ones_(self.weight_real)
            nn.init.zeros_(self.weight_imag)
        elif strategy == "kaiming":
            # Kaiming initialization adapted for complex numbers
            fan_in = self.height * (self.width // 2 + 1)
            std = (2.0 / fan_in) ** 0.5
            nn.init.normal_(self.weight_real, 0, std)
            nn.init.normal_(self.weight_imag, 0, std)
        else:
            raise ValueError(f"Unknown init strategy: {strategy}")
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        
        B, C, H, W = x.shape
        
        # Apply FFT to convert to frequency domain
        # rfft2 returns complex tensor
        x_freq = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
        
        # Create complex weight matrix: weight_real + 1j * weight_imag
        # Reshape to (1, C, H, W//2+1) for broadcasting
        complex_weight = (
            self.weight_real.unsqueeze(0) + 
            1j * self.weight_imag.unsqueeze(0)
        )
        
        # Apply channel-wise multiplication in frequency domain
        # Shape: (B, C, H, W//2+1) * (1, C, H, W//2+1) -> (B, C, H, W//2+1)
        x_filtered = x_freq * complex_weight
        
        # Optional: Hard thresholding to remove low-amplitude noise
        if self.threshold > 0:
            magnitude = torch.abs(x_filtered)
            mask = magnitude > self.threshold
            x_filtered = x_filtered * mask.float()
        
        # Apply inverse FFT to return to spatial domain
        output = torch.fft.irfft2(x_filtered, s=(H, W), dim=(-2, -1), norm="ortho")
        
        return output

class FrequencyLoss(nn.Module):

    def __init__(self, weight: float = 0.1):
        
        super().__init__()
        self.weight = weight
    
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        
        # Apply FFT
        pred_freq = torch.fft.rfft2(pred, dim=(-2, -1), norm="ortho")
        target_freq = torch.fft.rfft2(target, dim=(-2, -1), norm="ortho")
        
        # Compute L2 distance in frequency domain
        # Using both magnitude and phase information
        loss_real = F.mse_loss(pred_freq.real, target_freq.real)
        loss_imag = F.mse_loss(pred_freq.imag, target_freq.imag)
        
        return loss_real + loss_imag

if __name__ == "__main__":
    # Test SpectralGating
    batch_size, channels, height, width = 2, 64, 64, 64
    x = torch.randn(batch_size, channels, height, width)
    
    spec_gate = SpectralGating(channels, height, width)
    output = spec_gate(x)
    
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Module parameters: {sum(p.numel() for p in spec_gate.parameters())}")


Input shape: torch.Size([2, 64, 64, 64])
Output shape: torch.Size([2, 64, 64, 64])
Module parameters: 270336


### 2.3 Monogenic Signal Processing

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional

class RieszTransform(nn.Module):

    def __init__(self, epsilon: float = 1e-8):
        
        super().__init__()
        self.epsilon = epsilon
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        
        B, C, H, W = x.shape
        
        # Create frequency grid
        freq_y = torch.fft.fftfreq(H, device=x.device, dtype=x.dtype)
        freq_x = torch.fft.fftfreq(W, device=x.device, dtype=x.dtype)
        freq_y, freq_x = torch.meshgrid(freq_y, freq_x, indexing='ij')
        
        # Compute radial frequency (avoid division by zero)
        radius = torch.sqrt(freq_x**2 + freq_y**2 + self.epsilon)
        
        # Riesz kernels in frequency domain
        # H1 = -j * u / |w|, H2 = -j * v / |w|
        kernel_x = freq_x / radius
        kernel_y = freq_y / radius
        
        # Set DC component to zero
        kernel_x[0, 0] = 0
        kernel_y[0, 0] = 0
        
        # Apply FFT to input
        x_fft = torch.fft.fft2(x)
        
        # Apply Riesz kernels (multiplication by -j in frequency = Hilbert-like)
        # -j * X = real(X) * (-j) + imag(X) * (-j) * j = imag(X) - j*real(X)
        riesz_x_fft = -1j * x_fft * kernel_x.unsqueeze(0).unsqueeze(0)
        riesz_y_fft = -1j * x_fft * kernel_y.unsqueeze(0).unsqueeze(0)
        
        # Inverse FFT
        riesz_x = torch.fft.ifft2(riesz_x_fft).real
        riesz_y = torch.fft.ifft2(riesz_y_fft).real
        
        return riesz_x, riesz_y

class LogGaborFilter(nn.Module):

    def __init__(self, num_scales: int = 4, num_orientations: int = 6,
                 min_wavelength: float = 3.0, mult: float = 2.1,
                 sigma_on_f: float = 0.55):
        
        super().__init__()
        self.num_scales = num_scales
        self.num_orientations = num_orientations
        self.min_wavelength = min_wavelength
        self.mult = mult
        self.sigma_on_f = sigma_on_f
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        
        B, C, H, W = x.shape
        device = x.device
        dtype = x.dtype
        
        # Create frequency grid
        freq_y = torch.fft.fftfreq(H, device=device, dtype=dtype)
        freq_x = torch.fft.fftfreq(W, device=device, dtype=dtype)
        freq_y, freq_x = torch.meshgrid(freq_y, freq_x, indexing='ij')
        
        # Polar coordinates
        radius = torch.sqrt(freq_x**2 + freq_y**2)
        radius[0, 0] = 1  # Avoid log(0)
        theta = torch.atan2(freq_y, freq_x)
        
        # FFT of input
        x_fft = torch.fft.fft2(x)
        
        outputs = []
        
        for scale in range(self.num_scales):
            wavelength = self.min_wavelength * (self.mult ** scale)
            fo = 1.0 / wavelength  # Center frequency
            
            # Log-Gabor radial component
            log_gabor_radial = torch.exp(
                -(torch.log(radius / fo) ** 2) / (2 * math.log(self.sigma_on_f) ** 2)
            )
            log_gabor_radial[0, 0] = 0  # Zero DC
            
            for orient in range(self.num_orientations):
                angle = orient * math.pi / self.num_orientations
                
                # Angular component
                ds = torch.sin(theta - angle)
                dc = torch.cos(theta - angle)
                dtheta = torch.abs(torch.atan2(ds, dc))
                
                # Angular spread
                angular_spread = torch.exp(
                    -(dtheta ** 2) / (2 * (math.pi / self.num_orientations) ** 2)
                )
                
                # Combined filter
                log_gabor = log_gabor_radial * angular_spread
                
                # Apply filter
                filtered = torch.fft.ifft2(x_fft * log_gabor.unsqueeze(0).unsqueeze(0))
                outputs.append(filtered.abs())
        
        return torch.cat(outputs, dim=1)

class MonogenicSignal(nn.Module):

    def __init__(self, epsilon: float = 1e-8):
        
        super().__init__()
        self.riesz = RieszTransform(epsilon=epsilon)
        self.epsilon = epsilon
    
    def forward(self, x: torch.Tensor) -> dict:
        
        # Get Riesz components
        riesz_x, riesz_y = self.riesz(x)
        
        # Compute amplitude (local energy)
        # A = sqrt(f^2 + h1^2 + h2^2)
        amplitude = torch.sqrt(x**2 + riesz_x**2 + riesz_y**2 + self.epsilon)
        
        # Compute orientation
        # theta = atan2(h2, h1)
        orientation = torch.atan2(riesz_y, riesz_x + self.epsilon)
        
        # Compute phase
        # phi = atan2(sqrt(h1^2 + h2^2), f)
        riesz_magnitude = torch.sqrt(riesz_x**2 + riesz_y**2 + self.epsilon)
        phase = torch.atan2(riesz_magnitude, x + self.epsilon)
        
        return {
            'amplitude': amplitude,
            'phase': phase,
            'orientation': orientation,
            'riesz_x': riesz_x,
            'riesz_y': riesz_y
        }

class EnergyMap(nn.Module):

    def __init__(self, normalize: bool = True, smoothing_sigma: float = 1.0):
        
        super().__init__()
        self.monogenic = MonogenicSignal()
        self.normalize = normalize
        self.smoothing_sigma = smoothing_sigma
        
        # Create Gaussian smoothing kernel
        if smoothing_sigma > 0:
            kernel_size = int(6 * smoothing_sigma) | 1  # Ensure odd
            self.register_buffer('smooth_kernel', self._create_gaussian_kernel(
                kernel_size, smoothing_sigma
            ))
        else:
            self.smooth_kernel = None
    
    def _create_gaussian_kernel(self, kernel_size: int, sigma: float) -> torch.Tensor:
        
        x = torch.arange(kernel_size) - kernel_size // 2
        x = x.float()
        gaussian_1d = torch.exp(-x**2 / (2 * sigma**2))
        gaussian_2d = gaussian_1d.unsqueeze(0) * gaussian_1d.unsqueeze(1)
        gaussian_2d = gaussian_2d / gaussian_2d.sum()
        return gaussian_2d.unsqueeze(0).unsqueeze(0)
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        
        # Convert to grayscale if needed
        if x.shape[1] > 1:
            x = x.mean(dim=1, keepdim=True)
        
        # Get monogenic decomposition
        mono_out = self.monogenic(x)
        
        # Energy is the amplitude
        energy = mono_out['amplitude']
        
        # Optional smoothing
        if self.smooth_kernel is not None:
            pad = self.smooth_kernel.shape[-1] // 2
            energy = F.conv2d(energy, self.smooth_kernel, padding=pad)
        
        # Normalize to [0, 1]
        if self.normalize:
            B = energy.shape[0]
            energy_flat = energy.view(B, -1)
            energy_min = energy_flat.min(dim=1, keepdim=True)[0].view(B, 1, 1, 1)
            energy_max = energy_flat.max(dim=1, keepdim=True)[0].view(B, 1, 1, 1)
            energy = (energy - energy_min) / (energy_max - energy_min + 1e-8)
        
        return energy, mono_out

class BoundaryDetector(nn.Module):

    def __init__(self, num_scales: int = 4, num_orientations: int = 6,
                 noise_threshold: float = 0.1):
        
        super().__init__()
        self.log_gabor = LogGaborFilter(num_scales, num_orientations)
        self.num_scales = num_scales
        self.num_orientations = num_orientations
        self.noise_threshold = noise_threshold
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        
        # Get multi-scale responses
        responses = self.log_gabor(x)  # (B, S*O, H, W)
        
        # Sum across orientations to get edge strength per scale
        B, _, H, W = responses.shape
        responses = responses.view(B, self.num_scales, self.num_orientations, H, W)
        
        # Max across orientations (strongest edge direction)
        edge_strength = responses.max(dim=2)[0]  # (B, S, H, W)
        
        # Sum across scales
        edge_strength = edge_strength.sum(dim=1, keepdim=True)  # (B, 1, H, W)
        
        # Normalize and threshold
        edge_max = edge_strength.view(B, -1).max(dim=1)[0].view(B, 1, 1, 1)
        edge_strength = edge_strength / (edge_max + 1e-8)
        edge_strength = torch.clamp(edge_strength - self.noise_threshold, min=0)
        edge_strength = edge_strength / (1 - self.noise_threshold + 1e-8)
        
        return edge_strength

if __name__ == "__main__":
    # Test Monogenic Signal processing
    print("Testing Monogenic Signal Processing...")
    
    # Create test image with edges
    H, W = 128, 128
    x = torch.zeros(1, 1, H, W)
    x[:, :, 32:96, 32:96] = 1.0  # Square
    
    # Add some noise
    x = x + 0.1 * torch.randn_like(x)
    
    # Test Energy Map
    energy_extractor = EnergyMap(normalize=True)
    energy, mono = energy_extractor(x)
    
    print(f"Input shape: {x.shape}")
    print(f"Energy map shape: {energy.shape}")
    print(f"Energy range: [{energy.min():.3f}, {energy.max():.3f}]")
    print(f"Monogenic components: {list(mono.keys())}")
    
    # Test Boundary Detector
    boundary_detector = BoundaryDetector()
    boundaries = boundary_detector(x)
    
    print(f"Boundary map shape: {boundaries.shape}")
    print(f"Boundary range: [{boundaries.min():.3f}, {boundaries.max():.3f}]")
    
    print("\n‚úì All tests passed!")


Testing Monogenic Signal Processing...
Input shape: torch.Size([1, 1, 128, 128])
Energy map shape: torch.Size([1, 1, 128, 128])
Energy range: [0.000, 1.000]
Monogenic components: ['amplitude', 'phase', 'orientation', 'riesz_x', 'riesz_y']
Boundary map shape: torch.Size([1, 1, 128, 128])
Boundary range: [0.000, 1.000]

‚úì All tests passed!


### 2.4 Gabor Implicit Layers

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, List

class GaborBasis(nn.Module):

    def __init__(self, input_dim: int = 2, num_frequencies: int = 64,
                 sigma_range: Tuple[float, float] = (0.1, 2.0),
                 freq_range: Tuple[float, float] = (1.0, 10.0),
                 learnable: bool = True):
        
        super().__init__()
        self.input_dim = input_dim
        self.num_frequencies = num_frequencies
        self.output_dim = num_frequencies * 2  # sin and cos components
        
        # Initialize frequencies uniformly in log space
        log_freqs = torch.linspace(
            math.log(freq_range[0]), 
            math.log(freq_range[1]), 
            num_frequencies
        )
        freqs = torch.exp(log_freqs)
        
        # Initialize sigmas (Gaussian envelope widths)
        sigmas = torch.linspace(sigma_range[0], sigma_range[1], num_frequencies)
        
        # Random orientations for 2D
        orientations = torch.rand(num_frequencies) * 2 * math.pi
        
        # Random phases
        phases = torch.rand(num_frequencies) * 2 * math.pi
        
        # Create direction vectors from orientations
        directions = torch.stack([
            torch.cos(orientations),
            torch.sin(orientations)
        ], dim=-1)  # (num_freq, 2)
        
        if learnable:
            self.freqs = nn.Parameter(freqs)
            self.sigmas = nn.Parameter(sigmas)
            self.directions = nn.Parameter(directions)
            self.phases = nn.Parameter(phases)
        else:
            self.register_buffer('freqs', freqs)
            self.register_buffer('sigmas', sigmas)
            self.register_buffer('directions', directions)
            self.register_buffer('phases', phases)
    
    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        
        # Normalize directions
        directions = F.normalize(self.directions, dim=-1)  # (num_freq, 2)
        
        # Project coordinates onto directions
        # coords: (..., 2), directions: (num_freq, 2)
        proj = torch.matmul(coords, directions.T)  # (..., num_freq)
        
        # Compute Gaussian envelope
        # exp(-proj¬≤ / (2œÉ¬≤))
        sigmas = torch.abs(self.sigmas) + 0.01  # Ensure positive
        gaussian = torch.exp(-proj**2 / (2 * sigmas**2 + 1e-8))
        
        # Compute oscillatory component
        # cos(2œÄf¬∑proj + œÜ), sin(2œÄf¬∑proj + œÜ)
        freqs = torch.abs(self.freqs) + 0.1  # Ensure positive
        arg = 2 * math.pi * freqs * proj + self.phases
        
        cos_comp = gaussian * torch.cos(arg)
        sin_comp = gaussian * torch.sin(arg)
        
        # Concatenate sin and cos
        gabor_features = torch.cat([cos_comp, sin_comp], dim=-1)
        
        return gabor_features

class FourierFeatures(nn.Module):

    def __init__(self, input_dim: int = 2, num_frequencies: int = 64,
                 scale: float = 10.0, learnable: bool = False):
        
        super().__init__()
        self.input_dim = input_dim
        self.num_frequencies = num_frequencies
        self.output_dim = num_frequencies * 2
        
        # Random frequency matrix
        B = torch.randn(input_dim, num_frequencies) * scale
        
        if learnable:
            self.B = nn.Parameter(B)
        else:
            self.register_buffer('B', B)
    
    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        
        # Project: coords @ B
        proj = 2 * math.pi * torch.matmul(coords, self.B)  # (..., num_freq)
        
        # Sin and cos
        return torch.cat([torch.cos(proj), torch.sin(proj)], dim=-1)

class SIRENLayer(nn.Module):

    def __init__(self, in_features: int, out_features: int, 
                 omega_0: float = 30.0, is_first: bool = False):
        
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        
        self.linear = nn.Linear(in_features, out_features)
        self._init_weights()
    
    def _init_weights(self):
        
        with torch.no_grad():
            if self.is_first:
                # First layer: uniform in [-1/n, 1/n]
                self.linear.weight.uniform_(-1 / self.in_features, 
                                            1 / self.in_features)
            else:
                # Other layers: uniform in [-sqrt(6/n)/œâ‚ÇÄ, sqrt(6/n)/œâ‚ÇÄ]
                bound = math.sqrt(6 / self.in_features) / self.omega_0
                self.linear.weight.uniform_(-bound, bound)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sin(self.omega_0 * self.linear(x))

class GaborNet(nn.Module):

    def __init__(self, coord_dim: int = 2, feature_dim: int = 256,
                 hidden_dim: int = 256, output_dim: int = 1,
                 num_layers: int = 4, num_frequencies: int = 64,
                 use_gabor: bool = True, omega_0: float = 30.0):
        
        super().__init__()
        
        # Coordinate encoding
        if use_gabor:
            self.coord_encoder = GaborBasis(
                input_dim=coord_dim,
                num_frequencies=num_frequencies,
                learnable=True
            )
        else:
            self.coord_encoder = FourierFeatures(
                input_dim=coord_dim,
                num_frequencies=num_frequencies,
                learnable=False
            )
        
        coord_encoded_dim = self.coord_encoder.output_dim
        
        # Input dimension: encoded coords + features
        input_dim = coord_encoded_dim + feature_dim
        
        # Build SIREN network
        layers = []
        
        # First layer
        layers.append(SIRENLayer(input_dim, hidden_dim, omega_0, is_first=True))
        
        # Hidden layers
        for _ in range(num_layers - 2):
            layers.append(SIRENLayer(hidden_dim, hidden_dim, omega_0, is_first=False))
        
        # Final layer (linear, no sine activation)
        layers.append(nn.Linear(hidden_dim, output_dim))
        
        self.network = nn.Sequential(*layers)
    
    def forward(self, coords: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
        
        # Encode coordinates
        coord_encoded = self.coord_encoder(coords)  # (B, N, coord_encoded_dim)
        
        # Concatenate with features
        x = torch.cat([coord_encoded, features], dim=-1)  # (B, N, input_dim)
        
        # Pass through network
        output = self.network(x)  # (B, N, output_dim)
        
        return output

class ImplicitSegmentationHead(nn.Module):

    def __init__(self, feature_channels: int = 64, num_classes: int = 2,
                 hidden_dim: int = 256, num_layers: int = 4,
                 num_frequencies: int = 64, use_gabor: bool = True):
        
        super().__init__()
        
        self.feature_channels = feature_channels
        self.num_classes = num_classes
        
        # Feature projector (reduce channel dimension)
        self.feature_proj = nn.Sequential(
            nn.Conv2d(feature_channels, hidden_dim, kernel_size=1),
            nn.GroupNorm(32, hidden_dim),
            nn.GELU()
        )
        
        # Implicit decoder
        self.implicit_decoder = GaborNet(
            coord_dim=2,
            feature_dim=hidden_dim,
            hidden_dim=hidden_dim,
            output_dim=num_classes,
            num_layers=num_layers,
            num_frequencies=num_frequencies,
            use_gabor=use_gabor
        )
    
    def sample_features(self, feature_map: torch.Tensor, 
                       coords: torch.Tensor) -> torch.Tensor:
        
        B, C, H, W = feature_map.shape
        N = coords.shape[1]
        
        # Reshape coords for grid_sample: (B, N, 1, 2) -> (B, 1, N, 2)
        # grid_sample expects (B, H, W, 2) where last dim is (x, y)
        grid = coords.view(B, 1, N, 2)
        
        # Sample using bilinear interpolation
        # feature_map: (B, C, H, W), grid: (B, 1, N, 2)
        # output: (B, C, 1, N)
        sampled = F.grid_sample(
            feature_map, grid,
            mode='bilinear',
            padding_mode='border',
            align_corners=True
        )
        
        # Reshape: (B, C, 1, N) -> (B, N, C)
        sampled = sampled.squeeze(2).permute(0, 2, 1)
        
        return sampled
    
    def forward(self, feature_map: torch.Tensor, 
                coords: Optional[torch.Tensor] = None,
                output_size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        
        B, C, H_feat, W_feat = feature_map.shape
        device = feature_map.device
        
        # Project features
        feature_map = self.feature_proj(feature_map)  # (B, hidden_dim, H, W)
        
        # Generate coordinates if not provided
        if coords is None:
            if output_size is None:
                output_size = (H_feat * 4, W_feat * 4)
            
            H_out, W_out = output_size
            
            # Create normalized coordinate grid [-1, 1]
            y = torch.linspace(-1, 1, H_out, device=device)
            x = torch.linspace(-1, 1, W_out, device=device)
            yy, xx = torch.meshgrid(y, x, indexing='ij')
            coords = torch.stack([xx, yy], dim=-1)  # (H_out, W_out, 2)
            coords = coords.view(1, -1, 2).expand(B, -1, -1)  # (B, H*W, 2)
            
            reshape_output = True
        else:
            reshape_output = False
            H_out, W_out = None, None
        
        # Sample features at coordinates
        features = self.sample_features(feature_map, coords)  # (B, N, hidden_dim)
        
        # Implicit decoding
        logits = self.implicit_decoder(coords, features)  # (B, N, num_classes)
        
        # Reshape to image if using grid
        if reshape_output:
            logits = logits.view(B, H_out, W_out, self.num_classes)
            logits = logits.permute(0, 3, 1, 2)  # (B, C, H, W)
        
        return logits

if __name__ == "__main__":
    print("Testing Gabor Implicit Modules...")
    
    # Test Gabor Basis
    print("\n[1] Testing GaborBasis...")
    gabor = GaborBasis(input_dim=2, num_frequencies=32)
    coords = torch.randn(4, 100, 2)  # (B, N, 2)
    encoded = gabor(coords)
    print(f"Input coords: {coords.shape}")
    print(f"Gabor encoded: {encoded.shape}")
    
    # Test GaborNet
    print("\n[2] Testing GaborNet...")
    net = GaborNet(coord_dim=2, feature_dim=64, hidden_dim=128, 
                   output_dim=3, num_layers=3, num_frequencies=32)
    features = torch.randn(4, 100, 64)
    output = net(coords, features)
    print(f"GaborNet output: {output.shape}")
    
    # Test ImplicitSegmentationHead
    print("\n[3] Testing ImplicitSegmentationHead...")
    head = ImplicitSegmentationHead(
        feature_channels=64, num_classes=3,
        hidden_dim=128, num_layers=3, num_frequencies=32
    )
    feature_map = torch.randn(2, 64, 32, 32)
    
    # Test with automatic grid
    seg_output = head(feature_map, output_size=(128, 128))
    print(f"Feature map: {feature_map.shape}")
    print(f"Segmentation output (grid): {seg_output.shape}")
    
    # Test with custom coordinates
    custom_coords = torch.rand(2, 500, 2) * 2 - 1  # Random points in [-1, 1]
    seg_points = head(feature_map, coords=custom_coords)
    print(f"Segmentation output (points): {seg_points.shape}")
    
    print("\n‚úì All tests passed!")


Testing Gabor Implicit Modules...

[1] Testing GaborBasis...
Input coords: torch.Size([4, 100, 2])
Gabor encoded: torch.Size([4, 100, 64])

[2] Testing GaborNet...
GaborNet output: torch.Size([4, 100, 3])

[3] Testing ImplicitSegmentationHead...
Feature map: torch.Size([2, 64, 32, 32])
Segmentation output (grid): torch.Size([2, 3, 128, 128])
Segmentation output (points): torch.Size([2, 500, 3])

‚úì All tests passed!


### 2.5 Spectral Mamba

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List

class SpectralVSSBlock(nn.Module):

    def __init__(self, channels: int, height: int, width: int,
                 depth: int = 2, expansion_ratio: float = 2.0,
                 threshold: float = 0.1):
        
        super().__init__()
        self.channels = channels
        self.height = height
        self.width = width
        
        # Branch A: Spatial path (VSS Blocks)
        self.vss_blocks = MambaBlockStack(
            channels, depth=depth, 
            expansion_ratio=expansion_ratio, 
            scan_dim=min(64, channels)
        )
        
        # Branch B: Spectral path (FFT-based filtering)
        self.spectral_gate = SpectralGating(
            channels, height, width, 
            threshold=threshold, 
            complex_init="kaiming"
        )
        
        # Fusion layer (learnable weighting)
        self.fusion_weight = nn.Parameter(torch.tensor(0.5))
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        
        # Branch A: Spatial context (VSS)
        spatial_out = self.vss_blocks(x)
        
        # Branch B: Frequency filtering (Spectral)
        spectral_out = self.spectral_gate(x)
        
        # Learnable fusion with sigmoid weight
        weight = torch.sigmoid(self.fusion_weight)
        output = weight * spatial_out + (1 - weight) * spectral_out
        
        return output

class PatchEmbedding(nn.Module):

    def __init__(self, in_channels: int = 3, out_channels: int = 64, 
                 patch_size: int = 4):
        
        super().__init__()
        self.patch_size = patch_size
        self.conv = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=patch_size, stride=patch_size, bias=True
        )
        self.norm = nn.LayerNorm(out_channels)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        
        x = self.conv(x)
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
        x = self.norm(x)
        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return x

class PatchMerging(nn.Module):

    def __init__(self, channels: int, out_channels: Optional[int] = None):
        
        super().__init__()
        out_channels = out_channels or channels * 2
        self.conv = nn.Conv2d(channels, out_channels, kernel_size=2, 
                              stride=2, bias=True)
        self.norm = nn.LayerNorm(out_channels)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
        x = self.norm(x)
        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return x

class PatchExpanding(nn.Module):

    def __init__(self, channels: int, out_channels: Optional[int] = None):
        
        super().__init__()
        out_channels = out_channels or channels // 2
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', 
                                    align_corners=True)
        self.conv = nn.Conv2d(channels, out_channels, kernel_size=1, bias=True)
        self.norm = nn.LayerNorm(out_channels)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.upsample(x)
        x = self.conv(x)
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
        x = self.norm(x)
        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return x

class SpectralVMUNet(nn.Module):

    def __init__(self, in_channels: int = 1, out_channels: int = 3,
                 img_size: int = 256, base_channels: int = 64,
                 num_stages: int = 4, depth: int = 2):
        
        super().__init__()
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.img_size = img_size
        self.base_channels = base_channels
        self.num_stages = num_stages
        
        # Patch embedding
        self.patch_embed = PatchEmbedding(in_channels, base_channels, patch_size=4)
        initial_size = img_size // 4
        
        # Encoder
        self.encoder_blocks = nn.ModuleList()
        self.downsample_layers = nn.ModuleList()
        
        for i in range(num_stages):
            in_ch = base_channels * (2 ** i)
            out_ch = in_ch
            h = w = initial_size // (2 ** i)
            
            # SpectralVSSBlock
            block = SpectralVSSBlock(
                in_ch, h, w, depth=depth, expansion_ratio=2.0, threshold=0.1
            )
            self.encoder_blocks.append(block)
            
            # Downsampling (except after last encoder block)
            if i < num_stages - 1:
                down = PatchMerging(in_ch, in_ch * 2)
                self.downsample_layers.append(down)
        
        # Bottleneck - uses the last encoder's output channels
        # After num_stages-1 downsamplings, channels = base_channels * 2^(num_stages-1)
        bottleneck_ch = base_channels * (2 ** (num_stages - 1))
        bottleneck_h = bottleneck_w = initial_size // (2 ** (num_stages - 1))
        self.bottleneck = SpectralVSSBlock(
            bottleneck_ch, bottleneck_h, bottleneck_w,
            depth=depth + 1, expansion_ratio=2.0, threshold=0.1
        )
        
        # Decoder
        # We have num_stages - 1 decoder stages (matching skip connections)
        # Each decoder stage: upsample -> concat with skip -> fusion -> SpectralVSSBlock
        self.decoder_blocks = nn.ModuleList()
        self.upsample_layers = nn.ModuleList()
        
        num_decoder_stages = num_stages - 1
        
        for i in range(num_decoder_stages):
            # Going from deepest to shallowest
            # i=0: from bottleneck (8x8, 512ch) -> upsample to (16x16, 256ch)
            # i=1: from 16x16, 256ch -> upsample to (32x32, 128ch)
            # i=2: from 32x32, 128ch -> upsample to (64x64, 64ch)
            
            # Input channels: for i=0, it's bottleneck_ch; else from previous decoder output
            if i == 0:
                in_ch = bottleneck_ch  # 512 for default
            else:
                in_ch = base_channels * (2 ** (num_stages - 1 - i))
            
            # Output channels after upsampling
            out_ch = base_channels * (2 ** (num_stages - 2 - i))
            
            # Upsampling layer
            up = PatchExpanding(in_ch, out_ch)
            self.upsample_layers.append(up)
            
            # Skip connection comes from encoder at level (num_decoder_stages - 1 - i)
            # which has same spatial size after upsampling
            skip_ch = out_ch  # Skip has same channels as upsampled output
            
            # Spatial size at this level
            h = w = initial_size // (2 ** (num_stages - 2 - i))
            
            # Fusion: concatenate upsampled + skip, then reduce channels
            fused_ch = out_ch + skip_ch  # After concatenation
            fusion = nn.Sequential(
                nn.Conv2d(fused_ch, out_ch, kernel_size=1, bias=True),
                nn.GroupNorm(num_groups=min(32, out_ch), num_channels=out_ch, eps=1e-6)
            )
            self.decoder_blocks.append(fusion)
            
            # SpectralVSSBlock after fusion
            vss = SpectralVSSBlock(
                out_ch, h, w, depth=depth, expansion_ratio=2.0, threshold=0.1
            )
            self.decoder_blocks.append(vss)
        
        # Segmentation head
        self.seg_head = nn.Sequential(
            nn.Conv2d(base_channels, base_channels // 2, kernel_size=3, 
                      padding=1, bias=True),
            nn.GroupNorm(num_groups=32, num_channels=base_channels // 2, eps=1e-6),
            nn.GELU(),
            nn.Conv2d(base_channels // 2, out_channels, kernel_size=1, bias=True)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        
        # Patch embedding
        x = self.patch_embed(x)
        
        # Encoder path with skip connections storage
        # Skip connections are saved BEFORE downsampling
        skips = []
        for i in range(self.num_stages):
            x = self.encoder_blocks[i](x)
            # Save skip connection before downsampling
            if i < self.num_stages - 1:
                skips.append(x)
                x = self.downsample_layers[i](x)
        
        # The last encoder output goes to bottleneck (no skip for this level)
        # skips now contains: [stage0_out, stage1_out, stage2_out] for 4 stages
        
        # Bottleneck
        x = self.bottleneck(x)
        
        # Decoder path with skip connections
        # Decoder stages: num_stages - 1 (since last encoder has no skip)
        num_decoder_stages = self.num_stages - 1
        
        for i in range(num_decoder_stages):
            # Upsample
            x = self.upsample_layers[i](x)
            
            # Concatenate skip connection (in reverse order)
            # For i=0: skip from encoder stage num_stages-2 (last skip)
            # For i=1: skip from encoder stage num_stages-3
            skip_idx = num_decoder_stages - 1 - i
            skip = skips[skip_idx]
            x = torch.cat([x, skip], dim=1)
            
            # Fusion and processing
            x = self.decoder_blocks[2 * i](x)  # Fusion conv
            x = self.decoder_blocks[2 * i + 1](x)  # SpectralVSSBlock
        
        # Segmentation head
        output = self.seg_head(x)
        
        # Upsample to original resolution (since patch embedding uses stride 4)
        output = F.interpolate(output, size=(self.img_size, self.img_size),
                               mode='bilinear', align_corners=True)
        
        return output

if __name__ == "__main__":
    # Test the full architecture
    batch_size = 2
    in_channels = 1
    out_channels = 3  # Binary segmentation + background
    img_size = 256
    
    model = SpectralVMUNet(
        in_channels=in_channels,
        out_channels=out_channels,
        img_size=img_size,
        base_channels=64,
        num_stages=4,
        depth=2
    )
    
    # Create dummy input
    x = torch.randn(batch_size, in_channels, img_size, img_size)
    
    # Forward pass
    output = model(x)
    
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")


Input shape: torch.Size([2, 1, 256, 256])
Output shape: torch.Size([2, 3, 256, 256])
Total parameters: 10,312,523
Trainable parameters: 10,312,523


### 2.6 EGM-Net

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict
import math

class PatchEmbedding(nn.Module):

    def __init__(self, in_channels: int = 1, embed_dim: int = 64, patch_size: int = 4):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, 
                              stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)  # (B, C, H, W)
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
        x = self.norm(x)
        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return x

class DownsampleBlock(nn.Module):

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.norm = nn.GroupNorm(32, out_channels)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(self.conv(x))

class UpsampleBlock(nn.Module):

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.norm = nn.GroupNorm(min(32, out_channels), out_channels)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(self.conv(self.up(x)))

class MambaEncoderStage(nn.Module):

    def __init__(self, channels: int, depth: int = 2, spatial_size: int = 64):
        super().__init__()
        self.blocks = nn.ModuleList([
            VSSBlock(channels, scan_dim=min(64, channels))
            for _ in range(depth)
        ])
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x

class CoarseBranch(nn.Module):

    def __init__(self, in_channels: int, num_classes: int, num_stages: int = 3):
        super().__init__()
        
        self.upsample_layers = nn.ModuleList()
        self.conv_layers = nn.ModuleList()
        
        channels = in_channels
        for i in range(num_stages):
            out_ch = max(channels // 2, 64)
            self.upsample_layers.append(UpsampleBlock(channels, out_ch))
            self.conv_layers.append(nn.Sequential(
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
                nn.GroupNorm(min(32, out_ch), out_ch),
                nn.GELU()
            ))
            channels = out_ch
        
        self.head = nn.Conv2d(channels, num_classes, kernel_size=1)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for up, conv in zip(self.upsample_layers, self.conv_layers):
            x = up(x)
            x = conv(x)
        return self.head(x)

class EnergyGatedFusion(nn.Module):

    def __init__(self, temperature: float = 1.0):
        
        super().__init__()
        self.temperature = temperature
        self.gate_scale = nn.Parameter(torch.ones(1))
        self.gate_bias = nn.Parameter(torch.zeros(1))
    
    def forward(self, coarse: torch.Tensor, fine: torch.Tensor, 
                energy: torch.Tensor) -> torch.Tensor:
        
        # Resize energy to match prediction size
        if energy.shape[-2:] != coarse.shape[-2:]:
            energy = F.interpolate(energy, size=coarse.shape[-2:], 
                                   mode='bilinear', align_corners=True)
        
        # Apply learnable scaling and temperature
        gate = torch.sigmoid((energy * self.gate_scale + self.gate_bias) / self.temperature)
        
        # Blend: high energy ‚Üí use fine, low energy ‚Üí use coarse
        output = coarse + gate * (fine - coarse)
        
        return output

class FineBranch(nn.Module):

    def __init__(self, feature_channels: int, num_classes: int,
                 hidden_dim: int = 256, num_layers: int = 4,
                 num_frequencies: int = 64):
        super().__init__()
        
        self.implicit_head = ImplicitSegmentationHead(
            feature_channels=feature_channels,
            num_classes=num_classes,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_frequencies=num_frequencies,
            use_gabor=True  # Use Gabor instead of Fourier
        )
    
    def forward(self, features: torch.Tensor, 
                coords: Optional[torch.Tensor] = None,
                output_size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        
        return self.implicit_head(features, coords, output_size)

class EGMNet(nn.Module):

    def __init__(self, in_channels: int = 1, num_classes: int = 2,
                 img_size: int = 256, base_channels: int = 64,
                 num_stages: int = 4, encoder_depth: int = 2,
                 implicit_hidden: int = 256, implicit_layers: int = 4,
                 num_frequencies: int = 64):
        
        super().__init__()
        
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.img_size = img_size
        self.num_stages = num_stages
        
        # 1. Monogenic Energy Extractor (fixed, non-trainable)
        self.energy_extractor = EnergyMap(normalize=True, smoothing_sigma=1.0)
        
        # 2. Patch Embedding
        self.patch_embed = PatchEmbedding(in_channels, base_channels, patch_size=4)
        feat_size = img_size // 4
        
        # 3. Mamba Encoder
        self.encoder_stages = nn.ModuleList()
        self.downsample_layers = nn.ModuleList()
        
        channels = base_channels
        for i in range(num_stages):
            self.encoder_stages.append(
                MambaEncoderStage(channels, depth=encoder_depth, spatial_size=feat_size)
            )
            if i < num_stages - 1:
                self.downsample_layers.append(
                    DownsampleBlock(channels, channels * 2)
                )
                channels *= 2
                feat_size //= 2
        
        # Store final encoder channels
        self.encoder_channels = channels
        
        # 4. Bottleneck
        self.bottleneck = MambaEncoderStage(
            channels, depth=encoder_depth + 1, spatial_size=feat_size
        )
        
        # 5. Coarse Branch (standard decoder)
        self.coarse_branch = CoarseBranch(
            in_channels=channels,
            num_classes=num_classes,
            num_stages=num_stages - 1
        )
        
        # 6. Fine Branch (Gabor implicit decoder)
        self.fine_branch = FineBranch(
            feature_channels=channels,
            num_classes=num_classes,
            hidden_dim=implicit_hidden,
            num_layers=implicit_layers,
            num_frequencies=num_frequencies
        )
        
        # 7. Energy-Gated Fusion
        self.fusion = EnergyGatedFusion(temperature=1.0)
    
    def forward(self, x: torch.Tensor, 
                output_size: Optional[Tuple[int, int]] = None) -> Dict[str, torch.Tensor]:
        
        B, C, H, W = x.shape
        
        if output_size is None:
            output_size = (H, W)
        
        # 1. Extract energy map (detached, no gradients for physics module)
        with torch.no_grad():
            # Convert to grayscale if needed
            x_gray = x.mean(dim=1, keepdim=True) if C > 1 else x
            energy, mono_out = self.energy_extractor(x_gray)
        
        # 2. Patch embedding
        features = self.patch_embed(x)
        
        # 3. Encoder (Mamba stages)
        encoder_features = []
        for i, stage in enumerate(self.encoder_stages):
            features = stage(features)
            encoder_features.append(features)
            if i < len(self.downsample_layers):
                features = self.downsample_layers[i](features)
        
        # 4. Bottleneck
        features = self.bottleneck(features)
        
        # 5. Coarse branch
        coarse = self.coarse_branch(features)
        coarse = F.interpolate(coarse, size=output_size, 
                               mode='bilinear', align_corners=True)
        
        # 6. Fine branch (implicit decoder)
        fine = self.fine_branch(features, output_size=output_size)
        
        # 7. Energy-gated fusion
        output = self.fusion(coarse, fine, energy)
        
        return {
            'output': output,
            'coarse': coarse,
            'fine': fine,
            'energy': energy
        }
    
    def inference(self, x: torch.Tensor, 
                  output_size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        
        return self.forward(x, output_size)['output']
    
    def query_points(self, x: torch.Tensor, 
                     coords: torch.Tensor) -> torch.Tensor:
        
        B, C, H, W = x.shape
        
        # Encode image
        features = self.patch_embed(x)
        for i, stage in enumerate(self.encoder_stages):
            features = stage(features)
            if i < len(self.downsample_layers):
                features = self.downsample_layers[i](features)
        features = self.bottleneck(features)
        
        # Query fine branch at coordinates
        fine_points = self.fine_branch.implicit_head(features, coords=coords)
        
        return fine_points

class EGMNetLite(nn.Module):

    def __init__(self, in_channels: int = 1, num_classes: int = 2,
                 img_size: int = 256):
        super().__init__()
        
        self.model = EGMNet(
            in_channels=in_channels,
            num_classes=num_classes,
            img_size=img_size,
            base_channels=32,  # Reduced from 64
            num_stages=3,      # Reduced from 4
            encoder_depth=1,   # Reduced from 2
            implicit_hidden=128,  # Reduced from 256
            implicit_layers=3,    # Reduced from 4
            num_frequencies=32    # Reduced from 64
        )
    
    def forward(self, x, output_size=None):
        return self.model(x, output_size)
    
    def inference(self, x, output_size=None):
        return self.model.inference(x, output_size)

if __name__ == "__main__":
    print("=" * 60)
    print("Testing EGM-Net (Energy-Gated Gabor Mamba Network)")
    print("=" * 60)
    
    # Test full model
    print("\n[1] Testing EGM-Net Full...")
    model = EGMNet(
        in_channels=1,
        num_classes=3,
        img_size=256,
        base_channels=64,
        num_stages=4,
        encoder_depth=2
    )
    
    x = torch.randn(2, 1, 256, 256)
    outputs = model(x)
    
    print(f"Input: {x.shape}")
    print(f"Output: {outputs['output'].shape}")
    print(f"Coarse: {outputs['coarse'].shape}")
    print(f"Fine: {outputs['fine'].shape}")
    print(f"Energy: {outputs['energy'].shape}")
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nTotal parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    
    # Test point query (resolution-free inference)
    print("\n[2] Testing Point Query (Resolution-Free)...")
    coords = torch.rand(2, 1000, 2) * 2 - 1  # Random points in [-1, 1]
    point_output = model.query_points(x, coords)
    print(f"Query coords: {coords.shape}")
    print(f"Point output: {point_output.shape}")
    
    # Test lite model
    print("\n[3] Testing EGM-Net Lite...")
    lite_model = EGMNetLite(in_channels=1, num_classes=3, img_size=256)
    lite_outputs = lite_model(x)
    
    lite_params = sum(p.numel() for p in lite_model.parameters())
    print(f"Lite model parameters: {lite_params:,}")
    print(f"Lite output: {lite_outputs['output'].shape}")
    
    print("\n" + "=" * 60)
    print("‚úì All tests passed!")
    print("=" * 60)


Testing EGM-Net (Energy-Gated Gabor Mamba Network)

[1] Testing EGM-Net Full...
Input: torch.Size([2, 1, 256, 256])
Output: torch.Size([2, 3, 256, 256])
Coarse: torch.Size([2, 3, 256, 256])
Fine: torch.Size([2, 3, 256, 256])
Energy: torch.Size([2, 1, 256, 256])

Total parameters: 9,133,192
Trainable parameters: 9,133,192

[2] Testing Point Query (Resolution-Free)...
Query coords: torch.Size([2, 1000, 2])
Point output: torch.Size([2, 1000, 3])

[3] Testing EGM-Net Lite...
Lite model parameters: 635,272
Lite output: torch.Size([2, 3, 256, 256])

‚úì All tests passed!


### 2.7 Loss Functions

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

class DiceLoss(nn.Module):

    def __init__(self, smooth: float = 1e-5, reduction: str = "mean"):
        
        super().__init__()
        self.smooth = smooth
        self.reduction = reduction
    
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        
        # Convert logits to probabilities
        pred = torch.softmax(pred, dim=1)
        
        # Ensure target has same shape as pred for multi-class
        if target.ndim == 3:  # (B, H, W) -> convert to one-hot
            target = F.one_hot(target.long(), num_classes=pred.shape[1])
            target = target.permute(0, 3, 1, 2).float()
        
        # Flatten spatial dimensions
        pred = pred.view(pred.shape[0], pred.shape[1], -1)
        target = target.view(target.shape[0], target.shape[1], -1)
        
        # Compute Dice score
        intersection = torch.sum(pred * target, dim=2)
        union = torch.sum(pred, dim=2) + torch.sum(target, dim=2)
        
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        
        # Return loss (1 - Dice)
        loss = 1.0 - dice
        
        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        else:
            return loss

class FocalLoss(nn.Module):

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0,
                 reduction: str = "mean"):
        
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        
        # Get class probabilities
        p = torch.softmax(pred, dim=1)
        
        # Get class log probabilities
        ce = F.cross_entropy(pred, target.long(), reduction='none')
        
        # Get probability of true class
        p_t = torch.gather(p, 1, target.long().unsqueeze(1)).squeeze(1)
        
        # Compute focal loss
        focal_weight = (1.0 - p_t) ** self.gamma
        focal_loss = focal_weight * ce
        
        if self.reduction == "mean":
            return focal_loss.mean()
        elif self.reduction == "sum":
            return focal_loss.sum()
        else:
            return focal_loss

class FrequencyLoss(nn.Module):

    def __init__(self, weight: float = 0.1):
        
        super().__init__()
        self.weight = weight
    
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        
        # Ensure both have batch and channel dimensions
        if pred.ndim == 3:
            pred = pred.unsqueeze(1)
        if target.ndim == 3:
            target = target.unsqueeze(1)
        
        # Flatten to single channel for FFT comparison
        if pred.shape[1] > 1:
            # For multi-channel, convert to grayscale by averaging
            pred = pred.mean(dim=1, keepdim=True)
        if target.shape[1] > 1:
            target = target.mean(dim=1, keepdim=True)
        
        # Apply FFT to convert to frequency domain
        pred_freq = torch.fft.rfft2(pred, dim=(-2, -1), norm="ortho")
        target_freq = torch.fft.rfft2(target, dim=(-2, -1), norm="ortho")
        
        # Compute L2 distance in frequency domain
        # Consider both magnitude and phase information
        loss_real = F.mse_loss(pred_freq.real, target_freq.real, reduction='mean')
        loss_imag = F.mse_loss(pred_freq.imag, target_freq.imag, reduction='mean')
        
        return loss_real + loss_imag

class SpectralDualLoss(nn.Module):

    def __init__(self, spatial_weight: float = 1.0, freq_weight: float = 0.1,
                 use_dice: bool = True, use_focal: bool = True):
        
        super().__init__()
        self.spatial_weight = spatial_weight
        self.freq_weight = freq_weight
        self.use_dice = use_dice
        self.use_focal = use_focal
        
        # Spatial losses
        if use_dice:
            self.dice_loss = DiceLoss(smooth=1e-5)
        
        if use_focal:
            self.focal_loss = FocalLoss(alpha=0.25, gamma=2.0)
        else:
            self.ce_loss = nn.CrossEntropyLoss()
        
        # Frequency loss
        self.freq_loss = FrequencyLoss(weight=freq_weight)
    
    def forward(self, pred: torch.Tensor, target: torch.Tensor,
                return_components: bool = False) -> torch.Tensor:
        
        # Ensure target is on same device as pred
        target = target.to(pred.device)
        
        # Spatial losses
        spatial_loss = 0.0
        losses_dict = {}
        
        if self.use_dice:
            dice = self.dice_loss(pred, target)
            spatial_loss = spatial_loss + dice
            losses_dict['dice'] = dice.item()
        
        if self.use_focal:
            focal = self.focal_loss(pred, target)
            spatial_loss = spatial_loss + focal
            losses_dict['focal'] = focal.item()
        else:
            ce = self.ce_loss(pred, target)
            spatial_loss = spatial_loss + ce
            losses_dict['ce'] = ce.item()
        
        # Frequency loss
        # For frequency loss, we need to extract the predicted class (argmax) and compare
        pred_probs = torch.softmax(pred, dim=1)
        pred_class = torch.argmax(pred_probs, dim=1)  # (B, H, W)
        
        freq = self.freq_loss(pred_class.float(), target.float())
        losses_dict['freq'] = freq.item()
        
        # Weighted combination
        total_loss = (self.spatial_weight * spatial_loss + 
                     self.freq_weight * freq)
        losses_dict['total'] = total_loss.item()
        
        if return_components:
            return total_loss, losses_dict
        else:
            return total_loss

class BoundaryAwareLoss(nn.Module):

    def __init__(self, kernel_size: int = 3, weight: float = 1.0):
        
        super().__init__()
        self.kernel_size = kernel_size
        self.weight = weight
    
    def _compute_boundaries(self, mask: torch.Tensor) -> torch.Tensor:
        
        # Convert to float
        mask = mask.float().unsqueeze(1)  # (B, 1, H, W)
        
        # Compute gradients using Sobel-like operation
        kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                                dtype=mask.dtype, device=mask.device)
        kernel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                                dtype=mask.dtype, device=mask.device)
        
        kernel_x = kernel_x.view(1, 1, 3, 3)
        kernel_y = kernel_y.view(1, 1, 3, 3)
        
        grad_x = F.conv2d(mask, kernel_x, padding=1)
        grad_y = F.conv2d(mask, kernel_y, padding=1)
        
        # Compute magnitude of gradient
        grad_magnitude = torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-8)
        
        # Threshold to get boundary pixels
        boundary = (grad_magnitude > 0).float().squeeze(1)
        
        return boundary
    
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        
        # Get predicted class
        pred_probs = torch.softmax(pred, dim=1)
        pred_class = torch.argmax(pred_probs, dim=1)  # (B, H, W)
        
        # Compute boundary maps
        pred_boundary = self._compute_boundaries(pred_class)
        target_boundary = self._compute_boundaries(target)
        
        # Compute cross-entropy loss weighted by boundary
        ce_loss = F.cross_entropy(pred, target.long(), reduction='none')
        
        # Apply boundary weight (higher loss for boundary pixels)
        boundary_weight = (pred_boundary + target_boundary).clamp(0, 1)
        boundary_weight = 1.0 + boundary_weight  # Weight between 1 and 2
        
        weighted_loss = ce_loss * boundary_weight
        
        return weighted_loss.mean()

if __name__ == "__main__":
    # Test losses
    batch_size, num_classes, height, width = 2, 3, 64, 64
    
    # Create dummy predictions and targets
    pred = torch.randn(batch_size, num_classes, height, width)
    target = torch.randint(0, num_classes, (batch_size, height, width))
    
    # Test SpectralDualLoss
    loss_fn = SpectralDualLoss(spatial_weight=1.0, freq_weight=0.1)
    loss, components = loss_fn(pred, target, return_components=True)
    
    print(f"Total Loss: {loss.item():.4f}")
    for name, value in components.items():
        print(f"  {name}: {value:.4f}")
    
    # Test BoundaryAwareLoss
    boundary_loss_fn = BoundaryAwareLoss()
    boundary_loss = boundary_loss_fn(pred, target)
    print(f"\nBoundary Loss: {boundary_loss.item():.4f}")


Total Loss: 1.7057
  dice: 0.6659
  focal: 0.9038
  freq: 1.3602
  total: 1.7057

Boundary Loss: 2.7721


### 2.8 Metrics

In [10]:
import torch
import torch.nn.functional as F
import numpy as np
from monai.metrics import compute_hausdorff_distance

def count_parameters(model):
    
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

class SegmentationMetrics:
    def __init__(self, num_classes, device):
        self.num_classes = num_classes
        self.device = device
        self.reset()
        
    def reset(self):
        self.batches = 0
        self.total_correct_pixels = 0
        self.total_pixels = 0
        
        # Aggregated stats for Precision/Recall/F1 (Global)
        self.tp = torch.zeros(self.num_classes, device=self.device)
        self.fp = torch.zeros(self.num_classes, device=self.device)
        self.fn = torch.zeros(self.num_classes, device=self.device)
        
        # Accumulators for averaging Batch-wise metrics
        self.dice_sum = torch.zeros(self.num_classes, device=self.device)
        self.iou_sum = torch.zeros(self.num_classes, device=self.device)
        self.hd95_sum = torch.zeros(self.num_classes, device=self.device)
        
        # Track valid batches for HD95 (it can be NaN if class is missing)
        self.hd95_counts = torch.zeros(self.num_classes, device=self.device)

    def update(self, preds, targets):
        
        self.batches += 1
        
        # Accuracy
        self.total_correct_pixels += (preds == targets).sum().item()
        self.total_pixels += targets.numel()
        
        # Create one-hot for HD95 and Dice
        # preds_oh: (B, C, H, W)
        preds_oh = F.one_hot(preds, num_classes=self.num_classes).permute(0, 3, 1, 2).float()
        targets_oh = F.one_hot(targets, num_classes=self.num_classes).permute(0, 3, 1, 2).float()
        
        # Helper for Dice/IoU/TP/FP/FN
        for c in range(self.num_classes):
            p_flat = preds_oh[:, c].reshape(-1)
            t_flat = targets_oh[:, c].reshape(-1)
            
            intersection = (p_flat * t_flat).sum()
            union = p_flat.sum() + t_flat.sum()
            
            # Global TP/FP/FN accumulation
            self.tp[c] += intersection
            self.fp[c] += (p_flat.sum() - intersection)
            self.fn[c] += (t_flat.sum() - intersection)
            
            # Batch-wise Dice/IoU accumulation
            dice = (2. * intersection + 1e-6) / (union + 1e-6)
            iou = (intersection + 1e-6) / (union - intersection + 1e-6)
            
            self.dice_sum[c] += dice
            self.iou_sum[c] += iou
            
        # HD95 Compliance (MONAI)
        # compute_hausdorff_distance expects (B, C, spatial...)
        # include_background=True usually, but we iterate.
        # We can compute all classes at once.
        try:
            import warnings
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                
                # percentile=95
                hd95_batch = compute_hausdorff_distance(
                    y_pred=preds_oh, 
                    y=targets_oh, 
                    include_background=True, 
                    percentile=95.0,
                    spacing=None  # Pixel space
                )
            # hd95_batch is (B, C)
            
            for c in range(self.num_classes):
                # Filter NaNs/Infs (happens if class missing in both pred and target, or just one)
                # MONAI returns NaN if one is empty. We mostly care if target exists.
                # Common practice: if target is empty, skip. If target exists but pred empty, HD is high (inf).
                # MONAI behavior: Nan if both empty. Inf if one empty.
                
                valid_vals = hd95_batch[:, c]
                valid_mask = ~torch.isnan(valid_vals) & ~torch.isinf(valid_vals)
                
                if valid_mask.any():
                    self.hd95_sum[c] += valid_vals[valid_mask].sum()
                    self.hd95_counts[c] += valid_mask.sum()
                    
        except Exception as e:
            # Fallback or strict error? 
            # Often happens if shapes are weird or empty batch.
            pass

    def compute(self):
        
        metrics = {}
        
        # Global Accuracy
        metrics['accuracy'] = self.total_correct_pixels / max(self.total_pixels, 1)
        
        # Per-class metrics
        dice_scores = []
        iou_scores = []
        hd95_scores = []
        precision_scores = []
        recall_scores = []
        f1_scores = []
        
        for c in range(self.num_classes):
            # Batch-averaged Dice/IoU
            dice_scores.append((self.dice_sum[c] / max(self.batches, 1)).item())
            iou_scores.append((self.iou_sum[c] / max(self.batches, 1)).item())
            
            # Batch-averaged HD95
            if self.hd95_counts[c] > 0:
                hd95_scores.append((self.hd95_sum[c] / self.hd95_counts[c]).item())
            else:
                hd95_scores.append(float('nan')) # Or 0.0 or inf
            
            # Global-based Precision/Recall/F1
            p = (self.tp[c] / (self.tp[c] + self.fp[c] + 1e-6)).item()
            r = (self.tp[c] / (self.tp[c] + self.fn[c] + 1e-6)).item()
            f1 = 2 * p * r / (p + r + 1e-6) if (p + r) > 0 else 0.0
            
            precision_scores.append(p)
            recall_scores.append(r)
            f1_scores.append(f1)
            
        metrics['dice_scores'] = dice_scores
        metrics['iou'] = iou_scores
        metrics['hd95'] = hd95_scores
        metrics['precision'] = precision_scores
        metrics['recall'] = recall_scores
        metrics['f1_score'] = f1_scores
        
        return metrics




In [11]:
print('‚úÖ All modules loaded!')

‚úÖ All modules loaded!


---
## 3Ô∏è‚É£ Dataset

Download and preprocess ACDC cardiac MRI dataset.

### 3.1 Configuration

In [12]:
DATASET = 'ACDC'
DRIVE_FOLDER_ID = '1EelzBVjIoDQ4uzt0_2JzmF_PuUHsD93e'
RAW_DATA_DIR = f'./data/{DATASET}'
PREPROCESSED_DIR = f'./preprocessed_data/{DATASET}'
IMG_SIZE = 224
NUM_CLASSES = 4
CLASS_NAMES = ['Background', 'RV', 'Myocardium', 'LV']

os.makedirs(RAW_DATA_DIR, exist_ok=True)
print(f"üìä Dataset: {DATASET}")
print(f"   Classes: {CLASS_NAMES}")

üìä Dataset: ACDC
   Classes: ['Background', 'RV', 'Myocardium', 'LV']


### 3.2 Download from Kaggle

In [13]:
import kagglehub
acdc_path = kagglehub.dataset_download('samdazel/automated-cardiac-diagnosis-challenge-miccai17')

print('Data source import complete.')
print(f'Data path: {acdc_path}')
# Xem c·∫•u tr√∫c th∆∞ m·ª•c
import os
for item in os.listdir(acdc_path):
    print(f'  {item}')


RAW_DATA_DIR = os.path.join(acdc_path, 'database')
print(f'RAW_DATA_DIR: {RAW_DATA_DIR}')
# Xem c·∫•u tr√∫c
for item in os.listdir(RAW_DATA_DIR):
    print(f'  {item}')

Downloading from https://www.kaggle.com/api/v1/datasets/download/samdazel/automated-cardiac-diagnosis-challenge-miccai17?dataset_version_number=2...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2.11G/2.11G [01:39<00:00, 22.7MB/s]

Extracting files...





Data source import complete.
Data path: /root/.cache/kagglehub/datasets/samdazel/automated-cardiac-diagnosis-challenge-miccai17/versions/2
  database
RAW_DATA_DIR: /root/.cache/kagglehub/datasets/samdazel/automated-cardiac-diagnosis-challenge-miccai17/versions/2/database
  training
  testing
  MANDATORY_CITATION.md


### 3.3 Preprocess (NIfTI ‚Üí NumPy)

In [14]:
import os
import sys
import configparser
import numpy as np
import nibabel as nib
from tqdm import tqdm
import json
from skimage.transform import resize

# Add project root to path

def normalize_intensity(image):
    
    # Clip outliers
    p05 = np.percentile(image, 0.5)
    p995 = np.percentile(image, 99.5)
    image = np.clip(image, p05, p995)
    
    # Z-score
    mean = np.mean(image)
    std = np.std(image)
    if std > 0:
        return (image - mean) / std
    return image

def preprocess_single_patient_acdc(patient_path, output_dir, target_size=(224, 224)):
    
    patient_folder = os.path.basename(patient_path)
    info_cfg_path = os.path.join(patient_path, 'Info.cfg')
    
    # T·∫°o folder output cho slice
    img_save_dir = os.path.join(output_dir, 'images')
    mask_save_dir = os.path.join(output_dir, 'masks')
    os.makedirs(img_save_dir, exist_ok=True)
    os.makedirs(mask_save_dir, exist_ok=True)
    
    # ƒê·ªçc config ƒë·ªÉ bi·∫øt frame n√†o l√† ED, frame n√†o l√† ES
    if not os.path.exists(info_cfg_path):
        return 0
    
    try:
        parser = configparser.ConfigParser()
        with open(info_cfg_path, 'r') as f:
            config_string = '[DEFAULT]\n' + f.read()
        parser.read_string(config_string)
        ed_frame = int(parser['DEFAULT']['ED'])
        es_frame = int(parser['DEFAULT']['ES'])
    except Exception as e:
        print(f"  Error reading Info.cfg for {patient_folder}: {e}")
        return 0
    
    slices_saved = 0
    
    for frame_num, frame_name in [(ed_frame, 'ED'), (es_frame, 'ES')]:
        img_filename = f'{patient_folder}_frame{frame_num:02d}.nii.gz'
        mask_filename = f'{patient_folder}_frame{frame_num:02d}_gt.nii.gz'
        
        # T√¨m file (support c·∫£ .nii v√† .nii.gz)
        img_path = None
        mask_path = None
        for suffix in ['.gz', '']:
            test_img = os.path.join(patient_path, img_filename.replace('.gz', '') if suffix == '' else img_filename)
            test_mask = os.path.join(patient_path, mask_filename.replace('.gz', '') if suffix == '' else mask_filename)
            if os.path.exists(test_img):
                img_path = test_img
                mask_path = test_mask
                break
        
        if img_path is None or not os.path.exists(img_path):
            continue
            
        try:
            # Load NIfTI
            img_nii = nib.load(img_path)
            img_data = img_nii.get_fdata() # (H, W, D)
            
            mask_data = None
            if os.path.exists(mask_path):
                mask_data = nib.load(mask_path).get_fdata()
            else:
                continue # B·ªè qua n·∫øu kh√¥ng c√≥ mask
            
            # 1. Normalize Intensity TR∆Ø·ªöC khi resize (t√≠nh tr√™n to√†n volume 3D)
            img_data = normalize_intensity(img_data)
            
            num_slices = img_data.shape[2]
            
            # 2. X·ª≠ l√Ω t·ª´ng slice v√† l∆∞u ngay l·∫≠p t·ª©c
            for i in range(num_slices):
                slice_img = img_data[:, :, i]
                slice_mask = mask_data[:, :, i]
                
                # B·ªè qua slice ƒëen thui (kh√¥ng c√≥ th√¥ng tin) ƒë·ªÉ tr√°nh nhi·ªÖu training
                if np.sum(slice_img) == 0:
                    continue
                
                # Resize (L∆∞u √Ω: resize c·ªßa skimage range input=output, ƒë√£ normalize th√¨ v·∫´n gi·ªØ range)
                slice_img_resized = resize(
                    slice_img, target_size, order=1, preserve_range=True, anti_aliasing=True, mode='reflect'
                ).astype(np.float32)
                
                slice_mask_resized = resize(
                    slice_mask, target_size, order=0, preserve_range=True, anti_aliasing=False, mode='reflect'
                ).astype(np.uint8) # Mask ph·∫£i l√† int
                
                # T·∫°o t√™n file: patient001_ED_slice005.npy
                file_id = f"{patient_folder}_{frame_name}_slice{i:03d}"
                
                np.save(os.path.join(img_save_dir, f"{file_id}.npy"), slice_img_resized)
                np.save(os.path.join(mask_save_dir, f"{file_id}.npy"), slice_mask_resized)
                
                slices_saved += 1
                
        except Exception as e:
            print(f"  Error processing {patient_folder} frame {frame_num}: {e}")
            continue
            
    return slices_saved

def preprocess_acdc_dataset(input_dir, output_dir, target_size=(224, 224)):
    
    # L·∫•y danh s√°ch patient
    patient_folders = sorted([
        os.path.join(input_dir, d) 
        for d in os.listdir(input_dir) 
        if os.path.isdir(os.path.join(input_dir, d)) and d.startswith('patient')
    ])
    
    print(f"Found {len(patient_folders)} patients. Outputting 2D slices to {output_dir}")
    
    total_slices = 0
    
    for patient_path in tqdm(patient_folders, desc="Processing ACDC"):
        slices = preprocess_single_patient_acdc(patient_path, output_dir, target_size)
        total_slices += slices
        
    print(f"\nCompleted! Saved {total_slices} slices total.")
    print(f"Images: {os.path.join(output_dir, 'images')}")
    print(f"Masks:  {os.path.join(output_dir, 'masks')}")


training_dir = os.path.join(acdc_path, 'database', 'training')
output_dir = './preprocessed_data/ACDC'
print(f"üìÇ Input: {training_dir}")
print(f"üìÇ Output: {output_dir}")
preprocess_acdc_dataset(training_dir, output_dir, target_size=(224, 224))

üìÇ Input: /root/.cache/kagglehub/datasets/samdazel/automated-cardiac-diagnosis-challenge-miccai17/versions/2/database/training
üìÇ Output: ./preprocessed_data/ACDC
Found 100 patients. Outputting 2D slices to ./preprocessed_data/ACDC


Processing ACDC: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [00:16<00:00,  6.16it/s]


Completed! Saved 1902 slices total.
Images: ./preprocessed_data/ACDC/images
Masks:  ./preprocessed_data/ACDC/masks





### 3.4 Create DataLoaders

In [15]:
from torch.utils.data import Dataset, DataLoader
import glob

class ACDCDataset2D(Dataset):
    """Fast memmap-based dataset for preprocessed 2D slices."""
    
    def __init__(self, data_dir, split='train'):
        img_dir = os.path.join(data_dir, 'images')
        all_files = sorted(glob.glob(os.path.join(img_dir, '*.npy')))
        
        np.random.seed(42)
        indices = np.random.permutation(len(all_files))
        split_idx = int(0.8 * len(all_files))
        
        if split == 'train':
            self.files = [all_files[i] for i in indices[:split_idx]]
        else:
            self.files = [all_files[i] for i in indices[split_idx:]]
        
        self.mask_dir = os.path.join(data_dir, 'masks')
        print(f"   {split.upper()}: {len(self.files)} 2D slices")
    
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        img_path = self.files[idx]
        mask_path = os.path.join(self.mask_dir, os.path.basename(img_path))
        
        # memmap for fast disk access
        img = np.load(img_path, mmap_mode='r').copy()
        seg = np.load(mask_path, mmap_mode='r').copy()
        
        # Already 2D (H, W) -> add channel dim -> (1, H, W)
        return torch.from_numpy(img).unsqueeze(0).float(), torch.from_numpy(seg).long()

BATCH_SIZE = 4
train_ds = ACDCDataset2D(PREPROCESSED_DIR, 'train')
val_ds = ACDCDataset2D(PREPROCESSED_DIR, 'val')
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
print(f"‚úÖ Memmap DataLoader ready")

   TRAIN: 1521 2D slices
   VAL: 381 2D slices
‚úÖ Memmap DataLoader ready




---
## 4Ô∏è‚É£ Training

### 4.1 Create Model

In [16]:
model = EGMNet(in_channels=1, num_classes=NUM_CLASSES, img_size=IMG_SIZE).to(device)
print(f"‚úÖ EGM-Net: {sum(p.numel() for p in model.parameters()):,} parameters")

‚úÖ EGM-Net: 9,133,514 parameters


### 4.2 Loss & Optimizer

In [17]:
criterion = SpectralDualLoss(spatial_weight=1.0, freq_weight=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

### 4.3 Training Loop

In [None]:
# === TRAINING ===
from torch.cuda.amp import autocast, GradScaler

NUM_EPOCHS = 100
LEARNING_RATE = 1e-4
EARLY_STOP_PATIENCE = 20
NUM_CLASSES = 4
CLASS_NAMES = {0: 'BG', 1: 'RV', 2: 'MYO', 3: 'LV'}

print(f"Device: {device}")
model = EGMNet(in_channels=1, num_classes=NUM_CLASSES, img_size=IMG_SIZE).to(device)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

criterion = SpectralDualLoss(spatial_weight=1.0, freq_weight=0.1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=10)
scaler = GradScaler()

best_dice = 0.0
epochs_no_improve = 0

for epoch in range(NUM_EPOCHS):
    torch.cuda.empty_cache()
    
    model.train()
    train_loss = 0.0
    for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        
        with autocast():
            outputs = model(images)
            pred = outputs['output'] if isinstance(outputs, dict) else outputs
            loss = criterion(pred, masks)
        
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item()
    
    torch.cuda.empty_cache()
    
    model.eval()
    val_dice_sum = 0
    val_count = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            with autocast():
                outputs = model(images)
            pred = outputs['output'] if isinstance(outputs, dict) else outputs
            pred_cls = pred.argmax(dim=1)
            for c in range(1, NUM_CLASSES):
                inter = ((pred_cls == c) & (masks == c)).sum()
                union = (pred_cls == c).sum() + (masks == c).sum()
                val_dice_sum += (2 * inter / (union + 1e-5)).item()
            val_count += 1
    
    avg_dice = val_dice_sum / (val_count * (NUM_CLASSES - 1))
    scheduler.step(avg_dice)
    
    print(f"Epoch {epoch+1}: Loss={train_loss/len(train_loader):.4f}, Dice={avg_dice:.4f}")
    
    if avg_dice > best_dice:
        best_dice = avg_dice
        torch.save(model.state_dict(), 'best_model.pth')
        print(f"   ‚úÖ Best model saved!")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
    
    if epochs_no_improve >= EARLY_STOP_PATIENCE:
        print(f"Early stopping!")
        break

print(f"Training complete! Best Dice: {best_dice:.4f}")

Device: cuda


  scaler = GradScaler()


Parameters: 9,133,514


  with autocast():
Epoch 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 381/381 [05:58<00:00,  1.06it/s]
  with autocast():


Epoch 1: Loss=0.6124, Dice=0.5176
   ‚úÖ Best model saved!


Epoch 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 381/381 [05:55<00:00,  1.07it/s]


Epoch 2: Loss=0.4977, Dice=0.6052
   ‚úÖ Best model saved!


Epoch 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 381/381 [05:58<00:00,  1.06it/s]


Epoch 3: Loss=0.4342, Dice=0.6449
   ‚úÖ Best model saved!


Epoch 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 381/381 [05:56<00:00,  1.07it/s]


Epoch 4: Loss=0.3814, Dice=0.6960
   ‚úÖ Best model saved!


Epoch 5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 381/381 [05:54<00:00,  1.07it/s]


Epoch 5: Loss=0.3510, Dice=0.6771


Epoch 6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 381/381 [05:53<00:00,  1.08it/s]


Epoch 6: Loss=0.3233, Dice=0.7195
   ‚úÖ Best model saved!


Epoch 7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 381/381 [05:53<00:00,  1.08it/s]


Epoch 7: Loss=0.3036, Dice=0.7400
   ‚úÖ Best model saved!


Epoch 8: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 381/381 [05:55<00:00,  1.07it/s]


Epoch 8: Loss=0.2881, Dice=0.7424
   ‚úÖ Best model saved!


Epoch 9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 381/381 [05:54<00:00,  1.07it/s]


Epoch 9: Loss=0.2767, Dice=0.7636
   ‚úÖ Best model saved!


Epoch 10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 381/381 [05:52<00:00,  1.08it/s]


Epoch 10: Loss=0.2654, Dice=0.7768
   ‚úÖ Best model saved!


Epoch 11: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 381/381 [05:51<00:00,  1.08it/s]


Epoch 11: Loss=0.2558, Dice=0.7697


Epoch 12: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 381/381 [05:53<00:00,  1.08it/s]


Epoch 12: Loss=0.2499, Dice=0.7764


Epoch 13: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 381/381 [05:52<00:00,  1.08it/s]


Epoch 13: Loss=0.2401, Dice=0.7813
   ‚úÖ Best model saved!


Epoch 14: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 381/381 [05:53<00:00,  1.08it/s]


Epoch 14: Loss=0.2369, Dice=0.7853
   ‚úÖ Best model saved!


Epoch 15:  72%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè  | 276/381 [04:16<01:41,  1.04it/s]

In [None]:
# === EVALUATE FUNCTION ===
def evaluate_metrics(model, dataloader, device, num_classes):
    """Evaluate model on a dataloader and return metrics."""
    model.eval()
    metrics_tracker = SegmentationMetrics(num_classes=num_classes, device=device)
    
    with torch.no_grad():
        for images, masks in dataloader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            pred = outputs['output'] if isinstance(outputs, dict) else outputs
            pred_cls = pred.argmax(dim=1)
            metrics_tracker.update(pred_cls, masks)
    
    return metrics_tracker.compute()

# === FINAL TEST EVALUATION ===
def final_evaluation(model, test_loader, device, num_classes, class_names):
    """Run final evaluation on test set."""
    print("\n" + "="*60)
    print("üìä FINAL TEST EVALUATION")
    print("="*60)
    
    # Load best model
    checkpoint = torch.load('best_model.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']}")
    
    metrics = evaluate_metrics(model, test_loader, device, num_classes)
    
    print("\n--- Per-Class Results ---")
    for c in range(num_classes):
        print(f"   {class_names[c]:<5}: Dice={metrics['dice_scores'][c]:.4f}, "
              f"IoU={metrics['iou'][c]:.4f}, HD95={metrics['hd95'][c]:.2f}")
    
    avg_dice = np.mean(metrics['dice_scores'][1:])
    avg_iou = np.mean(metrics['iou'][1:])
    avg_hd95 = np.nanmean(metrics['hd95'][1:])
    
    print("\n--- Summary (Foreground Only) ---")
    print(f"   Avg Dice: {avg_dice:.4f}")
    print(f"   Avg IoU:  {avg_iou:.4f}")
    print(f"   Avg HD95: {avg_hd95:.2f}")
    print(f"   Accuracy: {metrics['accuracy']:.4f}")
    
    return metrics

In [None]:
# Final evaluation on test set
final_metrics = final_evaluation(model, val_loader, device, NUM_CLASSES, CLASS_NAMES)

---
## 5Ô∏è‚É£ Visualization

In [None]:
model.eval()
images, masks = next(iter(val_loader))
images, masks = images.to(device), masks.to(device)

with torch.no_grad():
    outputs = model(images)
    pred = outputs['output'] if isinstance(outputs, dict) else outputs
    pred_cls = pred.argmax(dim=1)

fig, axes = plt.subplots(3, 4, figsize=(16, 12))
for i in range(4):
    axes[0, i].imshow(images[i, 0].cpu(), cmap='gray')
    axes[0, i].set_title('Input')
    axes[1, i].imshow(masks[i].cpu(), cmap='jet', vmin=0, vmax=NUM_CLASSES-1)
    axes[1, i].set_title('Ground Truth')
    axes[2, i].imshow(pred_cls[i].cpu(), cmap='jet', vmin=0, vmax=NUM_CLASSES-1)
    axes[2, i].set_title('Prediction')
for ax in axes.flatten():
    ax.axis('off')
plt.tight_layout()
plt.show()