In [1]:
using CausalForest
using RCall
using StatsBase
using Statistics
using Random
using Distributions

In [2]:
# Experiment 1: compare the treatment-effect estimation error of R's
# grf::causal_forest with CausalForest.jl's build_forest/apply_forest on
# simulated data, for several (a1, a2) coefficient pairs.

# Seed the Julia RNG; n = number of samples, m = number of features.
Random.seed!(123);
n, m = 10^4, 10;
@rlibrary grf
# R side: seed R's RNG, build the 3-row grid of (a1, a2) pairs
# (1, 5), (1, 1), (1, 0.2), and create an empty result data frame `res`.
R"""
set.seed(123)
grid = as.data.frame(matrix(c(1,5,1,1,1,0.2), ncol=2, byrow=T))
colnames(grid) = c("a1", "a2")
N <- nrow(grid)
res <- data.frame()
"""
# Copy the number of grid rows from R into Julia.
@rget N
for i in 1:N
    # Send the row index to R, read that row's coefficients there, and copy
    # them back into Julia.
    @rput i
    R"""
    a1 <- grid$a1[i]
    a2 <- grid$a2[i]
    """
    @rget a1 a2
    # One error value per replication (60 replications per grid row).
    errors_grf = zeros(60)
    errors_hterf = zeros(60)
    for j in 1:60
        # Training features: n x m matrix of Uniform(0, 1) draws.
        u = Uniform(0,1);
        features = rand(u, (n, m));
        X = features;
        # Treatment indicator drawn from a default-parameter Bernoulli,
        # converted to Int64.
        b = Bernoulli();
        T = convert(Vector{Int64},rand(b, n));
        # Outcome: a1*sin(x1) applies only when T == 1 (the treatment effect);
        # a2*cos(2*x2 + 3*x3) applies to every sample.
        Y = a1*sin.(features*[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]).*T  + a2*cos.(features*[0, 2, 3, 0, 0, 0, 0, 0, 0, 0]);
        # Independent test features and the true effect a1*sin(x1) on them.
        Xtest = rand(u, (n, m));
        tau = a1*sin.(Xtest*[1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
        @rput X T Y Xtest tau
        # R side: fit grf's causal forest, predict on Xtest, and compute the
        # mean squared error against tau.
        R"""
        cf <- grf::causal_forest(X, Y, T, num.trees=500, tune.num.trees=500, sample.fraction=0.7, ci.group.size=1)
        tau.hat <- predict(cf, Xtest)$predictions
        mse = mean((tau.hat - tau)^2)
        """
        @rget mse
        # Julia side: fit and apply the CausalForest.jl forest.
        # NOTE(review): the meaning of the positional arguments
        # (false, true, ..., true, m, 500, 500) is defined by
        # CausalForest.build_forest and is not visible here — confirm there.
        cf = build_forest(false, true, Y, T, X, true, m, 500, 500)
        tauhat = apply_forest(cf, Xtest)
        # msd (StatsBase) is the mean squared deviation between tau and tauhat.
        errors_hterf[j] = msd(tau, tauhat)
        errors_grf[j] = mse
    end
    # Average the errors over the replications and append one row to `res`
    # in R; both errors are multiplied by 1000 in the stored row.
    err_hterf = mean(errors_hterf)
    err_grf = mean(errors_grf)
    @rput err_grf err_hterf
    R"""
    df = data.frame(a1 = a1, a2 = a2, C.GRF = err_grf * 1000, HTERF = err_hterf * 1000)
    res = rbind(res, df)
    """
end
# Bring the accumulated result table back into Julia and print it.
@rget res
print(res)

3×4 DataFrame
 Row │ a1       a2       C_GRF      HTERF
     │ Float64  Float64  Float64    Float64
─────┼─────────────────────────────────────────
   1 │     1.0      5.0  0.276206   0.117095
   2 │     1.0      1.0  0.121759   0.0116932
   3 │     1.0      0.2  0.0789389  0.00405936

In [3]:
"""
    get_all_nodes_in_tree!(tree, depth=3, result=[])

Append to `result` the `featid` of every node of `tree` that has a `featid`
property and lies within the first `depth` levels. A node is recorded before
the nodes of its `left` subtree, which come before those of its `right`
subtree. Nodes without a `featid` property end the descent on their branch.

Return `result` (the same object that was passed in).
"""
function get_all_nodes_in_tree!(
    tree,
    depth = 3,
    result = []
    )
    # Explicit stack of (node, levels still allowed below and including node).
    pending = Any[(tree, depth)]
    while !isempty(pending)
        node, levels_left = pop!(pending)
        (hasproperty(node, :featid) && levels_left > 0) || continue
        push!(result, node.featid)
        # Push right first so the left subtree is popped, and recorded, first.
        push!(pending, (node.right, levels_left - 1))
        push!(pending, (node.left, levels_left - 1))
    end
    return result
end

"""
    get_freq(forest, depth=3)

Collect, over every element of `forest.trees`, the `featid`s gathered by
`get_all_nodes_in_tree!` from that element's `tree` field down to `depth`
levels, and return `proportionmap` of the pooled ids (the share of collected
splits attributed to each feature id).
"""
function get_freq(forest, depth=3)
    featids = []
    for member in forest.trees
        # Accumulate straight into the shared vector instead of appending a
        # per-tree temporary.
        get_all_nodes_in_tree!(member.tree, depth, featids)
    end
    return proportionmap(featids)
end

get_freq (generic function with 2 methods)

In [4]:
# Experiment 2: for each (a1, a2) pair, measure how strongly each forest
# focuses on feature 1 (the only feature entering the treatment effect),
# using split frequencies at several depths and variable importance.
# Relies on `n`, `m`, `N` and the R-side `grid` defined by the earlier cell.

# Re-seed the Julia RNG and create an empty R result data frame `resu`.
Random.seed!(123);
R"""
resu <- data.frame()
"""
for i in 1:N
    # Send the row index to R, read that row's coefficients there, and copy
    # them back into Julia.
    @rput i
    R"""
    a1 <- grid$a1[i]
    a2 <- grid$a2[i]
    """
    @rget a1 a2
    # Per-replication storage (60 replications). Suffixes 1, 2, 3 hold the
    # feature-1 split share at depths 3, 5, 10; suffix 4 holds the feature-1
    # importance.
    hterf1 = zeros(60)
    hterf2 = zeros(60)
    hterf3 = zeros(60)
    hterf4 = zeros(60)
    grf1 = zeros(60)
    grf2 = zeros(60)
    grf3 = zeros(60)
    grf4 = zeros(60)
    for j in 1:60
        # Simulated data: Uniform(0, 1) features, default-parameter Bernoulli
        # treatment, outcome a1*sin(x1)*T + a2*cos(2*x2 + 3*x3).
        u = Uniform(0,1);
        features = rand(u, (n, m));
        X = features;
        b = Bernoulli();
        T = convert(Vector{Int64},rand(b, n));
        Y = a1*sin.(features*[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]).*T  + a2*cos.(features*[0, 2, 3, 0, 0, 0, 0, 0, 0, 0]);
        # NOTE(review): the meaning of build_forest's positional arguments is
        # defined by CausalForest and is not visible here — confirm there.
        cf0 = build_forest(false, true, Y, T, X, true, m, 500, 500)
        # Share of collected splits on feature id 1 within the first 3, 5 and
        # 10 levels of the trees.
        # NOTE(review): indexing the get_freq result with [1] presumably
        # throws if feature 1 is never split on in that range — confirm.
        hterf1[j] = get_freq(cf0, 3)[1]
        hterf2[j] = get_freq(cf0, 5)[1]
        hterf3[j] = get_freq(cf0, 10)[1]
        # First entry of CausalForest's importance vector.
        hterf4[j] = importance(cf0)[1]
        @rput X T Y
        # R side: fit grf's causal forest, take split frequencies up to depth
        # 20, and compute the column-1 share of splits over the first 3, 5 and
        # 10 rows, plus the first entry of grf's variable importance.
        R"""
        cf1 <- grf::causal_forest(X, Y, T, num.trees=500, tune.num.trees=500, sample.fraction=0.7, ci.group.size=1)
        freq = grf::split_frequencies(cf1,20)
        g_1 = sum(freq[1:3,1])/sum(freq[1:3,])
        g_2 = sum(freq[1:5,1])/sum(freq[1:5,])
        g_3 = sum(freq[1:10,1])/sum(freq[1:10,])
        g_4 = grf::variable_importance(cf1)[1]
        """
        @rget g_1 g_2 g_3 g_4
        grf1[j] = g_1
        grf2[j] = g_2
        grf3[j] = g_3
        grf4[j] = g_4
    end
    # Average every metric over the replications.
    grf_1 = mean(grf1)
    grf_2 = mean(grf2)
    grf_3 = mean(grf3)
    grf_4 = mean(grf4)
    hterf_1 = mean(hterf1)
    hterf_2 = mean(hterf2)
    hterf_3 = mean(hterf3)
    hterf_4 = mean(hterf4)
    # Append one summary row for this (a1, a2) pair to `resu` in R.
    @rput grf_1 grf_2 grf_3 grf_4 hterf_1 hterf_2 hterf_3 hterf_4
    R"""
    df = data.frame(a1 = a1, a2 = a2, GRF_3 = grf_1, GRF_5 = grf_2, GRF_10 = grf_3, GRF_imp = grf_4, HTERF_3 = hterf_1, HTERF_5 = hterf_2, HTERF_10 = hterf_3, HTERF_imp = hterf_4)
    resu = rbind(resu, df)
    """
    end
# Bring the accumulated result table back into Julia and print it.
@rget resu
print(resu)

3×10 DataFrame
 Row │ a1       a2       GRF_3     GRF_5     GRF_10    GRF_imp   HTERF_3   HTERF_5   HTERF_10  HTERF_imp
     │ Float64  Float64  Float64   Float64   Float64   Float64   Float64   Float64   Float64   Float64
─────┼───────────────────────────────────────────────────────────────────────────────────────────────────
   1 │     1.0      5.0  0.869723  0.377593  0.149885  0.852033  0.999886  0.497583  0.175225   0.984969
   2 │     1.0      1.0  0.87365   0.526142  0.17448   0.865501  1.0       0.994539  0.281786   1.0
   3 │     1.0      0.2  0.875354  0.627452  0.199608  0.87365   1.0       1.0       0.603331   1.0