In [5]:
using CausalForest
using StatsBase
using Plots
using Random
using Distributions
using RCall
using BenchmarkTools
using DelimitedFiles

In [6]:
# Simulated data set: n = 10^4 rows, m = 10 features drawn i.i.d. from Uniform(0, 10),
# a Bernoulli treatment indicator T (Distributions' default p = 0.5) and an outcome Y.
# The exact sequence of draws from the seeded global RNG below fixes every value that is
# later written to CSV and every downstream result, so the statements must stay in order.
Random.seed!(123);
n, m = 10^4, 10;
u = Uniform(0,10);
features = rand(u, (n, m));
# Alias, not a copy: X and features refer to the same matrix.
X = features;
d = Normal();
# NOTE(review): `eps` is drawn but never used below (Y has no noise term), and the name
# shadows `Base.eps` in this session. The draw still advances the RNG, so deleting it would
# change T and everything downstream — confirm whether noise was meant to be added to Y.
eps = rand(d, n);
b = Bernoulli();
# Bool draws converted to 0/1 integers.
T = convert(Vector{Int64},rand(b, n));
# Y = sin(X1) * T + 2 * X3 + 3 * X7: the treatment effect is sin(X1); X3 and X7 only
# shift the baseline outcome.
Y = sin.(features*[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]).*T  + features*[0, 0, 2, 0, 0, 0, 3, 0, 0, 0];

In [7]:
# Place features, treatment and outcome side by side in one n×(m+2) matrix; T is
# promoted from Int64 to Float64. Despite the name, this is a plain Matrix, not a DataFrame.
df = [X T Y]

10000×12 Matrix{Float64}:
 9.063     0.273863  9.2398     …  8.02953    8.61508   1.0  32.7318
 4.43494   8.65398   4.38668       4.26558    5.79587   0.0  10.2415
 7.45673   7.98071   9.35901       7.38559    9.99516   0.0  27.5137
 5.12083   8.25448   8.79223       0.835565   8.06585   0.0  24.796
 2.53849   5.30423   5.84012       0.570453   9.6668    0.0  22.6625
 3.34152   2.66141   9.8669     …  0.92876    9.49805   1.0  25.1241
 4.27328   8.75436   0.0258112     5.12688    2.24987   1.0   8.09666
 8.67547   7.27313   0.151025      0.0873517  0.765286  0.0  25.0373
 0.991336  1.20752   5.87184       8.38971    9.81301   0.0  20.8175
 1.25287   6.88003   3.64868       4.16181    3.00126   0.0  26.2578
 6.92209   0.298644  4.84843    …  8.07655    1.49476   0.0  36.4211
 1.36551   1.07227   0.570038      2.53062    7.60498   1.0  21.9686
 0.320967  6.45858   4.85281       9.74512    1.98433   0.0  32.6943
 ⋮                              ⋱                       ⋮    
 1.41477   4.12

In [8]:
# Column names, in the same order as the columns of `df` (features, treatment, outcome).
header = ["X1" "X2" "X3" "X4" "X5" "X6" "X7" "X8" "X9" "X10" "T" "Y"]
# Write the header row and then the data rows as comma-separated text to the same stream.
# This produces the same bytes as writing `[header; df]`, without materialising that
# (n+1)×12 Matrix{Any}.
open("sinus_causal_3.csv", "w") do io
    writedlm(io, header, ",")
    writedlm(io, df, ",")
end

In [9]:
# Reference test set drawn with R's RNG (seed 1): a 1000×10 matrix of runif values,
# i.e. U(0, 1) with R's default bounds.
# NOTE(review): the training features are U(0, 10), so these test points only cover the
# [0, 1]^10 corner of the training domain — confirm this mismatch is intended.
reval("set.seed(1)")
Xtest_1 = rcopy(R"matrix(runif(10000),nrow=1000)")
# True treatment effect at each test point: sin of the first feature.
true_effect_1 = sin.(view(Xtest_1, :, 1));

In [10]:
"""
    get_all_nodes_in_tree!(tree, result = [])

Append to `result` the `featid` of every node of `tree` that has one. Nodes are visited
in pre-order: the node itself, then its whole `left` subtree, then its `right` subtree.
A value without a `featid` property (e.g. a leaf) contributes nothing and is not
descended into. Returns `result` itself, which is a fresh `Any[]` when omitted.
"""
function get_all_nodes_in_tree!(tree, result = [])
    pending = Any[tree]
    while !isempty(pending)
        node = pop!(pending)
        hasproperty(node, :featid) || continue
        push!(result, node.featid)
        lchild, rchild = node.left, node.right
        # Push the right child first so the left subtree is popped (visited) first.
        push!(pending, rchild)
        push!(pending, lchild)
    end
    return result
end

"""
    get_freq(forest)

Return, for each feature id, its share of all split nodes pooled over every tree in
`forest.trees`. Each element of `forest.trees` must expose its root node as `.tree`; the
`featid`s are collected with `get_all_nodes_in_tree!` and summarised by `proportionmap`.
"""
function get_freq(forest)
    # Left untyped so the result is still a Dict{Any, Float64}, as callers have seen so far.
    featids = []
    for member in forest.trees
        # Accumulate straight into `featids` instead of building and copying a vector per tree.
        get_all_nodes_in_tree!(member.tree, featids)
    end
    return proportionmap(featids)
end

get_freq (generic function with 1 method)

# Critère alternatif (différence non centrée), version honnête et optimisée

### Subsampling

In [11]:
# Subsampling variant: benchmark forest construction and prediction, then measure the RMSE
# of the predicted treatment effect against the true effect sin(X1) on 100 R-seeded test
# sets. Seed 1 is the reference set `Xtest_1`; seeds 2-100 are drawn in the loop.
errors_11 = zeros(100)

# `@btime` runs the expression more than once and returns the value of the expression.
# `$` interpolates the globals so that global-variable access is not part of the timing.
cf = @btime build_forest_1_opti(false, false, true, $Y, $T, $X, true,
                                10, 500, 0.5, 0.5, -1, 5, 10, 2000)
pred = @btime apply_forest_1($cf, $Xtest_1)
errors_11[1] = rmsd(float(true_effect_1), pred)
for i in 2:100
    # Copy the Julia loop index into R so the R-side draw is seeded with it.
    @rput i
    reval("set.seed(i)")
    Xtest = rcopy(R"matrix(runif(10000),nrow=1000)")
    pred = apply_forest_1(cf, Xtest)
    true_effect = sin.(Xtest[:, 1])
    errors_11[i] = rmsd(float(true_effect), pred)
end
# Use println with labels so the two statistics are not concatenated into one token.
println("mean RMSE: ", mean(errors_11))
println("variance of RMSE: ", var(errors_11))

  7.920 s (2244564 allocations: 799.98 MiB)
  7.755 s (269208485 allocations: 4.47 GiB)
0.79318145606504280.00015793930591946576

In [12]:
# Rebuild the subsampling forest with the same arguments as the benchmarked call, then
# show each feature id's share of all split nodes across the forest.
# NOTE(review): no seed is set before this call, so this is presumably a fresh random
# forest rather than the one evaluated above — confirm whether they should coincide.
cf = build_forest_1_opti(false, false, true, Y, T, X, true,
                         10, 500, 0.5, 0.5, -1, 5, 10, 2000)

get_freq(cf)

Dict{Any, Float64} with 10 entries:
  5  => 0.115038
  4  => 0.112562
  6  => 0.117142
  7  => 0.0158946
  2  => 0.114642
  10 => 0.111077
  9  => 0.115075
  8  => 0.115694
  3  => 0.0661905
  1  => 0.116684

### Bootstrap

In [13]:
# Bootstrap variant (second positional flag set to `true`): benchmark forest construction
# and prediction, then measure the RMSE of the predicted treatment effect against the true
# effect sin(X1) on 100 R-seeded test sets. Seed 1 is the reference set `Xtest_1`;
# seeds 2-100 are drawn in the loop.
errors_12 = zeros(100)

# `@btime` runs the expression more than once and returns the value of the expression.
# `$` interpolates the globals so that global-variable access is not part of the timing.
cf = @btime build_forest_1_opti(false, true, true, $Y, $T, $X, true,
                                10, 500, 0.5, 0.5, -1, 5, 10, 2000)
pred = @btime apply_forest_1($cf, $Xtest_1)
errors_12[1] = rmsd(float(true_effect_1), pred)
for i in 2:100
    # Copy the Julia loop index into R so the R-side draw is seeded with it.
    @rput i
    reval("set.seed(i)")
    Xtest = rcopy(R"matrix(runif(10000),nrow=1000)")
    pred = apply_forest_1(cf, Xtest)
    true_effect = sin.(Xtest[:, 1])
    errors_12[i] = rmsd(float(true_effect), pred)
end
# Use println with labels so the two statistics are not concatenated into one token.
println("mean RMSE: ", mean(errors_12))
println("variance of RMSE: ", var(errors_12))

  20.544 s (4277445 allocations: 1.10 GiB)
  7.787 s (274731398 allocations: 4.56 GiB)
1.1535404443524460.00020294799849312836

In [14]:
# Rebuild the bootstrap forest with the same arguments as the benchmarked call, then
# show each feature id's share of all split nodes across the forest.
# NOTE(review): no seed is set before this call, so this is presumably a fresh random
# forest rather than the one evaluated above — confirm whether they should coincide.
cf = build_forest_1_opti(false, true, true, Y, T, X, true,
                         10, 500, 0.5, 0.5, -1, 5, 10, 2000)

get_freq(cf)

Dict{Any, Float64} with 10 entries:
  5  => 0.114546
  4  => 0.11263
  6  => 0.11441
  7  => 0.0227747
  2  => 0.115654
  10 => 0.112569
  9  => 0.111482
  8  => 0.11261
  3  => 0.0703623
  1  => 0.112963