In [None]:
import matplotlib.pyplot as plt
import torch
from torch import nn
import numpy as np

class TanhWarp(nn.Module):
    """
    Monotone warping of regression targets:

        g(y) = y + sum_i a_i * tanh(b_i * (y + c_i))

    With a_i, b_i > 0 the map is strictly increasing, because
    g'(y) = 1 + sum_i a_i*b_i*(1 - tanh(.)^2) >= 1.

    Args:
        M: number of tanh terms.
        b_transform: callable mapping the unconstrained ``raw_b`` to ``b``.
            Default: sigmoid, i.e. b in (0, 1).
        a_transform: callable mapping the unconstrained ``raw_a`` to ``a``.
            Default: 5 * sigmoid, i.e. a in (0, 5).
    """
    def __init__(self, M=3, b_transform=None, a_transform=None):
        super().__init__()
        self.raw_a = nn.Parameter(torch.zeros(M))  # -> a = a_transform(raw_a)
        self.raw_b = nn.Parameter(torch.zeros(M))  # -> b = b_transform(raw_b)
        # Default transforms: scaled sigmoids that keep
        # b in the open interval (q, p) and a in (q2, p2).
        p, q = 1.0, 0.0
        self.b_transform = b_transform if b_transform is not None else lambda x: (torch.sigmoid(x) * (p - q)) + q
        p2, q2 = 5.0, 0.0
        self.a_transform = a_transform if a_transform is not None else lambda x: (torch.sigmoid(x) * (p2 - q2)) + q2

        self.c     = nn.Parameter(torch.zeros(M))
        # Not used by any method of this class; kept so existing attribute access keeps working.
        self.softplus = nn.Softplus()

    def _terms(self, y):
        """Return ``(a, b, tanh(b * (y + c)))``; the tanh tensor has shape (..., M)."""
        a = self.a_transform(self.raw_a)  # (M,)
        b = self.b_transform(self.raw_b)  # (M,)
        z = y.unsqueeze(-1) + self.c      # (..., M)
        return a, b, torch.tanh(b * z)

    def forward(self, y):
        """Apply the warp element-wise: returns g(y) with the same shape as ``y``."""
        a, _, t = self._terms(y)
        return y + torch.sum(a * t, dim=-1)

    def log_abs_det_jacobian(self, y):
        """Return log g'(y) element-wise (same shape as ``y``)."""
        a, b, t = self._terms(y)
        sech2 = 1 - t**2
        deriv = 1.0 + torch.sum(a * b * sech2, dim=-1)
        deriv = torch.clamp(deriv, min=1e-8)  # numerical safety
        return torch.log(deriv)

    @torch.no_grad()
    def inverse(self, u, y_init=None, iters=40):
        """
        Solve g(y) = u for y with a bracketed (safeguarded) Newton iteration.

        Plain Newton can get trapped in a 2-cycle when the warp is strong
        (a*b large), because g is monotone but not convex. Here every
        evaluation also tightens a bracket around the root using the global
        slope bounds  min_slope <= g'(y) <= max_slope:  with r = g(y) - u the
        root lies between  y - r/min_slope  and  y - r/max_slope.  A Newton
        step that leaves the bracket is replaced by the bracket midpoint.

        Args:
            u: warped values to invert.
            y_init: optional starting guess (broadcastable with ``u``);
                defaults to ``u``.
            iters: number of iterations.

        Returns:
            Tensor ``y`` with ``g(y) ~= u``.
        """
        a = self.a_transform(self.raw_a)
        b = self.b_transform(self.raw_b)
        ab = a * b
        # g'(y) = 1 + sum_i ab_i * s_i with s_i in [0, 1]  =>  global slope bounds.
        # With the default transforms ab > 0, so min_slope is exactly 1.
        min_slope = torch.clamp(1.0 + torch.clamp(ab, max=0.0).sum(), min=1e-8)
        max_slope = 1.0 + torch.clamp(ab, min=0.0).sum()

        start = u if y_init is None else y_init
        y = torch.broadcast_tensors(start, u)[0].clone()
        if not torch.is_floating_point(y):
            y = y.to(self.c.dtype)

        lo = torch.full_like(y, float("-inf"))
        hi = torch.full_like(y, float("inf"))
        for _ in range(iters):
            resid = self.forward(y) - u
            slope = torch.exp(self.log_abs_det_jacobian(y))

            # Tighten the bracket from the slope bounds at the current iterate.
            far = y - resid / min_slope
            near = y - resid / max_slope
            lo = torch.maximum(lo, torch.minimum(far, near))
            hi = torch.minimum(hi, torch.maximum(far, near))

            # Newton step where it stays inside the bracket, bisection otherwise.
            y_newton = y - resid / slope
            inside = (y_newton >= lo) & (y_newton <= hi)
            y = torch.where(inside, y_newton, 0.5 * (lo + hi))
        return y


In [None]:
import gpytorch
class Model(gpytorch.models.ExactGP):
    """Exact GP with a zero mean and a scaled RBF kernel.

    An optional ``warp`` keyword argument is stored as ``self.warp``; when it
    is not supplied the attribute is not created. Any other extra keyword
    arguments are ignored.
    """

    def __init__(self, X, u, likelihood, **kwargs):
        super().__init__(X, u, likelihood)
        rbf_kernel = gpytorch.kernels.RBFKernel()
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf_kernel)
        # Attach the warp only if the caller actually passed one.
        if "warp" in kwargs:
            self.warp = kwargs["warp"]

    def forward(self, x):
        """Return the multivariate normal built from the mean and covariance at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )

In [None]:
# Synthetic benchmark in 1-D, data on x = 0..1 (formulas in Maple notation):
#   f := x -> 2 + exp(exp(3 - 100*x)) - 1/5*exp(-1/2*(x - 0.1)^2/(1/5)^2) + 0.05*x;
#
# Noise: heteroscedastic!
# The noise is Gaussian in log-space with a standard deviation of 5%:
#   [x -> exp(0.95*log(f(x))), f, x -> exp(1.05*log(f(x)))]
#
# A second input "a" (a = 0..1) is added; it has less influence:
#   g := x -> f(x - 0.1*sin(5*a/2)) * (1 + 0.1*arctan(5*a - 2.5)) + 0.1*cos(10*a/3);


def data_f(x):
    """Noise-free 1-D target f(x) from the formula above."""
    return 2 + torch.exp(torch.exp(3 - 100 * x)) - 1/5 * torch.exp(-1/2 * (x - 0.1)**2 / (1/5)**2) + 0.05 * x


def data_noise(x):
    """Lower curve of the log-space noise band: exp(0.95 * log f(x))."""
    return torch.exp(0.95 * torch.log(data_f(x)))


def data_noise_high(x):
    """Upper curve of the log-space noise band: exp(1.05 * log f(x))."""
    return torch.exp(1.05 * torch.log(data_f(x)))


def data_g(x, a):
    """Two-input variant g(x, a) of the target, see the formula above."""
    return data_f(x - 0.1 * torch.sin(5 * a / 2)) * (1 + 0.1 * torch.arctan(5 * a - 2.5)) + 0.1 * torch.cos(10 * a / 3)


X = torch.linspace(0.1, 1, 1000)


# TODO: log(y) WITHOUT warping


# Noise-free slice of g at a = 0.5, then standardized to zero mean / unit std.
y_raw = data_g(X, torch.tensor(0.5))
y = y_raw - y_raw.mean()  # center the data
y = y / y.std()           # normalize the data

warp = TanhWarp(M=1)
# Detached: this tensor only serves as the initial GP targets and for plotting.
# The training loop recomputes warp(y) itself, so no autograd graph is needed here.
u = warp(y).detach()  # transformed targets

plt.plot(X.numpy(), y.numpy(), label='Original Data')
plt.plot(X.numpy(), u.detach().numpy(), label='Warped Data')
plt.xlabel('X')
plt.legend()  # without this the labels above are never displayed
plt.show()


In [None]:

# Gaussian observation noise for the GP that models the warped targets u.
likelihood = gpytorch.likelihoods.GaussianLikelihood()

# The warp is handed to the model as a keyword argument, so it is stored as model.warp.
model = Model(X, u, likelihood, warp=warp)

# A single Adam optimizer over everything model.parameters() yields.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.1)

# Switch both modules to training mode and start from cleared gradients.
likelihood.train()
model.train()
optimizer.zero_grad()

In [None]:
# Joint training of GP hyperparameters and warp parameters.
#
# Warped-GP objective (per data point):
#   (1/N) * [ log N(g(y); 0, K + noise) + sum_n log g'(y_n) ]
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

# Make the cell re-runnable even after a later cell switched to eval mode.
model.train(); likelihood.train()

for step in range(200):
    optimizer.zero_grad()
    u = model.warp(y)  # y -> u
    output = model(X)  # MultivariateNormal over f(x) (latent function in u-space)
    base_term = mll(output, u)

    # ExactMarginalLogLikelihood returns the log marginal likelihood divided by
    # the number of data points, so the Jacobian term must be averaged as well
    # (a plain sum would overweight it by a factor of N).
    jac_term = model.warp.log_abs_det_jacobian(y).mean()  # + (1/N) Σ log g'(y_n)
    loss = -(base_term + jac_term)
    if step % 50 == 0:
        print(loss.item())

    loss.backward()
    optimizer.step()

# The ExactGP still stores the targets passed at construction time; replace them
# with the targets under the trained warp so that eval-mode predictions
# condition on the right data.
model.set_train_data(inputs=X, targets=model.warp(y).detach(), strict=False)

In [None]:
# Show the warp's parameters by name (raw_a, raw_b and c, as registered in TanhWarp.__init__).
list(model.warp.named_parameters())

In [None]:
# Show every parameter reachable from the GP model, by name.
list(model.named_parameters())
# Alternative: inspect only the likelihood's parameters.
#list(likelihood.named_parameters())

In [None]:

Xtest = torch.linspace(0.1, 1, 1000)

# In eval mode the ExactGP conditions on its stored training targets. Those were
# set when the model was constructed, so refresh them with the current warp
# before predicting.
model.set_train_data(inputs=X, targets=model.warp(y).detach(), strict=False)
model.eval(); likelihood.eval()
with torch.no_grad():
    mvn = likelihood(model(Xtest))            # dist over u*
    mean_u = mvn.mean
    var_u  = mvn.variance
    std_u  = var_u.sqrt()

    # Point estimates in y-space: the warp is monotone, so quantiles in u-space
    # map to quantiles in y-space.
    median_y = model.warp.inverse(mean_u)           # ≈ predictive median
    lower_y  = model.warp.inverse(mean_u - 2 * std_u)
    upper_y  = model.warp.inverse(mean_u + 2 * std_u)

    # Monte Carlo moments (seeded local generator so the figure is reproducible)
    S = 512
    mc_rng = torch.Generator(device=mean_u.device).manual_seed(0)
    eps = torch.randn(S, *mean_u.shape, generator=mc_rng, device=mean_u.device)
    u_samps = mean_u.unsqueeze(0) + eps * std_u.unsqueeze(0)
    y_samps = model.warp.inverse(u_samps)

    mean_y = y_samps.mean(dim=0)

    plt.figure(figsize=(10, 6))
    plt.plot(Xtest.cpu().numpy(), median_y.cpu().numpy(), label='Predictive Median', color='blue')
    plt.plot(Xtest.cpu().numpy(), mean_y.cpu().numpy(), label='Predictive Mean (MC)', color='green', linestyle='--')
    plt.fill_between(Xtest.cpu().numpy(),
                     lower_y.cpu().numpy(),
                     upper_y.cpu().numpy(),
                     color='blue', alpha=0.2, label='95% Credible Interval')
    plt.scatter(X.cpu().numpy(), y.cpu().numpy(), s=10, color='red', label='Data')
    plt.title('GP with Tanh Warp')
    plt.xlabel('X')
    plt.ylabel('y')
    plt.legend()
    plt.show()


In [None]:

# Sweep the warp amplitude parameter raw_a and plot the resulting prediction in y-space.
Xtest = torch.linspace(0.1, 1, 1000)
S = 512  # Monte Carlo samples per figure

# Remember the trained value: the sweep overwrites raw_a in place and must not
# leak into the cells that follow.
raw_a_trained = model.warp.raw_a.detach().clone()

model.eval(); likelihood.eval()
try:
    for a_val in range(0, 50, 5):
        with torch.no_grad():
            model.warp.raw_a.fill_(float(a_val))
        print(model.warp.raw_a.data)

        # Condition the GP on the targets produced by the warp being inspected.
        # NOTE(review): assumes the sweep is meant to show a self-consistent
        # warped GP for each raw_a — confirm.
        model.set_train_data(inputs=X, targets=model.warp(y).detach(), strict=False)

        with torch.no_grad():
            mvn = likelihood(model(Xtest))            # dist over u*
            mean_u = mvn.mean
            var_u  = mvn.variance
            std_u  = var_u.sqrt()

            # Point estimates in y-space
            median_y = model.warp.inverse(mean_u)           # ≈ predictive median
            lower_y  = model.warp.inverse(mean_u - 2 * std_u)
            upper_y  = model.warp.inverse(mean_u + 2 * std_u)

            # Monte Carlo moments (seeded local generator for reproducibility)
            mc_rng = torch.Generator(device=mean_u.device).manual_seed(0)
            eps = torch.randn(S, *mean_u.shape, generator=mc_rng, device=mean_u.device)
            u_samps = mean_u.unsqueeze(0) + eps * std_u.unsqueeze(0)
            y_samps = model.warp.inverse(u_samps)

            mean_y = y_samps.mean(dim=0)

            fig = plt.figure(figsize=(10, 6))
            plt.plot(Xtest.cpu().numpy(), median_y.cpu().numpy(), label='Predictive Median', color='blue')
            plt.plot(Xtest.cpu().numpy(), mean_y.cpu().numpy(), label='Predictive Mean (MC)', color='green', linestyle='--')
            plt.fill_between(Xtest.cpu().numpy(),
                             lower_y.cpu().numpy(),
                             upper_y.cpu().numpy(),
                             color='blue', alpha=0.2, label='95% Credible Interval')
            plt.scatter(X.cpu().numpy(), y.cpu().numpy(), s=10, color='red', label='Data')
            plt.title('GP with Tanh Warp')
            plt.xlabel('X')
            plt.ylabel('y')
            plt.legend()
            plt.show()
            plt.close(fig)  # ten figures in a loop: release each one after display
finally:
    # Restore the trained warp and re-condition the GP on it.
    with torch.no_grad():
        model.warp.raw_a.copy_(raw_a_trained)
    model.set_train_data(inputs=X, targets=model.warp(y).detach(), strict=False)


In [None]:
# Prediction in the warped (u) space, i.e. what the GP itself models.
# Refresh the stored training targets so the GP is conditioned on warp(y) under
# the warp's current parameters, and make the eval mode explicit instead of
# relying on an earlier cell.
model.set_train_data(inputs=X, targets=model.warp(y).detach(), strict=False)
model.eval(); likelihood.eval()
with torch.no_grad():
    mvn = likelihood(model(Xtest))            # dist over u*
    mean_u = mvn.mean
    var_u  = mvn.variance
    std_u  = torch.sqrt(var_u)

    plt.figure(figsize=(10, 6))
    plt.plot(Xtest.cpu().numpy(), mean_u.cpu().numpy(), label='Predictive Median', color='blue')
    plt.fill_between(Xtest.cpu().numpy(),
                     (mean_u - 2 * std_u).cpu().numpy(),
                     (mean_u + 2 * std_u).cpu().numpy(),
                     color='blue', alpha=0.2, label='95% Credible Interval')
    plt.title('GP prediction in warped space')
    plt.xlabel('X')
    plt.ylabel('u')
    plt.legend()
    plt.show()

In [None]:
# Overlay the warped targets g(y) and the standardized targets y over X.
with torch.no_grad():
    plt.plot(X, model.warp(y).numpy())
plt.plot(X, y)
# Cell output: the warp's parameter tensors.
list(model.warp.parameters())

In [None]:
# Standalone illustration of one strong tanh term: x + 80*tanh(2.1*x) on [-20, 20].
x = torch.linspace(start=-20, end=20, steps=1000)
b = 2.1
plt.plot(x, 80 * torch.tanh(x * b) + x)