
# Gaussian Process Emulator Demonstration

This notebook demonstrates how to work with the **Gaussian Process (GP)** emulators used in *Gaussian Process emulation for exploring complex infectious disease models*, which is currently available as a [preprint](https://www.medrxiv.org/content/10.1101/2024.11.28.24318136v2).

It walks through the following key steps:
1. Set up: Imports, Data Paths, and Parameter Space
2. Loading the Gaussian Process emulator
3. Evaluating GP performance
4. Sensitivity Analysis with the GP
5. Predictions with the GP
6. Sampling additional points based on GP predictions

For more information on GPs, please refer to:
- [GPyTorch Tutorials](https://gpytorch.ai)
- [Rasmussen & Williams (2006): *Gaussian Processes for Machine Learning*](http://www.gaussianprocess.org/gpml/)


## 1. Set up: Imports, Data Paths, and Parameter Space

In this section, we import the necessary libraries and define the key configuration elements used throughout the Gaussian Process emulation workflow.

* **`SIR_gp`**: This module contains the main Gaussian Process (GP) class implementation used for training and prediction.
* **`itertools.product`** and **`IPython.display.HTML`**: Utilities for building Cartesian products of values and for rendering HTML output within the notebook.
* **`emukit`**: Provides tools for **design of experiments** — here, we use its `ParameterSpace` and `LatinDesign` classes to perform **Latin Hypercube Sampling (LHS)**, ensuring that training data cover the input domain evenly.
* **`warnings.filterwarnings("ignore")`**: Suppresses non-critical warnings for cleaner notebook output.

Next, we define paths to:

* The **training data** (`PATH_TRAIN`)
* The **test data** (`PATH_TEST`)
* A **pretrained GP model snapshot** (`PATH_MODEL`)

Please check the **README** in the `../data/GP/data` directory for dataset details. 

The variable **`MODEL_TYPE`** indicates which epidemiological outcome the GP emulator models — for example:

* `"maxIncidence"`: maximum incidence
* `"establishment"`: outbreak probability
* `"duration"`: epidemic duration (log-10 transformed)

Finally, we define the **parameter space** using `emukit`’s `ParameterSpace`.
Each `ContinuousParameter` specifies a model parameter and its range.

The table below maps the variable names used in the code to their corresponding names and explanations in the manuscript.


| **Code Variable**  | **Manuscript Name**      | **Description**                                                                                                                    | **Range**    |
| ------------------ | ------------------------ | ---------------------------------------------------------------------------------------------------------------------------------- | ------------ |
| `alphaRest`        | **Average infectivity**  | Average infection probability across a year, removing the effect of seasonality.                                                   | [0, 0.03]    |
| `alphaAmp`         | **Seasonality strength** | Scaling factor (0–1) controlling the magnitude of seasonal variation in infection probability.                                     | [0, 1]       |
| `alphaShift`       | **First case timing**    | Timing of the first case relative to the seasonal peak in infection probability.                                                   | [0, 1]       |
| `infTicksCount`    | **Infectious period**    | Average number of days an individual remains infectious; actual durations vary probabilistically around this value.                | [4, 6]       |
| `avgVisitsCount`   | **Average mobility**     | Average number of visits a person makes to locations per day (in addition to their home).                                          | [1, 5]       |
| `pVisits`          | **Mobility skewness**    | Success probability in the negative binomial distribution determining daily visit counts — lower values yield greater variability. | [0.05, 0.95] |
| `propSocialVisits` | **Social structure**     | Probability that a visit occurs within an individual’s family cluster.                                                             | [0, 1]       |
| `locPerSGCount`    | **Family cluster size**  | Average number of locations per family cluster; actual sizes are probabilistically rounded around this value.                      | [1, 20]      |


In [1]:
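from SIR_gp import *  # class implementation of the GP
import torch          # explicit import: used below to build input tensors (may also be re-exported by SIR_gp)
import pandas as pd   # explicit import: used below to read the test data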
from itertools import product
from IPython.display import HTML
from emukit.core import ParameterSpace, ContinuousParameter #emukit for LHS
from emukit.core.initial_designs.latin_design import LatinDesign #emukit for LHS
import warnings
warnings.filterwarnings("ignore")

PATH_TRAIN = "../data/GP/data/sim-training-maxIncidence-round15.txt" #training data 
PATH_TEST = "../data/GP/data/DD-AML-test-LHS-10000-condSim-logDuration.txt" #test data 
PATH_MODEL = "../data/GP/model/maxIncidence-round15-snap3.pth" #trained GP model snapshot 
MODEL_TYPE = "maxIncidence" #defines GP type: imax = maxIncidence, outbreak probability = establishment, duration = duration 

PARAM_RANGES = ParameterSpace([
    ContinuousParameter("alphaRest", 0, 0.03),
    ContinuousParameter("alphaAmp", 0, 1),
    ContinuousParameter("alphaShift", 0, 1),
    ContinuousParameter("infTicksCount", 4, 6),
    ContinuousParameter("avgVisitsCount", 1, 5),
    ContinuousParameter("pVisits", 0.05, 0.95),
    ContinuousParameter("propSocialVisits", 0, 1),
    ContinuousParameter("locPerSGCount", 1, 20),
])
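`LatinDesign` handles the sampling for us, but the idea behind Latin Hypercube Sampling fits in a few lines. The NumPy sketch below is an illustration, not emukit's implementation: it splits each parameter's range into `n_samples` equal strata, draws one point per stratum, and shuffles the strata independently for each parameter, so every parameter's range is covered evenly. The bounds are copied from the table above, and `latin_hypercube` is a hypothetical helper.

```python
import numpy as np

def latin_hypercube(bounds, n_samples, seed=None):
    """Draw n_samples points by Latin Hypercube Sampling.

    bounds: list of (low, high) pairs, one per parameter.
    """
    rng = np.random.default_rng(seed)
    d = len(bounds)
    # One uniform draw inside each stratum k: (k + U(0, 1)) / n_samples
    u = (np.arange(n_samples)[:, None] + rng.random((n_samples, d))) / n_samples
    # Shuffle the strata independently for each parameter (column)
    for j in range(d):
        u[:, j] = u[rng.permutation(n_samples), j]
    lows = np.array([b[0] for b in bounds], dtype=float)
    highs = np.array([b[1] for b in bounds], dtype=float)
    # Scale from the unit hypercube to the parameter ranges
    return lows + u * (highs - lows)

# Ranges copied from the parameter table above
bounds = [(0, 0.03), (0, 1), (0, 1), (4, 6), (1, 5), (0.05, 0.95), (0, 1), (1, 20)]
samples = latin_hypercube(bounds, n_samples=10, seed=0)
print(samples.shape)  # (10, 8)
```

Each column of `samples` contains exactly one value in each tenth of that parameter's range, which is what distinguishes LHS from plain uniform random sampling.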


## 2. Loading the Gaussian Process emulator

In this step, we initialize and load a pre-trained Gaussian Process (GP) emulator based on the `SIR_GP` class defined in `SIR_gp.py`:


In [2]:
myGP = SIR_GP(training_data=PATH_TRAIN, model_type=MODEL_TYPE)
myGP.load(filename=PATH_MODEL)


Model loaded. Loss: -1.7712738513946533



* The first line creates an instance of the `SIR_GP` class.

  * `training_data=PATH_TRAIN` points to the file containing the training dataset used to fit the GP.
  * `model_type=MODEL_TYPE` specifies which epidemiological outcome the emulator predicts (e.g., *maximum incidence*, *outbreak probability*, or *epidemic duration*).

* The second line loads a pre-trained model snapshot (`.pth` file) from disk.
  This includes the learned kernel hyperparameters obtained during training.

Internally, `SIR_GP` automatically detects whether a GPU is available and places the model on the appropriate device.

This ensures the emulator runs efficiently on GPU-equipped systems but remains fully functional on CPU-only machines.

Once loaded, `myGP` is ready to make predictions for new parameter combinations sampled from the defined input space.


### 2.1. Training the GP model

The code below demonstrates how to train the GP model for **one iteration**.

The `train` method updates the GP model using the training data provided when `myGP` was created and returns the loss after training.

`num_iterations=1` means the model goes through the training data only **once**. A well-trained model normally needs many iterations (hundreds or thousands).

**Note:** Training for many iterations can take a long time, especially with a large training dataset.
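The loss returned by `train` is not defined in this notebook. A common choice for GPyTorch models, assumed here, is the negative log marginal likelihood divided by the number of training points, which can be negative (like the values printed by `load` and `train`). The minimal NumPy sketch below computes this quantity for a zero-mean GP with a squared-exponential kernel; the kernel, hyperparameter values and function names are illustrative assumptions, not `SIR_GP` internals. Each training iteration would then take one optimizer step on the kernel hyperparameters to reduce it.

```python
import numpy as np

def rbf_kernel(X1, X2, lengthscale=1.0, outputscale=1.0):
    # Squared-exponential kernel: outputscale * exp(-||x - x'||^2 / (2 * lengthscale^2))
    sq_dist = np.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return outputscale * np.exp(-0.5 * sq_dist / lengthscale**2)

def neg_mll_per_point(X, y, lengthscale=1.0, outputscale=1.0, noise_var=0.01):
    """Negative log marginal likelihood of a zero-mean GP, divided by n."""
    n = len(y)
    K = rbf_kernel(X, X, lengthscale, outputscale) + noise_var * np.eye(n)
    L = np.linalg.cholesky(K)                            # K = L @ L.T
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))  # alpha = K^{-1} y
    nll = (0.5 * y @ alpha                               # data-fit term
           + np.sum(np.log(np.diag(L)))                  # 0.5 * log|K|
           + 0.5 * n * np.log(2 * np.pi))                # normalising constant
    return nll / n

# Toy data: 30 points in 2 dimensions with a smooth signal plus noise
rng = np.random.default_rng(0)
X = rng.random((30, 2))
y = np.sin(3 * X[:, 0]) + 0.1 * rng.standard_normal(30)
print(neg_mll_per_point(X, y, lengthscale=0.5))
```

The Cholesky factorisation is the standard way to evaluate both the data-fit term and the log-determinant stably; its O(n³) cost is why training slows down with large training sets.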

In [3]:
myGP.train(num_iterations=1)

-1.7759909629821777

## 3. Evaluating Gaussian Process performance

* `get_rmse()` evaluates how well the GP model predicts data it was not trained on.
* `test_data` is a tab-separated text file containing test inputs and the corresponding simulated outputs.
* RMSE (Root Mean Square Error) is the square root of the mean squared difference between predicted and observed values, expressed in the units of the output:
  * Smaller RMSE → better predictions.
* It is a standard measure of the **accuracy** of a surrogate model.
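For reference, the RMSE over m test points is the square root of (1/m) times the sum of squared errors. The sketch below assumes that `get_rmse()` compares the GP's predictive mean with the observed column of the test file; the function name `rmse` is illustrative.

```python
import numpy as np

def rmse(observed, predicted):
    """Root mean square error: sqrt(mean((observed - predicted)^2))."""
    observed = np.asarray(observed, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    return float(np.sqrt(np.mean((observed - predicted) ** 2)))

# Errors 0, 0, 2 -> mean squared error 4/3 -> RMSE ~ 1.155
print(rmse([0.0, 1.0, 2.0], [0.0, 1.0, 4.0]))
```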


In [4]:
myGP.get_rmse(test_data=PATH_TEST)

0.042115769745832025

## 4. Sensitivity analysis with the GP

### 4.1. Sensitivity analysis across whole input domain 
`myGP.param_ranges` shows the **range of values** each input parameter can take; together these ranges define the model's input domain. For example, `alphaRest` ranges from 0 to 0.03 and `alphaAmp` from 0 to 1.


In [5]:
myGP.param_ranges # View the input parameter ranges (model domain)


{'alphaRest': (0, 0.03),
 'alphaAmp': (0, 1),
 'alphaShift': (0, 1),
 'infTicksCount': (4, 6),
 'avgVisitsCount': (1, 5),
 'pVisits': (0.05, 0.95),
 'propSocialVisits': (0, 1),
 'locPerSGCount': (1, 20)}

`sensitivity_analysis()` performs a **Sobol sensitivity analysis**, which tells us how sensitive the model output is to each input parameter.

- `pow2sampleSize=10` means the base number of Sobol samples is:  

$$
n = 2^{10} = 1024
$$

- The total number of points evaluated is calculated using the formula:  

$$
N_\text{total} = n \times (2d + 2)
$$

where $d$ is the number of input parameters.

- In our case, $d = 8$, so:

$$
N_\text{total} = 1024 \times (2 \cdot 8 + 2) = 1024 \times 18 = 18432
$$
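The count above assumes that second-order indices are computed. Under the standard Saltelli scheme (as implemented, for example, in SALib), dropping the second-order indices reduces the cost to n(d + 2) points. A small helper, hypothetical and not part of `SIR_GP`, makes the arithmetic explicit:

```python
def saltelli_sample_count(pow2_sample_size, d, second_order=True):
    """Number of model evaluations in a Saltelli design with n = 2**pow2_sample_size."""
    n = 2 ** pow2_sample_size
    # Base matrices A and B plus d mixed matrices give first-order and total
    # effects; second-order indices need d further mixed matrices.
    return n * (2 * d + 2) if second_order else n * (d + 2)

print(saltelli_sample_count(10, 8))                      # 18432
print(saltelli_sample_count(10, 8, second_order=False))  # 10240
```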


In [6]:
# Perform a Sobol sensitivity analysis
my_sa = myGP.sensitivity_analysis(
    pow2sampleSize=10,          # n = 2^10 
    param_ranges=myGP.param_ranges,  # Use the defined input ranges
    verbose=True                # Print progress and results
)

Points to be evaluated: 18432
Fished predictions. Starting sensitivity analysis.
                        ST   ST_conf
alphaRest         0.573244  0.048126
alphaAmp          0.037860  0.005755
alphaShift        0.129500  0.018578
infTicksCount     0.017824  0.001582
avgVisitsCount    0.296127  0.030926
pVisits           0.001423  0.000143
propSocialVisits  0.045998  0.008509
locPerSGCount     0.003858  0.001101
                        S1   S1_conf
alphaRest         0.547259  0.062607
alphaAmp          0.009152  0.014850
alphaShift        0.080566  0.025421
infTicksCount     0.015681  0.013017
avgVisitsCount    0.272837  0.050451
pVisits           0.000877  0.003000
propSocialVisits  0.025461  0.015755
locPerSGCount     0.001176  0.004889
                                          S2   S2_conf
(alphaRest, alphaAmp)              -0.021582  0.093106
(alphaRest, alphaShift)            -0.015278  0.096897
(alphaRest, infTicksCount)         -0.018158  0.090487
(alphaRest, avgVisitsCount)      


The function `sensitivity_analysis()` returns **three DataFrames**:

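1. **Total Effects (ST, ST_conf)**

   * `ST` = total sensitivity index of each parameter: the fraction of output variance it accounts for, including all its interactions with other parameters.
   * `ST_conf` = half-width of the confidence interval for `ST` (the estimate is roughly `ST ± ST_conf`).

2. **First-Order Effects (S1, S1_conf)**

   * `S1` = fraction of output variance due to the parameter **alone**.
   * `S1_conf` = half-width of the confidence interval for `S1`.

3. **Second-Order Effects (S2, S2_conf)**

   * `S2` = fraction of output variance due to the **interaction between a pair of parameters**.
   * `S2_conf` = half-width of the confidence interval for `S2`.

**Key takeaways:**

* High `ST`: the parameter strongly influences the output (e.g. `alphaRest` and `avgVisitsCount` above).
* Low `S1` but high `ST`: the parameter matters mostly through interactions.
* Large `S2`: a significant pairwise interaction.
* The indices are Monte Carlo estimates, so small negative values (such as the `S2` entries above) are estimation noise; an index whose confidence interval includes zero is best read as zero.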
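To see how `S1` and `ST` can be estimated from model evaluations, here is a minimal sketch using the Saltelli (2010) first-order estimator and Jansen's total-effect estimator on a toy linear function whose indices are known analytically (S_i = a_i² / Σ a_j² for independent uniform inputs). It is not necessarily the estimator used inside `sensitivity_analysis()`; second-order terms and bootstrap confidence intervals are omitted, so it needs only n(d + 2) evaluations. `sobol_first_and_total` is a hypothetical helper.

```python
import numpy as np

def sobol_first_and_total(f, d, n, seed=None):
    """Monte Carlo estimates of first-order (S1) and total (ST) Sobol indices
    for a function of d independent U(0, 1) inputs."""
    rng = np.random.default_rng(seed)
    A = rng.random((n, d))
    B = rng.random((n, d))
    fA, fB = f(A), f(B)
    var_y = np.var(np.concatenate([fA, fB]))      # total output variance
    S1, ST = np.empty(d), np.empty(d)
    for i in range(d):
        ABi = A.copy()
        ABi[:, i] = B[:, i]                         # A with column i taken from B
        fABi = f(ABi)
        S1[i] = np.mean(fB * (fABi - fA)) / var_y          # Saltelli (2010)
        ST[i] = 0.5 * np.mean((fA - fABi) ** 2) / var_y    # Jansen
    return S1, ST

# Toy additive model y = 3*x1 + 2*x2 + 1*x3 + 0*x4
coeffs = np.array([3.0, 2.0, 1.0, 0.0])
S1, ST = sobol_first_and_total(lambda X: X @ coeffs, d=4, n=2**16, seed=1)
print(np.round(S1, 3), np.round(ST, 3))  # exact values: [0.643, 0.286, 0.071, 0.0]
```

Because the toy model is additive, `S1` and `ST` agree up to sampling error; a gap between them would signal interactions, as described in the takeaways above.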


### 4.2. Sensitivity analysis over a conditional parameter subdomain 

Sometimes we are interested in how the model behaves in constrained scenarios rather than across the full range of input values.  
For example, we might fix some parameters at typical values (here `alphaRest` = 0.015 and `avgVisitsCount` = 2) and study the effects of varying the others.
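In the `param_ranges` dictionary below, a fixed parameter is given as a single number rather than a `(low, high)` pair; note that `(0.015)` without a trailing comma is just the float `0.015` in Python, not a tuple. How `sensitivity_analysis()` handles this internally is not shown in this notebook. A plausible approach, sketched with the hypothetical helper `expand_conditional_ranges` (uniform sampling, not a Saltelli design), is to sample the varying parameters and repeat the fixed values in every row. Because fixed parameters then contribute no variance, their Sobol indices should come out as essentially zero.

```python
import numpy as np

def expand_conditional_ranges(param_ranges, n_samples, seed=None):
    """Hypothetical helper: build an (n_samples, d) design from a mixed dict.

    (low, high) tuples are sampled uniformly; plain numbers are held fixed.
    """
    rng = np.random.default_rng(seed)
    columns = []
    for spec in param_ranges.values():
        if isinstance(spec, (tuple, list)):
            low, high = spec
            columns.append(rng.uniform(low, high, n_samples))
        else:
            # A scalar fixes the parameter at that value in every row
            columns.append(np.full(n_samples, float(spec)))
    return np.column_stack(columns)

design = expand_conditional_ranges(
    {"alphaRest": 0.015, "alphaAmp": (0, 1), "avgVisitsCount": 2},
    n_samples=5, seed=0,
)
print(design.shape)  # (5, 3)
```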


In [7]:
# Define a subset of the parameter ranges
param_ranges = {
    'alphaRest': 0.015,        # a single number fixes average infectivity at 0.015
    'alphaAmp': (0, 1),
    'alphaShift': (0, 1),
    'infTicksCount': (4, 6),
    'avgVisitsCount': 2,       # a single number fixes average mobility at 2
    'pVisits': (0.05, 0.95),
    'propSocialVisits': (0, 1),
    'locPerSGCount': (1, 20)
}

# Perform sensitivity analysis over this subdomain
sa = myGP.sensitivity_analysis(
    pow2sampleSize=10, 
    param_ranges=param_ranges, 
    verbose=True
)

Points to be evaluated: 18432
Fished predictions. Starting sensitivity analysis.
                            ST       ST_conf
alphaRest         4.162834e-11  5.725125e-12
alphaAmp          1.565111e-01  2.165385e-02
alphaShift        6.831600e-01  7.663662e-02
infTicksCount     8.675130e-02  8.977722e-03
avgVisitsCount    1.758958e-11  2.273427e-12
pVisits           6.316238e-03  6.522523e-04
propSocialVisits  2.067304e-01  4.021282e-02
locPerSGCount     1.730275e-02  3.721598e-03
                            S1       S1_conf
alphaRest        -3.551905e-08  5.796783e-07
alphaAmp         -5.781783e-03  3.837352e-02
alphaShift        5.089210e-01  7.524991e-02
infTicksCount     8.592577e-02  3.221632e-02
avgVisitsCount   -2.148522e-07  3.660348e-07
pVisits           7.365714e-03  6.379119e-03
propSocialVisits  1.880176e-01  5.359591e-02
locPerSGCount     5.115693e-03  1.051988e-02
                                              S2       S2_conf
(alphaRest, alphaAmp)              -1.731307e-

## 5. Predictions with the GP

### 5.1. Predicting model outputs for Latin Hypercube Sampled points


1. **Latin Hypercube Sampling (LHS)**

   * `LatinDesign(PARAM_RANGES).get_samples(10000)` generates 10,000 candidate points. Latin Hypercube Sampling splits each parameter's range into equal-width strata and places one point in each, so every parameter is covered evenly.

2. **Reformatting for the GP model**

   * The GP model works with PyTorch tensors, so we convert the NumPy array using `torch.from_numpy(...).float().contiguous()` (32-bit floats stored contiguously in memory).

3. **Predictions with the GP**

   * `predict_ys(parsed_data=res)` evaluates the GP at each candidate point.
   * It returns three tensors:

     * `predictions`: the mean predicted outputs.
     * `lower` and `upper`: the lower and upper bounds of the confidence interval for each prediction.
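To illustrate what `predictions`, `lower` and `upper` represent, here is a minimal NumPy sketch of exact GP regression with a squared-exponential kernel. It assumes the bounds are the posterior mean ± 2 standard deviations, the convention of GPyTorch's `confidence_region()`; whether `predict_ys` uses this convention, and which kernel and hyperparameters `SIR_GP` uses, are assumptions. The function names are illustrative.

```python
import numpy as np

def rbf(X1, X2, lengthscale=0.3, outputscale=1.0):
    # Squared-exponential kernel between two sets of points
    sq_dist = np.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return outputscale * np.exp(-0.5 * sq_dist / lengthscale**2)

def gp_predict(X_train, y_train, X_new, noise_var=1e-4, z=2.0):
    """Posterior mean and mean -/+ z * sd of a zero-mean GP at X_new."""
    K = rbf(X_train, X_train) + noise_var * np.eye(len(X_train))
    K_s = rbf(X_new, X_train)                                 # cross-covariance
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train))
    mean = K_s @ alpha                                        # posterior mean
    v = np.linalg.solve(L, K_s.T)
    var = np.diag(rbf(X_new, X_new)) - np.sum(v**2, axis=0)   # latent variance
    sd = np.sqrt(np.clip(var, 0.0, None))                     # guard tiny negatives
    return mean, mean - z * sd, mean + z * sd

X_train = np.linspace(0, 1, 8)[:, None]
y_train = np.sin(2 * np.pi * X_train[:, 0])
X_new = np.array([[0.5], [3.0]])   # inside vs. far outside the training data
mean, lower, upper = gp_predict(X_train, y_train, X_new)
print(upper - lower)  # the interval is much wider at x = 3.0
```

Far from the training data the posterior falls back to the prior, so the interval widens; this is the uncertainty signal that the sampling step in Section 6 can exploit.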


In [8]:
candidates = LatinDesign(PARAM_RANGES).get_samples(10000) #obtain LHS 
res = torch.from_numpy(candidates).float().contiguous() #reformat data
predictions, lower, upper = myGP.predict_ys(parsed_data = res) #predicting model outputs for LHS points

print(f"Mean predicted value: {predictions.mean().item():.3f}")


Mean predicted value: 0.530



### 5.2. Predicting model outputs for test data

1. `pd.read_csv(PATH_TEST, sep='\t')` loads the tab-separated test dataset, which contains the outputs observed in the individual-based model.
2. `myGP.predict(test_data=PATH_TEST)` predicts outputs for the same inputs.
3. We extract the observed values for the current model type (e.g. the `maxIncidence` column).
4. A comparison DataFrame shows observed vs. predicted values for the first few test points.
5. Comparing the mean observed and mean predicted values is a quick check for systematic bias; it complements, but does not replace, point-wise metrics such as the RMSE from Section 3.
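Beyond comparing means, the `lower` and `upper` bounds returned by `myGP.predict` can be checked for calibration: if they form, say, a 95% interval, roughly 95% of the observed values should fall inside. The sketch below assumes that `lower` and `upper` are tensors aligned with `predicted_mean`; the helper names are hypothetical. If the bounds describe only the latent mean (without simulation noise), coverage of the noisy observations will be lower than nominal.

```python
import numpy as np

def interval_coverage(observed, lower, upper):
    """Fraction of observations lying inside their [lower, upper] interval."""
    observed, lower, upper = (np.asarray(a, dtype=float) for a in (observed, lower, upper))
    return float(np.mean((observed >= lower) & (observed <= upper)))

def mean_bias(observed, predicted):
    """Mean of (predicted - observed); values near zero mean no systematic offset."""
    return float(np.mean(np.asarray(predicted, dtype=float) - np.asarray(observed, dtype=float)))

# Toy check: 3 of the 4 observations fall inside their intervals
print(interval_coverage([1, 2, 3, 4], [0, 0, 0, 0], [2, 2, 2, 5]))  # 0.75
print(mean_bias([1, 2], [2, 4]))                                   # 1.5

# In the notebook (tensors possibly on a GPU):
# interval_coverage(observed, lower.detach().cpu().numpy(), upper.detach().cpu().numpy())
```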

In [9]:
# Load the test dataset (observed values)
test = pd.read_csv(PATH_TEST, sep='\t')
test.info()  # Check the structure and number of rows/columns

# Predict outputs using the GP model
predicted_mean, lower, upper = myGP.predict(test_data=PATH_TEST)

# Extract the observed values for the same model type
y_id = getColumnIndex(myGP.model_type)
observed = test.iloc[:, y_id].values

# Compare observed vs predicted
comparison = pd.DataFrame({
    'Observed': observed,
    'Predicted': predicted_mean.detach().cpu().numpy()  # move to CPU first in case the GP runs on a GPU
})
print(comparison.head())  # show first few rows

print("\nPrediction summary:")
print(f"Mean observed value: {observed.mean():.3f}")
print(f"Mean predicted value: {predicted_mean.mean().item():.3f}")

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10000 entries, 0 to 9999
Data columns (total 16 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   simRound          10000 non-null  float64
 1   simID             10000 non-null  float64
 2   alphaRest         10000 non-null  float64
 3   alphaAmp          10000 non-null  float64
 4   alphaShift        10000 non-null  float64
 5   infTicksCounts    10000 non-null  float64
 6   avgVisitsCounts   10000 non-null  float64
 7   pVisits           10000 non-null  float64
 8   propSocialVisits  10000 non-null  float64
 9   locPerSGCount     10000 non-null  float64
 10  maxIncidence      10000 non-null  float64
 11  epidemicSize      10000 non-null  float64
 12  duration          10000 non-null  float64
 13  sd_maxIncidence   10000 non-null  float64
 14  sd_epidemicSize   10000 non-null  float64
 15  sd_duration       10000 non-null  float64
dtypes: float64(16)
memory usage: 1.2 MB
   Ob

## 6. Sampling additional points based on GP predictions


1. **Latin Hypercube Sampling (LHS)**

   * `LatinDesign(PARAM_RANGES).get_samples(100000)` generates 100,000 candidate points evenly distributed across the input space.

2. **GP-based sampling of points**

   * `myGP.samplePoints(...)` selects `N=100` points from the candidates.
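   * `p=0.5` sets the split between two sampling policies: a fraction p (here 50%) of the selected points is sampled with weights based on the predicted outputs (i.e., policy 2), and the remaining fraction 1 - p is sampled based on prediction uncertainty (i.e., policy 1).
   * Combining the two policies directs new simulations both towards regions of interest in the predicted output and towards regions where the GP is most uncertain, which is useful for active learning (iteratively refining the emulator with new simulations).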

3. **Result**

   * `candidates[id,]` shows the 100 points selected for further evaluation or simulation.
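The exact weighting used by `samplePoints` is not shown in this notebook. The sketch below is a hypothetical re-implementation under stated assumptions: p·N candidates are drawn without replacement with probability proportional to the (shifted, non-negative) predicted mean, and the remaining (1 - p)·N are drawn from the unselected candidates with probability proportional to the predictive standard deviation. In practice `mean` and `sd` would come from GP predictions over the candidate set; `sample_points_sketch` is not part of `SIR_GP`.

```python
import numpy as np

def sample_points_sketch(mean, sd, N, p, seed=None):
    """Hypothetical mixed policy: p*N points by predicted output, the rest by uncertainty."""
    rng = np.random.default_rng(seed)
    mean = np.asarray(mean, dtype=float)
    sd = np.asarray(sd, dtype=float)
    n_out = int(round(p * N))
    # Output-weighted policy: shift so that every weight is strictly positive
    w_out = mean - mean.min() + 1e-12
    idx_out = rng.choice(len(mean), size=n_out, replace=False, p=w_out / w_out.sum())
    # Uncertainty-weighted policy, drawn from candidates not yet selected
    remaining = np.setdiff1d(np.arange(len(mean)), idx_out)
    w_sd = sd[remaining] + 1e-12
    idx_sd = rng.choice(remaining, size=N - n_out, replace=False, p=w_sd / w_sd.sum())
    return np.concatenate([idx_out, idx_sd])

# Toy candidates: predicted output rises with the index, uncertainty is flat
mean = np.linspace(0.0, 1.0, 1000)
sd = np.ones(1000)
idx = sample_points_sketch(mean, sd, N=100, p=0.5, seed=0)
print(len(idx), mean[idx[:50]].mean())
```

Sampling proportionally, rather than taking the top-scoring candidates, keeps some spread among the selected points instead of clustering them around a single maximum.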


In [10]:
# Generate a large set of candidate points using Latin Hypercube Sampling (LHS)
candidates = LatinDesign(PARAM_RANGES).get_samples(100000)  # 100,000 samples

# Use the GP to sample 100 points from the candidate set
# 'p' splits the 100 points between the two sampling policies:
# a fraction p is weighted by predicted outputs, the remaining 1 - p by prediction uncertainty
id = myGP.samplePoints(candidates=candidates, N=100, p=0.5)

selected_points = candidates[id,]
print(selected_points[:10,]) # display the first 10 selected points

[[2.7065550e-02 1.0152500e-01 3.0255500e-01 4.3884300e+00 1.0260200e+00
  9.1119650e-01 7.4299500e-01 2.0513650e+00]
 [1.4457450e-02 2.8037500e-01 7.2828500e-01 4.7442700e+00 4.1023800e+00
  4.3992050e-01 7.7506500e-01 6.5680450e+00]
 [2.2298850e-02 9.8111500e-01 8.3368500e-01 5.3509100e+00 3.1892200e+00
  3.1313750e-01 9.9070500e-01 1.3425715e+01]
 [1.1933250e-02 7.8837500e-01 5.4135000e-02 5.0783700e+00 4.8741000e+00
  9.4028450e-01 7.8372500e-01 6.4266850e+00]
 [1.7198850e-02 4.1352500e-01 7.4141500e-01 4.3611500e+00 4.8309400e+00
  2.8495850e-01 9.2865500e-01 1.5833950e+00]
 [2.2937550e-02 9.9353500e-01 8.4858500e-01 5.0419100e+00 3.0187000e+00
  4.1078750e-01 5.2321500e-01 1.9828905e+01]
 [2.4753150e-02 5.3258500e-01 9.2792500e-01 4.1244500e+00 2.9831800e+00
  1.6923650e-01 1.9439500e-01 7.4545850e+00]
 [9.5626500e-03 9.5660500e-01 3.5745500e-01 5.0351700e+00 2.5010200e+00
  7.4754500e-02 2.0913500e-01 7.1274050e+00]
 [4.3395000e-04 2.8523500e-01 4.3990500e-01 4.0386300e+00 4.3889