# Solve a Nonlinear 1D Reaction-Diffusion Equation with Random Coefficients
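This notebook evaluates a trained physics-informed neural network (PINN) for the PDE below and compares its statistics with those of a reference solution: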

$$
\frac{\partial u}{\partial t} - D \frac{\partial^2 u}{\partial x^2} + g(x) u^3 = f(x), \quad t \in [0, T],\ x \in [-1, 1]
$$

with:

### Initial condition:
$$
u(0, x) = 0.5 \cos^2(\pi x)
$$

### Boundary conditions:
$$
u(t, -1) = u(t, 1) = 0.5
$$

### Reaction coefficient:
$$
g(x) = 0.2 + e^{r_1 x} \cos^2(r_2 x), \quad r_1 \sim \mathcal{U}(0.5, 1), \quad r_2 \sim \mathcal{U}(3, 4)
$$

### Forcing term:
$$
f(x) = \exp\left( -\frac{(x - 0.25)^2}{2 k_1^2} \right) \sin^2(k_2 x), \quad k_1 \sim \mathcal{U}(0.2, 0.8), \quad k_2 \sim \mathcal{U}(1, 4)
$$

In [None]:
%matplotlib widget

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.colors import LogNorm
from torch import tensor

from examples.equations.reaction_diffusion_1d.pde import reference_solution
from measure_uq.models import PINN
from measure_uq.pde import PDE
from measure_uq.trainers.trainer import Trainer
from measure_uq.utilities import cartesian_product_of_rows, to_numpy

mpl.rcParams.update(
    {
        "text.usetex": False,
        "font.family": "STIX",
        "mathtext.fontset": "stix",
        "font.size": 16,
        "axes.titlesize": 16,
        "axes.labelsize": 16,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "legend.fontsize": 16,
        "figure.titlesize": 16,
    }
)

# model_type_pinn_pce = PINN_PCE
# model_path_pinn_pce = "data/best_model_pinn_pce.pickle"
# pde_path_pinn_pce = "data/pde_pinn_pce.pickle"
# trainer_path_pinn_pce = "data/trainer_pinn_pce.pickle"

model_type_pinn = PINN
model_path_pinn = "data/best_model_pinn.pickle"
pde_path_pinn = "data/pde_pinn.pickle"
trainer_path_pinn = "data/trainer_pinn.pickle"

# model_pinn_pce = model_type_pinn_pce.load(model_path_pinn_pce)
# pde_pinn_pce = PDE.load(pde_path_pinn_pce)
# trainer_pinn_pce = Trainer.load(trainer_path_pinn_pce)

# NOTE(review): the ``*.pickle`` paths suggest these loaders unpickle their
# input, which can execute arbitrary code -- only load files you trust.
model_pinn = model_type_pinn.load(model_path_pinn)
pde_pinn = PDE.load(pde_path_pinn)
trainer_pinn = Trainer.load(trainer_path_pinn)

# Final time and diffusion coefficient, read from the first training condition.
T = pde_pinn.conditions_train.conditions[0].T
D = pde_pinn.conditions_train.conditions[0].D

# Evaluation grid and Monte Carlo sample count -- tune these here.
N_T = 101  # time points on [0, T]
N_X = 51  # spatial points on [-1, 1]
N_SAMPLES = 10_000  # random-parameter samples used for the statistics
SEED = 0

t = tensor(np.linspace(0, T, N_T)[:, None])
x = tensor(np.linspace(-1, 1, N_X)[:, None])

# Seed the global RNGs so the parameter draw below is reproducible across runs.
# NOTE(review): assumes ``sample_values`` draws from the global torch or NumPy
# RNG -- confirm.
torch.manual_seed(SEED)
np.random.seed(SEED)

pde_pinn.parameters_test.N = N_SAMPLES
pde_pinn.parameters_test.sample_values()
parameters = pde_pinn.parameters_test.values.detach().cpu()

Nx = x.shape[0]
Nt = t.shape[0]
Np = parameters.shape[0]

### Evaluate model and reference solutions

In [None]:
tx = cartesian_product_of_rows(t, x).float()

# PINN_PCE solution
# _, y_tmp_pinn_pce = model_pinn_pce.forward(tx, parameters)

# y_pinn_pce = np.array(
#     [y_tmp_pinn_pce[k::Np, :].reshape(Nt, Nx).detach().numpy().T for k in range(Np)],
# )  # shape: (samples, Nx, Nt)

# mean_y_pinn_pce = y_pinn_pce.mean(axis=0)
# std_y_pinn_pce = y_pinn_pce.std(axis=0)


# PINN solution
_, y_tmp_pinn = model_pinn.forward(tx, parameters)

# Assumes the output holds one row per (t, x, parameter) triple with the
# parameter index varying fastest -- the same layout that the per-sample
# ``[k::Np]`` slicing in the PINN_PCE branch above relies on. One reshape +
# transpose replaces Np separate slices; ``.cpu()`` keeps this working if the
# model lives on an accelerator.
assert y_tmp_pinn.numel() == Nt * Nx * Np, (
    f"expected {Nt * Nx * Np} model outputs, got {y_tmp_pinn.numel()}"
)
y_pinn = np.ascontiguousarray(
    y_tmp_pinn.detach().cpu().numpy().reshape(Nt, Nx, Np).transpose(2, 1, 0)
)  # shape: (samples, Nx, Nt)

mean_y_pinn = y_pinn.mean(axis=0)
std_y_pinn = y_pinn.std(axis=0)

# Reference solution, evaluated once per parameter sample
tt = to_numpy(t).squeeze()
xx = to_numpy(x).squeeze()

solutions = np.stack(
    [reference_solution(tt, xx, to_numpy(p), D) for p in parameters], axis=0
)  # shape: (samples, Nx, Nt)

# Later cells subtract PINN and reference statistics element-wise.
assert solutions.shape == y_pinn.shape, (
    f"reference {solutions.shape} and PINN {y_pinn.shape} arrays differ in shape"
)

mean_solution = solutions.mean(axis=0)
std_solution = solutions.std(axis=0)

### Plot the density of the solution at the middle time

In [None]:
plt.close("all")

methods = {
    "Reference": solutions,
    "PINN": y_pinn,
    # "PINN-PCE": y_pinn_pce,
}

x_nodes = to_numpy(x).reshape(-1)
t_nodes = to_numpy(t).reshape(-1)
middle_time_index = t_nodes.shape[0] // 2

n_ubins = 100

# One x-bin per spatial grid node, with edges halfway between neighbouring
# nodes. A fixed bin count that does not match the grid (e.g. 100 bins for
# 51 nodes) leaves every other column empty, which shows up as white stripes
# under the log colour scale.
assert x_nodes.size >= 2, "need at least two spatial nodes to build x-bins"
x_midpoints = 0.5 * (x_nodes[:-1] + x_nodes[1:])
x_bin_edges = np.concatenate(
    (
        [x_nodes[0] - (x_midpoints[0] - x_nodes[0])],
        x_midpoints,
        [x_nodes[-1] + (x_nodes[-1] - x_midpoints[-1])],
    )
)

# Snapshot of every method at the middle time, shape (samples, Nx).
u_snapshots = {label: sol[:, :, middle_time_index] for label, sol in methods.items()}

# Shared u-bins so every panel bins values identically and the shared y-axis
# is meaningful.
u_low = min(u.min() for u in u_snapshots.values())
u_high = max(u.max() for u in u_snapshots.values())
u_bin_edges = np.linspace(u_low, u_high, n_ubins + 1)

histograms = {}
for label, u in u_snapshots.items():
    # Row-major flattening of (samples, Nx) repeats the x-grid once per sample.
    hist, _, _ = np.histogram2d(
        np.tile(x_nodes, u.shape[0]),
        u.reshape(-1),
        bins=[x_bin_edges, u_bin_edges],
    )
    histograms[label] = hist.T  # rows index u, columns index x (imshow layout)

# One colour scale for all panels so densities are directly comparable;
# empty bins (count 0) fall below vmin and are drawn white.
norm = LogNorm(vmin=1, vmax=max(h.max() for h in histograms.values()))

cmap = plt.colormaps.get_cmap("plasma").copy()
cmap.set_bad(color="white")
cmap.set_under(color="white")

fig, axs = plt.subplots(
    1, len(methods), figsize=(18, 5), sharey=True, constrained_layout=True
)
axs = np.atleast_1d(axs)  # a single panel returns a bare Axes

for ax, (label, hist) in zip(axs, histograms.items(), strict=True):
    im = ax.imshow(
        hist,
        aspect="auto",
        origin="lower",
        extent=[x_bin_edges[0], x_bin_edges[-1], u_bin_edges[0], u_bin_edges[-1]],
        cmap=cmap,
        norm=norm,
        interpolation="bicubic",
    )
    ax.set_xlabel("x")
    ax.set_title(label)  # method name as subtitle

axs[0].set_ylabel("u")
fig.colorbar(im, ax=axs, label="count")
fig.suptitle(f"t = {t_nodes[middle_time_index]:.3g}")

fig.savefig("reaction_diffusion_histograms.pdf", bbox_inches="tight", dpi=300)
plt.show()

### Plot the absolute error of the mean as a heatmap

In [None]:
plt.close("all")

# Models whose sample mean is compared with the reference sample mean
# (``solutions`` holds the reference samples).
methods = {
    "PINN": y_pinn,
    # "PINN-PCE": y_pinn_pce,
}

x_vals = to_numpy(x).reshape(-1)
t_vals = to_numpy(t).reshape(-1)

# Pad the image extent by half a grid spacing on each side so every pixel is
# centred on its (x, t) grid node, rather than the image edges sitting on the
# first/last nodes. Both grids are uniform (built with np.linspace).
half_dx = 0.5 * (x_vals[-1] - x_vals[0]) / (x_vals.size - 1)
half_dt = 0.5 * (t_vals[-1] - t_vals[0]) / (t_vals.size - 1)
extent = [
    x_vals[0] - half_dx,
    x_vals[-1] + half_dx,
    t_vals[0] - half_dt,
    t_vals[-1] + half_dt,
]

mean_reference = np.mean(solutions, axis=0)  # shape: (n_x, n_t)

fig, axs = plt.subplots(
    1, len(methods), figsize=(12, 5), sharey=True, constrained_layout=True
)
axs = np.atleast_1d(axs)  # a single panel returns a bare Axes

for ax, (label, sol) in zip(axs, methods.items(), strict=True):
    mean_model = np.mean(sol, axis=0)
    abs_diff = np.abs(mean_model - mean_reference)

    # Transpose so x runs along the horizontal axis and t along the vertical.
    im = ax.imshow(
        abs_diff.T,
        aspect="auto",
        origin="lower",
        extent=extent,
        cmap="viridis",
    )

    ax.set_xlabel("x")
    ax.set_title(f"{label} vs. Reference")

axs[0].set_ylabel("t")
fig.colorbar(
    im, ax=axs, location="right", shrink=0.8, label="absolute error of the mean"
)

fig.savefig(
    "reaction_diffusion_absolute_error_heatmaps.pdf", bbox_inches="tight", dpi=300
)
plt.show()

### Plot mean and standard deviation at fixed times

In [None]:
plt.close("all")

# Upper-cased model class name, used in the legend labels of the plots below.
model_name = str.upper(model_type_pinn.__name__)

# Two side-by-side panels, one per snapshot time.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(16, 5))
axs = axs.ravel()


def plot_mean_and_std(ax, i):
    """Overlay reference and model statistics at time index ``i`` on ``ax``.

    The reference mean is drawn as a solid black line, with dashed black lines
    at mean +/- one standard deviation. The model mean is drawn as blue markers
    with a shaded band spanning mean +/- one standard deviation.

    Reads the notebook globals ``x``, ``mean_solution``, ``std_solution``,
    ``mean_y_pinn``, ``std_y_pinn`` and ``model_name``.
    """
    ref_mu, ref_sigma = mean_solution[:, i], std_solution[:, i]
    net_mu, net_sigma = mean_y_pinn[:, i], std_y_pinn[:, i]

    # Shared styling for all reference curves.
    black_line = {"color": "black", "linewidth": 2}

    ax.plot(x, ref_mu, "-", label="Ref. mean", **black_line)
    ax.plot(x, ref_mu + ref_sigma, "--", **black_line)
    ax.plot(x, ref_mu - ref_sigma, "--", label=r"Ref. mean $\pm$ std", **black_line)
    ax.plot(
        x,
        net_mu,
        "o",
        color="blue",
        label=f"{model_name} mean",
        linewidth=2,
        markersize=4,
    )

    # fill_between needs a 1-D abscissa, hence the squeeze.
    band_low = net_mu - net_sigma
    band_high = net_mu + net_sigma
    ax.fill_between(
        x.squeeze(),
        band_low,
        band_high,
        color="blue",
        alpha=0.3,
        label=f"{model_name} uncertainty band (1 std)",
    )


# Compare statistics at a quarter of the time horizon and at the final time.
for ax, t_index in zip(axs, (t.shape[0] // 4, -1), strict=True):
    plot_mean_and_std(ax, t_index)
    # Format the time so float round-off (e.g. 0.30000000000000004) does not
    # leak into the panel title.
    ax.set_title(f"t = {t[t_index].item():.3g}")
    ax.set_ylabel("u")
    ax.set_xlabel("x")
    ax.grid()

axs[0].legend()

plt.tight_layout()
fig.savefig("reaction_diffusion_mean_and_std.pdf", bbox_inches="tight", dpi=300)
plt.show()

### Plot the train and test loss

In [None]:
import matplotlib.pyplot as plt

# Close earlier figures; with the widget backend they otherwise accumulate.
plt.close("all")

fig, ax = plt.subplots(figsize=(8, 5))

# --- PINN (solid lines) ---
x1 = trainer_pinn.trainer_data.losses_train.i
y1 = trainer_pinn.trainer_data.losses_train.v
ax.plot(x1, y1, label="PINN – Training", linestyle="-", linewidth=2)

x2 = trainer_pinn.trainer_data.losses_test.i
y2 = trainer_pinn.trainer_data.losses_test.v
ax.plot(x2, y2, label="PINN – Testing", linestyle="-", linewidth=2)

# --- PINN-PCE (dashed lines) ---
# x3 = trainer_pinn_pce.trainer_data.losses_train.i
# y3 = trainer_pinn_pce.trainer_data.losses_train.v
# ax.plot(x3, y3, label='PINN-PCE – Training', linestyle='--', linewidth=2, color='tab:red')

# x4 = trainer_pinn_pce.trainer_data.losses_test.i
# y4 = trainer_pinn_pce.trainer_data.losses_test.v
# ax.plot(x4, y4, label='PINN-PCE – Testing', linestyle='--', linewidth=2, color='tab:cyan')

# --- Formatting ---
ax.set_yscale("log")

# Clip early-iteration spikes at 10^0, but only when some loss actually lies
# below 1: if every loss exceeds 1, the autoscaled lower limit would sit above
# the forced top and the axis would be inverted.
plotted_losses = np.concatenate([np.ravel(y1), np.ravel(y2)]).astype(float)
if plotted_losses.size and np.nanmin(plotted_losses) < 1:
    ax.set_ylim(top=1)

ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
ax.grid(True, which="both", linestyle="--", alpha=0.6)
ax.legend()
fig.tight_layout()

fig.savefig("reaction_diffusion_loss.pdf", bbox_inches="tight", dpi=300)

plt.show()