In [None]:
import sys
import os
sys.path.append('../src')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import pytorch_lightning as L
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

from data_utils import (
    EOBSDataLoader, 
    EOBSTemporalPredictionDataset, 
    EOBSMaskedModelingDataset,
    get_device
)
from models import (
    TemporalPredictionModel, 
    MaskedModelingModel, 
    TemporalPredictionLightningModule,
    MaskedModelingLightningModule
)

print("🌍 E-OBS Self-Supervised Learning with PyTorch Lightning")
print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Lightning version: {L.__version__}")

# Seed the torch, NumPy and Python `random` RNGs so model initialisation and
# DataLoader shuffling are reproducible under Restart & Run All.
# workers=True additionally gives DataLoaders passed to the Trainer a seeded
# worker_init_fn, so multi-process data loading is reproducible too.
SEED = 42
L.seed_everything(SEED, workers=True)

# Detect best available device
device = get_device()
print(f"Selected device: {device}")


In [None]:
# Load E-OBS data
eobs_loader = EOBSDataLoader(data_dir="../src/data")
eobs_data = eobs_loader.load_all_data()

# Only the mean precipitation field is used by the cells below. Each of them
# re-checks this key, so a missing file degrades to error messages instead
# of a NameError on `precip_data`.
if 'precipitation_mean' not in eobs_data:
    print("❌ Precipitation data not found. Please check data files.")
else:
    precip_data = eobs_data['precipitation_mean']
    print("\n📊 Precipitation Data Info:")
    eobs_loader.get_data_info(precip_data)


In [None]:
# Create temporal prediction dataset
if 'precipitation_mean' in eobs_data:
    SEQUENCE_LENGTH = 7      # Use past 7 days
    PREDICTION_HORIZON = 1   # Predict next 1 day
    BATCH_SIZE = 16
    TRAIN_FRACTION = 0.8

    temporal_dataset = EOBSTemporalPredictionDataset(
        precipitation_data=precip_data,
        sequence_length=SEQUENCE_LENGTH,
        prediction_horizon=PREDICTION_HORIZON,
        variable_name='rr',
        spatial_crop_size=(64, 64),  # Crop to 64x64 patches
        normalize=True,
        log_transform=True  # Apply log(1+x) to precipitation
    )

    # Split chronologically instead of with random_split. Samples are sliding
    # windows, so neighbouring indices share most of their days. A random split
    # would put near-duplicates of every validation sample into the training
    # set, which inflates validation scores.
    #
    # A gap of SEQUENCE_LENGTH + PREDICTION_HORIZON - 1 samples is dropped
    # between the two sets. The first validation window then starts after the
    # last day touched by any training window.
    # NOTE(review): assumes one sample per start day, with the sample index
    # increasing with the window's start date. TODO confirm in
    # EOBSTemporalPredictionDataset.
    n_samples = len(temporal_dataset)
    gap = SEQUENCE_LENGTH + PREDICTION_HORIZON - 1
    train_size = int(TRAIN_FRACTION * n_samples)
    val_start = train_size + gap
    val_size = n_samples - val_start

    if train_size == 0 or val_size <= 0:
        raise ValueError(
            f"Dataset too small for a chronological split: {n_samples} samples, "
            f"{train_size} for training and a gap of {gap} leave no validation data"
        )

    train_dataset = torch.utils.data.Subset(temporal_dataset, range(train_size))
    val_dataset = torch.utils.data.Subset(temporal_dataset, range(val_start, n_samples))

    # Create data loaders (shuffling within the training period only)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    print("📈 Temporal Prediction Dataset created:")
    print(f"   - Training samples: {len(train_dataset)}")
    print(f"   - Validation samples: {len(val_dataset)}")
    print(f"   - Samples dropped as train/val gap: {gap}")
    print(f"   - Batch size: {BATCH_SIZE}")
else:
    print("❌ Cannot create temporal dataset without precipitation data")


In [None]:
# Create temporal prediction model and Lightning module
if 'precipitation_mean' not in eobs_data:
    print("❌ Cannot create temporal model without precipitation data")
else:
    # Base network. sequence_length, prediction_horizon and spatial_size mirror
    # the dataset settings (7 input days, 1 predicted day, 64x64 crops) and
    # must be changed together with them.
    temporal_model = TemporalPredictionModel(
        input_channels=1, hidden_channels=64, num_layers=3,
        sequence_length=7, prediction_horizon=1, spatial_size=(64, 64),
    )

    # Wrap the network in its LightningModule, passing learning rate and weight decay
    temporal_lightning_model = TemporalPredictionLightningModule(
        model=temporal_model, learning_rate=1e-3, weight_decay=1e-4,
    )

    print("\n✅ Temporal prediction Lightning module created!")


In [None]:
# Setup Lightning trainer for temporal prediction
if 'precipitation_mean' in eobs_data:
    MAX_EPOCHS = 20  # Reduced for demo
    EARLY_STOPPING_PATIENCE = 10
    GRADIENT_CLIP_VAL = 1.0

    # Use fp16 mixed precision only on CUDA.
    # torch.device(...) normalises whatever get_device() returned (a string
    # such as 'cuda:0' or a torch.device). A plain `device != 'cpu'` check is
    # always True for a torch.device, because it never compares equal to a
    # string. That would request '16-mixed' on CPU, which Lightning does not
    # support. Other backends (e.g. MPS) stay at full precision.
    trainer_precision = '16-mixed' if torch.device(device).type == 'cuda' else '32-true'

    # Create callbacks
    callbacks = [
        # Validation runs twice per epoch (val_check_interval=0.5 below), and
        # EarlyStopping's patience counts validation checks, not epochs. So
        # EARLY_STOPPING_PATIENCE checks is roughly half that many epochs.
        # NOTE(review): assumes TemporalPredictionLightningModule logs 'val_loss'.
        EarlyStopping(
            monitor='val_loss',
            patience=EARLY_STOPPING_PATIENCE,
            verbose=True,
            mode='min'
        ),
        ModelCheckpoint(
            monitor='val_loss',
            dirpath='../src/runs/temporal_prediction_lightning/',
            filename='best-checkpoint',
            save_top_k=1,
            mode='min'
        ),
        LearningRateMonitor(logging_interval='epoch')
    ]

    # Create logger (TensorBoard logs under ../src/runs/temporal_prediction_lightning/)
    logger = TensorBoardLogger(
        save_dir='../src/runs/',
        name='temporal_prediction_lightning'
    )

    # Create trainer
    temporal_trainer = L.Trainer(
        max_epochs=MAX_EPOCHS,
        accelerator='auto',
        devices='auto',
        logger=logger,
        callbacks=callbacks,
        log_every_n_steps=10,
        val_check_interval=0.5,  # Validate twice per epoch
        precision=trainer_precision,  # Mixed precision for faster training on CUDA
        gradient_clip_val=GRADIENT_CLIP_VAL  # Gradient clipping
    )

    print("\n🚀 Lightning trainer configured:")
    print(f"   - Max epochs: {MAX_EPOCHS}")
    print(f"   - Early stopping: patience={EARLY_STOPPING_PATIENCE}")
    print(f"   - Precision: {trainer_precision}")
    print(f"   - Gradient clipping: {GRADIENT_CLIP_VAL}")
else:
    print("❌ Cannot setup trainer without precipitation data")
