In [None]:
from earthformer.cuboid_transformer.cuboid_transformer import CuboidTransformerModel
import torch

# Optimized config for Landsat 3-timestep forecasting
landsat_config = {
    'input_shape': (3, 128, 128, 9),    # 3 input timesteps, 128x128, 9 Landsat bands
    'target_shape': (3, 128, 128, 1),   # 3 output timesteps
    
    # Small model for prototyping
    'base_units': 96,                    # Small but efficient
    'num_heads': 6,                      # Divisible by base_units
    'enc_depth': [2, 2],                 # 2-level hierarchy (sufficient for short sequences)
    'dec_depth': [1, 1],                 # Matching decoder depth
    
    # Dropout for better generalization during prototyping
    'attn_drop': 0.1,
    'proj_drop': 0.1,
    'ffn_drop': 0.1,
    
    # Global vectors for capturing Landsat scene patterns
    'num_global_vectors': 8,
    'use_dec_self_global': True,
    'use_dec_cross_global': True,
    
    # Optimized for satellite imagery
    'pos_embed_type': 't+hw',            # Separate temporal and spatial embeddings
    'use_relative_pos': True,            # Good for satellite spatial patterns
    'ffn_activation': 'gelu',            # Works well for vision tasks
    
    # Cuboid settings optimized for short temporal sequences
    'enc_cuboid_size': [(2, 4, 4), (2, 4, 4)],     # Small temporal cuboids for 3 timesteps
    'enc_cuboid_strategy': [('l', 'l', 'l'), ('d', 'd', 'd')],
    
    # Cross-attention settings for decoder
    'dec_cross_cuboid_hw': [(4, 4), (4, 4)],
    'dec_cross_n_temporal': [1, 2],      # Use 1-2 temporal frames for cross-attention
}

# Create model
model = CuboidTransformerModel(**landsat_config)
print(f"✓ Landsat model created! Parameters: {sum(p.numel() for p in model.parameters()):,}")

# Test with dummy Landsat data
batch_size = 4  # You can use larger batches with 40GB VRAM
dummy_landsat = torch.randn(batch_size, 3, 128, 128, 9)
print(f"Input shape: {dummy_landsat.shape}")

# Forward pass test
with torch.no_grad():
    output = model(dummy_landsat)
    print(f"Output shape: {output.shape}")
    print("✓ Forward pass successful!")

# Memory usage estimate
def estimate_memory_usage(model, input_shape, batch_size=1):
    model.eval()
    dummy_input = torch.randn(batch_size, *input_shape)
    
    # Rough memory estimate
    param_memory = sum(p.numel() * 4 for p in model.parameters()) / 1e9  # GB
    input_memory = dummy_input.numel() * 4 / 1e9  # GB
    
    print(f"Estimated memory usage:")
    print(f"  Parameters: {param_memory:.2f} GB")
    print(f"  Input (batch={batch_size}): {input_memory:.2f} GB")
    print(f"  Activation estimate: ~{param_memory * 2:.2f} GB")
    print(f"  Total estimate: ~{param_memory * 3 + input_memory:.2f} GB")

estimate_memory_usage(model, (3, 128, 128, 9), batch_size=8)

In [3]:
import os
import torch
from model import LandsatLSTPredictor
from dataset import LandsatDataModule
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

def train_landsat_model(
    dataset_root: str = "./Data/Dataset",
    batch_size: int = 4,
    max_epochs: int = 100,
    learning_rate: float = 1e-4,
    num_workers: int = 4,
    gpus: int = 1,
    precision: str = "16-mixed",  # Use mixed precision for memory efficiency
    accumulate_grad_batches: int = 1,
    val_check_interval: float = 1.0,
    limit_train_batches: float = 1.0,
    limit_val_batches: float = 1.0,
    experiment_name: str = "landsat_lst_prediction",
    checkpoint_dir: str = "./checkpoints",
    log_dir: str = "./logs"
):
    """
    Complete training pipeline for Landsat LST prediction
    
    Args:
        dataset_root: Path to preprocessed dataset
        batch_size: Training batch size
        max_epochs: Maximum training epochs
        learning_rate: Initial learning rate
        num_workers: Number of data loading workers
        gpus: Number of GPUs to use
        precision: Training precision ('32' or '16-mixed')
        accumulate_grad_batches: Gradient accumulation steps
        val_check_interval: Validation frequency
        limit_train_batches: Fraction of training data to use (for debugging)
        limit_val_batches: Fraction of validation data to use (for debugging)
        experiment_name: Name for logging
        checkpoint_dir: Directory to save checkpoints
        log_dir: Directory for logs
    """
    
    # Create directories
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    
    # Initialize data module
    data_module = LandsatDataModule(
        dataset_root=dataset_root,
        batch_size=batch_size,
        num_workers=num_workers,
        sequence_length=3
    )
    
    # Initialize model
    model = LandsatLSTPredictor(
        learning_rate=learning_rate,
        weight_decay=1e-5,
        warmup_steps=1000,
        max_epochs=max_epochs
    )
    
    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=f'{experiment_name}-{{epoch:02d}}-{{val_loss:.3f}}',
        save_top_k=3,
        monitor='val_loss',
        mode='min',
        save_last=True
    )
    
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=15,
        mode='min',
        verbose=True
    )
    
    lr_monitor = LearningRateMonitor(logging_interval='step')
    
    # Logger
    logger = TensorBoardLogger(
        save_dir=log_dir,
        name=experiment_name,
        version=None
    )
    
    # Trainer
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='gpu' if gpus > 0 else 'cpu',
        devices=gpus if gpus > 0 else None,
        precision=precision,
        accumulate_grad_batches=accumulate_grad_batches,
        val_check_interval=val_check_interval,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        callbacks=[checkpoint_callback, early_stopping, lr_monitor],
        logger=logger,
        log_every_n_steps=50,
        enable_progress_bar=True,
        enable_model_summary=True
    )
    
    # Print model summary
    print(f"\n{'='*60}")
    print(f"LANDSAT LST PREDICTION TRAINING")
    print(f"{'='*60}")
    print(f"Dataset: {dataset_root}")
    print(f"Batch size: {batch_size}")
    print(f"Max epochs: {max_epochs}")
    print(f"Learning rate: {learning_rate}")
    print(f"Precision: {precision}")
    print(f"Devices: {gpus} GPU(s)" if gpus > 0 else "CPU")
    print(f"Experiment: {experiment_name}")
    print(f"{'='*60}\n")
    
    # Train the model
    try:
        trainer.fit(model, data_module)
        
        # Test the model
        print("\nRunning final test...")
        trainer.test(model, data_module, ckpt_path='best')
        
        print(f"\nTraining completed! Best model saved to: {checkpoint_callback.best_model_path}")
        
    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
        
    except Exception as e:
        print(f"\nTraining failed with error: {e}")
        raise
    
    return trainer, model, data_module


# Quick test/debug function
def debug_training(dataset_root: str = "./Data/Dataset"):
    """Quick debug run with small dataset fraction"""
    print("Running debug training...")
    
    trainer, model, data_module = train_landsat_model(
        dataset_root=dataset_root,
        batch_size=2,
        max_epochs=3,
        learning_rate=1e-3,
        num_workers=2,
        gpus=1,
        limit_train_batches=0.1,  # Use only 10% of data
        limit_val_batches=0.1,
        experiment_name="debug_landsat",
        val_check_interval=0.5,  # Check validation twice per epoch
    )
    
    print("Debug training completed!")


if __name__ == "__main__":
    # For full training, uncomment this:
    # train_landsat_model()
    
    # For debugging, use this:
    debug_training()
    
    # For custom training with specific parameters:
    # train_landsat_model(
    #     dataset_root="./Data/Dataset",
    #     batch_size=8,
    #     max_epochs=50,
    #     learning_rate=2e-4,
    #     gpus=1,
    #     experiment_name="landsat_experiment_1"
    # )

Running debug training...
Model initialized with 9,036,109 parameters


MisconfigurationException: Precision '16-mixed' is invalid. Allowed precision values: ('16', '32', '64', 'bf16', 'mixed')