In [None]:
from earthformer.cuboid_transformer.cuboid_transformer import CuboidTransformerModel
import torch

# Optimized config for Landsat 3-timestep forecasting.
# Every key below is forwarded as a keyword argument to CuboidTransformerModel.
landsat_config = {
    'input_shape': (3, 128, 128, 9),    # 3 input timesteps, 128x128, 9 Landsat bands
    'target_shape': (3, 128, 128, 1),   # 3 output timesteps, 128x128, 1 output channel
    
    # Small model for prototyping
    'base_units': 96,                    # Small but efficient
    'num_heads': 6,                      # base_units (96) is divisible by num_heads (6): 96 / 6 = 16
    'enc_depth': [2, 2],                 # 2-level hierarchy (sufficient for short sequences)
    'dec_depth': [1, 1],                 # Matching number of decoder levels (one block per level)
    
    # Dropout for better generalization during prototyping
    'attn_drop': 0.1,
    'proj_drop': 0.1,
    'ffn_drop': 0.1,
    
    # Global vectors for capturing Landsat scene patterns
    'num_global_vectors': 8,
    'use_dec_self_global': True,
    'use_dec_cross_global': True,
    
    # Optimized for satellite imagery
    'pos_embed_type': 't+hw',            # Separate temporal and spatial embeddings
    'use_relative_pos': True,            # Good for satellite spatial patterns
    'ffn_activation': 'gelu',            # Works well for vision tasks
    
    # Cuboid settings optimized for short temporal sequences.
    # One entry per encoder level, so both lists have the same length as enc_depth.
    # NOTE(review): the temporal cuboid size (2) does not evenly divide the 3 input
    # timesteps; presumably the model pads the sequence internally - confirm.
    'enc_cuboid_size': [(2, 4, 4), (2, 4, 4)],     # Small temporal cuboids for 3 timesteps
    'enc_cuboid_strategy': [('l', 'l', 'l'), ('d', 'd', 'd')],
    
    # Cross-attention settings for decoder (one entry per decoder level)
    'dec_cross_cuboid_hw': [(4, 4), (4, 4)],
    'dec_cross_n_temporal': [1, 2],      # Use 1-2 temporal frames for cross-attention
}

# Create model
model = CuboidTransformerModel(**landsat_config)
print(f"✓ Landsat model created! Parameters: {sum(p.numel() for p in model.parameters()):,}")

# Test with dummy Landsat data.
# The shape is taken from the config so the smoke test cannot drift away from
# the shape the model was actually built for.
batch_size = 4  # You can use larger batches with 40GB VRAM
dummy_landsat = torch.randn(batch_size, *landsat_config['input_shape'])
print(f"Input shape: {dummy_landsat.shape}")

# Forward pass test.
# torch.no_grad() only turns off gradient tracking; the dropout layers configured
# above stay active until the module is switched to eval mode, so do that first
# to make the smoke test deterministic and representative of inference.
model.eval()
with torch.no_grad():
    output = model(dummy_landsat)
    print(f"Output shape: {output.shape}")
    print("✓ Forward pass successful!")

# Memory usage estimate
def estimate_memory_usage(model, input_shape, batch_size=1):
    """Print a rough, back-of-the-envelope memory estimate for ``model``.

    The numbers are heuristics, not measurements: activations are assumed to
    cost twice the parameter memory, and the total is three times the parameter
    memory plus the input batch.

    No tensor is allocated and no random numbers are drawn; the input size is
    computed arithmetically from ``input_shape`` and ``batch_size``.

    Args:
        model: A ``torch.nn.Module``. It is switched to eval mode as a side
            effect (kept from the original implementation).
        input_shape: Per-sample input shape, without the batch dimension.
        batch_size: Number of samples per batch.

    Returns:
        None. The estimate is written to stdout.
    """
    import math

    model.eval()

    bytes_per_gb = 1e9
    input_bytes_per_element = 4  # the input batch is assumed to be float32

    # element_size() keeps the figure right for half/double precision weights.
    param_memory = sum(p.numel() * p.element_size() for p in model.parameters()) / bytes_per_gb
    input_elements = batch_size * math.prod(input_shape)
    input_memory = input_elements * input_bytes_per_element / bytes_per_gb

    print("Estimated memory usage:")
    print(f"  Parameters: {param_memory:.2f} GB")
    print(f"  Input (batch={batch_size}): {input_memory:.2f} GB")
    print(f"  Activation estimate: ~{param_memory * 2:.2f} GB")
    print(f"  Total estimate: ~{param_memory * 3 + input_memory:.2f} GB")

# Print the rough footprint for a batch of 8 samples.
# Note: the forward-pass smoke test earlier in this cell used batch_size = 4.
estimate_memory_usage(model, (3, 128, 128, 9), batch_size=8)

In [None]:
import os
import torch
from model import LandsatLSTPredictor
from dataset import LandsatDataModule
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

def train_landsat_model(
    dataset_root: str = "./Data/Dataset",
    batch_size: int = 4,
    max_epochs: int = 100,
    learning_rate: float = 1e-4,
    num_workers: int = 4,
    gpus: int = 1,
    precision: str = "32",  # Start with 32-bit precision for stability
    accumulate_grad_batches: int = 1,
    val_check_interval: float = 1.0,
    limit_train_batches: float = 1.0,
    limit_val_batches: float = 1.0,
    experiment_name: str = "landsat_lst_prediction",
    checkpoint_dir: str = "./checkpoints",
    log_dir: str = "./logs"
):
    """
    Complete training pipeline for Landsat LST prediction.

    Builds the data module, the model, the callbacks (checkpointing on
    ``val_loss``, early stopping, LR monitoring), a TensorBoard logger and a
    Lightning ``Trainer``; then runs ``fit`` followed by ``test`` on the best
    checkpoint.

    Args:
        dataset_root: Path to preprocessed dataset
        batch_size: Training batch size
        max_epochs: Maximum training epochs
        learning_rate: Initial learning rate
        num_workers: Number of data loading workers
        gpus: Number of GPUs to use; 0 (or less) selects the CPU accelerator
        precision: Training precision ('32', '16', or '16-mixed'), forwarded to
            the Trainer
        accumulate_grad_batches: Gradient accumulation steps
        val_check_interval: Validation frequency
        limit_train_batches: Fraction of training data to use (for debugging)
        limit_val_batches: Fraction of validation data to use (for debugging)
        experiment_name: Name for logging
        checkpoint_dir: Directory to save checkpoints
        log_dir: Directory for logs

    Returns:
        Tuple ``(trainer, model, data_module)``. It is also returned when the
        run is interrupted with Ctrl-C.

    Raises:
        Exception: Any error raised by ``trainer.fit`` / ``trainer.test`` is
            printed and then re-raised unchanged.
    """

    # Create directories
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Initialize data module
    data_module = LandsatDataModule(
        dataset_root=dataset_root,
        batch_size=batch_size,
        num_workers=num_workers,
        sequence_length=3
    )

    # Initialize model
    model = LandsatLSTPredictor(
        learning_rate=learning_rate,
        weight_decay=1e-5,
        warmup_steps=1000,
        max_epochs=max_epochs
    )

    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        # Doubled braces survive the f-string so Lightning can fill them in later.
        filename=f'{experiment_name}-{{epoch:02d}}-{{val_loss:.3f}}',
        save_top_k=3,
        monitor='val_loss',
        mode='min',
        save_last=True
    )

    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=15,
        mode='min',
        verbose=True
    )

    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # Logger
    logger = TensorBoardLogger(
        save_dir=log_dir,
        name=experiment_name,
        version=None
    )

    # Trainer
    use_gpu = gpus > 0
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='gpu' if use_gpu else 'cpu',
        # An explicit device count is used on CPU as well: recent Lightning
        # releases presumably reject devices=None for the CPU accelerator
        # (confirm against the installed version), while 1 is accepted everywhere.
        devices=gpus if use_gpu else 1,
        # Forward the caller's choice; previously the argument was accepted and
        # printed but never reached the Trainer.
        precision=precision,
        accumulate_grad_batches=accumulate_grad_batches,
        val_check_interval=val_check_interval,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        callbacks=[checkpoint_callback, early_stopping, lr_monitor],
        logger=logger,
        log_every_n_steps=50,
        enable_progress_bar=True,
        enable_model_summary=True,
        # Additional stability settings
        deterministic=False,  # Set to True for reproducibility, False for speed
        benchmark=True,  # Optimize for consistent input sizes
    )

    # Print run summary
    banner = '=' * 60
    print(f"\n{banner}")
    print("LANDSAT LST PREDICTION TRAINING")
    print(banner)
    print(f"Dataset: {dataset_root}")
    print(f"Batch size: {batch_size}")
    print(f"Max epochs: {max_epochs}")
    print(f"Learning rate: {learning_rate}")
    print(f"Precision: {precision}")
    print(f"Devices: {gpus} GPU(s)" if use_gpu else "Devices: CPU")
    print(f"Experiment: {experiment_name}")
    print(f"{banner}\n")

    # Train the model
    try:
        trainer.fit(model, data_module)

        # Test the model on the best checkpoint found during fit
        print("\nRunning final test...")
        trainer.test(model, data_module, ckpt_path='best')

        print(f"\nTraining completed! Best model saved to: {checkpoint_callback.best_model_path}")

    except KeyboardInterrupt:
        # Deliberately not re-raised: the partially trained objects are still returned.
        print("\nTraining interrupted by user")

    except Exception as e:
        print(f"\nTraining failed with error: {e}")
        raise

    return trainer, model, data_module


# Quick test/debug function with conservative settings
def debug_training(dataset_root: str = "./Data/Dataset"):
    """Quick debug run: 3 epochs on one GPU using half of the train/val batches.

    Args:
        dataset_root: Path to the preprocessed dataset, forwarded to
            ``train_landsat_model``.

    Returns:
        None. The objects returned by ``train_landsat_model`` are discarded.
    """
    print("Running debug training...")

    train_landsat_model(
        dataset_root=dataset_root,
        batch_size=2,
        max_epochs=3,
        learning_rate=1e-3,  # NOTE: 10x the train_landsat_model default of 1e-4
        num_workers=0,  # Disable multiprocessing for debugging
        gpus=1,
        precision="32",  # Use 32-bit for stability
        limit_train_batches=0.5,  # Use only 50% of the training batches
        limit_val_batches=0.5,  # Use only 50% of the validation batches
        experiment_name="debug_landsat",
        val_check_interval=0.5,  # Check validation twice per epoch
    )

    print("Debug training completed!")


# Even more minimal debug function
def minimal_debug_training(dataset_root: str = "./Data/Dataset"):
    """Minimal debug run: one CPU epoch, batch size 1, half of the train/val batches.

    Args:
        dataset_root: Path to the preprocessed dataset, forwarded to
            ``train_landsat_model``.

    Returns:
        None. The objects returned by ``train_landsat_model`` are discarded.
    """
    print("Running minimal debug training...")

    train_landsat_model(
        dataset_root=dataset_root,
        batch_size=1,  # Smallest possible batch
        max_epochs=1,  # Just one epoch
        learning_rate=1e-3,
        num_workers=0,  # No multiprocessing
        gpus=0,  # Use CPU to avoid GPU issues
        precision="32",  # Standard precision
        limit_train_batches=0.5,  # Use only 50% of the training batches
        limit_val_batches=0.5,  # Use only 50% of the validation batches
        experiment_name="minimal_debug",
        val_check_interval=1.0,  # Validate once, at the end of the epoch
    )

    print("Minimal debug training completed!")


if __name__ == "__main__":
    # Escalation ladder for debugging:
    #   step 1 - the standard debug run (1 GPU);
    #   step 2 - only if step 1 raises, the minimal CPU-only run.
    try:
        debug_training()
    except Exception as gpu_run_error:
        print(f"Debug training failed: {gpu_run_error}")
        print("Trying minimal debug on CPU...")

        try:
            minimal_debug_training()
        except Exception as cpu_run_error:
            # Both attempts failed; report and carry on without re-raising.
            print(f"Minimal debug also failed: {cpu_run_error}")
            print("Please check your dataset path and dependencies.")

    # Full training run (disabled). Uncomment to launch:
    # train_landsat_model(
    #     dataset_root="./Data/Dataset",
    #     batch_size=8,
    #     max_epochs=50,
    #     learning_rate=2e-4,
    #     gpus=1,
    #     precision="16-mixed",  # Use mixed precision for full training
    #     experiment_name="landsat_experiment_1"
    # )

In [None]:
import torch
import os
from dataset import LandsatDataModule
import numpy as np
from pathlib import Path

# Force CPU mode before any operations in this cell.
# NOTE(review): torch is already imported at this point and earlier notebook cells
# may have touched the GPU; presumably CUDA_VISIBLE_DEVICES only takes effect when
# it is set before CUDA is first initialised in the process - confirm, and restart
# the kernel before running this cell if the GPU must really be hidden.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
torch.set_default_device('cpu')

def force_cpu_mode():
    """Steer subsequent work onto the CPU, reporting each step on stdout.

    In order: empties the CUDA caching allocator when CUDA is available, sets
    the ``CUDA_VISIBLE_DEVICES`` environment variable to an empty string, and
    makes ``cpu`` the default torch device.
    """
    print("🔧 FORCING CPU MODE")

    gpu_present = torch.cuda.is_available()
    if gpu_present:
        # Releases cached blocks only; tensors that are still alive keep their memory.
        torch.cuda.empty_cache()
        print("  ✅ GPU cache cleared")

    # Blank out the device list for this process.
    os.environ.update(CUDA_VISIBLE_DEVICES='')
    print("  ✅ CUDA_VISIBLE_DEVICES set to empty")

    # New tensors created without an explicit device now land on the CPU.
    cpu_device = torch.device('cpu')
    torch.set_default_device(cpu_device)
    print("  ✅ Default device set to CPU")

def check_environment():
    """Print versions and device settings relevant to debugging.

    Reports the PyTorch and PyTorch Lightning versions, CUDA availability and
    device count, the default torch device, the ``CUDA_VISIBLE_DEVICES``
    environment variable and, when CUDA is available, the currently allocated
    CUDA memory.

    A missing PyTorch Lightning installation is reported in the output instead
    of aborting the check, so the remaining diagnostics are still printed.
    """
    print("🔍 ENVIRONMENT CHECK")
    print("=" * 60)

    # Imported lazily so that an absent/broken install becomes a finding, not a crash.
    try:
        import pytorch_lightning as pl
        lightning_version = pl.__version__
    except ImportError as exc:
        lightning_version = f"not installed ({exc})"

    cuda_available = torch.cuda.is_available()

    # torch.get_default_device is missing from older torch builds; fall back to
    # asking a freshly created tensor where it was placed.
    if hasattr(torch, "get_default_device"):
        default_device = torch.get_default_device()
    else:
        default_device = torch.empty(0).device

    print(f"PyTorch version: {torch.__version__}")
    print(f"PyTorch Lightning version: {lightning_version}")
    print(f"CUDA available: {cuda_available}")
    print(f"CUDA device count: {torch.cuda.device_count() if cuda_available else 0}")
    print(f"Default device: {default_device}")
    print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")

    # Memory currently held by live CUDA tensors in this process
    if cuda_available:
        print(f"Current CUDA memory usage: {torch.cuda.memory_allocated()} bytes")

    print("=" * 60)

def check_dataset_files(dataset_root: str = "./Data/Dataset"):
    """Check that the expected dataset directories exist and print a short listing.

    Expected layout under ``dataset_root``::

        Cities_Preprocessed/<city>/<scene>/*.tif
        DEM_2014_Preprocessed/...

    Only sub-directories are counted as cities and scenes, so stray files do not
    inflate the counts. Entries are visited in sorted order so repeated runs
    print the same sample (first three cities, first scene of each, first three
    ``.tif`` files of that scene).

    Args:
        dataset_root: Root directory of the preprocessed dataset.

    Returns:
        False when the dataset root or the ``Cities_Preprocessed`` directory is
        missing, True otherwise. A missing ``DEM_2014_Preprocessed`` directory
        is reported but does not change the result.
    """
    print(f"\n🔍 CHECKING DATASET FILES")
    print("=" * 60)

    dataset_path = Path(dataset_root)
    if not dataset_path.exists():
        print(f"❌ Dataset root does not exist: {dataset_path}")
        return False

    print(f"✅ Dataset root exists: {dataset_path}")

    # Check for expected subdirectories
    cities_dir = dataset_path / "Cities_Preprocessed"
    dem_dir = dataset_path / "DEM_2014_Preprocessed"

    # is_dir() also rejects a plain file that happens to carry the expected name.
    if not cities_dir.is_dir():
        print(f"❌ Cities directory not found: {cities_dir}")
        return False

    cities = sorted(entry for entry in cities_dir.iterdir() if entry.is_dir())
    print(f"✅ Cities directory found with {len(cities)} cities")

    # Sample the first 3 cities
    for city in cities[:3]:
        scenes = sorted(entry for entry in city.iterdir() if entry.is_dir())
        print(f"  City {city.name}: {len(scenes)} scenes")

        # Sample the first scene of the city
        if scenes:
            scene = scenes[0]
            files = sorted(scene.glob("*.tif"))
            print(f"    Scene {scene.name}: {len(files)} .tif files")
            for file in files[:3]:  # Show first 3 files
                print(f"      {file.name}")

    if dem_dir.is_dir():
        # Every entry is counted here; whether DEM data is stored as one file or
        # one directory per city is not visible from this function.
        dem_cities = list(dem_dir.iterdir())
        print(f"✅ DEM directory found with {len(dem_cities)} cities")
    else:
        # Reported only: a missing DEM directory does not fail the check.
        print(f"❌ DEM directory not found: {dem_dir}")

    return True

def minimal_tensor_test():
    """Run a handful of basic tensor operations and report pass/fail for each.

    A random tensor of shape (1, 3, 128, 128, 9) is created (test 1) and then
    sliced, fancy-indexed, reshaped and permuted (tests 2-7). Tests 2-7 are
    guarded individually so one failure does not hide the others; a failure in
    test 1 is reported once for the whole suite. Nothing is returned or raised.
    """
    print("\n🔍 MINIMAL TENSOR TESTS")
    print("=" * 60)

    # (test number, label, operation) for every individually guarded check.
    guarded_checks = (
        (2, "Spatial indexing", lambda t: t[:, :, :64, :64, :]),          # spatial crop
        (3, "Channel selection", lambda t: t[:, :, :, :, [0, 1, 2]]),     # first 3 channels
        (4, "Temporal indexing", lambda t: t[:, 1:, :, :, :]),            # skip first timestep
        (5, "Reshaping", lambda t: t.reshape(1, -1)),
        (6, "Permutation", lambda t: t.permute(0, 4, 1, 2, 3)),           # last axis moved to position 1
        (7, "Advanced indexing", lambda t: t[:, :, :, :, torch.tensor([0, 2, 4])]),
    )

    try:
        sample = torch.randn(1, 3, 128, 128, 9)
        print(f"✅ Test 1 passed: Basic tensor creation {sample.shape}")

        for number, label, operation in guarded_checks:
            try:
                outcome = operation(sample)
                print(f"✅ Test {number} passed: {label} {outcome.shape}")
            except Exception as err:
                print(f"❌ Test {number} failed: {label} - {err}")

    except Exception as err:
        print(f"❌ Minimal tensor tests failed: {err}")

def debug_data_shapes(dataset_root: str = "./Data/Dataset"):
    """Inspect one training batch, one validation batch, and smoke-test a tiny model.

    Steps, each reported on stdout:
      1. Build ``LandsatDataModule`` (batch size 1, no workers, sequence length 3)
         and call ``setup("fit")``.
      2. Print shape, dtype, device, value range, NaN/Inf flags and per-channel
         statistics for the first training batch. The indexing used here expects
         a 5-D input whose last axis is the channel axis.
      3. Print the input/target shapes of the first validation batch.
      4. Build a deliberately small ``CuboidTransformerModel`` whose input shape
         is taken from the loaded batch and run a single forward pass on CPU.

    Failures are printed (with a traceback where available) instead of being
    raised. An empty training dataloader is reported explicitly.

    Args:
        dataset_root: Directory passed to ``LandsatDataModule``.
    """
    import traceback

    print("🔍 DEBUGGING DATA SHAPES AND TENSOR DIMENSIONS")
    print("=" * 60)

    try:
        # Release cached CUDA blocks, if any, before the CPU-side checks.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Initialize data module with minimal settings
        data_module = LandsatDataModule(
            dataset_root=dataset_root,
            batch_size=1,
            num_workers=0,  # Single threaded
            sequence_length=3
        )

        print("✅ Data module created successfully")

        # Setup data module
        data_module.setup("fit")
        print("✅ Data module setup completed")

        # Get train dataloader
        train_loader = data_module.train_dataloader()
        print(f"✅ Train dataloader created with {len(train_loader)} batches")

        # Try to load one batch
        print("\n🔍 EXAMINING FIRST BATCH:")
        inputs = None
        targets = None
        batches_examined = 0

        for batch_idx, (batch_inputs, batch_targets) in enumerate(train_loader):
            inputs = batch_inputs
            targets = batch_targets

            print(f"\nBatch {batch_idx}:")
            print(f"  Input tensor shape: {inputs.shape}")
            print(f"  Input tensor dtype: {inputs.dtype}")
            print(f"  Input tensor device: {inputs.device}")
            print(f"  Input tensor min/max: {inputs.min():.3f} / {inputs.max():.3f}")
            print(f"  Input contains NaN: {torch.isnan(inputs).any()}")
            print(f"  Input contains Inf: {torch.isinf(inputs).any()}")

            print(f"\n  Target tensor shape: {targets.shape}")
            print(f"  Target tensor dtype: {targets.dtype}")
            print(f"  Target tensor device: {targets.device}")
            print(f"  Target tensor min/max: {targets.min():.3f} / {targets.max():.3f}")
            print(f"  Target contains NaN: {torch.isnan(targets).any()}")
            print(f"  Target contains Inf: {torch.isinf(targets).any()}")

            # Check individual dimensions
            print(f"\n  Detailed input shape analysis:")
            print(f"    Batch size: {inputs.shape[0]}")
            print(f"    Time steps: {inputs.shape[1]}")
            print(f"    Height: {inputs.shape[2]}")
            print(f"    Width: {inputs.shape[3]}")
            print(f"    Channels: {inputs.shape[4]}")

            # Check if any dimension is 0 or unexpected
            for i, dim in enumerate(inputs.shape):
                if dim == 0:
                    print(f"    ⚠️  WARNING: Dimension {i} is 0!")
                if dim > 1000:
                    print(f"    ⚠️  WARNING: Dimension {i} is very large: {dim}")

            # Check data range for each channel
            print(f"\n  Channel-wise statistics (input):")
            for c in range(inputs.shape[4]):
                channel_data = inputs[0, :, :, :, c]
                print(f"    Channel {c}: min={channel_data.min():.3f}, max={channel_data.max():.3f}, mean={channel_data.mean():.3f}")

            # Only examine first batch to avoid overwhelming output
            batches_examined = batch_idx + 1
            break

        # An empty loader never binds batch_idx, so report from the counter instead.
        if batches_examined:
            print(f"✅ Successfully examined {batches_examined} batch(es)")
        else:
            print("⚠️  WARNING: Train dataloader yielded no batches")

        # Test validation data too
        print(f"\n🔍 EXAMINING VALIDATION DATA:")
        try:
            val_loader = data_module.val_dataloader()
            print(f"✅ Validation dataloader created with {len(val_loader)} batches")

            for batch_idx, (val_inputs, val_targets) in enumerate(val_loader):
                print(f"\nValidation Batch {batch_idx}:")
                print(f"  Input shape: {val_inputs.shape}")
                print(f"  Target shape: {val_targets.shape}")

                # Just check first validation batch
                break

        except Exception as e:
            print(f"❌ Validation data loading failed: {e}")

        # Now test model components
        print("\n🔍 TESTING MODEL COMPONENTS:")
        try:
            # Imported here so a missing earthformer install is reported below.
            from earthformer.cuboid_transformer.cuboid_transformer import CuboidTransformerModel

            # Get the actual data shape from our loaded batch
            if inputs is not None:
                sample_input = inputs[:1].cpu()  # Take first sample, ensure CPU

                print(f"  Using sample input shape: {sample_input.shape}")

                # Minimal config whose input shape matches the loaded data.
                # NOTE(review): target_shape is hard-coded rather than derived from
                # the loaded targets - confirm it matches the dataset's target shape.
                minimal_config = {
                    'input_shape': tuple(sample_input.shape[1:]),  # Remove batch dimension
                    'target_shape': (3, 128, 128, 1),
                    'base_units': 32,  # Much smaller
                    'num_heads': 4,
                    'enc_depth': [1],  # Single level
                    'dec_depth': [1],
                    'attn_drop': 0.0,
                    'proj_drop': 0.0,
                    'ffn_drop': 0.0,
                    'num_global_vectors': 4,
                    'use_dec_self_global': False,  # Disable to simplify
                    'use_dec_cross_global': False,
                    'pos_embed_type': 't+hw',
                    'use_relative_pos': False,  # Disable to simplify
                    'ffn_activation': 'gelu',
                    'enc_cuboid_size': [(1, 4, 4)],  # Very conservative
                    'enc_cuboid_strategy': [('l', 'l', 'l')],
                    'dec_cross_cuboid_hw': [(4, 4)],
                    'dec_cross_n_temporal': [1],
                }

                print(f"  Creating minimal model with config:")
                for key, value in minimal_config.items():
                    print(f"    {key}: {value}")

                model = CuboidTransformerModel(**minimal_config)
                model.eval()
                model = model.cpu()  # Force CPU

                print(f"  ✅ Minimal model created with {sum(p.numel() for p in model.parameters()):,} parameters")

                # Test forward pass with actual data
                print(f"  Testing forward pass with real data...")

                with torch.no_grad():
                    try:
                        # Both the sample and the model were moved to CPU above.
                        print(f"    Input tensor device: {sample_input.device}")
                        print(f"    Input tensor shape: {sample_input.shape}")
                        print(f"    Model device: {next(model.parameters()).device}")

                        outputs = model(sample_input)
                        print(f"  ✅ Forward pass successful! Output shape: {outputs.shape}")

                        # Check output validity
                        print(f"    Output device: {outputs.device}")
                        print(f"    Output min/max: {outputs.min():.3f} / {outputs.max():.3f}")
                        print(f"    Output contains NaN: {torch.isnan(outputs).any()}")
                        print(f"    Output contains Inf: {torch.isinf(outputs).any()}")

                    except Exception as e:
                        print(f"  ❌ Forward pass failed: {e}")
                        print(f"     Error type: {type(e).__name__}")

                        # More detailed debugging
                        print(f"  🔍 Debugging model internals...")
                        print("  Full traceback:")
                        traceback.print_exc()
            else:
                print("  ❌ No input data available for model testing")

        except ImportError as e:
            print(f"  ❌ Could not import CuboidTransformerModel: {e}")
            print(f"     Check that earthformer is properly installed")
        except Exception as e:
            print(f"  ❌ Model creation failed: {e}")
            traceback.print_exc()

    except Exception as e:
        # Anything not handled by a narrower block above ends up here.
        print(f"❌ Data loading failed: {e}")
        print(f"Error type: {type(e).__name__}")
        traceback.print_exc()

def main(dataset_root: str = "./Data/Dataset"):
    """Run the whole diagnostic sequence and print troubleshooting hints.

    Order: force CPU mode, print the environment, verify the dataset layout,
    run the basic tensor checks, then inspect real batches and smoke-test a
    small model. The sequence stops early when the dataset layout check fails.

    Args:
        dataset_root: Dataset directory forwarded to ``check_dataset_files`` and
            ``debug_data_shapes``. Defaults to the path both of them use when
            called without arguments.
    """
    print("🚀 STARTING COMPREHENSIVE DEBUG")
    print("=" * 80)

    # Force CPU mode first
    force_cpu_mode()

    # Check environment
    check_environment()

    # Check dataset files; nothing below is meaningful without them
    if not check_dataset_files(dataset_root):
        print("❌ Dataset check failed - cannot proceed with data debugging")
        return

    # Run minimal tensor tests
    minimal_tensor_test()

    # Debug data shapes
    debug_data_shapes(dataset_root)

    print("\n" + "=" * 80)
    print("🏁 DEBUGGING COMPLETE")
    print("=" * 80)
    print("If all tests pass, the issue might be in the model configuration.")
    print("If tests fail, the issue is likely in data preprocessing or tensor operations.")
    print("Next steps:")
    print("1. If data loading works but model fails -> check model config")
    print("2. If data loading fails -> check dataset preprocessing")
    print("3. If tensor tests fail -> check PyTorch installation")

# Run the full diagnostic sequence when this cell/script is executed directly.
if __name__ == "__main__":
    main()