In [None]:
from earthformer.cuboid_transformer.cuboid_transformer import CuboidTransformerModel
import torch

# Optimized config for Landsat 3-timestep forecasting
landsat_config = {
    'input_shape': (3, 128, 128, 9),    # 3 input timesteps, 128x128, 9 Landsat bands
    'target_shape': (3, 128, 128, 1),   # 3 output timesteps, 128x128, 1 channel
    
    # Small model for prototyping
    'base_units': 96,                    # Small but efficient
    'num_heads': 6,                      # base_units (96) is divisible by num_heads (6)
    'enc_depth': [2, 2],                 # 2-level hierarchy (sufficient for short sequences)
    'dec_depth': [1, 1],                 # Decoder depth, one entry per encoder level
    
    # Dropout for better generalization during prototyping
    'attn_drop': 0.1,
    'proj_drop': 0.1,
    'ffn_drop': 0.1,
    
    # Global vectors for capturing Landsat scene patterns
    'num_global_vectors': 8,
    'use_dec_self_global': True,
    'use_dec_cross_global': True,
    
    # Optimized for satellite imagery
    'pos_embed_type': 't+hw',            # Separate temporal and spatial embeddings
    'use_relative_pos': True,            # Good for satellite spatial patterns
    'ffn_activation': 'gelu',            # Works well for vision tasks
    
    # Cuboid settings optimized for short temporal sequences
    'enc_cuboid_size': [(2, 4, 4), (2, 4, 4)],     # Small temporal cuboids for 3 timesteps
    'enc_cuboid_strategy': [('l', 'l', 'l'), ('d', 'd', 'd')],
    
    # Cross-attention settings for decoder
    'dec_cross_cuboid_hw': [(4, 4), (4, 4)],
    'dec_cross_n_temporal': [1, 2],      # Use 1-2 temporal frames for cross-attention
}

# Create model
model = CuboidTransformerModel(**landsat_config)
print(f"✓ Landsat model created! Parameters: {sum(p.numel() for p in model.parameters()):,}")

# Test with dummy Landsat data. The tensor shape is taken from the config so
# the smoke test cannot drift away from the shape the model was built for.
batch_size = 4  # Small batch: this is only a shape/sanity check
dummy_landsat = torch.randn(batch_size, *landsat_config['input_shape'])
print(f"Input shape: {dummy_landsat.shape}")

# Forward pass test. Switch to eval mode first: the config enables dropout
# (attn/proj/ffn = 0.1), which would otherwise stay active during this check.
model.eval()
with torch.no_grad():
    output = model(dummy_landsat)
    print(f"Output shape: {output.shape}")
    print("✓ Forward pass successful!")

# Memory usage estimate
def estimate_memory_usage(model, input_shape, batch_size=1):
    """Print a rough memory estimate (in GB) for running ``model``.

    The numbers are heuristics, not measurements: activations are guessed as
    twice the parameter memory and the total as three times the parameter
    memory plus the input.

    Args:
        model: A ``torch.nn.Module``; only ``eval()`` and ``parameters()`` are used.
        input_shape: Per-sample input shape (without the batch dimension).
        batch_size: Number of samples per batch.

    Returns:
        None. The estimate is printed to stdout.
    """
    # Preserved side effect: the model is left in eval mode after this call.
    model.eval()

    # Use each parameter's real element size so that half/double precision
    # models are not reported as if they were float32.
    param_memory = sum(p.numel() * p.element_size() for p in model.parameters()) / 1e9  # GB

    # Size the input arithmetically instead of allocating a throw-away tensor.
    input_elements = batch_size * torch.Size(input_shape).numel()
    input_memory = input_elements * 4 / 1e9  # GB, counting 4 bytes (float32) per element

    print(f"Estimated memory usage:")
    print(f"  Parameters: {param_memory:.2f} GB")
    print(f"  Input (batch={batch_size}): {input_memory:.2f} GB")
    print(f"  Activation estimate: ~{param_memory * 2:.2f} GB")
    print(f"  Total estimate: ~{param_memory * 3 + input_memory:.2f} GB")

# Print the rough estimate for a batch of 8 samples of shape (3, 128, 128, 9)
estimate_memory_usage(model, (3, 128, 128, 9), batch_size=8)

In [None]:
import os
import torch
from model import LandsatLSTPredictor
from dataset import LandsatDataModule
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
import wandb
from typing import List, Optional

def _sorted_allowed_months(dataset):
    """Return ``sorted(dataset.allowed_months)``, or ``None`` when ``dataset`` is ``None``."""
    if dataset is None:
        return None
    return sorted(dataset.allowed_months)


def train_landsat_model(
    dataset_root: str = "./Data/Dataset",
    batch_size: int = 4,
    max_epochs: int = 100,
    learning_rate: float = 1e-4,
    num_workers: int = 4,
    gpus: int = 1,
    precision: int = 16,  
    accumulate_grad_batches: int = 1,
    val_check_interval: float = 1.0,
    limit_train_batches: float = 1.0,
    limit_val_batches: float = 1.0,
    experiment_name: str = "landsat_lst_prediction",
    checkpoint_dir: str = "./checkpoints",
    log_dir: str = "./logs",
    wandb_project: str = "landsat-lst-forecasting",
    wandb_tags: Optional[List[str]] = None,
    # Year-based split parameters
    train_years: Optional[List[int]] = None,
    val_years: Optional[List[int]] = None,
    test_years: Optional[List[int]] = None,
    use_custom_years: bool = False,
    # New debug monthly split parameters
    debug_monthly_split: bool = False,
    debug_year: int = 2014,
    input_sequence_length: int = 3,
    output_sequence_length: int = 1
):
    """
    Complete training pipeline for Landsat LST prediction with year-based or debug monthly splits

    Steps: verify the tiled dataset directories, build and smoke-test the data
    module, build the Weights & Biases logger and the model, run a forward pass
    on one sample batch (when one is available), then fit and test with a
    PyTorch Lightning ``Trainer``.

    Args:
        dataset_root: Path to preprocessed dataset with Cities_Tiles and DEM_2014_Tiles
        batch_size: Training batch size
        max_epochs: Maximum training epochs
        learning_rate: Initial learning rate
        num_workers: Number of data loading workers
        gpus: Number of GPUs to use; 0 selects the CPU accelerator
        precision: Training precision, forwarded unchanged to ``pl.Trainer``
            (callers in this file pass ``16`` or ``"32"``)
        accumulate_grad_batches: Gradient accumulation steps
        val_check_interval: Validation frequency
        limit_train_batches: Fraction of training data to use (for debugging)
        limit_val_batches: Fraction of validation data to use (for debugging)
        experiment_name: Name for logging
        checkpoint_dir: Directory to save checkpoints
        log_dir: Directory for logs
        wandb_project: Weights & Biases project name
        wandb_tags: List of tags for the experiment (defaults depend on the split type)
        train_years: Years to use for training (if None, uses default 70/15/15 split)
        val_years: Years to use for validation
        test_years: Years to use for testing
        use_custom_years: Whether to use custom year splits in experiment name
        debug_monthly_split: If True, use monthly splits within debug_year for fast debugging
        debug_year: Year to use for debug monthly splits (default: 2014)
        input_sequence_length: Number of input timesteps; forwarded to the data
            module and the model and logged as the first ``input_shape`` dimension
        output_sequence_length: Number of predicted timesteps; forwarded to the data
            module and the model and logged as the first ``target_shape`` dimension

    Returns:
        Tuple ``(trainer, model, data_module)``. It is also returned when
        training is interrupted with Ctrl-C.

    Raises:
        FileNotFoundError: If ``Cities_Tiles`` or ``DEM_2014_Tiles`` is missing
            under ``dataset_root``.
        Exception: Any error from the data module check, the model smoke test
            or ``trainer.fit`` is printed and re-raised.
    """
    
    # Set up default tags
    if wandb_tags is None:
        if debug_monthly_split:
            wandb_tags = ["landsat", "lst-prediction", "earthformer", "debug-monthly-split"]
        else:
            wandb_tags = ["landsat", "lst-prediction", "earthformer", "year-based-split"]
    
    # Create directories
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    
    # Verify tiled dataset exists
    cities_tiles = os.path.join(dataset_root, "Cities_Tiles")
    dem_tiles = os.path.join(dataset_root, "DEM_2014_Tiles")
    
    if not os.path.exists(cities_tiles):
        raise FileNotFoundError(f"Cities_Tiles directory not found at {cities_tiles}. Please run convert_to_tiles() first.")
    if not os.path.exists(dem_tiles):
        raise FileNotFoundError(f"DEM_2014_Tiles directory not found at {dem_tiles}. Please run convert_to_tiles() first.")
    
    print(f"✅ Found tiled dataset at {dataset_root}")
    
    # Initialize data module with year-based or debug monthly splits
    data_module = LandsatDataModule(
        dataset_root=dataset_root,
        batch_size=batch_size,
        num_workers=num_workers,
        input_sequence_length=input_sequence_length,
        output_sequence_length=output_sequence_length,
        train_years=train_years,
        val_years=val_years,
        test_years=test_years,
        debug_monthly_split=debug_monthly_split,
        debug_year=debug_year
    )
    
    # Test data module setup to catch issues early
    if debug_monthly_split:
        print(f"Testing data module setup with debug monthly splits (year {debug_year})...")
    else:
        print("Testing data module setup with year-based splits...")
    
    # Filled in by the check below; sample_inputs stays None when the training
    # loader yields no batches, which later skips the model smoke test.
    sample_inputs = None
    train_months = val_months = test_months = None
    train_split_years = val_split_years = test_split_years = None
        
    try:
        data_module.setup("fit")
        train_loader = data_module.train_dataloader()
        val_loader = data_module.val_dataloader()
        
        print(f"✅ Training batches: {len(train_loader)}")
        print(f"✅ Validation batches: {len(val_loader)}")
        
        # Print split information
        if debug_monthly_split:
            train_months = sorted(data_module.train_dataset.allowed_months)
            val_months = sorted(data_module.val_dataset.allowed_months)
            # After setup("fit") the test dataset may be absent or still None;
            # both cases are reported as "not loaded" instead of crashing.
            test_months = _sorted_allowed_months(getattr(data_module, 'test_dataset', None))
            print(f"✅ Debug year: {debug_year}")
            print(f"✅ Training months: {train_months}")
            print(f"✅ Validation months: {val_months}")
            print(f"✅ Test months: {test_months if test_months is not None else 'Not loaded'}")
        else:
            # All three year lists are read from the training dataset.
            train_split_years = sorted(data_module.train_dataset.train_years)
            val_split_years = sorted(data_module.train_dataset.val_years)
            test_split_years = sorted(data_module.train_dataset.test_years)
            print(f"✅ Training years: {train_split_years}")
            print(f"✅ Validation years: {val_split_years}")
            print(f"✅ Test years: {test_split_years}")
        
        # Test one batch
        if len(train_loader) > 0:
            sample_batch = next(iter(train_loader))
            sample_inputs, sample_targets = sample_batch
            print(f"✅ Sample batch - Inputs: {sample_inputs.shape}, Targets: {sample_targets.shape}")
            
            # Show sample sequence information
            sample_seq = data_module.train_dataset.tile_sequences[0]
            city, tile_row, tile_col, input_months, output_months = sample_seq
            print(f"✅ Sample sequence: {city} tile({tile_row:03d},{tile_col:03d})")
            print(f"   Input months: {input_months}")
            print(f"   Output months: {output_months}")
        else:
            print("⚠️ No training batches found!")
        
    except Exception as e:
        print(f"❌ Data module test failed: {e}")
        raise
    
    # Create comprehensive config for wandb including split information
    config = {
        "batch_size": batch_size,
        "max_epochs": max_epochs,
        "learning_rate": learning_rate,
        "num_workers": num_workers,
        "precision": precision,
        "accumulate_grad_batches": accumulate_grad_batches,
        "val_check_interval": val_check_interval,
        "limit_train_batches": limit_train_batches,
        "limit_val_batches": limit_val_batches,
        "dataset_root": dataset_root,
        "model_type": "CuboidTransformer",
        "input_shape": [input_sequence_length, 128, 128, 9],
        "target_shape": [output_sequence_length, 128, 128, 1],
        "input_sequence_length": input_sequence_length,
        "output_sequence_length": output_sequence_length,
        "train_batches": len(train_loader),
        "val_batches": len(val_loader),
        # Upper bounds: the last batch of a loader may be smaller than batch_size.
        "total_train_samples": len(train_loader) * batch_size,
        "total_val_samples": len(val_loader) * batch_size,
        "debug_monthly_split": debug_monthly_split,
    }
    
    # Add split-specific configuration
    if debug_monthly_split:
        config.update({
            "split_type": "debug_monthly",
            "debug_year": debug_year,
            "train_months": train_months,
            "val_months": val_months,
            "test_months": test_months if test_months is not None else [],
            "temporal_coverage": f"{debug_year} (monthly splits)",
        })
        
        # Add debug info to experiment name
        experiment_name = f"{experiment_name}_debug_monthly_{debug_year}"
        
    else:
        config.update({
            "split_type": "year_based",
            "train_years": train_split_years,
            "val_years": val_split_years,
            "test_years": test_split_years,
            "train_year_range": f"{min(train_split_years)}-{max(train_split_years)}",
            "val_year_range": f"{min(val_split_years)}-{max(val_split_years)}",
            "test_year_range": f"{min(test_split_years)}-{max(test_split_years)}",
            "temporal_coverage": f"{min(train_split_years)}-{max(test_split_years)}",
        })
        
        # Add year split info to experiment name if using custom years
        if use_custom_years and train_years is not None:
            train_range = f"{min(train_years)}-{max(train_years)}"
            val_range = f"{min(val_years)}-{max(val_years)}" if val_years else "auto"
            experiment_name = f"{experiment_name}_train{train_range}_val{val_range}"
    
    # Initialize Weights & Biases logger
    logger = WandbLogger(
        project=wandb_project,
        name=experiment_name,
        tags=wandb_tags,
        config=config,
        save_dir=log_dir,
        log_model=True,  # Log model checkpoints to wandb
    )
    
    # Initialize model
    model = LandsatLSTPredictor(
        learning_rate=learning_rate,
        weight_decay=1e-5,
        warmup_steps=1000,
        max_epochs=max_epochs,
        input_sequence_length=input_sequence_length,
        output_sequence_length=output_sequence_length
    )
    
    # Test model with sample data. Without a sample batch there is nothing to
    # feed the model, so the check is skipped rather than failing on an
    # undefined variable.
    if sample_inputs is None:
        print("⚠️ Skipping model test: no sample batch available")
    else:
        print("Testing model with sample data...")
        try:
            model.eval()
            with torch.no_grad():
                test_output = model(sample_inputs)
                print(f"✅ Model test - Output shape: {test_output.shape}")
        except Exception as e:
            print(f"❌ Model test failed: {e}")
            raise
    
    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=f'{experiment_name}-{{epoch:02d}}-{{val_loss:.3f}}',
        save_top_k=3,
        monitor='val_loss',
        mode='min',
        save_last=True,
        verbose=True
    )
    
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=15 if not debug_monthly_split else 10,  # Shorter patience for debug
        mode='min',
        verbose=True
    )
    
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    
    # Trainer configuration
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='gpu' if gpus > 0 else 'cpu',
        # An explicit single CPU device avoids relying on devices=None, which
        # newer Lightning releases reject for the CPU accelerator.
        devices=gpus if gpus > 0 else 1,
        precision=precision,
        accumulate_grad_batches=accumulate_grad_batches,
        val_check_interval=val_check_interval,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        callbacks=[checkpoint_callback, early_stopping, lr_monitor],
        logger=logger,
        log_every_n_steps=50,
        enable_progress_bar=True,
        enable_model_summary=True,
        deterministic=False,
        benchmark=True,
    )
    
    # Print comprehensive training info
    print(f"\n{'='*80}")
    if debug_monthly_split:
        print(f"LANDSAT LST PREDICTION TRAINING - DEBUG MONTHLY SPLITS")
    else:
        print(f"LANDSAT LST PREDICTION TRAINING - YEAR-BASED SPLITS")
    print(f"{'='*80}")
    print(f"Dataset: {dataset_root}")
    print(f"  - Cities Tiles: {cities_tiles}")
    print(f"  - DEM Tiles: {dem_tiles}")
    
    if debug_monthly_split:
        print(f"Debug Monthly Split Configuration (Year {debug_year}):")
        print(f"  - Training months: {train_months} (Jan-Aug)")
        print(f"  - Validation months: {val_months} (Jun-Oct)")  
        print(f"  - Test months: {test_months if test_months is not None else 'Not loaded'} (Aug-Dec)")
        print(f"  - Overlap explanation: Months overlap to ensure sequence continuity")
    else:
        print(f"Temporal Split Configuration:")
        print(f"  - Training years: {train_split_years} ({len(train_split_years)} years)")
        print(f"  - Validation years: {val_split_years} ({len(val_split_years)} years)")
        print(f"  - Test years: {test_split_years} ({len(test_split_years)} years)")
    
    print(f"Training Configuration:")
    print(f"  - Batch size: {batch_size}")
    print(f"  - Max epochs: {max_epochs}")
    print(f"  - Learning rate: {learning_rate}")
    print(f"  - Precision: {precision}")
    print(f"  - Devices: {gpus} GPU(s)" if gpus > 0 else "  - Device: CPU")
    print(f"  - Num workers: {num_workers}")
    print(f"Dataset Statistics:")
    print(f"  - Training batches: {len(train_loader)} ({len(train_loader) * batch_size} samples)")
    print(f"  - Validation batches: {len(val_loader)} ({len(val_loader) * batch_size} samples)")
    print(f"  - Data limits: {limit_train_batches*100:.0f}% train, {limit_val_batches*100:.0f}% val")
    print(f"Logging:")
    print(f"  - Experiment: {experiment_name}")
    print(f"  - Checkpoints: {checkpoint_dir}")
    print(f"  - Logs: {log_dir}")
    print(f"  - Wandb project: {wandb_project}")
    print(f"  - Wandb tags: {wandb_tags}")
    print(f"{'='*80}\n")
    
    # Train the model
    try:
        if debug_monthly_split:
            print(f"🚀 Starting debug training with monthly splits (year {debug_year})...")
        else:
            print("🚀 Starting training with year-based temporal splits...")
            
        trainer.fit(model, data_module)
        
        # Test the model if we have test data. A failing test run is reported
        # but deliberately does not fail the whole training call.
        print("\n🧪 Running final test...")
        try:
            test_results = trainer.test(model, data_module, ckpt_path='best')
            print(f"✅ Test completed: {test_results}")
        except Exception as e:
            print(f"⚠️ Test failed (this is okay if no test data): {e}")
        
        print(f"\n🎉 Training completed successfully!")
        print(f"📁 Best model saved to: {checkpoint_callback.best_model_path}")
        print(f"🔗 View experiment at: {logger.experiment.url}")
        
        # Log final artifacts to wandb
        if checkpoint_callback.best_model_path:
            wandb.save(checkpoint_callback.best_model_path)
        
    except KeyboardInterrupt:
        # A manual interrupt is not treated as a failure: report and return.
        print("\n⚠️ Training interrupted by user")
        print(f"📁 Last checkpoint saved to: {checkpoint_callback.last_model_path}")
        
    except Exception as e:
        print(f"\n❌ Training failed with error: {e}")
        import traceback
        traceback.print_exc()
        
        # Only report to wandb when a run is actually active; logging without
        # one raises and would replace the real training error.
        if wandb.run is not None:
            wandb.log({"error": str(e)})
        
        raise
    
    finally:
        # Ensure wandb run is finished
        wandb.finish()
    
    return trainer, model, data_module


# ================================================================================
# DEBUG TRAINING FUNCTIONS (using monthly splits)
# ================================================================================

def debug_monthly_training(dataset_root: str = "./Data/Dataset", debug_year: int = 2014):
    """Short training run on the monthly split of a single year.

    Calls ``train_landsat_model`` with a fixed set of quick-iteration settings
    (5 epochs, 15% of the train/val batches, monthly split of ``debug_year``)
    and returns its ``(trainer, model, data_module)`` result.
    """
    print(f"🔧 Running debug training with monthly splits (year {debug_year})...")

    run_settings = dict(
        dataset_root=dataset_root,
        batch_size=16,
        max_epochs=5,                 # only a few epochs for a quick check
        learning_rate=1e-3,
        num_workers=0,                # no worker processes while debugging
        gpus=1,
        precision=16,
        limit_train_batches=0.15,
        limit_val_batches=0.15,
        experiment_name="debug_monthly_split",
        val_check_interval=0.5,
        wandb_project="landsat-debug",
        wandb_tags=["debug", "monthly-split", f"year-{debug_year}"],
        debug_monthly_split=True,
        debug_year=debug_year,
    )
    trainer, model, data_module = train_landsat_model(**run_settings)

    print("✅ Debug monthly training completed!")
    return trainer, model, data_module


def debug_monthly_minimal(dataset_root: str = "./Data/Dataset"):
    """Smallest possible monthly-split run, used to check basic functionality.

    Runs ``train_landsat_model`` on CPU with batch size 1 for 2 epochs on the
    2014 monthly split and returns its ``(trainer, model, data_module)`` result.
    """
    print("🔧 Running minimal debug with monthly splits...")

    minimal_settings = {
        "dataset_root": dataset_root,
        "batch_size": 1,
        "max_epochs": 2,
        "learning_rate": 1e-3,
        "num_workers": 0,
        "gpus": 0,                    # CPU only, for maximum compatibility
        "precision": "32",
        "limit_train_batches": 0.15,
        "limit_val_batches": 0.15,
        "experiment_name": "minimal_monthly_debug",
        "wandb_project": "landsat-debug",
        "wandb_tags": ["minimal", "debug", "monthly", "cpu"],
        "debug_monthly_split": True,
        "debug_year": 2014,
    }
    trainer, model, data_module = train_landsat_model(**minimal_settings)

    print("✅ Minimal monthly debug completed!")
    return trainer, model, data_module


def debug_monthly_different_years(dataset_root: str = "./Data/Dataset"):
    """Run a short monthly-split training for 2014, 2015 and 2016 in turn.

    A failure in one year is printed and recorded, and the loop moves on to the
    next year. Returns a list with one ``(year, trainer, model, data_module)``
    tuple per year; the last three entries are ``None`` for a failed year.
    """
    outcomes = []

    for year in (2014, 2015, 2016):
        print(f"\n🔧 Testing monthly splits for year {year}...")

        year_settings = dict(
            dataset_root=dataset_root,
            batch_size=2,
            max_epochs=3,
            learning_rate=1e-3,
            num_workers=0,
            gpus=1,
            precision=16,
            experiment_name=f"debug_monthly_{year}",
            wandb_project="landsat-debug",
            wandb_tags=["debug", "monthly", f"year-{year}", "comparison"],
            debug_monthly_split=True,
            debug_year=year,
        )

        try:
            trainer, model, data_module = train_landsat_model(**year_settings)
            outcomes.append((year, trainer, model, data_module))
            print(f"✅ Year {year} completed successfully!")
        except Exception as err:
            print(f"❌ Year {year} failed: {err}")
            outcomes.append((year, None, None, None))

    # Summary: a missing trainer marks a failed year.
    print("\n📊 Multi-year debug results:")
    for year, year_trainer, _, _ in outcomes:
        if year_trainer is None:
            print(f"  {year}: FAILED")
        else:
            print(f"  {year}: SUCCESS")

    return outcomes


# ================================================================================
# REGULAR TRAINING FUNCTIONS (using year-based splits)
# ================================================================================

def debug_training(dataset_root: str = "./Data/Dataset"):
    """Short year-based-split run with small, conservative settings.

    Calls ``train_landsat_model`` for 3 epochs with batch size 2 on 10% of the
    train/val batches and returns its ``(trainer, model, data_module)`` result.
    """
    print("🔧 Running debug training with year-based splits...")

    quick_settings = {
        "dataset_root": dataset_root,
        "batch_size": 2,
        "max_epochs": 3,
        "learning_rate": 1e-3,
        "num_workers": 0,
        "gpus": 1,
        "precision": 16,
        "limit_train_batches": 0.1,
        "limit_val_batches": 0.1,
        "experiment_name": "debug_year_based_split",
        "val_check_interval": 0.5,
        "wandb_project": "landsat-debug",
        "wandb_tags": ["debug", "year-based", "quick-test"],
    }
    trainer, model, data_module = train_landsat_model(**quick_settings)

    print("✅ Debug training completed!")
    return trainer, model, data_module


def debug_with_enhanced_logging(dataset_root: str = "./Data/Dataset"):
    """Year-based-split debug run on a quarter of the train/val batches.

    Calls ``train_landsat_model`` for 5 epochs with batch size 4 and two data
    loading workers, and returns its ``(trainer, model, data_module)`` result.
    """
    print("🔧 Running enhanced debug training with year-based splits...")

    enhanced_settings = dict(
        dataset_root=dataset_root,
        batch_size=4,
        max_epochs=5,
        learning_rate=1e-3,
        num_workers=2,
        gpus=1,
        precision=16,
        limit_train_batches=0.25,
        limit_val_batches=0.25,
        experiment_name="enhanced_debug_year_split",
        val_check_interval=0.5,
        wandb_project="landsat-debug",
        wandb_tags=["enhanced-debug", "year-based", "realistic-test"],
    )
    trainer, model, data_module = train_landsat_model(**enhanced_settings)

    print("✅ Enhanced debug training completed!")
    return trainer, model, data_module


def full_training_gpu(dataset_root: str = "./Data/Dataset"):
    """Full-length single-GPU training run using the default year splits.

    Calls ``train_landsat_model`` for up to 50 epochs with batch size 8 and no
    batch limits, and returns whatever that call returns.
    """
    print("🚀 Starting full GPU training with year-based splits...")

    production_settings = {
        "dataset_root": dataset_root,
        "batch_size": 8,
        "max_epochs": 50,
        "learning_rate": 2e-4,
        "num_workers": 4,
        "gpus": 1,
        "precision": 16,
        "experiment_name": "landsat_full_training_year_split",
        "val_check_interval": 1.0,
        "wandb_project": "landsat-lst-forecasting",
        "wandb_tags": ["full-training", "production", "earthformer", "gpu", "year-based-split"],
    }
    training_result = train_landsat_model(**production_settings)
    return training_result


def custom_year_training(
    dataset_root: str = "./Data/Dataset",
    train_years: List[int] = None,
    val_years: List[int] = None,
    test_years: List[int] = None
):
    """Train with explicitly chosen train/validation/test years.

    Any year list left as ``None`` falls back to the built-in default
    (train 2013-2017, validation 2022-2023, test 2024-2025). The chosen years
    are printed and passed to ``train_landsat_model`` with
    ``use_custom_years=True``; that call's result is returned.
    """
    # Fill in the default timeline for every split the caller did not specify.
    train_years = [2013, 2014, 2015, 2016, 2017] if train_years is None else train_years
    val_years = [2022, 2023] if val_years is None else val_years
    test_years = [2024, 2025] if test_years is None else test_years

    print(f"🚀 Starting training with custom year splits...")
    print(f"   Training: {train_years}")
    print(f"   Validation: {val_years}")
    print(f"   Test: {test_years}")

    custom_settings = dict(
        dataset_root=dataset_root,
        batch_size=6,
        max_epochs=40,
        learning_rate=1e-4,
        num_workers=4,
        gpus=1,
        precision=16,
        experiment_name="landsat_custom_year_split",
        val_check_interval=1.0,
        wandb_project="landsat-lst-forecasting",
        wandb_tags=["custom-years", "research-timeline", "earthformer"],
        train_years=train_years,
        val_years=val_years,
        test_years=test_years,
        use_custom_years=True,
    )
    return train_landsat_model(**custom_settings)

# NOTE: debug_monthly_training is already defined earlier in this file. Because
# this definition comes later, it is the one in effect (including for the call
# under the __main__ guard). It is kept in agreement with the earlier definition
# so that the duplication cannot silently change behaviour; one of the two
# should be deleted.
def debug_monthly_training(dataset_root: str = "./Data/Dataset", debug_year: int = 2014):
    """Quick debug run with monthly splits for rapid prototyping.

    Args:
        dataset_root: Path to the tiled dataset, forwarded to ``train_landsat_model``.
        debug_year: Year whose monthly split is used.

    Returns:
        The ``(trainer, model, data_module)`` tuple from ``train_landsat_model``.
    """
    print(f"🔧 Running debug training with monthly splits (year {debug_year})...")
    
    trainer, model, data_module = train_landsat_model(
        dataset_root=dataset_root,
        batch_size=16,
        max_epochs=5,  # Few epochs for quick testing
        learning_rate=1e-3,
        num_workers=0,  # Disable multiprocessing for debugging
        gpus=1,
        precision=16,
        # Same 15% batch limits as the earlier definition: without them this
        # "quick" run would iterate over the full train/val loaders.
        limit_train_batches=0.15,
        limit_val_batches=0.15,
        experiment_name="debug_monthly_split",
        val_check_interval=0.5,
        wandb_project="landsat-debug",
        wandb_tags=["debug", "monthly-split", f"year-{debug_year}"],
        debug_monthly_split=True,
        debug_year=debug_year
    )
    
    print("✅ Debug monthly training completed!")
    return trainer, model, data_module

# Entry point: runs the selected experiment. Exactly one call below is active;
# the others are kept commented out as alternatives.
if __name__ == "__main__":    
    # debug_monthly_minimal()
    debug_monthly_training()
    # NOTE(review): research_timeline_training is not defined in the visible
    # cells; custom_year_training() may be what was meant — confirm before enabling.
    # research_timeline_training()
    # full_training_gpu()

In [None]:
# Environment report: library versions and GPU availability.
import torch

cuda_ready = torch.cuda.is_available()
amp_present = hasattr(torch.cuda, 'amp')

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
print(f"AMP available: {amp_present}")
if cuda_ready:
    # Only query the device name when a CUDA device is actually usable.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

import pytorch_lightning as pl
print(f"PyTorch Lightning version: {pl.__version__}")