In [None]:
# ============================================================
# BOX 1/3 — Reading CSV snapshots + consistent crop (NO gridding)
# ============================================================
from __future__ import annotations

from flamekit.io_fields import field_path
from flamekit.io_fronts import Case
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.tri as mtri

# ----------------------------
# USER SETTINGS
# ----------------------------
# Inclusive range of snapshot time steps to load (both end points are read).
TIME_STEP_START = 200
TIME_STEP_END   = 269

# Case selectors forwarded unchanged to flamekit's Case(...).
# NOTE(review): presumably equivalence ratio / lateral size / post-processed flag -- confirm against flamekit.
PHI      = 0.40
LAT_SIZE = "025"
POST     = True

BASE_DIR  = Path("../isocontours")  # forwarded to Case as base_dir
VAR_NAME  = "T"                     # CSV column holding the field values that are read
SORT_COLS = ["x", "y"]              # columns used both as sort keys and as point coordinates
COORD_TOL = 0.0                     # absolute tolerance when comparing coordinates; 0.0 means exact equality

# NOTE: the name is misspelled ("THESHOLD"); it is kept because the later cells reference it.
X_THESHOLD = 300  # keep only points with x > threshold (one mask, reused for every snapshot)

# ----------------------------
# Helpers
# ----------------------------
def field_csv_path(base_dir: Path, phi: float, lat_size: str, time_step: int, post: bool) -> Path:
    """Return the path that ``field_path`` resolves for the described case.

    The five arguments are forwarded, by keyword and unchanged, to ``Case``;
    the resulting case object is handed to ``field_path`` and its result is
    returned as-is.
    """
    case_kwargs = {
        "base_dir": base_dir,
        "phi": phi,
        "lat_size": lat_size,
        "time_step": time_step,
        "post": post,
    }
    return field_path(Case(**case_kwargs))

def read_field_sorted(path: Path, var_name: str, sort_cols: list[str]) -> tuple[np.ndarray, np.ndarray]:
    """Load one CSV snapshot and return its points in a deterministic order.

    Rows with a non-finite or missing entry in any of ``sort_cols`` or in
    ``var_name`` are discarded; the remaining rows are stably sorted by
    ``sort_cols``.

    Returns:
        ``(coords, values)`` -- ``coords`` holds the ``sort_cols`` columns
        (one row per point), ``values`` the ``var_name`` column, both float64.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        ValueError: a required column is absent from the CSV.
    """
    if not path.exists():
        raise FileNotFoundError(f"Missing file:\n  {path}")

    frame = pd.read_csv(path)
    required = sort_cols + [var_name]

    absent = [name for name in required if name not in frame.columns]
    if absent:
        raise ValueError(f"{path.name}: missing columns {absent}")

    # Treat +/-inf like NaN so a single dropna removes every unusable row.
    frame = frame.replace([np.inf, -np.inf], np.nan)
    frame = frame.dropna(subset=required)
    # mergesort is stable, so rows with equal keys keep their file order.
    frame = frame.sort_values(sort_cols, kind="mergesort")
    frame = frame.reset_index(drop=True)

    return (
        frame[sort_cols].to_numpy(dtype=np.float64),
        frame[var_name].to_numpy(dtype=np.float64),
    )

def coords_match(a: np.ndarray, b: np.ndarray, atol: float) -> bool:
    """Tell whether two coordinate arrays describe the same points.

    Arrays of different shape never match. With ``atol == 0.0`` the arrays
    must be element-wise identical; otherwise entries may differ by at most
    ``atol`` in absolute terms (no relative tolerance is applied).
    """
    same_shape = a.shape == b.shape
    if not same_shape:
        return False
    exact = atol == 0.0
    return np.array_equal(a, b) if exact else np.allclose(a, b, rtol=0.0, atol=atol)

# ----------------------------
# Build X_points: (n_points_cropped, n_snaps)
# ----------------------------
# Guard the user-supplied range: an empty range would otherwise only surface
# as a bare IndexError on times[0] below.
if TIME_STEP_END < TIME_STEP_START:
    raise ValueError(
        f"Empty time range: TIME_STEP_END={TIME_STEP_END} < TIME_STEP_START={TIME_STEP_START}."
    )
times = list(range(TIME_STEP_START, TIME_STEP_END + 1))

# The first snapshot defines the reference point set and the crop mask.
ref_path = field_csv_path(BASE_DIR, PHI, LAT_SIZE, times[0], POST)
coords_ref_full, snap0_full = read_field_sorted(ref_path, VAR_NAME, SORT_COLS)

# Column 0 of the coordinates is the first entry of SORT_COLS ("x" in the settings above).
mask_x = coords_ref_full[:, 0] > X_THESHOLD
if not mask_x.any():
    # Fail here with a clear message instead of building an empty (0, T) matrix
    # that would only break later, in the neighbor search.
    raise ValueError(f"No points satisfy x > {X_THESHOLD}; nothing left after cropping.")
coords_ref = coords_ref_full[mask_x]
snap0 = snap0_full[mask_x]

snapshots = [snap0]
for t in times[1:]:
    p = field_csv_path(BASE_DIR, PHI, LAT_SIZE, t, POST)
    coords_t_full, snap_t_full = read_field_sorted(p, VAR_NAME, SORT_COLS)

    # Checked separately from the coordinate comparison so that a changed
    # point count gets its own, more specific message.
    if coords_t_full.shape[0] != coords_ref_full.shape[0]:
        raise ValueError(f"Full point count changed at t={t}.")

    if not coords_match(coords_t_full, coords_ref_full, COORD_TOL):
        raise ValueError(f"Full coordinates mismatch at t={t} (sorting/mesh changed).")

    # Reusing the reference mask is valid only because the full coordinates
    # were just verified to match the reference snapshot.
    snapshots.append(snap_t_full[mask_x])

# One column per snapshot: (N, T). The snapshots are already float64, so
# copy=False keeps the dtype guarantee without a second full copy.
X_points = np.stack(snapshots, axis=1).astype(np.float64, copy=False)
N, T_total = X_points.shape

print(f"Read {T_total} snapshots with N={N} cropped points (x>{X_THESHOLD}).")


In [None]:
from flamekit.io_fields import field_path
from flamekit.io_fronts import Case
# ============================================================
# BOX 2/3 — Train a local Graph-CNN (kNN message passing) to predict next step
#          NO grid interpolation; neighborhoods are defined on (x,y) points.
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# ----------------------------
# TRAINING SETTINGS
# ----------------------------
# Seed NumPy (used for sampling training pairs) and PyTorch (used for weight init).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Requested device; silently downgraded to CPU when CUDA is not available.
DEVICE = "cuda"
cuda_missing = not torch.cuda.is_available()
if cuda_missing and DEVICE == "cuda":
    DEVICE = "cpu"
device = torch.device(DEVICE)
print("Using device:", device)

KNN_K = 12                 # neighbors per point, the point itself not counted
LAYERS = 4                 # number of stacked graph conv layers
WIDTH = 64                 # hidden channel count
EPOCHS = 300
BATCH_TIMES = 4            # (t -> t+1) pairs drawn per gradient step
STEPS_PER_EPOCH = 50
LR = 2e-3
WEIGHT_DECAY = 1e-8

# If N is large, you may need smaller WIDTH, smaller KNN_K, and BATCH_TIMES=1..2.

# ----------------------------
# Build kNN graph once (CPU), then move indices to GPU
# ----------------------------
from sklearn.neighbors import NearestNeighbors

coords = coords_ref.astype(np.float64)

# Each point needs KNN_K neighbors other than itself.
if coords.shape[0] <= KNN_K:
    raise ValueError(
        f"Need more than KNN_K={KNN_K} points to build the kNN graph, got {coords.shape[0]}."
    )

nbrs = NearestNeighbors(n_neighbors=KNN_K, algorithm="auto").fit(coords)
# Calling kneighbors() without a query returns the neighbors of every indexed
# point and does not count the point as its own neighbor. Querying with the
# training points and dropping column 0 is not equivalent: with duplicated
# coordinates the point itself is not guaranteed to come first, so it could be
# kept as a "neighbor" while a real neighbor is dropped.
dist, idx = nbrs.kneighbors()  # both (N, K), self already excluded

nbr_idx = idx.astype(np.int64)                         # (N, K)
nbr_dist = dist.astype(np.float32)                     # (N, K)

# inverse-distance weights, normalized to sum to 1 per point
# (the small offsets avoid division by zero for coincident points)
w = 1.0 / (nbr_dist + 1e-6)
w = w / (w.sum(axis=1, keepdims=True) + 1e-12)         # (N,K)

nbr_idx_t = torch.from_numpy(nbr_idx).to(device)
w_t = torch.from_numpy(w).to(device)                   # (N,K)

# ----------------------------
# Data tensors (time as batch)
# ----------------------------
# X_points: (N,T). We train on pairs (t -> t+1).
# Time becomes the leading (batch) axis: one row per snapshot.
X_seq = np.transpose(X_points).astype(np.float32)  # (T,N)

# Standardize every point with its own mean / std taken over time;
# the 1e-6 offset keeps the division finite where a point never changes.
mu = np.mean(X_seq, axis=0, keepdims=True)
sd = np.std(X_seq, axis=0, keepdims=True) + 1e-6
Xn_seq = (X_seq - mu) / sd                         # (T,N)

Xn_t = torch.from_numpy(Xn_seq).to(device)         # (T,N)

# ----------------------------
# Graph-CNN layers (message passing)
# ----------------------------
class GraphConv(nn.Module):
    """
    Neighborhood-aggregating layer over a fixed kNN graph.

    For every point three feature vectors are concatenated on the channel
    axis -- the point's own features, the ``w``-weighted sum of its
    neighbors' features and the element-wise maximum over its neighbors --
    and the result is passed through a two-layer MLP followed by LayerNorm.
    """

    def __init__(self, c_in: int, c_out: int) -> None:
        super().__init__()
        # Input is [self, weighted sum, max] stacked on channels, hence 3 * c_in.
        self.mlp = nn.Sequential(
            nn.Linear(3 * c_in, c_out),
            nn.ReLU(inplace=True),
            nn.Linear(c_out, c_out),
        )
        self.norm = nn.LayerNorm(c_out)

    def forward(self, h: torch.Tensor, nbr_idx: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """
        h:       (B,N,C) point features
        nbr_idx: (N,K)   neighbor indices into the point axis
        w:       (N,K)   per-neighbor weights
        returns  (B,N,c_out)
        """
        # Unpacking doubles as a rank check: anything but a 3-D tensor raises here.
        _batch, _points, _channels = h.shape

        neighbors = h[:, nbr_idx, :]                         # (B,N,K,C)
        weights = w[None, ..., None]                         # (1,N,K,1)

        weighted_sum = torch.sum(neighbors * weights, dim=2)  # (B,N,C)
        neighbor_max = torch.max(neighbors, dim=2).values     # (B,N,C)

        stacked = torch.cat((h, weighted_sum, neighbor_max), dim=-1)  # (B,N,3C)
        return self.norm(self.mlp(stacked))

class GraphCNNNextStep(nn.Module):
    """
    One-step predictor on a point cloud.

    Takes one scalar per point, u(t), and returns one scalar per point,
    u(t+1). The scalar is lifted to ``width`` channels, refined by ``layers``
    residual ``GraphConv`` blocks, then projected back to a single channel.
    """

    def __init__(self, width: int, layers: int) -> None:
        super().__init__()
        self.in_lin = nn.Linear(1, width)

        self.convs = nn.ModuleList([GraphConv(width, width) for _ in range(layers)])
        self.out_lin = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Linear(width, 1),
        )

    def forward(self, u: torch.Tensor, nbr_idx: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """
        u:       (B,N) scalar field per point
        nbr_idx: (N,K) neighbor indices, passed through to every conv
        w:       (N,K) neighbor weights, passed through to every conv
        returns  (B,N)
        """
        feats = self.in_lin(u[..., None])               # (B,N) -> (B,N,width)
        for layer in self.convs:
            # residual update: each conv contributes a correction to the features
            feats = feats + layer(feats, nbr_idx, w)
        return self.out_lin(feats).squeeze(-1)          # (B,N)

model = GraphCNNNextStep(width=WIDTH, layers=LAYERS).to(device)
opt = optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
loss_fn = nn.MSELoss()

# ----------------------------
# Train
# ----------------------------
def _train_step() -> float:
    """Do one optimizer update on a random batch of (t -> t+1) pairs and return its loss."""
    # start indices stop at T_total - 2 so that start + 1 is always a valid snapshot
    start = np.random.randint(0, T_total - 1, size=BATCH_TIMES)
    inputs = Xn_t[start, :]            # (B,N)
    targets = Xn_t[start + 1, :]       # (B,N)

    batch_loss = loss_fn(model(inputs, nbr_idx_t, w_t), targets)

    opt.zero_grad(set_to_none=True)
    batch_loss.backward()
    opt.step()

    return batch_loss.item()

model.train()
for epoch in range(1, EPOCHS + 1):
    losses = [_train_step() for _ in range(STEPS_PER_EPOCH)]

    # report the first epoch and then every 50th
    if epoch % 50 == 0 or epoch == 1:
        print(f"Epoch {epoch:4d}/{EPOCHS} | train MSE (norm) = {float(np.mean(losses)):.3e}")

# ----------------------------
# Forecast next step from last snapshot
# ----------------------------
model.eval()
t_next = TIME_STEP_END + 1  # the step right after the last snapshot that was loaded

with torch.no_grad():
    u_last = Xn_t[-1:, :]                               # (1,N) normalized, batch axis kept
    u_next_pred_n = model(u_last, nbr_idx_t, w_t)[0]     # (N,) normalized

# Map the normalized prediction back with the per-point statistics used for training
scale = sd.reshape(-1)
shift = mu.reshape(-1)
u_next_pred = (u_next_pred_n.cpu().numpy() * scale + shift).astype(np.float64)  # (N,)

# Output goes next to the field files: only the parent directory of the
# resolved path is used, so the time_step passed here is just a placeholder.
ref_case = Case(base_dir=BASE_DIR, phi=PHI, lat_size=LAT_SIZE, time_step=0, post=POST)
out_dir = field_path(ref_case).parent
out_dir.mkdir(parents=True, exist_ok=True)

pred_col = f"{VAR_NAME}_pred"
out = pd.DataFrame(coords_ref, columns=SORT_COLS)
out[pred_col] = u_next_pred
out_path = out_dir / f"cnn_graph_pred_{VAR_NAME}_{t_next}_xgt{int(X_THESHOLD)}.csv"
out.to_csv(out_path, index=False)
print("Wrote:", out_path)


In [None]:
# ============================================================
# BOX 3/3 — Plotting: pred/true fields + errors + Ttrue vs Tpred
# ============================================================
# Path of the ground-truth snapshot for the predicted step (may not exist).
path_true = field_csv_path(BASE_DIR, PHI, LAT_SIZE, t_next, POST)

# Triangulation for spatial plots on the original (cropped) point cloud
x = coords_ref[:, 0].astype(float)
y = coords_ref[:, 1].astype(float)
tri = mtri.Triangulation(x, y)

# Masking flat triangles is cosmetic, so a failure must not abort plotting.
# It is reported rather than swallowed silently, so an unmasked plot is not
# mistaken for a masked one.
try:
    analyzer = mtri.TriAnalyzer(tri)
    tri.set_mask(analyzer.get_flat_tri_mask(min_circle_ratio=0.02))
except Exception as exc:
    print(
        f"Warning: could not mask flat triangles ({type(exc).__name__}: {exc}); "
        "plotting with the unmasked triangulation."
    )

def tricontour_field(vals: np.ndarray, title: str, cbar_label: str, vmin=None, vmax=None) -> None:
    """Show a filled contour plot of per-point values on the module-level triangulation ``tri``.

    Args:
        vals: one value per point of ``tri``.
        title: axes title.
        cbar_label: label of the horizontal colorbar.
        vmin, vmax: forwarded unchanged to ``tricontourf``.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    fig, ax = plt.subplots(figsize=(7.2, 5.8))

    filled = ax.tricontourf(tri, vals, levels=60, vmin=vmin, vmax=vmax)

    ax.set_title(title)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_aspect("equal", adjustable="box")

    colorbar = fig.colorbar(filled, ax=ax, orientation="horizontal", pad=0.08, fraction=0.06)
    colorbar.set_label(cbar_label)

    plt.tight_layout()
    plt.show()

# Compare against the true next-step snapshot when its file is available;
# otherwise only the prediction is plotted (else branch at the bottom).
if path_true.exists():
    coords_true_full, snap_true_full = read_field_sorted(path_true, VAR_NAME, SORT_COLS)

    # The crop mask was built on the reference snapshot, so it can only be
    # reused if the true snapshot has the same full point set.
    if coords_true_full.shape[0] != coords_ref_full.shape[0]:
        raise ValueError("True next-step has different full point count; cannot compare directly.")

    if not coords_match(coords_true_full, coords_ref_full, COORD_TOL):
        raise ValueError("True next-step full coordinates changed; cannot compare directly.")

    u_true = snap_true_full[mask_x].astype(np.float64)

    # Error is taken as prediction minus truth.
    err = u_next_pred - u_true
    abs_err = np.abs(err)

    rmse = float(np.sqrt(np.mean(err**2)))
    # 1e-12 guards the division when the true field is all zeros
    rel_l2 = float(np.linalg.norm(err) / (np.linalg.norm(u_true) + 1e-12))
    print(f"Next-step RMSE: {rmse:.6e}")
    print(f"Next-step relative L2 error: {rel_l2:.6e}")

    # Save error CSV (coordinates, true, pred, signed and absolute error per point)
    out_err = pd.DataFrame(coords_ref, columns=SORT_COLS)
    out_err[f"{VAR_NAME}_true"] = u_true
    out_err[f"{VAR_NAME}_pred"] = u_next_pred
    out_err["err"] = err
    out_err["abs_err"] = abs_err
    err_path = out_dir / f"cnn_graph_err_{VAR_NAME}_{t_next}_xgt{int(X_THESHOLD)}.csv"
    out_err.to_csv(err_path, index=False)
    print("Wrote:", err_path)

    # Shared color scale for true/pred fields: the 0.5 / 99.5 percentiles over
    # both fields are used instead of min / max
    vmin = float(min(np.percentile(u_true, 0.5), np.percentile(u_next_pred, 0.5)))
    vmax = float(max(np.percentile(u_true, 99.5), np.percentile(u_next_pred, 99.5)))

    tricontour_field(
        u_next_pred,
        title=f"{VAR_NAME} PRED (Graph-CNN) at t={t_next} (x > {X_THESHOLD})",
        cbar_label=f"{VAR_NAME} (pred)",
        vmin=vmin, vmax=vmax,
    )
    tricontour_field(
        u_true,
        title=f"{VAR_NAME} TRUE at t={t_next} (x > {X_THESHOLD})",
        cbar_label=f"{VAR_NAME} (true)",
        vmin=vmin, vmax=vmax,
    )

    # Requested: Ttrue vs Tpred scatter, with the identity line as reference
    plt.figure(figsize=(5, 5))
    plt.scatter(u_true, u_next_pred, s=2)
    lo = float(min(u_true.min(), u_next_pred.min()))
    hi = float(max(u_true.max(), u_next_pred.max()))
    plt.plot([lo, hi], [lo, hi], linestyle="--")
    plt.xlabel(f"{VAR_NAME}_true")
    plt.ylabel(f"{VAR_NAME}_pred")
    plt.title(f"{VAR_NAME}_true vs {VAR_NAME}_pred at t={t_next} (x > {X_THESHOLD})")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Error maps: limits come from the 99th percentile of |err|; the tiny
    # offset keeps them non-zero when the error is identically zero.
    # Both limits are computed from the same data (abs_err == np.abs(err)).
    vmax_signed = float(np.percentile(np.abs(err), 99.0)) + 1e-30
    vmax_abs = float(np.percentile(abs_err, 99.0)) + 1e-30

    tricontour_field(
        err,
        title=f"{VAR_NAME} signed error (pred - true) at t={t_next} (x > {X_THESHOLD})",
        cbar_label="Signed error",
        vmin=-vmax_signed, vmax=vmax_signed,
    )
    tricontour_field(
        abs_err,
        title=f"{VAR_NAME} absolute error |pred - true| at t={t_next} (x > {X_THESHOLD})",
        cbar_label="Absolute error",
        vmin=0.0, vmax=vmax_abs,
    )

else:
    # No ground truth on disk: show the prediction without vmin/vmax.
    print("True next-step file does not exist; plotting only prediction.")
    tricontour_field(
        u_next_pred,
        title=f"{VAR_NAME} PRED (Graph-CNN) at t={t_next} (x > {X_THESHOLD})",
        cbar_label=f"{VAR_NAME} (pred)",
    )
