In [1]:
using CausalForest
using RCall
using StatsBase
using Statistics
using Random
using Distributions

In [2]:
"""
    get_all_nodes_in_tree!(tree, depth = 3, result = [])

Walk `tree` in pre-order (current node, then the `left` subtree, then the
`right` subtree) and push the `featid` of every node visited within the first
`depth` levels onto `result`.

The descent stops at any value that has no `featid` property, and it also stops
once `depth` reaches 0. Because Julia evaluates the default `result = []` afresh
on each call, a call without an explicit `result` always starts from an empty
`Vector{Any}`.

Returns `result`, which is mutated in place.
"""
function get_all_nodes_in_tree!(
    tree,
    depth = 3,
    result = []
    )
    # Guard clause: nothing to record for featid-less nodes or exhausted depth.
    (hasproperty(tree, :featid) && depth > 0) || return result
    push!(result, tree.featid)
    # Left before right keeps the pre-order sequence of recorded ids.
    for side in (:left, :right)
        get_all_nodes_in_tree!(getproperty(tree, side), depth - 1, result)
    end
    return result
end

"""
    get_freq(forest, depth = 3)

Pool the feature ids that `get_all_nodes_in_tree!` collects from the first
`depth` levels of every tree in `forest.trees`. Each element's root node is read
from its `.tree` field, and trees are visited in iteration order.

Returns `proportionmap` of the pooled ids. Per StatsBase, this is a dictionary
mapping each distinct feature id to the fraction of pooled entries equal to it.
"""
function get_freq(forest, depth=3)
    pooled = []
    for member in forest.trees
        append!(pooled, get_all_nodes_in_tree!(member.tree, depth))
    end
    return proportionmap(pooled)
end

get_freq (generic function with 2 methods)

In [19]:
# Monte-Carlo comparison of R's grf::causal_forest ("GRF") and CausalForest.jl
# ("HTERF") on a synthetic heterogeneous-treatment-effect problem.
#
# Each replication draws:
#   - n training points with m features uniform on [-pi, pi];
#   - a Bernoulli() treatment W in {0, 1};
#   - the outcome Y = sin(x1) + 7 sin(x2)^2 + 0.3 x3^4 sin(x1) W.
# The true effect is therefore tau(x) = 0.3 x3^4 sin(x1). Both forests are fitted
# on the same draw and scored by RMSE of the estimated effect on a fresh test set
# of n points. Per-feature importances are recorded alongside.
Random.seed!(123);
n, m = 10^4, 3;
n_reps = 60  # number of Monte-Carlo replications
@rlibrary grf
R"""
set.seed(123)
res <- data.frame()
"""
# One slot per replication: effect RMSE and importance of features 1..3.
errors_grf = zeros(n_reps)
errors_hterf = zeros(n_reps)
hterf1 = zeros(n_reps)
hterf2 = zeros(n_reps)
hterf3 = zeros(n_reps)
grf1 = zeros(n_reps)
grf2 = zeros(n_reps)
grf3 = zeros(n_reps)
for j in 1:n_reps
    # Draw order (X, W, Xtest) is kept fixed so results stay reproducible
    # under the seed above.
    u = Uniform(-pi, pi)
    X = rand(u, (n, m))
    # Named W rather than T: `@rput T` would assign `T` in R's global
    # environment and mask R's `T` alias for TRUE.
    W = convert(Vector{Int64}, rand(Bernoulli(), n))
    Y = sin.(X[:, 1]) + 7 * (sin.(X[:, 2])) .^ 2 + 0.3 * (X[:, 3] .^ 4) .* sin.(X[:, 1]) .* W
    Xtest = rand(u, (n, m))
    tau = 0.3 * (Xtest[:, 3] .^ 4) .* sin.(Xtest[:, 1])  # true effect on test set
    @rput X W Y Xtest tau
    # variable_importance is evaluated once and indexed, instead of being
    # recomputed for every component.
    R"""
    cf <- grf::causal_forest(X, Y, W, num.trees=500, tune.num.trees=500, sample.fraction=0.7, ci.group.size=1)
    tau.hat <- predict(cf, Xtest)$predictions
    rmse = sqrt(mean((tau.hat - tau)^2))
    vi <- grf::variable_importance(cf)
    g_1 = vi[1]
    g_2 = vi[2]
    g_3 = vi[3]
    """
    @rget rmse g_1 g_2 g_3
    # NOTE(review): positional arguments follow CausalForest.build_forest's
    # signature, which is not visible here; passing m presumably means
    # "consider all features per split" — confirm against the package.
    cf = build_forest(false, true, Y, W, X, true, m, 500, 500)
    tauhat = apply_forest(cf, Xtest)
    imp = importance(cf)  # computed once so all three components are consistent
    hterf1[j], hterf2[j], hterf3[j] = imp[1], imp[2], imp[3]
    errors_hterf[j] = rmsd(tau, tauhat)
    grf1[j], grf2[j], grf3[j] = g_1, g_2, g_3
    errors_grf[j] = rmse
end
# Average over replications and tabulate one row per method in R.
err_hterf = mean(errors_hterf)
err_grf = mean(errors_grf)
grf_1 = mean(grf1)
grf_2 = mean(grf2)
grf_3 = mean(grf3)
hterf_1 = mean(hterf1)
hterf_2 = mean(hterf2)
hterf_3 = mean(hterf3)
@rput err_grf err_hterf grf_1 grf_2 grf_3 hterf_1 hterf_2 hterf_3
R"""
dfgrf = data.frame(method = "GRF", RMSE = err_grf, imp1 = grf_1, imp2 = grf_2, imp3 = grf_3)
dfhterf = data.frame(method = "HTERF", RMSE = err_hterf, imp1 = hterf_1, imp2 = hterf_2, imp3 = hterf_3)
res = rbind(res, dfgrf, dfhterf)

"""
@rget res
print(res)

[1m2×5 DataFrame[0m
[1m Row [0m│[1m method [0m[1m RMSE     [0m[1m imp1     [0m[1m imp2        [0m[1m imp3     [0m
[1m     [0m│[90m String [0m[90m Float64  [0m[90m Float64  [0m[90m Float64     [0m[90m Float64  [0m
─────┼───────────────────────────────────────────────────
   1 │ GRF     0.981815  0.655163  0.0754224    0.269415
   2 │ HTERF   0.765765  0.762526  0.000351466  0.237123