In [None]:
from earthformer.cuboid_transformer.cuboid_transformer import CuboidTransformerModel
import torch

# Optimized config for Landsat 3-timestep forecasting
landsat_config = {
    'input_shape': (3, 128, 128, 9),    # 3 input timesteps, 128x128, 9 Landsat bands
    'target_shape': (3, 128, 128, 1),   # 3 output timesteps, 1 output channel
    
    # Small model for prototyping
    'base_units': 96,                    # Small but efficient
    'num_heads': 6,                      # base_units must be divisible by num_heads (96 / 6 = 16 per head)
    'enc_depth': [2, 2],                 # 2-level hierarchy (sufficient for short sequences)
    'dec_depth': [1, 1],                 # Matching decoder depth
    
    # Dropout for better generalization during prototyping
    'attn_drop': 0.1,
    'proj_drop': 0.1,
    'ffn_drop': 0.1,
    
    # Global vectors for capturing Landsat scene patterns
    'num_global_vectors': 8,
    'use_dec_self_global': True,
    'use_dec_cross_global': True,
    
    # Optimized for satellite imagery
    'pos_embed_type': 't+hw',            # Separate temporal and spatial embeddings
    'use_relative_pos': True,            # Good for satellite spatial patterns
    'ffn_activation': 'gelu',            # Works well for vision tasks
    
    # Cuboid settings optimized for short temporal sequences
    'enc_cuboid_size': [(2, 4, 4), (2, 4, 4)],     # Small temporal cuboids for 3 timesteps
    'enc_cuboid_strategy': [('l', 'l', 'l'), ('d', 'd', 'd')],
    
    # Cross-attention settings for decoder
    'dec_cross_cuboid_hw': [(4, 4), (4, 4)],
    'dec_cross_n_temporal': [1, 2],      # Use 1-2 temporal frames for cross-attention
}

# Fail fast on an invalid attention-head split before building the model.
assert landsat_config['base_units'] % landsat_config['num_heads'] == 0, \
    "base_units must be divisible by num_heads"

# Create model
model = CuboidTransformerModel(**landsat_config)
print(f"✓ Landsat model created! Parameters: {sum(p.numel() for p in model.parameters()):,}")

# Test with dummy Landsat data. The tensor (and the model) stay on CPU here, so
# this only checks shapes/wiring, not GPU memory headroom.
batch_size = 4
dummy_landsat = torch.randn(batch_size, *landsat_config['input_shape'])
print(f"Input shape: {dummy_landsat.shape}")

# Forward pass test. eval() turns off the dropout configured above so it does
# not perturb the smoke-test output.
model.eval()
with torch.no_grad():
    output = model(dummy_landsat)
    print(f"Output shape: {output.shape}")
    expected_shape = (batch_size, *landsat_config['target_shape'])
    assert tuple(output.shape) == expected_shape, (
        f"Unexpected output shape {tuple(output.shape)}, expected {expected_shape}"
    )
    print("✓ Forward pass successful!")

# Memory usage estimate
def estimate_memory_usage(model, input_shape, batch_size=1, bytes_per_element=4):
    """Print a rough memory estimate for running ``model`` on a batch.

    No forward pass is executed; figures are derived from parameter counts and
    the input shape only.

    Args:
        model: A ``torch.nn.Module``; only its parameters are inspected. It is
            switched to eval mode as a side effect.
        input_shape: Per-sample input shape without the batch dimension,
            e.g. ``(T, H, W, C)``.
        batch_size: Number of samples in the hypothetical batch.
        bytes_per_element: Bytes per input element (4 for float32 inputs).

    Notes:
        Parameter memory uses each parameter's actual ``element_size()`` so it
        stays correct for half-precision models. The activation figure is a
        crude heuristic (2x parameter memory) that does NOT scale with batch
        size or spatial resolution; treat it as a lower bound, not a budget.
    """
    import math

    model.eval()

    # Size the input from its shape instead of allocating a throwaway tensor.
    input_elements = batch_size * math.prod(input_shape)

    param_memory = sum(p.numel() * p.element_size() for p in model.parameters()) / 1e9  # GB
    input_memory = input_elements * bytes_per_element / 1e9  # GB
    activation_memory = param_memory * 2  # heuristic only; see docstring

    print("Estimated memory usage:")
    print(f"  Parameters: {param_memory:.2f} GB")
    print(f"  Input (batch={batch_size}): {input_memory:.2f} GB")
    print(f"  Activation estimate: ~{activation_memory:.2f} GB")
    print(f"  Total estimate: ~{param_memory + activation_memory + input_memory:.2f} GB")

# Reuse the configured input shape so this estimate cannot drift from the model config.
estimate_memory_usage(model, landsat_config['input_shape'], batch_size=8)

In [1]:
import os
import torch
from model import LandsatLSTPredictor
from dataset import LandsatDataModule
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
import wandb

def _count_samples(loader, batch_size: int) -> int:
    """Return how many samples ``loader`` draws from.

    Uses ``len(loader.dataset)`` when the dataset defines a length; otherwise
    falls back to ``len(loader) * batch_size``, which over-counts by up to
    ``batch_size - 1`` when the final batch is partial.
    """
    try:
        return len(loader.dataset)
    except (AttributeError, TypeError):
        return len(loader) * batch_size


def _report_interrupted(checkpoint_callback) -> None:
    """Print the user-interrupt notice and the last checkpoint path recorded by the callback."""
    print("\n⚠️ Training interrupted by user")
    print(f"📁 Last checkpoint saved to: {checkpoint_callback.last_model_path}")


def train_landsat_model(
    dataset_root: str = "./Data/Dataset",
    batch_size: int = 4,
    max_epochs: int = 100,
    learning_rate: float = 1e-4,
    num_workers: int = 4,
    gpus: int = 1,
    precision: str = "32",  # Start with 32-bit precision for stability
    accumulate_grad_batches: int = 1,
    val_check_interval: float = 1.0,
    limit_train_batches: float = 1.0,
    limit_val_batches: float = 1.0,
    experiment_name: str = "landsat_lst_prediction",
    checkpoint_dir: str = "./checkpoints",
    log_dir: str = "./logs",
    wandb_project: str = "landsat-lst-forecasting",
    wandb_tags: list = None
):
    """
    Complete training pipeline for Landsat LST prediction.

    Verifies the tiled dataset exists, smoke-tests the data module and the
    model on one batch, trains with checkpointing / early stopping / W&B
    logging, then runs ``trainer.test`` on the best checkpoint.

    Args:
        dataset_root: Path to preprocessed dataset with Cities_Tiles and DEM_2014_Tiles
        batch_size: Training batch size
        max_epochs: Maximum training epochs
        learning_rate: Initial learning rate
        num_workers: Number of data loading workers
        gpus: Number of GPUs to use; 0 trains in a single CPU process
        precision: Training precision forwarded to ``pl.Trainer`` (e.g. '32',
            '16-mixed'); must be a value the installed Lightning version accepts
        accumulate_grad_batches: Gradient accumulation steps
        val_check_interval: Validation frequency
        limit_train_batches: Fraction of training data to use (for debugging)
        limit_val_batches: Fraction of validation data to use (for debugging)
        experiment_name: Name for logging
        checkpoint_dir: Directory to save checkpoints
        log_dir: Directory for logs
        wandb_project: Weights & Biases project name
        wandb_tags: List of tags for the experiment; defaults to
            ["landsat", "lst-prediction", "earthformer"]

    Returns:
        (trainer, model, data_module)

    Raises:
        FileNotFoundError: If Cities_Tiles or DEM_2014_Tiles is missing.
        Exception: Errors from the data/model smoke tests or from
            ``trainer.fit`` are reported and re-raised.
    """
    # Build the default per call so the list is never shared between runs.
    if wandb_tags is None:
        wandb_tags = ["landsat", "lst-prediction", "earthformer"]

    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Verify the tiled dataset exists before doing any expensive work.
    cities_tiles = os.path.join(dataset_root, "Cities_Tiles")
    dem_tiles = os.path.join(dataset_root, "DEM_2014_Tiles")
    for tile_dir_name, tile_dir in (("Cities_Tiles", cities_tiles), ("DEM_2014_Tiles", dem_tiles)):
        if not os.path.isdir(tile_dir):
            raise FileNotFoundError(
                f"{tile_dir_name} directory not found at {tile_dir}. "
                "Please run convert_to_tiles() first."
            )

    print(f"✅ Found tiled dataset at {dataset_root}")

    data_module = LandsatDataModule(
        dataset_root=dataset_root,
        batch_size=batch_size,
        num_workers=num_workers,
        sequence_length=3
    )

    # Pull one batch up front so loader/collate errors surface before training starts.
    print("Testing data module setup...")
    try:
        data_module.setup("fit")
        train_loader = data_module.train_dataloader()
        val_loader = data_module.val_dataloader()

        print(f"✅ Training batches: {len(train_loader)}")
        print(f"✅ Validation batches: {len(val_loader)}")

        inputs, targets = next(iter(train_loader))
        print(f"✅ Sample batch - Inputs: {inputs.shape}, Targets: {targets.shape}")

    except Exception as e:
        print(f"❌ Data module test failed: {e}")
        raise

    # Run configuration recorded in Weights & Biases.
    config = {
        "batch_size": batch_size,
        "max_epochs": max_epochs,
        "learning_rate": learning_rate,
        "num_workers": num_workers,
        "precision": precision,
        "accumulate_grad_batches": accumulate_grad_batches,
        "val_check_interval": val_check_interval,
        "limit_train_batches": limit_train_batches,
        "limit_val_batches": limit_val_batches,
        "dataset_root": dataset_root,
        "model_type": "CuboidTransformer",
        "input_shape": [3, 128, 128, 9],
        "target_shape": [3, 128, 128, 1],
        "sequence_length": 3,
        "train_batches": len(train_loader),
        "val_batches": len(val_loader),
        "total_train_samples": _count_samples(train_loader, batch_size),
        "total_val_samples": _count_samples(val_loader, batch_size),
    }

    logger = WandbLogger(
        project=wandb_project,
        name=experiment_name,
        tags=wandb_tags,
        config=config,
        save_dir=log_dir,
        log_model=True,  # Log model checkpoints to wandb
    )

    model = LandsatLSTPredictor(
        learning_rate=learning_rate,
        weight_decay=1e-5,
        warmup_steps=1000,
        max_epochs=max_epochs
    )

    # Forward the sample batch once to catch shape/config mismatches early.
    print("Testing model with sample data...")
    try:
        model.eval()
        with torch.no_grad():
            test_output = model(inputs)
            print(f"✅ Model test - Output shape: {test_output.shape}")
    except Exception as e:
        print(f"❌ Model test failed: {e}")
        raise

    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=f'{experiment_name}-{{epoch:02d}}-{{val_loss:.3f}}',
        save_top_k=3,
        monitor='val_loss',
        mode='min',
        save_last=True,
        verbose=True
    )

    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=15,
        mode='min',
        verbose=True
    )

    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    use_gpu = gpus > 0
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='gpu' if use_gpu else 'cpu',
        # An explicit device count avoids relying on how the installed
        # Lightning version interprets devices=None for CPU.
        devices=gpus if use_gpu else 1,
        # Honour the requested precision; previously it was only printed/logged.
        precision=precision,
        accumulate_grad_batches=accumulate_grad_batches,
        val_check_interval=val_check_interval,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        callbacks=[checkpoint_callback, early_stopping, lr_monitor],
        logger=logger,
        log_every_n_steps=50,
        enable_progress_bar=True,
        enable_model_summary=True,
        deterministic=False,  # Set to True for reproducibility, False for speed
        benchmark=True,  # Optimize for consistent input sizes
    )

    print(f"\n{'='*70}")
    print("LANDSAT LST PREDICTION TRAINING - TILED DATASET")
    print(f"{'='*70}")
    print(f"Dataset: {dataset_root}")
    print(f"  - Cities Tiles: {cities_tiles}")
    print(f"  - DEM Tiles: {dem_tiles}")
    print(f"Batch size: {batch_size}")
    print(f"Max epochs: {max_epochs}")
    print(f"Learning rate: {learning_rate}")
    print(f"Precision: {precision}")
    print(f"Devices: {gpus} GPU(s)" if use_gpu else "Devices: CPU")
    print(f"Num workers: {num_workers}")
    print(f"Experiment: {experiment_name}")
    print(f"Data limits: {limit_train_batches*100:.0f}% train, {limit_val_batches*100:.0f}% val")
    print(f"Checkpoints: {checkpoint_dir}")
    print(f"Logs: {log_dir}")
    print(f"Wandb project: {wandb_project}")
    print(f"Wandb tags: {wandb_tags}")
    print(f"{'='*70}\n")

    try:
        print("🚀 Starting training...")
        trainer.fit(model, data_module)

        # Lightning may handle Ctrl-C inside fit() (marking the trainer as
        # interrupted) and return normally; don't test or report success then.
        if getattr(trainer, "interrupted", False):
            _report_interrupted(checkpoint_callback)
            return trainer, model, data_module

        print("\n🧪 Running final test...")
        try:
            test_results = trainer.test(model, data_module, ckpt_path='best')
            print(f"✅ Test completed: {test_results}")
        except Exception as e:
            print(f"⚠️ Test failed (this is okay if no test data): {e}")

        print("\n🎉 Training completed successfully!")
        print(f"📁 Best model saved to: {checkpoint_callback.best_model_path}")
        print(f"🔗 View experiment at: {logger.experiment.url}")

        if checkpoint_callback.best_model_path:
            wandb.save(checkpoint_callback.best_model_path)

    except KeyboardInterrupt:
        _report_interrupted(checkpoint_callback)

    except Exception as e:
        print(f"\n❌ Training failed with error: {e}")
        import traceback
        traceback.print_exc()

        # Only log when a run is active; wandb.log without a run would raise
        # and mask the original error.
        if wandb.run is not None:
            wandb.log({"error": str(e)})

        raise

    finally:
        # Always close the W&B run (a no-op when none is active).
        wandb.finish()

    return trainer, model, data_module


# Quick test/debug function with conservative settings
def debug_training(dataset_root: str = "./Data/Dataset"):
    """Quick debug run: 10% of train/val data, 3 epochs, batch size 2, one GPU.

    Uses ``num_workers=0`` so data-loading errors are raised in the main
    process with a readable traceback.

    Returns:
        The ``(trainer, model, data_module)`` tuple from ``train_landsat_model``
        so the run can be inspected afterwards.
    """
    print("🔧 Running debug training with tiled dataset...")

    results = train_landsat_model(
        dataset_root=dataset_root,
        batch_size=2,
        max_epochs=3,
        learning_rate=1e-3,
        num_workers=0,  # Disable multiprocessing for debugging
        gpus=1,
        precision="32",  # Use 32-bit for stability
        limit_train_batches=0.1,  # Use only 10% of data
        limit_val_batches=0.1,
        experiment_name="debug_tiled_landsat",
        val_check_interval=0.5,  # Check validation twice per epoch
        wandb_project="landsat-debug",
        wandb_tags=["debug", "tiled", "quick-test"]
    )

    print("✅ Debug training completed!")
    return results


# Enhanced debug function with quarter dataset
def debug_with_enhanced_logging(dataset_root: str = "./Data/Dataset"):
    """Medium-size debug run: 25% of train/val data, 5 epochs, batch size 4, one GPU.

    Uses 2 loader workers so multiprocessing data loading is exercised, and
    validates twice per epoch.

    Returns:
        The ``(trainer, model, data_module)`` tuple from ``train_landsat_model``
        so the run can be inspected afterwards.
    """
    print("🔧 Running enhanced debug training with quarter dataset...")

    results = train_landsat_model(
        dataset_root=dataset_root,
        batch_size=4,
        max_epochs=5,
        learning_rate=1e-3,
        num_workers=2,  # Some multiprocessing for realistic testing
        gpus=1,
        precision="32",  # Use 32-bit for stability
        limit_train_batches=0.25,  # Use quarter of training data
        limit_val_batches=0.25,   # Use quarter of validation data
        experiment_name="enhanced_debug_quarter_dataset",
        val_check_interval=0.5,  # Check validation twice per epoch
        wandb_project="landsat-debug",
        wandb_tags=["enhanced-debug", "quarter-dataset", "realistic-test"]
    )

    print("✅ Enhanced debug training completed!")
    return results


# Even more minimal debug function
def minimal_debug_training(dataset_root: str = "./Data/Dataset"):
    """Smallest possible smoke run: CPU only, 5% of data, 1 epoch, batch size 1.

    Useful for checking the pipeline end to end without a GPU.

    Returns:
        The ``(trainer, model, data_module)`` tuple from ``train_landsat_model``
        so the run can be inspected afterwards.
    """
    print("🔧 Running minimal debug training...")

    results = train_landsat_model(
        dataset_root=dataset_root,
        batch_size=1,  # Smallest possible batch
        max_epochs=1,  # Just one epoch
        learning_rate=1e-3,
        num_workers=0,  # No multiprocessing
        gpus=0,  # Use CPU to avoid GPU issues
        precision="32",  # Standard precision
        limit_train_batches=0.05,  # Use only 5% of data
        limit_val_batches=0.05,
        experiment_name="minimal_debug_tiled",
        val_check_interval=1.0,
        wandb_project="landsat-debug",
        wandb_tags=["minimal", "debug", "cpu"]
    )

    print("✅ Minimal debug training completed!")
    return results


# Full training configurations
def full_training_gpu(dataset_root: str = "./Data/Dataset"):
    """Launch the full-length single-GPU training job.

    Trains on all of the data for up to 50 epochs with batch size 8, 4 loader
    workers and a 2e-4 learning rate, requests ``precision="16-mixed"``, and
    logs to the ``landsat-lst-forecasting`` W&B project.

    Returns:
        Whatever ``train_landsat_model`` returns (``(trainer, model, data_module)``).
    """
    print("🚀 Starting full GPU training...")

    run_kwargs = {
        "batch_size": 8,           # larger batch for better GPU utilization
        "max_epochs": 50,
        "learning_rate": 2e-4,
        "num_workers": 4,
        "gpus": 1,
        "precision": "16-mixed",   # mixed precision for speed
        "experiment_name": "landsat_full_training",
        "val_check_interval": 1.0,
        "wandb_project": "landsat-lst-forecasting",
        "wandb_tags": ["full-training", "production", "earthformer", "gpu"],
    }
    return train_landsat_model(dataset_root=dataset_root, **run_kwargs)



if __name__ == "__main__":
    # Run exactly one of the entry points defined above.

    # Medium debug run (25% of data, 5 epochs, GPU):
    debug_with_enhanced_logging()

    # Quick debug run (10% of data, 3 epochs, GPU):
    # debug_training()

    # CPU-only smoke test (5% of data, 1 epoch):
    # minimal_debug_training()

    # Full training run (all data, 50 epochs, GPU, 16-mixed precision requested):
    # full_training_gpu()

  from pandas.core.computation.check import NUMEXPR_INSTALLED


🔧 Running enhanced debug training with quarter dataset...
✅ Found tiled dataset at ./Data/Dataset
Testing data module setup...
train split: 86 cities, 1351 tile sequences
val split: 19 cities, 212 tile sequences
✅ Training batches: 338
✅ Validation batches: 53
✅ Sample batch - Inputs: torch.Size([4, 3, 128, 128, 9]), Targets: torch.Size([4, 3, 128, 128, 1])


[34m[1mwandb[0m: Currently logged in as: [33mjesus-guerrero[0m ([33mjesus-guerrero-ml[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Model initialized with 9,036,109 parameters
Testing model with sample data...
✅ Model test - Output shape: torch.Size([4, 3, 128, 128, 1])


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



LANDSAT LST PREDICTION TRAINING - TILED DATASET
Dataset: ./Data/Dataset
  - Cities Tiles: ./Data/Dataset/Cities_Tiles
  - DEM Tiles: ./Data/Dataset/DEM_2014_Tiles
Batch size: 4
Max epochs: 5
Learning rate: 0.001
Precision: 32
Devices: 1 GPU(s)
Num workers: 2
Experiment: enhanced_debug_quarter_dataset
Data limits: 25% train, 25% val
Checkpoints: ./checkpoints
Logs: ./logs
Wandb project: landsat-debug
Wandb tags: ['enhanced-debug', 'quarter-dataset', 'realistic-test']

🚀 Starting training...
train split: 86 cities, 1351 tile sequences
val split: 19 cities, 212 tile sequences


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                   | Params
-----------------------------------------------------
0 | model     | CuboidTransformerModel | 9.0 M 
1 | criterion | MSELoss                | 0     
-----------------------------------------------------
9.0 M     Trainable params
0         Non-trainable params
9.0 M     Total params
36.144    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Sample 0: LST range 66.0°F to 100.0°F
Sample 1: LST range 85.0°F to 119.0°F
Sample 2: LST range 75.0°F to 102.0°F
Sample 3: LST range 80.0°F to 98.0°F
=== LST DATA ANALYSIS ===
Timestep 0:
  Input LST  - Min: 9.0, Max: 19.0, Mean: 14.5
  Target LST - Min: 87.0, Max: 118.0, Mean: 100.6
  Pred LST   - Min: -0.9, Max: 1.2, Mean: -0.6
Timestep 1:
  Input LST  - Min: 57.0, Max: 77.0, Mean: 65.7
  Target LST - Min: 66.0, Max: 100.0, Mean: 80.2
  Pred LST   - Min: -1.0, Max: 1.1, Mean: -0.4
Timestep 2:
  Input LST  - Min: 70.0, Max: 100.0, Mean: 83.2
  Target LST - Min: 85.0, Max: 119.0, Mean: 98.9
  Pred LST   - Min: -1.1, Max: 0.9, Mean: -0.5
Global temp range: -0.8°F to 111.0°F


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  return fn(*args, **kwargs)


Sample 0: LST range 79.0°F to 128.0°F
Sample 1: LST range 104.0°F to 134.0°F
Sample 2: LST range 108.0°F to 125.0°F
Sample 3: LST range 104.0°F to 126.0°F
=== LST DATA ANALYSIS ===
Timestep 0:
  Input LST  - Min: 56.0, Max: 105.0, Mean: 83.7
  Target LST - Min: 84.0, Max: 129.0, Mean: 109.8
  Pred LST   - Min: -1.4, Max: 0.9, Mean: -0.3
Timestep 1:
  Input LST  - Min: 76.0, Max: 124.0, Mean: 103.3
  Target LST - Min: 79.0, Max: 128.0, Mean: 107.3
  Pred LST   - Min: -1.3, Max: 1.2, Mean: -0.2
Timestep 2:
  Input LST  - Min: 86.0, Max: 128.0, Mean: 110.1
  Target LST - Min: 75.0, Max: 125.0, Mean: 103.8
  Pred LST   - Min: -1.5, Max: 1.0, Mean: -0.4
Global temp range: -0.9°F to 125.0°F




Validation: 0it [00:00, ?it/s]

Sample 0: LST range 66.0°F to 100.0°F
Sample 1: LST range 85.0°F to 119.0°F
Sample 2: LST range 75.0°F to 102.0°F
Sample 3: LST range 80.0°F to 98.0°F
=== LST DATA ANALYSIS ===
Timestep 0:
  Input LST  - Min: 9.0, Max: 19.0, Mean: 14.5
  Target LST - Min: 87.0, Max: 118.0, Mean: 100.6
  Pred LST   - Min: 2.7, Max: 7.5, Mean: 7.4
Timestep 1:
  Input LST  - Min: 57.0, Max: 77.0, Mean: 65.7
  Target LST - Min: 66.0, Max: 100.0, Mean: 80.2
  Pred LST   - Min: 2.7, Max: 7.5, Mean: 7.4
Timestep 2:
  Input LST  - Min: 70.0, Max: 100.0, Mean: 83.2
  Target LST - Min: 85.0, Max: 119.0, Mean: 98.9
  Pred LST   - Min: 2.7, Max: 7.6, Mean: 7.4
Global temp range: 4.4°F to 111.0°F


Metric val_loss improved. New best score: 6438.942
Epoch 0, global step 42: 'val_loss' reached 6438.94238 (best 6438.94238), saving model to './checkpoints/enhanced_debug_quarter_dataset-epoch=00-val_loss=6438.942.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Sample 0: LST range 66.0°F to 100.0°F
Sample 1: LST range 85.0°F to 119.0°F
Sample 2: LST range 75.0°F to 102.0°F
Sample 3: LST range 80.0°F to 98.0°F
=== LST DATA ANALYSIS ===
Timestep 0:
  Input LST  - Min: 9.0, Max: 19.0, Mean: 14.5
  Target LST - Min: 87.0, Max: 118.0, Mean: 100.6
  Pred LST   - Min: 6.1, Max: 11.8, Mean: 10.3
Timestep 1:
  Input LST  - Min: 57.0, Max: 77.0, Mean: 65.7
  Target LST - Min: 66.0, Max: 100.0, Mean: 80.2
  Pred LST   - Min: 6.1, Max: 11.8, Mean: 10.3
Timestep 2:
  Input LST  - Min: 70.0, Max: 100.0, Mean: 83.2
  Target LST - Min: 85.0, Max: 119.0, Mean: 98.9
  Pred LST   - Min: 6.1, Max: 11.8, Mean: 10.3
Global temp range: 8.6°F to 111.0°F


Metric val_loss improved by 444.007 >= min_delta = 0.0. New best score: 5994.935
Epoch 0, global step 84: 'val_loss' reached 5994.93506 (best 5994.93506), saving model to './checkpoints/enhanced_debug_quarter_dataset-epoch=00-val_loss=5994.935.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 491.391 >= min_delta = 0.0. New best score: 5503.544
Epoch 1, global step 126: 'val_loss' reached 5503.54395 (best 5503.54395), saving model to './checkpoints/enhanced_debug_quarter_dataset-epoch=01-val_loss=5503.544.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 541.646 >= min_delta = 0.0. New best score: 4961.897
Epoch 1, global step 168: 'val_loss' reached 4961.89746 (best 4961.89746), saving model to './checkpoints/enhanced_debug_quarter_dataset-epoch=01-val_loss=4961.897.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 572.389 >= min_delta = 0.0. New best score: 4389.508
Epoch 2, global step 210: 'val_loss' reached 4389.50830 (best 4389.50830), saving model to './checkpoints/enhanced_debug_quarter_dataset-epoch=02-val_loss=4389.508.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 586.537 >= min_delta = 0.0. New best score: 3802.971
Epoch 2, global step 252: 'val_loss' reached 3802.97119 (best 3802.97119), saving model to './checkpoints/enhanced_debug_quarter_dataset-epoch=02-val_loss=3802.971.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 583.048 >= min_delta = 0.0. New best score: 3219.924
Epoch 3, global step 294: 'val_loss' reached 3219.92358 (best 3219.92358), saving model to './checkpoints/enhanced_debug_quarter_dataset-epoch=03-val_loss=3219.924.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 555.848 >= min_delta = 0.0. New best score: 2664.076
Epoch 3, global step 336: 'val_loss' reached 2664.07593 (best 2664.07593), saving model to './checkpoints/enhanced_debug_quarter_dataset-epoch=03-val_loss=2664.076.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 521.288 >= min_delta = 0.0. New best score: 2142.788
Epoch 4, global step 378: 'val_loss' reached 2142.78760 (best 2142.78760), saving model to './checkpoints/enhanced_debug_quarter_dataset-epoch=04-val_loss=2142.788.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 463.485 >= min_delta = 0.0. New best score: 1679.302
Epoch 4, global step 420: 'val_loss' reached 1679.30225 (best 1679.30225), saving model to './checkpoints/enhanced_debug_quarter_dataset-epoch=04-val_loss=1679.302.ckpt' as top 3
`Trainer.fit` stopped: `max_epochs=5` reached.
Restoring states from the checkpoint path at ./checkpoints/enhanced_debug_quarter_dataset-epoch=04-val_loss=1679.302.ckpt



🧪 Running final test...
test split: 19 cities, 0 tile sequences


  return torch.load(f, map_location=map_location)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at ./checkpoints/enhanced_debug_quarter_dataset-epoch=04-val_loss=1679.302.ckpt


✅ Test completed: []

🎉 Training completed successfully!
📁 Best model saved to: ./checkpoints/enhanced_debug_quarter_dataset-epoch=04-val_loss=1679.302.ckpt
🔗 View experiment at: https://wandb.ai/jesus-guerrero-ml/landsat-debug/runs/7rg8pmlv


  rank_zero_warn(


0,1
epoch,▁▁▁▁▃▃▃▃▃▅▅▅▅▅▆▆▆▆█████
lr-AdamW,▁▁▁▁▁
train_loss_epoch,█▆▅▃▁
train_loss_step,█▆▄▇▇▄▂▁
train_mae_epoch,█▇▅▃▁
train_mae_step,█▇▄█▇▄▂▁
train_temp_mae_scaled,█▇▅▃▁
train_temp_rmse_scaled,█▇▅▃▁
trainer/global_step,▁▂▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▅▆▆▇▇▇▇▇███
val_correlation,▁▂▄▇▇█████

0,1
epoch,4.0
lr-AdamW,0.001
train_loss_epoch,4936.82715
train_loss_step,3678.20776
train_mae_epoch,67.90926
train_mae_step,57.67937
train_temp_mae_scaled,67.90926
train_temp_rmse_scaled,69.93489
trainer/global_step,419.0
val_correlation,0.2399


✅ Enhanced debug training completed!
