# Bistable ODE Simulation
This notebook loads a trained PINN (or PINN-PCE) surrogate and compares it with the reference solution of the bistable ODE:

$$
\begin{aligned}
\frac{dy}{dt} &= -r (y - 1)(2 - y)(y - 3), \\
y(0) &= y_0,
\end{aligned}
$$
where $r\sim \mathcal{U}(0.8, 1.2)$ and $y_0 \sim \mathcal{U}(0,4)$.


In [None]:
%matplotlib widget

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML
from matplotlib import animation
from torch import tensor

from examples.equations.bistable_ode.ode import reference_solution
from measure_uq.models import PINN, PINN_PCE
from measure_uq.pde import PDE
from measure_uq.trainers.trainer import Trainer
from measure_uq.utilities import to_numpy

# Global Matplotlib styling applied to every figure created below.
# Text is rendered by Matplotlib's own mathtext engine rather than LaTeX.
mpl.rcParams["text.usetex"] = False

# STIX fonts for both regular text and math.
# NOTE(review): confirm a font family named "STIX" is installed on the target
# machine; otherwise Matplotlib falls back to its default font.
mpl.rcParams["font.family"] = "STIX"  # or 'Times New Roman' if available
mpl.rcParams["mathtext.fontset"] = "stix"

# Font sizes: axis labels, base font, legend, then tick labels.
mpl.rcParams["axes.labelsize"] = 16
mpl.rcParams["font.size"] = 14
mpl.rcParams["legend.fontsize"] = 14
mpl.rcParams["xtick.labelsize"] = 12
mpl.rcParams["ytick.labelsize"] = 12

# Select which trained surrogate this notebook post-processes.
# Set to True to analyse the PINN-PCE model instead of the plain PINN.
USE_PINN_PCE = False

model_type = PINN_PCE if USE_PINN_PCE else PINN

# The model, PDE and trainer files of one run share the same name suffix.
artifact_suffix = "pinn_pce" if USE_PINN_PCE else "pinn"
model_path = f"data/best_model_{artifact_suffix}.pickle"
pde_path = f"data/pde_{artifact_suffix}.pickle"
trainer_path = f"data/trainer_{artifact_suffix}.pickle"

# NOTE(review): these are ".pickle" files; if the loaders unpickle them, only
# load artefacts produced by a trusted training run.
model = model_type.load(model_path)
pde = PDE.load(pde_path)
trainer = Trainer.load(trainer_path)

# Right end of the time grid, taken from the first training condition.
T = pde.conditions_train.conditions[0].T

# 101 equally spaced time points on [0, T], stored as a float32 column vector.
time_grid = np.linspace(0, T, 101)
t = tensor(time_grid[:, None]).float()

# Draw a fresh set of 10000 test parameters and keep a detached CPU copy.
pde.parameters_test.N = 10000
pde.parameters_test.sample_values()
parameters = pde.parameters_test.values.detach().cpu()

# Number of time points and number of parameter samples.
Nt, Np = t.shape[0], parameters.shape[0]

### Evaluate the solution of the PINN for all the parameters

In [None]:
# Evaluate the network on the time grid for all sampled parameters. Only the
# second value returned by `forward` is used here.
_, y_tmp = model.forward(t, parameters)

# Sample k occupies rows k, k + Np, k + 2 * Np, ... of `y_tmp`, i.e. the rows
# are grouped by time point with the parameter index varying fastest.
# Assumes `y_tmp` has a single output column (shape (Nt * Np, 1)).
# A single reshape to (Nt, Np) followed by a transpose therefore yields the
# (samples, Nt) array directly, instead of taking Np separate strided slices.
y = np.ascontiguousarray(
    y_tmp.detach().cpu().numpy().reshape(-1, Np).T,
)  # shape: (samples, Nt)

# Statistics over the parameter samples at every time point.
mean_y = y.mean(axis=0)
std_y = y.std(axis=0)

### Evaluate the reference solution for all the parameters


In [None]:
# Time grid converted with `to_numpy` and squeezed for the reference solver.
tt = to_numpy(t).squeeze()

# One reference trajectory per sampled parameter row, stacked along axis 0.
solutions = np.stack(
    [reference_solution(tt, to_numpy(p)) for p in parameters],
)  # shape: (samples, Nt)

# Monte Carlo statistics over the samples at every time point.
mean_solution = np.mean(solutions, axis=0)
std_solution = np.std(solutions, axis=0)

### Plot the PINN and the reference solution

In [None]:
plt.close("all")

fig, ax = plt.subplots(figsize=(10, 5))

# Monte Carlo reference: mean curve plus the one-standard-deviation envelope.
mc_lower = mean_solution - std_solution
mc_upper = mean_solution + std_solution
ax.plot(t, mean_solution, label="MC mean")
ax.plot(t, mc_lower, color="black", label="MC mean - std")
ax.plot(t, mc_upper, color="black", label="MC mean + std")

# PINN prediction: dashed mean with a shaded one-standard-deviation band.
pinn_lower = mean_y - std_y
pinn_upper = mean_y + std_y
ax.plot(t, mean_y, "--", label="PINN mean")
ax.fill_between(
    t.squeeze(),
    pinn_lower,
    pinn_upper,
    alpha=0.3,
    label="Standard deviation of PINN",
)

ax.set(xlabel="time", ylabel="y")
ax.legend()
ax.grid()

fig.savefig("bistable_ode_mean_and_std.pdf", bbox_inches="tight", dpi=300)

In [None]:
plt.close("all")

# The NumPy time grid `tt` (defined with the reference solution) is used for
# the titles and the frame count. `t` is deliberately left untouched so that it
# stays a torch tensor and the model-evaluation cell can be re-run afterwards.
bins = np.linspace(0, 4, 100)

fig, ax = plt.subplots(figsize=(8, 4))


def update(frame):
    """Redraw the reference and PINN histograms for time index ``frame``.

    Parameters
    ----------
    frame : int
        Column index into ``solutions`` and ``y`` and index into ``tt``.
    """
    # Start from an empty axes; otherwise histograms of earlier frames pile up.
    ax.clear()

    ax.hist(
        solutions[:, frame],
        bins=bins,
        density=True,
        alpha=0.6,
        color="skyblue",
        label="Reference",
    )
    ax.hist(y[:, frame], bins=bins, density=True, alpha=0.6, color="red", label="PINN")

    # `ax.clear()` resets limits and labels, so they are set again every frame.
    ax.set_xlim(0, 4)
    ax.set_ylim(0, 2.0)
    ax.set_ylabel("Frequency")
    ax.set_xlabel("u")
    ax.set_title(f"t = {tt[frame]:.2f}")
    ax.legend()

    fig.tight_layout()  # ensures labels fit in the frame


# One frame per time point, 100 ms between frames.
anime = animation.FuncAnimation(fig, update, frames=len(tt), interval=100)
# Close the figure so only the rendered video is shown, not a static plot.
plt.close()

HTML(anime.to_html5_video())

In [None]:
# Export the animation as an MP4 whose file name carries the lower-cased name
# of the model class in use.
model_tag = model_type.__name__.lower()
path = f"bistable_ode_{model_tag}.mp4"
anime.save(path, fps=5, dpi=200, bitrate=3000)

### Plot the Wasserstein distance between the histograms

In [None]:
plt.close("all")

# `solutions` (reference) and `y` (PINN) must have identical shapes: one row
# per parameter sample and one column per time point.
if solutions.shape != y.shape:
    raise ValueError(
        f"Shape mismatch: solutions {solutions.shape} vs. y {y.shape}"
    )

# For two sample sets of equal size, the 1-Wasserstein distance between their
# empirical distributions is the mean absolute difference of the sorted
# samples. Sorting along axis 0 treats all time points at once.
sorted_reference = np.sort(solutions, axis=0)
sorted_pinn = np.sort(y, axis=0)
wasserstein_distances = np.abs(sorted_reference - sorted_pinn).mean(axis=0)

# Plot Wasserstein distance as function of time
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(tt, wasserstein_distances, linewidth=3, color="black")
ax.set_xlabel("t")
ax.set_ylabel("Wasserstein distance")
ax.grid(True)
fig.tight_layout()
fig.savefig("bistable_ode_wasserstein_distance.pdf", bbox_inches="tight", dpi=300)

plt.show()

### Plot the histograms at a single time

In [None]:
plt.close("all")

# The NumPy time grid `tt` (defined with the reference solution) is used for
# the title, so this cell does not depend on whether `t` is still a torch
# tensor or has already been converted by another cell.
bins = np.linspace(0, 4, 100)

fig, ax = plt.subplots(figsize=(8, 4))

# Index of the time point whose distributions are compared.
frame = 25

ax.hist(
    solutions[:, frame],
    bins=bins,
    density=True,
    alpha=0.6,
    color="skyblue",
    label="Reference",
)
ax.hist(y[:, frame], bins=bins, density=True, alpha=0.6, color="red", label="PINN")

# Restrict the view to [0.9, 3.1]; the bins still span [0, 4].
ax.set_xlim(0.9, 3.1)
ax.set_ylim(0, 2.0)
ax.set_ylabel("Frequency")
ax.set_xlabel("u")
ax.set_title(f"t = {tt[frame]:.2f}")
ax.legend()

fig.tight_layout()

fig.savefig("bistable_ode_histograms.pdf", bbox_inches="tight", dpi=300)

### Plot the train and test loss

In [None]:
fig, ax = plt.subplots(figsize=(8, 4))

# Plot the recorded training and testing losses, in that order. Each record
# supplies its `i` attribute as x-values and its `v` attribute as y-values.
history = trainer.trainer_data
for record, label in (
    (history.losses_train, "Training loss"),
    (history.losses_test, "Testing loss"),
):
    ax.plot(record.i, record.v, label=label, linewidth=3)

# Logarithmic loss axis.
ax.set_yscale("log")
ax.grid(True)

ax.set(xlabel="Iteration", ylabel="Loss")
ax.legend()

fig.tight_layout()

fig.savefig("bistable_ode_loss.pdf", bbox_inches="tight", dpi=300)