# In [None]:  (Jupyter cell marker left over from the notebook export; commented out so the file parses)
# -*- coding: utf-8 -*-
"""
FNO3d (Spatio-temporal Fourier Neural Operator) Training and Evaluation Script

This script implements the FNO3d model for flood inundation forecasting,
including data loading (with multiple resolution and chunking strategies),
model definition (FNO3d, SpectralConv3d), custom loss functions, and a
full training/evaluation loop with hyperparameter grid search.

This code is refactored from a Jupyter Notebook, standardizing it for
command-line execution and public release.

Key Features:
- FNO3d model architecture with 3D spectral convolutions.
- Dynamic 3D positional encoding with caching.
- Multiple Dataset classes for different loading strategies:
    - Flood3DMultiEventDataset: Loads all training events from a directory.
    - Flood3DTestDataset: Loads a full test event into RAM (for 100m data).
    - Flood3DChunkedTestDataset: Loads a test event in chunks (for 30m data).
- Hyperparameter grid search via command-line arguments.
- Robust memory management with explicit gc.collect() and torch.cuda.empty_cache().
- Logging for training progress and evaluation results.
"""

import os
import h5py
import time
import random
import numpy as np
import glob
import math
import gc
import psutil
import logging
import argparse  # Import argparse for command-line arguments
from typing import List, Tuple, Optional, Dict, Any

# PyTorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Scikit-learn Imports
from sklearn.metrics import r2_score, mean_squared_error

# =========================
# 1. Memory Management Utilities
# =========================

def print_memory_usage(logger: logging.Logger):
    """
    Report current memory consumption through ``logger``.

    One INFO line is written for the resident set size of this process,
    followed by one INFO line per CUDA device (reserved and allocated
    memory) when CUDA is available. All figures are in MB.
    """
    bytes_per_mb = 1024 * 1024

    # CPU: resident set size of the current process
    cpu_memory = psutil.Process(os.getpid()).memory_info().rss / bytes_per_mb
    logger.info(f"CPU Memory Usage: {cpu_memory:.2f} MB")

    # GPU: nothing to report without CUDA
    if not torch.cuda.is_available():
        return

    for i in range(torch.cuda.device_count()):
        reserved = torch.cuda.memory_reserved(i) / bytes_per_mb
        allocated = torch.cuda.memory_allocated(i) / bytes_per_mb
        logger.info(f"GPU:{i} Reserved Memory: {reserved:.2f} MB, Allocated Memory: {allocated:.2f} MB")

def clean_memory():
    """
    Run the Python garbage collector, then release cached CUDA memory.

    The CUDA cache is only emptied when a CUDA device is available.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

# =========================
# 2. Logging and Seeding Utilities
# =========================

def setup_logging() -> logging.Logger:
    """
    Reset the root logger to a single INFO-level console handler.

    Any handlers already attached to the root logger are removed first,
    so calling this repeatedly does not duplicate output.

    Returns:
        The configured root logger.
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # Start from a clean slate: drop whatever handlers are already attached.
    if root.hasHandlers():
        root.handlers.clear()

    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    )
    root.addHandler(console)
    return root

# Module-wide logger: the root logger configured by setup_logging(). It is
# created at import time and used by the dataset classes and the
# training/evaluation functions in this file.
logger = setup_logging()

def set_seed(seed: int = 42, deterministic: bool = True):
    """
    Seed Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed to every generator.
        deterministic: When True, also switch cuDNN to deterministic mode
            and turn its benchmark auto-tuner off.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    if not deterministic:
        return

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def seed_worker(worker_id: int):
    """
    Seed NumPy and Python ``random`` from the current PyTorch seed.

    The seed is ``torch.initial_seed()`` reduced modulo 2**32.
    ``worker_id`` is accepted but not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)

# =========================
# 3. Custom Loss Functions
# =========================

class DynamicWeightedMSELoss(nn.Module):
    """
    Dynamic Weighted MSE Loss.

    Scales the per-pixel squared error by a multiplier chosen from the
    percentile bin the target (ground truth) value falls into. The
    percentile thresholds are estimated from a batch of targets, cached,
    and re-estimated every ``update_freq`` calls.

    With thresholds ``t[0] <= ... <= t[k-1]`` and multipliers
    ``m[0] ... m[k]`` the weight of a pixel is:

        target <= t[0]              -> m[0]
        t[i-1] < target <= t[i]     -> m[i]
        target >  t[k-1]            -> m[k]

    The ``mask`` argument is used only to select which targets take part
    in the percentile estimate; the returned loss is NOT masked.
    """
    def __init__(self,
                 percentile_thresholds: Optional[List[float]] = None,
                 weights_multipliers: Optional[List[float]] = None,
                 eps: float = 1e-8,
                 update_freq: int = 10):
        """
        Args:
            percentile_thresholds: List of percentiles to define water depth
                                   bins. Defaults to [50, 75, 90].
            weights_multipliers: List of weights for each bin. Length must be
                                 len(percentile_thresholds) + 1.
                                 Defaults to [1.0, 2.0, 5.0, 10.0].
            eps: Stored on the instance; not used by this class itself.
            update_freq: How often (in batches) to update the thresholds.
                         Must be at least 1.

        Raises:
            ValueError: If the two lists have inconsistent lengths or
                        ``update_freq`` is smaller than 1.
        """
        super().__init__()

        # None sentinels instead of list defaults: a default list would be
        # one object shared by every instance created without arguments.
        if percentile_thresholds is None:
            percentile_thresholds = [50, 75, 90]
        if weights_multipliers is None:
            weights_multipliers = [1.0, 2.0, 5.0, 10.0]

        # Validate before touching instance state.
        if len(weights_multipliers) != len(percentile_thresholds) + 1:
            raise ValueError("Length of weights_multipliers must be len(percentile_thresholds) + 1")
        if update_freq < 1:
            # update_freq == 0 would otherwise fail later with a
            # ZeroDivisionError in forward().
            raise ValueError("update_freq must be a positive integer")

        # Copy so later changes to the caller's lists do not alter the loss.
        self.percentile_thresholds = list(percentile_thresholds)
        self.weights_multipliers = list(weights_multipliers)
        self.eps = eps
        self.update_freq = update_freq

        self.thresholds_cache = None
        self.call_count = 0

    def update_thresholds(self, targets: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Optional[List[torch.Tensor]]:
        """
        Estimates the percentile thresholds from the current batch.

        Args:
            targets: Ground truth values.
            mask: Optional mask; when given, only targets where
                  ``mask > 0.5`` are used.

        Returns:
            One 0-dim tensor per configured percentile, or None when no
            target value is available.
        """
        if mask is not None:
            valid_targets = targets[mask > 0.5]
        else:
            valid_targets = targets.reshape(-1)

        if valid_targets.numel() == 0:
            return None
        if not self.percentile_thresholds:
            return []

        # Thresholds only feed comparisons, so they never need a graph.
        valid_targets = valid_targets.detach()

        # One quantile call for all percentiles: the values are sorted once
        # instead of once per percentile.
        q = torch.tensor([p / 100.0 for p in self.percentile_thresholds],
                         dtype=valid_targets.dtype,
                         device=valid_targets.device)
        return list(torch.quantile(valid_targets, q).unbind(0))

    def forward(self,
                outputs: torch.Tensor,
                targets: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Calculates the dynamic weighted MSE loss.

        Args:
            outputs: Model predictions (B, 1, H, W)
            targets: Ground truth labels (B, 1, H, W)
            mask: Optional mask for valid areas (B, 1, H, W); only used
                  when the thresholds are re-estimated.

        Returns:
            Per-pixel weighted MSE loss tensor (B, 1, H, W). When no
            thresholds could be estimated (no valid target), the plain
            per-pixel MSE is returned.
        """
        mse = (outputs - targets) ** 2

        self.call_count += 1
        if self.thresholds_cache is None or self.call_count % self.update_freq == 0:
            self.thresholds_cache = self.update_thresholds(targets, mask)

        thresholds = self.thresholds_cache
        if thresholds is None:
            return mse

        # Everything starts in the lowest bin (target <= first threshold).
        weights = torch.full_like(targets, self.weights_multipliers[0])

        # Interior bins: (threshold[i-1], threshold[i]] -> multiplier[i]
        interior_bins = zip(thresholds, thresholds[1:], self.weights_multipliers[1:])
        for lower, upper, multiplier in interior_bins:
            weights.masked_fill_((targets > lower) & (targets <= upper), multiplier)

        # Top bin: (threshold[-1], infinity) -> last multiplier
        if thresholds:
            weights.masked_fill_(targets > thresholds[-1], self.weights_multipliers[-1])

        return mse * weights  # Per-pixel loss (B, 1, H, W)

class StandardMSELoss(nn.Module):
    """
    Unweighted per-pixel MSE with the same call signature as
    DynamicWeightedMSELoss, so the two can be swapped freely.
    """
    def __init__(self):
        super().__init__()
        # Keep the element-wise loss; reduction happens outside this class.
        self.mse = nn.MSELoss(reduction='none')

    def forward(self,
                outputs: torch.Tensor,
                targets: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the element-wise squared error.

        Args:
            outputs: Model predictions (B, 1, H, W)
            targets: Ground truth labels (B, 1, H, W)
            mask: Accepted for interface compatibility only; it has no
                  effect on the result.

        Returns:
            Per-pixel MSE loss tensor (B, 1, H, W)
        """
        per_pixel_loss = self.mse(outputs, targets)
        return per_pixel_loss

# =========================
# 4. Dataset Classes
# =========================

class Flood3DMultiEventDataset(Dataset):
    """
    In-memory training dataset built from every flood event in a directory.

    Files are discovered with the patterns ``*_X*.h5`` (drivers) and
    ``*_Y*.h5`` (labels); both lists are sorted and paired by position.

    Expected file format in data_dir:
    - *_X.h5 or *_X_norm.h5 (Driver data)
    - *_Y.h5 or *_Y_norm.h5 (Label data)

    H5 structure:
    - driver_key: (num_samples, T, H, W, C_in)
    - label_key: (num_samples, H, W, C_out)

    Each sample yields (input, label, mask), where the mask is channel
    index 4 of the last time step of the driver sequence.
    """
    def __init__(self,
                 data_dir: str,
                 time_steps: int = 11,
                 transform: Optional[Any] = None,
                 driver_key: str = 'X_data',
                 label_key: str = 'Y_data'):
        """
        Args:
            data_dir: Directory holding the paired H5 files.
            time_steps: Required length of the time axis; events with a
                        different length are skipped with a warning.
            transform: Optional callable applied as
                       ``transform(input, label, mask)`` on NumPy arrays.
            driver_key: Dataset name inside the driver files.
            label_key: Dataset name inside the label files.

        Raises:
            ValueError: If the number of driver and label files differs.
            FileNotFoundError: If no data files are found.
        """
        self.time_steps = time_steps
        self.transform = transform
        self.driver_key = driver_key
        self.label_key = label_key

        driver_files = sorted(glob.glob(os.path.join(data_dir, "*_X*.h5")))
        label_files = sorted(glob.glob(os.path.join(data_dir, "*_Y*.h5")))

        if len(driver_files) != len(label_files):
            raise ValueError(f"Driver file count ({len(driver_files)}) != Label file count ({len(label_files)}) in {data_dir}")

        if not driver_files:
            raise FileNotFoundError(f"No data files found in {data_dir}")

        self.input_tensors = []
        self.label_tensors = []
        self.mask_tensors = []

        for driver_file, label_file in zip(driver_files, label_files):
            # Event name = file name without the driver suffix
            event_name = os.path.basename(driver_file)
            for suffix in ("_X.h5", "_X_norm.h5"):
                event_name = event_name.replace(suffix, "")
            logger.info(f"Loading training event: {event_name}")

            driver_data = self._read_array(driver_file, self.driver_key)
            label_data = self._read_array(label_file, self.label_key)

            if driver_data.shape[1] != time_steps:
                logger.warning(f"Event {event_name} has {driver_data.shape[1]} time steps, expected {time_steps}. Skipping.")
                continue

            for i, raw_sequence in enumerate(driver_data):
                input_data, label, mask = self._prepare_sample(raw_sequence, label_data[i])

                if self.transform:
                    input_data, label, mask = self.transform(input_data, label, mask)

                self.input_tensors.append(torch.from_numpy(input_data).float())
                self.label_tensors.append(torch.from_numpy(label).float())
                self.mask_tensors.append(torch.from_numpy(mask).float())

            # Release the raw event arrays before loading the next event
            del driver_data, label_data
            gc.collect()

        self.num_samples = len(self.input_tensors)
        logger.info(f"Loaded {self.num_samples} training samples in total.")

    @staticmethod
    def _read_array(path: str, key: str) -> np.ndarray:
        """Read the full dataset ``key`` from the H5 file at ``path``."""
        with h5py.File(path, 'r') as hf:
            return hf[key][:]

    @staticmethod
    def _prepare_sample(raw_sequence: np.ndarray,
                        raw_label: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Turn one raw (T, H, W, C_in) sequence and its label into arrays."""
        driver_sequence = raw_sequence.astype(np.float32)  # (T, H, W, C_in)

        # Mask: 5th channel (index 4) of the last time step -> (1, H, W)
        mask = np.expand_dims(driver_sequence[-1, ..., 4], axis=0)

        # Channels-first input: (T, C_in, H, W)
        input_data = np.transpose(driver_sequence, (0, 3, 1, 2))

        # Label: drop a trailing singleton channel, then add a leading axis
        if raw_label.shape[-1] == 1:
            raw_label = raw_label.squeeze(-1)
        label = np.expand_dims(raw_label, axis=0)

        return input_data, label, mask

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        inputs = self.input_tensors[idx]   # (T, C_in, H, W)
        label = self.label_tensors[idx]    # (1, H, W)
        mask = self.mask_tensors[idx]      # (1, H, W)
        return inputs, label, mask

class Flood3DTestDataset(Dataset):
    """
    3D Spatio-temporal Test Dataset.
    Loads the *entire* test event into RAM upon initialization.
    Suitable for smaller datasets (e.g., 100m resolution).

    H5 structure:
    - driver_key: (num_samples, T, H, W, C_in)
    - label_key: (num_samples, H, W, C_out)

    Attributes set by ``__init__``:
    - inputs: float32 array (N, T, C_in, H, W)
    - masks: float32 array (N, 1, H, W), channel index 4 of the last step
    - label_data: float32 array with the layout stored in the label file
    """
    def __init__(self,
                 driver_path: str,
                 label_path: str,
                 driver_key: str = 'data',
                 label_key: str = 'data',
                 time_steps: int = 11):
        """
        Args:
            driver_path: Path of the H5 file with the driver data.
            label_path: Path of the H5 file with the label data.
            driver_key: Dataset name inside the driver file.
            label_key: Dataset name inside the label file.
            time_steps: Required length of the driver time axis.

        Raises:
            ValueError: If driver and label sample counts differ, or the
                        driver time axis does not have ``time_steps`` entries.
        """
        with h5py.File(driver_path, 'r') as hf:
            driver_data = hf[driver_key][:]

        with h5py.File(label_path, 'r') as hf:
            label_data = hf[label_key][:]

        # Explicit checks instead of `assert`: assertions are removed when
        # Python runs with -O, which would let bad files through silently.
        if driver_data.shape[0] != label_data.shape[0]:
            raise ValueError(
                f"Sample count mismatch: {driver_data.shape[0]} driver samples "
                f"vs {label_data.shape[0]} label samples"
            )
        if driver_data.shape[1] != time_steps:
            raise ValueError(
                f"Time steps must be {time_steps}, got {driver_data.shape[1]}"
            )

        self.num_samples = driver_data.shape[0]
        self.time_steps = time_steps

        # Pre-process masks (N, 1, H, W): channel 4 of the last time step
        self.masks = np.expand_dims(driver_data[:, -1, ..., 4], axis=1).astype(np.float32)

        # Pre-process inputs (N, T, C_in, H, W)
        self.inputs = np.transpose(driver_data, (0, 1, 4, 2, 3)).astype(np.float32)

        self.label_data = label_data.astype(np.float32)

        # The raw driver array is no longer needed once masks/inputs exist.
        del driver_data, label_data
        gc.collect()

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (input (T, C_in, H, W), label (1, H, W), mask (1, H, W))."""
        x = self.inputs[idx]  # (T, C_in, H, W)

        # A 3-D label sample carries a trailing channel axis: drop it.
        y = self.label_data[idx]
        if y.ndim == 3:
            y = y.squeeze(-1)
        y = np.expand_dims(y, axis=0)  # (1, H, W)

        m = self.masks[idx]  # (1, H, W)

        return (
            torch.from_numpy(x).float(),
            torch.from_numpy(y).float(),
            torch.from_numpy(m).float()
        )

class Flood3DChunkedTestDataset(Dataset):
    """
    3D Spatio-temporal Test Dataset with Chunked Loading.
    Loads data in chunks, suitable for very large high-resolution (30m) data
    that does not fit into RAM.

    Only one chunk of ``chunk_size`` consecutive samples is held in memory
    at a time; requesting a sample from another chunk replaces it.
    Negative indices count from the end.
    """
    def __init__(self,
                 driver_path: str,
                 label_path: str,
                 driver_key: str = 'data',
                 label_key: str = 'data',
                 time_steps: int = 11,
                 chunk_size: int = 4):
        """
        Args:
            driver_path: Path of the H5 file with the driver data.
            label_path: Path of the H5 file with the label data.
            driver_key: Dataset name inside the driver file.
            label_key: Dataset name inside the label file.
            time_steps: Stored on the instance; not validated here.
            chunk_size: Number of samples loaded per chunk (>= 1).

        Raises:
            FileNotFoundError: If either file does not exist.
            ValueError: If ``chunk_size`` is smaller than 1.
        """
        if not os.path.isfile(driver_path):
            raise FileNotFoundError(f"Driver file not found: {driver_path}")
        if not os.path.isfile(label_path):
            raise FileNotFoundError(f"Label file not found: {label_path}")
        if chunk_size < 1:
            # chunk_size == 0 would otherwise fail with ZeroDivisionError
            # on the first __getitem__ call.
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

        self.driver_path = driver_path
        self.label_path = label_path
        self.driver_key = driver_key
        self.label_key = label_key
        self.time_steps = time_steps
        self.chunk_size = chunk_size

        # Only the shape is read here; no sample data is loaded yet.
        with h5py.File(driver_path, 'r') as hf:
            self.num_samples = hf[driver_key].shape[0]

        self.current_chunk_idx = -1   # -1 means "no chunk loaded"
        self.current_chunk_data = None

    def __len__(self) -> int:
        return self.num_samples

    def load_chunk(self, chunk_idx: int):
        """Loads a specific chunk of data into memory, replacing the current one."""
        start_idx = chunk_idx * self.chunk_size
        end_idx = min(start_idx + self.chunk_size, self.num_samples)

        logger.info(f"Loading data chunk {chunk_idx}, samples: {start_idx} to {end_idx-1}")

        # Drop the previous chunk first so both never coexist in memory.
        # Resetting the index keeps the state consistent if the load fails.
        if self.current_chunk_data is not None:
            self.current_chunk_data = None
            self.current_chunk_idx = -1
            clean_memory()

        with h5py.File(self.driver_path, 'r') as hf:
            driver_data = hf[self.driver_key][start_idx:end_idx]

        with h5py.File(self.label_path, 'r') as hf:
            label_data = hf[self.label_key][start_idx:end_idx]

        # copy=False avoids a second copy when the file already stores float32
        driver_data = driver_data.astype(np.float32, copy=False)
        label_data = label_data.astype(np.float32, copy=False)

        input_tensors, label_tensors, mask_tensors = [], [], []

        for i in range(driver_data.shape[0]):
            # Mask: channel index 4 of the last time step -> (1, H, W)
            mask = np.expand_dims(driver_data[i, -1, ..., 4], axis=0)

            input_data = np.transpose(driver_data[i], (0, 3, 1, 2))  # (T, C_in, H, W)

            # A 3-D label sample carries a trailing channel axis: drop it.
            label_sample = label_data[i]
            if label_sample.ndim == 3:
                label_sample = label_sample.squeeze(-1)
            label_sample = np.expand_dims(label_sample, 0)  # (1, H, W)

            input_tensors.append(torch.from_numpy(input_data).float())
            label_tensors.append(torch.from_numpy(label_sample).float())
            mask_tensors.append(torch.from_numpy(mask).float())

        del driver_data, label_data
        gc.collect()

        self.current_chunk_idx = chunk_idx
        self.current_chunk_data = {
            'inputs': input_tensors,
            'labels': label_tensors,
            'masks': mask_tensors,
            'start_idx': start_idx
        }

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Return (input, label, mask) for sample ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Resolve negative indices up front: a raw idx of -1 would map to
        # chunk -1, which is also the "no chunk loaded" sentinel.
        if idx < 0:
            idx += self.num_samples
        if not 0 <= idx < self.num_samples:
            # Reject before touching the disk or evicting the cached chunk.
            raise IndexError(f"Index {idx} out of range for dataset of size {self.num_samples}")

        chunk_idx, local_idx = divmod(idx, self.chunk_size)

        if chunk_idx != self.current_chunk_idx or self.current_chunk_data is None:
            self.load_chunk(chunk_idx)

        if local_idx >= len(self.current_chunk_data['inputs']):
            raise IndexError(f"Index {idx} out of range for chunk {chunk_idx}")

        return (
            self.current_chunk_data['inputs'][local_idx],
            self.current_chunk_data['labels'][local_idx],
            self.current_chunk_data['masks'][local_idx]
        )

    def __del__(self):
        # Only drop the reference. A finalizer can run during interpreter
        # shutdown, when module globals such as `gc` may already be gone,
        # so no collection is triggered from here.
        self.current_chunk_data = None

# =========================
# 5. FNO3d Model Architecture
# =========================

class SpectralConv3d(nn.Module):
    """
    3D Spectral Convolution Layer.

    Transforms the input to the frequency domain, mixes channels on a
    truncated set of low-frequency modes with learned complex weights,
    and transforms back. Two weight tensors are used: one for the leading
    time-frequency block and one for the trailing time-frequency block.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 modes_t: int, modes_h: int, modes_w: int):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes_t = modes_t  # Max modes kept along time
        self.modes_h = modes_h  # Max modes kept along height
        self.modes_w = modes_w  # Max modes kept along width

        self.scale = 1 / (in_channels * out_channels)
        weight_shape = (in_channels, out_channels,
                        self.modes_t, self.modes_h, self.modes_w)
        # weights1 is drawn before weights2 (matters for seeded runs)
        self.weights1 = nn.Parameter(
            self.scale * torch.randn(*weight_shape, dtype=torch.cfloat)
        )
        self.weights2 = nn.Parameter(
            self.scale * torch.randn(*weight_shape, dtype=torch.cfloat)
        )

    @staticmethod
    def _mix_channels(modes: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        """Per-mode channel mixing: (b, i, t, h, w) x (i, o, t, h, w) -> (b, o, t, h, w)."""
        return torch.einsum("bixyz,ioxyz->boxyz", modes, weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: (batch, in_channels, T, H, W)
        size_t, size_h, size_w = x.size(-3), x.size(-2), x.size(-1)
        freq_w = size_w // 2 + 1  # length of the real-FFT axis

        # Forward transform over (T, H, W), always in float32
        x_ft = torch.fft.rfftn(x.float(), dim=(-3, -2, -1))

        # Frequency-domain output; modes that are not kept stay zero
        out_ft = torch.zeros(x.shape[0], self.out_channels,
                             size_t, size_h, freq_w,
                             dtype=torch.cfloat, device=x.device)

        # Never keep more modes than the input provides
        mt = min(self.modes_t, size_t)
        mh = min(self.modes_h, size_h)
        mw = min(self.modes_w, freq_w)

        weight_block = (slice(None), slice(None), slice(None, mt), slice(None, mh), slice(None, mw))
        leading = (slice(None), slice(None), slice(None, mt), slice(None, mh), slice(None, mw))
        trailing = (slice(None), slice(None), slice(-mt, None), slice(None, mh), slice(None, mw))

        # Leading time-frequency block -> weights1
        out_ft[leading] += self._mix_channels(x_ft[leading], self.weights1[weight_block])

        # Trailing (negative) time-frequency block -> weights2
        out_ft[trailing] += self._mix_channels(x_ft[trailing], self.weights2[weight_block])

        # Back to physical space at the input's (T, H, W) size
        return torch.fft.irfftn(out_ft, s=(size_t, size_h, size_w))

class PositionalEncoding3D(nn.Module):
    """
    3D Positional Encoding with caching.

    Builds a sinusoidal encoding from coordinates normalized by the grid
    size along T, H and W. Encodings are cached per (T, H, W, device), so
    each grid is computed only once per module instance.
    """
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        # Frequencies shared by the sin/cos pairs
        self.div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Encodings already built, keyed by grid size and device
        self.pe_cache: Dict[str, torch.Tensor] = {}

    def _build_encoding(self, T: int, H: int, W: int, device: torch.device) -> torch.Tensor:
        """Compute the (T, H, W, d_model) encoding for one grid."""
        # Normalized coordinates along each axis
        pos_t, pos_h, pos_w = (
            torch.arange(n, device=device).float() / n for n in (T, H, W)
        )

        pe = torch.zeros(T, H, W, self.d_model, device=device)
        div_term = self.div_term.to(device)

        # Even channels: product of sines; odd channels: product of cosines.
        # For odd d_model the last channel is left at zero.
        for i in range(self.d_model // 2):
            freq = div_term[i]
            sin_part = (torch.sin(pos_t * freq)[:, None, None]
                        * torch.sin(pos_h * freq)[None, :, None]
                        * torch.sin(pos_w * freq)[None, None, :])
            cos_part = (torch.cos(pos_t * freq)[:, None, None]
                        * torch.cos(pos_h * freq)[None, :, None]
                        * torch.cos(pos_w * freq)[None, None, :])
            pe[..., 2*i] = sin_part
            pe[..., 2*i+1] = cos_part

        return pe

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input shape: (batch, T, C, H, W)
        Output shape: (batch, T, H, W, d_model)

        Only the shape and device of ``x`` are used. The result is an
        expanded (broadcast) view of the cached encoding.
        """
        device = x.device
        batch, T, C, H, W = x.shape

        cache_key = f"{T}_{H}_{W}_{device}"

        if cache_key not in self.pe_cache:
            self.pe_cache[cache_key] = self._build_encoding(T, H, W, device)

        return self.pe_cache[cache_key][None, ...].expand(batch, -1, -1, -1, -1)

class FNO3d(nn.Module):
    """
    Spatio-temporal Fourier Neural Operator (FNO3d).

    Takes a sequence of gridded inputs (batch, T, C, H, W), lifts each grid
    point (together with a positional encoding) to a hidden width, applies
    a stack of Fourier layers over (T, H, W), and projects the features of
    the last time step to the output channels.
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 modes_t: int,
                 modes_h: int,
                 modes_w: int,
                 hidden_channels: int,
                 num_layers: int):
        super().__init__()

        # Positional encoding has the same width as the hidden features
        self.pos_encoder = PositionalEncoding3D(hidden_channels)

        # Lifting layer (P): raw channels + positional encoding -> hidden
        self.fc0 = nn.Linear(in_channels + hidden_channels, hidden_channels)

        # Fourier layers (Fl): one spectral (global) and one 1x1x1 (local)
        # convolution per layer, created in that order.
        self.spectral_convs = nn.ModuleList()
        self.pointwise_convs = nn.ModuleList()
        for _ in range(num_layers):
            spectral = SpectralConv3d(hidden_channels, hidden_channels,
                                      modes_t, modes_h, modes_w)
            self.spectral_convs.append(spectral)
            pointwise = nn.Conv3d(hidden_channels, hidden_channels, 1)
            self.pointwise_convs.append(pointwise)

        # Projection layer (Q): hidden -> output channels
        self.fc1 = nn.Linear(hidden_channels, out_channels)
        self.activation = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input shape: (batch, T, C, H, W)
        Output shape: (batch, out_channels, H, W)
        """
        # Unpacking enforces a 5-D input
        _batch, _steps, _channels, _height, _width = x.shape

        # Positional encoding: (batch, T, H, W, hidden)
        encoding = self.pos_encoder(x)

        # Channels-last features joined with the encoding:
        # (batch, T, H, W, C + hidden)
        features = torch.cat([x.movedim(2, -1), encoding], dim=-1)

        # Lift, then go channels-first: (batch, hidden, T, H, W)
        hidden = self.fc0(features).movedim(-1, 1)

        # Fourier layers: global (frequency) path + local (pointwise) path
        for layer_idx in range(len(self.spectral_convs)):
            global_path = self.spectral_convs[layer_idx](hidden)
            local_path = self.pointwise_convs[layer_idx](hidden)
            hidden = self.activation(global_path + local_path)

        # Keep only the last time step: (batch, hidden, H, W)
        last_step = hidden.select(2, -1)

        # Project per grid point and return channels-first
        projected = self.fc1(last_step.movedim(1, -1))  # (batch, H, W, out_channels)
        return projected.movedim(-1, 1)                 # (batch, out_channels, H, W)

# =========================
# 6. Training & Evaluation Functions
# =========================

def train_one_epoch(model: nn.Module, 
                    dataloader: DataLoader, 
                    optimizer: optim.Optimizer, 
                    criterion: nn.Module, 
                    device: torch.device, 
                    epoch: int, 
                    total_epochs: int, 
                    memory_cleanup_freq: int = 5) -> Tuple[float, float, float, float]:
    """
    Runs a single training epoch with optimized memory management.

    Args:
        model: Network to train; called as ``model(inputs)``.
        dataloader: Yields ``(inputs, targets, masks)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Called as ``criterion(outputs, targets, masks)`` and
                   expected to return a per-pixel loss tensor.
        device: Device the batches are moved to.
        epoch: Current epoch number (progress-bar label only).
        total_epochs: Total number of epochs (progress-bar label only).
        memory_cleanup_freq: Run clean_memory() every this many batches.

    Returns:
        (average batch loss, R2, RMSE, epoch duration in seconds). R2 and
        RMSE are computed over the valid pixels of all batches and stay at
        0.0 if they cannot be computed. The average loss is NaN when the
        dataloader yields no batch.
    """
    model.train()
    epoch_loss = 0.0
    batch_count = 0
    all_preds, all_labels = [], []
    start_time = time.time()
    
    progress_bar = tqdm(dataloader, 
                       desc=f"Epoch {epoch}/{total_epochs} [Train]", 
                       bar_format="{l_bar}{bar:20}{r_bar}",
                       leave=False)
    
    for batch_idx, (inputs, targets, masks) in enumerate(progress_bar):
        inputs = inputs.to(device)  # (B, T, C_in, H, W)
        targets = targets.to(device)  # (B, 1, H, W)
        masks = masks.to(device)  # (B, 1, H, W)
        
        optimizer.zero_grad(set_to_none=True)
        
        # Forward pass
        outputs = model(inputs)  # (B, 1, H, W)
        
        # Calculate loss (per-pixel)
        loss_tensor = criterion(outputs, targets, masks)
        # Aggregate loss only on valid (masked) pixels. The criterion
        # returns an unmasked per-pixel loss, so invalid pixels have to be
        # zeroed here; otherwise they would contribute to the numerator
        # while the denominator counts valid pixels only.
        masked_loss = (loss_tensor * masks).sum() / (masks.sum() + 1e-8)
        
        # Backward pass
        masked_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        epoch_loss += masked_loss.item()
        batch_count += 1
        
        # Keep valid-pixel predictions/labels on the CPU for epoch metrics
        with torch.no_grad():
            valid_mask = masks.squeeze(1).bool()
            all_preds.append(outputs.detach().squeeze(1)[valid_mask].cpu().numpy())
            all_labels.append(targets.squeeze(1)[valid_mask].cpu().numpy())
        
        # Drop references to the batch tensors before the next iteration
        del inputs, targets, masks, outputs, loss_tensor, masked_loss
        
        if (batch_idx + 1) % memory_cleanup_freq == 0:
            clean_memory()
        
        progress_bar.set_postfix({
            'Loss': f"{epoch_loss / (batch_idx + 1):.4f}",
            'LR': f"{optimizer.param_groups[0]['lr']:.2e}"
        }, refresh=False)
    
    epoch_duration = time.time() - start_time
    # An empty dataloader would otherwise raise ZeroDivisionError here.
    avg_loss = epoch_loss / batch_count if batch_count else float('nan')
    
    train_r2, train_rmse = 0.0, 0.0
    try:
        if all_preds and all_labels:
            all_preds = np.concatenate(all_preds)
            all_labels = np.concatenate(all_labels)
            
            train_r2 = r2_score(all_labels, all_preds)
            train_rmse = np.sqrt(mean_squared_error(all_labels, all_preds))
            
            del all_preds, all_labels
    except Exception as e:
        # Metrics are best-effort: a failure here must not abort training.
        logger.error(f"Error calculating training metrics: {e}")
    
    clean_memory()
    
    return avg_loss, train_r2, train_rmse, epoch_duration

@torch.no_grad()
def predict_and_evaluate(model: nn.Module, 
                         dataloader: DataLoader, 
                         device: torch.device, 
                         batch_size: int = 4, 
                         memory_cleanup_freq: int = 5) -> Tuple[float, float, float]:
    """
    Run inference over a test dataloader and score the predictions.

    The model is put in eval mode and run without gradients. CUDA autocast
    is enabled whenever CUDA is available. Only pixels whose mask is
    non-zero enter the metrics.

    Args:
        model: Network to evaluate; called as ``model(inputs)``.
        dataloader: Yields ``(inputs, labels, masks)`` batches.
        device: Device the batches are moved to.
        batch_size: Accepted for interface compatibility; not used here.
        memory_cleanup_freq: Run clean_memory() every this many batches.

    Returns:
        (R2, RMSE, elapsed seconds). R2/RMSE stay at 0.0 / inf when no
        predictions were collected or the metric computation fails.
    """
    model.eval()
    pred_chunks, label_chunks = [], []
    start_time = time.time()
    batch_count = 0

    pbar = tqdm(dataloader,
                desc="Testing [3D]",
                bar_format="{l_bar}{bar:20}{r_bar}",
                leave=False)

    # Automatic Mixed Precision (AMP) for inference, CUDA only
    amp_context = torch.amp.autocast(device_type='cuda', enabled=torch.cuda.is_available())
    with amp_context:
        for inputs, labels, masks in pbar:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)

            outputs = model(inputs)

            # Keep valid pixels only, as flat NumPy arrays on the CPU
            valid_mask = masks.squeeze(1).bool()
            pred_chunks.append(outputs.squeeze(1)[valid_mask].cpu().numpy())
            label_chunks.append(labels.squeeze(1)[valid_mask].cpu().numpy())

            batch_count += 1
            del inputs, labels, masks, outputs, valid_mask

            if batch_count % memory_cleanup_freq != 0:
                continue

            clean_memory()
            # Merge the accumulated pieces once the lists grow long
            if len(pred_chunks) > 100:
                pred_chunks = [np.concatenate(pred_chunks)]
                label_chunks = [np.concatenate(label_chunks)]

    r2, rmse = 0.0, float('inf')
    try:
        if pred_chunks and label_chunks:
            pred_chunks = np.concatenate(pred_chunks)
            label_chunks = np.concatenate(label_chunks)

            r2 = r2_score(label_chunks, pred_chunks)
            rmse = np.sqrt(mean_squared_error(label_chunks, pred_chunks))

            del pred_chunks, label_chunks
    except Exception as e:
        logger.error(f"Error calculating test metrics: {e}")

    test_time = time.time() - start_time
    clean_memory()

    return r2, rmse, test_time

@torch.no_grad()
def evaluate_on_training_set(model: nn.Module, 
                             train_dataset: Dataset, 
                             device: torch.device, 
                             batch_size: int = 600, 
                             memory_cleanup_freq: int = 5) -> Tuple[float, float, float]:
    """
    Scores a trained model against every sample of the training dataset.

    The dataset is iterated once, in order, through a dedicated DataLoader.
    Only the positions selected by each batch's mask contribute to the metrics.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        train_dataset: Dataset yielding (inputs, labels, masks) triples.
        device: Device the batches are moved to before the forward pass.
        batch_size: Number of samples per evaluation batch.
        memory_cleanup_freq: Run clean_memory() every this many batches.

    Returns:
        (r2, rmse, eval_time). If the metrics cannot be computed, the error is
        logged and r2=0.0, rmse=inf are returned with the elapsed time.
    """
    model.eval()
    started_at = time.time()
    pred_chunks = []
    label_chunks = []
    batches_done = 0

    # Sequential, single-process loading: evaluation order does not matter,
    # so shuffling is off and no worker processes are spawned.
    eval_loader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             pin_memory=True,
                             num_workers=0)

    progress = tqdm(eval_loader,
                    desc="Evaluating on Training Set",
                    bar_format="{l_bar}{bar:20}{r_bar}",
                    leave=False)

    with torch.amp.autocast(device_type='cuda', enabled=torch.cuda.is_available()):
        for inputs, labels, masks in progress:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)

            outputs = model(inputs)

            # Keep only the masked-in positions, as flat CPU arrays.
            valid_mask = masks.squeeze(1).bool()
            kept_preds = outputs.squeeze(1)[valid_mask]
            kept_labels = labels.squeeze(1)[valid_mask]
            pred_chunks.append(kept_preds.cpu().numpy())
            label_chunks.append(kept_labels.cpu().numpy())

            batches_done += 1
            # Drop device tensors before the next batch is fetched.
            del inputs, labels, masks, outputs, valid_mask, kept_preds, kept_labels

            if batches_done % memory_cleanup_freq != 0:
                continue

            clean_memory()
            # Collapse the accumulated pieces once there are many of them.
            if len(pred_chunks) > 100:
                pred_chunks = [np.concatenate(pred_chunks)]
                label_chunks = [np.concatenate(label_chunks)]

    # Fallback values used when nothing was collected or the metrics fail.
    r2 = 0.0
    rmse = float('inf')
    try:
        if pred_chunks and label_chunks:
            flat_preds = np.concatenate(pred_chunks)
            flat_labels = np.concatenate(label_chunks)

            r2 = r2_score(flat_labels, flat_preds)
            rmse = np.sqrt(mean_squared_error(flat_labels, flat_preds))

            del flat_preds, flat_labels
    except Exception as e:
        logger.error(f"Error calculating full training set metrics: {e}")

    eval_time = time.time() - started_at
    clean_memory()

    return r2, rmse, eval_time

# =========================
# 7. Main Execution
# =========================

def get_args() -> argparse.Namespace:
    """
    Builds the command-line interface and returns the parsed options.

    All options are declared in a single table of (flag, add_argument kwargs)
    pairs and registered in table order. The hyperparameter options accept one
    or more integers; every listed value takes part in the grid search.

    Returns:
        argparse.Namespace populated from the process command line.
    """
    option_table = (
        # --- Data Paths ---
        ('--train_data_dir', dict(
            type=str,
            default="/home/ubuntu/Documents/xjq/data_Q_timeseries_0329/train_data_LF",
            help="Directory for multi-event training data.")),
        ('--test_100m_dir', dict(
            type=str,
            default="/home/ubuntu/Documents/xjq/data_Q_timeseries_0329/test_data_LF",
            help="Directory for 100m (LR) test data.")),
        ('--test_30m_dir', dict(
            type=str,
            default="/home/ubuntu/Documents/xjq/data_Q_timeseries_0329/test_data_HF",
            help="Directory for 30m (HR) test data.")),

        # --- Training Parameters ---
        ('--num_epochs', dict(type=int, default=20, help="Number of training epochs.")),
        ('--batch_size', dict(type=int, default=2, help="Batch size for training.")),
        ('--lr', dict(type=float, default=1e-3, help="Learning rate.")),

        # --- Model Hyperparameters (Grid Search) ---
        ('--modes_t', dict(type=int, nargs='+', default=[2, 4, 8],
                           help="List of temporal modes to try.")),
        ('--modes_h', dict(type=int, nargs='+', default=[5, 10, 20, 40],
                           help="List of height modes to try.")),
        ('--modes_w', dict(type=int, nargs='+', default=[5, 10, 20, 40],
                           help="List of width modes to try.")),
        ('--hidden_channels', dict(type=int, nargs='+', default=[10, 20, 40],
                                   help="List of hidden channel sizes to try.")),
        ('--num_layers', dict(type=int, nargs='+', default=[1, 3, 5],
                              help="List of layer counts to try.")),

        # --- Loss and Memory Config ---
        ('--use_dynamic_loss', dict(
            action='store_true',
            help="Use DynamicWeightedMSELoss instead of standard MSE.")),
        ('--use_chunked_loading_hf', dict(
            action='store_true',
            help="Use chunked loading for 30m (HR) test data.")),
        ('--hf_chunk_size', dict(
            type=int, default=300,
            help="Chunk size for loading 30m data if chunked loading is enabled.")),
        ('--eval_batch_size', dict(
            type=int, default=60,
            help="Batch size for post-training evaluation on the full train set.")),

        # --- System Config ---
        ('--num_workers', dict(type=int, default=0,
                               help="Number of DataLoader workers.")),

        # --- Output Files ---
        ('--results_file', dict(type=str, default="FNO3D_test_results.csv",
                                help="File to save test results.")),
        ('--train_eval_file', dict(type=str, default="FNO3D_train_eval_results.csv",
                                   help="File to save full training set evaluation results.")),
    )

    parser = argparse.ArgumentParser(description="FNO3d Training and Evaluation Script")
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()

def main(args: argparse.Namespace):
    """
    Main training and evaluation loop.

    For every combination of the hyperparameter lists in ``args`` (temporal,
    height and width modes, hidden channels, number of layers) this function:

      1. builds a fresh FNO3d model, loss and AdamW optimizer,
      2. trains for ``args.num_epochs`` epochs, appending one line per epoch
         to ``train_log_<exp_id>.txt``,
      3. saves the final weights to ``best_model_<exp_id>.pth``,
      4. evaluates on the full training set and appends a row to
         ``args.train_eval_file``,
      5. evaluates on every 100m and 30m test event and appends one row per
         event to ``args.results_file``.

    A configuration whose checkpoint file already exists is skipped entirely.

    Args:
        args: Parsed command-line options as produced by ``get_args``.

    Raises:
        FileNotFoundError: If ``args.train_data_dir`` is missing or empty.
        ValueError: If a test directory has a different number of ``_X.h5``
            and ``_Y.h5`` files.
    """
    # Function-scope import so this block does not depend on the file header.
    from itertools import product

    # 1. Initialization
    set_seed(42, deterministic=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    logger.info("Initial memory state:")
    # print_memory_usage requires the logger to write to; calling it without
    # arguments raises TypeError.
    print_memory_usage(logger)

    # 2. Load Training Dataset
    logger.info("Loading training dataset...")

    if os.path.isdir(args.train_data_dir) and len(os.listdir(args.train_data_dir)) > 0:
        logger.info(f"Using multi-event dataset from: {args.train_data_dir}")
        train_dataset = Flood3DMultiEventDataset(args.train_data_dir)
    else:
        raise FileNotFoundError(f"Training data directory not found or is empty: {args.train_data_dir}")

    logger.info(f"Training dataset size: {len(train_dataset)}")

    logger.info("Memory state after loading training data:")
    print_memory_usage(logger)

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers,
        worker_init_fn=seed_worker
    )

    # 3. Prepare Test Files
    def prepare_test_files(test_dir: str) -> List[Tuple[str, str]]:
        """Return (driver, label) file-name pairs found in ``test_dir``.

        Drivers end in ``_X.h5`` and labels in ``_Y.h5``; both lists are
        sorted and paired by position.
        """
        entries = os.listdir(test_dir)
        driver_files = sorted(f for f in entries if f.endswith('_X.h5'))
        label_files = sorted(f for f in entries if f.endswith('_Y.h5'))
        if len(driver_files) != len(label_files):
            raise ValueError(f"Test file mismatch in {test_dir}")
        return list(zip(driver_files, label_files))

    test_100m_files = prepare_test_files(args.test_100m_dir)
    test_30m_files = prepare_test_files(args.test_30m_dir)

    # 4. Initialize Log Files
    # NOTE: both result files are opened in 'w' mode, so rows from a previous
    # run are discarded, including rows for configurations that are skipped
    # below because their checkpoint already exists.
    test_log_file = args.results_file
    with open(test_log_file, 'w') as f:
        f.write("modes_t,modes_h,modes_w,hidden,layers,resolution,event,r2,rmse,time\n")

    train_eval_log_file = args.train_eval_file
    with open(train_eval_log_file, 'w') as f:
        f.write("modes_t,modes_h,modes_w,hidden,layers,train_r2,train_rmse,eval_time\n")

    # 5. Experiment Grid Search Loop
    # product() varies the right-most list fastest, i.e. the same visiting
    # order as five nested for-loops in this argument order.
    grid = product(args.modes_t, args.modes_h, args.modes_w,
                   args.hidden_channels, args.num_layers)

    for modes_t, modes_h, modes_w, hidden, n_layers in grid:

        # The experiment id doubles as the checkpoint / log file name, so its
        # format must stay stable for the skip-if-exists check to work.
        exp_id = f"t{modes_t}_h{modes_h}_w{modes_w}_h{hidden}_l{n_layers}"
        if args.use_dynamic_loss:
            exp_id += "_dynamic"

        best_model_path = f"best_model_{exp_id}.pth"

        if os.path.exists(best_model_path):
            logger.info(f"Model for config ({exp_id}) already exists, skipping.")
            continue

        train_log_file = f"train_log_{exp_id}.txt"
        with open(train_log_file, 'w') as f:
            f.write("epoch,train_loss,train_r2,train_rmse,epoch_time\n")

        logger.info(f"Memory state before initializing model {exp_id}:")
        print_memory_usage(logger)

        # Initialize Model
        model = FNO3d(
            in_channels=7,  # 7 input features
            out_channels=1, # 1 output feature (water depth)
            modes_t=modes_t,
            modes_h=modes_h,
            modes_w=modes_w,
            hidden_channels=hidden,
            num_layers=n_layers
        ).to(device)

        # Initialize Loss
        if args.use_dynamic_loss:
            criterion = DynamicWeightedMSELoss(
                update_freq=50 # Example update frequency
            ).to(device)
            logger.info("Using Dynamic Weighted MSE Loss")
        else:
            criterion = StandardMSELoss().to(device)
            logger.info("Using Standard MSE Loss")

        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)

        logger.info(f"\n=== Starting Training: {exp_id} ===")
        logger.info("Memory state after model initialization:")
        print_memory_usage(logger)

        # --- Training Loop ---
        for epoch in range(1, args.num_epochs + 1):
            train_loss, train_r2, train_rmse, epoch_time = train_one_epoch(
                model, train_loader, optimizer, criterion, device,
                epoch, args.num_epochs,
                memory_cleanup_freq=10
            )

            with open(train_log_file, 'a') as f:
                f.write(f"{epoch},{train_loss:.6f},{train_r2:.4f},{train_rmse:.4f},{epoch_time:.2f}\n")

            logger.info(f"Epoch {epoch}/{args.num_epochs}: "
                        f"Train Loss={train_loss:.4f}, R2={train_r2:.4f}, RMSE={train_rmse:.4f}, "
                        f"Time={epoch_time:.2f}s")

            if epoch % 5 == 0:
                logger.info(f"Memory state after Epoch {epoch}:")
                print_memory_usage(logger)

        # The weights after the last epoch are saved (no best-epoch selection
        # happens here despite the file name).
        torch.save(model.state_dict(), best_model_path)
        logger.info(f"Training complete. Saved final model to {best_model_path}")

        logger.info("Memory state after training:")
        print_memory_usage(logger)

        # --- Full Training Set Evaluation ---
        logger.info("Evaluating model on the full training set...")

        train_r2, train_rmse, train_eval_time = evaluate_on_training_set(
            model, train_dataset, device,
            batch_size=args.eval_batch_size,
            memory_cleanup_freq=10
        )

        with open(train_eval_log_file, 'a') as f:
            f.write(f"{modes_t},{modes_h},{modes_w},{hidden},{n_layers},"
                    f"{train_r2:.6f},{train_rmse:.6f},{train_eval_time:.2f}\n")

        logger.info(f"Full Train Set Eval: R2={train_r2:.6f}, RMSE={train_rmse:.6f}, Time={train_eval_time:.2f}s")

        # --- Test Set Evaluation ---
        def run_testing(test_files: List[Tuple[str, str]], test_dir: str, resolution: str):
            """Evaluate the current model on each event and log one CSV row per event."""
            for drv_file, lbl_file in test_files:
                event_name = drv_file.replace('_X.h5', '')
                drv_path = os.path.join(test_dir, drv_file)
                lbl_path = os.path.join(test_dir, lbl_file)

                logger.info(f"Memory state before testing event {event_name}:")
                print_memory_usage(logger)

                # Select dataloader based on resolution and config
                if resolution == "30m" and args.use_chunked_loading_hf:
                    logger.info(f"Using Chunked Loader for {resolution} data")
                    test_ds = Flood3DChunkedTestDataset(
                        driver_path=drv_path, label_path=lbl_path,
                        driver_key='data', label_key='data',
                        chunk_size=args.hf_chunk_size
                    )
                    test_batch_size = 2 # Smaller batch size for HR data
                else:
                    logger.info(f"Using Full-Load Loader for {resolution} data")
                    test_ds = Flood3DTestDataset(
                        driver_path=drv_path, label_path=lbl_path,
                        driver_key='data', label_key='data'
                    )
                    test_batch_size = 4 if resolution == "100m" else 2

                test_loader = DataLoader(
                    test_ds,
                    batch_size=test_batch_size,
                    shuffle=False,
                    pin_memory=True,
                    num_workers=0 # Use 0 for chunked/lazy loading
                )

                r2, rmse, test_time = predict_and_evaluate(
                    model, test_loader, device,
                    batch_size=test_batch_size,
                    memory_cleanup_freq=10
                )

                with open(test_log_file, 'a') as f:
                    f.write(f"{modes_t},{modes_h},{modes_w},{hidden},{n_layers},"
                            f"{resolution},{event_name},{r2:.4f},{rmse:.4f},{test_time:.2f}\n")

                logger.info(f"Event: {event_name}, Res: {resolution}, "
                            f"R2={r2:.4f}, RMSE={rmse:.4f}, Time={test_time:.2f}s")

                # Release the event's data before loading the next one.
                del test_ds, test_loader
                clean_memory()

        # Run tests for 100m and 30m
        logger.info("Testing on 100m dataset...")
        run_testing(test_100m_files, args.test_100m_dir, "100m")

        logger.info("Testing on 30m (ZS-SR) dataset...")
        run_testing(test_30m_files, args.test_30m_dir, "30m")

        # Clean up model from memory
        del model, criterion, optimizer
        clean_memory()

        logger.info(f"Memory state after experiment {exp_id}:")
        print_memory_usage(logger)

    logger.info("\nAll experiments complete!")
    logger.info(f"Test results saved to: {test_log_file}")
    logger.info(f"Training set evaluation results saved to: {train_eval_log_file}")

if __name__ == "__main__":
    # Read the command-line options up front so the effective configuration
    # can be written to the log before any work starts.
    args = get_args()

    separator = "=" * 30
    logger.info("Starting FNO3d run with configuration:")
    logger.info(separator)
    for option_name, option_value in vars(args).items():
        logger.info(f"{option_name}: {option_value}")
    logger.info(separator)

    # Hand over to the training / evaluation driver.
    main(args)
