In [1]:
import numpy as np
from scipy.stats import norm, cauchy
from scipy.optimize import minimize

class CATE_VariationalBayes:
    def __init__(self, variational_family='linear', prior='normal', variational_params=None, prior_params=None):
        self.variational_family = variational_family
        self.prior = prior
        self.variational_params = variational_params
        self.prior_params = prior_params

    def _variational_objective(self, params, X, T, Y):
        if self.variational_family == 'linear':
            alpha, beta, gamma = params
            Y_pred = X * (T * gamma + beta) + alpha
        else:
            raise ValueError("Invalid variational family. Currently supported families: 'linear'.")

        # Variational distributions
        q_alpha = norm(loc=alpha, scale=self.variational_params.get('alpha_std', 1.0))
        q_beta = norm(loc=beta, scale=self.variational_params.get('beta_std', 1.0))
        q_gamma = norm(loc=gamma, scale=self.variational_params.get('gamma_std', 1.0))

        # Model likelihood
        Y_pred = X * (T * q_gamma.rvs() + q_beta.rvs()) + q_alpha.rvs()

        # Variational objective to minimize (negative log-likelihood)
        return -np.sum(norm.logpdf(Y, loc=Y_pred))

    def fit(self, X, T, Y):
        # Initialize variational parameters
        if self.variational_family == 'linear':
            init_params = [0.0, 0.0, 0.0]
        else:
            raise ValueError("Invalid variational family. Currently supported families: 'linear'.")

        # Optimize variational parameters
        result = minimize(self._variational_objective, init_params, args=(X, T, Y), method='L-BFGS-B')
        optimized_params = result.x

        # Set optimized parameters based on variational family
        if self.variational_family == 'linear':
            self.alpha, self.beta, self.gamma = optimized_params

        return self

    def predict(self, X, T):
        if self.variational_family == 'linear':
            Y_pred = X * (T * self.gamma + self.beta) + self.alpha
        else:
            raise ValueError("Invalid variational family. Currently supported families: 'linear'.")

        return Y_pred

# Generate synthetic data
np.random.seed(42)
n_samples = 1000
X_host = np.random.normal(2, 1, n_samples)
T_host = np.random.rand(n_samples)
Y_host = 1 + 2 * X_host + 3 * X_host * T_host + np.random.normal(0, 1, n_samples)

X_cand = np.random.normal(2, 1, n_samples)
T_cand = np.random.rand(n_samples)
Y_cand = 1 + 2 * X_cand + 3 * X_cand * T_cand + np.random.normal(0, 1, n_samples)

# Apply sigmoid to treatment assignments
T_host_binary = 1 / (1 + np.exp(-T_host))
T_cand_binary = 1 / (1 + np.exp(-T_cand))

# Define variational family parameters
linear_variational_params = {'alpha_std': 1.0, 'beta_std': 1.0, 'gamma_std': 1.0}

# Define prior parameters
normal_prior_params = {'alpha_std': 1.0, 'beta_std': 1.0, 'gamma_std': 1.0}

# Initialize and fit the Variational Bayes model
cate_variational_bayes = CATE_VariationalBayes(variational_family='linear', prior='normal',
                                              variational_params=linear_variational_params,
                                              prior_params=normal_prior_params)
cate_variational_bayes.fit(X_cand, T_cand_binary, Y_cand)

# Predict using the Variational Bayes model
Y_cand_pred = cate_variational_bayes.predict(X_cand, T_cand_binary)
Y_host_pred = cate_variational_bayes.predict(X_host, T_host_binary)

# Print parameters
print("Variational Bayes Parameters:")
print("Alpha:", cate_variational_bayes.alpha)
print("Beta:", cate_variational_bayes.beta)
print("Gamma:", cate_variational_bayes.gamma)

# Evaluate the model on the host data
mse_host = np.mean((Y_host - Y_host_pred)**2)
print("Mean Squared Error on Host Data:", mse_host)

# Evaluate the model on the candidate data
mse_cand = np.mean((Y_cand - Y_cand_pred)**2)
print("Mean Squared Error on Candidate Data:", mse_cand)


Variational Bayes Parameters:
Alpha: -4.714382067306908e-07
Beta: -7.212243949115454e-08
Gamma: 5.112424952775461e-07
Mean Squared Error on Host Data: 83.20436392321801
Mean Squared Error on Candidate Data: 80.74271065945769
