# SIR model: inverse problem
## A PINN approach

In this notebook, we solve the inverse problem for the SIR model using a Physics-Informed Neural Network (PINN). The goal is to estimate the infection rate $\beta$ from observations of the infected population. To do this, we train a PINN whose loss combines three terms that are minimized simultaneously: the residuals of the ODE system, the mismatch with the initial conditions, and the misfit to the observed data.

The SIR model is governed by the following set of ordinary differential equations (ODEs):

$$
\begin{cases}
\frac{dS}{dt} &= -\frac{\beta}{N} I S, \\
\frac{dI}{dt} &= \frac{\beta}{N} I S - \delta I, \\
\frac{dR}{dt} &= \delta I,
\end{cases}
$$

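where $S$, $I$, and $R$ denote the susceptible, infected, and recovered populations, $N$ is the total population, $\beta$ is the infection rate, and $\delta$ is the recovery rate. The system is considered for $t \in [0, 90]$ with the initial conditions $S(0) = N - 1$, $I(0) = 1$, and $R(0) = 0$.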

## Configuration

In [None]:
import json
import os

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn

# Output folders (created on first run, reused afterwards)
figures_dir, data_dir = "figures", "data"
os.makedirs(figures_dir, exist_ok=True)
os.makedirs(data_dir, exist_ok=True)

# Experiment configuration; key order is kept because the dictionary is
# later serialized as-is.
config = dict(
  # --- Data: time interval ---
  t_start=0,
  t_end=90,

  # --- SIR model ---
  population=1.0,      # normalized population
  delta=0.2,           # recovery rate (= 1 / 5)
  initial_beta=0.5,    # starting guess for the beta parameter

  # --- Network architecture: 1 input, four hidden layers of 50, 1 output ---
  layers=[1, *([50] * 4), 1],
  activation="tanh",
  output_activation="square",

  # --- Training ---
  learning_rate=1e-3,
  batch_size=100,
  dataset_size=6000,
  max_epochs=5000,

  # --- Early stopping ---
  patience=200,
  min_delta=1e-5,

  # --- Loss weights ---
  weight_pde=1.0,
  weight_ic=1.0,
  weight_data=1.0,

  # --- Optimizer ---
  clip_grad_norm=1.0,

  # --- Learning-rate scheduler ---
  scheduler_patience=100,
  scheduler_factor=0.5,
  min_lr=1e-6,

  # --- Logging ---
  log_interval=100,
)


## Utility functions

In [None]:
class Square(nn.Module):
  """Activation layer computing the element-wise square of its input."""

  @staticmethod
  def forward(x):
    # Stateless, so a static method suffices; output has the input's shape.
    return x.square()


def get_activation(name):
  """Return a new activation module for the given name.

  Recognized names: "tanh", "relu", "sigmoid", "square", "softplus" and
  "none" (identity). Any other name raises ValueError.
  """
  # Ordered (name, factory) pairs; factories are only invoked on a match,
  # so every call yields a fresh module instance.
  registry = (
    ("tanh", nn.Tanh),
    ("relu", nn.ReLU),
    ("sigmoid", nn.Sigmoid),
    ("square", lambda: Square()),
    ("softplus", nn.Softplus),
    ("none", nn.Identity),
  )
  for key, factory in registry:
    if name == key:
      return factory()
  raise ValueError(f"Activation function {name} not recognized")


def create_fnn(layers_dimensions, activation, output_activation):
  """Build a fully connected network as an ``nn.Sequential``.

  Consecutive entries of ``layers_dimensions`` give the input/output width
  of each linear layer. ``activation`` is inserted after every linear layer
  except the last one; ``output_activation`` (if not None) is appended at
  the end. Linear weights use Xavier-normal initialization, biases are zero.
  """
  width_pairs = list(zip(layers_dimensions[:-1], layers_dimensions[1:]))
  last_index = len(width_pairs) - 1

  stack = []
  for position, (fan_in, fan_out) in enumerate(width_pairs):
    stack.append(nn.Linear(fan_in, fan_out))
    if position != last_index:
      stack.append(activation)

  if output_activation is not None:
    stack.append(output_activation)

  net = nn.Sequential(*stack)

  # Re-initialize only after all layers exist, one linear layer at a time
  for layer in net:
    if not isinstance(layer, nn.Linear):
      continue
    nn.init.xavier_normal_(layer.weight)
    nn.init.zeros_(layer.bias)

  return net


## Model definition

In [None]:
class SIRPINN(pl.LightningModule):
  """Physics-informed neural network for the inverse SIR problem.

  Two feedforward networks approximate S(t) and I(t); R(t) is obtained from
  the conservation relation R = N - S - I. The infection rate ``beta`` is a
  trainable parameter that is optimized together with the network weights
  by minimizing a weighted sum of ODE-residual, initial-condition and data
  losses.

  Args:
    config: Dictionary of hyperparameters. Keys read by this class:
      "layers", "activation", "output_activation", "initial_beta", "delta",
      "population", "learning_rate", "scheduler_factor",
      "scheduler_patience", "min_lr", "weight_pde", "weight_ic",
      "weight_data", "log_interval" and "max_epochs".
  """

  def __init__(self, config):
    super().__init__()
    self.save_hyperparameters(config)
    self.config = config

    # Get activation functions
    activation = get_activation(config["activation"])
    output_activation = get_activation(config["output_activation"])

    # One network per compartment; R is derived in forward()
    self.net_S = create_fnn(config["layers"], activation, output_activation)
    self.net_I = create_fnn(config["layers"], activation, output_activation)

    # Model parameters: beta is learned, delta and N are fixed
    self.beta = nn.Parameter(
      torch.tensor(config["initial_beta"], dtype=torch.float32)
    )
    self.delta = config["delta"]
    self.N = config["population"]

    # Histories, appended to once per validation step
    self.l_pde_history = []
    self.l_ic_history = []
    self.l_data_history = []
    self.l_total_history = []
    self.beta_evolution = []
    self.param_changes = []
    self.re_s_history = []
    self.re_i_history = []

    # Snapshot of the parameters at the previous compute_param_change() call
    self.old_params = None

  def forward(self, x):
    """Return ``[S, I, R]`` concatenated along dim 1 for time points ``x``.

    R is not predicted by a network; it is computed as ``N - S - I``.
    """
    S = self.net_S(x)
    I = self.net_I(x)
    R = self.N - S - I
    return torch.cat([S, I, R], dim=1)

  def compute_pde_residuals(self, x):
    """Compute the residuals of the S and I equations at time points ``x``.

    The residuals follow the system
    ``dS/dt = -(beta/N) I S`` and ``dI/dt = (beta/N) I S - delta I``.

    Args:
      x: Tensor of time points fed to the networks. It is not modified; a
        detached copy is differentiated instead.

    Returns:
      Tuple ``(res_S, res_I)`` of residual tensors.
    """
    # Validation runs with gradients disabled by the trainer, but the time
    # derivatives need autograd. Re-enable it locally and work on a fresh
    # copy so the caller's tensor is never mutated.
    with torch.inference_mode(False), torch.enable_grad():
      t = x.detach().clone().requires_grad_(True)
      S = self.net_S(t)
      I = self.net_I(t)

      dS_dt = torch.autograd.grad(
        S, t, grad_outputs=torch.ones_like(S), create_graph=True
      )[0]
      dI_dt = torch.autograd.grad(
        I, t, grad_outputs=torch.ones_like(I), create_graph=True
      )[0]

      # Divide by N so the residual matches the ODE system for any
      # population size (a no-op for the normalized case N == 1).
      infection = self.beta / self.N * S * I
      res_S = dS_dt + infection
      res_I = dI_dt - infection + self.delta * I
    return res_S, res_I

  def loss_pde(self, x):
    """Return the sum of the mean squared S and I residuals at ``x``."""
    res_S, res_I = self.compute_pde_residuals(x)
    return torch.mean(res_S ** 2) + torch.mean(res_I ** 2)

  def loss_ic(self, t0_tensor, ic_tensor):
    """Return the mean squared error between ``self(t0_tensor)`` and ``ic_tensor``."""
    ic_pred = self(t0_tensor)
    return torch.mean((ic_pred - ic_tensor) ** 2)

  def loss_data(self, t_obs_tensor, i_obs_tensor):
    """Return the mean squared error between predicted and observed I."""
    I = self(t_obs_tensor)[:, 1].reshape(-1, 1)
    # Force a column so a 1-D observation vector cannot broadcast the
    # difference to an (n, n) matrix.
    return torch.mean((I - i_obs_tensor.reshape(-1, 1)) ** 2)

  def compute_param_change(self):
    """Return the L2 norm of the parameter change since the previous call.

    The first call only stores a snapshot and returns 0.0.
    """
    new_params = [p.clone().detach() for p in self.parameters()]
    if self.old_params is None:
      self.old_params = new_params
      return 0.0

    param_diff_squared_sum = sum(
      torch.sum((old_p - new_p) ** 2).item()
      for old_p, new_p in zip(self.old_params, new_params)
    )

    self.old_params = new_params
    return float(np.sqrt(param_diff_squared_sum))

  def configure_optimizers(self):
    """Configure Adam and a ReduceLROnPlateau scheduler monitoring ``val_loss``."""
    optimizer = torch.optim.Adam(
      self.parameters(),
      lr=self.config["learning_rate"]
    )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
      optimizer,
      mode='min',
      factor=self.config["scheduler_factor"],
      patience=self.config["scheduler_patience"],
      min_lr=self.config["min_lr"]
    )

    return {
      "optimizer":    optimizer,
      "lr_scheduler": {
        "scheduler": scheduler,
        "monitor":   "val_loss",
        "interval":  "epoch",
        "frequency": 1
      }
    }

  def training_step(self, batch, batch_idx):
    """Compute and log the weighted training loss for one batch.

    ``batch`` unpacks to ``(t_batch, is_obs, t0_tensor, ic_tensor,
    t_obs_indices)``. ``is_obs`` masks the entries of ``t_batch`` that have
    an observation; ``t_obs_indices`` indexes ``datamodule.i_obs``.
    """
    t_batch, is_obs, t0_tensor, ic_tensor, t_obs_indices = batch

    # Compute PDE loss
    L_pde = self.loss_pde(t_batch)

    # Compute initial condition loss
    L_ic = self.loss_ic(t0_tensor, ic_tensor)

    # Data loss only if the batch contains observation points
    L_data = torch.tensor(0.0, device=self.device)
    if is_obs.any():
      obs_points = t_batch[is_obs]
      predicted_I = self(obs_points)[:, 1].reshape(-1, 1)

      # Get corresponding observation values
      obs_indices = t_obs_indices[is_obs]
      true_I = self.trainer.datamodule.i_obs[obs_indices].to(self.device)

      # Same column shape as the prediction, to rule out silent
      # (k, 1) - (k,) broadcasting.
      L_data = torch.mean((predicted_I - true_I.reshape(-1, 1)) ** 2)

    # Compute weighted total loss
    loss = (self.config["weight_pde"] * L_pde +
            self.config["weight_ic"] * L_ic +
            self.config["weight_data"] * L_data)

    # Log losses
    self.log('train_loss_pde', L_pde, prog_bar=False)
    self.log('train_loss_ic', L_ic, prog_bar=False)
    self.log('train_loss_data', L_data, prog_bar=False)
    self.log('train_loss', loss, prog_bar=True)
    self.log('beta', self.beta.item(), prog_bar=True)

    return loss

  def validation_step(self, batch, batch_idx):
    """Compute, record and log the validation losses for one batch.

    ``batch`` unpacks to ``(t_all, t0_tensor, ic_tensor, t_obs_tensor,
    i_obs_tensor)``. If the datamodule has a ``true_data`` attribute, the
    relative L2 errors of S and I are recorded as well.
    """
    t_all, t0_tensor, ic_tensor, t_obs_tensor, i_obs_tensor = batch

    # Compute PDE loss on all validation points
    L_pde = self.loss_pde(t_all)

    # Compute initial condition loss
    L_ic = self.loss_ic(t0_tensor, ic_tensor)

    # Compute data loss on all observation points
    L_data = self.loss_data(t_obs_tensor, i_obs_tensor)

    # NOTE: unweighted sum, unlike the weighted loss in training_step
    val_loss = L_pde + L_ic + L_data

    # Store losses for history
    self.l_pde_history.append(L_pde.item())
    self.l_ic_history.append(L_ic.item())
    self.l_data_history.append(L_data.item())
    self.l_total_history.append(val_loss.item())

    # Store beta and compute parameter changes
    self.beta_evolution.append(self.beta.item())
    param_change = self.compute_param_change()
    self.param_changes.append(param_change)

    # Compute relative errors if true data is available
    datamodule = self.trainer.datamodule
    if hasattr(datamodule, 'true_data'):
      # The networks need a tensor input; convert to NumPy only afterwards.
      sir_pred = self(t_obs_tensor).detach().cpu().numpy()
      s_pred, i_pred = sir_pred[:, 0], sir_pred[:, 1]

      # Flattened so a column-shaped reference cannot broadcast against the
      # 1-D predictions. Assumes true_data is sampled at the same time
      # points as t_obs_tensor -- TODO confirm against the datamodule.
      s_true = np.asarray(datamodule.true_data['s_true']).reshape(-1)
      i_true = np.asarray(datamodule.true_data['i_true']).reshape(-1)

      re_s = float(
        np.linalg.norm(s_true - s_pred, 2) / np.linalg.norm(s_true, 2)
      )
      re_i = float(
        np.linalg.norm(i_true - i_pred, 2) / np.linalg.norm(i_true, 2)
      )

      self.re_s_history.append(re_s)
      self.re_i_history.append(re_i)

      self.log('val_re_s', re_s, prog_bar=False)
      self.log('val_re_i', re_i, prog_bar=False)

    # Log validation metrics
    self.log('val_loss_pde', L_pde, prog_bar=False)
    self.log('val_loss_ic', L_ic, prog_bar=False)
    self.log('val_loss_data', L_data, prog_bar=False)
    self.log('val_loss', val_loss, prog_bar=True)
    self.log('param_change', param_change, prog_bar=False)

    return val_loss

  def on_validation_epoch_end(self):
    """Print a progress line every ``log_interval`` epochs."""
    epoch = self.current_epoch + 1
    if epoch % self.config["log_interval"] != 0:
      return

    max_epochs = self.config['max_epochs']
    width = len(str(max_epochs))  # right-align the epoch counter
    current_lr = self.optimizers().param_groups[0]['lr']
    print(
      f"[{epoch:>{width}}/{max_epochs}]: "
      f"β = {self.beta.item():.4f} "
      f"| Loss = {self.l_total_history[-1]:.2e} "
      f"| Param Δ = {self.param_changes[-1]:.2e} "
      f"| lr = {current_lr:.2e}"
    )

  def on_train_end(self):
    """Write loss history, parameter evolution and config to ``data_dir``.

    Three timestamped files are written: ``loss_history_*.csv``,
    ``evolutions_*.csv`` and ``config_*.json``.
    """
    timestamp = pd.Timestamp.now().strftime("%Y%m%d%H%M%S")

    # Create DataFrames
    loss_df = pd.DataFrame(
      {
        'L_pde':   self.l_pde_history,
        'L_ic':    self.l_ic_history,
        'L_data':  self.l_data_history,
        'L_total': self.l_total_history
      }
    )

    # The relative-error histories stay empty when the datamodule has no
    # true_data; wrapping each column in a Series pads shorter columns with
    # NaN instead of raising on unequal lengths.
    evolution_columns = {
      'beta':   self.beta_evolution,
      'params': self.param_changes,
      're_s':   self.re_s_history,
      're_i':   self.re_i_history
    }
    evolution_df = pd.DataFrame(
      {
        name: pd.Series(values, dtype="float64")
        for name, values in evolution_columns.items()
      }
    )

    # Save to CSV
    loss_df.to_csv(
      f"{data_dir}/loss_history_{timestamp}.csv",
      float_format='%.6e'
    )
    evolution_df.to_csv(
      f"{data_dir}/evolutions_{timestamp}.csv",
      float_format='%.6e'
    )

    # Save config; NumPy scalars of any kind are converted to plain Python
    # numbers because json cannot serialize them.
    serializable_config = {
      k: v.item() if isinstance(v, np.generic) else v
      for k, v in self.config.items()
    }
    with open(
      f"{data_dir}/config_{timestamp}.json", "w", encoding="utf-8"
    ) as f:
      json.dump(serializable_config, f, indent=2)
