# SEIR Model, LHS Sampling, and Surrogate Training with LSTM & GRU

In this notebook we simulate an age‐structured SEIR model for infectious disease dynamics, explore the parameter space using Latin Hypercube Sampling (LHS), and train surrogate models using Recurrent Neural Networks (RNNs) – specifically, Long Short-Term Memory (LSTM) and Gated Recurrent Unit (GRU) architectures.

The notebook is divided into several sections:

1. **Infectious Disease Modelling Theory:** An introduction to compartmental models such as SEIR.
2. **Surrogate Models:** Why we build surrogates and how they enable rapid approximation of expensive simulations.
3. **RNNs for Time Series Prediction:** An overview of RNNs with a focus on LSTM and GRU cells, including the relevant mathematics.
4. **Simulation, Parameter Sampling and Surrogate Training:** Code sections for the SEIR model, LHS sampling, and training of surrogate models.

Let’s get started!

## 1. Infectious Disease (ID) Modelling Theory

Infectious disease modelling is a key tool for understanding and predicting the spread of diseases in populations. One common approach is to use **compartmental models**. In these models, the population is divided into distinct groups (or compartments) based on disease status. 

### The SEIR Model

The SEIR model is one of the most widely used compartmental models. It divides the population into:

- **S (Susceptible):** Individuals who can contract the disease.
- **E (Exposed):** Individuals who have been infected but are not yet infectious.
- **I (Infectious):** Individuals who can transmit the disease to others.
- **R (Recovered):** Individuals who have recovered and may have immunity.

An additional compartment, **C (Cumulative incidence)**, is sometimes added to keep track of the total number of new infections. 
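
As a minimal sketch of how such a cumulative compartment is used (the numbers below are made up for illustration), differencing consecutive daily values of C recovers the number of new infections per day:

```python
import numpy as np

# Hypothetical cumulative incidence C, sampled once per day (illustrative values only).
cumulative = np.array([0.0, 3.0, 7.0, 12.0, 20.0])

# Daily incidence is the day-to-day increase in C. The first day has no
# previous value to difference against, so it is set to zero.
daily_incidence = np.concatenate([[0.0], np.diff(cumulative)])

print(daily_incidence)  # [0. 3. 4. 5. 8.]
```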

The dynamics of these compartments are usually described by a system of differential equations. In our model, we further **stratify by age** (17 age groups) to capture heterogeneous mixing patterns.
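
Written out for age group $i$, the system implemented by the `deriv` method of the `SEIR` class in this notebook can be summarised as below. The notation is introduced here for illustration: $M$ is the contact matrix, $N_j$ the population of age group $j$, $R_t$ the time-varying reproduction number, and $r(M)$ the largest absolute eigenvalue of $M$. Note that $\sigma$ here is a rate, not the sigmoid function used in the RNN equations.

$$\begin{aligned}
\lambda_i(t) &= \beta(t) \sum_{j=1}^{17} M_{ij} \frac{I_j}{N_j}, \qquad \beta(t) = \frac{R_t \, \gamma}{r(M)} \\
\frac{dS_i}{dt} &= -\lambda_i S_i + \rho R_i \\
\frac{dE_i}{dt} &= \lambda_i S_i - \sigma E_i \\
\frac{dI_i}{dt} &= \sigma E_i - \gamma I_i \\
\frac{dR_i}{dt} &= \gamma I_i - \rho R_i \\
\frac{dC_i}{dt} &= \sigma E_i
\end{aligned}$$

Here $\sigma$, $\gamma$, and $\rho$ are the reciprocals of the latent, infectious, and immunity periods respectively.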

## 2. Surrogate Models and RNNs

### Why Surrogate Models?

High-fidelity simulations (such as our SEIR model) can be computationally expensive, especially when exploring large parameter spaces. **Surrogate models** are fast, approximate models that are trained to emulate the output of these expensive simulations. They allow rapid predictions, uncertainty quantification, and real-time decision support.

### Recurrent Neural Networks (RNNs)

RNNs are designed for sequential data (e.g., time series) by maintaining a hidden state that carries information from previous time steps. Two popular RNN architectures are **LSTM (Long Short-Term Memory)** and **GRU (Gated Recurrent Unit)**. 

#### LSTM Equations

The LSTM cell updates are given by:

$$\begin{aligned}
i_t &= \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i) \\
f_t &= \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f) \\
o_t &= \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o) \\
\tilde{c}_t &= \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
h_t &= o_t \odot \tanh(c_t) 
\end{aligned}$$

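Here, $i_t$, $f_t$, and $o_t$ are the input, forget, and output gates respectively, $\tilde{c}_t$ is the candidate cell state, $c_t$ is the cell state, and $h_t$ is the hidden state. $\sigma$ denotes the logistic sigmoid and $\odot$ the element-wise product.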
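
The LSTM update can be sketched directly in NumPy. This is an illustrative transcription of the equations above, not the `nn.LSTM` implementation used later; the weight dictionaries, the helper name `lstm_step`, and the layer sizes are assumptions made for the example.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def lstm_step(x_t, h_prev, c_prev, W_x, W_h, b):
    """One LSTM update; W_x, W_h, b are dicts keyed by 'i', 'f', 'o', 'c'."""
    i_t = sigmoid(W_x["i"] @ x_t + W_h["i"] @ h_prev + b["i"])      # input gate
    f_t = sigmoid(W_x["f"] @ x_t + W_h["f"] @ h_prev + b["f"])      # forget gate
    o_t = sigmoid(W_x["o"] @ x_t + W_h["o"] @ h_prev + b["o"])      # output gate
    c_tilde = np.tanh(W_x["c"] @ x_t + W_h["c"] @ h_prev + b["c"])  # candidate cell state
    c_t = f_t * c_prev + i_t * c_tilde
    h_t = o_t * np.tanh(c_t)
    return h_t, c_t

rng = np.random.default_rng(0)
n_in, n_hid = 10, 4  # illustrative sizes
W_x = {k: rng.normal(scale=0.1, size=(n_hid, n_in)) for k in "ifoc"}
W_h = {k: rng.normal(scale=0.1, size=(n_hid, n_hid)) for k in "ifoc"}
b = {k: np.zeros(n_hid) for k in "ifoc"}

# Run the cell over a short random sequence, carrying (h, c) forward.
h, c = np.zeros(n_hid), np.zeros(n_hid)
for x_t in rng.normal(size=(5, n_in)):
    h, c = lstm_step(x_t, h, c, W_x, W_h, b)

print(h.shape, c.shape)  # (4,) (4,)
```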

#### GRU Equations

The GRU cell updates are somewhat simpler:

$$\begin{aligned}
z_t &= \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z) \\
r_t &= \sigma(W_{xr} x_t + W_{hr} h_{t-1} + b_r) \\
\tilde{h}_t &= \tanh(W_{xh} x_t + r_t \odot (W_{hh} h_{t-1}) + b_h) \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
\end{aligned}$$

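Here, $z_t$ is the update gate, $r_t$ is the reset gate, and $\tilde{h}_t$ is the candidate hidden state. In both cells, the gates let the network carry information across many time steps and selectively forget or update it.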
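
A NumPy sketch of the GRU update, following the equations above term by term, shows how the update gate controls what is kept. The helper name `gru_step`, the weight dictionaries, and the sizes are assumptions for illustration; this is not the `nn.GRU` implementation used later.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(x_t, h_prev, W_x, W_h, b):
    """One GRU update; W_x, W_h, b are dicts keyed by 'z', 'r', 'h'."""
    z_t = sigmoid(W_x["z"] @ x_t + W_h["z"] @ h_prev + b["z"])  # update gate
    r_t = sigmoid(W_x["r"] @ x_t + W_h["r"] @ h_prev + b["r"])  # reset gate
    h_tilde = np.tanh(W_x["h"] @ x_t + r_t * (W_h["h"] @ h_prev) + b["h"])
    return (1.0 - z_t) * h_prev + z_t * h_tilde

rng = np.random.default_rng(0)
n_in, n_hid = 10, 4  # illustrative sizes
W_x = {k: rng.normal(scale=0.1, size=(n_hid, n_in)) for k in "zrh"}
W_h = {k: rng.normal(scale=0.1, size=(n_hid, n_hid)) for k in "zrh"}
b = {k: np.zeros(n_hid) for k in "zrh"}

x_t = rng.normal(size=n_in)
h_prev = rng.uniform(-1.0, 1.0, size=n_hid)
h_new = gru_step(x_t, h_prev, W_x, W_h, b)

# A strongly negative update-gate bias drives z_t towards 0, so the old state is kept.
b_keep = dict(b, z=np.full(n_hid, -50.0))
h_kept = gru_step(x_t, h_prev, W_x, W_h, b_keep)
print(np.allclose(h_kept, h_prev))  # True
```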

### What Does the Surrogate Do?

In our setting, the surrogate model (using either LSTM or GRU) takes as input a sequence that includes:

- A normalized time feature.
- A set of simulation parameters (e.g. latent period, infectious period, etc.) repeated at each time step.

It then predicts the time series output (e.g., total incidence over time) that would be produced by the full SEIR simulation. By training on many simulation runs, the surrogate learns to approximate the mapping from parameters to outcomes very quickly.
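
As a rough sketch of what one such input sequence looks like, the snippet below builds the `(T, 10)` array for a single run in the same way as the dataset class defined later in this notebook. The parameter values are hypothetical placeholders.

```python
import numpy as np

T = 730  # number of daily time steps (the simulations in this notebook use t_end=730)

# Hypothetical parameter vector: latent period, infectious period, immunity period,
# three Rt values, and three switching times.
params = np.array([3.0, 5.0, 547.5, 3.0, 0.9, 1.8, 60.0, 120.0, 180.0], dtype=np.float32)

time_feature = np.linspace(0.0, 1.0, T, dtype=np.float32).reshape(T, 1)
params_repeated = np.tile(params, (T, 1))  # same 9 values at every time step

X = np.concatenate([time_feature, params_repeated], axis=1)
print(X.shape)  # (730, 10)
```

Only the first column changes along the sequence, so the network has to infer the shape of the epidemic curve from the time feature combined with the constant parameter columns.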

In [None]:
# -------------------------------------------------------- #
# 1. Load in packages
# -------------------------------------------------------- #

import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
from numpy.linalg import eigvals
import pandas as pd

# Use inline plotting for Jupyter
%matplotlib inline

In [None]:
# -------------------------------------------------------- #
# 2. SEIR Model Definition
# -------------------------------------------------------- #

class SEIR:
    """
    A class for simulating a deterministic age‐structured SEIR model with 17 age groups.
    
    Compartments per age group:
      - S: Susceptible
      - E: Exposed (infected but not yet infectious)
      - I: Infectious
      - R: Recovered
      - C: Cumulative incidence (number of new infections entering I)
    """

    def __init__(
        self,
        contact_matrix,
        latent_period,
        infectious_period,
        immunity_period, 
        N,
        I0,
        seed=None,
    ):
        """
        Initialise the model with user‐provided parameters.
        """
        self.n_age = 17
        self.contact_matrix = np.array(contact_matrix)

        self.latent_period = latent_period
        self.infectious_period = infectious_period
        self.immunity_period = immunity_period
        self.sigma = 1.0 / latent_period      # rate of leaving the exposed compartment
        self.gamma = 1.0 / infectious_period  # recovery rate
        self.rho = 1.0 / immunity_period      # rate of loss of immunity

        N = np.array(N)
        if N.shape[0] != self.n_age:
            raise ValueError(f"N must be of length {self.n_age}.")

        # Allocate I0 infections among age groups 20-60 (indices 4 to 11)
        eligible_indices = np.arange(4, 12)
        I_init = np.zeros(self.n_age, dtype=int)
        rng = np.random.default_rng(seed)
        if I0 > 0:
            allocation = rng.multinomial(I0, np.full(len(eligible_indices), 1/len(eligible_indices)))
            I_init[eligible_indices] = allocation

        S_init = N - I_init
        if np.any(S_init < 0):
            raise ValueError("Allocated initial infections exceed population in one or more groups.")

        E_init = np.zeros(self.n_age, dtype=int)
        R_init = np.zeros(self.n_age, dtype=int)

        self.initial_conditions = {
            'S': S_init.astype(float),
            'E': E_init.astype(float),
            'I': I_init.astype(float),
            'R': R_init.astype(float)
        }

        self.pop_age = S_init + E_init + I_init + R_init

        self.initial_C = np.zeros(self.n_age)
        self.y0 = np.concatenate([
            self.initial_conditions['S'],
            self.initial_conditions['E'],
            self.initial_conditions['I'],
            self.initial_conditions['R'],
            self.initial_C
        ])

        eigenvalues = eigvals(self.contact_matrix)
        self.denom = np.max(np.abs(eigenvalues))

        self.results = None

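    def get_current_Rt(self, t):
        """
        Return the piecewise-constant reproduction number at time t.

        Rt[0] applies until tt_Rt[1], Rt[1] until tt_Rt[2], and Rt[2] afterwards.
        Note that tt_Rt[0] is never read: the first phase always starts at t = 0.
        """
        if t < self.tt_Rt[1]:
            return self.Rt[0]
        elif t < self.tt_Rt[2]:
            return self.Rt[1]
        else:
            return self.Rt[2]

    def beta_t(self, t):
        """
        Convert Rt into a transmission rate by dividing by the infectious period
        and the largest absolute eigenvalue of the contact matrix.
        """
        current_Rt = self.get_current_Rt(t)
        return current_Rt / (self.infectious_period * self.denom)

    def deriv(self, t, y):
        """Right-hand side of the ODE system for the stacked state [S, E, I, R, C]."""
        # Unpack the flat state vector into per-age-group compartments
        S = y[0:self.n_age]
        E = y[self.n_age:2*self.n_age]
        I = y[2*self.n_age:3*self.n_age]
        R = y[3*self.n_age:4*self.n_age]
        C = y[4*self.n_age:5*self.n_age]  # not used in the derivatives; unpacked for clarity

        # Force of infection on each age group
        beta = self.beta_t(t)
        lambda_vec = beta * np.dot(self.contact_matrix, I / self.pop_age)

        dS = -lambda_vec * S + self.rho * R
        dE = lambda_vec * S - self.sigma * E
        dI = self.sigma * E - self.gamma * I
        dR = self.gamma * I - self.rho * R
        dC = self.sigma * E  # C accumulates the flow from E into I

        return np.concatenate([dS, dE, dI, dR, dC])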

    def run(self, t_end, Rt, tt_Rt, dt=0.1):
        self.Rt = np.array(Rt)
        self.tt_Rt = np.array(tt_Rt)
        if len(self.Rt) != 3 or len(self.tt_Rt) != 3:
            raise ValueError("Rt and tt_Rt must be of length 3.")

        t_eval = np.arange(0, t_end, dt)
        sol = solve_ivp(self.deriv, [0, t_end], self.y0, t_eval=t_eval, vectorized=False)
        if not sol.success:
            raise RuntimeError("Integration failed.")
        self.results = sol
        return sol

    def get_output(self):
        if self.results is None:
            raise RuntimeError("No simulation results available. Run the simulation first.")

        t_all = self.results.t
        int_days = np.floor(t_all).astype(int)
        unique_days = np.unique(int_days)
        indices = []
        for day in unique_days:
            idx_day = np.where(int_days == day)[0]
            indices.append(idx_day[-1])
        indices = np.array(indices)

        S_int = self.results.y[0:self.n_age, :][:, indices].T
        E_int = self.results.y[self.n_age:2*self.n_age, :][:, indices].T
        I_int = self.results.y[2*self.n_age:3*self.n_age, :][:, indices].T
        R_int = self.results.y[3*self.n_age:4*self.n_age, :][:, indices].T
        C_int = self.results.y[4*self.n_age:5*self.n_age, :][:, indices].T

        incidence_age = np.vstack([np.zeros((1, self.n_age)), np.diff(C_int, axis=0)])
        total_incidence = incidence_age.sum(axis=1)

        output = {
            'time': unique_days,
            'S': S_int,
            'E': E_int,
            'I': I_int,
            'R': R_int,
            'C': C_int,
            'incidence_age': incidence_age,
            'total_incidence': total_incidence
        }
        return output

    def plot_output(self):
        output = self.get_output()
        t = output['time']
        S_total = output['S'].sum(axis=1)
        E_total = output['E'].sum(axis=1)
        I_total = output['I'].sum(axis=1)
        R_total = output['R'].sum(axis=1)
        total_incidence = output['total_incidence']

        fig, axs = plt.subplots(2, 1, figsize=(10, 8))

        axs[0].plot(t, S_total, label='Susceptible')
        axs[0].plot(t, E_total, label='Exposed')
        axs[0].plot(t, I_total, label='Infectious')
        axs[0].plot(t, R_total, label='Recovered')
        axs[0].set_xlabel('Time (days)')
        axs[0].set_ylabel('Number of individuals')
        axs[0].set_title('SEIR Model Compartments (Aggregated by Day)')
        axs[0].legend()

        axs[1].plot(t, total_incidence, label='Incidence', color='red')
        axs[1].set_xlabel('Time (days)')
        axs[1].set_ylabel('New infections per day')
        axs[1].set_title('Daily Incidence of Infections')
        axs[1].legend()

        plt.tight_layout()
        plt.show()

## 3. Demo: Running the SEIR Model

Below we load in the contact matrix and age distribution data (make sure the files `data/seir_contact_matrix.csv` and `data/sa_ages.csv` are available) and run the SEIR simulation. We then plot the aggregated compartments and compare the simulation output with external data.

In [None]:
# -------------------------------------------------------- #
# 3. Demo SEIR
# -------------------------------------------------------- #

n_age = 17

# Read in contact matrix and age distribution
contact_matrix = pd.read_csv("data/seir_contact_matrix.csv").values
N = pd.read_csv("data/sa_ages.csv").n.values

latent_period = 5.0
infectious_period = 7.0
immunity_period = 365

I0 = 20

Rt = [3, 0.9, 1.8]
tt_Rt = [0, 90, 180]

model = SEIR(contact_matrix, latent_period, infectious_period, immunity_period, N, I0, seed=42)

model.run(t_end=730, Rt=Rt, tt_Rt=tt_Rt, dt=0.1)

model.plot_output()

dat = pd.read_csv(
    "https://raw.githubusercontent.com/mrc-ide/global-lmic-reports/refs/heads/main/ZAF/2022-06-20/projections.csv"
)
dat["date"] = pd.to_datetime(dat["date"])
dat["t"] = (dat["date"] - dat["date"].min()).dt.days
dat["t"] = dat["t"].astype(int)

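# Compare the simulated daily incidence with the external projections.
# plot_output() ends with plt.show(), so the comparison is drawn on a new figure.
mask = (
    (dat["compartment"] == "infections")
    & (dat["scenario"] == "Maintain Status Quo")
    & (dat["t"] >= 110)
    & (dat["t"] <= 730 + 110)
)
output = model.get_output()

plt.figure(figsize=(10, 4))
plt.plot(output["time"], output["total_incidence"], label="SEIR model", color="red")
plt.plot(dat.loc[mask, "t"] - 110, dat.loc[mask, "y_median"], label="External Data")
plt.xlabel("Time (days)")
plt.ylabel("New infections per day")
plt.legend()
plt.show()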

## 4. Latin Hypercube Sampling (LHS) for Parameter Exploration

Latin Hypercube Sampling (LHS) is a stratified sampling method for exploring a high-dimensional parameter space efficiently: for `n` samples, each parameter's range is split into `n` equal-probability intervals and exactly one sample falls in each interval. In our model, we sample the latent period, infectious period, immunity duration, reproduction numbers, and switching times.

This approach helps us generate diverse simulation runs so that the surrogate model can be trained on a wide range of scenarios.
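
To make the stratification concrete, here is a hand-rolled NumPy sketch of the LHS idea. It is not the `scipy.stats.qmc.LatinHypercube` implementation used in the code below; the function name `latin_hypercube` and the two example ranges are assumptions for illustration.

```python
import numpy as np

def latin_hypercube(n_samples, n_dims, rng):
    """Return an (n_samples, n_dims) LHS design on the unit hypercube."""
    # Draw one random point inside each of the n_samples equal-width strata...
    u = (np.arange(n_samples)[:, None] + rng.random((n_samples, n_dims))) / n_samples
    # ...then shuffle the strata independently in every dimension.
    for d in range(n_dims):
        u[:, d] = rng.permutation(u[:, d])
    return u

rng = np.random.default_rng(0)
unit = latin_hypercube(n_samples=5, n_dims=2, rng=rng)

# Rescale to hypothetical ranges, e.g. latent period in [2, 4] and Rt in [0.5, 4].
lower = np.array([2.0, 0.5])
upper = np.array([4.0, 4.0])
scaled = lower + unit * (upper - lower)

# Each of the 5 equal-width bins holds exactly one sample in every dimension.
bins = np.floor(unit * 5).astype(int)
print(np.sort(bins, axis=0).T)
```

Plain random sampling gives no such guarantee: with only a few samples, some regions of a parameter's range can be left empty.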

In [None]:
# -------------------------------------------------------- #
# 4. LHS
# -------------------------------------------------------- #

from scipy.stats import qmc

def sample_parameter_sets(
    n_samples, latent_range, infectious_range, immunity_range, Rt_range, tt_Rt_range, seed=None
):
    """
    Sample parameter sets using Latin Hypercube Sampling.
    """
    n_params = 9
    sampler = qmc.LatinHypercube(d=n_params, seed=seed)
    sample_unit = sampler.random(n=n_samples)

    latent_samples = latent_range[0] + sample_unit[:, 0] * (latent_range[1] - latent_range[0])
    infectious_samples = infectious_range[0] + sample_unit[:, 1] * (infectious_range[1] - infectious_range[0])
    immunity_samples = immunity_range[0] + sample_unit[:, 2] * (immunity_range[1] - immunity_range[0])

    Rt_samples = []
    for i in range(3):
        col = 3 + i
        Rt_samples.append(Rt_range[0] + sample_unit[:, col] * (Rt_range[1] - Rt_range[0]))
    Rt_samples = np.column_stack(Rt_samples)

    tt_Rt_samples = []
    for i in range(3):
        col = 6 + i
        tt_Rt_samples.append(tt_Rt_range[0] + sample_unit[:, col] * (tt_Rt_range[1] - tt_Rt_range[0]))
    tt_Rt_samples = np.column_stack(tt_Rt_samples)

    parameter_sets = []
    for i in range(n_samples):
        params = {
            "latent_period": latent_samples[i],
            "infectious_period": infectious_samples[i],
            "immunity_period": immunity_samples[i],
            "Rt": list(Rt_samples[i, :]),
            "tt_Rt": sorted(list(tt_Rt_samples[i, :]))
        }
        parameter_sets.append(params)

    return parameter_sets

def sample_and_run_models(
    n_samples,
    contact_matrix,
    N,
    I0,
    t_end,
    latent_range,
    infectious_range,
    immunity_range,
    Rt_range,
    tt_Rt_range,
    dt=0.1,
    seed=None,
):
    """
    Sample parameter sets using LHS and run the SEIR model for each set.
    """
    parameter_sets = sample_parameter_sets(
        n_samples, latent_range, infectious_range, immunity_range, Rt_range, tt_Rt_range, seed=seed
    )

    results = []
    for params in parameter_sets:
        model = SEIR(
            contact_matrix,
            params["latent_period"],
            params["infectious_period"],
            params["immunity_period"],
            N,
            I0,
            seed=seed,
        )
        model.run(t_end=t_end, Rt=params["Rt"], tt_Rt=params["tt_Rt"], dt=dt)
        output = model.get_output()
        results.append((params, output))

    return results

def plot_sampled_results(sampled_results):
    """
    Plot the simulation outputs for multiple sampled runs.
    """
    plt.figure(figsize=(12, 6))

    for i, (params, output) in enumerate(sampled_results):
        t = output["time"]
        total_incidence = output["total_incidence"]
        label_str = (
            f"latent={params['latent_period']:.2f}, "
            f"inf={params['infectious_period']:.2f}, "
            f"imm={params['immunity_period']:.0f}, "
            f"Rt=[{params['Rt'][0]:.2f}, {params['Rt'][1]:.2f}, {params['Rt'][2]:.2f}], "
            f"tt_Rt=[{params['tt_Rt'][0]:.0f}, {params['tt_Rt'][1]:.0f}, {params['tt_Rt'][2]:.0f}]"
        )
        plt.plot(t, total_incidence, label=label_str)

    plt.xlabel("Time (days)")
    plt.ylabel("Daily Total Incidence")
    plt.title("Daily Total Incidence for Sampled Parameter Sets")
    plt.legend(fontsize="small", loc="upper right")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# LHS Demo

latent_range = (2.0, 4.0)
infectious_range = (4.0, 7.0)
immunity_range = (365*1.5, 365*1.5)  # zero-width range: immunity period is held fixed at 1.5 years
Rt_range = (0.5, 4.0)
tt_Rt_range = (50, 200)

n_samples = 5

sampled_results = sample_and_run_models(
    n_samples,
    contact_matrix,
    N,
    I0,
    t_end=730,
    latent_range=latent_range,
    infectious_range=infectious_range,
    immunity_range=immunity_range,
    Rt_range=Rt_range,
    tt_Rt_range=tt_Rt_range,
    dt=0.1,
    seed=42,
)

for i, (params, output) in enumerate(sampled_results):
    print(f"Sample {i+1}:")
    print("Parameters:")
    print(params)
    print("Total incidence on final day:", output["total_incidence"][-1])
    print("-" * 50)

plot_sampled_results(sampled_results)

## 5. Surrogate Training with LSTM and GRU

In the final section we build surrogate models to emulate the output of the SEIR simulation. 

### Dataset and Scaling

We first construct a custom PyTorch dataset that for each simulation run produces an input sequence `X` of shape `(T, 1+9)`:

- The **first column** is a normalized time feature (values between 0 and 1).
- The **remaining columns** are the simulation parameters (latent period, infectious period, immunity period, Rt values, and switching times) repeated over all time steps. 

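The target `Y` is the time series of total incidence (per day). The parameter columns of `X` and the targets are standardized with `StandardScaler` objects fitted on the training runs; the time feature is already in `[0, 1]` and is left as is.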
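
The standardization itself is simple. The sketch below reproduces in plain NumPy what `StandardScaler` does for a single feature (subtract the mean, divide by the population standard deviation) and shows that the transform can be inverted to recover incidence on the original scale. The values are hypothetical.

```python
import numpy as np

# Hypothetical incidence values pooled over all training runs and time steps.
y_train = np.array([0.0, 10.0, 20.0, 30.0, 40.0])

mean = y_train.mean()
std = y_train.std()  # population standard deviation (ddof=0), as StandardScaler uses

y_scaled = (y_train - mean) / std    # what transform() computes
y_restored = y_scaled * std + mean   # what inverse_transform() computes

print(y_scaled.mean(), y_scaled.std())
```

Reusing the training mean and standard deviation for validation data keeps the two sets on the same scale and avoids leaking validation statistics into training.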

### LSTM and GRU Models

We define two architectures:

- **LSTM Model:** Uses LSTM cells as described earlier.
- **GRU Model:** Uses GRU cells.

The surrogate learns to predict the output sequence from the input sequence. During training, the model minimizes the mean-squared error (MSE) between its predictions and the true simulation output.

Below is the code for the dataset definition, model definitions, training loop, and prediction plotting. 

The training process uses mixed precision (via `torch.cuda.amp`) for efficiency and includes learning rate scheduling.
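
To see what the step schedule implies, the sketch below evaluates the closed form of a step decay, `lr = initial_lr * gamma ** (epoch // step_size)`, using the settings from the training script in this notebook (`lr=1e-3`, `step_size=10`, `gamma=0.5`). It assumes the scheduler is stepped once per epoch; `step_lr` is a hypothetical helper, not a PyTorch function.

```python
def step_lr(initial_lr, epoch, step_size, gamma):
    """Closed-form step decay: multiply the rate by gamma every step_size epochs."""
    return initial_lr * gamma ** (epoch // step_size)

# Settings taken from the training script in this notebook.
initial_lr, step_size, gamma = 1e-3, 10, 0.5

for epoch in (0, 9, 10, 50, 100, 299):
    print(f"epoch {epoch:3d}: lr = {step_lr(initial_lr, epoch, step_size, gamma):.3e}")
```

With these settings the learning rate drops below `1e-6` by epoch 100, so the later part of a 300-epoch run changes the weights very little. Whether that is desirable depends on how quickly the loss plateaus.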

In [None]:
# -------------------------------------------------------- #
# 6. DLS Surrogate Training
# -------------------------------------------------------- #

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import time
import random
import math
from sklearn.preprocessing import StandardScaler

##############################################################################
# 0. Dataset Definition with Data Scaling
##############################################################################

class SEIRTimeSeriesDataset(Dataset):
    """
    Expects a list of tuples (params, output). For each sample, creates an input sequence X of shape (T, 10) where:
      - Column 0: normalized time feature (values between 0 and 1)
      - Columns 1-9: 9-dimensional parameter vector repeated at each time step
    The target Y is the total incidence time series (reshaped to (T, 1)).
    """

    def __init__(self, results_list, x_scaler=None, y_scaler=None):
        self.samples = []
        self.seq_len = None
        self.x_scaler = x_scaler
        self.y_scaler = y_scaler
        for params, output in results_list:
            p = np.array([
                params["latent_period"],
                params["infectious_period"],
                params["immunity_period"]
            ] + params["Rt"] + params["tt_Rt"], dtype=np.float32)
            y = np.array(output["total_incidence"], dtype=np.float32)
            if self.seq_len is None:
                self.seq_len = len(y)
            else:
                assert len(y) == self.seq_len, "All sequences must have the same length."
            self.samples.append((p, y))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        p, y = self.samples[idx]
        T = self.seq_len
        time_feature = np.linspace(0, 1, T).reshape(T, 1).astype(np.float32)
        p_repeated = np.tile(p, (T, 1))
        X = np.concatenate([time_feature, p_repeated], axis=1)
        if self.x_scaler is not None:
            scaled_params = self.x_scaler.transform(p.reshape(1, -1))
            scaled_params = np.tile(scaled_params, (T, 1))
            X[:, 1:] = scaled_params
        if self.y_scaler is not None:
            y = self.y_scaler.transform(y.reshape(-1, 1)).reshape(-1)
        return torch.tensor(X, dtype=torch.float32), torch.tensor(y.reshape(-1, 1), dtype=torch.float32)

##############################################################################
# 1. Model Definitions: LSTM and GRU
##############################################################################

class SEIRLSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=2, dropout_prob=0.1):
        super(SEIRLSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers=num_layers,
            dropout=dropout_prob if num_layers > 1 else 0.0,
            batch_first=True,
        )
        self.ln = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.ln(out)
        out = self.dropout(out)
        out = self.fc(out)
        return out

class SEIRGRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=2, dropout_prob=0.1):
        super(SEIRGRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(
            input_size,
            hidden_size,
            num_layers=num_layers,
            dropout=dropout_prob if num_layers > 1 else 0.0,
            batch_first=True,
        )
        self.ln = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.gru(x)
        out = self.ln(out)
        out = self.dropout(out)
        out = self.fc(out)
        return out

##############################################################################
# 2. Training Function
##############################################################################

def train_model(model, train_loader, val_loader, epochs, optimizer, scheduler, device):
    criterion = nn.MSELoss()
    scaler = GradScaler()

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        epoch_start = time.time()

        for X, Y in train_loader:
            X = X.to(device)
            Y = Y.to(device)

            optimizer.zero_grad()
            with autocast():
                pred = model(X)
                loss = criterion(pred, Y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch}: Train Loss = {avg_train_loss:.4f} | Duration: {time.time()-epoch_start:.2f}s")

        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for X_val, Y_val in val_loader:
                X_val = X_val.to(device)
                Y_val = Y_val.to(device)
                with autocast():
                    pred_val = model(X_val)
                    loss_val = criterion(pred_val, Y_val)
                total_val_loss += loss_val.item()
        avg_val_loss = total_val_loss / len(val_loader)
        print(f"          Val Loss   = {avg_val_loss:.4f}")

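        if scheduler is not None:
            # StepLR advances by epoch count and does not take a metric;
            # passing the loss here would be interpreted as an epoch index.
            scheduler.step()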

##############################################################################
# 3. Prediction Plotting Function (Side by Side for LSTM and GRU)
##############################################################################

def plot_predictions(model_lstm, model_gru, dataset, device, y_scaler, n_samples=5):
    model_lstm.eval()
    model_gru.eval()
    indices = np.random.choice(len(dataset), n_samples, replace=False)

    n_cols = 5
    n_rows = math.ceil(n_samples / n_cols)

    fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows))
    if n_rows * n_cols == 1:
        axs = [axs]
    else:
        axs = axs.flatten()

    for ax in axs[n_samples:]:
        ax.axis("off")

    for i, idx in enumerate(indices):
        X, Y = dataset[idx]
        X = X.unsqueeze(0).to(device)
        with torch.no_grad():
            pred_lstm = model_lstm(X)
            pred_gru = model_gru(X)

        pred_lstm = pred_lstm.squeeze(0).cpu().numpy()
        pred_gru = pred_gru.squeeze(0).cpu().numpy()
        Y = Y.squeeze(1).cpu().numpy()

        pred_lstm_orig = y_scaler.inverse_transform(pred_lstm)
        pred_gru_orig = y_scaler.inverse_transform(pred_gru)
        Y_orig = y_scaler.inverse_transform(Y.reshape(-1, 1)).reshape(-1)

        ax = axs[i]
        ax.plot(Y_orig, label="Ground Truth", color="black")
        ax.plot(pred_lstm_orig, label="LSTM", linestyle="--", color="blue")
        ax.plot(pred_gru_orig, label="GRU", linestyle=":", color="red")
        ax.set_title(f"Sample {idx}")
        ax.set_xlabel("Time Step")
        ax.set_ylabel("Total Incidence")
        ax.legend()

    plt.tight_layout()
    plt.show()

##############################################################################
# 5. Main Script for Surrogate Training
##############################################################################

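# NOTE: `train_results`, `test_results`, and `t_end` are not defined in any cell shown
# in this notebook. The block below is an assumed reconstruction: the sample sizes and
# seeds are illustrative placeholders, not the values used by the original authors.
t_end = 730
n_train, n_test = 1000, 200  # assumed sizes

sampling_kwargs = dict(
    latent_range=latent_range,
    infectious_range=infectious_range,
    immunity_range=immunity_range,
    Rt_range=Rt_range,
    tt_Rt_range=tt_Rt_range,
    dt=0.1,
)
train_results = sample_and_run_models(n_train, contact_matrix, N, I0, t_end, seed=0, **sampling_kwargs)
test_results = sample_and_run_models(n_test, contact_matrix, N, I0, t_end, seed=1, **sampling_kwargs)

# Fit the input scaler on the 9 sampled parameters of the training runs only
train_params = np.array([
    np.array([
        p["latent_period"], p["infectious_period"], p["immunity_period"]
    ] + p["Rt"] + p["tt_Rt"])
    for (p, o) in train_results
], dtype=np.float32)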

x_scaler = StandardScaler().fit(train_params)

train_targets = np.array([
    o["total_incidence"] for (p, o) in train_results
], dtype=np.float32)
y_scaler = StandardScaler().fit(train_targets.reshape(-1, 1))

train_dataset = SEIRTimeSeriesDataset(train_results, x_scaler=x_scaler, y_scaler=y_scaler)
val_dataset = SEIRTimeSeriesDataset(test_results, x_scaler=x_scaler, y_scaler=y_scaler)

batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_size = 10  
hidden_size = 128
output_size = 1
num_layers = 2
dropout_prob = 0.05

# Create both LSTM and GRU models
model_lstm = SEIRLSTMModel(input_size, hidden_size, output_size, num_layers, dropout_prob).to(device)
model_gru = SEIRGRUModel(input_size, hidden_size, output_size, num_layers, dropout_prob).to(device)

optimizer_lstm = optim.Adam(model_lstm.parameters(), lr=1e-3)
scheduler_lstm = optim.lr_scheduler.StepLR(optimizer_lstm, step_size=10, gamma=0.5)

optimizer_gru = optim.Adam(model_gru.parameters(), lr=1e-3)
scheduler_gru = optim.lr_scheduler.StepLR(optimizer_gru, step_size=10, gamma=0.5)

epochs = 300
print("Training LSTM model...")
train_model(model_lstm, train_loader, val_loader, epochs, optimizer_lstm, scheduler_lstm, device)
print("Training GRU model...")
train_model(model_gru, train_loader, val_loader, epochs, optimizer_gru, scheduler_gru, device)

plot_predictions(model_lstm, model_gru, val_dataset, device, y_scaler, n_samples=5)

# Evaluate on additional, freshly sampled test runs (unseeded, so each batch differs).
# t_end must match the sequence length used for training (730 days here).
for i in range(10):
    n_samples = 5
    test_results = sample_and_run_models(
        n_samples, contact_matrix, N, I0, t_end=730,
        latent_range=latent_range, infectious_range=infectious_range,
        immunity_range=immunity_range, Rt_range=Rt_range,
        tt_Rt_range=tt_Rt_range, dt=0.1,
    )
    val_dataset = SEIRTimeSeriesDataset(test_results, x_scaler=x_scaler, y_scaler=y_scaler)
    plot_predictions(model_lstm, model_gru, val_dataset, device, y_scaler, n_samples=n_samples)


## Conclusion

In this notebook we:

- Introduced the theory behind infectious disease (ID) modelling with the SEIR compartmental framework.
- Discussed the benefits of surrogate models to approximate expensive simulations.
- Reviewed the basics of RNNs and provided detailed equations for LSTM and GRU cells.
- Demonstrated parameter exploration using Latin Hypercube Sampling (LHS).
- Trained and compared two surrogate models (one based on LSTM and one on GRU) to predict the SEIR simulation output.

This integrated approach can be extended to other epidemiological models and serves as a foundation for rapid prediction and uncertainty quantification in infectious disease dynamics.

Feel free to modify and expand upon this notebook for your research or teaching needs!