In [None]:
import math
import os
import sys

# Environment setup must happen *before* the imports that depend on it:
# - KMP_DUPLICATE_LIB_OK works around the duplicate-OpenMP-runtime error seen on macOS;
#   the safe ordering is to set it before torch (and anything else linking OpenMP) loads.
# - The parent folder must be on sys.path before `core.*` is imported, otherwise the
#   project imports below only work if the kernel happens to start in the repo root.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
sys.path.append('..')  # Add parent folder to sys.path so the `core` package resolves

import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt

import core.GP_CommonCalculation as GP
from core.kernel import ARDKernel

JITTER = 1e-6  # small diagonal term added to covariance matrices for numerical stability
EPS = 1e-10
PI = math.pi  # previously the truncated literal 3.1415

In [None]:
# Generate synthetic data: y = exp(0.3 x) with 1% multiplicative Gaussian noise.
torch.manual_seed(0)
xtr = torch.linspace(10, 20, 100, dtype=torch.float64).unsqueeze(1)
# torch.randn uses the default dtype (float32 unless changed); multiplying by the
# float64 exp(...) term promotes ytr to float64.
ytr = torch.exp(0.3 * xtr) * (1 + 0.01 * torch.randn(xtr.size()))

# "Test" data. NOTE: this is the same grid as the training inputs, so the MSE
# comparison at the end of the notebook measures in-sample fit, not extrapolation.
xte = torch.linspace(10, 20, 100, dtype=torch.float64).unsqueeze(1)
yte = torch.exp(0.3 * xte)  # noise-free ground truth

# Plot the original data together with the true function (both labelled so the
# legend shows both series).
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(xtr.numpy(), ytr.numpy(), 'b*', label='Original Data')
ax.plot(xte.numpy(), yte.numpy(), 'r', label='True function')
ax.set_xlabel('X')
ax.set_ylabel('y')
ax.set_title('Original Data')
ax.legend()
plt.show()


In [None]:
# Warp the targets before fitting. With method "log" this presumably maps y -> log(y),
# which turns the exponential trend into a roughly linear one (see core.GP_CommonCalculation).
warp = GP.Warp(method="log")
y_log = warp.transform(ytr)  # training targets in the warped space

In [None]:
# ARD kernel for the GP; the argument 1 presumably is the input dimension
# (x is one-dimensional here) — confirm against core.kernel.ARDKernel.
kernel = ARDKernel(1)

In [None]:
# Initial log noise precision. The noise variance used by the model is
# 1 / exp(log_beta) = exp(-log_beta), so -4 gives exp(4) ≈ 54.6: a deliberately
# large noise that the optimiser then shrinks to a proper value.
log_beta = nn.Parameter(torch.full((1,), -4.0))

In [None]:
def negative_log_likelihood(xtr, ytr, kernel, log_beta):
    """Negative log marginal likelihood of the GP under Gaussian observation noise.

    Args:
        xtr: training inputs, passed as both arguments to ``kernel``.
        ytr: training targets (in this notebook, either the log-warped or raw targets).
        kernel: callable with ``kernel(a, b)`` returning a covariance matrix tensor.
        log_beta: log of the noise precision; the noise variance is ``exp(-log_beta)``.

    Returns:
        ``-GP.Gaussian_log_likelihood(ytr, Sigma)`` with
        ``Sigma = K(xtr, xtr) + exp(-log_beta) * I + JITTER * I``.

    A change-of-variables (Jacobian) term for the warp could be added to the loss,
    but it depends only on the data, not on the hyper-parameters, so it does not
    affect the optimisation.
    """
    K = kernel(xtr, xtr)
    # Create the identity on K's dtype/device: a bare torch.eye is always on the CPU
    # with the default dtype, which breaks as soon as the kernel lives on a GPU.
    eye = torch.eye(xtr.size(0), dtype=K.dtype, device=K.device)
    Sigma = K + log_beta.exp().pow(-1) * eye + JITTER * eye
    return -GP.Gaussian_log_likelihood(ytr, Sigma)

In [None]:
def forward(xtr, ytr, xte, kernel, log_beta):
    """Predictive mean and variance of the GP at the test inputs ``xte``.

    Args:
        xtr: training inputs.
        ytr: training targets (in the same space the model was trained in).
        xte: test inputs.
        kernel: callable with ``kernel(a, b)`` returning a covariance matrix tensor.
        log_beta: log noise precision; the noise variance is ``exp(-log_beta)``.

    Returns:
        mean: the mean returned by ``GP.conditional_Gaussian`` (presumably the
            posterior mean at ``xte``).
        var_diag: column vector (``view(-1, 1)``) of predictive variances, including
            the observation-noise variance ``exp(-log_beta)``.
    """
    K = kernel(xtr, xtr)
    # Identity on K's dtype/device so this also works when the kernel is on a GPU.
    eye = torch.eye(xtr.size(0), dtype=K.dtype, device=K.device)
    Sigma = K + log_beta.exp().pow(-1) * eye + JITTER * eye

    K_s = kernel(xtr, xte)
    K_ss = kernel(xte, xte)

    mean, var = GP.conditional_Gaussian(ytr, Sigma, K_s, K_ss)

    # NOTE(review): summing over dim 0 yields per-point variances only if
    # conditional_Gaussian returns per-element terms. If it returns the full
    # posterior covariance matrix, the diagonal (var.diagonal()) is what is wanted
    # here — confirm against core.GP_CommonCalculation.
    var_diag = var.sum(dim=0).view(-1, 1)
    # Add observation noise: this is the variance of a noisy y, not of the latent f.
    var_diag = var_diag + log_beta.exp().pow(-1)
    return mean, var_diag

In [None]:
def train_adam(xtr, y_log, kernel, log_beta, niteration=10, lr=0.1, print_every=10):
    """Fit the kernel hyper-parameters and ``log_beta`` by minimising the NLL with Adam.

    Args:
        xtr: training inputs.
        y_log: training targets. The name reflects the log-warped use case, but any
            targets accepted by ``negative_log_likelihood`` work (the notebook also
            passes the raw ``ytr``).
        kernel: object whose ``.parameters()`` are optimised (presumably an nn.Module).
        log_beta: log noise-precision parameter, optimised jointly with the kernel.
        niteration: number of Adam steps.
        lr: learning rate shared by both parameter groups.
        print_every: print the loss every ``print_every`` iterations; 0 or None
            disables printing.

    Parameters are updated in place; nothing is returned. The printed loss is the
    value computed *before* that iteration's optimiser step.
    """
    optimizer = optim.Adam(
        [{'params': kernel.parameters()}, {'params': [log_beta]}],
        lr=lr,
    )

    for i in range(niteration):
        optimizer.zero_grad()
        loss = negative_log_likelihood(xtr, y_log, kernel, log_beta)
        loss.backward()
        optimizer.step()

        # Guard on print_every first so 0/None disables logging instead of
        # raising ZeroDivisionError.
        if print_every and (i + 1) % print_every == 0:
            print('iter', i + 1, 'nll:{:.5f}'.format(loss.item()))

In [None]:
# Train the GP on the log-warped targets.
train_adam(
    xtr=xtr,
    y_log=y_log,
    kernel=kernel,
    log_beta=log_beta,
    niteration=200,
    lr=0.1,
)

In [None]:
# Predict in the warped (log) space; no gradients are needed for inference.
with torch.no_grad():
    ypred_log, yvar_log = forward(
        xtr=xtr, ytr=y_log, xte=xte, kernel=kernel, log_beta=log_beta
    )

In [None]:
# Map the log-space predictive mean and variance back to the original y scale
# (the exact mapping is defined by GP.Warp.back_transform).
ypred, yvar = warp.back_transform(ypred_log, yvar_log)

In [None]:
# Prediction of the log-warped GP on the original scale, with ±1 std error bars.
# The figure is created *before* drawing: calling plt.figure() after plotting
# opens a second, empty figure and leaves the real plot at the default size.
fig, ax = plt.subplots(figsize=(12, 6))
ax.errorbar(
    xte.numpy().reshape(-1),
    ypred.detach().numpy().reshape(-1),
    yerr=yvar.sqrt().squeeze().detach().numpy(),
    fmt='r-.',
    alpha=0.2,
    label='Predicted mean ± 1 std (log warp)',
)
ax.plot(xtr.detach().numpy(), ytr.detach().numpy(), 'b+', label='Training data')
ax.set_xlabel('X')
ax.set_ylabel('y')
ax.set_title('GP prediction with log-transformed targets')
ax.legend()
plt.show()

In [None]:
%%capture captured_output
#reset parameter and train the model without log transform
kernel=ARDKernel(1)
log_beta = nn.Parameter(torch.ones(1) * -4) 
train_adam(xtr, ytr, kernel, log_beta, niteration=1000,lr=0.1)

In [None]:
# Run this cell to display the captured training log. It shows that, without the log transform, training needs more iterations to converge.
captured_output.show()  # replay the stdout captured by %%capture during baseline training

In [29]:
# Baseline (no warp) prediction, then the MSE of both models against the
# noise-free truth. NOTE: xte is the same grid as xtr, so these are in-sample errors.
with torch.no_grad():
    ypred2, yvar2 = forward(xtr=xtr, ytr=ytr, xte=xte, kernel=kernel, log_beta=log_beta)

mse_log = (yte - ypred).pow(2).mean()
print('mse_log_transform:', mse_log)

mse_standard = (yte - ypred2).pow(2).mean()
print('mse_standard:', mse_standard)

mse_log_transform: tensor(0.0471, dtype=torch.float64)
mse_standard: tensor(0.0789, dtype=torch.float64)
