In [1]:
"""
================================================================================
AEGNN BASELINE TRAINING - N-Caltech101
================================================================================
Official AEGNN reproduction using graph_res network
"""

import os
import sys
from pathlib import Path

# Make a local AEGNN checkout importable. The path is resolved against the
# current working directory; when ./aegnn does not exist nothing is added and
# the `aegnn` import further down relies on whatever is already on sys.path.
aegnn_root = Path("./aegnn").resolve()
if aegnn_root.exists() and str(aegnn_root) not in sys.path:
    # Skip the insert when the entry is already present so that re-running
    # this cell/script in the same interpreter does not stack up duplicates.
    sys.path.insert(0, str(aegnn_root))

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

from aegnn.models.recognition import RecognitionModel
from dataset_caltech_aegnn import get_aegnn_dataloaders

# ============================================================================
# CONFIGURATION
# ============================================================================

# Single source of truth for the run. Sections are read by the data module,
# the Lightning wrapper and main() below.
CONFIG = {
    # Keyword arguments forwarded to RecognitionModel.
    'model': {
        'network': 'graph_res',
        'dataset': 'ncaltech101',
        'num_classes': 101,
        'img_shape': (180, 240),
        # positions are 3D: (x, y, t)
        'dim': 3,
        # extra keyword arguments splatted into RecognitionModel
        'model_kwargs': {},
    },

    # Keyword arguments forwarded to get_aegnn_dataloaders.
    # r / d_max / n_samples / beta are marked as the paper settings.
    'dataset': {
        'root': './datasets/ncaltech',
        # spatial radius
        'r': 3.0,
        # maximum number of neighbours
        'd_max': 32,
        # events per sample
        'n_samples': 15000,
        # time scaling factor
        'beta': 3e-3,
        'train_ratio': 0.7,
        'random_seed': 42,
    },

    'training': {
        # tune to the available GPU memory
        'batch_size': 8,
        # NOTE(review): not read by any code visible in this file
        'eval_batch_size': 8,
        # forwarded to get_aegnn_dataloaders as num_workers
        'num_workers': 0,
        'lr': 1e-3,
        'weight_decay': 0.0,
        'epochs': 100,
        # StepLR: multiply the LR by scheduler_gamma every scheduler_step epochs
        'scheduler_step': 30,
        'scheduler_gamma': 0.5,
    },

    'hardware': {
        # passed to pl.Trainer as `devices`
        'gpus': 1,
        'accelerator': 'gpu',
        # fp32, chosen for stability
        'precision': 32,
    },

    'logging': {
        'save_dir': './checkpoints_aegnn',
        'log_dir': './runs_aegnn',
        'save_top_k': 3,
        'log_every_n_steps': 20,
    },

    # global seed handed to pl.seed_everything
    'seed': 42,
}


# ============================================================================
# DATA MODULE
# ============================================================================

class NCaltech101DataModule(pl.LightningDataModule):
    """Lightning DataModule that serves the N-Caltech101 AEGNN dataloaders.

    Both loaders are produced by ``get_aegnn_dataloaders`` from the
    ``dataset`` and ``training`` sections of ``config``.

    NOTE(review): ``val_dataloader`` and ``test_dataloader`` return the same
    loader, so any model selection driven by validation metrics sees the
    data that is later reported as the test result. Also,
    ``config['training']['eval_batch_size']`` is not forwarded: the call
    below passes a single ``batch_size``.
    """

    def __init__(self, config):
        """Store ``config``; the loaders are created lazily in ``setup``."""
        super().__init__()
        self.config = config
        self.train_loader = None
        self.test_loader = None

    def setup(self, stage=None):
        """Create the train/test dataloaders once.

        Lightning calls ``setup`` for every stage (``fit``, then again for
        ``test``). The loaders do not depend on ``stage``, so they are built
        on the first call and reused afterwards instead of being recreated.
        """
        if self.train_loader is not None and self.test_loader is not None:
            return

        dataset_cfg = self.config['dataset']
        training_cfg = self.config['training']

        print("\n📦 Loading AEGNN dataset...")

        self.train_loader, self.test_loader = get_aegnn_dataloaders(
            root_dir=dataset_cfg['root'],
            batch_size=training_cfg['batch_size'],
            num_workers=training_cfg['num_workers'],
            r=dataset_cfg['r'],
            d_max=dataset_cfg['d_max'],
            n_samples=dataset_cfg['n_samples'],
            beta=dataset_cfg['beta'],
            train_ratio=dataset_cfg['train_ratio'],
            random_seed=dataset_cfg['random_seed'],
        )

        print("✅ Setup complete!")

    def train_dataloader(self):
        """Return the training loader built in ``setup``."""
        return self.train_loader

    def val_dataloader(self):
        """Return the held-out loader (shared with ``test_dataloader``)."""
        return self.test_loader

    def test_dataloader(self):
        """Return the held-out loader (shared with ``val_dataloader``)."""
        return self.test_loader


# ============================================================================
# LIGHTNING MODULE
# ============================================================================

class AEGNNLightningWrapper(pl.LightningModule):
    """PyTorch Lightning wrapper around the AEGNN ``RecognitionModel``.

    Adds cross-entropy training/validation/test steps, accuracy logging (in
    percent) and an Adam + StepLR optimiser setup driven by ``config``.
    Every batch passes through ``_sanitize_edges`` before the model sees it.
    """

    def __init__(self, config):
        """Build the wrapped model from ``config['model']``.

        Args:
            config: nested dict with at least the ``model`` and ``training``
                sections used below; it is also stored as hyperparameters.
        """
        super().__init__()
        self.config = config
        self.save_hyperparameters(config)

        # Plain Python flag (not a buffer/parameter): verbose batch
        # diagnostics are printed for the first batch only.
        self._validated_first_batch = False

        model_cfg = config['model']
        self.model = RecognitionModel(
            network=model_cfg['network'],
            dataset=model_cfg['dataset'],
            num_classes=model_cfg['num_classes'],
            img_shape=model_cfg['img_shape'],
            dim=model_cfg['dim'],
            **model_cfg['model_kwargs']
        )

    def forward(self, data):
        """Sanitise the batch's edges, then run the wrapped model on it."""
        self._sanitize_edges(data)
        return self.model(data)

    def _sanitize_edges(self, data):
        """Drop edges whose endpoints fall outside ``[0, num_nodes)``.

        The check runs on every batch (train and eval alike). Verbose
        diagnostics are printed once, for the first batch; afterwards
        something is printed only when invalid edges are actually found.
        ``data`` is modified in place: ``edge_index`` is filtered and
        ``edge_attr``, when present with one row per edge, is filtered with
        the same mask so the two stay aligned.
        """
        report = not self._validated_first_batch
        self._validated_first_batch = True

        num_nodes = data.num_nodes
        num_edges = data.num_edges

        if report:
            print("\n🔍 GPU Validation:")
            print(f"  Device: {data.x.device}")
            print(f"  num_nodes: {num_nodes}")
            print(f"  num_edges: {num_edges}")

        if num_edges == 0:
            return

        edge_index = data.edge_index
        if report:
            print(f"  edge_index.max(): {edge_index.max().item()}")
            print(f"  Expected max: {num_nodes - 1}")

        # An edge is kept only if both endpoints are valid node indices;
        # negative indices are rejected as well as too-large ones.
        keep = ((edge_index >= 0) & (edge_index < num_nodes)).all(dim=0)
        if bool(keep.all()):
            return

        dropped = int((~keep).sum().item())
        print("\n❌ INVALID EDGE_INDEX!")
        print(f"   {dropped} of {num_edges} edges reference nodes outside [0, {num_nodes})")

        data.edge_index = edge_index[:, keep]
        edge_attr = getattr(data, 'edge_attr', None)
        if edge_attr is not None and edge_attr.size(0) == keep.numel():
            data.edge_attr = edge_attr[keep]
        print(f"   Fixed: new num_edges={data.edge_index.size(1)}")

    def _shared_step(self, batch):
        """Return ``(loss, accuracy_percent)`` for one batch."""
        logits = self(batch)
        loss = torch.nn.functional.cross_entropy(logits, batch.y)
        acc = (logits.argmax(dim=1) == batch.y).float().mean() * 100
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log epoch-level loss/accuracy."""
        loss, acc = self._shared_step(batch)

        self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=batch.num_graphs)
        self.log('train_acc', acc, prog_bar=True, on_step=False, on_epoch=True, batch_size=batch.num_graphs)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy; dump raw batch stats for batch 0."""
        if batch_idx == 0:
            # Inspect the batch exactly as delivered, before forward() has a
            # chance to filter its edges.
            print(f"\n{'='*70}")
            print("RAW BATCH INSPECTION (before forward):")
            print(f"  batch.num_nodes: {batch.num_nodes}")
            print(f"  batch.num_edges: {batch.num_edges}")
            print(f"  batch.edge_index device: {batch.edge_index.device}")
            print(f"  batch.edge_index.shape: {batch.edge_index.shape}")

            if batch.num_edges > 0:
                print(f"  batch.edge_index.max(): {batch.edge_index.max().item()}")
                print(f"  batch.edge_index.min(): {batch.edge_index.min().item()}")

                if batch.edge_index.max() >= batch.num_nodes:
                    print("  ❌ ALREADY INVALID BEFORE FORWARD!")
                else:
                    print("  ✓ Valid edge_index")
            print(f"{'='*70}\n")

        loss, acc = self._shared_step(batch)

        self.log('val_loss', loss, prog_bar=True, batch_size=batch.num_graphs)
        self.log('val_acc', acc, prog_bar=True, batch_size=batch.num_graphs)

        return {'val_loss': loss, 'val_acc': acc}

    def test_step(self, batch, batch_idx):
        """Log test accuracy (percent) for one batch."""
        _, acc = self._shared_step(batch)

        self.log('test_acc', acc, batch_size=batch.num_graphs)

        return {'test_acc': acc}

    def configure_optimizers(self):
        """Return Adam with an epoch-wise StepLR schedule from ``config``."""
        training_cfg = self.config['training']

        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=training_cfg['lr'],
            weight_decay=training_cfg['weight_decay']
        )

        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=training_cfg['scheduler_step'],
            gamma=training_cfg['scheduler_gamma']
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch'
            }
        }


# ============================================================================
# MAIN TRAINING LOOP
# ============================================================================

def main():
    """Train the AEGNN baseline, evaluate it and print a summary.

    Seeds all RNGs from ``CONFIG['seed']``, builds the data module and model,
    trains with checkpointing / early stopping on ``val_acc``, then runs the
    test loop on the best checkpoint when one was recorded (otherwise on the
    weights currently held by the model).
    """
    # Set seed for reproducibility
    pl.seed_everything(CONFIG['seed'])

    # Print header
    print("\n" + "="*80)
    print("🚀 AEGNN BASELINE TRAINING")
    print("="*80)
    print(f"📦 Model: {CONFIG['model']['network']}")
    print("📊 Dataset: N-Caltech101")
    print(f"🔢 Batch Size: {CONFIG['training']['batch_size']}")
    print(f"📈 Epochs: {CONFIG['training']['epochs']}")
    print(f"⚙️  Learning Rate: {CONFIG['training']['lr']}")
    print(f"🎯 Events/Sample: {CONFIG['dataset']['n_samples']}")
    print(f"📐 Radius: {CONFIG['dataset']['r']}")
    print(f"🔗 Max Neighbors: {CONFIG['dataset']['d_max']}")
    print("="*80 + "\n")

    # Create datamodule and model
    datamodule = NCaltech101DataModule(CONFIG)
    model = AEGNNLightningWrapper(CONFIG)

    # Print model size (trainable parameters only)
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"✅ Model parameters: {params:,}\n")

    # Create directories
    os.makedirs(CONFIG['logging']['save_dir'], exist_ok=True)
    os.makedirs(CONFIG['logging']['log_dir'], exist_ok=True)

    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath=CONFIG['logging']['save_dir'],
        filename='aegnn-{epoch:02d}-{val_acc:.2f}',
        monitor='val_acc',
        mode='max',
        save_top_k=CONFIG['logging']['save_top_k'],
        save_last=True,
        verbose=True
    )

    early_stop_callback = EarlyStopping(
        monitor='val_acc',
        patience=30,
        mode='max',
        verbose=True
    )

    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # Logger
    logger = TensorBoardLogger(
        CONFIG['logging']['log_dir'],
        name='aegnn'
    )

    # Trainer
    trainer = pl.Trainer(
        max_epochs=CONFIG['training']['epochs'],
        accelerator=CONFIG['hardware']['accelerator'],
        devices=CONFIG['hardware']['gpus'],
        precision=CONFIG['hardware']['precision'],
        callbacks=[checkpoint_callback, early_stop_callback, lr_monitor],
        logger=logger,
        log_every_n_steps=CONFIG['logging']['log_every_n_steps'],
        # gradient_clip_val=1.0,
        enable_progress_bar=True,
        enable_model_summary=True
    )

    # Train
    print("🏃 Starting training...\n")
    trainer.fit(model, datamodule=datamodule)

    # Test. 'best' is only usable once the checkpoint callback has recorded
    # a best model; otherwise evaluate the weights the model holds now.
    print("\n📊 Running final evaluation...\n")
    best_path = checkpoint_callback.best_model_path
    trainer.test(model, datamodule=datamodule, ckpt_path='best' if best_path else None)

    # Print results
    print("\n" + "="*80)
    print("🏁 TRAINING COMPLETE!")
    print("="*80)
    print(f"✅ Best model: {best_path}")
    best_score = checkpoint_callback.best_model_score
    if best_score is None:
        # No score was recorded, so there is nothing to format with :.2f.
        print("✅ Best val_acc: n/a")
    else:
        print(f"✅ Best val_acc: {float(best_score):.2f}%")
    print("="*80 + "\n")


# Start training only when this file is executed directly, not when imported.
if __name__ == '__main__':
    main()


Seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
  return _C._get_float32_matmul_precision()
You are using a CUDA device ('NVIDIA GeForce RTX 5080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision



🚀 AEGNN BASELINE TRAINING
📦 Model: graph_res
📊 Dataset: N-Caltech101
🔢 Batch Size: 8
📈 Epochs: 100
⚙️  Learning Rate: 0.001
🎯 Events/Sample: 15000
📐 Radius: 3.0
🔗 Max Neighbors: 32

✅ Model parameters: 20,401,120

🏃 Starting training...


📦 Loading AEGNN dataset...

🚀 Creating AEGNN DataLoaders
Batch size: 8
Num workers: 0
Debug mode: True



FileNotFoundError: Directory not found: datasets/ncaltech/img