# Train Gaussian process regression model

In this notebook we will train the Gaussian process (GP) regression model that we will later use for the sensitivity analysis.

We will go through the following steps:

* Load the dataset.
* Prepare the training and validation data.
* Train a GP regression model.
* Check the model predictions.
* Save the trained model parameters to a file.


## Dependencies

First we import the dependencies.

If you are in Colab, you need to install the [pyro](https://pyro.ai/) package by uncommenting and running the line `!pip3 install pyro-ppl` below before proceeding.

In [None]:
# install dependencies
# !pip3 install pyro-ppl

In [None]:
# imports
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import pyro
import pyro.contrib.gp as gp

pyro.set_rng_seed(0)
print(f"torch version: {torch.__version__}")
print(f"pyro version: {pyro.__version__}")

## Load dataset

We can load the dataset directly from the GitHub URL.
Alternatively, the dataset can be loaded from a local file.

In [None]:
# load dataset
dataset_path = "https://raw.githubusercontent.com/BIG-MAP/sensitivity_analysis_tutorial/main/data/p2d_sei_10k.csv"
# dataset_path = "data/p2d_sei_10k.csv"  # local
df = pd.read_csv(dataset_path, index_col=0)

# store the names of the features and the name of the target variable
features = df.columns[:15].tolist()  # use input parameters as features
target = "SEI_thickness(m)"  # primary target
# target = "Capacity loss (%)"  # secondary target

## Prepare training and validation data

In preparation for training the GP regression model, we apply a few data transformations:

* The target variable is log transformed and then normalised to zero mean and unit variance.
* The input features are normalised to zero mean and unit variance. This makes the kernel parameters easier to learn and puts all inputs on the same scale, so that the results for the individual inputs are directly comparable.

The data is split into a training and a validation set. The normalisation statistics are computed from the training split only, so that no information from the validation data leaks into the model.
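The preprocessing steps above can be sketched in plain Python for a single target variable. This is an illustration under assumptions, not the notebook's code: the values in `y_raw` are made up, and `split_indices` and `fit_standardiser` are hypothetical helpers that mirror `create_data_split_index` and `create_normaliser`. Like `torch.std`, `statistics.stdev` uses the unbiased (n - 1) estimator.

```python
import math
import random
import statistics

def split_indices(n_data, n_train, n_valid, shuffle=False, seed=0):
    """Return disjoint train/valid index lists, optionally shuffled with a fixed seed."""
    index = list(range(n_data))
    if shuffle:
        random.Random(seed).shuffle(index)
    return index[:n_train], index[n_train:n_train + n_valid]

def fit_standardiser(values):
    """Compute mean and sample standard deviation from the training values only."""
    return statistics.mean(values), statistics.stdev(values)

# hypothetical positive targets (e.g. a thickness in metres), not the tutorial data
y_raw = [2.0e-9, 5.0e-9, 1.0e-8, 4.0e-8, 8.0e-8, 3.0e-7]
train_idx, valid_idx = split_indices(len(y_raw), n_train=4, n_valid=2)

# log transform first, then standardise with statistics from the training split
y_log = [math.log(v) for v in y_raw]
y_mean, y_std = fit_standardiser([y_log[i] for i in train_idx])
y_norm = [(v - y_mean) / y_std for v in y_log]

train_norm = [y_norm[i] for i in train_idx]
valid_norm = [y_norm[i] for i in valid_idx]
print(statistics.mean(train_norm), statistics.stdev(train_norm))  # approximately 0 and 1
print(valid_norm)  # not forced to zero mean / unit variance
```

Because the statistics come from the training split only, the normalised validation values are not exactly zero-mean; here they are all positive because the largest values happen to fall in the validation split. If the rows of a dataset were sorted, an unshuffled split could produce this kind of shift, which is what the `shuffle` option guards against.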

In [None]:
# helper functions

def create_data_split_index(n_data, n_train, n_valid=None, shuffle=False):
    """Create data split index."""
    n_valid = n_data - n_train if n_valid is None else n_valid        
    index = torch.randperm(n_data) if shuffle else torch.arange(n_data)
    split = {
        "train": index[:n_train],
        "valid": index[n_train:n_train + n_valid],
        "rest":  index[n_train + n_valid:],
    }
    return split

def create_normaliser(x, y):
    """Create data normalisation function"""
    x_mean, x_std = x.mean(axis=0), x.std(axis=0)
    y_mean, y_std = y.mean(axis=0), y.std(axis=0)
    def normaliser(x, y):
        return (x - x_mean) / x_std, (y - y_mean) / y_std
    normaliser_params = {"x_mean": x_mean, "x_std": x_std, "y_mean": y_mean, "y_std": y_std}
    return normaliser, normaliser_params

In [None]:
# settings
shuffle = False
n_data = len(df)
n_train = 5000
n_valid = 5000

assert n_train + n_valid <= n_data

# create data tensors
x_data_orig = torch.tensor(df[features].values, dtype=torch.float)
y_data_orig = torch.tensor(df[target].values, dtype=torch.float)

# log transform y
y_data_orig = torch.log(y_data_orig)

# create data split index
split = create_data_split_index(n_data, n_train, n_valid, shuffle=shuffle)

# create normalisation function from training split
normaliser, normaliser_params = create_normaliser(x_data_orig[split["train"]], y_data_orig[split["train"]])

# normalise data
x_data, y_data = normaliser(x_data_orig, y_data_orig)

# create data splits 
x_train, y_train = x_data[split["train"]], y_data[split["train"]]
x_valid, y_valid = x_data[split["valid"]], y_data[split["valid"]]

assert len(x_train) == len(y_train) == n_train
assert len(x_valid) == len(y_valid) == n_valid

n_bins = 50
plt.figure(figsize=(8,3))
plt.subplot(121)
plt.hist(y_train.numpy(), bins=n_bins)
plt.xlabel("y_train")
plt.subplot(122)
plt.hist(y_valid.numpy(), bins=n_bins)
plt.xlabel("y_valid")
plt.show()

## Train sparse GP regression model

Now we train the GP regression model that we will later use in the sensitivity analysis.
Specifically, we use the [SparseGPRegression](https://docs.pyro.ai/en/stable/contrib.gp.html#module-pyro.contrib.gp.models.sgpr) model from the [pyro](https://pyro.ai/) package. In our experience it handles fairly large datasets while remaining fast to train, and it works well with automatic differentiation, which we will make use of later.
Please refer to the [pyro documentation](https://docs.pyro.ai/en/stable/contrib.gp.html#module-pyro.contrib.gp.models.sgpr) for details about the model.

If you want to apply this method to a small dataset, you may not need a sparse model and can use the simpler [GPRegression](https://docs.pyro.ai/en/stable/contrib.gp.html#module-pyro.contrib.gp.models.gpr) model instead.
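As a rough, back-of-the-envelope sketch of why the sparse model helps here: an exact GP factorises an n × n kernel matrix, which costs on the order of n³ operations and n² memory, whereas a sparse GP with m inducing points costs on the order of n·m² operations and n·m memory. The numbers below plug in the notebook's settings (5000 training points, 100 inducing points) and ignore constant factors, so they indicate scaling only, not measured run times. The helper `gp_cost` is hypothetical.

```python
def gp_cost(n_data, n_inducing=None):
    """Leading-order cost of GP training (constant factors ignored).

    Exact GP: Cholesky of an n x n matrix, ~n**3 operations and n**2 stored entries.
    Sparse GP with m inducing points: ~n * m**2 operations and n * m stored entries.
    """
    if n_inducing is None:
        return n_data ** 3, n_data ** 2
    return n_data * n_inducing ** 2, n_data * n_inducing

exact_ops, exact_mem = gp_cost(5000)                    # n_train in the notebook
sparse_ops, sparse_mem = gp_cost(5000, n_inducing=100)  # default n_inducing_points
print(f"exact : {exact_ops:.2e} ops, {exact_mem:.2e} matrix entries")
print(f"sparse: {sparse_ops:.2e} ops, {sparse_mem:.2e} matrix entries")
print(f"ratio : {exact_ops // sparse_ops}x fewer ops, {exact_mem // sparse_mem}x less memory")
```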

The model training might take a minute to run. 

In [None]:
# helper functions

def mnll(loc, var, targets):
    """Compute mean negative log likelihood of targets under N(loc, var), where var is the predictive variance."""
    log2pi = np.log(2 * np.pi)
    loglik = -0.5 * (torch.log(var) + log2pi + (targets - loc)**2 / var)
    return torch.mean(-loglik)

def rmse(y_true, y_pred):
    """Compute root mean squared error."""
    return torch.sqrt(torch.mean((y_true - y_pred)**2))

def mae(y_true, y_pred):
    """Compute mean absolute error."""
    return torch.mean(torch.abs(y_true - y_pred))

def r2(y_true, y_pred):
    """Compute coefficient of determination."""
    ssr = torch.sum((y_true - y_pred)**2)
    sst = torch.sum((y_true - torch.mean(y_true))**2)
    return 1 - (ssr / sst)

@torch.no_grad()
def evaluate(model, x, y):
    """Evaluate model."""
    mean, var = model(x, full_cov=False, noiseless=False)
    errors = dict()
    errors["mnll"] = mnll(mean, var, y).detach().item()
    errors["rmse"] = rmse(y, mean).detach().item()
    errors["mae"] = mae(y, mean).detach().item()
    errors["r2"] = r2(y, mean).detach().item()
    return errors
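Note that the second output of the model call with `full_cov=False` is the predictive *variance*, and `mnll` expects a variance rather than a standard deviation. The following plain-Python sketch (the name `gaussian_nll` is hypothetical) checks the same formula for a single point against `statistics.NormalDist` from the standard library. It also shows that the negative log likelihood can be negative: this happens whenever the predictive density at the target exceeds 1, as it can for small predictive variances, so negative MNLL values in the training log are not a bug.

```python
import math
from statistics import NormalDist

def gaussian_nll(mean, var, target):
    """Negative log likelihood of `target` under N(mean, var); `var` is a variance."""
    return 0.5 * (math.log(var) + math.log(2 * math.pi) + (target - mean) ** 2 / var)

mean, var, target = 0.3, 0.04, 0.5  # illustrative values: std = 0.2, target one std away
nll = gaussian_nll(mean, var, target)
reference = -math.log(NormalDist(mu=mean, sigma=math.sqrt(var)).pdf(target))
print(nll, reference)  # the two values agree (and are negative here)

# passing the standard deviation instead of the variance gives a different, wrong value
wrong = gaussian_nll(mean, math.sqrt(var), target)
print(wrong)
```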

In [None]:
# train model

def train(
    x_train,
    y_train,
    x_valid,
    y_valid,
    n_inducing_points=100,
    n_steps=1000,
    eval_freq=100,
    jitter=1.0e-5
):
    pyro.clear_param_store()
    n_features = x_train.shape[1]

    # select the first n training points as the inducing inputs
    x_inducing = x_train[:n_inducing_points].clone()
    
    # initialise the kernel and model
    kernel = gp.kernels.RBF(input_dim=n_features, variance=torch.tensor(5.), lengthscale=torch.tensor(n_features * [10.]))
    model = gp.models.SparseGPRegression(x_train, y_train, kernel, Xu=x_inducing, jitter=jitter)

    # set up optimiser and loss function
    optimiser = torch.optim.Adam(model.parameters(), lr=0.01)
    loss_fn = pyro.infer.Trace_ELBO().differentiable_loss

    errors = defaultdict(list)
    for step in range(n_steps):
        # train
        optimiser.zero_grad()
        loss = loss_fn(model.model, model.guide)
        loss.backward()
        optimiser.step()
        # evaluate
        if step == 0 or (step + 1) % eval_freq == 0:
            with torch.no_grad():
                errors["train_step"].append(step + 1)
                errors["train_loss"].append(loss.item() / len(x_train))
                for k,v in evaluate(model, x_train, y_train).items():
                    errors["train_" + k].append(v)
                for k,v in evaluate(model, x_valid, y_valid).items():
                    errors["valid_" + k].append(v)
            print(f"[{step + 1:5d}] train loss: {errors['train_loss'][-1]:7.4f} train mnll: {errors['train_mnll'][-1]:7.4f} valid mnll: {errors['valid_mnll'][-1]:7.4f}")        
    return model, errors
  

model, errors = train(x_train, y_train, x_valid, y_valid, n_steps=800, jitter=1.0e-4)
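The kernel initialisation and the `jitter` argument can be illustrated in plain Python. The sketch below re-implements the squared-exponential (RBF) kernel with one lengthscale per feature, `k(x, z) = variance * exp(-0.5 * sum_i ((x_i - z_i) / l_i)**2)`, which is the form pyro's `RBF` kernel uses; on inputs normalised to unit variance, the initial lengthscale of 10 means the function starts out varying slowly, and training then adapts it. The sketch also shows why a small jitter is added to the diagonal of the inducing-point kernel matrix before its Cholesky factorisation: coincident (or, in float32, nearly coincident) inducing inputs make that matrix singular. The helper names and the 2 × 2 Cholesky are illustrative only, not pyro's implementation.

```python
import math

N_FEATURES = 15  # the notebook uses the first 15 columns as features

def rbf(x, z, variance, lengthscales):
    """RBF kernel with one lengthscale per feature:
    k(x, z) = variance * exp(-0.5 * sum(((x_i - z_i) / l_i) ** 2))."""
    r2 = sum(((xi - zi) / li) ** 2 for xi, zi, li in zip(x, z, lengthscales))
    return variance * math.exp(-0.5 * r2)

def cholesky_2x2(a, b, d):
    """Cholesky factor (l11, l21, l22) of [[a, b], [b, d]].

    Raises ValueError if the matrix is not (numerically) positive definite."""
    if a <= 0.0:
        raise ValueError("matrix is not numerically positive definite")
    schur = d - b * b / a  # remaining pivot after eliminating the first column
    if schur <= 0.0:
        raise ValueError("matrix is not numerically positive definite")
    l11 = math.sqrt(a)
    return l11, b / l11, math.sqrt(schur)

variance = 5.0                        # initial kernel variance in the notebook
lengthscales = [10.0] * N_FEATURES    # initial lengthscale per feature in the notebook
x = [0.0] * N_FEATURES
z = [1.0] + [0.0] * (N_FEATURES - 1)  # one normalised unit apart in the first feature

k_xx = rbf(x, x, variance, lengthscales)  # equals the variance
k_xz = rbf(x, z, variance, lengthscales)  # 5 * exp(-0.005), still close to 5
print(k_xx, k_xz)

# two coincident inducing inputs give the singular matrix [[k_xx, k_xx], [k_xx, k_xx]]
try:
    cholesky_2x2(k_xx, k_xx, k_xx)
    factorised_without_jitter = True
except ValueError as err:
    factorised_without_jitter = False
    print("without jitter:", err)

jitter = 1.0e-4  # the value passed to train() above
l11, l21, l22 = cholesky_2x2(k_xx + jitter, k_xx, k_xx + jitter)
print("with jitter:", l11, l21, l22)
```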

In [None]:
# plot training curve
plt.figure()
plt.plot(errors["train_step"], errors["train_mnll"], label="train mnll")
plt.plot(errors["train_step"], errors["valid_mnll"], label="valid mnll")
plt.xlabel("training step"); plt.ylabel("mean negative log likelihood")
plt.legend()
plt.grid()
plt.show()

We should see the training and validation errors go down with the number of training steps.
Go ahead and plot some of the other errors stored in the `errors` dictionary if you like.
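For example, the stored history can be used to find the evaluation step with the lowest validation MNLL and to watch the gap between training and validation error; a gap that keeps widening suggests overfitting. The sketch below uses a small made-up `errors` dictionary with the same keys as the one returned by `train`, so the numbers are placeholders, not results; `best_step` is a hypothetical helper.

```python
# made-up training history with the same keys as the notebook's `errors` dict
errors = {
    "train_step": [1, 100, 200, 300, 400],
    "train_mnll": [1.20, 0.40, 0.10, -0.05, -0.12],
    "valid_mnll": [1.22, 0.45, 0.18, 0.12, 0.15],
}

def best_step(errors, key="valid_mnll"):
    """Return (step, value) of the evaluation point with the lowest value of `key`."""
    i = min(range(len(errors[key])), key=errors[key].__getitem__)
    return errors["train_step"][i], errors[key][i]

step, value = best_step(errors)
print(f"lowest valid mnll {value:.2f} at step {step}")  # lowest valid mnll 0.12 at step 300

# gap between validation and training error at the last evaluation step
final_gap = errors["valid_mnll"][-1] - errors["train_mnll"][-1]
print(f"final train/valid gap: {final_gap:.2f}")
```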

## Check model predictions

Before we do any further analyses, we want to verify that the model fits the training data and makes good predictions on the held-out validation data. 

In [None]:
def evaluate_predictions(y_true, y_pred, lim=(-3,3), figsize=(5,5)):
    _r2 = r2(y_true, y_pred)  # coefficient of determination
    _mae = mae(y_true, y_pred)  # mean absolute error
    print(f"r2: {_r2:.4f}, mae: {_mae:.4f}\n")
    # plot y_true against y_pred
    plt.figure(figsize=figsize)
    plt.plot(lim, lim, color="k", linestyle="--", linewidth=1)
    plt.plot(y_true, y_pred, ".", alpha=0.1)
    plt.xlabel("y_true"); plt.ylabel("y_pred")
    plt.xlim(lim); plt.ylim(lim)
    plt.grid()
    plt.show()

In [None]:
# evaluate on training data
y_pred, y_var = model(x_train, full_cov=False, noiseless=False)
evaluate_predictions(y_train.detach(), y_pred.detach())

In [None]:
# evaluate on validation data
y_pred, y_var = model(x_valid, full_cov=False, noiseless=False)
evaluate_predictions(y_valid.detach(), y_pred.detach())

We should see that the model achieves an R² (`r2`) value close to 1, meaning it explains most of the variance in the data, and that in both the training and validation plots the points lie close to the diagonal, i.e. the predictions agree closely with the true target values.
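To make the R² value concrete, the following plain-Python sketch mirrors the notebook's `r2` helper on a few hand-picked numbers. An R² of 1 means a perfect fit, 0 means the model does no better than always predicting the mean of the targets, and negative values mean it does worse. Because the metric is computed on the normalised log targets, it describes the fit in log space: applying the same shift and scale to targets and predictions leaves R² unchanged, but the log transform does not, so R² in the original units could differ.

```python
def r2(y_true, y_pred):
    """Coefficient of determination, 1 - SSR / SST, for plain Python lists."""
    mean = sum(y_true) / len(y_true)
    ssr = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))  # residual sum of squares
    sst = sum((t - mean) ** 2 for t in y_true)               # total sum of squares
    return 1 - ssr / sst

y_true = [-1.0, 0.0, 1.0, 2.0]                      # mean 0.5, SST = 5.0
close_fit = r2(y_true, [-0.9, 0.1, 0.9, 2.1])       # SSR = 0.04, so about 1 - 0.04 / 5
mean_only = r2(y_true, [0.5] * 4)                   # always predicting the mean
reversed_fit = r2(y_true, [2.0, 1.0, 0.0, -1.0])    # SSR = 20.0, worse than the mean
print(close_fit, mean_only, reversed_fit)
```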

## Save trained model

Finally, we save the trained model parameters so we can use the model for analysis later.
We additionally save the data normalisation parameters, the data ranges, and the training and validation data, which will be useful later.

IMPORTANT: If you are running this notebook in Colab, you should make sure to download the saved file as we will need it later in the tutorial series. 
You can find it in the Files section to the left (the small folder icon) after running the code below.

In [None]:
# store data normalisation parameters
pyro.param("data.x_mean", normaliser_params["x_mean"])
pyro.param("data.x_std", normaliser_params["x_std"])
pyro.param("data.y_mean", normaliser_params["y_mean"])
pyro.param("data.y_std", normaliser_params["y_std"])

# store data range parameters (in normalised units)
pyro.param("data.x_min", x_data.min(dim=0)[0])
pyro.param("data.x_max", x_data.max(dim=0)[0])
pyro.param("data.y_min", y_data.min())
pyro.param("data.y_max", y_data.max())

# store training and validation data
pyro.param("data.x_train", x_train)
pyro.param("data.y_train", y_train)
pyro.param("data.x_valid", x_valid)
pyro.param("data.y_valid", y_valid)

# save model parameters in a file
print(pyro.get_param_store().keys())
if target == "SEI_thickness(m)":
    pyro.get_param_store().save("sgpr_params_sei.p")
if target == "Capacity loss (%)":
    pyro.get_param_store().save("sgpr_params_cap.p")
  
# !!! remember to download the saved file !!!