## 1. Setup & Installation

In [None]:
# Install dependencies
!pip install -q torch torchvision numpy matplotlib tqdm

In [None]:
# Clone repository
!git clone https://github.com/QuocKhanhLuong/FourierNetwork.git
%cd FourierNetwork

In [None]:
# Import libraries
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Seed the Python, NumPy and PyTorch RNGs so the random test inputs, synthetic
# noise, dummy dataset and weight initialisation below are reproducible when
# the notebook is re-run from a fresh kernel.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included

# Check GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"🖥️ Using device: {device}")
if device == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2. Import Models

In [None]:
# Import our models (project modules from the FourierNetwork repository,
# resolved relative to the current working directory).
from monogenic import EnergyMap, MonogenicSignal, BoundaryDetector
from gabor_implicit import GaborBasis, GaborNet, ImplicitSegmentationHead
from egm_net import EGMNet, EGMNetLite
from spectral_mamba import SpectralVMUNet

print("✅ All modules imported successfully!")

## 3. Test Monogenic Signal Processing

In [None]:
# Create a test image with edges
def create_test_image(size=256, noise_std=0.05, tumor_offset=(30, -20), generator=None):
    """Create a synthetic medical-like image: a disc "organ" with a brighter disc "tumor".

    The organ is a disc of radius ``size // 4`` centred in the image with
    intensity 0.7. The tumor is a disc of radius ``size // 10`` with intensity
    1.0 centred at ``(size // 2 + dx, size // 2 + dy)``, where
    ``(dx, dy) = tumor_offset`` (x = column, y = row); it is painted after the
    organ, so it overrides it. Gaussian noise is then added to the image only.

    Note: ``tumor_offset`` is in pixels and its default is tuned for
    ``size=256``; for small sizes the tumor may leave the organ or the image.

    Args:
        size: side length of the square image, in pixels.
        noise_std: standard deviation of the additive Gaussian noise (0 disables it).
        tumor_offset: ``(dx, dy)`` pixel offset of the tumor centre from the image centre.
        generator: optional ``torch.Generator`` for reproducible noise; ``None``
            uses PyTorch's global RNG.

    Returns:
        ``(img, organ_mask, tumor_mask)``: ``img`` is a float tensor of shape
        ``(1, 1, size, size)``; both masks are float tensors of shape
        ``(size, size)`` holding 1.0 inside their disc and 0.0 elsewhere.
        ``organ_mask`` is computed independently of the tumor, so it also
        covers any part of the tumor lying inside the organ.
    """
    img = torch.zeros(1, 1, size, size)
    y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')

    def disc(cx, cy, radius):
        # Strict inequality: pixels exactly on the circle are excluded.
        return (x - cx) ** 2 + (y - cy) ** 2 < radius ** 2

    # Circular "organ" in the centre of the image.
    organ = disc(size // 2, size // 2, size // 4)
    img[0, 0, organ] = 0.7

    # Smaller, brighter "tumor".
    dx, dy = tumor_offset
    tumor = disc(size // 2 + dx, size // 2 + dy, size // 10)
    img[0, 0, tumor] = 1.0

    # Additive Gaussian noise (same draw as randn_like when generator is None).
    img = img + noise_std * torch.randn(img.shape, generator=generator)

    return img, organ.float(), tumor.float()


# Create test image
test_img, organ_mask, tumor_mask = create_test_image(256)
print(f"Test image shape: {test_img.shape}")

In [None]:
# Test Monogenic Energy Extraction
energy_extractor = EnergyMap(normalize=True, smoothing_sigma=1.0)
energy, mono_out = energy_extractor(test_img)

# (row, col, tensor, title, colormap) for each panel of the 2x3 grid.
# Each tensor is displayed as its [0, 0] slice.
# NOTE(review): assumes `energy` and every `mono_out` entry are shaped
# (batch, channel, H, W) like the input — confirm against monogenic.EnergyMap.
monogenic_panels = [
    (0, 0, test_img, 'Input Image', 'gray'),
    (0, 1, energy, 'Energy Map (Edges)', 'hot'),
    (0, 2, mono_out['phase'], 'Phase', 'twilight'),
    (1, 0, mono_out['orientation'], 'Orientation', 'hsv'),
    (1, 1, mono_out['riesz_x'], 'Riesz X Component', 'RdBu'),
    (1, 2, mono_out['riesz_y'], 'Riesz Y Component', 'RdBu'),
]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for row, col, tensor, title, cmap in monogenic_panels:
    ax = axes[row, col]
    # detach(): outputs may carry autograd history; cpu() is a no-op on CPU.
    ax.imshow(tensor[0, 0].detach().cpu(), cmap=cmap)
    ax.set_title(title)
    ax.axis('off')

plt.suptitle('Monogenic Signal Decomposition', fontsize=14)
plt.tight_layout()
plt.show()

print("\n✅ Monogenic processing works correctly!")

## 4. Test Gabor Basis vs Fourier Features

In [None]:
from gabor_implicit import GaborBasis, FourierFeatures

# Create a coordinate grid over [-1, 1] x [-1, 1]
size = 128
y = torch.linspace(-1, 1, size)
x = torch.linspace(-1, 1, size)
yy, xx = torch.meshgrid(y, x, indexing='ij')
coords = torch.stack([xx, yy], dim=-1).reshape(1, -1, 2)  # (1, size*size, 2), (x, y) order

# Compare Gabor vs Fourier
gabor = GaborBasis(input_dim=2, num_frequencies=32)
fourier = FourierFeatures(input_dim=2, num_frequencies=32, scale=10.0)

# The features are only visualised, so don't build an autograd graph for them.
with torch.no_grad():
    gabor_features = gabor(coords)
    fourier_features = fourier(coords)

print(f"Gabor features shape: {gabor_features.shape}")
print(f"Fourier features shape: {fourier_features.shape}")

In [None]:
# Visualize the first few basis functions of each encoding.
# NOTE(review): vmin/vmax assume feature values lie in [-1, 1] — values
# outside that range are clipped in the colour map.
num_shown = 4
fig, axes = plt.subplots(2, num_shown, figsize=(4 * num_shown, 8))

basis_rows = [
    ('Gabor', gabor_features),
    ('Fourier', fourier_features),
]
for row, (label, features) in enumerate(basis_rows):
    for i in range(num_shown):
        ax = axes[row, i]
        # Feature i over all grid points, folded back into a size x size image.
        basis = features[0, :, i].reshape(size, size).detach().cpu().numpy()
        ax.imshow(basis, cmap='RdBu', vmin=-1, vmax=1)
        ax.set_title(f'{label} Basis {i+1}')
        # Hide ticks and frame instead of ax.axis('off'): axis('off') also
        # hides the y-label, which would make the row labels below invisible.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

axes[0, 0].set_ylabel('Gabor\n(Localized)', fontsize=12)
axes[1, 0].set_ylabel('Fourier\n(Global)', fontsize=12)

plt.suptitle('Gabor vs Fourier Basis Functions\n(Gabor is localized → No Gibbs ringing)', fontsize=14)
plt.tight_layout()
plt.show()

## 5. Create and Analyze Models

In [None]:
# Create EGM-Net models
print("Creating models...")

# Full model
egm_net = EGMNet(
    in_channels=1,
    num_classes=3,
    img_size=256,
    base_channels=64,
    num_stages=4,
    encoder_depth=2
).to(device)

# Lite model
egm_lite = EGMNetLite(
    in_channels=1,
    num_classes=3,
    img_size=256
).to(device)

# Spectral Mamba (comparison)
spec_mamba = SpectralVMUNet(
    in_channels=1,
    out_channels=3,
    img_size=256,
    base_channels=64,
    num_stages=4
).to(device)

models = {
    'EGM-Net Full': egm_net,
    'EGM-Net Lite': egm_lite,
    'SpectralVMUNet': spec_mamba
}

# These models are only used for inference in the following sections.
# nn.Module starts in training mode, where layers such as BatchNorm/Dropout (if
# the models contain them) use batch statistics / random masks, and BatchNorm
# updates its running stats even under torch.no_grad(). Switch to eval mode.
for net in models.values():
    net.eval()

print("\n📊 Model Comparison:")
print("-" * 50)
for name, net in models.items():
    params = sum(p.numel() for p in net.parameters())
    print(f"{name:20s}: {params:,} parameters ({params/1e6:.2f}M)")

## 6. Test Forward Pass

In [None]:
# Test forward pass on a random batch of two single-channel 256x256 images.
test_input = torch.randn(2, 1, 256, 256).to(device)

print("Testing forward pass...")
print(f"Input shape: {test_input.shape}")

with torch.no_grad():
    # EGM-Net returns a dict of named tensors (later cells read 'output',
    # 'coarse', 'fine' and 'energy').
    egm_out = egm_net(test_input)
    print("\n🔹 EGM-Net Output:")
    for k, v in egm_out.items():
        print(f"   {k}: {v.shape}")

    # SpectralVMUNet returns a single tensor (only its shape is inspected here).
    spec_out = spec_mamba(test_input)
    print(f"\n🔹 SpectralVMUNet Output: {spec_out.shape}")

print("\n✅ Forward pass successful!")

## 7. Test Resolution-Free Inference (Unique to EGM-Net)

In [None]:
# EGM-Net can query at arbitrary coordinates!
print("Testing Resolution-Free Inference...")

# Create query points at random locations, uniform in [-1, 1).
# Drawn on the CPU and then moved, so the values do not depend on `device`.
# NOTE(review): assumes query_points expects (batch, num_points, 2) coordinates
# in [-1, 1] — confirm the (x, y) ordering against egm_net.query_points.
num_points = 10000
random_coords = torch.rand(1, num_points, 2).to(device) * 2 - 1

with torch.no_grad():
    # Query at random points
    point_output = egm_net.query_points(test_input[:1], random_coords)

print(f"Query coordinates: {random_coords.shape}")
print(f"Point outputs: {point_output.shape}")
print("\n✅ Resolution-free inference works!")
print("   → You can zoom into boundaries at ANY resolution!")

In [None]:
# Demonstrate resolution-free rendering: the same model at several output sizes.
resolutions = [64, 128, 256, 512]

# squeeze=False keeps `axes` 2-D even if only one resolution is listed.
fig, axes = plt.subplots(1, len(resolutions), figsize=(4 * len(resolutions), 4), squeeze=False)

with torch.no_grad():
    for idx, res in enumerate(resolutions):
        ax = axes[0, idx]
        # Render at this resolution
        output = egm_net(test_input[:1], output_size=(res, res))
        logits = output['output']
        pred = torch.argmax(logits, dim=1)[0].cpu().numpy()

        # Pin the colour range to the full class range (dim 1 of the logits)
        # so a class has the same colour in every panel; imshow otherwise
        # rescales each panel to its own min/max. Nearest-neighbour
        # interpolation shows the label map at its true resolution instead of
        # smoothing it, which is the point of this comparison.
        ax.imshow(pred, cmap='viridis', vmin=0, vmax=logits.shape[1] - 1,
                  interpolation='nearest')
        ax.set_title(f'{res}×{res}')
        ax.axis('off')

plt.suptitle('Resolution-Free Rendering (Same model, different output sizes)', fontsize=14)
plt.tight_layout()
plt.show()

## 8. Visualize Energy-Gated Fusion

In [None]:
# Visualize the dual-branch architecture
with torch.no_grad():
    outputs = egm_net(test_input[:1])

# Shared colour scaling for every label map. Without an explicit vmin/vmax,
# imshow rescales each panel to its own min/max, so the same class could get
# different colours in the coarse, fine and fused panels.
# NOTE(review): assumes 'coarse' and 'fine' have the same number of classes
# (dim 1) as 'output' — confirm against EGMNet.
n_label_classes = outputs['output'].shape[1]
label_style = dict(cmap='viridis', vmin=0, vmax=n_label_classes - 1, interpolation='nearest')

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Input
axes[0, 0].imshow(test_input[0, 0].cpu(), cmap='gray')
axes[0, 0].set_title('Input Image')
axes[0, 0].axis('off')

# Energy Map
axes[0, 1].imshow(outputs['energy'][0, 0].cpu(), cmap='hot')
axes[0, 1].set_title('Energy Map (Edge Detection)')
axes[0, 1].axis('off')

# Coarse Branch
coarse_pred = torch.argmax(outputs['coarse'], dim=1)[0].cpu()
axes[0, 2].imshow(coarse_pred, **label_style)
axes[0, 2].set_title('Coarse Branch (Smooth)')
axes[0, 2].axis('off')

# Fine Branch
fine_pred = torch.argmax(outputs['fine'], dim=1)[0].cpu()
axes[1, 0].imshow(fine_pred, **label_style)
axes[1, 0].set_title('Fine Branch (Sharp)')
axes[1, 0].axis('off')

# Final Output
final_pred = torch.argmax(outputs['output'], dim=1)[0].cpu()
axes[1, 1].imshow(final_pred, **label_style)
axes[1, 1].set_title('Final Output (Fused)')
axes[1, 1].axis('off')

# Difference: 1.0 where the two branches predict different classes. A fixed
# 0..1 range makes an "all agree" map render blank instead of auto-scaled.
diff = (fine_pred != coarse_pred).float()
axes[1, 2].imshow(diff, cmap='Reds', vmin=0, vmax=1, interpolation='nearest')
axes[1, 2].set_title('Difference (Fine vs Coarse)')
axes[1, 2].axis('off')

plt.suptitle('EGM-Net Dual-Branch Architecture', fontsize=14)
plt.tight_layout()
plt.show()

## 9. Quick Training Demo

In [None]:
from train_egm import EGMNetTrainer, create_dummy_dataset
from torch.utils.data import DataLoader

# A tiny synthetic dataset (16 samples) keeps the demo to a few seconds.
print("Creating dummy dataset...")
dataset = create_dummy_dataset(num_samples=16, img_size=256, num_classes=3)
train_loader = DataLoader(dataset, shuffle=True, batch_size=4)

# Hyper-parameters handed to EGMNetTrainer in the next cell.
config = dict(
    learning_rate=1e-4,
    weight_decay=1e-5,
    num_epochs=2,
    num_points=1024,      # presumably points sampled per image for the implicit head — TODO confirm
    boundary_ratio=0.5,   # presumably share of sampled points near boundaries — TODO confirm
    checkpoint_dir='./checkpoints_demo',
)

# The lite model keeps training fast.
model = EGMNetLite(in_channels=1, num_classes=3, img_size=256)
print(f"Model: {sum(param.numel() for param in model.parameters()):,} parameters")

In [None]:
# Train for a few epochs
print("\nStarting training demo...")
trainer = EGMNetTrainer(model, config, device=device)
# Take the epoch count from the config so the two values cannot drift apart.
trainer.train(train_loader, num_epochs=config['num_epochs'])

print("\n✅ Training demo completed!")

## 10. Inference Speed Benchmark

In [None]:
import time

def benchmark_model(model, input_tensor, num_runs=50, warmup=10):
    """Benchmark inference speed of ``model`` on a fixed input.

    The model is switched to eval mode (in place) and run without autograd:
    first ``warmup`` untimed forward passes, then ``num_runs`` timed ones.

    Args:
        model: module to benchmark; called as ``model(input_tensor)``.
        input_tensor: tensor passed unchanged to every forward call. Its device
            decides whether CUDA synchronisation is needed.
        num_runs: number of timed forward passes; must be at least 1.
        warmup: number of untimed forward passes run before timing.

    Returns:
        ``(mean_ms, std_ms)``: mean and standard deviation of the per-call
        latency, in milliseconds.

    Raises:
        ValueError: if ``num_runs`` is less than 1 (the statistics would be NaN).
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    model.eval()

    def _sync():
        # CUDA kernels launch asynchronously; wait for them so each timing
        # covers the whole forward pass. Keyed on the input's own device
        # rather than a notebook-global `device` variable.
        if input_tensor.is_cuda:
            torch.cuda.synchronize(input_tensor.device)

    with torch.no_grad():
        # Warmup
        for _ in range(warmup):
            model(input_tensor)
        _sync()

        # Benchmark. perf_counter is monotonic and high-resolution, unlike time.time.
        times = []
        for _ in range(num_runs):
            start = time.perf_counter()
            model(input_tensor)
            _sync()
            times.append(time.perf_counter() - start)

    return np.mean(times) * 1000, np.std(times) * 1000  # ms

# Benchmark
print("Benchmarking inference speed...")
print("-" * 60)

# Dedicated name so this single-image input does not overwrite the batch-of-2
# `test_input` from section 6; re-running earlier cells stays valid.
bench_input = torch.randn(1, 1, 256, 256).to(device)

# Loop variable is `net`, not `model`: `model` holds the lite model trained in
# section 9 and must not be silently replaced.
for name, net in [('EGM-Net Full', egm_net), ('EGM-Net Lite', egm_lite)]:
    mean_time, std_time = benchmark_model(net, bench_input)
    fps = 1000 / mean_time
    print(f"{name:20s}: {mean_time:.2f} ± {std_time:.2f} ms ({fps:.1f} FPS)")

print("\n✅ Benchmark completed!")

## 11. Summary

In [None]:
import unicodedata

SUMMARY_WIDTH = 71  # interior width of the summary box, in display columns


def display_width(text):
    """Return the number of terminal columns ``text`` occupies.

    Wide and full-width characters (East Asian width 'W' or 'F', which includes
    emoji such as 🔬) count as two columns; every other character counts as one.
    """
    return sum(2 if unicodedata.east_asian_width(ch) in ('W', 'F') else 1 for ch in text)


def render_box(sections, width=SUMMARY_WIDTH):
    """Draw ``sections`` (each a list of text lines) inside a double-line box.

    Every line is right-padded with spaces to ``width`` display columns (lines
    already that wide are left unpadded). Consecutive sections are separated by
    a divider row.

    Returns:
        The box as one newline-joined string, without a trailing newline.
    """
    rows = ['╔' + '═' * width + '╗']
    for index, section in enumerate(sections):
        if index:
            rows.append('╠' + '═' * width + '╣')
        for line in section:
            padding = ' ' * max(width - display_width(line), 0)
            rows.append('║' + line + padding + '║')
    rows.append('╚' + '═' * width + '╝')
    return '\n'.join(rows)


# NOTE(review): the parameter counts are hard-coded snapshots; section 5 prints
# the exact counts for the current code.
summary_sections = [
    ['EGM-NET ARCHITECTURE SUMMARY'.center(SUMMARY_WIDTH)],
    [
        '',
        '  🔬 KEY INNOVATIONS:',
        '',
        '  1. MONOGENIC ENERGY GATING',
        '     • Physics-based edge detection (Riesz Transform)',
        '     • Automatically focuses on boundary regions',
        '     • Suppresses artifacts in flat regions',
        '',
        '  2. GABOR BASIS (vs Fourier)',
        '     • Localized oscillations (Gaussian × sin)',
        '     • NO Gibbs ringing artifacts',
        '     • Sharp edges remain clean',
        '',
        '  3. DUAL-PATH ARCHITECTURE',
        '     • Coarse Branch: Smooth body regions (Conv decoder)',
        '     • Fine Branch: Sharp boundaries (Gabor Implicit)',
        '     • Energy-gated fusion: Best of both worlds',
        '',
        '  4. RESOLUTION-FREE INFERENCE',
        '     • Query at ANY coordinate → Infinite zoom',
        '     • No retraining needed for different resolutions',
        '     • Perfect for high-resolution medical imaging',
        '',
        '  5. MAMBA ENCODER',
        '     • O(N) complexity (vs O(N²) for Transformers)',
        '     • Global context awareness',
        '     • Efficient for large images',
        '',
    ],
    [
        '',
        '  📊 MODEL SIZES:',
        '     • EGM-Net Full:  ~9.13M parameters',
        '     • EGM-Net Lite:  ~635K parameters',
        '     • SpectralVMUNet: ~10.31M parameters',
        '',
    ],
]

print('\n' + render_box(summary_sections) + '\n')

---

## 📚 Next Steps

1. **Train on real data**: Replace dummy dataset with medical imaging dataset (e.g., Synapse, ACDC)
2. **Tune hyperparameters**: Adjust `num_frequencies`, `boundary_ratio`, learning rate
3. **Evaluate metrics**: Dice score, IoU, Hausdorff distance
4. **Ablation study**: Compare Gabor vs Fourier, with/without energy gating

---

**Repository**: https://github.com/QuocKhanhLuong/FourierNetwork