# Double/Debiased Machine Learning for the Partially Linear Regression Model.

This is a simple implementation of Debiased Machine Learning for the Partially Linear Regression Model.

Reference: 

https://arxiv.org/abs/1608.00060


https://www.amazon.com/Business-Data-Science-Combining-Accelerate/dp/1260452778

The code is based on the book.

# DML algorithm

Here we perform estimation and inference of predictive coefficient $\alpha$ in the partially linear statistical model, 
$$
Y = D\alpha + g(X) + U, \quad E (U | D, X) = 0. 
$$
For $\tilde Y = Y- E(Y|X)$ and $\tilde D= D- E(D|X)$, we can write
$$
\tilde Y = \alpha \tilde D + U, \quad E (U |\tilde D) =0.
$$
Parameter $\alpha$ is then estimated using cross-fitting approach to obtain the residuals $\tilde D$ and $\tilde Y$.
The algorithm comsumes $Y, D, X$, and machine learning methods for learning the residuals $\tilde Y$ and $\tilde D$, where
the residuals are obtained by cross-validation (cross-fitting).

The statistical parameter $\alpha$ has a causal intertpreation of being the effect of $D$ on $Y$ in the causal DAG $$ D\to Y, \quad X\to (D,Y)$$ or the counterfactual outcome model with conditionally exogenous (conditionally random) assignment of treatment $D$ given $X$:
$$
Y(d) = d\alpha + g(X) + U(d),\quad  U(d) \text{ indep } D |X, \quad Y = Y(D), \quad U = U(D).
$$


In [1]:
install.packages("hdm")
install.packages("AER")
install.packages("randomForest")



Installing package into 'C:/Users/Anzony/Documents/R/win-library/4.0'
(as 'lib' is unspecified)



package 'hdm' successfully unpacked and MD5 sums checked

The downloaded binary packages are in
	C:\Users\Anzony\AppData\Local\Temp\RtmpwVuAEC\downloaded_packages


Installing package into 'C:/Users/Anzony/Documents/R/win-library/4.0'
(as 'lib' is unspecified)



package 'AER' successfully unpacked and MD5 sums checked

The downloaded binary packages are in
	C:\Users\Anzony\AppData\Local\Temp\RtmpwVuAEC\downloaded_packages


Installing package into 'C:/Users/Anzony/Documents/R/win-library/4.0'
(as 'lib' is unspecified)



package 'randomForest' successfully unpacked and MD5 sums checked

The downloaded binary packages are in
	C:\Users\Anzony\AppData\Local\Temp\RtmpwVuAEC\downloaded_packages


In [12]:
getAnywhere( vcovHC.default )

A single object matching 'vcovHC.default' was found
It was found in the following places
  package:sandwich
  registered S3 method for vcovHC from namespace sandwich
  namespace:sandwich
with value

function (x, type = c("HC3", "const", "HC", "HC0", "HC1", "HC2", 
    "HC4", "HC4m", "HC5"), omega = NULL, sandwich = TRUE, ...) 
{
    type <- match.arg(type)
    rval <- meatHC(x, type = type, omega = omega)
    if (sandwich) 
        rval <- sandwich(x, meat. = rval, ...)
    return(rval)
}
<bytecode: 0x000000000d0968e8>
<environment: namespace:sandwich>

In [2]:
DML2.for.PLM <- function(x, d, y, dreg, yreg, nfold=2) {
  nobs <- nrow(x) #number of observations
  foldid <- rep.int(1:nfold,times = ceiling(nobs/nfold))[sample.int(nobs)] #define folds indices
  I <- split(1:nobs, foldid)  #split observation indices into folds  
  ytil <- dtil <- rep(NA, nobs)
  cat("fold: ")
  for(b in 1:length(I)){
    dfit <- dreg(x[-I[[b]],], d[-I[[b]]]) #take a fold out
    yfit <- yreg(x[-I[[b]],], y[-I[[b]]]) # take a foldt out
    dhat <- predict(dfit, x[I[[b]],], type="response") #predict the left-out fold 
    yhat <- predict(yfit, x[I[[b]],], type="response") #predict the left-out fold 
    dtil[I[[b]]] <- (d[I[[b]]] - dhat) #record residual for the left-out fold
    ytil[I[[b]]] <- (y[I[[b]]] - yhat) #record residial for the left-out fold
    cat(b," ")
        }
  rfit <- lm(ytil ~ dtil)    #estimate the main parameter by regressing one residual on the other
  coef.est <- coef(rfit)[2]  #extract coefficient
  se <- sqrt(vcovHC(rfit)[2,2]) #record robust standard error
  cat(sprintf("\ncoef (se) = %g (%g)\n", coef.est , se))  #printing output
  return( list(coef.est =coef.est , se=se, dtil=dtil, ytil=ytil) ) #save output and residuals 
}


In [3]:
library(AER)  #applied econometrics library
library(randomForest)  #random Forest library
library(hdm) #high-dimensional econometrics library
library(glmnet) #glm net

data(GrowthData)  

"package 'AER' was built under R version 4.0.5"
Loading required package: car

Loading required package: carData

Loading required package: lmtest

Loading required package: zoo


Attaching package: 'zoo'


The following objects are masked from 'package:base':

    as.Date, as.Date.numeric


Loading required package: sandwich

Loading required package: survival

"package 'randomForest' was built under R version 4.0.5"
randomForest 4.6-14

Type rfNews() to see new features/changes/bug fixes.

"package 'hdm' was built under R version 4.0.5"
Loading required package: Matrix

Loaded glmnet 4.1-1



In [5]:
save(GrowthData, file = "../data/GrowthData.RData")

In [16]:
vcovHC(lm(y~d +x) , type = "HC3")

Unnamed: 0,(Intercept),d,xbmp1l,xfreeop,xfreetar,xh65,xhm65,xhf65,xp65,xpm65,...,xseccf65,xsyr65,xsyrm65,xsyrf65,xteapri65,xteasec65,xex1,xim1,xxr65,xtot1
(Intercept),2.46214926,0.0270541713,0.0157779680,0.1027024067,0.557035170,-0.1215138603,0.256464780,-0.092400907,0.1504513461,0.087924556,...,-0.1978348527,0.26989229,-0.11009350,-0.19524540,-1.754018e-04,-4.258588e-04,-0.202962635,1.657021e-01,2.259938e-05,0.0564987136
d,0.02705417,0.0044800846,0.0003153135,0.0005653200,0.019148431,-0.0241453264,0.011345831,0.011038354,0.0029726754,-0.003088819,...,0.0245318252,0.41945332,-0.19476160,-0.22928603,1.659792e-05,-4.446398e-06,-0.029849684,2.648345e-02,-7.033014e-07,-0.0046631461
xbmp1l,0.01577797,0.0003153135,0.0085879910,0.0111782580,-0.013774437,0.0007266886,0.007206927,-0.014290817,-0.0015134840,-0.002554885,...,-0.0005260561,-0.17102401,0.07229316,0.10096386,-6.354781e-05,-3.648880e-05,-0.003660238,-2.133201e-03,9.546143e-07,0.0034978142
xfreeop,0.10270241,0.0005653201,0.0111782580,0.1586291725,0.067099435,0.2347979290,-0.189898773,-0.060734685,-0.0015130036,0.024534263,...,0.0098294467,-2.33832928,1.15806300,1.21986139,-1.766267e-04,1.269197e-04,0.022305594,-8.628802e-02,1.337824e-06,0.0054039582
xfreetar,0.55703517,0.0191484312,-0.0137744368,0.0670994352,0.831327004,0.1511399959,-0.152070056,-0.028878568,0.0406819118,-0.036140920,...,0.0678272216,2.05286070,-0.95259312,-1.05666438,1.107170e-04,7.695293e-05,-0.156290228,1.371161e-01,1.999134e-05,0.0022514956
xh65,-0.12151386,-0.0241453264,0.0007266887,0.2347979296,0.151139996,2.0019286089,-0.974352440,-1.018528829,0.0269664951,-0.009667150,...,-0.1825324097,-6.14858199,3.02934239,3.13691932,8.316001e-07,5.296413e-04,0.249786583,-3.624705e-01,1.557314e-05,0.0176571631
xhm65,0.25646478,0.0113458312,0.0072069269,-0.1898987730,-0.152070056,-0.9743524390,0.777422356,0.249693960,0.0163759741,-0.027339512,...,0.0187626322,5.04080733,-2.49037734,-2.63039854,-9.839973e-05,-4.932112e-04,-0.109973917,1.780032e-01,-3.795330e-06,0.0125707011
xhf65,-0.09240091,0.0110383538,-0.0142908170,-0.0607346854,-0.028878568,-1.0185288290,0.249693960,0.804758811,-0.0337330032,0.046786871,...,0.0918863158,0.62483632,-0.28901434,-0.28947457,8.151681e-05,2.577723e-05,-0.123574636,1.676842e-01,-1.984981e-05,-0.0205128315
xp65,0.15045135,0.0029726754,-0.0015134840,-0.0015130036,0.040681912,0.0269664954,0.016375974,-0.033733003,0.1066512178,-0.041519671,...,0.0225785615,1.56315870,-0.74827828,-0.81593144,1.156964e-04,1.584556e-04,-0.022412969,2.100140e-02,4.149174e-06,-0.0138762881
xpm65,0.08792456,-0.0030888188,-0.0025548851,0.0245342626,-0.036140920,-0.0096671498,-0.027339512,0.046786870,-0.0415196713,0.087289895,...,-0.0513284224,-1.97172214,0.94738736,1.03976189,-8.928399e-05,-1.493732e-05,0.045773714,-4.531569e-02,-4.372326e-06,0.0022059369


In [13]:
vcovHC(lm(y~d +x))

Unnamed: 0,(Intercept),d,xbmp1l,xfreeop,xfreetar,xh65,xhm65,xhf65,xp65,xpm65,...,xseccf65,xsyr65,xsyrm65,xsyrf65,xteapri65,xteasec65,xex1,xim1,xxr65,xtot1
(Intercept),2.46214926,0.0270541713,0.0157779680,0.1027024067,0.557035170,-0.1215138603,0.256464780,-0.092400907,0.1504513461,0.087924556,...,-0.1978348527,0.26989229,-0.11009350,-0.19524540,-1.754018e-04,-4.258588e-04,-0.202962635,1.657021e-01,2.259938e-05,0.0564987136
d,0.02705417,0.0044800846,0.0003153135,0.0005653200,0.019148431,-0.0241453264,0.011345831,0.011038354,0.0029726754,-0.003088819,...,0.0245318252,0.41945332,-0.19476160,-0.22928603,1.659792e-05,-4.446398e-06,-0.029849684,2.648345e-02,-7.033014e-07,-0.0046631461
xbmp1l,0.01577797,0.0003153135,0.0085879910,0.0111782580,-0.013774437,0.0007266886,0.007206927,-0.014290817,-0.0015134840,-0.002554885,...,-0.0005260561,-0.17102401,0.07229316,0.10096386,-6.354781e-05,-3.648880e-05,-0.003660238,-2.133201e-03,9.546143e-07,0.0034978142
xfreeop,0.10270241,0.0005653201,0.0111782580,0.1586291725,0.067099435,0.2347979290,-0.189898773,-0.060734685,-0.0015130036,0.024534263,...,0.0098294467,-2.33832928,1.15806300,1.21986139,-1.766267e-04,1.269197e-04,0.022305594,-8.628802e-02,1.337824e-06,0.0054039582
xfreetar,0.55703517,0.0191484312,-0.0137744368,0.0670994352,0.831327004,0.1511399959,-0.152070056,-0.028878568,0.0406819118,-0.036140920,...,0.0678272216,2.05286070,-0.95259312,-1.05666438,1.107170e-04,7.695293e-05,-0.156290228,1.371161e-01,1.999134e-05,0.0022514956
xh65,-0.12151386,-0.0241453264,0.0007266887,0.2347979296,0.151139996,2.0019286089,-0.974352440,-1.018528829,0.0269664951,-0.009667150,...,-0.1825324097,-6.14858199,3.02934239,3.13691932,8.316001e-07,5.296413e-04,0.249786583,-3.624705e-01,1.557314e-05,0.0176571631
xhm65,0.25646478,0.0113458312,0.0072069269,-0.1898987730,-0.152070056,-0.9743524390,0.777422356,0.249693960,0.0163759741,-0.027339512,...,0.0187626322,5.04080733,-2.49037734,-2.63039854,-9.839973e-05,-4.932112e-04,-0.109973917,1.780032e-01,-3.795330e-06,0.0125707011
xhf65,-0.09240091,0.0110383538,-0.0142908170,-0.0607346854,-0.028878568,-1.0185288290,0.249693960,0.804758811,-0.0337330032,0.046786871,...,0.0918863158,0.62483632,-0.28901434,-0.28947457,8.151681e-05,2.577723e-05,-0.123574636,1.676842e-01,-1.984981e-05,-0.0205128315
xp65,0.15045135,0.0029726754,-0.0015134840,-0.0015130036,0.040681912,0.0269664954,0.016375974,-0.033733003,0.1066512178,-0.041519671,...,0.0225785615,1.56315870,-0.74827828,-0.81593144,1.156964e-04,1.584556e-04,-0.022412969,2.100140e-02,4.149174e-06,-0.0138762881
xpm65,0.08792456,-0.0030888188,-0.0025548851,0.0245342626,-0.036140920,-0.0096671498,-0.027339512,0.046786870,-0.0415196713,0.087289895,...,-0.0513284224,-1.97172214,0.94738736,1.03976189,-8.928399e-05,-1.493732e-05,0.045773714,-4.531569e-02,-4.372326e-06,0.0022059369


In [27]:
x1=x[1:50, ]

In [28]:
y1 = y[1:50]

In [32]:
help(glmnet)

In [29]:
model = glmnet(x1, y1 , lambda = 0)

In [None]:
predict( model , x , type = 'response')

In [33]:
library(AER)  #applied econometrics library
library(randomForest)  #random Forest library
library(hdm) #high-dimensional econometrics library
library(glmnet) #glm net

data(GrowthData)                     # Barro-Lee growth data
y= as.matrix(GrowthData[,1])         # outcome: growth rate
d= as.matrix(GrowthData[,3])         # treatment: initial wealth
x= as.matrix(GrowthData[,-c(1,2,3)]) # controls: country characteristics

cat(sprintf("\n length of y is %g \n", length(y) ))
cat(sprintf("\n num features x is %g \n", dim(x)[2] ))


#summary(y)
#summary(d)
#summary(x)

cat(sprintf("\n Naive OLS that uses all features w/o cross-fitting \n"))
lres=summary(lm(y~d +x))$coef[2,1:2]

cat(sprintf("\ncoef (se) = %g (%g)\n", lres[1] , lres[2]))



#DML with OLS


cat(sprintf("\n DML with OLS w/o feature selection \n"))

set.seed(1)
dreg <- function(x,d){ glmnet(x, d, lambda = 0) } #ML method= OLS using glmnet; using lm gives bugs
yreg <- function(x,y){ glmnet(x, y, lambda = 0) } #ML method = OLS

DML2.OLS = DML2.for.PLM(x, d, y, dreg, yreg, nfold=10)


#DML with Lasso:

cat(sprintf("\n DML with Lasso \n"))

set.seed(1)
dreg <- function(x,d){ rlasso(x,d, post=FALSE) } #ML method= lasso from hdm 
yreg <- function(x,y){ rlasso(x,y, post=FALSE) } #ML method = lasso from hdm
DML2.lasso = DML2.for.PLM(x, d, y, dreg, yreg, nfold=10)

cat(sprintf("\n DML with Random Forest \n"))

#DML with Random Forest:
dreg <- function(x,d){ randomForest(x, d) } #ML method=Forest 
yreg <- function(x,y){ randomForest(x, y) } #ML method=Forest
set.seed(1)
DML2.RF = DML2.for.PLM(x, d, y, dreg, yreg, nfold=10)


cat(sprintf("\n DML with Lasso/Random Forest \n"))

#DML MIX:
dreg <- function(x,d){ rlasso(x,d, post=FALSE) } #ML method=Forest 
yreg <- function(x,y){ randomForest(x, y) } #ML method=Forest
set.seed(1)
DML2.RF = DML2.for.PLM(x, d, y, dreg, yreg, nfold=10)



 length of y is 90 

 num features x is 60 

 Naive OLS that uses all features w/o cross-fitting 

coef (se) = -0.00937799 (0.0298877)

 DML with OLS w/o feature selection 
fold: 1  2  3  4  5  6  7  8  9  10  
coef (se) = 0.01013 (0.0167061)

 DML with Lasso 
fold: 1  2  3  4  5  6  7  8  9  10  
coef (se) = -0.0352021 (0.0161357)

 DML with Random Forest 
fold: 1  2  3  4  5  6  7  8  9  10  
coef (se) = -0.0365831 (0.0151539)

 DML with Lasso/Random Forest 
fold: 1  2  3  4  5  6  7  8  9  10  
coef (se) = -0.0375019 (0.0135088)


In [37]:
if (replace) nrow(x) else ceiling(.632*nrow(x))

ERROR: Error in if (replace) nrow(x) else ceiling(0.632 * nrow(x)): argument is not interpretable as logical


In [36]:
if (!is.null(y) && !is.factor(y)) 5 else 1

In [35]:
if (!is.null(y) && !is.factor(y))
             max(floor(ncol(x)/3), 1) else floor(sqrt(ncol(x)))

In [4]:
prRes.D<- c( mean((DML2.OLS$dtil)^2), mean((DML2.lasso$dtil)^2), mean((DML2.RF$dtil)^2));
prRes.Y<- c(mean((DML2.OLS$ytil)^2), mean((DML2.lasso$ytil)^2),mean((DML2.RF$ytil)^2));
prRes<- rbind(sqrt(prRes.D), sqrt(prRes.Y)); 
rownames(prRes)<- c("RMSE D", "RMSE Y");
colnames(prRes)<- c("OLS", "Lasso", "RF")
print(prRes,digit=2)


         OLS Lasso    RF
RMSE D 0.467 0.372 0.372
RMSE Y 0.054 0.052 0.046
