In [2]:
# Install torch, torchvision and torchaudio into the kernel's environment from the
# PyTorch "cu126" wheel index (IPython %pip magic; not valid outside IPython/Jupyter).
# NOTE(review): no versions are pinned, so a later re-run may resolve different releases.
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

Looking in indexes: https://download.pytorch.org/whl/cu126
Collecting torch
  Using cached https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp311-cp311-win_amd64.whl.metadata (28 kB)
Collecting torchvision
  Using cached https://download.pytorch.org/whl/cu126/torchvision-0.21.0%2Bcu126-cp311-cp311-win_amd64.whl.metadata (6.3 kB)
Collecting torchaudio
  Using cached https://download.pytorch.org/whl/cu126/torchaudio-2.6.0%2Bcu126-cp311-cp311-win_amd64.whl.metadata (6.8 kB)
Collecting filelock (from torch)
  Using cached https://download.pytorch.org/whl/filelock-3.13.1-py3-none-any.whl.metadata (2.8 kB)
Collecting networkx (from torch)
  Using cached https://download.pytorch.org/whl/networkx-3.3-py3-none-any.whl.metadata (5.1 kB)
Collecting jinja2 (from torch)
  Using cached https://download.pytorch.org/whl/Jinja2-3.1.4-py3-none-any.whl.metadata (2.6 kB)
Collecting fsspec (from torch)
  Using cached https://download.pytorch.org/whl/fsspec-2024.6.1-py3-none-any.whl.metadata (11 k



In [3]:
import torch

# Quick environment report: installed PyTorch build and what CUDA hardware it can see.
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

if not torch.cuda.is_available():
    print("CUDA is not available.")
else:
    # Device index 0 is queried explicitly, matching the single-GPU setup this was run on.
    print("Number of GPUs:", torch.cuda.device_count())
    print("Current GPU:", torch.cuda.get_device_name(0))


PyTorch version: 2.6.0+cu126
CUDA available: True
Number of GPUs: 1
Current GPU: NVIDIA GeForce RTX 3050 Laptop GPU


In [6]:
import os
# Request the PyTorch backend through the environment *before* deepxde is imported
# further down in this block (DeepXDE's documented backend-selection variable).
os.environ["DDE_BACKEND"] = "pytorch"
# Install-if-missing preamble: each package is probed with a plain import and, only
# when that raises ImportError, installed quietly into the kernel environment with
# the IPython %pip magic (so this cell is not valid plain Python).
# The failed import is not retried here; the unconditional imports that follow this
# block perform the real import after installation.
# NOTE(review): the installs are unpinned, so the resolved versions may change between runs.
try:
  import torch
except ImportError:
  %pip install torch -q
try:
  import deepxde
except ImportError:
  %pip install deepxde -q
try:
  import pandas
except ImportError:
  %pip install pandas -q
try:
  import matplotlib
except ImportError:
  %pip install matplotlib -q
import sys
import time
import logging
import numpy as np
import torch
import deepxde as dde
import matplotlib.pyplot as plt
import pandas as pd
from scipy.interpolate import griddata
import re # <<<--- IMPORT REGEX MODULE



# --- Best-effort explicit backend selection ---
# The backend has already been requested through DDE_BACKEND before deepxde was
# imported; this call is only an extra attempt for DeepXDE builds whose dde.config
# exposes set_default_backend. Any failure is reported and otherwise ignored.
try:
    dde.config.set_default_backend("pytorch")
    print("Attempted to explicitly set DeepXDE backend to PyTorch.")
except AttributeError:
    print("Warning: Could not explicitly set backend via dde.config (likely older DeepXDE version). Relying on environment variable.")
except Exception as e:
    print(f"Warning: Could not explicitly set backend via dde.config: {e}")

print(f"DeepXDE Backend requested: {os.environ.get('DDE_BACKEND', 'Not Set')}")
# `import deepxde as dde` above is unconditional, so the module is guaranteed to be
# loaded here; the former `"deepxde" in sys.modules` guard could never be False.
print(f"DeepXDE Backend actual: {dde.backend.backend_name}")

if dde.backend.backend_name == "pytorch":
    # Only the default dtype is configured. No default device is set anywhere in
    # this cell, so the messages below report what is available rather than
    # claiming that a device was selected.
    try:
        torch.set_default_dtype(torch.float32)
    except Exception as e:
        print(f"Warning: Failed to set PyTorch default dtype: {e}")

    if torch.cuda.is_available():
        print("CUDA available. Device placement is left to DeepXDE.")
        try:
            current_device = torch.cuda.current_device()
            print(f"PyTorch current CUDA device index: {current_device}")
        except Exception as e:
            print(f"Warning: Error during PyTorch device setup: {e}")
    else:
        print("CUDA not available. Using CPU.")

    print(f"PyTorch default dtype: {torch.get_default_dtype()}")
else:
    print(f"Warning: Backend is '{dde.backend.backend_name}', not PyTorch. Skipping PyTorch device setup.")


class PlotterConfig:
    """Plot-only settings: prediction-grid resolution and colormap names."""

    # Number of prediction-grid points (Nx along x, Ny along y, per the log output).
    NX_PRED = 200
    NY_PRED = 100

    # Colormap names, one per family of plotted field.
    CMAP_VELOCITY = "viridis"
    CMAP_PRESSURE = "coolwarm"
    CMAP_TURBULENCE = "plasma"


class Config:
    """Stores configuration parameters for the simulation.

    Every entry is a class attribute evaluated once, top to bottom, when the class
    body runs. Derived entries (output paths, Reynolds number, wall-function
    targets, log-transformed turbulence values, the geometry) are computed from
    the entries above them at that moment and are NOT recomputed if a base value
    is reassigned later; code that changes OUTPUT_DIR must re-derive the dependent
    paths itself.
    """
    # --- Output locations ---
    DRIVE_MOUNT_POINT = '/content/drive'
    # Base folder name; also the default (relative) output root.
    GDRIVE_BASE_FOLDER = 'PINN_RANS_ChannelFlow'
    OUTPUT_DIR = GDRIVE_BASE_FOLDER
    MODEL_DIR = os.path.join(OUTPUT_DIR, "model_checkpoints")
    LOG_DIR = os.path.join(OUTPUT_DIR, "logs")
    PLOT_DIR = os.path.join(OUTPUT_DIR, "plots")
    DATA_DIR = os.path.join(OUTPUT_DIR, "data")
    LOG_FILE = os.path.join(LOG_DIR, "training_log.log")
    REFERENCE_DATA_FILE = os.path.join(DATA_DIR, "reference_output_data.csv")

    # Checkpoint filename base only (no step number, no extension).
    CHECKPOINT_FILENAME_BASE  = "rans_channel_wf"

    # --- Fluid and geometry parameters ---
    NU = 0.0002  # kinematic viscosity
    RHO = 1.0  # density
    MU = RHO * NU  # dynamic viscosity
    U_INLET = 1.0  # inlet velocity
    H = 2.0  # full channel height
    CHANNEL_HALF_HEIGHT = H / 2.0
    L = 10.0  # channel length
    RE_H = U_INLET * H / NU  # Reynolds number on the full height (about 1e4 with these values)
    EPS_SMALL = 1e-10  # floor used to keep log() arguments and denominators away from zero

    # --- k-epsilon model constants ---
    CMU = 0.09
    CEPS1 = 1.44
    CEPS2 = 1.92
    SIGMA_K = 1.0
    SIGMA_EPS = 1.3

    # --- Wall-function parameters ---
    KAPPA = 0.41  # von Karman constant
    E_WALL = 9.8  # log-law constant
    Y_P = 0.04  # distance from the wall at which wall-function targets are evaluated
    RE_TAU_TARGET = 350  # target friction Reynolds number
    U_TAU_TARGET = RE_TAU_TARGET * NU / CHANNEL_HALF_HEIGHT  # friction velocity (0.07 here)
    YP_PLUS_TARGET = Y_P * U_TAU_TARGET / NU  # y+ at Y_P (14 here)
    # NOTE(review): y+ = 14 looks low for a log-law wall function - confirm Y_P is intended.
    U_TARGET_WF = (U_TAU_TARGET / KAPPA) * np.log(max(E_WALL * YP_PLUS_TARGET, EPS_SMALL))
    K_TARGET_WF = U_TAU_TARGET**2 / np.sqrt(CMU)
    EPS_TARGET_WF = U_TAU_TARGET**3 / max(KAPPA * Y_P, EPS_SMALL)

    # --- Inlet turbulence, from intensity and mixing length ---
    TURBULENCE_INTENSITY = 0.05
    MIXING_LENGTH_SCALE = 0.07 * CHANNEL_HALF_HEIGHT
    K_INLET = 1.5 * (U_INLET * TURBULENCE_INTENSITY)**2
    EPS_INLET = (CMU**0.75) * (K_INLET**1.5) / MIXING_LENGTH_SCALE
    # Natural-log transformed values, clamped at EPS_SMALL before taking the log.
    K_INLET_TRANSFORMED = np.log(max(K_INLET, EPS_SMALL))
    EPS_INLET_TRANSFORMED = np.log(max(EPS_INLET, EPS_SMALL))
    K_TARGET_WF_TRANSFORMED = np.log(max(K_TARGET_WF, EPS_SMALL))
    EPS_TARGET_WF_TRANSFORMED = np.log(max(EPS_TARGET_WF, EPS_SMALL))

    # Rectangular domain: x in [0, L], y in [-CHANNEL_HALF_HEIGHT, +CHANNEL_HALF_HEIGHT].
    GEOM = dde.geometry.Rectangle(xmin=[0, -CHANNEL_HALF_HEIGHT], xmax=[L, CHANNEL_HALF_HEIGHT])

    # --- Network architecture ---
    NUM_LAYERS = 8
    NUM_NEURONS = 64
    ACTIVATION = "tanh"
    INITIALIZER = "Glorot normal"
    NETWORK_INPUTS = 2  # x, y
    NETWORK_OUTPUTS = 5  # u, v, p', log(k), log(eps)

    # --- Training parameters ---
    NUM_DOMAIN_POINTS = 20000
    NUM_BOUNDARY_POINTS = 4000
    NUM_TEST_POINTS = 5000
    NUM_WF_POINTS_PER_WALL = 200
    LEARNING_RATE_ADAM = 1e-3
    ADAM_ITERATIONS = 50000
    LBFGS_ITERATIONS = 20000
    # LOSS_WEIGHTS is the 5 PDE weights followed by the 10 BC weights (15 entries);
    # the order must match the order of the loss terms built elsewhere.
    PDE_WEIGHTS = [1, 1, 1, 1, 1]
    BC_WEIGHTS = [10, 10, 10, 10, 10, 10, 10, 20, 20, 20]
    LOSS_WEIGHTS = PDE_WEIGHTS + BC_WEIGHTS
    SAVE_INTERVAL = 1000 # Checkpoint saving interval
    DISPLAY_EVERY = 1000 # Loss display interval


# Module-level configuration objects shared by the rest of the notebook.
# Both classes only define class attributes, so these instances start out
# reading the class-level values.
cfg = Config()
plotter_cfg = PlotterConfig()

DeepXDE Backend requested: pytorch
DeepXDE Backend actual: pytorch
CUDA available. Setting default device to CUDA.
PyTorch CUDA device detected by DDE: 0
PyTorch default dtype: torch.float32


In [1]:
import os
# Request the PyTorch backend through the environment *before* deepxde is imported
# below (DeepXDE's documented backend-selection variable).
os.environ["DDE_BACKEND"] = "pytorch"

# Fail fast, with one actionable error, if a required package is missing.
# Raising keeps the kernel (and the message) alive; the previous print() + exit()
# asked a Jupyter kernel to shut down and relied on the site-provided exit() helper
# when run as a script. The plain `import` statements are kept so that the module
# names torch, deepxde, pandas and matplotlib are still bound in the namespace.
try:
    import torch
    import deepxde
    import pandas
    import matplotlib
except ImportError as import_err:
    # ImportError.name is the module that could not be found (it may be a transitive
    # dependency rather than one of the four names above); it can be None.
    missing_pkg = import_err.name
    install_hint = f" (e.g. `pip install {missing_pkg}`)" if missing_pkg else ""
    raise ImportError(
        f"Required package {missing_pkg or '<unknown>'} could not be imported. "
        f"Please install it{install_hint} and re-run this cell.",
        name=missing_pkg,
    ) from import_err

import sys
import time
import logging
import numpy as np
import torch
import deepxde as dde
import matplotlib.pyplot as plt
import pandas as pd
from scipy.interpolate import griddata
import re # Regex module for checkpoint parsing


# --- Best-effort explicit backend selection ---
# The backend has already been requested through DDE_BACKEND before deepxde was
# imported; this call is only an extra attempt for DeepXDE builds whose dde.config
# exposes set_default_backend. Any failure is reported and otherwise ignored.
try:
    dde.config.set_default_backend("pytorch")
    print("Attempted to explicitly set DeepXDE backend to PyTorch.")
except AttributeError:
    print("Warning: Could not explicitly set backend via dde.config (likely older DeepXDE version). Relying on environment variable.")
except Exception as e:
    print(f"Warning: Could not explicitly set backend via dde.config: {e}")

print(f"DeepXDE Backend requested: {os.environ.get('DDE_BACKEND', 'Not Set')}")

# --- Report the actual backend and configure the PyTorch default dtype ---
# `import deepxde as dde` above is unconditional, so the module is guaranteed to be
# loaded here; the former `"deepxde" in sys.modules` guard could never be False.
print(f"DeepXDE Backend actual: {dde.backend.backend_name}")

if dde.backend.backend_name == "pytorch":
    # Only the default dtype is configured. No default device is set anywhere in
    # this cell, so nothing below claims that a device was selected.
    try:
        torch.set_default_dtype(torch.float32)
    except Exception as e:
        print(f"Warning: Failed to set PyTorch default dtype: {e}")

    # The version is reported on CPU-only machines too, not just when CUDA is present.
    print(f"PyTorch version: {torch.__version__}")

    if torch.cuda.is_available():
        print("CUDA available.")
        try:
            current_device = torch.cuda.current_device()
            print(f"PyTorch current CUDA device: {current_device} ({torch.cuda.get_device_name(current_device)})")
            print(f"Number of GPUs: {torch.cuda.device_count()}")
        except Exception as e:
            print(f"Warning: Error during PyTorch device setup: {e}")
    else:
        print("CUDA not available. Using CPU.")

    print(f"PyTorch default dtype: {torch.get_default_dtype()}")
else:
    print(f"Warning: Backend is '{dde.backend.backend_name}', not PyTorch. PyTorch-specific device setup skipped.")


# =============================
# ===== Configuration Classes =====
# =============================

class PlotterConfig:
    """Plot-only settings: prediction-grid resolution and colormap names."""

    # Number of prediction-grid points (Nx along x, Ny along y, per the log output).
    NX_PRED = 200
    NY_PRED = 100

    # Colormap names, one per family of plotted field.
    CMAP_VELOCITY = "viridis"
    CMAP_PRESSURE = "coolwarm"
    CMAP_TURBULENCE = "plasma"


class Config:
    """Stores configuration parameters for the simulation.

    Every entry is a class attribute evaluated once, top to bottom, when the class
    body runs. Derived entries (output paths, Reynolds number, wall-function
    targets, log-transformed turbulence values, the geometry) are computed from
    the entries above them at that moment and are NOT recomputed if a base value
    is reassigned later; code that changes OUTPUT_DIR must re-derive the dependent
    paths itself.
    """
    # --- Output locations ---
    DRIVE_MOUNT_POINT = '/content/drive' # Specific to Google Colab
    GDRIVE_BASE_FOLDER = 'PINN_RANS_ChannelFlow' # Base folder name on Drive or local
    OUTPUT_DIR = GDRIVE_BASE_FOLDER # Default (relative) output root; may be reassigned on the instance later
    MODEL_DIR = os.path.join(OUTPUT_DIR, "model_checkpoints")
    LOG_DIR = os.path.join(OUTPUT_DIR, "logs")
    PLOT_DIR = os.path.join(OUTPUT_DIR, "plots")
    DATA_DIR = os.path.join(OUTPUT_DIR, "data")
    LOG_FILE = os.path.join(LOG_DIR, "training_log.log")
    REFERENCE_DATA_FILE = os.path.join(DATA_DIR, "reference_output_data.csv")

    # Checkpoint filename base (without step number or extension)
    CHECKPOINT_FILENAME_BASE  = "rans_channel_wf"

    # --- Fluid and Geometry Parameters ---
    NU = 0.0002 # Kinematic viscosity
    RHO = 1.0 # Density (often set to 1 for incompressible flow)
    MU = RHO * NU # Dynamic viscosity
    U_INLET = 1.0 # Inlet velocity
    H = 2.0 # Full channel height
    CHANNEL_HALF_HEIGHT = H / 2.0
    L = 10.0 # Channel length
    RE_H = U_INLET * H / NU # Reynolds number based on full height (about 1e4 with these values)
    EPS_SMALL = 1e-10 # Small epsilon for numerical stability (avoid log(0), division by zero)

    # --- k-epsilon Model Constants ---
    CMU = 0.09
    CEPS1 = 1.44
    CEPS2 = 1.92
    SIGMA_K = 1.0
    SIGMA_EPS = 1.3
    KAPPA = 0.41 # Von Karman constant

    # --- Wall Function Parameters ---
    E_WALL = 9.8 # Log-law constant for smooth walls
    Y_P = 0.04 # Distance from wall for applying wall functions (y_p)
    # Target values for deriving wall function BCs (based on the desired Re_tau)
    RE_TAU_TARGET = 350 # Target friction Reynolds number
    U_TAU_TARGET = RE_TAU_TARGET * NU / CHANNEL_HALF_HEIGHT # Target friction velocity (0.07 here)
    YP_PLUS_TARGET = Y_P * U_TAU_TARGET / NU # Target y+ at y_p (14 here)
    # NOTE(review): y+ = 14 looks low for a log-law wall function - confirm Y_P is intended.
    # Target values at y_p: log-law velocity, k = u_tau^2/sqrt(Cmu), eps = u_tau^3/(kappa*y_p)
    U_TARGET_WF = (U_TAU_TARGET / KAPPA) * np.log(max(E_WALL * YP_PLUS_TARGET, EPS_SMALL))
    K_TARGET_WF = U_TAU_TARGET**2 / np.sqrt(CMU)
    EPS_TARGET_WF = U_TAU_TARGET**3 / max(KAPPA * Y_P, EPS_SMALL)

    # --- Inlet Turbulence Parameters ---
    TURBULENCE_INTENSITY = 0.05 # Typical value for channel flow inlet
    MIXING_LENGTH_SCALE = 0.07 * CHANNEL_HALF_HEIGHT # Estimate based on boundary layer thickness
    # Inlet k and epsilon based on intensity and length scale
    K_INLET = 1.5 * (U_INLET * TURBULENCE_INTENSITY)**2
    EPS_INLET = (CMU**0.75) * (K_INLET**1.5) / MIXING_LENGTH_SCALE
    # Natural-log transformed values for network prediction/BCs, clamped at EPS_SMALL before the log
    K_INLET_TRANSFORMED = np.log(max(K_INLET, EPS_SMALL))
    EPS_INLET_TRANSFORMED = np.log(max(EPS_INLET, EPS_SMALL))
    K_TARGET_WF_TRANSFORMED = np.log(max(K_TARGET_WF, EPS_SMALL))
    EPS_TARGET_WF_TRANSFORMED = np.log(max(EPS_TARGET_WF, EPS_SMALL))

    # --- Domain Geometry ---
    # Rectangle with x in [0, L] and y in [-CHANNEL_HALF_HEIGHT, +CHANNEL_HALF_HEIGHT]
    GEOM = dde.geometry.Rectangle(xmin=[0, -CHANNEL_HALF_HEIGHT], xmax=[L, CHANNEL_HALF_HEIGHT])

    # --- Network Architecture ---
    NUM_LAYERS = 8
    NUM_NEURONS = 64
    ACTIVATION = "tanh"
    INITIALIZER = "Glorot normal"
    NETWORK_INPUTS = 2 # x, y
    NETWORK_OUTPUTS = 5 # u, v, p', log(k), log(eps)

    # --- Training Parameters ---
    NUM_DOMAIN_POINTS = 20000
    NUM_BOUNDARY_POINTS = 4000 # For physical boundaries (inlet, outlet, walls)
    NUM_TEST_POINTS = 5000 # For evaluating PDE residuals during training/testing
    NUM_WF_POINTS_PER_WALL = 200 # Anchor points for wall function BCs (per wall)
    LEARNING_RATE_ADAM = 1e-3
    ADAM_ITERATIONS = 50000 # Number of iterations for Adam optimizer
    LBFGS_ITERATIONS = 20000 # Max iterations for L-BFGS optimizer
    # Loss weights: [PDE_cont, PDE_mom_x, PDE_mom_y, PDE_k, PDE_eps, BC_u_in, BC_v_in, BC_k_in, BC_eps_in, BC_p_out, BC_u_wall, BC_v_wall, BC_u_wf, BC_k_wf, BC_eps_wf]
    # The concatenated list has 15 entries; its order must match the order of the loss terms built elsewhere.
    PDE_WEIGHTS = [1, 1, 1, 1, 1] # Weights for the 5 PDE residuals
    BC_WEIGHTS = [10, 10, 10, 10, 10, 10, 10, 20, 20, 20] # Weights for the 10 BCs (adjust as needed)
    LOSS_WEIGHTS = PDE_WEIGHTS + BC_WEIGHTS
    SAVE_INTERVAL = 1000 # Checkpoint saving interval (steps)
    DISPLAY_EVERY = 1000 # Loss display interval (steps)


# Module-level configuration objects shared by the rest of the notebook.
# `cfg` is read and mutated as a global by mount_drive(); the other helpers
# receive the config objects as arguments.
cfg = Config()
plotter_cfg = PlotterConfig()


# ==============================================
# ===== Custom Checkpoint Callback Class ========
# ==============================================
class CustomModelCheckpoint(dde.callbacks.Callback):
    """Callback that writes a checkpoint whenever the global step hits a multiple of ``period``.

    Args:
        filepath_base: Prefix of the checkpoint path; the step number and ``.pt``
            are appended to it when saving.
        period: Save whenever ``step % period == 0`` (for ``step > 0``).
        verbose: Values greater than 0 log an info message before each save.
    """

    def __init__(self, filepath_base, period, verbose=1):
        super().__init__()
        self.filepath_base = filepath_base
        self.period = period
        self.verbose = verbose
        # Steps already written by this callback instance; prevents saving the same
        # step twice. It is never cleared after construction.
        self._saved_steps = set()

    def on_epoch_end(self):
        """Hook invoked by the training loop; delegates to the step check."""
        self._save_checkpoint()

    def _save_checkpoint(self):
        """Save the model if the current global step is due and not yet saved."""
        model = getattr(self, 'model', None)
        if not model:
            return  # no model attached yet
        if not getattr(model, 'train_state', None):
            return  # training state not initialised yet

        if dde.backend.backend_name != "pytorch":
            logging.warning("CustomModelCheckpoint requires PyTorch backend for model.save behavior.")
            return

        step = model.train_state.step
        is_due = step > 0 and step % self.period == 0
        if not is_due or step in self._saved_steps:
            return

        # NOTE(review): DeepXDE's Model.save appears to append its own step/extension
        # suffix to the path it is given - confirm the files on disk have the names
        # the checkpoint-loading code expects.
        filepath = f"{self.filepath_base}{step}.pt"
        if self.verbose > 0:
            logging.info(f"Step {step}: saving model to {filepath} ...")
        try:
            model.save(filepath, verbose=0)
            self._saved_steps.add(step)
        except Exception as e:
            # A failed save is logged with traceback but does not interrupt training.
            logging.error(f"Error saving checkpoint at step {step} to {filepath}: {e}", exc_info=True)

# ==============================================
# ===== END: Custom Checkpoint Callback ========
# ==============================================


# ==========================
# ===== Utility Functions =====
# ==========================
def setup_logging(log_file):
    """Configure the root logger to write INFO+ records to ``log_file`` and stdout.

    Safe to call repeatedly (e.g. when a notebook cell is re-run): handlers left
    over from a previous call are detached and closed first, so messages are not
    duplicated and the previous log file handle is released.

    Args:
        log_file: Path of the log file. Its parent directory is created if needed;
            a bare filename (no directory part) is written to the current directory.
    """
    log_dir = os.path.dirname(log_file)
    if log_dir:
        # Only create a directory when the path actually has one; dirname() of a
        # bare filename is "" and cannot be created.
        ensure_dir(log_dir)

    root_logger = logging.getLogger()
    # Detach AND close old handlers. Clearing the list alone would leave the
    # previous FileHandler's file open (a leaked handle per re-run, and a locked
    # file on Windows).
    for stale_handler in list(root_logger.handlers):
        root_logger.removeHandler(stale_handler)
        stale_handler.close()

    log_formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s", datefmt='%Y-%m-%d %H:%M:%S')
    root_logger.setLevel(logging.INFO) # Set root logger level

    # File handler
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(log_formatter)
    root_logger.addHandler(file_handler)

    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(log_formatter)
    root_logger.addHandler(console_handler)
    logging.info("Logging configured.")

def ensure_dir(directory):
    """Create ``directory`` (including missing parents) if it doesn't exist.

    Args:
        directory: Path to create. An empty string or ``None`` is treated as
            "nothing to create" and ignored; this is what ``os.path.dirname``
            returns for a bare filename.

    An info message is logged only when the directory is actually created.
    """
    if not directory:
        return
    if os.path.exists(directory):
        return
    # exist_ok=True closes the gap between the existence check above and the
    # creation: another process creating the directory in between is not an error.
    os.makedirs(directory, exist_ok=True)
    logging.info(f"Created directory: {directory}")

def _set_output_root(output_dir):
    """Point the global ``cfg`` at ``output_dir`` and re-derive every dependent path."""
    cfg.OUTPUT_DIR = output_dir
    cfg.MODEL_DIR = os.path.join(cfg.OUTPUT_DIR, "model_checkpoints")
    cfg.LOG_DIR = os.path.join(cfg.OUTPUT_DIR, "logs")
    cfg.PLOT_DIR = os.path.join(cfg.OUTPUT_DIR, "plots")
    cfg.DATA_DIR = os.path.join(cfg.OUTPUT_DIR, "data")
    cfg.LOG_FILE = os.path.join(cfg.LOG_DIR, "training_log.log")
    cfg.REFERENCE_DATA_FILE = os.path.join(cfg.DATA_DIR, "reference_output_data.csv")


def mount_drive(mount_point):
    """Mounts Google Drive if running in Colab and repoints the global ``cfg`` paths.

    Outcomes:
      * Not in Colab: paths are reset to the local folder ``cfg.GDRIVE_BASE_FOLDER``.
      * In Colab, ``<mount_point>/MyDrive`` already present: paths are set to
        ``<mount_point>/MyDrive/<GDRIVE_BASE_FOLDER>`` without remounting.
      * In Colab, not mounted: Drive is mounted (force_remount=True), paths are set
        to the Drive folder and that folder is created if missing. Any exception in
        this step is logged and the paths fall back to the local folder.

    Args:
        mount_point: Directory at which Google Drive is (to be) mounted.

    Side effects: mutates the module-level ``cfg`` (OUTPUT_DIR and all derived paths).
    """
    if 'google.colab' not in sys.modules:
        logging.info("Not running in Google Colab. Using local directory structure.")
        _set_output_root(cfg.GDRIVE_BASE_FOLDER) # Use base folder name locally
        return

    if os.path.exists(os.path.join(mount_point, 'MyDrive')):
        logging.info("Google Drive already mounted.")
        # Still update paths if already mounted
        _set_output_root(os.path.join(mount_point, 'MyDrive', cfg.GDRIVE_BASE_FOLDER))
        logging.info(f"Output paths point to Google Drive: {cfg.OUTPUT_DIR}")
        return

    try:
        from google.colab import drive
        logging.info(f"Mounting Google Drive at {mount_point}...")
        drive.mount(mount_point, force_remount=True)
        logging.info("Google Drive mounted successfully.")
        _set_output_root(os.path.join(mount_point, 'MyDrive', cfg.GDRIVE_BASE_FOLDER))
        logging.info(f"Output paths updated to Google Drive: {cfg.OUTPUT_DIR}")
        # Verify base folder access after mount
        ensure_dir(cfg.OUTPUT_DIR)
        if os.path.exists(cfg.OUTPUT_DIR):
            logging.info(f"Base folder exists: {cfg.OUTPUT_DIR}")
        else:
            logging.warning(f"Configured base folder NOT found after mount: {cfg.OUTPUT_DIR}")
    except Exception as e:
        logging.error(f"Error mounting Google Drive or accessing path: {e}")
        # Fallback to local directory if mount fails
        logging.warning("Falling back to local directory structure.")
        _set_output_root(cfg.GDRIVE_BASE_FOLDER) # Use base folder name locally


def setup_output_directories(config):
    """Make sure every output directory named on ``config`` exists.

    Args:
        config: Object exposing OUTPUT_DIR, MODEL_DIR, LOG_DIR, PLOT_DIR and
            DATA_DIR; each is passed to ``ensure_dir`` in that order.
    """
    logging.info("Setting up output directories...")
    for dir_attr in ("OUTPUT_DIR", "MODEL_DIR", "LOG_DIR", "PLOT_DIR", "DATA_DIR"):
        ensure_dir(getattr(config, dir_attr))
    logging.info("Output directories verified/created.")

def log_configuration(config, plotter_config):
    """Write a summary of the run settings to the log at INFO level.

    Args:
        config: Simulation configuration; the attributes read are visible in the
            messages below (paths, Reynolds number, wall-function targets, network
            size, optimizer settings).
        plotter_config: Plot configuration; only NX_PRED and NY_PRED are read.
    """
    emit = logging.info
    banner = "=" * 50

    emit(banner)
    emit("Simulation Configuration:")
    emit(f"  Output Directory: {config.OUTPUT_DIR}")
    emit(f"  Re_H: {config.RE_H:.0f}")
    emit(f"  Wall Function y_p: {config.Y_P} (Target y+: {config.YP_PLUS_TARGET:.2f})")
    emit(f"  Network: {config.NUM_LAYERS} layers, {config.NUM_NEURONS} neurons")
    emit(f"  Inlet k (log): {config.K_INLET_TRANSFORMED:.4f}, Inlet eps (log): {config.EPS_INLET_TRANSFORMED:.4f}")
    emit(f"  Target WF U: {config.U_TARGET_WF:.4f}, k (log): {config.K_TARGET_WF_TRANSFORMED:.4f}, eps (log): {config.EPS_TARGET_WF_TRANSFORMED:.4f}")
    emit(f"  Adam Iterations: {config.ADAM_ITERATIONS}, LR: {config.LEARNING_RATE_ADAM}")
    emit(f"  L-BFGS Iterations: {config.LBFGS_ITERATIONS}")
    emit(f"  Checkpoint Interval: {config.SAVE_INTERVAL}")
    emit(f"  Reference Data File (CSV): {config.REFERENCE_DATA_FILE}")
    emit("Plotter Configuration:")
    emit(f"  Prediction Grid Nx: {plotter_config.NX_PRED}, Ny: {plotter_config.NY_PRED}")
    emit(banner)
# --- End Utility Functions ---


# ===============================
# ===== PDE System Definition =====
# ===============================
def pde(x, y, config): # config is an explicit third argument (not read from a global)
    """Residuals of the steady 2D RANS equations with the k-epsilon closure.

    Args:
        x: Network input. Column 0 is differentiated as "x" and column 1 as "y"
            (derivative index j=0 / j=1 below).
        y: Network output with five columns, in order: u, v, p', log(k), log(eps).
        config: Object providing NU, CMU, CEPS1, CEPS2, SIGMA_K, SIGMA_EPS and
            EPS_SMALL.

    Returns:
        A list of five residual tensors:
        [continuity, x-momentum, y-momentum, k-equation, eps-equation].
        If any derivative computation raises, the error is logged and the list
        holds five references to one all-zero tensor shaped like a column of ``y``.

    Notes:
        * k and eps are recovered as exp(output) + EPS_SMALL, so both are strictly
          positive. Their derivatives are taken with torch.autograd directly, so
          the function effectively requires the PyTorch backend (a warning is
          logged otherwise, but execution continues).
        * NOTE(review): the zero-residual fallback makes the PDE loss vanish when
          differentiation fails, so a broken graph can look like a converged run;
          confirm this best-effort behaviour is intended.
    """
    if dde.backend.backend_name != "pytorch":
        logging.warning("PDE function relies on PyTorch autograd for transformations. Backend mismatch.")

    # Local aliases for the model constants.
    nu = config.NU; Cmu = config.CMU; Ceps1 = config.CEPS1; Ceps2 = config.CEPS2
    sigma_k = config.SIGMA_K; sigma_eps = config.SIGMA_EPS; eps_small = config.EPS_SMALL

    # Network outputs: u, v, p', log(k), log(eps) - each kept as an (N, 1) column slice
    u, v, p_prime, k_raw, eps_raw = y[:, 0:1], y[:, 1:2], y[:, 2:3], y[:, 3:4], y[:, 4:5]

    # --- Apply inverse transformation and enforce positivity ---
    # exp() maps the log-variables back to k and eps; eps_small is added after the
    # exponential so neither quantity can reach exactly zero.
    k = dde.backend.exp(k_raw) + eps_small
    eps = dde.backend.exp(eps_raw) + eps_small

    # --- Calculate Gradients ---
    # u, v, p' are raw network outputs, so DeepXDE's jacobian/hessian helpers are used.
    # k and eps are transformed quantities, so they are differentiated with torch.autograd.
    try:
        # First derivatives of the primitive variables (i = output column, j = input column)
        u_x = dde.grad.jacobian(y, x, i=0, j=0); u_y = dde.grad.jacobian(y, x, i=0, j=1)
        v_x = dde.grad.jacobian(y, x, i=1, j=0); v_y = dde.grad.jacobian(y, x, i=1, j=1)
        p_prime_x = dde.grad.jacobian(y, x, i=2, j=0); p_prime_y = dde.grad.jacobian(y, x, i=2, j=1)

        # Second derivatives of u and v
        u_xx = dde.grad.hessian(y, x, component=0, i=0, j=0); u_yy = dde.grad.hessian(y, x, component=0, i=1, j=1)
        v_xx = dde.grad.hessian(y, x, component=1, i=0, j=0); v_yy = dde.grad.hessian(y, x, component=1, i=1, j=1)
        # Mixed derivatives (needed for momentum diffusion terms)
        u_xy = dde.grad.hessian(y, x, component=0, i=0, j=1); v_xy = dde.grad.hessian(y, x, component=1, i=0, j=1) # or j=0, i=1

        # Gradients of transformed k, eps using torch.autograd
        # NOTE(review): enabling requires_grad here happens after y was computed, so it
        # presumably cannot connect x to an already-built graph - confirm DeepXDE has
        # enabled gradients on x before calling this function.
        if isinstance(x, torch.Tensor) and not x.requires_grad:
            x.requires_grad_(True)

        # create_graph=True keeps these derivatives differentiable (needed for the
        # second derivatives below and for back-propagating the loss).
        grad_k = torch.autograd.grad(k, x, grad_outputs=torch.ones_like(k), create_graph=True)[0]
        k_x, k_y = grad_k[:, 0:1], grad_k[:, 1:2]
        grad_eps = torch.autograd.grad(eps, x, grad_outputs=torch.ones_like(eps), create_graph=True)[0]
        eps_x, eps_y = grad_eps[:, 0:1], grad_eps[:, 1:2]

        # Second derivatives of k, eps: differentiate each first derivative again
        # and keep only the matching column (xx from d(k_x), yy from d(k_y)).
        grad_kx = torch.autograd.grad(k_x, x, grad_outputs=torch.ones_like(k_x), create_graph=True)[0]
        k_xx = grad_kx[:, 0:1]
        grad_ky = torch.autograd.grad(k_y, x, grad_outputs=torch.ones_like(k_y), create_graph=True)[0]
        k_yy = grad_ky[:, 1:2]

        grad_epsx = torch.autograd.grad(eps_x, x, grad_outputs=torch.ones_like(eps_x), create_graph=True)[0]
        eps_xx = grad_epsx[:, 0:1]
        grad_epsy = torch.autograd.grad(eps_y, x, grad_outputs=torch.ones_like(eps_y), create_graph=True)[0]
        eps_yy = grad_epsy[:, 1:2]

    except Exception as grad_e:
        logging.error(f"Error calculating gradients in PDE function: {grad_e}", exc_info=True)
        # Fallback: zeros with the shape/device of one output column. The list holds
        # the same tensor object five times (list multiplication).
        zero_tensor = torch.zeros_like(y[:, 0:1])
        return [zero_tensor] * 5 # Match the number of expected PDE residual outputs

    # --- Turbulent Viscosity ---
    # k and eps already include eps_small, so both are positive; the aliases below
    # only rename them.
    k_safe = k
    eps_safe = eps # eps already has eps_small added
    # eps_small is added to the denominator a second time as an extra guard
    nu_t = Cmu * dde.backend.square(k_safe) / (eps_safe + eps_small) # Eq: nu_t = Cmu * k^2 / eps
    nu_eff = nu + nu_t # Effective viscosity

    # --- Gradients of nu_eff (Needed for diffusion terms in momentum eqns) ---
    # nu is constant, so grad(nu_eff) = grad(nu_t), obtained by the chain rule:
    # d(nu_t)/dx = d(nu_t)/dk * dk/dx + d(nu_t)/deps * deps/dx
    dnut_dk = 2.0 * Cmu * k_safe / (eps_safe + eps_small)
    dnut_deps = -Cmu * dde.backend.square(k_safe) / dde.backend.square(eps_safe + eps_small)

    nu_eff_x = dnut_dk * k_x + dnut_deps * eps_x
    nu_eff_y = dnut_dk * k_y + dnut_deps * eps_y

    # --- RANS Equation Residuals ---

    # 1. Continuity Equation: d(u)/dx + d(v)/dy = 0
    eq_continuity = u_x + v_y

    # 2. X-Momentum Equation:
    # d(u)/dt + u*du/dx + v*du/dy = -dp'/dx + d/dx[nu_eff * (2*du/dx)] + d/dy[nu_eff * (du/dy + dv/dx)]
    # Steady state: u*du/dx + v*du/dy + dp'/dx - [d/dx(...) + d/dy(...)] = 0
    adv_u = u * u_x + v * u_y # Advection
    # Diffusion terms (expanded using product rule)
    # d/dx[nu_eff * (2*du/dx)] = d(nu_eff)/dx * (2*du/dx) + nu_eff * (2*d^2u/dx^2)
    diff_u_term1 = nu_eff_x * (2 * u_x) + nu_eff * (2 * u_xx)
    # d/dy[nu_eff * (du/dy + dv/dx)] = d(nu_eff)/dy * (du/dy + dv/dx) + nu_eff * (d^2u/dy^2 + d^2v/dxdy)
    diff_u_term2 = nu_eff_y * (u_y + v_x) + nu_eff * (u_yy + v_xy) # v_xy is used as d^2v/dxdy
    eq_mom_x = adv_u + p_prime_x - (diff_u_term1 + diff_u_term2)

    # 3. Y-Momentum Equation:
    # d(v)/dt + u*dv/dx + v*dv/dy = -dp'/dy + d/dx[nu_eff * (dv/dx + du/dy)] + d/dy[nu_eff * (2*dv/dy)]
    # Steady state: u*dv/dx + v*dv/dy + dp'/dy - [d/dx(...) + d/dy(...)] = 0
    adv_v = u * v_x + v * v_y # Advection
    # Diffusion terms (expanded)
    # d/dx[nu_eff * (dv/dx + du/dy)] = d(nu_eff)/dx * (dv/dx + du/dy) + nu_eff * (d^2v/dx^2 + d^2u/dxdy)
    diff_v_term1 = nu_eff_x * (v_x + u_y) + nu_eff * (v_xx + u_xy) # u_xy is used as d^2u/dxdy
    # d/dy[nu_eff * (2*dv/dy)] = d(nu_eff)/dy * (2*dv/dy) + nu_eff * (2*d^2v/dy^2)
    diff_v_term2 = nu_eff_y * (2 * v_y) + nu_eff * (2 * v_yy)
    eq_mom_y = adv_v + p_prime_y - (diff_v_term1 + diff_v_term2)

    # --- Turbulence Model Equations ---

    # Production term P_k = nu_t * S^2, where S is the modulus of the mean strain rate tensor
    # S^2 = 2*(du/dx)^2 + 2*(dv/dy)^2 + (du/dy + dv/dx)^2  (for 2D)
    S_squared = 2 * (dde.backend.square(u_x) + dde.backend.square(v_y)) + dde.backend.square(u_y + v_x)
    P_k = nu_t * S_squared

    # 4. k-Equation:
    # d(k)/dt + u*dk/dx + v*dk/dy = d/dx[(nu + nu_t/sigma_k)*dk/dx] + d/dy[(nu + nu_t/sigma_k)*dk/dy] + P_k - eps
    # Steady state: u*dk/dx + v*dk/dy - [Diffusion] - P_k + eps = 0
    adv_k = u * k_x + v * k_y # Advection
    # Diffusion term: div[ (nu + nu_t/sigma_k) * grad(k) ]
    diffusivity_k = nu + nu_t / sigma_k
    # Gradient of diffusivity: d(diff_k)/dx = (1/sigma_k) * d(nu_t)/dx
    d_diffk_dx = (1 / sigma_k) * nu_eff_x # nu_eff_x equals the gradient of nu_t (nu is constant)
    d_diffk_dy = (1 / sigma_k) * nu_eff_y # nu_eff_y equals the gradient of nu_t (nu is constant)
    laplacian_k = k_xx + k_yy
    # Expand divergence using product rule: div(D*grad(k)) = grad(D).grad(k) + D*laplacian(k)
    diffusion_k = d_diffk_dx * k_x + d_diffk_dy * k_y + diffusivity_k * laplacian_k
    # The sink term is eps itself (positive by construction)
    eq_k = adv_k - diffusion_k - P_k + eps_safe

    # 5. eps-Equation:
    # d(eps)/dt + u*deps/dx + v*deps/dy = d/dx[(nu + nu_t/sigma_eps)*deps/dx] + d/dy[(nu + nu_t/sigma_eps)*deps/dy] + Ceps1*(eps/k)*P_k - Ceps2*(eps^2/k)
    # Steady state: u*deps/dx + v*deps/dy - [Diffusion] - Source + Sink = 0
    adv_eps = u * eps_x + v * eps_y # Advection
    # Diffusion term: div[ (nu + nu_t/sigma_eps) * grad(eps) ]
    diffusivity_eps = nu + nu_t / sigma_eps
    d_diffeps_dx = (1 / sigma_eps) * nu_eff_x # Gradients of nu_t
    d_diffeps_dy = (1 / sigma_eps) * nu_eff_y
    laplacian_eps = eps_xx + eps_yy
    diffusion_eps = d_diffeps_dx * eps_x + d_diffeps_dy * eps_y + diffusivity_eps * laplacian_eps
    # Source/Sink terms (use the positive k and eps)
    # eps_small is added to the denominators again as an extra guard
    source_eps = Ceps1 * (eps_safe / (k_safe + eps_small)) * P_k
    sink_eps = Ceps2 * (dde.backend.square(eps_safe) / (k_safe + eps_small))
    eq_eps = adv_eps - diffusion_eps - source_eps + sink_eps

    # Return the residuals of the 5 equations
    return [eq_continuity, eq_mom_x, eq_mom_y, eq_k, eq_eps]
# --- End PDE function ---


# =============================
# ===== Boundary Conditions =====
# =============================
def get_boundary_conditions(config):
    """Assemble every boundary condition used by the channel-flow PINN.

    Network output columns are addressed by index: 0 -> u, 1 -> v, 2 -> p',
    3 -> log(k), 4 -> log(eps).

    Args:
        config: Object providing GEOM, CHANNEL_HALF_HEIGHT, L, Y_P,
            NUM_WF_POINTS_PER_WALL, the inlet values (U_INLET,
            K_INLET_TRANSFORMED, EPS_INLET_TRANSFORMED) and the wall-function
            targets (U_TARGET_WF, K_TARGET_WF_TRANSFORMED,
            EPS_TARGET_WF_TRANSFORMED).

    Returns:
        tuple: ``(bcs, anchors)`` where ``bcs`` is the ordered list of ten
        boundary conditions (4 inlet, 1 outlet, 2 wall, 3 wall-function) and
        ``anchors`` is the ``(2 * NUM_WF_POINTS_PER_WALL, 2)`` array of
        wall-function points (bottom-wall rows first, then top-wall rows).
    """
    domain = config.GEOM
    half_height = config.CHANNEL_HALF_HEIGHT
    channel_length = config.L
    wall_offset = config.Y_P  # distance of the wall-function points from the wall
    points_per_wall = config.NUM_WF_POINTS_PER_WALL

    # --- Boundary selectors -------------------------------------------------
    def on_inlet(x, on_boundary):
        return on_boundary and np.isclose(x[0], 0)

    def on_outlet(x, on_boundary):
        return on_boundary and np.isclose(x[0], channel_length)

    def on_bottom_wall(x, on_boundary):
        # Physical wall at y = -h
        return on_boundary and np.isclose(x[1], -half_height)

    def on_top_wall(x, on_boundary):
        # Physical wall at y = +h
        return on_boundary and np.isclose(x[1], half_height)

    def on_any_wall(x, on_boundary):
        return on_bottom_wall(x, on_boundary) or on_top_wall(x, on_boundary)

    def dirichlet(value_fn, selector, component):
        """Dirichlet condition on ``domain`` for one output column."""
        return dde.DirichletBC(domain, value_fn, selector, component=component)

    # --- Inlet: u = U_INLET, v = 0, log(k) and log(eps) fixed ---------------
    # The lambdas look the values up on ``config`` when they are evaluated.
    inlet_u = dirichlet(lambda pts: config.U_INLET, on_inlet, 0)
    inlet_v = dirichlet(lambda pts: 0, on_inlet, 1)
    inlet_k = dirichlet(lambda pts: config.K_INLET_TRANSFORMED, on_inlet, 3)
    inlet_eps = dirichlet(lambda pts: config.EPS_INLET_TRANSFORMED, on_inlet, 4)

    # --- Outlet: gauge pressure p' = 0 --------------------------------------
    outlet_p = dirichlet(lambda pts: 0, on_outlet, 2)

    # --- Physical walls: no-slip (u = v = 0) --------------------------------
    wall_u = dirichlet(lambda pts: 0, on_any_wall, 0)
    wall_v = dirichlet(lambda pts: 0, on_any_wall, 1)

    # --- Wall-function anchor points at y = +/-(h - y_p) --------------------
    # The x-range is shrunk by 1% of L at each end so that no anchor sits
    # exactly on the inlet or the outlet.
    x_column = np.linspace(
        0 + channel_length * 0.01,
        channel_length - channel_length * 0.01,
        points_per_wall,
    )[:, None]

    def anchor_row(y_value):
        """Anchor points sharing ``x_column`` at the constant height ``y_value``."""
        return np.hstack((x_column, np.full_like(x_column, y_value)))

    anchors = np.vstack((
        anchor_row(-half_height + wall_offset),
        anchor_row(half_height - wall_offset),
    ))
    n_anchors = anchors.shape[0]
    logging.info(f"Generated {n_anchors} anchor points for wall functions.")

    def target_column(value):
        """Column vector holding the same target ``value`` for every anchor."""
        return np.full((n_anchors, 1), value)

    # Point-wise targets for u, log(k) and log(eps) at the anchors.
    wf_u = dde.PointSetBC(anchors, target_column(config.U_TARGET_WF), component=0)
    wf_k = dde.PointSetBC(anchors, target_column(config.K_TARGET_WF_TRANSFORMED), component=3)
    wf_eps = dde.PointSetBC(anchors, target_column(config.EPS_TARGET_WF_TRANSFORMED), component=4)

    # The list order is part of the contract (it lines up with per-term loss weights).
    bcs = [
        inlet_u, inlet_v, inlet_k, inlet_eps,  # Inlet (4 BCs)
        outlet_p,                              # Outlet (1 BC)
        wall_u, wall_v,                        # Physical walls (2 BCs)
        wf_u, wf_k, wf_eps,                    # Wall functions (3 BCs)
    ]
    # The anchors are returned as well because the data object needs them later.
    return bcs, anchors
# --- End Boundary Conditions ---


# =======================
# ===== Trainer Class =====
# =======================
class Trainer:
    """Handles the setup, training, and checkpointing of the PINN model."""
    def __init__(self, config):
        self.config = config
        self.model = None
        self.losshistory = None
        self.train_state = None
        self.pde = pde # Assign the PDE function

    def build_model(self, bcs, anchor_points):
        """Builds the DeepXDE model including network and data."""
        logging.info("Building the PINN model...")
        if dde.backend.backend_name != "pytorch":
             raise RuntimeError("This code relies on the PyTorch backend.")

        # Define the neural network
        net = dde.maps.FNN(
            layer_sizes=[self.config.NETWORK_INPUTS] + [self.config.NUM_NEURONS] * self.config.NUM_LAYERS + [self.config.NETWORK_OUTPUTS],
            activation=self.config.ACTIVATION,
            kernel_initializer=self.config.INITIALIZER
        )

        # Wrap PDE to include config
        pde_with_config = lambda x, y: self.pde(x, y, config=self.config)

        # Define the PDE data object
        data = dde.data.PDE(
            geometry=self.config.GEOM,
            pde=pde_with_config,
            bcs=bcs, # List of boundary conditions
            num_domain=self.config.NUM_DOMAIN_POINTS,
            num_boundary=self.config.NUM_BOUNDARY_POINTS, # Sample points on physical boundaries
            num_test=self.config.NUM_TEST_POINTS, # Points for testing PDE residual during training
            anchors=anchor_points # Provide wall function anchor points here
        )
        self.model = dde.Model(data, net)
        logging.info("Model built successfully.")

    # ==============================================================
    # ========    UPDATED Trainer.train Method            ==========
    # ==============================================================
    def train(self):
        """Compile the model, optionally resume from a checkpoint, then run Adam and L-BFGS.

        Sequence implemented below:

        1. Build a ``CustomModelCheckpoint`` callback writing under
           ``MODEL_DIR`` with the prefix ``"<CHECKPOINT_FILENAME_BASE>-"``.
        2. Scan ``MODEL_DIR`` for files named
           ``<CHECKPOINT_FILENAME_BASE>-<step>.pt`` and pick the highest step.
        3. Compile with Adam, or with L-BFGS when a checkpoint is being restored
           whose step is at least ``ADAM_ITERATIONS``.
        4. Restore the checkpoint (if any) and overwrite ``train_state.step``
           with the step parsed from the file name.
        5. Run Adam for the iterations still missing up to ``ADAM_ITERATIONS``.
        6. Run L-BFGS when ``LBFGS_ITERATIONS > 0`` and the step counter lies
           in ``[ADAM_ITERATIONS, ADAM_ITERATIONS + LBFGS_ITERATIONS)``.
        7. Run ``_post_training_checks``; errors there are logged, not raised.

        Returns:
            tuple: ``(model, losshistory, train_state)``.  ``(None, None, None)``
            is returned when the model has not been built, when the initial
            compile fails, or when restoring a checkpoint fails (in the restore
            failure case ``self.model`` is also reset to ``None``).  If Adam
            training raises, the current values are returned immediately and
            the L-BFGS phase is not attempted.
        """
        if self.model is None:
            logging.error("Model not built. Call build_model first.")
            return None, None, None

        # --- Define Checkpoint Paths and Callback ---
        # Only the prefix is built here; the resume scan below expects files
        # named "<prefix><step>.pt" (presumably what the callback writes — confirm
        # against CustomModelCheckpoint).
        filepath_base = os.path.join(self.config.MODEL_DIR, f"{self.config.CHECKPOINT_FILENAME_BASE}-")
        logging.info(f"Checkpoint filename base: {filepath_base}")
        custom_checkpointer = CustomModelCheckpoint(
            filepath_base=filepath_base,
            period=self.config.SAVE_INTERVAL,
            verbose=1
        )

        # --- Check for Latest Checkpoint ---
        latest_checkpoint = None
        restored_step = 0
        if os.path.exists(self.config.MODEL_DIR):
            try:
                # NOTE(review): relies on `re` being imported at module level. A missing
                # import would raise NameError here, be caught by the `except` below and
                # silently turn into "no checkpoint found" — confirm the import exists.
                filename_pattern = re.compile(rf"^{re.escape(self.config.CHECKPOINT_FILENAME_BASE)}-(\d+)\.pt$")
                checkpoint_files = []
                logging.info(f"Searching for checkpoints in: {self.config.MODEL_DIR}")
                logging.info(f"Using pattern: {filename_pattern.pattern}")
                for f in os.listdir(self.config.MODEL_DIR):
                    match = filename_pattern.match(f)
                    if match:
                        step_num = int(match.group(1))
                        full_path = os.path.join(self.config.MODEL_DIR, f)
                        checkpoint_files.append((step_num, full_path))
                if checkpoint_files:
                    # Highest step number wins.
                    checkpoint_files.sort(key=lambda item: item[0])
                    restored_step, latest_checkpoint = checkpoint_files[-1]
                    logging.info(f"Found latest valid checkpoint: {latest_checkpoint} at step {restored_step}")
                else:
                    logging.info("No valid checkpoints found matching the pattern.")
            except Exception as e:
                # Any failure while scanning is treated as "start from scratch".
                logging.error(f"Error finding/parsing checkpoint filenames: {e}", exc_info=True)
                latest_checkpoint = None; restored_step = 0

        restore_path = latest_checkpoint if (latest_checkpoint and os.path.isfile(latest_checkpoint)) else None

        # --- Determine Initial Optimizer Based on Restored Step ---
        initial_optimizer = "adam"
        initial_lr = self.config.LEARNING_RATE_ADAM
        if restore_path and restored_step >= self.config.ADAM_ITERATIONS:
            # If restoring from a step within or after the L-BFGS phase, compile L-BFGS first.
            initial_optimizer = "L-BFGS"
            initial_lr = None # No `lr` argument is passed to compile() for L-BFGS
            logging.info(f"Restored step ({restored_step}) >= Adam iterations ({self.config.ADAM_ITERATIONS}). Will compile with L-BFGS initially.")
        else:
            # Start with Adam (from scratch or resuming within Adam phase).
            logging.info(f"Restored step ({restored_step}) < Adam iterations ({self.config.ADAM_ITERATIONS}) or no checkpoint. Will compile with Adam initially.")

        # --- Compile with the Determined Initial Optimizer ---
        logging.info(f"Compiling model with {initial_optimizer} optimizer initially.")
        if dde.backend.backend_name != "pytorch":
             raise RuntimeError("Cannot compile model, backend is not PyTorch.")

        compile_args = {"optimizer": initial_optimizer, "loss_weights": self.config.LOSS_WEIGHTS}
        if initial_lr is not None:
            compile_args["lr"] = initial_lr
        try:
            self.model.compile(**compile_args)
            logging.info(f"Model compiled with {initial_optimizer}.")
        except Exception as e:
            logging.error(f"Failed to compile model with {initial_optimizer}: {e}", exc_info=True)
            return None, None, None


        # --- Explicit Restore (if applicable) ---
        if restore_path:
            try:
                logging.info(f"Explicitly restoring model state from: {restore_path}")
                # The optimizer compiled above was chosen to match the phase the checkpoint
                # belongs to, so the saved optimizer state is expected to load.
                self.model.restore(restore_path, verbose=1)

                current_step_after_restore = self.model.train_state.step if self.model.train_state else -1
                logging.info(f"Model state restored. Internal DDE step count *after* restore: {current_step_after_restore}")

                # Force set the step counter based on the filename step for consistency
                if self.model.train_state:
                    if current_step_after_restore != restored_step:
                         logging.warning(f"Mismatch between expected restored step ({restored_step}) and DDE internal step ({current_step_after_restore}) after restore! Forcing DDE step.")
                    # else:
                         # logging.info("DDE internal step matches restored step.") # Can be verbose
                    logging.info(f"Manually setting internal step count to the restored step: {restored_step}")
                    self.model.train_state.step = restored_step
                    current_step_after_manual_set = self.model.train_state.step
                    logging.info(f"Internal step count after manual setting: {current_step_after_manual_set}")
                else:
                    # This should not happen if restore was successful with a valid state
                    logging.error("Cannot manually set step count: model.train_state is None after restore. Restore might have failed silently.")
                    self.model = None # Indicate failure
                    return None, None, None # Stop if restore fails

            except Exception as e:
                # Compiling with the matching optimizer first is meant to avoid optimizer-state
                # mismatches; other errors (file corruption, etc.) might still occur.
                logging.error(f"Failed during explicit model restore: {e}", exc_info=True)
                self.model = None # Indicate failure
                return None, None, None # Stop if restore fails
        else:
            logging.info("No suitable checkpoint found. Starting training from scratch.")
            # Also covers the case where a step was parsed but the file is no longer there.
            restored_step = 0 # Ensure this is 0 if starting fresh
            if self.model.train_state:
                self.model.train_state.step = 0 # Ensure internal step is 0 for new training
            else:
                 # Defensive: a falsy train_state here means the model is not in a usable state.
                 logging.error("model.train_state is None when starting from scratch. Compile likely failed earlier.")
                 return None, None, None


        # --- Adam Training Phase ---
        run_adam_phase = False
        adam_iters_to_run = 0
        # Use the step count that's been synchronized after potential restore
        current_step_after_restore_or_init = self.model.train_state.step if self.model.train_state else restored_step

        if initial_optimizer == "adam":
            # Only run Adam if we started with Adam and haven't finished its iterations
            if current_step_after_restore_or_init < self.config.ADAM_ITERATIONS:
                run_adam_phase = True
                # Only the iterations still missing up to ADAM_ITERATIONS are run.
                adam_iters_to_run = self.config.ADAM_ITERATIONS - current_step_after_restore_or_init
                logging.info(f"Starting Adam training for remaining {adam_iters_to_run} iterations...")
            else:
                logging.info(f"Adam phase already completed (current step {current_step_after_restore_or_init} >= {self.config.ADAM_ITERATIONS}).")
        else:
            # Adam is skipped if we restored directly into the L-BFGS phase
            logging.info("Skipping Adam phase as initial compilation was L-BFGS.")

        adam_start_time = time.time()
        if run_adam_phase and adam_iters_to_run > 0:
            try:
                # Ensure model is compiled with Adam before training Adam
                if self.model.opt_name != "adam": # Safety check
                     logging.warning("Model optimizer is not Adam before Adam training. Recompiling.")
                     self.model.compile("adam", lr=self.config.LEARNING_RATE_ADAM, loss_weights=self.config.LOSS_WEIGHTS)

                self.losshistory, self.train_state = self.model.train(
                    iterations=adam_iters_to_run,
                    display_every=self.config.DISPLAY_EVERY,
                    callbacks=[custom_checkpointer] # Use the custom callback
                )
                adam_time = time.time() - adam_start_time
                # Check if losshistory is valid before accessing attributes
                if self.losshistory and hasattr(self.losshistory, 'loss_train') and self.losshistory.loss_train:
                    final_loss = self.losshistory.loss_train[-1]
                    current_step = self.model.train_state.step if self.model.train_state else -1
                    logging.info(f"Adam training ({adam_iters_to_run} iterations) finished in {adam_time:.2f}s. Final loss: {final_loss}. Current step: {current_step}")
                else:
                    # Only logged: execution continues to the L-BFGS phase below.
                    logging.error("Adam training finished but loss history is empty/invalid.")
                    # Allow proceeding to L-BFGS maybe, but log error. Could also return here.
                    # return self.model, self.losshistory, self.train_state

            except Exception as e:
                logging.error(f"Error occurred during Adam training: {e}", exc_info=True)
                return self.model, self.losshistory, self.train_state # Exit on error

        # --- L-BFGS Training Phase ---
        run_lbfgs_phase = False
        lbfgs_iters_to_run = 0
        needs_lbfgs_compile = False

        # Get the step count *after* any Adam training that might have occurred
        current_step_after_adam = self.model.train_state.step if self.model.train_state else current_step_after_restore_or_init

        if self.config.LBFGS_ITERATIONS > 0:
            # Rough target for total steps if both phases run fully
            total_target_steps = self.config.ADAM_ITERATIONS + self.config.LBFGS_ITERATIONS

            if current_step_after_adam < self.config.ADAM_ITERATIONS:
                 # Still in Adam phase according to step count, L-BFGS should not run yet.
                 logging.info(f"Skipping L-BFGS: Current step {current_step_after_adam} is less than Adam target {self.config.ADAM_ITERATIONS}.")
            elif current_step_after_adam >= total_target_steps:
                 # Already past the combined target steps based on initial config.
                 # Note: If LBFGS runs for its full iteration count regardless of starting step,
                 # the final step might exceed this simplistic target.
                 logging.info(f"Skipping L-BFGS: Current step {current_step_after_adam} meets or exceeds nominal total target steps {total_target_steps}.")
            else:
                 # We are at or past the Adam iterations, L-BFGS is configured, and potentially haven't finished.
                 run_lbfgs_phase = True
                 # Simplification: the full configured L-BFGS budget is requested even when
                 # resuming part-way through the L-BFGS phase (remaining iterations are not tracked).
                 lbfgs_iters_to_run = self.config.LBFGS_ITERATIONS
                 logging.info(f"Proceeding to L-BFGS phase (current step {current_step_after_adam}). Target iterations for this phase: {lbfgs_iters_to_run}")

                 # Check if we need to compile L-BFGS (i.e., if the *current* optimizer is Adam)
                 if self.model.opt_name == "adam": # Check the actual current optimizer
                     needs_lbfgs_compile = True
                     logging.info("Switching optimizer: Will compile for L-BFGS.")
                 elif self.model.opt_name == "L-BFGS":
                      logging.info("Optimizer is already L-BFGS (likely resumed). No recompilation needed.")
                 else:
                      logging.warning(f"Unexpected optimizer '{self.model.opt_name}' before L-BFGS phase.")
                      # Force compile just in case.
                      needs_lbfgs_compile = True


        else:
            logging.info("L-BFGS iterations set to 0 in config, skipping L-BFGS training.")


        if run_lbfgs_phase and lbfgs_iters_to_run > 0:
             if self.model is not None and self.model.net is not None:
                 lbfgs_start_time = time.time()
                 try:
                     # Re-compile for L-BFGS only if necessary
                     if needs_lbfgs_compile:
                         self.model.compile("L-BFGS", loss_weights=self.config.LOSS_WEIGHTS)
                         logging.info(f"Model re-compiled with L-BFGS. Step count after compile: {self.model.train_state.step if self.model.train_state else 'N/A'}")

                     # The same checkpoint callback instance is reused for L-BFGS saves.
                     # NOTE(review): with DeepXDE's PyTorch backend the L-BFGS iteration budget is
                     # believed to come from dde.optimizers.set_LBFGS_options(maxiter=...) rather
                     # than from `iterations=`; if so, LBFGS_ITERATIONS has no effect here unless
                     # that option is set elsewhere — TODO confirm.
                     self.losshistory, self.train_state = self.model.train(
                         iterations=lbfgs_iters_to_run, # Let L-BFGS run its course
                         display_every=self.config.DISPLAY_EVERY, # Can use display_every for LBFGS too
                         callbacks=[custom_checkpointer] # Pass the custom callback instance
                     )
                     lbfgs_time = time.time() - lbfgs_start_time
                     if self.losshistory and hasattr(self.losshistory, 'loss_train') and self.losshistory.loss_train:
                         final_loss = self.losshistory.loss_train[-1]
                         # The global step counter might not increase by LBFGS_ITERATIONS here,
                         # it might just reflect the step where L-BFGS started or finished.
                         # Checkpoint saving relies on the step counter updated by the callback mechanism.
                         current_step = self.model.train_state.step if self.model.train_state else -1
                         logging.info(f"L-BFGS training (max {lbfgs_iters_to_run} internal iterations) finished in {lbfgs_time:.2f}s. Final loss: {final_loss}. Final global step recorded: {current_step}")
                     else:
                         # Only logged; the method still proceeds to the post-training checks.
                         logging.error("L-BFGS training finished but loss history is invalid.")

                 except Exception as e:
                     # Unlike the Adam phase, an L-BFGS failure does not return early.
                     logging.error(f"Error during L-BFGS compilation or training: {e}", exc_info=True)
             else:
                 logging.warning("Skipping L-BFGS training because the model state is not valid (e.g., restore failed or Adam failed).")

        # --- Post-training Validation (Optional) ---
        if self.model and self.model.net:
            try:
                self._post_training_checks()
            except Exception as e:
                 logging.error(f"Error during post-training checks: {e}", exc_info=True)

        logging.info("Training sequence complete.")
        # Make sure to return the potentially updated model, losshistory, train_state
        return self.model, self.losshistory, self.train_state
    # ==============================================================
    #========== END OF UPDATED Trainer.train Method      ===========
    # ==============================================================


    # --- Optional Post Training Checks ---
    def _post_training_checks(self):
        """Perform physics validation checks compatible with PyTorch backend."""
        logging.info("Performing PyTorch-based post-training checks...")
        if dde.backend.backend_name != "pytorch":
             logging.warning("Post-training checks skipped, requires PyTorch backend.")
             return
        if self.model is None or self.model.net is None:
             logging.error("Model or network not available for post-training checks.")
             return
        try:
            # Example check:
            self._check_turbulence_production()
        except Exception as e:
            logging.error(f"Error during post-training checks: {e}", exc_info=True)

    def _check_turbulence_production(self):
        """Ensure turbulence production term P_k >= 0 using PyTorch autograd."""
        logging.info("Checking turbulence production term Pk...")
        if not hasattr(self.model.data, 'test_x') or self.model.data.test_x is None or len(self.model.data.test_x) == 0:
            logging.warning("Test points not available, using training points for P_k check.")
            if not hasattr(self.model.data, 'train_x') or self.model.data.train_x is None or len(self.model.data.train_x) == 0:
                 logging.error("No points available (train or test) for P_k check.")
                 return
            X = self.model.data.train_x # Fallback to training points
        else:
             X = self.model.data.test_x # Prefer test points

        if self.model.net is None or not list(self.model.net.parameters()):
             logging.error("Model network not available or has no parameters for P_k check.")
             return

        # Get device from model parameters
        try:
            device = next(self.model.net.parameters()).device
        except StopIteration:
             logging.error("Model network has no parameters.")
             return

        x_tensor = torch.tensor(X, dtype=torch.float32, device=device, requires_grad=True)

        # Forward pass with gradient tracking
        y_tensor = self.model.net(x_tensor)
        if y_tensor.shape[1] != self.config.NETWORK_OUTPUTS:
             logging.error(f"Network output has unexpected shape {y_tensor.shape} for P_k check.")
             return

        u = y_tensor[:, 0:1]; v = y_tensor[:, 1:2]
        k_raw = y_tensor[:, 3:4]; eps_raw = y_tensor[:, 4:5]

        # Use the same transformation as in PDE
        k = torch.exp(k_raw) + self.config.EPS_SMALL
        eps = torch.exp(eps_raw) + self.config.EPS_SMALL

        # Compute gradients using PyTorch autograd
        try:
            u_grad = torch.autograd.grad(u, x_tensor, grad_outputs=torch.ones_like(u), create_graph=False)[0] # No graph needed here
            u_x, u_y = u_grad[:, 0:1], u_grad[:, 1:2]
            v_grad = torch.autograd.grad(v, x_tensor, grad_outputs=torch.ones_like(v), create_graph=False)[0]
            v_x, v_y = v_grad[:, 0:1], v_grad[:, 1:2]
        except Exception as grad_e:
             logging.error(f"Error computing velocity gradients for P_k check: {grad_e}", exc_info=True)
             return

        # Compute nu_t using safe k, eps from transformation
        k_safe_check = k
        eps_safe_check = eps
        nu_t_check = self.config.CMU * torch.square(k_safe_check) / torch.maximum(eps_safe_check, torch.tensor(self.config.EPS_SMALL, device=device))

        # Strain rate tensor squared (S^2)
        S_squared = 2*(torch.square(u_x) + torch.square(v_y)) + torch.square(u_y + v_x)
        # Production term P_k = nu_t * S^2
        P_k = nu_t_check * S_squared

        # Check for negative production (detach before converting to numpy/item)
        try:
            P_k_detached = P_k.detach()
            min_Pk = torch.min(P_k_detached).item()
            max_Pk = torch.max(P_k_detached).item()
            num_negative = torch.sum(P_k_detached < 0).item()
            if min_Pk < -self.config.EPS_SMALL * 10: # Allow small numerical errors
                logging.warning(f"Negative turbulence production detected! Min P_k = {min_Pk:.3e}. ({num_negative}/{len(P_k)} points < 0)")
            else:
                logging.info(f"Turbulence production check passed (min P_k = {min_Pk:.3e}, max P_k = {max_Pk:.3e})")
        except Exception as check_e:
             logging.error(f"Error checking P_k value: {check_e}", exc_info=True)
# --- End Trainer Class ---


# ========================
# ===== Plotter Class =====
# ========================
class Plotter:
    """Handles post-processing and plotting of simulation results."""
    def __init__(self, config, plotter_config, model, losshistory, train_state):
        self.config = config
        self.plotter_config = plotter_config
        self.model = model
        self.losshistory = losshistory
        self.train_state = train_state
        self.ref_data_path = config.REFERENCE_DATA_FILE
        self.plots_dir = config.PLOT_DIR
        self.ref_data = None
        self.has_ref_data = False
        self.ref_data_utau = None # Estimated friction velocity from reference data
        self.pinn_data_utau = None # Estimated friction velocity from PINN data

        # Predicted fields (initialized to None)
        self.X_grid, self.Y_grid = None, None
        self.u_pred, self.v_pred, self.p_prime_pred = None, None, None
        self.k_pred, self.eps_pred, self.nu_t_pred = None, None, None
        self.p_pred = None # Kinematic pressure

        os.makedirs(self.plots_dir, exist_ok=True)
        logging.info(f"Plotter initialized. Plots will be saved in: {self.plots_dir}")
        if self.ref_data_path:
             if os.path.exists(self.ref_data_path):
                 logging.info(f"Reference CSV data path found: {self.ref_data_path}")
             else:
                 logging.warning(f"Reference CSV file not found: {self.ref_data_path}. Comparisons will be skipped.")
        else:
             logging.info("No reference data path provided in config. Comparisons will be skipped.")

    def plot_loss_history(self):
        """Plots and saves the training loss history."""
        if self.losshistory and self.train_state:
            logging.info("Saving loss history plot...")
            try:
                os.makedirs(self.plots_dir, exist_ok=True)
                # Use isplot=False to prevent showing plot in non-interactive envs
                dde.saveplot(self.losshistory, self.train_state, issave=True, isplot=False, output_dir=self.plots_dir)
                # Rename default 'loss.png' for clarity
                default_loss_file = os.path.join(self.plots_dir, "loss.png")
                target_loss_file = os.path.join(self.plots_dir, "training_loss_history.png")
                if os.path.exists(default_loss_file):
                    try:
                        # Use os.replace for atomic rename if possible, fallback to rename
                        os.replace(default_loss_file, target_loss_file)
                    except OSError: # Fallback for cross-device links etc.
                        os.rename(default_loss_file, target_loss_file)
                    logging.info(f"Loss history plot saved as '{target_loss_file}'.")
                else:
                     # Check if the loss file was created with a different name pattern potentially
                     # List files in the plots_dir and check for recently created png files
                     potential_files = [f for f in os.listdir(self.plots_dir) if f.endswith('.png')]
                     if potential_files:
                          logging.warning(f"dde.saveplot might not have produced 'loss.png'. Found: {potential_files}. Check DeepXDE version/behavior.")
                     else:
                          logging.warning("dde.saveplot did not produce 'loss.png'. Check file permissions or DeepXDE version.")
            except ImportError:
                logging.error("Matplotlib might be needed by dde.saveplot but is not installed.")
            except Exception as e:
                 logging.error(f"Could not save loss history plot: {e}", exc_info=True)
        else:
            logging.warning("Loss history or train state not available, skipping loss plot.")

    def load_reference_data(self):
        """Loads and preprocesses reference data from a CSV file."""
        if not self.ref_data_path:
            self.has_ref_data = False
            return
        if not os.path.exists(self.ref_data_path):
            logging.warning(f"Reference CSV file not found: '{self.ref_data_path}'. Skipping.")
            self.has_ref_data = False
            return

        logging.info(f"Loading reference data from: {self.ref_data_path}")
        try:
            df_ref = pd.read_csv(self.ref_data_path)
            logging.info(f"Loaded reference data: {df_ref.shape[0]} rows, {df_ref.shape[1]} cols. Columns: {df_ref.columns.tolist()}")

            # --- Data Filtering (Optional) ---
            # Filter for latest time step if applicable
            if 'Time' in df_ref.columns:
                df_ref = df_ref[df_ref['Time'] == df_ref['Time'].max()].copy()
                logging.info(f"Filtered for latest time: {df_ref.shape[0]} rows remaining.")
            elif 'TimeStep' in df_ref.columns:
                df_ref = df_ref[df_ref['TimeStep'] == df_ref['TimeStep'].max()].copy()
                logging.info(f"Filtered for latest timestep: {df_ref.shape[0]} rows remaining.")

            # Identify coordinate columns (handle variations in naming)
            x_col, y_col, z_col = None, None, None
            coord_map = {'Points:0':'x', 'x-coordinate':'x', 'x':'x',
                         'Points:1':'y', 'y-coordinate':'y', 'y':'y',
                         'Points:2':'z', 'z-coordinate':'z', 'z':'z'}
            for col in df_ref.columns:
                mapped_coord = coord_map.get(col.lower().strip()) # Use lower case and strip spaces
                if mapped_coord == 'x': x_col = col
                if mapped_coord == 'y': y_col = col
                if mapped_coord == 'z': z_col = col
            if not x_col or not y_col:
                 raise ValueError(f"Could not definitively identify x/y coordinates in reference columns: {df_ref.columns.tolist()}")
            logging.info(f"Identified coordinate columns: x='{x_col}', y='{y_col}'" + (f", z='{z_col}'" if z_col else ""))

            # Filter for specific Z-plane if data is 3D
            if z_col and len(df_ref[z_col].unique()) > 1:
                target_z = 0.0 # Target the center plane
                unique_z = df_ref[z_col].unique()
                nearest_z = unique_z[np.argmin(np.abs(unique_z - target_z))]
                df_ref = df_ref[np.isclose(df_ref[z_col], nearest_z)].copy()
                logging.info(f"Filtered for z-plane near {target_z} (actual: {nearest_z:.4f}): {df_ref.shape[0]} rows remaining.")

            # --- Variable Renaming (Handle variations) ---
            var_map = { # Map potential CSV column names to consistent internal names
                'U:0':'u_ref', 'U_x':'u_ref', 'Velocity:0':'u_ref', 'velocity_x':'u_ref', 'u':'u_ref',
                'U:1':'v_ref', 'U_y':'v_ref', 'Velocity:1':'v_ref', 'velocity_y':'v_ref', 'v':'v_ref',
                'p':'p_ref', 'pressure':'p_ref', 'kinematic_pressure':'p_ref', # Assuming kinematic pressure if 'p'
                'k':'k_ref', 'turbulentKineticEnergy':'k_ref', 'tke':'k_ref',
                'epsilon':'eps_ref', 'turbulenceDissipationRate':'eps_ref', 'epsilon_dissipation_rate':'eps_ref', 'dissipation':'eps_ref',
                'nut':'nut_ref', 'turbulentViscosity':'nut_ref', 'nuTilda':'nut_ref', 'eddy_viscosity':'nut_ref'
            }
            rename_dict = {}
            for col in df_ref.columns:
                 mapped_var = var_map.get(col.lower().strip()) # Case-insensitive mapping
                 if mapped_var:
                      rename_dict[col] = mapped_var

            # Add coordinate renaming
            rename_dict[x_col] = 'x'
            rename_dict[y_col] = 'y'
            if z_col: rename_dict[z_col] = 'z'

            df_ref.rename(columns=rename_dict, inplace=True)
            logging.info(f"Renamed reference columns based on mapping: {rename_dict}")

            # --- Check for Required Columns ---
            required_cols = ['x', 'y', 'u_ref', 'p_ref', 'k_ref', 'eps_ref']
            missing_cols = [col for col in required_cols if col not in df_ref.columns]
            if missing_cols:
                raise ValueError(f"Missing required columns after renaming in reference data: {missing_cols}. Available: {df_ref.columns.tolist()}")

            # Keep only necessary columns + optional ones if present
            cols_to_keep = ['x', 'y'] + [col for col in ['u_ref', 'v_ref', 'p_ref', 'k_ref', 'eps_ref', 'nut_ref'] if col in df_ref.columns]
            if z_col and 'z' in df_ref.columns: cols_to_keep.append('z')
            df_ref = df_ref[cols_to_keep]

            # Sort and reset index
            df_ref.sort_values(by=['x', 'y'], inplace=True)
            df_ref.reset_index(drop=True, inplace=True)
            self.ref_data = df_ref
            self.has_ref_data = True
            logging.info(f"Successfully loaded and preprocessed reference CSV data. Final columns: {df_ref.columns.tolist()}")

        except FileNotFoundError:
            # Already logged warning, just pass
            pass
        except Exception as e:
            logging.error(f"Error loading or processing reference CSV: {e}", exc_info=True)
            self.ref_data = None
            self.has_ref_data = False

    def predict_pinn_fields(self):
        """Predicts flow fields using the trained PINN model on a grid."""
        if self.model is None or self.model.net is None:
             logging.error("PINN Model or network not available for prediction.")
             return False

        logging.info("Predicting PINN flow fields on evaluation grid...")
        nx = self.plotter_config.NX_PRED
        ny = self.plotter_config.NY_PRED
        x_coords = np.linspace(0, self.config.L, nx)
        y_coords = np.linspace(-self.config.CHANNEL_HALF_HEIGHT, self.config.CHANNEL_HALF_HEIGHT, ny)
        self.X_grid, self.Y_grid = np.meshgrid(x_coords, y_coords)
        # Create prediction points (N, 2) array
        pred_points = np.vstack((np.ravel(self.X_grid), np.ravel(self.Y_grid))).T

        try:
            # Use model.predict for inference
            # Ensure input is float32, as model was likely trained with it
            pred_points_tensor = torch.tensor(pred_points, dtype=torch.float32)
            predictions_raw = self.model.predict(pred_points_tensor.cpu().numpy()) # Predict expects numpy

            if predictions_raw is None or predictions_raw.shape[1] != self.config.NETWORK_OUTPUTS:
                logging.error(f"Prediction shape mismatch. Expected {self.config.NETWORK_OUTPUTS} outputs, got shape {predictions_raw.shape if predictions_raw is not None else 'None'}.")
                return False
        except Exception as e:
            logging.error(f"Error during PINN prediction: {e}", exc_info=True)
            return False

        # --- Process Raw Predictions ---
        # Reshape predictions back to grid format (ny, nx)
        self.u_pred = predictions_raw[:, 0].reshape(ny, nx)
        self.v_pred = predictions_raw[:, 1].reshape(ny, nx)
        self.p_prime_pred = predictions_raw[:, 2].reshape(ny, nx)
        k_raw_pred = predictions_raw[:, 3].reshape(ny, nx)
        eps_raw_pred = predictions_raw[:, 4].reshape(ny, nx)

        # Apply inverse transform (exp) and add epsilon for positivity
        self.k_pred = np.exp(k_raw_pred) + self.config.EPS_SMALL
        self.eps_pred = np.exp(eps_raw_pred) + self.config.EPS_SMALL

        # Calculate kinematic pressure p = p' - (2/3)*k (assuming isotropic normal stress contribution)
        # Note: For plotting/comparison, often p' itself is used, or a relative pressure.
        # This definition assumes p' = p_kinematic + (2/3)k
        self.p_pred = self.p_prime_pred - (2.0 / 3.0) * self.k_pred # Definition check needed

        # Calculate turbulent viscosity nu_t = Cmu * k^2 / eps
        # Use maximum with small number in denominator for robustness
        self.nu_t_pred = self.config.CMU * np.square(self.k_pred) / np.maximum(self.eps_pred, self.config.EPS_SMALL**2)

        logging.info("PINN field prediction and processing complete.")
        return True

    def plot_contour_fields(self):
        """Plots contour fields of the predicted PINN variables."""
        if self.u_pred is None:
            logging.warning("PINN data unavailable for plotting. Run predict_pinn_fields first. Skipping contours.")
            return

        logging.info("Generating PINN contour plots...")
        plt.figure(figsize=(18, 12)) # Adjust figure size as needed

        cmap_vel = self.plotter_config.CMAP_VELOCITY
        cmap_p = self.plotter_config.CMAP_PRESSURE
        cmap_turb = self.plotter_config.CMAP_TURBULENCE

        # Helper function for plotting individual contours
        def plot_contour(subplot_idx, data, title, label, cmap='viridis', is_log=False):
            plt.subplot(2, 3, subplot_idx)
            plot_data = data
            cbar_label = label
            levels = 50 # Number of contour levels

            # Optional log scale for positive quantities like k, eps, nut
            if is_log:
                # Avoid log(0) or log(negative) issues
                min_positive = np.min(data[data > self.config.EPS_SMALL]) if np.any(data > self.config.EPS_SMALL) else self.config.EPS_SMALL
                # Floor values at a small fraction of min_positive before taking log10
                plot_data = np.log10(np.maximum(data, min_positive * 0.1))
                cbar_label = f'log10({label})'

            try:
                cf = plt.contourf(self.X_grid, self.Y_grid, plot_data, levels=levels, cmap=cmap, extend='both')
                plt.colorbar(cf, label=cbar_label)
                plt.title(title)
                plt.xlabel('x (m)')
                plt.ylabel('y (m)')
                plt.gca().set_aspect('equal', adjustable='box') # Make aspect ratio equal
            except Exception as e:
                logging.error(f"Error plotting contour for {title}: {e}")

        # Plot individual fields
        plot_contour(1, self.u_pred, 'PINN Streamwise Velocity (u)', 'u (m/s)', cmap=cmap_vel)
        plot_contour(2, self.v_pred, 'PINN Transverse Velocity (v)', 'v (m/s)', cmap=cmap_vel)
        plot_contour(3, self.p_pred, "PINN Kinematic Pressure (p)", 'p/rho (m^2/s^2)', cmap=cmap_p) # Using calculated p
        # plot_contour(3, self.p_prime_pred, "PINN Fluctuation Pressure (p')", 'p\'/rho (m^2/s^2)', cmap=cmap_p) # Alternative: plot p'

        plot_contour(4, self.k_pred, 'PINN Turbulent Kinetic Energy (k)', 'k (m^2/s^2)', cmap=cmap_turb) # Linear scale k
        # plot_contour(4, self.k_pred, 'PINN Turbulent Kinetic Energy (k)', 'k (m^2/s^2)', cmap=cmap_turb, is_log=True) # Log scale k

        plot_contour(5, self.eps_pred, 'PINN Dissipation Rate (epsilon)', 'eps (m^2/s^3)', cmap=cmap_turb) # Linear scale eps
        # plot_contour(5, self.eps_pred, 'PINN Dissipation Rate (epsilon)', 'eps (m^2/s^3)', cmap=cmap_turb, is_log=True) # Log scale eps

        # Plot eddy viscosity ratio nu_t / nu
        eddy_viscosity_ratio = self.nu_t_pred / self.config.NU
        plot_contour(6, eddy_viscosity_ratio, 'PINN Eddy Viscosity Ratio', 'nu_t / nu', cmap=cmap_turb) # Linear scale ratio
        # plot_contour(6, eddy_viscosity_ratio, 'PINN Eddy Viscosity Ratio', 'nu_t / nu', cmap=cmap_turb, is_log=True) # Log scale ratio


        plt.tight_layout()
        save_path = os.path.join(self.plots_dir, "pinn_field_contours.png")
        try:
            plt.savefig(save_path, dpi=200, bbox_inches='tight')
            plt.close() # Close figure to free memory
            logging.info(f"PINN contour field plots saved to {save_path}")
        except Exception as e:
            logging.error(f"Failed to save contour plot: {e}")

    def _estimate_utau(self, data_source='pinn', x_slice_loc=None):
        """Estimates friction velocity u_tau from data near the wall."""
        if x_slice_loc is None:
            # Choose a location away from inlet/outlet, e.g., 80% downstream
            x_slice_loc = self.config.L * 0.8

        h = self.config.CHANNEL_HALF_HEIGHT
        y_p = self.config.Y_P # Use the wall function distance y_p
        nu = self.config.NU

        # Define two points near the wall (e.g., slightly inside and outside y_p) for gradient calc
        # Ensure points are within the channel bounds [-h, h]
        y_eval_1 = h - y_p * 1.1 # Slightly further from wall than y_p
        y_eval_2 = h - y_p * 0.9 # Slightly closer to wall than y_p

        # Clamp points to be within the domain bounds
        y_eval_1 = max(y_eval_1, -h + self.config.EPS_SMALL * 1.1)
        y_eval_2 = min(y_eval_2, h - self.config.EPS_SMALL * 0.9)
        # Ensure y_eval_2 > y_eval_1
        if y_eval_2 <= y_eval_1:
             y_eval_1 = h - y_p - self.config.EPS_SMALL
             y_eval_2 = h - y_p + self.config.EPS_SMALL
             y_eval_1 = max(y_eval_1, -h + self.config.EPS_SMALL * 1.1)
             y_eval_2 = min(y_eval_2, h - self.config.EPS_SMALL * 0.9)
             if y_eval_2 <= y_eval_1: # Failsafe
                  logging.error("Cannot define distinct points near y_p for u_tau estimation.")
                  return None

        eval_points = np.array([[x_slice_loc, y_eval_1], [x_slice_loc, y_eval_2]])
        # Average distance from the *nearest* wall (top wall in this case)
        y_dist_wall_avg = h - (y_eval_1 + y_eval_2) / 2.0

        u1, k1, eps1, u2, k2, eps2 = None, None, None, None, None, None

        try:
            # Get u, k, eps at the two evaluation points
            if data_source == 'pinn':
                if self.model is None: return None
                pred_raw = self.model.predict(eval_points)
                if pred_raw is None or pred_raw.shape[0] < 2: return None
                # Extract and transform
                u1, u2 = pred_raw[:, 0]
                k1_raw, k2_raw = pred_raw[:, 3]
                eps1_raw, eps2_raw = pred_raw[:, 4]
                k1 = np.exp(k1_raw) + self.config.EPS_SMALL
                k2 = np.exp(k2_raw) + self.config.EPS_SMALL
                eps1 = np.exp(eps1_raw) + self.config.EPS_SMALL
                eps2 = np.exp(eps2_raw) + self.config.EPS_SMALL

            elif data_source == 'reference' and self.has_ref_data:
                if self.ref_data is None: return None
                points_ref = self.ref_data[['x', 'y']].values
                req_cols = ['u_ref', 'k_ref', 'eps_ref']
                if not all(col in self.ref_data.columns for col in req_cols):
                    logging.warning(f"Reference data missing required columns {req_cols} for u_tau estimation.")
                    return None
                # Interpolate required values
                u_ref = self.ref_data['u_ref'].values
                k_ref = self.ref_data['k_ref'].values
                eps_ref = self.ref_data['eps_ref'].values
                interp_u = griddata(points_ref, u_ref, eval_points, method='linear')
                interp_k = griddata(points_ref, k_ref, eval_points, method='linear')
                interp_eps = griddata(points_ref, eps_ref, eval_points, method='linear')

                # Handle potential NaN from linear interpolation (e.g., if eval_points are outside convex hull)
                nan_mask = np.isnan(interp_u) | np.isnan(interp_k) | np.isnan(interp_eps)
                if np.any(nan_mask):
                    logging.warning(f"Linear interpolation failed for u_tau estimation ({data_source}) at x={x_slice_loc:.2f}. Trying 'nearest'.")
                    interp_u[nan_mask] = griddata(points_ref, u_ref, eval_points[nan_mask], method='nearest')
                    interp_k[nan_mask] = griddata(points_ref, k_ref, eval_points[nan_mask], method='nearest')
                    interp_eps[nan_mask] = griddata(points_ref, eps_ref, eval_points[nan_mask], method='nearest')
                    # Check again if nearest neighbor failed
                    if np.isnan(interp_u).any():
                         logging.error(f"Nearest neighbor interpolation also failed for u_tau estimation ({data_source}) at x={x_slice_loc:.2f}.")
                         return None
                u1, u2 = interp_u
                k1, k2 = interp_k
                eps1, eps2 = interp_eps # Already physical values from CSV

            else:
                logging.warning(f"Invalid data_source '{data_source}' or missing data for u_tau estimation.")
                return None

            # Estimate gradient du/dy
            # Need absolute value since y_eval_2 > y_eval_1 but corresponds to smaller wall distance
            du_dy_eval = (u2 - u1) / (y_eval_2 - y_eval_1) # Should be positive near top wall

            # Estimate effective viscosity at the average location
            k_avg = (k1 + k2) / 2.0
            eps_avg = (eps1 + eps2) / 2.0
            nu_t_avg = self.config.CMU * k_avg**2 / max(eps_avg, self.config.EPS_SMALL**2)
            nu_eff_avg = nu + nu_t_avg

            # Estimate wall shear stress tau_w = rho * nu_eff * |du/dy| (near wall)
            # Using absolute value of gradient for robustness
            tau_w_grad = self.config.RHO * nu_eff_avg * abs(du_dy_eval)
            # Estimate u_tau = sqrt(tau_w / rho)
            u_tau_grad = np.sqrt(max(tau_w_grad / self.config.RHO, self.config.EPS_SMALL)) # Ensure positive arg

            # Optional: Refine using log-law if in log region
            u_avg = (u1 + u2) / 2.0
            y_plus_est = y_dist_wall_avg * u_tau_grad / nu
            u_tau_estimated = u_tau_grad # Default estimate

            # If y+ is sufficiently large, blend with log-law estimate
            if y_plus_est > 11: # Heuristic threshold for log-law applicability
                 try:
                    # Estimate u_tau from U = u_tau/kappa * ln(E*y+)
                    log_arg = max(self.config.E_WALL * y_plus_est, self.config.EPS_SMALL)
                    denominator = (1 / self.config.KAPPA) * np.log(log_arg)
                    if abs(denominator) > self.config.EPS_SMALL:
                         u_tau_log = u_avg / denominator
                         # Simple average blend (could use weighting)
                         u_tau_estimated = (u_tau_grad + u_tau_log) / 2.0
                 except Exception as log_e:
                     logging.warning(f"Could not apply log-law refinement for u_tau: {log_e}")
                     pass # Use gradient estimate if log-law fails

            logging.info(f"Estimated u_tau ({data_source}) at x={x_slice_loc:.2f} m: {u_tau_estimated:.4f} m/s (y+ ~ {y_plus_est:.1f})")
            return u_tau_estimated

        except Exception as e:
            logging.error(f"Error estimating u_tau for {data_source} at x={x_slice_loc:.2f}: {e}", exc_info=True)
            return None

    def plot_profile_comparison(self):
        """Plots profiles of PINN vs Reference data at a channel cross-section."""
        if self.u_pred is None:
            logging.warning("PINN data unavailable for plotting. Skipping profile comparison.")
            return
        if not self.has_ref_data:
            logging.warning("Reference CSV data not loaded or failed processing. Skipping profile comparison.")
            return

        logging.info("Generating profile comparison plots...")
        # Define slice location (e.g., channel midpoint)
        x_slice_loc = self.config.L / 2.0
        ny_pinn = self.plotter_config.NY_PRED # Number of points in y from PINN grid

        # Find the closest x-coordinate in the PINN grid
        y_coords_pinn = self.Y_grid[:, 0] # y-coordinates from PINN grid
        x_coords_pinn = self.X_grid[0, :] # x-coordinates from PINN grid
        x_slice_idx_pinn = np.argmin(np.abs(x_coords_pinn - x_slice_loc))
        actual_x_pinn = x_coords_pinn[x_slice_idx_pinn] # Actual x used for slicing

        # Extract PINN data slice at the chosen x-index
        pinn_slice = {
            'y': y_coords_pinn,
            'u': self.u_pred[:, x_slice_idx_pinn],
            'v': self.v_pred[:, x_slice_idx_pinn],
            'p': self.p_pred[:, x_slice_idx_pinn], # Use calculated kinematic p
            'k': self.k_pred[:, x_slice_idx_pinn],
            'eps': self.eps_pred[:, x_slice_idx_pinn],
            'nut': self.nu_t_pred[:, x_slice_idx_pinn]
        }

        # --- Interpolate Reference Data onto PINN y-coordinates at the same x ---
        ref_slice = {'y': y_coords_pinn} # Initialize dict for interpolated ref data
        try:
            if self.ref_data is None: raise ValueError("Reference data frame is None.")

            # Points from reference data (x, y)
            points_ref = self.ref_data[['x', 'y']].values
            # Target points for interpolation (same x, PINN y-coords)
            target_points = np.vstack((np.full(ny_pinn, actual_x_pinn), y_coords_pinn)).T

            logging.info(f"Interpolating reference data onto {ny_pinn} points at x={actual_x_pinn:.2f}...")

            # Interpolate each variable present in the reference data
            variables_to_interpolate = [
                ('u_ref', 'u'), ('v_ref', 'v'), ('p_ref', 'p'),
                ('k_ref', 'k'), ('eps_ref', 'eps'), ('nut_ref', 'nut')
            ]
            for var_ref, var_pinn in variables_to_interpolate:
                if var_ref in self.ref_data.columns:
                    values_ref = self.ref_data[var_ref].values
                    # Linear interpolation
                    interp_values = griddata(points_ref, values_ref, target_points, method='linear')
                    # Handle NaNs with nearest neighbor fallback
                    nan_mask = np.isnan(interp_values)
                    if np.any(nan_mask):
                        interp_nearest = griddata(points_ref, values_ref, target_points[nan_mask], method='nearest')
                        interp_values[nan_mask] = interp_nearest
                        if np.any(np.isnan(interp_values)): # Check if nearest also failed
                             logging.warning(f"Interpolation (linear & nearest) failed for '{var_ref}' at some points.")
                    ref_slice[var_pinn] = interp_values
                else:
                    logging.warning(f"Reference variable '{var_ref}' not found in CSV data. Skipping interpolation.")
                    ref_slice[var_pinn] = np.full(ny_pinn, np.nan) # Fill with NaN if missing

        except Exception as e:
            logging.error(f"Error interpolating reference data for profiles: {e}", exc_info=True)
            # If interpolation fails, mark ref data as unavailable for this plot
            self.has_ref_data = False # Prevent plotting potentially bad data

        # --- Create Plots ---
        plt.figure(figsize=(15, 12)) # Adjust size
        plot_idx = 1
        h = self.config.CHANNEL_HALF_HEIGHT
        plot_vars = [ # Variables to plot, their names, and units
            ('u', 'Velocity u', 'm/s'),
            ('v', 'Velocity v', 'm/s'),
            ('p', 'Kinematic Pressure p', 'm^2/s^2'),
            ('k', 'TKE k', 'm^2/s^2'),
            ('eps', 'Dissipation eps', 'm^2/s^3'),
            ('nut', 'Eddy Viscosity nu_t', 'm^2/s')
        ]

        for key, name, unit in plot_vars:
            plt.subplot(3, 2, plot_idx) # Arrange plots in 3 rows, 2 columns

            # Plot PINN data
            plt.plot(pinn_slice[key], pinn_slice['y'] / h, 'r-', linewidth=2, label='PINN')

            # Plot Reference data if available and valid
            if self.has_ref_data and key in ref_slice and not np.all(np.isnan(ref_slice[key])):
                plt.plot(ref_slice[key], ref_slice['y'] / h, 'b--', linewidth=1.5, label='Reference (CSV)')

            plt.xlabel(f'{name} ({unit})')
            plt.ylabel('y/h') # Normalize y by half-height
            plt.title(f'{name} Profile at x={actual_x_pinn:.2f}m')
            plt.legend()
            plt.grid(True, linestyle=':')

            # Use log scale for x-axis for turbulence quantities (k, eps, nut) if values are positive
            if key in ['k', 'eps', 'nut']:
                 try:
                     # Check if all plotted values are sufficiently positive
                     min_val_for_log = self.config.EPS_SMALL / 10.0
                     pinn_positive = np.all(pinn_slice[key] > min_val_for_log)
                     ref_positive = True # Assume true if no ref data plotted
                     if self.has_ref_data and key in ref_slice and not np.all(np.isnan(ref_slice[key])):
                          ref_positive = np.all(ref_slice[key][~np.isnan(ref_slice[key])] > min_val_for_log)

                     if pinn_positive and ref_positive:
                          plt.semilogx()
                          plt.grid(True, which='both', linestyle=':') # Add minor grid for log scale
                 except Exception as log_e:
                      logging.warning(f"Could not apply log scale for {key}: {log_e}")

            plot_idx += 1

        plt.suptitle(f'Profile Comparison at x ≈ {x_slice_loc:.2f} m', fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.94]) # Adjust layout to prevent title overlap

        save_path = os.path.join(self.plots_dir, "profile_comparison_pinn_vs_csv.png")
        try:
            plt.savefig(save_path, dpi=200, bbox_inches='tight')
            plt.close()
            logging.info(f"Profile comparison plot saved to {save_path}")
        except Exception as e:
            logging.error(f"Failed to save profile comparison plot: {e}")

    def plot_wall_unit_comparison(self):
        """Plots profiles in wall units (y+, U+, k+, eps+) vs reference/theory."""
        if self.u_pred is None:
            logging.warning("PINN data unavailable for plotting. Skipping wall unit plots.")
            return

        logging.info("Generating wall unit comparison plots...")

        # --- Estimate Friction Velocity (u_tau) ---
        # Use a location away from inlet/outlet for better u_tau estimate
        x_slice_loc_utau = self.config.L * 0.8
        self.pinn_data_utau = self._estimate_utau(data_source='pinn', x_slice_loc=x_slice_loc_utau)
        if self.has_ref_data:
            self.ref_data_utau = self._estimate_utau(data_source='reference', x_slice_loc=x_slice_loc_utau)
        else:
            self.ref_data_utau = None

        # Proceed only if PINN u_tau could be estimated
        if not self.pinn_data_utau:
            logging.error("Could not estimate PINN u_tau. Skipping wall unit plots.")
            return
        if self.has_ref_data and not self.ref_data_utau:
            logging.warning("Could not estimate reference u_tau. Plotting PINN wall units only vs theory.")

        # --- Prepare Data for Wall Units ---
        nu = self.config.NU
        h = self.config.CHANNEL_HALF_HEIGHT
        kappa = self.config.KAPPA
        E_wall_plot = self.config.E_WALL # Use config value for consistency
        Cmu = self.config.CMU

        # Use the same x-slice as the profile plots for consistency, or define a new one
        x_slice_loc_plot = self.config.L / 2.0
        y_coords_pinn = self.Y_grid[:, 0]
        x_coords_pinn = self.X_grid[0, :]
        x_slice_idx_pinn = np.argmin(np.abs(x_coords_pinn - x_slice_loc_plot))
        actual_x_pinn = x_coords_pinn[x_slice_idx_pinn]

        # Extract PINN data near the top wall (y >= 0)
        # Could average top/bottom walls or plot separately
        wall_indices_pinn = y_coords_pinn >= 0 # Indices for top half
        y_wall_pinn = y_coords_pinn[wall_indices_pinn]
        # Distance from the nearest wall (top wall: h - y)
        y_dist_wall_pinn = np.maximum(h - y_wall_pinn, self.config.EPS_SMALL * h) # Avoid zero distance

        # Get corresponding u, k, eps from the PINN slice
        u_wall_pinn = self.u_pred[wall_indices_pinn, x_slice_idx_pinn]
        k_wall_pinn = self.k_pred[wall_indices_pinn, x_slice_idx_pinn]
        eps_wall_pinn = self.eps_pred[wall_indices_pinn, x_slice_idx_pinn]

        # Calculate wall units for PINN data
        utau_pinn_safe = max(self.pinn_data_utau, self.config.EPS_SMALL) # Avoid division by zero
        y_plus_pinn = y_dist_wall_pinn * utau_pinn_safe / nu
        u_plus_pinn = u_wall_pinn / utau_pinn_safe
        # k+ = k / u_tau^2
        k_plus_pinn = k_wall_pinn / max(utau_pinn_safe**2, self.config.EPS_SMALL)
        # eps+ = eps * nu / u_tau^4
        eps_plus_pinn = eps_wall_pinn * nu / max(utau_pinn_safe**4, self.config.EPS_SMALL)

        # Sort PINN data by y+ for plotting lines correctly
        sort_idx_pinn = np.argsort(y_plus_pinn)
        y_plus_pinn = y_plus_pinn[sort_idx_pinn]
        u_plus_pinn = u_plus_pinn[sort_idx_pinn]
        k_plus_pinn = k_plus_pinn[sort_idx_pinn]
        eps_plus_pinn = eps_plus_pinn[sort_idx_pinn]

        # --- Prepare Reference Data (if available and u_tau estimated) ---
        y_plus_ref, u_plus_ref, k_plus_ref, eps_plus_ref = None, None, None, None
        if self.has_ref_data and self.ref_data_utau:
            try:
                # Filter reference data near the chosen x-slice and top wall (y>=0)
                # Use a tolerance for x matching
                ref_wall_data = self.ref_data[
                    (np.isclose(self.ref_data['x'], actual_x_pinn, rtol=0.05, atol=0.1)) &
                    (self.ref_data['y'] >= -self.config.EPS_SMALL) # Include y=0
                ].copy()

                if not ref_wall_data.empty:
                    y_dist_wall_ref = np.maximum(h - ref_wall_data['y'].values, self.config.EPS_SMALL * h)
                    utau_ref_safe = max(self.ref_data_utau, self.config.EPS_SMALL)
                    y_plus_ref = y_dist_wall_ref * utau_ref_safe / nu
                    # Check if required columns exist before accessing
                    if 'u_ref' in ref_wall_data.columns: u_plus_ref = ref_wall_data['u_ref'].values / utau_ref_safe
                    if 'k_ref' in ref_wall_data.columns: k_plus_ref = ref_wall_data['k_ref'].values / max(utau_ref_safe**2, self.config.EPS_SMALL)
                    if 'eps_ref' in ref_wall_data.columns: eps_plus_ref = ref_wall_data['eps_ref'].values * nu / max(utau_ref_safe**4, self.config.EPS_SMALL)

                    # Sort reference data by y+
                    sort_idx_ref = np.argsort(y_plus_ref)
                    y_plus_ref = y_plus_ref[sort_idx_ref]
                    if u_plus_ref is not None: u_plus_ref = u_plus_ref[sort_idx_ref]
                    if k_plus_ref is not None: k_plus_ref = k_plus_ref[sort_idx_ref]
                    if eps_plus_ref is not None: eps_plus_ref = eps_plus_ref[sort_idx_ref]
                    logging.info(f"Processed {len(y_plus_ref)} reference points for wall unit comparison.")
                else:
                    logging.warning(f"No reference data found near x={actual_x_pinn:.2f}, y>=0 for wall unit plots.")
            except Exception as e:
                logging.error(f"Error processing reference data for wall units: {e}", exc_info=True)
                # Prevent plotting bad ref data
                y_plus_ref, u_plus_ref, k_plus_ref, eps_plus_ref = None, None, None, None

        # --- Create Wall Unit Plots ---
        plt.figure(figsize=(18, 6)) # Figure for U+, k+, eps+

        # Determine plot limits dynamically
        y_plus_max_plot = 1.1 * max(
            np.max(y_plus_pinn) if len(y_plus_pinn) > 0 else 100,
            np.max(y_plus_ref) if y_plus_ref is not None and len(y_plus_ref) > 0 else 100,
            self.config.YP_PLUS_TARGET * 1.5 # Ensure target y+ is visible
        )
        u_plus_max_plot = 1.1 * max(
             np.max(u_plus_pinn) if len(u_plus_pinn)>0 else 25,
             np.max(u_plus_ref) if u_plus_ref is not None and len(u_plus_ref)>0 else 25
        )
        k_plus_max_plot = 1.1 * max(
             np.max(k_plus_pinn) if len(k_plus_pinn)>0 else 5,
             np.max(k_plus_ref) if k_plus_ref is not None and len(k_plus_ref)>0 else 5,
             (self.config.K_TARGET_WF / max(self.pinn_data_utau**2, 1e-9)) * 1.2 if self.pinn_data_utau else 5 # Include target k+ if estimable
        )
        # Determine eps+ limits carefully (can vary wildly)
        min_eps_plus_data = min(
            np.min(eps_plus_pinn[eps_plus_pinn > 0]) if np.any(eps_plus_pinn > 0) else 1e-4,
            np.min(eps_plus_ref[eps_plus_ref > 0]) if eps_plus_ref is not None and np.any(eps_plus_ref > 0) else 1e-4
        )
        max_eps_plus_data = max(
            np.max(eps_plus_pinn) if len(eps_plus_pinn)>0 else 1,
            np.max(eps_plus_ref) if eps_plus_ref is not None and len(eps_plus_ref)>0 else 1
        )


        # 1. U+ vs y+ plot
        ax1 = plt.subplot(1, 3, 1)
        ax1.semilogx(y_plus_pinn, u_plus_pinn, 'r.', ms=4, label=f'PINN ($u_\\tau \\approx {self.pinn_data_utau:.3f}$)')
        if y_plus_ref is not None and u_plus_ref is not None:
            ax1.semilogx(y_plus_ref, u_plus_ref, 'bo', mfc='none', ms=5, label=f'Reference ($u_\\tau \\approx {self.ref_data_utau:.3f}$)' if self.ref_data_utau else 'Reference (u_tau N/A)')
        # Theoretical laws
        y_plus_theory_log = np.logspace(np.log10(max(11, 1)), np.log10(y_plus_max_plot*1.1), 100) # Extend slightly beyond max y+
        u_plus_loglaw = (1 / kappa) * np.log(y_plus_theory_log) + E_wall_plot
        y_plus_theory_vis = np.linspace(0.1, 20, 50) # Viscous sublayer range
        u_plus_viscous = y_plus_theory_vis # U+ = y+
        ax1.semilogx(y_plus_theory_log, u_plus_loglaw, 'k:', lw=1.5, label=f'Log-Law ($\\kappa={kappa}, E={E_wall_plot}$)')
        ax1.semilogx(y_plus_theory_vis, u_plus_viscous, 'k--', lw=1.5, label='Viscous ($U^+=y^+$)')
        ax1.set_xlabel('$y^+$')
        ax1.set_ylabel('$U^+$')
        ax1.set_title(f'$U^+$ vs $y^+$ (x={actual_x_pinn:.2f}m)')
        ax1.legend(fontsize=9)
        ax1.grid(True, which='both', ls=':')
        ax1.set_ylim(bottom=0, top=u_plus_max_plot)
        ax1.set_xlim(left=0.1, right=y_plus_max_plot) # Start x-axis slightly > 0 for log scale


        # 2. k+ vs y+ plot
        ax2 = plt.subplot(1, 3, 2)
        ax2.semilogx(y_plus_pinn, k_plus_pinn, 'r.', ms=4, label='PINN')
        if y_plus_ref is not None and k_plus_ref is not None:
            ax2.semilogx(y_plus_ref, k_plus_ref, 'bo', mfc='none', ms=5, label='Reference')
        # Add line for target y+ location
        ax2.axvline(self.config.YP_PLUS_TARGET, color='g', ls='-.', lw=1, label=f'Target $y_p^+ \\approx {self.config.YP_PLUS_TARGET:.1f}$')
        ax2.set_xlabel('$y^+$')
        ax2.set_ylabel('$k^+$')
        ax2.set_title('$k^+$ vs $y^+$')
        ax2.legend(fontsize=9)
        ax2.grid(True, which='both', ls=':')
        ax2.set_ylim(bottom=0, top=k_plus_max_plot) # k+ starts from 0
        ax2.set_xlim(left=0.1, right=y_plus_max_plot)


        # 3. eps+ vs y+ plot (log-log scale often used)
        ax3 = plt.subplot(1, 3, 3)
        ax3.loglog(y_plus_pinn, eps_plus_pinn, 'r.', ms=4, label='PINN')
        if y_plus_ref is not None and eps_plus_ref is not None:
            ax3.loglog(y_plus_ref, eps_plus_ref, 'bo', mfc='none', ms=5, label='Reference')
        # Theoretical trend near wall: eps+ ~ 1 / (kappa * y+) -> C / y+
        y_plus_theory_eps = np.logspace(np.log10(max(1,0.1)), np.log10(y_plus_max_plot*1.1), 100)
        # Adjust constant C for better visual fit if needed, 1/kappa is theoretical
        eps_plus_target_theory = 1.0 / (kappa * y_plus_theory_eps)
        ax3.loglog(y_plus_theory_eps, eps_plus_target_theory, 'k:', lw=1.5, label='$\\epsilon^+ \\propto 1/y^+$')
        ax3.axvline(self.config.YP_PLUS_TARGET, color='g', ls='-.', lw=1, label=f'Target $y_p^+ \\approx {self.config.YP_PLUS_TARGET:.1f}$')
        ax3.set_xlabel('$y^+$')
        ax3.set_ylabel('$\\epsilon^+$')
        ax3.set_title('$\\epsilon^+$ vs $y^+$ (log-log)')
        ax3.legend(fontsize=9)
        ax3.grid(True, which='both', ls=':')
        # Set reasonable y-limits for eps+
        ax3.set_ylim(bottom=max(min_eps_plus_data * 0.5, 1e-5), top=max_eps_plus_data * 2)
        ax3.set_xlim(left=0.1, right=y_plus_max_plot)


        plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout
        plt.suptitle(f'Wall Unit Profiles Comparison (Top Wall, x ≈ {actual_x_pinn:.2f}m)', fontsize=16)
        plt.subplots_adjust(top=0.88) # Adjust top margin for suptitle

        save_path = os.path.join(self.plots_dir, "profile_comparison_wall_units.png")
        try:
            plt.savefig(save_path, dpi=200, bbox_inches='tight')
            plt.close()
            logging.info(f"Wall unit comparison plots saved to {save_path}")
        except Exception as e:
            logging.error(f"Failed to save wall unit comparison plot: {e}")

    def plot_pressure_gradient_comparison(self):
        """Plot the streamwise pressure gradient dp/dx along the channel centerline.

        The PINN gradient is computed with ``numpy.gradient`` from the row of
        ``self.p_pred`` whose y-coordinate is closest to 0. When reference data
        with a ``'p_ref'`` column is loaded, its centerline gradient is overlaid.
        The figure is saved as ``pressure_gradient_comparison.png`` inside
        ``self.plots_dir`` and is always closed afterwards.

        Returns:
            None. Returns early (after logging a warning) when ``self.p_pred``
            is None or the grid has fewer than two streamwise points.
        """
        if self.p_pred is None:
            logging.warning("PINN pressure data unavailable. Skipping pressure gradient plot.")
            return

        # NOTE(review): assumes X_grid/Y_grid/p_pred are 2-D arrays indexed [y, x] - confirm.
        x_coords_pinn = self.X_grid[0, :] # Streamwise coordinates
        y_coords_pinn = self.Y_grid[:, 0] # Transverse coordinates
        if len(x_coords_pinn) < 2:
            # np.gradient needs at least two samples; skip instead of raising.
            logging.warning("Fewer than two streamwise grid points. Skipping pressure gradient plot.")
            return

        logging.info("Generating centerline pressure gradient comparison plot...")
        # Find index closest to centerline y=0
        center_idx_pinn = np.argmin(np.abs(y_coords_pinn - 0.0))
        actual_y_center = y_coords_pinn[center_idx_pinn]

        # Differentiate the PINN pressure along the centerline row.
        p_centerline_pinn = self.p_pred[center_idx_pinn, :]
        dp_dx_pinn = np.gradient(p_centerline_pinn, x_coords_pinn)

        # --- Calculate Reference Pressure Gradient (if data available) ---
        dp_dx_ref = None
        x_coords_ref = None
        if self.has_ref_data and 'p_ref' in self.ref_data.columns:
            try:
                # Keep reference rows within 2% of the half-height from y = 0.
                near_centerline = np.isclose(
                    self.ref_data['y'], 0.0, atol=0.02 * self.config.CHANNEL_HALF_HEIGHT
                )
                ref_centerline = self.ref_data[near_centerline]

                if len(ref_centerline) > 1:
                    # Average pressure over rows sharing the same x. groupby sorts
                    # its keys by default, so the index is ascending in x.
                    centerline_grouped = ref_centerline.groupby('x')['p_ref'].mean()
                    if len(centerline_grouped) > 5: # Need sufficient points for gradient
                        x_coords_ref = centerline_grouped.index.values
                        dp_dx_ref = np.gradient(centerline_grouped.values, x_coords_ref)
                        logging.info(f"Calculated reference pressure gradient from {len(x_coords_ref)} centerline points.")
                    else:
                        logging.warning("Not enough grouped x-points in reference centerline data for gradient.")
                else:
                    logging.warning("Not enough points near centerline in reference data for gradient.")
            except Exception as e:
                logging.warning(f"Could not calculate reference pressure gradient: {e}")
                # Make sure a partially computed reference curve is never plotted.
                dp_dx_ref, x_coords_ref = None, None
        elif self.has_ref_data:
            logging.warning("Reference data loaded, but 'p_ref' column missing. Cannot plot reference pressure gradient.")

        # --- Create Plot ---
        fig = plt.figure(figsize=(10, 6))
        try:
            plt.plot(x_coords_pinn, dp_dx_pinn, 'r-', lw=2, label='PINN $dp/dx$')
            if dp_dx_ref is not None and x_coords_ref is not None:
                plt.plot(x_coords_ref, dp_dx_ref, 'b--', lw=1.5, label='Reference $dp/dx$ (CSV)')

            plt.xlabel('x (m)')
            plt.ylabel('$dp/dx$ $(m/s^2)$') # Assuming kinematic pressure p
            plt.title(f'Streamwise Kinematic Pressure Gradient along Centerline (y ≈ {actual_y_center:.3f}m)')
            plt.legend()
            plt.grid(True, ls=':')

            # Optional: zoom the y-axis on the middle half of the domain.
            # The previous end index was `-len(...) // 4`, which is never positive,
            # so its `start < end` guard always failed and no limit was applied.
            # Non-negative slice bounds are used here instead.
            try:
                n_pts = len(dp_dx_pinn)
                focus_region_pinn = dp_dx_pinn[n_pts // 4 : n_pts - n_pts // 4]
                if len(focus_region_pinn) > 0:
                    mean_dpdx = float(np.mean(focus_region_pinn))
                    std_dpdx = float(np.std(focus_region_pinn))
                    # Half-width: 4 std deviations, but at least 40% of |mean|.
                    half_range = 4 * max(std_dpdx, abs(mean_dpdx) * 0.1)
                    if np.isfinite(mean_dpdx) and np.isfinite(half_range) and half_range > 0:
                        plt.ylim(mean_dpdx - half_range, mean_dpdx + half_range)
                    else:
                        logging.debug("Degenerate dp/dx statistics; keeping automatic y-limits.")
            except Exception as e:
                # Axis scaling is best-effort and must not abort the plot.
                logging.debug(f"Could not set y-limits for pressure gradient plot: {e}")

            plt.tight_layout()
            save_path = os.path.join(self.plots_dir, "pressure_gradient_comparison.png")
            try:
                plt.savefig(save_path, dpi=200)
                logging.info(f"Pressure gradient comparison plot saved to {save_path}")
            except Exception as e:
                logging.error(f"Failed to save pressure gradient plot: {e}")
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)

    def run_post_processing(self):
        """Drive the complete post-processing pipeline in order.

        The loss history is always plotted. Field prediction is attempted next;
        if it reports failure nothing else is produced. On success the contour
        plots are generated and the reference data is loaded, and the three
        PINN-vs-reference comparison plots run only when ``self.has_ref_data``
        is truthy after loading.
        """
        logging.info("--- Starting Full Post-Processing ---")
        self.plot_loss_history()

        # Every field plot depends on the predicted fields.
        if not self.predict_pinn_fields():
            logging.error("PINN field prediction failed. Aborting further post-processing.")
        else:
            self.plot_contour_fields()
            # Reference data is only needed once predictions exist.
            self.load_reference_data()
            if not self.has_ref_data:
                logging.warning("Skipping comparison plots as reference data is unavailable or failed to load.")
            else:
                logging.info("Proceeding with PINN vs Reference CSV comparisons...")
                self.plot_profile_comparison()
                self.plot_wall_unit_comparison()
                self.plot_pressure_gradient_comparison()

        logging.info("--- Post-Processing Finished ---")
# --- End Plotter Class ---


# =============================
# ===== Main Execution Block =====
# =============================
if __name__ == "__main__":
    main_start_time = time.time()

    # --- 1. Initial Setup ---
    # Default simulation and plotting configurations.
    main_cfg = Config()
    main_plot_cfg = PlotterConfig()

    # --- Google Drive Mount (if applicable) ---
    # Done first: per the original note, a successful Colab mount may change the
    # paths stored on main_cfg (TODO confirm against mount_drive).
    mount_drive(main_cfg.DRIVE_MOUNT_POINT)

    # --- Setup Output Dirs and Logging ---
    # Must follow mount_drive so any updated output paths are the ones used.
    setup_output_directories(main_cfg)
    setup_logging(main_cfg.LOG_FILE)

    logging.info("="*60)
    logging.info(" PINN RANS k-epsilon Channel Flow Simulation Start ")
    logging.info("="*60)
    # Record the configuration actually in effect for this run.
    log_configuration(main_cfg, main_plot_cfg)

    # --- 2. Define Boundaries ---
    try:
        bcs, anchor_points = get_boundary_conditions(main_cfg)
        logging.info(f"Defined {len(bcs)} boundary conditions.")
        if anchor_points is None or len(anchor_points) == 0:
            # Anchor points are expected whenever wall functions are in use.
            logging.warning("No anchor points generated for wall functions, check get_boundary_conditions.")
        else:
            logging.info(f"Using {anchor_points.shape[0]} anchor points for wall functions.")
    except Exception as e:
        logging.error(f"Failed to define boundary conditions: {e}", exc_info=True)
        # Nothing downstream can run without boundary conditions.
        sys.exit(1)

    # --- 3. Training Phase ---
    model_trained = history_trained = state_trained = None
    training_successful = False
    try:
        trainer = Trainer(main_cfg)
        trainer.build_model(bcs, anchor_points)
        model_trained, history_trained, state_trained = trainer.train()

        # Minimal validity check: all three training outputs must be present.
        training_successful = all(
            result is not None
            for result in (model_trained, history_trained, state_trained)
        )
        if training_successful:
            logging.info("Training phase returned valid model, history, and state objects.")
        else:
            logging.error("Training phase finished but returned an invalid state (model, losshistory, or train_state is None).")
    except Exception as e:
        # The outputs keep their None defaults and training_successful stays False.
        logging.error(f"A critical error occurred during model building or the training phase: {e}", exc_info=True)

    # --- 4. Post-processing and Plotting Phase ---
    if not training_successful:
        # Reached when training raised, or returned at least one None output.
        logging.error("Training did not complete successfully or produced an invalid state. Skipping post-processing.")
    else:
        logging.info("Proceeding to post-processing.")
        try:
            # Hand the trained model, loss history and final state to the plotter.
            plotter = Plotter(main_cfg, main_plot_cfg, model_trained, history_trained, state_trained)
            plotter.run_post_processing()
        except Exception as e:
            logging.error(f"An error occurred during post-processing: {e}", exc_info=True)

    main_end_time = time.time()
    logging.info("="*60)
    logging.info(f" Script Execution Finished in {main_end_time - main_start_time:.2f} seconds")
    logging.info("="*60)

Using backend: pytorch
Other supported backends: tensorflow.compat.v1, tensorflow, jax, paddle.
paddle supports more examples now and is recommended.


DeepXDE Backend requested: pytorch
DeepXDE Backend actual: pytorch
CUDA available.
PyTorch CUDA device detected by DDE: 0 (NVIDIA GeForce RTX 3050 Laptop GPU)
PyTorch version: 2.6.0+cu126
Number of GPUs: 1
PyTorch default dtype: torch.float32
2025-04-14 23:38:14 [INFO] Logging configured.
2025-04-14 23:38:14 [INFO]  PINN RANS k-epsilon Channel Flow Simulation Start 
2025-04-14 23:38:14 [INFO] Simulation Configuration:
2025-04-14 23:38:14 [INFO]   Output Directory: PINN_RANS_ChannelFlow
2025-04-14 23:38:14 [INFO]   Re_H: 10000
2025-04-14 23:38:14 [INFO]   Wall Function y_p: 0.04 (Target y+: 14.00)
2025-04-14 23:38:14 [INFO]   Network: 8 layers, 64 neurons
2025-04-14 23:38:14 [INFO]   Inlet k (log): -5.5860, Inlet eps (log): -7.5257
2025-04-14 23:38:14 [INFO]   Target WF U: 0.8402, k (log): -4.1145, eps (log): -3.8673
2025-04-14 23:38:14 [INFO]   Adam Iterations: 50000, LR: 0.001
2025-04-14 23:38:14 [INFO]   L-BFGS Iterations: 20000
2025-04-14 23:38:14 [INFO]   Checkpoint Interval: 1000


KeyboardInterrupt: 