In [109]:
import gpytorch
import torch
from itertools import product

In [None]:
# Quick look at torch.linspace output: 10 evenly spaced points from 0 to 1 inclusive.
torch.linspace(start=0, end=1, steps=10)

In [111]:
class ExactMIGPModel(gpytorch.models.ExactGP):
    """Exact GP with a zero prior mean and an additive two-input covariance.

    The covariance is an RBF kernel on input column 0 plus a Matern kernel on
    input column 1, i.e. k(x, x') = k_rbf(x[0], x'[0]) + k_matern(x[1], x'[1]).

    Args:
        train_x: training inputs. The kernels index columns 0 and 1, so this
            presumably has shape (n, 2) — TODO confirm for other callers.
        train_y: training targets, passed straight through to ``ExactGP``.
        likelihood: GPyTorch likelihood, passed straight through to ``ExactGP``.
        kernel_text: currently ignored; kept so existing call sites still work.
        weights: currently ignored; kept so existing call sites still work.
    """

    def __init__(self, train_x, train_y, likelihood, kernel_text="RBF", weights=None):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.mean_module = None or gpytorch.means.ZeroMean() if False else gpytorch.means.ZeroMean()
        # Each sub-kernel restricts itself to a single input column through
        # active_dims, so their plain sum already is the additive kernel.
        # Wrapping it in AdditiveStructureKernel adds nothing numerically and
        # relies on GPyTorch's `last_dim_is_batch` path, which recent GPyTorch
        # releases deprecate.
        self.covar_module = (
            gpytorch.kernels.RBFKernel(active_dims=0)
            + gpytorch.kernels.MaternKernel(active_dims=1)
        )

    def forward(self, x):
        """Return the GP prior at ``x`` as a MultivariateNormal."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

In [None]:
X_SIZE = 10  # number of grid points along each input dimension

# Training inputs: a regular X_SIZE x X_SIZE grid over [0, 5]^2, one (x1, x2) pair per row.
# torch.cartesian_prod yields the same row order as itertools.product (the first
# coordinate varies slowest) without round-tripping every element through Python tuples.
grid_1d = torch.linspace(0, 5, X_SIZE)
train_x = torch.cartesian_prod(grid_1d, grid_1d)
# The target is additive across the two inputs: sin(x1) + cos(x2).
train_y = torch.sin(train_x[:, 0]) + torch.cos(train_x[:, 1])

# Check shapes instead of dumping every element.
assert train_x.shape == (X_SIZE * X_SIZE, 2)
assert train_y.shape == (X_SIZE * X_SIZE,)
print(train_x.shape, train_y.shape)
# Gaussian likelihood, shared by the GP below and by the later plotting calls.
likelihood = gpytorch.likelihoods.GaussianLikelihood()
data_model = ExactMIGPModel(train_x=train_x, train_y=train_y, likelihood=likelihood)

In [None]:
# Sanity check: evaluate the model on its own training inputs and display the
# resulting distribution. Only a cell's last expression is shown, so this is the
# single expression here.
data_model(train_x)


In [None]:
# NOTE(review): plot_2d_gp is defined in a later cell, so this cell raises NameError
# on a fresh Restart & Run All — TODO move this cell below the plot_2d_gp definition.
# Plot over the extent of the training grid instead of a fixed unit square, which
# only covered a corner of the data.
plot_2d_gp(
    data_model,
    likelihood,
    x_min=train_x[:, 0].min().item(),
    x_max=train_x[:, 0].max().item(),
    y_min=train_x[:, 1].min().item(),
    y_max=train_x[:, 1].max().item(),
    resolution=50,
)

In [None]:
# Exploratory: build an additive prior by hand, one kernel per input column
# (presumably to cross-check ExactMIGPModel's covariance — confirm).
zer = gpytorch.means.ZeroMean()
covar_module = [gpytorch.kernels.RBFKernel(), gpytorch.kernels.MaternKernel()]
columns = train_x.T  # row i holds every training value of input dimension i

mean_x = sum(zer(xp) for xp in columns)
print(mean_x.shape)  # shown explicitly; a mid-cell bare expression is never displayed

covar_x = sum(covar(xp) for covar, xp in zip(covar_module, columns))
# to_dense() is the current linear_operator name for the deprecated evaluate().
covar_x.to_dense()

In [None]:
# Draw 5 samples of the latent function from the GP prior at the training inputs.
# No likelihood is applied, so these samples carry no observation noise.
torch.manual_seed(0)  # make the draws reproducible on Restart & Run All
with torch.no_grad(), gpytorch.settings.prior_mode(True):
    f_preds = data_model(train_x)
    # Sample inside no_grad as well; .sample(torch.Size([n])) replaces the
    # deprecated .sample_n(n) and returns the same (n, num_points) shape.
    all_observations_y = f_preds.sample(torch.Size([5]))

In [117]:
import matplotlib.pyplot as plt

def plot_2d_gp(model, likelihood, x_min=0.0, x_max=1.0, y_min=0.0, y_max=1.0, resolution=50):
    """Plot a GP's predictive mean over a 2-D grid as filled contours.

    Dotted and dashed white contour lines show the lower and upper bounds
    returned by ``confidence_region()`` of the predictive distribution.

    Args:
        model: GPyTorch model evaluated on an (N, 2) tensor of grid points.
        likelihood: likelihood applied to the model output to get predictions.
        x_min, x_max: grid extent along the first input (X1).
        y_min, y_max: grid extent along the second input (X2).
        resolution: number of grid points per axis.

    Side effects:
        Switches ``model`` and ``likelihood`` to eval mode and leaves them there.
    """
    # indexing="ij" matches torch.meshgrid's historical default; passing it
    # explicitly silences the deprecation warning without changing the grid.
    xx, yy = torch.meshgrid(
        torch.linspace(x_min, x_max, resolution),
        torch.linspace(y_min, y_max, resolution),
        indexing="ij",
    )
    test_x = torch.stack([xx.reshape(-1), yy.reshape(-1)], dim=-1)

    model.eval()
    likelihood.eval()
    with torch.no_grad():
        preds = likelihood(model(test_x))
        mean = preds.mean.reshape(resolution, resolution)
        lower, upper = (bound.reshape(resolution, resolution) for bound in preds.confidence_region())

    xx_np, yy_np = xx.numpy(), yy.numpy()
    fig, ax = plt.subplots()
    mean_contour = ax.contourf(xx_np, yy_np, mean.numpy(), levels=50, cmap='viridis')
    fig.colorbar(mean_contour, ax=ax, label='Mean')
    ax.contour(xx_np, yy_np, lower.numpy(), levels=10, linestyles='dotted', colors='white', alpha=0.7)
    ax.contour(xx_np, yy_np, upper.numpy(), levels=10, linestyles='dashed', colors='white', alpha=0.7)
    # The contours are confidence-region bounds, not a variance, so label them as such.
    ax.set_title('2D GP Mean and Confidence Region')
    ax.set_xlabel('X1')
    ax.set_ylabel('X2')
    plt.show()

In [141]:
def plot_3d_gp(model, likelihood, x_min=0.0, x_max=1.0, y_min=0.0, y_max=1.0, resolution=50):
    """Plot a GP's predictive mean over a 2-D grid as a 3-D surface.

    The lower and upper bounds returned by ``confidence_region()`` are drawn
    as translucent gray surfaces around the viridis-coloured mean surface.

    Args:
        model: GPyTorch model evaluated on an (N, 2) tensor of grid points.
        likelihood: likelihood applied to the model output to get predictions.
        x_min, x_max: grid extent along the first input (X1).
        y_min, y_max: grid extent along the second input (X2).
        resolution: number of grid points per axis.

    Side effects:
        Switches ``model`` and ``likelihood`` to eval mode and leaves them there.
    """
    model.eval()
    likelihood.eval()

    # indexing="ij" matches torch.meshgrid's historical default; passing it
    # explicitly silences the deprecation warning without changing the grid.
    xx, yy = torch.meshgrid(
        torch.linspace(x_min, x_max, resolution),
        torch.linspace(y_min, y_max, resolution),
        indexing="ij",
    )
    test_x = torch.stack([xx.reshape(-1), yy.reshape(-1)], dim=-1)

    with torch.no_grad():
        preds = likelihood(model(test_x))
        mean = preds.mean.reshape(resolution, resolution)
        lower, upper = (bound.reshape(resolution, resolution) for bound in preds.confidence_region())

    xx_np, yy_np = xx.numpy(), yy.numpy()
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')

    # Mean surface, then the two confidence-region bounds as faint gray sheets.
    ax.plot_surface(xx_np, yy_np, mean.numpy(), cmap='viridis', alpha=0.8)
    for bound in (lower, upper):
        ax.plot_surface(xx_np, yy_np, bound.numpy(), color='gray', alpha=0.2)

    ax.set_title('2D GP in 3D')
    ax.set_xlabel('X1')
    ax.set_ylabel('X2')
    # The gray surfaces are confidence-region bounds, not a variance.
    ax.set_zlabel('Mean and Confidence Region')

    plt.show()


In [None]:
# Plot a symmetric window well beyond the [0, 5]^2 training grid to show how the
# model behaves away from the data.
PLOT_HALF_WIDTH = 10.0
plot_3d_gp(
    data_model,
    likelihood,
    x_min=-PLOT_HALF_WIDTH,
    x_max=PLOT_HALF_WIDTH,
    y_min=-PLOT_HALF_WIDTH,
    y_max=PLOT_HALF_WIDTH,
    resolution=50,
)