In [1]:
import numpy as np
import warnings
from matplotlib import pyplot as plt
warnings.simplefilter(action='ignore', category=FutureWarning)

In [2]:
# Simulation 1: The unbalanced case with a simple CATE
#
# Seed NumPy's global RNG so that `beta` below -- and any global-RNG draws made
# later by iterate_experiments_econ (assumption: it uses np.random; confirm) --
# are identical on every Restart & Run All. Without a seed each run draws a new
# beta and the reported MSE curves cannot be reproduced.
SEED = 42
np.random.seed(SEED)

N = [300, 1000, 3000, 6000, 10000]  # sample sizes to evaluate
num_experiments = 10  # number of repeated experiments (presumably per sample size)

# Simulation setup
d = 20  # number of covariates; beta holds one coefficient per covariate
beta = np.random.uniform(low=-5, high=5, size=d)


def e(x):
    """Return the constant 0.1 for every covariate vector ``x``.

    Presumably used by iterate_experiments_econ as the treatment propensity
    P(treated | x); a constant 0.1 makes treated units a small minority,
    i.e. the "unbalanced" case.
    """
    return 0.1


def mu0(x):
    """Baseline (presumably control) outcome: linear in ``x`` via ``beta``,
    plus a step of +5 when ``x[0] > 0.5``."""
    return np.dot(x, beta) + 5 * (x[0] > 0.5)


def mu1(x):
    """Treated outcome: ``mu0(x)`` plus a step of +8 when ``x[1] > 0.1``.

    The implied CATE is tau(x) = mu1(x) - mu0(x) = 8 * 1[x[1] > 0.1].
    """
    return mu0(x) + 8 * (x[1] > 0.1)

In [3]:
from Simulation.Perform_experiments_econML import iterate_experiments_econ

# Run the experiment grid with "LGBM" as the base model, collecting MSEs for the
# custom X-learner and (presumably) the EconML X-learner respectively.
model = "LGBM"
x_mse_total, x_mse_total_econ = iterate_experiments_econ(N, num_experiments, e, d, mu0, mu1, model)

# Average over axis 0 (presumably the repeated experiments) so one value remains
# per entry of N, ready to be plotted against N.
x_mse_lgbm = np.mean(x_mse_total, axis=0)
x_mse_lgbm_econ = np.mean(x_mse_total_econ, axis=0)

# Fail here with a readable message, rather than inside plt.plot, if the
# averaged curves do not line up with the sample-size grid.
assert len(x_mse_lgbm) == len(x_mse_lgbm_econ) == len(N), (
    "averaged MSE curves must have one value per sample size in N"
)

In [4]:
# Plot the averaged MSE against the number of samples for both X-learner
# implementations (LGBM base model).
fig, ax = plt.subplots()
ax.plot(N, x_mse_lgbm, marker='o', label='X-learner')
ax.plot(N, x_mse_lgbm_econ, marker='o', label='X-learner EconML')
ax.set_xlabel('Number of samples')
ax.set_ylabel('MSE')
# matplotlib does not wrap titles; a title this long can exceed the default
# figure width and be clipped, so split it over two lines.
ax.set_title('Simulation 1: LGBM Regressor\ncomparing custom XLearner with EconML XLearner')
ax.legend()
fig.tight_layout()  # make room for the two-line title
plt.show()

In [5]:
# Echo the averaged MSE arrays (the numbers behind the LGBM plot), one item per line.
print(
    "LGBM:",
    "X-learner: ",
    x_mse_lgbm,
    "X-learner econML:",
    x_mse_lgbm_econ,
    sep="\n",
)

In [6]:
# Same experiment grid with "GradientBoostingRegressor" as the base model.
# iterate_experiments_econ is already imported in the LGBM cell, so it is not
# imported a second time here.
model = "GradientBoostingRegressor"
x_mse_total, x_mse_total_econ = iterate_experiments_econ(N, num_experiments, e, d, mu0, mu1, model)

# Average over axis 0 (presumably the repeated experiments) so one value remains
# per entry of N, ready to be plotted against N.
x_mse_gbr = np.mean(x_mse_total, axis=0)
x_mse_gbr_econ = np.mean(x_mse_total_econ, axis=0)

# Fail here with a readable message, rather than inside plt.plot, if the
# averaged curves do not line up with the sample-size grid.
assert len(x_mse_gbr) == len(x_mse_gbr_econ) == len(N), (
    "averaged MSE curves must have one value per sample size in N"
)

In [7]:
# Plot the averaged MSE against the number of samples for both X-learner
# implementations (GradientBoostingRegressor base model).
fig, ax = plt.subplots()
ax.plot(N, x_mse_gbr, marker='o', label='X-learner')
ax.plot(N, x_mse_gbr_econ, marker='o', label='X-learner EconML')
ax.set_xlabel('Number of samples')
ax.set_ylabel('MSE')
# matplotlib does not wrap titles; this title is wider than the default figure
# and would be clipped on one line, so split it over two lines.
ax.set_title('Simulation 1: Gradient Boosting Regressor\ncomparing custom XLearner with EconML XLearner')
ax.legend()
fig.tight_layout()  # make room for the two-line title
plt.show()

In [9]:
# Echo the averaged MSE arrays (the numbers behind the gradient-boosting plot),
# one item per line.
print(
    "Gradient Boosting Regressor:",
    "X-learner: ",
    x_mse_gbr,
    "X-learner econML:",
    x_mse_gbr_econ,
    sep="\n",
)