In [1]:
library(grf)
library(tidyverse)
library(ggplot2)

Registered S3 methods overwritten by 'ggplot2':
  method         from 
  [.quosures     rlang
  c.quosures     rlang
  print.quosures rlang
Registered S3 method overwritten by 'rvest':
  method            from
  read_xml.response xml2
── Attaching packages ─────────────────────────────────────── tidyverse 1.2.1 ──
✔ ggplot2 3.1.1       ✔ purrr   0.3.2  
✔ tibble  2.1.1       ✔ dplyr   0.8.0.1
✔ tidyr   0.8.3       ✔ stringr 1.4.0  
✔ readr   1.3.1       ✔ forcats 0.4.0  
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()


### Causal Forests

In [2]:
# Name of the fold sub-directory under ../build/simulation/ that every path
# built from `fold` points into.
fold <- "Fold3"

In [3]:
# Read the feature-column names file for this fold. The compact spec 'ic'
# parses the first column as integer and the second as character, so the
# names can later be used to select columns by name.
fc_names_path <- file.path('../build/simulation', fold, 'examination_fc_names.csv')
feature_column_names <- read_csv(fc_names_path, col_types = 'ic')

“Missing column names filled in: 'X1' [1]”

In [5]:
# Display the parsed table (auto-filled index column plus the feature names).
feature_column_names

X1,feature_column_names
0,13
1,38
2,53
3,63
4,64
5,88
6,107
7,126
8,129
9,133


In [6]:
# Character vector of feature column names, used below to select the feature
# columns of every data frame passed to the forests.
feature_col <- feature_column_names[["feature_column_names"]]

In [7]:
# Load examination_test.csv for this fold; column types are guessed by readr.
propensity_test_path <- file.path('../build/simulation', fold, 'examination_test.csv')
propensity_test_features <- read_csv(propensity_test_path)

Parsed with column specification:
cols(
  qid = col_double(),
  y = col_double(),
  `13` = col_double(),
  `38` = col_double(),
  `53` = col_double(),
  `63` = col_double(),
  `64` = col_double(),
  `88` = col_double(),
  `107` = col_double(),
  `126` = col_double(),
  `129` = col_double(),
  `133` = col_double()
)


In [8]:
# Load examination_features.csv for this fold; column types are guessed by readr.
# NOTE(review): the file carries a `partition` column and its first rows are
# 'train', so despite the variable name this may not be test-only data -- confirm.
ltr_features_path <- file.path('../build/simulation', fold, 'examination_features.csv')
ltr_test_features <- read_csv(ltr_features_path)

Parsed with column specification:
cols(
  partition = col_character(),
  `13` = col_double(),
  `38` = col_double(),
  `53` = col_double(),
  `63` = col_double(),
  `64` = col_double(),
  `88` = col_double(),
  `107` = col_double(),
  `126` = col_double(),
  `129` = col_double(),
  `133` = col_double(),
  qd_id = col_double()
)


In [9]:
# Preview the first rows of propensity_test_features as a sanity check.
head(propensity_test_features)

qid,y,13,38,53,63,64,88,107,126,129,133
1,1,-1.0517464,-0.8043028,-0.5379134,-0.7348817,-0.8378642,-0.7383881,-1.2176843,-0.8195302,-0.1129179,-0.01501939
1,1,-0.5011712,1.3082802,1.7877163,1.6193548,2.3323488,0.7436748,-1.2176843,-0.3575459,-0.6929039,-0.01501939
1,1,0.5999794,1.3082802,0.8574616,0.6776574,0.430221,0.7436748,-1.2176843,1.5743887,-0.6706098,-0.01501939
1,0,0.875267,-0.8043028,-0.5379134,-0.7348817,-0.7842749,-0.7383881,-1.2176843,0.6084214,-0.6963567,-0.01501939
1,0,-1.0517464,-0.8043028,-0.5379134,-0.7348817,0.7760532,-0.7383881,-0.5841292,-0.5255402,-0.6954615,-0.01501939
1,0,-1.0517464,-0.8043028,-0.5379134,-0.7348817,0.5640294,-0.7383881,-1.2176843,-0.5255402,-0.6954615,-0.01501939


In [10]:
# Preview the first rows of ltr_test_features as a sanity check.
head(ltr_test_features)

partition,13,38,53,63,64,88,107,126,129,133,qd_id
train,0.875267,0.2519887,-0.5379134,-0.09281209,0.446275,0.0417442,-0.37168315,0.9864086,-0.1828691,-0.0150193933,0
train,0.5999794,0.2519887,-0.5379134,-0.02861219,0.6196983,0.0417442,0.02206703,0.3984285,-0.6789221,-0.0150193933,1
train,1.9764176,-0.8043028,-0.5379134,-0.73488173,0.5179472,-0.7383881,-0.33572059,1.3223972,0.2067863,-0.0150193933,2
train,-0.5011712,0.2519887,-0.5379134,0.44222946,0.6925879,0.0417442,-0.17473643,-0.6935345,2.0668253,0.0004752088,3
train,-1.0517464,0.2519887,-0.5379134,1.03079212,0.7472423,0.0417442,-0.17473643,-1.239516,2.0808496,-0.0150193933,4
train,0.5999794,1.3082802,-0.5379134,0.67765735,-0.3119385,0.8218767,-0.12738014,0.2304342,-0.696783,-0.0150193933,5


Count the number of distinct `qid` values in the train/validation rankings.

In [11]:
# Load the train/validation rankings for this fold and count how many
# distinct qid values they contain.
rankings_path <- file.path('../build/simulation', fold, 'sim_exp_train_vali_rankings.csv')
sim_exp_train_vali_rankings <- read_csv(rankings_path)
nqids <- n_distinct(sim_exp_train_vali_rankings$qid)

Parsed with column specification:
cols(
  .default = col_double(),
  partition = col_character()
)
See spec(...) for full column specifications.


In [12]:
# Display the number of distinct qids computed above.
nqids

In [13]:
# Simulation grid: every (avg_click, nq) combination is substituted into the
# names of the training files read by the training loop.
avg_clicks <- c(5, 10, 25, 50)
# Query-count settings: nqids integer-divided by 100, 10, 2 and 1.
query_divisors <- c(100, 10, 2, 1)
nqueries <- nqids %/% query_divisors

In [24]:
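# Scratch note: once the training loop has run (it defines cf_model), the tuned
# parameters of the last fitted forest can be inspected via cf_model$tuning.output$params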

In [14]:
# Fit causal forests for every simulation setting and save their predictions.
#
# For each (avg_click, nq) pair this reads
#   sim_exp_swap_causal_forests_train_clicks_<avg_click>_<nq>.csv
# and, for each treatment group, fits
#   1. a regression forest for the outcome (click),
#   2. a regression forest for the treatment indicator (treatment),
#   3. a causal forest that is given the predictions of 1. and 2. as
#      Y.hat / W.hat.
# The causal forest's predictions (tau_pred) on propensity_test_features and
# on ltr_test_features are then written to one CSV per setting.

# The forests are stochastic. grf's default `seed` argument is drawn from R's
# RNG, so fixing the RNG once makes a Restart & Run All repeat the same fits.
# NOTE(review): grf results may still vary with num.threads -- confirm if
# exact reproducibility across machines is required.
set.seed(42)

sim_dir <- paste0('../build/simulation/', fold, '/')

# Treatment groups to model; the result lists are sized from this vector so
# the two can no longer drift apart.
treatment_groups <- 2:10
n_groups <- length(treatment_groups)

# The test feature matrices are identical for every model, so slice them once.
propensity_test_X <- propensity_test_features[feature_col]
ltr_test_X <- ltr_test_features[feature_col]

for (avg_click in avg_clicks) {
    for (nq in nqueries) {
        dpath <- paste0(sim_dir, 'sim_exp_swap_causal_forests_train_clicks_', avg_click, '_', nq, '.csv')
        message('read data from ', dpath)
        data <- read_csv(dpath)

        # Re-created for every setting, so after the loops `models` only holds
        # the forests of the last (avg_click, nq) pair.
        models <- vector("list", n_groups)
        propensity_test_results <- vector("list", n_groups)
        ltr_test_results <- vector("list", n_groups)

        for (idx in seq_along(treatment_groups)) {
            i <- treatment_groups[[idx]]
            message('sess ', avg_click, ' nsample ', nq, ' train ', i, ' model')

            # Rows belonging to this treatment group only.
            train_pairs <- data %>% filter(treatment_group == i)
            train_X <- train_pairs[feature_col]

            # Outcome model; predict() without new data returns the forest's
            # predictions for its own training rows.
            Y_model <- regression_forest(train_X, train_pairs$click, tune.parameters = "all")
            Y_pred <- predict(Y_model)$predictions

            # Treatment model, same scheme as the outcome model.
            W_model <- regression_forest(train_X, train_pairs$treatment, tune.parameters = "all")
            W_pred <- predict(W_model)$predictions

            # Causal forest reusing the two nuisance predictions.
            cf_model <- causal_forest(train_X, train_pairs$click, train_pairs$treatment, ci.group.size = 1,
                                      Y.hat = Y_pred, W.hat = W_pred, tune.parameters = "all")

            # Predicted tau on the propensity test set.
            propensity_test_tau_pred <- predict(cf_model, propensity_test_X)$predictions
            propensity_test_result <- propensity_test_features %>% mutate(tau_pred = propensity_test_tau_pred)

            # Predicted tau on the LTR feature set.
            ltr_test_tau_pred <- predict(cf_model, ltr_test_X)$predictions
            ltr_test_result <- ltr_test_features %>% mutate(tau_pred = ltr_test_tau_pred)

            # Store results at list position idx (= i - 1 for groups 2:10).
            models[[idx]] <- cf_model
            propensity_test_results[[idx]] <- propensity_test_result
            ltr_test_results[[idx]] <- ltr_test_result
        }

        # The lists are unnamed, so bind_rows() fills `treatment_rank` with the
        # list position (1..9), i.e. treatment_group - 1.
        # NOTE(review): confirm downstream code expects this offset.
        propensity_test_results_binded <- bind_rows(propensity_test_results, .id = 'treatment_rank')
        saveto <- paste0(sim_dir, 'sim_exp_swap_causal_forests_propensity_test_results_', avg_click, '_', nq, '.csv')
        message('save propensity results to ', saveto)
        propensity_test_results_binded %>% write_csv(saveto)

        # ltr results
        ltr_test_results_binded <- bind_rows(ltr_test_results, .id = 'treatment_rank')
        saveto <- paste0(sim_dir, 'sim_exp_swap_causal_forests_ltr_test_results_', avg_click, '_', nq, '.csv')
        message('save ltr results to ', saveto)
        ltr_test_results_binded %>% write_csv(saveto)
    }
}

read data from ../build/simulation/Fold3/sim_exp_swap_causal_forests_train_clicks_5_159.csv
Parsed with column specification:
cols(
  partition = col_character(),
  qd_id = col_double(),
  swapped_rank = col_double(),
  click = col_double(),
  treatment = col_double(),
  true_click_probability = col_double(),
  true_propensity = col_double(),
  treatment_group = col_double(),
  `13` = col_double(),
  `38` = col_double(),
  `53` = col_double(),
  `63` = col_double(),
  `64` = col_double(),
  `88` = col_double(),
  `107` = col_double(),
  `126` = col_double(),
  `129` = col_double(),
  `133` = col_double()
)
sess 5 nsample 159 train 2 model
sess 5 nsample 159 train 3 model
sess 5 nsample 159 train 4 model
sess 5 nsample 159 train 5 model
sess 5 nsample 159 train 6 model
sess 5 nsample 159 train 7 model
sess 5 nsample 159 train 8 model
sess 5 nsample 159 train 9 model
sess 5 nsample 159 train 10 model
save propensity results to ../build/simulation/Fold3/sim_exp_swap_causal_forests_propens