In [None]:
import torch
import numpy as np
from scipy.integrate import quad
from norm_constrain import Norm2ConstrainedContainer_rational, Norm2ConstrainedContainer_SE,Norm2ConstrainedContainer_ConvexCombination
from gpytorch.distributions import MultivariateNormal
from matplotlib import pyplot as plt
import matplotlib.cm as cm
%matplotlib inline
from cycler import cycler

# Double precision for every tensor created without an explicit dtype, and a
# fixed seed for torch's global RNG so random draws repeat on a fresh
# "Restart & Run All".
torch.set_default_dtype(torch.double)
torch.manual_seed(12345)


def C1_integrand(x, nk):

    res = nk.C1(x)
    res = res.detach().numpy() * x**2
    return res

def C2_integrand(x, y, nk):
    """Integrand x**2 * k(x, y) for ``scipy.integrate.quad`` over ``x``.

    Parameters
    ----------
    x : float
        Integration variable supplied by the quadrature routine.
    y : float or tensor
        Fixed second argument of the base kernel.
    nk : kernel object
        Must expose ``base_kernel_eval(x, y)`` returning a torch tensor.

    Returns
    -------
    The weighted kernel value, detached and converted to NumPy.
    """
    weighted = nk.base_kernel_eval(x, y) * x**2
    return weighted.detach().numpy()

def quadr(y, dx, x=None):
    """Trapezoid-rule estimate of the integral of y(x) * x**2 on a uniform grid.

    Used by ``mean_normalization_check`` and
    ``sample_curve_normalization_check`` to integrate the mean / sampled
    curves with the radial weight x**2.

    Parameters
    ----------
    y : torch.Tensor or numpy.ndarray
        Function values on the grid. The last axis runs over grid points; any
        leading axes are treated as a batch (e.g. one row per curve).
    dx : float
        Uniform grid spacing.
    x : torch.Tensor or numpy.ndarray, optional
        Grid points matching the last axis of ``y``. Defaults to the
        module-level ``x`` assigned by the notebook cells. Pass it explicitly
        to avoid depending on hidden notebook state.

    Returns
    -------
    The integral estimate, with the last axis of ``y`` reduced away.

    Raises
    ------
    NameError
        If ``x`` is omitted and no module-level ``x`` has been defined yet.
    """
    if x is None:
        try:
            x = globals()["x"]
        except KeyError:
            raise NameError(
                "quadr: pass the grid `x` explicitly "
                "(no module-level `x` is defined yet)"
            ) from None
    f = y * x**2
    if f.shape[-1] < 2:
        # Fewer than two samples span zero width, so the integral is zero.
        return f.sum(axis=-1) * 0.0
    # Trapezoid rule: both endpoints carry half weight, interior points full.
    return (0.5 * (f[..., 0] + f[..., -1]) + f[..., 1:-1].sum(axis=-1)) * dx

def C0_Consistency(nk):
    """Print nk.C0() next to a quadrature of x**2 * nk.C1(x) over [0, inf).

    Parameters
    ----------
    nk : kernel object
        Must expose ``C0()`` and ``C1(x)``, both returning torch tensors.
    """
    print("C0 Consistency check:")
    value, abserr = quad(C1_integrand, 0.0, np.inf, args=(nk,))
    from_kernel = nk.C0().detach().numpy()
    print("C0 from quadrature = %12.5E ; Error = %10.3E" % (value, abserr))
    print("C0 evaluated by kernel = %12.5E\n" % from_kernel)

def C1_Consistency(nk, test_vals):
    """For each test point y, print nk.C1(y) next to its quadrature estimate.

    The quadrature integrates x**2 * nk.base_kernel_eval(x, y) over
    x in [0, inf).

    Parameters
    ----------
    nk : kernel object
        Must expose ``C1(y)`` and ``base_kernel_eval(x, y)``.
    test_vals : iterable
        Points y at which to compare. Each must be %-formattable as a float.
    """
    print("C1 Consistency check:")
    for x1 in test_vals:
        value, abserr = quad(C2_integrand, 0.0, np.inf, args=(x1, nk))
        kernel_val = nk.C1(x1)
        print("At x1 = %12.5E, C1 from quadrature = %12.5E ; Error = %10.3E" % (x1, value, abserr))
        print("C1 evaluated by kernel = %12.5E\n" % kernel_val)

def mean_normalization_check(mean, dx):
    """Print the x**2-weighted grid integral of the mean (computed by ``quadr``).

    Parameters
    ----------
    mean : tensor or array
        Mean values on the notebook's sampling grid.
    dx : float
        Grid spacing forwarded to ``quadr``.
    """
    print("Mean normalization check:")
    norm = quadr(mean, dx)
    message = "Mean normalization = %12.5E\n" % (norm,)
    print(message)

def sample_curve_normalization_check(curves, dx):
    """Print the x**2-weighted grid integral of every sampled curve.

    ``quadr`` reduces the last axis of ``curves``. The result is iterated and
    one line is printed per entry, each followed by a blank line.

    Parameters
    ----------
    curves : tensor or array
        Sampled curves. Presumably shaped (ncurves, ngrid); confirm.
    dx : float
        Grid spacing forwarded to ``quadr``.
    """
    print("Sample curve normalization check:")
    per_curve = quadr(curves, dx)
    for value in per_curve:
        line = "Curve normalization = %12.5E\n" % (value,)
        print(line)

def plot_mean_and_samples(mean, curves=None, plot_filename=None, x=None):
    """Plot the mean (and optional sample curves) on the sampling grid.

    Two side-by-side panels are drawn:
    - left: Y(X);
    - right: the radially weighted Y(X) * X**2, i.e. the integrand ``quadr``
      uses for the normalization checks.

    Both panels share the same style cycle, so sample curve i has the same
    color and linestyle on each side.

    Parameters
    ----------
    mean : torch.Tensor
        Mean values on the grid. Plotted as a solid red line.
    curves : torch.Tensor, optional
        Sample curves. Each row is plotted as its own line.
    plot_filename : str, optional
        If given, the figure is also saved to this path in PNG format.
    x : torch.Tensor or array-like, optional
        Grid the values live on. Defaults to the module-level ``x`` assigned
        by the notebook cell that built ``mean``/``curves``.

    Raises
    ------
    NameError
        If ``x`` is omitted and no module-level ``x`` has been defined yet.
    """
    if x is None:
        try:
            x = globals()["x"]
        except KeyError:
            raise NameError(
                "plot_mean_and_samples: pass the grid `x` explicitly "
                "(no module-level `x` is defined yet)"
            ) from None
    xx = x.detach().numpy() if torch.is_tensor(x) else np.asarray(x)
    weight = xx**2

    mn = mean.detach().numpy()
    # Convert once and reuse for both panels.
    sc = curves.detach().numpy() if curves is not None else None

    custom_cycler = (cycler(color=['c', 'm', 'y', 'k']) +
                     cycler(ls=['--', ':', '-.', ':']))

    fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(16.0, 8.0))

    panels = (
        (ax_left, 1.0, 'Y'),
        (ax_right, weight, r'Y$\times$X$^2$'),
    )
    for ax, w, ylabel in panels:
        # Apply the cycler to every panel so the sample styles match.
        ax.set_prop_cycle(custom_cycler)
        ax.plot(xx, mn * w, "r-")
        if sc is not None:
            for s in sc:
                ax.plot(xx, s * w)

        ax.tick_params(labelsize=20)
        ax.set_xlabel('X', size=20)
        ax.set_ylabel(ylabel, size=20)
        ax.set_xlim(left=0.0)
        ax.set_ylim(bottom=0.0)

    if plot_filename is not None:
        # Save this specific figure rather than pyplot's "current" one.
        fig.savefig(plot_filename, format="png")

In [None]:
# --- Model 1: Norm2ConstrainedContainer_SE (presumably a squared-exponential
# kernel; defined in norm_constrain) -------------------------------------------
# `nval` is passed as `norm_val` here and reused by the two cells below.
nval = 2.0
nmod1 = Norm2ConstrainedContainer_SE(norm_val=nval)
# Hyperparameters; their meaning is defined in norm_constrain (not visible here).
nmod1.sigma = 1.0 ; nmod1.gamma=1.0 ; nmod1.A = 4.0

# `nk` (covariance module) and `nm` (mean module) are module-level names that
# the later cells rebind for their own models.
nk = nmod1.covar_module ; nm = nmod1.mean_module


# Compare nk.C0() with a quadrature of x**2 * nk.C1(x) over [0, inf).
C0_Consistency(nk)

# Test points for the C1 check; `xt` is also reused by the cells below.
xt = torch.tensor([0.0,0.2,0.4,0.6,0.8,1.0])
C1_Consistency(nk, xt)

# Uniform grid on [llim, ulim] with nsamp points; dx is its spacing.
llim = 0.0
ulim = 8.0
nsamp = 200
dx = (ulim-llim) / (nsamp - 1)

# NOTE: quadr() and plot_mean_and_samples() read this module-level `x`
# implicitly, so it must be (re)assigned before they are called.
x = torch.linspace(llim,ulim, nsamp)
mean = nm(x)
cov = nk(x)
# Diagonal jitter added to the dense covariance. Presumably this keeps
# sampling numerically stable; confirm.
small = 1.0E-10
cov = cov.to_dense() + small * torch.eye(nsamp)
mvn = MultivariateNormal(mean, cov)
ncurves = 5
# Draw `ncurves` sample curves on the grid. NOTE(review): this presumably uses
# torch's global RNG (seeded once in the first cell), so re-running only this
# cell gives different curves.
curves = mvn.rsample(sample_shape=torch.Size((ncurves,)))

# Print the x**2-weighted grid integrals of the mean and of each sample.
mean_normalization_check(mean, dx)
sample_curve_normalization_check(curves, dx)

# Two-panel figure (Y and Y*X**2), also saved as a PNG.
plot_mean_and_samples(mean, curves=curves, plot_filename="norm_constrain_SE.png")

In [None]:
# --- Model 2: Norm2ConstrainedContainer_rational (defined in norm_constrain) ---
# `nval` and `xt` come from the previous cell.
nmod2 = Norm2ConstrainedContainer_rational(norm_val=nval)
# Hyperparameters; their meaning is defined in norm_constrain (not visible here).
nmod2.alpha = 1.0 ; nmod2.p = 10.0 ; nmod2.A = 1000.0
# Rebinds the module-level `nk`/`nm` to this model's covariance/mean modules.
nk = nmod2.covar_module ; nm = nmod2.mean_module

# Compare nk.C0() with a quadrature of x**2 * nk.C1(x) over [0, inf).
C0_Consistency(nk)

# Compare nk.C1(y) with its quadrature estimate at each test point in `xt`.
C1_Consistency(nk, xt)

# Uniform grid on [llim, ulim] with nsamp points; dx is its spacing.
llim = 0.0
ulim = 10.0
nsamp = 200
dx = (ulim-llim) / (nsamp - 1)

# NOTE: quadr() and plot_mean_and_samples() read this module-level `x`
# implicitly, so it must be (re)assigned before they are called.
x = torch.linspace(llim,ulim, nsamp)
mean = nm(x)
cov = nk(x)
# Diagonal jitter added to the dense covariance. Presumably this keeps
# sampling numerically stable; confirm.
small = 1.0E-10
cov = cov.to_dense() + small * torch.eye(nsamp)
mvn = MultivariateNormal(mean, cov)
ncurves = 5
# NOTE(review): sampling presumably uses torch's global RNG, so the curves
# depend on how many draws earlier cells made.
curves = mvn.rsample(sample_shape=torch.Size((ncurves,)))

# Print the x**2-weighted grid integrals of the mean and of each sample.
mean_normalization_check(mean, dx)
sample_curve_normalization_check(curves, dx)


# Two-panel figure (Y and Y*X**2), also saved as a PNG.
plot_mean_and_samples(mean, curves, "norm_constrain_rational.png")
# Alternative (disabled): plot the mean only, without sample curves.
#plot_mean_and_samples(mean, plot_filename="norm_constrain_rational.png")

In [None]:
# --- Model 3: Norm2ConstrainedContainer_ConvexCombination (defined in
# norm_constrain) --------------------------------------------------------------
# Built from `nmod1` and `nmod2`, carrying the hyperparameters the previous two
# cells set on them. `nval` and `xt` also come from earlier cells.
nmod3 = Norm2ConstrainedContainer_ConvexCombination(norm_val=nval,kernels=(nmod1,nmod2))

# Debug aid (disabled): list the combined model's named parameters.
# for nm,par in nmod3.named_parameters():
#     print(f"Parameter Name: {nm}\nParameter: {par}\n\n")
# Rebinds the module-level `nk`/`nm` to this model's covariance/mean modules.
nk = nmod3.covar_module ; nm = nmod3.mean_module


# Compare nk.C0() with a quadrature of x**2 * nk.C1(x) over [0, inf).
C0_Consistency(nk)

# Compare nk.C1(y) with its quadrature estimate at each test point in `xt`.
C1_Consistency(nk, xt)


# Uniform grid on [llim, ulim] with nsamp points; dx is its spacing.
llim = 0.0
ulim = 10.0
nsamp = 200
dx = (ulim-llim) / (nsamp - 1)

# NOTE: quadr() and plot_mean_and_samples() read this module-level `x`
# implicitly, so it must be (re)assigned before they are called.
x = torch.linspace(llim,ulim, nsamp)
mean = nm(x)
cov = nk(x)
# Diagonal jitter added to the dense covariance. Presumably this keeps
# sampling numerically stable; confirm.
small = 1.0E-10
cov = cov.to_dense() + small * torch.eye(nsamp)
mvn = MultivariateNormal(mean, cov)
ncurves = 5
# NOTE(review): sampling presumably uses torch's global RNG, so the curves
# depend on how many draws earlier cells made.
curves = mvn.rsample(sample_shape=torch.Size((ncurves,)))

# Print the x**2-weighted grid integrals of the mean and of each sample.
mean_normalization_check(mean, dx)
sample_curve_normalization_check(curves, dx)


# Two-panel figure (Y and Y*X**2), also saved as a PNG.
plot_mean_and_samples(mean, curves, "norm_constrain_convex.png")