## Solution of the 1D Heat Equation with PINN and PINN-PCE Methods

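We solve the following one-dimensional heat equation for $x \in [0, \pi]$:

$$
u_t - \frac{a}{k^2} u_{xx} = 0
$$

with initial and boundary conditions:

$$
\begin{aligned}
u(0, x) &= \sin(k x), \\
u(t, 0) &= 0, \\
u(t, \pi) &= e^{-a t} \sin(\pi k).
\end{aligned}
$$

These data are consistent with the exact solution $u(t, x) = e^{-a t} \sin(k x)$. Indeed, $u_t = -a u$ and $u_{xx} = -k^2 u$, so $u_t - \frac{a}{k^2} u_{xx} = 0$.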


In [None]:
%matplotlib widget

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm
from scipy.ndimage import gaussian_filter
from torch import tensor

from examples.equations.heat_1d.pde import reference_solution
from measure_uq.models import PINN, PINN_PCE
from measure_uq.pde import PDE
from measure_uq.trainers.trainer import Trainer
from measure_uq.utilities import cartesian_product_of_rows, to_numpy

mpl.rcParams.update(
    {
        "text.usetex": False,
        "font.family": "STIX",  # or 'Times New Roman' if available
        "mathtext.fontset": "stix",
        "axes.labelsize": 16,
        "font.size": 14,
        "legend.fontsize": 14,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
    }
)


model_type_pinn_pce = PINN_PCE
model_path_pinn_pce = "data/best_model_pinn_pce.pickle"
pde_path_pinn_pce = "data/pde_pinn_pce.pickle"
trainer_path_pinn_pce = "data/trainer_pinn_pce.pickle"

model_type_pinn = PINN
model_path_pinn = "data/best_model_pinn.pickle"
pde_path_pinn = "data/pde_pinn.pickle"
trainer_path_pinn = "data/trainer_pinn.pickle"

model_pinn_pce = model_type_pinn_pce.load(model_path_pinn_pce)
pde_pinn_pce = PDE.load(pde_path_pinn_pce)
trainer_pinn_pce = Trainer.load(trainer_path_pinn_pce)

model_pinn = model_type_pinn.load(model_path_pinn)
pde_pinn = PDE.load(pde_path_pinn)
trainer_pinn = Trainer.load(trainer_path_pinn)

T = pde_pinn_pce.conditions_train.conditions[0].T
X = pde_pinn_pce.conditions_train.conditions[0].X

t = tensor(np.linspace(0, T, 101)[:, None])
x = tensor(np.linspace(0, X, 51)[:, None])

pde_pinn_pce.parameters_test.N = 10000
pde_pinn_pce.parameters_test.sample_values()
parameters = pde_pinn_pce.parameters_test.values.detach().cpu()

Nx = x.shape[0]
Nt = t.shape[0]
Np = parameters.shape[0]

### Evaluate model and reference solutions

In [None]:
def model_samples_on_grid(y_flat, n_t: int, n_x: int, n_p: int) -> np.ndarray:
    """Rearrange a flat model output into one space-time grid per parameter sample.

    Parameters
    ----------
    y_flat : torch.Tensor
        Model output with ``n_t * n_x * n_p`` rows and a single column.
        Rows are assumed to be ordered by (t, x, sample), with the sample index
        varying fastest and t slowest.
        NOTE(review): this is the layout expected from
        ``forward(cartesian_product_of_rows(t, x), parameters)`` in this
        notebook; confirm against the PINN / PINN_PCE implementations.
    n_t, n_x, n_p : int
        Number of time points, space points and parameter samples.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_p, n_x, n_t)``.
    """
    values = y_flat.detach().cpu().numpy().reshape(n_t, n_x, n_p)
    # Reorder (t, x, sample) -> (sample, x, t) in one vectorised step instead of
    # slicing the tensor once per sample.
    return np.ascontiguousarray(values.transpose(2, 1, 0))


tx = cartesian_product_of_rows(t, x).float()

# PINN-PCE solution
_, y_flat_pinn_pce = model_pinn_pce.forward(tx, parameters)
y_pinn_pce = model_samples_on_grid(y_flat_pinn_pce, Nt, Nx, Np)  # (samples, Nx, Nt)
mean_y_pinn_pce = y_pinn_pce.mean(axis=0)
std_y_pinn_pce = y_pinn_pce.std(axis=0)

# PINN solution
_, y_flat_pinn = model_pinn.forward(tx, parameters)
y_pinn = model_samples_on_grid(y_flat_pinn, Nt, Nx, Np)  # (samples, Nx, Nt)
mean_y_pinn = y_pinn.mean(axis=0)
std_y_pinn = y_pinn.std(axis=0)

# Reference solution, evaluated on NumPy copies of the same grids.
tt = to_numpy(t).squeeze()
xx = to_numpy(x).squeeze()

solutions = np.stack(
    [reference_solution(tt, xx, to_numpy(p)) for p in parameters], axis=0
)  # (samples, Nx, Nt)
mean_solution = solutions.mean(axis=0)
std_solution = solutions.std(axis=0)

# The plots below compare the three ensembles element-wise, so their layouts
# must agree.
assert solutions.shape == y_pinn.shape == y_pinn_pce.shape == (Np, Nx, Nt), (
    solutions.shape,
    y_pinn.shape,
    y_pinn_pce.shape,
)

### Plot the density of the solution over the parameter samples at the middle time

In [None]:
plt.close("all")

# Per-method ensembles of shape (samples, Nx, Nt).
methods = {
    "Reference": solutions,
    "PINN": y_pinn,
    "PINN-PCE": y_pinn_pce,
}

middle_time_index = Nt // 2

# One x-bin per spatial grid point. With more uniform bins than grid points
# (e.g. 100 bins for 51 points), roughly every other bin receives no samples and
# the image shows vertical stripes.
n_xbins = Nx
n_ubins = 100
threshold = 0.1  # smoothed counts below this are masked and drawn white

x_vals = xx  # NumPy space grid, shape (Nx,)

# Share the u-range across methods so every panel uses identical bins. This
# makes the densities directly comparable; the panels also share the y-axis.
u_min = min(sol[:, :, middle_time_index].min() for sol in methods.values())
u_max = max(sol[:, :, middle_time_index].max() for sol in methods.values())

cmap = plt.colormaps.get_cmap("plasma").copy()
cmap.set_bad(color="white")
cmap.set_under(color="white")

fig, axs = plt.subplots(1, len(methods), figsize=(18, 5), sharey=True)

for ax, (label, sol) in zip(axs, methods.items(), strict=True):
    # Flatten (samples, Nx) row-major; tiling x repeats the grid once per sample.
    u_vals = sol[:, :, middle_time_index].reshape(-1)
    x_grid = np.tile(x_vals, sol.shape[0])

    hist, x_edges, u_edges = np.histogram2d(
        x_grid,
        u_vals,
        bins=[n_xbins, n_ubins],
        range=[[x_vals.min(), x_vals.max()], [u_min, u_max]],
    )

    # Transpose to (u-bins, x-bins) so u runs along the image's vertical axis.
    hist_smooth = gaussian_filter(hist.T, sigma=1.0)
    hist_masked = np.ma.masked_less(hist_smooth, threshold)

    norm = LogNorm(vmin=threshold, vmax=hist_smooth.max())

    im = ax.imshow(
        hist_masked,
        aspect="auto",
        origin="lower",
        extent=[x_edges[0], x_edges[-1], u_edges[0], u_edges[-1]],
        cmap=cmap,
        norm=norm,
    )

    ax.set_xlabel("x", fontsize=12)
    ax.set_title(label, fontsize=14)  # method name as subtitle
    fig.colorbar(im, ax=ax, label="Smoothed count (log scale)")

# The vertical axis shows the solution value u; colour encodes the frequency.
axs[0].set_ylabel("u", fontsize=12)
fig.suptitle(f"Density of u(t, x) at t = {tt[middle_time_index]:.3g}")

plt.tight_layout()
fig.savefig("diffusion_histograms.pdf", bbox_inches="tight", dpi=300)
plt.show()

### Plot the absolute error of the mean solution

In [None]:
# Absolute error of each model's sample mean against the reference sample mean.
# The means were computed in the evaluation cell; each has shape (Nx, Nt).
methods = {
    "PINN": mean_y_pinn,
    "PINN-PCE": mean_y_pinn_pce,
}
abs_errors = {
    label: np.abs(mean_model - mean_solution) for label, mean_model in methods.items()
}

# A single colorbar is drawn for both panels, so they must share one colour
# scale. Otherwise each imshow autoscales independently and the colorbar only
# matches the last panel.
vmax = max(err.max() for err in abs_errors.values())

fig, axs = plt.subplots(
    1, len(abs_errors), figsize=(12, 5), sharey=True, constrained_layout=True
)

for ax, (label, abs_diff) in zip(axs, abs_errors.items(), strict=True):
    im = ax.imshow(
        abs_diff.T,  # (Nt, Nx): time on the vertical axis
        aspect="auto",
        origin="lower",
        extent=[xx.min(), xx.max(), tt.min(), tt.max()],
        cmap="viridis",
        vmin=0.0,
        vmax=vmax,
    )

    ax.set_xlabel("x", fontsize=12)
    ax.set_title(f"{label} vs. Reference", fontsize=14)

axs[0].set_ylabel("t", fontsize=12)
fig.colorbar(im, ax=axs, location="right", shrink=0.8, label="|mean error|")

fig.savefig("diffusion_absolute_error_heatmaps.pdf", bbox_inches="tight", dpi=300)
plt.show()

### Plot the train and test loss

In [None]:
# Training and test loss histories of both models on a common log-scale axis.
fig, ax = plt.subplots(figsize=(8, 5))

# Each entry is (label, loss history, line style, color).
# PINN curves are solid and use the default colour cycle (color=None).
# PINN-PCE curves are dashed with fixed colours.
loss_curves = [
    ("PINN – Training", trainer_pinn.trainer_data.losses_train, "-", None),
    ("PINN – Testing", trainer_pinn.trainer_data.losses_test, "-", None),
    (
        "PINN-PCE – Training",
        trainer_pinn_pce.trainer_data.losses_train,
        "--",
        "tab:red",
    ),
    (
        "PINN-PCE – Testing",
        trainer_pinn_pce.trainer_data.losses_test,
        "--",
        "tab:cyan",
    ),
]

for label, history, linestyle, color in loss_curves:
    # `.i` goes on the "Iteration" axis and `.v` on the "Loss" axis.
    ax.plot(
        history.i,
        history.v,
        label=label,
        linestyle=linestyle,
        linewidth=2,
        color=color,
    )

# --- Formatting ---
ax.set_yscale("log")
ax.set_ylim(top=1)  # Set maximum y to 10^0
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
ax.set_title("Training and test loss")
ax.grid(True, which="both", linestyle="--", alpha=0.6)
ax.legend(fontsize=12)
fig.tight_layout()

fig.savefig("diffusion_loss.pdf", bbox_inches="tight", dpi=300)

plt.show()