# AutoML-Fire: Spatiotemporal U-Net CNN (Notebook-first)

This notebook implements a simple spatiotemporal CNN baseline for fire prediction using a U-Net architecture. The model takes the last T days of meteorological data as input channels and predicts next-day fire counts on a 0.25° grid.

**Key Features:**
- Spatiotemporal CNN with U-Net-like architecture
- Input: Last T days of meteorological variables as channels
- Output: Next-day fire count prediction per grid cell
- Handles missing data with masking
- Supports both static and dynamic features


In [None]:
# Configuration
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
import warnings

# NOTE(review): this silences every warning for the whole session (e.g. NumPy
# empty-slice or PyTorch device fallbacks); consider narrowing it when debugging.
warnings.filterwarnings('ignore')


@dataclass
class Config:
    """Configuration for Spatiotemporal U-Net CNN.

    List-valued fields use ``field(default_factory=...)``: dataclasses reject
    plain list defaults (``ValueError: mutable default``), and every instance
    needs its own list so editing one config cannot leak into another.
    """
    # Data paths
    data_zarr_or_nc_path: str = "data/spatiotemporal_data.zarr"  # Placeholder - update with your path

    # Model parameters
    T: int = 7  # Number of time steps to look back
    seed: int = 42

    # Training/validation date ranges ("YYYY-MM-DD")
    train_start: str = "2020-01-01"
    train_end: str = "2022-12-31"
    val_start: str = "2023-01-01"
    val_end: str = "2023-12-31"

    # Variables (dynamic inputs, one series per variable)
    variables: List[str] = field(default_factory=lambda: [
        "tmin", "tmax", "humidity", "windspeed",
        "soil_moisture", "ndvi", "rain", "cloudcover",
    ])
    # Optional static inputs; use an empty list to disable
    static_variables: List[str] = field(default_factory=lambda: [
        "elevation", "slope", "aspect", "landcover",
    ])

    # Training parameters
    batch_size: int = 32
    learning_rate: float = 1e-3
    n_epochs: int = 100
    patience: int = 10  # epochs without validation improvement before stopping

    # Model architecture
    n_filters: int = 64  # channels at the first U-Net level (doubles per level)
    n_depths: int = 3
    dropout_rate: float = 0.2

# Initialize configuration
CFG = Config()
print(f"Configuration loaded: T={CFG.T}, Variables={len(CFG.variables)}")
print(f"Training period: {CFG.train_start} to {CFG.train_end}")
print(f"Validation period: {CFG.val_start} to {CFG.val_end}")


In [None]:
# Imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchmetrics

# Try optional imports
try:
    import xarray as xr
    XARRAY_AVAILABLE = True
    print("âœ“ xarray available")
except ImportError:
    XARRAY_AVAILABLE = False
    print("âš  xarray not available")

try:
    import zarr
    ZARR_AVAILABLE = True
    print("âœ“ zarr available")
except ImportError:
    ZARR_AVAILABLE = False
    print("âš  zarr not available")

try:
    import dask
    DASK_AVAILABLE = True
    print("âœ“ dask available")
except ImportError:
    DASK_AVAILABLE = False
    print("âš  dask not available")

# Set random seeds
torch.manual_seed(CFG.seed)
np.random.seed(CFG.seed)

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Mixed precision support. ``torch.cuda.amp`` targets CUDA devices, so only
# report (and later enable) mixed precision when a GPU is actually in use;
# on CPU the training loop then takes the plain fp32 path (scaler is None).
try:
    from torch.cuda.amp import autocast, GradScaler
    MIXED_PRECISION = device.type == 'cuda'
except ImportError:
    MIXED_PRECISION = False

if MIXED_PRECISION:
    print("âœ“ Mixed precision available")
else:
    print("âš  Mixed precision not available")


In [None]:
# Data Loader
class SpatiotemporalDataset(Dataset):
    """Per-grid-cell samples drawn from spatiotemporal arrays.

    Item ``idx`` is the spatial cell ``(idx // W, idx % W)`` (row-major), so
    the dataset has ``H * W`` items.

    Args:
        data_dict: ``{variable: array[T, H, W, C]}``. For each cell the
            ``[T, C]`` slice is flattened (time-major) into a float tensor of
            length ``T*C`` stored under the variable name.
        target_dict: ``{target: array[H, W]}``. The cell value is returned as
            a float tensor of shape ``[1]`` under the target name.
        mask_dict: Optional ``{variable: bool array[T, H, W]}``; the cell's
            mask is flattened into a bool tensor stored under
            ``'<variable>_mask'``.
        transform: Optional callable applied to the finished sample dict.

    ``T``, ``H`` and ``W`` are read from the first array in ``data_dict``.

    NOTE(review): samples are per-cell vectors, not ``[C, H, W]`` images, so
    they do not match the U-Net's input layout - TODO confirm intended use.
    """

    def __init__(self, data_dict, target_dict, mask_dict=None, transform=None):
        self.data_dict = data_dict
        self.target_dict = target_dict
        self.mask_dict = mask_dict if mask_dict else {}
        self.transform = transform

        reference = next(iter(data_dict.values()))
        self.T, self.H, self.W = reference.shape[:3]
        self.n_samples = self.H * self.W

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        row = idx // self.W
        col = idx % self.W

        # Features first, then targets, then masks (later keys win on clashes).
        sample = {
            name: torch.FloatTensor(values[:, row, col, :].reshape(-1))
            for name, values in self.data_dict.items()
        }
        sample.update(
            (name, torch.FloatTensor([values[row, col]]))
            for name, values in self.target_dict.items()
        )
        sample.update(
            (f'{name}_mask', torch.BoolTensor(mask[:, row, col].reshape(-1)))
            for name, mask in self.mask_dict.items()
        )

        return self.transform(sample) if self.transform else sample

def load_spatiotemporal_data(data_path, variables, static_variables=None, *,
                             n_days=None, grid_shape=(180, 360),
                             start_date="2020-01-01", end_date="2023-12-31"):
    """Load spatiotemporal data from a zarr/netCDF file (currently dummy data).

    NOTE: placeholder implementation. ``data_path`` is only printed; random
    arrays are generated so the rest of the notebook can run end to end.

    Args:
        data_path: Path to the zarr/netCDF store (not read yet).
        variables: Names of dynamic variables. Each becomes a float32 array of
            shape ``[n_days, H, W, 1]`` drawn from a standard normal.
        static_variables: Optional names of static variables. Each is ONE
            random ``[H, W, 1]`` field broadcast over all days (a read-only
            ``[n_days, H, W, 1]`` view), so it is really constant in time and
            does not cost ``n_days`` copies of memory.
        n_days: Number of daily steps. ``None`` (default) spans
            ``start_date``..``end_date`` inclusive.
        grid_shape: ``(H, W)`` of the dummy grid. The default 180x360 is a
            1-degree global grid (a true 0.25-degree grid is 720x1440).
        start_date: First date of the daily index.
        end_date: Last date (inclusive), used only when ``n_days`` is None.

    Returns:
        ``(data_dict, target_dict, dates)``: ``target_dict['fire_count']`` has
        shape ``[n_days, H, W]`` (Poisson(0.1) counts as float32) and
        ``dates`` is a daily ``DatetimeIndex`` of length ``n_days``.
    """
    print(f"Loading data from: {data_path}")

    # This is a placeholder - adapt based on your data format.
    # For now, create dummy data for demonstration.
    print("âš  Using dummy data - replace with actual data loading")

    H, W = grid_shape

    # Daily index. The default 2020-01-01..2023-12-31 span covers both the
    # training (2020-2022) and validation (2023) windows of Config; a fixed
    # 1000 days would end in September 2022 and leave validation empty.
    if n_days is None:
        dates = pd.date_range(start_date, end_date, freq='D')
    else:
        dates = pd.date_range(start_date, periods=n_days, freq='D')
    n_days = len(dates)

    # Dynamic variables: independent N(0, 1) noise for every day and cell.
    data_dict = {
        var: np.random.randn(n_days, H, W, 1).astype(np.float32)
        for var in variables
    }

    # Static variables: one random field broadcast over time (read-only view).
    for var in static_variables or ():
        static_field = np.random.randn(1, H, W, 1).astype(np.float32)
        data_dict[var] = np.broadcast_to(static_field, (n_days, H, W, 1))

    # Dummy target (fire counts)
    target_dict = {
        'fire_count': np.random.poisson(0.1, (n_days, H, W)).astype(np.float32)
    }

    print(f"Data shape: {n_days} days, {H}x{W} grid")
    print(f"Variables: {list(data_dict.keys())}")
    print(f"Target: {list(target_dict.keys())}")

    return data_dict, target_dict, dates

# Load the raw arrays and their daily date index (arguments passed by keyword).
data_dict, target_dict, dates = load_spatiotemporal_data(
    data_path=CFG.data_zarr_or_nc_path,
    variables=CFG.variables,
    static_variables=CFG.static_variables,
)


In [None]:
# Train/Validation Split and Normalization
print("Creating train/validation split...")

# Convert dates to datetime
dates = pd.to_datetime(dates)

# Boolean day masks; both window bounds are inclusive.
train_mask = (dates >= CFG.train_start) & (dates <= CFG.train_end)
val_mask = (dates >= CFG.val_start) & (dates <= CFG.val_end)

# The scalers below are fitted on training days only. An empty training window
# would make every mean/std NaN (and warnings are globally ignored in this
# notebook), so stop with a clear message instead.
if not train_mask.any():
    raise ValueError(
        f"No dates fall in the training window {CFG.train_start}..{CFG.train_end} "
        f"(data covers {dates.min()} to {dates.max()})"
    )

print(f"Training period: {dates[train_mask].min()} to {dates[train_mask].max()} ({train_mask.sum()} days)")
if val_mask.any():
    print(f"Validation period: {dates[val_mask].min()} to {dates[val_mask].max()} ({val_mask.sum()} days)")
else:
    print(f"WARNING: validation window {CFG.val_start}..{CFG.val_end} contains no data "
          f"(data covers {dates.min()} to {dates.max()})")

# Split data
train_data = {var: data[train_mask] for var, data in data_dict.items()}
val_data = {var: data[val_mask] for var, data in data_dict.items()}

train_target = {target: data[train_mask] for target, data in target_dict.items()}
val_target = {target: data[val_mask] for target, data in target_dict.items()}

# Normalization using training statistics
print("Computing normalization statistics...")
scalers = {}
for var in CFG.variables + CFG.static_variables:
    if var not in train_data:
        continue
    # Per-channel mean/std pooled over every axis except the last (channel).
    var_data = train_data[var].reshape(-1, train_data[var].shape[-1])
    scalers[var] = {
        'mean': np.mean(var_data, axis=0),
        'std': np.std(var_data, axis=0) + 1e-8,  # epsilon guards constant channels
    }
    print(f"  {var}: mean={scalers[var]['mean'][0]:.3f}, std={scalers[var]['std'][0]:.3f}")
# Normalize data
def normalize_data(data_dict, scalers):
    """Return a new dict with every scaled variable standardized.

    Variables that have an entry in ``scalers`` become
    ``(data - mean) / std``; all others are passed through unchanged (the
    very same array object, not a copy). Key order follows ``data_dict``.
    """
    return {
        name: (array - scalers[name]['mean']) / scalers[name]['std']
        if name in scalers else array
        for name, array in data_dict.items()
    }

# Standardize both splits with the statistics fitted on the training split.
train_data_norm, val_data_norm = (
    normalize_data(split, scalers) for split in (train_data, val_data)
)

print("âœ“ Data normalized using training statistics")


In [None]:
# U-Net Model Architecture
class SpatiotemporalUNet(nn.Module):
    """Simple U-Net for spatiotemporal fire prediction.

    Maps a ``[B, input_channels, H, W]`` tensor (time steps and variables
    stacked as channels) to ``[B, 1, H, W]`` positive predictions.

    Encoder level ``i`` has ``n_filters * 2**i`` channels; every level except
    the deepest is kept as a skip connection and then 2x max-pooled. Each
    decoder stage bilinearly resizes the deeper features to the skip's spatial
    size, concatenates the skip, and applies two 3x3 conv/BN/ReLU layers.

    Args:
        input_channels: Number of input channels.
        n_filters: Channels at the first level.
        n_depths: Number of encoder levels (>= 1).
        dropout_rate: ``Dropout2d`` probability after each conv pair.

    Raises:
        ValueError: if ``n_depths < 1``.
    """

    def __init__(self, input_channels, n_filters=64, n_depths=3, dropout_rate=0.2):
        super(SpatiotemporalUNet, self).__init__()
        if n_depths < 1:
            raise ValueError(f"n_depths must be >= 1, got {n_depths}")

        self.n_filters = n_filters
        self.n_depths = n_depths

        # Encoder (downsampling path): level i outputs n_filters * 2**i channels.
        self.encoder = nn.ModuleList()
        in_channels = input_channels
        for i in range(n_depths):
            out_channels = n_filters * (2 ** i)
            self.encoder.append(self._double_conv(in_channels, out_channels, dropout_rate))
            in_channels = out_channels

        # Decoder (upsampling path). The stage for level i consumes the resized
        # level-i features concatenated with the level-(i-1) skip, so its input
        # width is the SUM of both; it outputs the skip's width. Upsampling is
        # done in forward(), so no transposed convolution is needed here.
        self.decoder = nn.ModuleList()
        for i in range(n_depths - 1, 0, -1):
            deep_channels = n_filters * (2 ** i)
            skip_channels = n_filters * (2 ** (i - 1))
            self.decoder.append(
                self._double_conv(deep_channels + skip_channels, skip_channels, dropout_rate)
            )

        # Final output layer
        self.final_conv = nn.Sequential(
            nn.Conv2d(n_filters, 1, kernel_size=1),
            nn.Softplus()  # Ensure positive output for count prediction
        )

    @staticmethod
    def _double_conv(in_channels, out_channels, dropout_rate):
        """Return two (3x3 conv -> BatchNorm -> ReLU) layers plus 2-D dropout."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout_rate),
        )

    def forward(self, x):
        """Map ``[B, input_channels, H, W]`` to ``[B, 1, H, W]``."""
        skips = []
        deepest = len(self.encoder) - 1
        for i, encoder_block in enumerate(self.encoder):
            x = encoder_block(x)
            if i < deepest:  # the deepest level is neither skipped nor pooled
                skips.append(x)
                x = nn.functional.max_pool2d(x, 2)

        for decoder_block in self.decoder:
            skip = skips.pop()
            # Resize to the skip's exact size: pooling floors odd sizes
            # (e.g. 45 -> 22), so a fixed 2x upsample would not line up.
            x = nn.functional.interpolate(
                x, size=skip.shape[-2:], mode='bilinear', align_corners=False
            )
            x = torch.cat([x, skip], dim=1)
            x = decoder_block(x)

        return self.final_conv(x)

# Input channels: CFG.T time steps of every dynamic variable, plus one channel
# per static variable (none when the static list is empty or unset).
input_channels = CFG.T * len(CFG.variables) + (
    len(CFG.static_variables) if CFG.static_variables else 0
)

print(f"Model input channels: {input_channels}")
print(f"Architecture: {CFG.n_depths} depths, {CFG.n_filters} base filters")

# Build the U-Net from the config and move it to the selected device.
model = SpatiotemporalUNet(
    input_channels,
    n_filters=CFG.n_filters,
    n_depths=CFG.n_depths,
    dropout_rate=CFG.dropout_rate,
)
model.to(device)

print(f"Model parameters: {sum(map(torch.numel, model.parameters())):,}")
print(f"Model size: {sum(map(torch.numel, model.parameters())) / 1e6:.2f}M parameters")


In [None]:
# Loss Functions
class FocalLoss(nn.Module):
    """Focal-style reweighting of a squared-error loss.

    The per-element squared error ``e = (inputs - targets) ** 2`` is scaled
    by ``alpha * (1 - exp(-e)) ** gamma``, so elements that are already fit
    well (small ``e``) contribute less. Despite the name, no cross-entropy is
    involved.

    Args:
        alpha: Constant multiplier.
        gamma: Focusing exponent.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-element tensor.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        sq_err = nn.functional.mse_loss(inputs, targets, reduction='none')
        weight = self.alpha * (1 - torch.exp(-sq_err)) ** self.gamma
        per_element = weight * sq_err

        if self.reduction == 'mean':
            return torch.mean(per_element)
        if self.reduction == 'sum':
            return torch.sum(per_element)
        return per_element

class TweedieLoss(nn.Module):
    """Tweedie negative log-likelihood for non-negative count-like targets.

    Up to terms that do not depend on ``pred`` the loss is::

        p == 1 (Poisson):  mean(pred - y * log(pred))
        p == 2 (Gamma):    mean(y / pred + log(pred))
        otherwise:         mean(-y * pred**(1-p) / (1-p) + pred**(2-p) / (2-p))

    Each form is minimized at ``pred == y``, so ``pred`` must be the positive
    mean prediction (e.g. a Softplus output). Note the leading minus in the
    general case: without it the expression is the log-likelihood itself and
    gradient descent would push predictions away from the targets.

    Args:
        p: Tweedie power.
        eps: Small constant keeping logs/negative powers finite at ``pred == 0``.
    """

    def __init__(self, p=1.5, eps=1e-8):
        super(TweedieLoss, self).__init__()
        self.p = p
        self.eps = eps

    def forward(self, pred, target):
        p, eps = self.p, self.eps
        if p == 1:  # Poisson
            return torch.mean(pred - target * torch.log(pred + eps))
        if p == 2:  # Gamma
            return torch.mean(target / (pred + eps) + torch.log(pred + eps))
        # General Tweedie deviance. Clamp so pred**(1-p) stays finite when the
        # prediction underflows to 0 (1 - p < 0 for p > 1).
        mu = pred.clamp(min=eps)
        return torch.mean(
            -target * torch.pow(mu, 1 - p) / (1 - p)
            + torch.pow(mu, 2 - p) / (2 - p)
        )

# Loss: MSE baseline. Alternatives defined in this cell:
#   criterion = FocalLoss(alpha=1, gamma=2)  # down-weights well-fit cells
#   criterion = TweedieLoss(p=1.5)           # count-data likelihood
criterion = nn.MSELoss()
print(f"Loss function: {type(criterion).__name__}")

# AdamW (weight_decay=1e-4), cosine-annealed over CFG.n_epochs epochs.
optimizer = optim.AdamW(model.parameters(), lr=CFG.learning_rate, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.n_epochs)

# Gradient scaler for mixed precision; None selects the fp32 training path.
scaler = GradScaler() if MIXED_PRECISION else None
print("âœ“ Mixed precision enabled" if MIXED_PRECISION else "âš  Mixed precision not available")

print(f"Optimizer: AdamW (lr={CFG.learning_rate})")
print("Scheduler: CosineAnnealingLR")


In [None]:
# Training Loop
from collections.abc import Mapping


def _batch_to_input(batch, variables):
    """Concatenate a batch's per-variable tensors into one ``[B, C, H, W]`` input.

    A 4-D ``batch[var]`` is used as-is (assumed ``[B, C_var, H, W]``, e.g.
    ``C_var = T`` for a dynamic variable); a 3-D one (``[B, H, W]``) gets a
    channel axis, so for all-3-D inputs this equals stacking on ``dim=1``.
    Channels follow the order of ``variables``; absent names are skipped.

    Returns ``None`` when ``batch`` is not a mapping (e.g. the placeholder
    integer loaders) or holds none of ``variables``; callers skip such batches.
    """
    if not isinstance(batch, Mapping):
        return None
    parts = [
        batch[var].unsqueeze(1) if batch[var].dim() == 3 else batch[var]
        for var in variables
        if var in batch
    ]
    return torch.cat(parts, dim=1) if parts else None


def _batch_loss(model, criterion, x, y):
    """Run ``model`` on ``x`` and score it against ``y`` reshaped to the output.

    Reshaping the target (instead of ``outputs.squeeze()``) keeps the batch
    axis when ``B == 1`` and any spatial axis of size 1, so prediction and
    target line up element for element.
    """
    outputs = model(x)
    return criterion(outputs, y.reshape(outputs.shape))


def train_epoch(model, dataloader, optimizer, criterion, device, scaler=None,
                *, variables=None, target_key='fire_count'):
    """Train for one epoch and return the mean loss over processed batches.

    Args:
        model: Network mapping ``[B, C, H, W]`` inputs to predictions.
        dataloader: Iterable of batch dicts keyed by variable name plus
            ``target_key``; unusable batches are skipped.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss ``criterion(prediction, target)``.
        device: Device the tensors are moved to.
        scaler: Optional gradient scaler; when given, the forward pass runs
            under ``autocast()`` and the step goes through the scaler.
        variables: Input names in channel order; defaults to
            ``CFG.variables + CFG.static_variables``.
        target_key: Batch key holding the target tensor.

    Returns:
        Mean batch loss, or 0.0 when no batch was usable.
    """
    if variables is None:
        variables = CFG.variables + CFG.static_variables
    model.train()
    total_loss = 0.0
    n_batches = 0

    for batch in dataloader:
        x = _batch_to_input(batch, variables)
        if x is None:
            continue

        optimizer.zero_grad()
        x = x.to(device)
        y = batch[target_key].to(device)

        if scaler is not None:
            with autocast():
                loss = _batch_loss(model, criterion, x, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss = _batch_loss(model, criterion, x, y)
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    return total_loss / max(n_batches, 1)


def validate_epoch(model, dataloader, criterion, device,
                   *, variables=None, target_key='fire_count'):
    """Evaluate one pass without gradients and return the mean batch loss.

    Arguments mirror ``train_epoch``; batches that are not mappings or hold
    none of ``variables`` are skipped. Returns 0.0 when no batch was usable.
    """
    if variables is None:
        variables = CFG.variables + CFG.static_variables
    model.eval()
    total_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for batch in dataloader:
            x = _batch_to_input(batch, variables)
            if x is None:
                continue
            y = batch[target_key].to(device)
            loss = _batch_loss(model, criterion, x.to(device), y)
            total_loss += loss.item()
            n_batches += 1

    return total_loss / max(n_batches, 1)

# Placeholder loaders (replace with real data).
# NOTE(review): these iterate plain integer ranges, so their batches are not
# the per-variable dicts the training loop reads - TODO swap in real datasets.
print("Creating dummy dataloaders...")
print("âš  Replace with actual SpatiotemporalDataset implementation")

dummy_train_loader, dummy_val_loader = (
    DataLoader(range(n_items), batch_size=CFG.batch_size, shuffle=shuffle)
    for n_items, shuffle in ((100, True), (50, False))
)

print(f"Training batches: {len(dummy_train_loader)}")
print(f"Validation batches: {len(dummy_val_loader)}")
print("âš  Note: Replace dummy dataloaders with actual SpatiotemporalDataset")


In [None]:
# Training with Early Stopping
print("Starting training...")

train_losses = []
val_losses = []
best_val_loss = float('inf')
patience_counter = 0
best_model_state = None

for epoch in range(CFG.n_epochs):
    # Training
    train_loss = train_epoch(model, dummy_train_loader, optimizer, criterion, device, scaler)
    train_losses.append(train_loss)

    # Validation
    val_loss = validate_epoch(model, dummy_val_loader, criterion, device)
    val_losses.append(val_loss)

    # Record the LR this epoch trained with, then advance the schedule.
    epoch_lr = optimizer.param_groups[0]['lr']
    scheduler.step()

    # Early stopping bookkeeping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # state_dict() hands back the live parameter tensors and dict.copy()
        # is shallow, so clone each tensor; otherwise later epochs would
        # overwrite this "best" snapshot in place.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
    else:
        patience_counter += 1

    # Print progress
    if epoch % 10 == 0 or epoch == CFG.n_epochs - 1:
        print(f"Epoch {epoch:3d}/{CFG.n_epochs}: "
              f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"LR: {epoch_lr:.2e}")

    # Early stopping
    if patience_counter >= CFG.patience:
        print(f"Early stopping at epoch {epoch}")
        break

# Load best model
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("âœ“ Loaded best model weights")

print(f"Training completed. Best validation loss: {best_val_loss:.4f}")
print(f"Total epochs: {len(train_losses)}")
print(f"Total epochs: {len(train_losses)}")


In [None]:
# Evaluation
print("Evaluating model...")

# Plot training curves
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss', alpha=0.8)
plt.plot(val_losses, label='Validation Loss', alpha=0.8)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Curves')
plt.legend()
plt.grid(True, alpha=0.3)

# The optimizer only holds the *current* LR, so rebuild the per-epoch values of
# CosineAnnealingLR(T_max=CFG.n_epochs) with its default eta_min=0:
#   lr_e = learning_rate * (1 + cos(pi * e / n_epochs)) / 2
# Keep this in sync if the scheduler configuration changes.
epochs = np.arange(len(train_losses))
lr_schedule = CFG.learning_rate * 0.5 * (1 + np.cos(np.pi * epochs / CFG.n_epochs))

plt.subplot(1, 2, 2)
plt.plot(epochs, lr_schedule, label='Learning Rate')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Model evaluation metrics
print(f"\nðŸ“Š TRAINING SUMMARY:")
print(f"Final Training Loss: {train_losses[-1]:.4f}")
print(f"Best Validation Loss: {best_val_loss:.4f}")
print(f"Total Epochs: {len(train_losses)}")
print(f"Early Stopping: {'Yes' if patience_counter >= CFG.patience else 'No'}")

# Model size and complexity
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nðŸ”§ MODEL COMPLEXITY:")
print(f"Total Parameters: {total_params:,}")
print(f"Trainable Parameters: {trainable_params:,}")
print(f"Model Size: {total_params / 1e6:.2f}M parameters")


In [None]:
# Inference Function
def _to_channel_stack(arr, name):
    """Reshape one variable's array to channels-first ``[channels, H, W]``.

    ``[T, H, W, C]`` -> ``[T*C, H, W]`` (time-major), ``[H, W, C]`` ->
    ``[C, H, W]``, ``[H, W]`` -> ``[1, H, W]``; anything else raises ValueError.
    """
    if arr.ndim == 4:
        t, h, w, c = arr.shape
        return arr.transpose(0, 3, 1, 2).reshape(t * c, h, w)
    if arr.ndim == 3:
        return arr.transpose(2, 0, 1)
    if arr.ndim == 2:
        return arr[np.newaxis]
    raise ValueError(f"{name}: expected a 2-D, 3-D or 4-D array, got shape {arr.shape}")


def predict_next_day(model, recent_data, device, scalers=None, *, variables=None):
    """
    Predict fire count for next day given recent T days of data.

    Args:
        model: Trained SpatiotemporalUNet model (or any module mapping
            ``[1, C, H, W]`` to ``[1, 1, H, W]``).
        recent_data: Dict of arrays. Dynamic variables are ``[T, H, W, C]``
            and become ``T*C`` time-major channels; static variables may be
            ``[H, W, C]`` (C channels) or ``[H, W]`` (one channel). A 3-D
            array is always read as ``[H, W, C]``.
        device: torch device
        scalers: Optional ``{var: {'mean': ..., 'std': ...}}`` statistics,
            broadcast against the last axis of each array.
        variables: Channel order; defaults to
            ``CFG.variables + CFG.static_variables``. Absent names are skipped.

    Returns:
        np.array: Predicted fire counts [H, W]

    Raises:
        ValueError: if none of ``variables`` is in ``recent_data`` or an array
            has an unsupported number of dimensions.
    """
    if variables is None:
        variables = CFG.variables + CFG.static_variables
    model.eval()

    channels = []
    for var in variables:
        if var not in recent_data:
            continue
        arr = np.asarray(recent_data[var], dtype=np.float32)
        if scalers and var in scalers:
            arr = (arr - scalers[var]['mean']) / scalers[var]['std']
        channels.append(_to_channel_stack(arr, var))

    if not channels:
        raise ValueError("No input variables found")

    # [1, total_channels, H, W] for a single-sample batch
    x = torch.from_numpy(np.concatenate(channels, axis=0)).float().unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(x)

    # Index instead of squeeze() so H == 1 or W == 1 grids keep their shape.
    return outputs[0, 0].cpu().numpy()

# --- Example usage (kept commented out; needs real recent observations) ---
# recent_data = {
#     'tmin': np.random.randn(CFG.T, 180, 360, 1),
#     'tmax': np.random.randn(CFG.T, 180, 360, 1),
#     # ... other variables
# }
# predictions = predict_next_day(model, recent_data, device, scalers)
# print(f"Prediction shape: {predictions.shape}")
# print(f"Prediction range: {predictions.min():.3f} to {predictions.max():.3f}")

print(
    "âœ“ Inference function created: predict_next_day()",
    "Usage: predictions = predict_next_day(model, recent_data_dict, device, scalers)",
    sep="\n",
)

# --- Save model (optional) ---
# torch.save(model.state_dict(), 'spatiotemporal_unet_model.pth')
# print("✓ Model saved to spatiotemporal_unet_model.pth")
