In [1]:
using CausalForest
using RCall
using StatsBase
using Statistics
using Random
using Distributions

In [2]:
"""
    get_all_nodes_in_tree!(tree, depth = 3, result = [])

Append to `result` the `featid` of every node reachable from `tree` that has a
`featid` property and lies within the first `depth` levels. A node is recorded
before its `left` subtree, which is recorded before its `right` subtree.

Nodes without a `featid` property are not recorded and are not descended into.
Returns `result` (the same vector that was passed in, mutated in place).
"""
function get_all_nodes_in_tree!(
    tree ,
    depth = 3,
    result  = []
    )
    # Nothing to record for nodes that carry no split feature.
    hasproperty(tree, :featid) || return result
    # Depth budget exhausted: do not record this node or anything below it.
    depth > 0 || return result

    push!(result, tree.featid)
    for side in (:left, :right)
        get_all_nodes_in_tree!(getproperty(tree, side), depth - 1, result)
    end
    return result
end

"""
    get_freq(forest, depth=3)

Pool the feature ids that `get_all_nodes_in_tree!` collects, down to `depth`,
from the `.tree` field of every element of `forest.trees`, and return
`proportionmap` of the pooled ids (each distinct id mapped to its share of the
pooled entries).
"""
function get_freq(forest, depth=3) 
    pooled = []
    foreach(forest.trees) do member
        append!(pooled, get_all_nodes_in_tree!(member.tree, depth))
    end
    return proportionmap(pooled)
end

get_freq (generic function with 2 methods)

In [3]:
# Monte-Carlo comparison of three causal-forest estimators on synthetic data:
#   * "GRF"       - R's grf::causal_forest, driven through RCall
#   * "HTERF"     - forest from `build_forest` / `apply_forest`
#   * "HTERF-OLS" - forest from `build_forest_ols` / `apply_forest_ols`
# For each replication we record the RMSE of the estimated treatment effect on a
# fresh test set, the share of splits on feature 1 within depths 3, 5 and 10,
# and the importance score of feature 1. The averages over all replications are
# collected in the R data frame `res`, copied back to Julia and printed.
Random.seed!(123);
n, m = 10^4, 10;
n_reps = 60;  # number of Monte-Carlo replications
@rlibrary grf
R"""
set.seed(123)
res <- data.frame()
"""
errors_grf = zeros(n_reps)
errors_hterf = zeros(n_reps)
errors_OLS = zeros(n_reps)
hterf1 = zeros(n_reps)
hterf2 = zeros(n_reps)
hterf3 = zeros(n_reps)
hterf4 = zeros(n_reps)
grf1 = zeros(n_reps)
grf2 = zeros(n_reps)
grf3 = zeros(n_reps)
grf4 = zeros(n_reps)
ols1 = zeros(n_reps)
ols2 = zeros(n_reps)
ols3 = zeros(n_reps)
ols4 = zeros(n_reps)

# Data-generating process: the treatment effect is sin(x1); the baseline
# outcome is 2*x2 + 3*x3. All other features are pure noise.
tau_coef = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
base_coef = [0, 2, 3, 0, 0, 0, 0, 0, 0, 0]

# Share of splits on feature 1 within the first `depth` levels of `forest`.
# Falls back to 0.0 instead of throwing a KeyError when feature 1 was never
# split on in that depth range.
feature1_share(forest, depth) = get(get_freq(forest, depth), 1, 0.0)

for j in 1:n_reps
    u = Uniform(0,1);
    features = rand(u, (n, m));
    X = features;
    b = Bernoulli();
    T = convert(Vector{Int64},rand(b, n));
    Y = sin.(features*tau_coef).*T  + features*base_coef;
    Xtest = rand(u, (n, m));
    tau = sin.(Xtest*tau_coef)  # true treatment effect on the test set
    @rput X T Y Xtest tau
    # NOTE: the R variable `mse` actually holds a root-mean-squared error (it is
    # sqrt(mean(...))); the name is kept so the R workspace stays unchanged.
    R"""
    cf <- grf::causal_forest(X, Y, T, num.trees=500, tune.num.trees=500, sample.fraction=0.7, ci.group.size=1)
    tau.hat <- predict(cf, Xtest)$predictions
    mse = sqrt(mean((tau.hat - tau)^2))
    freq = grf::split_frequencies(cf,20)
    g_1 = sum(freq[1:3,1])/sum(freq[1:3,])
    g_2 = sum(freq[1:5,1])/sum(freq[1:5,])
    g_3 = sum(freq[1:10,1])/sum(freq[1:10,])
    g_4 = grf::variable_importance(cf)[1]
    """
    @rget mse g_1 g_2 g_3 g_4
    cf = build_forest(false, true, Y, T, X, true, m, 500, 500)
    tauhat = apply_forest(cf, Xtest)
    cf1 = build_forest_ols(false, true, Y, T, X, true, m, 500, 500)
    tauhat1 = apply_forest_ols(cf1, Xtest)
    hterf1[j] = feature1_share(cf, 3)
    hterf2[j] = feature1_share(cf, 5)
    hterf3[j] = feature1_share(cf, 10)
    hterf4[j] = importance(cf)[1]
    errors_hterf[j] = rmsd(tau, tauhat)
    grf1[j] = g_1
    grf2[j] = g_2
    grf3[j] = g_3
    grf4[j] = g_4
    errors_grf[j] = mse
    ols1[j] = feature1_share(cf1, 3)
    ols2[j] = feature1_share(cf1, 5)
    ols3[j] = feature1_share(cf1, 10)
    ols4[j] = importance(cf1)[1]
    errors_OLS[j] = rmsd(tau, tauhat1)
end

# Average every metric over the replications.
err_hterf = mean(errors_hterf)
err_grf = mean(errors_grf)
err_ols = mean(errors_OLS)
grf_1 = mean(grf1)
grf_2 = mean(grf2)
grf_3 = mean(grf3)
grf_4 = mean(grf4)
hterf_1 = mean(hterf1)
hterf_2 = mean(hterf2)
hterf_3 = mean(hterf3)
hterf_4 = mean(hterf4)
ols_1 = mean(ols1)
ols_2 = mean(ols2)
ols_3 = mean(ols3)
ols_4 = mean(ols4)

# Assemble the summary table on the R side (one row per method), then pull it
# back into Julia for display.
@rput err_grf err_hterf err_ols grf_1 grf_2 grf_3 grf_4 hterf_1 hterf_2 hterf_3 hterf_4 ols_1 ols_2 ols_3 ols_4
R"""
dfgrf = data.frame(method = "GRF", RMSE = err_grf, dep3 = grf_1, dep5 = grf_2, dep10 = grf_3, imp = grf_4)
dfhterf = data.frame(method = "HTERF", RMSE = err_hterf, dep3 = hterf_1, dep5 = hterf_2, dep10 = hterf_3, imp = hterf_4)
dfols = data.frame(method = "HTERF-OLS", RMSE = err_ols, dep3 = ols_1, dep5 = ols_2, dep10 = ols_3, imp = ols_4)
res = rbind(res, dfgrf, dfhterf, dfols)

"""
@rget res
print(res)

3×6 DataFrame
 Row │ method     RMSE        dep3      dep5      dep10     imp
     │ String     Float64     Float64   Float64   Float64   Float64
─────┼──────────────────────────────────────────────────────────────
   1 │ GRF        0.0115552   0.875165  0.51386   0.171388  0.86674
   2 │ HTERF      0.00434751  1.0       0.953929  0.238555  1.0
   3 │ HTERF-OLS  0.00148948  1.0       1.0       0.944206  1.0