In [1]:
from __future__ import annotations
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from Methods import EstimationMethods
from Gen_data import SimulationStudy
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
def get_split(
    simulation: pd.DataFrame,
    test_size: float = 0.5,
    random_state: int = 42,
) -> tuple[
    pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.Series, pd.Series,
    pd.DataFrame, pd.Series, pd.Series, pd.DataFrame, pd.DataFrame,
]:
    """Split a simulated dataset into train/test halves and slice out the modelling inputs.

    Feature columns are all columns whose name starts with ``'X'``; the
    treatment, outcome and true effect are read from the ``'T'``, ``'y'`` and
    ``'CATE'`` columns.

    Args:
        simulation: Dataset containing the feature columns plus ``'T'``,
            ``'y'`` and ``'CATE'``.
        test_size: Forwarded to ``train_test_split``; defaults to the
            previously hard-coded 0.5.
        random_state: Forwarded to ``train_test_split``; defaults to the
            previously hard-coded 42, so the row assignment is the same on
            every call.

    Returns:
        A 10-tuple, in this order:
        ``(train_df, test_df, X_train, Y_train, T_train, X_test, T_test,
        y_test, true_cate_train, true_cate_test)``.
        Note that the training triple is ordered (X, Y, T) while the test
        triple is ordered (X, T, y). ``true_cate_*`` hold the ``'CATE'`` and
        ``'T'`` columns of the respective half.

    Raises:
        KeyError: If ``'T'``, ``'y'`` or ``'CATE'`` is missing from
            ``simulation``.
    """
    # Fail fast with a readable message instead of a bare key error after the split
    missing = [col for col in ('T', 'y', 'CATE') if col not in simulation.columns]
    if missing:
        raise KeyError(f"simulation is missing required column(s): {missing}")

    train_df, test_df = train_test_split(simulation, test_size=test_size, random_state=random_state)

    # Compute the feature column list once and reuse it for both halves
    feature_cols = [col for col in simulation.columns if col.startswith('X')]

    # Features, treatment and outcome for training
    X_train = train_df[feature_cols]
    T_train = train_df['T']
    Y_train = train_df['y']

    # Features, treatment and outcome for testing
    X_test = test_df[feature_cols]
    T_test = test_df['T']
    y_test = test_df['y']

    # True effect together with the treatment indicator, per half
    true_cate_train = train_df[['CATE', 'T']]
    true_cate_test = test_df[['CATE', 'T']]

    return train_df, test_df, X_train, Y_train, T_train, X_test, T_test, y_test, true_cate_train, true_cate_test

In [3]:
def mse_analysis(
    ols: bool = False,
    t_learner: bool = False,
    cf_dml: bool = False,
    x_learner: bool = False,
    grf: bool = False,
    *,
    p_values: tuple[int, ...] = (20, 30, 40, 50),
    n_values: tuple[int, ...] = (500, 1000, 2000, 3000, 4000),
    n_reps: int = 4,
) -> pd.DataFrame:
    """Average the test MSE of one estimator over a grid of simulated settings.

    For every combination of ``p`` (passed to ``SimulationStudy`` as ``p``) and
    ``n`` (passed as ``n``), ``n_reps`` datasets are generated, split with
    ``get_split`` and fitted with the selected ``EstimationMethods`` method.
    The test MSE reported by the method is averaged over the repetitions.

    Args:
        ols, t_learner, cf_dml, x_learner, grf: Estimator selection flags.
            If several are truthy, the first one in this order is used.
        p_values: Values of ``p`` to iterate over (outer loop, shown in the
            progress bar).
        n_values: Values of ``n`` to iterate over (inner loop).
        n_reps: Number of simulated datasets per (p, n) combination.

    Returns:
        A DataFrame with one row per (p, n) combination and two columns:
        ``'n'``, which holds the string key ``'{p}_{n}'`` (not the bare sample
        size), and ``'MSE Test'``, the mean test MSE.

    Raises:
        ValueError: If no estimator flag is set, or if ``n_reps`` is below 1.
    """
    if n_reps < 1:
        raise ValueError('n_reps must be at least 1')

    # (flag, EstimationMethods method name, whether the test MSE is the LAST of the
    # four returned values).
    # NOTE(review): the OLS result is unpacked as (..., MSE_train, MSE_test) while the
    # other four are unpacked as (..., MSE_test, MSE_train). That original unpacking is
    # preserved here; confirm against Methods.EstimationMethods which order each
    # method actually returns.
    estimator_table = (
        (ols, 'ols_estimator', True),
        (t_learner, 'TLearner_estimator', False),
        (cf_dml, 'CF_DML', False),
        (x_learner, 'XLearner_estimator', False),
        (grf, 'GRF_estimator', False),
    )
    chosen = next(((name, last) for flag, name, last in estimator_table if flag), None)
    if chosen is None:
        # Previously this case only printed the message and then crashed on an
        # unbound local; raise before any simulation work is done instead.
        raise ValueError('Choose either ols, t_learner, cf_dml, x_learner or grf')
    method_name, test_mse_is_last = chosen

    # Materialise so the inner grid can be iterated once per outer value
    n_values = tuple(n_values)

    mean_test_mse = {}
    for p in tqdm(p_values):
        for n in n_values:
            test_mses = []

            for _ in range(n_reps):
                sim: SimulationStudy = SimulationStudy(p=p, mean_correlation=0.1, cor_variance=0.2, n=n, no_feat_cate=2, non_linear='linear')
                simulation = sim.create_dataset()
                (train_df, test_df, X_train, Y_train, T_train,
                 X_test, T_test, Y_test, true_cate_train, true_cate_test) = get_split(simulation)
                estimators: EstimationMethods = EstimationMethods(X_train, T_train, Y_train, X_test, T_test, Y_test, true_cate_train, true_cate_test)

                # Every estimator method is expected to return exactly four values
                _, _, third, fourth = getattr(estimators, method_name)()
                test_mses.append(fourth if test_mse_is_last else third)

            mean_test_mse[f'{p}_{n}'] = np.mean(np.array(test_mses))

    # Build the result frame once, after the whole grid has been processed
    mse_df = pd.DataFrame({
        'n': list(mean_test_mse.keys()),
        'MSE Test': list(mean_test_mse.values()),
    })
    return mse_df


In [4]:
# Run the simulation grid with the OLS estimator selected
mse_ols = mse_analysis(ols=True)

100%|██████████| 4/4 [00:24<00:00,  6.10s/it]


In [5]:
# Display the OLS results table (columns 'n' and 'MSE Test')
mse_ols

Unnamed: 0,n,MSE Test
0,20_500,1.005184
1,20_1000,0.608188
2,20_2000,0.448902
3,20_3000,0.353552
4,20_4000,0.344468
5,30_500,1.329103
6,30_1000,0.770422
7,30_2000,0.588961
8,30_3000,0.419197
9,30_4000,0.386664


In [6]:
# Run the simulation grid with the X-learner selected
mse_x = mse_analysis(x_learner=True)

  0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
# Run the simulation grid once per estimator. The comparison plot further down
# reads all five result frames, so each of them has to be assigned here for the
# notebook to run top to bottom on a fresh kernel.
mse_ols = mse_analysis(ols=True)
mse_t = mse_analysis(t_learner=True)
mse_cf_dml = mse_analysis(cf_dml=True)
mse_x = mse_analysis(x_learner=True)
mse_grf = mse_analysis(grf=True)


In [None]:
# Display the X-learner results table (columns 'n' and 'MSE Test')
mse_x

In [None]:

# Compare the mean test MSE of every estimator across all simulated settings
plt.style.use('seaborn-v0_8-ticks')
fig, ax = plt.subplots()

# One entry per estimator: (results frame, marker colour, legend label)
estimator_results = [
    (mse_ols, 'red', 'OLS'),
    (mse_t, 'green', 'T-Learner'),
    (mse_cf_dml, 'blue', 'Causal Forest DML'),
    (mse_x, 'purple', 'X-Learner'),
    (mse_grf, 'yellow', 'Generalized Random Forest'),
]
for frame, colour, label in estimator_results:
    ax.scatter(frame['n'], frame['MSE Test'], alpha=0.5, color=colour, label=label)

# The 'n' column holds '{p}_{n}' strings, so the axis is categorical; rotate the
# tick labels so they stay legible and label the axis accordingly.
ax.tick_params(axis='x', labelrotation=90)
ax.set_xlabel('Setting (p_n)')
ax.set_ylabel('MSE Test')
ax.legend()

# Show the plot
plt.show()

In [None]:
# Sample curve for the 2x2 layout demo. It is defined inside this cell so the
# cell does not depend on leftover kernel state.
x = np.linspace(0, 2 * np.pi, 400)
y = np.sin(x ** 2)

fig, axs = plt.subplots(2, 2)

# (row, column, sign applied to y, line colour or None for the default)
panels = [
    (0, 0, 1, None),
    (0, 1, 1, 'tab:orange'),
    (1, 0, -1, 'tab:green'),
    (1, 1, -1, 'tab:red'),
]
for row, col, sign, colour in panels:
    if colour is None:
        axs[row, col].plot(x, sign * y)
    else:
        axs[row, col].plot(x, sign * y, colour)
    axs[row, col].set_title(f'Axis [{row}, {col}]')

for ax in axs.flat:
    ax.set(xlabel='x-label', ylabel='y-label')

# Hide x labels and tick labels for top plots and y ticks for right plots.
for ax in axs.flat:
    ax.label_outer()

In [None]:
import seaborn as sns

# Switch seaborn to its "ticks" theme before drawing
sns.set_theme(style="ticks")

# OLS only: 'MSE Test' plotted against the 'n' column of the results frame
sns.scatterplot(x="n", y="MSE Test", data=mse_ols)