In [4]:
# imports
import torch
import sys
from pathlib import Path

# Make the repository root (two directory levels above the notebook's working
# directory) importable so the project-local `src` package imported below resolves.
# The membership check keeps re-runs of this cell from adding duplicate entries.
path = Path.cwd().parents[1]
sys.path.extend([] if str(path) in sys.path else [str(path)])

from src.BayesIMP import *
from src.kernels import *
from src.dgps import *
from src.CBO import *

class CBOPriorKernel:
    """Adapter that exposes a plain covariance callable through a ``get_gram`` method.

    Parameters
    ----------
    kernel_func : callable
        Called as ``kernel_func(X, Z)``; whatever it returns is handed back
        unchanged by :meth:`get_gram`.
    """

    def __init__(self, kernel_func):
        # Stored as-is: no validation, wrapping or copying.
        self.kernel_func = kernel_func

    def get_gram(self, X, Z):
        """Return the Gram matrix between ``X`` and ``Z`` as computed by ``kernel_func``."""
        compute = self.kernel_func
        return compute(X, Z)

# ---- Experiment / training settings ----
seed = 40
n = 100                  # observational sample size drawn by STATIN_PSA
n_int = 100              # number of do(statin) grid points (torch.linspace(0, 1, n_int))
two_datasets = True      # draw a second observational dataset (see the data cell)
niter = 1000             # BayesIMP training iterations -- distinct from CBO's `n_iter` below
learn_rate = 0.1
optimise_mu = True       # forwarded to model.train as `optimise_measure`
exact = True
mc_samples = 100

torch.manual_seed(seed)

# ---- BayesIMP settings ----
default_nu = 1.0
reg = 0.1                # regulariser passed to model.post_mean / model.post_var
Kernel = GaussianKernel
quantiles = torch.linspace(0, 1, 101).unsqueeze(-1)   # column vector, shape (101, 1)

# ---- CBO settings ----
int_samples = 10 ** 5    # interventional draws (STATIN_PSA with interventional_data=True)
n_iter = 20              # number of CBO iterations
xi = 0.0
update_hyperparameters = False
noise_init = -10.0
cbo_reg = 0.1


In [5]:
""" Draw int data """
# Interventional data: int_samples draws for each value on the do(statin) grid.
dostatin = torch.linspace(0, 1, n_int)
age, bmi, aspirin, statin, cancer, psa = STATIN_PSA(int_samples,
                                                    seed = seed,
                                                    gamma = False,
                                                    interventional_data = True,
                                                    dostatin = dostatin)
psa, fvol, vol = PSA_VOL(psa = psa)

# Observational training data.
age_, bmi_, aspirin_, statin_, cancer_, psa_ = STATIN_PSA(n,
                                                          seed = seed,
                                                          gamma = False,
                                                          interventional_data = False,
                                                          dostatin = [])
# Treatment covariates A always come from this first draw (column 3 = statin,
# the intervened variable); identical in both branches below.
A = torch.column_stack((age_, bmi_, aspirin_, statin_))

if two_datasets:
    # NOTE(review): this second draw passes the same `seed` as the first; if STATIN_PSA
    # seeds its RNG from it, the "second" dataset is an exact copy of the first.
    # V below also comes from this draw while A comes from the first, so once the
    # draws differ the (A, V) rows would no longer correspond -- confirm the intended
    # two-dataset design before changing the seed.
    age_2, bmi_2, aspirin_2, statin_2, cancer_2, psa_2 = STATIN_PSA(n,
                                                                    seed = seed,
                                                                    gamma = False,
                                                                    interventional_data = False,
                                                                    dostatin = [])
    psa_, fvol_, vol_ = PSA_VOL(psa = psa_2)
    V = psa_.reshape(len(psa_), 1)
else:
    psa_, fvol_, vol_ = PSA_VOL(psa = psa_)
    # Same V passed twice; presumably one entry per BayesIMP stage -- TODO confirm.
    V = [psa_.reshape(len(psa_), 1), psa_.reshape(len(psa_), 1)]
Y = vol_

# Initialise model. V is a tensor in one branch and a list of tensors in the other;
# a list has no `.size()`, so read the feature dimension from its first tensor.
model = BayesIMP(Kernel_A = Kernel,
                 Kernel_V = Kernel,
                 Kernel_Z = [],
                 dim_A = A.size()[1],
                 dim_V = (V[0] if isinstance(V, list) else V).size()[1],
                 samples = 10**5,  # NOTE(review): hard-coded, not tied to a config constant -- confirm
                 exact = exact)

# Train model.
model.train(Y, A, V, niter, learn_rate, optimise_measure = optimise_mu, mc_samples = mc_samples)

  self.kernel_V.dist.scale = torch.tensor(measure_init*V[1].var()**0.5).requires_grad_(optimise_measure)


iter 0 P(Y|V) loss:  tensor(13862.9346)
iter 100 P(Y|V) loss:  tensor(365.2025)
iter 200 P(Y|V) loss:  tensor(285.1972)
iter 300 P(Y|V) loss:  tensor(267.1989)
iter 400 P(Y|V) loss:  tensor(262.0068)
iter 500 P(Y|V) loss:  tensor(260.3353)
iter 600 P(Y|V) loss:  tensor(259.7330)
iter 700 P(Y|V) loss:  tensor(259.4659)
iter 800 P(Y|V) loss:  tensor(259.3099)
iter 900 P(Y|V) loss:  tensor(259.1981)
iter 0 P(V|A) loss:  tensor(-45183.6289)
iter 100 P(V|A) loss:  tensor(-54388.0938)
iter 200 P(V|A) loss:  tensor(-54576.1406)
iter 300 P(V|A) loss:  tensor(-54728.2148)
iter 400 P(V|A) loss:  tensor(-54859.8867)
iter 500 P(V|A) loss:  tensor(-54938.1836)
iter 600 P(V|A) loss:  tensor(-54995.1680)
iter 700 P(V|A) loss:  tensor(-55042.7500)
iter 800 P(V|A) loss:  tensor(-55083.9219)
iter 900 P(V|A) loss:  tensor(-55119.3750)


In [6]:
 # Posterior funcs from the trained BayesIMP model, then the CBO prior kernel.
def mean(X):
    """Posterior mean of the trained BayesIMP model at intervention values ``X``.

    ``X`` is reshaped to a column (one intervention value per row) and passed as
    ``doA`` to ``model.post_mean``, together with ``Y``, ``A``, ``V`` and ``reg``
    read from the notebook's global scope. The intervention is on column 3 of
    ``A`` (statin, the 4th column stacked into ``A``), with ``average_doA=True``.
    Returns whatever ``model.post_mean`` returns.
    """
    doA = X.reshape(len(X), 1)
    return model.post_mean(Y, A, V, doA,
                           reg = reg,
                           average_doA = True,
                           intervention_indices = [3])

def cov(X, Z, diag = False):
    """Posterior covariance of the trained BayesIMP model between ``X`` and ``Z``.

    Both inputs are reshaped to column vectors (one intervention value per row)
    and forwarded to ``model.post_var`` together with ``Y``, ``A``, ``V`` and
    ``reg`` from the notebook's global scope, intervening on column 3 of ``A``
    (statin) with ``average_doA=True``. ``diag`` is forwarded unchanged.
    """
    rows_x, rows_z = (t.reshape(len(t), 1) for t in (X, Z))
    return model.post_var(
        Y, A, V, rows_x, rows_z,
        reg=reg,
        average_doA=True,
        intervention_indices=[3],
        diag=diag,
    )

# Expose the posterior covariance to CBO through the get_gram(X, Z) interface.
cbo_kernel = CBOPriorKernel(cov)

# Run CBO.
# Grid of intervention points and the Monte-Carlo estimate of E[Y | do(x)] at each.
doX = dostatin[:, None]
# Assumes fvol is laid out as n_int consecutive blocks of int_samples draws, one
# block per dostatin value -- TODO confirm against STATIN_PSA / PSA_VOL.
EYdoX = fvol.reshape(n_int, int_samples).mean(1)[:, None]

# Random initial design: one grid point drawn uniformly. torch.randint's upper
# bound is exclusive, so use the grid length to make every index (including the
# last grid point) selectable and to stay correct if n_int changes.
torch.manual_seed(seed)
start = torch.randint(0, len(doX), (1,))[0]
doXtrain, EYdoXtrain = doX[start].reshape(1, 1), EYdoX[start].reshape(1, 1)

# Run CBO iterations, minimising E[Y | do(x)] over the grid doX.
doXeval, EYdoXeval = causal_bayesian_optimization(X_train = doXtrain,
                                                    y_train = EYdoXtrain,
                                                    kernel = cbo_kernel,
                                                    mean = mean,
                                                    X_test = doX,
                                                    Y_test = EYdoX,
                                                    n_iter = n_iter,
                                                    update_hyperparameters = update_hyperparameters,
                                                    xi = xi,
                                                    print_ = True,
                                                    minimise = True,
                                                    noise_init = noise_init,
                                                    reg = cbo_reg)

Iteration 1: X = 1.0, Y = 4.873571395874023
Iteration 2: X = 1.0, Y = 4.873571395874023
Iteration 3: X = 1.0, Y = 4.873571395874023
Iteration 4: X = 1.0, Y = 4.873571395874023
Iteration 5: X = 1.0, Y = 4.873571395874023
Iteration 6: X = 1.0, Y = 4.873571395874023
Iteration 7: X = 1.0, Y = 4.873571395874023
Iteration 8: X = 1.0, Y = 4.873571395874023
Iteration 9: X = 1.0, Y = 4.873571395874023
Iteration 10: X = 1.0, Y = 4.873571395874023
Iteration 11: X = 1.0, Y = 4.873571395874023
Iteration 12: X = 1.0, Y = 4.873571395874023
Iteration 13: X = 1.0, Y = 4.873571395874023
Iteration 14: X = 1.0, Y = 4.873571395874023
Iteration 15: X = 1.0, Y = 4.873571395874023
Iteration 16: X = 1.0, Y = 4.873571395874023
Iteration 17: X = 1.0, Y = 4.873571395874023
Iteration 18: X = 1.0, Y = 4.873571395874023
Iteration 19: X = 1.0, Y = 4.873571395874023
Iteration 20: X = 1.0, Y = 4.873571395874023
