In [1]:
from __future__ import annotations
import Analysis_new as analysis
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle
import pandas as pd

In [2]:
p_list = [20, 30, 40, 50]
n_list= [500, 1000, 2000, 3000, 4000]

In [3]:
p_dict_mse= {}
p_dict_coverage = {}
estimator_list = ['OLS', 'T-Learner', 'HRF', 'CF DML', 'GRF']

for p in tqdm(p_list):
    
    estimator_dict_mse = {}
    estimator_dict_coverage = {}

    for estimator in estimator_list:
            print(f'Current p is {p} and current estimator is {estimator}')
            key_est = f'{estimator}'

            if estimator=='OLS' or estimator == 'T-Learner':
                mse_simulation = analysis.mse_ci_analysis(p=p, mean_correlation=0.1, n_list=n_list, estimator=estimator, function='quadratic')

            else:
                mse_simulation, coverage_dict = analysis.mse_ci_analysis(p=p, mean_correlation=0.1, n_list=n_list, estimator=estimator)
                estimator_dict_coverage[key_est] = coverage_dict
                print(estimator_dict_coverage)

            estimator_dict_mse[key_est] = mse_simulation

    key_p = f'{p}'
    p_dict_mse[key_p] = estimator_dict_mse
    p_dict_coverage[key_p] = estimator_dict_coverage

     



  0%|          | 0/4 [00:00<?, ?it/s]

Current p is 20 and current estimator is OLS
500
1000
2000
3000
4000
Current p is 20 and current estimator is T-Learner
500
1000
2000
3000
4000
Current p is 20 and current estimator is HRF
500
1000
2000
3000
4000
{'HRF': {'500': 0.188, '1000': 0.252, '2000': 0.321, '3000': 0.244, '4000': 0.2525}}
Current p is 20 and current estimator is CF DML
500
1000
2000
3000
4000
{'HRF': {'500': 0.188, '1000': 0.252, '2000': 0.321, '3000': 0.244, '4000': 0.2525}, 'CF DML': {'500': 0.752, '1000': 0.746, '2000': 0.721, '3000': 0.78, '4000': 0.8415}}
Current p is 20 and current estimator is GRF
500
1000


  0%|          | 0/4 [10:50<?, ?it/s]


KeyboardInterrupt: 

In [None]:
rows = []
for p_key, est_dict in p_dict_coverage.items():
    for est_key, n_dict in est_dict.items():
        for n_key, coverage_rate in n_dict.items():
            rows.append({'p': p_key, 'est': est_key, 'n': n_key, 'coverage_rates': coverage_rate})

# Creating the DataFrame
coverage_df = pd.DataFrame(rows, columns=['p', 'est', 'n', 'coverage_rates'])
            
coverage_df

In [None]:
mse_20_features_linear = p_dict_mse['20']
mse_30_features_linear = p_dict_mse['30']
mse_40_features_linear = p_dict_mse['40']
mse_50_features_linear = p_dict_mse['50']

In [None]:
def plot_mse_analysis_test(mse_list: list, p_list: list):
      
    plt.style.use('seaborn-v0_8')

    mse_list = enumerate(mse_list)
    p_list = enumerate(p_list)

    fig = plt.figure(figsize=(10,10))
    gs = fig.add_gridspec(nrows=2, ncols=2, hspace=0.2)
    axs = gs.subplots(sharex=True, sharey=True)
    
    for j in range(0,2): 
        for i in range(0, 2):

            colors = enumerate(['red', 'green', 'blue', 'yellow'])
            mse_dict = next(mse_list)[1]
            
            for est, mse_df in mse_dict.items():
                color = next(colors)
                
                axs[j, i].scatter(mse_df['n'], mse_df['MSE Test'], s=20, marker='o', color = color[1], label=est)

            p = next(p_list)[1]
            axs[j, i].set_title(f'p $=$ {p}')


        if j == 0 and i==1:
            axs[j, i].legend()
        else:
            continue
    # Create a single legend for the entire figure

    #fig.suptitle('Influence of Individual Feature Values on the CATE Function', y=0.96)
    fig.text(0.5, 0.04, 'No. of Observations', ha='center')
    fig.text(0.04, 0.5, 'MSE Test', va='center', rotation='vertical')

    # Hide x labels and tick labels for top plots and y ticks for right plots.
    #for ax in axs.flat:
     #   ax.label_outer()

In [None]:
mse_list = [mse_20_features_linear, mse_30_features_linear, mse_40_features_linear, mse_50_features_linear]
plot_mse_analysis_test(mse_list, p_list)

In [None]:
with open('mse_20_features_low.pkl', 'wb') as pickle_file:
    pickle.dump(mse_20_features_linear, pickle_file)

with open('mse_30_features_low.pkl', 'wb') as pickle_file:
    pickle.dump(mse_30_features_linear, pickle_file)

with open('mse_40_features_low.pkl', 'wb') as pickle_file:
    pickle.dump(mse_40_features_linear, pickle_file)

with open('mse_50_features_low.pkl', 'wb') as pickle_file:
    pickle.dump(mse_50_features_linear, pickle_file)