## Solution of the 1D Heat Equation on $[0, \pi]$ Using PINN and PINN-PCE Methods

We solve the following heat equation for $t \ge 0$ and $x \in [0, \pi]$:

$$
u_t - \frac{a}{k^2} u_{xx} = 0
$$

with initial and boundary conditions:

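$$
\begin{aligned}
u(0, x) &= \sin(k x), \\
u(t, 0) &= 0, \\
u(t, \pi) &= e^{-a t} \sin(\pi k).
\end{aligned}
$$

These conditions are satisfied by the exact solution $u(t, x) = e^{-a t} \sin(k x)$.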


In [3]:
%matplotlib widget

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML
from matplotlib.animation import FuncAnimation
from matplotlib.lines import Line2D
from torch import tensor

from examples.equations.heat_1d.pde import reference_solution
from measure_uq.models import PINN, PINN_PCE
from measure_uq.pde import PDE
from measure_uq.trainers.trainer import Trainer
from measure_uq.utilities import cartesian_product_of_rows, to_numpy

# Typeset text and math with STIX (Times-like) fonts. Matplotlib ships these
# under the family name "STIXGeneral"; a bare "STIX" family is typically not
# found (unless a system font with that exact name is installed), which makes
# Matplotlib fall back to another font with a findfont warning.
mpl.rcParams.update(
    {
        "text.usetex": False,
        "font.family": "STIXGeneral",  # or 'Times New Roman' if available
        "mathtext.fontset": "stix",
        "axes.labelsize": 16,
        "font.size": 14,
        "legend.fontsize": 14,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
    }
)


# Which trained model to inspect: True -> PINN-PCE, False -> plain PINN.
# The saved model, PDE and trainer artifacts share a per-model file suffix.
USE_PCE = False

model_type = PINN_PCE if USE_PCE else PINN
_artifact_suffix = "pinn_pce" if USE_PCE else "pinn"
model_path = f"data/best_model_{_artifact_suffix}.pickle"
pde_path = f"data/pde_{_artifact_suffix}.pickle"
trainer_path = f"data/trainer_{_artifact_suffix}.pickle"

model = model_type.load(model_path)
pde = PDE.load(pde_path)

trainer = Trainer.load(trainer_path)

# Domain extents (final time T, right end X) are read from the first training
# condition.
first_condition = pde.conditions_train.conditions[0]
T = first_condition.T
X = first_condition.X

# Evaluation grid: column vectors of uniformly spaced points on [0, T] and
# [0, X].
N_T_POINTS = 101
N_X_POINTS = 51
t = tensor(np.linspace(0, T, N_T_POINTS)[:, None])
x = tensor(np.linspace(0, X, N_X_POINTS)[:, None])

# Parameter samples from the PDE's test set, detached and moved to the CPU.
parameters = pde.parameters_test.values.detach().cpu()

Nx = x.shape[0]
Nt = t.shape[0]
Np = parameters.shape[0]

ModuleNotFoundError: No module named 'matplotlib'

### Evaluate the trained model (PINN or PINN-PCE) for all parameter samples

In [None]:
tx = cartesian_product_of_rows(t, x).float()

_, y_tmp = model.forward(tx, parameters)

# The rows of ``y_tmp`` enumerate (time, space, parameter) with the parameter
# index varying fastest, so one reshape separates the three axes. Permute to
# (samples, Nx, Nt) and make the result contiguous before converting to NumPy,
# so the memory layout matches stacking one (Nx, Nt) slab per sample.
y = (
    y_tmp.detach()
    .reshape(Nt, Nx, Np)
    .permute(2, 1, 0)
    .contiguous()
    .numpy()
)  # shape: (samples, Nx, Nt)

# Ensemble statistics over the parameter samples.
mean_y, std_y = y.mean(axis=0), y.std(axis=0)

### Evaluate the reference solution for all the parameters


In [None]:
tt = to_numpy(t).squeeze()
xx = to_numpy(x).squeeze()

# One reference field per parameter sample, stacked along a new leading axis.
solutions = np.stack(
    [reference_solution(tt, xx, to_numpy(sample)) for sample in parameters],
    axis=0,
)  # shape: (samples, Nx, Nt)

# Ensemble statistics over the parameter samples.
mean_solution, std_solution = solutions.mean(axis=0), solutions.std(axis=0)

### Plot the model and reference solutions

In [None]:
model_name = model_type.__name__.upper()

fig, axs = plt.subplots(2, 1, figsize=(16, 10))
axs = axs.flatten()


def _padded_limits(lo, hi, frac=0.05):
    """
    Widen the interval ``[lo, hi]`` by a fraction of its length on each side.

    Unlike multiplying the endpoints by a constant factor, this always widens
    the interval, whatever the signs of ``lo`` and ``hi``.

    Parameters
    ----------
    lo, hi : float
        Lower and upper data bounds.
    frac : float, optional
        Fraction of ``hi - lo`` added below ``lo`` and above ``hi``.

    Returns
    -------
    tuple
        The padded ``(lower, upper)`` limits.
    """
    pad = frac * (hi - lo)
    if pad == 0:
        # Constant data: fall back to a small absolute margin.
        pad = frac * max(abs(lo), 1.0)
    return lo - pad, hi + pad


# Fix the y-limits once for the whole animation. Both panels cover the model
# AND the reference data so that neither family of curves gets clipped.
axs[0].set_ylim(
    *_padded_limits(
        min(y.min(), solutions.min()),
        max(y.max(), solutions.max()),
    )
)
axs[1].set_ylim(
    *_padded_limits(
        min((mean_y - std_y).min(), (mean_solution - std_solution).min()),
        max((mean_y + std_y).max(), (mean_solution + std_solution).max()),
    )
)

i = 0  # index of the first time frame drawn

# Top panel: one curve per parameter sample (solid: reference, dashed: model).
ax0_lines1 = axs[0].plot(x, solutions[:, :, i].T)
ax0_lines2 = axs[0].plot(x, y[:, :, i].T, "--")

# Manual legend for line styles: colours vary per sample, so only the line
# style distinguishes reference from model.
reference_line = Line2D([0], [0], color="black", linestyle="-", label="Reference")
pinn_line = Line2D([0], [0], color="black", linestyle="--", label=f"{model_name}")
axs[0].legend(
    handles=[reference_line, pinn_line],
    loc="lower left",
    fontsize=16,
)

# Bottom panel: ensemble mean +/- one standard deviation of both solutions.
(ax1_sol_mean,) = axs[1].plot(
    x,
    mean_solution[:, i],
    "-",
    color="black",
    label="Ref. mean",
    linewidth=2,
)
(ax1_sol_std_plus,) = axs[1].plot(
    x,
    mean_solution[:, i] + std_solution[:, i],
    "--",
    color="black",
    linewidth=2,
)
(ax1_sol_std_minus,) = axs[1].plot(
    x,
    mean_solution[:, i] - std_solution[:, i],
    "--",
    color="black",
    label=r"Ref. mean $\pm$ std",
    linewidth=2,
)
(ax1_y_mean,) = axs[1].plot(
    x,
    mean_y[:, i],
    ".",
    color="blue",
    label=f"{model_name} mean",
    linewidth=2,
    markersize=12,
)

ax1_fill = axs[1].fill_between(
    x.squeeze(),
    mean_y[:, i] - std_y[:, i],
    mean_y[:, i] + std_y[:, i],
    color="blue",
    alpha=0.3,
    label=f"{model_name} uncertainty band (1 std)",
)

axs[0].set_ylabel("u", fontsize=16)
axs[1].set_ylabel("u", fontsize=16)
axs[1].set_xlabel("x", fontsize=16)

for axis in axs:
    axis.grid()

axs[1].legend(loc="lower left", fontsize=16)

# Keep a handle on the title so ``animate`` can change only its text.
# Re-calling ``fig.suptitle`` without ``fontsize`` would reset the size to the
# rcParams default after the first frame.
time_title = fig.suptitle(f"t = {t[i].item():.2f}", fontsize=18)


def animate(i: int):
    """
    Update every artist in the figure to show time frame ``i``.

    Parameters
    ----------
    i : int
        The index of the frame (an index into the time grid ``t``).

    Returns
    -------
    tuple
        Updated matplotlib artists.
    """
    global ax1_fill

    time_title.set_text(f"t = {t[i].item():.2f}")

    # One line per parameter sample; each row is that sample's profile in x.
    for line, profile in zip(ax0_lines1, solutions[:, :, i]):
        line.set_ydata(profile)
    for line, profile in zip(ax0_lines2, y[:, :, i]):
        line.set_ydata(profile)

    ax1_sol_mean.set_ydata(mean_solution[:, i])
    ax1_sol_std_plus.set_ydata(mean_solution[:, i] + std_solution[:, i])
    ax1_sol_std_minus.set_ydata(mean_solution[:, i] - std_solution[:, i])
    ax1_y_mean.set_ydata(mean_y[:, i])

    # fill_between creates a new artist each call, so replace the band. Remove
    # the tracked artist itself rather than whatever is first in
    # ``axs[1].collections``.
    ax1_fill.remove()
    ax1_fill = axs[1].fill_between(
        x.squeeze(),
        mean_y[:, i] - std_y[:, i],
        mean_y[:, i] + std_y[:, i],
        color="blue",
        alpha=0.3,
    )

    return (
        *ax0_lines1,
        *ax0_lines2,
        ax1_y_mean,
        ax1_sol_mean,
        ax1_sol_std_plus,
        ax1_sol_std_minus,
        ax1_fill,
    )


anime = FuncAnimation(
    fig,
    animate,
    interval=200,
    blit=False,
    frames=t.shape[0],
)

plt.close(fig)  # display only the HTML5 video below, not the static figure
HTML(anime.to_html5_video())

In [None]:
# Encode the animation to MP4. At 5 fps each frame lasts 200 ms, the same
# pacing as the ``interval`` of the FuncAnimation preview.
path = f"diffusion_{model_type.__name__.lower()}.mp4"
video_options = dict(
    dpi=200,  # higher resolution
    bitrate=3000,  # higher quality (in kbps)
    fps=5,
)
anime.save(path, **video_options)

### Plot the training and test losses

In [None]:
_, ax = plt.subplots()

# Both loss histories expose iteration indices (``i``) and loss values (``v``).
loss_history = trainer.trainer_data
for losses, curve_label in (
    (loss_history.losses_train, "Training loss"),
    (loss_history.losses_test, "Testing loss"),
):
    ax.plot(losses.i, losses.v, label=curve_label)

ax.set_yscale("log")
ax.grid(True)

ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
ax.legend()

plt.show()