<a href="https://colab.research.google.com/github/Sabelz/Master_Thesis_Alexander/blob/main/utils/functions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Functions for the project

# Imports

In [1]:
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/Master_Thesis_Alexander
!git config --global user.email "alexander.sabelstrom.1040@student.uu.se"
!git config --global user.name "Sabelz"

import numpy as np
import matplotlib.pyplot as plt
import torch
!pip install gpytorch > \dev\null # Suppress prints
import gpytorch
from matplotlib import pyplot as plt
import math
import jax
import jax.numpy as jnp

Mounted at /content/drive
/content/drive/MyDrive/Master_Thesis_Alexander


# Training Function for Exact GP

In [2]:
def train(model, likelihood, x_train, y_train, training_iter=10, lr=0.1):
    """
    Trains an exact Gaussian Process (GP) model by minimizing the negative ExactMarginalLogLikelihood.

    Parameters:
    model (gpytorch.models.ExactGP): The GP model to be trained.
    likelihood (gpytorch.likelihoods._OneDimensionalLikelihood): The likelihood function to be used with the model.
    x_train (torch.Tensor): The training inputs.
    y_train (torch.Tensor): The training targets.
    training_iter (int, optional): Number of full-batch optimization steps. Default is 10.
    lr (float, optional): Learning rate of the Adam optimizer. Default is 0.1.

    Returns:
    None. The function operates in-place on the `model` and `likelihood` objects.

    Note:
    If CUDA is available, the model and likelihood are moved to the GPU, and GPU copies
    of the training data are used for training (the caller's tensors are not modified).
    """
    if torch.cuda.is_available():
        model = model.cuda()
        likelihood = likelihood.cuda()
        # Inputs must live on the same device as the model parameters,
        # otherwise the forward pass raises a device-mismatch error.
        x_train = x_train.cuda()
        y_train = y_train.cuda()
    model.train()
    likelihood.train()
    # Includes the GaussianLikelihood parameters (the ExactGP holds its likelihood as a submodule)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # "Loss" for GPs - the marginal log likelihood
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    for _ in range(training_iter):
        optimizer.zero_grad()
        output = model(x_train)
        loss = -mll(output, y_train)
        loss.backward()
        optimizer.step()



# Training Function for Inducing Points GP

In [None]:
def train_ELBO(model, likelihood, x_train, y_train, training_iter=10, train_loader=None, lr=0.1):
    """
    Trains an approximate (variational) Gaussian Process (GP) model by minimizing the negative VariationalELBO.

    Parameters:
    model (gpytorch.models.ApproximateGP): The GP model to be trained.
    likelihood (gpytorch.likelihoods._OneDimensionalLikelihood): The likelihood function to be used with the model.
    x_train (torch.Tensor): The training inputs (the full dataset).
    y_train (torch.Tensor): The training targets (the full dataset). Its number of
        elements is passed to the ELBO as the dataset size, also when `train_loader` is used.
    training_iter (int, optional): Number of iterations. Without a loader this is the
        number of full-batch steps; with a loader it is the number of passes over it. Default is 10.
    train_loader (torch.utils.data.DataLoader, optional):
        The DataLoader that provides (x_batch, y_batch) batches of the training data.
        If None, the entire training dataset is used in each iteration.
    lr (float, optional): Learning rate of the Adam optimizer. Default is 0.1.

    Returns:
    None. The function operates in-place on the `model` and `likelihood` objects.

    Note:
    The model and likelihood are converted to float64 (and moved to the GPU if CUDA is
    available). Training inputs/targets and every loader batch are converted to match;
    the caller's tensors are not modified.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model = model.cuda()
        likelihood = likelihood.cuda()
    # Parameters and input data must share a dtype, so the whole model is trained in float64
    model = model.double()
    likelihood = likelihood.double()

    def prepare(t):
        # Cast a data tensor to float64 on the same device as the model parameters
        t = t.double()
        return t.cuda() if use_cuda else t

    model.train()
    likelihood.train()
    n_points = y_train.numel()  # Dataset size, used by the ELBO to scale the data term
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, n_points)
    # An ApproximateGP does not own the likelihood, so both parameter sets are optimized
    optimizer = torch.optim.Adam(list(model.parameters()) + list(likelihood.parameters()), lr=lr)

    def step(x, y):
        # One optimization step on a single (x, y) batch
        optimizer.zero_grad()
        output = model(x)
        loss = -mll(output, y)
        loss.backward()
        optimizer.step()

    if train_loader is None:
        x_full, y_full = prepare(x_train), prepare(y_train)
        for _ in range(training_iter):
            step(x_full, y_full)
    else:
        for _ in range(training_iter):
            for x_batch, y_batch in train_loader:
                step(prepare(x_batch), prepare(y_batch))

# Predict Function

In [None]:
def predict(model, likelihood, test_x):
    """
    This function makes predictions using a given model and likelihood.

    The function sets the model and likelihood to evaluation mode (they are left in
    that mode), then computes the likelihood of the model's predictions on the test data.
    It uses PyTorch's `no_grad` context manager to avoid tracking gradients during the prediction,
    and GPyTorch's `fast_pred_var` setting for efficient computation.

    Parameters:
    model (gpytorch.models.GP): The Gaussian Process model to make predictions with.
    likelihood (gpytorch.likelihoods.Likelihood): The likelihood associated with the model.
    test_x (torch.Tensor): The test inputs to make predictions on. A converted copy on the
        model's device and in the model's parameter dtype is used; the caller's tensor is unchanged.

    Returns:
    gpytorch.distributions.MultivariateNormal: The distribution of the model's predictions
    (on the same device as the model).
    """
    model.eval()
    likelihood.eval()
    # Training may have moved the model to the GPU and/or cast it to float64,
    # so the test inputs are converted to match the model parameters.
    param = next(model.parameters(), None)
    if param is not None:
        test_x = test_x.to(device=param.device, dtype=param.dtype)
    # Make predictions by feeding model through likelihood
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        return likelihood(model(test_x))

# Plot Function for GPyTorch Models

In [4]:
def plotGP(x_train, y_train, model, likelihood, title="GP Model"):
  """
  Plot a GPyTorch regression model's predictive mean and 95% interval together with the observed data.

  The model and likelihood are switched to evaluation mode (and left in it), and the
  predictive distribution is evaluated on 1000 evenly spaced points between the
  smallest and largest training input.

  Parameters:
  x_train (torch.Tensor): The training inputs.
  y_train (torch.Tensor): The training targets.
  model (gpytorch.models.GP): The Gaussian Process regression model.
  likelihood (gpytorch.likelihoods.Likelihood): The likelihood function to use for the model.
  title (str, optional): The title of the plot. Defaults to "GP Model".

  Returns:
  None. The function creates a plot and does not return anything.
  """
  # A GP must be in eval mode to return its posterior predictive distribution
  model.eval()
  likelihood.eval()

  # Evaluate on the device/dtype of the model parameters (training may have moved
  # the model to the GPU and/or cast it to float64).
  param = next(model.parameters(), None)
  device = x_train.device if param is None else param.device
  dtype = x_train.dtype if param is None else param.dtype
  min_value, max_value = x_train.min().item(), x_train.max().item()
  x_plot = torch.linspace(min_value, max_value, 1000, device=device, dtype=dtype)

  with torch.no_grad(), gpytorch.settings.fast_pred_var():
    prediction = likelihood(model(x_plot))
    mean = prediction.mean
    std = prediction.variance.sqrt()
  # 95% confidence region
  lower_bound = mean - 1.96 * std
  upper_bound = mean + 1.96 * std

  def to_numpy(t):
    # Matplotlib needs CPU numpy arrays without autograd history
    return t.detach().cpu().numpy()

  plt.style.use('classic')
  _, ax = plt.subplots(1, 1, figsize=(4, 3))
  ax.plot(to_numpy(x_train), to_numpy(y_train), 'ko')
  # Plot predictive means
  ax.plot(to_numpy(x_plot), to_numpy(mean), 'purple')
  # Plot confidence bounds as lightly shaded region
  ax.fill_between(to_numpy(x_plot), to_numpy(lower_bound), to_numpy(upper_bound),
                  alpha=0.5, color="violet", zorder=-1)
  ax.set_title(title)
  ax.legend(['Observed Data', 'Mean', '95% Confidence'])
  plt.grid(False)


# Plot Function for State Space Model

In [None]:
def plotGP_SS(x_train, y_train, ell=1, sigma=1, m0=0, v0=1, title="GP Model", n_test_points=1000,
              observation_cov=1):
  """
  Plot a 1-D GP posterior computed with its equivalent state-space model.

  The transition F = exp(-dt / ell) and process noise W = sigma^2 * (1 - exp(-2 dt / ell))
  discretize a process with covariance sigma^2 * exp(-|x - x'| / ell). The posterior is
  obtained with `kalmanFilter` followed by `kalmanSmoothing`, evaluated at
  `n_test_points` evenly spaced points between min(x_train) and max(x_train).

  Parameters:
  x_train (torch.Tensor): Training inputs (flattened to 1-D).
  y_train (torch.Tensor): Training targets (flattened to 1-D); NaN entries are treated as missing.
  ell (float, optional): Length scale. Defaults to 1.
  sigma (float, optional): Signal standard deviation. Defaults to 1.
  m0 (float, optional): Initial state mean for the Kalman filter. Defaults to 0.
  v0 (float, optional): Initial state variance for the Kalman filter. Defaults to 1.
  title (str, optional): The title of the plot. Defaults to "GP Model".
  n_test_points (int, optional): Number of points the posterior is evaluated on. Defaults to 1000.
  observation_cov (float, optional): Observation noise variance passed to the filter. Defaults to 1.

  Returns:
  None. The function creates a plot and does not return anything.
  """
  x_np = x_train.detach().cpu().numpy().ravel()
  y_np = y_train.detach().cpu().numpy().ravel()
  # Find min and max value of training set
  min_value, max_value = x_np.min(), x_np.max()
  # Create points between min and max values
  x_test = np.linspace(min_value, max_value, n_test_points)

  all_points = jnp.concatenate([x_np, x_test])
  # Remember which entries are test points so they can be extracted after sorting
  is_test = jnp.concatenate([jnp.zeros(x_np.size, dtype=bool), jnp.ones(n_test_points, dtype=bool)])
  temporal_order = jnp.argsort(all_points)
  # State Space X's and Y's (test points are unobserved -> NaN)
  ss_xs = all_points[temporal_order]
  ss_ys = jnp.concatenate([y_np, jnp.full((n_test_points,), jnp.nan)])[temporal_order]
  is_test = is_test[temporal_order]

  # Compute the equivalent SS model; the first step has dt = 0, so the prior (m0, v0) is kept
  dts = jnp.diff(ss_xs, prepend=float(min_value))
  Fs = jnp.exp(-1 / ell * dts)
  Ws = sigma ** 2 * (1 - jnp.exp(-2 / ell * dts))
  mfs, vfs, mps, vps = kalmanFilter(ss_ys, Fs, Ws, m0=m0, v0=v0, observation_cov=observation_cov)
  # Smoothed means and variances
  mss, vss = kalmanSmoothing(Fs, mfs, vfs, mps, vps)

  # Posterior distribution at the test points (in increasing x order, like x_test)
  ss_posterior_mean = mss[is_test]
  ss_posterior_std = jnp.sqrt(vss[is_test])

  plt.style.use('default')
  _, ax = plt.subplots(1, 1)

  ax.scatter(x_np, y_np, color='k', marker='o', label='Observed data')
  ax.plot(x_test, ss_posterior_mean, label="Mean", color='purple', alpha=1)
  ax.fill_between(x_test,
                  ss_posterior_mean - 1.96 * ss_posterior_std,
                  ss_posterior_mean + 1.96 * ss_posterior_std,
                  alpha=0.5,
                  label="95% Confidence", color="violet", zorder=-1)
  ax.set_title(title)
  ax.set_xlim([min_value, max_value])
  ax.legend()
  plt.grid(False)

# Kalman Filter and Smoother

In [None]:
def kalmanFilter(ss_ys, Fs, Ws, m0=0, v0=1, observation_cov=1):
    """
    Run a scalar Kalman filter over a sequence of observations, some of which may be missing.

    The model is x_k = F_k * x_{k-1} + w_k with w_k ~ N(0, W_k), observed as
    y_k = x_k + e_k with e_k ~ N(0, observation_cov). A NaN entry in `ss_ys` marks a
    missing observation; for that step the predicted mean and variance are used as
    the filtered values. The loop is a `jax.lax.scan`; the missing/observed choice is a
    `jax.lax.cond`.

    Args:
        ss_ys (array): Observations in the state space model (NaN = missing).
        Fs (array): Transition coefficient for each step.
        Ws (array): Process noise variance for each step.
        m0 (float, optional): Initial mean for the Kalman filter. Defaults to 0.
        v0 (float, optional): Initial variance for the Kalman filter. Defaults to 1.
        observation_cov (float, optional): Observation noise variance. Defaults to 1.

    Returns:
        mfs (array): Filtered means
        vfs (array): Filtered variances
        mps (array): Predicted means
        vps (array): Predicted variances
    """
    def predict_step(mean, var, F, W):
        # Propagate the previous filtered estimate through the transition
        return F * mean, F * var * F + W

    def correct_step(y, mean_pred, var_pred):
        # Condition the prediction on observation y
        innovation_var = var_pred + observation_cov
        gain = var_pred / innovation_var
        mean_filt = mean_pred + gain * (y - mean_pred)
        var_filt = var_pred - gain * gain * innovation_var
        return mean_filt, var_filt

    def step(state, inputs):
        y, F, W = inputs
        mean_pred, var_pred = predict_step(state[0], state[1], F, W)
        mean_filt, var_filt = jax.lax.cond(
            jnp.isnan(y),
            lambda _: (mean_pred, var_pred),
            lambda _: correct_step(y, mean_pred, var_pred),
            None,
        )
        return (mean_filt, var_filt), (mean_filt, var_filt, mean_pred, var_pred)

    init_state = (m0, v0)
    _, outputs = jax.lax.scan(step, init_state, (ss_ys, Fs, Ws))
    mfs, vfs, mps, vps = outputs
    return mfs, vfs, mps, vps


def kalmanSmoothing(Fs, mfs, vfs, mps, vps):
    """
    Run the backward (smoothing) pass over the output of `kalmanFilter`.

    Starting from the last filtered estimate, which is already smoothed, each earlier
    step k is combined with the smoothed estimate of step k + 1 using the gain
    G_k = vf_k * F_{k+1} / vp_{k+1}:
        ms_k = mf_k + G_k * (ms_{k+1} - mp_{k+1})
        vs_k = vf_k + G_k * (vs_{k+1} - vp_{k+1}) * G_k
    The recursion runs as a reversed `jax.lax.scan`.

    Args:
        Fs (array): Transition coefficient for each step
        mfs (array): Filtered means
        vfs (array): Filtered variances
        mps (array): Predicted means
        vps (array): Predicted variances

    Returns:
        mss (array): Smoothed means
        vss (array): Smoothed variances
    """
    def backward_step(next_smoothed, inputs):
        ms_next, vs_next = next_smoothed
        mf, vf, mp_next, vp_next, F_next = inputs
        gain = vf * F_next / vp_next
        ms = mf + gain * (ms_next - mp_next)
        vs = vf + gain * (vs_next - vp_next) * gain
        return (ms, vs), (ms, vs)

    last_estimate = (mfs[-1], vfs[-1])
    # Pair step k's filtered values with step k+1's predictions and transition
    per_step = (mfs[:-1], vfs[:-1], mps[1:], vps[1:], Fs[1:])
    _, (ms_head, vs_head) = jax.lax.scan(backward_step, last_estimate, per_step, reverse=True)
    mss = jnp.concatenate([ms_head, mfs[-1:]], axis=0)
    vss = jnp.concatenate([vs_head, vfs[-1:]], axis=0)
    return mss, vss