In [17]:
using CausalForest
using StatsBase
using Plots
using Random
using Distributions
using RCall
using BenchmarkTools
using DelimitedFiles

In [18]:
# Simulated training data for the causal-forest experiments.
# Fix Julia's global RNG so the simulated data set is reproducible.
Random.seed!(123);
# n observations, m features.
n, m = 10^4, 10;
# Every feature is drawn independently from Uniform(0, 10).
u = Uniform(0,10);
features = rand(u, (n, m));
# X is an alias of `features` (same array, not a copy).
X = features;
# Standard-normal draws.
# NOTE(review): `eps` is not used in any visible cell (Y below has no noise
# term) and the name shadows Base.eps. The draw still advances the RNG, so
# deleting this line would change the values of T — confirm before removing.
d = Normal();
eps = rand(d, n);
# Binary treatment indicator, converted to Int64.
# Bernoulli() without an argument uses p = 0.5 in Distributions.jl.
b = Bernoulli();
T = convert(Vector{Int64},rand(b, n));
# Outcome: baseline 0.5 plus sin(X1) for treated units. The 0/1 weight vector
# selects the first feature only, so the treatment effect is sin(X1) and the
# remaining nine features do not enter Y.
Y = sin.(features*[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]).*T  .+ 0.5;

In [19]:
# Column layout: the feature columns, then the treatment T, then the outcome Y.
# The horizontal concatenation promotes the integer treatment to Float64.
df = [X T Y]

10000×12 Matrix{Float64}:
 9.063     0.273863  9.2398     5.83606   …  8.61508   1.0   0.853941
 4.43494   8.65398   4.38668    8.66793      5.79587   0.0   0.5
 7.45673   7.98071   9.35901    5.96536      9.99516   0.0   0.5
 5.12083   8.25448   8.79223    2.87236      8.06585   0.0   0.5
 2.53849   5.30423   5.84012    9.7704       9.6668    0.0   0.5
 3.34152   2.66141   9.8669     9.76664   …  9.49805   1.0   0.301406
 4.27328   8.75436   0.0258112  0.863643     2.24987   1.0  -0.40513
 8.67547   7.27313   0.151025   4.99536      0.765286  0.0   0.5
 0.991336  1.20752   5.87184    3.17045      9.81301   0.0   0.5
 1.25287   6.88003   3.64868    5.80333      3.00126   0.0   0.5
 6.92209   0.298644  4.84843    8.11503   …  1.49476   0.0   0.5
 1.36551   1.07227   0.570038   1.05282      7.60498   1.0   1.479
 0.320967  6.45858   4.85281    2.0843       1.98433   0.0   0.5
 ⋮                                        ⋱            ⋮    
 1.41477   4.12499   4.66667    9.07923      2.01061

In [20]:
# Export the data set as a comma-separated file with a header row. Stacking the
# 1×12 String header on top of the numeric matrix yields a mixed-type matrix
# that writedlm writes out row by row.
header = hcat("X1", "X2", "X3", "X4", "X5", "X6", "X7", "X8", "X9", "X10", "T", "Y")
writedlm("sinus_causal_2.csv", vcat(header, df), ",")

In [5]:
# First test set (R seed 1): 10000 runif draws reshaped in R into a
# 1000-row matrix (10 columns) and copied to Julia through RCall.
# NOTE(review): runif defaults to the range [0, 1], whereas the training
# features are drawn from Uniform(0, 10) — confirm that evaluating only on
# this narrower range is intended.
reval("set.seed(1)")
Xtest_1 = rcopy(R"matrix(runif(10000),nrow=1000)")
# Reference treatment effect at the test points: sin of the first column,
# i.e. the term that multiplies T in the definition of Y.
true_effect_1 = sin.(Xtest_1[:, 1]);

In [6]:
"""
    get_all_nodes_in_tree!(tree, result = [])

Append to `result` the `featid` of every node reachable from `tree` that
exposes a `featid` property, in pre-order (node, then its `left` subtree, then
its `right` subtree). A node without a `featid` property contributes nothing
and is not descended into. Returns `result`.
"""
function get_all_nodes_in_tree!(
    tree ,
    result  = []
    )
    pending = Any[tree]
    while !isempty(pending)
        node = pop!(pending)
        hasproperty(node, :featid) || continue
        push!(result, node.featid)
        # The right child is stacked first so that the left subtree is
        # fully visited before the right one.
        push!(pending, node.right)
        push!(pending, node.left)
    end
    return result
end

"""
    get_freq(forest)

Gather the `featid` values of all nodes of every `forest.trees[i].tree` and
return their relative frequencies as computed by `proportionmap`
(a dictionary mapping each `featid` to its share of all collected nodes).
"""
function get_freq(forest)
    res = []
    for member in forest.trees
        # Accumulate straight into `res` instead of appending a temporary.
        get_all_nodes_in_tree!(member.tree, res)
    end
    return proportionmap(res)
end

get_freq (generic function with 1 method)

# Critère exact : centré honnête

### Subsampling

In [7]:
# Benchmark for the section "exact criterion (centred, honest) / subsampling".
# One forest is built and timed, then the RMSD between its predictions and
# sin(first test column) is recorded on 100 test sets drawn in R
# (seed 1 is the pre-built Xtest_1, seeds 2..100 are drawn in the loop).
errors_7 = zeros(100)

# Globals are interpolated with `$` so that BenchmarkTools times the call
# itself rather than untyped global look-ups; @btime still returns the value.
cf = @btime build_forest(true, false, true, $Y, $T, $X, true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000)
pred = @btime apply_forest($cf, $Xtest_1)
errors_7[1] = rmsd(float(true_effect_1), pred)
for i in 2:lastindex(errors_7)
    # Push the loop index to R so it can be used as the seed there.
    @rput i
    reval("set.seed(i)")
    Xtest = rcopy(R"matrix(runif(10000),nrow=1000)")
    pred = apply_forest(cf, Xtest)
    true_effect = sin.(Xtest[:, 1])
    errors_7[i] = rmsd(float(true_effect), pred)
end
# println keeps the mean and the variance on separate lines; with print the
# two numbers run together into one unreadable token.
println(mean(errors_7))
println(var(errors_7))

  204.151 s (101779117 allocations: 12.34 GiB)
  7.878 s (245341149 allocations: 4.12 GiB)
0.418525669659797763.2258673065726005e-5

In [8]:
# Build a forest with the same arguments as the benchmarked subsampling call,
# then report the share of tree nodes carrying each `featid`.
# NOTE(review): this is a fresh build, so it is presumably not the exact forest
# that was timed — confirm how build_forest draws its randomness.
cf = build_forest(true, false, true, Y, T, X, true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000)

get_freq(cf)

Dict{Any, Float64} with 10 entries:
  5  => 0.100447
  4  => 0.102306
  6  => 0.0964178
  7  => 0.101045
  2  => 0.101663
  10 => 0.102705
  9  => 0.0994362
  8  => 0.100192
  3  => 0.101734
  1  => 0.0940532

### Bootstrap

In [9]:
# Benchmark for the section "exact criterion (centred, honest) / bootstrap".
# One forest is built and timed, then the RMSD between its predictions and
# sin(first test column) is recorded on 100 test sets drawn in R
# (seed 1 is the pre-built Xtest_1, seeds 2..100 are drawn in the loop).
errors_8 = zeros(100)

# Globals are interpolated with `$` so that BenchmarkTools times the call
# itself rather than untyped global look-ups; @btime still returns the value.
cf = @btime build_forest(true, true, true, $Y, $T, $X, true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000)
pred = @btime apply_forest($cf, $Xtest_1)
errors_8[1] = rmsd(float(true_effect_1), pred)
for i in 2:lastindex(errors_8)
    # Push the loop index to R so it can be used as the seed there.
    @rput i
    reval("set.seed(i)")
    Xtest = rcopy(R"matrix(runif(10000),nrow=1000)")
    pred = apply_forest(cf, Xtest)
    true_effect = sin.(Xtest[:, 1])
    errors_8[i] = rmsd(float(true_effect), pred)
end
# println keeps the mean and the variance on separate lines; with print the
# two numbers run together into one unreadable token.
println(mean(errors_8))
println(var(errors_8))

  227.593 s (104756678 allocations: 12.74 GiB)
  7.867 s (247680383 allocations: 4.15 GiB)
0.41322214044403773.2099570307814725e-5

# Critère alternatif (différence non centrée) honnête

### Subsampling

In [10]:
# Benchmark for the section "alternative criterion (uncentred difference),
# honest / subsampling". One forest is built and timed, then the RMSD between
# its predictions and sin(first test column) is recorded on 100 test sets
# drawn in R (seed 1 is the pre-built Xtest_1, seeds 2..100 are drawn in the loop).
errors_111 = zeros(100)

# Globals are interpolated with `$` so that BenchmarkTools times the call
# itself rather than untyped global look-ups; @btime still returns the value.
cf = @btime build_forest_1(false, false, true, $Y, $T, $X, true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000)
pred = @btime apply_forest_1($cf, $Xtest_1)
errors_111[1] = rmsd(float(true_effect_1), pred)
for i in 2:lastindex(errors_111)
    # Push the loop index to R so it can be used as the seed there.
    @rput i
    reval("set.seed(i)")
    Xtest = rcopy(R"matrix(runif(10000),nrow=1000)")
    pred = apply_forest_1(cf, Xtest)
    true_effect = sin.(Xtest[:, 1])
    errors_111[i] = rmsd(float(true_effect), pred)
end
# println keeps the mean and the variance on separate lines; with print the
# two numbers run together into one unreadable token.
println(mean(errors_111))
println(var(errors_111))

  8.443 s (3133652 allocations: 862.82 MiB)
  7.971 s (244947232 allocations: 4.11 GiB)
0.0224892425902219378.432681684806591e-7

In [11]:
# Build a forest with the same arguments as the benchmarked subsampling call,
# then report the share of tree nodes carrying each `featid`.
# NOTE(review): this is a fresh build, so it is presumably not the exact forest
# that was timed — confirm how build_forest_1 draws its randomness.
cf = build_forest_1(false, false, true, Y, T, X, true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000)

get_freq(cf)

Dict{Any, Float64} with 10 entries:
  5  => 0.0478191
  4  => 0.0479592
  6  => 0.0506356
  7  => 0.0485445
  2  => 0.0480843
  10 => 0.0465885
  9  => 0.0497101
  8  => 0.0479842
  3  => 0.0486045
  1  => 0.56407

### Bootstrap

In [12]:
# Benchmark for the section "alternative criterion (uncentred difference),
# honest / bootstrap". One forest is built and timed, then the RMSD between
# its predictions and sin(first test column) is recorded on 100 test sets
# drawn in R (seed 1 is the pre-built Xtest_1, seeds 2..100 are drawn in the loop).
errors_121 = zeros(100)

# Globals are interpolated with `$` so that BenchmarkTools times the call
# itself rather than untyped global look-ups; @btime still returns the value.
cf = @btime build_forest_1(false, true, true, $Y, $T, $X, true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000)
pred = @btime apply_forest_1($cf, $Xtest_1)
errors_121[1] = rmsd(float(true_effect_1), pred)
for i in 2:lastindex(errors_121)
    # Push the loop index to R so it can be used as the seed there.
    @rput i
    reval("set.seed(i)")
    Xtest = rcopy(R"matrix(runif(10000),nrow=1000)")
    pred = apply_forest_1(cf, Xtest)
    true_effect = sin.(Xtest[:, 1])
    errors_121[i] = rmsd(float(true_effect), pred)
end
# println keeps the mean and the variance on separate lines; with print the
# two numbers run together into one unreadable token.
println(mean(errors_121))
println(var(errors_121))

  33.972 s (6042450 allocations: 1.23 GiB)
  8.071 s (247592779 allocations: 4.15 GiB)
0.038715198256655283.758565423668049e-6

# Critère alternatif (différence non centrée) honnête optimisé

### Subsampling

In [13]:
# Benchmark for the section "alternative criterion (uncentred difference),
# honest, optimised / subsampling". One forest is built and timed, then the
# RMSD between its predictions and sin(first test column) is recorded on 100
# test sets drawn in R (seed 1 is the pre-built Xtest_1, seeds 2..100 are
# drawn in the loop).
errors_11 = zeros(100)

# Globals are interpolated with `$` so that BenchmarkTools times the call
# itself rather than untyped global look-ups; @btime still returns the value.
cf = @btime build_forest_1_opti(false, false, true, $Y, $T, $X, true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000)
pred = @btime apply_forest_1($cf, $Xtest_1)
errors_11[1] = rmsd(float(true_effect_1), pred)
for i in 2:lastindex(errors_11)
    # Push the loop index to R so it can be used as the seed there.
    @rput i
    reval("set.seed(i)")
    Xtest = rcopy(R"matrix(runif(10000),nrow=1000)")
    pred = apply_forest_1(cf, Xtest)
    true_effect = sin.(Xtest[:, 1])
    errors_11[i] = rmsd(float(true_effect), pred)
end
# println keeps the mean and the variance on separate lines; with print the
# two numbers run together into one unreadable token.
println(mean(errors_11))
println(var(errors_11))

  5.447 s (2200113 allocations: 795.56 MiB)
  8.522 s (274231343 allocations: 4.55 GiB)
0.0066428928347308061.0184612246468457e-7

In [14]:
# Build a forest with the same arguments as the benchmarked subsampling call,
# then report the share of tree nodes carrying each `featid`.
# NOTE(review): this is a fresh build, so it is presumably not the exact forest
# that was timed — confirm how build_forest_1_opti draws its randomness.
cf = build_forest_1_opti(false, false, true, Y, T, X, true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000)

get_freq(cf)

Dict{Any, Float64} with 10 entries:
  5  => 0.00795208
  4  => 0.00808202
  6  => 0.00893959
  7  => 0.00815998
  2  => 0.00904354
  10 => 0.00818596
  9  => 0.00817297
  8  => 0.00835488
  3  => 0.00786112
  1  => 0.925248

### Bootstrap

In [15]:
# Benchmark for the section "alternative criterion (uncentred difference),
# honest, optimised / bootstrap". One forest is built and timed, then the
# RMSD between its predictions and sin(first test column) is recorded on 100
# test sets drawn in R (seed 1 is the pre-built Xtest_1, seeds 2..100 are
# drawn in the loop).
errors_12 = zeros(100)

# Globals are interpolated with `$` so that BenchmarkTools times the call
# itself rather than untyped global look-ups; @btime still returns the value.
cf = @btime build_forest_1_opti(false, true, true, $Y, $T, $X, true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000)
pred = @btime apply_forest_1($cf, $Xtest_1)
errors_12[1] = rmsd(float(true_effect_1), pred)
for i in 2:lastindex(errors_12)
    # Push the loop index to R so it can be used as the seed there.
    @rput i
    reval("set.seed(i)")
    Xtest = rcopy(R"matrix(runif(10000),nrow=1000)")
    pred = apply_forest_1(cf, Xtest)
    true_effect = sin.(Xtest[:, 1])
    errors_12[i] = rmsd(float(true_effect), pred)
end
# println keeps the mean and the variance on separate lines; with print the
# two numbers run together into one unreadable token.
println(mean(errors_12))
println(var(errors_12))

  13.070 s (4185300 allocations: 1.09 GiB)
  8.606 s (281574467 allocations: 4.66 GiB)
0.00360142797735669724.8440732233700656e-8

In [16]:
# Build a forest with the same arguments as the benchmarked bootstrap call,
# then report the share of tree nodes carrying each `featid`.
# NOTE(review): this is a fresh build, so it is presumably not the exact forest
# that was timed — confirm how build_forest_1_opti draws its randomness.
cf = build_forest_1_opti(false, true, true, Y, T, X, true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000)

get_freq(cf)

Dict{Any, Float64} with 10 entries:
  5  => 0.0156709
  4  => 0.0157069
  6  => 0.0158221
  7  => 0.0156925
  2  => 0.0150446
  10 => 0.0159948
  9  => 0.0147783
  8  => 0.015966
  3  => 0.0154189
  1  => 0.859905