In [None]:
from earthformer.cuboid_transformer.cuboid_transformer import CuboidTransformerModel
import torch

# Optimized config for Landsat 3-timestep forecasting
landsat_config = {
    'input_shape': (3, 128, 128, 9),    # 3 input timesteps, 128x128, 9 Landsat bands
    'target_shape': (3, 128, 128, 1),   # 3 output timesteps
    
    # Small model for prototyping
    'base_units': 96,                    # Small but efficient
    'num_heads': 6,                      # Divisible by base_units
    'enc_depth': [2, 2],                 # 2-level hierarchy (sufficient for short sequences)
    'dec_depth': [1, 1],                 # Matching decoder depth
    
    # Dropout for better generalization during prototyping
    'attn_drop': 0.1,
    'proj_drop': 0.1,
    'ffn_drop': 0.1,
    
    # Global vectors for capturing Landsat scene patterns
    'num_global_vectors': 8,
    'use_dec_self_global': True,
    'use_dec_cross_global': True,
    
    # Optimized for satellite imagery
    'pos_embed_type': 't+hw',            # Separate temporal and spatial embeddings
    'use_relative_pos': True,            # Good for satellite spatial patterns
    'ffn_activation': 'gelu',            # Works well for vision tasks
    
    # Cuboid settings optimized for short temporal sequences
    'enc_cuboid_size': [(2, 4, 4), (2, 4, 4)],     # Small temporal cuboids for 3 timesteps
    'enc_cuboid_strategy': [('l', 'l', 'l'), ('d', 'd', 'd')],
    
    # Cross-attention settings for decoder
    'dec_cross_cuboid_hw': [(4, 4), (4, 4)],
    'dec_cross_n_temporal': [1, 2],      # Use 1-2 temporal frames for cross-attention
}

# Create model
model = CuboidTransformerModel(**landsat_config)
print(f"✓ Landsat model created! Parameters: {sum(p.numel() for p in model.parameters()):,}")

# Test with dummy Landsat data
batch_size = 4  # You can use larger batches with 40GB VRAM
dummy_landsat = torch.randn(batch_size, 3, 128, 128, 9)
print(f"Input shape: {dummy_landsat.shape}")

# Forward pass test
with torch.no_grad():
    output = model(dummy_landsat)
    print(f"Output shape: {output.shape}")
    print("✓ Forward pass successful!")

# Memory usage estimate
def estimate_memory_usage(model, input_shape, batch_size=1):
    model.eval()
    dummy_input = torch.randn(batch_size, *input_shape)
    
    # Rough memory estimate
    param_memory = sum(p.numel() * 4 for p in model.parameters()) / 1e9  # GB
    input_memory = dummy_input.numel() * 4 / 1e9  # GB
    
    print(f"Estimated memory usage:")
    print(f"  Parameters: {param_memory:.2f} GB")
    print(f"  Input (batch={batch_size}): {input_memory:.2f} GB")
    print(f"  Activation estimate: ~{param_memory * 2:.2f} GB")
    print(f"  Total estimate: ~{param_memory * 3 + input_memory:.2f} GB")

estimate_memory_usage(model, (3, 128, 128, 9), batch_size=8)

In [4]:
import os
import torch
from model import LandsatLSTPredictor  # Your enhanced model
from dataset import LandsatDataModule
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
import wandb

def train_landsat_model_with_images(
    dataset_root: str = "./Data/Dataset",
    batch_size: int = 4,
    max_epochs: int = 100,
    learning_rate: float = 1e-4,
    num_workers: int = 4,
    gpus: int = 1,
    precision: str = "32",
    accumulate_grad_batches: int = 1,
    val_check_interval: float = 1.0,
    limit_train_batches: float = 1.0,
    limit_val_batches: float = 1.0,
    experiment_name: str = "landsat_lst_prediction_with_viz",
    checkpoint_dir: str = "./checkpoints",
    log_dir: str = "./logs",
    wandb_project: str = "landsat-lst-forecasting",
    wandb_tags: list = None,
    log_images_every_n_epochs: int = 5,
    max_images_to_log: int = 4
):
    """
    Enhanced training pipeline with WandB image logging
    
    Args:
        log_images_every_n_epochs: How often to log images (every N epochs)
        max_images_to_log: Maximum number of samples to visualize per batch
    """
    
    if wandb_tags is None:
        wandb_tags = ["landsat", "lst-prediction", "earthformer", "with-viz"]
    
    # Create directories
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    
    # Verify dataset
    dataset_path = os.path.join(dataset_root)
    cities_tiles = os.path.join(dataset_path, "Cities_Tiles")
    dem_tiles = os.path.join(dataset_path, "DEM_2014_Tiles")
    
    if not os.path.exists(cities_tiles):
        raise FileNotFoundError(f"Cities_Tiles directory not found at {cities_tiles}")
    if not os.path.exists(dem_tiles):
        raise FileNotFoundError(f"DEM_2014_Tiles directory not found at {dem_tiles}")
    
    print(f"✅ Found tiled dataset at {dataset_root}")
    
    # Initialize data module
    data_module = LandsatDataModule(
        dataset_root=dataset_root,
        batch_size=batch_size,
        num_workers=num_workers,
        sequence_length=3
    )
    
    # Test data module
    print("Testing data module setup...")
    try:
        data_module.setup("fit")
        train_loader = data_module.train_dataloader()
        val_loader = data_module.val_dataloader()
        
        print(f"✅ Training batches: {len(train_loader)}")
        print(f"✅ Validation batches: {len(val_loader)}")
        
        # Test one batch
        sample_batch = next(iter(train_loader))
        inputs, targets = sample_batch
        print(f"✅ Sample batch - Inputs: {inputs.shape}, Targets: {targets.shape}")
        
    except Exception as e:
        print(f"❌ Data module test failed: {e}")
        raise
    
    # Enhanced config for wandb
    config = {
        "batch_size": batch_size,
        "max_epochs": max_epochs,
        "learning_rate": learning_rate,
        "num_workers": num_workers,
        "precision": precision,
        "accumulate_grad_batches": accumulate_grad_batches,
        "val_check_interval": val_check_interval,
        "limit_train_batches": limit_train_batches,
        "limit_val_batches": limit_val_batches,
        "dataset_root": dataset_root,
        "model_type": "CuboidTransformer",
        "input_shape": [3, 128, 128, 9],
        "target_shape": [3, 128, 128, 1],
        "sequence_length": 3,
        "train_batches": len(train_loader),
        "val_batches": len(val_loader),
        "total_train_samples": len(train_loader) * batch_size,
        "total_val_samples": len(val_loader) * batch_size,
        "log_images_every_n_epochs": log_images_every_n_epochs,
        "max_images_to_log": max_images_to_log,
        "band_names": ['DEM', 'LST', 'Red', 'Green', 'Blue', 'NDVI', 'NDWI', 'NDBI', 'Albedo']
    }
    
    # Initialize Weights & Biases logger
    logger = WandbLogger(
        project=wandb_project,
        name=experiment_name,
        tags=wandb_tags,
        config=config,
        save_dir=log_dir,
        log_model=True,
    )
    
    # Initialize enhanced model with image logging
    model = LandsatLSTPredictor(
        learning_rate=learning_rate,
        weight_decay=1e-5,
        warmup_steps=1000,
        max_epochs=max_epochs,
        log_images_every_n_epochs=log_images_every_n_epochs,
        max_images_to_log=max_images_to_log
    )
    
    # Test model
    print("Testing model with sample data...")
    try:
        model.eval()
        with torch.no_grad():
            test_output = model(inputs)
            print(f"✅ Model test - Output shape: {test_output.shape}")
    except Exception as e:
        print(f"❌ Model test failed: {e}")
        raise
    
    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=f'{experiment_name}-{{epoch:02d}}-{{val_loss:.3f}}',
        save_top_k=3,
        monitor='val_loss',
        mode='min',
        save_last=True,
        verbose=True
    )
    
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=15,
        mode='min',
        verbose=True
    )
    
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    
    # Trainer configuration
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='gpu' if gpus > 0 else 'cpu',
        devices=gpus if gpus > 0 else None,
        accumulate_grad_batches=accumulate_grad_batches,
        val_check_interval=val_check_interval,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        callbacks=[checkpoint_callback, early_stopping, lr_monitor],
        logger=logger,
        log_every_n_steps=50,
        enable_progress_bar=True,
        enable_model_summary=True,
        deterministic=False,
        benchmark=True,
    )
    
    # Print training info
    print(f"\n{'='*70}")
    print(f"LANDSAT LST PREDICTION TRAINING - WITH IMAGE VISUALIZATION")
    print(f"{'='*70}")
    print(f"Dataset: {dataset_root}")
    print(f"Batch size: {batch_size}")
    print(f"Max epochs: {max_epochs}")
    print(f"Learning rate: {learning_rate}")
    print(f"Precision: {precision}")
    print(f"Devices: {gpus} GPU(s)" if gpus > 0 else "CPU")
    print(f"Experiment: {experiment_name}")
    print(f"Image logging: Every {log_images_every_n_epochs} epochs")
    print(f"Max images per log: {max_images_to_log}")
    print(f"Wandb project: {wandb_project}")
    print(f"{'='*70}\n")
    
    # Train the model
    try:
        print("🚀 Starting training with image logging...")
        trainer.fit(model, data_module)
        
        # Test the model
        print("\n🧪 Running final test...")
        try:
            test_results = trainer.test(model, data_module, ckpt_path='best')
            print(f"✅ Test completed: {test_results}")
        except Exception as e:
            print(f"⚠️ Test failed: {e}")
        
        print(f"\n🎉 Training completed successfully!")
        print(f"📁 Best model saved to: {checkpoint_callback.best_model_path}")
        print(f"🔗 View experiment with images at: {logger.experiment.url}")
        
        # Log final artifacts
        if checkpoint_callback.best_model_path:
            wandb.save(checkpoint_callback.best_model_path)
        
    except KeyboardInterrupt:
        print("\n⚠️ Training interrupted by user")
        print(f"📁 Last checkpoint: {checkpoint_callback.last_model_path}")
        
    except Exception as e:
        print(f"\n❌ Training failed: {e}")
        import traceback
        traceback.print_exc()
        
        if 'logger' in locals():
            wandb.log({"error": str(e)})
        raise
    
    finally:
        if 'logger' in locals():
            wandb.finish()
    
    return trainer, model, data_module


def debug_with_images(dataset_root: str = "./Data/Dataset"):
    """Debug run with image logging"""
    print("🔧 Running debug training with image visualization...")
    
    return train_landsat_model_with_images(
        dataset_root=dataset_root,
        batch_size=2,
        max_epochs=3,
        learning_rate=1e-3,
        num_workers=0,
        gpus=1,
        precision="32",
        limit_train_batches=0.1,
        limit_val_batches=0.1,
        experiment_name="debug_with_images",
        val_check_interval=0.5,
        wandb_project="landsat-debug-viz",
        wandb_tags=["debug", "visualization", "landsat"],
        log_images_every_n_epochs=1,  # Log every epoch for debugging
        max_images_to_log=2
    )


def full_training_with_viz(dataset_root: str = "./Data/Dataset"):
    """Full training with image visualization"""
    print("🚀 Starting full training with visualization...")
    
    return train_landsat_model_with_images(
        dataset_root=dataset_root,
        batch_size=8,
        max_epochs=50,
        learning_rate=2e-4,
        num_workers=4,
        gpus=1,
        precision="16-mixed",
        experiment_name="landsat_full_training_with_viz",
        val_check_interval=1.0,
        wandb_project="landsat-lst-forecasting",
        wandb_tags=["full-training", "visualization", "earthformer", "gpu"],
        log_images_every_n_epochs=5,  # Log images every 5 epochs
        max_images_to_log=4
    )


# Additional utility function for creating sample visualizations
def create_sample_visualization(dataset_root: str = "./Data/Dataset"):
    """Create and log a sample visualization without training"""
    from dataset import LandsatDataModule
    import matplotlib.pyplot as plt
    import wandb
    
    print("📊 Creating sample visualization...")
    
    # Initialize wandb for standalone visualization
    wandb.init(
        project="landsat-sample-viz",
        name="sample_tiles_preview",
        tags=["sample", "preview", "no-training"]
    )
    
    # Load sample data
    data_module = LandsatDataModule(
        dataset_root=dataset_root,
        batch_size=4,
        num_workers=0,
        sequence_length=3
    )
    
    data_module.setup("fit")
    train_loader = data_module.train_dataloader()
    
    # Get one batch
    sample_batch = next(iter(train_loader))
    inputs, targets = sample_batch
    
    print(f"Sample data shapes - Inputs: {inputs.shape}, Targets: {targets.shape}")
    
    # Create temporary model for visualization methods
    temp_model = LandsatLSTPredictor()
    
    # Create visualizations
    fig1 = temp_model.create_landsat_visualization(inputs, targets, None, 0, 4)
    fig2 = temp_model.create_temporal_sequence_viz(inputs, targets, None, 0)
    
    # Log to wandb
    wandb.log({
        "sample_landsat_tiles": wandb.Image(fig1, caption="Sample Landsat Tiles"),
        "sample_temporal_sequence": wandb.Image(fig2, caption="Sample Temporal Sequence")
    })
    
    plt.close(fig1)
    plt.close(fig2)
    
    print("✅ Sample visualization created and logged to WandB!")
    print(f"🔗 View at: {wandb.run.url}")
    
    wandb.finish()

import matplotlib.pyplot as plt

# Then add this function to your training script:
def train_with_visualization(
    dataset_root: str = "./Data/Dataset",
    batch_size: int = 4,
    max_epochs: int = 50,
    learning_rate: float = 2e-4,
    num_workers: int = 4,
    gpus: int = 1,
    experiment_name: str = "landsat_with_viz",
    log_images_every_n_epochs: int = 5,
    max_images_to_log: int = 4
):
    """Enhanced training with visualization"""
    
    # Your existing data module setup
    data_module = LandsatDataModule(
        dataset_root=dataset_root,
        batch_size=batch_size,
        num_workers=num_workers,
        sequence_length=3
    )
    
    # Enhanced WandB logger
    logger = WandbLogger(
        project="landsat-lst-forecasting",
        name=experiment_name,
        tags=["earthformer", "visualization", "landsat"],
        save_dir="./logs",
        log_model=True,
    )
    
    # Enhanced model with visualization
    model = LandsatLSTPredictor(
        learning_rate=learning_rate,
        weight_decay=1e-5,
        warmup_steps=1000,
        max_epochs=max_epochs,
        log_images_every_n_epochs=log_images_every_n_epochs,
        max_images_to_log=max_images_to_log
    )
    
    # Your existing callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath="./checkpoints",
        filename=f'{experiment_name}-{{epoch:02d}}-{{val_loss:.3f}}',
        save_top_k=3,
        monitor='val_loss',
        mode='min',
        save_last=True,
        verbose=True
    )
    
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=15,
        mode='min',
        verbose=True
    )
    
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    
    # Your existing trainer
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='gpu' if gpus > 0 else 'cpu',
        devices=gpus if gpus > 0 else None,
        callbacks=[checkpoint_callback, early_stopping, lr_monitor],
        logger=logger,
        log_every_n_steps=50,
        enable_progress_bar=True,
        enable_model_summary=True,
    )
    
    # Train as usual
    trainer.fit(model, data_module)
    
    return trainer, model, data_module

# Quick debug function
def debug_visualization():
    """Quick test with image logging"""
    return train_with_visualization(
        batch_size=2,
        max_epochs=3,
        experiment_name="debug_viz",
        log_images_every_n_epochs=1,  # Log every epoch for debugging
        max_images_to_log=2
    )

def test_visualization_only():
    """Just create sample visualizations to test"""
    from dataset import LandsatDataModule
    
    # Initialize wandb for testing
    wandb.init(project="landsat-viz-test", name="test_viz")
    
    # Load sample data
    data_module = LandsatDataModule(
        dataset_root="./Data/Dataset",
        batch_size=2,
        num_workers=0
    )
    data_module.setup("fit")
    
    # Get one batch
    train_loader = data_module.train_dataloader()
    inputs, targets = next(iter(train_loader))
    
    # Create model just for visualization
    model = LandsatLSTPredictor()
    
    # Create dummy predictions
    with torch.no_grad():
        predictions = model(inputs)
    
    # Test visualization
    fig = model.create_landsat_visualization(inputs, targets, predictions)
    wandb.log({"test_visualization": wandb.Image(fig)})
    
    plt.close(fig)
    wandb.finish()
    print("✅ Test visualization created and logged to WandB!")


if __name__ == "__main__":
    # test_visualization_only()
    # debug_visualization()

    # For debugging with image logging:
    debug_with_images()
    
    # For full training with visualization:
    # full_training_with_viz()
    
    # For just creating sample visualizations:
    # create_sample_visualization()

🔧 Running debug training with image visualization...
✅ Found tiled dataset at ./Data/Dataset
Testing data module setup...
train split: 86 cities, 1351 tile sequences
val split: 19 cities, 212 tile sequences
✅ Training batches: 676
✅ Validation batches: 106
✅ Sample batch - Inputs: torch.Size([2, 3, 128, 128, 9]), Targets: torch.Size([2, 3, 128, 128, 1])


  rank_zero_warn(


Model initialized with 9,036,109 parameters
Testing model with sample data...
✅ Model test - Output shape: torch.Size([2, 3, 128, 128, 1])


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



LANDSAT LST PREDICTION TRAINING - WITH IMAGE VISUALIZATION
Dataset: ./Data/Dataset
Batch size: 2
Max epochs: 3
Learning rate: 0.001
Precision: 32
Devices: 1 GPU(s)
Experiment: debug_with_images
Image logging: Every 1 epochs
Max images per log: 2
Wandb project: landsat-debug-viz

🚀 Starting training with image logging...
train split: 86 cities, 1351 tile sequences
val split: 19 cities, 212 tile sequences


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                   | Params
-----------------------------------------------------
0 | model     | CuboidTransformerModel | 9.0 M 
1 | criterion | MSELoss                | 0     
-----------------------------------------------------
9.0 M     Trainable params
0         Non-trainable params
9.0 M     Total params
36.144    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved. New best score: 6454.605
Epoch 0, global step 33: 'val_loss' reached 6454.60498 (best 6454.60498), saving model to './checkpoints/debug_with_images-epoch=00-val_loss=6454.605.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 329.479 >= min_delta = 0.0. New best score: 6125.126
Epoch 0, global step 66: 'val_loss' reached 6125.12598 (best 6125.12598), saving model to './checkpoints/debug_with_images-epoch=00-val_loss=6125.126.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 380.510 >= min_delta = 0.0. New best score: 5744.616
Epoch 1, global step 100: 'val_loss' reached 5744.61572 (best 5744.61572), saving model to './checkpoints/debug_with_images-epoch=01-val_loss=5744.616.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 408.777 >= min_delta = 0.0. New best score: 5335.839
Epoch 1, global step 133: 'val_loss' reached 5335.83887 (best 5335.83887), saving model to './checkpoints/debug_with_images-epoch=01-val_loss=5335.839.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 444.283 >= min_delta = 0.0. New best score: 4891.556
Epoch 2, global step 167: 'val_loss' reached 4891.55566 (best 4891.55566), saving model to './checkpoints/debug_with_images-epoch=02-val_loss=4891.556.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 463.529 >= min_delta = 0.0. New best score: 4428.026
Epoch 2, global step 200: 'val_loss' reached 4428.02637 (best 4428.02637), saving model to './checkpoints/debug_with_images-epoch=02-val_loss=4428.026.ckpt' as top 3
`Trainer.fit` stopped: `max_epochs=3` reached.
Restoring states from the checkpoint path at ./checkpoints/debug_with_images-epoch=02-val_loss=4428.026.ckpt



🧪 Running final test...
test split: 19 cities, 0 tile sequences


  return torch.load(f, map_location=map_location)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at ./checkpoints/debug_with_images-epoch=02-val_loss=4428.026.ckpt


✅ Test completed: []

🎉 Training completed successfully!
📁 Best model saved to: ./checkpoints/debug_with_images-epoch=02-val_loss=4428.026.ckpt
🔗 View experiment with images at: https://wandb.ai/jesus-guerrero-ml/landsat-lst-forecasting/runs/p7gmy923


  rank_zero_warn(


0,1
epoch,▁▁▁▁▁▁▁▁▁▁▅▅▅▅▅▅▅▅▅▅██████████▁▁▁▁▅▅████
lr-AdamW,▁▁▁███
train_loss_epoch,▇▄▁█▆▅
train_loss_step,▅▆█▆▅▆▄▅▅▇▆▆▄▅▄▅▅▅▃▄▅▄▅▂▅▂▄▄▃▂▃▃▁▂▁▂▁██▆
train_mae_epoch,▇▄▁█▆▅
train_mae_step,▆▆█▆▅▇▄▆▅▇▇▆▅▅▅▅▆▅▄▅▆▅▆▃▅▃▅▅▃▂▄▄▂▃▁▃▂██▇
train_temp_mae_scaled,▇▄▁█▆▅
train_temp_rmse_scaled,▇▄▁█▆▅
trainer/global_step,▁▁▁▂▂▃▃▃▃▃▄▄▅▅▅▆▆▆▆▆▇██████▁▁▁▁▁▁▁▁▂▂▂▂▂
val_correlation,███▁▆▃▅▆█

0,1
epoch,2.0
lr-AdamW,0.001
train_loss_epoch,9112.23047
train_loss_step,11278.71875
train_mae_epoch,93.66899
train_mae_step,106.0422
train_temp_mae_scaled,93.66899
train_temp_rmse_scaled,94.97568
trainer/global_step,200.0
val_correlation,0.27445
