# 02 Tune CATE estimators

In this notebook, we tune the hyperparameters of our CATE (conditional average treatment effect) estimators.

### Contents:
1. Description of estimator library  
2. Setting up  
3. Actual tuning

## 1. Description of estimator library

We will consider the following estimators:

1. S-learner:  
A. RF  
B. XGB
2. T-learner:  
A. Lasso  
B. logistic  
C. RF  
D. XGB
3. X-learner:  
A. Outcome_learner: lasso, effect_learner: lasso  
B. Outcome_learner: logistic, effect_learner: lasso  
C. Outcome_learner: RF, effect_learner: lasso  
D. Outcome_learner: XGB, effect_learner: lasso
4. R-learner:  
A. Outcome_learner: lasso, effect_learner: lasso  
B. Outcome_learner: lasso, effect_learner: XGB  
C. Outcome_learner: RF, effect_learner: lasso  
D. Outcome_learner: RF, effect_learner: RF
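For reference, the S-, T- and X-learner strategies listed above can be sketched in a few lines. This is a generic illustration under stated assumptions, not the code behind `SLearnerWrapper`, `TLearnerWrapper` or `XLearnerWrapper`, whose internals are not shown here. The function names are hypothetical. `LinearRegression` stands in for the tuned base learners, the synthetic data are randomized with a constant effect of 1.5, and the X-learner weight `g` is a constant (the share of treated units) rather than an estimated propensity score.

```python
import numpy as np
from sklearn.linear_model import LinearRegression


def s_learner_cate(X, t, y, model):
    """S-learner: fit one model on [X, t]; CATE(x) = f(x, 1) - f(x, 0)."""
    model.fit(np.column_stack([X, t]), y)
    n = len(X)
    treated = np.column_stack([X, np.ones(n)])
    control = np.column_stack([X, np.zeros(n)])
    return model.predict(treated) - model.predict(control)


def t_learner_cate(X, t, y, model_0, model_1):
    """T-learner: one outcome model per arm; CATE(x) = mu_1(x) - mu_0(x)."""
    model_0.fit(X[t == 0], y[t == 0])
    model_1.fit(X[t == 1], y[t == 1])
    return model_1.predict(X) - model_0.predict(X)


def x_learner_cate(X, t, y, outcome_0, outcome_1, effect_0, effect_1, g=0.5):
    """X-learner: impute unit-level effects with the opposite arm's outcome
    model, regress them on X, and blend the two effect models with weight g."""
    outcome_0.fit(X[t == 0], y[t == 0])
    outcome_1.fit(X[t == 1], y[t == 1])
    # Imputed effects: treated units use mu_0, control units use mu_1
    d_1 = y[t == 1] - outcome_0.predict(X[t == 1])
    d_0 = outcome_1.predict(X[t == 0]) - y[t == 0]
    effect_1.fit(X[t == 1], d_1)
    effect_0.fit(X[t == 0], d_0)
    # g is usually the propensity score; a constant is used here for simplicity
    return g * effect_0.predict(X) + (1 - g) * effect_1.predict(X)


# Synthetic randomized data with a constant treatment effect of 1.5
rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 3))
t = rng.integers(0, 2, size=2000)
y = X @ np.array([1.0, -0.5, 0.2]) + 1.5 * t + rng.normal(scale=0.1, size=2000)

cate_s = s_learner_cate(X, t, y, LinearRegression())
cate_t = t_learner_cate(X, t, y, LinearRegression(), LinearRegression())
cate_x = x_learner_cate(X, t, y, LinearRegression(), LinearRegression(),
                        LinearRegression(), LinearRegression(), g=t.mean())
print(cate_s.mean(), cate_t.mean(), cate_x.mean())  # all close to 1.5
```

With a linear model and no treatment interactions, the S-learner can only return a constant effect. That is one reason the library pairs it with flexible learners (RF, XGB) only.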

The four R-learner combinations were drawn at random (without replacement) from the nine possible (outcome learner, effect learner) pairs over {lasso, RF, XGB}; see Section 2.3.
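The R-learner is the only method here with both an outcome learner and an effect learner that are tuned jointly, so a sketch of the usual residual-on-residual construction may help. This is an assumption-laden illustration, not the implementation of `RLearnerWrapper`. The function `r_learner_cate` is hypothetical. It assumes a randomized treatment, so the propensity is taken as the constant share of treated units instead of being estimated by a separate model. The outcome model $m(x) = E[y \mid x]$ is cross-fitted with `KFold`.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_predict


def r_learner_cate(X, t, y, outcome_learner, effect_learner, n_splits=4, seed=0):
    """R-learner sketch: residualize y and t, then fit tau(x) by weighted regression."""
    folds = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    # Cross-fitted (out-of-fold) estimate of m(x) = E[y | x]
    m_hat = cross_val_predict(outcome_learner, X, y, cv=folds)
    # Assumption: randomized treatment, so the propensity is a constant
    # estimated by the share of treated units
    e_hat = np.full(len(t), t.mean())
    y_res = y - m_hat
    t_res = t - e_hat
    # Minimizing sum((y_res - tau(x) * t_res)^2) is equivalent to regressing
    # y_res / t_res on x with sample weights t_res^2
    effect_learner.fit(X, y_res / t_res, sample_weight=t_res ** 2)
    return effect_learner.predict(X)


# Synthetic randomized data with heterogeneous effect tau(x) = 1 + 2 * x_0
rng = np.random.default_rng(1)
X = rng.normal(size=(4000, 3))
t = rng.integers(0, 2, size=4000)
tau = 1 + 2 * X[:, 0]
y = X @ np.array([1.0, -0.5, 0.2]) + tau * t + rng.normal(scale=0.1, size=4000)

cate_r = r_learner_cate(X, t, y, LinearRegression(), LinearRegression())
print(np.abs(cate_r - tau).mean())  # small relative to the spread of tau
```

The effect learner must accept `sample_weight` in `fit`. scikit-learn's `Lasso` and `RandomForestRegressor` do, as does `XGBRegressor`.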

We tune the models for four outcomes (GI, cardio, hypertension, and severe GI), without perturbations.

## 2. Setting up

In [6]:
# Standard imports
import numpy as np
import pandas as pd
import sys
import copy
import random
import joblib

# Import sklearn methods
from xgboost import XGBRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Lasso, LogisticRegression
from sklearn.model_selection import StratifiedKFold

# Import own methods
from methods.data_processing import prepare_df, separate_vars
from methods.cate_estimator_wrappers import (SLearnerWrapper, TLearnerWrapper,
                                             XLearnerWrapper, RLearnerWrapper,
                                             CausalTreeWrapper, CausalForestWrapper)
from methods.cate_estimator_validation import make_estimator_library


### 2.1. Defining some globals

In [2]:
BASEDIR = "/home/ubuntu/vioxx_vigor/"
DATA_PATH = BASEDIR + "data/01_transformed/20jul2020/"
cv = StratifiedKFold(n_splits = 4, shuffle = True, random_state = 405)
features = ["male", "white", "US", "PUB_PRIOR_HISTORY", 'elderly_65_adj',
            "HYPGRP", "DBTGRP", "CHLGRP", "ASPFDA", "ASCGRP", 'obese',
            "PSTRDS", "PNSAIDS", "PNAPRXN", "smoker", "drinker"]
treatment_var = "TREATED"
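`StratifiedKFold` keeps the class proportions of the stratification label roughly equal across folds, which matters when a binary outcome is rare. The notebook does not show which variable its wrappers pass as the stratification label. The sketch below simply assumes a rare binary outcome and uses hypothetical synthetic data. It reuses the notebook's splitter settings and prints the event rate in each validation fold.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Same splitter settings as the notebook's global `cv`
cv = StratifiedKFold(n_splits=4, shuffle=True, random_state=405)

# Hypothetical data: 1,000 units with a rare binary outcome (about 5% events)
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))
y = (rng.random(1000) < 0.05).astype(int)

fold_rates = []
for train_idx, val_idx in cv.split(X, y):
    # Event rate in each validation fold
    fold_rates.append(y[val_idx].mean())

print(round(y.mean(), 3), [round(r, 3) for r in fold_rates])
```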

### 2.2. Defining parameter grids and base learners

In [3]:
lasso_grid = {"alpha" : np.logspace(-5,5,500) }
logistic_grid = {"penalty" : ["l1", "l2"], 
                 "C" : np.logspace(-5,5,500)}
rf_grid = {'min_samples_leaf': [10,50,100,200,300,400,500],
           'max_depth': [3,4,5,6,7,8],
           'bootstrap': [False, True],
           'n_estimators': [100,200,300,400,500]}
xgb_grid = {'max_depth': [5,6,7,8,9,10,11,12],
            'gamma': [0, 0.1, 0.2, 0.3, 0.4],
            'subsample': [0.7, 0.75, 0.8,1],
            'reg_lambda': [100,150,200,250, 300, 350, 400],
            'n_estimators': [200, 300, 400, 500, 600, 700, 800, 900, 1000],
            'min_child_weight': [4,5,6,7,8,9,10],
            'learning_rate': [0.1,0.125,0.15,0.175,0.2,0.225,0.25]}

base_learners = {"lasso" : Lasso(),
                 "logistic" : LogisticRegression(solver = "liblinear", 
                                                 max_iter = 500),
                 "rf" : RandomForestRegressor(),
                 "xgb" : XGBRegressor(objective = "reg:squarederror")}
param_grids = {"lasso" : lasso_grid,
               "logistic" : logistic_grid,
               "rf" : rf_grid,
               "xgb" : xgb_grid}
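These grids differ greatly in size, which is why an exhaustive search is impractical for the tree ensembles. The sketch below counts the combinations with scikit-learn's `ParameterGrid`. It then draws 200 candidate settings from the XGB grid with `ParameterSampler`. The assumption is that `make_estimator_library` performs a randomized search of `n_iter` candidates per learner, as the `n_iter = 200` argument in Section 3 suggests. Its internals are not shown here.

```python
import numpy as np
from sklearn.model_selection import ParameterGrid, ParameterSampler

# Copies of the notebook's grids
lasso_grid = {"alpha": np.logspace(-5, 5, 500)}
rf_grid = {'min_samples_leaf': [10, 50, 100, 200, 300, 400, 500],
           'max_depth': [3, 4, 5, 6, 7, 8],
           'bootstrap': [False, True],
           'n_estimators': [100, 200, 300, 400, 500]}
xgb_grid = {'max_depth': [5, 6, 7, 8, 9, 10, 11, 12],
            'gamma': [0, 0.1, 0.2, 0.3, 0.4],
            'subsample': [0.7, 0.75, 0.8, 1],
            'reg_lambda': [100, 150, 200, 250, 300, 350, 400],
            'n_estimators': [200, 300, 400, 500, 600, 700, 800, 900, 1000],
            'min_child_weight': [4, 5, 6, 7, 8, 9, 10],
            'learning_rate': [0.1, 0.125, 0.15, 0.175, 0.2, 0.225, 0.25]}

for name, grid in [("lasso", lasso_grid), ("rf", rf_grid), ("xgb", xgb_grid)]:
    print(name, len(ParameterGrid(grid)))
# lasso 500, rf 420, xgb 493920

# Assumption: tuning draws n_iter = 200 candidates per learner, as a randomized search would.
# With list-valued grids, ParameterSampler samples grid points without replacement.
candidates = list(ParameterSampler(xgb_grid, n_iter=200, random_state=405))
print(len(candidates), candidates[0])
```

With 200 draws, the randomized search covers under 0.05% of the 493,920 XGB combinations but about half of the 420 RF combinations.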

### 2.3. Selecting four base learner combinations for the R-learner
##### (don't need to run this cell again)

In [None]:
# Requires X, t, y from a data-loading step such as the one in Section 3
r_learners_all = {}
for name_1 in ["lasso", "rf", "xgb"]:          # outcome learner type
    for name_2 in ["lasso", "rf", "xgb"]:      # effect learner type
        r_learners_all["r_" + name_1 + name_2] = RLearnerWrapper(
            X, t, y, cv,
            outcome_learner = base_learners[name_1],
            effect_learner = base_learners[name_2],
            outcome_param_grid = param_grids[name_1],
            effect_param_grid = param_grids[name_2])
random.seed(405)
r_learner_names = random.sample(list(r_learners_all.keys()), 4)
# Recorded result: r_learner_names = ['r_lassolasso', 'r_lassoxgb', 'r_rfrf', 'r_lassorf']
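The random draw only needs the combination names, so it can be done without constructing any wrappers or loading data. The sketch below builds the nine names with `itertools.product`, which iterates in the same order as the nested loops above. It assumes a hypothetical naming scheme `"r_<outcome><effect>"`, chosen to be consistent with recorded names such as `'r_rfrf'`. Whether it reproduces the recorded list exactly depends on the key order matching the notebook's dictionary, so treat the printed names as illustrative.

```python
import itertools
import random

learner_types = ["lasso", "rf", "xgb"]
# Hypothetical naming scheme "r_<outcome><effect>", consistent with names such as 'r_rfrf'
all_r_names = ["r_" + outcome + effect
               for outcome, effect in itertools.product(learner_types, repeat=2)]

random.seed(405)
r_learner_names = random.sample(all_r_names, 4)  # 4 distinct pairs, no replacement
print(len(all_r_names), r_learner_names)
```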

## 3. Actual tuning

In [None]:
results = {}
for response_var in ["GI", "cfd_cardio", "all_hypertension", "severe_GI"]:
    print("=== Getting results for " + response_var + " ===")
    # Load the train/validation data for this outcome
    DIR_PATH = DATA_PATH + response_var + "/"
    trainval_df = prepare_df(DIR_PATH + "trainval_data.csv", 
                             features, response_var, treatment_var)
    X, t, y = separate_vars(trainval_df, response_var, treatment_var)
    # Tune every estimator in the library for this outcome
    results[response_var] = make_estimator_library(X, t, y, cv, 
                                                   base_learners, param_grids, 
                                                   n_iter = 200)
    # Save the tuned hyperparameters of each estimator
    tuned_params = {}
    for estimator_name, estimator in results[response_var].items():
        tuned_params[estimator_name] = estimator.get_params()
    joblib.dump(tuned_params, 
                f"data_files/{response_var}/{response_var}_tuned_params")

=== Getting results for GI ===
Tuning s_rf
Tuning s_xgb
Tuning t_lasso
Tuning t_logistic
Tuning t_rf
Tuning t_xgb
Tuning x_lasso
Tuning x_logistic
Tuning x_rf
Tuning x_xgb
Tuning r_lassolasso
Tuning r_rfrf
Tuning r_lassorf
Tuning r_lassoxgb
=== Getting results for cfd_cardio ===
Tuning s_rf
Tuning s_xgb
Tuning t_lasso
Tuning t_logistic
Tuning t_rf
Tuning t_xgb
Tuning x_lasso
Tuning x_logistic
Tuning x_rf
Tuning x_xgb
Tuning r_lassolasso
Tuning r_rfrf
Tuning r_lassorf
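The saved dictionaries can later be reloaded with `joblib.load` and passed back to estimators through `set_params`. The sketch below shows that round trip on a plain scikit-learn `Lasso` standing in for one tuned estimator. It writes to a temporary file rather than `data_files/<outcome>/`. It assumes the notebook's wrappers follow scikit-learn's `get_params`/`set_params` convention, which is not confirmed here.

```python
import os
import tempfile

import joblib
from sklearn.linear_model import Lasso

# Stand-in for one tuned estimator; the notebook stores get_params() of its wrappers
tuned = Lasso(alpha=0.01)
tuned_params = {"t_lasso": tuned.get_params()}

# Save to a temporary file rather than data_files/<outcome>/
path = os.path.join(tempfile.mkdtemp(), "GI_tuned_params")
joblib.dump(tuned_params, path)

# Later: reload the dictionary and rebuild an estimator with the stored settings
loaded = joblib.load(path)
rebuilt = Lasso().set_params(**loaded["t_lasso"])
print(rebuilt.alpha)  # 0.01
```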
