# Fast Batch Multitask Net GP

In [1]:
import fastgp
import torch
import numpy as np

In [2]:
# Make float64 the default floating-point dtype for tensors created from here on.
torch.set_default_dtype(torch.float64)

## True Function

In [3]:
def f(l,x):
    """
    Evaluate a pair of truncated trigonometric series at the points `x`.

    The frequencies are 2^1, ..., 2^(l+1) and every term is divided by its own
    frequency, so a larger `l` adds finer, smaller-amplitude oscillations.

    Args:
        l (int): level index; `l+1` terms are summed.
        x (torch.Tensor): evaluation points; in this notebook a column of shape (n,1).

    Returns:
        torch.Tensor: two stacked rows, the sine series first and the cosine series second.
    """
    freqs = 2**torch.arange(1,l+2) # 2, 4, ..., 2^(l+1)
    phases = freqs*np.pi*x # broadcast: one column per frequency
    sine_series = (torch.sin(phases)/freqs).sum(1)
    cosine_series = (torch.cos(phases)/freqs).sum(1)
    return torch.vstack([sine_series,cosine_series])
# Problem setup: task count, input dimension and seeded random evaluation points.
num_tasks = 4
d = 1 # dimension
rng = torch.Generator().manual_seed(17) # fixed seed so x and z are reproducible
x = torch.rand((2**7,d),generator=rng) # random testing locations
# true values at random testing locations: task l is f(l,.), tasks stacked along axis 1
y = torch.stack([f(task,x) for task in range(num_tasks)],dim=1)
z = torch.rand((2**8,d),generator=rng) # other random locations at which to evaluate covariance
print("x.shape = %s"%(tuple(x.shape),))
print("y.shape = %s"%(tuple(y.shape),))
print("z.shape = %s"%(tuple(z.shape),))

x.shape = (128, 1)
y.shape = (2, 4, 128)
z.shape = (256, 1)


## Construct Fast GP

In [4]:
# Build the multitask GP and request an initial design of 2^5, 2^6, ... points, one size per task.
fgp = fastgp.FastGPDigitalNetB2(d,seed_for_seq=7,num_tasks=num_tasks)
initial_sizes = 2**torch.arange(5,5+num_tasks)
x_next = fgp.get_x_next(n=initial_sizes)
# Evaluate task l's true function at the locations requested for task l.
y_next = [f(task,x_next[task]) for task in range(num_tasks)]
fgp.add_y_next(y_next)
assert len(x_next)==len(y_next)
# Report the per-task shapes of the sampling locations and of the observed values.
for i,(x_i,y_i) in enumerate(zip(x_next,y_next)):
    print("i = %d"%i)
    print("\tx_next[%d].shape = %s"%(i,tuple(x_i.shape)))
    print("\ty_next[%d].shape = %s"%(i,tuple(y_i.shape)))

i = 0
	x_next[0].shape = (32, 1)
	y_next[0].shape = (2, 32)
i = 1
	x_next[1].shape = (64, 1)
	y_next[1].shape = (2, 64)
i = 2
	x_next[2].shape = (128, 1)
	y_next[2].shape = (2, 128)
i = 3
	x_next[3].shape = (256, 1)
	y_next[3].shape = (2, 256)


In [5]:
# Posterior mean at the test locations, evaluated before fit() is called in the next cell.
pmean = fgp.post_mean(x)
print("pmean.shape = %s"%(tuple(pmean.shape),))
# Relative l2 error over the test points, one entry per (output row, task).
abs_err = torch.linalg.norm(y-pmean,dim=-1)
ref_norm = torch.linalg.norm(y,dim=-1)
print("l2 relative error:\n%s"%(abs_err/ref_norm,))

pmean.shape = (2, 4, 128)
l2 relative error:
tensor([[0.0781, 0.0383, 0.0265, 0.0227],
        [0.0797, 0.0374, 0.0279, 0.0283]])


In [6]:
# Run fit() with its default arguments. As the recorded output below shows, it prints a
# progress table with columns for the iteration, NMLL, noise, scale, lengthscales and task_kernel.
data = fgp.fit()
# Display the names of the histories that fit() returned.
list(data.keys())

     iter of 5.0e+03 | NMLL       | noise      | scale      | lengthscales         | task_kernel 
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            0.00e+00 | 5.02e+03   | 1.00e-16   | 1.00e+00   | [1.0e+00]            | [[2.0e+00 1.0e+00 1.0e+00 1.0e+00] [1.0e+00 2.0e+00 1.0e+00 1.0e+00] [1.0e+00 1.0e+00 2.0e+00 1.0e+00] [1.0e+00 1.0e+00 1.0e+00 2.0e+00]]
            5.00e+00 | -3.57e+04  | 1.00e-16   | 4.75e-01   | [4.8e-01]            | [[5.4e-01 6.5e-02 6.5e-02 6.5e-02] [6.5e-02 5.4e-01 6.5e-02 6.5e-02] [6.5e-02 6.5e-02 5.4e-01 6.5e-02] [6.5e-02 6.5e-02 6.5e-02 5.4e-01]]
            1.00e+01 | -6.87e+04  | 1.00e-16   | 7.46e-02   | [7.5e-02]            | [[9.5e-02 2.0e-02 2.0e-02 -5.2e-03] [2.0e-02 9.5e-02 2.0e-02 -5.2e-03] [2.0e-02 2.0e-02 9.5e-02 -5.2e-03] [-5.2e-03 -5.2e-03 -5.2e-03 7.6e-02]]
            1.50e+01 | -9.05e+04  | 1.00e-16   | 1.13e-01   | [1.1e-01]            | [[1.2e-01 7.5e-02 7.5e-02 2.2e-02] [7.5e-02 

['mll_hist', 'scale_hist', 'lengthscales_hist', 'task_kernel_hist']

In [7]:
# Posterior mean, variance and 99% credible interval at the test locations.
pmean,pvar,q,ci_low,ci_high = fgp.post_ci(x,confidence=0.99)
print("pmean.shape = %s"%(tuple(pmean.shape),))
print("pvar.shape = %s"%(tuple(pvar.shape),))
print("q = %.2f"%q)
print("ci_low.shape = %s"%(tuple(ci_low.shape),))
print("ci_high.shape = %s"%(tuple(ci_high.shape),))
abs_err = torch.linalg.norm(y-pmean,dim=-1)
ref_norm = torch.linalg.norm(y,dim=-1)
print("l2 relative error:\n%s"%(abs_err/ref_norm,))
# Full posterior covariance between every pair of tasks and every pair of test points.
pcov = fgp.post_cov(x,x)
print("pcov.shape = %s"%(tuple(pcov.shape),))
# Consistency check: the (task,task,point,point) diagonal of pcov must reproduce pvar.
# The first diagonal collapses the two task axes (moving the task axis last),
# the second collapses the two point axes, leaving a (task,point) tensor.
pcov_diag = pcov.diagonal(dim1=0,dim2=1).diagonal(dim1=0,dim2=1)
assert torch.allclose(pcov_diag,pvar)
assert (pvar>=0).all()
# Cross-covariance between the test locations x and the other locations z.
pcov2 = fgp.post_cov(x,z)
print("pcov2.shape = %s"%(tuple(pcov2.shape),))

pmean.shape = (2, 4, 128)
pvar.shape = (4, 128)
q = 2.58
ci_low.shape = (2, 4, 128)
ci_high.shape = (2, 4, 128)
l2 relative error:
tensor([[0.0656, 0.0390, 0.0252, 0.0241],
        [0.0715, 0.0370, 0.0306, 0.0315]])
pcov.shape = (4, 4, 128, 128)
pcov2.shape = (4, 4, 128, 256)


In [8]:
# Posterior mean, variance and 99% credible interval returned by post_cubature_ci.
pcmean,pcvar,q,cci_low,cci_high = fgp.post_cubature_ci(confidence=0.99)
print("pcmean:\n%s"%(pcmean,))
print("\npcvar:\n%s"%(pcvar,))
print("\ncci_low:\n%s"%(cci_low,))
print("\ncci_high:\n%s"%(cci_high,))

pcmean:
tensor([[ 1.2705e-20,  9.0633e-20,  4.9551e-20,  4.6587e-20],
        [-1.8381e-19, -3.1001e-19, -3.7439e-19, -3.8625e-19]])

pcvar:
tensor([3.5095e-08, 2.8624e-08, 2.5938e-08, 2.7794e-08])

cci_low:
tensor([[-0.0005, -0.0004, -0.0004, -0.0004],
        [-0.0005, -0.0004, -0.0004, -0.0004]])

cci_high:
tensor([[0.0005, 0.0004, 0.0004, 0.0004],
        [0.0005, 0.0004, 0.0004, 0.0004]])


## Project and Increase Sample Size

In [9]:
# Target sample sizes: scale the current sizes by 2^(num_tasks-1), ..., 2^1, 2^0.
growth = 2**torch.arange(num_tasks).flip(0)
n_new = fgp.n*growth
# Posterior (co)variances projected to the target sizes; the next cell asserts they
# match the values obtained once the extra data has actually been added.
pcov_future = fgp.post_cov(x,z,n=n_new)
pvar_future = fgp.post_var(x,n=n_new)
pcvar_future = fgp.post_cubature_var(n=n_new)

In [10]:
# Request the additional locations needed to reach n_new and evaluate each task there.
x_next = fgp.get_x_next(n_new)
y_next = [f(task,x_next[task]) for task in range(num_tasks)]
for y_task in y_next:
    print(y_task.shape)
fgp.add_y_next(y_next)
abs_err = torch.linalg.norm(y-fgp.post_mean(x),dim=-1)
ref_norm = torch.linalg.norm(y,dim=-1)
print("l2 relative error:\n%s"%(abs_err/ref_norm,))
# The (co)variances projected before adding the data must match the actual ones now.
assert torch.allclose(fgp.post_cov(x,z),pcov_future)
assert torch.allclose(fgp.post_var(x),pvar_future)
assert torch.allclose(fgp.post_cubature_var(),pcvar_future)

torch.Size([2, 224])
torch.Size([2, 192])
torch.Size([2, 128])
torch.Size([2, 0])
l2 relative error:
tensor([[0.0163, 0.0163, 0.0195, 0.0239],
        [0.0157, 0.0180, 0.0228, 0.0321]])


In [11]:
# Refit on the enlarged data set with printing and the listed history options switched off.
data = fgp.fit(
    verbose=False,
    store_mll_hist=False,
    store_scale_hist=False,
    store_lengthscales_hist=False,
    store_noise_hist=False,
)
abs_err = torch.linalg.norm(y-fgp.post_mean(x),dim=-1)
ref_norm = torch.linalg.norm(y,dim=-1)
print("l2 relative error:\n%s"%(abs_err/ref_norm,))

l2 relative error:
tensor([[0.0165, 0.0161, 0.0200, 0.0240],
        [0.0154, 0.0171, 0.0225, 0.0312]])


In [12]:
# Target sample sizes: scale the current sizes by 2^0, 2^1, ..., 2^(num_tasks-1).
growth = 2**torch.arange(num_tasks)
n_new = fgp.n*growth
# Project the posterior (co)variances to the target sizes before any new data is added.
pcov_new = fgp.post_cov(x,z,n=n_new)
pvar_new = fgp.post_var(x,n=n_new)
pcvar_new = fgp.post_cubature_var(n=n_new)
# Add the data needed to reach n_new, then confirm the projections were exact.
x_next = fgp.get_x_next(n_new)
y_next = [f(task,x_next[task]) for task in range(num_tasks)]
fgp.add_y_next(y_next)
assert torch.allclose(fgp.post_cov(x,z),pcov_new)
assert torch.allclose(fgp.post_var(x),pvar_new)
assert torch.allclose(fgp.post_cubature_var(),pcvar_new)