# In [None]:
from __future__ import annotations

from flamekit.io_fields import field_path
from flamekit.io_fronts import Case
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.tri as mtri

# Optional MPI (safe in serial)
# Sets `comm` (MPI.COMM_WORLD, or None) and `rank` (this process's rank, 0 in
# serial). Any exception while importing mpi4py or querying the communicator
# falls back to serial mode with comm=None, rank=0.
# NOTE(review): the broad `except Exception` presumably also covers MPI runtime
# initialisation failures, not just a missing mpi4py — confirm this is intended.
try:
    from mpi4py import MPI
    comm = MPI.COMM_WORLD  # communicator over all launched processes
    rank = comm.rank
except Exception:
    comm = None
    rank = 0


# ============================================================
# USER SETTINGS
# ============================================================

# Inclusive range of simulation time steps loaded as snapshots.
TIME_STEP_START: int = 200
TIME_STEP_END: int = 269

# Case identifiers forwarded to flamekit's `Case`.
PHI: float = 0.40
LAT_SIZE: str = "025"
POST: bool = True

# Input location and the columns read from every field CSV.
BASE_DIR: Path = Path("../data/isocontours")
VAR_NAME: str = "T"                  # field to model
SORT_COLS: list[str] = ["x", "y"]    # row order used to align snapshots
# Coordinate-match tolerance between timesteps:
#   0.0 -> exact equality; > 0 -> np.allclose with atol=COORD_TOL, rtol=0.
COORD_TOL: float = 0.0

# Spatial crop: keep only points with x > X_THESHOLD. The mask is built once
# from the reference timestep and reused for every timestep.
# (Spelling of the name is kept because the rest of the script uses it.)
X_THESHOLD: int = 300


# ============================================================
# Koopman-CNN settings
# ============================================================

# Requested compute device; the torch setup downgrades "cuda" to "cpu" when
# CUDA is unavailable.
DEVICE: str = "cuda"   # "cuda" or "cpu"
DT: float = 1.0        # index-time step (kept for completeness)

# Model and optimisation sizes.
Z_DIM: int = 16                # latent dim (8..64 typical)
EPOCHS: int = 600
BATCHES_PER_EPOCH: int = 50
BATCH_SIZE: int = 8
ROLLOUT_LEN: int = 4           # multi-step training horizon (<= n_snaps-1)

LR: float = 2e-3
WEIGHT_DECAY: float = 1e-7

# Loss-term weights.
W_RECON: float = 1.0   # reconstruction of each snapshot
W_PRED: float = 1.0    # multi-step prediction in observation space (decoded)
W_LAT: float = 0.1     # latent-space multi-step consistency
W_STAB: float = 1e-3   # spectral radius penalty

# Seed NumPy's global RNG (it drives batch-window sampling); torch is seeded
# separately with the same SEED.
SEED: int = 0
np.random.seed(SEED)


# ============================================================
# File helpers
# ============================================================

def field_csv_path(base_dir: Path, phi: float, lat_size: str, time_step: int, post: bool) -> Path:
    """Return the CSV path of one field snapshot, as resolved by flamekit.

    The arguments are packed into a flamekit ``Case`` and passed to
    ``field_path``; no existence check is performed here.
    """
    case_kwargs = dict(
        base_dir=base_dir,
        phi=phi,
        lat_size=lat_size,
        time_step=time_step,
        post=post,
    )
    return field_path(Case(**case_kwargs))

def read_field_sorted(path: Path, var_name: str, sort_cols: list[str]) -> tuple[np.ndarray, np.ndarray]:
    """Read one field CSV and return coordinates and values in a canonical order.

    Only the columns in ``sort_cols`` and ``var_name`` are parsed. Rows with
    NaN or +/-inf in any of those columns are dropped. The remaining rows are
    stably sorted by ``sort_cols`` so that snapshots from different timesteps
    line up row for row.

    Args:
        path: CSV file to read (``Path`` or ``str``).
        var_name: Name of the value column to return.
        sort_cols: Coordinate columns, returned as coordinates and used (in
            order) as the sort key.

    Returns:
        ``(coords, values)``: float64 arrays of shape ``(n, len(sort_cols))``
        and ``(n,)``.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If a required column is absent from the file.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Missing file:\n  {path}")

    sort_cols = list(sort_cols)
    # Ordered, de-duplicated list of required columns (var_name may also be a
    # sort column).
    required = list(dict.fromkeys([*sort_cols, var_name]))
    wanted = set(required)
    # Parse only the columns we need: field files can be large and carry many
    # unused columns. A callable `usecols` (unlike a list) does not raise on
    # missing columns, so we can report them with our own message below.
    df = pd.read_csv(path, usecols=lambda col: col in wanted)

    missing = [c for c in required if c not in df.columns]
    if missing:
        raise ValueError(f"{path.name}: missing columns {missing}")

    df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=required)
    # mergesort is stable, so ties keep their file order.
    df = df.sort_values(sort_cols, kind="mergesort").reset_index(drop=True)

    coords = df[sort_cols].to_numpy(dtype=np.float64)
    values = df[var_name].to_numpy(dtype=np.float64)
    return coords, values


# ============================================================
# Build snapshot matrix X (n_points, n_snaps) without interpolation
# ============================================================

# Load every timestep's field, crop it with a single x-mask taken from the
# reference (first) timestep, and stack the cropped values column-wise into X.
# All timesteps must share the reference's FULL coordinate set (exactly, or
# within COORD_TOL), which is what makes reusing one mask and one point order
# valid without interpolation.
times = list(range(TIME_STEP_START, TIME_STEP_END + 1))

ref_path = field_csv_path(BASE_DIR, PHI, LAT_SIZE, times[0], POST)
coords_ref_full, snap0_full = read_field_sorted(ref_path, VAR_NAME, SORT_COLS)

# Apply x-threshold mask ONCE (from reference coords), then apply to all timesteps
x_ref = coords_ref_full[:, 0]
mask_x = x_ref > X_THESHOLD

coords_ref = coords_ref_full[mask_x]
snap0 = snap0_full[mask_x]

n_points = coords_ref.shape[0]
n_snaps = len(times)

if rank == 0:
    print(f"Reference timestep: {times[0]}")
    print(f"X_THESHOLD={X_THESHOLD} -> keeping {n_points}/{coords_ref_full.shape[0]} points")
    print(f"n_points={n_points}, n_snapshots={n_snaps}")
    # Note: the reference file has already been read at this point.
    print(f"Reading: {ref_path}")

snapshots = [snap0]
for t in times[1:]:
    path = field_csv_path(BASE_DIR, PHI, LAT_SIZE, t, POST)
    coords_t_full, snap_t_full = read_field_sorted(path, VAR_NAME, SORT_COLS)

    # Coordinate match check on FULL coords
    if coords_t_full.shape[0] != coords_ref_full.shape[0]:
        raise ValueError(
            f"Inconsistent n_points (full) at timestep {t}: {coords_t_full.shape[0]} vs {coords_ref_full.shape[0]}"
        )

    # Exact comparison by default; with COORD_TOL > 0 allow small absolute
    # (rtol=0) floating-point differences.
    if COORD_TOL == 0.0:
        same = np.array_equal(coords_t_full, coords_ref_full)
    else:
        same = np.allclose(coords_t_full, coords_ref_full, atol=COORD_TOL, rtol=0.0)

    if not same:
        raise ValueError(
            f"Coordinates mismatch at timestep {t}.\n"
            f"Set COORD_TOL>0 (if only small floating diffs) or interpolate/regrid."
        )

    snap_t = snap_t_full[mask_x]
    # Defensive only: once the full point counts match, mask_x (built from the
    # reference) selects exactly n_points entries, so this should be unreachable.
    if snap_t.shape[0] != n_points:
        raise RuntimeError("Masking produced inconsistent point count. Check X_THESHOLD and sorting.")

    snapshots.append(snap_t)

# X: (n_points, n_snaps)
X = np.stack(snapshots, axis=1).astype(np.float64)


# ============================================================
# Grid inference / gridding utilities (for CNN)
# ============================================================

def infer_structured_grid(coords_xy: np.ndarray) -> tuple[bool, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Try to infer a structured (x,y) grid from scattered coords.
    Returns:
      structured_ok, xs, ys, ix, iy
    where xs are unique sorted x coords, ys unique sorted y coords,
    and ix,iy are integer indices per point such that grid[iy, ix] maps to point value.
    """
    x = coords_xy[:, 0]
    y = coords_xy[:, 1]
    xs = np.unique(x)
    ys = np.unique(y)
    nx = xs.size
    ny = ys.size

    if coords_xy.shape[0] != nx * ny:
        return False, xs, ys, None, None

    # map each point to grid indices
    ix = np.searchsorted(xs, x)
    iy = np.searchsorted(ys, y)

    # check that every (ix,iy) is present exactly once
    lin = iy * nx + ix
    if np.unique(lin).size != lin.size:
        return False, xs, ys, None, None

    return True, xs, ys, ix, iy


def points_to_grid_structured(values: np.ndarray, nx: int, ny: int, ix: np.ndarray, iy: np.ndarray) -> np.ndarray:
    """
    Scatter per-point values onto a (ny, nx) float32 image.

    Point ``k`` is written to ``grid[iy[k], ix[k]]``. Cells that no point maps
    to are NaN rather than uninitialised memory, so an incomplete index map
    shows up as NaNs instead of silently injecting arbitrary numbers.

    Args:
        values: (n_points,) values to scatter.
        nx, ny: Grid width and height.
        ix, iy: (n_points,) integer column / row indices.

    Returns:
        float32 array of shape (ny, nx).
    """
    grid = np.full((ny, nx), np.nan, dtype=np.float32)
    grid[iy, ix] = np.asarray(values, dtype=np.float32)
    return grid


def points_to_grid_triangulated(
    coords_xy: np.ndarray,
    values: np.ndarray,
    xs: np.ndarray,
    ys: np.ndarray,
    *,
    tri: "mtri.Triangulation | None" = None,
) -> np.ndarray:
    """
    Fallback gridding: linearly interpolate scattered values onto ``meshgrid(xs, ys)``.

    Uses matplotlib's ``LinearTriInterpolator``. Grid nodes outside the
    triangulation (outside the convex hull of the points) are filled with a
    constant: the mean of the successfully interpolated nodes, or 0.0 if no
    node could be interpolated. This is a mean fill, not a nearest-neighbour
    fill.

    Args:
        coords_xy: (n_points, >=2) array; columns 0 and 1 are x and y.
        values: (n_points,) values at those points.
        xs, ys: 1-D grid coordinates; the output has shape (ys.size, xs.size).
        tri: Optional prebuilt ``mtri.Triangulation`` of exactly these
            ``coords_xy``. Passing one lets repeated calls on the same point
            set (e.g. one per snapshot) skip re-triangulating. When None, a
            Delaunay triangulation is built here.

    Returns:
        float32 array of shape (ys.size, xs.size) without NaNs.
    """
    if tri is None:
        tri = mtri.Triangulation(
            coords_xy[:, 0].astype(np.float64),
            coords_xy[:, 1].astype(np.float64),
        )
    interp = mtri.LinearTriInterpolator(tri, np.asarray(values, dtype=np.float64))
    Xg, Yg = np.meshgrid(xs, ys)
    # The interpolator returns a masked array; masked = outside the triangulation.
    Z = np.asarray(interp(Xg, Yg).filled(np.nan), dtype=np.float32)

    outside = np.isnan(Z)
    if outside.any():
        # Check "all NaN" first so np.nanmean never sees an empty slice
        # (which would emit a RuntimeWarning before falling back to 0.0).
        fill = 0.0 if outside.all() else float(np.nanmean(Z))
        if not np.isfinite(fill):
            fill = 0.0
        Z = np.nan_to_num(Z, nan=fill).astype(np.float32)
    return Z


def grid_to_points_structured(grid: np.ndarray, ix: np.ndarray, iy: np.ndarray) -> np.ndarray:
    """Gather per-point values back out of a (ny, nx) grid.

    Inverse of the structured scatter: point ``k`` takes ``grid[iy[k], ix[k]]``.
    The gathered values are returned as float64.
    """
    picked = grid[iy, ix]
    return picked.astype(np.float64)


# ============================================================
# Prepare image sequence for CNN: (T, 1, ny, nx)
# ============================================================

# Decide how point data becomes images: exact scatter when the masked points
# form a complete rectangular (x, y) lattice, triangulated interpolation otherwise.
# NOTE(review): in the run logged at the end of this file, structured_ok=False
# with n_points (255744) > nx*ny (196875). Every point's x is in xs and y in ys,
# so some (x, y) pairs must occur more than once; that alone fails the count
# check in infer_structured_grid and forces the triangulated fallback.
structured_ok, xs, ys, ix, iy = infer_structured_grid(coords_ref)
nx = xs.size
ny = ys.size

if rank == 0:
    print(f"Grid inference: structured_ok={structured_ok}, nx={nx}, ny={ny}, n_points={n_points}")

# Build gridded snapshots: one (ny, nx) float32 image per column of X.
imgs = []
for k in range(n_snaps):
    v = X[:, k]
    if structured_ok:
        g = points_to_grid_structured(v, nx=nx, ny=ny, ix=ix, iy=iy)
    else:
        # regular grid definition for fallback
        # (using inferred unique xs/ys still, but may be non-rectangular coverage)
        # NOTE(review): points_to_grid_triangulated builds a new Triangulation of
        # coords_ref on each call here, i.e. once per snapshot, although the
        # coordinates never change; this is costly for large point sets.
        g = points_to_grid_triangulated(coords_ref, v, xs, ys)
    imgs.append(g)

# imgs_np: (T, ny, nx)
imgs_np = np.stack(imgs, axis=0).astype(np.float32)

# Standardise each pixel over time (zero mean, unit std across the T snapshots).
# The +1e-6 keeps pixels whose value never changes from dividing by zero.
mean_img = imgs_np.mean(axis=0, keepdims=True)                 # (1, ny, nx)
std_img  = imgs_np.std(axis=0, keepdims=True) + 1e-6           # (1, ny, nx)
imgs_n   = (imgs_np - mean_img) / std_img                      # (T, ny, nx)

# Add channel dim: (T, 1, ny, nx)
imgs_n = imgs_n[:, None, :, :]

# Training samples windows of ROLLOUT_LEN + 1 consecutive snapshots, so at
# least one such window must fit.
T_total = imgs_n.shape[0]
if ROLLOUT_LEN >= T_total:
    raise ValueError(f"ROLLOUT_LEN={ROLLOUT_LEN} must be < number of snapshots T={T_total}.")


# ============================================================
# Koopman-CNN (PyTorch)
# ============================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Seed torch's RNG (e.g. model weight initialisation) before anything is built.
torch.manual_seed(SEED)

# Resolve the compute device, downgrading a "cuda" request to CPU when CUDA is unavailable.
DEVICE = "cpu" if (DEVICE == "cuda" and not torch.cuda.is_available()) else DEVICE
device = torch.device(DEVICE)

if rank == 0:
    print("Using device:", device)
if rank == 0 and device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))

# Whole normalised image sequence, moved to the device once: (T, 1, ny, nx).
X_torch = torch.from_numpy(imgs_n).to(device)


def spectral_radius(K: torch.Tensor) -> torch.Tensor:
    """
    Return the spectral radius ``max_i |lambda_i(K)|`` of a square matrix ``K``.

    The result is a real 0-dim tensor that stays in the autograd graph, so it
    can be used directly as a differentiable penalty term.

    Eigenvalues of a general real matrix are complex, but ``abs`` of a complex
    tensor is already real. An extra ``.real`` is therefore redundant, and may
    be rejected for real tensors by older PyTorch releases.
    """
    eigvals = torch.linalg.eigvals(K)
    return torch.max(torch.abs(eigvals))


def sample_batch_indices(T: int, rollout_len: int, batch_size: int) -> list[int]:
    """
    Draw ``batch_size`` random start indices for rollout windows.

    Every start ``s`` satisfies ``0 <= s <= T - (rollout_len + 1)``, so the
    window ``[s, s + rollout_len]`` fits in a sequence of length ``T``. Draws
    come from NumPy's global RNG, so ``np.random.seed`` makes them reproducible.

    Raises:
        ValueError: If ``T < rollout_len + 1``.
    """
    window = rollout_len + 1
    last_start = T - window
    if last_start < 0:
        raise ValueError(f"Not enough snapshots: T={T}, rollout_len={rollout_len}")
    starts = np.random.randint(0, last_start + 1, size=batch_size)
    return [int(s) for s in starts]


def make_batch(Xseq: torch.Tensor, idx0: list[int], rollout_len: int) -> torch.Tensor:
    """
    Stack windows of ``rollout_len + 1`` consecutive frames of ``Xseq``.

    One window starts at each index in ``idx0``. For ``Xseq`` of shape
    (T, 1, ny, nx) the result has shape (len(idx0), rollout_len + 1, 1, ny, nx).
    """
    span = rollout_len + 1
    windows = []
    for start in idx0:
        windows.append(Xseq[start : start + span])
    return torch.stack(windows, dim=0)


class KoopmanCNN(nn.Module):
    """
    Convolutional autoencoder whose latent state evolves linearly.

    - Encoder: image (B, 1, H, W) -> latent row vectors z (B, z_dim).
    - Dynamics: one step is ``z_{k+1} = z_k K^T`` with a learnable
      (z_dim, z_dim) matrix ``K``, initialised to the identity.
    - Decoder: z -> 128x4x4 feature map -> 64x64 via four stride-2 transposed
      convolutions -> bilinearly resized to the requested (H, W).
    """

    def __init__(self, z_dim: int) -> None:
        super().__init__()

        # Encoder: 3x3 conv, then three stride-2 4x4 convs (each halves H and W),
        # then adaptive pooling to a fixed 4x4 so any input size works.
        enc_layers: list[nn.Module] = [nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(inplace=True)]
        for c_in, c_out in ((32, 32), (32, 64), (64, 128)):
            enc_layers += [nn.Conv2d(c_in, c_out, 4, stride=2, padding=1), nn.ReLU(inplace=True)]
        self.enc_conv = nn.Sequential(*enc_layers)
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        self.enc_fc = nn.Linear(128 * 4 * 4, z_dim)

        # Decoder: z -> 128x4x4, then 4 -> 8 -> 16 -> 32 -> 64 via stride-2
        # transposed convs, then a 3x3 conv down to one channel.
        self.dec_fc = nn.Linear(z_dim, 128 * 4 * 4)
        dec_layers: list[nn.Module] = []
        for c_in, c_out in ((128, 64), (64, 32), (32, 16), (16, 8)):
            dec_layers += [nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1), nn.ReLU(inplace=True)]
        dec_layers.append(nn.Conv2d(8, 1, 3, padding=1))
        self.dec_deconv = nn.Sequential(*dec_layers)

        # Linear Koopman operator in latent space (applied by step_latent).
        self.K = nn.Parameter(torch.eye(z_dim))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Map images (B, 1, H, W) to latent vectors (B, z_dim)."""
        feats = self.pool(self.enc_conv(x))
        return self.enc_fc(torch.flatten(feats, start_dim=1))

    def decode(self, z: torch.Tensor, out_hw: tuple[int, int]) -> torch.Tensor:
        """Map latent vectors (B, z_dim) to images (B, 1, out_hw[0], out_hw[1])."""
        seed_map = self.dec_fc(z).view(z.shape[0], 128, 4, 4)
        coarse = self.dec_deconv(seed_map)  # (B, 1, 64, 64)
        return F.interpolate(coarse, size=out_hw, mode="bilinear", align_corners=False)

    def step_latent(self, z: torch.Tensor) -> torch.Tensor:
        """Advance latent row vectors (B, z_dim) by one step: ``z @ K^T``."""
        return torch.matmul(z, self.K.t())


# Network on the chosen device, the MSE criterion used by the reconstruction /
# prediction / latent loss terms, and the Adam optimiser over all parameters
# (including the Koopman matrix K).
model = KoopmanCNN(Z_DIM).to(device)
mse = nn.MSELoss()
opt = optim.Adam(params=model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

if rank == 0:
    print(f"Training KoopmanCNN: z_dim={Z_DIM}, T={T_total}, rollout_len={ROLLOUT_LEN}, img=({ny},{nx})")


# Prediction and latent-consistency losses need at least one rolled-out step;
# with ROLLOUT_LEN == 0 they would be means over empty tensors (NaN).
if ROLLOUT_LEN < 1:
    raise ValueError(f"ROLLOUT_LEN={ROLLOUT_LEN} must be >= 1.")

for epoch in range(1, EPOCHS + 1):
    model.train()
    losses = []

    for _ in range(BATCHES_PER_EPOCH):
        # Random windows of ROLLOUT_LEN + 1 consecutive snapshots.
        idx0 = sample_batch_indices(T_total, ROLLOUT_LEN, batch_size=BATCH_SIZE)
        Xb = make_batch(X_torch, idx0, ROLLOUT_LEN)  # (B,L+1,1,ny,nx)
        B, Lp1, C, H, W = Xb.shape
        L = Lp1 - 1

        # Encode all frames in the batch at once.
        Xflat = Xb.reshape(B * (L + 1), 1, H, W)
        z_true = model.encode(Xflat).reshape(B, L + 1, Z_DIM)

        # Reconstruction of every frame (autoencoder term).
        X_recon = model.decode(z_true.reshape(B * (L + 1), Z_DIM), out_hw=(H, W)).reshape(B, L + 1, 1, H, W)
        loss_recon = mse(X_recon, Xb)

        # Latent rollout from z0 using only the linear operator K.
        z_list = [z_true[:, 0, :]]
        for _k in range(L):
            z_list.append(model.step_latent(z_list[-1]))
        z_pred = torch.stack(z_list, dim=1)  # (B,L+1,Z_DIM); z_pred[:, 0] is z_true[:, 0]

        # Decode only the predicted steps k=1..L. Step 0 is the encoded input
        # itself (already decoded in X_recon) and is excluded from the loss, so
        # decoding it again would only cost time and memory.
        z_pred_steps = z_pred[:, 1:, :].reshape(B * L, Z_DIM)
        X_pred = model.decode(z_pred_steps, out_hw=(H, W)).reshape(B, L, 1, H, W)

        # Multi-step prediction loss in observation space (k >= 1).
        loss_pred = mse(X_pred, Xb[:, 1:, ...])

        # Latent consistency: rolled-out vs encoded true latents (k >= 1).
        loss_lat = mse(z_pred[:, 1:, :], z_true[:, 1:, :])

        # Stability penalty: penalize rho(K) > 1
        rho = spectral_radius(model.K)
        loss_stab = torch.relu(rho - 1.0) ** 2

        loss = W_RECON * loss_recon + W_PRED * loss_pred + W_LAT * loss_lat + W_STAB * loss_stab

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        losses.append(loss.item())

    if rank == 0 and (epoch == 1 or epoch % 50 == 0):
        avg = float(np.mean(losses)) if losses else float("nan")
        with torch.no_grad():
            rho_val = float(spectral_radius(model.K).cpu().item())
        print(f"Epoch {epoch:5d}/{EPOCHS} | loss={avg:.6e} | rho(K)={rho_val:.6f}")


# ============================================================
# One-step forecast: predict t_next = last timestep + 1
# ============================================================

# One-step forecast: encode the last observed snapshot, apply K once in latent
# space, decode, undo the per-pixel normalisation, and map back to points.
# t_next assumes consecutive integer timesteps (times is a step-1 range).
model.eval()
t_next = times[-1] + 1

with torch.no_grad():
    x_last = X_torch[-1].unsqueeze(0)             # (1,1,ny,nx) normalized
    z_last = model.encode(x_last)                 # (1,z_dim)
    z_next = model.step_latent(z_last)            # (1,z_dim)
    x_next_n = model.decode(z_next, out_hw=(ny, nx))[0, 0].cpu().numpy()  # (ny,nx), normalized

# Undo normalization back to physical values
x_next_grid = (x_next_n * std_img[0] + mean_img[0]).astype(np.float64)   # (ny,nx)

# Convert grid back to per-point ordering (coords_ref)
if structured_ok:
    x_next_pred = grid_to_points_structured(x_next_grid, ix=ix, iy=iy)   # (n_points,)
else:
    # Fallback (not structured): xs / ys are the unique x / y values of
    # coords_ref itself, so searchsorted returns each point's exact grid node
    # (not a nearest-neighbour guess) and the clips are only a safety net.
    # Points sharing the same (x, y) all receive the same predicted value.
    xi = np.searchsorted(xs, coords_ref[:, 0].astype(np.float64))
    yi = np.searchsorted(ys, coords_ref[:, 1].astype(np.float64))
    xi = np.clip(xi, 0, nx - 1)
    yi = np.clip(yi, 0, ny - 1)
    x_next_pred = x_next_grid[yi, xi].astype(np.float64)

# Save prediction to CSV
# Save prediction to CSV. Paths are computed (and the directory created, which
# is race-safe with exist_ok=True) on every rank because later code writes into
# out_dir. Only rank 0 writes the file, so parallel MPI ranks do not race on
# the same output path.
out_dir = field_csv_path(BASE_DIR, PHI, LAT_SIZE, 0, POST).parent
out_dir.mkdir(parents=True, exist_ok=True)

out = pd.DataFrame(coords_ref, columns=SORT_COLS)
out[f"{VAR_NAME}_pred"] = x_next_pred
out_path = out_dir / f"koopmancnn_pred_{VAR_NAME}_{t_next}_xgt{int(X_THESHOLD)}.csv"

if rank == 0:
    out.to_csv(out_path, index=False)
    print("Wrote:", out_path)


# ============================================================
# Compare to true next timestep if present + plots + XY error maps
# ============================================================

path_true = field_csv_path(BASE_DIR, PHI, LAT_SIZE, t_next, POST)

# Build triangulation once for plotting on scattered coords
x_sc = coords_ref[:, 0].astype(float)
y_sc = coords_ref[:, 1].astype(float)
tri_sc = mtri.Triangulation(x_sc, y_sc)

# Best effort: mask excessively flat border triangles before contouring. If
# that fails, plot with the unmasked triangulation, but report it instead of
# silently ignoring the error.
try:
    analyzer = mtri.TriAnalyzer(tri_sc)
    mask = analyzer.get_flat_tri_mask(min_circle_ratio=0.02)
    tri_sc.set_mask(mask)
except Exception as exc:
    if rank == 0:
        print(f"Warning: flat-triangle masking skipped ({type(exc).__name__}: {exc})")

def tricontour_field(vals_points: np.ndarray, title: str, cbar_label: str, vmin=None, vmax=None) -> None:
    """
    Show a filled-contour plot of per-point values on the plotting triangulation ``tri_sc``.

    When both ``vmin`` and ``vmax`` are given and ``vmax > vmin``, the 60
    contour bands are spaced evenly on ``[vmin, vmax]``. Values outside that
    range use the extreme colours (``extend="both"``).

    Why: with an integer ``levels``, matplotlib picks the bands from each
    field's own data range and uses vmin/vmax only for colour normalisation.
    Plots meant to share a scale would then get different bands and
    colourbars, and percentile-clipped error maps would spend most bands on
    outliers. Otherwise, 60 automatic levels are used.

    Args:
        vals_points: (n_points,) values at the triangulation's points.
        title: Axes title.
        cbar_label: Colourbar label.
        vmin, vmax: Optional colour-scale limits.
    """
    if vmin is not None and vmax is not None and vmax > vmin:
        contour_kw = dict(levels=np.linspace(vmin, vmax, 61), extend="both")
    else:
        contour_kw = dict(levels=60, vmin=vmin, vmax=vmax)

    fig = plt.figure(figsize=(7.2, 5.8))
    ax = fig.add_subplot(111)
    cf = ax.tricontourf(tri_sc, vals_points, **contour_kw)
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title(title)
    cbar = fig.colorbar(cf, ax=ax, orientation="horizontal", pad=0.08, fraction=0.06)
    cbar.set_label(cbar_label)
    plt.tight_layout()
    plt.show()

if path_true.exists():
    coords_true_full, snap_true_full = read_field_sorted(path_true, VAR_NAME, SORT_COLS)

    # The true next step must be on the same full (unmasked) point set as the
    # reference, so that mask_x selects the same points in the same order.
    if coords_true_full.shape[0] != coords_ref_full.shape[0]:
        raise ValueError(
            f"True next-step has different full point count: {coords_true_full.shape[0]} vs {coords_ref_full.shape[0]}"
        )

    if COORD_TOL == 0.0:
        same = np.array_equal(coords_true_full, coords_ref_full)
    else:
        same = np.allclose(coords_true_full, coords_ref_full, atol=COORD_TOL, rtol=0.0)
    if not same:
        raise ValueError("True next-step coordinates do not match reference coordinates.")

    snap_true = snap_true_full[mask_x].astype(np.float64)

    # Pointwise error (pred - true) and global metrics.
    err = x_next_pred - snap_true
    abs_err = np.abs(err)

    rmse = float(np.sqrt(np.mean(err**2)))
    rel_l2 = float(np.linalg.norm(err) / (np.linalg.norm(snap_true) + 1e-12))

    if rank == 0:
        print(f"Next-step compare at t={t_next}: RMSE={rmse:.6e}, relL2={rel_l2:.6e}")

    # Save error CSV (cropped). Only rank 0 writes, so several MPI ranks never
    # write the same file concurrently.
    out_err = pd.DataFrame(coords_ref, columns=SORT_COLS)
    out_err[f"{VAR_NAME}_true"] = snap_true
    out_err[f"{VAR_NAME}_pred"] = x_next_pred
    out_err["err"] = err
    out_err["abs_err"] = abs_err
    err_path = out_dir / f"koopmancnn_err_{VAR_NAME}_{t_next}_xgt{int(X_THESHOLD)}.csv"
    if rank == 0:
        out_err.to_csv(err_path, index=False)
        print("Wrote:", err_path)

    # Shared color scale for pred/true fields (robust 0.5 / 99.5 percentiles)
    vmin = float(min(np.percentile(snap_true, 0.5), np.percentile(x_next_pred, 0.5)))
    vmax = float(max(np.percentile(snap_true, 99.5), np.percentile(x_next_pred, 99.5)))

    tricontour_field(
        x_next_pred,
        title=f"{VAR_NAME} PRED (Koopman-CNN) at t={t_next} (x > {X_THESHOLD})",
        cbar_label=f"{VAR_NAME} (pred)",
        vmin=vmin, vmax=vmax,
    )
    tricontour_field(
        snap_true,
        title=f"{VAR_NAME} TRUE at t={t_next} (x > {X_THESHOLD})",
        cbar_label=f"{VAR_NAME} (true)",
        vmin=vmin, vmax=vmax,
    )

    # Side-by-side pred vs true. Explicit levels on [vmin, vmax] give both
    # panels identical contour bands, matching the single shared colourbar
    # (an integer `levels` would be derived from each panel's own data range).
    if vmax > vmin:
        shared_kw = dict(levels=np.linspace(vmin, vmax, 61), extend="both")
    else:
        shared_kw = dict(levels=60, vmin=vmin, vmax=vmax)
    fig, axes = plt.subplots(1, 2, figsize=(13.5, 5.6), dpi=110)
    for ax, (vals, tag) in zip(axes, ((x_next_pred, "PRED"), (snap_true, "TRUE"))):
        cf = ax.tricontourf(tri_sc, vals, **shared_kw)
        ax.set_aspect("equal", adjustable="box")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_title(f"{VAR_NAME} {tag} (t={t_next})")

    cbar = fig.colorbar(cf, ax=axes.ravel().tolist(), orientation="horizontal", pad=0.10, fraction=0.06)
    cbar.set_label(f"{VAR_NAME} (shared scale)")
    plt.tight_layout()
    plt.show()

    # True vs pred scatter
    plt.figure(figsize=(5, 5))
    plt.scatter(snap_true, x_next_pred, s=2)
    lo = float(min(snap_true.min(), x_next_pred.min()))
    hi = float(max(snap_true.max(), x_next_pred.max()))
    plt.plot([lo, hi], [lo, hi], linestyle="--")
    plt.xlabel("True")
    plt.ylabel("Predicted (Koopman-CNN)")
    plt.title(f"{VAR_NAME}: true vs predicted at t={t_next} (x > {X_THESHOLD})")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Error maps, clipped at the 99th percentile (tiny epsilon avoids a zero range)
    vmax_signed = float(np.percentile(np.abs(err), 99.0)) + 1e-30
    vmax_abs = float(np.percentile(abs_err, 99.0)) + 1e-30

    tricontour_field(
        err,
        title=f"{VAR_NAME} signed error (pred - true) at t={t_next} (x > {X_THESHOLD})",
        cbar_label="Signed error",
        vmin=-vmax_signed, vmax=vmax_signed,
    )
    tricontour_field(
        abs_err,
        title=f"{VAR_NAME} absolute error |pred - true| at t={t_next} (x > {X_THESHOLD})",
        cbar_label="Absolute error",
        vmin=0.0, vmax=vmax_abs,
    )

    # Worst-points overlay (capped at the number of points so the title is accurate)
    k = min(300, abs_err.size)
    idx_worst = np.argsort(abs_err)[-k:]
    fig = plt.figure(figsize=(7.2, 5.8))
    ax = fig.add_subplot(111)
    ax.scatter(x_sc, y_sc, s=1)
    ax.scatter(x_sc[idx_worst], y_sc[idx_worst], s=8)
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title(f"Top-{k} worst points by |error| at t={t_next} (x > {X_THESHOLD})")
    plt.tight_layout()
    plt.show()

else:
    if rank == 0:
        print("True next-step file does not exist; only the forecast was saved.")

    # Still plot predicted field
    tricontour_field(
        x_next_pred,
        title=f"{VAR_NAME} PRED (Koopman-CNN) at t={t_next} (x > {X_THESHOLD})",
        cbar_label=f"{VAR_NAME} (pred)",
    )


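# Captured cell output. The logged "Reading:" path does not start with the
# current BASE_DIR, so this output is possibly from an earlier configuration.
# Reference timestep: 200
# X_THESHOLD=300 -> keeping 255744/839680 points
# n_points=255744, n_snapshots=70
# Reading: ..\isocontours\phi0.40\h400x025_ref\extracted_field_post_200.csv
# Grid inference: structured_ok=False, nx=875, ny=225, n_points=255744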
