In [None]:
from helpers.plotting_functions import plot_3d_gp, plot_3d_data
import matplotlib.pyplot as plt
import itertools
import torch

# Make float64 the default floating-point dtype for tensors created from here on.
torch.set_default_dtype(torch.float64)
# NOTE(review): seeding is disabled, so the random draws made further down
# (torch.randn / torch.randn_like) differ from run to run; uncomment to reproduce.
#torch.manual_seed(42)

In [None]:
import matplotlib.pyplot as plt
import torch
from torch import nn
import numpy as np

class TanhWarp(nn.Module):
    """
    Monotone output warp

        g(y) = y + sum_i a_i * tanh(b_i * (y + c_i))

    with a_i, b_i > 0 to guarantee monotonicity:
        g'(y) = 1 + sum_i a_i*b_i*(1 - tanh(.)^2) > 0

    Args:
        M: number of tanh terms.
        b_transform: optional callable mapping the unconstrained ``raw_b`` to
            ``b``. Defaults to softplus. To confine ``b`` to an interval
            ``(q, p)`` pass e.g. ``lambda x: torch.sigmoid(x) * (p - q) + q``.
        a_transform: same as ``b_transform`` but for ``raw_a`` -> ``a``.
    """
    def __init__(self, M=3, b_transform=None, a_transform=None):
        super().__init__()
        self.raw_a = nn.Parameter(torch.zeros(M))  # -> a = a_transform(raw_a)
        self.raw_b = nn.Parameter(torch.zeros(M))  # -> b = b_transform(raw_b)
        # Honour user-supplied transforms; fall back to softplus otherwise.
        self.a_transform = a_transform if a_transform is not None else nn.Softplus()
        self.b_transform = b_transform if b_transform is not None else nn.Softplus()

        self.c = nn.Parameter(torch.zeros(M))

    def forward(self, y):
        """Apply the warp elementwise; the result has the same shape as ``y``."""
        a = self.a_transform(self.raw_a)  # (M,)
        b = self.b_transform(self.raw_b)  # (M,)
        z = y.unsqueeze(-1) + self.c      # (..., M)
        return y + torch.sum(a * torch.tanh(b * z), dim=-1)

    def log_abs_det_jacobian(self, y):
        """Return ``log g'(y)`` elementwise (same shape as ``y``)."""
        a = self.a_transform(self.raw_a)
        b = self.b_transform(self.raw_b)
        z = y.unsqueeze(-1) + self.c
        t = torch.tanh(b * z)
        sech2 = 1 - t**2
        deriv = 1.0 + torch.sum(a * b * sech2, dim=-1)
        deriv = torch.clamp(deriv, min=1e-8)  # numerical safety
        return torch.log(deriv)

    @torch.no_grad()
    def inverse(self, u, iters=64):
        """Solve ``g(y) = u`` for ``y`` elementwise by bisection.

        ``g`` has no closed-form inverse, so the root is bracketed and bisected.
        Because ``|tanh| < 1`` we have ``|g(y) - y| <= sum_i |a_i|``, hence the
        solution always lies in ``[u - sum|a|, u + sum|a|]``; monotonicity of
        ``g`` makes that root unique.

        Args:
            u: tensor of warped values, any shape.
            iters: number of bisection steps; each one halves the bracket.

        Returns:
            Tensor shaped like ``u`` with ``forward(result) ~= u``.
        """
        radius = self.a_transform(self.raw_a).abs().sum()
        lo = u - radius
        hi = u + radius
        for _ in range(iters):
            mid = 0.5 * (lo + hi)
            # g is increasing: g(mid) > u means the root is to the left of mid.
            overshoot = self.forward(mid) > u
            hi = torch.where(overshoot, mid, hi)
            lo = torch.where(overshoot, lo, mid)
        return 0.5 * (lo + hi)


#    @torch.no_grad()
#    def inverse(self, u, y_init=None, iters=40):
#        # Newton solve for y: g(y)=u ; monotone g ensures unique root
#        y = u.clone() if y_init is None else y_init.clone()
#        for _ in range(iters):
#            gy  = self.forward(y)
#            dgy = torch.exp(self.log_abs_det_jacobian(y))
#            y   = y - (gy - u)/dgy
#        return y
#

In [None]:
import gpytorch
class Model(gpytorch.models.ExactGP):
    """Exact GP with a zero mean and a scaled RBF kernel, optionally on warped targets.

    Args:
        X: training inputs.
        u: training targets handed to ``ExactGP`` (the warped targets when a
            warp is used).
        likelihood: gpytorch likelihood.
        **kwargs: optional ``warp`` (callable mapping raw targets to latent
            targets) together with ``y`` (the raw targets). When ``warp`` is
            given, the model stores ``X``/``y``/``warp`` and keeps ``self.u``
            and the GP training targets equal to ``warp(y)`` via
            :meth:`apply_warp`.

    Raises:
        ValueError: if ``warp`` is passed without ``y``.
    """
    def __init__(self, X, u, likelihood, **kwargs):
        super().__init__(X, u, likelihood)
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        if "warp" in kwargs:
            if "y" not in kwargs:
                raise ValueError("Model: passing `warp` also requires the raw targets `y`.")
            self.warp = kwargs["warp"]
            self.y = kwargs["y"]
            self.X = X
            self.apply_warp()

    def forward(self, x):
        """Return the GP prior at ``x`` as a multivariate normal."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    def apply_warp(self):
        """Recompute ``u = warp(y)`` and install it as the GP training targets.

        The previous implementation tried to skip this when the parameters were
        unchanged by comparing two ``self.parameters()`` generators; that
        comparison is always False (two distinct generator objects), and the
        stored generator cannot be pickled or deep-copied. The warp is
        therefore simply re-evaluated on every call, which is what effectively
        happened before. Does nothing when the model was built without a warp.
        """
        if getattr(self, "warp", None) is None:
            return
        self.u = self.warp(self.y)
        self.set_train_data(self.X, self.u)



class MIModel(gpytorch.models.ExactGP):
    """Exact GP for two inputs: zero mean, product of one scaled RBF kernel per input.

    Args:
        X: training inputs; column 0 and column 1 are each handled by their own
            RBF kernel.
        u: training targets handed to ``ExactGP`` (the warped targets when a
            warp is used).
        likelihood: gpytorch likelihood.
        **kwargs: optional ``warp`` (callable mapping raw targets to latent
            targets) together with ``y`` (the raw targets). When ``warp`` is
            given, the model stores ``X``/``y``/``warp`` and keeps ``self.u``
            and the GP training targets equal to ``warp(y)`` via
            :meth:`apply_warp`.

    Raises:
        ValueError: if ``warp`` is passed without ``y``.
    """
    def __init__(self, X, u, likelihood, **kwargs):
        super().__init__(X, u, likelihood)
        self.mean_module = gpytorch.means.ZeroMean()
        # active_dims is given in sequence form rather than as a bare int.
        # NOTE(review): the bare-int form may also work in some gpytorch
        # versions — the list form is the documented one.
        k0 = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(active_dims=[0]))
        k1 = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(active_dims=[1]))
        self.covar_module = k0 * k1
        if "warp" in kwargs:
            if "y" not in kwargs:
                raise ValueError("MIModel: passing `warp` also requires the raw targets `y`.")
            self.warp = kwargs["warp"]
            self.y = kwargs["y"]
            self.X = X
            self.apply_warp()

    def forward(self, x):
        """Return the GP prior at ``x`` as a multivariate normal."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    def apply_warp(self):
        """Recompute ``u = warp(y)`` and install it as the GP training targets.

        The previous implementation tried to skip this when the parameters were
        unchanged by comparing two ``self.parameters()`` generators; that
        comparison is always False (two distinct generator objects), and the
        stored generator cannot be pickled or deep-copied. The warp is
        therefore simply re-evaluated on every call, which is what effectively
        happened before. Does nothing when the model was built without a warp.
        """
        if getattr(self, "warp", None) is None:
            return
        self.u = self.warp(self.y)
        self.set_train_data(self.X, self.u)

In [None]:
# Synthetic benchmark specification.
#
# In 1-D, data on x = 0..1:
#   f := x -> 2 + exp(exp(3 - 100*x)) - 1/5*exp(-1/2*(x - 0.1)^2/(1/5)^2) + 0.05*x
#
# Noise (heteroscedastic!): Gaussian in log-space with a standard deviation of 5%
#   [x -> exp(0.95*log(f(x))), f, x -> exp(1.05*log(f(x)))]
#
# A second input "a" (a = 0..1) is added; it has less influence:
#   g := x -> f(x - 0.1*sin(5*a/2)) * (1 + 0.1*arctan(5*a - 2.5)) + 0.1*cos(10*a/3)


# Requires limits of a in [0, 4], b in [0, 1]
# (presumably the bounds of the warp parameters a, b — TODO confirm)
def data_f(x):
    """Noise-free 1-D target f(x) from the specification above."""
    spike = torch.exp(torch.exp(3 - 100 * x))
    dip = 1/5 * torch.exp(-1/2 * (x - 0.1)**2 / (1/5)**2)
    return 2 + spike - dip + 0.05 * x


# Requires limits of a in [0, 2], b in [0, 1]
def data_noise(x):
    """f(x) with heteroscedastic noise: log f(x) is scaled by (1 + 0.05 * N(0, 1))."""
    log_f = torch.log(data_f(x))
    jitter = 1 + 0.05 * torch.randn_like(x)
    return torch.exp(log_f * jitter)


# Requires limits of a in [0, 1], b in [0, 1]
# Original
#data_g = lambda x, a: data_f(x - 0.1 * torch.sin(5 * a / 2)) * (1 + 0.1 * torch.arctan(5 * a - 2.5)) + 0.1 * torch.cos(10 * a / 3)
def data_g(x, a):
    """2-D target: f with an a-dependent input shift (0.01 instead of the original 0.1), gain and offset."""
    shifted = data_f(x - 0.01 * torch.sin(5 * a / 2))
    gain = 1 + 0.1 * torch.arctan(5 * a - 2.5)
    return shifted * gain + 0.1 * torch.cos(10 * a / 3)


def data_g2(x):
    """Evaluate ``data_g`` on an (N, 2) tensor whose columns are (x, a)."""
    return data_g(x[:, 0], x[:, 1])


#lin = torch.stack([torch.linspace(0.000, 0.19, 5), torch.linspace(0.19, 1, 5)], dim=-1).flatten()
lin = torch.linspace(0.000, 0.1, 1)
lin2 = torch.linspace(0.0, 1.0, 10)

# TODO log(y) WITHOUT warping

xx, yy = torch.meshgrid(lin, lin2, indexing="xy")
coords_old = torch.stack([xx.flatten(), yy.flatten()], dim=1)  # (10, 2)


soboleng = torch.quasirandom.SobolEngine(dimension=2)
#coords = soboleng.draw(100)
#coords = torch.cat([coords, coords_old])
#X = torch.tensor([[0.1, 0], [0.5, 0.5], [0.7, 1.0]])

# 10 points packed into [0, 0.001] followed by 10 points spread over [0.01, 1].
X = torch.cat([torch.linspace(0, 0.001, 10), torch.linspace(0.01, 1, 10)])
#X = torch.linspace(0, 1.0, 100)

y_raw = data_f(X)
#y = data_g2(coords)
y_centered = y_raw - y_raw.mean()  # center the data
y = y_centered / y_centered.std()  # normalize the data

#X = torch.stack([X, a], dim=-1)  # shape (100, 2)

In [None]:
# Display the centered, unit-variance targets built in the previous cell.
y

In [None]:
# Assemble the warped GP: Gaussian noise model, a single-term tanh warp and
# the RBF model trained on the warped targets u = warp(y).
likelihood = gpytorch.likelihoods.GaussianLikelihood()
warp = TanhWarp(M=1)
u = warp(y)
model = Model(X, u, likelihood, y=y, warp=warp)
if "coords" in locals():
    # Only rebinds the notebook-level X for later cells; `model` above was
    # already built from the previous X.
    X = coords
#model = MIModel(X, u, likelihood, warp=warp, y=y)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
model.train()
likelihood.train()
optimizer.zero_grad()

In [None]:
# 3-D scatter of the warped targets over the 2-D inputs (only when `coords` exists).
if "coords" in locals() and coords.ndim > 1:
    z_vals = model.u.detach()
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    ax.scatter(
        coords[:, 0],
        coords[:, 1],
        z_vals.numpy(),
        c=z_vals.numpy(),
        cmap='viridis',
        alpha=0.8,
    )

In [None]:
# 1-D case: overlay the raw targets and their warped version against the inputs.
if "coords" not in locals():
    inputs_np = model.X.numpy()
    plt.plot(inputs_np, model.y.numpy(), linestyle='dashed', label='Original Data')
    plt.plot(inputs_np, model.u.detach().numpy(), linestyle='dotted', label='Warped Data')
    plt.legend()

In [None]:
# Change GP data to be new u
#model.set_train_data(X, model.warp(y))

In [None]:
from helpers.training_functions import granso_optimization
from helpers.util_functions import log_normalized_prior, get_full_kernels_in_kernel_expression, randomize_model_hyperparameters
# Globally enable autograd anomaly detection for everything that follows.
torch.autograd.set_detect_anomaly(True)

# Every kernel hyperparameter gets the same prior settings (mean 0, std 10),
# keyed by (kernel class name, parameter name).
kernel_parameter_priors = {
    kernel_and_param: {"mean": 0.0, "std": 10.0}
    for kernel_and_param in (
        ("RBFKernel", "lengthscale"),
        ("MaternKernel", "lengthscale"),
        ("LinearKernel", "variance"),
        ("AffineKernel", "variance"),
        ("RQKernel", "lengthscale"),
        ("RQKernel", "alpha"),
        ("CosineKernel", "period_length"),
        ("PeriodicKernel", "lengthscale"),
        ("PeriodicKernel", "period_length"),
        ("ScaleKernel", "outputscale"),
        ("LODE_Kernel", "signal_variance_2_0"),  # full match
        ("LODE_Kernel", "lengthscale"),          # base fallback
    )
}


# Priors for non-kernel parameters, keyed by their name inside the model.
parameter_priors = {
    "likelihood.raw_task_noises": dict(mean=0.0, std=10.0),
    "likelihood.raw_noise": dict(mean=0.0, std=10.0),
    "warp.raw_a": dict(mean=-5.0, std=10.0),
    "warp.raw_b": dict(mean=-5.0, std=10.0),
    "warp.c": dict(mean=0.0, std=10.0),
}


# Bounds per kernel hyperparameter; add type="uniform" to use a uniform distribution.
kernel_param_specs = {
    ("RBFKernel", "lengthscale"): dict(bounds=(1e-1, 5.0)),
    ("MaternKernel", "lengthscale"): dict(bounds=(1e-1, 1.0)),
    ("LinearKernel", "variance"): dict(bounds=(1e-1, 1.0)),
    ("AffineKernel", "variance"): dict(bounds=(1e-1, 1.0)),
    ("RQKernel", "lengthscale"): dict(bounds=(1e-1, 1.0)),
    ("RQKernel", "alpha"): dict(bounds=(1e-1, 1.0)),
    ("CosineKernel", "period_length"): dict(bounds=(1e-1, 10.0), type="uniform"),
    ("PeriodicKernel", "lengthscale"): dict(bounds=(1e-1, 5.0)),
    ("PeriodicKernel", "period_length"): dict(bounds=(1e-1, 10.0), type="uniform"),
    ("ScaleKernel", "outputscale"): dict(bounds=(1e-1, 10.0)),
    #("LODE_Kernel", "signal_variance_2_0"): {"bounds": (0.05, 0.5)},  # full match
    ("LODE_Kernel", "signal_variance"): dict(bounds=(1e-1, 10)),  # base
    ("LODE_Kernel", "lengthscale"): dict(bounds=(1e-1, 5.0)),
}


# Bounds for the likelihood noise parameters.
param_specs = {
    "likelihood.raw_task_noises": dict(bounds=(1e-1, 1e-0)),
    "likelihood.raw_noise": dict(bounds=(1e-1, 1e-0)),
}


# Define the objective functions
def _negative_warped_mll(model):
    """Negative per-datum log marginal likelihood of the warped GP.

    Computes ``-(log p(u | X) + sum_n log g'(y_n)) / N`` where ``u = g(y)``.

    Two details matter here:
    * ``ExactMarginalLogLikelihood`` applies the likelihood to the distribution
      it is given, so it must receive the latent GP output ``model(X)``;
      passing ``likelihood(model(X))`` would add the observation noise twice.
    * ``ExactMarginalLogLikelihood`` divides by the number of data points, so
      the Jacobian term is averaged (not summed) to stay on the same scale.
    """
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
    output = model(model.X)
    log_jacobian = model.warp.log_abs_det_jacobian(model.y).mean()
    return -mll(output, model.u) - log_jacobian


def _failure_loss():
    """Huge finite loss returned when the real objective could not be evaluated."""
    return torch.tensor(np.finfo(np.float32).max, requires_grad=True) + torch.tensor(-10.0)


def objective_function(model):
    """Maximum-likelihood objective for the warped GP.

    Args:
        model: a model exposing ``X``, ``y``, ``u``, ``warp``, ``likelihood``
            and ``apply_warp()``.

    Returns:
        ``[loss, None, None]`` — presumably the (objective, inequality,
        equality) triple expected by PyGRANSO; TODO confirm against
        ``granso_optimization``.
    """
    # Refresh u = warp(y) so the loss sees the current warp parameters.
    model.apply_warp()
    try:
        # TODO PyGRANSO dying is a severe problem. as it literally exits the program instead of raising an error
        # negative scaled MLL
        loss = _negative_warped_mll(model)
    except Exception as E:
        print("LOG ERROR: Severe PyGRANSO issue. Loss is inf+0")
        print(f"LOG ERROR: {E}")
        loss = _failure_loss()
    #print(f"LOG: {loss}")
    model.apply_warp()
    return [loss, None, None]


def objective_function_MAP(model):
    """MAP objective: the maximum-likelihood loss minus the log prior.

    The prior is evaluated by ``log_normalized_prior`` with the notebook-level
    ``parameter_priors`` and ``kernel_parameter_priors``.
    NOTE(review): whether ``log_normalized_prior`` is already divided by the
    number of data points (like the MLL term) is not visible here — confirm.

    Returns:
        ``[loss, None, None]`` (same convention as ``objective_function``).
    """
    # Refresh u = warp(y) so the loss sees the current warp parameters.
    model.apply_warp()
    try:
        # TODO PyGRANSO dying is a severe problem. as it literally exits the program instead of raising an error
        # negative scaled MLL
        loss = _negative_warped_mll(model)
        # log_normalized_prior is in metrics.py
        log_p = log_normalized_prior(model, param_specs=parameter_priors, kernel_param_specs=kernel_parameter_priors, prior=None)
        # negative scaled MAP (out-of-place so the autograd graph is not mutated)
        loss = loss - log_p
    except Exception as E:
        print("LOG ERROR: Severe PyGRANSO issue. Loss is inf+0")
        print(f"LOG ERROR: {E}")
        loss = _failure_loss()
    #print(f"LOG: {loss}")
    model.apply_warp()
    return [loss, None, None]


#var_in = {"model" : model, "train_x" : X, "train_y" : y, "likelihood" : likelihood, "MAP" : False,
#          "parameter_priors" : parameter_priors, "kernel_parameter_priors" : kernel_parameter_priors, "model_parameter_prior" : None}
#
#comb_fn = lambda X_struct: objective_function(var_in)

# Optimise the hyperparameters (GP, likelihood and warp) with the project helper.
# NOTE(review): MAP=False is passed although the custom objective is the MAP
# one — confirm which of the two the helper actually uses.
neg_scaled_mll, model_MLL, model_likelihood_MLL, training_log_MLL = granso_optimization(
    model,
    likelihood,
    model.X,
    model.u,
    random_restarts=5,
    maxit=1000,
    MAP=False,
    double_precision=False,
    verbose=True,
    objective_function=objective_function_MAP,
)

#neg_scaled_mll -= model.warp.log_abs_det_jacobian(model.y).sum()

model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))


# First derivatives of the returned loss w.r.t. every trainable parameter;
# create_graph=True keeps them differentiable for the second pass below.
jacobian_neg_unscaled_map = torch.autograd.grad(
    neg_scaled_mll,
    model_parameters,
    retain_graph=True,
    create_graph=True,
    allow_unused=True,
)
# Calculate -\nabla\nabla log(f(\theta)) (i.e. Hessian of negative log posterior):
# one tuple of second derivatives per entry of the gradient tuple above.
hessian_neg_unscaled_map_raw = [
    torch.autograd.grad(grad_entry, model_parameters, retain_graph=True, allow_unused=True)
    for grad_entry in jacobian_neg_unscaled_map
]

In [None]:
# Display the raw second derivatives: one tuple per parameter gradient.
hessian_neg_unscaled_map_raw

In [None]:
# Display the first value returned by granso_optimization (the final loss).
neg_scaled_mll

In [None]:
# Display every named parameter of the model after optimisation.
list(model.named_parameters())
#list(likelihood.named_parameters())

In [None]:

def _predict_through_warp(model, likelihood, Xtest, num_samples):
    """Predict at ``Xtest`` in latent space and map the result back through the warp.

    Puts ``model`` and ``likelihood`` into eval mode and evaluates
    ``likelihood(model(Xtest))`` without tracking gradients.

    Args:
        model: GP model exposing ``warp.inverse``.
        likelihood: the matching likelihood.
        Xtest: test inputs.
        num_samples: number of Monte Carlo draws per test point.

    Returns:
        Tuple ``(mvn, mean_u, var_u, median_y, y_samps)``:
        the latent predictive distribution, its mean and variance, the inverse
        warp of the mean (≈ predictive median in y-space) and ``num_samples``
        draws mapped to y-space. The draws use the marginal variances only,
        i.e. they are independent across test points.
    """
    model.eval(); likelihood.eval()
    with torch.no_grad():
        mvn = likelihood(model(Xtest))            # dist over u*
        mean_u = mvn.mean
        var_u = mvn.variance

        # Point estimate in y-space
        median_y = model.warp.inverse(mean_u)     # ≈ predictive median

        # Monte Carlo samples, pushed back to y-space
        eps = torch.randn(num_samples, *mean_u.shape, device=mean_u.device, dtype=mean_u.dtype)
        u_samps = mean_u.unsqueeze(0) + eps * var_u.sqrt().unsqueeze(0)
        y_samps = model.warp.inverse(u_samps)
    return mvn, mean_u, var_u, median_y, y_samps


S = 512  # Monte Carlo draws per test point

if "coords" in locals() and coords.ndim > 1:
    # 2-D inputs: evaluate along the diagonal of the unit square.
    axis_grid = torch.linspace(0.0, 1, 1000)
    Xtest = torch.stack([axis_grid, axis_grid], dim=-1)  # shape (1000, 2)
else:
    Xtest = torch.linspace(0.0, 2, 1000)

mvn, mean_u, var_u, median_y, y_samps = _predict_through_warp(model, likelihood, Xtest, S)
mean_y = y_samps.mean(dim=0)

if Xtest.ndim > 1:
    # Latent-space surface plot. The training points are taken from the model
    # and re-warped with the optimised parameters; the notebook-level `u` was
    # computed before optimisation and would be stale.
    with torch.no_grad():
        u_train = model.warp(model.y)
    plot_3d_gp(model, likelihood, data=torch.stack([model.X[:, 0], model.X[:, 1], u_train], -1),
               x_min=0.0, x_max=1.0, y_min=0.0, y_max=1.0)
else:
    # +-2 sigma in latent space, mapped through the (monotone) inverse warp.
    band_half_width = 2 * torch.sqrt(var_u)
    lower_y = model.warp.inverse(mean_u - band_half_width)
    upper_y = model.warp.inverse(mean_u + band_half_width)
    Xtest_np = Xtest.cpu().numpy()

    plt.figure(figsize=(10, 6))
    plt.plot(Xtest_np, median_y.cpu().numpy(), label='Predictive Median', color='blue')
    plt.fill_between(Xtest_np,
                     lower_y.cpu().numpy(),
                     upper_y.cpu().numpy(),
                     color='blue', alpha=0.2, label='95% Credible Interval')

    #plt.fill_between(Xtest.cpu().numpy(),
    #                 y_samps.quantile(0.05, dim=0).cpu().numpy(),
    #                 y_samps.quantile(0.95, dim=0).cpu().numpy(),
    #                 color='blue', alpha=0.2, label='95% Credible Interval')
    plt.scatter(model.X.cpu().numpy(), model.y.cpu().numpy(), s=10, color='red', label='Data')
    plt.title('GP with Tanh Warp. Real space')
    plt.xlabel('X')
    plt.ylabel('y')
    plt.legend()
    plt.show()


In [None]:
# Display the raw targets next to their warped counterparts.
model.y, model.u

In [None]:
# Latent-space view: predictive mean with a +-2 sigma band over the warped data.
with torch.no_grad():
    mvn = likelihood(model(Xtest))            # dist over u*
    mean_u = mvn.mean
    var_u  = mvn.variance

Xtest_np = Xtest.cpu().numpy()
band_half_width = 2 * torch.sqrt(var_u)

plt.figure(figsize=(10, 6))
# The latent predictive is Gaussian, so its mean is also its median.
plt.plot(Xtest_np, mean_u.cpu().numpy(), label='Predictive Median', color='blue')
#plt.fill_between(warp.inverse(mean_u -+ 2*torch.sqrt(var_u))))
plt.fill_between(Xtest_np,
                 (mean_u - band_half_width).cpu().numpy(),
                 (mean_u + band_half_width).cpu().numpy(),
                 color='blue', alpha=0.2, label='95% Credible Interval')

# model.u is produced by the warp and still requires grad, so it must be
# detached before converting to numpy (as the other plotting cells do).
plt.scatter(model.X.cpu().numpy(), model.u.detach().cpu().numpy(), s=10, color='red', label='Data')
plt.title('GP with Tanh Warp. Latent space')
plt.xlabel('X')
plt.ylabel('warp(y)')
plt.legend()  # the labels above were never shown without this
plt.show()

In [None]:
# Scatter the learned warp: raw targets on the x-axis ("y"), warped targets
# on the y-axis ("warp(y)").
ax = plt.gca()
ax.scatter(model.y, model.u.detach().numpy(), label='Warped Data')
#plt.plot(X, y, label='Original Data')
ax.set_xlabel("y")
ax.set_ylabel("warp(y)")
ax.legend()
#list(model.warp.parameters())



#plot_3d_data(X[:, 0], X[:, 1], model.warp(y).detach())
#plot_3d_data(X[:, 0], X[:, 1], y.detach())