# In [None]:  (Jupyter notebook cell marker left over from the export)
# -*- coding: utf-8 -*-
"""
ConvLSTM (Convolutional Long Short-Term Memory) Training and Evaluation Script

This script implements a ConvLSTM-based model for flood inundation forecasting,
serving as a baseline comparison to architectures like FNO. It includes data
loading (multi-event, chunked), model definition (ConvLSTMCell, ConvLSTMFloodModel),
custom loss functions, and a full training/evaluation loop with hyperparameter
grid search.

This code is refactored from a Jupyter Notebook, standardizing it for
command-line execution and public release.

Key Features:
- ConvLSTM model architecture with dynamic 3D positional encoding.
- Multiple Dataset classes for different loading strategies:
    - Flood3DMultiEventDataset: Loads all training events from a directory.
    - Flood3DTestDataset: Loads a full test event into RAM (for 100m data).
    - Flood3DChunkedTestDataset: Loads a test event in chunks (for 30m data).
- Hyperparameter grid search via command-line arguments.
- Robust memory management and logging.
"""

import os
import h5py
import time
import random
import numpy as np
import glob
import math
import gc
import psutil
import logging
import argparse  # Import argparse for command-line arguments
import csv
from typing import List, Tuple, Optional, Dict, Any

# PyTorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Scikit-learn Imports
from sklearn.metrics import r2_score, mean_squared_error

# =========================
# 1. Memory Management Utilities
# =========================

def print_memory_usage(logger: logging.Logger):
    """
    Report the resident memory of this process and, when CUDA is available,
    the reserved/allocated memory of every visible GPU (all values in MB).

    Args:
        logger: Logger that receives the INFO-level report lines.
    """
    bytes_per_mb = 1024 * 1024

    # Host side: resident set size of the current process.
    cpu_memory = psutil.Process(os.getpid()).memory_info().rss / bytes_per_mb
    logger.info(f"CPU Memory Usage: {cpu_memory:.2f} MB")

    # Device side: nothing to report without CUDA.
    if not torch.cuda.is_available():
        return

    for i in range(torch.cuda.device_count()):
        reserved = torch.cuda.memory_reserved(i) / bytes_per_mb
        allocated = torch.cuda.memory_allocated(i) / bytes_per_mb
        logger.info(f"GPU:{i} Reserved Memory: {reserved:.2f} MB, Allocated Memory: {allocated:.2f} MB")

def clean_memory():
    """
    Run Python garbage collection, then release cached CUDA memory
    when a CUDA device is available.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

# =========================
# 2. Logging and Seeding Utilities
# =========================

def setup_logging() -> logging.Logger:
    """
    Reset the root logger to a single INFO-level console handler.

    Any handlers already attached to the root logger are removed first,
    so calling this repeatedly does not duplicate output.

    Returns:
        The configured root logger.
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # Drop handlers left over from a previous configuration.
    if root.hasHandlers():
        root.handlers.clear()

    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    )
    root.addHandler(console)
    return root

# Initialize the module-level logger at import time; the dataset classes
# in this file log through this `logger` name.
logger = setup_logging()

def set_seed(seed: int = 42, deterministic: bool = True):
    """
    Seed Python, NumPy and PyTorch (CPU and, if present, CUDA) RNGs.

    Args:
        seed: Seed value applied to every generator.
        deterministic: When True, force deterministic cuDNN kernels and
            disable the cuDNN benchmark auto-tuner.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    if not deterministic:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def seed_worker(worker_id: int):
    """
    Seed NumPy and Python's `random` from the current PyTorch initial seed.

    The PyTorch seed is reduced to 32 bits before being applied.

    Args:
        worker_id: Worker index (unused by this function).
    """
    worker_seed = torch.initial_seed() & 0xFFFFFFFF  # same as % 2**32
    for seeder in (np.random.seed, random.seed):
        seeder(worker_seed)

# =========================
# 3. Custom Loss Functions
# =========================

class DynamicWeightedMSELoss(nn.Module):
    """
    Dynamic Weighted MSE Loss.

    Returns a per-element squared error in which every element is scaled by
    a weight chosen from the percentile band its target value falls into.
    The percentile thresholds are estimated from the targets and cached;
    they are refreshed on the first call and then every `update_freq` calls.

    Note:
        The optional `mask` is only used to select which targets contribute
        to the threshold estimate. The returned loss tensor is NOT masked or
        reduced; the caller is responsible for both.
    """

    # Upper bound on the number of elements handed to torch.quantile.
    # torch.quantile rejects very large inputs (roughly 16M elements);
    # larger inputs are strided-subsampled to stay below this cap.
    max_quantile_elements: int = 16_000_000 - 1

    def __init__(self,
                 percentile_thresholds: Optional[List[float]] = None,
                 weights_multipliers: Optional[List[float]] = None,
                 eps: float = 1e-8,
                 update_freq: int = 10):
        """
        Args:
            percentile_thresholds: Percentiles (0-100). Defaults to [50, 75, 90].
            weights_multipliers: Weights, one per band; length must be
                len(percentile_thresholds) + 1. Defaults to [1.0, 2.0, 5.0, 10.0].
            eps: Epsilon kept for interface compatibility (stored, not used here).
            update_freq: How often (in forward calls) to refresh the thresholds.

        Raises:
            ValueError: If the two lists have incompatible lengths.
        """
        super().__init__()

        # Fresh lists per instance: list defaults in the signature would be
        # shared (and mutable) across every instance.
        if percentile_thresholds is None:
            percentile_thresholds = [50, 75, 90]
        if weights_multipliers is None:
            weights_multipliers = [1.0, 2.0, 5.0, 10.0]
        percentile_thresholds = list(percentile_thresholds)
        weights_multipliers = list(weights_multipliers)

        # Validate before touching instance state.
        if len(weights_multipliers) != len(percentile_thresholds) + 1:
            raise ValueError("Length of weights_multipliers must be len(percentile_thresholds) + 1")

        self.percentile_thresholds = percentile_thresholds
        self.weights_multipliers = weights_multipliers
        self.eps = eps
        self.update_freq = update_freq

        self.thresholds_cache = None
        self.call_count = 0

    def update_thresholds(self, targets: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Optional[List[torch.Tensor]]:
        """
        Estimate the percentile thresholds from the (optionally masked) targets.

        Args:
            targets: Target tensor.
            mask: Optional mask; elements with mask > 0.5 are considered valid.

        Returns:
            One 0-dim tensor per configured percentile, or None when no valid
            target element exists.
        """
        if mask is not None:
            valid_targets = targets[mask > 0.5]
        else:
            valid_targets = targets.reshape(-1)

        count = valid_targets.numel()
        if count == 0:
            return None

        # Keep the quantile input below the size cap with an even stride.
        if count > self.max_quantile_elements:
            step = -(-count // self.max_quantile_elements)  # ceil division
            valid_targets = valid_targets[::step]

        if not self.percentile_thresholds:
            return []

        # Single call: the values are sorted once for all percentiles.
        q = torch.tensor(
            [p / 100.0 for p in self.percentile_thresholds],
            dtype=valid_targets.dtype,
            device=valid_targets.device,
        )
        return list(torch.quantile(valid_targets, q).unbind(0))

    def forward(self,
                outputs: torch.Tensor,
                targets: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the weighted, unreduced squared error.

        Args:
            outputs: Predictions.
            targets: Ground truth, same shape as `outputs`.
            mask: Optional validity mask used only for threshold estimation.

        Returns:
            Tensor shaped like `targets` holding (outputs - targets)^2 * weight.
            If no thresholds could be estimated, the plain squared error.
        """
        mse = (outputs - targets) ** 2

        self.call_count += 1
        if self.thresholds_cache is None or self.call_count % self.update_freq == 0:
            self.thresholds_cache = self.update_thresholds(targets, mask)

        if self.thresholds_cache is None:
            return mse

        thresholds = self.thresholds_cache
        # Everything starts in the lowest band (this also covers targets <= 0).
        weights = torch.ones_like(targets) * self.weights_multipliers[0]

        for i in range(len(thresholds)):
            if i == 0:
                condition = (targets > 0) & (targets <= thresholds[i])
            else:
                condition = (targets > thresholds[i-1]) & (targets <= thresholds[i])
            weights = torch.where(condition, self.weights_multipliers[i], weights)

        # Values above the highest threshold get the last (largest-band) weight.
        if thresholds:
            weights = torch.where(targets > thresholds[-1], self.weights_multipliers[-1], weights)

        return mse * weights

class StandardMSELoss(nn.Module):
    """
    Unreduced MSE with the same call signature as DynamicWeightedMSELoss.

    The `mask` argument is accepted for interface parity and ignored; the
    result is the element-wise squared error from nn.MSELoss(reduction='none').
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss(reduction='none')

    def forward(self,
                outputs: torch.Tensor,
                targets: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        per_element = self.mse(outputs, targets)
        return per_element

# =========================
# 4. Dataset Classes
# =========================

class Flood3DMultiEventDataset(Dataset):
    """
    In-memory spatio-temporal dataset built from several flood events.

    Every "*_X*.h5" (driver) file in `data_dir` is paired, in sorted order,
    with a "*_Y*.h5" (label) file. All samples of all events are converted
    to tensors up front and kept in RAM.

    Each item is a tuple (input, label, mask) where, per the reshaping done
    below, input is (T, C, H, W) and label/mask are (1, H, W). The mask is
    channel 4 of the last time step of the driver sequence.
    """

    def __init__(self,
                 data_dir: str,
                 time_steps: int = 11,
                 transform: Optional[Any] = None,
                 driver_key: str = 'data',  # Key was 'data' in convLSTM notebook
                 label_key: str = 'data'):
        """
        Args:
            data_dir: Directory containing the driver/label H5 files.
            time_steps: Required sequence length; events with a different
                length are skipped with a warning.
            transform: Optional callable applied as
                transform(input, label, mask) -> (input, label, mask).
            driver_key: H5 dataset name inside driver files.
            label_key: H5 dataset name inside label files.

        Raises:
            ValueError: If no driver files exist or the driver/label file
                counts differ.
        """
        self.time_steps = time_steps
        self.transform = transform
        self.driver_key = driver_key
        self.label_key = label_key

        driver_files, label_files = (
            sorted(glob.glob(os.path.join(data_dir, pattern)))
            for pattern in ("*_X*.h5", "*_Y*.h5")
        )

        if len(driver_files) == 0 or len(driver_files) != len(label_files):
            raise ValueError(f"Driver/Label file mismatch or not found in {data_dir}")

        self.input_tensors = []
        self.label_tensors = []
        self.mask_tensors = []

        for driver_file, label_file in zip(driver_files, label_files):
            event_name = os.path.basename(driver_file).split('_X')[0]
            logger.info(f"Loading training event: {event_name}")

            with h5py.File(driver_file, 'r') as hf:
                driver_data = hf[self.driver_key][:]

            with h5py.File(label_file, 'r') as hf:
                label_data = hf[self.label_key][:]

            # Events with an unexpected sequence length are ignored entirely.
            if driver_data.shape[1] != time_steps:
                logger.warning(f"Event {event_name} has {driver_data.shape[1]} time steps, expected {time_steps}. Skipping.")
                continue

            sample_count = driver_data.shape[0]
            for i in range(sample_count):
                driver_sequence = driver_data[i].astype(np.float32)
                # Mask: channel 4 at the final time step -> (1, H, W)
                mask = driver_sequence[-1, ..., 4][np.newaxis, ...]
                # Channels-last -> channels-first: (T, C, H, W)
                input_data = driver_sequence.transpose(0, 3, 1, 2)
                # Drop the trailing singleton axis, add a leading one: (1, H, W)
                label = label_data[i].squeeze(-1)[np.newaxis, ...]

                if self.transform:
                    input_data, label, mask = self.transform(input_data, label, mask)

                self.input_tensors.append(torch.from_numpy(input_data).float())
                self.label_tensors.append(torch.from_numpy(label).float())
                self.mask_tensors.append(torch.from_numpy(mask).float())

            # Free the raw arrays before the next event is read.
            del driver_data, label_data
            gc.collect()

        self.num_samples = len(self.input_tensors)
        logger.info(f"Loaded {self.num_samples} training samples in total.")

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        sample_input = self.input_tensors[idx]
        sample_label = self.label_tensors[idx]
        sample_mask = self.mask_tensors[idx]
        return (sample_input, sample_label, sample_mask)

class Flood3DTestDataset(Dataset):
    """
    Spatio-temporal test dataset that reads one whole event into RAM.
    Suitable for 100m (LR) data.

    After construction the instance holds float32 NumPy arrays:
    `inputs` (N, T, C, H, W), `labels` (N, H, W, 1) and `masks` (N, 1, H, W),
    the mask being channel 4 of the last time step of the driver data.
    """

    def __init__(self,
                 driver_path: str,
                 label_path: str,
                 driver_key: str = 'data',
                 label_key: str = 'data',
                 time_steps: int = 11):
        """
        Args:
            driver_path: Path of the driver H5 file.
            label_path: Path of the label H5 file.
            driver_key: H5 dataset name inside the driver file.
            label_key: H5 dataset name inside the label file.
            time_steps: Required length of axis 1 of the driver data.
        """
        with h5py.File(driver_path, 'r') as hf:
            driver = hf[driver_key][:]
        with h5py.File(label_path, 'r') as hf:
            label = hf[label_key][:]

        # Handle single-sample files
        if driver.ndim == 5 and driver.shape[0] == 1:
            driver = driver[0][np.newaxis, ...]
        if label.ndim == 4 and label.shape[0] == 1:
            label = label[0][np.newaxis, ...]

        assert driver.shape[0] == label.shape[0], "Sample count mismatch"
        assert driver.shape[1] == time_steps, f"Time steps must be {time_steps}"

        self.num_samples = driver.shape[0]
        # (N, 1, H, W)
        self.masks = driver[:, -1, ..., 4][:, np.newaxis].astype(np.float32)
        # (N, T, C, H, W)
        self.inputs = driver.transpose(0, 1, 4, 2, 3).astype(np.float32)
        # (N, H, W, 1)
        self.labels = label.astype(np.float32)

        # The raw arrays are no longer needed once converted.
        del driver, label
        gc.collect()

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Label: drop trailing singleton axis, add leading one -> (1, H, W)
        y = self.labels[idx].squeeze(-1)[np.newaxis, ...]
        return (
            torch.from_numpy(self.inputs[idx]),
            torch.from_numpy(y),
            torch.from_numpy(self.masks[idx]),
        )

class Flood3DChunkedTestDataset(Dataset):
    """
    Spatio-temporal test dataset that reads an event chunk by chunk.
    Suitable for 30m (HR) data.

    Only one chunk of `chunk_size` consecutive samples is held in memory at
    a time; requesting a sample from another chunk replaces the cached one.
    Each item is (input, label, mask): input is (T, C, H, W), label and mask
    are (1, H, W), the mask being channel 4 of the last driver time step.
    """

    def __init__(self,
                 driver_path: str,
                 label_path: str,
                 driver_key: str = 'data',
                 label_key: str = 'data',
                 time_steps: int = 11,
                 chunk_size: int = 4):
        """
        Args:
            driver_path: Path of the driver H5 file.
            label_path: Path of the label H5 file.
            driver_key: H5 dataset name inside the driver file.
            label_key: H5 dataset name inside the label file.
            time_steps: Stored for reference; not validated here.
            chunk_size: Number of samples read from disk per chunk.

        Raises:
            FileNotFoundError: If either file does not exist.
        """
        if not os.path.isfile(driver_path):
            raise FileNotFoundError(f"Driver file not found: {driver_path}")
        if not os.path.isfile(label_path):
            raise FileNotFoundError(f"Label file not found: {label_path}")

        self.driver_path = driver_path
        self.label_path = label_path
        self.driver_key = driver_key
        self.label_key = label_key
        self.time_steps = time_steps
        self.chunk_size = chunk_size

        # Only the shape is read here; the sample count is the leading axis.
        with h5py.File(driver_path, 'r') as hf:
            self.num_samples = hf[driver_key].shape[0]

        self.current_chunk_idx = -1
        self.current_chunk_data = None

    def __len__(self) -> int:
        return self.num_samples

    def load_chunk(self, chunk_idx: int):
        """
        Read chunk `chunk_idx` from disk and make it the cached chunk.

        The previously cached chunk is released first so that two chunks are
        never resident at the same time.
        """
        start_idx = chunk_idx * self.chunk_size
        end_idx = min(start_idx + self.chunk_size, self.num_samples)

        logger.info(f"Loading data chunk {chunk_idx}, samples: {start_idx} to {end_idx-1}")

        if self.current_chunk_data is not None:
            # Reset (rather than delete) the attribute so the dataset stays
            # in a consistent "nothing cached" state if the read below fails.
            self.current_chunk_data = None
            self.current_chunk_idx = -1
            clean_memory()

        with h5py.File(self.driver_path, 'r') as hf:
            driver_data = hf[self.driver_key][start_idx:end_idx]

        with h5py.File(self.label_path, 'r') as hf:
            label_data = hf[self.label_key][start_idx:end_idx]

        driver_data = driver_data.astype(np.float32)
        label_data = label_data.astype(np.float32)

        input_tensors, label_tensors, mask_tensors = [], [], []

        for i in range(driver_data.shape[0]):
            mask = np.expand_dims(driver_data[i, -1, ..., 4], axis=0)  # (1, H, W)
            input_data = np.transpose(driver_data[i], (0, 3, 1, 2))  # (T, C, H, W)
            label_sample = np.expand_dims(label_data[i].squeeze(-1), 0)  # (1, H, W)

            input_tensors.append(torch.from_numpy(input_data).float())
            label_tensors.append(torch.from_numpy(label_sample).float())
            mask_tensors.append(torch.from_numpy(mask).float())

        del driver_data, label_data
        gc.collect()

        self.current_chunk_idx = chunk_idx
        self.current_chunk_data = {
            'inputs': input_tensors,
            'labels': label_tensors,
            'masks': mask_tensors,
            'start_idx': start_idx
        }

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Return sample `idx`, loading its chunk from disk if it is not cached.

        Negative indices count from the end, as with sequences.

        Raises:
            IndexError: If `idx` is outside the dataset, or the loaded chunk
                holds fewer samples than expected.
        """
        # Normalise and validate first: avoids a pointless disk read for bad
        # indices and keeps idx == -1 from colliding with the "no chunk
        # loaded" sentinel value of current_chunk_idx.
        if idx < 0:
            idx += self.num_samples
        if not 0 <= idx < self.num_samples:
            raise IndexError(f"Index {idx} out of range for dataset of size {self.num_samples}")

        chunk_idx, local_idx = divmod(idx, self.chunk_size)

        if chunk_idx != self.current_chunk_idx or self.current_chunk_data is None:
            self.load_chunk(chunk_idx)

        chunk = self.current_chunk_data
        if local_idx >= len(chunk['inputs']):
            raise IndexError(f"Index {idx} out of range for chunk {chunk_idx}")

        return (
            chunk['inputs'][local_idx],
            chunk['labels'][local_idx],
            chunk['masks'][local_idx]
        )

    def __del__(self):
        # Drop the cached chunk when the dataset object is destroyed.
        self.current_chunk_data = None
        gc.collect()

# =========================
# 5. ConvLSTM Model Architecture
# =========================

class PositionalEncoding3D(nn.Module):
    """
    Sinusoidal 3D positional encoding over (T, H, W) with per-shape caching.

    Coordinates are normalised to [0, 1) along each axis. For every frequency
    index i, channel 2*i is the product of the three per-axis sines and
    channel 2*i+1 the product of the three per-axis cosines. Encodings are
    cached per (T, H, W, device) combination.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        exponents = torch.arange(0, d_model, 2).float()
        self.div_term = torch.exp(exponents * (-math.log(10000.0) / d_model))
        self.pe_cache: Dict[str, torch.Tensor] = {}

    def _build_encoding(self, T: int, H: int, W: int, device: torch.device) -> torch.Tensor:
        """Compute the (T, H, W, d_model) encoding tensor on `device`."""
        pos_t = (torch.arange(T, device=device).float() / T)
        pos_h = (torch.arange(H, device=device).float() / H)
        pos_w = (torch.arange(W, device=device).float() / W)

        pe = torch.zeros(T, H, W, self.d_model, device=device)
        freqs = self.div_term.to(device)

        for k in range(self.d_model // 2):
            freq = freqs[k]
            sin_t, sin_h, sin_w = (torch.sin(p * freq) for p in (pos_t, pos_h, pos_w))
            cos_t, cos_h, cos_w = (torch.cos(p * freq) for p in (pos_t, pos_h, pos_w))

            # Broadcast the three 1-D profiles into a (T, H, W) volume.
            pe[..., 2 * k] = sin_t[:, None, None] * sin_h[None, :, None] * sin_w[None, None, :]
            pe[..., 2 * k + 1] = cos_t[:, None, None] * cos_h[None, :, None] * cos_w[None, None, :]

        return pe

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input shape: (batch, T, C, H, W)
        Output shape: (batch, T, H, W, d_model)
        """
        device = x.device
        batch, T, _, H, W = x.shape
        cache_key = f"{T}_{H}_{W}_{device}"

        pe = self.pe_cache.get(cache_key)
        if pe is None:
            pe = self._build_encoding(T, H, W, device)
            self.pe_cache[cache_key] = pe

        # Expand (no copy) the cached volume across the batch dimension.
        return pe[None, ...].expand(batch, -1, -1, -1, -1)

class ConvLSTMCell(nn.Module):
    """
    One convolutional LSTM cell.

    A single convolution over the concatenated [input, hidden] tensor yields
    all four gate pre-activations (input, forget, candidate, output).
    """

    def __init__(self, input_channels: int, hidden_channels: int,
                 kernel_size: int, padding: int):
        super().__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.padding = padding

        # Joint input-to-hidden / hidden-to-hidden convolution; the output
        # channels are laid out as [i | f | g | o], hidden_channels each.
        self.conv = nn.Conv2d(
            in_channels=input_channels + hidden_channels,
            out_channels=4 * hidden_channels,
            kernel_size=kernel_size,
            padding=padding,
        )

        # Forget-gate bias starts at 1 so the cell initially retains memory.
        with torch.no_grad():
            self.conv.bias[hidden_channels:2 * hidden_channels].fill_(1.0)

    def forward(self, x: torch.Tensor,
              h_prev: torch.Tensor,
              c_prev: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Advance the cell by one time step.

        Args:
            x: Input tensor [B, C_in, H, W]
            h_prev: Previous hidden state [B, C_hidden, H, W]
            c_prev: Previous cell state [B, C_hidden, H, W]

        Returns:
            (h_next, c_next): the new hidden and cell states.
        """
        gates = self.conv(torch.cat((x, h_prev), dim=1))
        pre_i, pre_f, pre_g, pre_o = gates.split(self.hidden_channels, dim=1)

        input_gate = pre_i.sigmoid()
        forget_gate = pre_f.sigmoid()
        candidate = pre_g.tanh()
        output_gate = pre_o.sigmoid()

        c_next = forget_gate * c_prev + input_gate * candidate
        h_next = output_gate * c_next.tanh()
        return h_next, c_next

class ConvLSTM(nn.Module):
    """
    Stack of ConvLSTM cells unrolled over the time axis.

    Layer 0 consumes the input frames; every deeper layer consumes the
    (dropout-regularised) hidden state of the layer below at the same step.
    """

    def __init__(self, input_channels: int, hidden_channels: int,
                 kernel_size: int, num_layers: int, dropout: float = 0.0):
        super().__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.padding = kernel_size // 2  # keeps H and W unchanged for odd kernels

        self.cell_list = nn.ModuleList()
        for layer_idx in range(max(num_layers, 1)):
            # Only the first cell sees the raw input channel count.
            cell_inputs = input_channels if layer_idx == 0 else hidden_channels
            self.cell_list.append(
                ConvLSTMCell(cell_inputs, hidden_channels, kernel_size, self.padding)
            )

        self.dropout_layer = nn.Dropout(dropout)

    def forward(self, input_tensor: torch.Tensor) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Run the full sequence through every layer.

        Args:
            input_tensor: [B, T, C_in, H, W]

        Returns:
            One (h, c) tuple per layer holding the states after the last step.
        """
        batch_size, seq_len, _, height, width = input_tensor.size()

        # Zero-initialised states, one pair per layer.
        hidden_states, cell_states = [], []
        for _ in range(self.num_layers):
            h0, c0 = self._init_hidden(batch_size, height, width, input_tensor.device)
            hidden_states.append(h0)
            cell_states.append(c0)

        for t in range(seq_len):
            frame = input_tensor[:, t, :, :, :]

            for layer_idx in range(self.num_layers):
                if layer_idx > 0:
                    # Dropout sits between stacked layers only.
                    layer_input = self.dropout_layer(hidden_states[layer_idx - 1])
                else:
                    layer_input = frame

                cell = self.cell_list[layer_idx]
                new_h, new_c = cell(
                    layer_input, hidden_states[layer_idx], cell_states[layer_idx]
                )
                hidden_states[layer_idx] = new_h
                cell_states[layer_idx] = new_c

        return list(zip(hidden_states, cell_states))

    def _init_hidden(self, batch_size: int, height: int, width: int,
                     device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return zero-filled (hidden, cell) states on `device`."""
        state_shape = (batch_size, self.hidden_channels, height, width)
        hidden = torch.zeros(*state_shape, device=device)
        cell = torch.zeros(*state_shape, device=device)
        return (hidden, cell)

class ConvLSTMFloodModel(nn.Module):
    """
    End-to-end ConvLSTM flood model.

    Pipeline: concatenate the input features with a 3D positional encoding,
    lift to `hidden_channels` with a per-pixel linear layer, run the stacked
    ConvLSTM over time, then project the top layer's final hidden state to
    `out_channels` with a 1x1 convolution.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 hidden_channels: int, num_layers: int,
                 kernel_size: int = 3, dropout: float = 0.2):
        super().__init__()

        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.kernel_size = kernel_size

        self.pos_encoder = PositionalEncoding3D(hidden_channels)

        # Lifting layer (P): features + positional channels -> hidden
        self.fc0 = nn.Linear(in_channels + hidden_channels, hidden_channels)

        # Recurrent core
        self.convlstm = ConvLSTM(
            input_channels=hidden_channels,
            hidden_channels=hidden_channels,
            kernel_size=kernel_size,
            num_layers=num_layers,
            dropout=dropout,
        )

        # Projection layer (Q): hidden -> output channels, per pixel
        self.conv_out = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input shape: (batch, T, C, H, W)
        Output shape: (batch, out_channels, H, W)
        """
        batch_size, T, C, H, W = x.shape

        # Positional encoding arrives channels-last: (batch, T, H, W, hidden)
        pos_enc = self.pos_encoder(x)

        # Channels-last features joined with the encoding: (batch, T, H, W, C + hidden)
        features = torch.cat([x.permute(0, 1, 3, 4, 2), pos_enc], dim=-1)

        # Lift, then return to channels-first for the convolutions:
        # (batch, T, H, W, hidden) -> (batch, T, hidden, H, W)
        lifted = self.fc0(features).permute(0, 1, 4, 2, 3)

        # Final (h, c) per layer; keep the hidden state of the top layer.
        top_hidden, _ = self.convlstm(lifted)[-1]

        return self.conv_out(top_hidden)

# =========================
# 6. Training & Evaluation Functions
# =========================

def train_one_epoch(model: nn.Module, 
                    dataloader: DataLoader, 
                    optimizer: optim.Optimizer, 
                    criterion: nn.Module, 
                    device: torch.device, 
                    epoch: int, 
                    total_epochs: int) -> Tuple[float, float, float, float]:
    """
    Runs a single training epoch.

    Args:
        model: Network to train; called as model(inputs).
        dataloader: Yields (inputs, targets, masks) batches.
        optimizer: Optimizer stepping the model parameters.
        criterion: Loss returning an UNREDUCED per-pixel tensor shaped like
            the targets (DynamicWeightedMSELoss, StandardMSELoss, or e.g.
            nn.MSELoss(reduction='none')).
        device: Device the batches are moved to.
        epoch: Current epoch number (progress-bar label only).
        total_epochs: Total number of epochs (progress-bar label only).

    Returns:
        (avg_loss, train_r2, train_rmse, epoch_duration). The metrics are
        computed over valid (mask != 0) pixels; they stay 0.0 if the metric
        computation fails or no batch was seen.
    """
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    all_preds, all_labels = [], []
    start_time = time.time()

    # Only this file's loss wrappers take a mask argument; a plain torch loss
    # such as nn.MSELoss would raise TypeError if handed one.
    accepts_mask = isinstance(criterion, (DynamicWeightedMSELoss, StandardMSELoss))
    
    progress_bar = tqdm(dataloader, 
                       desc=f"Epoch {epoch}/{total_epochs} [ConvLSTM]", 
                       bar_format="{l_bar}{bar:20}{r_bar}",
                       leave=False)
    
    for batch_idx, (inputs, targets, masks) in enumerate(progress_bar):
        inputs = inputs.to(device)
        targets = targets.to(device)
        masks = masks.to(device)
        
        optimizer.zero_grad(set_to_none=True)
        
        outputs = model(inputs)
        
        # Per-pixel loss
        if accepts_mask:
            loss_tensor = criterion(outputs, targets, masks)
        else:
            loss_tensor = criterion(outputs, targets)
        
        # Mask every loss type. DynamicWeightedMSELoss uses the mask only to
        # estimate its thresholds and returns an unmasked tensor, so without
        # this step invalid pixels would be summed into a loss that is
        # normalised by the valid-pixel count.
        loss_tensor = loss_tensor * masks

        # Aggregate loss over valid pixels
        masked_loss = loss_tensor.sum() / (masks.sum() + 1e-8)
        
        masked_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        # Single host sync per batch for the scalar loss value.
        batch_loss = masked_loss.item()
        epoch_loss += batch_loss
        num_batches += 1
        
        with torch.no_grad():
            valid_mask = masks.squeeze(1).bool()
            all_preds.append(outputs.detach().squeeze(1)[valid_mask].cpu().numpy())
            all_labels.append(targets.squeeze(1)[valid_mask].cpu().numpy())
        
        progress_bar.set_postfix({
            'Loss': f"{batch_loss:.4f}",
            'LR': f"{optimizer.param_groups[0]['lr']:.2e}"
        }, refresh=False)
        
        del inputs, targets, masks, outputs, loss_tensor, masked_loss
        if batch_idx % 10 == 0:
            clean_memory()
    
    epoch_duration = time.time() - start_time
    # An empty dataloader yields a 0.0 loss instead of an unbound-variable error.
    avg_loss = epoch_loss / num_batches if num_batches else 0.0
    
    train_r2, train_rmse = 0.0, 0.0
    try:
        if all_preds and all_labels:
            all_preds = np.concatenate(all_preds)
            all_labels = np.concatenate(all_labels)
            train_r2 = r2_score(all_labels, all_preds)
            train_rmse = np.sqrt(mean_squared_error(all_labels, all_preds))
            del all_preds, all_labels
    except Exception as e:
        # Metrics are best-effort: a failure here must not abort training.
        logger.error(f"Error calculating training metrics: {e}")
    
    clean_memory()
    return avg_loss, train_r2, train_rmse, epoch_duration

@torch.no_grad()
def _evaluate_masked(model: nn.Module,
                     dataloader: DataLoader,
                     device: torch.device,
                     desc: str,
                     clean_every: int,
                     error_context: str) -> Tuple[float, float]:
    """
    Shared inference loop: predict every batch and score the valid pixels.

    Args:
        model: Network to evaluate; called as model(inputs).
        dataloader: Yields (inputs, labels, masks) batches.
        device: Device the batches are moved to.
        desc: Progress-bar description.
        clean_every: Call clean_memory() after every `clean_every` batches.
        error_context: Text inserted into the metric-failure log message.

    Returns:
        (r2, rmse) over all pixels whose mask is non-zero; (0.0, inf) when
        nothing was collected or the metric computation failed.
    """
    model.eval()
    all_preds, all_labels = [], []

    # Mixed precision only makes sense when inference really runs on CUDA.
    use_amp = torch.device(device).type == 'cuda'

    pbar = tqdm(dataloader,
               desc=desc,
               bar_format="{l_bar}{bar:20}{r_bar}",
               leave=False)

    with torch.amp.autocast(device_type='cuda', enabled=use_amp):
        for batch_number, (inputs, labels, masks) in enumerate(pbar, start=1):
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)

            outputs = model(inputs)

            valid_mask = masks.squeeze(1).bool()
            # Under autocast the outputs may be float16; store float32 so the
            # metrics are not computed on half-precision arrays.
            all_preds.append(outputs.squeeze(1)[valid_mask].float().cpu().numpy())
            all_labels.append(labels.squeeze(1)[valid_mask].float().cpu().numpy())

            del inputs, labels, masks, outputs
            if batch_number % clean_every == 0:
                clean_memory()

    r2, rmse = 0.0, float('inf')
    try:
        if all_preds and all_labels:
            all_preds = np.concatenate(all_preds)
            all_labels = np.concatenate(all_labels)
            r2 = r2_score(all_labels, all_preds)
            rmse = np.sqrt(mean_squared_error(all_labels, all_preds))
            del all_preds, all_labels
    except Exception as e:
        # Metrics are best-effort: log and fall back to the defaults.
        logger.error(f"Error calculating {error_context} metrics: {e}")

    return r2, rmse


@torch.no_grad()
def predict_and_evaluate(model: nn.Module, 
                         dataloader: DataLoader, 
                         device: torch.device) -> Tuple[float, float, float]:
    """
    Prediction and evaluation loop for test sets.

    Memory is cleaned after every batch.

    Args:
        model: Network to evaluate.
        dataloader: Yields (inputs, labels, masks) batches.
        device: Device used for inference.

    Returns:
        (r2, rmse, test_time) where the metrics cover valid (mask != 0)
        pixels and test_time is the elapsed wall-clock time in seconds.
    """
    start_time = time.time()
    r2, rmse = _evaluate_masked(
        model, dataloader, device,
        desc="Testing [ConvLSTM]",
        clean_every=1,
        error_context="test",
    )
    test_time = time.time() - start_time
    clean_memory()
    return r2, rmse, test_time


@torch.no_grad()
def evaluate_on_training_set(model: nn.Module, 
                             train_dataset: Dataset, 
                             device: torch.device, 
                             batch_size: int = 100) -> Tuple[float, float, float]:
    """
    Evaluates the model on the full training set (post-training).

    The dataset is iterated in order (no shuffling) in the main process;
    memory is cleaned every 10 batches.

    Args:
        model: Network to evaluate.
        train_dataset: Dataset yielding (inputs, labels, masks) samples.
        device: Device used for inference.
        batch_size: Evaluation batch size.

    Returns:
        (r2, rmse, eval_time) where the metrics cover valid (mask != 0)
        pixels and eval_time is the elapsed wall-clock time in seconds.
    """
    start_time = time.time()
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=0,
    )
    
    r2, rmse = _evaluate_masked(
        model, train_loader, device,
        desc="Evaluating on Training Set [ConvLSTM]",
        clean_every=10,
        error_context="full training set",
    )
    
    eval_time = time.time() - start_time
    clean_memory()
    return r2, rmse, eval_time

# =========================
# 7. Main Execution
# =========================

def get_args() -> argparse.Namespace:
    """
    Builds the command-line interface and returns the parsed options.

    Options are declared in a table and registered in order, so the resulting
    namespace lists them in the same sequence as they appear below.

    Returns:
        argparse.Namespace: Values parsed from ``sys.argv`` (or their defaults).
    """
    option_table = [
        # --- Data Paths ---
        ('--train_data_dir', dict(
            type=str,
            default="/home/ubuntu/Documents/xjq/data_Q_timeseries_0329/train_data_LF",
            help="Directory for multi-event training data.")),
        ('--test_100m_dir', dict(
            type=str,
            default="/home/ubuntu/Documents/xjq/data_Q_timeseries_0329/test_data_LF",
            help="Directory for 100m (LR) test data.")),
        ('--test_30m_dir', dict(
            type=str,
            default="/home/ubuntu/Documents/xjq/data_Q_timeseries_0329/test_data_HF",
            help="Directory for 30m (HR) test data.")),

        # --- Training Parameters ---
        ('--num_epochs', dict(type=int, default=20, help="Number of training epochs.")),
        ('--batch_size', dict(type=int, default=2, help="Batch size for training.")),
        ('--lr', dict(type=float, default=1e-3, help="Learning rate.")),
        ('--dropout', dict(type=float, default=0.1, help="Dropout rate in ConvLSTM.")),

        # --- Model Hyperparameters (Grid Search) ---
        # Each of these accepts one or more values; main() iterates over them.
        ('--kernel_size', dict(
            type=int, nargs='+', default=[3, 5, 7],
            help="List of kernel sizes to try.")),
        ('--hidden_channels', dict(
            type=int, nargs='+', default=[8, 16, 32, 64],
            help="List of hidden channel sizes to try.")),
        ('--num_layers', dict(
            type=int, nargs='+', default=[1, 2, 3, 4, 5],
            help="List of layer counts to try.")),

        # --- Loss and Memory Config ---
        ('--use_dynamic_loss', dict(
            action='store_true',
            help="Use DynamicWeightedMSELoss instead of standard MSE.")),
        ('--hf_chunk_size', dict(
            type=int, default=10,
            help="Chunk size for loading 30m data (must use chunked loader for 30m).")),
        ('--eval_batch_size', dict(
            type=int, default=100,
            help="Batch size for post-training evaluation on the full train set.")),

        # --- System Config ---
        ('--num_workers', dict(
            type=int, default=0,
            help="Number of DataLoader workers (recommend 0).")),

        # --- Output Files ---
        ('--results_file', dict(
            type=str, default="ConvLSTM_test_results.csv",
            help="File to save test results.")),
        ('--train_eval_file', dict(
            type=str, default="ConvLSTM_train_eval_results.csv",
            help="File to save full training set evaluation results.")),
    ]

    cli = argparse.ArgumentParser(description="ConvLSTM Training and Evaluation Script")
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()

def main(args: argparse.Namespace):
    """
    Main training and evaluation loop for ConvLSTM.

    Runs a grid search over every (kernel_size, hidden_channels, num_layers)
    combination in ``args``. For each configuration it trains a model, keeps
    the checkpoint with the lowest training loss, evaluates that checkpoint on
    the full training set, and then on every 100m and 30m test event.
    Per-epoch metrics, training-set evaluation and test results are written to
    CSV logs. Configurations whose checkpoint file already exists are skipped.

    Args:
        args: Parsed command-line options as returned by ``get_args``.

    Raises:
        ValueError: If a test directory holds a different number of ``_X.h5``
            and ``_Y.h5`` files.
    """
    # 1. Initialization
    set_seed(42, deterministic=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    logger.info("Initial memory state:")
    # print_memory_usage requires the logger to write to.
    print_memory_usage(logger)

    # 2. Load Training Dataset
    logger.info(f"Loading training dataset from: {args.train_data_dir}")
    full_dataset = Flood3DMultiEventDataset(args.train_data_dir, driver_key='data', label_key='data')
    logger.info(f"Training dataset size: {len(full_dataset)}")
    
    train_loader = DataLoader(
        full_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers,
        worker_init_fn=seed_worker
    )

    # 3. Prepare Test Files
    def prepare_test_files(test_dir: str) -> List[Tuple[str, str]]:
        # Drivers and labels are paired by sorted position, so both lists must
        # have the same length.
        driver_files = sorted([f for f in os.listdir(test_dir) if f.endswith('_X.h5')])
        label_files = sorted([f for f in os.listdir(test_dir) if f.endswith('_Y.h5')])
        if len(driver_files) != len(label_files):
            raise ValueError(f"Test file mismatch in {test_dir}")
        return list(zip(driver_files, label_files))

    test_100m_files = prepare_test_files(args.test_100m_dir)
    test_30m_files = prepare_test_files(args.test_30m_dir)

    # 4. Initialize Log Files
    def init_csv_log(path: str, header: List[str]) -> None:
        # Configurations with an existing checkpoint are skipped below, so a
        # resumed run must not erase the rows they already produced. The header
        # is therefore only written when the file is new or empty.
        if os.path.exists(path) and os.path.getsize(path) > 0:
            logger.info(f"Appending to existing log file: {path}")
            return
        with open(path, 'w', newline='') as f:
            csv.writer(f).writerow(header)

    loss_name = "dynamic" if args.use_dynamic_loss else "mse"
    test_log_file = args.results_file.replace(".csv", f"_{loss_name}.csv")
    train_eval_log_file = args.train_eval_file.replace(".csv", f"_{loss_name}.csv")
    
    test_log_header = ["kernel_size", "hidden", "layers", "resolution", "event", "test_r2", "test_rmse", "pred_time"]
    train_eval_header = ["kernel_size", "hidden", "layers", "train_r2", "train_rmse", "eval_time"]
    
    init_csv_log(test_log_file, test_log_header)
    init_csv_log(train_eval_log_file, train_eval_header)

    # 5. Experiment Grid Search Loop
    for kernel_size in args.kernel_size:
        for hidden in args.hidden_channels:
            for n_layers in args.num_layers:
                
                exp_id = f"ks{kernel_size}_h{hidden}_l{n_layers}_{loss_name}"
                final_model_path = f"final_ConvLSTM_model_{exp_id}.pth"
                
                # NOTE(review): the checkpoint is first written after epoch 1,
                # so an interrupted run also leaves this file behind and the
                # configuration will be skipped on the next launch.
                if os.path.exists(final_model_path):
                    logger.info(f"Model for config ({exp_id}) already exists, skipping.")
                    continue

                # The per-experiment training log is recreated because this
                # configuration is being trained from scratch.
                train_log_file = f"train_log_ConvLSTM_{exp_id}.csv"
                train_log_header = ["kernel_size", "hidden", "layers", "epoch", "train_loss", "train_r2", "train_rmse", "epoch_time"]
                with open(train_log_file, 'w', newline='') as f:
                    writer = csv.writer(f)
                    writer.writerow(train_log_header)

                # Initialize Model
                model = ConvLSTMFloodModel(
                    in_channels=7,
                    out_channels=1,
                    hidden_channels=hidden,
                    num_layers=n_layers,
                    kernel_size=kernel_size,
                    dropout=args.dropout
                ).to(device)
                
                # Initialize Loss
                if args.use_dynamic_loss:
                    criterion = DynamicWeightedMSELoss().to(device)
                    logger.info("Using Dynamic Weighted MSE Loss")
                else:
                    criterion = StandardMSELoss().to(device)
                    logger.info("Using Standard MSE Loss")
                
                optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
                scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

                logger.info(f"\n=== Starting Training: {exp_id} ===")
                logger.info("Memory state after model initialization:")
                print_memory_usage(logger)
                
                best_train_loss = float('inf')
                
                # --- Training Loop ---
                for epoch in range(1, args.num_epochs + 1):
                    train_loss, train_r2, train_rmse, epoch_time = train_one_epoch(
                        model, train_loader, optimizer, criterion, device, epoch, args.num_epochs
                    )
                    
                    with open(train_log_file, 'a', newline='') as f:
                        writer = csv.writer(f)
                        writer.writerow([
                            kernel_size, hidden, n_layers, epoch, 
                            f"{train_loss:.6f}", f"{train_r2:.4f}", f"{train_rmse:.4f}", f"{epoch_time:.2f}"
                        ])
                    
                    logger.info(f"Epoch {epoch}/{args.num_epochs}: "
                               f"Loss={train_loss:.4f}, R2={train_r2:.4f}, RMSE={train_rmse:.4f}, "
                               f"Time={epoch_time:.2f}s")

                    scheduler.step(train_loss)
                    
                    # Checkpoint whenever the training loss improves; a NaN
                    # loss never compares as smaller, so it never saves.
                    if train_loss < best_train_loss:
                        best_train_loss = train_loss
                        torch.save(model.state_dict(), final_model_path)
                        logger.info(f"Saved best model (Train Loss: {train_loss:.4f})")
                
                # --- Post-Training Evaluation ---
                if os.path.exists(final_model_path):
                    logger.info(f"Loading best model from {final_model_path} for evaluation.")
                    # map_location keeps the load working when the checkpoint
                    # was written on a different device than the current one.
                    model.load_state_dict(torch.load(final_model_path, map_location=device))
                else:
                    # Happens when no epoch ran or no epoch improved the loss.
                    logger.warning(f"No checkpoint was saved for {exp_id}; "
                                   f"evaluating the model with its current weights.")
                
                logger.info("Evaluating model on the full training set...")
                train_eval_r2, train_eval_rmse, train_eval_time = evaluate_on_training_set(
                    model, full_dataset, device, batch_size=args.eval_batch_size
                )
                
                with open(train_eval_log_file, 'a', newline='') as f:
                    writer = csv.writer(f)
                    writer.writerow([
                        kernel_size, hidden, n_layers, 
                        f"{train_eval_r2:.6f}", f"{train_eval_rmse:.6f}", f"{train_eval_time:.2f}"
                    ])
                
                logger.info(f"Full Train Set Eval: R2={train_eval_r2:.6f}, RMSE={train_eval_rmse:.6f}, Time={train_eval_time:.2f}s")

                # --- Test Set Evaluation ---
                def run_testing(test_files: List[Tuple[str, str]], test_dir: str, resolution: str):
                    # Evaluates the current model on each event of one
                    # resolution and appends one CSV row per event.
                    for drv_file, lbl_file in test_files:
                        event_name = drv_file.replace('_X.h5', '')
                        drv_path = os.path.join(test_dir, drv_file)
                        lbl_path = os.path.join(test_dir, lbl_file)
                        
                        logger.info(f"Memory state before testing event {event_name}:")
                        print_memory_usage(logger)

                        if resolution == "30m":
                            logger.info(f"Using Chunked Loader for 30m data (Chunk size: {args.hf_chunk_size})")
                            test_ds = Flood3DChunkedTestDataset(
                                driver_path=drv_path, label_path=lbl_path,
                                driver_key='data', label_key='data',
                                chunk_size=args.hf_chunk_size
                            )
                            test_batch_size = 2
                        else:
                            logger.info("Using Full-Load Loader for 100m data")
                            test_ds = Flood3DTestDataset(
                                driver_path=drv_path, label_path=lbl_path,
                                driver_key='data', label_key='data'
                            )
                            test_batch_size = 4
                            
                        test_loader = DataLoader(
                            test_ds, 
                            batch_size=test_batch_size, 
                            shuffle=False,
                            pin_memory=True,
                            num_workers=0
                        )

                        r2, rmse, test_time = predict_and_evaluate(model, test_loader, device)

                        with open(test_log_file, 'a', newline='') as f:
                            writer = csv.writer(f)
                            writer.writerow([
                                kernel_size, hidden, n_layers, 
                                resolution, event_name, f"{r2:.4f}", f"{rmse:.4f}", f"{test_time:.2f}"
                            ])
                        
                        logger.info(f"Event: {event_name}, Res: {resolution}, "
                                   f"R2={r2:.4f}, RMSE={rmse:.4f}, Time={test_time:.2f}s")
                        
                        # Release the event's data before loading the next one.
                        del test_ds, test_loader
                        clean_memory()

                logger.info("Testing on 100m dataset...")
                run_testing(test_100m_files, args.test_100m_dir, "100m")

                logger.info("Testing on 30m (ZS-SR) dataset...")
                run_testing(test_30m_files, args.test_30m_dir, "30m")

                del model
                clean_memory()
                logger.info(f"Memory state after experiment {exp_id}:")
                print_memory_usage(logger)

    logger.info("\nAll ConvLSTM experiments complete!")
    logger.info(f"Test results saved to: {test_log_file}")
    logger.info(f"Training set evaluation results saved to: {train_eval_log_file}")

if __name__ == "__main__":
    # Resolve the run configuration from the command line.
    args = get_args()

    # Echo every option so the log records exactly how this run was configured.
    banner = "=" * 30
    logger.info("Starting ConvLSTM run with configuration:")
    logger.info(banner)
    for option, setting in vars(args).items():
        logger.info(f"{option}: {setting}")
    logger.info(banner)

    # Hand over to the grid-search driver.
    main(args)
