In [13]:
# prompt: write the code to mount the gdrive and list the files and folders in the drive , list also using the os

from google.colab import drive
import os

drive.mount('/content/drive')

# List files and folders using Google Colab's drive API
!ls '/content/drive/My Drive'

# List files and folders using the os module
print("\nListing files and folders using os module:")
for root, dirs, files in os.walk('/content/drive/My Drive'):
    level = root.replace('/content/drive/My Drive', '').count(os.sep)
    indent = ' ' * 4 * (level)
    print('{}{}/'.format(indent, os.path.basename(root)))
    subindent = ' ' * 4 * (level + 1)
    for f in files:
        print('{}{}'.format(subindent, f))


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
'Colab Notebooks'   PINN_RANS_ChannelFlow

Listing files and folders using os module:
My Drive/
    Colab Notebooks/
        Untitled0.ipynb
        Untitled1.ipynb
    PINN_RANS_ChannelFlow/
        model_checkpoints/
            rans_channel_wf-30000.pt-17000.pt
            rans_channel_wf-49000.pt-8000.pt
            rans_channel_wf.ckpt-7000.pt
            rans_channel_wf-27000.pt-14000.pt
            rans_channel_wf-40000.pt-27000.pt
            rans_channel_wf-33000.pt-20000.pt
            rans_channel_wf.ckpt-6000.pt
            rans_channel_wf.ckpt-10000.pt
            rans_channel_wf-20000.pt-7000.pt
            rans_channel_wf-42000.pt-1000.pt
            rans_channel_wf-43000.pt-2000.pt
            rans_channel_wf-15000.pt-2000.pt
            rans_channel_wf-50000.pt-9000.pt
            rans_channel_wf.ckpt-9000.pt
            rans_channel_wf.ckpt-

In [8]:
import os
# Set environment variable *before* importing deepxde or torch
os.environ["DDE_BACKEND"] = "pytorch"
try:
  import torch
except ImportError:
  print("Installing torch...")
  !pip install torch -q
try:
  import deepxde
except ImportError:
  print("Installing deepxde...")
  !pip install deepxde -q
try:
  import pandas
except ImportError:
  print("Installing pandas...")
  !pip install pandas -q
try:
  import matplotlib
except ImportError:
  print("Installing matplotlib...")
  !pip install matplotlib -q

import sys
import time
import logging
import numpy as np
import torch # Now import torch after potentially installing
import deepxde as dde # Import deepxde after setting backend
import matplotlib.pyplot as plt
import pandas as pd
from scipy.interpolate import griddata
import re # <<<--- IMPORT REGEX MODULE


# --- Attempt to explicitly set backend (optional but good practice) ---
try:
    # This might still fail on older DeepXDE versions, but the env var is primary
    dde.config.set_default_backend("pytorch")
    print("Attempted to explicitly set DeepXDE backend to PyTorch.")
except AttributeError:
    print(f"Warning: Could not explicitly set backend via dde.config (likely older DeepXDE version). Relied on environment variable DDE_BACKEND={os.environ.get('DDE_BACKEND')}.")
except Exception as e:
    print(f"Warning: Could not explicitly set backend via dde.config: {e}")

print(f"DeepXDE Backend requested: {os.environ.get('DDE_BACKEND', 'Not Set')}")

# --- Check actual backend and setup device/dtype ---
if "deepxde" in sys.modules and hasattr(dde, 'backend'):
    print(f"DeepXDE Backend actual: {dde.backend.backend_name}")
    if dde.backend.backend_name == "pytorch":
        if torch.cuda.is_available():
            print("CUDA available.")
            try:
                # Use float32 as it's common for PINNs and avoids potential double precision issues
                torch.set_default_dtype(torch.float32)
                current_device = torch.cuda.current_device()
                print(f"PyTorch CUDA device detected by DDE: {current_device} ({torch.cuda.get_device_name(current_device)})")
                print(f"PyTorch version: {torch.__version__}")
                print(f"Number of GPUs: {torch.cuda.device_count()}")
            except Exception as e:
                print(f"Warning: Error during PyTorch device setup: {e}")
        else:
            print("CUDA not available. Using CPU.")
            try:
                torch.set_default_dtype(torch.float32)
                print(f"PyTorch default device set to: cpu")
            except Exception as e:
                print(f"Warning: Failed to set default PyTorch device to CPU: {e}")
        print(f"PyTorch default dtype: {torch.get_default_dtype()}")
    else:
        print(f"Warning: Backend is '{dde.backend.backend_name}', not PyTorch. PyTorch-specific device setup skipped.")
else:
    print("Warning: deepxde module or dde.backend not fully available for backend check.")


# =============================
# ===== Configuration Classes =====
# =============================

class PlotterConfig:
    """Stores configuration parameters specifically for plotting."""
    NX_PRED = 200
    NY_PRED = 100
    CMAP_VELOCITY = 'viridis'
    CMAP_PRESSURE = 'coolwarm'
    CMAP_TURBULENCE = 'plasma'


class Config:
    """Stores configuration parameters for the simulation."""
    DRIVE_MOUNT_POINT = '/content/drive'
      # Adjust path if necessary
    GDRIVE_BASE_FOLDER = '/content/drive/MyDrive/PINN_RANS_ChannelFlow'
    OUTPUT_DIR = GDRIVE_BASE_FOLDER # Default output dir (can be overwritten if not on Colab)
    MODEL_DIR = os.path.join(OUTPUT_DIR, "model_checkpoints")
    LOG_DIR = os.path.join(OUTPUT_DIR, "logs")
    PLOT_DIR = os.path.join(OUTPUT_DIR, "plots")
    DATA_DIR = os.path.join(OUTPUT_DIR, "data")
    LOG_FILE = os.path.join(LOG_DIR, "training_log.log")
    REFERENCE_DATA_FILE = os.path.join(DATA_DIR, "reference_output_data.csv")

    # Checkpoint filename base (without step number or extension)
    CHECKPOINT_FILENAME_BASE  = "rans_channel_wf"

    # --- Fluid and Geometry Parameters ---
    NU = 0.0002 # Kinematic viscosity
    RHO = 1.0 # Density (often set to 1 for incompressible flow)
    MU = RHO * NU # Dynamic viscosity
    U_INLET = 1.0 # Inlet velocity
    H = 2.0 # Full channel height
    CHANNEL_HALF_HEIGHT = H / 2.0
    L = 10.0 # Channel length
    RE_H = U_INLET * H / NU # Reynolds number based on full height
    EPS_SMALL = 1e-10 # Small epsilon for numerical stability (avoid log(0), division by zero)

    # --- k-epsilon Model Constants ---
    CMU = 0.09
    CEPS1 = 1.44
    CEPS2 = 1.92
    SIGMA_K = 1.0
    SIGMA_EPS = 1.3
    KAPPA = 0.41 # Von Karman constant

    # --- Wall Function Parameters ---
    E_WALL = 9.8 # Log-law constant for smooth walls
    Y_P = 0.04 # Distance from wall for applying wall functions (y_p)
    # Target values for deriving wall function BCs (can be based on desired Re_tau)
    RE_TAU_TARGET = 350 # Target friction Reynolds number
    U_TAU_TARGET = RE_TAU_TARGET * NU / CHANNEL_HALF_HEIGHT # Target friction velocity
    YP_PLUS_TARGET = Y_P * U_TAU_TARGET / NU # Target y+ at y_p
    # Target values at y_p based on log-law and turbulence equilibrium
    U_TARGET_WF = (U_TAU_TARGET / KAPPA) * np.log(max(E_WALL * YP_PLUS_TARGET, EPS_SMALL))
    K_TARGET_WF = U_TAU_TARGET**2 / np.sqrt(CMU)
    EPS_TARGET_WF = U_TAU_TARGET**3 / max(KAPPA * Y_P, EPS_SMALL)

    # --- Inlet Turbulence Parameters ---
    TURBULENCE_INTENSITY = 0.05 # Typical value for channel flow inlet
    MIXING_LENGTH_SCALE = 0.07 * CHANNEL_HALF_HEIGHT # Estimate based on boundary layer thickness
    # Inlet k and epsilon based on intensity and length scale
    K_INLET = 1.5 * (U_INLET * TURBULENCE_INTENSITY)**2
    EPS_INLET = (CMU**0.75) * (K_INLET**1.5) / MIXING_LENGTH_SCALE
    # Transformed values (log) for network prediction/BCs
    K_INLET_TRANSFORMED = np.log(max(K_INLET, EPS_SMALL))
    EPS_INLET_TRANSFORMED = np.log(max(EPS_INLET, EPS_SMALL))
    K_TARGET_WF_TRANSFORMED = np.log(max(K_TARGET_WF, EPS_SMALL))
    EPS_TARGET_WF_TRANSFORMED = np.log(max(EPS_TARGET_WF, EPS_SMALL))

    # --- Domain Geometry ---
    GEOM = dde.geometry.Rectangle(xmin=[0, -CHANNEL_HALF_HEIGHT], xmax=[L, CHANNEL_HALF_HEIGHT])

    # --- Network Architecture ---
    NUM_LAYERS = 8
    NUM_NEURONS = 64
    ACTIVATION = "tanh"
    INITIALIZER = "Glorot normal"
    NETWORK_INPUTS = 2 # x, y
    NETWORK_OUTPUTS = 5 # u, v, p', log(k), log(eps)

    # --- Training Parameters ---
    NUM_DOMAIN_POINTS = 20000
    NUM_BOUNDARY_POINTS = 4000 # For physical boundaries (inlet, outlet, walls)
    NUM_TEST_POINTS = 5000 # For evaluating PDE residuals during training/testing
    NUM_WF_POINTS_PER_WALL = 200 # Anchor points for wall function BCs (per wall)
    LEARNING_RATE_ADAM = 1e-3
    ADAM_ITERATIONS = 50000 # Number of iterations for Adam optimizer
    LBFGS_ITERATIONS = 20000 # Max iterations for L-BFGS optimizer
    # Loss weights: [PDE_cont, PDE_mom_x, PDE_mom_y, PDE_k, PDE_eps, BC_u_in, BC_v_in, BC_k_in, BC_eps_in, BC_p_out, BC_u_wall, BC_v_wall, BC_u_wf, BC_k_wf, BC_eps_wf]
    PDE_WEIGHTS = [1, 1, 1, 1, 1] # Weights for the 5 PDE residuals
    BC_WEIGHTS = [10, 10, 10, 10, 10, 10, 10, 20, 20, 20] # Weights for the 10 BCs (adjust as needed)
    LOSS_WEIGHTS = PDE_WEIGHTS + BC_WEIGHTS
    SAVE_INTERVAL = 1000 # Checkpoint saving interval (steps)
    DISPLAY_EVERY = 1000 # Loss display interval (steps)


# Instantiate config objects
cfg = Config()
plotter_cfg = PlotterConfig()


# ==============================================
# ===== Custom Checkpoint Callback Class ========
# ==============================================
class CustomModelCheckpoint(dde.callbacks.Callback):
    """Custom checkpoint callback that saves based on global step."""
    def __init__(self, filepath_base, period, verbose=1):
        super().__init__()
        self.filepath_base = filepath_base # e.g., /path/to/model_checkpoints/rans_channel_wf-
        self.period = period
        self.verbose = verbose
        self._saved_steps = set() # Tracks steps saved in the current trainer.train() call

    def on_epoch_end(self):
        """Check step at the end of each epoch."""
        self._save_checkpoint()

    # on_batch_end can also be used for finer control if needed, but on_epoch_end is common
    # def on_batch_end(self):
    #     self._save_checkpoint()

    def _save_checkpoint(self):
        """Internal method to check step and save."""
        if not hasattr(self, 'model') or not self.model or not hasattr(self.model, 'train_state') or not self.model.train_state:
            logging.debug("Model or train_state not ready for checkpoint.")
            return # Model not ready

        if dde.backend.backend_name != "pytorch":
             logging.warning("CustomModelCheckpoint requires PyTorch backend for model.save behavior.")
             return

        # Access the global step count directly from train_state
        step = self.model.train_state.step

        # Check if the current step is a multiple of the saving period,
        # is positive (avoid saving at step 0 unnecessarily),
        # and hasn't already been saved during this train() call.
        if step > 0 and step % self.period == 0 and step not in self._saved_steps:
            filepath = f"{self.filepath_base}{step}.pt" # Construct filename using global step
            if self.verbose > 0:
                logging.info(f"Step {step}: saving model to {filepath} ...")
            try:
                # Use DeepXDE's save method which handles backend specifics
                # Important: Pass save_optimizer_state=True if you need to restore the optimizer too (crucial for resuming)
                # Pass save_best_only=False as we are saving periodically based on steps
                self.model.save(filepath, verbose=0) # verbose=0 prevents double logging
                self._saved_steps.add(step) # Mark this step as saved for this run
            except Exception as e:
                logging.error(f"Error saving checkpoint at step {step} to {filepath}: {e}", exc_info=True)

# ==============================================
# ===== END: Custom Checkpoint Callback ========
# ==============================================


# ==========================
# ===== Utility Functions =====
# ==========================
def setup_logging(log_file):
    """Configures logging to file and console."""
    log_dir = os.path.dirname(log_file)
    ensure_dir(log_dir)
    root_logger = logging.getLogger()
    # Clear existing handlers to avoid duplicate messages if run multiple times in notebook
    if root_logger.hasHandlers():
        root_logger.handlers.clear()
    log_formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s", datefmt='%Y-%m-%d %H:%M:%S')
    root_logger.setLevel(logging.INFO) # Set root logger level

    # File handler
    try:
        file_handler = logging.FileHandler(log_file, mode='a') # Append mode
        file_handler.setFormatter(log_formatter)
        root_logger.addHandler(file_handler)
    except Exception as e:
        print(f"Error setting up file logger at {log_file}: {e}")


    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(log_formatter)
    root_logger.addHandler(console_handler)
    logging.info("Logging configured.")

def ensure_dir(directory):
    """Creates a directory if it doesn't exist."""
    if not os.path.exists(directory):
        try:
            os.makedirs(directory)
            logging.info(f"Created directory: {directory}")
        except OSError as e:
            logging.error(f"Failed to create directory {directory}: {e}")


def mount_drive(mount_point):
    """Mounts Google Drive if running in Colab."""
    if 'google.colab' in sys.modules:
        if not os.path.exists(os.path.join(mount_point, 'MyDrive')):
            try:
                from google.colab import drive
                logging.info(f"Mounting Google Drive at {mount_point}...")
                drive.mount(mount_point, force_remount=True)
                logging.info("Google Drive mounted successfully.")
                # Verify base folder access after mount
                # IMPORTANT: Use os.path.join correctly. Assumes GDRIVE_BASE_FOLDER starts relative to MyDrive
                gdrive_output_path = os.path.join(mount_point, 'MyDrive', cfg.GDRIVE_BASE_FOLDER.lstrip('/')) # Remove leading / if present
                cfg.OUTPUT_DIR = gdrive_output_path # IMPORTANT: Update config path
                cfg.MODEL_DIR = os.path.join(cfg.OUTPUT_DIR, "model_checkpoints")
                cfg.LOG_DIR = os.path.join(cfg.OUTPUT_DIR, "logs")
                cfg.PLOT_DIR = os.path.join(cfg.OUTPUT_DIR, "plots")
                cfg.DATA_DIR = os.path.join(cfg.OUTPUT_DIR, "data")
                cfg.LOG_FILE = os.path.join(cfg.LOG_DIR, "training_log.log")
                cfg.REFERENCE_DATA_FILE = os.path.join(cfg.DATA_DIR, "reference_output_data.csv")
                logging.info(f"Output paths updated to Google Drive: {cfg.OUTPUT_DIR}")
                ensure_dir(cfg.OUTPUT_DIR) # Create base dir on GDrive if it doesn't exist
                if os.path.exists(cfg.OUTPUT_DIR):
                    logging.info(f"Base folder exists: {cfg.OUTPUT_DIR}")
                else:
                    logging.warning(f"Configured base folder NOT found after mount attempt: {cfg.OUTPUT_DIR}")
            except Exception as e:
                logging.error(f"Error mounting Google Drive or accessing path: {e}")
                # Fallback to local directory if mount fails
                logging.warning("Falling back to local directory structure.")
                cfg.OUTPUT_DIR = cfg.GDRIVE_BASE_FOLDER # Use base folder name locally
                cfg.MODEL_DIR = os.path.join(cfg.OUTPUT_DIR, "model_checkpoints")
                cfg.LOG_DIR = os.path.join(cfg.OUTPUT_DIR, "logs")
                cfg.PLOT_DIR = os.path.join(cfg.OUTPUT_DIR, "plots")
                cfg.DATA_DIR = os.path.join(cfg.OUTPUT_DIR, "data")
                cfg.LOG_FILE = os.path.join(cfg.LOG_DIR, "training_log.log")
                cfg.REFERENCE_DATA_FILE = os.path.join(cfg.DATA_DIR, "reference_output_data.csv")
        else:
            logging.info("Google Drive already mounted.")
            # Still update paths if already mounted
            gdrive_output_path = os.path.join(mount_point, 'MyDrive', cfg.GDRIVE_BASE_FOLDER.lstrip('/'))
            cfg.OUTPUT_DIR = gdrive_output_path # IMPORTANT: Update config path
            cfg.MODEL_DIR = os.path.join(cfg.OUTPUT_DIR, "model_checkpoints")
            cfg.LOG_DIR = os.path.join(cfg.OUTPUT_DIR, "logs")
            cfg.PLOT_DIR = os.path.join(cfg.OUTPUT_DIR, "plots")
            cfg.DATA_DIR = os.path.join(cfg.OUTPUT_DIR, "data")
            cfg.LOG_FILE = os.path.join(cfg.LOG_DIR, "training_log.log")
            cfg.REFERENCE_DATA_FILE = os.path.join(cfg.DATA_DIR, "reference_output_data.csv")
            logging.info(f"Output paths point to Google Drive: {cfg.OUTPUT_DIR}")
    else:
        logging.info("Not running in Google Colab. Using local directory structure.")
        # Ensure local paths are based on the script location or CWD
        cfg.OUTPUT_DIR = cfg.GDRIVE_BASE_FOLDER # Use base folder name locally
        cfg.MODEL_DIR = os.path.join(cfg.OUTPUT_DIR, "model_checkpoints")
        cfg.LOG_DIR = os.path.join(cfg.OUTPUT_DIR, "logs")
        cfg.PLOT_DIR = os.path.join(cfg.OUTPUT_DIR, "plots")
        cfg.DATA_DIR = os.path.join(cfg.OUTPUT_DIR, "data")
        cfg.LOG_FILE = os.path.join(cfg.LOG_DIR, "training_log.log")
        cfg.REFERENCE_DATA_FILE = os.path.join(cfg.DATA_DIR, "reference_output_data.csv")


def setup_output_directories(config):
    """Creates all necessary output directories."""
    logging.info("Setting up output directories...")
    ensure_dir(config.OUTPUT_DIR)
    ensure_dir(config.MODEL_DIR)
    ensure_dir(config.LOG_DIR)
    ensure_dir(config.PLOT_DIR)
    ensure_dir(config.DATA_DIR)
    logging.info("Output directories verified/created.")

def log_configuration(config, plotter_config):
    """Logs the simulation and plotter configuration."""
    logging.info("=" * 50)
    logging.info("Simulation Configuration:")
    logging.info(f"  Output Directory: {config.OUTPUT_DIR}")
    logging.info(f"  Re_H: {config.RE_H:.0f}")
    logging.info(f"  Wall Function y_p: {config.Y_P} (Target y+: {config.YP_PLUS_TARGET:.2f})")
    logging.info(f"  Network: {config.NUM_LAYERS} layers, {config.NUM_NEURONS} neurons")
    logging.info(f"  Inlet k (log): {config.K_INLET_TRANSFORMED:.4f}, Inlet eps (log): {config.EPS_INLET_TRANSFORMED:.4f}")
    logging.info(f"  Target WF U: {config.U_TARGET_WF:.4f}, k (log): {config.K_TARGET_WF_TRANSFORMED:.4f}, eps (log): {config.EPS_TARGET_WF_TRANSFORMED:.4f}")
    logging.info(f"  Adam Iterations: {config.ADAM_ITERATIONS}, LR: {config.LEARNING_RATE_ADAM}")
    logging.info(f"  L-BFGS Iterations: {config.LBFGS_ITERATIONS}")
    logging.info(f"  Checkpoint Interval: {config.SAVE_INTERVAL}")
    logging.info(f"  Reference Data File (CSV): {config.REFERENCE_DATA_FILE}")
    logging.info("Plotter Configuration:")
    logging.info(f"  Prediction Grid Nx: {plotter_config.NX_PRED}, Ny: {plotter_config.NY_PRED}")
    logging.info("=" * 50)
# --- End Utility Functions ---


# ===============================
# ===== PDE System Definition =====
# ===============================
def pde(x, y, config): # Pass config explicitly
    """Defines the RANS k-epsilon PDE system."""
    # Ensure backend is PyTorch as autograd syntax is used
    if dde.backend.backend_name != "pytorch":
        raise RuntimeError("PDE function relies on PyTorch autograd. Backend mismatch.")

    nu = config.NU; Cmu = config.CMU; Ceps1 = config.CEPS1; Ceps2 = config.CEPS2
    sigma_k = config.SIGMA_K; sigma_eps = config.SIGMA_EPS; eps_small = config.EPS_SMALL

    # Network outputs: u, v, p', log(k), log(eps)
    u, v, p_prime, k_raw, eps_raw = y[:, 0:1], y[:, 1:2], y[:, 2:3], y[:, 3:4], y[:, 4:5]

    # --- Apply inverse transformation and enforce positivity ---
    # Add eps_small *after* exp to ensure positivity even if raw output is very small
    k = torch.exp(k_raw) + eps_small
    eps = torch.exp(eps_raw) + eps_small

    # --- Calculate Gradients using PyTorch Autograd via DeepXDE wrappers ---
    try:
        # Gradients of primitive variables (u, v, p') directly from network output 'y'
        u_x = dde.grad.jacobian(y, x, i=0, j=0); u_y = dde.grad.jacobian(y, x, i=0, j=1)
        v_x = dde.grad.jacobian(y, x, i=1, j=0); v_y = dde.grad.jacobian(y, x, i=1, j=1)
        p_prime_x = dde.grad.jacobian(y, x, i=2, j=0); p_prime_y = dde.grad.jacobian(y, x, i=2, j=1)

        u_xx = dde.grad.hessian(y, x, component=0, i=0, j=0); u_yy = dde.grad.hessian(y, x, component=0, i=1, j=1)
        v_xx = dde.grad.hessian(y, x, component=1, i=0, j=0); v_yy = dde.grad.hessian(y, x, component=1, i=1, j=1)
        # Mixed derivatives (needed for momentum diffusion terms)
        u_xy = dde.grad.hessian(y, x, component=0, i=0, j=1)
        v_xy = dde.grad.hessian(y, x, component=1, i=0, j=1) # Order matters: d/dx(d/dy(...))

        # --- Gradients of transformed k, eps using PyTorch autograd ---
        # We need gradients of k and eps themselves, not log(k), log(eps)
        # Ensure x requires grad if DeepXDE doesn't handle it automatically
        if isinstance(x, torch.Tensor) and not x.requires_grad:
            # This might be needed if x comes from data loading without gradients enabled
            x.requires_grad_(True)

        # Calculate grads for k
        grad_k = torch.autograd.grad(k, x, grad_outputs=torch.ones_like(k), create_graph=True)[0]
        k_x, k_y = grad_k[:, 0:1], grad_k[:, 1:2]
        # Calculate grads for eps
        grad_eps = torch.autograd.grad(eps, x, grad_outputs=torch.ones_like(eps), create_graph=True)[0]
        eps_x, eps_y = grad_eps[:, 0:1], grad_eps[:, 1:2]

        # --- Hessians of transformed k, eps using PyTorch autograd ---
        # Calculate hessians for k (laplacian components)
        # Need create_graph=True on first grad to compute second grad
        grad_kx = torch.autograd.grad(k_x, x, grad_outputs=torch.ones_like(k_x), create_graph=True)[0]
        k_xx = grad_kx[:, 0:1]
        grad_ky = torch.autograd.grad(k_y, x, grad_outputs=torch.ones_like(k_y), create_graph=True)[0]
        k_yy = grad_ky[:, 1:2]

        # Calculate hessians for eps (laplacian components)
        grad_epsx = torch.autograd.grad(eps_x, x, grad_outputs=torch.ones_like(eps_x), create_graph=True)[0]
        eps_xx = grad_epsx[:, 0:1]
        grad_epsy = torch.autograd.grad(eps_y, x, grad_outputs=torch.ones_like(eps_y), create_graph=True)[0]
        eps_yy = grad_epsy[:, 1:2]

    except RuntimeError as grad_e:
        # Catch common autograd errors like trying to backward twice without retain_graph
        logging.error(f"PyTorch Autograd RuntimeError calculating gradients in PDE: {grad_e}. Ensure create_graph=True is used correctly for higher-order derivatives.", exc_info=True)
        # Return tensors of zeros with the correct shape and device matching input 'y'
        zero_tensor = torch.zeros_like(y[:, 0:1])
        return [zero_tensor] * 5 # Match the number of expected PDE residual outputs
    except Exception as grad_e:
        logging.error(f"General error calculating gradients in PDE function: {grad_e}", exc_info=True)
        zero_tensor = torch.zeros_like(y[:, 0:1])
        return [zero_tensor] * 5

    # --- Turbulent Viscosity ---
    # Use the transformed k, eps which are guaranteed positive
    k_safe = k # k already has eps_small added after exp
    eps_safe = eps # eps already has eps_small added after exp
    # Add eps_small to denominator for extra safety, although eps_safe should be positive
    nu_t = Cmu * torch.square(k_safe) / (eps_safe + eps_small) # Eq: nu_t = Cmu * k^2 / eps
    nu_eff = nu + nu_t # Effective viscosity

    # --- Gradients of nu_eff (Needed for diffusion terms in momentum eqns) ---
    # Using chain rule: d(nu_eff)/dx = d(nu_t)/dk * dk/dx + d(nu_t)/deps * deps/dx
    # Ensure denominators are safe
    dnut_dk = 2.0 * Cmu * k_safe / (eps_safe + eps_small)
    dnut_deps = -Cmu * torch.square(k_safe) / torch.square(eps_safe + eps_small)

    nu_eff_x = dnut_dk * k_x + dnut_deps * eps_x
    nu_eff_y = dnut_dk * k_y + dnut_deps * eps_y

    # --- RANS Equation Residuals ---

    # 1. Continuity Equation: d(u)/dx + d(v)/dy = 0
    eq_continuity = u_x + v_y

    # 2. X-Momentum Equation:
    # d(u)/dt + u*du/dx + v*du/dy = -dp'/dx + d/dx[nu_eff * (2*du/dx)] + d/dy[nu_eff * (du/dy + dv/dx)]
    # Steady state: u*du/dx + v*du/dy + dp'/dx - [d/dx(...) + d/dy(...)] = 0
    adv_u = u * u_x + v * u_y # Advection
    # Diffusion terms (expanded using product rule)
    # d/dx[nu_eff * (2*du/dx)] = d(nu_eff)/dx * (2*du/dx) + nu_eff * (2*d^2u/dx^2)
    diff_u_term1 = nu_eff_x * (2 * u_x) + nu_eff * (2 * u_xx)
    # d/dy[nu_eff * (du/dy + dv/dx)] = d(nu_eff)/dy * (du/dy + dv/dx) + nu_eff * (d^2u/dy^2 + d^2v/dxdy)
    # Assuming v_xy = d/dx(d/dy(v)) = d^2v/(dx dy)
    diff_u_term2 = nu_eff_y * (u_y + v_x) + nu_eff * (u_yy + v_xy)
    eq_mom_x = adv_u + p_prime_x - (diff_u_term1 + diff_u_term2)

    # 3. Y-Momentum Equation:
    # d(v)/dt + u*dv/dx + v*dv/dy = -dp'/dy + d/dx[nu_eff * (dv/dx + du/dy)] + d/dy[nu_eff * (2*dv/dy)]
    # Steady state: u*dv/dx + v*dv/dy + dp'/dy - [d/dx(...) + d/dy(...)] = 0
    adv_v = u * v_x + v * v_y # Advection
    # Diffusion terms (expanded)
    # d/dx[nu_eff * (dv/dx + du/dy)] = d(nu_eff)/dx * (dv/dx + du/dy) + nu_eff * (d^2v/dx^2 + d^2u/dxdy)
    # Assuming u_xy = d/dx(d/dy(u)) = d^2u/(dx dy)
    diff_v_term1 = nu_eff_x * (v_x + u_y) + nu_eff * (v_xx + u_xy)
    # d/dy[nu_eff * (2*dv/dy)] = d(nu_eff)/dy * (2*dv/dy) + nu_eff * (2*d^2v/dy^2)
    diff_v_term2 = nu_eff_y * (2 * v_y) + nu_eff * (2 * v_yy)
    eq_mom_y = adv_v + p_prime_y - (diff_v_term1 + diff_v_term2)

    # --- Turbulence Model Equations ---

    # Production term P_k = nu_t * S^2, where S is the modulus of the mean strain rate tensor
    # S^2 = 2*((du/dx)^2 + (dv/dy)^2) + (du/dy + dv/dx)^2  (for 2D)
    S_squared = 2 * (torch.square(u_x) + torch.square(v_y)) + torch.square(u_y + v_x)
    # Ensure P_k is non-negative (though theoretically it should be if nu_t >= 0 and S^2 >= 0)
    P_k = torch.relu(nu_t * S_squared) # Using ReLU for safety, or just nu_t * S_squared if confident

    # 4. k-Equation:
    # d(k)/dt + u*dk/dx + v*dk/dy = d/dx[(nu + nu_t/sigma_k)*dk/dx] + d/dy[(nu + nu_t/sigma_k)*dk/dy] + P_k - eps
    # Steady state: u*dk/dx + v*dk/dy - [Diffusion] - P_k + eps = 0
    adv_k = u * k_x + v * k_y # Advection
    # Diffusion term: div[ (nu + nu_t/sigma_k) * grad(k) ]
    diffusivity_k = nu + nu_t / sigma_k
    # Gradient of diffusivity: d(diff_k)/dx = (1/sigma_k) * d(nu_t)/dx
    # Note: nu_eff_x = d(nu_t)/dx, nu_eff_y = d(nu_t)/dy
    d_diffk_dx = (1 / sigma_k) * nu_eff_x
    d_diffk_dy = (1 / sigma_k) * nu_eff_y
    laplacian_k = k_xx + k_yy
    # Expand divergence using product rule: div(D*grad(k)) = grad(D).grad(k) + D*laplacian(k)
    diffusion_k = d_diffk_dx * k_x + d_diffk_dy * k_y + diffusivity_k * laplacian_k
    # Use eps_safe which is guaranteed positive
    eq_k = adv_k - diffusion_k - P_k + eps_safe

    # 5. ε-Equation:
    # d(eps)/dt + u*deps/dx + v*deps/dy = d/dx[(nu + nu_t/sigma_eps)*deps/dx] + d/dy[(nu + nu_t/sigma_eps)*deps/dy] + Ceps1*(eps/k)*P_k - Ceps2*(eps^2/k)
    # Steady state: u*deps/dx + v*deps/dy - [Diffusion] - Source + Sink = 0
    adv_eps = u * eps_x + v * eps_y # Advection
    # Diffusion term: div[ (nu + nu_t/sigma_eps) * grad(eps) ]
    diffusivity_eps = nu + nu_t / sigma_eps
    d_diffeps_dx = (1 / sigma_eps) * nu_eff_x # Gradients of nu_t
    d_diffeps_dy = (1 / sigma_eps) * nu_eff_y
    laplacian_eps = eps_xx + eps_yy
    diffusion_eps = d_diffeps_dx * eps_x + d_diffeps_dy * eps_y + diffusivity_eps * laplacian_eps
    # Source/Sink terms (use safe, positive k and eps)
    # Add eps_small to denominators for robustness, even if k_safe/eps_safe have it
    source_eps = Ceps1 * (eps_safe / (k_safe + eps_small)) * P_k
    sink_eps = Ceps2 * (torch.square(eps_safe) / (k_safe + eps_small))
    eq_eps = adv_eps - diffusion_eps - source_eps + sink_eps

    # Return the residuals of the 5 equations
    return [eq_continuity, eq_mom_x, eq_mom_y, eq_k, eq_eps]
# --- End PDE function ---


# =============================
# ===== Boundary Conditions =====
# =============================
def get_boundary_conditions(config):
    """Defines all boundary conditions for the channel flow problem."""
    geom = config.GEOM
    h = config.CHANNEL_HALF_HEIGHT
    L = config.L
    y_p = config.Y_P # Wall function distance
    n_wf_points = config.NUM_WF_POINTS_PER_WALL

    # --- Boundary Definition Functions ---
    def boundary_inlet(x, on_boundary):
        return on_boundary and np.isclose(x[0], 0)

    def boundary_outlet(x, on_boundary):
        return on_boundary and np.isclose(x[0], L)

    def boundary_bottom_wall_physical(x, on_boundary):
        # Physical wall at y = -h
        return on_boundary and np.isclose(x[1], -h)

    def boundary_top_wall_physical(x, on_boundary):
        # Physical wall at y = +h
        return on_boundary and np.isclose(x[1], h)

    def boundary_walls_physical(x, on_boundary):
        # Combined physical walls
        return boundary_bottom_wall_physical(x, on_boundary) or boundary_top_wall_physical(x, on_boundary)

    # --- Inlet BCs (Dirichlet) ---
    # u = U_INLET, v = 0, k = k_inlet_transformed, eps = eps_inlet_transformed
    bc_u_inlet = dde.DirichletBC(geom, lambda x: config.U_INLET, boundary_inlet, component=0) # component=0 -> u
    bc_v_inlet = dde.DirichletBC(geom, lambda x: 0, boundary_inlet, component=1) # component=1 -> v
    bc_k_inlet = dde.DirichletBC(geom, lambda x: config.K_INLET_TRANSFORMED, boundary_inlet, component=3) # component=3 -> log(k)
    bc_eps_inlet = dde.DirichletBC(geom, lambda x: config.EPS_INLET_TRANSFORMED, boundary_inlet, component=4) # component=4 -> log(eps)

    # --- Outlet BC (Dirichlet) ---
    # p' = 0 (gauge pressure relative to outlet)
    bc_p_outlet = dde.DirichletBC(geom, lambda x: 0, boundary_outlet, component=2) # component=2 -> p'

    # --- Physical Wall BCs (No-slip) ---
    # u = 0, v = 0
    bc_u_walls = dde.DirichletBC(geom, lambda x: 0, boundary_walls_physical, component=0) # u=0 on walls
    bc_v_walls = dde.DirichletBC(geom, lambda x: 0, boundary_walls_physical, component=1) # v=0 on walls

    # --- Wall Function BCs (PointSetBC near walls) ---
    # Define anchor points slightly away from the walls at y = +/- (h - y_p)
    # Avoid placing points exactly at inlet/outlet for stability
    x_wf_coords = np.linspace(0 + L * 0.01, L - L * 0.01, n_wf_points)[:, None] # Exclude exact corners
    points_bottom_wf = np.hstack((x_wf_coords, np.full_like(x_wf_coords, -h + y_p)))
    points_top_wf = np.hstack((x_wf_coords, np.full_like(x_wf_coords, h - y_p)))
    anchor_points_wf = np.vstack((points_bottom_wf, points_top_wf))
    logging.info(f"Generated {anchor_points_wf.shape[0]} anchor points for wall functions.")

    # Target values at these anchor points (using pre-calculated config values)
    U_target_vals = np.full((anchor_points_wf.shape[0], 1), config.U_TARGET_WF)
    k_target_vals = np.full((anchor_points_wf.shape[0], 1), config.K_TARGET_WF_TRANSFORMED)
    eps_target_vals = np.full((anchor_points_wf.shape[0], 1), config.EPS_TARGET_WF_TRANSFORMED)

    # Define PointSetBCs for u, k (log), eps (log) at the anchor points
    bc_u_wf = dde.PointSetBC(anchor_points_wf, U_target_vals, component=0) # Target u at WF points
    bc_k_wf = dde.PointSetBC(anchor_points_wf, k_target_vals, component=3) # Target log(k) at WF points
    bc_eps_wf = dde.PointSetBC(anchor_points_wf, eps_target_vals, component=4) # Target log(eps) at WF points

    # --- Collect all BCs ---
    all_bcs = [
        bc_u_inlet, bc_v_inlet, bc_k_inlet, bc_eps_inlet, # Inlet (4 BCs)
        bc_p_outlet, # Outlet (1 BC)
        bc_u_walls, bc_v_walls, # Physical Walls (2 BCs)
        bc_u_wf, bc_k_wf, bc_eps_wf # Wall Functions (3 BCs)
    ]
    # Note: The anchor_points_wf array is also needed by the Data object later.
    return all_bcs, anchor_points_wf
# --- End Boundary Conditions ---


# =======================
# ===== Trainer Class =====
# =======================
class Trainer:
    """Handles the setup, training, and checkpointing of the PINN model."""
    def __init__(self, config):
        self.config = config
        self.model = None
        self.losshistory = None
        self.train_state = None
        self.pde = pde # Assign the PDE function

    def build_model(self, bcs, anchor_points):
        """Builds the DeepXDE model including network and data."""
        logging.info("Building the PINN model...")
        if dde.backend.backend_name != "pytorch":
             raise RuntimeError("This code relies on the PyTorch backend.")

        # Define the neural network
        net = dde.maps.FNN(
            layer_sizes=[self.config.NETWORK_INPUTS] + [self.config.NUM_NEURONS] * self.config.NUM_LAYERS + [self.config.NETWORK_OUTPUTS],
            activation=self.config.ACTIVATION,
            kernel_initializer=self.config.INITIALIZER
        )

        # Wrap PDE to include config
        pde_with_config = lambda x, y: self.pde(x, y, config=self.config)

        # Define the PDE data object
        # Use 'anchor_points' argument if available in your DeepXDE version, otherwise rely on BC sampling
        try:
            data = dde.data.PDE(
                geometry=self.config.GEOM,
                pde=pde_with_config,
                bcs=bcs, # List of boundary conditions including PointSetBCs
                num_domain=self.config.NUM_DOMAIN_POINTS,
                num_boundary=self.config.NUM_BOUNDARY_POINTS, # Sample points on physical boundaries
                num_test=self.config.NUM_TEST_POINTS, # Points for testing PDE residual during training
                anchors=anchor_points # Explicitly provide wall function anchor points here
            )
            logging.info(f"Using {anchor_points.shape[0]} anchor points for wall functions (passed to anchors).")
        except TypeError: # Handle older DeepXDE versions that might not have 'anchors'
            logging.warning("DeepXDE version might not support 'anchors' argument in PDE. Relying on PointSetBC sampling.")
            data = dde.data.PDE(
                geometry=self.config.GEOM,
                pde=pde_with_config,
                bcs=bcs, # PointSetBCs are still included here
                num_domain=self.config.NUM_DOMAIN_POINTS,
                num_boundary=self.config.NUM_BOUNDARY_POINTS,
                num_test=self.config.NUM_TEST_POINTS
                # num_anchors is deprecated/removed in newer versions, do not use
            )


        self.model = dde.Model(data, net)
        logging.info("Model built successfully.")

    # ==============================================================
    # ========    UPDATED Trainer.train Method            ==========
    # ==============================================================
    def train(self):
        """Compiles and trains the model, handles checkpointing and optimizer switching."""
        if self.model is None:
            logging.error("Model not built. Call build_model first.")
            return None, None, None

        # --- Define Checkpoint Paths and Callback ---
        filepath_base = os.path.join(self.config.MODEL_DIR, f"{self.config.CHECKPOINT_FILENAME_BASE}-")
        logging.info(f"Checkpoint filename base: {filepath_base}")
        custom_checkpointer = CustomModelCheckpoint(
            filepath_base=filepath_base,
            period=self.config.SAVE_INTERVAL,
            verbose=1
        )

        # --- Check for Latest Checkpoint ---
        latest_checkpoint = None
        restored_step = 0
        if os.path.exists(self.config.MODEL_DIR):
            try:
                # Use a regex that matches the base name, a hyphen, digits, and the .pt extension
                filename_pattern = re.compile(rf"^{re.escape(os.path.basename(filepath_base))}(\d+)\.pt$")
                checkpoint_files = []
                logging.info(f"Searching for checkpoints in: {self.config.MODEL_DIR}")
                logging.info(f"Using pattern: {filename_pattern.pattern}")

                for f in os.listdir(self.config.MODEL_DIR):
                    full_path = os.path.join(self.config.MODEL_DIR, f)
                    if os.path.isfile(full_path): # Ensure it's a file
                        match = filename_pattern.match(f)
                        if match:
                            step_num = int(match.group(1))
                            checkpoint_files.append((step_num, full_path))
                            logging.debug(f"Found potential checkpoint: {f} (Step: {step_num})") # Debug log

                if checkpoint_files:
                    checkpoint_files.sort(key=lambda item: item[0], reverse=True) # Sort descending by step
                    restored_step, latest_checkpoint = checkpoint_files[0] # Get the one with the highest step
                    logging.info(f"Found latest valid checkpoint: {latest_checkpoint} at step {restored_step}")
                else:
                    logging.info("No valid checkpoints found matching the pattern.")
            except Exception as e:
                logging.error(f"Error finding/parsing checkpoint filenames: {e}", exc_info=True)
                latest_checkpoint = None; restored_step = 0

        restore_path = latest_checkpoint if (latest_checkpoint and os.path.isfile(latest_checkpoint)) else None

        # --- Determine Initial Optimizer Based on Restored Step ---
        initial_optimizer = "adam"
        initial_lr = self.config.LEARNING_RATE_ADAM
        if restore_path and restored_step >= self.config.ADAM_ITERATIONS:
            # If restoring from a step within or after the L-BFGS phase, compile L-BFGS first.
            initial_optimizer = "L-BFGS"
            initial_lr = None # L-BFGS doesn't use LR in compile signature
            logging.info(f"Restored step ({restored_step}) >= Adam iterations ({self.config.ADAM_ITERATIONS}). Will compile with L-BFGS initially.")
        else:
            # Start with Adam (from scratch or resuming within Adam phase).
            logging.info(f"Restored step ({restored_step}) < Adam iterations ({self.config.ADAM_ITERATIONS}) or no checkpoint. Will compile with Adam initially.")

        # --- Compile with the Determined Initial Optimizer ---
        logging.info(f"Compiling model with {initial_optimizer} optimizer initially.")
        if dde.backend.backend_name != "pytorch":
             raise RuntimeError("Cannot compile model, backend is not PyTorch.")

        compile_args = {"optimizer": initial_optimizer, "loss_weights": self.config.LOSS_WEIGHTS}
        if initial_lr is not None:
            compile_args["lr"] = initial_lr
        try:
            # Compile *before* restoring state
            self.model.compile(**compile_args)
            logging.info(f"Model compiled with {initial_optimizer}.")
        except Exception as e:
            logging.error(f"Failed to compile model with {initial_optimizer}: {e}", exc_info=True)
            return None, None, None # Stop if compilation fails


        # --- Explicit Restore (if applicable) ---
        if restore_path:
            try:
                logging.info(f"Explicitly restoring model state from: {restore_path}")
                # Restore should work because the compiled optimizer matches the saved state (or will be handled by restore)
                # Pass ignore_hyperparameters=True if you only want weights/biases and not optimizer state/step etc.
                # However, for resuming, we *do* want the optimizer state and step.
                self.model.restore(restore_path, verbose=1)

                # DeepXDE's restore *should* update train_state.step internally. Let's verify.
                current_step_after_restore = self.model.train_state.step if self.model.train_state else -1
                logging.info(f"Model state restored. Internal DDE step count *after* restore: {current_step_after_restore}")

                # Force set the step counter based on the filename step for consistency, IF different.
                if self.model.train_state:
                    if current_step_after_restore != restored_step:
                         logging.warning(f"Mismatch between expected restored step ({restored_step}) from filename and DDE internal step ({current_step_after_restore}) after restore! Forcing DDE step to match filename.")
                         self.model.train_state.step = restored_step
                         # Log the step count again after forcing it
                         current_step_after_manual_set = self.model.train_state.step
                         logging.info(f"Internal step count after manual setting: {current_step_after_manual_set}")
                    else:
                         logging.info("DDE internal step count matches restored step from filename.")
                else:
                    # This should not happen if restore was successful with a valid state
                    logging.error("Cannot verify/set step count: model.train_state is None after restore. Restore might have failed silently or checkpoint was incomplete.")
                    self.model = None # Indicate failure
                    return None, None, None # Stop if restore fails

            except Exception as e:
                # The specific KeyError: 'step' might occur if checkpoint is old or missing optimizer state.
                # Other errors (file corruption, backend mismatch during save/load) might still occur.
                logging.error(f"Failed during explicit model restore from {restore_path}: {e}", exc_info=True)
                logging.error("Common causes: Incomplete checkpoint (missing optimizer state?), version mismatch, file corruption.")
                self.model = None # Indicate failure
                return None, None, None # Stop if restore fails
        else:
            logging.info("No suitable checkpoint found. Starting training from scratch.")
            restored_step = 0 # Ensure this is 0 if starting fresh
            if self.model.train_state:
                self.model.train_state.step = 0 # Ensure internal step is 0 for new training
            else:
                 # This can happen if the initial compile failed
                 logging.error("model.train_state is None when starting from scratch. Compile likely failed earlier.")
                 return None, None, None


        # --- Adam Training Phase ---
        run_adam_phase = False
        adam_iters_to_run = 0
        # Use the step count that's been synchronized after potential restore
        current_step_synced = self.model.train_state.step if self.model.train_state else restored_step

        if initial_optimizer == "adam":
            # Only run Adam if we started with Adam and haven't finished its iterations
            if current_step_synced < self.config.ADAM_ITERATIONS:
                run_adam_phase = True
                adam_iters_to_run = self.config.ADAM_ITERATIONS - current_step_synced
                logging.info(f"Starting Adam training phase from step {current_step_synced} for {adam_iters_to_run} iterations...")
            else:
                logging.info(f"Adam phase already completed (current step {current_step_synced} >= {self.config.ADAM_ITERATIONS}).")
        else:
            # Adam is skipped if we restored directly into the L-BFGS phase
            logging.info("Skipping Adam phase as initial compilation was L-BFGS.")

        adam_start_time = time.time()
        if run_adam_phase and adam_iters_to_run > 0:
            try:
                # Ensure model is compiled with Adam before training Adam
                if self.model.opt_name != "adam": # Safety check
                     logging.warning("Model optimizer is not Adam before Adam training. Recompiling with Adam.")
                     self.model.compile("adam", lr=self.config.LEARNING_RATE_ADAM, loss_weights=self.config.LOSS_WEIGHTS)
                     # Restore step count if recompile reset it (might happen)
                     if self.model.train_state and self.model.train_state.step != current_step_synced:
                          logging.info(f"Restoring step count to {current_step_synced} after Adam recompile.")
                          self.model.train_state.step = current_step_synced

                self.losshistory, self.train_state = self.model.train(
                    iterations=adam_iters_to_run,
                    display_every=self.config.DISPLAY_EVERY,
                    callbacks=[custom_checkpointer] # Use the custom callback
                )
                adam_time = time.time() - adam_start_time
                # Update current step after Adam phase completes
                current_step_synced = self.model.train_state.step if self.model.train_state else -1
                # Check if losshistory is valid before accessing attributes
                if self.losshistory and hasattr(self.losshistory, 'loss_train') and self.losshistory.loss_train:
                    final_loss_adam = self.losshistory.loss_train[-1] if self.losshistory.loss_train else "N/A"
                    logging.info(f"Adam training ({adam_iters_to_run} iterations) finished in {adam_time:.2f}s. Final loss: {final_loss_adam}. Current step: {current_step_synced}")
                else:
                    logging.error(f"Adam training finished in {adam_time:.2f}s but loss history is empty/invalid. Current step: {current_step_synced}")
                    # Consider whether to proceed to L-BFGS or return failure

            except Exception as e:
                logging.error(f"Error occurred during Adam training: {e}", exc_info=True)
                return self.model, self.losshistory, self.train_state # Exit on error

        # --- L-BFGS Training Phase ---
        run_lbfgs_phase = False
        lbfgs_iters_to_run = 0
        needs_lbfgs_compile = False

        # Get the step count *after* any Adam training that might have occurred
        current_step_after_adam = self.model.train_state.step if self.model.train_state else current_step_synced

        # Calculate the nominal total steps required by the configuration
        total_target_steps = self.config.ADAM_ITERATIONS + self.config.LBFGS_ITERATIONS

        # ===> MODIFICATION START: Check if total steps already completed <===
        if current_step_after_adam >= total_target_steps:
            logging.info(f"Current step ({current_step_after_adam}) meets or exceeds total configured steps ({total_target_steps}). Skipping further training.")
            run_lbfgs_phase = False # Explicitly prevent L-BFGS phase
        # ===> MODIFICATION END <===
        elif self.config.LBFGS_ITERATIONS > 0: # Original check: L-BFGS is configured to run
             # Proceed with the logic to potentially run L-BFGS
            if current_step_after_adam < self.config.ADAM_ITERATIONS:
                 # This case should ideally not happen if Adam ran correctly, but acts as a safeguard
                 logging.warning(f"Attempting L-BFGS, but current step {current_step_after_adam} is less than Adam target {self.config.ADAM_ITERATIONS}. Check logic.")
                 # Decide if L-BFGS should run anyway or be skipped. Here we skip it.
                 run_lbfgs_phase = False
            else:
                 # We are at or past the Adam iterations, L-BFGS is configured, and haven't hit the total target yet.
                 run_lbfgs_phase = True
                 # Simplification: Run the full configured L-BFGS iterations when entering this phase.
                 lbfgs_iters_to_run = self.config.LBFGS_ITERATIONS
                 logging.info(f"Proceeding to L-BFGS phase (current step {current_step_after_adam}). Target iterations for this phase: {lbfgs_iters_to_run}")

                 # Check if we need to compile L-BFGS (i.e., if the *current* optimizer is Adam)
                 if self.model.opt_name == "adam": # Check the actual current optimizer
                     needs_lbfgs_compile = True
                     logging.info("Switching optimizer: Will compile for L-BFGS.")
                 elif self.model.opt_name == "L-BFGS":
                      logging.info("Optimizer is already L-BFGS (likely resumed). No recompilation needed.")
                 else:
                      # Should ideally be either 'adam' or 'L-BFGS' at this point
                      logging.warning(f"Unexpected optimizer '{self.model.opt_name}' before L-BFGS phase. Forcing L-BFGS compile.")
                      needs_lbfgs_compile = True
        else:
            # L-BFGS iterations set to 0 in config
            logging.info("L-BFGS iterations set to 0 in config, skipping L-BFGS training phase.")
            run_lbfgs_phase = False


        # --- Execute L-BFGS if needed ---
        if run_lbfgs_phase and lbfgs_iters_to_run > 0:
             if self.model is not None and self.model.net is not None:
                 lbfgs_start_time = time.time()
                 try:
                     # Re-compile for L-BFGS only if necessary
                     if needs_lbfgs_compile:
                         self.model.compile("L-BFGS", loss_weights=self.config.LOSS_WEIGHTS)
                         logging.info(f"Model re-compiled with L-BFGS.")
                         # Restore step count if recompile reset it
                         if self.model.train_state and self.model.train_state.step != current_step_after_adam:
                              logging.info(f"Restoring step count to {current_step_after_adam} after L-BFGS recompile.")
                              self.model.train_state.step = current_step_after_adam

                     # Use the same CUSTOM checkpointer instance for L-BFGS saves
                     # The 'iterations' for L-BFGS in DeepXDE usually means max_iter for the optimizer's internal loop.
                     # The global step advancement depends on the callbacks and how DeepXDE handles L-BFGS steps.
                     self.losshistory, self.train_state = self.model.train(
                         iterations=lbfgs_iters_to_run, # Let L-BFGS run its course
                         display_every=self.config.DISPLAY_EVERY, # Use display_every for LBFGS too
                         callbacks=[custom_checkpointer] # Pass the custom callback instance
                     )
                     lbfgs_time = time.time() - lbfgs_start_time
                     # Update current step after L-BFGS phase completes
                     current_step_synced = self.model.train_state.step if self.model.train_state else -1
                     if self.losshistory and hasattr(self.losshistory, 'loss_train') and self.losshistory.loss_train:
                         final_loss_lbfgs = self.losshistory.loss_train[-1] if self.losshistory.loss_train else "N/A"
                         logging.info(f"L-BFGS training (max {lbfgs_iters_to_run} internal iterations) finished in {lbfgs_time:.2f}s. Final loss: {final_loss_lbfgs}. Final global step recorded: {current_step_synced}")
                     else:
                         logging.error(f"L-BFGS training finished in {lbfgs_time:.2f}s but loss history is invalid. Final global step: {current_step_synced}")
                         # Decide whether to continue based on Adam success

                 except Exception as e:
                     logging.error(f"Error during L-BFGS compilation or training: {e}", exc_info=True)
                     # Decide whether to continue based on Adam success
             else:
                 logging.warning("Skipping L-BFGS training because the model state is not valid (e.g., restore failed or Adam failed).")

        # --- Post-training Validation (Using Corrected Gradient Calculation) ---
        if self.model and self.model.net:
            try:
                self._post_training_checks()
            except Exception as e:
                 logging.error(f"Error during post-training checks: {e}", exc_info=True)

        logging.info("Training sequence complete.")
        # Make sure to return the potentially updated model, losshistory, train_state
        return self.model, self.losshistory, self.train_state
    # ==============================================================
    #========== END OF UPDATED Trainer.train Method      ===========
    # ==============================================================


    # --- Optional Post Training Checks ---
    def _post_training_checks(self):
        """Perform physics validation checks compatible with PyTorch backend."""
        logging.info("Performing PyTorch-based post-training checks...")
        if dde.backend.backend_name != "pytorch":
             logging.warning("Post-training checks skipped, requires PyTorch backend.")
             return
        if self.model is None or self.model.net is None:
             logging.error("Model or network not available for post-training checks.")
             return
        try:
            # Run the turbulence production check
            self._check_turbulence_production()
            # Add other checks here if needed
        except Exception as e:
            # Catch errors specifically from the check functions
            logging.error(f"Error during post-training check execution: {e}", exc_info=True)


    def _check_turbulence_production(self):
        """Ensure turbulence production term P_k >= 0 using dde.grad.jacobian."""
        logging.info("Checking turbulence production term Pk...")
        if not hasattr(self.model.data, 'test_x') or self.model.data.test_x is None or len(self.model.data.test_x) == 0:
            logging.warning("Test points not available, using training points for P_k check.")
            if not hasattr(self.model.data, 'train_x') or self.model.data.train_x is None or len(self.model.data.train_x) == 0:
                 logging.error("No points available (train or test) for P_k check.")
                 return
            X = self.model.data.train_x # Fallback to training points
        else:
             X = self.model.data.test_x # Prefer test points

        if self.model.net is None or not list(self.model.net.parameters()):
             logging.error("Model network not available or has no parameters for P_k check.")
             return

        # Get device from model parameters
        try:
            # Ensure model parameters exist and get device
            device = next(iter(self.model.net.parameters())).device
        except StopIteration:
             logging.error("Model network has no parameters.")
             return
        except AttributeError:
             logging.error("Model or network object structure unexpected.")
             return

        # Ensure X is a tensor on the correct device with requires_grad=True
        try:
            x_tensor = torch.tensor(X, dtype=torch.float32, device=device, requires_grad=True)
        except Exception as e:
            logging.error(f"Failed to convert points to tensor for P_k check: {e}")
            return

        # Forward pass with gradient tracking enabled by requires_grad=True on x_tensor
        try:
            y_tensor = self.model.net(x_tensor)
        except Exception as e:
            logging.error(f"Failed during network forward pass for P_k check: {e}")
            return

        if y_tensor is None or y_tensor.shape[1] != self.config.NETWORK_OUTPUTS:
             logging.error(f"Network output has unexpected shape {y_tensor.shape if y_tensor is not None else 'None'} for P_k check.")
             return

        # Extract variables (these are tensors now)
        u = y_tensor[:, 0:1]; v = y_tensor[:, 1:2]
        k_raw = y_tensor[:, 3:4]; eps_raw = y_tensor[:, 4:5]

        # Use the same transformation as in PDE
        k_check = torch.exp(k_raw) + self.config.EPS_SMALL
        eps_check = torch.exp(eps_raw) + self.config.EPS_SMALL

        # Compute gradients using DeepXDE's jacobian function
        try:
            # Calculate individual gradient components using dde.grad.jacobian
            u_x = dde.grad.jacobian(y_tensor, x_tensor, i=0, j=0) # d(output 0)/d(input 0) = du/dx
            u_y = dde.grad.jacobian(y_tensor, x_tensor, i=0, j=1) # d(output 0)/d(input 1) = du/dy
            v_x = dde.grad.jacobian(y_tensor, x_tensor, i=1, j=0) # d(output 1)/d(input 0) = dv/dx
            v_y = dde.grad.jacobian(y_tensor, x_tensor, i=1, j=1) # d(output 1)/d(input 1) = dv/dy

        except Exception as grad_e:
             # Catch potential errors during gradient computation
             logging.error(f"Error computing velocity gradients via dde.grad.jacobian for P_k check: {grad_e}", exc_info=True)
             return

        # Compute nu_t using safe k, eps from transformation
        # Use torch.maximum for safe division, ensure tensors are on same device
        eps_safe_check = torch.maximum(eps_check, torch.tensor(self.config.EPS_SMALL**2, device=device))
        nu_t_check = self.config.CMU * torch.square(k_check) / eps_safe_check

        # Strain rate tensor squared (S^2)
        S_squared = 2*(torch.square(u_x) + torch.square(v_y)) + torch.square(u_y + v_x)
        # Production term P_k = nu_t * S^2
        P_k = nu_t_check * S_squared

        # Check for negative production (detach before converting to numpy/item)
        try:
            P_k_detached = P_k.detach() # Detach from graph before summary stats
            min_Pk = torch.min(P_k_detached).item()
            max_Pk = torch.max(P_k_detached).item()
            num_negative = torch.sum(P_k_detached < 0).item()
            # Allow small negative values due to numerical precision
            negative_threshold = -self.config.EPS_SMALL * 100
            if min_Pk < negative_threshold:
                logging.warning(f"Negative turbulence production detected! Min P_k = {min_Pk:.3e}. ({num_negative}/{len(P_k)} points < {negative_threshold:.1e})")
            else:
                logging.info(f"Turbulence production check passed (min P_k = {min_Pk:.3e}, max P_k = {max_Pk:.3e})")
        except Exception as check_e:
             logging.error(f"Error checking P_k value statistics: {check_e}", exc_info=True)

# --- End Trainer Class ---


# ========================
# ===== Plotter Class =====
# ========================
class Plotter:
    """Handles post-processing and plotting of simulation results."""
    def __init__(self, config, plotter_config, model, losshistory, train_state):
        self.config = config
        self.plotter_config = plotter_config
        self.model = model
        self.losshistory = losshistory
        self.train_state = train_state
        self.ref_data_path = config.REFERENCE_DATA_FILE
        self.plots_dir = config.PLOT_DIR
        self.ref_data = None
        self.has_ref_data = False
        self.ref_data_utau = None # Estimated friction velocity from reference data
        self.pinn_data_utau = None # Estimated friction velocity from PINN data

        # Predicted fields (initialized to None)
        self.X_grid, self.Y_grid = None, None
        self.u_pred, self.v_pred, self.p_prime_pred = None, None, None
        self.k_pred, self.eps_pred, self.nu_t_pred = None, None, None
        self.p_pred = None # Kinematic pressure

        # Ensure plot directory exists before plotting
        os.makedirs(self.plots_dir, exist_ok=True)
        logging.info(f"Plotter initialized. Plots will be saved in: {self.plots_dir}")
        if self.ref_data_path:
             if os.path.exists(self.ref_data_path):
                 logging.info(f"Reference CSV data path found: {self.ref_data_path}")
             else:
                 logging.warning(f"Reference CSV file not found: {self.ref_data_path}. Comparisons will be skipped.")
        else:
             logging.info("No reference data path provided in config. Comparisons will be skipped.")

    def plot_loss_history(self):
        """Plots and saves the training loss history."""
        if self.losshistory and self.train_state:
            logging.info("Saving loss history plot...")
            try:
                os.makedirs(self.plots_dir, exist_ok=True) # Ensure dir exists
                # Use isplot=False to prevent showing plot in non-interactive envs like Colab background runs
                dde.saveplot(self.losshistory, self.train_state, issave=True, isplot=False, output_dir=self.plots_dir)
                # Check if the standard 'loss.png' was created and rename it
                default_loss_file = os.path.join(self.plots_dir, "loss.png")
                target_loss_file = os.path.join(self.plots_dir, "training_loss_history.png")
                if os.path.exists(default_loss_file):
                    try:
                        # Use os.replace for atomic rename if possible, fallback to rename
                        os.replace(default_loss_file, target_loss_file)
                        logging.info(f"Loss history plot saved as '{os.path.basename(target_loss_file)}'.")
                    except OSError as rename_err: # Fallback for cross-device links etc.
                         logging.warning(f"Could not replace/rename loss plot, using os.rename: {rename_err}")
                         os.rename(default_loss_file, target_loss_file)
                         logging.info(f"Loss history plot saved as '{os.path.basename(target_loss_file)}'.")

                else:
                     # Check if the loss file was created with a different name pattern potentially
                     potential_files = [f for f in os.listdir(self.plots_dir) if f.lower().endswith('.png') and 'loss' in f.lower()]
                     if potential_files:
                          logging.warning(f"dde.saveplot might not have produced 'loss.png'. Found: {potential_files}. Check DeepXDE version/behavior.")
                     else:
                          logging.warning("dde.saveplot did not produce 'loss.png'. Check file permissions, DeepXDE version, or if training actually ran.")
            except ImportError:
                logging.error("Matplotlib might be needed by dde.saveplot but is not installed or importable.")
            except Exception as e:
                 logging.error(f"Could not save loss history plot: {e}", exc_info=True)
        else:
            logging.warning("Loss history or train state not available, skipping loss plot.")

    def load_reference_data(self):
        """Loads and preprocesses reference data from a CSV file."""
        self.has_ref_data = False # Assume false until successfully loaded
        if not self.ref_data_path:
            logging.info("No reference data path provided. Skipping load.")
            return
        if not os.path.exists(self.ref_data_path):
            logging.warning(f"Reference CSV file not found: '{self.ref_data_path}'. Skipping load.")
            return

        logging.info(f"Loading reference data from: {self.ref_data_path}")
        try:
            df_ref = pd.read_csv(self.ref_data_path)
            logging.info(f"Loaded reference data: {df_ref.shape[0]} rows, {df_ref.shape[1]} cols. Initial columns: {df_ref.columns.tolist()}")

            # --- Data Filtering (Optional, customize as needed) ---
            # Example: Filter for latest time step if applicable
            time_col = None
            if 'Time' in df_ref.columns: time_col = 'Time'
            elif 'TimeStep' in df_ref.columns: time_col = 'TimeStep' # Adapt to actual column name
            if time_col:
                latest_time = df_ref[time_col].max()
                df_ref = df_ref[df_ref[time_col] == latest_time].copy()
                logging.info(f"Filtered for latest time/step ({time_col}={latest_time}): {df_ref.shape[0]} rows remaining.")

            # Identify coordinate columns (handle variations in naming robustly)
            x_col, y_col, z_col = None, None, None
            # Prioritize common exact names, then case-insensitive, then containing keywords
            potential_x = ['x', 'Points:0', 'X', 'x-coordinate']
            potential_y = ['y', 'Points:1', 'Y', 'y-coordinate']
            potential_z = ['z', 'Points:2', 'Z', 'z-coordinate']

            for p_x in potential_x:
                if p_x in df_ref.columns: x_col = p_x; break
            if not x_col: # Fallback: case-insensitive and keyword check
                for col in df_ref.columns:
                    if col.lower() in ['x', 'points:0', 'x-coordinate']: x_col = col; break

            for p_y in potential_y:
                if p_y in df_ref.columns: y_col = p_y; break
            if not y_col:
                 for col in df_ref.columns:
                    if col.lower() in ['y', 'points:1', 'y-coordinate']: y_col = col; break

            for p_z in potential_z:
                 if p_z in df_ref.columns: z_col = p_z; break
            if not z_col:
                 for col in df_ref.columns:
                    if col.lower() in ['z', 'points:2', 'z-coordinate']: z_col = col; break

            if not x_col or not y_col:
                 raise ValueError(f"Could not identify x/y coordinates in reference columns: {df_ref.columns.tolist()}")
            logging.info(f"Identified coordinate columns: x='{x_col}', y='{y_col}'" + (f", z='{z_col}'" if z_col else ""))

            # Filter for specific Z-plane if data is 3D and multiple Z values exist
            if z_col and len(df_ref[z_col].unique()) > 1:
                target_z = 0.0 # Target the center plane
                unique_z = df_ref[z_col].unique()
                nearest_z_idx = np.argmin(np.abs(unique_z - target_z))
                nearest_z = unique_z[nearest_z_idx]
                # Use a tolerance for floating point comparison
                df_ref = df_ref[np.isclose(df_ref[z_col], nearest_z)].copy()
                logging.info(f"Filtered for z-plane near {target_z} (actual: {nearest_z:.4f}): {df_ref.shape[0]} rows remaining.")

            # --- Variable Renaming (Handle variations) ---
            # Map potential CSV column names (lowercase) to consistent internal names
            var_map = {
                'u:0':'u_ref', 'u_x':'u_ref', 'velocity:0':'u_ref', 'velocity_x':'u_ref', 'u':'u_ref', 'velocityu':'u_ref',
                'u:1':'v_ref', 'u_y':'v_ref', 'velocity:1':'v_ref', 'velocity_y':'v_ref', 'v':'v_ref', 'velocityv':'v_ref',
                'p':'p_ref', 'pressure':'p_ref', 'kinematicpressure':'p_ref', 'kinematic_pressure':'p_ref', # Assume kinematic pressure if 'p'
                'k':'k_ref', 'turbulentkinetienergy':'k_ref', 'turbulentkineticenergy':'k_ref', 'tke':'k_ref',
                'epsilon':'eps_ref', 'turbulencedissipationrate':'eps_ref', 'dissipationrate':'eps_ref', 'dissipation':'eps_ref',
                'nut':'nut_ref', 'turbulentviscosity':'nut_ref', 'eddyviscosity':'nut_ref', 'nutilda':'nut_ref' # Check nuTilda if using Spalart-Allmaras
            }
            rename_dict = {}
            processed_cols = set() # Track columns already mapped
            # Apply mapping based on lowercase, stripped column names
            for col in df_ref.columns:
                col_lower = col.lower().strip().replace('_','').replace('-','').replace(' ','') # Normalize heavily
                if col_lower in var_map and col not in processed_cols:
                    rename_dict[col] = var_map[col_lower]
                    processed_cols.add(col) # Mark as processed

            # Add coordinate renaming (original column name -> standard name)
            rename_dict[x_col] = 'x'
            rename_dict[y_col] = 'y'
            if z_col: rename_dict[z_col] = 'z'
            processed_cols.update([x_col, y_col, z_col] if z_col else [x_col, y_col])

            # Check for unmapped columns that might be important
            unmapped_cols = [col for col in df_ref.columns if col not in processed_cols]
            if unmapped_cols:
                 logging.debug(f"Unmapped columns in reference data: {unmapped_cols}")

            df_ref.rename(columns=rename_dict, inplace=True)
            logging.info(f"Renamed reference columns based on mapping: {rename_dict}")
            logging.info(f"Columns after rename: {df_ref.columns.tolist()}")


            # --- Check for Required Columns (after renaming) ---
            # Define which columns are absolutely essential for comparison plots
            required_cols_for_plots = ['x', 'y', 'u_ref'] # Minimal for velocity profile
            # Add others based on which plots you intend to generate
            if 'plot_profile_comparison' in dir(self): required_cols_for_plots.extend(['p_ref', 'k_ref', 'eps_ref'])
            if 'plot_wall_unit_comparison' in dir(self): required_cols_for_plots.extend(['k_ref', 'eps_ref'])
            # Add nut_ref if needed
            required_cols_for_plots = list(set(required_cols_for_plots)) # Remove duplicates

            missing_cols = [col for col in required_cols_for_plots if col not in df_ref.columns]
            if missing_cols:
                # Raise error only if essential columns like x, y, u are missing
                if any(c in missing_cols for c in ['x', 'y', 'u_ref']):
                    raise ValueError(f"Missing essential columns after renaming in reference data: {missing_cols}. Available: {df_ref.columns.tolist()}")
                else:
                    logging.warning(f"Missing some optional columns for plots: {missing_cols}. Comparison plots might be incomplete.")

            # Keep only necessary columns + optional ones if present
            cols_to_keep = ['x', 'y'] + [col for col in ['u_ref', 'v_ref', 'p_ref', 'k_ref', 'eps_ref', 'nut_ref'] if col in df_ref.columns]
            if z_col and 'z' in df_ref.columns: cols_to_keep.append('z') # Keep z if it was present
            df_ref = df_ref[list(set(cols_to_keep))] # Ensure unique columns

            # Sort and reset index
            df_ref.sort_values(by=['x', 'y'], inplace=True)
            df_ref.reset_index(drop=True, inplace=True)
            self.ref_data = df_ref
            self.has_ref_data = True # Set flag only on successful load and processing
            logging.info(f"Successfully loaded and preprocessed reference CSV data. Final columns: {df_ref.columns.tolist()}")

        except FileNotFoundError:
            # Already logged warning, just pass
            pass
        except ValueError as ve: # Catch specific errors like missing columns
            logging.error(f"ValueError processing reference CSV: {ve}")
        except Exception as e:
            logging.error(f"Unexpected error loading or processing reference CSV: {e}", exc_info=True)
            self.ref_data = None
            self.has_ref_data = False

    def predict_pinn_fields(self):
        """Predicts flow fields using the trained PINN model on a grid."""
        if self.model is None or self.model.net is None:
             logging.error("PINN Model or network not available for prediction.")
             return False

        logging.info("Predicting PINN flow fields on evaluation grid...")
        nx = self.plotter_config.NX_PRED
        ny = self.plotter_config.NY_PRED
        x_coords = np.linspace(0, self.config.L, nx)
        y_coords = np.linspace(-self.config.CHANNEL_HALF_HEIGHT, self.config.CHANNEL_HALF_HEIGHT, ny)
        self.X_grid, self.Y_grid = np.meshgrid(x_coords, y_coords)
        # Create prediction points (N, 2) array
        pred_points = np.vstack((np.ravel(self.X_grid), np.ravel(self.Y_grid))).T

        try:
            # Use model.predict for inference
            # Ensure input is float32, as model was likely trained with it
            # Convert to numpy explicitly if it's a tensor
            if isinstance(pred_points, torch.Tensor):
                pred_points_np = pred_points.cpu().numpy()
            else:
                pred_points_np = np.array(pred_points, dtype=np.float32)

            # Check model state before predicting
            if not hasattr(self.model, 'sess') and dde.backend.backend_name == "tensorflow.compat.v1":
                 logging.warning("TensorFlow v1 backend detected, but model session (sess) seems unavailable. Prediction might fail.")
            elif not hasattr(self.model, 'net') or self.model.net is None:
                 logging.error("Model network attribute is missing or None. Cannot predict.")
                 return False

            predictions_raw = self.model.predict(pred_points_np) # Predict expects numpy

            if predictions_raw is None or not isinstance(predictions_raw, np.ndarray) or predictions_raw.shape[1] != self.config.NETWORK_OUTPUTS:
                logging.error(f"Prediction shape mismatch or invalid type. Expected {self.config.NETWORK_OUTPUTS} outputs, got shape {predictions_raw.shape if predictions_raw is not None else 'None'} and type {type(predictions_raw)}.")
                return False

        except AttributeError as ae:
             logging.error(f"AttributeError during PINN prediction (check model state/backend compatibility): {ae}", exc_info=True)
             return False
        except Exception as e:
            logging.error(f"Error during PINN prediction: {e}", exc_info=True)
            return False

        # --- Process Raw Predictions ---
        # Reshape predictions back to grid format (ny, nx)
        try:
            self.u_pred = predictions_raw[:, 0].reshape(ny, nx)
            self.v_pred = predictions_raw[:, 1].reshape(ny, nx)
            self.p_prime_pred = predictions_raw[:, 2].reshape(ny, nx)
            k_raw_pred = predictions_raw[:, 3].reshape(ny, nx)
            eps_raw_pred = predictions_raw[:, 4].reshape(ny, nx)

            # Apply inverse transform (exp) and add epsilon for positivity
            # Use np.exp for numpy arrays
            self.k_pred = np.exp(k_raw_pred) + self.config.EPS_SMALL
            self.eps_pred = np.exp(eps_raw_pred) + self.config.EPS_SMALL

            # Calculate kinematic pressure p = p' - (2/3)*k (assuming isotropic normal stress contribution)
            # This definition assumes p' in the RANS equations represents p_kinematic + (2/3)k
            # Verify this assumption based on the specific RANS formulation used.
            self.p_pred = self.p_prime_pred - (2.0 / 3.0) * self.k_pred

            # Calculate turbulent viscosity nu_t = Cmu * k^2 / eps
            # Use np.maximum for safe division with numpy arrays
            eps_safe_pred = np.maximum(self.eps_pred, self.config.EPS_SMALL**2) # Use squared epsilon to match units if needed
            self.nu_t_pred = self.config.CMU * np.square(self.k_pred) / eps_safe_pred

            logging.info("PINN field prediction and processing complete.")
            return True
        except Exception as proc_e:
            logging.error(f"Error processing raw predictions: {proc_e}", exc_info=True)
            return False


    def plot_contour_fields(self):
        """Plots contour fields of the predicted PINN variables."""
        if self.u_pred is None: # Check if prediction data exists
            logging.warning("PINN data unavailable for plotting. Run predict_pinn_fields first. Skipping contours.")
            return

        logging.info("Generating PINN contour plots...")
        try:
            # Create a figure with subplots
            fig, axes = plt.subplots(2, 3, figsize=(18, 10)) # Adjust figure size as needed
            axes = axes.ravel() # Flatten the axes array for easy indexing

            cmap_vel = self.plotter_config.CMAP_VELOCITY
            cmap_p = self.plotter_config.CMAP_PRESSURE
            cmap_turb = self.plotter_config.CMAP_TURBULENCE

            # Data to plot: (data_array, title, label, cmap, optional: use_log)
            plot_data_list = [
                (self.u_pred, 'PINN Streamwise Velocity (u)', 'u (m/s)', cmap_vel),
                (self.v_pred, 'PINN Transverse Velocity (v)', 'v (m/s)', cmap_vel),
                (self.p_pred, "PINN Kinematic Pressure (p)", r'$p/\rho$ ($m^2/s^2$)', cmap_p),
                (self.k_pred, 'PINN TKE (k)', r'$k$ ($m^2/s^2$)', cmap_turb),
                (self.eps_pred, 'PINN Dissipation ($\epsilon$)', r'$\epsilon$ ($m^2/s^3$)', cmap_turb),
                (self.nu_t_pred / self.config.NU, 'PINN Eddy Viscosity Ratio', r'$\nu_t / \nu$', cmap_turb, True) # Plot ratio, optionally log scale
            ]

            for i, (data, title, label, cmap, *log_flag) in enumerate(plot_data_list):
                ax = axes[i]
                plot_values = data
                cbar_label = label
                levels = 50 # Number of contour levels
                use_log = log_flag[0] if log_flag else False # Check if log flag is provided

                # Optional log scale for positive quantities
                if use_log and np.nanmin(data) > self.config.EPS_SMALL: # Check if data is positive before log
                    try:
                        # Floor values slightly above zero before taking log10
                        min_positive = np.nanmin(data[data > self.config.EPS_SMALL*10])
                        plot_values = np.log10(np.maximum(data, min_positive * 0.01)) # Use nanmin
                        cbar_label = f'log10({label})'
                        levels = np.logspace(np.log10(min_positive*0.01), np.log10(np.nanmax(data)), levels) # Log spaced levels
                        logging.debug(f"Using log scale for {title}")
                    except Exception as log_err:
                         logging.warning(f"Could not apply log scale for {title}: {log_err}. Using linear scale.")
                         use_log = False # Revert to linear scale on error
                         plot_values = data # Ensure plot_values is reset

                # Use contourf for filled contours
                try:
                    if use_log: # Use log levels if calculated
                        cf = ax.contourf(self.X_grid, self.Y_grid, plot_values, levels=levels, cmap=cmap, extend='both', locator=plt.LogLocator())
                    else: # Linear levels
                         cf = ax.contourf(self.X_grid, self.Y_grid, plot_values, levels=levels, cmap=cmap, extend='both')

                    fig.colorbar(cf, ax=ax, label=cbar_label)
                    ax.set_title(title)
                    ax.set_xlabel('x (m)')
                    ax.set_ylabel('y (m)')
                    ax.set_aspect('equal', adjustable='box') # Make aspect ratio equal
                except ValueError as ve:
                     logging.error(f"ValueError during contour plot for {title} (check data range/levels): {ve}")
                except Exception as e:
                    logging.error(f"Error plotting contour for {title}: {e}")

            # Hide any unused subplots if necessary (e.g., if plot_data_list has < 6 items)
            for j in range(i + 1, len(axes)):
                 fig.delaxes(axes[j])

            plt.tight_layout()
            save_path = os.path.join(self.plots_dir, "pinn_field_contours.png")
            plt.savefig(save_path, dpi=200, bbox_inches='tight')
            plt.close(fig) # Close the specific figure
            logging.info(f"PINN contour field plots saved to {os.path.basename(save_path)}")

        except Exception as e:
            logging.error(f"Failed to generate or save contour plots: {e}", exc_info=True)
            # Ensure figure is closed if error occurs mid-plotting
            if 'fig' in locals() and plt.fignum_exists(fig.number):
                plt.close(fig)

    def _estimate_utau(self, data_source='pinn', x_slice_loc=None):
        """Estimates friction velocity u_tau from data near the wall using gradient and log-law."""
        if x_slice_loc is None:
            # Choose a location away from inlet/outlet, e.g., 80% downstream
            x_slice_loc = self.config.L * 0.8

        h = self.config.CHANNEL_HALF_HEIGHT
        y_p = self.config.Y_P # Use the wall function distance y_p for context, but sample closer
        nu = self.config.NU
        rho = self.config.RHO # Density needed for stress calculation

        # Define two points very near the physical wall (e.g., y+ of ~1 and ~5 if possible)
        # This requires an initial guess of u_tau or an iterative approach.
        # Simpler: use fixed small distances from the wall.
        y_dist_1 = 0.001 * h # Very close to wall
        y_dist_2 = 0.01 * h  # Still close, but further than y_dist_1

        # Calculate y-coordinates relative to centerline for top wall
        y_eval_top_1 = h - y_dist_1
        y_eval_top_2 = h - y_dist_2
        eval_points_top = np.array([[x_slice_loc, y_eval_top_1], [x_slice_loc, y_eval_top_2]])

        # Calculate y-coordinates relative to centerline for bottom wall
        y_eval_bot_1 = -h + y_dist_1
        y_eval_bot_2 = -h + y_dist_2
        eval_points_bot = np.array([[x_slice_loc, y_eval_bot_1], [x_slice_loc, y_eval_bot_2]])

        # Average distance from the *nearest* wall (used for y+ estimate later)
        y_dist_wall_avg_top = (y_dist_1 + y_dist_2) / 2.0
        y_dist_wall_avg_bot = (y_dist_1 + y_dist_2) / 2.0


        u1_top, k1_top, eps1_top, u2_top, k2_top, eps2_top = [None]*6
        u1_bot, k1_bot, eps1_bot, u2_bot, k2_bot, eps2_bot = [None]*6

        try:
            # --- Get u, k, eps at the evaluation points ---
            interp_method = 'linear' # Start with linear interpolation for reference data

            if data_source == 'pinn':
                if self.model is None: return None
                # Predict for both top and bottom points together
                eval_points_all = np.vstack((eval_points_top, eval_points_bot))
                pred_raw = self.model.predict(eval_points_all)
                if pred_raw is None or pred_raw.shape[0] < 4:
                     logging.error(f"PINN prediction failed or returned insufficient points for u_tau estimate at x={x_slice_loc:.2f}")
                     return None
                # Extract and transform
                u_all = pred_raw[:, 0]
                k_raw_all = pred_raw[:, 3]
                eps_raw_all = pred_raw[:, 4]
                k_all = np.exp(k_raw_all) + self.config.EPS_SMALL
                eps_all = np.exp(eps_raw_all) + self.config.EPS_SMALL

                # Split results
                u1_top, u2_top, u1_bot, u2_bot = u_all
                k1_top, k2_top, k1_bot, k2_bot = k_all
                eps1_top, eps2_top, eps1_bot, eps2_bot = eps_all

            elif data_source == 'reference' and self.has_ref_data:
                if self.ref_data is None: return None
                points_ref = self.ref_data[['x', 'y']].values
                req_cols = ['u_ref', 'k_ref', 'eps_ref'] # Need these for nu_eff calculation
                if not all(col in self.ref_data.columns for col in req_cols):
                    logging.warning(f"Reference data missing required columns {req_cols} for u_tau estimation from gradient.")
                    return None

                # Interpolate required values for top and bottom points
                values_to_interp = {}
                for col in req_cols:
                    values_to_interp[col] = self.ref_data[col].values

                interp_results = {}
                eval_points_all = np.vstack((eval_points_top, eval_points_bot))

                for col, values in values_to_interp.items():
                    interp_vals = griddata(points_ref, values, eval_points_all, method=interp_method)
                    nan_mask = np.isnan(interp_vals)
                    if np.any(nan_mask):
                        logging.debug(f"Linear interpolation failed for '{col}' ({data_source}) at x={x_slice_loc:.2f}. Trying 'nearest'.")
                        interp_nearest = griddata(points_ref, values, eval_points_all[nan_mask], method='nearest')
                        interp_vals[nan_mask] = interp_nearest
                        if np.any(np.isnan(interp_vals)):
                             logging.error(f"Interpolation (linear & nearest) failed for '{col}' ({data_source}) at x={x_slice_loc:.2f}. Cannot estimate u_tau.")
                             return None
                    interp_results[col] = interp_vals

                # Split results
                u1_top, u2_top, u1_bot, u2_bot = interp_results['u_ref']
                k1_top, k2_top, k1_bot, k2_bot = interp_results['k_ref']
                eps1_top, eps2_top, eps1_bot, eps2_bot = interp_results['eps_ref']

            else:
                logging.warning(f"Invalid data_source '{data_source}' or missing data for u_tau estimation.")
                return None

            # --- Estimate Gradients and u_tau for Top Wall ---
            # Gradient du/dy = (u_further - u_closer) / (y_further - y_closer)
            # y_eval_top_2 is further from center, closer to wall than y_eval_top_1
            du_dy_top = (u2_top - u1_top) / (y_eval_top_2 - y_eval_top_1) # Should be negative
            # Effective viscosity near wall (average of the two points)
            k_avg_top = (k1_top + k2_top) / 2.0
            eps_avg_top = (eps1_top + eps2_top) / 2.0
            nu_t_avg_top = self.config.CMU * k_avg_top**2 / max(eps_avg_top, self.config.EPS_SMALL**2)
            nu_eff_avg_top = nu + nu_t_avg_top
            # Wall shear stress tau_w = rho * nu_eff * |du/dy| (note the absolute value)
            tau_w_top = rho * nu_eff_avg_top * abs(du_dy_top)
            u_tau_top = np.sqrt(max(tau_w_top / rho, self.config.EPS_SMALL)) # Ensure non-negative sqrt arg

             # --- Estimate Gradients and u_tau for Bottom Wall ---
            # y_eval_bot_2 is further from center, closer to wall than y_eval_bot_1
            du_dy_bot = (u2_bot - u1_bot) / (y_eval_bot_2 - y_eval_bot_1) # Should be positive
            k_avg_bot = (k1_bot + k2_bot) / 2.0
            eps_avg_bot = (eps1_bot + eps2_bot) / 2.0
            nu_t_avg_bot = self.config.CMU * k_avg_bot**2 / max(eps_avg_bot, self.config.EPS_SMALL**2)
            nu_eff_avg_bot = nu + nu_t_avg_bot
            tau_w_bot = rho * nu_eff_avg_bot * abs(du_dy_bot)
            u_tau_bot = np.sqrt(max(tau_w_bot / rho, self.config.EPS_SMALL))

            # --- Average or Choose ---
            # Average the top and bottom estimates for a single channel value
            u_tau_estimated = (u_tau_top + u_tau_bot) / 2.0

            # Optional: Log-law refinement (less reliable if points are deep in viscous sublayer)
            # y_plus_est_top = y_dist_wall_avg_top * u_tau_top / nu
            # y_plus_est_bot = y_dist_wall_avg_bot * u_tau_bot / nu
            # logging.debug(f"Intermediate u_tau ({data_source}): Top={u_tau_top:.4f} (y+ ~{y_plus_est_top:.1f}), Bottom={u_tau_bot:.4f} (y+ ~{y_plus_est_bot:.1f})")

            logging.info(f"Estimated u_tau ({data_source}) at x={x_slice_loc:.2f} m: {u_tau_estimated:.4f} m/s (avg of top/bottom grad estimates)")
            return u_tau_estimated

        except Exception as e:
            logging.error(f"Error estimating u_tau for {data_source} at x={x_slice_loc:.2f}: {e}", exc_info=True)
            return None

    def plot_profile_comparison(self):
        """Plots profiles of PINN vs Reference data at a channel cross-section."""
        if self.u_pred is None: # Check if prediction data exists
            logging.warning("PINN data unavailable for plotting. Skipping profile comparison.")
            return
        if not self.has_ref_data:
            logging.warning("Reference CSV data not loaded or failed processing. Skipping profile comparison.")
            return

        logging.info("Generating profile comparison plots...")
        # Define slice location (e.g., channel midpoint or further downstream)
        x_slice_loc = self.config.L * 0.8 # Use same location as u_tau estimate for consistency
        ny_pinn = self.plotter_config.NY_PRED # Number of points in y from PINN grid

        # Find the closest x-coordinate in the PINN grid
        y_coords_pinn = self.Y_grid[:, 0] # y-coordinates from PINN grid
        x_coords_pinn = self.X_grid[0, :] # x-coordinates from PINN grid
        try:
             x_slice_idx_pinn = np.argmin(np.abs(x_coords_pinn - x_slice_loc))
             actual_x_pinn = x_coords_pinn[x_slice_idx_pinn] # Actual x used for slicing
        except IndexError:
             logging.error("PINN grid coordinates seem invalid. Cannot find x-slice.")
             return

        # Extract PINN data slice at the chosen x-index
        pinn_slice = {
            'y': y_coords_pinn,
            'u': self.u_pred[:, x_slice_idx_pinn],
            'v': self.v_pred[:, x_slice_idx_pinn],
            'p': self.p_pred[:, x_slice_idx_pinn], # Use calculated kinematic p
            'k': self.k_pred[:, x_slice_idx_pinn],
            'eps': self.eps_pred[:, x_slice_idx_pinn],
            'nut': self.nu_t_pred[:, x_slice_idx_pinn]
        }

        # --- Interpolate Reference Data onto PINN y-coordinates at the same x ---
        ref_slice = {'y': y_coords_pinn} # Initialize dict for interpolated ref data
        interpolation_successful = False
        try:
            if self.ref_data is None: raise ValueError("Reference data frame is None.")

            # Points from reference data (x, y)
            points_ref = self.ref_data[['x', 'y']].values
            # Target points for interpolation (same x, PINN y-coords)
            target_points = np.vstack((np.full(ny_pinn, actual_x_pinn), y_coords_pinn)).T

            logging.info(f"Interpolating reference data onto {ny_pinn} points at x={actual_x_pinn:.3f}...")

            # Interpolate each variable present in the reference data
            variables_to_interpolate = [
                ('u_ref', 'u'), ('v_ref', 'v'), ('p_ref', 'p'),
                ('k_ref', 'k'), ('eps_ref', 'eps'), ('nut_ref', 'nut')
            ]
            missing_ref_vars = []
            for var_ref, var_pinn in variables_to_interpolate:
                if var_ref in self.ref_data.columns:
                    values_ref = self.ref_data[var_ref].values
                    # Linear interpolation
                    interp_values = griddata(points_ref, values_ref, target_points, method='linear')
                    # Handle NaNs with nearest neighbor fallback
                    nan_mask = np.isnan(interp_values)
                    num_nans = np.sum(nan_mask)
                    if num_nans > 0:
                        logging.debug(f"{num_nans} NaNs found for '{var_ref}' after linear interp. Trying nearest neighbor.")
                        # Only interpolate NaN points with nearest
                        interp_nearest = griddata(points_ref, values_ref, target_points[nan_mask], method='nearest')
                        interp_values[nan_mask] = interp_nearest
                        if np.any(np.isnan(interp_values)): # Check if nearest also failed
                             logging.warning(f"Interpolation (linear & nearest) failed for '{var_ref}' at some points. Profile will be incomplete.")
                    ref_slice[var_pinn] = interp_values
                else:
                    logging.debug(f"Reference variable '{var_ref}' not found in CSV data. Skipping interpolation.")
                    ref_slice[var_pinn] = np.full(ny_pinn, np.nan) # Fill with NaN if missing
                    missing_ref_vars.append(var_ref)

            interpolation_successful = True # Mark as successful if loop completes
            if missing_ref_vars:
                 logging.warning(f"Could not interpolate reference variables: {missing_ref_vars}")

        except ValueError as ve:
             logging.error(f"ValueError during reference data interpolation: {ve}")
        except Exception as e:
            logging.error(f"Error interpolating reference data for profiles: {e}", exc_info=True)
            # If interpolation fails, mark ref data as unavailable for this plot
            interpolation_successful = False

        # --- Create Plots ---
        try:
            fig, axes = plt.subplots(3, 2, figsize=(12, 15)) # Adjust size
            axes = axes.ravel()
            plot_idx = 0
            h = self.config.CHANNEL_HALF_HEIGHT
            plot_vars = [ # Variables to plot, their names, and units
                ('u', 'Velocity u', 'm/s'),
                ('v', 'Velocity v', 'm/s'),
                ('p', 'Kinematic Pressure p', r'$m^2/s^2$'),
                ('k', 'TKE k', r'$m^2/s^2$'),
                ('eps', 'Dissipation eps', r'$m^2/s^3$'),
                ('nut', 'Eddy Viscosity nu_t', r'$m^2/s$')
            ]

            for key, name, unit in plot_vars:
                if plot_idx >= len(axes): break # Avoid index error if fewer plots than axes
                ax = axes[plot_idx]

                # Plot PINN data
                ax.plot(pinn_slice[key], pinn_slice['y'] / h, 'r-', linewidth=2, label='PINN')

                # Plot Reference data if interpolation was successful and data exists
                if interpolation_successful and key in ref_slice and not np.all(np.isnan(ref_slice[key])):
                    ax.plot(ref_slice[key], ref_slice['y'] / h, 'b--', linewidth=1.5, label='Reference (CSV)')
                elif not interpolation_successful and key != 'y': # Add empty placeholder if interp failed
                     ax.plot([], [], 'b--', label='Reference (Failed)')

                ax.set_xlabel(f'{name} ({unit})')
                ax.set_ylabel('y/h') # Normalize y by half-height
                ax.set_title(f'{name} Profile') # Title moved to suptitle
                ax.legend(fontsize=8)
                ax.grid(True, linestyle=':')

                # Use log scale for x-axis for turbulence quantities (k, eps, nut) if values are positive
                if key in ['k', 'eps', 'nut']:
                     try:
                         # Check if plotted values are sufficiently positive
                         min_val_for_log = self.config.EPS_SMALL
                         pinn_valid = np.nanmin(pinn_slice[key]) > min_val_for_log
                         ref_valid = False
                         if interpolation_successful and key in ref_slice and not np.all(np.isnan(ref_slice[key])):
                             ref_valid = np.nanmin(ref_slice[key]) > min_val_for_log

                         if pinn_valid and (ref_valid or not interpolation_successful): # Allow log if ref failed but PINN is ok
                              ax.set_xscale('log')
                              ax.grid(True, which='both', linestyle=':') # Add minor grid for log scale
                              logging.debug(f"Using log scale for {key} profile.")
                     except ValueError: # Handle cases where data might be exactly zero or negative
                           logging.warning(f"Could not apply log scale for {key} (likely non-positive values).")
                     except Exception as log_e:
                          logging.warning(f"Error applying log scale for {key}: {log_e}")

                plot_idx += 1

            # Hide unused axes
            for j in range(plot_idx, len(axes)):
                fig.delaxes(axes[j])

            plt.suptitle(f'Profile Comparison at x ≈ {actual_x_pinn:.3f} m', fontsize=16)
            plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to prevent title overlap

            save_path = os.path.join(self.plots_dir, "profile_comparison_pinn_vs_csv.png")
            plt.savefig(save_path, dpi=200, bbox_inches='tight')
            plt.close(fig)
            logging.info(f"Profile comparison plot saved to {os.path.basename(save_path)}")

        except Exception as e:
            logging.error(f"Failed to generate or save profile comparison plot: {e}", exc_info=True)
            if 'fig' in locals() and plt.fignum_exists(fig.number):
                 plt.close(fig)


    def plot_wall_unit_comparison(self):
        """Plots profiles in wall units (y+, U+, k+, eps+) vs reference/theory."""
        if self.u_pred is None: # Check if prediction data exists
            logging.warning("PINN data unavailable for plotting. Skipping wall unit plots.")
            return

        logging.info("Generating wall unit comparison plots...")

        # --- Estimate Friction Velocity (u_tau) ---
        # Use a location away from inlet/outlet for potentially more developed flow
        x_slice_loc_utau = self.config.L * 0.8
        self.pinn_data_utau = self._estimate_utau(data_source='pinn', x_slice_loc=x_slice_loc_utau)
        if self.has_ref_data:
            self.ref_data_utau = self._estimate_utau(data_source='reference', x_slice_loc=x_slice_loc_utau)
        else:
            self.ref_data_utau = None

        # Proceed only if PINN u_tau could be estimated
        if not self.pinn_data_utau:
            logging.error("Could not estimate PINN u_tau. Skipping wall unit plots.")
            return
        if self.has_ref_data and not self.ref_data_utau:
            logging.warning("Could not estimate reference u_tau. Plotting PINN wall units only vs theory.")

        # --- Prepare Data for Wall Units ---
        nu = self.config.NU
        h = self.config.CHANNEL_HALF_HEIGHT
        kappa = self.config.KAPPA
        # Use E+ (often denoted B or C+ in literature) instead of E_wall for log-law plot
        # E+ = E_wall * exp(kappa * B_offset) if B_offset is used, but often just a constant ~5.0-5.5
        # Let's use a typical value, adjust if needed based on expected Re_tau
        B_const = 5.2 # Typical log-law intercept constant for smooth walls (E+ or C+)

        # Use the same x-slice as the profile plots or the u_tau estimate location
        x_slice_loc_plot = x_slice_loc_utau # Use u_tau location for consistency
        y_coords_pinn = self.Y_grid[:, 0]
        x_coords_pinn = self.X_grid[0, :]
        try:
            x_slice_idx_pinn = np.argmin(np.abs(x_coords_pinn - x_slice_loc_plot))
            actual_x_pinn = x_coords_pinn[x_slice_idx_pinn]
        except IndexError:
             logging.error("PINN grid coordinates seem invalid. Cannot find x-slice for wall units.")
             return


        # Extract PINN data near one wall (e.g., top wall: y >= 0) for clarity
        # Could potentially average top/bottom walls if symmetry is expected
        wall_indices_pinn = y_coords_pinn >= -self.config.EPS_SMALL # Include centerline point if present
        y_wall_pinn = y_coords_pinn[wall_indices_pinn]
        # Distance from the NEAREST wall (top wall if y>=0, bottom wall if y<0)
        # For top wall (y>=0): distance = h - y
        y_dist_wall_pinn = np.maximum(h - y_wall_pinn, self.config.EPS_SMALL * h) # Avoid zero distance

        # Get corresponding u, k, eps from the PINN slice
        u_wall_pinn = self.u_pred[wall_indices_pinn, x_slice_idx_pinn]
        k_wall_pinn = self.k_pred[wall_indices_pinn, x_slice_idx_pinn]
        eps_wall_pinn = self.eps_pred[wall_indices_pinn, x_slice_idx_pinn]

        # Calculate wall units for PINN data
        utau_pinn_safe = max(self.pinn_data_utau, self.config.EPS_SMALL) # Avoid division by zero
        y_plus_pinn = y_dist_wall_pinn * utau_pinn_safe / nu
        u_plus_pinn = u_wall_pinn / utau_pinn_safe
        # k+ = k / u_tau^2
        k_plus_pinn = k_wall_pinn / max(utau_pinn_safe**2, self.config.EPS_SMALL**2)
        # eps+ = eps * nu / u_tau^4
        eps_plus_pinn = eps_wall_pinn * nu / max(utau_pinn_safe**4, self.config.EPS_SMALL**4)

        # Sort PINN data by y+ for plotting lines correctly
        sort_idx_pinn = np.argsort(y_plus_pinn)
        y_plus_pinn = y_plus_pinn[sort_idx_pinn]
        u_plus_pinn = u_plus_pinn[sort_idx_pinn]
        k_plus_pinn = k_plus_pinn[sort_idx_pinn]
        eps_plus_pinn = eps_plus_pinn[sort_idx_pinn]

        # --- Prepare Reference Data (if available and u_tau estimated) ---
        y_plus_ref, u_plus_ref, k_plus_ref, eps_plus_ref = None, None, None, None
        ref_processed = False
        if self.has_ref_data and self.ref_data_utau:
            try:
                # Filter reference data near the chosen x-slice and top wall (y>=0)
                # Use a tolerance for x matching
                ref_wall_data = self.ref_data[
                    (np.isclose(self.ref_data['x'], actual_x_pinn, rtol=0.05, atol=0.1*self.config.L)) & # Wider tolerance for x
                    (self.ref_data['y'] >= -self.config.EPS_SMALL) # Include y=0
                ].copy()

                if not ref_wall_data.empty:
                    y_wall_ref = ref_wall_data['y'].values
                    y_dist_wall_ref = np.maximum(h - y_wall_ref, self.config.EPS_SMALL * h)
                    utau_ref_safe = max(self.ref_data_utau, self.config.EPS_SMALL)
                    y_plus_ref = y_dist_wall_ref * utau_ref_safe / nu

                    # Check if required columns exist before accessing
                    if 'u_ref' in ref_wall_data.columns: u_plus_ref = ref_wall_data['u_ref'].values / utau_ref_safe
                    if 'k_ref' in ref_wall_data.columns: k_plus_ref = ref_wall_data['k_ref'].values / max(utau_ref_safe**2, self.config.EPS_SMALL**2)
                    if 'eps_ref' in ref_wall_data.columns: eps_plus_ref = ref_wall_data['eps_ref'].values * nu / max(utau_ref_safe**4, self.config.EPS_SMALL**4)

                    # Sort reference data by y+
                    sort_idx_ref = np.argsort(y_plus_ref)
                    y_plus_ref = y_plus_ref[sort_idx_ref]
                    if u_plus_ref is not None: u_plus_ref = u_plus_ref[sort_idx_ref]
                    if k_plus_ref is not None: k_plus_ref = k_plus_ref[sort_idx_ref]
                    if eps_plus_ref is not None: eps_plus_ref = eps_plus_ref[sort_idx_ref]
                    logging.info(f"Processed {len(y_plus_ref)} reference points for wall unit comparison.")
                    ref_processed = True
                else:
                    logging.warning(f"No reference data found near x={actual_x_pinn:.3f}, y>=0 for wall unit plots.")
            except KeyError as ke:
                 logging.error(f"Missing column in reference data needed for wall units: {ke}")
            except Exception as e:
                logging.error(f"Error processing reference data for wall units: {e}", exc_info=True)
                # Prevent plotting bad ref data
                y_plus_ref, u_plus_ref, k_plus_ref, eps_plus_ref = None, None, None, None

        # --- Create Wall Unit Plots ---
        try:
            fig, axes = plt.subplots(1, 3, figsize=(18, 5.5)) # Figure for U+, k+, eps+

            # Determine plot limits dynamically
            y_plus_max_pinn = np.max(y_plus_pinn) if len(y_plus_pinn) > 0 else 100
            y_plus_max_ref = np.max(y_plus_ref) if ref_processed and y_plus_ref is not None and len(y_plus_ref) > 0 else y_plus_max_pinn
            y_plus_max_plot = 1.1 * max(y_plus_max_pinn, y_plus_max_ref, self.config.YP_PLUS_TARGET * 1.5) # Ensure target y+ visible, extend slightly

            u_plus_max_pinn = np.max(u_plus_pinn) if len(u_plus_pinn) > 0 else 25
            u_plus_max_ref = np.max(u_plus_ref) if ref_processed and u_plus_ref is not None and len(u_plus_ref) > 0 else u_plus_max_pinn
            u_plus_max_plot = 1.1 * max(u_plus_max_pinn, u_plus_max_ref)

            k_plus_max_pinn = np.max(k_plus_pinn) if len(k_plus_pinn) > 0 else 5
            k_plus_max_ref = np.max(k_plus_ref) if ref_processed and k_plus_ref is not None and len(k_plus_ref) > 0 else k_plus_max_pinn
            k_plus_max_plot = 1.1 * max(k_plus_max_pinn, k_plus_max_ref)


            # 1. U+ vs y+ plot
            ax = axes[0]
            ax.semilogx(y_plus_pinn, u_plus_pinn, 'r.', ms=4, label=f'PINN ($u_\\tau \\approx {self.pinn_data_utau:.3f}$)')
            if ref_processed and y_plus_ref is not None and u_plus_ref is not None:
                ax.semilogx(y_plus_ref, u_plus_ref, 'bo', mfc='none', ms=5, label=f'Ref ($u_\\tau \\approx {self.ref_data_utau:.3f}$)' if self.ref_data_utau else 'Ref (u_tau N/A)')
            # Theoretical laws
            y_plus_log_min = 11 # Start log law around y+=11
            y_plus_theory_log = np.logspace(np.log10(max(y_plus_log_min, 1)), np.log10(y_plus_max_plot*1.1), 100) # Extend slightly beyond max y+
            u_plus_loglaw = (1 / kappa) * np.log(y_plus_theory_log) + B_const
            y_plus_theory_vis = np.linspace(0.1, 30, 50) # Viscous sublayer range (y+ < 5), buffer (y+ 5-30)
            u_plus_viscous = y_plus_theory_vis # U+ = y+
            ax.semilogx(y_plus_theory_log, u_plus_loglaw, 'k:', lw=1.5, label=f'Log Law ($\\kappa={kappa}, B={B_const}$)')
            ax.semilogx(y_plus_theory_vis, u_plus_viscous, 'k--', lw=1.5, label='Viscous ($U^+=y^+$)')
            ax.set_xlabel('$y^+$')
            ax.set_ylabel('$U^+$')
            ax.set_title(f'$U^+$ vs $y^+$ Profile')
            ax.legend(fontsize=9)
            ax.grid(True, which='both', ls=':')
            ax.set_ylim(bottom=0, top=u_plus_max_plot)
            ax.set_xlim(left=0.1, right=y_plus_max_plot) # Start x-axis slightly > 0 for log scale


            # 2. k+ vs y+ plot
            ax = axes[1]
            ax.semilogx(y_plus_pinn, k_plus_pinn, 'r.', ms=4, label='PINN')
            if ref_processed and y_plus_ref is not None and k_plus_ref is not None:
                ax.semilogx(y_plus_ref, k_plus_ref, 'bo', mfc='none', ms=5, label='Reference')
            # Add line for target y+ location used in wall function (if defined)
            if hasattr(self.config, 'YP_PLUS_TARGET'):
                 ax.axvline(self.config.YP_PLUS_TARGET, color='g', ls='-.', lw=1, label=f'WF $y_p^+ \\approx {self.config.YP_PLUS_TARGET:.1f}$')
            ax.set_xlabel('$y^+$')
            ax.set_ylabel('$k^+$')
            ax.set_title('$k^+$ vs $y^+$ Profile')
            ax.legend(fontsize=9)
            ax.grid(True, which='both', ls=':')
            ax.set_ylim(bottom=0, top=k_plus_max_plot) # k+ starts from 0
            ax.set_xlim(left=0.1, right=y_plus_max_plot)


            # 3. eps+ vs y+ plot (log-log scale often used)
            ax = axes[2]
            # Filter out non-positive values before plotting on log-log scale
            valid_idx_pinn = eps_plus_pinn > self.config.EPS_SMALL**2 # Use small threshold
            if np.any(valid_idx_pinn):
                 ax.loglog(y_plus_pinn[valid_idx_pinn], eps_plus_pinn[valid_idx_pinn], 'r.', ms=4, label='PINN')
            else: ax.plot([],[], 'r.', label='PINN (No positive data)') # Placeholder

            if ref_processed and y_plus_ref is not None and eps_plus_ref is not None:
                 valid_idx_ref = eps_plus_ref > self.config.EPS_SMALL**2
                 if np.any(valid_idx_ref):
                     ax.loglog(y_plus_ref[valid_idx_ref], eps_plus_ref[valid_idx_ref], 'bo', mfc='none', ms=5, label='Reference')
                 else: ax.plot([],[], 'bo', mfc='none', label='Reference (No positive data)')

            # Theoretical trend near wall (local equilibrium): eps+ ~ 1 / (kappa * y+) -> C / y+
            # Valid mainly away from y+=0
            y_plus_theory_eps = np.logspace(np.log10(max(1, 0.1)), np.log10(y_plus_max_plot*1.1), 100)
            eps_plus_target_theory = 1.0 / (kappa * y_plus_theory_eps) # Simplified equilibrium scaling
            ax.loglog(y_plus_theory_eps, eps_plus_target_theory, 'k:', lw=1.5, label='$\\epsilon^+ \\propto 1/y^+$')

            if hasattr(self.config, 'YP_PLUS_TARGET'):
                ax.axvline(self.config.YP_PLUS_TARGET, color='g', ls='-.', lw=1, label=f'WF $y_p^+ \\approx {self.config.YP_PLUS_TARGET:.1f}$')

            ax.set_xlabel('$y^+$')
            ax.set_ylabel('$\\epsilon^+$')
            ax.set_title('$\\epsilon^+$ vs $y^+$ Profile (log-log)')
            ax.legend(fontsize=9)
            ax.grid(True, which='both', ls=':')

            # Set reasonable y-limits for eps+ based on observed data
            min_eps_plus_data = np.min(eps_plus_pinn[valid_idx_pinn]) if np.any(valid_idx_pinn) else 1e-5
            max_eps_plus_data = np.max(eps_plus_pinn[valid_idx_pinn]) if np.any(valid_idx_pinn) else 1
            if ref_processed and eps_plus_ref is not None and np.any(valid_idx_ref):
                  min_eps_plus_data = min(min_eps_plus_data, np.min(eps_plus_ref[valid_idx_ref]))
                  max_eps_plus_data = max(max_eps_plus_data, np.max(eps_plus_ref[valid_idx_ref]))

            ax.set_ylim(bottom=max(min_eps_plus_data * 0.1, 1e-6), top=max_eps_plus_data * 10)
            ax.set_xlim(left=0.1, right=y_plus_max_plot)


            plt.suptitle(f'Wall Unit Profiles (Top Wall, x ≈ {actual_x_pinn:.3f}m)', fontsize=16)
            plt.tight_layout(rect=[0, 0.03, 1, 0.93]) # Adjust layout for suptitle

            save_path = os.path.join(self.plots_dir, "profile_comparison_wall_units.png")
            plt.savefig(save_path, dpi=200, bbox_inches='tight')
            plt.close(fig)
            logging.info(f"Wall unit comparison plots saved to {os.path.basename(save_path)}")

        except Exception as e:
            logging.error(f"Failed to generate or save wall unit comparison plot: {e}", exc_info=True)
            if 'fig' in locals() and plt.fignum_exists(fig.number):
                 plt.close(fig)

    def plot_pressure_gradient_comparison(self):
        """Plots the streamwise pressure gradient dp/dx along the centerline."""
        if self.p_pred is None: # Check if PINN kinematic pressure is available
            logging.warning("PINN pressure data unavailable. Skipping pressure gradient plot.")
            return

        logging.info("Generating centerline pressure gradient comparison plot...")
        try:
            x_coords_pinn = self.X_grid[0, :] # Streamwise coordinates from PINN grid
            y_coords_pinn = self.Y_grid[:, 0] # Transverse coordinates
            # Find index closest to centerline y=0
            center_idx_pinn = np.argmin(np.abs(y_coords_pinn - 0.0))
            actual_y_center = y_coords_pinn[center_idx_pinn]

            # Extract PINN kinematic pressure along centerline
            p_centerline_pinn = self.p_pred[center_idx_pinn, :]
            # Calculate gradient using numpy.gradient (central difference)
            dp_dx_pinn = np.gradient(p_centerline_pinn, x_coords_pinn)

            # --- Calculate Reference Pressure Gradient (if data available) ---
            dp_dx_ref = None
            x_coords_ref = None
            ref_grad_calculated = False
            if self.has_ref_data and 'p_ref' in self.ref_data.columns:
                try:
                    # Filter reference data near centerline (allow some tolerance)
                    centerline_tol = 0.05 * self.config.CHANNEL_HALF_HEIGHT
                    ref_centerline = self.ref_data[
                        np.abs(self.ref_data['y']) <= centerline_tol
                    ].copy()

                    if not ref_centerline.empty:
                        # If multiple y-values close to center, average pressure at each unique x
                        # Group by x and calculate the mean pressure
                        centerline_grouped = ref_centerline.groupby('x')['p_ref'].mean()
                        # Sort by x just in case grouping changed order
                        centerline_grouped = centerline_grouped.sort_index()

                        if len(centerline_grouped) > 5: # Need sufficient points for reliable gradient
                            x_coords_ref = centerline_grouped.index.values
                            p_centerline_ref = centerline_grouped.values
                            if len(x_coords_ref) > 1:
                                dp_dx_ref = np.gradient(p_centerline_ref, x_coords_ref)
                                logging.info(f"Calculated reference pressure gradient from {len(x_coords_ref)} centerline points.")
                                ref_grad_calculated = True
                            else: logging.warning("Not enough unique x-points in reference centerline data for gradient.")
                        else: logging.warning(f"Not enough grouped x-points ({len(centerline_grouped)}) near centerline in reference data for gradient.")
                    else: logging.warning("No points found near centerline in reference data.")
                except KeyError as ke:
                     logging.warning(f"Could not calculate reference pressure gradient due to missing column: {ke}")
                except Exception as e:
                    logging.warning(f"Could not calculate reference pressure gradient: {e}")
            elif self.has_ref_data:
                logging.warning("Reference data loaded, but 'p_ref' column missing. Cannot plot reference pressure gradient.")

            # --- Create Plot ---
            fig, ax = plt.subplots(figsize=(10, 6))
            ax.plot(x_coords_pinn, dp_dx_pinn, 'r-', lw=2, label='PINN $dp/dx$')
            if ref_grad_calculated and dp_dx_ref is not None and x_coords_ref is not None:
                ax.plot(x_coords_ref, dp_dx_ref, 'b--', lw=1.5, label='Reference $dp/dx$ (CSV)')
            elif self.has_ref_data: # Add placeholder if ref data exists but grad failed
                ax.plot([], [], 'b--', label='Reference $dp/dx$ (Failed)')


            ax.set_xlabel('x / L') # Normalize x-axis by channel length
            ax.set_ylabel(r'$dp/dx$ $(m/s^2)$') # Assuming kinematic pressure p
            ax.set_title(f'Streamwise Kinematic Pressure Gradient along Centerline (y ≈ {actual_y_center:.3f}m)')
            ax.legend()
            ax.grid(True, ls=':')
            # Normalize x coordinates for axis limits
            ax.set_xlim(0, 1) # Range from 0 to L

            # Optional: Set y-limits based on expected range (often negative and near constant in developed region)
            try:
               # Focus on the developed region (e.g., latter half, excluding outlet proximity)
               focus_start_idx = len(dp_dx_pinn) // 2
               focus_end_idx = int(len(dp_dx_pinn) * 0.95) # Exclude last 5%
               if focus_start_idx < focus_end_idx : # Ensure valid slice
                    focus_region_pinn = dp_dx_pinn[focus_start_idx:focus_end_idx]
                    if len(focus_region_pinn) > 0:
                         mean_dpdx = np.mean(focus_region_pinn)
                         std_dpdx = np.std(focus_region_pinn)
                         # Set ylim to mean +/- a few std deviations, or use padding
                         pad = 5 * max(std_dpdx, abs(mean_dpdx)*0.1, 1e-4) # Ensure pad is reasonable
                         ax.set_ylim(mean_dpdx - pad, mean_dpdx + pad)
                         logging.debug(f"Adjusted pressure gradient plot ylim based on developed region: [{mean_dpdx-pad:.2e}, {mean_dpdx+pad:.2e}]")
            except Exception as ylim_e:
                 logging.warning(f"Could not automatically set y-limits for pressure gradient plot: {ylim_e}")

            # Normalize x-axis ticks
            ax.set_xticks(np.linspace(0, self.config.L, 6))
            ax.set_xticklabels([f"{x/self.config.L:.1f}" for x in np.linspace(0, self.config.L, 6)])
            ax.set_xlabel('x / L')


            plt.tight_layout()
            save_path = os.path.join(self.plots_dir, "pressure_gradient_comparison.png")
            plt.savefig(save_path, dpi=200)
            plt.close(fig)
            logging.info(f"Pressure gradient comparison plot saved to {os.path.basename(save_path)}")

        except Exception as e:
            logging.error(f"Failed to generate or save pressure gradient plot: {e}", exc_info=True)
            if 'fig' in locals() and plt.fignum_exists(fig.number):
                 plt.close(fig)

    def run_post_processing(self):
        """Runs the full post-processing sequence."""
        if self.model is None:
             logging.error("Model object is None. Cannot run post-processing.")
             return

        logging.info("--- Starting Full Post-Processing ---")
        # Plot loss history regardless of prediction success
        self.plot_loss_history()

        # Prediction is required for all field/profile plots
        prediction_successful = self.predict_pinn_fields()

        if prediction_successful:
            # Plot contours based on prediction
            self.plot_contour_fields()

            # Load reference data only if prediction was successful (needed for comparisons)
            self.load_reference_data() # Sets self.has_ref_data flag

            # Proceed with comparisons if both prediction and ref data are available
            if self.has_ref_data:
                logging.info("Proceeding with PINN vs Reference CSV comparisons...")
                self.plot_profile_comparison()
                self.plot_wall_unit_comparison() # Depends on _estimate_utau which needs ref data
                self.plot_pressure_gradient_comparison() # Depends on ref data pressure
            else:
                logging.warning("Skipping comparison plots as reference data is unavailable or failed to load.")
        else:
            logging.error("PINN field prediction failed. Aborting further post-processing that depends on predictions.")

        logging.info("--- Post-Processing Finished ---")
# --- End Plotter Class ---


# =============================
# ===== Main Execution Block =====
# =============================
if __name__ == "__main__":
    main_start_time = time.time()

    # --- 1. Initial Setup ---
    main_cfg = Config() # Instantiate default config
    main_plot_cfg = PlotterConfig()

    # --- Google Drive Mount (if applicable) ---
    # This will update main_cfg paths if running in Colab and mount is successful
    # Run this *before* setting up logging/dirs based on potentially updated paths
    mount_drive(main_cfg.DRIVE_MOUNT_POINT)

    # --- Setup Output Dirs and Logging ---
    # These must run AFTER mount_drive potentially updates main_cfg.OUTPUT_DIR etc.
    setup_output_directories(main_cfg)
    setup_logging(main_cfg.LOG_FILE) # Configure logging (will log to updated path)

    logging.info("="*60); logging.info(" PINN RANS k-epsilon Channel Flow Simulation Start "); logging.info("="*60)
    log_configuration(main_cfg, main_plot_cfg) # Log the final configuration used

    # --- 2. Define Boundaries ---
    try:
        # Pass the config object to the function
        bcs, anchor_points = get_boundary_conditions(main_cfg)
        logging.info(f"Defined {len(bcs)} boundary conditions.")
        if anchor_points is not None and len(anchor_points) > 0:
            # Logging moved to build_model where 'anchors' is used
            pass
            # logging.info(f"Using {anchor_points.shape[0]} anchor points for wall functions.")
        else:
             # Should have anchor points if using wall functions as defined
             logging.warning("No anchor points generated for wall functions, check get_boundary_conditions.")
             # Ensure anchor_points is None or empty list if not generated, for build_model compatibility
             anchor_points = None
    except Exception as e:
        logging.error(f"Failed to define boundary conditions: {e}", exc_info=True)
        sys.exit(1) # Exit if BCs cannot be defined

    # --- 3. Training Phase ---
    model_trained, history_trained, state_trained = None, None, None
    training_successful = False
    try:
        # Pass the config object to the Trainer
        trainer = Trainer(main_cfg)
        # Pass anchor points to build_model
        trainer.build_model(bcs, anchor_points)
        # Execute the updated training method
        model_trained, history_trained, state_trained = trainer.train()

        # Check if training outputs seem valid
        if model_trained is not None and history_trained is not None and state_trained is not None:
             # Add more checks? e.g., check if final loss is reasonable (not NaN/inf)
             final_loss = history_trained.loss_train[-1] if history_trained.loss_train else float('inf')
             if np.isfinite(np.sum(final_loss)): # Check if final loss is finite
                 training_successful = True
                 logging.info("Training phase returned valid model, history, and state objects.")
             else:
                  logging.error(f"Training phase finished but final loss is invalid: {final_loss}. Considering training failed.")
                  training_successful = False # Mark as failed if loss is bad
        else:
             logging.error("Training phase finished but returned an invalid state (model, losshistory, or train_state is None).")
             training_successful = False

    except Exception as e:
         logging.error(f"A critical error occurred during model building or the training phase: {e}", exc_info=True)
         # Ensure training_successful remains False
         training_successful = False

    # --- 4. Post-processing and Plotting Phase ---
    if training_successful:
        logging.info("Proceeding to post-processing.")
        try:
            # Pass the final trained state to the plotter
            plotter = Plotter(main_cfg, main_plot_cfg, model_trained, history_trained, state_trained)
            plotter.run_post_processing()
        except Exception as e:
             logging.error(f"An error occurred during post-processing: {e}", exc_info=True)
    else:
        # This message will be logged if training failed, returned None, or hit a critical error
        logging.error("Training did not complete successfully or produced an invalid state. Skipping post-processing.")

    main_end_time = time.time()
    logging.info("="*60); logging.info(f" Script Execution Finished in {main_end_time - main_start_time:.2f} seconds"); logging.info("="*60)


DeepXDE Backend requested: pytorch
DeepXDE Backend actual: pytorch
CUDA available.
PyTorch CUDA device detected by DDE: 0 (Tesla T4)
PyTorch version: 2.6.0+cu124
Number of GPUs: 1
PyTorch default dtype: torch.float32
2025-04-14 20:39:08 [INFO] Google Drive already mounted.
2025-04-14 20:39:08 [INFO] Output paths point to Google Drive: /content/drive/MyDrive/content/drive/MyDrive/PINN_RANS_ChannelFlow
2025-04-14 20:39:08 [INFO] Setting up output directories...
2025-04-14 20:39:08 [INFO] Output directories verified/created.
2025-04-14 20:39:08 [INFO] Logging configured.
2025-04-14 20:39:08 [INFO]  PINN RANS k-epsilon Channel Flow Simulation Start 
2025-04-14 20:39:08 [INFO] Simulation Configuration:
2025-04-14 20:39:08 [INFO]   Output Directory: /content/drive/MyDrive/PINN_RANS_ChannelFlow
2025-04-14 20:39:08 [INFO]   Re_H: 10000
2025-04-14 20:39:08 [INFO]   Wall Function y_p: 0.04 (Target y+: 14.00)
2025-04-14 20:39:08 [INFO]   Network: 8 layers, 64 neurons
2025-04-14 20:39:08 [INFO]   

In [12]:
# import os
# # Set environment variable *before* importing deepxde or torch
# os.environ["DDE_BACKEND"] = "pytorch"
# try:
#   import torch
# except ImportError:
#   print("Installing torch...")
#   !pip install torch -q
# try:
#   import deepxde
# except ImportError:
#   print("Installing deepxde...")
#   !pip install deepxde -q
# try:
#   import pandas
# except ImportError:
#   print("Installing pandas...")
#   !pip install pandas -q
# try:
#   import matplotlib
# except ImportError:
#   print("Installing matplotlib...")
#   !pip install matplotlib -q

# import sys
# import time
# import logging
# import numpy as np
# import torch # Now import torch after potentially installing
# import deepxde as dde # Import deepxde after setting backend
# import matplotlib.pyplot as plt
# import pandas as pd
# from scipy.interpolate import griddata
# import re # <<<--- IMPORT REGEX MODULE


# # --- Attempt to explicitly set backend (optional but good practice) ---
# try:
#     # This might still fail on older DeepXDE versions, but the env var is primary
#     dde.config.set_default_backend("pytorch")
#     print("Attempted to explicitly set DeepXDE backend to PyTorch.")
# except AttributeError:
#     print(f"Warning: Could not explicitly set backend via dde.config (likely older DeepXDE version). Relied on environment variable DDE_BACKEND={os.environ.get('DDE_BACKEND')}.")
# except Exception as e:
#     print(f"Warning: Could not explicitly set backend via dde.config: {e}")

# print(f"DeepXDE Backend requested: {os.environ.get('DDE_BACKEND', 'Not Set')}")

# # --- Check actual backend and setup device/dtype ---
# if "deepxde" in sys.modules and hasattr(dde, 'backend'):
#     print(f"DeepXDE Backend actual: {dde.backend.backend_name}")
#     if dde.backend.backend_name == "pytorch":
#         if torch.cuda.is_available():
#             print("CUDA available.")
#             try:
#                 # Use float32 as it's common for PINNs and avoids potential double precision issues
#                 torch.set_default_dtype(torch.float32)
#                 current_device = torch.cuda.current_device()
#                 print(f"PyTorch CUDA device detected by DDE: {current_device} ({torch.cuda.get_device_name(current_device)})")
#                 print(f"PyTorch version: {torch.__version__}")
#                 print(f"Number of GPUs: {torch.cuda.device_count()}")
#             except Exception as e:
#                 print(f"Warning: Error during PyTorch device setup: {e}")
#         else:
#             print("CUDA not available. Using CPU.")
#             try:
#                 torch.set_default_dtype(torch.float32)
#                 print(f"PyTorch default device set to: cpu")
#             except Exception as e:
#                 print(f"Warning: Failed to set default PyTorch device to CPU: {e}")
#         print(f"PyTorch default dtype: {torch.get_default_dtype()}")
#     else:
#         print(f"Warning: Backend is '{dde.backend.backend_name}', not PyTorch. PyTorch-specific device setup skipped.")
# else:
#     print("Warning: deepxde module or dde.backend not fully available for backend check.")


# # =============================
# # ===== Configuration Classes =====
# # =============================

# class PlotterConfig:
#     """Stores configuration parameters specifically for plotting."""
#     NX_PRED = 200
#     NY_PRED = 100
#     CMAP_VELOCITY = 'viridis'
#     CMAP_PRESSURE = 'coolwarm'
#     CMAP_TURBULENCE = 'plasma'


# class Config:
#     """Stores configuration parameters for the simulation."""
#     DRIVE_MOUNT_POINT = '/content/drive'
#       # Adjust path if necessary
#     # === PATH CORRECTION ===
#     GDRIVE_BASE_FOLDER = 'PINN_RANS_ChannelFlow' # Path relative to MyDrive
#     # =======================
#     OUTPUT_DIR = os.path.join(DRIVE_MOUNT_POINT, 'MyDrive', GDRIVE_BASE_FOLDER) # Default if not overwritten
#     MODEL_DIR = os.path.join(OUTPUT_DIR, "model_checkpoints")
#     LOG_DIR = os.path.join(OUTPUT_DIR, "logs")
#     PLOT_DIR = os.path.join(OUTPUT_DIR, "plots")
#     DATA_DIR = os.path.join(OUTPUT_DIR, "data")
#     LOG_FILE = os.path.join(LOG_DIR, "plotting_log.log") # Changed log file name
#     REFERENCE_DATA_FILE = os.path.join(DATA_DIR, "reference_output_data.csv")

#     # Checkpoint filename base (without step number or extension)
#     CHECKPOINT_FILENAME_BASE  = "rans_channel_wf"

#     # --- Fluid and Geometry Parameters ---
#     NU = 0.0002 # Kinematic viscosity
#     RHO = 1.0 # Density (often set to 1 for incompressible flow)
#     MU = RHO * NU # Dynamic viscosity
#     U_INLET = 1.0 # Inlet velocity - STILL NEEDED FOR BC DEFINITION
#     H = 2.0 # Full channel height
#     CHANNEL_HALF_HEIGHT = H / 2.0
#     L = 10.0 # Channel length
#     RE_H = U_INLET * H / NU # Reynolds number based on full height
#     EPS_SMALL = 1e-10 # Small epsilon for numerical stability (avoid log(0), division by zero)

#     # --- k-epsilon Model Constants ---
#     CMU = 0.09
#     CEPS1 = 1.44
#     CEPS2 = 1.92
#     SIGMA_K = 1.0
#     SIGMA_EPS = 1.3
#     KAPPA = 0.41 # Von Karman constant

#     # --- Wall Function Parameters (Needed for BCs) ---
#     E_WALL = 9.8 # Log-law constant for smooth walls
#     Y_P = 0.04 # Distance from wall for applying wall functions (y_p)
#     RE_TAU_TARGET = 350 # Target friction Reynolds number
#     U_TAU_TARGET = RE_TAU_TARGET * NU / CHANNEL_HALF_HEIGHT # Target friction velocity
#     YP_PLUS_TARGET = Y_P * U_TAU_TARGET / NU # Target y+ at y_p
#     U_TARGET_WF = (U_TAU_TARGET / KAPPA) * np.log(max(E_WALL * YP_PLUS_TARGET, EPS_SMALL))
#     K_TARGET_WF = U_TAU_TARGET**2 / np.sqrt(CMU)
#     EPS_TARGET_WF = U_TAU_TARGET**3 / max(KAPPA * Y_P, EPS_SMALL)

#     # --- Inlet Turbulence Parameters (Needed for BCs) ---
#     TURBULENCE_INTENSITY = 0.05 # Typical value for channel flow inlet
#     MIXING_LENGTH_SCALE = 0.07 * CHANNEL_HALF_HEIGHT # Estimate based on boundary layer thickness
#     K_INLET = 1.5 * (U_INLET * TURBULENCE_INTENSITY)**2
#     EPS_INLET = (CMU**0.75) * (K_INLET**1.5) / MIXING_LENGTH_SCALE
#     K_INLET_TRANSFORMED = np.log(max(K_INLET, EPS_SMALL))
#     EPS_INLET_TRANSFORMED = np.log(max(EPS_INLET, EPS_SMALL))
#     K_TARGET_WF_TRANSFORMED = np.log(max(K_TARGET_WF, EPS_SMALL))
#     EPS_TARGET_WF_TRANSFORMED = np.log(max(EPS_TARGET_WF, EPS_SMALL))

#     # --- Domain Geometry ---
#     GEOM = dde.geometry.Rectangle(xmin=[0, -CHANNEL_HALF_HEIGHT], xmax=[L, CHANNEL_HALF_HEIGHT])

#     # --- Network Architecture (MUST MATCH SAVED MODEL) ---
#     NUM_LAYERS = 8
#     NUM_NEURONS = 64
#     ACTIVATION = "tanh"
#     INITIALIZER = "Glorot normal" # Initializer doesn't matter for loading weights
#     NETWORK_INPUTS = 2 # x, y
#     NETWORK_OUTPUTS = 5 # u, v, p', log(k), log(eps)

#     # --- Training Parameters (Not used for training, but needed for Data object) ---
#     NUM_DOMAIN_POINTS = 1 # Minimal points needed just to build Data object
#     NUM_BOUNDARY_POINTS = 1 # Minimal points
#     NUM_TEST_POINTS = 1 # Minimal points
#     NUM_WF_POINTS_PER_WALL = 200 # Needed for get_boundary_conditions
#     # Other training params (LR, iterations, weights) are irrelevant now
#     LOSS_WEIGHTS = None # Not needed for prediction


# # Instantiate config objects
# cfg = Config()
# plotter_cfg = PlotterConfig()


# # ==============================================
# # ===== Custom Checkpoint Callback Class ========
# # ==============================================
# # --- Callback is NOT needed for loading/plotting ---


# # ==========================
# # ===== Utility Functions =====
# # ==========================
# def setup_logging(log_file):
#     """Configures logging to file and console."""
#     log_dir = os.path.dirname(log_file)
#     ensure_dir(log_dir)
#     root_logger = logging.getLogger()
#     # Clear existing handlers to avoid duplicate messages if run multiple times in notebook
#     if root_logger.hasHandlers():
#         root_logger.handlers.clear()
#     log_formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s", datefmt='%Y-%m-%d %H:%M:%S')
#     root_logger.setLevel(logging.INFO) # Set root logger level

#     # File handler
#     try:
#         file_handler = logging.FileHandler(log_file, mode='a') # Append mode
#         file_handler.setFormatter(log_formatter)
#         root_logger.addHandler(file_handler)
#     except Exception as e:
#         print(f"Error setting up file logger at {log_file}: {e}")


#     # Console handler
#     console_handler = logging.StreamHandler(sys.stdout)
#     console_handler.setFormatter(log_formatter)
#     root_logger.addHandler(console_handler)
#     logging.info("Logging configured.")

# def ensure_dir(directory):
#     """Creates a directory if it doesn't exist."""
#     if not os.path.exists(directory):
#         try:
#             os.makedirs(directory)
#             logging.info(f"Created directory: {directory}")
#         except OSError as e:
#             logging.error(f"Failed to create directory {directory}: {e}")


# def mount_drive(mount_point):
#     """Mounts Google Drive if running in Colab."""
#     # === Uses Corrected Config Path ===
#     if 'google.colab' in sys.modules:
#         if not os.path.exists(os.path.join(mount_point, 'MyDrive')):
#             try:
#                 from google.colab import drive
#                 logging.info(f"Mounting Google Drive at {mount_point}...")
#                 drive.mount(mount_point, force_remount=True)
#                 logging.info("Google Drive mounted successfully.")
#                 # Construct path correctly using relative GDRIVE_BASE_FOLDER
#                 gdrive_output_path = os.path.join(mount_point, 'MyDrive', cfg.GDRIVE_BASE_FOLDER)
#                 cfg.OUTPUT_DIR = gdrive_output_path
#                 cfg.MODEL_DIR = os.path.join(cfg.OUTPUT_DIR, "model_checkpoints")
#                 cfg.LOG_DIR = os.path.join(cfg.OUTPUT_DIR, "logs")
#                 cfg.PLOT_DIR = os.path.join(cfg.OUTPUT_DIR, "plots")
#                 cfg.DATA_DIR = os.path.join(cfg.OUTPUT_DIR, "data")
#                 cfg.LOG_FILE = os.path.join(cfg.LOG_DIR, "plotting_log.log") # Use new log file name
#                 cfg.REFERENCE_DATA_FILE = os.path.join(cfg.DATA_DIR, "reference_output_data.csv")
#                 logging.info(f"Output paths updated to Google Drive: {cfg.OUTPUT_DIR}")
#                 ensure_dir(cfg.OUTPUT_DIR)
#                 if not os.path.exists(cfg.OUTPUT_DIR):
#                     logging.warning(f"Configured base folder NOT found after mount attempt: {cfg.OUTPUT_DIR}")
#             except Exception as e:
#                 logging.error(f"Error mounting Google Drive or accessing path: {e}")
#                 logging.warning("Falling back to local directory structure.")
#                 # Use relative path locally too if mount fails
#                 cfg.OUTPUT_DIR = cfg.GDRIVE_BASE_FOLDER
#                 cfg.MODEL_DIR = os.path.join(cfg.OUTPUT_DIR, "model_checkpoints")
#                 cfg.LOG_DIR = os.path.join(cfg.OUTPUT_DIR, "logs")
#                 cfg.PLOT_DIR = os.path.join(cfg.OUTPUT_DIR, "plots")
#                 cfg.DATA_DIR = os.path.join(cfg.OUTPUT_DIR, "data")
#                 cfg.LOG_FILE = os.path.join(cfg.LOG_DIR, "plotting_log.log")
#                 cfg.REFERENCE_DATA_FILE = os.path.join(cfg.DATA_DIR, "reference_output_data.csv")
#         else:
#             logging.info("Google Drive already mounted.")
#             # Update paths correctly if already mounted
#             gdrive_output_path = os.path.join(mount_point, 'MyDrive', cfg.GDRIVE_BASE_FOLDER)
#             cfg.OUTPUT_DIR = gdrive_output_path
#             cfg.MODEL_DIR = os.path.join(cfg.OUTPUT_DIR, "model_checkpoints")
#             cfg.LOG_DIR = os.path.join(cfg.OUTPUT_DIR, "logs")
#             cfg.PLOT_DIR = os.path.join(cfg.OUTPUT_DIR, "plots")
#             cfg.DATA_DIR = os.path.join(cfg.OUTPUT_DIR, "data")
#             cfg.LOG_FILE = os.path.join(cfg.LOG_DIR, "plotting_log.log")
#             cfg.REFERENCE_DATA_FILE = os.path.join(cfg.DATA_DIR, "reference_output_data.csv")
#             logging.info(f"Output paths point to Google Drive: {cfg.OUTPUT_DIR}")
#     else:
#         logging.info("Not running in Google Colab. Using local directory structure.")
#         # Use relative path locally
#         cfg.OUTPUT_DIR = cfg.GDRIVE_BASE_FOLDER
#         cfg.MODEL_DIR = os.path.join(cfg.OUTPUT_DIR, "model_checkpoints")
#         cfg.LOG_DIR = os.path.join(cfg.OUTPUT_DIR, "logs")
#         cfg.PLOT_DIR = os.path.join(cfg.OUTPUT_DIR, "plots")
#         cfg.DATA_DIR = os.path.join(cfg.OUTPUT_DIR, "data")
#         cfg.LOG_FILE = os.path.join(cfg.LOG_DIR, "plotting_log.log")
#         cfg.REFERENCE_DATA_FILE = os.path.join(cfg.DATA_DIR, "reference_output_data.csv")

# def setup_output_directories(config):
#     """Creates all necessary output directories."""
#     logging.info("Setting up output directories...")
#     ensure_dir(config.OUTPUT_DIR)
#     ensure_dir(config.MODEL_DIR)
#     ensure_dir(config.LOG_DIR)
#     ensure_dir(config.PLOT_DIR)
#     ensure_dir(config.DATA_DIR)
#     logging.info("Output directories verified/created.")

# def log_configuration(config, plotter_config):
#     """Logs the simulation and plotter configuration (relevant parts)."""
#     logging.info("=" * 50)
#     logging.info("Plotting Configuration:")
#     logging.info(f"  Output Directory: {config.OUTPUT_DIR}")
#     logging.info(f"  Model Directory: {config.MODEL_DIR}")
#     logging.info(f"  Plot Directory: {config.PLOT_DIR}")
#     logging.info(f"  Ref Data File: {config.REFERENCE_DATA_FILE}")
#     logging.info(f"  Network Expected: {config.NUM_LAYERS} layers, {config.NUM_NEURONS} neurons")
#     logging.info(f"  Prediction Grid Nx: {plotter_config.NX_PRED}, Ny: {plotter_config.NY_PRED}")
#     logging.info("=" * 50)
# # --- End Utility Functions ---


# # ===============================
# # ===== PDE System Definition =====
# # ===============================
# # --- PDE function is needed for Data object creation, keep as is ---
# def pde(x, y, config):
#     if dde.backend.backend_name != "pytorch":
#         raise RuntimeError("PDE function relies on PyTorch autograd. Backend mismatch.")
#     nu = config.NU; Cmu = config.CMU; Ceps1 = config.CEPS1; Ceps2 = config.CEPS2
#     sigma_k = config.SIGMA_K; sigma_eps = config.SIGMA_EPS; eps_small = config.EPS_SMALL
#     u, v, p_prime, k_raw, eps_raw = y[:, 0:1], y[:, 1:2], y[:, 2:3], y[:, 3:4], y[:, 4:5]
#     k = torch.exp(k_raw) + eps_small
#     eps = torch.exp(eps_raw) + eps_small
#     try:
#         u_x = dde.grad.jacobian(y, x, i=0, j=0); u_y = dde.grad.jacobian(y, x, i=0, j=1)
#         v_x = dde.grad.jacobian(y, x, i=1, j=0); v_y = dde.grad.jacobian(y, x, i=1, j=1)
#         p_prime_x = dde.grad.jacobian(y, x, i=2, j=0); p_prime_y = dde.grad.jacobian(y, x, i=2, j=1)
#         u_xx = dde.grad.hessian(y, x, component=0, i=0, j=0); u_yy = dde.grad.hessian(y, x, component=0, i=1, j=1)
#         v_xx = dde.grad.hessian(y, x, component=1, i=0, j=0); v_yy = dde.grad.hessian(y, x, component=1, i=1, j=1)
#         u_xy = dde.grad.hessian(y, x, component=0, i=0, j=1)
#         v_xy = dde.grad.hessian(y, x, component=1, i=0, j=1)
#         if isinstance(x, torch.Tensor) and not x.requires_grad: x.requires_grad_(True)
#         grad_k = torch.autograd.grad(k, x, grad_outputs=torch.ones_like(k), create_graph=True)[0]
#         k_x, k_y = grad_k[:, 0:1], grad_k[:, 1:2]
#         grad_eps = torch.autograd.grad(eps, x, grad_outputs=torch.ones_like(eps), create_graph=True)[0]
#         eps_x, eps_y = grad_eps[:, 0:1], grad_eps[:, 1:2]
#         grad_kx = torch.autograd.grad(k_x, x, grad_outputs=torch.ones_like(k_x), create_graph=True)[0]
#         k_xx = grad_kx[:, 0:1]
#         grad_ky = torch.autograd.grad(k_y, x, grad_outputs=torch.ones_like(k_y), create_graph=True)[0]
#         k_yy = grad_ky[:, 1:2]
#         grad_epsx = torch.autograd.grad(eps_x, x, grad_outputs=torch.ones_like(eps_x), create_graph=True)[0]
#         eps_xx = grad_epsx[:, 0:1]
#         grad_epsy = torch.autograd.grad(eps_y, x, grad_outputs=torch.ones_like(eps_y), create_graph=True)[0]
#         eps_yy = grad_epsy[:, 1:2]
#     except RuntimeError as grad_e:
#         logging.error(f"PyTorch Autograd RuntimeError calculating gradients in PDE: {grad_e}. Ensure create_graph=True is used correctly for higher-order derivatives.", exc_info=True)
#         zero_tensor = torch.zeros_like(y[:, 0:1])
#         return [zero_tensor] * 5
#     except Exception as grad_e:
#         logging.error(f"General error calculating gradients in PDE function: {grad_e}", exc_info=True)
#         zero_tensor = torch.zeros_like(y[:, 0:1])
#         return [zero_tensor] * 5
#     k_safe = k; eps_safe = eps
#     nu_t = Cmu * torch.square(k_safe) / (eps_safe + eps_small)
#     nu_eff = nu + nu_t
#     dnut_dk = 2.0 * Cmu * k_safe / (eps_safe + eps_small)
#     dnut_deps = -Cmu * torch.square(k_safe) / torch.square(eps_safe + eps_small)
#     nu_eff_x = dnut_dk * k_x + dnut_deps * eps_x
#     nu_eff_y = dnut_dk * k_y + dnut_deps * eps_y
#     eq_continuity = u_x + v_y
#     adv_u = u * u_x + v * u_y
#     diff_u_term1 = nu_eff_x * (2 * u_x) + nu_eff * (2 * u_xx)
#     diff_u_term2 = nu_eff_y * (u_y + v_x) + nu_eff * (u_yy + v_xy)
#     eq_mom_x = adv_u + p_prime_x - (diff_u_term1 + diff_u_term2)
#     adv_v = u * v_x + v * v_y
#     diff_v_term1 = nu_eff_x * (v_x + u_y) + nu_eff * (v_xx + u_xy)
#     diff_v_term2 = nu_eff_y * (2 * v_y) + nu_eff * (2 * v_yy)
#     eq_mom_y = adv_v + p_prime_y - (diff_v_term1 + diff_v_term2)
#     S_squared = 2 * (torch.square(u_x) + torch.square(v_y)) + torch.square(u_y + v_x)
#     P_k = torch.relu(nu_t * S_squared)
#     adv_k = u * k_x + v * k_y
#     diffusivity_k = nu + nu_t / sigma_k
#     d_diffk_dx = (1 / sigma_k) * nu_eff_x
#     d_diffk_dy = (1 / sigma_k) * nu_eff_y
#     laplacian_k = k_xx + k_yy
#     diffusion_k = d_diffk_dx * k_x + d_diffk_dy * k_y + diffusivity_k * laplacian_k
#     eq_k = adv_k - diffusion_k - P_k + eps_safe
#     adv_eps = u * eps_x + v * eps_y
#     diffusivity_eps = nu + nu_t / sigma_eps
#     d_diffeps_dx = (1 / sigma_eps) * nu_eff_x
#     d_diffeps_dy = (1 / sigma_eps) * nu_eff_y
#     laplacian_eps = eps_xx + eps_yy
#     diffusion_eps = d_diffeps_dx * eps_x + d_diffeps_dy * eps_y + diffusivity_eps * laplacian_eps
#     source_eps = Ceps1 * (eps_safe / (k_safe + eps_small)) * P_k
#     sink_eps = Ceps2 * (torch.square(eps_safe) / (k_safe + eps_small))
#     eq_eps = adv_eps - diffusion_eps - source_eps + sink_eps
#     return [eq_continuity, eq_mom_x, eq_mom_y, eq_k, eq_eps]


# # =============================
# # ===== Boundary Conditions =====
# # =============================
# # --- BC function is needed for Data object creation, keep as is ---
# def get_boundary_conditions(config):
#     geom = config.GEOM; h = config.CHANNEL_HALF_HEIGHT; L = config.L
#     y_p = config.Y_P; n_wf_points = config.NUM_WF_POINTS_PER_WALL
#     def boundary_inlet(x, on_boundary): return on_boundary and np.isclose(x[0], 0)
#     def boundary_outlet(x, on_boundary): return on_boundary and np.isclose(x[0], L)
#     def boundary_bottom_wall_physical(x, on_boundary): return on_boundary and np.isclose(x[1], -h)
#     def boundary_top_wall_physical(x, on_boundary): return on_boundary and np.isclose(x[1], h)
#     def boundary_walls_physical(x, on_boundary): return boundary_bottom_wall_physical(x, on_boundary) or boundary_top_wall_physical(x, on_boundary)
#     bc_u_inlet = dde.DirichletBC(geom, lambda x: config.U_INLET, boundary_inlet, component=0)
#     bc_v_inlet = dde.DirichletBC(geom, lambda x: 0, boundary_inlet, component=1)
#     bc_k_inlet = dde.DirichletBC(geom, lambda x: config.K_INLET_TRANSFORMED, boundary_inlet, component=3)
#     bc_eps_inlet = dde.DirichletBC(geom, lambda x: config.EPS_INLET_TRANSFORMED, boundary_inlet, component=4)
#     bc_p_outlet = dde.DirichletBC(geom, lambda x: 0, boundary_outlet, component=2)
#     bc_u_walls = dde.DirichletBC(geom, lambda x: 0, boundary_walls_physical, component=0)
#     bc_v_walls = dde.DirichletBC(geom, lambda x: 0, boundary_walls_physical, component=1)
#     x_wf_coords = np.linspace(0 + L * 0.01, L - L * 0.01, n_wf_points)[:, None]
#     points_bottom_wf = np.hstack((x_wf_coords, np.full_like(x_wf_coords, -h + y_p)))
#     points_top_wf = np.hstack((x_wf_coords, np.full_like(x_wf_coords, h - y_p)))
#     anchor_points_wf = np.vstack((points_bottom_wf, points_top_wf))
#     logging.info(f"Generated {anchor_points_wf.shape[0]} anchor points for wall functions (needed for BC setup).")
#     U_target_vals = np.full((anchor_points_wf.shape[0], 1), config.U_TARGET_WF)
#     k_target_vals = np.full((anchor_points_wf.shape[0], 1), config.K_TARGET_WF_TRANSFORMED)
#     eps_target_vals = np.full((anchor_points_wf.shape[0], 1), config.EPS_TARGET_WF_TRANSFORMED)
#     bc_u_wf = dde.PointSetBC(anchor_points_wf, U_target_vals, component=0)
#     bc_k_wf = dde.PointSetBC(anchor_points_wf, k_target_vals, component=3)
#     bc_eps_wf = dde.PointSetBC(anchor_points_wf, eps_target_vals, component=4)
#     all_bcs = [bc_u_inlet, bc_v_inlet, bc_k_inlet, bc_eps_inlet, bc_p_outlet, bc_u_walls, bc_v_walls, bc_u_wf, bc_k_wf, bc_eps_wf]
#     return all_bcs, anchor_points_wf


# # =======================
# # ===== Trainer Class =====
# # =======================
# # --- Trainer class is only used for build_model ---
# class Trainer:
#     """Handles the setup of the PINN model structure for loading/plotting."""
#     def __init__(self, config):
#         self.config = config
#         self.model = None
#         self.pde = pde # Assign the PDE function

#     def build_model(self, bcs, anchor_points):
#         """Builds the DeepXDE model structure (network and data)."""
#         logging.info("Building the PINN model structure...")
#         if dde.backend.backend_name != "pytorch":
#              raise RuntimeError("This code relies on the PyTorch backend.")

#         # Define the neural network (MUST MATCH SAVED ARCHITECTURE)
#         net = dde.maps.FNN(
#             layer_sizes=[self.config.NETWORK_INPUTS] + [self.config.NUM_NEURONS] * self.config.NUM_LAYERS + [self.config.NETWORK_OUTPUTS],
#             activation=self.config.ACTIVATION,
#             kernel_initializer=self.config.INITIALIZER # Initializer doesn't matter for loading
#         )

#         # Wrap PDE to include config
#         pde_with_config = lambda x, y: self.pde(x, y, config=self.config)

#         # Define the PDE data object (minimal points needed)
#         try:
#             data = dde.data.PDE(
#                 geometry=self.config.GEOM,
#                 pde=pde_with_config,
#                 bcs=bcs, # BCs are needed for model structure
#                 num_domain=self.config.NUM_DOMAIN_POINTS,
#                 num_boundary=self.config.NUM_BOUNDARY_POINTS,
#                 num_test=self.config.NUM_TEST_POINTS, # Use test points for potential evaluation later
#                 anchors=anchor_points
#             )
#             logging.info(f"Using {len(anchor_points) if anchor_points is not None else 0} anchor points in Data object.")
#         except TypeError:
#             logging.warning("DeepXDE version might not support 'anchors'.")
#             data = dde.data.PDE(
#                 geometry=self.config.GEOM,
#                 pde=pde_with_config,
#                 bcs=bcs,
#                 num_domain=self.config.NUM_DOMAIN_POINTS,
#                 num_boundary=self.config.NUM_BOUNDARY_POINTS,
#                 num_test=self.config.NUM_TEST_POINTS
#             )

#         self.model = dde.Model(data, net)
#         logging.info("Model structure built successfully.")
#         return self.model # Return the built model

#     # --- train() and _post_training_checks() are removed ---

# # --- End Trainer Class ---


# # ========================
# # ===== Plotter Class =====
# # ========================
# # --- Plotter class remains the same ---
# class Plotter:
#     """Handles post-processing and plotting of simulation results."""
#     def __init__(self, config, plotter_config, model, losshistory, train_state):
#         self.config = config
#         self.plotter_config = plotter_config
#         self.model = model
#         self.losshistory = losshistory if losshistory else None
#         self.train_state = train_state if train_state else None
#         self.ref_data_path = config.REFERENCE_DATA_FILE
#         self.plots_dir = config.PLOT_DIR
#         self.ref_data = None
#         self.has_ref_data = False
#         self.ref_data_utau = None
#         self.pinn_data_utau = None
#         self.X_grid, self.Y_grid = None, None
#         self.u_pred, self.v_pred, self.p_prime_pred = None, None, None
#         self.k_pred, self.eps_pred, self.nu_t_pred = None, None, None
#         self.p_pred = None
#         os.makedirs(self.plots_dir, exist_ok=True)
#         logging.info(f"Plotter initialized. Plots will be saved in: {self.plots_dir}")
#         if self.ref_data_path:
#              if os.path.exists(self.ref_data_path):
#                  logging.info(f"Reference CSV data path found: {self.ref_data_path}")
#              else:
#                  logging.warning(f"Reference CSV file not found: {self.ref_data_path}. Comparisons will be skipped.")
#         else:
#              logging.info("No reference data path provided in config. Comparisons will be skipped.")

#     def plot_loss_history(self):
#         if self.losshistory and self.train_state:
#             logging.info("Saving loss history plot...")
#             try:
#                 os.makedirs(self.plots_dir, exist_ok=True)
#                 dde.saveplot(self.losshistory, self.train_state, issave=True, isplot=False, output_dir=self.plots_dir)
#                 default_loss_file = os.path.join(self.plots_dir, "loss.png")
#                 target_loss_file = os.path.join(self.plots_dir, "training_loss_history.png")
#                 if os.path.exists(default_loss_file):
#                     try:
#                         os.replace(default_loss_file, target_loss_file)
#                         logging.info(f"Loss history plot saved as '{os.path.basename(target_loss_file)}'.")
#                     except OSError as rename_err:
#                          logging.warning(f"Could not replace/rename loss plot, using os.rename: {rename_err}")
#                          os.rename(default_loss_file, target_loss_file)
#                          logging.info(f"Loss history plot saved as '{os.path.basename(target_loss_file)}'.")
#                 else:
#                      potential_files = [f for f in os.listdir(self.plots_dir) if f.lower().endswith('.png') and 'loss' in f.lower()]
#                      if potential_files: logging.warning(f"dde.saveplot might not have produced 'loss.png'. Found: {potential_files}.")
#                      else: logging.warning("dde.saveplot did not produce 'loss.png'.")
#             except ImportError: logging.error("Matplotlib might be needed by dde.saveplot but is not installed.")
#             except Exception as e: logging.error(f"Could not save loss history plot: {e}", exc_info=True)
#         else:
#             logging.info("Loss history or train state not available (likely loaded model). Skipping loss plot.")

#     def load_reference_data(self):
#         self.has_ref_data = False
#         if not self.ref_data_path: return
#         if not os.path.exists(self.ref_data_path):
#             logging.warning(f"Reference CSV file not found: '{self.ref_data_path}'. Skipping load.")
#             return
#         logging.info(f"Loading reference data from: {self.ref_data_path}")
#         try:
#             df_ref = pd.read_csv(self.ref_data_path)
#             logging.info(f"Loaded reference data: {df_ref.shape[0]} rows, {df_ref.shape[1]} cols.")
#             time_col = None
#             if 'Time' in df_ref.columns: time_col = 'Time'
#             elif 'TimeStep' in df_ref.columns: time_col = 'TimeStep'
#             if time_col:
#                 latest_time = df_ref[time_col].max()
#                 df_ref = df_ref[df_ref[time_col] == latest_time].copy()
#                 logging.info(f"Filtered for latest time/step ({time_col}={latest_time}): {df_ref.shape[0]} rows remaining.")
#             x_col, y_col, z_col = None, None, None
#             potential_x = ['x', 'Points:0', 'X', 'x-coordinate']
#             potential_y = ['y', 'Points:1', 'Y', 'y-coordinate']
#             potential_z = ['z', 'Points:2', 'Z', 'z-coordinate']
#             for p_x in potential_x:
#                 if p_x in df_ref.columns: x_col = p_x; break
#             if not x_col:
#                 for col in df_ref.columns:
#                     if col.lower() in ['x', 'points:0', 'x-coordinate']: x_col = col; break
#             for p_y in potential_y:
#                 if p_y in df_ref.columns: y_col = p_y; break
#             if not y_col:
#                  for col in df_ref.columns:
#                     if col.lower() in ['y', 'points:1', 'y-coordinate']: y_col = col; break
#             for p_z in potential_z:
#                  if p_z in df_ref.columns: z_col = p_z; break
#             if not z_col:
#                  for col in df_ref.columns:
#                     if col.lower() in ['z', 'points:2', 'z-coordinate']: z_col = col; break
#             if not x_col or not y_col: raise ValueError(f"Could not identify x/y coordinates in reference columns: {df_ref.columns.tolist()}")
#             logging.info(f"Identified coordinate columns: x='{x_col}', y='{y_col}'" + (f", z='{z_col}'" if z_col else ""))
#             if z_col and len(df_ref[z_col].unique()) > 1:
#                 target_z = 0.0
#                 unique_z = df_ref[z_col].unique()
#                 nearest_z_idx = np.argmin(np.abs(unique_z - target_z))
#                 nearest_z = unique_z[nearest_z_idx]
#                 df_ref = df_ref[np.isclose(df_ref[z_col], nearest_z)].copy()
#                 logging.info(f"Filtered for z-plane near {target_z} (actual: {nearest_z:.4f}): {df_ref.shape[0]} rows remaining.")
#             var_map = {
#                 'u:0':'u_ref', 'u_x':'u_ref', 'velocity:0':'u_ref', 'velocity_x':'u_ref', 'u':'u_ref', 'velocityu':'u_ref',
#                 'u:1':'v_ref', 'u_y':'v_ref', 'velocity:1':'v_ref', 'velocity_y':'v_ref', 'v':'v_ref', 'velocityv':'v_ref',
#                 'p':'p_ref', 'pressure':'p_ref', 'kinematicpressure':'p_ref', 'kinematic_pressure':'p_ref',
#                 'k':'k_ref', 'turbulentkinetienergy':'k_ref', 'turbulentkineticenergy':'k_ref', 'tke':'k_ref',
#                 'epsilon':'eps_ref', 'turbulencedissipationrate':'eps_ref', 'dissipationrate':'eps_ref', 'dissipation':'eps_ref',
#                 'nut':'nut_ref', 'turbulentviscosity':'nut_ref', 'eddyviscosity':'nut_ref', 'nutilda':'nut_ref'
#             }
#             rename_dict = {}
#             processed_cols = set()
#             for col in df_ref.columns:
#                 col_lower = col.lower().strip().replace('_','').replace('-','').replace(' ','')
#                 if col_lower in var_map and col not in processed_cols:
#                     rename_dict[col] = var_map[col_lower]
#                     processed_cols.add(col)
#             rename_dict[x_col] = 'x'; rename_dict[y_col] = 'y'
#             if z_col: rename_dict[z_col] = 'z'
#             processed_cols.update([x_col, y_col, z_col] if z_col else [x_col, y_col])
#             unmapped_cols = [col for col in df_ref.columns if col not in processed_cols]
#             if unmapped_cols: logging.debug(f"Unmapped columns in reference data: {unmapped_cols}")
#             df_ref.rename(columns=rename_dict, inplace=True)
#             logging.info(f"Renamed reference columns based on mapping. New columns: {df_ref.columns.tolist()}")
#             required_cols_for_plots = list(set(['x', 'y', 'u_ref', 'p_ref', 'k_ref', 'eps_ref']))
#             missing_cols = [col for col in required_cols_for_plots if col not in df_ref.columns]
#             if missing_cols:
#                 if any(c in missing_cols for c in ['x', 'y', 'u_ref']):
#                     raise ValueError(f"Missing essential columns after renaming in reference data: {missing_cols}.")
#                 else: logging.warning(f"Missing some optional columns for plots: {missing_cols}.")
#             cols_to_keep = ['x', 'y'] + [col for col in ['u_ref', 'v_ref', 'p_ref', 'k_ref', 'eps_ref', 'nut_ref'] if col in df_ref.columns]
#             if z_col and 'z' in df_ref.columns: cols_to_keep.append('z')
#             df_ref = df_ref[list(set(cols_to_keep))]
#             df_ref.sort_values(by=['x', 'y'], inplace=True)
#             df_ref.reset_index(drop=True, inplace=True)
#             self.ref_data = df_ref
#             self.has_ref_data = True
#             logging.info(f"Successfully loaded and preprocessed reference CSV data. Final columns: {df_ref.columns.tolist()}")
#         except FileNotFoundError: pass
#         except ValueError as ve: logging.error(f"ValueError processing reference CSV: {ve}")
#         except Exception as e:
#             logging.error(f"Unexpected error loading or processing reference CSV: {e}", exc_info=True)
#             self.ref_data = None; self.has_ref_data = False

#     def predict_pinn_fields(self):
#         if self.model is None or self.model.net is None:
#              logging.error("PINN Model or network not available for prediction.")
#              return False
#         logging.info("Predicting PINN flow fields on evaluation grid...")
#         nx = self.plotter_config.NX_PRED; ny = self.plotter_config.NY_PRED
#         x_coords = np.linspace(0, self.config.L, nx); y_coords = np.linspace(-self.config.CHANNEL_HALF_HEIGHT, self.config.CHANNEL_HALF_HEIGHT, ny)
#         self.X_grid, self.Y_grid = np.meshgrid(x_coords, y_coords)
#         pred_points = np.vstack((np.ravel(self.X_grid), np.ravel(self.Y_grid))).T
#         try:
#             if isinstance(pred_points, torch.Tensor): pred_points_np = pred_points.cpu().numpy()
#             else: pred_points_np = np.array(pred_points, dtype=np.float32)
#             if not hasattr(self.model, 'net') or self.model.net is None:
#                  logging.error("Model network attribute is missing or None. Cannot predict.")
#                  return False
#             predictions_raw = self.model.predict(pred_points_np)
#             if predictions_raw is None or not isinstance(predictions_raw, np.ndarray) or predictions_raw.shape[1] != self.config.NETWORK_OUTPUTS:
#                 logging.error(f"Prediction shape mismatch or invalid type. Expected {self.config.NETWORK_OUTPUTS} outputs, got shape {predictions_raw.shape if predictions_raw is not None else 'None'} and type {type(predictions_raw)}.")
#                 return False
#         except AttributeError as ae:
#              logging.error(f"AttributeError during PINN prediction: {ae}", exc_info=True)
#              return False
#         except Exception as e:
#             logging.error(f"Error during PINN prediction: {e}", exc_info=True)
#             return False
#         try:
#             self.u_pred = predictions_raw[:, 0].reshape(ny, nx)
#             self.v_pred = predictions_raw[:, 1].reshape(ny, nx)
#             self.p_prime_pred = predictions_raw[:, 2].reshape(ny, nx)
#             k_raw_pred = predictions_raw[:, 3].reshape(ny, nx); eps_raw_pred = predictions_raw[:, 4].reshape(ny, nx)
#             self.k_pred = np.exp(k_raw_pred) + self.config.EPS_SMALL
#             self.eps_pred = np.exp(eps_raw_pred) + self.config.EPS_SMALL
#             self.p_pred = self.p_prime_pred - (2.0 / 3.0) * self.k_pred
#             eps_safe_pred = np.maximum(self.eps_pred, self.config.EPS_SMALL**2)
#             self.nu_t_pred = self.config.CMU * np.square(self.k_pred) / eps_safe_pred
#             logging.info("PINN field prediction and processing complete.")
#             return True
#         except Exception as proc_e:
#             logging.error(f"Error processing raw predictions: {proc_e}", exc_info=True)
#             return False

#     def plot_contour_fields(self):
#         if self.u_pred is None:
#             logging.warning("PINN data unavailable for plotting. Run predict_pinn_fields first. Skipping contours.")
#             return
#         logging.info("Generating PINN contour plots...")
#         try:
#             fig, axes = plt.subplots(2, 3, figsize=(18, 10))
#             axes = axes.ravel()
#             cmap_vel = self.plotter_config.CMAP_VELOCITY; cmap_p = self.plotter_config.CMAP_PRESSURE; cmap_turb = self.plotter_config.CMAP_TURBULENCE
#             plot_data_list = [
#                 (self.u_pred, 'PINN Streamwise Velocity (u)', 'u (m/s)', cmap_vel),
#                 (self.v_pred, 'PINN Transverse Velocity (v)', 'v (m/s)', cmap_vel),
#                 (self.p_pred, "PINN Kinematic Pressure (p)", r'$p/\rho$ ($m^2/s^2$)', cmap_p),
#                 (self.k_pred, 'PINN TKE (k)', r'$k$ ($m^2/s^2$)', cmap_turb),
#                 (self.eps_pred, 'PINN Dissipation ($\epsilon$)', r'$\epsilon$ ($m^2/s^3$)', cmap_turb),
#                 (self.nu_t_pred / self.config.NU, 'PINN Eddy Viscosity Ratio', r'$\nu_t / \nu$', cmap_turb, True)
#             ]
#             for i, (data, title, label, cmap, *log_flag) in enumerate(plot_data_list):
#                 ax = axes[i]; plot_values = data; cbar_label = label; levels = 50
#                 use_log = log_flag[0] if log_flag else False
#                 if use_log and np.nanmin(data) > self.config.EPS_SMALL:
#                     try:
#                         min_positive = np.nanmin(data[data > self.config.EPS_SMALL*10])
#                         plot_values = np.log10(np.maximum(data, min_positive * 0.01))
#                         cbar_label = f'log10({label})'
#                         levels = np.logspace(np.log10(min_positive*0.01), np.log10(np.nanmax(data)), levels)
#                         logging.debug(f"Using log scale for {title}")
#                     except Exception as log_err:
#                          logging.warning(f"Could not apply log scale for {title}: {log_err}. Using linear scale.")
#                          use_log = False; plot_values = data
#                 try:
#                     if use_log: cf = ax.contourf(self.X_grid, self.Y_grid, plot_values, levels=levels, cmap=cmap, extend='both', locator=plt.LogLocator())
#                     else: cf = ax.contourf(self.X_grid, self.Y_grid, plot_values, levels=levels, cmap=cmap, extend='both')
#                     fig.colorbar(cf, ax=ax, label=cbar_label)
#                     ax.set_title(title); ax.set_xlabel('x (m)'); ax.set_ylabel('y (m)')
#                     ax.set_aspect('equal', adjustable='box')
#                 except ValueError as ve: logging.error(f"ValueError during contour plot for {title}: {ve}")
#                 except Exception as e: logging.error(f"Error plotting contour for {title}: {e}")
#             for j in range(i + 1, len(axes)): fig.delaxes(axes[j])
#             plt.tight_layout()
#             save_path = os.path.join(self.plots_dir, "pinn_field_contours.png")
#             plt.savefig(save_path, dpi=200, bbox_inches='tight')
#             plt.close(fig)
#             logging.info(f"PINN contour field plots saved to {os.path.basename(save_path)}")
#         except Exception as e:
#             logging.error(f"Failed to generate or save contour plots: {e}", exc_info=True)
#             if 'fig' in locals() and plt.fignum_exists(fig.number): plt.close(fig)

#     def _estimate_utau(self, data_source='pinn', x_slice_loc=None):
#         if x_slice_loc is None: x_slice_loc = self.config.L * 0.8
#         h = self.config.CHANNEL_HALF_HEIGHT; nu = self.config.NU; rho = self.config.RHO
#         y_dist_1 = 0.001 * h; y_dist_2 = 0.01 * h
#         y_eval_top_1 = h - y_dist_1; y_eval_top_2 = h - y_dist_2
#         eval_points_top = np.array([[x_slice_loc, y_eval_top_1], [x_slice_loc, y_eval_top_2]])
#         y_eval_bot_1 = -h + y_dist_1; y_eval_bot_2 = -h + y_dist_2
#         eval_points_bot = np.array([[x_slice_loc, y_eval_bot_1], [x_slice_loc, y_eval_bot_2]])
#         u1_top, k1_top, eps1_top, u2_top, k2_top, eps2_top = [None]*6
#         u1_bot, k1_bot, eps1_bot, u2_bot, k2_bot, eps2_bot = [None]*6
#         try:
#             interp_method = 'linear'
#             if data_source == 'pinn':
#                 if self.model is None: return None
#                 eval_points_all = np.vstack((eval_points_top, eval_points_bot))
#                 pred_raw = self.model.predict(eval_points_all)
#                 if pred_raw is None or pred_raw.shape[0] < 4:
#                      logging.error(f"PINN prediction failed or returned insufficient points for u_tau estimate at x={x_slice_loc:.2f}")
#                      return None
#                 u_all = pred_raw[:, 0]; k_raw_all = pred_raw[:, 3]; eps_raw_all = pred_raw[:, 4]
#                 k_all = np.exp(k_raw_all) + self.config.EPS_SMALL; eps_all = np.exp(eps_raw_all) + self.config.EPS_SMALL
#                 u1_top, u2_top, u1_bot, u2_bot = u_all
#                 k1_top, k2_top, k1_bot, k2_bot = k_all
#                 eps1_top, eps2_top, eps1_bot, eps2_bot = eps_all
#             elif data_source == 'reference' and self.has_ref_data:
#                 if self.ref_data is None: return None
#                 points_ref = self.ref_data[['x', 'y']].values
#                 req_cols = ['u_ref', 'k_ref', 'eps_ref']
#                 if not all(col in self.ref_data.columns for col in req_cols):
#                     logging.warning(f"Reference data missing required columns {req_cols} for u_tau estimation from gradient.")
#                     return None
#                 values_to_interp = {col: self.ref_data[col].values for col in req_cols}
#                 interp_results = {}
#                 eval_points_all = np.vstack((eval_points_top, eval_points_bot))
#                 for col, values in values_to_interp.items():
#                     interp_vals = griddata(points_ref, values, eval_points_all, method=interp_method)
#                     nan_mask = np.isnan(interp_vals)
#                     if np.any(nan_mask):
#                         logging.debug(f"Linear interpolation failed for '{col}' ({data_source}) at x={x_slice_loc:.2f}. Trying 'nearest'.")
#                         interp_nearest = griddata(points_ref, values, eval_points_all[nan_mask], method='nearest')
#                         interp_vals[nan_mask] = interp_nearest
#                         if np.any(np.isnan(interp_vals)):
#                              logging.error(f"Interpolation (linear & nearest) failed for '{col}' ({data_source}) at x={x_slice_loc:.2f}. Cannot estimate u_tau.")
#                              return None
#                     interp_results[col] = interp_vals
#                 u1_top, u2_top, u1_bot, u2_bot = interp_results['u_ref']
#                 k1_top, k2_top, k1_bot, k2_bot = interp_results['k_ref']
#                 eps1_top, eps2_top, eps1_bot, eps2_bot = interp_results['eps_ref']
#             else:
#                 logging.warning(f"Invalid data_source '{data_source}' or missing data for u_tau estimation.")
#                 return None
#             du_dy_top = (u2_top - u1_top) / (y_eval_top_2 - y_eval_top_1)
#             k_avg_top = (k1_top + k2_top) / 2.0; eps_avg_top = (eps1_top + eps2_top) / 2.0
#             nu_t_avg_top = self.config.CMU * k_avg_top**2 / max(eps_avg_top, self.config.EPS_SMALL**2)
#             nu_eff_avg_top = nu + nu_t_avg_top
#             tau_w_top = rho * nu_eff_avg_top * abs(du_dy_top)
#             u_tau_top = np.sqrt(max(tau_w_top / rho, self.config.EPS_SMALL))
#             du_dy_bot = (u2_bot - u1_bot) / (y_eval_bot_2 - y_eval_bot_1)
#             k_avg_bot = (k1_bot + k2_bot) / 2.0; eps_avg_bot = (eps1_bot + eps2_bot) / 2.0
#             nu_t_avg_bot = self.config.CMU * k_avg_bot**2 / max(eps_avg_bot, self.config.EPS_SMALL**2)
#             nu_eff_avg_bot = nu + nu_t_avg_bot
#             tau_w_bot = rho * nu_eff_avg_bot * abs(du_dy_bot)
#             u_tau_bot = np.sqrt(max(tau_w_bot / rho, self.config.EPS_SMALL))
#             u_tau_estimated = (u_tau_top + u_tau_bot) / 2.0
#             logging.info(f"Estimated u_tau ({data_source}) at x={x_slice_loc:.2f} m: {u_tau_estimated:.4f} m/s (avg of top/bottom grad estimates)")
#             return u_tau_estimated
#         except Exception as e:
#             logging.error(f"Error estimating u_tau for {data_source} at x={x_slice_loc:.2f}: {e}", exc_info=True)
#             return None

#     def plot_profile_comparison(self):
#         if self.u_pred is None: return
#         if not self.has_ref_data: return
#         logging.info("Generating profile comparison plots...")
#         x_slice_loc = self.config.L * 0.8; ny_pinn = self.plotter_config.NY_PRED
#         y_coords_pinn = self.Y_grid[:, 0]; x_coords_pinn = self.X_grid[0, :]
#         try: x_slice_idx_pinn = np.argmin(np.abs(x_coords_pinn - x_slice_loc)); actual_x_pinn = x_coords_pinn[x_slice_idx_pinn]
#         except IndexError: logging.error("PINN grid coords invalid."); return
#         pinn_slice = {'y': y_coords_pinn, 'u': self.u_pred[:, x_slice_idx_pinn], 'v': self.v_pred[:, x_slice_idx_pinn],
#                       'p': self.p_pred[:, x_slice_idx_pinn], 'k': self.k_pred[:, x_slice_idx_pinn],
#                       'eps': self.eps_pred[:, x_slice_idx_pinn], 'nut': self.nu_t_pred[:, x_slice_idx_pinn]}
#         ref_slice = {'y': y_coords_pinn}; interpolation_successful = False
#         try:
#             if self.ref_data is None: raise ValueError("Ref data is None.")
#             points_ref = self.ref_data[['x', 'y']].values
#             target_points = np.vstack((np.full(ny_pinn, actual_x_pinn), y_coords_pinn)).T
#             logging.info(f"Interpolating ref data at x={actual_x_pinn:.3f}...")
#             variables_to_interpolate = [('u_ref', 'u'), ('v_ref', 'v'), ('p_ref', 'p'), ('k_ref', 'k'), ('eps_ref', 'eps'), ('nut_ref', 'nut')]
#             missing_ref_vars = []
#             for var_ref, var_pinn in variables_to_interpolate:
#                 if var_ref in self.ref_data.columns:
#                     values_ref = self.ref_data[var_ref].values
#                     interp_values = griddata(points_ref, values_ref, target_points, method='linear')
#                     nan_mask = np.isnan(interp_values)
#                     if np.any(nan_mask):
#                         interp_nearest = griddata(points_ref, values_ref, target_points[nan_mask], method='nearest')
#                         interp_values[nan_mask] = interp_nearest
#                         if np.any(np.isnan(interp_values)): logging.warning(f"Interp failed for '{var_ref}'.")
#                     ref_slice[var_pinn] = interp_values
#                 else:
#                     ref_slice[var_pinn] = np.full(ny_pinn, np.nan); missing_ref_vars.append(var_ref)
#             interpolation_successful = True
#             if missing_ref_vars: logging.warning(f"Could not interp ref vars: {missing_ref_vars}")
#         except ValueError as ve: logging.error(f"ValueError during ref interp: {ve}")
#         except Exception as e: logging.error(f"Error interpolating ref data: {e}", exc_info=True); interpolation_successful = False
#         try:
#             fig, axes = plt.subplots(3, 2, figsize=(12, 15)); axes = axes.ravel(); plot_idx = 0; h = self.config.CHANNEL_HALF_HEIGHT
#             plot_vars = [('u', 'Velocity u', 'm/s'), ('v', 'Velocity v', 'm/s'), ('p', 'Kinematic Pressure p', r'$m^2/s^2$'),
#                          ('k', 'TKE k', r'$m^2/s^2$'), ('eps', 'Dissipation eps', r'$m^2/s^3$'), ('nut', 'Eddy Viscosity nu_t', r'$m^2/s$')]
#             for key, name, unit in plot_vars:
#                 if plot_idx >= len(axes): break
#                 ax = axes[plot_idx]
#                 ax.plot(pinn_slice[key], pinn_slice['y'] / h, 'r-', linewidth=2, label='PINN')
#                 if interpolation_successful and key in ref_slice and not np.all(np.isnan(ref_slice[key])):
#                     ax.plot(ref_slice[key], ref_slice['y'] / h, 'b--', linewidth=1.5, label='Reference (CSV)')
#                 elif not interpolation_successful and key != 'y': ax.plot([], [], 'b--', label='Reference (Failed)')
#                 ax.set_xlabel(f'{name} ({unit})'); ax.set_ylabel('y/h'); ax.set_title(f'{name} Profile')
#                 ax.legend(fontsize=8); ax.grid(True, linestyle=':')
#                 if key in ['k', 'eps', 'nut']:
#                      try:
#                          min_val_for_log = self.config.EPS_SMALL; pinn_valid = np.nanmin(pinn_slice[key]) > min_val_for_log; ref_valid = False
#                          if interpolation_successful and key in ref_slice and not np.all(np.isnan(ref_slice[key])): ref_valid = np.nanmin(ref_slice[key]) > min_val_for_log
#                          if pinn_valid and (ref_valid or not interpolation_successful):
#                               ax.set_xscale('log'); ax.grid(True, which='both', linestyle=':')
#                      except ValueError: logging.warning(f"Could not apply log scale for {key}.")
#                      except Exception as log_e: logging.warning(f"Error log scale for {key}: {log_e}")
#                 plot_idx += 1
#             for j in range(plot_idx, len(axes)): fig.delaxes(axes[j])
#             plt.suptitle(f'Profile Comparison at x ≈ {actual_x_pinn:.3f} m', fontsize=16)
#             plt.tight_layout(rect=[0, 0.03, 1, 0.95])
#             save_path = os.path.join(self.plots_dir, "profile_comparison_pinn_vs_csv.png")
#             plt.savefig(save_path, dpi=200, bbox_inches='tight'); plt.close(fig)
#             logging.info(f"Profile comparison plot saved to {os.path.basename(save_path)}")
#         except Exception as e:
#             logging.error(f"Failed to generate profile comparison plot: {e}", exc_info=True)
#             if 'fig' in locals() and plt.fignum_exists(fig.number): plt.close(fig)

#     def plot_wall_unit_comparison(self):
#         if self.u_pred is None: return
#         logging.info("Generating wall unit comparison plots...")
#         x_slice_loc_utau = self.config.L * 0.8
#         self.pinn_data_utau = self._estimate_utau(data_source='pinn', x_slice_loc=x_slice_loc_utau)
#         if self.has_ref_data: self.ref_data_utau = self._estimate_utau(data_source='reference', x_slice_loc=x_slice_loc_utau)
#         else: self.ref_data_utau = None
#         if not self.pinn_data_utau: logging.error("Could not estimate PINN u_tau."); return
#         if self.has_ref_data and not self.ref_data_utau: logging.warning("Could not estimate ref u_tau.")
#         nu = self.config.NU; h = self.config.CHANNEL_HALF_HEIGHT; kappa = self.config.KAPPA; B_const = 5.2
#         x_slice_loc_plot = x_slice_loc_utau
#         y_coords_pinn = self.Y_grid[:, 0]; x_coords_pinn = self.X_grid[0, :]
#         try: x_slice_idx_pinn = np.argmin(np.abs(x_coords_pinn - x_slice_loc_plot)); actual_x_pinn = x_coords_pinn[x_slice_idx_pinn]
#         except IndexError: logging.error("PINN grid coords invalid."); return
#         wall_indices_pinn = y_coords_pinn >= -self.config.EPS_SMALL; y_wall_pinn = y_coords_pinn[wall_indices_pinn]
#         y_dist_wall_pinn = np.maximum(h - y_wall_pinn, self.config.EPS_SMALL * h)
#         u_wall_pinn = self.u_pred[wall_indices_pinn, x_slice_idx_pinn]; k_wall_pinn = self.k_pred[wall_indices_pinn, x_slice_idx_pinn]; eps_wall_pinn = self.eps_pred[wall_indices_pinn, x_slice_idx_pinn]
#         utau_pinn_safe = max(self.pinn_data_utau, self.config.EPS_SMALL)
#         y_plus_pinn = y_dist_wall_pinn * utau_pinn_safe / nu; u_plus_pinn = u_wall_pinn / utau_pinn_safe
#         k_plus_pinn = k_wall_pinn / max(utau_pinn_safe**2, self.config.EPS_SMALL**2); eps_plus_pinn = eps_wall_pinn * nu / max(utau_pinn_safe**4, self.config.EPS_SMALL**4)
#         sort_idx_pinn = np.argsort(y_plus_pinn)
#         y_plus_pinn = y_plus_pinn[sort_idx_pinn]; u_plus_pinn = u_plus_pinn[sort_idx_pinn]; k_plus_pinn = k_plus_pinn[sort_idx_pinn]; eps_plus_pinn = eps_plus_pinn[sort_idx_pinn]
#         y_plus_ref, u_plus_ref, k_plus_ref, eps_plus_ref = None, None, None, None; ref_processed = False
#         if self.has_ref_data and self.ref_data_utau:
#             try:
#                 ref_wall_data = self.ref_data[(np.isclose(self.ref_data['x'], actual_x_pinn, rtol=0.05, atol=0.1*self.config.L)) & (self.ref_data['y'] >= -self.config.EPS_SMALL)].copy()
#                 if not ref_wall_data.empty:
#                     y_wall_ref = ref_wall_data['y'].values; y_dist_wall_ref = np.maximum(h - y_wall_ref, self.config.EPS_SMALL * h)
#                     utau_ref_safe = max(self.ref_data_utau, self.config.EPS_SMALL); y_plus_ref = y_dist_wall_ref * utau_ref_safe / nu
#                     if 'u_ref' in ref_wall_data.columns: u_plus_ref = ref_wall_data['u_ref'].values / utau_ref_safe
#                     if 'k_ref' in ref_wall_data.columns: k_plus_ref = ref_wall_data['k_ref'].values / max(utau_ref_safe**2, self.config.EPS_SMALL**2)
#                     if 'eps_ref' in ref_wall_data.columns: eps_plus_ref = ref_wall_data['eps_ref'].values * nu / max(utau_ref_safe**4, self.config.EPS_SMALL**4)
#                     sort_idx_ref = np.argsort(y_plus_ref); y_plus_ref = y_plus_ref[sort_idx_ref]
#                     if u_plus_ref is not None: u_plus_ref = u_plus_ref[sort_idx_ref]
#                     if k_plus_ref is not None: k_plus_ref = k_plus_ref[sort_idx_ref]
#                     if eps_plus_ref is not None: eps_plus_ref = eps_plus_ref[sort_idx_ref]
#                     logging.info(f"Processed {len(y_plus_ref)} reference points for wall units.")
#                     ref_processed = True
#                 else: logging.warning(f"No ref data found near x={actual_x_pinn:.3f}, y>=0 for wall units.")
#             except KeyError as ke: logging.error(f"Missing column in ref data for wall units: {ke}")
#             except Exception as e: logging.error(f"Error processing ref data for wall units: {e}", exc_info=True)
#         try:
#             fig, axes = plt.subplots(1, 3, figsize=(18, 5.5))
#             y_plus_max_pinn = np.max(y_plus_pinn) if len(y_plus_pinn) > 0 else 100; y_plus_max_ref = np.max(y_plus_ref) if ref_processed and y_plus_ref is not None and len(y_plus_ref) > 0 else y_plus_max_pinn
#             y_plus_max_plot = 1.1 * max(y_plus_max_pinn, y_plus_max_ref, self.config.YP_PLUS_TARGET * 1.5)
#             u_plus_max_pinn = np.max(u_plus_pinn) if len(u_plus_pinn) > 0 else 25; u_plus_max_ref = np.max(u_plus_ref) if ref_processed and u_plus_ref is not None and len(u_plus_ref) > 0 else u_plus_max_pinn
#             u_plus_max_plot = 1.1 * max(u_plus_max_pinn, u_plus_max_ref)
#             k_plus_max_pinn = np.max(k_plus_pinn) if len(k_plus_pinn) > 0 else 5; k_plus_max_ref = np.max(k_plus_ref) if ref_processed and k_plus_ref is not None and len(k_plus_ref) > 0 else k_plus_max_pinn
#             k_plus_max_plot = 1.1 * max(k_plus_max_pinn, k_plus_max_ref)
#             ax = axes[0]; ax.semilogx(y_plus_pinn, u_plus_pinn, 'r.', ms=4, label=f'PINN ($u_\\tau \\approx {self.pinn_data_utau:.3f}$)')
#             if ref_processed and y_plus_ref is not None and u_plus_ref is not None: ax.semilogx(y_plus_ref, u_plus_ref, 'bo', mfc='none', ms=5, label=f'Ref ($u_\\tau \\approx {self.ref_data_utau:.3f}$)' if self.ref_data_utau else 'Ref (u_tau N/A)')
#             y_plus_log_min = 11; y_plus_theory_log = np.logspace(np.log10(max(y_plus_log_min, 1)), np.log10(y_plus_max_plot*1.1), 100)
#             u_plus_loglaw = (1 / kappa) * np.log(y_plus_theory_log) + B_const; y_plus_theory_vis = np.linspace(0.1, 30, 50); u_plus_viscous = y_plus_theory_vis
#             ax.semilogx(y_plus_theory_log, u_plus_loglaw, 'k:', lw=1.5, label=f'Log Law ($\\kappa={kappa}, B={B_const}$)'); ax.semilogx(y_plus_theory_vis, u_plus_viscous, 'k--', lw=1.5, label='Viscous ($U^+=y^+$)')
#             ax.set_xlabel('$y^+$'); ax.set_ylabel('$U^+$'); ax.set_title(f'$U^+$ vs $y^+$ Profile'); ax.legend(fontsize=9); ax.grid(True, which='both', ls=':'); ax.set_ylim(bottom=0, top=u_plus_max_plot); ax.set_xlim(left=0.1, right=y_plus_max_plot)
#             ax = axes[1]; ax.semilogx(y_plus_pinn, k_plus_pinn, 'r.', ms=4, label='PINN')
#             if ref_processed and y_plus_ref is not None and k_plus_ref is not None: ax.semilogx(y_plus_ref, k_plus_ref, 'bo', mfc='none', ms=5, label='Reference')
#             if hasattr(self.config, 'YP_PLUS_TARGET'): ax.axvline(self.config.YP_PLUS_TARGET, color='g', ls='-.', lw=1, label=f'WF $y_p^+ \\approx {self.config.YP_PLUS_TARGET:.1f}$')
#             ax.set_xlabel('$y^+$'); ax.set_ylabel('$k^+$'); ax.set_title('$k^+$ vs $y^+$ Profile'); ax.legend(fontsize=9); ax.grid(True, which='both', ls=':'); ax.set_ylim(bottom=0, top=k_plus_max_plot); ax.set_xlim(left=0.1, right=y_plus_max_plot)
#             ax = axes[2]; valid_idx_pinn = eps_plus_pinn > self.config.EPS_SMALL**2
#             if np.any(valid_idx_pinn): ax.loglog(y_plus_pinn[valid_idx_pinn], eps_plus_pinn[valid_idx_pinn], 'r.', ms=4, label='PINN')
#             else: ax.plot([],[], 'r.', label='PINN (No positive data)')
#             if ref_processed and y_plus_ref is not None and eps_plus_ref is not None:
#                  valid_idx_ref = eps_plus_ref > self.config.EPS_SMALL**2
#                  if np.any(valid_idx_ref): ax.loglog(y_plus_ref[valid_idx_ref], eps_plus_ref[valid_idx_ref], 'bo', mfc='none', ms=5, label='Reference')
#                  else: ax.plot([],[], 'bo', mfc='none', label='Reference (No positive data)')
#             y_plus_theory_eps = np.logspace(np.log10(max(1, 0.1)), np.log10(y_plus_max_plot*1.1), 100); eps_plus_target_theory = 1.0 / (kappa * y_plus_theory_eps)
#             ax.loglog(y_plus_theory_eps, eps_plus_target_theory, 'k:', lw=1.5, label='$\\epsilon^+ \\propto 1/y^+$')
#             if hasattr(self.config, 'YP_PLUS_TARGET'): ax.axvline(self.config.YP_PLUS_TARGET, color='g', ls='-.', lw=1, label=f'WF $y_p^+ \\approx {self.config.YP_PLUS_TARGET:.1f}$')
#             ax.set_xlabel('$y^+$'); ax.set_ylabel('$\\epsilon^+$'); ax.set_title('$\\epsilon^+$ vs $y^+$ Profile (log-log)'); ax.legend(fontsize=9); ax.grid(True, which='both', ls=':')
#             min_eps_plus_data = np.min(eps_plus_pinn[valid_idx_pinn]) if np.any(valid_idx_pinn) else 1e-5; max_eps_plus_data = np.max(eps_plus_pinn[valid_idx_pinn]) if np.any(valid_idx_pinn) else 1
#             if ref_processed and eps_plus_ref is not None and np.any(valid_idx_ref): min_eps_plus_data = min(min_eps_plus_data, np.min(eps_plus_ref[valid_idx_ref])); max_eps_plus_data = max(max_eps_plus_data, np.max(eps_plus_ref[valid_idx_ref]))
#             ax.set_ylim(bottom=max(min_eps_plus_data * 0.1, 1e-6), top=max_eps_plus_data * 10); ax.set_xlim(left=0.1, right=y_plus_max_plot)
#             plt.suptitle(f'Wall Unit Profiles (Top Wall, x ≈ {actual_x_pinn:.3f}m)', fontsize=16)
#             plt.tight_layout(rect=[0, 0.03, 1, 0.93])
#             save_path = os.path.join(self.plots_dir, "profile_comparison_wall_units.png")
#             plt.savefig(save_path, dpi=200, bbox_inches='tight'); plt.close(fig)
#             logging.info(f"Wall unit comparison plots saved to {os.path.basename(save_path)}")
#         except Exception as e:
#             logging.error(f"Failed to generate wall unit comparison plot: {e}", exc_info=True)
#             if 'fig' in locals() and plt.fignum_exists(fig.number): plt.close(fig)

#     def plot_pressure_gradient_comparison(self):
#         if self.p_pred is None: return
#         logging.info("Generating centerline pressure gradient comparison plot...")
#         try:
#             x_coords_pinn = self.X_grid[0, :]; y_coords_pinn = self.Y_grid[:, 0]
#             center_idx_pinn = np.argmin(np.abs(y_coords_pinn - 0.0)); actual_y_center = y_coords_pinn[center_idx_pinn]
#             p_centerline_pinn = self.p_pred[center_idx_pinn, :]; dp_dx_pinn = np.gradient(p_centerline_pinn, x_coords_pinn)
#             dp_dx_ref = None; x_coords_ref = None; ref_grad_calculated = False
#             if self.has_ref_data and 'p_ref' in self.ref_data.columns:
#                 try:
#                     centerline_tol = 0.05 * self.config.CHANNEL_HALF_HEIGHT; ref_centerline = self.ref_data[np.abs(self.ref_data['y']) <= centerline_tol].copy()
#                     if not ref_centerline.empty:
#                         centerline_grouped = ref_centerline.groupby('x')['p_ref'].mean().sort_index()
#                         if len(centerline_grouped) > 5:
#                             x_coords_ref = centerline_grouped.index.values; p_centerline_ref = centerline_grouped.values
#                             if len(x_coords_ref) > 1: dp_dx_ref = np.gradient(p_centerline_ref, x_coords_ref); ref_grad_calculated = True
#                             else: logging.warning("Not enough unique x for ref grad.")
#                         else: logging.warning(f"Not enough grouped x ({len(centerline_grouped)}) for ref grad.")
#                     else: logging.warning("No points near centerline in ref data.")
#                 except KeyError as ke: logging.warning(f"Missing ref col for pressure grad: {ke}")
#                 except Exception as e: logging.warning(f"Could not calc ref pressure grad: {e}")
#             elif self.has_ref_data: logging.warning("Ref data loaded, but 'p_ref' missing.")
#             fig, ax = plt.subplots(figsize=(10, 6))
#             ax.plot(x_coords_pinn, dp_dx_pinn, 'r-', lw=2, label='PINN $dp/dx$')
#             if ref_grad_calculated and dp_dx_ref is not None and x_coords_ref is not None: ax.plot(x_coords_ref, dp_dx_ref, 'b--', lw=1.5, label='Reference $dp/dx$ (CSV)')
#             elif self.has_ref_data: ax.plot([], [], 'b--', label='Reference $dp/dx$ (Failed)')
#             ax.set_xlabel('x / L'); ax.set_ylabel(r'$dp/dx$ $(m/s^2)$'); ax.set_title(f'Streamwise Kinematic Pressure Gradient along Centerline (y ≈ {actual_y_center:.3f}m)')
#             ax.legend(); ax.grid(True, ls=':'); ax.set_xlim(0, 1)
#             try:
#                focus_start_idx = len(dp_dx_pinn) // 2; focus_end_idx = int(len(dp_dx_pinn) * 0.95)
#                if focus_start_idx < focus_end_idx :
#                     focus_region_pinn = dp_dx_pinn[focus_start_idx:focus_end_idx]
#                     if len(focus_region_pinn) > 0:
#                          mean_dpdx = np.mean(focus_region_pinn); std_dpdx = np.std(focus_region_pinn)
#                          pad = 5 * max(std_dpdx, abs(mean_dpdx)*0.1, 1e-4)
#                          ax.set_ylim(mean_dpdx - pad, mean_dpdx + pad)
#             except Exception as ylim_e: logging.warning(f"Could not set y-limits for pressure grad plot: {ylim_e}")
#             ax.set_xticks(np.linspace(0, self.config.L, 6)); ax.set_xticklabels([f"{x/self.config.L:.1f}" for x in np.linspace(0, self.config.L, 6)]); ax.set_xlabel('x / L')
#             plt.tight_layout(); save_path = os.path.join(self.plots_dir, "pressure_gradient_comparison.png")
#             plt.savefig(save_path, dpi=200); plt.close(fig)
#             logging.info(f"Pressure gradient comparison plot saved to {os.path.basename(save_path)}")
#         except Exception as e:
#             logging.error(f"Failed to generate pressure grad plot: {e}", exc_info=True)
#             if 'fig' in locals() and plt.fignum_exists(fig.number): plt.close(fig)

#     def run_post_processing(self):
#         """Runs the full post-processing sequence (plotting only)."""
#         if self.model is None:
#              logging.error("Model object is None. Cannot run post-processing.")
#              return
#         logging.info("--- Starting Full Post-Processing (Plotting Only) ---")
#         self.plot_loss_history() # Will be skipped if history is None
#         prediction_successful = self.predict_pinn_fields()
#         if prediction_successful:
#             self.plot_contour_fields()
#             self.load_reference_data()
#             if self.has_ref_data:
#                 logging.info("Proceeding with PINN vs Reference CSV comparisons...")
#                 self.plot_profile_comparison()
#                 self.plot_wall_unit_comparison()
#                 self.plot_pressure_gradient_comparison()
#             else:
#                 logging.warning("Skipping comparison plots as reference data is unavailable or failed to load.")
#         else:
#             logging.error("PINN field prediction failed. Aborting further post-processing that depends on predictions.")
#         logging.info("--- Post-Processing Finished ---")
# # --- End Plotter Class ---


# # =============================
# # ===== Main Execution Block =====
# # =============================
# if __name__ == "__main__":
#     main_start_time = time.time()

#     # --- 1. Initial Setup ---
#     main_cfg = Config()
#     main_plot_cfg = PlotterConfig()
#     mount_drive(main_cfg.DRIVE_MOUNT_POINT) # Mount drive first
#     setup_output_directories(main_cfg)    # Setup dirs based on potentially updated path
#     setup_logging(main_cfg.LOG_FILE)       # Setup logging to the correct file

#     logging.info("="*60); logging.info(" PINN RANS k-epsilon PLOTTING Start "); logging.info("="*60)
#     log_configuration(main_cfg, main_plot_cfg) # Log config

#     # --- 2. Define Boundaries (Still needed for model structure) ---
#     try:
#         bcs, anchor_points = get_boundary_conditions(main_cfg)
#         logging.info(f"Defined {len(bcs)} boundary conditions for model structure.")
#         if anchor_points is None: anchor_points = [] # Ensure it's iterable
#     except Exception as e:
#         logging.error(f"Failed to define boundary conditions: {e}", exc_info=True)
#         sys.exit(1)

#     # --- 3. Build Model Structure and Load Checkpoint ---
#     model_loaded = None
#     load_successful = False
#     try:
#         # Instantiate Trainer just to use its build_model method
#         builder = Trainer(main_cfg)
#         model_loaded = builder.build_model(bcs, anchor_points)

#         if model_loaded is None:
#              raise RuntimeError("Model building failed.")

#         # --- Find the latest checkpoint ---
#         filepath_base = os.path.join(main_cfg.MODEL_DIR, f"{main_cfg.CHECKPOINT_FILENAME_BASE}-")
#         latest_checkpoint = None
#         restored_step = 0
#         if os.path.exists(main_cfg.MODEL_DIR):
#             filename_pattern = re.compile(rf"^{re.escape(os.path.basename(filepath_base))}(\d+)\.pt$")
#             checkpoint_files = []
#             logging.info(f"Searching for checkpoints in: {main_cfg.MODEL_DIR} with pattern {filename_pattern.pattern}")
#             for f in os.listdir(main_cfg.MODEL_DIR):
#                 full_path = os.path.join(main_cfg.MODEL_DIR, f)
#                 if os.path.isfile(full_path):
#                     match = filename_pattern.match(f)
#                     if match:
#                         step_num = int(match.group(1))
#                         checkpoint_files.append((step_num, full_path))
#             if checkpoint_files:
#                 checkpoint_files.sort(key=lambda item: item[0], reverse=True)
#                 restored_step, latest_checkpoint = checkpoint_files[0]
#                 logging.info(f"Found latest checkpoint to load: {latest_checkpoint} at step {restored_step}")
#             else:
#                 logging.error("No valid checkpoints found matching the pattern. Cannot proceed with plotting.")
#                 sys.exit(1) # Exit if no checkpoint found
#         else:
#              logging.error(f"Model directory not found: {main_cfg.MODEL_DIR}. Cannot load model.")
#              sys.exit(1)

#         restore_path = latest_checkpoint

#         # --- Compile model (Still recommended for DeepXDE structure before prediction) ---
#         # Use a dummy optimizer/lr as it won't be used for training
#         logging.info("Compiling model structure (recommended before prediction)...")
#         # Ensure the loss_weights argument is not passed if it's None in Config
#         compile_args = {"optimizer": "adam", "lr": 1e-4}
#         if main_cfg.LOSS_WEIGHTS is not None:
#              compile_args["loss_weights"] = main_cfg.LOSS_WEIGHTS
#         else:
#              logging.debug("Compiling without loss_weights as it's None in Config.")
#         model_loaded.compile(**compile_args) # Optimizer/LR choice is arbitrary here
#         logging.info("Model compiled.")

#         # ===> KEY CHANGE: Manual Weight Loading <===
#         logging.info(f"Manually loading network weights from: {restore_path}")
#         # Determine device to load onto (CPU or GPU if available)
#         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#         logging.info(f"Loading checkpoint onto device: {device}")
#         # Load the entire checkpoint dictionary
#         checkpoint = torch.load(restore_path, map_location=device)

#         # Check if 'model_state_dict' key exists (standard PyTorch/DeepXDE practice)
#         if 'model_state_dict' in checkpoint:
#             # Access the underlying network (assuming it's stored in model.net)
#             if hasattr(model_loaded, 'net') and isinstance(model_loaded.net, torch.nn.Module):
#                 # Load only the model's parameters
#                 model_loaded.net.load_state_dict(checkpoint['model_state_dict'])
#                 logging.info(f"Model weights restored manually from step {restored_step}.")
#                 load_successful = True
#             else:
#                 logging.error("Model object does not have a 'net' attribute or it's not a torch.nn.Module. Cannot load state dict.")
#                 load_successful = False
#         # Fallback: Check for 'state_dict' (might be used in some cases)
#         elif 'state_dict' in checkpoint:
#              logging.warning("Checkpoint using 'state_dict' key instead of 'model_state_dict'. Attempting load.")
#              if hasattr(model_loaded, 'net') and isinstance(model_loaded.net, torch.nn.Module):
#                 model_loaded.net.load_state_dict(checkpoint['state_dict'])
#                 logging.info(f"Model weights restored manually from step {restored_step} using 'state_dict'.")
#                 load_successful = True
#              else:
#                 logging.error("Model object does not have a 'net' attribute or it's not a torch.nn.Module. Cannot load state dict.")
#                 load_successful = False
#         else:
#             logging.error("Checkpoint file does not contain 'model_state_dict' or 'state_dict'. Cannot load weights.")
#             load_successful = False
#         # ===> END OF KEY CHANGE <===

#     except Exception as e:
#          logging.error(f"A critical error occurred during model build or load: {e}", exc_info=True)
#          load_successful = False

#     # --- 4. Post-processing and Plotting Phase ---
#     if load_successful and model_loaded is not None:
#         logging.info("Proceeding to post-processing (plotting).")
#         try:
#             # Instantiate Plotter with the loaded model
#             # Pass None for history and state as they weren't generated
#             plotter = Plotter(main_cfg, main_plot_cfg, model_loaded, None, None)
#             plotter.run_post_processing()
#         except Exception as e:
#              logging.error(f"An error occurred during post-processing/plotting: {e}", exc_info=True)
#     else:
#         logging.error("Model loading failed or model is invalid. Skipping post-processing.")

#     main_end_time = time.time()
#     logging.info("="*60); logging.info(f" Script Execution (Plotting Only) Finished in {main_end_time - main_start_time:.2f} seconds"); logging.info("="*60)

DeepXDE Backend requested: pytorch
DeepXDE Backend actual: pytorch
CUDA available.
PyTorch CUDA device detected by DDE: 0 (Tesla T4)
PyTorch version: 2.6.0+cu124
Number of GPUs: 1
PyTorch default dtype: torch.float32
2025-04-14 20:50:54 [INFO] Google Drive already mounted.
2025-04-14 20:50:54 [INFO] Output paths point to Google Drive: /content/drive/MyDrive/PINN_RANS_ChannelFlow
2025-04-14 20:50:54 [INFO] Setting up output directories...
2025-04-14 20:50:54 [INFO] Output directories verified/created.
2025-04-14 20:50:54 [INFO] Logging configured.
2025-04-14 20:50:54 [INFO]  PINN RANS k-epsilon PLOTTING Start 
2025-04-14 20:50:54 [INFO] Plotting Configuration:
2025-04-14 20:50:54 [INFO]   Output Directory: /content/drive/MyDrive/PINN_RANS_ChannelFlow
2025-04-14 20:50:54 [INFO]   Model Directory: /content/drive/MyDrive/PINN_RANS_ChannelFlow/model_checkpoints
2025-04-14 20:50:54 [INFO]   Plot Directory: /content/drive/MyDrive/PINN_RANS_ChannelFlow/plots
2025-04-14 20:50:54 [INFO]   Ref Da

  if use_log: cf = ax.contourf(self.X_grid, self.Y_grid, plot_values, levels=levels, cmap=cmap, extend='both', locator=plt.LogLocator())


2025-04-14 20:50:59 [INFO] PINN contour field plots saved to pinn_field_contours.png
2025-04-14 20:50:59 [INFO] Loading reference data from: /content/drive/MyDrive/PINN_RANS_ChannelFlow/data/reference_output_data.csv
2025-04-14 20:50:59 [INFO] Loaded reference data: 414060 rows, 12 cols.
2025-04-14 20:50:59 [INFO] Filtered for latest time/step (Time=1000): 41406 rows remaining.
2025-04-14 20:50:59 [INFO] Identified coordinate columns: x='Points:0', y='Points:1', z='Points:2'
2025-04-14 20:50:59 [INFO] Filtered for z-plane near 0.0 (actual: 0.0000): 20703 rows remaining.
2025-04-14 20:50:59 [INFO] Renamed reference columns based on mapping. New columns: ['TimeStep', 'Time', 'x', 'y', 'z', 'u_ref', 'v_ref', 'U:2', 'eps_ref', 'k_ref', 'nut_ref', 'p_ref']
2025-04-14 20:50:59 [INFO] Successfully loaded and preprocessed reference CSV data. Final columns: ['y', 'u_ref', 'eps_ref', 'p_ref', 'k_ref', 'nut_ref', 'z', 'v_ref', 'x']
2025-04-14 20:50:59 [INFO] Proceeding with PINN vs Reference CSV 