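## Fisher–KPP inverse problem

Recover the diffusion coefficient $D$ and the growth rate $r$ of the Fisher–KPP equation

$$u_t = D\,u_{xx} + r\,u\,(1 - u)$$

from noisy, scattered observations of $u(x, t)$, using a physics-informed neural network (PINN) built with `pinnstorch` and PyTorch Lightning.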



#### Install Libraries

In [1]:
!pip install pinnstorch
!pip install lightning



#### Import Libraries

In [None]:
from typing import Dict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import lightning.pytorch as pl
import pinnstorch




ModuleNotFoundError: No module named 'pinnstorch'

In [None]:
class KPPDataset(Dataset):
    """
    Scattered Fisher–KPP observations served as (X, u) pairs.

    The CSV is assumed to provide the columns 't', 'x', 'u_exact' and
    'u_noisy'. Only 't', 'x' and 'u_noisy' are read, because the inverse
    problem is fitted to the noisy measurements.

    Attributes:
        X: float32 tensor of shape (N, 2), column 0 = x, column 1 = t.
        u: float32 tensor of shape (N, 1) with the noisy observations.
        t_min, t_max, x_min, x_max: observed domain bounds as Python floats,
            used by the PINN for input normalisation and collocation sampling.
    """

    def __init__(self, csv_path: str):
        super().__init__()

        frame = pd.read_csv(csv_path)

        # Flat float32 columns.
        t_col = frame["t"].to_numpy().astype(np.float32)
        x_col = frame["x"].to_numpy().astype(np.float32)
        u_col = frame["u_noisy"].to_numpy().astype(np.float32)

        # Inputs are laid out as [x, t]; the training step reads x from
        # column 0 and t from column 1.
        self.X = torch.from_numpy(np.column_stack((x_col, t_col)))
        self.u = torch.from_numpy(u_col.reshape(-1, 1))

        self.x_min, self.x_max = float(x_col.min()), float(x_col.max())
        self.t_min, self.t_max = float(t_col.min()), float(t_col.max())

    def __len__(self):
        """Number of observation points."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the idx-th observation as the pair (X_i, u_i)."""
        return self.X[idx], self.u[idx]


In [None]:
def sample_collocation_points(
    N_f: int,
    t_min: float,
    t_max: float,
    x_min: float,
    x_max: float,
    device: torch.device,
    requires_grad: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample collocation points (t_f, x_f) uniformly in the domain.

    These points are used only to enforce the PDE (physics loss). The loss
    differentiates the network output with respect to them, so by default
    they are returned as leaf tensors with ``requires_grad`` enabled.
    Autograd cannot differentiate with respect to tensors that lack it.

    Args:
        N_f: number of points to draw.
        t_min, t_max: time interval to sample from.
        x_min, x_max: spatial interval to sample from.
        device: device the returned tensors are placed on.
        requires_grad: whether the returned tensors track gradients.

    Returns:
        ``(t_f, x_f)``, time first and space second. Each is a float32
        tensor of shape (N_f, 1).
    """
    t_f = np.random.uniform(t_min, t_max, size=(N_f, 1)).astype(np.float32)
    x_f = np.random.uniform(x_min, x_max, size=(N_f, 1)).astype(np.float32)

    # Move to the device first, then enable grad, so the tensors that
    # require grad are the leaves actually fed to the network.
    t_f = torch.from_numpy(t_f).to(device).requires_grad_(requires_grad)
    x_f = torch.from_numpy(x_f).to(device).requires_grad_(requires_grad)

    return t_f, x_f




In [None]:
class FisherKPPNet(nn.Module):
    """
    Neural surrogate u(x, t) plus the learnable Fisher–KPP parameters D and r.

    Args:
        lb, ub: lower/upper domain bounds, each a 1-D tensor of length 2
            ordered [x, t] (i.e. [x_min, t_min] and [x_max, t_max]).
            Forwarded to the pinnstorch FCN for input normalisation.
        layers: FCN layer widths. The first entry must be 2 (inputs x, t)
            and the last 1 (output u). Default: four hidden layers of 100.
        D_init: initial guess for the diffusion coefficient D.
        r_init: initial guess for the growth rate r.
    """
    def __init__(
        self,
        lb,
        ub,
        layers: Tuple[int, ...] = (2, 100, 100, 100, 100, 1),
        D_init: float = 0.1,
        r_init: float = 1.0,
    ):
        super().__init__()

        self.lb = lb
        self.ub = ub

        # FCN from pinnstorch; its single output is exposed under key "u".
        self.net = pinnstorch.models.FCN(
            layers=list(layers),
            output_names=["u"],
            lb=lb,
            ub=ub,
        )

        # Trainable PDE parameters. They are unconstrained; nothing here
        # keeps D or r positive.
        self.D = nn.Parameter(torch.tensor(D_init, dtype=torch.float32))
        self.r = nn.Parameter(torch.tensor(r_init, dtype=torch.float32))

    def forward(self, x, t):
        """
        Evaluate the network at (x, t). Both are assumed to have shape (N, 1).

        Returns the FCN's output dict (prediction under "u"), extended with
        the PDE parameters under the keys "D" and "r".
        """
        # The pinnstorch FCN takes the spatial coordinates as a sequence
        # and time separately; for 1-D space that sequence is just (x,).
        outputs = self.net((x,), t)
        outputs["D"] = self.D
        outputs["r"] = self.r
        return outputs



In [None]:
def fisher_kpp_pde(outputs, x, t):
    """
    Pointwise residual of the Fisher–KPP equation u_t = D u_xx + r u (1 - u).

    Computes f = u_t - D * u_xx - r * u * (1 - u).

    Args:
        outputs: dict holding the network prediction under "u" and the PDE
            parameters under "D" and "r".
        x, t: the inputs "u" was computed from. They must be part of u's
            autograd graph (e.g. leaf tensors with requires_grad=True).

    Returns:
        The residual tensor; shaped like outputs["u"] when x, t and u share
        a shape (D and r broadcast).
    """
    u = outputs["u"]
    gradient = pinnstorch.utils.gradient

    # du/dx and du/dt in one call, then d2u/dx2.
    u_x, u_t = gradient(u, [x, t])
    u_xx = gradient(u_x, x)[0]

    diffusion = outputs["D"] * u_xx
    reaction = outputs["r"] * u * (1 - u)

    return u_t - diffusion - reaction


In [None]:
class FisherKPP_PINN(pl.LightningModule):
    """
    Physics-Informed Neural Network for the Fisher–KPP inverse problem.

    Loss = data_loss + lambda_pde * pde_loss

    - data_loss: mean squared error between u(x, t) and the observations
    - pde_loss:  mean squared Fisher–KPP residual at collocation points,
                 freshly resampled on every training step

    Args:
        dataset: observation dataset. It supplies the domain bounds used for
            input normalisation and collocation sampling.
        n_f: number of collocation points drawn per training step.
        lambda_pde: weight of the PDE loss.
        lr: Adam learning rate.
    """

    def __init__(
        self,
        dataset: KPPDataset,
        n_f: int = 10000,
        lambda_pde: float = 1.0,
        lr: float = 1e-3,
    ):
        super().__init__()

        self.save_hyperparameters(ignore=["dataset"])

        # Domain bounds for input normalisation, ordered [x, t] to match
        # the network's input layout (x first, then t).
        lb = torch.tensor([dataset.x_min, dataset.t_min], dtype=torch.float32)
        ub = torch.tensor([dataset.x_max, dataset.t_max], dtype=torch.float32)

        self.net = FisherKPPNet(lb=lb, ub=ub)

        self.dataset = dataset
        self.lambda_pde = lambda_pde
        self.lr = lr
        self.n_f = n_f  # collocation points per training step (not per epoch)

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Evaluate the model at times ``t`` and positions ``x``.

        FisherKPPNet takes its inputs as (x, t), so they are reordered here
        and callers can rely on this (t, x) signature.
        """
        return self.net(x, t)

    def configure_optimizers(self):
        # self.parameters() includes the network weights as well as D and r.
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _data_loss(self, X_data: torch.Tensor, u_data: torch.Tensor) -> torch.Tensor:
        """Mean squared misfit between predictions and observed u."""
        # KPPDataset stacks inputs as [x, t]: column 0 = x, column 1 = t.
        x_data = X_data[:, 0:1]
        t_data = X_data[:, 1:2]
        u_pred = self(t=t_data, x=x_data)["u"]
        return torch.mean((u_pred - u_data) ** 2)

    def _pde_loss(self) -> torch.Tensor:
        """Mean squared Fisher–KPP residual at freshly sampled collocation points."""
        t_f, x_f = sample_collocation_points(
            N_f=self.n_f,
            t_min=self.dataset.t_min,
            t_max=self.dataset.t_max,
            x_min=self.dataset.x_min,
            x_max=self.dataset.x_max,
            device=self.device,
        )
        # The residual differentiates u with respect to x_f and t_f, so both
        # must track gradients. This is a no-op if the sampler already enabled it.
        x_f.requires_grad_(True)
        t_f.requires_grad_(True)

        f = fisher_kpp_pde(self(t=t_f, x=x_f), x_f, t_f)
        return torch.mean(f ** 2)

    def training_step(self, batch, batch_idx):
        # Lightning has already moved the batch to self.device.
        X_data, u_data = batch

        # 1. Data misfit on the observation batch
        loss_data = self._data_loss(X_data, u_data)

        # 2. Physics residual on collocation points
        loss_pde = self._pde_loss()

        loss = loss_data + self.lambda_pde * loss_pde

        self.log("loss", loss, prog_bar=True)
        self.log("loss_data", loss_data)
        self.log("loss_pde", loss_pde)
        # Log the current parameter estimates as plain values (no graph).
        self.log("D", self.net.D.detach())
        self.log("r", self.net.r.detach())

        return loss



In [None]:
# ----- Load the Fisher–KPP observations -----
# Point csv_path at your file if it is stored elsewhere.
csv_path = "data/kpp_training_data.csv"
dataset = KPPDataset(csv_path)

summary_lines = [
    "Dataset created.",
    f"t in [{dataset.t_min:.4f}, {dataset.t_max:.4f}]",
    f"x in [{dataset.x_min:.4f}, {dataset.x_max:.4f}]",
    f"Number of samples: {len(dataset)}",
]
print("\n".join(summary_lines))


Dataset created.
t in [0.0000, 5.0000]
x in [0.0000, 10.0000]
Number of samples: 2000


In [None]:
# Training hyper-parameters
SEED = 0
batch_size = 64
n_f = 5000          # collocation points, resampled on every training step
lambda_pde = 1.0
lr = 1e-3
max_epochs = 5000   # you can reduce for quick tests

# Seed Python, NumPy (used for collocation sampling) and torch so that
# weight init, shuffling and collocation draws are reproducible.
pl.seed_everything(SEED)

# DataLoader for observation data
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Lightning model
model = FisherKPP_PINN(
    dataset=dataset,
    n_f=n_f,
    lambda_pde=lambda_pde,
    lr=lr,
)

# Trainer
trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=max_epochs,
    log_every_n_steps=50,
)

trainer.fit(model=model, train_dataloaders=train_loader)



TypeError: FisherKPPNet.__init__() got an unexpected keyword argument 'lb'

In [None]:
print("Training finished.")
print(f"Estimated D: {model.net.D.item():.4f}")
print(f"Estimated r: {model.net.r.item():.4f}")

# Quick visualization of the learned solution on a regular (t, x) grid.
Nt_plot, Nx_plot = 100, 100
t_plot = np.linspace(dataset.t_min, dataset.t_max, Nt_plot, dtype=np.float32)
x_plot = np.linspace(dataset.x_min, dataset.x_max, Nx_plot, dtype=np.float32)
T_grid, X_grid = np.meshgrid(t_plot, x_plot, indexing="ij")  # rows = t, cols = x

t_flat = torch.from_numpy(T_grid.reshape(-1, 1)).to(model.device)
x_flat = torch.from_numpy(X_grid.reshape(-1, 1)).to(model.device)

# Call the inner network directly. Its signature is unambiguously (x, t),
# so the wrapper's argument order cannot swap the coordinates.
model.eval()
with torch.no_grad():
    preds = model.net(x_flat, t_flat)["u"].cpu().numpy().reshape(Nt_plot, Nx_plot)

fig, ax = plt.subplots(figsize=(7, 4))
im = ax.imshow(
    preds,
    extent=[dataset.x_min, dataset.x_max, dataset.t_max, dataset.t_min],
    aspect="auto",
    origin="upper",
)
fig.colorbar(im, ax=ax, label="u_pred(t, x)")
ax.set_xlabel("x")
ax.set_ylabel("t")
ax.set_title("Fisher–KPP PINN prediction")
fig.tight_layout()
plt.show()
