## 1. Setup & Installation

In [None]:
# Install dependencies
!pip install -q torch torchvision numpy matplotlib tqdm

In [None]:
# Clone repository
!git clone https://github.com/QuocKhanhLuong/FourierNetwork.git
%cd FourierNetwork

In [None]:
# Import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Use the CUDA accelerator when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"üñ•Ô∏è Using device: {device}")

# On GPU, report the name and total memory (bytes / 1e9) of device index 0.
if device == 'cuda':
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"   GPU: {gpu_name}")
    print(f"   Memory: {gpu_memory_gb:.2f} GB")

## 2. Import Models

In [None]:
# Import our models
from monogenic import EnergyMap, MonogenicSignal, BoundaryDetector
from gabor_implicit import GaborBasis, GaborNet, ImplicitSegmentationHead
from egm_net import EGMNet, EGMNetLite
from spectral_mamba import SpectralVMUNet

print("‚úÖ All modules imported successfully!")

## 🏥 REAL DATA TRAINING

### ACDC Dataset (Automated Cardiac Diagnosis Challenge)
- **Modality**: Cine-MRI (Cardiac)
- **150 patients**: 100 training + 50 testing
- **Classes**: 4 (Background, RV, Myocardium, LV)
- **Frames**: ED (End Diastole) + ES (End Systole) per patient
- **Total slices**: ~1900 (varies 6-18 slices per volume)

### BraTS21 Dataset (Brain Tumor Segmentation)
- **Modality**: Multi-modal MRI (T1, T1ce, T2, FLAIR)
- **Classes**: 4 (Background, Necrotic/Non-enhancing, Edema, Enhancing)

### M&Ms Dataset (Multi-Centre, Multi-Vendor)
- **Modality**: Cardiac MRI from 4 different vendors
- **Classes**: 4 (Background, RV, Myocardium, LV)

In [None]:
# Step 1: Download dataset from Google Drive
import os

# Install dependencies
!pip install -q gdown nibabel scikit-image

# ============================================================
# üîß SELECT YOUR DATASET HERE
# ============================================================
DATASET = 'ACDC'  # Options: 'ACDC', 'BraTS21', 'MnM'

# ============================================================
# Google Drive folder IDs
# ============================================================
# Maps each dataset name to the Google Drive folder ID(s) it is downloaded from.
# A dataset entry may provide a single 'all' folder, or separate 'training' and
# 'testing' folders; the download step below checks which keys are present.
# If you have two separate links for training and testing, fill them in here:
DRIVE_FOLDERS = {
    'ACDC': {
        'all': '1EelzBVjIoDQ4uzt0_2JzmF_PuUHsD93e',  # Folder containing both training and testing
        # If training/testing are hosted separately, uncomment the entries below:
        # 'training': 'YOUR_TRAINING_FOLDER_ID_HERE',
        # 'testing': 'YOUR_TESTING_FOLDER_ID_HERE'
    },
    'BraTS21': {
        'all': '1m7b5u_6cEj9PbqzZQDNoDonnXIowBCY2'
    },
    'MnM': {
        'all': '1DpW8ucYE17Tj8iMlsAev_yM6TaZZmWYj'
    }
}

# Per-dataset settings (class count and names, input channels, square image
# size) looked up by dataset name.
#
# ACDC notes: 100 training patients (patient001-100) and 50 testing patients
# (patient101-150). Each patient has ED + ES frames, each frame has 6-18
# slices (avg ~9), so roughly 1900-2000 slices are expected in total.
DATASET_CONFIG = dict(
    ACDC=dict(
        num_classes=4,
        class_names=['Background', 'RV', 'Myocardium', 'LV'],
        in_channels=1,
        img_size=224,
    ),
    BraTS21=dict(
        num_classes=4,
        class_names=['Background', 'NCR/NET', 'Edema', 'Enhancing'],
        in_channels=4,
        img_size=224,
    ),
    MnM=dict(
        num_classes=4,
        class_names=['Background', 'RV', 'Myocardium', 'LV'],
        in_channels=1,
        img_size=224,
    ),
)

# Paths: raw downloads and preprocessed outputs each live in a per-dataset folder.
RAW_DATA_DIR = f'./data/{DATASET}'
PREPROCESSED_DIR = f'./preprocessed_data/{DATASET}'

os.makedirs(RAW_DATA_DIR, exist_ok=True)

# Download from Google Drive.
# The lines starting with `!` are IPython shell escapes, so this cell only runs
# inside a notebook/IPython session; `{name}` is filled in from the Python variables.
# NOTE(review): `--remaining-ok` presumably lets gdown carry on when a Drive folder
# holds more entries than gdown's per-folder limit, which would silently truncate
# the download — confirm against the gdown documentation if patients are missing.
folder_config = DRIVE_FOLDERS[DATASET]
print(f"üì• Step 1: Downloading {DATASET} from Google Drive...")

# Check if training and testing are separate folders
if 'training' in folder_config and 'testing' in folder_config:
    # Separate links for training and testing: each goes into its own subfolder
    print("   Mode: Separate training and testing folders")
    
    training_id = folder_config['training']
    testing_id = folder_config['testing']
    
    print(f"\n   üì¶ Downloading TRAINING data...")
    !gdown --folder "https://drive.google.com/drive/folders/{training_id}" -O {RAW_DATA_DIR}/training --remaining-ok
    
    print(f"\n   üì¶ Downloading TESTING data...")
    !gdown --folder "https://drive.google.com/drive/folders/{testing_id}" -O {RAW_DATA_DIR}/testing --remaining-ok
else:
    # Single folder containing all data, downloaded straight into RAW_DATA_DIR
    folder_id = folder_config['all']
    print(f"   Mode: Single folder (should contain training + testing subfolders)")
    print(f"   Folder ID: {folder_id}")
    
    !gdown --folder "https://drive.google.com/drive/folders/{folder_id}" -O {RAW_DATA_DIR} --remaining-ok

# NOTE(review): this message is printed unconditionally; the exit status of gdown is not checked.
print(f"\n‚úÖ Download complete!")

# Look up the settings of the selected dataset and print a short summary.
config_data = DATASET_CONFIG[DATASET]
print(f"\nüìä Dataset: {DATASET}")
# The slice estimate is specific to ACDC, so it is only shown for that dataset
# (it used to be printed for BraTS21 and MnM as well, which was misleading).
if DATASET == 'ACDC':
    print(f"   Expected slices: ~1900 for ACDC (100 training + 50 testing patients)")
print(f"   Classes: {config_data['class_names']}")

In [None]:
# Step 2: Check downloaded data structure
import glob

print(f"üìÇ Step 2: Checking downloaded data structure...")
print(f"   Root: {RAW_DATA_DIR}")
print()

# Show folder structure
def show_tree(path, prefix="", max_depth=3, current_depth=0, max_dirs=10, max_files=5):
    """Print an abbreviated directory tree rooted at ``path``.

    Sub-directories are listed (and recursed into) before files; both lists are
    sorted by name and truncated, with a summary line for the omitted entries.

    Args:
        path: Directory to list.
        prefix: Text printed before every line; grows by three spaces per level.
        max_depth: Number of directory levels to print.
        current_depth: Depth of ``path`` relative to the root call (used by the
            recursion; callers normally leave it at 0).
        max_dirs: Maximum number of sub-directories shown per directory.
        max_files: Maximum number of files shown per directory.

    Returns:
        None. Paths that cannot be listed are skipped silently.
    """
    if current_depth >= max_depth:
        return

    try:
        entries = sorted(os.listdir(path))
    except OSError:
        # Missing, unreadable or non-directory paths are skipped. Only OS-level
        # errors are swallowed so that e.g. KeyboardInterrupt still propagates.
        return

    dirs = [e for e in entries if os.path.isdir(os.path.join(path, e))]
    files = [e for e in entries if os.path.isfile(os.path.join(path, e))]

    # Show directories first, descending into each one
    for d in dirs[:max_dirs]:
        print(f"{prefix}üìÅ {d}/")
        show_tree(
            os.path.join(path, d), prefix + "   ", max_depth, current_depth + 1,
            max_dirs, max_files,
        )
    if len(dirs) > max_dirs:
        print(f"{prefix}   ... and {len(dirs) - max_dirs} more folders")

    # Then show the first few files
    for f in files[:max_files]:
        print(f"{prefix}üìÑ {f}")
    if len(files) > max_files:
        print(f"{prefix}   ... and {len(files) - max_files} more files")

show_tree(RAW_DATA_DIR)

# Count patients
print(f"\n" + "="*60)
print("üìä ACDC Patient Count Analysis")
print("="*60)

# Find all patient folders recursively, remembering the directory each one was found in
all_patients = [
    (d, root)
    for root, dirs, files in os.walk(RAW_DATA_DIR)
    for d in dirs
    if d.startswith('patient')
]

# Separate training vs testing by the name of the containing path
training_patients = [p for p, path in all_patients if 'training' in path.lower()]
testing_patients = [p for p, path in all_patients if 'testing' in path.lower()]
unknown_patients = [p for p, path in all_patients if 'training' not in path.lower() and 'testing' not in path.lower()]

# Also index by patient number (a number found in several places keeps the last path seen)
patients_by_number = {}
for p, path in all_patients:
    try:
        num = int(p.replace('patient', ''))
    except ValueError:
        # Folder name is not of the form patient<number>; leave it out of the numeric index
        continue
    patients_by_number[num] = path

training_by_num = [n for n in patients_by_number if n <= 100]
testing_by_num = [n for n in patients_by_number if n > 100]

print(f"\nüìç By folder structure:")
print(f"   Training folder patients: {len(training_patients)}")
print(f"   Testing folder patients: {len(testing_patients)}")
print(f"   Unknown location: {len(unknown_patients)}")

print(f"\nüìç By patient number (ACDC convention):")
print(f"   Training (001-100): {len(training_by_num)} patients")
print(f"   Testing (101-150): {len(testing_by_num)} patients")
print(f"   Total unique patients: {len(patients_by_number)}")

# Expected vs actual
print(f"\n" + "="*60)
print("‚ö†Ô∏è DIAGNOSTIC")
print("="*60)

if len(patients_by_number) >= 140:
    print(f"‚úÖ Looks good! Found {len(patients_by_number)} patients (expected ~150)")
elif len(patients_by_number) >= 90:
    print(f"‚ö†Ô∏è Partial data: Found {len(patients_by_number)} patients")
    print(f"   Missing: {'training' if len(training_by_num) < 90 else 'testing'} data")
else:
    # Everything below 90 patients is reported as incomplete. The previous
    # `< 60` test left counts of 60-89 without any diagnostic at all.
    print(f"‚ùå Incomplete: Only {len(patients_by_number)} patients found!")
    print(f"   Expected: 150 (100 training + 50 testing)")
    print(f"\nüîß Possible fixes:")
    print(f"   1. Check if your Drive folder has training/ and testing/ subfolders")
    print(f"   2. If they're separate links, update DRIVE_FOLDERS above with both IDs")
    print(f"   3. Make sure sharing is enabled on the Drive folder")

In [None]:
# Step 3: Preprocess ACDC data (NIfTI → .npy)
# ONLY the TRAINING folder (100 patients) is processed - split 80/20 into train/val
import configparser
import json
import numpy as np
import nibabel as nib
from skimage.transform import resize
from tqdm import tqdm

def preprocess_single_patient_acdc(patient_path, target_size=(224, 224)):
    """
    Process one ACDC patient folder into resized volumes and masks.

    ACDC structure per patient:
    - Info.cfg: contains ED and ES frame numbers
    - patient001_frame01.nii(.gz): 3D volume at frame 01
    - patient001_frame01_gt.nii(.gz): Ground truth segmentation

    When Info.cfg is present, only the ED and ES frames it names are processed
    (in that order). When it is missing, every ``*_gt.nii*`` file that has a
    matching image file is processed instead.

    Args:
        patient_path: Path of the patient folder.
        target_size: (height, width) passed on to ``process_volume``.

    Returns:
        List of (volume_3d, mask_3d, volume_id) tuples, where volume_id is
        "<patient folder>_<frame name>". Frames that cannot be found or loaded
        are skipped (a message is printed), so the list may be empty.
    """
    # normpath drops a trailing separator, which would otherwise make the
    # basename (and therefore every expected file name) empty.
    patient_folder = os.path.basename(os.path.normpath(patient_path))
    info_cfg_path = os.path.join(patient_path, 'Info.cfg')

    def _load_frame(img_file, gt_file, frame_name):
        """Load one image/mask pair and return (volume, mask, volume_id)."""
        img_data = nib.load(img_file).get_fdata()
        mask_data = nib.load(gt_file).get_fdata()
        volume, mask = process_volume(img_data, mask_data, target_size)
        return volume, mask, f"{patient_folder}_{frame_name}"

    if not os.path.exists(info_cfg_path):
        # No Info.cfg: fall back to every ground-truth file in the folder.
        # Sorted so the result order does not depend on filesystem listing order.
        results = []
        for gt_file in sorted(glob.glob(os.path.join(patient_path, '*_gt.nii*'))):
            img_file = gt_file.replace('_gt.nii', '.nii')
            if not os.path.exists(img_file):
                continue
            # e.g. "patient001_frame01_gt.nii.gz" -> "frame01"
            frame_name = os.path.basename(gt_file).split('_gt')[0].split('_')[-1]
            try:
                results.append(_load_frame(img_file, gt_file, frame_name))
            except Exception as e:
                # Best effort: one unreadable frame must not abort the whole run
                print(f"  Error: {e}")
        return results

    # Read Info.cfg to get ED and ES frame numbers. The file has no section
    # header, so one is prepended to make it parseable by configparser.
    try:
        parser = configparser.ConfigParser()
        with open(info_cfg_path, 'r') as f:
            config_string = '[DEFAULT]\n' + f.read()
        parser.read_string(config_string)
        ed_frame = int(parser['DEFAULT']['ED'])
        es_frame = int(parser['DEFAULT']['ES'])
    except Exception as e:
        print(f"  Error reading Info.cfg for {patient_folder}: {e}")
        return []

    results = []
    for frame_num, frame_name in [(ed_frame, 'ED'), (es_frame, 'ES')]:
        stem = os.path.join(patient_path, f'{patient_folder}_frame{frame_num:02d}')

        # Prefer the compressed files; both image and mask must use the same extension
        pair = None
        for ext in ['.nii.gz', '.nii']:
            if os.path.exists(f'{stem}{ext}') and os.path.exists(f'{stem}_gt{ext}'):
                pair = (f'{stem}{ext}', f'{stem}_gt{ext}')
                break
        if pair is None:
            continue

        try:
            results.append(_load_frame(pair[0], pair[1], frame_name))
        except Exception as e:
            print(f"  Error processing {patient_folder} {frame_name}: {e}")

    return results


def process_volume(img_data, mask_data, target_size):
    """Resize every slice of a volume and its mask, then scale the image by its maximum.

    Slices are taken along the third axis. Image slices are resized with
    order=1 and anti-aliasing into a float32 array; mask slices are resized
    with order=0 and no anti-aliasing into a uint8 array. The resized image is
    divided by its maximum value when that maximum is positive.

    Args:
        img_data: Image volume indexed as [:, :, slice].
        mask_data: Mask volume indexed the same way.
        target_size: (height, width) of each output slice.

    Returns:
        (resized_img, resized_mask), both shaped (height, width, num_slices).
    """
    height, width = target_size[0], target_size[1]
    depth = img_data.shape[2]

    out_img = np.zeros((height, width, depth), dtype=np.float32)
    out_mask = np.zeros((height, width, depth), dtype=np.uint8)

    img_options = dict(order=1, preserve_range=True, anti_aliasing=True, mode='reflect')
    mask_options = dict(order=0, preserve_range=True, anti_aliasing=False, mode='reflect')

    for z in range(depth):
        out_img[:, :, z] = resize(img_data[:, :, z], target_size, **img_options)
        out_mask[:, :, z] = resize(mask_data[:, :, z], target_size, **mask_options)

    # Scale by the volume-wide maximum; an all-zero volume is left untouched
    peak = out_img.max()
    if peak > 0:
        out_img = out_img / peak

    return out_img, out_mask


def _find_acdc_training_dir(input_dir):
    """Return the directory that directly holds ``patient*`` folders, or None."""
    # Check common locations first; the root itself is the last resort
    # (maybe patients are directly in the root).
    candidates = [
        os.path.join(input_dir, 'training'),
        os.path.join(input_dir, 'Training'),
        os.path.join(input_dir, 'TRAINING'),
        input_dir,
    ]
    for path in candidates:
        if not os.path.isdir(path):
            continue
        patient_dirs = [
            d for d in os.listdir(path)
            if d.startswith('patient') and os.path.isdir(os.path.join(path, d))
        ]
        if patient_dirs:
            print(f"‚úÖ Found training data at: {path}")
            print(f"   Contains {len(patient_dirs)} patient folders")
            return path

    # Deep search: any nested directory whose path mentions "training"
    for root, dirs, files in os.walk(input_dir):
        if 'training' not in root.lower():
            continue
        patient_dirs = [d for d in dirs if d.startswith('patient')]
        if patient_dirs:
            print(f"‚úÖ Found training data at: {root}")
            print(f"   Contains {len(patient_dirs)} patient folders")
            return root

    return None


def preprocess_acdc_training_only(input_dir, output_dir, target_size=(224, 224)):
    """
    Preprocess ONLY the training folder of ACDC and write .npy volumes/masks.

    Training: 100 patients, split 80/20 into train/val.
    The testing folder is kept apart for the final evaluation.

    The split is made per patient, not per volume: all volumes of one patient
    (ED and ES) end up on the same side, so the validation set never contains
    a patient that was also seen during training.

    Args:
        input_dir: Root of the downloaded raw data.
        output_dir: Destination; receives ``volumes/``, ``masks/`` and
            ``metadata.json`` (which records the split).
        target_size: (height, width) of the resized slices.

    Returns:
        (number of volumes, total number of slices, list of volume ids).
        (0, 0, []) when no training patients could be found.
    """
    os.makedirs(output_dir, exist_ok=True)

    volumes_dir = os.path.join(output_dir, 'volumes')
    masks_dir = os.path.join(output_dir, 'masks')
    os.makedirs(volumes_dir, exist_ok=True)
    os.makedirs(masks_dir, exist_ok=True)

    # =========================================================================
    # FIND TRAINING FOLDER
    # =========================================================================
    training_dir = _find_acdc_training_dir(input_dir)
    if training_dir is None:
        print(f"‚ùå Cannot find training folder in {input_dir}")
        print(f"   Looking for folder containing patient* subfolders...")
        return 0, 0, []

    # =========================================================================
    # GET ALL TRAINING PATIENTS
    # =========================================================================
    patient_folders = sorted([
        os.path.join(training_dir, d)
        for d in os.listdir(training_dir)
        if d.startswith('patient') and os.path.isdir(os.path.join(training_dir, d))
    ])

    print(f"\nüìä Training patients found: {len(patient_folders)}")
    if len(patient_folders) > 0:
        print(f"   First: {os.path.basename(patient_folders[0])}")
        print(f"   Last: {os.path.basename(patient_folders[-1])}")

    if len(patient_folders) == 0:
        print("‚ùå No patient folders found!")
        return 0, 0, []

    # =========================================================================
    # PROCESS ALL PATIENTS
    # =========================================================================
    volume_info = {}
    total_slices = 0
    all_volume_ids = []
    volumes_by_patient = {}  # patient folder name -> volume ids, in processing order

    for patient_path in tqdm(patient_folders, desc="Preprocessing ACDC Training"):
        patient_id = os.path.basename(patient_path)
        patient_results = preprocess_single_patient_acdc(patient_path, target_size)

        for volume, mask, volume_id in patient_results:
            np.save(os.path.join(volumes_dir, f'{volume_id}.npy'), volume)
            np.save(os.path.join(masks_dir, f'{volume_id}.npy'), mask)

            num_slices = mask.shape[2]
            volume_info[volume_id] = {'num_slices': num_slices}
            total_slices += num_slices
            all_volume_ids.append(volume_id)
            volumes_by_patient.setdefault(patient_id, []).append(volume_id)

    # =========================================================================
    # SPLIT 80/20 FOR TRAIN/VAL (BY PATIENT)
    # =========================================================================
    np.random.seed(42)  # Reproducible split
    shuffled_patients = np.random.permutation(list(volumes_by_patient)).tolist()

    split_idx = int(0.8 * len(shuffled_patients))
    train_volumes = [
        vid for pid in shuffled_patients[:split_idx] for vid in volumes_by_patient[pid]
    ]
    val_volumes = [
        vid for pid in shuffled_patients[split_idx:] for vid in volumes_by_patient[pid]
    ]

    # Save metadata with split info
    metadata = {
        'dataset': 'ACDC',
        'split_source': 'training_folder_only',
        'split_level': 'patient',
        'target_size': list(target_size),
        'total_volumes': len(volume_info),
        'total_slices': total_slices,
        'train_volumes': train_volumes,
        'val_volumes': val_volumes,
        'train_ratio': 0.8,
        'volume_info': volume_info,
        'num_classes': 4,
        'class_names': ['Background', 'RV', 'MYO', 'LV']
    }

    with open(os.path.join(output_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"\n‚úÖ Preprocessing Complete!")
    print(f"   Total volumes: {len(volume_info)} (from {len(patient_folders)} patients)")
    print(f"   Total slices: {total_slices}")
    print(f"   Train volumes: {len(train_volumes)} (80%)")
    print(f"   Val volumes: {len(val_volumes)} (20%)")
    print(f"   Output: {output_dir}")

    return len(volume_info), total_slices, all_volume_ids


# Run preprocessing: square target size taken from the dataset configuration
target_size = (config_data['img_size'], config_data['img_size'])

print(f"\nüîÑ Step 3: Preprocessing {DATASET} TRAINING data...")
print(f"   Input: {RAW_DATA_DIR}")
print(f"   Output: {PREPROCESSED_DIR}")
print(f"   Target size: {target_size[0]}√ó{target_size[1]}")
print(f"   Split: 80% train, 20% val")
print()

num_volumes, total_slices, all_ids = preprocess_acdc_training_only(
    RAW_DATA_DIR, PREPROCESSED_DIR, target_size
)

# Compare the number of slices produced with the rough expectation
print(f"\nüìä Expected vs Actual:")
print(f"   Expected: ~1800 slices (100 patients √ó 2 frames √ó ~9 slices)")
print(f"   Actual: {total_slices} slices")

# Fewer than 500 slices is treated as a sign that the download was incomplete
if total_slices < 500:
    print(f"\n‚ö†Ô∏è Slice count is low! Check if training folder was downloaded.")

In [None]:
# Step 4: Load preprocessed data with the 80/20 split
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json


class ACDCPreprocessedDataset(Dataset):
    """
    Slice-level ACDC dataset backed by preprocessed .npy volumes and masks.

    The train/val membership of each volume is read from ``metadata.json``
    (80/20 split). If the metadata lists no volumes for the requested split,
    the sorted volume files are split 80/20 by position instead.
    Classes: Background (0), RV (1), Myocardium (2), LV (3)

    Args:
        data_dir: Folder containing ``volumes/``, ``masks/`` and ``metadata.json``.
        split: 'train' selects the training volumes; any other value selects
            the validation volumes.
        min_foreground: Minimum number of non-zero mask pixels a slice needs
            in order to be included.
    """
    def __init__(self, data_dir, split='train', min_foreground=50):
        self.data_dir = data_dir
        self.split = split
        self.min_foreground = min_foreground

        volumes_dir = os.path.join(data_dir, 'volumes')
        masks_dir = os.path.join(data_dir, 'masks')
        metadata_path = os.path.join(data_dir, 'metadata.json')

        # Without metadata the dataset stays empty
        if not os.path.exists(metadata_path):
            print(f"‚ùå metadata.json not found in {data_dir}")
            self.slices = []
            return

        with open(metadata_path, 'r') as f:
            metadata = json.load(f)

        # Volumes of this split (the 80/20 split has already been made)
        split_key = 'train_volumes' if split == 'train' else 'val_volumes'
        volume_ids = metadata.get(split_key, [])

        if len(volume_ids) == 0:
            # Fallback: split the files on disk by ratio
            all_volumes = sorted(glob.glob(os.path.join(volumes_dir, '*.npy')))
            cut = int(0.8 * len(all_volumes))
            chosen = all_volumes[:cut] if split == 'train' else all_volumes[cut:]
            volume_ids = [os.path.basename(p).replace('.npy', '') for p in chosen]

        # Index every slice that has enough foreground as (volume file, mask file, slice)
        self.slices = []
        for vol_id in volume_ids:
            vol_path = os.path.join(volumes_dir, f'{vol_id}.npy')
            mask_path = os.path.join(masks_dir, f'{vol_id}.npy')

            if not (os.path.exists(vol_path) and os.path.exists(mask_path)):
                continue

            mask = np.load(mask_path)
            self.slices.extend(
                (vol_path, mask_path, z)
                for z in range(mask.shape[2])
                if np.sum(mask[:, :, z] > 0) >= self.min_foreground
            )

        print(f"   {split.upper()}: {len(self.slices)} slices from {len(volume_ids)} volumes")

    def __len__(self):
        return len(self.slices)

    def __getitem__(self, idx):
        vol_path, mask_path, slice_idx = self.slices[idx]

        # Memory-map the files so only the requested slice is copied out
        volume = np.load(vol_path, mmap_mode='r')
        mask = np.load(mask_path, mmap_mode='r')
        image_2d = volume[:, :, slice_idx].copy()
        label_2d = mask[:, :, slice_idx].copy()

        # Image becomes [1, H, W] float; label becomes [H, W] long
        image = torch.from_numpy(image_2d).unsqueeze(0).float()
        label = torch.from_numpy(label_2d).long()
        return image, label


# Create datasets
print(f"\nüìä Step 4: Loading preprocessed {DATASET} dataset...")
print(f"   Data source: TRAINING folder only (100 patients)")
print(f"   Split ratio: 80% train, 20% val")
print()

train_dataset = ACDCPreprocessedDataset(PREPROCESSED_DIR, split='train')
val_dataset = ACDCPreprocessedDataset(PREPROCESSED_DIR, split='val')

# These constants are defined unconditionally: the training-configuration cell
# reads them, and used to fail with a NameError whenever no data was loaded.
BATCH_SIZE = 8
NUM_CLASSES = config_data['num_classes']
IN_CHANNELS = config_data['in_channels']
IMG_SIZE = config_data['img_size']

# Check if data loaded
if len(train_dataset) == 0:
    print("\n‚ö†Ô∏è No data loaded! Please check preprocessing step.")
else:
    # Create dataloaders. Pinned memory only helps host-to-GPU copies, so it is
    # enabled only when CUDA is available (avoids a warning on CPU-only hosts).
    loader_kwargs = dict(
        batch_size=BATCH_SIZE,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    print(f"\n‚úÖ Dataset loaded successfully!")
    print(f"   Training slices: {len(train_dataset)} (80%)")
    print(f"   Validation slices: {len(val_dataset)} (20%)")
    print(f"   Total: {len(train_dataset) + len(val_dataset)} slices")
    print(f"   Batch size: {BATCH_SIZE}")
    print(f"   Number of classes: {NUM_CLASSES}")
    print(f"   Image size: {IMG_SIZE}√ó{IMG_SIZE}")

## 🎯 Training Configuration

In [None]:
# Training configuration, assembled section by section into one flat dict
config = {}

# Dataset (auto-configured from the cells above)
config.update({
    'dataset': DATASET,
    'in_channels': IN_CHANNELS,
    'num_classes': NUM_CLASSES,
    'img_size': IMG_SIZE,
    'class_names': config_data['class_names'],
})

# Model: one of 'egm_net', 'egm_net_lite', 'spectral_vmamba'
config.update({
    'model': 'egm_net',
    'base_channels': 64,
})

# Optimisation
config.update({
    'num_epochs': 100,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'batch_size': BATCH_SIZE,
})

# Loss weights
config.update({
    'dice_weight': 1.0,
    'ce_weight': 1.0,
    'boundary_weight': 0.5,
})

# Implicit representation
config.update({
    'num_points': 2048,        # Points sampled for implicit loss
    'boundary_ratio': 0.5,     # Ratio of points near boundaries
})

# Checkpointing and early stopping
config.update({
    'save_every': 10,
    'checkpoint_dir': f'./checkpoints_{DATASET}',
    'patience': 20,
})

# Make sure the checkpoint directory exists before training starts
os.makedirs(config['checkpoint_dir'], exist_ok=True)

print("üìã Training Configuration:")
for label, key in (
    ("Dataset", 'dataset'),
    ("Classes", 'class_names'),
    ("Model", 'model'),
    ("Epochs", 'num_epochs'),
    ("LR", 'learning_rate'),
):
    print(f"   {label}: {config[key]}")

In [None]:
# Build the network selected by config['model'] and move it to the chosen device
model_name = config['model']
print(f"Creating {model_name} model...")
print(f"   Input channels: {config['in_channels']}")
print(f"   Output classes: {config['num_classes']}")

# Arguments shared by every constructor
shared_kwargs = dict(
    in_channels=config['in_channels'],
    img_size=config['img_size'],
)

if model_name == 'egm_net':
    model = EGMNet(
        num_classes=config['num_classes'],
        base_channels=config['base_channels'],
        num_stages=4,
        encoder_depth=2,
        **shared_kwargs,
    )
elif model_name == 'egm_net_lite':
    model = EGMNetLite(num_classes=config['num_classes'], **shared_kwargs)
else:
    # Any other name falls through to SpectralVMUNet (intended for 'spectral_vmamba')
    model = SpectralVMUNet(
        out_channels=config['num_classes'],
        base_channels=config['base_channels'],
        num_stages=4,
        **shared_kwargs,
    )

model = model.to(device)

# Count parameters in a single pass over the model
param_stats = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_stats)
trainable_params = sum(count for count, trainable in param_stats if trainable)

print(f"\n‚úÖ Model created!")
print(f"   Total parameters: {total_params:,} ({total_params/1e6:.2f}M)")
print(f"   Trainable parameters: {trainable_params:,}")

## 📈 Loss Functions & Metrics

In [None]:
# Loss functions
class DiceLoss(nn.Module):
    """Soft Dice loss for multi-class segmentation, computed from raw logits.

    The Dice coefficient is evaluated per sample and per class on the softmax
    probabilities, then averaged; the loss is one minus that average.

    Args:
        smooth: Constant added to numerator and denominator to avoid 0/0.
    """
    def __init__(self, smooth=1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return the scalar loss for logits ``pred`` (B, C, H, W) and labels ``target`` (B, H, W)."""
        probs = F.softmax(pred, dim=1)
        # One-hot labels rearranged from (B, H, W, C) to (B, C, H, W)
        one_hot = F.one_hot(target, probs.shape[1]).permute(0, 3, 1, 2).float()

        spatial_dims = (2, 3)
        overlap = (probs * one_hot).sum(dim=spatial_dims)
        total = probs.sum(dim=spatial_dims) + one_hot.sum(dim=spatial_dims)

        dice_per_class = (2.0 * overlap + self.smooth) / (total + self.smooth)
        return 1.0 - dice_per_class.mean()


class BoundaryLoss(nn.Module):
    """Boundary-aware loss: binary cross-entropy restricted to label edges.

    Edges are found by applying Sobel filters to the label map; the loss is the
    summed BCE between class probabilities and one-hot labels on edge pixels,
    divided by the number of edge pixels.
    """
    def __init__(self):
        super().__init__()
        # Sobel kernels, stored as buffers shaped (1, 1, 3, 3) for conv2d
        kernels = {
            'sobel_x': [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
            'sobel_y': [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
        }
        for buffer_name, values in kernels.items():
            kernel = torch.tensor(values, dtype=torch.float32)
            self.register_buffer(buffer_name, kernel.view(1, 1, 3, 3))

    def get_boundaries(self, mask):
        """Return a (B, 1, H, W) float map that is 1 where the label map (B, H, W) has an edge."""
        label_map = mask.float().unsqueeze(1)
        grad_x = F.conv2d(label_map, self.sobel_x, padding=1)
        grad_y = F.conv2d(label_map, self.sobel_y, padding=1)
        magnitude = torch.sqrt(grad_x**2 + grad_y**2)
        return (magnitude > 0.5).float()

    def forward(self, pred, target):
        """Return the scalar boundary loss for logits ``pred`` and labels ``target``."""
        edge_mask = self.get_boundaries(target)

        probs = F.softmax(pred, dim=1)
        one_hot = F.one_hot(target, pred.shape[1]).permute(0, 3, 1, 2).float()

        # Off-edge pixels are zeroed in both arguments, so they contribute nothing
        summed_bce = F.binary_cross_entropy(
            probs * edge_mask,
            one_hot * edge_mask,
            reduction='sum'
        )
        # The small constant keeps the division defined when there are no edges
        return summed_bce / (edge_mask.sum() + 1e-6)


class CombinedLoss(nn.Module):
    """Weighted sum of Dice, cross-entropy and boundary losses for EGM-Net training.

    Args:
        dice_weight: Weight of the Dice term.
        ce_weight: Weight of the cross-entropy term.
        boundary_weight: Weight of the boundary term.
        num_classes: Accepted for interface compatibility; not used here.
    """
    def __init__(self, dice_weight=1.0, ce_weight=1.0, boundary_weight=0.5, num_classes=2):
        super().__init__()
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight
        self.boundary_weight = boundary_weight

        self.dice_loss = DiceLoss()
        self.ce_loss = nn.CrossEntropyLoss()
        self.boundary_loss = BoundaryLoss()

    def forward(self, outputs, targets):
        """Return the total loss.

        ``outputs`` is either a logits tensor (standard output, e.g.
        SpectralVMUNet) or a dict with an 'output' entry and optional 'coarse'
        and 'fine' entries (EGM-Net). Each auxiliary entry that is present
        adds a cross-entropy term weighted by 0.3.
        """
        if isinstance(outputs, dict):
            main_pred = outputs['output']
            auxiliary = (outputs.get('coarse'), outputs.get('fine'))
        else:
            main_pred = outputs
            auxiliary = ()

        # Main loss on the final prediction
        total = self.dice_weight * self.dice_loss(main_pred, targets)
        total = total + self.ce_weight * self.ce_loss(main_pred, targets)
        total = total + self.boundary_weight * self.boundary_loss(main_pred, targets)

        # Auxiliary losses (coarse and fine branches)
        for aux_pred in auxiliary:
            if aux_pred is not None:
                total = total + 0.3 * self.ce_loss(aux_pred, targets)

        return total


# Metrics
def compute_dice(pred, target, num_classes):
    """Return a list with one hard Dice score per class.

    Predictions are the argmax of ``pred`` over dim 1. A class that is absent
    from both prediction and target scores 1.0.

    Args:
        pred: Class scores with the class axis at dim 1.
        target: Integer label map matching the argmax shape.
        num_classes: Number of classes to score (0 .. num_classes - 1).
    """
    labels = torch.argmax(pred, dim=1)
    scores = []

    for cls in range(num_classes):
        in_pred = (labels == cls).float()
        in_target = (target == cls).float()

        overlap = (in_pred * in_target).sum()
        total = in_pred.sum() + in_target.sum()

        # Both empty = perfect
        score = (2.0 * overlap) / total if total > 0 else torch.tensor(1.0)
        scores.append(score.item())

    return scores


def compute_iou(pred, target, num_classes):
    """Return a list with one intersection-over-union score per class.

    Predictions are the argmax of ``pred`` over dim 1. A class that is absent
    from both prediction and target scores 1.0.

    Args:
        pred: Class scores with the class axis at dim 1.
        target: Integer label map matching the argmax shape.
        num_classes: Number of classes to score (0 .. num_classes - 1).
    """
    labels = torch.argmax(pred, dim=1)
    scores = []

    for cls in range(num_classes):
        in_pred = (labels == cls).float()
        in_target = (target == cls).float()

        overlap = (in_pred * in_target).sum()
        # |A ∪ B| = |A| + |B| - |A ∩ B|
        combined = in_pred.sum() + in_target.sum() - overlap

        score = overlap / combined if combined > 0 else torch.tensor(1.0)
        scores.append(score.item())

    return scores


print("‚úÖ Loss functions and metrics defined!")

## 🚀 Training Loop

In [None]:
# Full training function
def train_model(model, train_loader, val_loader, config, device):
    """Train ``model`` with per-epoch validation, checkpointing and early stopping.

    Args:
        model: Network to optimise. Its forward output is either a tensor
            or a dict whose ``'output'`` entry holds the prediction.
        train_loader: Iterable yielding ``(images, masks)`` training batches.
        val_loader: Iterable yielding ``(images, masks)`` validation batches.
        config: Dict providing ``dice_weight``, ``ce_weight``,
            ``boundary_weight``, ``num_classes``, ``learning_rate``,
            ``weight_decay``, ``num_epochs``, ``checkpoint_dir``,
            ``save_every`` and ``patience``.
        device: Device that batches and the loss module are moved to.

    Returns:
        dict: Per-epoch lists under ``'train_loss'``, ``'val_loss'``,
        ``'val_dice'``, ``'val_iou'`` and ``'learning_rate'``.
    """

    # Checkpoints are written below; make sure the target directory exists
    # so the first torch.save cannot fail after a full epoch of work.
    os.makedirs(config['checkpoint_dir'], exist_ok=True)

    # Loss and optimizer
    criterion = CombinedLoss(
        dice_weight=config['dice_weight'],
        ce_weight=config['ce_weight'],
        boundary_weight=config['boundary_weight'],
        num_classes=config['num_classes']
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )

    # Learning rate scheduler (cosine annealing)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=config['num_epochs'],
        eta_min=1e-6
    )

    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_dice': [],
        'val_iou': [],
        'learning_rate': []
    }

    best_val_dice = 0.0
    patience_counter = 0

    print("\n" + "="*60)
    print("üöÄ Starting Training")
    print("="*60)

    for epoch in range(config['num_epochs']):
        # =============== Training ===============
        model.train()
        train_loss = 0.0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Train]")

        for batch_idx, (images, masks) in enumerate(train_pbar):
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)

            # Compute loss
            loss = criterion(outputs, masks)

            # Backward pass
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            train_loss += loss.item()
            train_pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        avg_train_loss = train_loss / len(train_loader)
        history['train_loss'].append(avg_train_loss)

        # =============== Validation ===============
        model.eval()
        val_loss = 0.0
        all_dice = []
        all_iou = []

        with torch.no_grad():
            for images, masks in tqdm(val_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Val]"):
                images = images.to(device)
                masks = masks.to(device)

                outputs = model(images)

                # Get prediction tensor
                pred = outputs['output'] if isinstance(outputs, dict) else outputs

                loss = criterion(outputs, masks)
                val_loss += loss.item()

                # Compute metrics
                dice_scores = compute_dice(pred, masks, config['num_classes'])
                iou_scores = compute_iou(pred, masks, config['num_classes'])

                all_dice.append(np.mean(dice_scores[1:]))  # Exclude background
                all_iou.append(np.mean(iou_scores[1:]))

        avg_val_loss = val_loss / len(val_loader)
        # Plain Python floats keep the pickled history free of NumPy scalar
        # objects, so checkpoints stay loadable by restricted unpicklers.
        avg_val_dice = float(np.mean(all_dice))
        avg_val_iou = float(np.mean(all_iou))

        history['val_loss'].append(avg_val_loss)
        history['val_dice'].append(avg_val_dice)
        history['val_iou'].append(avg_val_iou)
        # Recorded before scheduler.step(): this is the rate used this epoch.
        history['learning_rate'].append(optimizer.param_groups[0]['lr'])

        # Update learning rate
        scheduler.step()

        # Print epoch summary
        print(f"\nüìä Epoch {epoch+1}/{config['num_epochs']}")
        print(f"   Train Loss: {avg_train_loss:.4f}")
        print(f"   Val Loss:   {avg_val_loss:.4f}")
        print(f"   Val Dice:   {avg_val_dice:.4f}")
        print(f"   Val IoU:    {avg_val_iou:.4f}")
        print(f"   LR:         {optimizer.param_groups[0]['lr']:.6f}")

        # =============== Checkpointing ===============
        # Save best model
        if avg_val_dice > best_val_dice:
            best_val_dice = avg_val_dice
            patience_counter = 0

            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_dice': best_val_dice,
                'config': config,
                'history': history
            }
            torch.save(checkpoint, os.path.join(config['checkpoint_dir'], 'best_model.pth'))
            print(f"   ‚úÖ New best model saved! (Dice: {best_val_dice:.4f})")
        else:
            patience_counter += 1

        # Save periodic checkpoints
        if (epoch + 1) % config['save_every'] == 0:
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Needed to resume with the same learning-rate schedule.
                'scheduler_state_dict': scheduler.state_dict(),
                'history': history
            }
            torch.save(checkpoint, os.path.join(config['checkpoint_dir'], f'checkpoint_epoch_{epoch+1}.pth'))

        # Early stopping
        if patience_counter >= config['patience']:
            print(f"\n‚ö†Ô∏è Early stopping triggered after {epoch+1} epochs")
            break

    print("\n" + "="*60)
    print(f"üéâ Training completed!")
    print(f"   Best Val Dice: {best_val_dice:.4f}")
    print("="*60)

    return history


# Cell-completion marker: printed once train_model has been defined.
print("‚úÖ Training function defined!")

In [None]:
# üöÄ START TRAINING
# Runs the full loop; `history` (per-epoch metrics) is reused by the plotting
# and export cells further down.
history = train_model(model, train_loader, val_loader, config, device)

## 📊 Training Visualization

In [None]:
# Plot training curves
# 2x2 grid: loss, validation Dice, validation IoU, learning-rate schedule.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Loss curves
axes[0, 0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0, 0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training & Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Dice score
axes[0, 1].plot(history['val_dice'], label='Val Dice', linewidth=2, color='green')
# Draw the best-score marker BEFORE building the legend; a legend only lists
# artists that exist when legend() is called, so the "Best" entry was missing.
axes[0, 1].axhline(y=max(history['val_dice']), color='r', linestyle='--', alpha=0.5, label=f"Best: {max(history['val_dice']):.4f}")
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Dice Score')
axes[0, 1].set_title('Validation Dice Score')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# IoU score
axes[1, 0].plot(history['val_iou'], label='Val IoU', linewidth=2, color='orange')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('IoU Score')
axes[1, 0].set_title('Validation IoU Score')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Learning rate
axes[1, 1].plot(history['learning_rate'], label='Learning Rate', linewidth=2, color='purple')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Learning Rate')
axes[1, 1].set_title('Learning Rate Schedule')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)
axes[1, 1].set_yscale('log')

plt.suptitle('Training Progress', fontsize=14)
plt.tight_layout()
# Saved to disk as well; the export cell copies this file.
plt.savefig('training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nüìà Training Summary:")
print(f"   Final Train Loss: {history['train_loss'][-1]:.4f}")
print(f"   Final Val Loss:   {history['val_loss'][-1]:.4f}")
print(f"   Best Val Dice:    {max(history['val_dice']):.4f}")
print(f"   Best Val IoU:     {max(history['val_iou']):.4f}")

## 🔍 Inference & Visualization

In [None]:
# Load best model
# map_location keeps the load working when the checkpoint was written on a GPU
# runtime and is restored on a CPU-only one. weights_only=False is explicit
# because this checkpoint also stores the config dict and the training history
# (not just tensors); only use it on checkpoints you produced yourself, since
# full unpickling can execute arbitrary code.
checkpoint = torch.load(
    os.path.join(config['checkpoint_dir'], 'best_model.pth'),
    map_location=device,
    weights_only=False,
)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"‚úÖ Loaded best model from epoch {checkpoint['epoch']} (Dice: {checkpoint['best_val_dice']:.4f})")

In [None]:
# Visualize predictions on validation set
# One row per sample: input, ground truth, prediction, and a colour overlay.
model.eval()

fig, axes = plt.subplots(4, 4, figsize=(16, 16))

with torch.no_grad():
    for i in range(4):
        # Four samples spread evenly across the validation set.
        idx = i * (len(val_dataset) // 4)
        img, mask = val_dataset[idx]
        img = img.unsqueeze(0).to(device)

        # Get prediction. Models may return either a dict with an 'output'
        # entry or the prediction tensor itself, so never call dict methods
        # on `outputs` unconditionally (the former `outputs.get(...)` raised
        # AttributeError for tensor-returning models and its result was unused).
        outputs = model(img)
        pred = outputs['output'] if isinstance(outputs, dict) else outputs
        pred_mask = torch.argmax(pred, dim=1)[0].cpu()

        # Plot
        axes[i, 0].imshow(img[0, 0].cpu(), cmap='gray')
        axes[i, 0].set_title('Input Image')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(mask, cmap='viridis')
        axes[i, 1].set_title('Ground Truth')
        axes[i, 1].axis('off')

        axes[i, 2].imshow(pred_mask, cmap='viridis')
        axes[i, 2].set_title('Prediction')
        axes[i, 2].axis('off')

        # Overlay: replicate the first channel into RGB, then paint masks.
        overlay = img[0, 0].cpu().numpy()
        overlay = np.stack([overlay, overlay, overlay], axis=-1)
        pred_np = pred_mask.numpy()
        mask_np = mask.numpy()

        # Red for prediction, blue for ground truth
        overlay[..., 0] = np.where(pred_np > 0, 1.0, overlay[..., 0])
        overlay[..., 2] = np.where(mask_np > 0, 1.0, overlay[..., 2])
        # imshow expects float RGB in [0, 1].
        overlay = np.clip(overlay, 0, 1)

        axes[i, 3].imshow(overlay)
        axes[i, 3].set_title('Overlay (Red=Pred, Blue=GT)')
        axes[i, 3].axis('off')

plt.suptitle('Validation Predictions', fontsize=14)
plt.tight_layout()
plt.savefig('predictions.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Visualize EGM-Net branches (Coarse vs Fine)
# Only EGM-Net variants are indexed by 'coarse'/'fine'/'energy' here, hence
# the guard on the configured model name.
if config['model'] in ['egm_net', 'egm_net_lite']:
    fig, axes = plt.subplots(3, 5, figsize=(20, 12))

    with torch.no_grad():
        for i in range(3):
            # Three samples spread evenly across the validation set.
            idx = i * (len(val_dataset) // 3)
            img, mask = val_dataset[idx]
            img = img.unsqueeze(0).to(device)

            outputs = model(img)

            # Extract all outputs
            final_pred = torch.argmax(outputs['output'], dim=1)[0].cpu()
            coarse_pred = torch.argmax(outputs['coarse'], dim=1)[0].cpu()
            fine_pred = torch.argmax(outputs['fine'], dim=1)[0].cpu()
            energy = outputs['energy'][0, 0].cpu()

            # Plot
            axes[i, 0].imshow(img[0, 0].cpu(), cmap='gray')
            axes[i, 0].set_title('Input')
            axes[i, 0].axis('off')

            axes[i, 1].imshow(energy, cmap='hot')
            axes[i, 1].set_title('Energy Map')
            axes[i, 1].axis('off')

            axes[i, 2].imshow(coarse_pred, cmap='viridis')
            axes[i, 2].set_title('Coarse Branch')
            axes[i, 2].axis('off')

            axes[i, 3].imshow(fine_pred, cmap='viridis')
            axes[i, 3].set_title('Fine Branch')
            axes[i, 3].axis('off')

            axes[i, 4].imshow(final_pred, cmap='viridis')
            axes[i, 4].set_title('Final (Fused)')
            axes[i, 4].axis('off')

    # A real newline: the previous '\\n' rendered a literal backslash-n.
    plt.suptitle('EGM-Net Branch Analysis\n(Energy-Gated Fusion of Coarse + Fine)', fontsize=14)
    plt.tight_layout()
    plt.savefig('branch_analysis.png', dpi=150, bbox_inches='tight')
    plt.show()

## 🎯 Resolution-Free Inference Demo

In [None]:
# Demonstrate resolution-free inference (unique to EGM-Net)
# The model is called with an `output_size` keyword, so the demo is limited
# to the EGM-Net variants selected in the config.
if config['model'] in ['egm_net', 'egm_net_lite']:
    print("üî¨ Resolution-Free Inference Demo")
    print("   EGM-Net can render at ANY resolution without retraining!")

    resolutions = [64, 128, 256, 512]

    # Get a sample image
    img, mask = val_dataset[0]
    img = img.unsqueeze(0).to(device)

    # Top row: resized input; bottom row: prediction at the same size.
    fig, axes = plt.subplots(2, len(resolutions), figsize=(16, 8))

    with torch.no_grad():
        for i, res in enumerate(resolutions):
            # Render at different resolutions
            outputs = model(img, output_size=(res, res))
            pred = torch.argmax(outputs['output'], dim=1)[0].cpu()

            # Also show input at same res for comparison
            input_resized = F.interpolate(img, size=(res, res), mode='bilinear', align_corners=False)

            axes[0, i].imshow(input_resized[0, 0].cpu(), cmap='gray')
            axes[0, i].set_title(f'Input {res}√ó{res}')
            axes[0, i].axis('off')

            axes[1, i].imshow(pred, cmap='viridis')
            axes[1, i].set_title(f'Prediction {res}√ó{res}')
            axes[1, i].axis('off')

    # Real newlines: the previous '\\n' rendered a literal backslash-n.
    plt.suptitle('Resolution-Free Rendering\n(Same model weights, different output resolutions)', fontsize=14)
    plt.tight_layout()
    plt.savefig('resolution_free.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("\n‚úÖ Same model can output at 64√ó64 to 512√ó512 (or higher)!")
    print("   This is impossible with standard CNN decoders.")

## 💾 Save & Export Model

In [None]:
# Save final model to Google Drive
# Best-effort export: any failure (including running outside Colab, where
# `google.colab` cannot be imported) falls through to the message below
# instead of aborting the notebook. The import therefore lives inside `try`.
try:
    from google.colab import drive

    drive.mount('/content/drive')

    # Save to Drive
    save_path = '/content/drive/MyDrive/EGM_Net_Models'
    os.makedirs(save_path, exist_ok=True)

    # Save checkpoint
    final_checkpoint = {
        'model_state_dict': model.state_dict(),
        'config': config,
        'history': history,
        'best_val_dice': max(history['val_dice'])
    }
    torch.save(final_checkpoint, os.path.join(save_path, 'egm_net_trained.pth'))

    # Also copy training curves
    import shutil
    shutil.copy('training_curves.png', save_path)
    shutil.copy('predictions.png', save_path)

    print(f"‚úÖ Model saved to Google Drive: {save_path}")
    print(f"   - egm_net_trained.pth")
    print(f"   - training_curves.png")
    print(f"   - predictions.png")

except Exception as e:
    print(f"‚ö†Ô∏è Could not save to Google Drive: {e}")
    print("   Model is saved locally in ./checkpoints/")

## 📊 Final Evaluation Metrics

In [None]:
# Final evaluation on full validation set
model.eval()

# Per-sample accumulators filled by the evaluation loop below; the Dice and
# IoU entries are foreground means (background class excluded).
all_dice_scores = []
all_iou_scores = []
all_hd95_scores = []  # Hausdorff Distance 95

# Helper function for Hausdorff distance
def compute_hausdorff_95(pred, target, spacing=None):
    """Compute the symmetric 95th-percentile Hausdorff distance of two masks.

    Args:
        pred: Predicted binary mask (CPU tensor or array-like); any non-zero
            element counts as foreground.
        target: Reference binary mask with the same shape as ``pred``.
        spacing: Optional per-axis element size passed to the distance
            transform. ``None`` (the default) measures distances in pixels.

    Returns:
        float: ``0.0`` when both masks are empty. When exactly one mask is
        empty the distance is undefined, so the largest distance possible in
        the grid (its diagonal) is returned as a penalty rather than a
        perfect score. Otherwise the larger of the two directed
        95th-percentile surface distances.
    """
    from scipy.ndimage import distance_transform_edt

    pred_np = np.asarray(pred).astype(bool)
    target_np = np.asarray(target).astype(bool)

    pred_empty = not pred_np.any()
    target_empty = not target_np.any()
    if pred_empty and target_empty:
        return 0.0
    if pred_empty or target_empty:
        # A complete miss must not score as a perfect match.
        extent = np.asarray(pred_np.shape, dtype=float) - 1.0
        if spacing is not None:
            extent = extent * np.asarray(spacing, dtype=float)
        return float(np.linalg.norm(extent))

    # Distance from every element to the nearest foreground element.
    pred_dist = distance_transform_edt(~pred_np, sampling=spacing)
    target_dist = distance_transform_edt(~target_np, sampling=spacing)

    # Surface = foreground elements at unit (index-space) distance from the
    # background, so the threshold is deliberately independent of `spacing`.
    pred_surface = pred_np & (distance_transform_edt(pred_np) <= 1)
    target_surface = target_np & (distance_transform_edt(target_np) <= 1)

    # Distances from pred surface to target, and vice versa
    d_pred_to_target = target_dist[pred_surface]
    d_target_to_pred = pred_dist[target_surface]

    # NOTE(review): reached when a mask has no detectable surface (e.g. it
    # covers the whole grid); 0.0 is kept from the original behaviour.
    if len(d_pred_to_target) == 0 or len(d_target_to_pred) == 0:
        return 0.0

    # 95th percentile in each direction; report the worse one.
    hd95 = max(np.percentile(d_pred_to_target, 95), np.percentile(d_target_to_pred, 95))
    return float(hd95)

print("üîç Running final evaluation...")

# Samples whose HD95 computation raised; reported instead of silently dropped.
hd95_failures = 0

with torch.no_grad():
    for images, masks in tqdm(val_loader, desc="Evaluating"):
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)
        pred = outputs['output'] if isinstance(outputs, dict) else outputs
        pred_masks = torch.argmax(pred, dim=1)

        for b in range(images.shape[0]):
            # Per-sample metrics (b:b+1 keeps the batch axis the helpers expect)
            dice = compute_dice(pred[b:b+1], masks[b:b+1], config['num_classes'])
            iou = compute_iou(pred[b:b+1], masks[b:b+1], config['num_classes'])

            all_dice_scores.append(np.mean(dice[1:]))  # Exclude background
            all_iou_scores.append(np.mean(iou[1:]))

            # Hausdorff distance (for foreground): all non-background classes
            # are merged into a single binary mask.
            try:
                hd95 = compute_hausdorff_95(
                    (pred_masks[b] > 0).cpu(),
                    (masks[b] > 0).cpu()
                )
            except Exception as exc:
                # Stay best-effort, but make the failure visible once.
                hd95_failures += 1
                if hd95_failures == 1:
                    print(f"   HD95 failed for a sample ({type(exc).__name__}: {exc}); skipping")
            else:
                all_hd95_scores.append(hd95)

# Print results
print("\n" + "="*60)
print("üìä FINAL EVALUATION RESULTS")
print("="*60)
print(f"\n{'Metric':<20} {'Mean':<12} {'Std':<12}")
print("-"*44)
print(f"{'Dice Score':<20} {np.mean(all_dice_scores):.4f}       {np.std(all_dice_scores):.4f}")
print(f"{'IoU Score':<20} {np.mean(all_iou_scores):.4f}       {np.std(all_iou_scores):.4f}")
if all_hd95_scores:
    # No voxel spacing is passed to the distance computation, so the unit is
    # pixels, not millimetres.
    print(f"{'HD95 (px)':<20} {np.mean(all_hd95_scores):.2f}         {np.std(all_hd95_scores):.2f}")
print("-"*44)
print(f"\nTotal validation samples: {len(all_dice_scores)}")
if hd95_failures:
    print(f"HD95 skipped for {hd95_failures} sample(s) because of errors")
print("="*60)

## 3. Test Monogenic Signal Processing

In [None]:
# Create a test image with edges
def create_test_image(size=256):
    """Build a noisy synthetic image containing two bright disks.

    Returns a ``(1, 1, size, size)`` image plus the two disk masks as float
    tensors of shape ``(size, size)``: a large centred "organ" disk with
    intensity 0.7 and a smaller offset "tumor" disk with intensity 1.0,
    followed by additive Gaussian noise scaled by 0.05.
    """
    half = size // 2
    rows, cols = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')

    def disk(cx, cy, radius):
        # Strictly-inside test against the squared radius.
        return (cols - cx) ** 2 + (rows - cy) ** 2 < radius ** 2

    organ = disk(half, half, size // 4)
    tumor = disk(half + 30, half - 20, size // 10)

    canvas = torch.zeros(1, 1, size, size)
    canvas[0, 0, organ] = 0.7
    # Painted second so the tumor overrides the organ where they overlap.
    canvas[0, 0, tumor] = 1.0

    noisy = canvas + 0.05 * torch.randn_like(canvas)
    return noisy, organ.float(), tumor.float()

# Create test image
# `test_img` feeds the monogenic-signal cell below; the two masks are the
# noise-free organ and tumor regions.
test_img, organ_mask, tumor_mask = create_test_image(256)
print(f"Test image shape: {test_img.shape}")

In [None]:
# Test Monogenic Energy Extraction
energy_extractor = EnergyMap(normalize=True, smoothing_sigma=1.0)
energy, mono_out = energy_extractor(test_img)

# Visualize: one (image, colormap, title) entry per panel, listed in the
# row-major order in which the 2x3 axes grid is traversed.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
monogenic_panels = [
    (test_img[0, 0], 'gray', 'Input Image'),
    (energy[0, 0].detach(), 'hot', 'Energy Map (Edges)'),
    (mono_out['phase'][0, 0].detach(), 'twilight', 'Phase'),
    (mono_out['orientation'][0, 0].detach(), 'hsv', 'Orientation'),
    (mono_out['riesz_x'][0, 0].detach(), 'RdBu', 'Riesz X Component'),
    (mono_out['riesz_y'][0, 0].detach(), 'RdBu', 'Riesz Y Component'),
]
for panel_ax, (panel_img, panel_cmap, panel_title) in zip(axes.flat, monogenic_panels):
    panel_ax.imshow(panel_img, cmap=panel_cmap)
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')

plt.suptitle('Monogenic Signal Decomposition', fontsize=14)
plt.tight_layout()
plt.show()

print("\n‚úÖ Monogenic processing works correctly!")

## 4. Test Gabor Basis vs Fourier Features

In [None]:
from gabor_implicit import GaborBasis, FourierFeatures

# Create coordinate grid
# A size x size lattice over [-1, 1] in both axes, flattened to a list of
# (x, y) pairs. `size` and the feature tensors are reused by the next cell.
size = 128
y = torch.linspace(-1, 1, size)
x = torch.linspace(-1, 1, size)
yy, xx = torch.meshgrid(y, x, indexing='ij')
coords = torch.stack([xx, yy], dim=-1).view(1, -1, 2)  # (1, size*size, 2)

# Compare Gabor vs Fourier
# Both encoders are built with the same number of frequencies and applied to
# the same coordinates so their outputs can be compared side by side.
gabor = GaborBasis(input_dim=2, num_frequencies=32)
fourier = FourierFeatures(input_dim=2, num_frequencies=32, scale=10.0)

gabor_features = gabor(coords)
fourier_features = fourier(coords)

print(f"Gabor features shape: {gabor_features.shape}")
print(f"Fourier features shape: {fourier_features.shape}")

In [None]:
# Visualize first few basis functions
# Top row: Gabor features; bottom row: Fourier features. Each panel reshapes
# one feature channel back onto the size x size coordinate grid.
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for i in range(4):
    # Gabor
    gabor_vis = gabor_features[0, :, i].view(size, size).detach().numpy()
    axes[0, i].imshow(gabor_vis, cmap='RdBu', vmin=-1, vmax=1)
    axes[0, i].set_title(f'Gabor Basis {i+1}')

    # Fourier
    fourier_vis = fourier_features[0, :, i].view(size, size).detach().numpy()
    axes[1, i].imshow(fourier_vis, cmap='RdBu', vmin=-1, vmax=1)
    axes[1, i].set_title(f'Fourier Basis {i+1}')

    # Hide ticks and frame by hand. axis('off') would also hide the axis
    # labels, which made the row labels set below invisible.
    for basis_ax in (axes[0, i], axes[1, i]):
        basis_ax.set_xticks([])
        basis_ax.set_yticks([])
        for spine in basis_ax.spines.values():
            spine.set_visible(False)

axes[0, 0].set_ylabel('Gabor\n(Localized)', fontsize=12)
axes[1, 0].set_ylabel('Fourier\n(Global)', fontsize=12)

plt.suptitle('Gabor vs Fourier Basis Functions\n(Gabor is localized ‚Üí No Gibbs ringing)', fontsize=14)
plt.tight_layout()
plt.show()

## 5. Create and Analyze Models

In [None]:
# Create EGM-Net models
print("Creating models...")

# Full model
egm_net = EGMNet(
    in_channels=1,
    num_classes=3,
    img_size=256,
    base_channels=64,
    num_stages=4,
    encoder_depth=2
).to(device)

# Lite model
egm_lite = EGMNetLite(
    in_channels=1,
    num_classes=3,
    img_size=256
).to(device)

# Spectral Mamba (comparison)
spec_mamba = SpectralVMUNet(
    in_channels=1,
    out_channels=3,
    img_size=256,
    base_channels=64,
    num_stages=4
).to(device)

print("\nüìä Model Comparison:")
print("-" * 50)
models = {
    'EGM-Net Full': egm_net,
    'EGM-Net Lite': egm_lite,
    'SpectralVMUNet': spec_mamba
}

# The loop variable is deliberately not called `model`: notebook cells share
# one namespace, and rebinding `model` here would silently replace the
# network used by the training/evaluation cells.
for name, net in models.items():
    params = sum(p.numel() for p in net.parameters())
    print(f"{name:20s}: {params:,} parameters ({params/1e6:.2f}M)")

## 6. Test Forward Pass

In [None]:
# Test forward pass
# Random batch of two single-channel 256x256 inputs; `test_input` is reused
# by the following cells.
test_input = torch.randn(2, 1, 256, 256).to(device)

print("Testing forward pass...")
print(f"Input shape: {test_input.shape}")

# Shape check only, so gradients are not tracked.
with torch.no_grad():
    # EGM-Net
    # Iterated with .items(), i.e. the output is treated as a mapping of
    # named results.
    egm_out = egm_net(test_input)
    print(f"\nüîπ EGM-Net Output:")
    for k, v in egm_out.items():
        print(f"   {k}: {v.shape}")
    
    # SpectralVMUNet
    # Used as a single object with a .shape attribute.
    spec_out = spec_mamba(test_input)
    print(f"\nüîπ SpectralVMUNet Output: {spec_out.shape}")

print("\n‚úÖ Forward pass successful!")

## 7. Test Resolution-Free Inference (Unique to EGM-Net)

In [None]:
# EGM-Net can query at arbitrary coordinates!
print("Testing Resolution-Free Inference...")

# Create query points (random locations)
num_points = 10000
# rand() samples [0, 1); scaling by 2 and shifting by -1 maps it to [-1, 1).
random_coords = torch.rand(1, num_points, 2).to(device) * 2 - 1  # [-1, 1]

with torch.no_grad():
    # Query at random points
    # Only the first sample of the batch is used, matching the single set of
    # coordinates above.
    point_output = egm_net.query_points(test_input[:1], random_coords)
    
print(f"Query coordinates: {random_coords.shape}")
print(f"Point outputs: {point_output.shape}")
print("\n‚úÖ Resolution-free inference works!")
print("   ‚Üí You can zoom into boundaries at ANY resolution!")

In [None]:
# Demonstrate resolution-free: render at different resolutions
resolutions = [64, 128, 256, 512]

# One panel per requested output size.
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

with torch.no_grad():
    for idx, res in enumerate(resolutions):
        # Render at this resolution
        # The same input is passed every time; only `output_size` changes.
        output = egm_net(test_input[:1], output_size=(res, res))
        # Hard label map taken as the argmax over dim 1 of 'output'.
        pred = torch.argmax(output['output'], dim=1)[0].cpu().numpy()
        
        axes[idx].imshow(pred, cmap='viridis')
        axes[idx].set_title(f'{res}√ó{res}')
        axes[idx].axis('off')

plt.suptitle('Resolution-Free Rendering (Same model, different output sizes)', fontsize=14)
plt.tight_layout()
plt.show()

## 8. Visualize Energy-Gated Fusion

In [None]:
# Visualize the dual-branch architecture
with torch.no_grad():
    outputs = egm_net(test_input[:1])

# Hard label maps of each branch and of the fused result, plus a binary map
# marking where the fine and coarse branches disagree.
coarse_pred = outputs['coarse'].argmax(dim=1)[0].cpu()
fine_pred = outputs['fine'].argmax(dim=1)[0].cpu()
final_pred = outputs['output'].argmax(dim=1)[0].cpu()
diff = (fine_pred != coarse_pred).float()

# One (image, colormap, title) entry per panel, listed in the row-major order
# in which the 2x3 axes grid is traversed.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
fusion_panels = [
    (test_input[0, 0].cpu(), 'gray', 'Input Image'),
    (outputs['energy'][0, 0].cpu(), 'hot', 'Energy Map (Edge Detection)'),
    (coarse_pred, 'viridis', 'Coarse Branch (Smooth)'),
    (fine_pred, 'viridis', 'Fine Branch (Sharp)'),
    (final_pred, 'viridis', 'Final Output (Fused)'),
    (diff, 'Reds', 'Difference (Fine vs Coarse)'),
]
for panel_ax, (panel_img, panel_cmap, panel_title) in zip(axes.flat, fusion_panels):
    panel_ax.imshow(panel_img, cmap=panel_cmap)
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')

plt.suptitle('EGM-Net Dual-Branch Architecture', fontsize=14)
plt.tight_layout()
plt.show()

## 9. Quick Training Demo

In [None]:
from train_egm import EGMNetTrainer, create_dummy_dataset
from torch.utils.data import DataLoader

# Create small dummy dataset
# NOTE: this cell rebinds the notebook-level names `train_loader`, `config`
# and `model`, replacing any objects created by earlier cells.
print("Creating dummy dataset...")
dataset = create_dummy_dataset(num_samples=16, img_size=256, num_classes=3)
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Training config
# Passed to EGMNetTrainer in the next cell.
# NOTE(review): 'num_points' and 'boundary_ratio' are presumably point-sampling
# settings consumed by the trainer — confirm in train_egm.
config = {
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'num_epochs': 2,
    'num_points': 1024,
    'boundary_ratio': 0.5,
    'checkpoint_dir': './checkpoints_demo'
}

# Use lite model for faster training
model = EGMNetLite(in_channels=1, num_classes=3, img_size=256)
print(f"Model: {sum(p.numel() for p in model.parameters()):,} parameters")

In [None]:
# Train for a few epochs
print("\nStarting training demo...")
# The epoch count is passed explicitly here and repeats config['num_epochs'].
trainer = EGMNetTrainer(model, config, device=device)
trainer.train(train_loader, num_epochs=2)

print("\n‚úÖ Training demo completed!")

## 10. Inference Speed Benchmark

In [None]:
import time

def benchmark_model(model, input_tensor, num_runs=50, warmup=10):
    """Benchmark inference speed.

    Puts ``model`` in eval mode, runs ``warmup`` untimed forward passes and
    then times ``num_runs`` forward passes under ``torch.no_grad()``.

    Args:
        model: Module to benchmark; it is left in eval mode afterwards.
        input_tensor: Input passed to every forward call. Its device decides
            whether CUDA synchronisation is needed around the timings.
        num_runs: Number of timed forward passes.
        warmup: Number of untimed forward passes executed first.

    Returns:
        tuple: ``(mean, std)`` of the per-call latency in milliseconds.
    """
    model.eval()

    # Derived from the tensor rather than a notebook-level `device` variable,
    # so the function works wherever it is called from.
    on_cuda = input_tensor.is_cuda

    with torch.no_grad():
        # Warmup
        for _ in range(warmup):
            model(input_tensor)

        # CUDA kernels run asynchronously; drain the queue before timing.
        if on_cuda:
            torch.cuda.synchronize()

        # Benchmark
        times = []
        for _ in range(num_runs):
            # perf_counter is monotonic and has the highest available
            # resolution, unlike the wall clock returned by time.time().
            start = time.perf_counter()
            model(input_tensor)
            if on_cuda:
                torch.cuda.synchronize()
            times.append(time.perf_counter() - start)

    return np.mean(times) * 1000, np.std(times) * 1000  # ms

# Benchmark
print("Benchmarking inference speed...")
print("-" * 60)

# Single-sample input: latency is measured per image.
test_input = torch.randn(1, 1, 256, 256).to(device)

# The loop variable is deliberately not called `model`: notebook cells share
# one namespace, and rebinding `model` here would silently replace the
# network trained in the cells above.
for name, net in [('EGM-Net Full', egm_net), ('EGM-Net Lite', egm_lite)]:
    mean_time, std_time = benchmark_model(net, test_input)
    fps = 1000 / mean_time  # mean_time is in milliseconds
    print(f"{name:20s}: {mean_time:.2f} ¬± {std_time:.2f} ms ({fps:.1f} FPS)")

print("\n‚úÖ Benchmark completed!")

## 11. Summary

In [None]:
print("""
‚ïî‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïó
‚ïë                    EGM-NET ARCHITECTURE SUMMARY                       ‚ïë
‚ï†‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ï£
‚ïë                                                                       ‚ïë
‚ïë  üî¨ KEY INNOVATIONS:                                                  ‚ïë
‚ïë                                                                       ‚ïë
‚ïë  1. MONOGENIC ENERGY GATING                                          ‚ïë
‚ïë     ‚Ä¢ Physics-based edge detection (Riesz Transform)                 ‚ïë
‚ïë     ‚Ä¢ Automatically focuses on boundary regions                      ‚ïë
‚ïë     ‚Ä¢ Suppresses artifacts in flat regions                           ‚ïë
‚ïë                                                                       ‚ïë
‚ïë  2. GABOR BASIS (vs Fourier)                                         ‚ïë
‚ïë     ‚Ä¢ Localized oscillations (Gaussian √ó sin)                        ‚ïë
‚ïë     ‚Ä¢ NO Gibbs ringing artifacts                                     ‚ïë
‚ïë     ‚Ä¢ Sharp edges remain clean                                       ‚ïë
‚ïë                                                                       ‚ïë
‚ïë  3. DUAL-PATH ARCHITECTURE                                           ‚ïë
‚ïë     ‚Ä¢ Coarse Branch: Smooth body regions (Conv decoder)              ‚ïë
‚ïë     ‚Ä¢ Fine Branch: Sharp boundaries (Gabor Implicit)                 ‚ïë
‚ïë     ‚Ä¢ Energy-gated fusion: Best of both worlds                       ‚ïë
‚ïë                                                                       ‚ïë
‚ïë  4. RESOLUTION-FREE INFERENCE                                        ‚ïë
‚ïë     ‚Ä¢ Query at ANY coordinate ‚Üí Infinite zoom                        ‚ïë
‚ïë     ‚Ä¢ No retraining needed for different resolutions                 ‚ïë
‚ïë     ‚Ä¢ Perfect for high-resolution medical imaging                    ‚ïë
‚ïë                                                                       ‚ïë
‚ïë  5. MAMBA ENCODER                                                    ‚ïë
‚ïë     ‚Ä¢ O(N) complexity (vs O(N¬≤) for Transformers)                    ‚ïë
‚ïë     ‚Ä¢ Global context awareness                                       ‚ïë
‚ïë     ‚Ä¢ Efficient for large images                                     ‚ïë
‚ïë                                                                       ‚ïë
‚ï†‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ï£
‚ïë                                                                       ‚ïë
‚ïë  üìä MODEL SIZES:                                                      ‚ïë
‚ïë     ‚Ä¢ EGM-Net Full:  ~9.13M parameters                               ‚ïë
‚ïë     ‚Ä¢ EGM-Net Lite:  ~635K parameters                                ‚ïë
‚ïë     ‚Ä¢ SpectralVMUNet: ~10.31M parameters                             ‚ïë
‚ïë                                                                       ‚ïë
‚ïö‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïù
""")

---

## 📚 Next Steps

1. **Train on real data**: Replace dummy dataset with medical imaging dataset (e.g., Synapse, ACDC)
2. **Tune hyperparameters**: Adjust `num_frequencies`, `boundary_ratio`, learning rate
3. **Evaluate metrics**: Dice score, IoU, Hausdorff distance
4. **Ablation study**: Compare Gabor vs Fourier, with/without energy gating

---

**Repository**: https://github.com/QuocKhanhLuong/FourierNetwork