In [2]:
import os
import torch
from model import LandsatLSTPredictor
from dataset import LandsatDataModule
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
import wandb
from typing import List, Optional

def train_landsat_model(wandb_project: str, dataset_root: str, config: dict):
    """Train a LandsatLSTPredictor, test the best checkpoint, and log the run to W&B.

    The function builds a ``LandsatDataModule``, a ``WandbLogger``, the model and a
    ``pl.Trainer`` (with checkpointing, early stopping and LR monitoring), runs
    ``trainer.fit`` followed by a best-effort ``trainer.test``, and always closes
    the W&B run before returning.

    Args:
        wandb_project: Weights & Biases project name passed to ``WandbLogger``.
        dataset_root: Directory of the preprocessed dataset; must exist.
            (Original notes: contains ``Cities_Tiles`` and ``DEM_2014_Tiles`` —
            not checked here.)
        config: Hyperparameter dict. The whole dict is also logged to W&B, so
            keys this function does not read are still recorded with the run.

            Required keys:
                learning_rate: Initial learning rate (forwarded to the model).
                max_epochs: Maximum training epochs (model and Trainer).
                batch_size: Batch size forwarded to the data module.
                num_workers: Number of data loading workers.
                input_sequence_length / output_sequence_length: Forwarded to
                    both the data module and the model.
                gpus: Number of GPUs; ``> 0`` selects the GPU accelerator,
                    otherwise training runs on CPU.
                precision: Forwarded unchanged to ``pl.Trainer(precision=...)``.
                limit_train_batches / limit_val_batches: Forwarded to the
                    Trainer (used to shrink runs for debugging).
                train_years / val_years / test_years: Year lists forwarded to
                    the data module. (Original notes say ``None`` selects a
                    default 70/15/15 split — confirm in LandsatDataModule.)
                debug_monthly_split: If truthy, the run is tagged
                    "debug-monthly-split" and early-stopping patience is 10
                    instead of 15; also forwarded to the data module.
                debug_year: Year used for the debug monthly split.
                max_input_nodata_pct: Forwarded to the data module.
                cluster: Forwarded to the data module.

            Optional keys (defaults reproduce the previous hard-coded values):
                model_size: Model variant, default ``"small"``.
                limit_test_batches: Forwarded to the Trainer, default ``1.0``.
                weight_decay: Default ``1e-5``.
                warmup_steps: Default ``1000``.
                interpolated_scenes_file: Default ``"./Data/ML/interpolated.txt"``.

    Returns:
        Tuple ``(trainer, model, data_module)``.

    Raises:
        FileNotFoundError: If ``dataset_root`` is not an existing directory.
        Exception: Any error raised by ``trainer.fit`` is reported and re-raised.
            A ``KeyboardInterrupt`` is caught and the function returns normally.
    """
    import traceback  # local import: only needed for error reporting below

    # Fail fast, before any directory or W&B run is created, so the
    # "Found" message below is actually true.
    if not os.path.isdir(dataset_root):
        raise FileNotFoundError(f"Dataset root not found: {dataset_root}")
    print(f"✅ Found tiled dataset at {dataset_root}")

    debug_split = config["debug_monthly_split"]
    model_size = config.get("model_size", "small")

    # The two split modes differ only in one tag.
    wandb_tags = [
        "landsat", "lst-prediction", "earthformer",
        "debug-monthly-split" if debug_split else "year-based-split",
        f"model-{model_size}",
    ]

    # Create directories
    checkpoint_dir = "./checkpoints"
    log_dir = "./logs"
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Initialize data module
    data_module = LandsatDataModule(
        dataset_root=dataset_root,
        cluster=config["cluster"],
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        input_sequence_length=config["input_sequence_length"],
        output_sequence_length=config["output_sequence_length"],
        train_years=config["train_years"],
        val_years=config["val_years"],
        test_years=config["test_years"],
        debug_monthly_split=debug_split,
        debug_year=config["debug_year"],
        interpolated_scenes_file=config.get(
            "interpolated_scenes_file", "./Data/ML/interpolated.txt"
        ),
        max_input_nodata_pct=config["max_input_nodata_pct"]
    )

    # Initialize Weights & Biases logger
    logger = WandbLogger(
        project=wandb_project,
        tags=wandb_tags,
        config=config,
        save_dir=log_dir,
        log_model=True,
    )

    # Initialize model with configurable size
    model = LandsatLSTPredictor(
        learning_rate=config["learning_rate"],
        weight_decay=config.get("weight_decay", 1e-5),
        warmup_steps=config.get("warmup_steps", 1000),
        max_epochs=config["max_epochs"],
        input_sequence_length=config["input_sequence_length"],
        output_sequence_length=config["output_sequence_length"],
        model_size=model_size
    )

    # Checkpoint names embed the W&B run name; accessing logger.experiment
    # here is what makes the run name available.
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=f'{logger.experiment.name}-{{epoch:02d}}-{{val_loss:.3f}}',
        save_top_k=3,
        monitor='val_loss',
        mode='min',
        save_last=True,
        verbose=True
    )

    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=10 if debug_split else 15,
        mode='min',
        verbose=True
    )

    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    use_gpu = config["gpus"] > 0
    trainer = pl.Trainer(
        max_epochs=config["max_epochs"],
        accelerator='gpu' if use_gpu else 'cpu',
        # NOTE(review): devices=None is kept from the original; confirm it is
        # still accepted before upgrading Lightning.
        devices=config["gpus"] if use_gpu else None,
        precision=config["precision"],
        accumulate_grad_batches=1,
        val_check_interval=0.5,
        limit_train_batches=config["limit_train_batches"],
        limit_val_batches=config["limit_val_batches"],
        # Previously never forwarded, so the caller's setting was ignored.
        limit_test_batches=config.get("limit_test_batches", 1.0),
        callbacks=[checkpoint_callback, early_stopping, lr_monitor],
        logger=logger,
        log_every_n_steps=50,
        enable_progress_bar=True,
        enable_model_summary=True,
        deterministic=False,
        benchmark=True,
    )

    try:
        trainer.fit(model, data_module)

        print("\n🧪 Running final test...")
        try:
            test_results = trainer.test(model, data_module, ckpt_path='best')
            print(f"✅ Test completed: {test_results}")
        except Exception as e:
            # Testing stays best-effort, but show the traceback: the broad
            # except would otherwise hide genuine bugs in the test loop.
            print(f"⚠️ Test failed (this is okay if no test data): {e}")
            traceback.print_exc()

        print(f"\n🎉 Training completed successfully!")
        print(f"📁 Best model saved to: {checkpoint_callback.best_model_path}")
        print(f"🔗 View experiment at: {logger.experiment.url}")

        if checkpoint_callback.best_model_path:
            wandb.save(checkpoint_callback.best_model_path)

    except KeyboardInterrupt:
        print("\n⚠️ Training interrupted by user")
        print(f"📁 Last checkpoint saved to: {checkpoint_callback.last_model_path}")

    except Exception as e:
        print(f"\n❌ Training failed with error: {e}")
        traceback.print_exc()

        # Only log when a run is active, so a logging failure cannot mask
        # the original exception re-raised below.
        if wandb.run is not None:
            wandb.log({"error": str(e)})

        raise

    finally:
        # Close the W&B run on every exit path.
        wandb.finish()

    return trainer, model, data_module

if __name__ == "__main__":
    # Notebook entry point: configure one run and launch training.
    wandb_project = "AAAI-Project-final-tests"
    dataset_root = "./Data/ML"

    hyperparameters = {
        # --- model / optimisation ---
        "model_size": "medium",  # "tiny", "small", "medium", "large"
        "learning_rate": 0.001,
        "max_epochs": 3,
        "batch_size": 4,
        "input_sequence_length": 3,
        "output_sequence_length": 1,
        # --- hardware ---
        "gpus": 1,
        "precision": 16,
        "num_workers": 8,
        # --- data selection ---
        "cluster": "all",  # 1, 2, 3, 4 or "all"
        "max_input_nodata_pct": 0.95,
        "debug_monthly_split": True,
        "debug_year": 2014,
        "train_years": list(range(2013, 2022)),  # 2013 through 2021
        "val_years": [2022, 2023],
        "test_years": [2024, 2025],
        "use_custom_years": True,
        # --- per-split batch limits ---
        "limit_train_batches": 0.01,
        "limit_val_batches": 0.01,
        "limit_test_batches": 0.01,
    }

    train_landsat_model(
        wandb_project=wandb_project,
        dataset_root=dataset_root,
        config=hyperparameters,
    )
# TODO (planned runs):
#   - Run "large" for 5 epochs, full
#   - Run "tiny" for 5 epochs, full
#   - Run "medium" for 5 epochs, full
#   - Run a sweep for a month
#   - Run a sweep for years

# NOTE: model.py was updated afterwards with the gradient-explosion changes.

✅ Found tiled dataset at ./Data/ML


Model 'medium' initialized with 14,978,705 parameters


Using 16bit native Automatic Mixed Precision (AMP)
  scaler = torch.cuda.amp.GradScaler()
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Loaded 174 interpolated scenes to exclude from ground truth
Examples of interpolated scenes:
  Arlington_TX/2013-12-15T12:00:00Z
  Arlington_TX/2015-02-15T12:00:00Z
  Arlington_TX/2015-03-15T12:00:00Z
  Arlington_TX/2015-11-15T12:00:00Z
  Arlington_TX/2016-04-15T12:00:00Z
  ... and 169 more
Debug monthly splits for year 2014:
  Train months: [1, 2, 3, 4, 5, 6, 7, 8] (Jan-Aug)
  Val months: [6, 7, 8, 9, 10] (Jun-Oct)
  Test months: [8, 9, 10, 11, 12] (Aug-Dec)
  Current split (train): [1, 2, 3, 4, 5, 6, 7, 8]

🔄 Building tile sequences for train split using 124 cores...
   Excluding 174 interpolated scenes from ground truth...


Processing cities (train): 100%|██████████| 124/124 [00:43<00:00,  2.86city/s]


Sequences by month for train split (year 2014):
  01 (Jan): 3062 sequences
  02 (Feb): 3862 sequences
  03 (Mar): 2790 sequences
  04 (Apr): 2813 sequences
  05 (May): 3124 sequences
  06 (Jun): 0 sequences
  07 (Jul): 0 sequences
  08 (Aug): 0 sequences

=== INTERPOLATED SCENE FILTERING STATS ===
Interpolated scenes loaded: 174
Valid sequences after filtering: 15651
Interpolated scenes affect years in this split: [2014]
DEBUG train split: 124 cities, year 2014, months [1, 2, 3, 4, 5, 6, 7, 8], 15651 tile sequences
Loaded 174 interpolated scenes to exclude from ground truth
Examples of interpolated scenes:
  Arlington_TX/2013-12-15T12:00:00Z
  Arlington_TX/2015-02-15T12:00:00Z
  Arlington_TX/2015-03-15T12:00:00Z
  Arlington_TX/2015-11-15T12:00:00Z
  Arlington_TX/2016-04-15T12:00:00Z
  ... and 169 more
Debug monthly splits for year 2014:
  Train months: [1, 2, 3, 4, 5, 6, 7, 8] (Jan-Aug)
  Val months: [6, 7, 8, 9, 10] (Jun-Oct)
  Test months: [8, 9, 10, 11, 12] (Aug-Dec)
  Current split

Processing cities (val): 100%|██████████| 124/124 [00:26<00:00,  4.71city/s]


Sequences by month for val split (year 2014):
  06 (Jun): 3106 sequences
  07 (Jul): 3114 sequences
  08 (Aug): 0 sequences
  09 (Sep): 0 sequences
  10 (Oct): 0 sequences

=== INTERPOLATED SCENE FILTERING STATS ===
Interpolated scenes loaded: 174
Valid sequences after filtering: 6220
Interpolated scenes affect years in this split: [2014]
DEBUG val split: 124 cities, year 2014, months [6, 7, 8, 9, 10], 6220 tile sequences


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(

  | Name      | Type                   | Params
-----------------------------------------------------
0 | model     | CuboidTransformerModel | 15.0 M
1 | criterion | MSELoss                | 0     
-----------------------------------------------------
15.0 M    Trainable params
0         Non-trainable params
15.0 M    Total params
29.957    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

🖼️ Attempting to log images at epoch 0
✅ Successfully logged validation image at epoch 0


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  return fn(*args, **kwargs)


Validation: 0it [00:00, ?it/s]

🖼️ Attempting to log images at epoch 0
✅ Successfully logged validation image at epoch 0


Metric val_loss improved. New best score: 0.165
Epoch 0, global step 19: 'val_loss' reached 0.16480 (best 0.16480), saving model to './checkpoints/spring-dust-4-epoch=00-val_loss=0.165.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

🖼️ Attempting to log images at epoch 0
✅ Successfully logged validation image at epoch 0


Metric val_loss improved by 0.062 >= min_delta = 0.0. New best score: 0.102
Epoch 0, global step 38: 'val_loss' reached 0.10231 (best 0.10231), saving model to './checkpoints/spring-dust-4-epoch=00-val_loss=0.102.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

🖼️ Attempting to log images at epoch 1
✅ Successfully logged validation image at epoch 1


Metric val_loss improved by 0.036 >= min_delta = 0.0. New best score: 0.067
Epoch 1, global step 58: 'val_loss' reached 0.06681 (best 0.06681), saving model to './checkpoints/spring-dust-4-epoch=01-val_loss=0.067.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

🖼️ Attempting to log images at epoch 1
✅ Successfully logged validation image at epoch 1


Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 0.062
Epoch 1, global step 77: 'val_loss' reached 0.06196 (best 0.06196), saving model to './checkpoints/spring-dust-4-epoch=01-val_loss=0.062.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

🖼️ Attempting to log images at epoch 2
✅ Successfully logged validation image at epoch 2


Metric val_loss improved by 0.017 >= min_delta = 0.0. New best score: 0.045
Epoch 2, global step 97: 'val_loss' reached 0.04524 (best 0.04524), saving model to './checkpoints/spring-dust-4-epoch=02-val_loss=0.045.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

🖼️ Attempting to log images at epoch 2
✅ Successfully logged validation image at epoch 2


Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 0.039
Epoch 2, global step 116: 'val_loss' reached 0.03937 (best 0.03937), saving model to './checkpoints/spring-dust-4-epoch=02-val_loss=0.039.ckpt' as top 3
`Trainer.fit` stopped: `max_epochs=3` reached.



🧪 Running final test...
Loaded 174 interpolated scenes to exclude from ground truth
Examples of interpolated scenes:
  Arlington_TX/2013-12-15T12:00:00Z
  Arlington_TX/2015-02-15T12:00:00Z
  Arlington_TX/2015-03-15T12:00:00Z
  Arlington_TX/2015-11-15T12:00:00Z
  Arlington_TX/2016-04-15T12:00:00Z
  ... and 169 more
Debug monthly splits for year 2014:
  Train months: [1, 2, 3, 4, 5, 6, 7, 8] (Jan-Aug)
  Val months: [6, 7, 8, 9, 10] (Jun-Oct)
  Test months: [8, 9, 10, 11, 12] (Aug-Dec)
  Current split (test): [8, 9, 10, 11, 12]

🔄 Building tile sequences for test split using 124 cores...
   Excluding 174 interpolated scenes from ground truth...


Processing cities (test): 100%|██████████| 124/124 [00:20<00:00,  5.94city/s]
Restoring states from the checkpoint path at ./checkpoints/spring-dust-4-epoch=02-val_loss=0.039.ckpt


Sequences by month for test split (year 2014):
  08 (Aug): 2334 sequences
  09 (Sep): 2119 sequences
  10 (Oct): 0 sequences
  11 (Nov): 0 sequences
  12 (Dec): 0 sequences

=== INTERPOLATED SCENE FILTERING STATS ===
Interpolated scenes loaded: 174
Valid sequences after filtering: 4453
Interpolated scenes affect years in this split: [2014]
DEBUG test split: 124 cities, year 2014, months [8, 9, 10, 11, 12], 4453 tile sequences


  return torch.load(f, map_location=map_location)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at ./checkpoints/spring-dust-4-epoch=02-val_loss=0.039.ckpt


Testing: 0it [00:00, ?it/s]

⚠️ Test failed (this is okay if no test data): 'LandsatLSTPredictor' object has no attribute 'log_images_to_wandb'

🎉 Training completed successfully!
📁 Best model saved to: ./checkpoints/spring-dust-4-epoch=02-val_loss=0.039.ckpt
🔗 View experiment at: https://wandb.ai/jesus-guerrero-ml/AAAI-Project-final-tests/runs/k49syb2z


0,1
epoch,▁▁▁▅▅▅███
lr-AdamW,▁▁▁
train_loss,█▁▁
train_mae,█▁▁
train_mae_F,█▁▁
train_rmse_F,█▁▁
trainer/global_step,▁▂▃▃▃▄▆▆▆▇██
val_loss,█▅▃▂▁▁
val_mae,█▄▂▂▁▁
val_mae_F,█▄▂▂▁▁

0,1
epoch,2.0
lr-AdamW,0.001
train_loss,0.05925
train_mae,0.04769
train_mae_F,19.07764
train_rmse_F,23.69869
trainer/global_step,116.0
val_loss,0.03937
val_mae,0.0346
val_mae_F,13.8379


In [None]:
# Environment report: library versions and GPU visibility for this kernel.
import torch
import pytorch_lightning as pl

cuda_ok = torch.cuda.is_available()

report = [
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {cuda_ok}",
    f"AMP available: {hasattr(torch.cuda, 'amp')}",
]
if cuda_ok:
    # Name of the first visible CUDA device.
    report.append(f"GPU: {torch.cuda.get_device_name(0)}")
report.append(f"PyTorch Lightning version: {pl.__version__}")

print("\n".join(report))