# Refactored HistoryMatching workflow example

## 1. Set up

In [1]:
import torch

# imports from main
from autoemulate.simulations.epidemic import simulate_epidemic
from autoemulate.history_matching_dashboard import HistoryMatchingDashboard

# imports from experimental
from autoemulate.experimental.types import TensorLike
from autoemulate.experimental.emulators.gaussian_process.exact import (
    GaussianProcessExact,
)
from autoemulate.experimental.simulations.base import Simulator
from autoemulate.experimental.history_matching import HistoryMatching

### Simulate

Set up a Simulator and generate data.

In [2]:
class EpidemicSimulator(Simulator):
    """
    Simulator of infectious disease spread (SIR).
    """

    def __init__(
        self, 
        param_ranges={"beta": (0.1, 0.5), "gamma": (0.01, 0.2)}, 
        output_names = ["distance"]
        ):
        super().__init__(param_ranges, output_names)

    def _forward(self, x: TensorLike) -> TensorLike:
        """
        Parameters
        ----------
        x : TensorLike
            input parameter values to simulate [beta, gamma]:
            - `beta`: the transimission rate per day
            - `gamma`: the recovery rate per day

        Returns
        -------
        infection : np.ndarray
            Peak infection rate.
        """
        y = simulate_epidemic(x.numpy()[0])
        return torch.tensor([y]).view(-1, 1)
    
simulator = EpidemicSimulator()
x = simulator.sample_inputs(10)
y = simulator.forward_batch(x)

Running simulations: 100%|██████████| 10/10 [00:00<00:00, 988.34it/s]

Successfully completed 10/10 simulations (100.0%)





### Train a GP

(this should be done with AutoEmulate obviously)

In [3]:
gp_pytorch = GaussianProcessExact(
        x,
        y,
    )
gp_pytorch.fit(x, y)

### Generate predictions

In [4]:
x = simulator.sample_inputs(5)
output = gp_pytorch.predict(torch.tensor(x, dtype=torch.float32))
pred_means, pred_vars = (
    output.mean.float().detach(),
    output.variance.float().detach(),
)

  output = gp_pytorch.predict(torch.tensor(x, dtype=torch.float32))


## 2. History Matching

Firstly, one can instantiate HistoryMatching without a simulator or an emulator. It can be used to calculate implausability for a given set of predictions.

In [5]:
# Define observed data with means and variances
observations = {"beta": (0.25, 0.05), "gamma": (0.1, 0.01)}

# Create history matcher
hm = HistoryMatching(
    observations=observations,
    threshold=3.0
)

implausability = hm.calculate_implausibility(pred_means, pred_vars)


Once implausability has been calculated, it can be used to identify NROY parameters.

In [8]:

hm.get_nroy(implausability, x)

tensor([[0.4235, 0.0585],
        [0.2786, 0.1565],
        [0.1068, 0.0143],
        [0.2131, 0.1845],
        [0.3751, 0.0925]])

## 3. Iterative HistoryMatching

We can execute an iterative sample-predict-evaluate procedure with `hm.run()`. In each wave:
- sample parameter values to test from the NROY space
    - at the start, NROY is the entire parameter space
    - use emulator to filter out implausible samples
- make predictions for the sampled parameters using the simulator
- refit the emulator using the simulated data

In [9]:
emulator = hm.run(
    n_waves=2,
    n_samples_per_wave=20,
    simulator=simulator,
    emulator=gp_pytorch,
)

Running simulations: 100%|██████████| 20/20 [00:00<00:00, 766.81it/s]
History Matching:  50%|█████     | 1/2 [00:00<00:00,  7.90wave/s, samples=20, failed=0, NROY=17, min_impl=0.02, max_impl=5.28]

Successfully completed 20/20 simulations (100.0%)


Running simulations: 100%|██████████| 15/15 [00:00<00:00, 840.47it/s]
History Matching: 100%|██████████| 2/2 [00:00<00:00,  9.62wave/s, samples=15, failed=0, NROY=15, min_impl=0.03, max_impl=2.08]


Successfully completed 15/15 simulations (100.0%)


In [10]:
hm.tested_params.shape, hm.impl_scores.shape

(torch.Size([35, 2]), torch.Size([35, 2]))

## 4. Integration with dashboard

In [11]:
dashboard = HistoryMatchingDashboard(
    samples=hm.tested_params,
    impl_scores=hm.impl_scores,
    param_names=simulator.param_names,  
    output_names=simulator.output_names, 
    )

In [12]:
dashboard.display()

HTML(value='<h2>History Matching Dashboard</h2>')

VBox(children=(HBox(children=(Dropdown(description='Plot Type:', options=('Parameter vs Implausibility', 'Pair…