# Set up

In [1]:
import torch
import gpytorch
import pandas as pd
import numpy as np
import tqdm as tqdm
from linear_operator import settings

import pyro
import math
import pickle
import time
from joblib import Parallel, delayed

from sklearn.preprocessing import StandardScaler

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
import arviz as az
import seaborn as sns

import os

from torch.utils.data import TensorDataset, DataLoader
import itertools

In [2]:
import GP_functions.Loss_function as Loss_function
import GP_functions.bound as bound
import GP_functions.Estimation as Estimation
import GP_functions.Training as Training
import GP_functions.Prediction as Prediction
import GP_functions.GP_models as GP_models
import GP_functions.Tools as Tools
import GP_functions.FeatureE as FeatureE

# Data

In [3]:
X_train = pd.read_csv('Data/X_train.csv', header=None, delimiter=',').values
X_test = pd.read_csv('Data/X_test.csv', header=None, delimiter=',').values

Y_train_21 = pd.read_csv('Data/Y_train_std_21.csv', header=None, delimiter=',').values
Y_test_21 = pd.read_csv('Data/Y_test_std_21.csv', header=None, delimiter=',').values

Y_train_std = pd.read_csv('Data/Y_train_std.csv', header=None, delimiter=',').values
Y_test_std = pd.read_csv('Data/Y_test_std.csv', header=None, delimiter=',').values


train_x = torch.tensor(X_train, dtype=torch.float32)
test_x = torch.tensor(X_test, dtype=torch.float32)

train_y_21 = torch.tensor(Y_train_21, dtype=torch.float32)
test_y_21 = torch.tensor(Y_test_21, dtype=torch.float32)

# train_y = torch.tensor(Y_train_std, dtype=torch.float32)
# test_y = torch.tensor(Y_test_std, dtype=torch.float32)


# torch.set_default_dtype(torch.float32)

# Model

In [7]:
class DGPHiddenLayer(gpytorch.models.deep_gps.DeepGPLayer):
    def __init__(
        self,
        input_dims,
        output_dims,
        num_inducing = 512,
        covar_type = "RBF",
        linear_mean = False,
        train_x_for_init = None
    ):
        self.input_dims = input_dims
        self.output_dims = output_dims
        batch_shape = torch.Size([output_dims])

        if train_x_for_init is not None:
            idx = torch.randperm(train_x_for_init.size(0))[:num_inducing]
            inducing_points = train_x_for_init[idx].clone()
            inducing_points = inducing_points.unsqueeze(0).expand(
                output_dims, -1, -1
            )  # B x M x D
        else:
            inducing_points = (
                torch.rand(output_dims, num_inducing, input_dims) * 4.9 + 0.1
            )

        variational_dist = gpytorch.variational.CholeskyVariationalDistribution(
            num_inducing_points=num_inducing,
            batch_shape=batch_shape,
        )
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self,
            inducing_points,
            variational_dist,
            learn_inducing_locations=True,
        )

        super().__init__(variational_strategy, input_dims, output_dims)
        
        self.mean_module = gpytorch.means.ZeroMean() if linear_mean else gpytorch.means.LinearMean(input_dims)
        
        if covar_type == 'Matern5/2':
            base_kernel = gpytorch.kernels.MaternKernel(nu=2.5,
                                                        batch_shape=batch_shape,
                                                        ard_num_dims=input_dims)
        elif covar_type == 'RBF':
            base_kernel = gpytorch.kernels.RBFKernel(batch_shape=batch_shape,
                                                     ard_num_dims=input_dims)
        elif covar_type == 'Matern3/2':
            base_kernel = gpytorch.kernels.MaternKernel(nu=1.5,
                                                        batch_shape=batch_shape,
                                                        ard_num_dims=input_dims)
        elif covar_type == 'RQ':
            base_kernel = gpytorch.kernels.RQKernel(batch_shape=batch_shape,
                                                    ard_num_dims=input_dims)
        elif covar_type == 'PiecewisePolynomial':
            base_kernel = gpytorch.kernels.PiecewisePolynomialKernel(q=2,
                                                                     batch_shape=batch_shape,
                                                                     ard_num_dims=input_dims)
        else:
            raise ValueError("RBF, Matern5/2, Matern3/2, RQ, PiecewisePolynomial")
        
        self.covar_module = gpytorch.kernels.ScaleKernel(base_kernel,
                                                         batch_shape=batch_shape, 
                                                         ard_num_dims=None)
    
    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
    


class DeepGP2(gpytorch.models.deep_gps.DeepGP):
    def __init__(
        self,
        train_x,
        train_y,
        hidden_dim = 4,
        inducing_num = 512,
        covar_types = ["RBF", "RBF"],
    ):
        num_tasks = train_y.size(-1)

        layer1 = DGPHiddenLayer(
            input_dims=train_x.size(-1),
            output_dims=hidden_dim,
            num_inducing=inducing_num,
            covar_type=covar_types[0],
            train_x_for_init=train_x,
        )
        layer2 = DGPHiddenLayer(
            input_dims=hidden_dim,
            output_dims=num_tasks,
            num_inducing=inducing_num,
            covar_type=covar_types[1],
            linear_mean=True,
            train_x_for_init=train_x,
        )

        super().__init__()
        self.layers = torch.nn.ModuleList([layer1, layer2])
        self.likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=num_tasks)

    def forward(self, x):
        x = self.layers[0](x)
        return self.layers[1](x)
    
    def predict(self, test_x):
        # with gpytorch.settings.fast_pred_var():
        preds = self.likelihood(self(test_x)).to_data_independent_dist()

        return preds.mean.mean(0).squeeze(), preds.variance.mean(0).squeeze()

In [5]:
def train_dgp_minibatch(
    train_x,
    train_y,
    hidden_dim = 4,
    inducing_num = 512,
    num_iterations = 3000,
    patience = 100,
    batch_size = 256,
    eval_every = 200,
    eval_batch_size = 1024,
    lr = 0.05,
    device = "cuda"
):
    train_x, train_y = train_x.to(device), train_y.to(device)

    model = DeepGP2(
        train_x, train_y, hidden_dim, inducing_num
    ).to(device)

    model.train()


    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    mll = gpytorch.mlls.DeepApproximateMLL(
        gpytorch.mlls.VariationalELBO(
            likelihood=model.likelihood, model=model, num_data=train_y.size(0)
        )
    )


    best_loss = float("inf")
    best_state = model.state_dict()
    no_improve = 0

    loader = itertools.cycle(
        DataLoader(TensorDataset(train_x, train_y), batch_size, shuffle=True)
    )

    # --- jitter ---
    jitter_ctx = gpytorch.settings.variational_cholesky_jitter(1e-3)

    with tqdm.tqdm(total=num_iterations, desc="Training DGP") as pbar, jitter_ctx:
        for step in range(num_iterations):
            x_batch, y_batch = next(loader)
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            output = model(x_batch)
            loss = -mll(output, y_batch)
            loss.backward()
            optimizer.step()

            if (step + 1) % eval_every == 0:
                model.eval()
                with torch.no_grad(), gpytorch.settings.fast_pred_var():
                    total_loss = 0.0
                    for i in range(0, train_x.size(0), eval_batch_size):
                        xb, yb = (
                            train_x[i : i + eval_batch_size],
                            train_y[i : i + eval_batch_size],
                        )
                        out = model(xb)
                        total_loss += -mll(out, yb).item() * yb.size(0)
                full_loss = total_loss / train_x.size(0)
                pbar.set_postfix(loss=f"{full_loss:.4f}")
                model.train()

                if full_loss < best_loss - 1e-4:
                    best_loss, best_state, no_improve = full_loss, model.state_dict(), 0
                else:
                    no_improve += 1
                    if no_improve >= patience:
                        print("Early stopping")
                        break
            pbar.update(1)

    model.load_state_dict(best_state)
    model.eval()
    return model

In [35]:
dgp_model= train_dgp_minibatch(train_x, train_y_21,
                                hidden_dim = 10,
                                inducing_num = 100,
                                num_iterations = 5000,
                                patience = 100,
                                batch_size = 256,
                                eval_every = 100,
                                eval_batch_size = 1024,
                                lr = 0.05,
                                device = "cuda")

Training DGP: 100%|██████████| 5000/5000 [04:00<00:00, 20.75it/s, loss=-24.7777]


In [36]:
test_x = test_x.to('cuda')

In [37]:
dgp_model.eval()
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    mean, var = dgp_model.predict(test_x[0,:].unsqueeze(0))

In [38]:
dgp_model.predict(test_x[0,:].unsqueeze(0))

(tensor([-0.9150,  4.2057, -0.1588,  1.6351,  1.0829,  1.3062, -0.2847, -0.0042,
         -0.3159,  0.0779, -0.0245,  0.1177, -0.0522, -0.2008,  0.1500, -0.1301,
         -0.0589, -0.0170,  0.0771,  0.0000,  0.0121], device='cuda:0',
        grad_fn=<SqueezeBackward0>),
 tensor([0.0038, 0.0225, 0.0246, 0.0209, 0.0151, 0.0133, 0.0033, 0.0035, 0.0034,
         0.0031, 0.0028, 0.0030, 0.0030, 0.0031, 0.0026, 0.0025, 0.0024, 0.0025,
         0.0024, 0.0097, 0.0024], device='cuda:0', grad_fn=<SqueezeBackward0>))

In [39]:
test_y_21[0]

tensor([-0.9226,  4.2240, -0.2954,  1.5937,  1.0363,  1.3768, -0.2562,  0.0340,
        -0.3592,  0.1005,  0.0101,  0.1750, -0.0758, -0.2362,  0.1707, -0.1471,
        -0.0780, -0.0288,  0.0638,  0.0508,  0.0267])