## 1. Setup & Installation

In [None]:
# Install dependencies
!pip install -q torch torchvision numpy matplotlib tqdm

In [None]:
# Clone repository
!git clone https://github.com/QuocKhanhLuong/FourierNetwork.git
%cd FourierNetwork

In [None]:
# Import libraries
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm  # widget progress bars in notebooks, plain text elsewhere

# Reproducibility: seed every RNG the notebook relies on (Python, NumPy, PyTorch).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

# Check GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"üñ•Ô∏è Using device: {device}")
if device == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2. Import Models

In [None]:
# Import the project's model components; these modules are importable from the
# cloned repository's working directory (see the clone cell above).
from monogenic import (
    BoundaryDetector,
    EnergyMap,
    MonogenicSignal,
)
from gabor_implicit import (
    GaborBasis,
    GaborNet,
    ImplicitSegmentationHead,
)
from egm_net import EGMNet, EGMNetLite
from spectral_mamba import SpectralVMUNet

print("‚úÖ All modules imported successfully!")

## 🏥 REAL DATA TRAINING (Synapse Multi-Organ Dataset)

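This section trains on the **Synapse Multi-Organ CT Dataset**, a standard benchmark for medical image segmentation. Synapse requires registration and is not downloaded automatically: place it under `data/Synapse/` and set `USE_ISIC = False`. Otherwise the notebook downloads a subset of the ISIC 2018 skin-lesion dataset, and it falls back to a synthetic dataset if that download fails.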

**Dataset Info:**
- 30 abdominal CT scans with 3779 axial contrast-enhanced clinical CT images
- 8 abdominal organs: Aorta, Gallbladder, Spleen, Left Kidney, Right Kidney, Liver, Pancreas, Stomach
- Training: 18 cases, Testing: 12 cases

In [None]:
# Download Synapse dataset from Google Drive or alternative sources
import os
import zipfile

try:
    # Only available inside Google Colab; optional so this cell also runs locally.
    from google.colab import drive
except ImportError:
    drive = None

# Option 1: Mount Google Drive (if you have the dataset there)
# drive.mount('/content/drive')
# DATA_PATH = '/content/drive/MyDrive/Synapse'

# Option 2: Download from public source (preprocessed npz format)
# Note: You may need to request access from https://www.synapse.org/
# Here we provide a preprocessed version structure

# Create data directory
os.makedirs('data/Synapse', exist_ok=True)

# The layout below matches what SynapseDataset reads: train_npz/ for training and
# test_npz/ for validation.
print("""
üì• DATASET DOWNLOAD OPTIONS:

Option 1: Download preprocessed Synapse dataset
-------------------------------------------------
!gdown --id <your_google_drive_file_id> -O data/synapse.zip
!unzip -q data/synapse.zip -d data/

Option 2: Use your own dataset
-------------------------------
Upload your data to: data/Synapse/
Structure:
  data/Synapse/
    ‚îú‚îÄ‚îÄ train_npz/
    ‚îÇ   ‚îú‚îÄ‚îÄ case0001_slice001.npz
    ‚îÇ   ‚îú‚îÄ‚îÄ case0001_slice002.npz
    ‚îÇ   ‚îî‚îÄ‚îÄ ...
    ‚îî‚îÄ‚îÄ test_npz/
        ‚îú‚îÄ‚îÄ case0002_slice001.npz
        ‚îî‚îÄ‚îÄ ...

Option 3: Use ISIC Skin Lesion dataset (easier to download)
------------------------------------------------------------
We'll download ISIC 2018 dataset for demonstration.
""")

USE_ISIC = True  # Set to False if you have Synapse dataset

In [None]:
# Download ISIC 2018 Skin Lesion Dataset (easy to access)
if USE_ISIC:
    print("üì• Downloading ISIC 2018 Dataset...")
    
    # Create directories
    os.makedirs('data/ISIC2018/images', exist_ok=True)
    os.makedirs('data/ISIC2018/masks', exist_ok=True)
    
    # Download training images and masks
    !pip install -q gdown
    
    # ISIC 2018 Training Data (subset for demo - 500 images)
    # Full dataset: https://challenge.isic-archive.com/data/
    !gdown --fuzzy "https://drive.google.com/file/d/1E2xHt5jqXLxWCjWwIZ9lNj9lE4kZ8vIz/view?usp=sharing" -O data/ISIC2018_subset.zip 2>/dev/null || echo "Using backup download..."
    
    # If gdown fails, use wget from alternative source
    if not os.path.exists('data/ISIC2018_subset.zip'):
        print("Downloading from alternative source...")
        # Create synthetic data if download fails (for demo purposes)
        print("‚ö†Ô∏è Could not download dataset. Creating synthetic medical data for demo...")
        CREATE_SYNTHETIC = True
    else:
        !unzip -q data/ISIC2018_subset.zip -d data/ISIC2018/
        CREATE_SYNTHETIC = False
else:
    CREATE_SYNTHETIC = False

print("‚úÖ Data setup complete!")

In [None]:
# Create synthetic medical imaging dataset (fallback if download fails)
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import glob

class SyntheticMedicalDataset(Dataset):
    """Synthetic medical-like images containing 1-3 elliptical "organs", for demos.

    Every sample is generated from its own RNG seeded with ``(seed, idx)``, so
    ``dataset[i]`` always yields the same image/mask pair. This keeps the
    validation set fixed across epochs, and also holds inside DataLoader worker
    processes. The only random step left is the horizontal-flip augmentation on
    the ``'train'`` split.

    Args:
        num_samples: Number of samples reported by ``len()``.
        img_size: Height and width of each square image.
        num_classes: Stored for reference; masks only use labels 0 and 1.
        split: ``'train'`` enables random horizontal flips; other values disable them.
        seed: Base seed. Defaults to 42 for ``'train'`` and 123 otherwise
            (the seeds used by earlier versions of this class).
    """
    def __init__(self, num_samples=500, img_size=256, num_classes=2, split='train', seed=None):
        self.num_samples = num_samples
        self.img_size = img_size
        self.num_classes = num_classes
        self.split = split
        self.seed = (42 if split == 'train' else 123) if seed is None else seed
        # Row / column coordinate grids, shared by every sample.
        self._rows, self._cols = np.meshgrid(np.arange(img_size), np.arange(img_size), indexing='ij')

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(img, mask)``: float tensor (1, H, W) in [0, 1] and long tensor (H, W).

        Raises:
            IndexError: if ``idx`` is outside ``[-len, len)``. Plain iteration
                (``for x in dataset``) relies on this to terminate.
        """
        if idx < 0:
            idx += self.num_samples
        if not 0 <= idx < self.num_samples:
            raise IndexError(f"index {idx} out of range for dataset of size {self.num_samples}")

        size = self.img_size
        rng = np.random.default_rng([self.seed, int(idx)])
        y, x = self._rows, self._cols

        # Background texture
        img = (0.1 * rng.standard_normal((size, size))).astype(np.float32)
        mask = np.zeros((size, size), dtype=np.int64)

        # Add random rotated elliptical organs
        for _ in range(rng.integers(1, 4)):
            cx = rng.integers(size // 4, 3 * size // 4)
            cy = rng.integers(size // 4, 3 * size // 4)
            rx = rng.integers(size // 8, size // 3)
            ry = rng.integers(size // 8, size // 3)
            angle = rng.random() * np.pi

            cos_a, sin_a = np.cos(angle), np.sin(angle)
            x_rot = cos_a * (x - cx) + sin_a * (y - cy)
            y_rot = -sin_a * (x - cx) + cos_a * (y - cy)
            inside = (x_rot / rx) ** 2 + (y_rot / ry) ** 2 < 1

            img[inside] = 0.3 + 0.5 * rng.random()
            mask[inside] = 1  # Foreground class

        # Add acquisition noise, then clamp to the valid intensity range
        img = np.clip(img + 0.05 * rng.standard_normal((size, size)), 0, 1).astype(np.float32)

        img_t = torch.from_numpy(img).unsqueeze(0).float()
        mask_t = torch.from_numpy(mask).long()

        # Data augmentation for training. torch's RNG is reseeded per DataLoader
        # worker, so the flip decisions differ between workers.
        if self.split == 'train' and torch.rand(()).item() < 0.5:
            img_t = torch.flip(img_t, dims=[2])
            mask_t = torch.flip(mask_t, dims=[1])

        return img_t, mask_t


class ISICDataset(Dataset):
    """ISIC skin-lesion segmentation dataset (grayscale image + binary mask).

    Expects ``<data_dir>/images/*.jpg`` and ``<data_dir>/masks/*.png``. Each
    image is paired with the mask whose file stem equals the image stem,
    optionally followed by ``_segmentation`` (the ISIC naming scheme, e.g.
    ``ISIC_0000000.jpg`` <-> ``ISIC_0000000_segmentation.png``). Images without a
    matching mask are skipped. If no stems match at all, files are paired
    positionally in sorted order, which requires equal counts.

    The first 80% of pairs (sorted by image path) form the ``'train'`` split;
    the remainder forms any other split.

    Raises:
        ValueError: if no stems match and the image/mask counts differ, so a
            positional pairing would silently misalign labels.
    """
    _MASK_SUFFIX = '_segmentation'

    def __init__(self, data_dir, img_size=256, split='train'):
        self.img_size = img_size
        self.split = split

        img_paths = sorted(glob.glob(os.path.join(data_dir, 'images', '*.jpg')))
        mask_paths = sorted(glob.glob(os.path.join(data_dir, 'masks', '*.png')))

        pairs = self._pair_by_stem(img_paths, mask_paths)
        if not pairs:
            if len(img_paths) != len(mask_paths):
                raise ValueError(
                    f"Cannot pair {len(img_paths)} images with {len(mask_paths)} masks in {data_dir!r}: "
                    "no file stems match and the counts differ."
                )
            pairs = list(zip(img_paths, mask_paths))

        # Split 80/20
        split_idx = int(0.8 * len(pairs))
        chosen = pairs[:split_idx] if split == 'train' else pairs[split_idx:]
        self.images = [img_path for img_path, _ in chosen]
        self.masks = [mask_path for _, mask_path in chosen]

    @staticmethod
    def _stem(path):
        """File name without directory and extension."""
        return os.path.splitext(os.path.basename(path))[0]

    @classmethod
    def _pair_by_stem(cls, img_paths, mask_paths):
        """Return ``(image, mask)`` pairs whose stems match (mask suffix optional)."""
        masks_by_stem = {}
        for mask_path in mask_paths:
            stem = cls._stem(mask_path)
            if stem.endswith(cls._MASK_SUFFIX):
                stem = stem[:-len(cls._MASK_SUFFIX)]
            masks_by_stem[stem] = mask_path
        return [
            (img_path, masks_by_stem[cls._stem(img_path)])
            for img_path in img_paths
            if cls._stem(img_path) in masks_by_stem
        ]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(img, mask)``: float tensor (1, H, W) in [0, 1] and long {0, 1} tensor (H, W)."""
        # Load image
        img = Image.open(self.images[idx]).convert('L')  # Grayscale
        img = img.resize((self.img_size, self.img_size), Image.BILINEAR)
        img = np.array(img, dtype=np.float32) / 255.0

        # Load mask; nearest-neighbour resizing keeps it binary before thresholding.
        mask = Image.open(self.masks[idx]).convert('L')
        mask = mask.resize((self.img_size, self.img_size), Image.NEAREST)
        mask = (np.array(mask) > 127).astype(np.int64)

        img = torch.from_numpy(img).unsqueeze(0).float()
        mask = torch.from_numpy(mask).long()

        return img, mask


class SynapseDataset(Dataset):
    """Synapse multi-organ CT slices stored as ``.npz`` files.

    ``split='train'`` reads ``<data_dir>/train_npz/*.npz``; any other split reads
    ``<data_dir>/test_npz/*.npz``. Each file must contain an ``image`` array and
    a ``label`` array. Slices whose image is not ``img_size x img_size`` are
    resized: bilinear for the image, nearest-neighbour for the label.
    """
    def __init__(self, data_dir, img_size=256, split='train'):
        self.img_size = img_size
        self.split = split

        subdir = 'train_npz' if split == 'train' else 'test_npz'
        self.data_files = sorted(glob.glob(os.path.join(data_dir, subdir, '*.npz')))

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, idx):
        """Return ``(img, mask)``: float tensor (1, H, W) and long label tensor (H, W)."""
        # np.load on an .npz returns an NpzFile that holds its file handle open
        # until closed; the context manager closes it once both arrays are read.
        with np.load(self.data_files[idx]) as data:
            # float32 is also required by PIL's 'F' mode used for resizing below.
            img = np.asarray(data['image'], dtype=np.float32)
            mask = np.asarray(data['label'])

        # Resize if needed
        if img.shape != (self.img_size, self.img_size):
            img = np.array(Image.fromarray(img).resize((self.img_size, self.img_size), Image.BILINEAR))
            mask = np.array(Image.fromarray(mask.astype(np.uint8)).resize((self.img_size, self.img_size), Image.NEAREST))

        img = torch.from_numpy(img).unsqueeze(0).float()
        # int64 first: torch.from_numpy rejects some unsigned label dtypes (e.g. uint16) on older PyTorch.
        mask = torch.from_numpy(mask.astype(np.int64)).long()

        return img, mask


# Create datasets
IMG_SIZE = 256
NUM_CLASSES = 2  # Background + Foreground (or 9 for Synapse)

# Source priority: synthetic (download failed) > ISIC 2018 > Synapse > synthetic fallback.
if globals().get('CREATE_SYNTHETIC', False):
    print("üìä Creating synthetic medical dataset...")
    train_dataset = SyntheticMedicalDataset(num_samples=500, img_size=IMG_SIZE, split='train')
    val_dataset = SyntheticMedicalDataset(num_samples=100, img_size=IMG_SIZE, split='val')
elif glob.glob('data/ISIC2018/images/*.jpg'):
    print("üìä Loading ISIC 2018 dataset...")
    train_dataset = ISICDataset('data/ISIC2018', img_size=IMG_SIZE, split='train')
    val_dataset = ISICDataset('data/ISIC2018', img_size=IMG_SIZE, split='val')
elif os.path.exists('data/Synapse/train_npz'):
    print("üìä Loading Synapse dataset...")
    train_dataset = SynapseDataset('data/Synapse', img_size=IMG_SIZE, split='train')
    val_dataset = SynapseDataset('data/Synapse', img_size=IMG_SIZE, split='val')
    NUM_CLASSES = 9  # 8 organs + background
else:
    print("üìä Creating synthetic medical dataset (fallback)...")
    train_dataset = SyntheticMedicalDataset(num_samples=500, img_size=IMG_SIZE, split='train')
    val_dataset = SyntheticMedicalDataset(num_samples=100, img_size=IMG_SIZE, split='val')

# Fail fast with an actionable message instead of a ZeroDivisionError deep in training.
if len(train_dataset) == 0 or len(val_dataset) == 0:
    raise ValueError(
        f"Empty dataset split (train={len(train_dataset)}, val={len(val_dataset)}). "
        "Check the data layout; e.g. Synapse validation slices are read from data/Synapse/test_npz/."
    )

# Create dataloaders
BATCH_SIZE = 4
PIN_MEMORY = device == 'cuda'  # pinned host memory only helps host->GPU copies
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=PIN_MEMORY)

print(f"‚úÖ Dataset loaded!")
print(f"   Training samples: {len(train_dataset)}")
print(f"   Validation samples: {len(val_dataset)}")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Number of classes: {NUM_CLASSES}")

In [None]:
# Visualize some training samples, evenly spaced across the dataset
# (works for any dataset size, unlike fixed offsets such as i * 50).
NUM_PREVIEW = 4
preview_indices = np.linspace(0, len(train_dataset) - 1, num=NUM_PREVIEW).round().astype(int)

fig, axes = plt.subplots(2, NUM_PREVIEW, figsize=(16, 8))

for col, sample_idx in enumerate(preview_indices):
    img, mask = train_dataset[int(sample_idx)]

    axes[0, col].imshow(img[0], cmap='gray')
    axes[0, col].set_title(f'Image {col+1}')
    axes[0, col].axis('off')

    # Nearest-neighbour display keeps label boundaries crisp.
    axes[1, col].imshow(mask, cmap='viridis', interpolation='nearest')
    axes[1, col].set_title(f'Mask {col+1}')
    axes[1, col].axis('off')

plt.suptitle('Training Samples', fontsize=14)
plt.tight_layout()
plt.show()

## 🎯 Training Configuration

In [None]:
# Training configuration: every knob a reader might tune, in one place.
config = dict(
    # Model architecture: 'egm_net', 'egm_net_lite' or 'spectral_vmamba'
    model='egm_net',
    in_channels=1,
    num_classes=NUM_CLASSES,
    img_size=IMG_SIZE,
    base_channels=64,

    # Optimisation
    num_epochs=100,
    learning_rate=1e-4,
    weight_decay=1e-5,
    batch_size=BATCH_SIZE,

    # Loss weights (consumed by CombinedLoss)
    dice_weight=1.0,
    ce_weight=1.0,
    boundary_weight=0.5,

    # Implicit-representation sampling
    # NOTE(review): not read by train_model/CombinedLoss in this notebook.
    num_points=2048,        # points sampled for the implicit loss
    boundary_ratio=0.5,     # fraction of those points taken near boundaries

    # Checkpointing
    save_every=10,
    checkpoint_dir='./checkpoints',

    # Early stopping: epochs without validation-Dice improvement
    patience=20,
)

os.makedirs(config['checkpoint_dir'], exist_ok=True)

print("üìã Training Configuration:")
print("\n".join(f"   {key}: {value}" for key, value in config.items()))

In [None]:
# Create model
print(f"Creating {config['model']} model...")

model_name = config['model']
if model_name == 'egm_net':
    model = EGMNet(
        in_channels=config['in_channels'],
        num_classes=config['num_classes'],
        img_size=config['img_size'],
        base_channels=config['base_channels'],
        num_stages=4,
        encoder_depth=2
    )
elif model_name == 'egm_net_lite':
    model = EGMNetLite(
        in_channels=config['in_channels'],
        num_classes=config['num_classes'],
        img_size=config['img_size']
    )
elif model_name == 'spectral_vmamba':
    model = SpectralVMUNet(
        in_channels=config['in_channels'],
        out_channels=config['num_classes'],
        img_size=config['img_size'],
        base_channels=config['base_channels'],
        num_stages=4
    )
else:
    # Previously any typo silently fell through to SpectralVMUNet.
    raise ValueError(
        f"Unknown model {model_name!r}; expected 'egm_net', 'egm_net_lite' or 'spectral_vmamba'."
    )

model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"‚úÖ Model created!")
print(f"   Total parameters: {total_params:,} ({total_params/1e6:.2f}M)")
print(f"   Trainable parameters: {trainable_params:,}")

## 📈 Loss Functions & Metrics

In [None]:
# Loss functions
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss on softmax probabilities.

    A smoothed Dice coefficient is computed for every (sample, class) pair,
    background included, and the loss is ``1 - mean`` of those coefficients.

    Args:
        smooth: constant added to numerator and denominator to avoid 0/0.
    """
    def __init__(self, smooth=1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return the scalar Dice loss.

        Args:
            pred: (B, C, H, W) logits.
            target: (B, H, W) integer class indices.
        """
        probs = pred.softmax(dim=1)
        one_hot = F.one_hot(target, probs.shape[1]).movedim(-1, 1).float()

        spatial_dims = (2, 3)
        overlap = (probs * one_hot).sum(dim=spatial_dims)
        denominator = probs.sum(dim=spatial_dims) + one_hot.sum(dim=spatial_dims)

        coeff = (2.0 * overlap + self.smooth) / (denominator + self.smooth)
        return 1.0 - coeff.mean()


class BoundaryLoss(nn.Module):
    """Boundary-restricted BCE between softmax probabilities and one-hot targets.

    Boundary pixels are those where the Sobel gradient magnitude of the integer
    label map exceeds 0.5, i.e. where the label changes. Outside the boundary
    both inputs to the BCE are zeroed. The summed BCE is divided by the number
    of boundary pixels.
    """
    def __init__(self):
        super().__init__()
        # Sobel kernels
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
        self.register_buffer('sobel_x', sobel_x.view(1, 1, 3, 3))
        self.register_buffer('sobel_y', sobel_y.view(1, 1, 3, 3))

    def get_boundaries(self, mask):
        """Return a (B, 1, H, W) float map with 1.0 on label boundaries of a (B, H, W) mask."""
        mask = mask.float().unsqueeze(1)
        # Replicate padding: zero padding would create spurious "edges" wherever a
        # non-zero label touches the image border, which is not an object boundary.
        padded = F.pad(mask, (1, 1, 1, 1), mode='replicate')
        edge_x = F.conv2d(padded, self.sobel_x)
        edge_y = F.conv2d(padded, self.sobel_y)
        edges = torch.sqrt(edge_x**2 + edge_y**2)
        return (edges > 0.5).float()

    def forward(self, pred, target):
        """Return the boundary BCE for (B, C, H, W) logits and a (B, H, W) label map."""
        # Get boundary regions
        boundaries = self.get_boundaries(target)

        # Restrict the comparison to boundary pixels (broadcast over classes)
        pred_probs = F.softmax(pred, dim=1)
        target_one_hot = F.one_hot(target, pred.shape[1]).permute(0, 3, 1, 2).float()

        # BCE at boundaries; non-boundary entries are 0 vs 0 and contribute nothing
        boundary_loss = F.binary_cross_entropy(
            pred_probs * boundaries,
            target_one_hot * boundaries,
            reduction='sum'
        ) / (boundaries.sum() + 1e-6)

        return boundary_loss


class CombinedLoss(nn.Module):
    """Weighted sum of Dice, cross-entropy and boundary losses.

    ``outputs`` may be a logits tensor or a dict with the main logits under
    ``'output'`` and optional auxiliary logits under ``'coarse'`` / ``'fine'``.
    Each auxiliary branch present adds ``0.3 * cross-entropy``. ``num_classes``
    is accepted for API compatibility but is not used.
    """
    AUX_WEIGHT = 0.3

    def __init__(self, dice_weight=1.0, ce_weight=1.0, boundary_weight=0.5, num_classes=2):
        super().__init__()
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight
        self.boundary_weight = boundary_weight

        self.dice_loss = DiceLoss()
        self.ce_loss = nn.CrossEntropyLoss()
        self.boundary_loss = BoundaryLoss()

    def _main_loss(self, logits, targets):
        """Dice + CE + boundary terms on a single prediction tensor."""
        total = self.dice_weight * self.dice_loss(logits, targets)
        total = total + self.ce_weight * self.ce_loss(logits, targets)
        return total + self.boundary_weight * self.boundary_loss(logits, targets)

    def forward(self, outputs, targets):
        if not isinstance(outputs, dict):
            # Plain logits tensor (e.g. SpectralVMUNet)
            return self._main_loss(outputs, targets)

        # EGM-Net style dict: main loss plus CE on any auxiliary branch
        total = self._main_loss(outputs['output'], targets)
        for branch in ('coarse', 'fine'):
            aux_logits = outputs.get(branch)
            if aux_logits is not None:
                total = total + self.AUX_WEIGHT * self.ce_loss(aux_logits, targets)
        return total


# Metrics
def compute_dice(pred, target, num_classes):
    """Per-class hard Dice scores.

    Args:
        pred: class scores with classes along dim 1 (e.g. (B, C, H, W) logits);
            labels are taken with argmax over dim 1.
        target: integer label map matching pred's shape without dim 1.
        num_classes: classes 0 .. num_classes-1 are scored.

    Returns:
        list of Python floats, one per class. A class absent from both the
        prediction and the target scores 1.0.
    """
    labels = pred.argmax(dim=1)
    scores = []
    for cls in range(num_classes):
        hit_pred = labels.eq(cls).float()
        hit_true = target.eq(cls).float()
        denom = hit_pred.sum() + hit_true.sum()
        if denom > 0:
            scores.append((2.0 * (hit_pred * hit_true).sum() / denom).item())
        else:
            scores.append(1.0)  # both empty = perfect
    return scores


def compute_iou(pred, target, num_classes):
    """Per-class hard IoU (Jaccard) scores.

    Args:
        pred: class scores with classes along dim 1 (e.g. (B, C, H, W) logits);
            labels are taken with argmax over dim 1.
        target: integer label map matching pred's shape without dim 1.
        num_classes: classes 0 .. num_classes-1 are scored.

    Returns:
        list of Python floats, one per class. A class absent from both the
        prediction and the target scores 1.0.
    """
    labels = pred.argmax(dim=1)

    def _iou_for(cls):
        in_pred = (labels == cls).float()
        in_true = (target == cls).float()
        inter = (in_pred * in_true).sum()
        union = in_pred.sum() + in_true.sum() - inter
        return (inter / union).item() if union > 0 else 1.0

    return [_iou_for(cls) for cls in range(num_classes)]


print("‚úÖ Loss functions and metrics defined!")

## 🚀 Training Loop

In [None]:
# Full training function
def train_model(model, train_loader, val_loader, config, device):
    """Train ``model`` with validation, checkpointing and early stopping.

    Each epoch does the following:
    - runs one pass over ``train_loader``;
    - evaluates on ``val_loader`` (mean foreground Dice/IoU, background excluded);
    - steps a cosine-annealing LR schedule;
    - saves ``best_model.pth`` whenever validation Dice improves;
    - writes a periodic checkpoint every ``config['save_every']`` epochs.

    Training stops early after ``config['patience']`` epochs without improvement.

    Args:
        model: segmentation network whose forward returns logits or a dict with
            an ``'output'`` entry (plus optional ``'coarse'``/``'fine'``).
        train_loader: DataLoader yielding ``(images, masks)`` for training.
        val_loader: DataLoader yielding ``(images, masks)`` for validation.
        config: dict providing the keys read below (loss weights, learning_rate,
            weight_decay, num_epochs, num_classes, checkpoint_dir, save_every, patience).
        device: device batches are moved to.

    Returns:
        dict of per-epoch lists of floats: ``train_loss``, ``val_loss``,
        ``val_dice``, ``val_iou`` and ``learning_rate`` (the LR used in that epoch).

    Raises:
        ValueError: if either loader yields no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError(
            "train_loader and val_loader must each yield at least one batch "
            f"(got {len(train_loader)} train / {len(val_loader)} val batches)."
        )

    # Loss and optimizer
    criterion = CombinedLoss(
        dice_weight=config['dice_weight'],
        ce_weight=config['ce_weight'],
        boundary_weight=config['boundary_weight'],
        num_classes=config['num_classes']
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )

    # Learning rate scheduler (cosine annealing)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=config['num_epochs'],
        eta_min=1e-6
    )

    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_dice': [],
        'val_iou': [],
        'learning_rate': []
    }

    # -inf rather than 0.0 guarantees best_model.pth is written after the first
    # epoch even if validation Dice is exactly 0; later cells load that file.
    best_val_dice = float('-inf')
    patience_counter = 0

    def _checkpoint(epoch_num):
        """Full resumable snapshot: model, optimizer, scheduler and run metadata."""
        return {
            'epoch': epoch_num,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_dice': best_val_dice,
            'config': config,
            'history': history
        }

    print("\n" + "="*60)
    print("üöÄ Starting Training")
    print("="*60)

    for epoch in range(config['num_epochs']):
        # =============== Training ===============
        model.train()
        train_loss = 0.0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Train]")

        for images, masks in train_pbar:
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            train_loss += loss.item()
            train_pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        avg_train_loss = train_loss / len(train_loader)
        history['train_loss'].append(avg_train_loss)

        # =============== Validation ===============
        model.eval()
        val_loss = 0.0
        all_dice = []
        all_iou = []

        with torch.no_grad():
            for images, masks in tqdm(val_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Val]"):
                images = images.to(device)
                masks = masks.to(device)

                outputs = model(images)

                # Get prediction tensor
                pred = outputs['output'] if isinstance(outputs, dict) else outputs

                loss = criterion(outputs, masks)
                val_loss += loss.item()

                # Compute metrics
                dice_scores = compute_dice(pred, masks, config['num_classes'])
                iou_scores = compute_iou(pred, masks, config['num_classes'])

                all_dice.append(np.mean(dice_scores[1:]))  # Exclude background
                all_iou.append(np.mean(iou_scores[1:]))

        avg_val_loss = val_loss / len(val_loader)
        # Plain Python floats keep history/checkpoints free of NumPy scalar types,
        # which torch.load(weights_only=True) refuses to unpickle.
        avg_val_dice = float(np.mean(all_dice))
        avg_val_iou = float(np.mean(all_iou))
        epoch_lr = optimizer.param_groups[0]['lr']  # LR used during this epoch

        history['val_loss'].append(avg_val_loss)
        history['val_dice'].append(avg_val_dice)
        history['val_iou'].append(avg_val_iou)
        history['learning_rate'].append(epoch_lr)

        # Update learning rate (takes effect from the next epoch)
        scheduler.step()

        # Print epoch summary
        print(f"\nüìä Epoch {epoch+1}/{config['num_epochs']}")
        print(f"   Train Loss: {avg_train_loss:.4f}")
        print(f"   Val Loss:   {avg_val_loss:.4f}")
        print(f"   Val Dice:   {avg_val_dice:.4f}")
        print(f"   Val IoU:    {avg_val_iou:.4f}")
        print(f"   LR:         {epoch_lr:.6f}")

        # =============== Checkpointing ===============
        # Save best model
        if avg_val_dice > best_val_dice:
            best_val_dice = avg_val_dice
            patience_counter = 0
            torch.save(_checkpoint(epoch + 1), os.path.join(config['checkpoint_dir'], 'best_model.pth'))
            print(f"   ‚úÖ New best model saved! (Dice: {best_val_dice:.4f})")
        else:
            patience_counter += 1

        # Save periodic checkpoints (now including scheduler state, so runs can resume)
        if (epoch + 1) % config['save_every'] == 0:
            torch.save(_checkpoint(epoch + 1), os.path.join(config['checkpoint_dir'], f'checkpoint_epoch_{epoch+1}.pth'))

        # Early stopping
        if patience_counter >= config['patience']:
            print(f"\n‚ö†Ô∏è Early stopping triggered after {epoch+1} epochs")
            break

    print("\n" + "="*60)
    print(f"üéâ Training completed!")
    print(f"   Best Val Dice: {best_val_dice:.4f}")
    print("="*60)

    return history


print("‚úÖ Training function defined!")

In [None]:
# 🚀 Start training; the returned per-epoch history feeds the plots below.
history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    device=device,
)

## 📊 Training Visualization

In [None]:
# Plot training curves
def _plot_series(ax, series, *, label, ylabel, title, **line_kwargs):
    """Draw one per-epoch curve on ``ax`` with the shared axis styling."""
    ax.plot(series, label=label, linewidth=2, **line_kwargs)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(2, 2, figsize=(14, 10))
ax_loss, ax_dice, ax_iou, ax_lr = axes.flat

# Loss curves
ax_loss.plot(history['train_loss'], label='Train Loss', linewidth=2)
_plot_series(ax_loss, history['val_loss'], label='Val Loss', ylabel='Loss',
             title='Training & Validation Loss')

# Dice score, with the best epoch marked. The reference line is drawn before
# legend() is called, so its "Best" label appears in the legend.
best_dice = max(history['val_dice'])
_plot_series(ax_dice, history['val_dice'], label='Val Dice', ylabel='Dice Score',
             title='Validation Dice Score', color='green')
ax_dice.axhline(y=best_dice, color='r', linestyle='--', alpha=0.5, label=f"Best: {best_dice:.4f}")

# IoU score
_plot_series(ax_iou, history['val_iou'], label='Val IoU', ylabel='IoU Score',
             title='Validation IoU Score', color='orange')

# Learning rate
_plot_series(ax_lr, history['learning_rate'], label='Learning Rate', ylabel='Learning Rate',
             title='Learning Rate Schedule', color='purple')
ax_lr.set_yscale('log')

for ax in axes.flat:
    ax.legend()

plt.suptitle('Training Progress', fontsize=14)
plt.tight_layout()
plt.savefig('training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nüìà Training Summary:")
print(f"   Final Train Loss: {history['train_loss'][-1]:.4f}")
print(f"   Final Val Loss:   {history['val_loss'][-1]:.4f}")
print(f"   Best Val Dice:    {max(history['val_dice']):.4f}")
print(f"   Best Val IoU:     {max(history['val_iou']):.4f}")

## 🔍 Inference & Visualization

In [None]:
# Load best model.
# weights_only=False: the checkpoint also stores config/history (and, for runs saved
# by older code, NumPy scalars), which PyTorch >= 2.6's default weights_only=True
# refuses to unpickle. Only do this for checkpoints you created yourself, since
# full unpickling can execute arbitrary code.
best_ckpt_path = os.path.join(config['checkpoint_dir'], 'best_model.pth')
checkpoint = torch.load(best_ckpt_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"‚úÖ Loaded best model from epoch {checkpoint['epoch']} (Dice: {checkpoint['best_val_dice']:.4f})")

In [None]:
# Visualize predictions on validation set
model.eval()

NUM_ROWS = 4
row_stride = len(val_dataset) // NUM_ROWS
fig, axes = plt.subplots(NUM_ROWS, 4, figsize=(16, 16))

with torch.no_grad():
    for i in range(NUM_ROWS):
        img, mask = val_dataset[i * row_stride]
        img = img.unsqueeze(0).to(device)

        # EGM-Net returns a dict, SpectralVMUNet a plain logits tensor, so check the
        # type before using dict-only access (calling .get() on a tensor crashes).
        outputs = model(img)
        pred = outputs['output'] if isinstance(outputs, dict) else outputs
        pred_mask = torch.argmax(pred, dim=1)[0].cpu()

        image_np = img[0, 0].cpu().numpy()

        axes[i, 0].imshow(image_np, cmap='gray')
        axes[i, 0].set_title('Input Image')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(mask, cmap='viridis')
        axes[i, 1].set_title('Ground Truth')
        axes[i, 1].axis('off')

        axes[i, 2].imshow(pred_mask, cmap='viridis')
        axes[i, 2].set_title('Prediction')
        axes[i, 2].axis('off')

        # Overlay: red channel marks predicted foreground, blue marks ground truth.
        overlay = np.stack([image_np, image_np, image_np], axis=-1)
        overlay[..., 0] = np.where(pred_mask.numpy() > 0, 1.0, overlay[..., 0])
        overlay[..., 2] = np.where(mask.numpy() > 0, 1.0, overlay[..., 2])
        overlay = np.clip(overlay, 0, 1)

        axes[i, 3].imshow(overlay)
        axes[i, 3].set_title('Overlay (Red=Pred, Blue=GT)')
        axes[i, 3].axis('off')

plt.suptitle('Validation Predictions', fontsize=14)
plt.tight_layout()
plt.savefig('predictions.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Visualize EGM-Net branches (Coarse vs Fine)
if config['model'] in ['egm_net', 'egm_net_lite']:
    fig, axes = plt.subplots(3, 5, figsize=(20, 12))
    row_stride = len(val_dataset) // 3

    with torch.no_grad():
        for i in range(3):
            img, mask = val_dataset[i * row_stride]
            img = img.unsqueeze(0).to(device)

            outputs = model(img)

            # (image, colormap, title) for each column of this row
            panels = [
                (img[0, 0].cpu(), 'gray', 'Input'),
                (outputs['energy'][0, 0].cpu(), 'hot', 'Energy Map'),
                (torch.argmax(outputs['coarse'], dim=1)[0].cpu(), 'viridis', 'Coarse Branch'),
                (torch.argmax(outputs['fine'], dim=1)[0].cpu(), 'viridis', 'Fine Branch'),
                (torch.argmax(outputs['output'], dim=1)[0].cpu(), 'viridis', 'Final (Fused)'),
            ]
            for ax, (data, cmap, title) in zip(axes[i], panels):
                ax.imshow(data, cmap=cmap)
                ax.set_title(title)
                ax.axis('off')

    # '\n' is a real newline here, so the subtitle renders on its own line.
    plt.suptitle('EGM-Net Branch Analysis\n(Energy-Gated Fusion of Coarse + Fine)', fontsize=14)
    plt.tight_layout()
    plt.savefig('branch_analysis.png', dpi=150, bbox_inches='tight')
    plt.show()

## 🎯 Resolution-Free Inference Demo

In [None]:
# Demonstrate resolution-free inference (unique to EGM-Net)
if config['model'] in ['egm_net', 'egm_net_lite']:
    print("üî¨ Resolution-Free Inference Demo")
    print("   EGM-Net can render at ANY resolution without retraining!")

    resolutions = [64, 128, 256, 512]

    # Get a sample image
    img, mask = val_dataset[0]
    img = img.unsqueeze(0).to(device)

    fig, axes = plt.subplots(2, len(resolutions), figsize=(16, 8))

    with torch.no_grad():
        for col, res in enumerate(resolutions):
            # Same weights, different requested output size
            outputs = model(img, output_size=(res, res))
            pred = torch.argmax(outputs['output'], dim=1)[0].cpu()

            # Input resampled to the same size, for visual comparison only
            input_resized = F.interpolate(img, size=(res, res), mode='bilinear', align_corners=False)

            axes[0, col].imshow(input_resized[0, 0].cpu(), cmap='gray')
            axes[0, col].set_title(f'Input {res}√ó{res}')
            axes[0, col].axis('off')

            axes[1, col].imshow(pred, cmap='viridis')
            axes[1, col].set_title(f'Prediction {res}√ó{res}')
            axes[1, col].axis('off')

    # '\n' is a real newline here, so the subtitle renders on its own line.
    plt.suptitle('Resolution-Free Rendering\n(Same model weights, different output resolutions)', fontsize=14)
    plt.tight_layout()
    plt.savefig('resolution_free.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("\n‚úÖ Same model can output at 64√ó64 to 512√ó512 (or higher)!")
    print("   This is impossible with standard CNN decoders.")

## 💾 Save & Export Model

In [None]:
# Save final model to Google Drive
import shutil

try:
    # Imported inside the try: outside Colab this raises ImportError, which must
    # reach the local-save fallback below instead of aborting the cell.
    from google.colab import drive
    drive.mount('/content/drive')

    # Save to Drive
    save_path = '/content/drive/MyDrive/EGM_Net_Models'
    os.makedirs(save_path, exist_ok=True)

    # Save checkpoint (plain float so it loads with torch.load(weights_only=True))
    final_checkpoint = {
        'model_state_dict': model.state_dict(),
        'config': config,
        'history': history,
        'best_val_dice': float(max(history['val_dice']))
    }
    torch.save(final_checkpoint, os.path.join(save_path, 'egm_net_trained.pth'))

    # Also copy whichever figures this session produced; a missing figure
    # should not hide the fact that the model itself was saved.
    copied_figures = []
    for figure_name in ('training_curves.png', 'predictions.png'):
        if os.path.exists(figure_name):
            shutil.copy(figure_name, save_path)
            copied_figures.append(figure_name)

    print(f"‚úÖ Model saved to Google Drive: {save_path}")
    print(f"   - egm_net_trained.pth")
    for figure_name in copied_figures:
        print(f"   - {figure_name}")

except Exception as e:
    print(f"‚ö†Ô∏è Could not save to Google Drive: {e}")
    print("   Model is saved locally in ./checkpoints/")

## 📊 Final Evaluation Metrics

In [None]:
# Final evaluation on the full validation set: switch to inference mode and
# start empty per-sample metric accumulators.
model.eval()

all_dice_scores, all_iou_scores = [], []
all_hd95_scores = []  # 95th-percentile Hausdorff distance per sample

# Helper function for Hausdorff distance
def compute_hausdorff_95(pred, target, spacing=None, empty_value=0.0):
    """Compute the symmetric 95th-percentile Hausdorff distance (HD95).

    Args:
        pred: Binary prediction mask (CPU torch tensor or array-like);
            any non-zero value counts as foreground.
        target: Binary ground-truth mask with the same shape as ``pred``.
        spacing: Optional per-axis pixel/voxel spacing, forwarded to
            ``scipy.ndimage.distance_transform_edt`` as ``sampling``.
            With ``None`` (default) the result is in pixels, not physical units.
        empty_value: Returned when either mask, or either extracted surface,
            is empty. Defaults to 0.0 for backward compatibility; pass
            ``float('nan')`` to keep such cases out of averages.

    Returns:
        float: the larger of the 95th percentiles of the pred->target and
        target->pred surface distances.
    """
    from scipy.ndimage import distance_transform_edt

    # np.asarray works for CPU tensors as well as numpy arrays / lists.
    pred_np = np.asarray(pred).astype(bool)
    target_np = np.asarray(target).astype(bool)

    if not pred_np.any() or not target_np.any():
        return empty_value

    # Distance from every pixel to the nearest foreground pixel of each mask
    pred_dist = distance_transform_edt(~pred_np, sampling=spacing)
    target_dist = distance_transform_edt(~target_np, sampling=spacing)

    # Surface = foreground pixels adjacent to background (unit-grid EDT <= 1).
    # Computed without `spacing` so the surface definition does not depend on scale.
    pred_surface = pred_np & (distance_transform_edt(pred_np) <= 1)
    target_surface = target_np & (distance_transform_edt(target_np) <= 1)

    # Distances from pred surface to target, and vice versa
    d_pred_to_target = target_dist[pred_surface]
    d_target_to_pred = pred_dist[target_surface]

    if d_pred_to_target.size == 0 or d_target_to_pred.size == 0:
        return empty_value

    # 95th percentile in each direction; the symmetric HD95 is the larger one
    return float(max(np.percentile(d_pred_to_target, 95),
                     np.percentile(d_target_to_pred, 95)))

print("üîç Running final evaluation...")

with torch.no_grad():
    for images, masks in tqdm(val_loader, desc="Evaluating"):
        images = images.to(device)
        masks = masks.to(device)
        
        outputs = model(images)
        pred = outputs['output'] if isinstance(outputs, dict) else outputs
        pred_masks = torch.argmax(pred, dim=1)
        
        for b in range(images.shape[0]):
            # Per-sample metrics
            dice = compute_dice(pred[b:b+1], masks[b:b+1], config['num_classes'])
            iou = compute_iou(pred[b:b+1], masks[b:b+1], config['num_classes'])
            
            all_dice_scores.append(np.mean(dice[1:]))  # Exclude background
            all_iou_scores.append(np.mean(iou[1:]))
            
            # Hausdorff distance (for foreground)
            try:
                hd95 = compute_hausdorff_95(
                    (pred_masks[b] > 0).cpu(),
                    (masks[b] > 0).cpu()
                )
                all_hd95_scores.append(hd95)
            except:
                pass

# Print results
print("\n" + "="*60)
print("üìä FINAL EVALUATION RESULTS")
print("="*60)
print(f"\n{'Metric':<20} {'Mean':<12} {'Std':<12}")
print("-"*44)
print(f"{'Dice Score':<20} {np.mean(all_dice_scores):.4f}       {np.std(all_dice_scores):.4f}")
print(f"{'IoU Score':<20} {np.mean(all_iou_scores):.4f}       {np.std(all_iou_scores):.4f}")
if all_hd95_scores:
    print(f"{'HD95 (mm)':<20} {np.mean(all_hd95_scores):.2f}         {np.std(all_hd95_scores):.2f}")
print("-"*44)
print(f"\nTotal validation samples: {len(all_dice_scores)}")
print("="*60)

## 3. Test Monogenic Signal Processing

In [None]:
# Create a test image with edges
def create_test_image(size=256, noise_std=0.05, generator=None):
    """Create a synthetic medical-like image: a disc "organ" containing a brighter "tumor".

    Args:
        size: Height and width of the square image in pixels.
        noise_std: Standard deviation of the additive Gaussian noise
            (0.05 reproduces the original behaviour; 0 gives a clean image).
        generator: Optional ``torch.Generator`` for reproducible noise;
            ``None`` uses the global torch RNG.

    Returns:
        tuple: ``(img, organ_mask, tumor_mask)``. ``img`` has shape
        ``(1, 1, size, size)``; both masks are float ``(size, size)`` tensors
        with values in {0, 1}.

    Note:
        The tumor centre is offset by a fixed (+30, -20) pixels from the image
        centre, so for small ``size`` it may fall partly outside the organ/image.
    """
    img = torch.zeros(1, 1, size, size)
    
    # Pixel coordinate grids; the circle centres below are given as (x, y)
    y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')
    
    # Add circular "organ"
    center1 = (size // 2, size // 2)
    radius1 = size // 4
    mask1 = ((x - center1[0])**2 + (y - center1[1])**2) < radius1**2
    img[0, 0, mask1] = 0.7
    
    # Add smaller "tumor" (painted after the organ, so it overrides that intensity)
    center2 = (size // 2 + 30, size // 2 - 20)
    radius2 = size // 10
    mask2 = ((x - center2[0])**2 + (y - center2[1])**2) < radius2**2
    img[0, 0, mask2] = 1.0
    
    # Add noise (an explicit generator makes it reproducible)
    if noise_std:
        noise = torch.randn(img.shape, generator=generator, dtype=img.dtype)
        img = img + noise_std * noise
    
    return img, mask1.float(), mask2.float()

# Build the 256x256 synthetic image (plus its organ/tumor masks) used below
test_img, organ_mask, tumor_mask = create_test_image(size=256)
print(f"Test image shape: {test_img.shape}")

In [None]:
# Test Monogenic Energy Extraction on the synthetic image
energy_extractor = EnergyMap(normalize=True, smoothing_sigma=1.0)
energy, mono_out = energy_extractor(test_img)

# One (row, col, image, title, colormap) entry per panel, drawn in order
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

mono_panels = [
    (0, 0, test_img[0, 0], 'Input Image', 'gray'),
    (0, 1, energy[0, 0].detach(), 'Energy Map (Edges)', 'hot'),
    (0, 2, mono_out['phase'][0, 0].detach(), 'Phase', 'twilight'),
    (1, 0, mono_out['orientation'][0, 0].detach(), 'Orientation', 'hsv'),
    (1, 1, mono_out['riesz_x'][0, 0].detach(), 'Riesz X Component', 'RdBu'),
    (1, 2, mono_out['riesz_y'][0, 0].detach(), 'Riesz Y Component', 'RdBu'),
]
for row, col, image, title, cmap in mono_panels:
    panel_ax = axes[row, col]
    panel_ax.imshow(image, cmap=cmap)
    panel_ax.set_title(title)
    panel_ax.axis('off')

plt.suptitle('Monogenic Signal Decomposition', fontsize=14)
plt.tight_layout()
plt.show()

print("\n‚úÖ Monogenic processing works correctly!")

## 4. Test Gabor Basis vs Fourier Features

In [None]:
from gabor_implicit import GaborBasis, FourierFeatures

# Normalized [-1, 1] coordinate grid flattened into query points
size = 128
y = torch.linspace(-1, 1, size)
x = torch.linspace(-1, 1, size)
yy, xx = torch.meshgrid(y, x, indexing='ij')
# (x, y) pairs for every pixel: shape (1, size*size, 2)
coords = torch.stack((xx, yy), dim=-1).reshape(1, size * size, 2)

# Same number of frequencies for both encodings so they are comparable
gabor = GaborBasis(input_dim=2, num_frequencies=32)
fourier = FourierFeatures(input_dim=2, num_frequencies=32, scale=10.0)

gabor_features, fourier_features = gabor(coords), fourier(coords)

print(f"Gabor features shape: {gabor_features.shape}")
print(f"Fourier features shape: {fourier_features.shape}")

In [None]:
# Visualize the first four basis functions: Gabor (top row) vs Fourier (bottom row)
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for i in range(4):
    # Gabor
    gabor_vis = gabor_features[0, :, i].view(size, size).detach().numpy()
    axes[0, i].imshow(gabor_vis, cmap='RdBu', vmin=-1, vmax=1)
    axes[0, i].set_title(f'Gabor Basis {i+1}')
    
    # Fourier
    fourier_vis = fourier_features[0, :, i].view(size, size).detach().numpy()
    axes[1, i].imshow(fourier_vis, cmap='RdBu', vmin=-1, vmax=1)
    axes[1, i].set_title(f'Fourier Basis {i+1}')

# Hide ticks rather than calling axis('off'): axis('off') also hides axis
# labels, which would make the row labels below invisible.
for panel_ax in axes.flat:
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])

axes[0, 0].set_ylabel('Gabor\n(Localized)', fontsize=12)
axes[1, 0].set_ylabel('Fourier\n(Global)', fontsize=12)

plt.suptitle('Gabor vs Fourier Basis Functions\n(Gabor is localized ‚Üí No Gibbs ringing)', fontsize=14)
plt.tight_layout()
plt.show()

## 5. Create and Analyze Models

In [None]:
# Create EGM-Net models (full and lite) plus the SpectralVMUNet baseline
print("Creating models...")

# Full model
egm_net = EGMNet(
    in_channels=1,
    num_classes=3,
    img_size=256,
    base_channels=64,
    num_stages=4,
    encoder_depth=2
).to(device)

# Lite model
egm_lite = EGMNetLite(
    in_channels=1,
    num_classes=3,
    img_size=256
).to(device)

# Spectral Mamba (comparison); this class names its class count `out_channels`
spec_mamba = SpectralVMUNet(
    in_channels=1,
    out_channels=3,
    img_size=256,
    base_channels=64,
    num_stages=4
).to(device)

print("\nüìä Model Comparison:")
print("-" * 50)
models = {
    'EGM-Net Full': egm_net,
    'EGM-Net Lite': egm_lite,
    'SpectralVMUNet': spec_mamba
}

for name, net in models.items():
    # Loop variable is `net`, not `model`: reusing `model` would silently replace
    # the trained model that the save/evaluation cells refer to.
    params = sum(p.numel() for p in net.parameters())
    print(f"{name:20s}: {params:,} parameters ({params/1e6:.2f}M)")

## 6. Test Forward Pass

In [None]:
# Test forward pass on a random batch (two single-channel 256x256 images)
test_input = torch.randn(2, 1, 256, 256).to(device)

print("Testing forward pass...")
print(f"Input shape: {test_input.shape}")

# Newly built modules are in train mode. Switch to inference mode first: if the
# models contain dropout/BatchNorm layers (not visible from this notebook),
# train mode would randomize outputs and let this random batch overwrite
# BatchNorm running statistics even under no_grad.
egm_net.eval()
spec_mamba.eval()

with torch.no_grad():
    # EGM-Net
    egm_out = egm_net(test_input)
    print(f"\nüîπ EGM-Net Output:")
    for k, v in egm_out.items():
        print(f"   {k}: {v.shape}")
    
    # SpectralVMUNet
    spec_out = spec_mamba(test_input)
    print(f"\nüîπ SpectralVMUNet Output: {spec_out.shape}")

print("\n‚úÖ Forward pass successful!")

## 7. Test Resolution-Free Inference (Unique to EGM-Net)

In [None]:
# EGM-Net can be queried at arbitrary continuous coordinates
print("Testing Resolution-Free Inference...")

# Random query locations drawn uniformly from the normalized [-1, 1] square
num_points = 10000
random_coords = 2 * torch.rand(1, num_points, 2).to(device) - 1

with torch.no_grad():
    # Evaluate the model only at those points
    point_output = egm_net.query_points(test_input[:1], random_coords)

for report_line in (
    f"Query coordinates: {random_coords.shape}",
    f"Point outputs: {point_output.shape}",
    "\n‚úÖ Resolution-free inference works!",
    "   ‚Üí You can zoom into boundaries at ANY resolution!",
):
    print(report_line)

In [None]:
# Resolution-free rendering: same weights, several output sizes
resolutions = [64, 128, 256, 512]

fig, axes = plt.subplots(1, len(resolutions), figsize=(16, 4))

with torch.no_grad():
    for ax_res, res in zip(axes, resolutions):
        # Decode directly at res x res and take the per-pixel class label
        output = egm_net(test_input[:1], output_size=(res, res))
        label_map = output['output'].argmax(dim=1)[0].cpu().numpy()
        
        ax_res.imshow(label_map, cmap='viridis')
        ax_res.set_title(f'{res}√ó{res}')
        ax_res.axis('off')

plt.suptitle('Resolution-Free Rendering (Same model, different output sizes)', fontsize=14)
plt.tight_layout()
plt.show()

## 8. Visualize Energy-Gated Fusion

In [None]:
# Run EGM-Net once on the first test sample and show each branch's output
with torch.no_grad():
    outputs = egm_net(test_input[:1])

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Per-pixel class labels (argmax over the class axis) for each branch
coarse_pred = torch.argmax(outputs['coarse'], dim=1)[0].cpu()
fine_pred = torch.argmax(outputs['fine'], dim=1)[0].cpu()
final_pred = torch.argmax(outputs['output'], dim=1)[0].cpu()

# 1.0 wherever the fine and coarse branches disagree
diff = (fine_pred != coarse_pred).float()

# grid position -> (image, colormap, title); dict order fixes the drawing order
branch_panels = {
    (0, 0): (test_input[0, 0].cpu(), 'gray', 'Input Image'),
    (0, 1): (outputs['energy'][0, 0].cpu(), 'hot', 'Energy Map (Edge Detection)'),
    (0, 2): (coarse_pred, 'viridis', 'Coarse Branch (Smooth)'),
    (1, 0): (fine_pred, 'viridis', 'Fine Branch (Sharp)'),
    (1, 1): (final_pred, 'viridis', 'Final Output (Fused)'),
    (1, 2): (diff, 'Reds', 'Difference (Fine vs Coarse)'),
}
for (row, col), (image, cmap, title) in branch_panels.items():
    panel_ax = axes[row, col]
    panel_ax.imshow(image, cmap=cmap)
    panel_ax.set_title(title)
    panel_ax.axis('off')

plt.suptitle('EGM-Net Dual-Branch Architecture', fontsize=14)
plt.tight_layout()
plt.show()

## 9. Quick Training Demo

In [None]:
from train_egm import EGMNetTrainer, create_dummy_dataset
from torch.utils.data import DataLoader

# Tiny synthetic dataset so the demo trains in seconds
print("Creating dummy dataset...")
dataset = create_dummy_dataset(num_samples=16, img_size=256, num_classes=3)
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Hyperparameters for the short demo run (passed to EGMNetTrainer below)
config = dict(
    learning_rate=1e-4,
    weight_decay=1e-5,
    num_epochs=2,
    num_points=1024,
    boundary_ratio=0.5,
    checkpoint_dir='./checkpoints_demo',
)

# The lite variant keeps the demo fast
model = EGMNetLite(in_channels=1, num_classes=3, img_size=256)
num_params = sum(p.numel() for p in model.parameters())
print(f"Model: {num_params:,} parameters")

In [None]:
# Train for the number of epochs set in the demo config
print("\nStarting training demo...")
trainer = EGMNetTrainer(model, config, device=device)
# Read the epoch count from config instead of repeating the literal,
# so the two can never disagree.
trainer.train(train_loader, num_epochs=config['num_epochs'])

print("\n‚úÖ Training demo completed!")

## 10. Inference Speed Benchmark

In [None]:
import time

def benchmark_model(model, input_tensor, num_runs=50, warmup=10):
    """Measure forward-pass latency of ``model`` on ``input_tensor``.

    Side effect: puts ``model`` into eval mode (the previous mode is not restored).

    Args:
        model: Module to time; called as ``model(input_tensor)``.
        input_tensor: Input batch, already on the model's device.
        num_runs: Number of timed forward passes; must be >= 1.
        warmup: Untimed passes run first (lazy init, kernel autotuning, caches).

    Returns:
        tuple[float, float]: (mean, std) latency in milliseconds.

    Raises:
        ValueError: if ``num_runs`` < 1 (the mean would be undefined).
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    model.eval()
    # CUDA launches are asynchronous, so synchronize before reading the clock.
    # Derived from the input instead of a global device flag so the helper is self-contained.
    sync_cuda = input_tensor.is_cuda

    with torch.no_grad():
        # Warmup
        for _ in range(warmup):
            model(input_tensor)
        if sync_cuda:
            torch.cuda.synchronize()

        # Benchmark with a monotonic, high-resolution clock
        times = []
        for _ in range(num_runs):
            start = time.perf_counter()
            model(input_tensor)
            if sync_cuda:
                torch.cuda.synchronize()
            times.append(time.perf_counter() - start)

    return float(np.mean(times)) * 1000, float(np.std(times)) * 1000  # ms

# Benchmark single-image (batch size 1) latency of both EGM-Net variants
print("Benchmarking inference speed...")
print("-" * 60)

test_input = torch.randn(1, 1, 256, 256).to(device)

for name, net in [('EGM-Net Full', egm_net), ('EGM-Net Lite', egm_lite)]:
    # `net`, not `model`: avoid overwriting the global `model` used by other cells
    mean_time, std_time = benchmark_model(net, test_input)
    fps = 1000 / mean_time
    print(f"{name:20s}: {mean_time:.2f} ¬± {std_time:.2f} ms ({fps:.1f} FPS)")

print("\n‚úÖ Benchmark completed!")

## 11. Summary

In [None]:
# Print a framed summary of the EGM-Net design.
# - The previous literal banner had mis-encoded (mojibake) box and emoji characters.
# - The model sizes were hard-coded and could drift from the real models; they are
#   now computed from the `models` dict built in the model-creation cell.
SUMMARY_WIDTH = 71  # characters between the vertical borders

innovation_lines = [
    "",
    "🔬 KEY INNOVATIONS:",
    "",
    "1. MONOGENIC ENERGY GATING",
    "   • Physics-based edge detection (Riesz Transform)",
    "   • Automatically focuses on boundary regions",
    "   • Suppresses artifacts in flat regions",
    "",
    "2. GABOR BASIS (vs Fourier)",
    "   • Localized oscillations (Gaussian × sin)",
    "   • NO Gibbs ringing artifacts",
    "   • Sharp edges remain clean",
    "",
    "3. DUAL-PATH ARCHITECTURE",
    "   • Coarse Branch: Smooth body regions (Conv decoder)",
    "   • Fine Branch: Sharp boundaries (Gabor Implicit)",
    "   • Energy-gated fusion: Best of both worlds",
    "",
    "4. RESOLUTION-FREE INFERENCE",
    "   • Query at ANY coordinate → Infinite zoom",
    "   • No retraining needed for different resolutions",
    "   • Perfect for high-resolution medical imaging",
    "",
    "5. MAMBA ENCODER",
    "   • O(N) complexity (vs O(N²) for Transformers)",
    "   • Global context awareness",
    "   • Efficient for large images",
    "",
]

size_lines = ["", "📊 MODEL SIZES:"]
for name, net in models.items():
    n_params = sum(p.numel() for p in net.parameters())
    size_lines.append(f"   • {name + ':':<16} {n_params / 1e6:.2f}M parameters")
size_lines.append("")


def _framed_row(text=""):
    """Return one banner row with a 2-space left margin (emoji rows may render 1 column wide)."""
    return "║  " + text.ljust(SUMMARY_WIDTH - 2) + "║"


rule = "═" * SUMMARY_WIDTH
summary_banner = [
    "╔" + rule + "╗",
    "║" + "EGM-NET ARCHITECTURE SUMMARY".center(SUMMARY_WIDTH) + "║",
    "╠" + rule + "╣",
    *(_framed_row(line) for line in innovation_lines),
    "╠" + rule + "╣",
    *(_framed_row(line) for line in size_lines),
    "╚" + rule + "╝",
]
# Leading/trailing newlines mirror the original triple-quoted string
print("\n" + "\n".join(summary_banner) + "\n")

---

## 📚 Next Steps

1. **Train on real data**: The quick training demo (Section 9) uses a dummy dataset; for full experiments, use the Synapse training section above or another medical imaging dataset (e.g., ACDC)
2. **Tune hyperparameters**: Adjust `num_frequencies`, `boundary_ratio`, learning rate
3. **Evaluate metrics**: Dice score, IoU, Hausdorff distance
4. **Ablation study**: Compare Gabor vs Fourier, with/without energy gating

---

**Repository**: https://github.com/QuocKhanhLuong/FourierNetwork