In [21]:
using CausalForest
using StatsBase
using Plots
using Random
using Distributions
using RCall
using BenchmarkTools
using DelimitedFiles

In [10]:
# Simulated training data for the causal-forest experiments.
Random.seed!(123);
n, m = 10^4, 10;                       # n observations, m covariates

# Covariates: an n×m matrix of Uniform() draws.
u = Uniform();
features = rand(u, (n, m));
X = features;

# Normal() draws. NOTE(review): `eps` is never used afterwards (Y below is noiseless)
# and it shadows `Base.eps`; the draw is kept because it advances the RNG and
# therefore determines T.
d = Normal();
eps = rand(d, n);

# Treatment indicator: Bernoulli() draws stored as 0/1 Int64.
b = Bernoulli();
T = Int64.(rand(b, n));

# Outcome: baseline 0.5, plus sin of the first covariate for treated units only,
# so the true treatment effect is sin(X[:, 1]).
Y = sin.(features[:, 1]) .* T .+ 0.5;

In [11]:
# Training table: covariate columns, then treatment T, then outcome Y.
df = [X T Y]

10000×12 Matrix{Float64}:
 0.9063     0.0273863  0.92398     0.583606   …  0.861508   1.0  1.28723
 0.443494   0.865398   0.438668    0.866793      0.579587   0.0  0.5
 0.745673   0.798071   0.935901    0.596536      0.999516   0.0  0.5
 0.512083   0.825448   0.879223    0.287236      0.806585   0.0  0.5
 0.253849   0.530423   0.584012    0.97704       0.96668    0.0  0.5
 0.334152   0.266141   0.98669     0.976664   …  0.949805   1.0  0.827968
 0.427328   0.875436   0.00258112  0.0863643     0.224987   1.0  0.91444
 0.867547   0.727313   0.0151025   0.499536      0.0765286  0.0  0.5
 0.0991336  0.120752   0.587184    0.317045      0.981301   0.0  0.5
 0.125287   0.688003   0.364868    0.580333      0.300126   0.0  0.5
 0.692209   0.0298644  0.484843    0.811503   …  0.149476   0.0  0.5
 0.136551   0.107227   0.0570038   0.105282      0.760498   1.0  0.636128
 0.0320967  0.645858   0.485281    0.20843       0.198433   0.0  0.5
 ⋮                                            ⋱            

In [23]:
# 1×12 header row (X1..X10, T, Y) stacked on top of the data and written as CSV.
header = permutedims(vcat(["X$k" for k in 1:10], "T", "Y"))
writedlm("sinus_causal.csv", vcat(header, df), ",")

In [12]:
# Test set 1: a 1000×10 matrix of R `runif` draws generated with R seed 1.
reval("set.seed(1)")
Xtest_1 = rcopy(R"matrix(runif(10000),nrow=1000)")
# Ground-truth treatment effect on this test set: sin of the first covariate.
true_effect_1 = map(sin, Xtest_1[:, 1]);

In [24]:
"""
    get_all_nodes_in_tree!(tree, result = [])

Walk `tree` depth-first (node, then left subtree, then right subtree). Push the
`featid` of every internal node onto `result`. A node is treated as internal when it
has a `featid` property; anything else is a leaf and contributes nothing.

Return `result`.
"""
function get_all_nodes_in_tree!(
    tree,
    result = []
    )
    hasproperty(tree, :featid) || return result
    push!(result, tree.featid)
    get_all_nodes_in_tree!(tree.left, result)
    get_all_nodes_in_tree!(tree.right, result)
    return result
end

"""
    get_freq(forest)

Return a dictionary mapping each split feature id to the fraction of all internal
nodes, over every tree in `forest.trees`, that split on that feature.
"""
function get_freq(forest)
    # Deliberately untyped, so the result stays a `Dict{Any, Float64}` as before.
    split_features = []
    for member in forest.trees
        # Accumulate in place instead of building and appending a temporary per tree.
        get_all_nodes_in_tree!(member.tree, split_features)
    end
    return proportionmap(split_features)
end

get_freq (generic function with 1 method)

# Critère exact : centré honnête

### Subsampling

In [13]:
# Exact (centred, honest) criterion, subsampling: time one forest build and one
# prediction. Then record the RMSE of the predicted effect against sin(X[:, 1]) on
# 100 test sets drawn with R seeds 1..100 (seed 1 is Xtest_1).
errors_7 = zeros(100)

# `$` interpolation keeps @btime from also timing global-variable access.
cf = @btime build_forest(true, false, true, $Y, $T, $X, true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000)
pred = @btime apply_forest($cf, $Xtest_1)
errors_7[1] = rmsd(float(true_effect_1), pred)
for i in 2:100
    @rput i
    reval("set.seed(i)")
    Xtest = rcopy(R"matrix(runif(10000),nrow=1000)")
    pred = apply_forest(cf, Xtest)
    true_effect = sin.(Xtest[:, 1])
    errors_7[i] = rmsd(float(true_effect), pred)
end
# println with labels: bare `print` ran both numbers together into one token.
println("mean RMSE: ", mean(errors_7))
println("variance:  ", var(errors_7))

  229.404 s (101443434 allocations: 12.32 GiB)
  7.885 s (246926015 allocations: 4.14 GiB)
0.27140091262435971.7674570181564734e-5

In [25]:
# Build one forest (exact criterion, subsampling) and report, for each feature id,
# the fraction of all splits in the forest that use it. Only feature 1 drives the
# treatment effect, so a good criterion should split on it most often.
cf = build_forest(
    true, false, true,
    Y, T, X,
    true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000,
)

get_freq(cf)

Dict{Any, Float64} with 10 entries:
  5  => 0.0976571
  4  => 0.0987599
  6  => 0.100649
  7  => 0.0994695
  2  => 0.0975652
  10 => 0.0989182
  9  => 0.0938893
  8  => 0.0981064
  3  => 0.0977235
  1  => 0.117262

### Bootstrap

In [14]:
# Exact (centred, honest) criterion, bootstrap: time one forest build and one
# prediction. Then record the RMSE of the predicted effect against sin(X[:, 1]) on
# 100 test sets drawn with R seeds 1..100 (seed 1 is Xtest_1).
errors_8 = zeros(100)

# `$` interpolation keeps @btime from also timing global-variable access.
cf = @btime build_forest(true, true, true, $Y, $T, $X, true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000)
pred = @btime apply_forest($cf, $Xtest_1)
errors_8[1] = rmsd(float(true_effect_1), pred)
for i in 2:100
    @rput i
    reval("set.seed(i)")
    Xtest = rcopy(R"matrix(runif(10000),nrow=1000)")
    pred = apply_forest(cf, Xtest)
    true_effect = sin.(Xtest[:, 1])
    errors_8[i] = rmsd(float(true_effect), pred)
end
# println with labels: bare `print` ran both numbers together into one token.
println("mean RMSE: ", mean(errors_8))
println("variance:  ", var(errors_8))

  327.383 s (104412041 allocations: 12.71 GiB)
  10.623 s (249554580 allocations: 4.18 GiB)
0.26983421958306021.6089012971035695e-5

# Critère alternatif (différence non centrée) honnête

### Subsampling

In [17]:
# Alternative (non-centred difference, honest) criterion, subsampling: time one
# forest build and one prediction. Then record the RMSE of the predicted effect
# against sin(X[:, 1]) on 100 test sets drawn with R seeds 1..100 (seed 1 is Xtest_1).
errors_111 = zeros(100)

# `$` interpolation keeps @btime from also timing global-variable access.
cf = @btime build_forest_1(false, false, true, $Y, $T, $X, true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000)
pred = @btime apply_forest_1($cf, $Xtest_1)
errors_111[1] = rmsd(float(true_effect_1), pred)
for i in 2:100
    @rput i
    reval("set.seed(i)")
    Xtest = rcopy(R"matrix(runif(10000),nrow=1000)")
    pred = apply_forest_1(cf, Xtest)
    true_effect = sin.(Xtest[:, 1])
    errors_111[i] = rmsd(float(true_effect), pred)
end
# println with labels: bare `print` ran both numbers together into one token.
println("mean RMSE: ", mean(errors_111))
println("variance:  ", var(errors_111))

  6.820 s (3136175 allocations: 862.99 MiB)
  7.884 s (246561184 allocations: 4.14 GiB)
0.004844742098710911.0308282414730315e-6

In [26]:
# Build one forest (alternative criterion, subsampling) and report, for each feature
# id, the fraction of all splits in the forest that use it. Only feature 1 drives
# the treatment effect, so a good criterion should split on it most often.
cf = build_forest_1(
    false, false, true,
    Y, T, X,
    true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000,
)

get_freq(cf)

Dict{Any, Float64} with 10 entries:
  5  => 0.0487784
  4  => 0.0502723
  6  => 0.0497427
  7  => 0.0496977
  2  => 0.0476692
  10 => 0.0470847
  9  => 0.0516463
  8  => 0.0489133
  3  => 0.0491431
  1  => 0.557052

### Bootstrap

In [19]:
# Alternative (non-centred difference, honest) criterion, bootstrap: time one
# forest build and one prediction. Then record the RMSE of the predicted effect
# against sin(X[:, 1]) on 100 test sets drawn with R seeds 1..100 (seed 1 is Xtest_1).
errors_121 = zeros(100)

# `$` interpolation keeps @btime from also timing global-variable access.
cf = @btime build_forest_1(false, true, true, $Y, $T, $X, true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000)
pred = @btime apply_forest_1($cf, $Xtest_1)
errors_121[1] = rmsd(float(true_effect_1), pred)
for i in 2:100
    @rput i
    reval("set.seed(i)")
    Xtest = rcopy(R"matrix(runif(10000),nrow=1000)")
    pred = apply_forest_1(cf, Xtest)
    true_effect = sin.(Xtest[:, 1])
    errors_121[i] = rmsd(float(true_effect), pred)
end
# println with labels: bare `print` ran both numbers together into one token.
println("mean RMSE: ", mean(errors_121))
println("variance:  ", var(errors_121))

  29.951 s (6048326 allocations: 1.23 GiB)
  8.261 s (249960520 allocations: 4.19 GiB)
0.0097649919794376532.7775000043447086e-6

# Critère alternatif (différence non centrée) honnête optimisé

### Subsampling

In [15]:
# Optimised alternative (non-centred difference, honest) criterion, subsampling:
# time one forest build and one prediction. Then record the RMSE of the predicted
# effect against sin(X[:, 1]) on 100 test sets drawn with R seeds 1..100 (seed 1 is
# Xtest_1).
errors_11 = zeros(100)

# `$` interpolation keeps @btime from also timing global-variable access.
cf = @btime build_forest_1_opti(false, false, true, $Y, $T, $X, true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000)
pred = @btime apply_forest_1($cf, $Xtest_1)
errors_11[1] = rmsd(float(true_effect_1), pred)
for i in 2:100
    @rput i
    reval("set.seed(i)")
    Xtest = rcopy(R"matrix(runif(10000),nrow=1000)")
    pred = apply_forest_1(cf, Xtest)
    true_effect = sin.(Xtest[:, 1])
    errors_11[i] = rmsd(float(true_effect), pred)
end
# println with labels: bare `print` ran both numbers together into one token.
println("mean RMSE: ", mean(errors_11))
println("variance:  ", var(errors_11))

  6.916 s (2188551 allocations: 794.13 MiB)
  11.585 s (273917525 allocations: 4.54 GiB)
0.0003813293803487854.5860544899894993e-10

In [27]:
# Build one forest (optimised alternative criterion, subsampling) and report, for
# each feature id, the fraction of all splits in the forest that use it. Only
# feature 1 drives the treatment effect, so a good criterion should favour it.
cf = build_forest_1_opti(
    false, false, true,
    Y, T, X,
    true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000,
)

get_freq(cf)

Dict{Any, Float64} with 10 entries:
  5  => 0.00706998
  4  => 0.00654335
  6  => 0.00654335
  7  => 0.00701731
  2  => 0.00662234
  10 => 0.00701731
  9  => 0.00602989
  8  => 0.00700415
  3  => 0.00683299
  1  => 0.939319

### Bootstrap

In [16]:
# Optimised alternative (non-centred difference, honest) criterion, bootstrap:
# time one forest build and one prediction. Then record the RMSE of the predicted
# effect against sin(X[:, 1]) on 100 test sets drawn with R seeds 1..100 (seed 1 is
# Xtest_1).
errors_12 = zeros(100)

# `$` interpolation keeps @btime from also timing global-variable access.
cf = @btime build_forest_1_opti(false, true, true, $Y, $T, $X, true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000)
pred = @btime apply_forest_1($cf, $Xtest_1)
errors_12[1] = rmsd(float(true_effect_1), pred)
for i in 2:100
    @rput i
    reval("set.seed(i)")
    Xtest = rcopy(R"matrix(runif(10000),nrow=1000)")
    pred = apply_forest_1(cf, Xtest)
    true_effect = sin.(Xtest[:, 1])
    errors_12[i] = rmsd(float(true_effect), pred)
end
# println with labels: bare `print` ran both numbers together into one token.
println("mean RMSE: ", mean(errors_12))
println("variance:  ", var(errors_12))

  11.218 s (4178057 allocations: 1.09 GiB)
  8.643 s (281324202 allocations: 4.65 GiB)
0.000244755791987972041.1306716385320434e-10

In [28]:
# Build one forest (optimised alternative criterion, bootstrap) and report, for
# each feature id, the fraction of all splits in the forest that use it. Only
# feature 1 drives the treatment effect, so a good criterion should favour it.
cf = build_forest_1_opti(
    false, true, true,
    Y, T, X,
    true, 10, 500, 0.5, 0.5, -1, 5, 10, 2000,
)

get_freq(cf)

Dict{Any, Float64} with 10 entries:
  5  => 0.0144808
  4  => 0.0132759
  6  => 0.0146973
  7  => 0.0136439
  2  => 0.0143293
  10 => 0.0139469
  9  => 0.0137232
  8  => 0.0143798
  3  => 0.0138531
  1  => 0.87367