In [1]:
# imports
import torch
import sys
from pathlib import Path

# change import path
path = Path.cwd().parents[1]
if str(path) not in sys.path:
    sys.path.append(str(path))

from src.BayesIMP import *
from src.kernels import *
from src.dgps import *
from src.CBO import *

# CBO prior kernel class
class CBOPriorKernel:
    """Adapter that exposes a plain covariance callable through ``get_gram``.

    The callable supplied at construction is stored on the instance and every
    ``get_gram(X, Z)`` call is forwarded to it unchanged.
    """

    def __init__(self, kernel_func):
        """Store the callable used to evaluate the gram matrix.

        Args:
            kernel_func: Callable invoked as ``kernel_func(X, Z)``.
        """
        self.kernel_func = kernel_func

    def get_gram(self, X, Z):
        """Return ``kernel_func(X, Z)`` without any post-processing.

        Args:
            X: First argument forwarded to the stored callable.
            Z: Second argument forwarded to the stored callable.

        Returns:
            Whatever the stored callable returns for ``(X, Z)``.
        """
        return self.kernel_func(X, Z)

# ---- Experiment settings ----
seed = 19
n = 100
n_int = 100
two_datasets = True
niter = 1000
learn_rate = 0.1
optimise_mu = False
exact = True
mc_samples = 100

# Seed torch's global random number generator
torch.manual_seed(seed)

# ---- BayesIMP settings ----
default_nu = 1.0
reg = 0.1
Kernel = GaussianKernel
# 101 evenly spaced values in [0, 1], stored as a column
quantiles = torch.linspace(0, 1, 101).unsqueeze(1)

# ---- CBO settings ----
int_samples = 100_000
n_iter = 20
xi = 0.0
update_hyperparameters = False
noise_init = -10.0
cbo_reg = 0.1




In [2]:
""" Draw int data """
# Grid of n_int evenly spaced statin intervention values in [0, 1]
dostatin = torch.linspace(0, 1, n_int)

# Sample from the data-generating process with interventional_data switched on
age, bmi, aspirin, statin, cancer, psa = STATIN_PSA(
    int_samples,
    seed=seed,
    gamma=False,
    interventional_data=True,
    dostatin=dostatin,
)
psa,fvol, vol = PSA_VOL(psa = psa)  

""" Draw training data"""
# First non-interventional draw
age_, bmi_, aspirin_, statin_, cancer_, psa_ = STATIN_PSA(
    n,
    seed=seed,
    gamma=False,
    interventional_data=False,
    dostatin=[],
)

# A is always built from the first draw, so construct it once for both branches.
# Column order is (age, bmi, aspirin, statin), i.e. statin sits at index 3.
A = torch.column_stack((age_, bmi_, aspirin_, statin_))

if two_datasets:
    # Second, independently seeded draw supplies the (V, Y) pairs
    age_2, bmi_2, aspirin_2, statin_2, cancer_2, psa_2 = STATIN_PSA(
        n,
        seed=seed + 1,
        gamma=False,
        interventional_data=False,
        dostatin=[],
    )
    psa_2, fvol_2, vol_2 = PSA_VOL(psa=psa_2)
    # Each PSA tensor is reshaped to a column using its OWN length
    V = [psa_.reshape(len(psa_), 1), psa_2.reshape(len(psa_2), 1)]
    Y = vol_2
else:
    # Single dataset: the same PSA sample is used for both entries of V
    psa_, fvol_, vol_ = PSA_VOL(psa=psa_)
    V = [psa_.reshape(len(psa_), 1), psa_.reshape(len(psa_), 1)]
    Y = vol_

""" Initialise model """
# Variance of the outcome, reused for both initial hyperparameter values
y_var = Y.var()

model = BayesIMP(
    Kernel_A=Kernel,
    Kernel_V=Kernel,
    Kernel_Z=[],
    dim_A=A.shape[1],
    dim_V=V[1].shape[1],
    samples=10**5,
    exact=exact,
    scale_V_init=y_var**0.5 / 2,
    noise_Y_init=torch.log(y_var / 4),
)

# Fit the model with the settings from the configuration cell
model.train(
    Y,
    A,
    V,
    niter,
    learn_rate,
    optimise_measure=optimise_mu,
    mc_samples=mc_samples,
    reg=reg,
)

  scale = torch.tensor(scale_V_init, requires_grad = True))
  self.noise_Y = torch.tensor(noise_Y_init, requires_grad = True).float()
  self.kernel_V.dist.scale = torch.tensor(measure_init*V[1].var()**0.5).requires_grad_(optimise_measure)


iter 0 P(Y|V) loss:  tensor(889.4224)
iter 100 P(Y|V) loss:  tensor(165.9147)
iter 200 P(Y|V) loss:  tensor(164.5506)
iter 300 P(Y|V) loss:  tensor(164.1805)
iter 400 P(Y|V) loss:  tensor(164.0024)
iter 500 P(Y|V) loss:  tensor(163.8721)
iter 600 P(Y|V) loss:  tensor(163.7448)
iter 700 P(Y|V) loss:  tensor(163.6001)
iter 800 P(Y|V) loss:  tensor(163.4211)
iter 900 P(Y|V) loss:  tensor(163.1900)
iter 0 P(V|A) loss:  tensor(-21002.3613)
iter 100 P(V|A) loss:  tensor(-38754.6602)
iter 200 P(V|A) loss:  tensor(-39047.5703)
iter 300 P(V|A) loss:  tensor(-39221.3125)
iter 400 P(V|A) loss:  tensor(-39305.3789)
iter 500 P(V|A) loss:  tensor(-39352.6328)
iter 600 P(V|A) loss:  tensor(-39475.2383)
iter 700 P(V|A) loss:  tensor(-39532.)
iter 800 P(V|A) loss:  tensor(-39572.1914)
iter 900 P(V|A) loss:  tensor(-39603.2695)


In [3]:
""" Get posterior funcs and CBO prior kernel """
def mean(X, reg=reg):
    """Evaluate ``model.post_mean`` at the intervention values ``X``.

    Args:
        X: Intervention values; reshaped to a ``(len(X), 1)`` column before
            being passed on as ``doA``.
        reg: Regulariser forwarded to ``model.post_mean``. The default is
            bound when this function is defined, so later cells that rebind
            the notebook-level ``reg`` do not silently change the result.

    Returns:
        The value returned by ``model.post_mean`` with ``average_doA=True``
        and ``intervention_indices=[3]``.
    """
    doA = X.reshape(len(X), 1)
    return model.post_mean(
        Y, A, V, doA,
        reg=reg,
        average_doA=True,
        intervention_indices=[3],
    )

def cov(X, Z, diag=False, reg=reg):
    """Evaluate ``model.post_var`` between intervention values ``X`` and ``Z``.

    Args:
        X: First set of intervention values; reshaped to ``(len(X), 1)``.
        Z: Second set of intervention values; reshaped to ``(len(Z), 1)``.
        diag: Forwarded unchanged to ``model.post_var``.
        reg: Regulariser forwarded to ``model.post_var``. The default is
            bound when this function is defined, so later cells that rebind
            the notebook-level ``reg`` do not silently change the result.

    Returns:
        The value returned by ``model.post_var`` with ``average_doA=True``
        and ``intervention_indices=[3]``.
    """
    doA = X.reshape(len(X), 1)
    doA2 = Z.reshape(len(Z), 1)
    return model.post_var(
        Y, A, V, doA, doA2,
        reg=reg,
        average_doA=True,
        intervention_indices=[3],
        diag=diag,
    )

# Wrap the posterior covariance function so it is reachable via get_gram(X, Z)
cbo_kernel = CBOPriorKernel(cov)

# Grid of intervention points, and E[Y|do(x)] estimated by averaging the
# int_samples draws that belong to each of the n_int grid values
doX = dostatin[:, None]
EYdoX = fvol.reshape(n_int, int_samples).mean(1)[:, None]

# Random choice of the first intervention point.
# torch.randint's upper bound is exclusive, so indices 0 .. n_int - 2 can be
# drawn and the final grid point is never the starting point. The bound is
# now tied to n_int instead of the literal 99 (identical draw for n_int = 100).
# NOTE(review): confirm that excluding the last grid point is intentional.
torch.manual_seed(seed)
start = torch.randint(0, n_int - 1, (1,))[0]
doXtrain, EYdoXtrain = doX[start].reshape(1, 1), EYdoX[start].reshape(1, 1)

In [4]:
""" Run CBO """
# Run the CBO iterations, starting from the single randomly chosen point and
# searching over the precomputed intervention grid
doXeval, EYdoXeval = causal_bayesian_optimization(
    X_train=doXtrain,
    y_train=EYdoXtrain,
    kernel=cbo_kernel,
    mean=mean,
    X_test=doX,
    Y_test=EYdoX,
    n_iter=n_iter,
    update_hyperparameters=update_hyperparameters,
    xi=xi,
    print_=False,
    minimise=True,
    noise_init=noise_init,
    reg=cbo_reg,
)

In [5]:
# Bind the keyword arguments of the causal_bayesian_optimization call in the
# previous cell to notebook-level names (presumably so the optimisation loop
# can be stepped through by hand in the next cell).
X_train = doXtrain
y_train = EYdoXtrain
kernel = cbo_kernel
X_test = doX
Y_test = EYdoX
print_ = False
minimise = True

# mean, n_iter, update_hyperparameters, xi and noise_init already carry the
# values passed to that call under the same names, so they are not re-bound.

# The call passes reg = cbo_reg; reuse the config value rather than repeating
# the literal. NOTE: this rebinds the notebook-level `reg`.
reg = cbo_reg

In [6]:
# Fall back to empty training tensors when no initial observation is supplied
if X_train is None or y_train is None:
    X_train = torch.empty((0, X_test.shape[1]))
    y_train = torch.empty((0, 1))

# Gaussian process built from the initial data, prior mean and prior kernel
gp = GaussianProcess(X_train=X_train, y_train=y_train, kernel=kernel, noise_init = noise_init, mean = mean, nugget = reg)

# Initial incumbent: follow the optimisation direction, consistent with the
# update performed at the end of every iteration below (the original always
# took the maximum here, even when minimising).
if len(y_train) > 0:
    y_best = torch.min(y_train) if minimise else torch.max(y_train)
else:
    y_best = 0
x_best = 0.5
i = 0

# Iterate until the incumbent location reaches 1.0 or n_iter steps have run
while x_best < 1.0 and i < n_iter:
    # GP predictions on the test grid
    mu_s, cov_s = gp(X_test)
    # abs() guards the square root against negative diagonal entries
    sigma_s = torch.sqrt(torch.diag(cov_s).abs())

    # Expected Improvement over the grid
    ei = expected_improvement(mu_s[:,0], sigma_s, y_best, xi = xi, minimise = minimise)

    # Grid point with the highest acquisition value, and its precomputed outcome
    next_index = torch.argmax(ei)
    next_x = X_test[next_index]
    next_y = Y_test[next_index]

    # Append the new observation to the training data
    X_train = torch.cat((X_train, next_x.unsqueeze(0)), dim=0)
    y_train = torch.cat((y_train, next_y.unsqueeze(0)), dim=0)

    # Hand the enlarged training set to the GP
    gp.X_train = X_train
    gp.y_train = y_train

    # Optional hyperparameter refit.
    # NOTE(review): update_interval, hyperparam_steps and lr are not defined
    # anywhere in this notebook; they must be set before enabling
    # update_hyperparameters, otherwise this branch raises NameError.
    if update_hyperparameters and (i + 1) % update_interval == 0:
        gp.optimize_hyperparameters(num_steps=hyperparam_steps, lr=lr, print_ = print_)

    # Refresh the incumbent
    if minimise:
        y_best = torch.min(y_train)
        x_best = X_train[torch.argmin(y_train)]
    else:
        y_best = torch.max(y_train)
        x_best = X_train[torch.argmax(y_train)]

    if print_:
        print(f"Iteration {i+1}: X = {X_train.min()}, Y = {y_train.min()}")

    i += 1

# If the loop stopped early, pad with the incumbent so that exactly n_iter
# rows are appended to the initial training data
if i < n_iter:
    X_train = torch.cat((X_train,x_best*torch.ones((n_iter - i,1))), dim = 0)
    y_train = torch.cat((y_train,y_best*torch.ones((n_iter - i,1))), dim = 0)

In [7]:
# Arguments for a step-by-step evaluation of the posterior covariance
doA = X_test
doA2 = []
average_doA = True
intervention_indices = [3]
diag = False
# NOTE: this rebinds the notebook-level `reg` (set to 1e-1 in the config cell)
reg = 1e-2
# Only used when model.exact is False. The original cell referenced `samples`
# without defining it; 10**5 mirrors the value passed to BayesIMP at
# construction. NOTE(review): confirm against the library default.
samples = 10**5

if not model.exact:
    model.kernel_V.samples = samples

# An empty list means "no second set supplied": fall back to doA.
# The explicit list check avoids comparing a tensor against [].
if isinstance(doA2, list) and len(doA2) == 0:
    doA2 = doA

# Dimensions
n = len(Y)
n0, N, M, D = len(A), len(doA), len(doA2), A.shape[1]
Y = Y.reshape(n, 1)

# Expand doA with replacement using the selected columns from A
if average_doA:
    # Columns of A that are averaged over rather than intervened on
    average_indices = [j for j in range(D) if j not in intervention_indices]
    expanded_doA = model.expand_doA(doA, A, intervention_indices)  # Shape (n0*N, D)
    expanded_doA2 = model.expand_doA(doA2, A, intervention_indices)  # Shape (n0*M, D)
else:
    # Without averaging, the intervention inputs must already span every column of A
    assert doA.shape[1] == A.shape[1]
    assert doA2.shape[1] == A.shape[1]
    expanded_doA = doA  # (N, D)
    expanded_doA2 = doA2  # (M, D)

In [8]:
# Stack both V samples row-wise; record the sizes of the Y and A samples
Vall = torch.row_stack((V[0], V[1]))
n1 = len(Y)
n0 = len(A)

kern_V = model.kernel_V
kern_A = model.kernel_A

# Gram matrices from kernel_V.get_gram
R_v1v0 = kern_V.get_gram(V[1], V[0])
R_v0v0 = kern_V.get_gram(V[0], V[0])
R_v1v1 = kern_V.get_gram(V[1], V[1])
R_vv1 = kern_V.get_gram(Vall, V[1])
R_vv0 = kern_V.get_gram(Vall, V[0])
R_vv = kern_V.get_gram(Vall, Vall)

# Gram matrices from kernel_V.get_gram_base
K_v0v0 = kern_V.get_gram_base(V[0], V[0])
K_v1v1 = kern_V.get_gram_base(V[1], V[1])
K_vv1 = kern_V.get_gram_base(Vall, V[1])
Kvv0 = kern_V.get_gram_base(Vall, V[0])
K_vv = kern_V.get_gram_base(Vall, Vall)

# Gram matrices from kernel_A, between training inputs and expanded test inputs
K_aa = kern_A.get_gram(A, A)
k_atest = kern_A.get_gram(expanded_doA, A)
k_atest2 = kern_A.get_gram(expanded_doA2, A)

# Add exp(noise) + reg on the diagonal of the training gram matrices
diag_Y = model.noise_Y.exp() + reg
R_v1 = R_v1v1 + diag_Y * torch.eye(n1)
K_v1 = K_v1v1 + diag_Y * torch.eye(n1)
K_a = K_aa + (model.noise_feat.exp() + reg) * torch.eye(n0)

In [9]:
# Shape comments below assume get_gram(X, Z) returns a (len(X), len(Z)) matrix.

# Average the expanded cross-gram matrices over the n0 training rows.
# Results go to NEW names so this cell can be re-run without first re-running
# the cell that builds k_atest / k_atest2.
if average_doA:
    k_atest_avg = k_atest.reshape(n0, N, n0).mean(0)  # (N, n0)
    k_atest2_avg = k_atest2.reshape(n0, M, n0).mean(0)  # (M, n0)
else:
    k_atest_avg, k_atest2_avg = k_atest, k_atest2

# Terms reused by several Theta expressions; each is computed once
K_vv_reg = K_vv + torch.eye(n1 + n0) * reg
R_v0v0_reg = R_v0v0 + torch.eye(n0) * reg
R_resid = R_vv - R_vv1 @ torch.linalg.solve(R_v1, R_vv1.T)  # (n1+n0, n1+n0)
Kinv_R_vv0 = torch.linalg.solve(K_vv_reg, R_vv0)  # (n1+n0, n0)
R0inv_R_v0v = torch.linalg.solve(R_v0v0_reg, R_vv0.T)  # (n0, n1+n0)
Kinv_R_resid = torch.linalg.solve(K_vv_reg, R_resid)  # (n1+n0, n1+n0)

# Matrix-vector products (same left-to-right product order as before)
Theta1 = Kinv_R_vv0 @ torch.linalg.solve(R_v0v0_reg, K_v0v0)  # (n1+n0, n0)
Theta4 = torch.linalg.solve(K_vv_reg, R_vv1) @ torch.linalg.solve(K_v1, Y)  # (n1+n0, 1)
Theta2a = Theta4.T @ R_vv @ Theta4  # (1, 1)
Theta2b = Theta4.T @ R_vv0 @ R0inv_R_v0v @ Theta4  # (1, 1)
Theta3a = torch.trace(torch.linalg.solve(K_vv_reg, R_vv) @ Kinv_R_resid)  # scalar
Theta3b = torch.trace(Kinv_R_vv0 @ R0inv_R_v0v @ Kinv_R_resid)  # scalar
E_a = torch.linalg.solve(K_a, k_atest_avg.T)  # (n0, N)
E_a2 = torch.linalg.solve(K_a, k_atest2_avg.T)  # (n0, M)
G_aa = E_a.T @ k_atest2_avg.T  # (N, M)

# Gram matrix on doA, doA2
if average_doA:
    # NOTE(review): deepcopy is not imported explicitly in this notebook;
    # presumably it arrives through one of the star-imports. Confirm, or add
    # `from copy import deepcopy` to the import cell.

    # Copy of kernel_A restricted to the averaged columns, with unit scale
    kernel_Aavg = deepcopy(model.kernel_A)
    kernel_Aavg.lengthscale = model.kernel_A.lengthscale[average_indices]
    kernel_Aavg.scale = 1.0

    # Copy of kernel_A restricted to the intervened columns
    kernel_doA = deepcopy(model.kernel_A)
    kernel_doA.lengthscale = model.kernel_A.lengthscale[intervention_indices]

    # Mean of the gram matrix over the averaged columns (a scalar) times the
    # gram matrix over the intervention values
    Aavg = A[:, average_indices].reshape(n0, len(average_indices))
    K_Aavg = kernel_Aavg.get_gram(Aavg, Aavg).mean()
    K_doA = kernel_doA.get_gram(doA, doA2)

    K_atestatest = K_Aavg * K_doA
else:
    # Otherwise, just build the N x M gram matrix on doA
    K_atestatest = model.kernel_A.get_gram(doA, doA2)
F_aa = K_atestatest  # (N, M)

# Final computations
V1 = E_a.T @ Theta1.T @ R_resid @ Theta1 @ E_a2  # (N, M)
V2 = Theta2a * F_aa - Theta2b * G_aa  # (N, M)
V3 = Theta3a * F_aa - Theta3b * G_aa  # (N, M)

if not diag:
    posterior_variance = V1 + V2 + V3  # (N, M)
else:
    posterior_variance = (V1 + V2 + V3).diag().reshape(N, 1)  # (N, 1)