# Refactored HistoryMatching workflow example

## 1. Set up

In [1]:
import torch

# imports from main
from autoemulate.history_matching_dashboard import HistoryMatchingDashboard

# imports from experimental
from autoemulate.experimental.emulators.gaussian_process.exact import (
    GaussianProcessExact,
)
from autoemulate.experimental.simulations.epidemic import Epidemic
from autoemulate.experimental.history_matching import HistoryMatching, HistoryMatchingWorkflow

### Simulate data & train a GP

Set up a `Simulator` and generate data.

In [2]:
simulator = Epidemic()
x = simulator.sample_inputs(10)
y = simulator.forward_batch(x)

Running simulations: 100%|██████████| 10/10 [00:00<00:00, 553.21it/s]

Successfully completed 10/10 simulations (100.0%)





In practice, choosing and training an emulator should be done with `AutoEmulate.compare()`; here we fit an exact Gaussian process directly.

In [3]:
gp = GaussianProcessExact(x, y)
gp.fit(x, y)

## 2. HistoryMatching

First, `HistoryMatching` can be instantiated without a simulator or an emulator and used to calculate implausibility for a given set of predictions.
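For reference, implausibility in history matching is commonly defined per output as below. Here $z$ is the observed mean and $\sigma^2_{\text{obs}}$ the observed variance (the two values in each `observations` entry), and $\mu(x)$ and $\sigma^2_{\text{em}}(x)$ are the emulator's predictive mean and variance at input $x$. This is the standard textbook form; we assume `HistoryMatching` follows it, although the library may include additional terms such as a model-discrepancy variance.

$$
I(x) = \frac{\lvert z - \mu(x) \rvert}{\sqrt{\sigma^2_{\text{obs}} + \sigma^2_{\text{em}}(x)}}
$$

A parameter $x$ is then kept as not ruled out yet (NROY) when its implausibility does not exceed the chosen threshold $T$ (with several outputs, typically the maximum over outputs is compared), e.g. $T = 3$ for the `threshold=3.0` used in this example.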

In [4]:
# generate predictions for new x inputs
x_new = simulator.sample_inputs(10)
# x_new is already a tensor, so cast it instead of re-wrapping it with torch.tensor()
output = gp.predict(x_new.to(torch.float32))
pred_means, pred_vars = (
    output.mean.float().detach(),
    output.variance.float().detach(),
)

  output = gp.predict(torch.tensor(x_new, dtype=torch.float32))


In [5]:
# Define observed data with means and variances
observations = {"infection_rate": (0.3, 0.05)}

# Create history matcher
hm = HistoryMatching(
    observations=observations,
    threshold=3.0
)
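To make the calculation concrete, here is a minimal pure-Python sketch of the standard implausibility formula applied to an observations dictionary like the one above. The function `implausibility_sketch` and its list-based inputs are hypothetical; `hm.calculate_implausibility()` works on tensors and its internals may differ.

```python
import math


def implausibility_sketch(pred_means, pred_vars, observations):
    """Hypothetical re-implementation of the standard implausibility measure.

    pred_means, pred_vars: one list per sample, one entry per output,
        ordered like the keys of `observations`.
    observations: dict mapping output name -> (observed mean, observed variance).
    Returns one list of implausibility scores per sample.
    """
    obs = list(observations.values())
    scores = []
    for means, variances in zip(pred_means, pred_vars):
        # |z - mu| / sqrt(obs_var + emulator_var), computed for each output
        scores.append([
            abs(z - m) / math.sqrt(obs_var + v)
            for (z, obs_var), m, v in zip(obs, means, variances)
        ])
    return scores


demo_obs = {"infection_rate": (0.3, 0.05)}
demo_scores = implausibility_sketch([[0.3], [1.0]], [[0.0], [0.0]], demo_obs)
print(demo_scores)
```

The first prediction matches the observation exactly (score 0.0); the second scores about 3.13 and would therefore be ruled out at `threshold=3.0`.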

In [6]:
implausibility = hm.calculate_implausibility(pred_means, pred_vars)

Once implausibility has been calculated, it can be used to identify the indices of NROY ("not ruled out yet") parameters:

In [7]:
hm.get_nroy(implausibility)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

Or to return the parameter values at those NROY indices:

In [8]:
hm.get_nroy(implausibility, x_new)

tensor([[0.1709, 0.0508],
        [0.4401, 0.0828],
        [0.3762, 0.0232],
        [0.1998, 0.1910],
        [0.3041, 0.1451],
        [0.2680, 0.0296],
        [0.4938, 0.1741],
        [0.2367, 0.0964],
        [0.1025, 0.1294],
        [0.3978, 0.1181]])
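A minimal sketch of this selection step, assuming a sample is kept when its largest implausibility across outputs is at most the threshold. `get_nroy_sketch` is a hypothetical helper operating on plain lists, not the library's `get_nroy()`.

```python
def get_nroy_sketch(scores, threshold=3.0, x=None):
    """Hypothetical sketch: a sample is NROY if its largest implausibility
    across outputs is at most `threshold`. Returns the NROY indices, or the
    NROY rows of `x` when `x` is given."""
    idx = [i for i, row in enumerate(scores) if max(row) <= threshold]
    if x is None:
        return idx
    return [x[i] for i in idx]


demo_impl = [[0.05], [3.2], [1.1]]  # one row per sample, one column per output
demo_params = [[0.17, 0.05], [0.44, 0.08], [0.38, 0.02]]
print(get_nroy_sketch(demo_impl))                 # [0, 2]
print(get_nroy_sketch(demo_impl, x=demo_params))  # [[0.17, 0.05], [0.38, 0.02]]
```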

Optionally, `HistoryMatching` can be instantiated with an emulator; its `emulator_predict()` method then returns the prediction means and variances directly.

In [9]:
hm_with_emul = HistoryMatching(
    observations=observations,
    threshold=3.0,
    emulator=gp
)

pred_means, pred_vars = hm_with_emul.emulator_predict(x_new)
hm_with_emul.calculate_implausibility(pred_means, pred_vars)

tensor([[0.0526],
        [0.0409],
        [0.0536],
        [0.0313],
        [0.0283],
        [0.0547],
        [0.0281],
        [0.0363],
        [0.0300],
        [0.0315]])

## 3. Iterative HistoryMatchingWorkflow

We also have a separate class that implements an iterative sample-predict-refit workflow:
- sample `n_test_samples` parameter sets from the NROY space
- use the emulator to filter out implausible samples and update the NROY space
- run the simulator for `n_simulations` of the sampled parameters
- refit the emulator using the simulated data
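The loop above can be sketched in plain Python. Everything here is hypothetical and chosen for illustration: a one-parameter `toy_simulator`, a `NearestNeighbourEmulator` standing in for the GP, and an NROY region approximated by the bounding interval of the surviving samples (real NROY regions need not be intervals). It is not the implementation of `HistoryMatchingWorkflow.run()`.

```python
import math
import random


def toy_simulator(theta):
    """Hypothetical stand-in for a simulator: one parameter in, one output out."""
    return theta ** 2


class NearestNeighbourEmulator:
    """Hypothetical toy emulator: predicts the output of the nearest training
    point, with a variance that grows with the distance to that point."""

    def fit(self, xs, ys):
        self.xs, self.ys = list(xs), list(ys)

    def predict(self, x):
        # nearest training point (ties broken by output value)
        dist, y = min((abs(x - xi), yi) for xi, yi in zip(self.xs, self.ys))
        return y, 1e-4 + dist ** 2  # (mean, variance)


def run_wave(state, n_simulations, n_test_samples, obs, threshold, rng):
    obs_mean, obs_var = obs
    lo, hi = state["nroy_bounds"]
    # 1. Sample candidate parameters from the current NROY region.
    tests = [rng.uniform(lo, hi) for _ in range(n_test_samples)]
    # 2. Score them with the emulator and keep the NROY ones.
    scores = []
    for t in tests:
        mean, var = state["emulator"].predict(t)
        scores.append(abs(obs_mean - mean) / math.sqrt(obs_var + var))
    nroy = [t for t, s in zip(tests, scores) if s <= threshold]
    if nroy:  # shrink the NROY region to the span of surviving samples
        state["nroy_bounds"] = (min(nroy), max(nroy))
    # 3. Run the simulator on up to n_simulations of the NROY samples.
    chosen = rng.sample(nroy, min(n_simulations, len(nroy)))
    state["train_x"] += chosen
    state["train_y"] += [toy_simulator(t) for t in chosen]
    # 4. Refit the emulator on all simulated data so far.
    state["emulator"].fit(state["train_x"], state["train_y"])
    return tests, scores


rng = random.Random(0)
init_x = [rng.uniform(0.0, 1.0) for _ in range(10)]
init_y = [toy_simulator(t) for t in init_x]
toy_emulator = NearestNeighbourEmulator()
toy_emulator.fit(init_x, init_y)
state = {"emulator": toy_emulator, "train_x": init_x, "train_y": init_y,
         "nroy_bounds": (0.0, 1.0)}
obs = (0.25, 0.0025)  # observed (mean, variance); generated by theta = 0.5

for wave in range(3):  # each call builds on the state left by the previous one
    wave_tests, wave_scores = run_wave(state, n_simulations=20, n_test_samples=100,
                                       obs=obs, threshold=3.0, rng=rng)
    print(wave, state["nroy_bounds"], len(state["train_x"]))
```

Because `state` persists between calls, each wave samples from the region left by the previous one and refits on every simulation run so far.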

`HistoryMatchingWorkflow` maintains its internal state and updates it each time `run()` is called, so this loop can be repeated as many times as needed.

In [10]:
hmw = HistoryMatchingWorkflow(
    simulator=simulator,
    emulator=gp,
    observations=observations,
    threshold=3.0,
    train_x=x,
    train_y=y
)

test_parameters, impl_scores = hmw.run(n_simulations=20, n_test_samples=100)

Running simulations: 100%|██████████| 20/20 [00:00<00:00, 1140.05it/s]

Successfully completed 20/20 simulations (100.0%)





In [11]:
test_parameters.shape, impl_scores.shape

(torch.Size([100, 2]), torch.Size([100, 1]))

Calling `run()` again continues from the state stored by previous runs:

In [12]:
test_parameters, impl_scores = hmw.run(n_simulations=20, n_test_samples=100)

Running simulations: 100%|██████████| 20/20 [00:00<00:00, 1008.67it/s]

Successfully completed 20/20 simulations (100.0%)





## 4. Integration with dashboard

In [13]:
dashboard = HistoryMatchingDashboard(
    samples=test_parameters,
    impl_scores=impl_scores,
    param_names=simulator.param_names,
    output_names=simulator.output_names,
)

In [14]:
dashboard.display()

(Interactive widget output: "History Matching Dashboard" with a "Plot Type" dropdown, e.g. "Parameter vs Implausibility".)