In [1]:
using CausalForest
using RCall
using StatsBase
using Statistics

In [27]:
# Simulation study: out-of-sample treatment-effect error of grf's causal forest
# (fitted in R) versus the forest built by `build_forest` on the Julia side, for
# every (n, p, dgp) row of the design grid created in R below.
@rlibrary grf

# R-side setup: seed R's RNG, build the design grid, start an empty results frame.
R"""
set.seed(123)
n <- c(800, 1600)
p <- c(10, 20)
dgp <- c("aw2", "aw1", "aw3")
grid <- expand.grid(n = n, p = p, dgp = dgp, stringsAsFactors = FALSE)
N <- nrow(grid)
res <- data.frame()
"""
@rget N

for i in 1:N
    # Make R pick the i-th design row (the index is pushed from Julia).
    @rput i
    R"""
    n <- grid$n[i]
    p <- grid$p[i]
    dgp <- grid$dgp[i]
    """

    # One slot per replication and per method; the replication count is the
    # length of these buffers.
    grf_mse = zeros(60)
    hterf_mse = zero(grf_mse)

    for rep in eachindex(grf_mse)
        # R draws a training set and a 1000-row test set, fits grf's causal
        # forest on the training draw and computes its test-set MSE.
        R"""
        data <- grf::generate_causal_data(n = n, p = p, dgp = dgp, sigma.tau = 1)
        data.test <- grf::generate_causal_data(n = 1000, p = p, dgp = dgp, sigma.tau = 1)
        X = data$X
        Y = data$Y
        W = data$W
        Xtest = data.test$X
        tau = data.test$tau
        cf <- grf::causal_forest(data$X, data$Y, data$W, num.trees=500, tune.num.trees=500, sample.fraction=0.7, ci.group.size=1)
        tau.hat <- predict(cf, data.test$X)$predictions
        mse = mean((tau.hat - data.test$tau)^2)
        """
        @rget Y X W Xtest tau p mse

        # Fit the Julia forest on the same training draw and score it on the same
        # test points with the mean squared deviation from the true effects.
        # NOTE(review): the meaning of the positional flags and of the two 500s is
        # defined by `build_forest`'s signature, which is not visible here — confirm.
        forest = build_forest(false, true, Y, W, X, true, p, 500, 500)
        hterf_mse[rep] = msd(tau, apply_forest(forest, Xtest))
        grf_mse[rep] = mse
    end

    # Average over replications, then append one row (errors scaled by 10) to the
    # R-side results frame. The names below are read by the R snippet.
    err_hterf = mean(hterf_mse)
    err_grf = mean(grf_mse)
    @rput err_grf err_hterf
    R"""
    df = data.frame(dgp = dgp, p = p, n = n, C.GRF = err_grf * 10, HTERF = err_hterf * 10)
    res = rbind(res, df)
    """
end

# Pull the accumulated results back into Julia and display them.
@rget res
print(res)

12×5 DataFrame
 Row │ dgp     p        n        C_GRF      HTERF
     │ String  Float64  Float64  Float64    Float64
─────┼────────────────────────────────────────────────
   1 │ aw2        10.0    800.0  1.00637    0.843369
   2 │ aw2        10.0   1600.0  0.584201   0.504359
   3 │ aw2        20.0    800.0  1.06687    0.91687
   4 │ aw2        20.0   1600.0  0.649022   0.549425
   5 │ aw1        10.0    800.0  0.136139   0.146508
   6 │ aw1        10.0   1600.0  0.0906345  0.0866225
   7 │ aw1        20.0    800.0  0.103666   0.109072
   8 │ aw1        20.0   1600.0  0.0753912  0.075559
   9 │ aw3        10.0    800.0  1.15571    1.12422
  10 │ aw3        10.0   1600.0  0.692547   0.629847
  11 │ aw3        20.0    800.0  1.28517    1.22857
  12 │ aw3        20.0   1600.0  0.740011   0.630446

In [7]:
# Simulation study: average variable importance of the first three covariates as
# reported by grf's causal forest (R) and by `importance` on the Julia-side
# forest, for every (n, p, dgp) row of the design grid created in R below.
@rlibrary grf

# R-side setup: seed R's RNG, build the design grid, start an empty results frame.
R"""
set.seed(123)
n <- c(800, 1600)
p <- c(10, 20)
dgp <- c("aw2", "aw1", "aw3")
grid <- expand.grid(n = n, p = p, dgp = dgp, stringsAsFactors = FALSE)
N <- nrow(grid)
resu <- data.frame()
"""
@rget N

for i in 1:N
    # Make R pick the i-th design row (the index is pushed from Julia).
    @rput i
    R"""
    n <- grid$n[i]
    p <- grid$p[i]
    dgp <- grid$dgp[i]
    """

    # Replication indices, plus one buffer per tracked covariate (1..3) and method.
    reps = 1:60
    grf_imp = [zeros(length(reps)) for _ in 1:3]
    hterf_imp = [zeros(length(reps)) for _ in 1:3]

    for rep in reps
        # R draws a training set, fits grf's causal forest and extracts its
        # variable-importance scores.
        R"""
        data <- grf::generate_causal_data(n = n, p = p, dgp = dgp, sigma.tau = 1)
        X = data$X
        Y = data$Y
        W = data$W
        cf0 <- grf::causal_forest(data$X, data$Y, data$W, num.trees=500, tune.num.trees=500, sample.fraction=0.7, ci.group.size=1)
        imp0 = grf::variable_importance(cf0)
        """
        @rget Y X W p imp0

        # Record grf's importance of covariates 1-3 for this replication.
        for k in 1:3
            grf_imp[k][rep] = imp0[k]
        end

        # Fit the Julia forest on the same draw and record its importances.
        # NOTE(review): the meaning of the positional flags and of the two 500s is
        # defined by `build_forest`'s signature, which is not visible here — confirm.
        forest = build_forest(false, true, Y, W, X, true, p, 500, 500)
        hterf_scores = importance(forest)
        for k in 1:3
            hterf_imp[k][rep] = hterf_scores[k]
        end
    end

    # Average each covariate's importance over the replications. The scalar names
    # below are read by the R snippet that appends the result row.
    grf_1, grf_2, grf_3 = mean.(grf_imp)
    hterf_1, hterf_2, hterf_3 = mean.(hterf_imp)
    @rput grf_1 grf_2 grf_3 hterf_1 hterf_2 hterf_3
    R"""
    df = data.frame(dgp = dgp, p = p, n = n, GRF_1 = grf_1, GRF_2 = grf_2, GRF_3 = grf_3, HTERF_1 = hterf_1, HTERF_2 = hterf_2, HTERF_3 = hterf_3)
    resu = rbind(resu, df)
    """
end

# Pull the accumulated results back into Julia and display them.
@rget resu
print(resu)

12×9 DataFrame
 Row │ dgp     p        n        GRF_1      GRF_2      GRF_3       HTERF_1    HTERF_2    HTERF_3
     │ String  Float64  Float64  Float64    Float64    Float64     Float64    Float64    Float64
─────┼──────────────────────────────────────────────────────────────────────────────────────────────
   1 │ aw2        10.0    800.0  0.416431   0.41021    0.022094    0.446331   0.416176   0.0171308
   2 │ aw2        10.0   1600.0  0.413139   0.431188   0.0198204   0.440149   0.447187   0.0147454
   3 │ aw2        20.0    800.0  0.397831   0.413228   0.0108249   0.410206   0.426739   0.0090882
   4 │ aw2        20.0   1600.0  0.415919   0.419826   0.00907746  0.434067   0.433229   0.00742244
   5 │ aw1        10.0    800.0  0.0866326  0.0971825  0.100964    0.153515   0.0939358  0.0930449
   6 │ aw