In [4]:
# imports
import torch
import sys
from pathlib import Path

# change import path
path = Path.cwd().parents[1]
if str(path) not in sys.path:
    sys.path.append(str(path))

from src.causalKLGP import *
from src.kernels import *
from src.dgps import *
from src.CBO import *

# CBO prior kernel class
class CBOPriorKernel:
    """Adapter exposing a two-argument callable through a ``get_gram`` method."""

    def __init__(self, kernel_func):
        """Store ``kernel_func``, the callable that ``get_gram`` delegates to."""
        self.kernel_func = kernel_func

    def get_gram(self, X, Z):
        """Return ``kernel_func(X, Z)`` unchanged."""
        gram = self.kernel_func(X, Z)
        return gram

# ---- experiment settings ----
seed = 1                 # passed to torch.manual_seed and to the data generators
n = 100                  # sample count requested for each observational draw
n_int = 100              # number of points on the do(statin) grid
two_datasets = True      # selects which branch builds V when drawing training data
niter = 1000             # forwarded to model.train
learn_rate = 0.1         # forwarded to model.train
calibrate = False        # run model.frequentist_calibrate to choose nu
sample_split = False     # forwarded to the calibration; also re-enables model.train
marginal_loss = False    # forwarded to the calibration
retrain_hypers = False   # forwarded to the calibration

torch.manual_seed(seed)

# ---- causalklgp configs ----
default_nu = 1.0                                  # nu used when calibrate is False
cal_nulist = 2 ** torch.linspace(-4, 4, 5)        # candidate nu values for calibration
quantiles = torch.linspace(0, 1, 101)[:, None]    # 101 levels in [0, 1] as a column
reg = 1e-3
Kernel = GaussianKernel
force_PD = True                                   # forwarded to train / calibration

# ---- CBO configs ----
int_samples = 10 ** 5            # sample count requested for the interventional draw
n_iter = 20                      # forwarded to causal_bayesian_optimization
xi = 0.0                         # forwarded to causal_bayesian_optimization
update_hyperparameters = False   # forwarded to causal_bayesian_optimization
noise_init = -10.0               # forwarded to causal_bayesian_optimization
cbo_reg = 1e-1                   # passed as `reg` to causal_bayesian_optimization

In [5]:
""" Draw int data """
# Grid of n_int intervention values on [0, 1], handed to the generator as `dostatin`.
dostatin = torch.linspace(0, 1, n_int)

# Interventional draw: int_samples is the first positional argument.
age, bmi, aspirin, statin, cancer, psa = STATIN_PSA(
    int_samples,
    seed=seed,
    gamma=False,
    interventional_data=True,
    dostatin=dostatin,
)

# PSA_VOL returns three values; the first overwrites `psa`.
psa, fvol, vol = PSA_VOL(psa=psa)

""" Draw training data"""
# First observational draw (interventional_data=False, empty dostatin).
age_, bmi_, aspirin_, statin_, cancer_, psa_ = STATIN_PSA(
    n,
    seed=seed,
    gamma=False,
    interventional_data=False,
    dostatin=[],
)

if two_datasets:
    # NOTE(review): the second draw reuses `seed`; if STATIN_PSA is deterministic
    # in its seed this reproduces the first sample — confirm that is intended.
    age_2, bmi_2, aspirin_2, statin_2, cancer_2, psa_2 = STATIN_PSA(
        n,
        seed=seed,
        gamma=False,
        interventional_data=False,
        dostatin=[],
    )
    # psa_ from the first draw is replaced by PSA_VOL's output for the second draw.
    psa_, fvol_, vol_ = PSA_VOL(psa=psa_2)
    V = psa_.reshape(len(psa_), 1)
else:
    psa_, fvol_, vol_ = PSA_VOL(psa=psa_)
    # V is a two-element list of column-shaped psa_ in this branch.
    V = [psa_.reshape(len(psa_), 1) for _ in range(2)]

# Both branches use the covariates of the first draw for A and PSA_VOL's
# third output for Y.
A = torch.column_stack((age_, bmi_, aspirin_, statin_))
Y = vol_

""" Initialise model """
# V is a plain Python list when two_datasets is False, and lists have no
# .size(); in that case read the feature dimension from the first entry so
# model construction does not raise AttributeError.
model = causalKLGP(Kernel_A = GaussianKernel, 
               Kernel_V = GaussianKernel, 
               Kernel_Z = [],
               dim_A = A.size()[1], 
               dim_V = (V[0] if isinstance(V, (list, tuple)) else V).size()[1], 
               samples = 10**5,
               # Initial values derived from the empirical variance of Y.
               scale_V_init = Y.var()**0.5/2,
               noise_Y_init = torch.log(Y.var()/4)
              )

""" Train + Calibrate model """
if calibrate:
    # NOTE(review): this call passes (Y, V, A, ...) while model.train below
    # passes (Y, A, V, ...) — confirm both match the model's signatures.
    Post_levels, Calibration_losses = model.frequentist_calibrate(
        Y, V, A, dostatin[:, None],
        nulist=cal_nulist,
        sample_split=sample_split,
        marginal_loss=marginal_loss,
        retrain_hypers=retrain_hypers,
        average_doA=True,
        intervention_indices=[3],
        force_PD=force_PD,
    )
    # First index (along dim 0) at which the calibration loss hits its minimum.
    best_ind = (Calibration_losses == Calibration_losses.min()).nonzero(as_tuple=True)[0][0]
    nu_best = cal_nulist[best_ind]
else:
    nu_best = default_nu

# Train explicitly unless calibration ran without sample splitting.
if sample_split or not calibrate:
    model.train(Y, A, V, niter, learn_rate, force_PD=force_PD)

  self.noise_Y = torch.tensor(noise_Y_init, requires_grad = True).float()


iter 0 P(Y|V) loss:  tensor(456.5264)
iter 100 P(Y|V) loss:  tensor(343.1656)
iter 200 P(Y|V) loss:  tensor(342.1736)
iter 300 P(Y|V) loss:  tensor(341.7100)
iter 400 P(Y|V) loss:  tensor(341.3611)
iter 500 P(Y|V) loss:  tensor(341.0620)
iter 600 P(Y|V) loss:  tensor(340.8018)
iter 700 P(Y|V) loss:  tensor(340.5723)
iter 800 P(Y|V) loss:  tensor(340.3672)
iter 900 P(Y|V) loss:  tensor(340.1852)
iter 0 P(V|A) loss:  tensor(-2055.8352)
iter 100 P(V|A) loss:  tensor(-5967.9531)
iter 200 P(V|A) loss:  tensor(-6119.8135)
iter 300 P(V|A) loss:  tensor(-6191.1914)
iter 400 P(V|A) loss:  tensor(-6232.5098)
iter 500 P(V|A) loss:  tensor(-6259.1729)
iter 600 P(V|A) loss:  tensor(-6277.6660)
iter 700 P(V|A) loss:  tensor(-6291.1299)
iter 800 P(V|A) loss:  tensor(-6301.3789)
iter 900 P(V|A) loss:  tensor(-6309.4453)


In [6]:
""" Get posterior funcs and CBO prior kernel """
def mean(X):
    """Evaluate ``model.post_mean`` at the intervention values in ``X``.

    ``X`` is reshaped to a column of shape ``(len(X), 1)`` and passed as the
    intervention argument, together with the notebook-level ``Y``, ``A`` and
    ``V``. The regulariser is read from the notebook config constant ``reg``
    instead of a duplicated literal, so changing the config takes effect here.

    Returns whatever ``model.post_mean`` returns.
    """
    doA = X.reshape(len(X),1)
    return model.post_mean(Y,A,V,doA,
                           reg = reg, 
                           average_doA = True, 
                           intervention_indices = [3]) 

def cov(X, Z, diag = False):
    """Evaluate ``model.post_var`` between the intervention values ``X`` and ``Z``.

    Both inputs are reshaped to columns of shape ``(len(.), 1)`` and passed as
    the two intervention arguments, together with the notebook-level ``Y``,
    ``A`` and ``V``. ``nu`` comes from ``nu_best``; the regulariser is read
    from the notebook config constant ``reg`` instead of a duplicated literal.

    ``diag`` is forwarded unchanged to ``model.post_var``.
    Returns whatever ``model.post_var`` returns.
    """
    doA = X.reshape(len(X),1)
    doA2 = Z.reshape(len(Z),1)
    return model.post_var(Y,A,V,doA,doA2,
                          reg = reg,
                          nu = nu_best, 
                          average_doA = True, 
                          intervention_indices = [3], 
                          diag = diag)

# Wrap the posterior covariance function so it exposes get_gram(X, Z).
cbo_kernel = CBOPriorKernel(cov)

""" Run CBO """
# Define a grid of intervention points and precompute E[Y|do(x)]
doX = dostatin[:,None]
# assumes fvol holds int_samples consecutive draws per grid point, in grid
# order, so each reshaped row averages one intervention — TODO confirm
EYdoX = fvol.reshape(n_int,int_samples).mean(1)[:,None]

# Random search for first intervention point.
# torch.randint excludes its upper bound, so bound by the grid length: every
# grid index (including the last) can be drawn, and the bound follows n_int.
start = torch.randint(0,len(doX),(1,))[0]
doXtrain,EYdoXtrain = doX[start].reshape(1,1),EYdoX[start].reshape(1,1)

# Run CBO iters
doXeval, EYdoXeval = causal_bayesian_optimization(X_train = doXtrain, 
                                                    y_train = EYdoXtrain, 
                                                    kernel = cbo_kernel, 
                                                    mean = mean,
                                                    X_test = doX, 
                                                    Y_test = EYdoX, 
                                                    n_iter = n_iter, 
                                                    update_hyperparameters = update_hyperparameters,
                                                    xi = xi, 
                                                    print_ = True, 
                                                    minimise = True,
                                                    noise_init = noise_init,
                                                    reg = cbo_reg)

Iteration 1: X = 1.0, Y = 10.476282119750977
Iteration 2: X = 1.0, Y = 10.476282119750977
Iteration 3: X = 1.0, Y = 10.476282119750977
Iteration 4: X = 1.0, Y = 10.476282119750977
Iteration 5: X = 1.0, Y = 10.476282119750977
Iteration 6: X = 1.0, Y = 10.476282119750977
Iteration 7: X = 1.0, Y = 10.476282119750977
Iteration 8: X = 1.0, Y = 10.476282119750977
Iteration 9: X = 1.0, Y = 10.476282119750977
Iteration 10: X = 1.0, Y = 10.476282119750977
Iteration 11: X = 1.0, Y = 10.476282119750977
Iteration 12: X = 1.0, Y = 10.476282119750977
Iteration 13: X = 1.0, Y = 10.476282119750977
Iteration 14: X = 1.0, Y = 10.476282119750977
Iteration 15: X = 1.0, Y = 10.476282119750977
Iteration 16: X = 1.0, Y = 10.476282119750977
Iteration 17: X = 1.0, Y = 10.476282119750977
Iteration 18: X = 1.0, Y = 10.476282119750977
Iteration 19: X = 1.0, Y = 10.476282119750977
Iteration 20: X = 1.0, Y = 10.476282119750977
