In [None]:
# ============================================================
# BOX 1/3 — Reading CSV snapshots + consistent crop + (re)gridding for FFT/FNO
# ============================================================
from __future__ import annotations

from flamekit.io_fields import field_path
from flamekit.io_fronts import Case
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.tri as mtri

# ----------------------------
# USER SETTINGS
# ----------------------------
# Inclusive range of snapshot time steps to load.
TIME_STEP_START, TIME_STEP_END = 200, 269

# Case identification, forwarded to flamekit's Case by field_csv_path.
PHI = 0.40
LAT_SIZE = "025"
POST = True

# Input root, field to model, point ordering and coordinate-check tolerance.
BASE_DIR = Path("../isocontours")
VAR_NAME = "T"
SORT_COLS = ["x", "y"]   # sort keys; the first two columns are used as (x, y)
COORD_TOL = 0.0          # 0.0 -> snapshots must have bit-identical coordinates

# Spatial crop: only points with x strictly greater than this value are kept.
# (Name is misspelled but referenced by later cells, so it stays as is.)
X_THESHOLD = 300

# The FFTs inside the FNO need a regular tensor grid, so the cropped scattered
# points are interpolated onto a uniform GRID_NY x GRID_NX grid.
GRID_NX = 256
GRID_NY = 128

# ----------------------------
# Helpers (same style as your template)
# ----------------------------
def field_csv_path(base_dir: Path, phi: float, lat_size: str, time_step: int, post: bool) -> Path:
    """Return the CSV path of one field snapshot.

    The arguments are packed into a flamekit ``Case``, and the location is
    resolved with ``field_path``.
    """
    return field_path(
        Case(base_dir=base_dir, phi=phi, lat_size=lat_size, time_step=time_step, post=post)
    )

def read_field_sorted(path: Path, var_name: str, sort_cols: list[str]) -> tuple[np.ndarray, np.ndarray]:
    """Load one snapshot CSV and return its coordinates and field values in a stable order.

    Rows whose sort columns or ``var_name`` are NaN or +/-inf are dropped.
    The remaining rows are stably sorted (mergesort) by ``sort_cols``.

    Returns:
        (coords, values): float64 arrays of shape (n, len(sort_cols)) and (n,).

    Raises:
        FileNotFoundError: ``path`` does not exist.
        ValueError: a required column is absent from the CSV.
    """
    if not path.exists():
        raise FileNotFoundError(f"Missing file:\n  {path}")

    frame = pd.read_csv(path)
    needed = sort_cols + [var_name]
    absent = [col for col in needed if col not in frame.columns]
    if absent:
        raise ValueError(f"{path.name}: missing columns {absent}")

    # Treat infinities like missing values so both are removed in one pass.
    frame = frame.replace([np.inf, -np.inf], np.nan)
    frame = frame.dropna(subset=needed)
    # A stable sort keeps the original order among rows with identical keys.
    frame = frame.sort_values(sort_cols, kind="mergesort")
    frame = frame.reset_index(drop=True)

    xy = frame[sort_cols].to_numpy(dtype=np.float64)
    field = frame[var_name].to_numpy(dtype=np.float64)
    return xy, field

def coords_match(a: np.ndarray, b: np.ndarray, atol: float) -> bool:
    """Report whether two coordinate arrays describe the same points.

    Arrays of different shape never match. With ``atol == 0.0`` the comparison
    is exact; otherwise elements may differ by at most ``atol`` (no relative
    tolerance).
    """
    if a.shape == b.shape:
        exact = atol == 0.0
        return np.array_equal(a, b) if exact else np.allclose(a, b, atol=atol, rtol=0.0)
    return False

# ----------------------------
# Build X on CROPPED point list (for later saving/plots)
# ----------------------------
# Every snapshot in [TIME_STEP_START, TIME_STEP_END] must have exactly the same
# sorted coordinates as the first (reference) snapshot. The x-crop mask is
# therefore computed once on the reference and reused for all snapshots.
if TIME_STEP_END < TIME_STEP_START:
    raise ValueError(
        f"Empty time range: TIME_STEP_START={TIME_STEP_START} > TIME_STEP_END={TIME_STEP_END}."
    )
times = list(range(TIME_STEP_START, TIME_STEP_END + 1))

ref_path = field_csv_path(BASE_DIR, PHI, LAT_SIZE, times[0], POST)
coords_ref_full, snap0_full = read_field_sorted(ref_path, VAR_NAME, SORT_COLS)

mask_x = coords_ref_full[:, 0] > X_THESHOLD
if not mask_x.any():
    # Fail here with a clear message rather than on an empty-array min() during regridding.
    raise ValueError(
        f"No points with x > {X_THESHOLD} in {ref_path} "
        f"({coords_ref_full.shape[0]} points before cropping)."
    )
coords_ref = coords_ref_full[mask_x]
snap0 = snap0_full[mask_x]

snapshots = [snap0]
for t in times[1:]:
    p = field_csv_path(BASE_DIR, PHI, LAT_SIZE, t, POST)
    coords_t_full, snap_t_full = read_field_sorted(p, VAR_NAME, SORT_COLS)

    # The point-count check comes first because it gives a more specific
    # message than coords_match.
    if coords_t_full.shape[0] != coords_ref_full.shape[0]:
        raise ValueError(f"Full point count changed at t={t}: {coords_t_full.shape[0]} vs {coords_ref_full.shape[0]}")

    if not coords_match(coords_t_full, coords_ref_full, COORD_TOL):
        raise ValueError(f"Full coordinates mismatch at t={t}. Sorting/interpolation/mesh changed.")

    snapshots.append(snap_t_full[mask_x])

# X_points: (n_points_cropped, n_snaps); column k is the snapshot at times[k]
X_points = np.stack(snapshots, axis=1).astype(np.float64)
n_points, n_snaps = X_points.shape
print(f"Read snapshots: n_points_cropped={n_points}, n_snaps={n_snaps}, x>{X_THESHOLD}")

# ----------------------------
# Regrid to uniform (ny,nx) for FFT/FNO
# ----------------------------
# Uniform target grid covering the bounding box of the CROPPED coordinates.
xy_lo = coords_ref[:, :2].min(axis=0)
xy_hi = coords_ref[:, :2].max(axis=0)
x_min, x_max = float(xy_lo[0]), float(xy_hi[0])
y_min, y_max = float(xy_lo[1]), float(xy_hi[1])

xs_u = np.linspace(x_min, x_max, num=GRID_NX, dtype=np.float64)
ys_u = np.linspace(y_min, y_max, num=GRID_NY, dtype=np.float64)
Xg, Yg = np.meshgrid(xs_u, ys_u, indexing="xy")  # both (GRID_NY, GRID_NX)

# Delaunay triangulation of the cropped points. Every snapshot's interpolation
# onto (Xg, Yg) reuses it.
tri = mtri.Triangulation(coords_ref[:, 0], coords_ref[:, 1])

def points_to_uniform_grid(vals_points: np.ndarray) -> np.ndarray:
    """Interpolate per-point values onto the uniform (ys_u, xs_u) grid.

    Uses piecewise-linear interpolation on the module-level triangulation
    ``tri``, evaluated at ``(Xg, Yg)``. Grid nodes left masked by the
    interpolator (outside the triangulation) are filled with the mean of the
    valid nodes, or with 0.0 when no node is valid.

    Returns a float32 array of shape (ny, nx).
    """
    interpolator = mtri.LinearTriInterpolator(tri, vals_points.astype(np.float64))
    grid = np.asarray(interpolator(Xg, Yg).filled(np.nan), dtype=np.float32)

    if not np.isnan(grid).any():
        return grid

    fill_value = np.nanmean(grid)
    if not np.isfinite(fill_value):
        fill_value = 0.0
    return np.nan_to_num(grid, nan=float(fill_value)).astype(np.float32)

# Build gridded time sequence: one (ny, nx) grid per snapshot, stacked to (T, ny, nx).
U = np.stack([points_to_uniform_grid(column) for column in X_points.T], axis=0).astype(np.float32)
T_total, ny, nx = U.shape
print(f"Uniform grid sequence: U.shape={U.shape} (T,ny,nx)")

# Per-pixel standardisation over time, to help training stability. The 1e-6
# keeps pixels whose value never changes from dividing by zero.
U_mean = U.mean(axis=0, keepdims=True)         # (1, ny, nx)
U_std = U.std(axis=0, keepdims=True) + 1e-6    # (1, ny, nx)

# Normalised sequence with a singleton channel axis for the network: (T, 1, ny, nx).
U_n = ((U - U_mean) / U_std)[:, np.newaxis, :, :].astype(np.float32)

# Convenience: map uniform-grid field back to your cropped point list (for saving/metrics)
def uniform_grid_to_points_bilinear(
    Z: np.ndarray,
    coords_xy: np.ndarray,
    xs: "np.ndarray | None" = None,
    ys: "np.ndarray | None" = None,
) -> np.ndarray:
    """
    Bilinearly interpolate a uniform-grid field back to scattered points.

    Z: (ny, nx) values on the grid ys x xs (row index = y, column index = x).
    coords_xy: (n_points, >=2) coordinates; column 0 = x, column 1 = y.
    xs, ys: ascending 1D grid axes. They default to the module-level
        ``xs_u`` / ``ys_u`` that the uniform grid was built on.

    Points outside the grid extent are linearly extrapolated from the nearest
    edge cell. Returns a float64 array of shape (n_points,).

    Raises ValueError if an axis has fewer than 2 nodes or if Z's shape is not
    (len(ys), len(xs)).
    """
    if xs is None:
        xs = xs_u
    if ys is None:
        ys = ys_u
    xs = np.asarray(xs, dtype=np.float64)
    ys = np.asarray(ys, dtype=np.float64)
    Z = np.asarray(Z)

    # Derive grid sizes from the axes actually indexed, not from unrelated globals.
    n_x, n_y = xs.size, ys.size
    if n_x < 2 or n_y < 2:
        raise ValueError(f"Grid axes need at least 2 nodes each, got len(xs)={n_x}, len(ys)={n_y}.")
    if Z.shape != (n_y, n_x):
        raise ValueError(f"Z has shape {Z.shape}, expected (len(ys), len(xs)) = {(n_y, n_x)}.")

    x = coords_xy[:, 0].astype(np.float64)
    y = coords_xy[:, 1].astype(np.float64)

    # Lower-left node of each point's cell. searchsorted returns the first node
    # >= the coordinate, so subtract 1. Clipping keeps boundary and out-of-range
    # points in the first/last cell.
    ix = np.clip(np.searchsorted(xs, x) - 1, 0, n_x - 2)
    iy = np.clip(np.searchsorted(ys, y) - 1, 0, n_y - 2)

    x0, x1 = xs[ix], xs[ix + 1]
    y0, y1 = ys[iy], ys[iy + 1]

    # Fractional position inside the cell. The tiny epsilon only matters for a
    # zero-width (degenerate) axis.
    tx = (x - x0) / (x1 - x0 + 1e-30)
    ty = (y - y0) / (y1 - y0 + 1e-30)

    z00 = Z[iy, ix]
    z10 = Z[iy, ix + 1]
    z01 = Z[iy + 1, ix]
    z11 = Z[iy + 1, ix + 1]

    z0 = (1 - tx) * z00 + tx * z10
    z1 = (1 - tx) * z01 + tx * z11
    return ((1 - ty) * z0 + ty * z1).astype(np.float64)


In [None]:
from flamekit.io_fields import field_path
from flamekit.io_fronts import Case
# ============================================================
# BOX 2/3 — kFNO (paper-style: L -> H -> K* (Koopman-like A iterated) -> Q* -> P*)
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# ----------------------------
# kFNO SETTINGS
# ----------------------------
DEVICE = "cuda"
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

if DEVICE == "cuda" and not torch.cuda.is_available():
    DEVICE = "cpu"
device = torch.device(DEVICE)
print("Using device:", device)

N_PRED = 4              # n-step operator output length (t+1..t+n)
EPOCHS = 800
BATCH_SIZE = 6
BATCHES_PER_EPOCH = 60

LR = 2e-3
WEIGHT_DECAY = 1e-7

WIDTH = 48              # hidden channels
N_H_LAYERS = 4          # layers in H (baseline FNO trunk)
N_Q_LAYERS = 1          # layers in Q*
MODES_X = 20            # retained Fourier modes along x
MODES_Y = 20            # retained Fourier modes along y

ADD_COORDS = True       # append normalised x/y coordinate channels to the input

# FourierLayer combines its two paths as
#     alpha * pointwise(u) + (1 - alpha) * spectral(u)
# alpha = 0 keeps only the spectral convolution. alpha = 1 keeps only the 1x1
# pointwise path: the spectral term is multiplied by 0, its weights receive no
# gradient, and the network loses all spatial coupling. 0.5 weights both paths
# equally. Up to a scale that the learned weights absorb, this is the standard
# FNO layer W u + K u.
ALPHA_SKIP = 0.5
if not 0.0 <= ALPHA_SKIP <= 1.0:
    raise ValueError(f"ALPHA_SKIP must be in [0, 1], got {ALPHA_SKIP}.")
if ALPHA_SKIP == 1.0:
    print("WARNING: ALPHA_SKIP=1.0 switches off the spectral path of every Fourier layer.")

# Koopman-like advancement operator A in hidden space (shared across steps):
#   "linear"    -> one Fourier layer without activation
#   "nonlinear" -> two Fourier layers with a GELU between them (more expressive)
A_TYPE = "linear"        # or "nonlinear"

# Loss in normalized space
mse = nn.MSELoss()

U_torch = torch.from_numpy(U_n).to(device)  # (T,1,ny,nx)

if N_PRED >= T_total:
    raise ValueError(f"N_PRED={N_PRED} must be < number of snapshots T={T_total}.")

# ----------------------------
# Core FNO building blocks
# ----------------------------
class SpectralConv2d(nn.Module):
    """
    2D spectral convolution (FNO-style).

    The input is transformed with a real 2D FFT. The lowest Fourier modes are
    multiplied by learned complex weights (which also mix channels), all other
    modes are zeroed, and the result is transformed back.

    ``rfft2`` stores only non-negative kx but both signs of ky. Two blocks of
    modes are therefore kept, each with its own weights:
      * rows ``[0, modes_y)``      -> ky >= 0  (weights ``w``)
      * the last ``modes_y`` rows  -> ky < 0   (weights ``w_neg``)
    Keeping only the first block would discard every ky < 0 mode and make the
    filter direction-dependent.
    """
    def __init__(self, in_channels: int, out_channels: int, modes_x: int, modes_y: int) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes_x = modes_x
        self.modes_y = modes_y

        # complex weights for the truncated modes
        scale = 1.0 / (in_channels * out_channels)
        # ky >= 0 block (attribute name kept for compatibility)
        self.w = nn.Parameter(scale * torch.randn(in_channels, out_channels, modes_y, modes_x, dtype=torch.cfloat))
        # ky < 0 block
        self.w_neg = nn.Parameter(scale * torch.randn(in_channels, out_channels, modes_y, modes_x, dtype=torch.cfloat))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, C_in, ny, nx) real -> (B, C_out, ny, nx) real."""
        B, _, ny_, nx_ = x.shape
        x_ft = torch.fft.rfft2(x, norm="ortho")  # (B, C_in, ny, nx//2+1)

        out_ft = torch.zeros(B, self.out_channels, ny_, nx_ // 2 + 1, dtype=torch.cfloat, device=x.device)

        mx = min(self.modes_x, nx_ // 2 + 1)
        # Split the ny rows between the two blocks so they never overlap, even
        # when 2 * modes_y exceeds ny. The positive block always keeps ky = 0.
        my_neg = min(self.modes_y, ny_ // 2)
        my_pos = min(self.modes_y, ny_ - my_neg)

        out_ft[:, :, :my_pos, :mx] = torch.einsum(
            "bcyx,coyx->boyx", x_ft[:, :, :my_pos, :mx], self.w[:, :, :my_pos, :mx]
        )
        if my_neg > 0:
            # Explicit start index avoids the `-0:` slice pitfall (which would select all rows).
            neg = slice(ny_ - my_neg, ny_)
            out_ft[:, :, neg, :mx] = torch.einsum(
                "bcyx,coyx->boyx", x_ft[:, :, neg, :mx], self.w_neg[:, :, :my_neg, :mx]
            )

        return torch.fft.irfft2(out_ft, s=(ny_, nx_), norm="ortho")  # (B, C_out, ny, nx)

class FourierLayer(nn.Module):
    """
    One Fourier layer with an alpha-weighted skip connection:

        u_{l+1} = act( alpha * W(u_l) + (1 - alpha) * F^{-1}(R(F(u_l))) + b )

    W is a 1x1 convolution (which carries its own bias), R is the truncated
    spectral multiplier of ``SpectralConv2d``, and b is an extra learned
    per-channel bias. ``act`` is GELU when ``activation`` is True and the
    identity otherwise. alpha = 0 keeps only the spectral path; alpha = 1 keeps
    only the pointwise path, because the spectral term is then multiplied by 0.
    """
    def __init__(self, width: int, modes_x: int, modes_y: int, alpha: float, activation: bool = True) -> None:
        super().__init__()
        self.alpha = float(alpha)
        # Sub-modules are created in this order so parameter initialisation
        # consumes the RNG in the same sequence.
        self.spectral = SpectralConv2d(width, width, modes_x, modes_y)
        self.pointwise = nn.Conv2d(width, width, kernel_size=1)
        self.bias = nn.Parameter(torch.zeros(1, width, 1, 1))
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        spectral_term = (1.0 - self.alpha) * self.spectral(x)
        skip_term = self.alpha * self.pointwise(x)
        pre_act = skip_term + spectral_term + self.bias
        return F.gelu(pre_act) if self.activation else pre_act

class FNOBlock(nn.Module):
    """Stack of ``n_layers`` FourierLayers applied in sequence.

    Every layer uses GELU, except that the last one is left linear when
    ``activation_last`` is False.
    """
    def __init__(self, width: int, modes_x: int, modes_y: int, n_layers: int, alpha: float, activation_last: bool = True) -> None:
        super().__init__()
        last = n_layers - 1
        self.net = nn.Sequential(
            *(
                FourierLayer(width, modes_x, modes_y, alpha=alpha, activation=activation_last or idx != last)
                for idx in range(n_layers)
            )
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

# ----------------------------
# kFNO model
# ----------------------------
class kFNO(nn.Module):
    """
    Koopman-style FNO that predicts n future steps in one forward pass.

        input   u(t)                   : (B, 1, ny, nx)
        output  [u(t+1), ..., u(t+n)]  : (B, n, 1, ny, nx)

    Pipeline:
      L  (lift)
      -> H  (FNO trunk)
      -> K* (repeatedly apply one shared advancement operator A in hidden space)
      -> Q* (refine each advanced state)
      -> P* (project back to one physical channel)
    """
    def __init__(
        self,
        width: int,
        modes_x: int,
        modes_y: int,
        n_h_layers: int,
        n_q_layers: int,
        n_pred: int,
        add_coords: bool,
        alpha_skip: float,
        a_type: str = "linear",
    ) -> None:
        super().__init__()
        self.n_pred = int(n_pred)
        self.add_coords = bool(add_coords)

        # NOTE: sub-modules are built in a fixed order. Reordering them would
        # change how the seeded RNG initialises the parameters.

        # L: 1x1 lift from the field (plus optional x/y coordinate channels) to `width` channels.
        self.lift = nn.Conv2d(3 if self.add_coords else 1, width, kernel_size=1)

        # H: baseline FNO trunk.
        self.H = FNOBlock(width, modes_x, modes_y, n_layers=n_h_layers, alpha=alpha_skip, activation_last=True)

        # A: advancement operator, shared by every step of K*.
        if a_type == "linear":
            # a single Fourier layer without activation
            self.A = FourierLayer(width, modes_x, modes_y, alpha=alpha_skip, activation=False)
        elif a_type == "nonlinear":
            # two Fourier layers with a GELU between them
            self.A = nn.Sequential(
                FourierLayer(width, modes_x, modes_y, alpha=alpha_skip, activation=True),
                FourierLayer(width, modes_x, modes_y, alpha=alpha_skip, activation=False),
            )
        else:
            raise ValueError("a_type must be 'linear' or 'nonlinear'.")

        # Q*: refinement of each advanced hidden state (identity when n_q_layers == 0).
        if n_q_layers > 0:
            self.Q = FNOBlock(width, modes_x, modes_y, n_layers=n_q_layers, alpha=alpha_skip, activation_last=True)
        else:
            self.Q = nn.Identity()

        # P*: pointwise two-layer projection from hidden channels to the physical field.
        self.proj = nn.Sequential(nn.Conv2d(width, width, 1), nn.GELU(), nn.Conv2d(width, 1, 1))

    def forward(self, u0: torch.Tensor) -> torch.Tensor:
        """
        u0: (B, 1, ny, nx)
        returns: (B, n_pred, 1, ny, nx), the predictions for steps 1..n_pred
        """
        batch, _, h, w = u0.shape

        features = u0
        if self.add_coords:
            # Normalised coordinates in [0, 1], broadcast over the batch.
            grid_x = torch.linspace(0, 1, w, device=u0.device).view(1, 1, 1, w).expand(batch, 1, h, w)
            grid_y = torch.linspace(0, 1, h, device=u0.device).view(1, 1, h, 1).expand(batch, 1, h, w)
            features = torch.cat([u0, grid_x, grid_y], dim=1)

        hidden = self.H(self.lift(features))  # L, then H

        steps = []
        for _ in range(self.n_pred):
            hidden = self.A(hidden)                  # K*: one more application of A
            steps.append(self.proj(self.Q(hidden)))  # Q* then P*; the carried state is not refined
        return torch.stack(steps, dim=1)

# ----------------------------
# Training utilities (learn extended n-step operator)
# ----------------------------
def sample_batch_starts(T: int, n_pred: int, batch_size: int) -> np.ndarray:
    """Draw random window start indices for n-step training.

    A window starting at ``s`` uses u(s) as input and u(s+1..s+n_pred) as
    targets, so the valid starts are 0..T-n_pred-1 inclusive. Indices are drawn
    with replacement from NumPy's global RNG.

    Raises ValueError if T < n_pred + 1.
    """
    n_valid = T - n_pred
    if n_valid <= 0:
        raise ValueError("Not enough snapshots for chosen N_PRED.")
    return np.random.randint(0, n_valid, size=batch_size)

def rel_l2(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Relative L2 error ||a - b|| / (||b|| + eps).

    Both norms are taken over all elements together (batch, channels and
    pixels). The result is a 0-dim tensor.
    """
    diff_norm = torch.linalg.norm(a - b)
    ref_norm = torch.linalg.norm(b)
    return diff_norm / (ref_norm + eps)

# ----------------------------
# Train
# ----------------------------
# Build the model. Module construction order fixes how the torch RNG (seeded in
# the settings cell) is consumed for parameter initialisation.
model = kFNO(
    width=WIDTH,
    modes_x=MODES_X,
    modes_y=MODES_Y,
    n_h_layers=N_H_LAYERS,
    n_q_layers=N_Q_LAYERS,
    n_pred=N_PRED,
    add_coords=ADD_COORDS,
    alpha_skip=ALPHA_SKIP,
    a_type=A_TYPE,
).to(device)

# Adam over all parameters, with weight decay WEIGHT_DECAY.
opt = optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

print(f"kFNO: WIDTH={WIDTH}, modes=({MODES_Y},{MODES_X}), N_PRED={N_PRED}, A_TYPE={A_TYPE}")

# One epoch = BATCHES_PER_EPOCH mini-batches of (u(t) -> u(t+1..t+N_PRED))
# windows. Windows are drawn with replacement from the whole normalised
# sequence (no held-out split).
for epoch in range(1, EPOCHS + 1):
    model.train()
    losses = []  # per-batch MSE in normalised units
    rerrs = []   # per-batch relative L2 error

    for _ in range(BATCHES_PER_EPOCH):
        # Window starts come from NumPy's global RNG (seeded in the settings cell).
        idx0 = sample_batch_starts(T_total, N_PRED, BATCH_SIZE)

        u0 = U_torch[idx0, :, :, :]                              # (B,1,ny,nx)
        tgt = torch.stack([U_torch[idx0 + j, :, :, :] for j in range(1, N_PRED + 1)], dim=1)  # (B,N,1,ny,nx)

        pred = model(u0)                                         # (B,N,1,ny,nx)

        # Loss averaged over all N_PRED steps at once.
        loss = mse(pred, tgt)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        # The metrics describe the predictions made before this optimiser step.
        losses.append(loss.item())
        rerrs.append(float(rel_l2(pred, tgt).detach().cpu().item()))

    # Report on the first epoch and every 100th epoch.
    if epoch == 1 or epoch % 100 == 0:
        print(f"Epoch {epoch:4d}/{EPOCHS} | MSE={np.mean(losses):.3e} | relL2={np.mean(rerrs):.3e}")

# ----------------------------
# One-step forecast at next timestep (t_next = TIME_STEP_END + 1)
# ----------------------------
model.eval()
t_next = TIME_STEP_END + 1

# Roll the model forward from the last normalised snapshot. Only the first of
# the N_PRED predicted steps is used here.
with torch.no_grad():
    u_last = U_torch[-1:]                          # (1, 1, ny, nx), normalised
    pred_seq = model(u_last)                       # (1, N_PRED, 1, ny, nx)
    u1_pred_n = pred_seq[0, 0, 0].cpu().numpy()    # (ny, nx), first step, normalised

# Undo the per-pixel normalisation to get physical values on the uniform grid.
T_pred_grid = np.asarray(u1_pred_n * U_std[0] + U_mean[0], dtype=np.float64)  # (ny, nx)

# Sample the grid at the cropped point list for saving/metrics.
T_pred_points = uniform_grid_to_points_bilinear(T_pred_grid, coords_ref)

# The output directory is the parent of the time-step-0 field path.
out_dir = field_csv_path(BASE_DIR, PHI, LAT_SIZE, 0, POST).parent
out_dir.mkdir(parents=True, exist_ok=True)

out = pd.DataFrame(coords_ref, columns=SORT_COLS)
out[f"{VAR_NAME}_pred"] = T_pred_points
out_path = out_dir / f"kfno_pred_{VAR_NAME}_{t_next}_xgt{int(X_THESHOLD)}.csv"
out.to_csv(out_path, index=False)
print("Wrote:", out_path)


In [None]:
# ============================================================
# BOX 3/3 — Plotting (pred/true fields + error maps + Ttrue vs Tpred)
# ============================================================

# Triangulation of the cropped points, shared by every XY plot below. Nearly
# flat boundary triangles are masked so they do not smear the contours. This is
# best effort: if the analysis fails it is skipped, and the reason is printed.
x = coords_ref[:, 0].astype(float)
y = coords_ref[:, 1].astype(float)
tri_xy = mtri.Triangulation(x, y)
try:
    analyzer = mtri.TriAnalyzer(tri_xy)
    tri_xy.set_mask(analyzer.get_flat_tri_mask(min_circle_ratio=0.02))
except Exception as exc:  # best effort: plots still work without the mask
    print(f"Flat-triangle masking skipped: {exc!r}")


def tricontour_field(vals: np.ndarray, title: str, cbar_label: str, vmin=None, vmax=None) -> None:
    """Filled-contour plot of per-point values on ``tri_xy``.

    When both vmin and vmax are given (with vmin < vmax), the 60 colour bands
    span exactly [vmin, vmax], and values outside that range use the end colours
    (``extend="both"``). Plots that share vmin/vmax therefore share both their
    colours and their colour-bar range. Otherwise about 60 bands span the data
    range.
    """
    fig = plt.figure(figsize=(7.2, 5.8))
    ax = fig.add_subplot(111)
    if vmin is not None and vmax is not None and vmax > vmin:
        # An integer `levels` would still follow each plot's own data range;
        # explicit levels make the scale really shared.
        cf = ax.tricontourf(tri_xy, vals, levels=np.linspace(vmin, vmax, 61), extend="both")
    else:
        cf = ax.tricontourf(tri_xy, vals, levels=60, vmin=vmin, vmax=vmax)
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel("x"); ax.set_ylabel("y")
    ax.set_title(title)
    cbar = fig.colorbar(cf, ax=ax, orientation="horizontal", pad=0.08, fraction=0.06)
    cbar.set_label(cbar_label)
    plt.tight_layout()
    plt.show()


# Read true next snapshot (if exists), validate full coords, then crop
path_true = field_csv_path(BASE_DIR, PHI, LAT_SIZE, t_next, POST)

if not path_true.exists():
    print("True next-step file does not exist; plotting only prediction.")
    tricontour_field(
        T_pred_points,
        title=f"{VAR_NAME} PRED (kFNO) at t={t_next} (x > {X_THESHOLD})",
        cbar_label=f"{VAR_NAME} (pred)",
    )
else:
    coords_true_full, snap_true_full = read_field_sorted(path_true, VAR_NAME, SORT_COLS)

    if coords_true_full.shape[0] != coords_ref_full.shape[0]:
        raise ValueError(
            f"True next-step has different full point count: {coords_true_full.shape[0]} vs {coords_ref_full.shape[0]}"
        )
    if not coords_match(coords_true_full, coords_ref_full, COORD_TOL):
        raise ValueError("True next-step full coordinates changed; cannot compare directly.")

    T_true_points = snap_true_full[mask_x].astype(np.float64)

    # Errors on points
    err = T_pred_points - T_true_points
    abs_err = np.abs(err)

    rmse = float(np.sqrt(np.mean(err**2)))
    # Named rel_l2_err (not rel_l2) so the rel_l2() helper used by the training
    # cell is not shadowed if that cell is re-run.
    rel_l2_err = float(np.linalg.norm(err) / (np.linalg.norm(T_true_points) + 1e-12))
    print(f"Next-step RMSE: {rmse:.6e}")
    print(f"Next-step relative L2 error: {rel_l2_err:.6e}")

    # Save error CSV (cropped)
    out_err = pd.DataFrame(coords_ref, columns=SORT_COLS)
    out_err[f"{VAR_NAME}_true"] = T_true_points
    out_err[f"{VAR_NAME}_pred"] = T_pred_points
    out_err["err"] = err
    out_err["abs_err"] = abs_err
    err_path = out_dir / f"kfno_err_{VAR_NAME}_{t_next}_xgt{int(X_THESHOLD)}.csv"
    out_err.to_csv(err_path, index=False)
    print("Wrote:", err_path)

    # Shared scale for the true/pred fields: 0.5–99.5 percentile range of both.
    vmin = float(min(np.percentile(T_true_points, 0.5), np.percentile(T_pred_points, 0.5)))
    vmax = float(max(np.percentile(T_true_points, 99.5), np.percentile(T_pred_points, 99.5)))

    tricontour_field(
        T_pred_points,
        title=f"{VAR_NAME} PRED (kFNO) at t={t_next} (x > {X_THESHOLD})",
        cbar_label=f"{VAR_NAME} (pred)",
        vmin=vmin, vmax=vmax,
    )
    tricontour_field(
        T_true_points,
        title=f"{VAR_NAME} TRUE at t={t_next} (x > {X_THESHOLD})",
        cbar_label=f"{VAR_NAME} (true)",
        vmin=vmin, vmax=vmax,
    )

    # Ttrue vs Tpred scatter with the identity line
    plt.figure(figsize=(5, 5))
    plt.scatter(T_true_points, T_pred_points, s=2)
    lo = float(min(T_true_points.min(), T_pred_points.min()))
    hi = float(max(T_true_points.max(), T_pred_points.max()))
    plt.plot([lo, hi], [lo, hi], linestyle="--")
    plt.xlabel(f"{VAR_NAME}_true")
    plt.ylabel(f"{VAR_NAME}_pred")
    plt.title(f"{VAR_NAME}_true vs {VAR_NAME}_pred at t={t_next} (x > {X_THESHOLD})")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Error maps, clipped at the 99th percentile of |err|. The epsilon keeps
    # vmax > vmin even for a perfect prediction.
    vmax_signed = float(np.percentile(np.abs(err), 99.0)) + 1e-30
    vmax_abs = float(np.percentile(abs_err, 99.0)) + 1e-30

    tricontour_field(
        err,
        title=f"{VAR_NAME} signed error (pred - true) at t={t_next} (x > {X_THESHOLD})",
        cbar_label="Signed error",
        vmin=-vmax_signed, vmax=vmax_signed,
    )
    tricontour_field(
        abs_err,
        title=f"{VAR_NAME} absolute error |pred - true| at t={t_next} (x > {X_THESHOLD})",
        cbar_label="Absolute error",
        vmin=0.0, vmax=vmax_abs,
    )
