
# PINN Demo — Heat Diffusion + Ortho/Para Kinetics

This notebook demonstrates the workflow used in the repository:

1. Generate reference data with a **finite-difference (FD)** solver (heat + kinetics).
2. Train a small **differentiable surrogate** for the equilibrium ortho-fraction $f_{\rm eq}(T)$ derived from statistical mechanics (the rotational partition function).
3. Train a **Physics-Informed Neural Network (PINN)** that predicts $T(x,t)$ and $f(x,t)$ and enforces the coupled physics through residual losses.
4. Compare **PINN vs FD** and visualize residual fields.


In [None]:

# Force CPU + silence CUDA warnings (useful on some systems)
import os, warnings
os.environ["CUDA_VISIBLE_DEVICES"] = ""
warnings.filterwarnings("ignore", message=".*CUDA initialization.*")

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from scipy.interpolate import RectBivariateSpline

# Seed every RNG the notebook draws from, so Restart & Run All is reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Everything runs on the CPU (CUDA_VISIBLE_DEVICES is blanked above).
device = torch.device("cpu")
print("Using device:", device)



## Physics model and parameters

We solve, in 1D on $x\in[0,L]$ and $t\in[0,t_{\max}]$:

- Heat equation with a source from conversion heat,  
$$
T_t = \alpha T_{xx} + S, \qquad S = -\frac{\Delta H}{\rho c_p}\frac{\partial f}{\partial t}
$$

- Local kinetics,  
$$
\frac{\partial f}{\partial t} = -k(T)\big(f - f_{\rm eq}(T)\big), \quad k(T)=k_0 e^{-E_{\rm act}/T}.
$$

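Below we use placeholder values for demonstration. Note the units implied by the equations: $E_{\rm act}$ is an activation temperature in K (so that $E_{\rm act}/T$ is dimensionless), and $\Delta H$ is the heat released per unit volume per unit decrease of $f$ (J/m³), so that $\Delta H/(\rho c_p)$ is a temperature rise. Replace with literature values for your application.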


In [None]:

# --- Physical parameters (demo values) ---
L = 0.1                 # domain length, m
t_max = 60.0            # simulated time span, s
alpha = 1e-5            # thermal diffusivity, m^2/s
rho_cp = 1.0e6          # volumetric heat capacity, J/(m^3 K)
DeltaH = 1e5            # heat released per unit decrease of f, J/m^3 (toy); DeltaH/rho_cp is a temperature (K)
k0 = 1e-3               # Arrhenius prefactor, s^-1
E_act = 5.0             # activation temperature in K (toy): k(T) = k0 * exp(-E_act / T)

# Boltzmann constant and rotational temperature (for f_eq reference)
kB = 1.380649e-23       # J/K
Theta_rot = 85.4        # K

def E_J(J):
    """Rotational energy k_B * Theta_rot * J (J + 1) in joules; J may be a scalar or a NumPy array."""
    return kB * Theta_rot * J * (J + 1)

def f_ortho_eq_numpy(T, Jmax=40):
    """Equilibrium ortho fraction from the rotational partition function (NumPy reference).

    Sums rotational levels J = 0..Jmax, weighting each by (2J+1) times a nuclear-spin
    degeneracy of 1 for even J (para) and 3 for odd J (ortho). The result is the
    odd-J share of the partition function.

    Args:
        T: temperature(s) in K, scalar or array. Values below 1 K are clamped to 1 K.
        Jmax: highest rotational level included in the sums.

    Returns:
        A Python float for scalar T, otherwise an array with the same shape as T.
    """
    # The clamp guarantees kB*T > 0, so no epsilon is needed in the denominator.
    T_arr = np.maximum(np.asarray(T, dtype=float), 1.0)
    Js = np.arange(0, Jmax + 1)
    odd = Js % 2 == 1
    g_ns = np.where(odd, 3, 1)             # nuclear spin degeneracy (para even=1, ortho odd=3)
    degeneracy = (2 * Js + 1) * g_ns
    # Boltzmann factors for all levels in one vectorized call; trailing axis runs over J.
    boltzmann = np.exp(-E_J(Js) / (kB * T_arr[..., None]))
    weights = degeneracy * boltzmann
    frac = weights[..., odd].sum(axis=-1) / weights.sum(axis=-1)
    if np.ndim(T) == 0:
        return float(frac)
    return frac



## Finite-difference (FD) reference solver

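We use an explicit scheme (forward Euler in time, central differences in space). If needed, $\Delta t$ is reduced automatically so that the explicit diffusion stability limit $\alpha\,\Delta t/\Delta x^2 \le 1/2$ is satisfied.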


In [None]:

def generate_synthetic(L=0.1, Nx=81, t_max=60.0, Nt=301):
    """Finite-difference reference solution of the coupled heat / ortho-para kinetics model.

    Explicit update: forward Euler in time, central second difference in space, with
    Dirichlet T = 20 at both ends. If the requested step t_max/(Nt-1) exceeds the
    diffusion stability limit 0.5*dx^2/alpha, Nt is raised until dt fits under it.
    Reads the module-level alpha, k0, E_act, DeltaH, rho_cp and f_ortho_eq_numpy.

    Returns:
        x (Nx,), t (Nt',), T_all (Nt', Nx), f_all (Nt', Nx). Nt' is the (possibly
        adjusted) number of time levels, or fewer if a NaN/Inf stopped the run early.
    """
    x = np.linspace(0, L, Nx)
    dx = L / (Nx - 1)

    dt_stable = 0.5 * dx * dx / alpha
    needs_refinement = t_max / (Nt - 1) > dt_stable
    if needs_refinement:
        Nt = int(np.ceil(t_max / dt_stable)) + 1
    dt = t_max / (Nt - 1)
    if needs_refinement:
        print(f"[FD] Adjusted Nt to {Nt} and dt to {dt:.3e} for stability (max_dt {dt_stable:.3e})")

    t = np.linspace(0, t_max, Nt)

    # Initial state: uniform 20 plus a Gaussian bump (amplitude 5, sigma 0.01) at x = L/2; f = 0.75.
    T = 20.0 + 5.0 * np.exp(-((x - L/2)**2)/(2*(0.01)**2))
    f = np.full(Nx, 0.75)

    T_all = np.zeros((Nt, Nx))
    f_all = np.zeros((Nt, Nx))
    T_all[0] = T
    f_all[0] = f

    # Loop-invariant coefficients (same evaluation order as the inline expressions).
    diffusion_number = alpha * dt / dx**2
    heat_per_df = -DeltaH / rho_cp

    for n in range(1, Nt):
        k_vals = k0 * np.exp(-E_act / (np.maximum(T, 1e-3) + 1e-12))
        f_eq_vals = np.array([f_ortho_eq_numpy(Ti) for Ti in np.maximum(T, 1.0)])
        dfdt = -k_vals * (f - f_eq_vals)
        f_new = f + dt * dfdt

        S = heat_per_df * dfdt
        T_new = T.copy()
        T_new[1:-1] = T[1:-1] + diffusion_number * (T[2:] - 2*T[1:-1] + T[:-2]) + dt * S[1:-1]
        T_new[0] = T_new[-1] = 20.0

        if not (np.isfinite(T_new).all() and np.isfinite(f_new).all()):
            print(f"[FD] NaN/Inf at step {n} — truncating.")
            T_all, f_all, t = T_all[:n], f_all[:n], t[:n]
            break

        T = np.maximum(T_new, 1e-6)
        f = np.clip(f_new, 0.0, 1.0)
        T_all[n] = T
        f_all[n] = f

    print(f"[FD] Generated: Nt={T_all.shape[0]}, Nx={T_all.shape[1]}, T range [{T_all.min():.3e}, {T_all.max():.3e}]")
    return x, t, T_all, f_all

# Reference solution on the default FD grid. Nt is raised automatically if the
# requested time step would violate the explicit stability limit.
x_grid, t_grid, T_all, f_all = generate_synthetic(L=L, Nx=81, t_max=t_max, Nt=301)

# Centerline diagnostics at the grid node closest to x = L/2.
center_idx = np.argmin(np.abs(x_grid - L/2))

for series, ylabel, title in (
    (T_all[:, center_idx], "T (K)", "FD: Center temperature evolution"),
    (f_all[:, center_idx], "f", "FD: Center ortho-fraction evolution"),
):
    fig, ax = plt.subplots()
    ax.plot(t_grid, series)
    ax.set_xlabel("t (s)")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)
    plt.show()



## Differentiable surrogate for $f_{\rm eq}(T)$

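We train a small neural network on a dense temperature grid (1–300 K) to reproduce the equilibrium ortho-fraction curve. The surrogate is **frozen** during PINN training: its weights do not change, but gradients still flow through it to the predicted temperature.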


In [None]:

class FEqSurrogate(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, 64), nn.Tanh(),
            nn.Linear(64, 64), nn.Tanh(),
            nn.Linear(64, 1)
        )
    def forward(self, T):
        return self.net(T)

def build_feq_surrogate(n_iter=1200, lr=1e-3, log_every=300):
    """Fit an FEqSurrogate to the exact f_eq(T) curve on 1-300 K, then freeze it.

    Targets come from f_ortho_eq_numpy on 2000 evenly spaced temperatures. Training
    is full-batch Adam on the mean squared error.

    Args:
        n_iter: number of Adam steps.
        lr: Adam learning rate.
        log_every: print the loss every `log_every` steps (0 disables periodic logs).
            The loss after the last step is always printed.

    Returns:
        The trained surrogate, in eval mode and with requires_grad=False on every parameter.
    """
    T_grid = np.linspace(1.0, 300.0, 2000).astype(np.float32)
    f_grid = np.array([f_ortho_eq_numpy(Ti) for Ti in T_grid], dtype=np.float32)
    X = torch.tensor(T_grid.reshape(-1, 1), dtype=torch.float32, device=device)
    Y = torch.tensor(f_grid.reshape(-1, 1), dtype=torch.float32, device=device)

    sur = FEqSurrogate().to(device)
    opt = torch.optim.Adam(sur.parameters(), lr=lr)
    for i in range(n_iter):
        opt.zero_grad()
        loss = ((sur(X) - Y) ** 2).mean()
        loss.backward()
        opt.step()
        if log_every and i % log_every == 0:
            print(f"[surrogate] iter {i:4d} loss {loss.item():.3e}")

    # Report the quality of the weights actually returned, not a stale mid-training value.
    with torch.no_grad():
        final_loss = ((sur(X) - Y) ** 2).mean().item()
    print(f"[surrogate] final loss {final_loss:.3e}")

    sur.eval()
    sur.requires_grad_(False)   # frozen: the PINN optimizer must not touch these weights
    return sur

# Fit the surrogate once. build_feq_surrogate sets requires_grad=False on all of
# its parameters, so PINN training never updates it.
sur = build_feq_surrogate()

def f_eq_torch(T_tensor):
    """Equilibrium ortho fraction from the frozen surrogate.

    Maps T_tensor of shape (N, 1) to f_eq of shape (N, 1). The weights are frozen,
    but the result remains differentiable with respect to T_tensor.
    """
    f_eq = sur(T_tensor)
    return f_eq



## PINN model

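One network predicts both fields $T(x,t)$ and $f(x,t)$. We normalize inputs to $[-1,1]$ and use the chain rule to compute derivatives in physical coordinates, e.g. $\partial T/\partial t = \tfrac{2}{t_{\max}}\,\partial T/\partial \hat t$ and $\partial^2 T/\partial x^2 = \tfrac{4}{L^2}\,\partial^2 T/\partial \hat x^2$.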


In [None]:

class PINN(nn.Module):
    """Fully connected tanh network mapping normalized (x, t) to the two fields [T, f].

    `layers` lists the layer widths from input to output, e.g. [2, 64, 64, 2].
    Every Linear layer gets Xavier-normal weights and zero bias. Hidden layers use
    tanh, and the last layer is linear.
    """

    def __init__(self, layers):
        super().__init__()
        self.net = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(layers[:-1], layers[1:])]
        )
        self.act = nn.Tanh()
        for linear in self.net:
            if isinstance(linear, nn.Linear):
                nn.init.xavier_normal_(linear.weight)
                nn.init.zeros_(linear.bias)

    def forward(self, x):
        n_hidden = len(self.net) - 1
        h = x
        for idx in range(n_hidden):
            h = self.act(self.net[idx](h))
        return self.net[n_hidden](h)   # columns: [T, f]

# Normalization + chain-rule factors.
# Network inputs are affine maps of (x, t) onto [-1, 1]. A derivative taken with
# respect to a normalized coordinate is converted to physical units by multiplying
# it by the matching factor below (squared for second derivatives).
x_min, x_max = 0.0, L
t_min, t_max_local = 0.0, t_max
sx = 2.0 / (x_max - x_min)         # dx_norm/dx
st = 2.0 / (t_max_local - t_min)   # dt_norm/dt

def normalize_X_np(X):
    """Map physical coordinates to the network's [-1, 1] input range.

    Args:
        X: array-like of shape (N, 2) with columns (x, t) in physical units. Any
           extra columns are copied through unchanged.

    Returns:
        A new array of the same shape; the input is never modified. Floating inputs
        keep their dtype. Other dtypes (e.g. integer coordinates) are promoted to
        float64, so the scaled values are not truncated to integers.
    """
    X = np.asarray(X)
    out_dtype = X.dtype if np.issubdtype(X.dtype, np.floating) else np.float64
    Xn = X.astype(out_dtype, copy=True)
    Xn[:, 0] = 2 * (X[:, 0] - x_min) / (x_max - x_min) - 1.0
    Xn[:, 1] = 2 * (X[:, 1] - t_min) / (t_max_local - t_min) - 1.0
    return Xn



## Train the PINN (short demo)

We use moderate settings so the demo runs quickly. For the repository script you can increase the model size, number of epochs, and enable LBFGS refinement.


In [None]:

import time

def _sample_pinn_points(collocation_N, ic_N, bc_N, data_N):
    """Draw the data, collocation, IC and BC point sets (physical coordinates) from the FD reference.

    Random draws use the global NumPy RNG in a fixed order, so results follow the notebook seed.
    Returns a dict of NumPy arrays.
    """
    # Sparse "measurements": rows are t-major, matching T_all.ravel() / f_all.ravel().
    Xmesh, Tmesh = np.meshgrid(x_grid, t_grid)
    XT = np.column_stack([Xmesh.ravel(), Tmesh.ravel()])
    rand_idx = np.random.choice(XT.shape[0], size=data_N, replace=False)
    X_data = XT[rand_idx]
    T_data = T_all.ravel()[rand_idx]
    f_data = f_all.ravel()[rand_idx]

    # Collocation: uniform over the domain, plus a third extra near the initial bump
    # (x ~ L/2, first 30% of the time window) where the solution changes fastest.
    x_coll = np.random.rand(collocation_N) * L
    t_coll = np.random.rand(collocation_N) * t_max
    N_focus = collocation_N // 3
    x_focus = (L / 2) + 0.015 * np.random.randn(N_focus)
    t_focus = (0.3 * t_max) * np.random.rand(N_focus)
    X_coll = np.vstack([
        np.column_stack([x_coll, t_coll]),
        np.column_stack([np.clip(x_focus, 0.0, L), np.clip(t_focus, 0.0, t_max)]),
    ])

    # Initial condition: FD state at t = 0, interpolated to random x.
    x_ic = np.random.rand(ic_N) * L
    X_ic = np.column_stack([x_ic, np.zeros(ic_N)])
    T_ic = np.interp(x_ic, x_grid, T_all[0, :])
    f_ic = np.interp(x_ic, x_grid, f_all[0, :])

    # Boundaries: Dirichlet T = 20 (as in the FD solver); f taken from the FD boundary nodes.
    t_bc = np.random.rand(bc_N) * t_max
    X_bc = np.column_stack([
        np.concatenate([np.zeros(bc_N), np.ones(bc_N) * L]),
        np.concatenate([t_bc, t_bc]),
    ])
    T_bc = np.full(2 * bc_N, 20.0)
    f_bc = np.concatenate([np.interp(t_bc, t_grid, f_all[:, 0]),
                           np.interp(t_bc, t_grid, f_all[:, -1])])

    return dict(X_data=X_data, T_data=T_data, f_data=f_data, X_coll=X_coll,
                X_ic=X_ic, T_ic=T_ic, f_ic=f_ic, X_bc=X_bc, T_bc=T_bc, f_bc=f_bc)


def _pinn_residuals(model, X_n):
    """Heat and kinetics residuals (r_T, r_f), each (N, 1), at normalized inputs X_n.

    X_n must be a tensor with requires_grad=True. Derivatives w.r.t. normalized
    coordinates are converted to physical units with the chain-rule factors sx, st.
    NaN/Inf are replaced by +-1e6 so one bad point cannot poison the loss.
    """
    pred = model(X_n)
    T_c, f_c = pred[:, 0:1], pred[:, 1:2]
    # create_graph=True: T_x is differentiated again, and the loss is back-propagated through all of it.
    gT = torch.autograd.grad(T_c, X_n, torch.ones_like(T_c), create_graph=True)[0]
    T_xn, T_tn = gT[:, 0:1], gT[:, 1:2]
    T_xxn = torch.autograd.grad(T_xn, X_n, torch.ones_like(T_xn), create_graph=True)[0][:, 0:1]
    T_t, T_xx = T_tn * st, T_xxn * (sx ** 2)

    gf = torch.autograd.grad(f_c, X_n, torch.ones_like(f_c), create_graph=True)[0]
    f_t = gf[:, 1:2] * st

    T_safe = T_c.clamp(min=1e-3, max=1e5)   # keep exp(-E_act/T) and the surrogate well-defined
    f_eq_vals = f_eq_torch(T_safe)
    k_vals = k0 * torch.exp(-E_act / (T_safe + 1e-8))
    S_c = -DeltaH / rho_cp * f_t

    r_T = torch.nan_to_num(T_t - alpha * T_xx - S_c, nan=1e6, posinf=1e6, neginf=-1e6)
    r_f = torch.nan_to_num(f_t + k_vals * (f_c - f_eq_vals), nan=1e6, posinf=1e6, neginf=-1e6)
    return r_T, r_f


def train_pinn_demo(num_epochs=600, collocation_N=1800, ic_N=200, bc_N=200, data_N=250,
                    lr=1e-4, layers=(2, 64, 64, 64, 2), log_every=100):
    """Train a PINN for (T, f) against the FD reference with Adam (short demo run).

    The loss is a weighted sum of PDE residuals on collocation points plus IC, BC
    and sparse-data mismatches, with weights 1, 10, 10, 10.

    Args:
        num_epochs: Adam steps.
        collocation_N, ic_N, bc_N, data_N: point counts. bc_N is per boundary, and
            collocation_N // 3 extra focused collocation points are added.
        lr: Adam learning rate.
        layers: layer widths passed to PINN (input 2, output 2).
        log_every: print the losses every `log_every` epochs (0 = only the last epoch).

    Returns:
        The trained PINN.
    """
    W_RESID, W_IC, W_BC, W_DATA = 1.0, 10.0, 10.0, 10.0

    pts = _sample_pinn_points(collocation_N, ic_N, bc_N, data_N)

    def to_t(arr, req=False):
        tensor = torch.tensor(arr, dtype=torch.float32, device=device)
        tensor.requires_grad = req
        return tensor

    def col(arr):
        return torch.tensor(arr, dtype=torch.float32, device=device)[:, None]

    # Only collocation points need input gradients (for the PDE residual).
    X_coll_t = to_t(normalize_X_np(pts["X_coll"]), req=True)
    X_ic_t = to_t(normalize_X_np(pts["X_ic"]))
    X_bc_t = to_t(normalize_X_np(pts["X_bc"]))
    X_data_t = to_t(normalize_X_np(pts["X_data"]))

    T_ic_t, f_ic_t = col(pts["T_ic"]), col(pts["f_ic"])
    T_bc_t, f_bc_t = col(pts["T_bc"]), col(pts["f_bc"])
    T_data_t, f_data_t = col(pts["T_data"]), col(pts["f_data"])

    model = PINN(list(layers)).to(device)
    with torch.no_grad():
        # Start the outputs at the background state (T = 20, f = 0.75) rather than 0.
        final = model.net[-1]
        if isinstance(final, nn.Linear):
            final.bias[0].fill_(20.0); final.bias[1].fill_(0.75)

    opt = torch.optim.Adam(model.parameters(), lr=lr)
    mse = nn.MSELoss()

    t0 = time.time()
    for epoch in range(num_epochs):
        model.train(); opt.zero_grad()

        r_T, r_f = _pinn_residuals(model, X_coll_t)
        loss_r = mse(r_T, torch.zeros_like(r_T)) + mse(r_f, torch.zeros_like(r_f))

        pred_ic = model(X_ic_t)
        loss_ic = mse(pred_ic[:, 0:1], T_ic_t) + mse(pred_ic[:, 1:2], f_ic_t)
        pred_bc = model(X_bc_t)
        loss_bc = mse(pred_bc[:, 0:1], T_bc_t) + mse(pred_bc[:, 1:2], f_bc_t)
        pred_d = model(X_data_t)
        loss_data = mse(pred_d[:, 0:1], T_data_t) + mse(pred_d[:, 1:2], f_data_t)

        loss = W_RESID*loss_r + W_IC*loss_ic + W_BC*loss_bc + W_DATA*loss_data
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        opt.step()

        if (log_every and epoch % log_every == 0) or epoch == num_epochs - 1:
            print(f"Epoch {epoch:4d}: Loss {loss.item():.3e}, resid {loss_r.item():.3e}, "
                  f"ic {loss_ic.item():.3e}, bc {loss_bc.item():.3e}, data {loss_data.item():.3e}")

    print(f"[Adam] done in {time.time()-t0:.1f}s")
    return model

model = train_pinn_demo()



## Evaluation: PINN vs FD


In [None]:

# Common evaluation grid. Rows are t-major (row = i_t * Nx_vis + i_x) and the
# columns are (x, t), which is the layout normalize_X_np expects.
Nx_vis, Nt_vis = 101, 101
xv = np.linspace(0, L, Nx_vis)
tv = np.linspace(0, t_max, Nt_vis)
xx_vis, tt_vis = np.meshgrid(xv, tv)
X_vis = np.column_stack((xx_vis.ravel(), tt_vis.ravel()))
X_vis_n = normalize_X_np(X_vis)
X_vis_t = torch.tensor(X_vis_n, dtype=torch.float32, device=device)

with torch.no_grad():
    pred = model(X_vis_t).cpu().numpy()
T_pred = pred[:, 0].reshape(Nt_vis, Nx_vis)
f_pred = pred[:, 1].reshape(Nt_vis, Nx_vis)

# Resample the FD reference onto the same grid with bicubic splines. In grid
# mode the rows follow tv and the columns follow xv.
spline_T = RectBivariateSpline(t_grid, x_grid, T_all)
spline_f = RectBivariateSpline(t_grid, x_grid, f_all)
T_ref = spline_T(tv, xv)
f_ref = spline_f(tv, xv)

rmse_T = np.sqrt(np.mean((T_pred - T_ref)**2))
rmse_f = np.sqrt(np.mean((f_pred - f_ref)**2))
print(f"RMSE T: {rmse_T:.4e} K, RMSE f: {rmse_f:.4e}")

# Centerline traces: PINN at the middle visualization column, FD at its nearest grid node.
center_idx = Nx_vis // 2
fd_center_idx = np.argmin(np.abs(x_grid - xv[center_idx]))
T_ref_center = np.interp(tv, t_grid, T_all[:, fd_center_idx])
f_ref_center = np.interp(tv, t_grid, f_all[:, fd_center_idx])

for pinn_series, fd_series, pinn_label, fd_label, ylabel, title in (
    (T_pred[:, center_idx], T_ref_center, "PINN T center", "FD ref T center",
     "T (K)", "Center temperature evolution"),
    (f_pred[:, center_idx], f_ref_center, "PINN f center", "FD ref f center",
     "f", "Center ortho-fraction evolution"),
):
    fig, ax = plt.subplots()
    ax.plot(tv, pinn_series, label=pinn_label)
    ax.plot(tv, fd_series, "--", label=fd_label)
    ax.set_xlabel("t (s)")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    plt.show()



## PDE residual fields

We visualize the magnitudes of the residuals $r_T = T_t - \alpha T_{xx} - S$ and $r_f = f_t + k(T)\,(f-f_{\rm eq})$ over the space-time grid.


In [None]:

# Compute PDE residuals on the visualization grid.
def _residual_grid(model, X_n):
    """Return the heat and kinetics residuals (r_T, r_f) as flat NumPy arrays for inputs X_n (N, 2).

    Works on a private copy of X_n with requires_grad enabled, so the caller's
    tensor is never mutated. Re-running this cell or the evaluation cell is therefore safe.
    """
    X_req = X_n.detach().clone().requires_grad_(True)
    pred = model(X_req)
    T_v, f_v = pred[:, 0:1], pred[:, 1:2]

    # create_graph=True on the first derivative so T_x can be differentiated again.
    gT = torch.autograd.grad(T_v, X_req, torch.ones_like(T_v), create_graph=True)[0]
    T_xn, T_tn = gT[:, 0:1], gT[:, 1:2]
    T_xxn = torch.autograd.grad(T_xn, X_req, torch.ones_like(T_xn), create_graph=True)[0][:, 0:1]
    T_t, T_xx = T_tn * st, T_xxn * (sx ** 2)

    gf = torch.autograd.grad(f_v, X_req, torch.ones_like(f_v), create_graph=True)[0]
    f_t = gf[:, 1:2] * st

    # Same temperature clamp as the training residual.
    T_safe = T_v.clamp(min=1e-3, max=1e5)
    k_vals = k0 * torch.exp(-E_act / (T_safe + 1e-8))
    f_eq_vals = f_eq_torch(T_safe)
    S_v = -DeltaH / rho_cp * f_t

    r_T = T_t - alpha * T_xx - S_v
    r_f = f_t + k_vals * (f_v - f_eq_vals)
    return r_T.detach().cpu().numpy().ravel(), r_f.detach().cpu().numpy().ravel()

r_T_flat, r_f_flat = _residual_grid(model, X_vis_t)
rT = r_T_flat.reshape(Nt_vis, Nx_vis)
rf = r_f_flat.reshape(Nt_vis, Nx_vis)

# Heatmaps: x across, t downward (row 0 of the arrays is t = 0, drawn at the top).
residual_extent = [0, L, tv[-1], tv[0]]
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
for ax, field, title in zip(axes, (rT, rf), ("abs(r_T)", "abs(r_f)")):
    im = ax.imshow(np.abs(field), extent=residual_extent, aspect="auto")
    ax.set_title(title)
    ax.set_xlabel("x (m)")
    ax.set_ylabel("t (s)")
    fig.colorbar(im, ax=ax)
fig.tight_layout()
plt.show()



## Save figures for README (optional)

Run this cell to save key plots under `figures/`.


In [None]:

import os

FIG_DIR = "figures"
os.makedirs(FIG_DIR, exist_ok=True)


def _save_and_show(fig, filename):
    """Save `fig` as FIG_DIR/filename (150 dpi, tight bounding box), then display it."""
    fig.savefig(os.path.join(FIG_DIR, filename), dpi=150, bbox_inches="tight")
    plt.show()


def _center_trace_figure(pinn_series, fd_series, quantity, ylabel, title):
    """Centerline comparison figure: PINN solid line, FD reference dashed, against tv."""
    fig, ax = plt.subplots()
    ax.plot(tv, pinn_series, label=f"PINN {quantity} center")
    ax.plot(tv, fd_series, "--", label=f"FD ref {quantity} center")
    ax.set_xlabel("t (s)")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    return fig


# Recreate and save the two main plots.
_save_and_show(
    _center_trace_figure(T_pred[:, center_idx], T_ref_center, "T", "T (K)", "Center temperature evolution"),
    "center_T_vs_time.png",
)
_save_and_show(
    _center_trace_figure(f_pred[:, center_idx], f_ref_center, "f", "f", "Center ortho-fraction evolution"),
    "center_f_vs_time.png",
)

# Residual magnitudes: x across, t downward (row 0 is t = 0).
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
for ax, field, title in zip(axes, (rT, rf), ("abs(r_T)", "abs(r_f)")):
    im = ax.imshow(np.abs(field), extent=[0, L, tv[-1], tv[0]], aspect="auto")
    ax.set_title(title)
    ax.set_xlabel("x (m)")
    ax.set_ylabel("t (s)")
    fig.colorbar(im, ax=ax)
fig.tight_layout()
_save_and_show(fig, "residuals.png")

print(f"Saved figures to ./{FIG_DIR}")
