# Technical Details

As with all packages, numerous technical details are abstracted away from the user, and this abstraction is necessary to keep the interface clean. However, it can be confusing to navigate a package's source code and pin down what is going on when so many _under the hood_ operations are taking place. In this notebook I'll aim to shed some light on the tricks that we use in GPJax, to help make the code clearer to anyone wishing to extend GPJax for their own uses.

## Parameter Transformations

### Motivations

Many parameters in a Gaussian process are what we call _constrained parameters_. By this, we mean that the parameter's value is only defined on a subset of $\mathbb{R}$. One example is the lengthscale parameter in any of the stationary kernels. A negative lengthscale would not make sense, so the parameter's value is constrained to the positive real line.

Whilst mathematically correct, constrained parameters can become a pain when optimising, as many optimisers are designed to operate on an unconstrained space. Further, it can often be computationally inefficient to restrict the search space of an optimiser. For these reasons, we instead transform the constrained parameter so that it lives in an unconstrained space. Optimisation is then done on this unconstrained parameter, and we transform it back whenever we need to evaluate its constrained value.
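
The idea can be sketched in a few lines of plain JAX, without GPJax. This is a minimal illustration under stated assumptions: the quadratic `loss`, the starting value, the step size, and the names `raw` and `unconstrained_loss` are all invented for this example and are not part of GPJax.

```python
import jax
import jax.numpy as jnp
from jax.nn import softplus


# Toy objective over a positive lengthscale (an assumption for illustration);
# its minimiser is at lengthscale = 2.0.
def loss(lengthscale):
    return (lengthscale - 2.0) ** 2


# Compose the loss with the map from unconstrained to constrained space,
# so the optimiser only ever sees an unconstrained value.
def unconstrained_loss(raw):
    return loss(softplus(raw))


grad_fn = jax.grad(unconstrained_loss)

raw = jnp.array(0.0)  # any real number is a valid starting point
for _ in range(500):
    raw = raw - 0.1 * grad_fn(raw)  # plain gradient descent, no projection step

# Map back to the constrained space only when the value is needed.
lengthscale = softplus(raw)
print(lengthscale)
```

Because `softplus` only returns positive values, the recovered lengthscale stays positive however far the optimiser moves `raw`.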

Only bijective transformations are valid, as we cannot afford to lose the original parameter value when transforming. As such, we have to be careful about which transformations we use. Some common choices include the log-exponential bijection and the softplus transform. By default, we opt for the softplus transformation in GPJax as it is less prone to overflowing than the log-exp transformation.
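
To illustrate the overflow point, the sketch below maps the same unconstrained values back to the positive reals with both bijections. It assumes single-precision arrays, and the variable names are chosen for this example only.

```python
import jax.numpy as jnp
from jax.nn import softplus

# Unconstrained values, including one large enough to overflow exp in float32.
raw = jnp.array([0.0, 10.0, 100.0], dtype=jnp.float32)

via_exp = jnp.exp(raw)        # log-exp bijection: constrained = exp(raw)
via_softplus = softplus(raw)  # softplus: constrained = log(1 + exp(raw))

print(via_exp)
print(via_softplus)
```

In single precision `jnp.exp(100.0)` overflows to `inf`, whereas softplus grows roughly linearly for large inputs and so stays finite.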

### Implementation

In terms of implementation, we attach the transformation directly to the `Parameter` class. It is an optional argument that one can specify when instantiating a parameter. To see this, consider the following example:

In [16]:
from gpjax.parameters import Parameter
from gpjax.transforms import Softplus
import jax.numpy as jnp

x = Parameter(jnp.array(1.0), transform=Softplus())

The softplus transformation maps an input $x \in \mathbb{R}_{>0}$ to the unconstrained space through the inverse of the softplus function,
$$\alpha(x) = \log(\exp(x)-1),$$
where $\alpha(x) \in \mathbb{R}$. In this instance, $\alpha(1) = \log(e - 1) \approx 0.54$. This unconstrained value is stored within the parameter's `value` property

In [17]:
print(x.value)

0.541324854612918


whilst the original constrained value can be computed by accessing the parameter's `untransform` property

In [20]:
print(x.untransform)

1.0
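
The same round trip can be checked by hand in plain JAX, without the `Parameter` class. This is only a sketch of the arithmetic; the variable names are chosen for illustration.

```python
import jax.numpy as jnp
from jax.nn import softplus

constrained = jnp.array(1.0)

# Constrained -> unconstrained: alpha(x) = log(exp(x) - 1).
unconstrained = jnp.log(jnp.exp(constrained) - 1.0)

# Unconstrained -> constrained: softplus(y) = log(1 + exp(y)).
recovered = softplus(unconstrained)

print(unconstrained)  # approximately 0.5413
print(recovered)      # approximately 1.0
```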


### Custom transformation

Should you wish to define your own custom transformation, you can do so by extending the `Transform` class within `gpjax.transforms` and defining a forward transformation and a backward transformation.

In [23]:
class Transform:
    def __init__(self, name="Transformation"):
        self.name = name

    @staticmethod
    def forward(x: jnp.ndarray) -> jnp.ndarray:
        raise NotImplementedError

    @staticmethod
    def backward(x: jnp.ndarray) -> jnp.ndarray:
        raise NotImplementedError

The `forward` method is the transformation that maps from a constrained space to an unconstrained space, whilst the `backward` method is the transformation that reverses this. A nice example of this is the softplus transformation used earlier:

In [24]:
from jax.nn import softplus

class Softplus(Transform):
    def __init__(self):
        super().__init__(name='Softplus')

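    @staticmethod
    def forward(x: jnp.ndarray) -> jnp.ndarray:
        # Constrained (positive) -> unconstrained: the inverse of softplus.
        return jnp.log(jnp.exp(x) - 1.)

    @staticmethod
    def backward(x: jnp.ndarray) -> jnp.ndarray:
        # Unconstrained -> constrained: softplus(x) = log(1 + exp(x)).
        return softplus(x)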
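
Following the same pattern, a custom transformation might look like the sketch below, which implements the log-exponential bijection mentioned earlier. The `LogExp` class is hypothetical and written for this example; it is not taken from GPJax. The `Transform` base class is repeated from above only so that the snippet runs on its own.

```python
import jax.numpy as jnp


class Transform:
    # Minimal copy of the base class shown above, so the sketch is self-contained.
    def __init__(self, name="Transformation"):
        self.name = name

    @staticmethod
    def forward(x: jnp.ndarray) -> jnp.ndarray:
        raise NotImplementedError

    @staticmethod
    def backward(x: jnp.ndarray) -> jnp.ndarray:
        raise NotImplementedError


class LogExp(Transform):
    # Hypothetical transform for illustration only.
    def __init__(self):
        super().__init__(name="LogExp")

    @staticmethod
    def forward(x: jnp.ndarray) -> jnp.ndarray:
        # Constrained (positive) -> unconstrained.
        return jnp.log(x)

    @staticmethod
    def backward(x: jnp.ndarray) -> jnp.ndarray:
        # Unconstrained -> constrained (positive).
        return jnp.exp(x)


transform = LogExp()
positive = jnp.array([0.1, 1.0, 5.0])

# A valid transform must be a bijection: backward should undo forward.
roundtrip = transform.backward(transform.forward(positive))
print(roundtrip)
```

Checking that `backward(forward(x))` recovers `x` on a few positive values is a cheap sanity test for any new transformation.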