# Question 2 - DML (Debiased Machine Learning)
## Julia Implementation

This notebook implements Debiased Machine Learning with and without cross-fitting using multiple ML methods.

## 0) Setup and imports

In [None]:
using Pkg

# Install every required package that is missing from the active project.
# Collecting them first lets a single `Pkg.add` call resolve the environment
# once, instead of once per missing package.
required_packages = ["DataFrames", "CSV", "Downloads", "Statistics",
                    "Random", "LinearAlgebra", "GLM", "GLMNet",
                    "DecisionTree", "Flux", "Distributions", "Printf"]

project_deps = Pkg.project().dependencies
packages_to_add = filter(pkg -> !haskey(project_deps, pkg), required_packages)
isempty(packages_to_add) || Pkg.add(packages_to_add)

using DataFrames, CSV, Downloads
using Statistics, Random, LinearAlgebra
using GLM          # For OLS and Logistic Regression
using GLMNet       # For Lasso
using DecisionTree # For Random Forest
using Flux         # For Neural Networks
using Distributions # For Normal distribution
using Printf

Random.seed!(12345)

println("Packages loaded successfully!")

[32m[1m    Updating[22m[39m registry at `~/.julia/registries/General.toml`
[32m[1m   Resolving[22m[39m package versions...
Precompiling CUDA_Driver_jll...
   1337.3 ms  ✓ CUDA_Driver_jll
  1 dependency successfully precompiled in 7 seconds. 25 already precompiled.
Precompiling CUDA_Runtime_jll...
   2545.0 ms  ✓ CUDA_Runtime_jll
  1 dependency successfully precompiled in 8 seconds. 27 already precompiled.
[32m[1m    Updating[22m[39m `~/.julia/environments/v1.11/Project.toml`
  [90m[f43a241f] [39m[92m+ Downloads v1.6.0[39m
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.11/Manifest.toml`
[92m[1mPrecompiling[22m[39m project...
   1479.2 ms[32m  ✓ [39m[90mStructUtils → StructUtilsTablesExt[39m
   1939.7 ms[32m  ✓ [39m[90mComputePipeline[39m
   1970.8 ms[32m  ✓ [39m[90mOpenSSL_jll[39m
   4607.1 ms[32m  ✓ [39m[90mDispatchDoctor[39m
   1691.2 ms[32m  ✓ [39m[90mGraphite2_jll[39m
   2430.2 ms[32m  ✓ [39m[90mLibmount_jll[39m
   2399.7 ms

## 1) Load and Clean Data

In [None]:
# Download the Pennsylvania re-employment bonus data (space-delimited text).
url = "https://raw.githubusercontent.com/CausalAIBook/MetricsMLNotebooks/main/data/penn_jae.dat"
data_path = Downloads.download(url)
DT = CSV.read(data_path, DataFrame, delim=' ', ignorerepeated=true)

# Normalize column names to lowercase. `rename!(f, df)` applies `f` to every
# column name; the pair form of `rename!` expects `old => new`.
rename!(lowercase, DT)

# Keep only the control group (tg == 0) and treatment group 4 (tg == 4).
DT = DT[in.(DT.tg, Ref([0, 4])), :]

# Treatment indicator: 1 for treatment group 4, 0 for control.
DT.T4 = Int.(DT.tg .== 4)

# Outcome: log of the unemployment duration.
DT.y = log.(DT.inuidur1)

# One dummy per value of `dep` (0, 1, 2); dep_0 is the omitted category below.
DT.dep = Int.(DT.dep)
DT.dep_0 = Int.(DT.dep .== 0)
DT.dep_1 = Int.(DT.dep .== 1)
DT.dep_2 = Int.(DT.dep .== 2)

# Build the age dummies from `age` only when the file does not already have them.
if !hasproperty(DT, :agelt35) && hasproperty(DT, :age)
    DT.agelt35 = Int.(DT.age .< 35)
    DT.agegt54 = Int.(DT.age .> 54)
end

# Controls X used by every nuisance model.
x_vars = [:female, :black, :othrace,
          :dep_1, :dep_2,
          :q2, :q3, :q4, :q5, :q6,
          :recall, :agelt35, :agegt54,
          :durable, :nondurable, :lusd, :husd]

# Keep outcome, treatment and controls; drop rows with any missing value.
use_cols = vcat([:y, :T4], x_vars)
DT = DT[:, use_cols]
DT = dropmissing(DT)

# Outcome vector, treatment vector and control matrix.
y = DT.y
d = DT.T4
X = Matrix{Float64}(DT[:, x_vars])
n, p = size(X)

println("Final sample: $n rows, $p predictors.")

## 2) Utility Functions

In [None]:
"""
    rmse(a, b)

Return the root mean squared error `sqrt(mean((a .- b).^2))` between two
vectors of equal length.

Throws `DimensionMismatch` when the lengths differ. Plain broadcasting would
silently expand a length-1 argument instead of failing.
"""
function rmse(a::AbstractVector, b::AbstractVector)
    length(a) == length(b) || throw(DimensionMismatch(
        "rmse: vectors have lengths $(length(a)) and $(length(b))"))
    return sqrt(mean(abs2, a .- b))
end

"""
    plm_theta_se(y_tilde, d_tilde)

Compute the partialling-out estimate of θ in the partially linear model from
outcome residuals `y_tilde` and treatment residuals `d_tilde`:

    θ̂ = Σ d̃ᵢ ỹᵢ / Σ d̃ᵢ²

The standard error comes from the score ψᵢ = (ỹᵢ − θ̂ d̃ᵢ) d̃ᵢ:

    se = sqrt( mean(ψ²) / (n · mean(d̃²)²) )

Returns the NamedTuple `(theta=θ̂, se=se)`.

Throws `DimensionMismatch` if the two vectors differ in length. Throws
`ArgumentError` if `d_tilde` is identically zero, because θ is then not
identified.
"""
function plm_theta_se(y_tilde::AbstractVector, d_tilde::AbstractVector)
    n = length(y_tilde)
    length(d_tilde) == n || throw(DimensionMismatch(
        "y_tilde has length $n but d_tilde has length $(length(d_tilde))"))

    sdd = sum(abs2, d_tilde)
    iszero(sdd) && throw(ArgumentError(
        "d_tilde is identically zero; theta is not identified"))

    theta = dot(d_tilde, y_tilde) / sdd
    psi = (y_tilde .- theta .* d_tilde) .* d_tilde
    # sdd / n == mean(d_tilde .^ 2), the Jacobian of the score.
    se = sqrt(mean(abs2, psi) / (n * (sdd / n)^2))
    return (theta=theta, se=se)
end

## 3) Learners (OLS, Lasso, RF, NN)

In [None]:
# --- OLS and Logistic Regression ---
"""
    fit_y_ols(X, y)

Fit an OLS regression of `y` on every column of `X` plus an intercept.

StatsModels' `@formula` has no `y ~ .` shorthand, so the right-hand side is
built programmatically from the auto-generated column names `x1, x2, …`.
Prediction must use the same `DataFrame(X, :auto)` naming.
"""
function fit_y_ols(X::Matrix, y::Vector)
    df = DataFrame(X, :auto)
    regressors = term.(propertynames(df))
    # With no columns, fall back to an intercept-only model.
    rhs = isempty(regressors) ? term(1) : sum(regressors)
    df.y = y
    return lm(term(:y) ~ rhs, df)
end

"""
Predict the outcome for the rows of `X` from an OLS fit.

`X` is rebuilt as a DataFrame with the auto-generated `x1, x2, …` column names.
"""
pred_y_ols(fit, X::Matrix) = predict(fit, DataFrame(X, :auto))

"""
    fit_d_logit(X, d)

Fit a logistic regression (Binomial family, logit link) of the binary
treatment `d` on every column of `X` plus an intercept.

StatsModels' `@formula` has no `d ~ .` shorthand, so the right-hand side is
built programmatically from the auto-generated column names `x1, x2, …`.
Prediction must use the same `DataFrame(X, :auto)` naming.
"""
function fit_d_logit(X::Matrix, d::Vector)
    df = DataFrame(X, :auto)
    regressors = term.(propertynames(df))
    # With no columns, fall back to an intercept-only model.
    rhs = isempty(regressors) ? term(1) : sum(regressors)
    df.d = d
    return glm(term(:d) ~ rhs, df, Binomial(), LogitLink())
end

"""
Return GLM's `predict` of a logit fit for the rows of `X`.

`X` is rebuilt as a DataFrame with the auto-generated `x1, x2, …` column names.
"""
pred_d_logit(fit, X::Matrix) = predict(fit, DataFrame(X, :auto))

# --- Lasso ---
"""
Fit a cross-validated lasso path (`alpha = 1`) of `y` on `X` with `glmnetcv`.
"""
fit_y_lasso(X::Matrix, y::Vector) = glmnetcv(X, y; alpha=1.0)

"""
Predict the outcome for the rows of `X` from a `glmnetcv` fit.

The result is flattened to a plain vector.
"""
pred_y_lasso(fit, X::Matrix) = GLMNet.predict(fit, X) |> vec

"""
    fit_d_lasso(X, d)

Fit a cross-validated L1-penalized (`alpha = 1`) logistic path of the 0/1
treatment `d` on `X` with `glmnetcv`.

GLMNet's binomial family takes a numeric response as a two-column count
matrix `[failures successes]`. For 0/1 data that matrix is `[1 .- d  d]`.
"""
function fit_d_lasso(X::Matrix, d::Vector)
    d01 = Float64.(d)
    counts = hcat(1.0 .- d01, d01)
    return glmnetcv(X, counts, Binomial(), alpha=1.0)
end

"""
Predict for the rows of `X` from a binomial `glmnetcv` fit with `outtype=:prob`.

The result is flattened to a plain vector.
NOTE(review): presumably these are probabilities of the positive class — confirm.
"""
pred_d_lasso(fit, X::Matrix) = vec(GLMNet.predict(fit, X; outtype=:prob))

# --- Random Forest ---
"""
    fit_y_rf(X, y)

Fit a 1000-tree regression forest of `y` on `X`: no depth limit, at least 5
samples per leaf, seeded with `MersenneTwister(1)`.

DecisionTree's native `build_forest` takes its hyper-parameters positionally
(`n_subfeatures, n_trees, partial_sampling, max_depth, min_samples_leaf, …`).
Only `rng` is a keyword, so `n_trees=1000` as a keyword is rejected.
"""
function fit_y_rf(X::Matrix, y::Vector)
    n_subfeatures = -1      # DecisionTree's documented default value
    n_trees = 1000
    partial_sampling = 0.7  # DecisionTree's documented default sample fraction per tree
    max_depth = -1          # no depth limit
    min_samples_leaf = 5
    return build_forest(y, X,
                        n_subfeatures,
                        n_trees,
                        partial_sampling,
                        max_depth,
                        min_samples_leaf;
                        rng=Random.MersenneTwister(1))
end

"""
Predict the outcome for the rows of `X` with a regression forest, via `apply_forest`.
"""
pred_y_rf(fit, X::Matrix) = apply_forest(fit, X)

"""
    fit_d_rf(X, d)

Fit a 1000-tree classification forest of the treatment `d` on `X`. The labels
are the strings of `d` (`"0"`, `"1"`). The forest has no depth limit, at least
5 samples per leaf, and is seeded with `MersenneTwister(1)`.

DecisionTree's native `build_forest` takes its hyper-parameters positionally
(`n_subfeatures, n_trees, partial_sampling, max_depth, min_samples_leaf, …`).
Only `rng` is a keyword, so `n_trees=1000` as a keyword is rejected.
"""
function fit_d_rf(X::Matrix, d::Vector)
    d_labels = string.(d)
    n_subfeatures = -1      # DecisionTree's documented default value
    n_trees = 1000
    partial_sampling = 0.7  # DecisionTree's documented default sample fraction per tree
    max_depth = -1          # no depth limit
    min_samples_leaf = 5
    return build_forest(d_labels, X,
                        n_subfeatures,
                        n_trees,
                        partial_sampling,
                        max_depth,
                        min_samples_leaf;
                        rng=Random.MersenneTwister(1))
end

"""
Return the forest's class-probability column for label `"1"` for the rows of `X`.

Columns follow the label order `["0", "1"]` passed to `apply_forest_proba`.
"""
function pred_d_rf(fit, X::Matrix)
    label_order = ["0", "1"]
    proba = apply_forest_proba(fit, X, label_order)
    treated_col = 2  # position of "1" in label_order
    return proba[:, treated_col]
end

# --- Neural Network ---
"""
    fit_y_nn(X, y; size=4, max_epochs=500, lr=1e-3)

Fit a one-hidden-layer network predicting `y` from `X`. The hidden layer has
`size` ReLU units and the output is linear. Training runs `max_epochs`
full-batch Adam steps (learning rate `lr`) on the mean-squared error.

Features and target are standardized first. The returned NamedTuple
`(model, X_mean, X_std, y_mean, y_std)` carries the moments needed to
standardize new inputs and rescale predictions.
"""
function fit_y_nn(X::Matrix, y::Vector; size=4, max_epochs=500, lr=1e-3)
    # The `size` keyword (hidden width) shadows `Base.size` inside this body.
    # An unqualified `size(X, 2)` would try to call the integer and throw.
    n_features = Base.size(X, 2)

    # 1e-8 avoids division by zero for constant columns.
    X_mean = mean(X, dims=1)
    X_std = std(X, dims=1) .+ 1e-8
    X_norm = (X .- X_mean) ./ X_std

    y_mean = mean(y)
    y_std = std(y) + 1e-8
    y_norm = (y .- y_mean) ./ y_std

    model = Chain(
        Dense(n_features => size, relu),
        Dense(size => 1)
    )

    # Flux layers consume features × observations.
    X_t = Float32.(X_norm')
    y_t = Float32.(reshape(y_norm, 1, :))

    opt = Flux.setup(Adam(lr), model)
    data = [(X_t, y_t)]  # one full batch per epoch

    for _ in 1:max_epochs
        Flux.train!((m, x, t) -> Flux.mse(m(x), t), model, data, opt)
    end

    return (model=model, X_mean=X_mean, X_std=X_std, y_mean=y_mean, y_std=y_std)
end

"""
Predict the outcome for the rows of `X` with the network returned by `fit_y_nn`.

`X` is standardized with the training moments. The network output is mapped
back to the original scale of `y`.
"""
function pred_y_nn(fit, X::Matrix)
    Z = Float32.(((X .- fit.X_mean) ./ fit.X_std)')
    scaled = vec(fit.model(Z))
    return scaled .* fit.y_std .+ fit.y_mean
end

"""
    fit_d_nn(X, d; size=3, max_epochs=500, lr=1e-3)

Fit a one-hidden-layer network predicting the 0/1 treatment `d` from `X`. The
hidden layer has `size` ReLU units and the output is a sigmoid. Training runs
`max_epochs` full-batch Adam steps (learning rate `lr`) on binary
cross-entropy.

Features are standardized first. The returned NamedTuple
`(model, X_mean, X_std)` carries the moments needed to standardize new inputs.
"""
function fit_d_nn(X::Matrix, d::Vector; size=3, max_epochs=500, lr=1e-3)
    # The `size` keyword (hidden width) shadows `Base.size` inside this body.
    # An unqualified `size(X, 2)` would try to call the integer and throw.
    n_features = Base.size(X, 2)

    # 1e-8 avoids division by zero for constant columns.
    X_mean = mean(X, dims=1)
    X_std = std(X, dims=1) .+ 1e-8
    X_norm = (X .- X_mean) ./ X_std

    model = Chain(
        Dense(n_features => size, relu),
        Dense(size => 1, sigmoid)
    )

    # Flux layers consume features × observations.
    X_t = Float32.(X_norm')
    d_t = Float32.(reshape(d, 1, :))

    opt = Flux.setup(Adam(lr), model)
    data = [(X_t, d_t)]  # one full batch per epoch

    for _ in 1:max_epochs
        Flux.train!((m, x, t) -> Flux.binarycrossentropy(m(x), t), model, data, opt)
    end

    return (model=model, X_mean=X_mean, X_std=X_std)
end

"""
Return the sigmoid outputs of the network from `fit_d_nn` for the rows of `X`.

`X` is standardized with the training moments before it is passed to the
network.
"""
pred_d_nn(fit, X::Matrix) = vec(fit.model(Float32.(((X .- fit.X_mean) ./ fit.X_std)')))

## 4) DML with Cross-Fitting

In [None]:
"""
    dml_plm(y, d, X, K=2; ml_y, ml_d, return_nuisance_rmse=true,
            rng=Random.default_rng())

Compute the double/debiased ML estimate of `θ` in the partially linear model
with `K`-fold cross-fitting. For each fold, the models for `y` and `d` are
fitted on the other `K-1` folds and predicted on the held-out fold.
`plm_theta_se` is then applied to the out-of-fold residuals.

# Arguments
- `ml_y`, `ml_d`: NamedTuples whose `fit(X, v)` field fits a model and whose
  `pred(fit, X)` field returns predictions.
- `rng`: drives the random fold assignment. The default is the task's default
  RNG, so `Random.seed!` still controls it.

# Returns
`(theta, se)`. With `return_nuisance_rmse=true` it also contains `rmse_y` and
`rmse_d`, the averages of the per-fold out-of-fold RMSEs.

# Throws
- `DimensionMismatch` if `y`, `d` and `X` have different numbers of observations.
- `ArgumentError` unless `2 ≤ K ≤ n`.
"""
function dml_plm(y::Vector, d::Vector, X::Matrix, K::Int=2;
                 ml_y::NamedTuple,
                 ml_d::NamedTuple,
                 return_nuisance_rmse::Bool=true,
                 rng::Random.AbstractRNG=Random.default_rng())

    n = length(y)
    (length(d) == n && size(X, 1) == n) ||
        throw(DimensionMismatch("y, d and X must have the same number of observations"))
    2 <= K <= n ||
        throw(ArgumentError("cross-fitting requires 2 <= K <= n (got K=$K, n=$n)"))

    # Balanced fold labels 1..K, randomly permuted across observations.
    folds = shuffle!(rng, repeat(1:K, outer=cld(n, K))[1:n])

    yhat = fill(NaN, n)   # out-of-fold predictions from ml_y
    dhat = fill(NaN, n)   # out-of-fold predictions from ml_d
    rmse_y_folds = Float64[]
    rmse_d_folds = Float64[]

    for k in 1:K
        I_tr = findall(!=(k), folds)
        I_te = findall(==(k), folds)

        fit_y = ml_y.fit(X[I_tr, :], y[I_tr])
        fit_d = ml_d.fit(X[I_tr, :], d[I_tr])

        yhat[I_te] = ml_y.pred(fit_y, X[I_te, :])
        dhat[I_te] = ml_d.pred(fit_d, X[I_te, :])

        if return_nuisance_rmse
            push!(rmse_y_folds, rmse(y[I_te], yhat[I_te]))
            push!(rmse_d_folds, rmse(d[I_te], dhat[I_te]))
        end
    end

    est = plm_theta_se(y .- yhat, d .- dhat)
    result = (theta=est.theta, se=est.se)

    if return_nuisance_rmse
        result = merge(result,
                       (rmse_y=mean(rmse_y_folds),
                        rmse_d=mean(rmse_d_folds)))
    end

    return result
end

## 5) DML WITHOUT Cross-Fitting

In [None]:
"""
    dml_plm_no_cf(y, d, X, K=2; ml_y, ml_d, return_nuisance_rmse=true,
                  rng=Random.default_rng())

Compute the partialling-out estimate of `θ` WITHOUT cross-fitting. Observations
are split at random into `K` folds. Within each fold, the models for `y` and
`d` are fitted and predicted on the *same* observations, so the residuals and
the reported RMSEs are in-sample. `K = 1` fits each model once on the full
sample.

# Arguments
- `ml_y`, `ml_d`: NamedTuples whose `fit(X, v)` field fits a model and whose
  `pred(fit, X)` field returns predictions.
- `rng`: drives the random fold assignment. The default is the task's default
  RNG, so `Random.seed!` still controls it.

# Returns
`(theta, se)`. With `return_nuisance_rmse=true` it also contains `rmse_y` and
`rmse_d`, the averages of the per-fold in-sample RMSEs.

# Throws
- `DimensionMismatch` if `y`, `d` and `X` have different numbers of observations.
- `ArgumentError` unless `1 ≤ K ≤ n`.
"""
function dml_plm_no_cf(y::Vector, d::Vector, X::Matrix, K::Int=2;
                       ml_y::NamedTuple,
                       ml_d::NamedTuple,
                       return_nuisance_rmse::Bool=true,
                       rng::Random.AbstractRNG=Random.default_rng())

    n = length(y)
    (length(d) == n && size(X, 1) == n) ||
        throw(DimensionMismatch("y, d and X must have the same number of observations"))
    1 <= K <= n ||
        throw(ArgumentError("requires 1 <= K <= n (got K=$K, n=$n)"))

    # Balanced fold labels 1..K, randomly permuted across observations.
    folds = shuffle!(rng, repeat(1:K, outer=cld(n, K))[1:n])

    yhat = fill(NaN, n)   # in-sample predictions from ml_y
    dhat = fill(NaN, n)   # in-sample predictions from ml_d
    rmse_y_folds = Float64[]
    rmse_d_folds = Float64[]

    for k in 1:K
        I_k = findall(==(k), folds)

        # Fit and predict on the same observations: no cross-fitting.
        fit_y = ml_y.fit(X[I_k, :], y[I_k])
        fit_d = ml_d.fit(X[I_k, :], d[I_k])

        yhat[I_k] = ml_y.pred(fit_y, X[I_k, :])
        dhat[I_k] = ml_d.pred(fit_d, X[I_k, :])

        if return_nuisance_rmse
            push!(rmse_y_folds, rmse(y[I_k], yhat[I_k]))
            push!(rmse_d_folds, rmse(d[I_k], dhat[I_k]))
        end
    end

    est = plm_theta_se(y .- yhat, d .- dhat)
    result = (theta=est.theta, se=est.se)

    if return_nuisance_rmse
        result = merge(result,
                       (rmse_y=mean(rmse_y_folds),
                        rmse_d=mean(rmse_d_folds)))
    end

    return result
end

## 6) Execute: CF and No-CF with 4 models

In [None]:
# Registry of learners. Each entry pairs an outcome model (`ml_y`, fitted on y)
# with a treatment model (`ml_d`, fitted on d). Both expose `fit` and `pred`.
learners = let
    pairing(fy, py, fd, pd) = (ml_y = (fit=fy, pred=py), ml_d = (fit=fd, pred=pd))
    Dict(
        "OLS+LOGIT" => pairing(fit_y_ols, pred_y_ols, fit_d_logit, pred_d_logit),
        "LASSO"     => pairing(fit_y_lasso, pred_y_lasso, fit_d_lasso, pred_d_lasso),
        "RF"        => pairing(fit_y_rf, pred_y_rf, fit_d_rf, pred_d_rf),
        "NN"        => pairing(fit_y_nn, pred_y_nn, fit_d_nn, pred_d_nn)
    )
end

"""
Run the estimator `fun` once per entry of `learners` and collect the results.

Before each run the global RNG is reseeded with 42, so RNG-driven steps inside
`fun` (such as a fold shuffle) are reproducible across learners. Returns a
DataFrame with one row per learner: `Method`, `theta`, `se`, the two-sided
normal `pval`, and the nuisance RMSEs `rmse_y` and `rmse_d`.
"""
function run_block(fun::Function, y::Vector, d::Vector, X::Matrix,
                   K::Int, learners::Dict)
    rows = DataFrame(Method=String[], theta=Float64[], se=Float64[],
                     pval=Float64[], rmse_y=Float64[], rmse_d=Float64[])
    z_dist = Normal(0, 1)

    for name in keys(learners)
        spec = learners[name]
        println("Running $name...")
        Random.seed!(42)

        est = fun(y, d, X, K; ml_y=spec.ml_y, ml_d=spec.ml_d)
        z = est.theta / est.se

        push!(rows, (Method=name,
                     theta=est.theta,
                     se=est.se,
                     pval=2 * cdf(z_dist, -abs(z)),
                     rmse_y=est.rmse_y,
                     rmse_d=est.rmse_d))
    end

    return rows
end

# Run cross-fitted DML for every learner, then the no-cross-fitting variant
# with the same K, and stack both tables sorted by CrossFitting and Method.
K = 2

println("="^70, "\nRunning DML with Cross-Fitting...\n", "="^70)
tab_cf = run_block(dml_plm, y, d, X, K, learners)
tab_cf[!, :CrossFitting] .= "Yes"

println("\n", "="^70, "\nRunning DML WITHOUT Cross-Fitting...\n", "="^70)
tab_nocf = run_block(dml_plm_no_cf, y, d, X, K, learners)
tab_nocf[!, :CrossFitting] .= "No"

results_all = sort!(vcat(tab_cf, tab_nocf), [:CrossFitting, :Method])

println("\n", "="^70, "\nALL RESULTS\n", "="^70)
println(results_all)

## 7) OLS with Controls as Benchmark

In [None]:
# Benchmark: OLS of y on d plus the same controls used by DML (no ML, no
# cross-fitting). The nuisance RMSEs reported here are in-sample, from an OLS
# model for y and a logit model for d.
df_full = DataFrame(X, x_vars)
df_full.y = y
df_full.d = d

# Build the formulas from `x_vars` so the benchmark uses exactly the DML
# covariate set; `d` is the first regressor.
ols_full = lm(term(:y) ~ sum(term.(vcat(:d, x_vars))), df_full)

# `coefnames` returns Strings, so the treatment coefficient is looked up as "d".
ols_coef_table = coeftable(ols_full)
d_idx = findfirst(==("d"), coefnames(ols_full))
isnothing(d_idx) && error("treatment coefficient `d` not found in the OLS fit")

theta_ols_controls = coef(ols_full)[d_idx]
se_ols_controls = stderror(ols_full)[d_idx]
pval_ols_controls = ols_coef_table.cols[ols_coef_table.pvalcol][d_idx]

# In-sample RMSE of the outcome regression.
y_pred_ols = predict(ols_full, df_full)
rmse_y_ols = rmse(y, y_pred_ols)

# Logistic model for d on the controls, and its in-sample RMSE.
logit_d = glm(term(:d) ~ sum(term.(x_vars)), df_full, Binomial(), LogitLink())
d_pred_logit = predict(logit_d, df_full)
rmse_d_ols = rmse(Float64.(d), d_pred_logit)

# One-row table matching the columns of `results_all`.
ols_row = DataFrame(
    Method = ["OLS with controls"],
    theta = [theta_ols_controls],
    se = [se_ols_controls],
    pval = [pval_ols_controls],
    rmse_y = [rmse_y_ols],
    rmse_d = [rmse_d_ols],
    CrossFitting = ["N/A"]
)

# Drop a benchmark row left by a previous run of this cell before appending.
results_all = vcat(filter(row -> row.CrossFitting != "N/A", results_all), ols_row)
sort!(results_all, [:CrossFitting, :Method])

println("OLS with controls added to results")

## 8) Model Selection (CF) and Final Estimation

In [None]:
# Rank the cross-fitted learners by the standard error of θ and keep the
# first row, i.e. the most precise estimate.
tab_cf_sorted = sort(tab_cf, :se)
best_cf = first(tab_cf_sorted)

println("\nBest model (smallest SE with cross-fitting):")
println(best_cf)

"""
Re-run cross-fitted DML (`dml_plm`) with the learner registered as
`method_name` in `learners`.

Prints θ, its SE and the two-sided normal p-value, and returns the estimate
NamedTuple from `dml_plm`.
"""
function run_final(method_name::String, learners::Dict, y::Vector, d::Vector,
                   X::Matrix, K::Int)
    spec = learners[method_name]
    out = dml_plm(y, d, X, K; ml_y=spec.ml_y, ml_d=spec.ml_d)
    pval = 2 * cdf(Normal(0, 1), -abs(out.theta / out.se))

    println("\nFinal DML (CF) with $method_name")
    @printf("theta=%.4f, se=%.4f, pval=%.4g\n", out.theta, out.se, pval)

    return out
end

# Uncomment to run final estimation:
# final_fit = run_final(best_cf.Method, learners, y, d, X, K)

## 9) Print Readable Tables

In [None]:
"""
Print `tab` under `title`, framed by 70-character dashed rules.

Shows the columns CrossFitting, Method, theta, se, pval, rmse_y and rmse_d.
θ, SE and the RMSEs are rounded to 4 decimals; p-values are rounded to 3
significant digits.
"""
function print_table(tab::DataFrame, title::String)
    rule = "-"^70
    four_dp = x -> round.(x, digits=4)
    three_sig = x -> round.(x, sigdigits=3)

    shown = select(tab, :CrossFitting, :Method,
                   :theta => four_dp => :theta,
                   :se => four_dp => :se,
                   :pval => three_sig => :pval,
                   :rmse_y => four_dp => :rmse_y,
                   :rmse_d => four_dp => :rmse_d)

    println("\n $title ")
    println(rule)
    println(shown)
    println(rule)
    return nothing
end

# Table A: cross-fitted rows. Table B: rows without cross-fitting.
# The appendix prints every row, including the OLS benchmark.
for (flag, title) in (("Yes", "Table A. DML con cross-fitting"),
                      ("No", "Table B. DML sin cross-fitting"))
    print_table(filter(:CrossFitting => ==(flag), results_all), title)
end
print_table(results_all,
            "Appendix. Todos los modelos (incluye OLS con controles)")

## 10) Answers

### PLM and DML
We estimate the partially linear model:
$$y = \theta d + g_0(X) + \varepsilon, \quad d = m_0(X) + \nu$$

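DML uses cross-fitting to build out-of-sample residuals. Here $\ell_0(X) = E[y \mid X] = \theta\, m_0(X) + g_0(X)$ is the conditional mean of the outcome and $m_0(X) = E[d \mid X]$ is the propensity score:
$$\tilde{y} = y - \hat{\ell}(X), \; \tilde{d} = d - \hat{m}(X)$$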

and estimates:
$$\hat{\theta} = \frac{\sum_i \tilde{d}_i \tilde{y}_i}{\sum_i \tilde{d}_i^2}$$

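The standard errors are based on the influence function of the orthogonal score $\psi_i = (\tilde{y}_i - \hat{\theta}\,\tilde{d}_i)\,\tilde{d}_i$:
$$\widehat{\mathrm{se}}(\hat{\theta}) = \sqrt{\frac{\frac{1}{n}\sum_i \psi_i^2}{n\left(\frac{1}{n}\sum_i \tilde{d}_i^2\right)^2}}$$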

### Cross-fitting vs no cross-fitting
- RMSE for predicting $y$ and $d$ is usually **smaller** without cross-fitting due to in-sample optimism.
- Lower RMSE there does **not** mean better causal inference; it reflects **overfitting** of nuisances.
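- Without cross-fitting, the overfitting bias of the nuisance estimates leaks into $\hat{\theta}$, producing **bias** and **anti-conservative (overconfident) inference**. Neyman orthogonality protects against regularization bias; cross-fitting is what removes this overfitting bias.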

### Selected model
Choose the CF method with the smallest SE in Table A and report its $\hat{\theta}$ as the final effect.