# 1. Loading and processing the data

1. 

In [17]:
library(hdm)
library(xtable)
library(glmnet)
library(ggplot2)
library(tidyr)
library(caret)

In [18]:
file<- "https://raw.githubusercontent.com/d2cml-ai/CausalAI-Course/main/data/wage2015_subsample_inference.csv"
data <- read.csv(file)
colnames(data)

In [19]:
data <- subset(data, select = -c(rownames))
dim(data)[2]

In [20]:
features <- data[,!names(data) %in% c('wage','lwage')]
dependent_var <- data['lwage']

In [21]:
features_flexible <- model.matrix(~ 0 + sex + (exp1 + exp2 + exp3 + exp4 + hsg + scl + clg + ad + so + we + ne + 
                                         factor(occ2) + factor(ind2))^2, data = features)

In [22]:
dependent_var<- as.matrix(dependent_var)
features_flexible<- as.matrix(features_flexible)

2.

# 2. Creating Lasso Cross-Validation Procedure 

Log_grid function

In [23]:
log_grid <- function(lower, upper, log_step) {
  log_grid <- seq(lower, upper, length.out = 1 / log_step)
  return(exp(log_grid))
}

K-fold procedure

In [24]:
k_folds <- function(data, k = 5) {
  
  module <- nrow(data) %% k
  floor <- nrow(data) %/% k 
  
  if (module == 0) {
    trues <- matrix(1, nrow = floor, ncol = 1)
    split_matrix <- kronecker(diag(k), trues)
  } else {
    trues_g1 <- matrix(1, nrow = floor + 1, ncol = 1)
    split_matrix_g1 <- kronecker(diag(module), trues_g1)
    
    trues_g2 <- matrix(1, nrow = floor, ncol = 1)
    split_matrix_g2 <- kronecker(diag(k - module), trues_g2)
    
    split_matrix <- rbind(
      cbind(split_matrix_g1, matrix(0, nrow = nrow(split_matrix_g1), ncol = ncol(split_matrix_g2))),
      cbind(matrix(0, nrow = nrow(split_matrix_g2), ncol = ncol(split_matrix_g1)), split_matrix_g2)
    )
  }
  
  sm_bool <- split_matrix == 1
  splits <- lapply(1:k, function(x) sm_bool[, x])
  
  return(splits)
}

Optimal_lambda procedure 

In [25]:
optimal_lambda <- function(Y, X, lambda_bounds, k = 5, niter = 100) {
  
  library(glmnet)
  
  Y <- drop(Y)
  
  if (is.vector(X)) {
    X <- matrix(X, ncol = 1)
  }
  
  folds <- k_folds(X, k)
  all_lambdas <- exp(seq(lambda_bounds[1], lambda_bounds[2], length.out = niter))
  all_mse <- numeric(niter)
  
  for (l in all_lambdas) {
    split_pes <- numeric(k)
    
    for (i in seq_len(k)) {
      X_train <- X[!folds[[i]], ]
      X_test <- X[folds[[i]], ]
      y_train <- Y[!folds[[i]]]
      y_test <- Y[folds[[i]]]
      
      model <- glmnet(X_train, y_train, alpha = 1, lambda = l,standardize = FALSE)
      predict <- predict(model, X_test, s = l)
      
      pe <- sum((y_test - predict)^2)
      split_pes[i] <- pe
    }
    
    all_mse[which(all_lambdas == l)] <- mean(split_pes)
  }
  
  selected <- which.min(all_mse)
  optimal_lambda <- all_lambdas[selected]
  optimal_model <- glmnet(X, Y, alpha = 1, lambda = optimal_lambda)
  optimal_coef <- coef(optimal_model, s = optimal_lambda)
  
  output <- list(
    optimal_lambda = optimal_lambda,
    optimal_coef = optimal_coef,
    all_lambdas = all_lambdas,
    all_mse = all_mse
  )
  
  return(output)
}

Prediction model 

In [26]:
predict_model <- function(optimal_model, X) {
  
  intercept <- matrix(1, nrow = nrow(X), ncol = 1)
  Z <- cbind(intercept, X)
  
  return(Z %*% optimal_model$optimal_coef)
}

## Applying the lasso cross-validation procedure 

In [28]:
library(caTools)

split <- sample.split(dependent_var, SplitRatio = 0.75)
features_flexible_train <- subset(features_flexible, split == TRUE)
features_flexible_test <- subset(features_flexible, split == FALSE)
dependent_train <- dependent_var[split]
dependent_test <- dependent_var[!split]

Ols fitting

In [29]:
ols <- lm(dependent_train ~ ., data = data.frame(dependent_train = dependent_train, features_flexible_train))

Finding optimal lambda

In [30]:
model_lasso <- optimal_lambda(dependent_train, features_flexible_train, c(-7, 7))

In [31]:
model_lasso$optimal_lambda

In [32]:
model_lasso$optimal_coef

981 x 1 sparse Matrix of class "dgCMatrix"
                                         s1
(Intercept)                    2.680768e+00
sex                           -5.731770e-02
exp1                           8.760830e-03
exp2                           .           
exp3                           .           
exp4                           .           
hsg                            .           
scl                            .           
clg                            3.188792e-01
ad                             3.367445e-01
so                             .           
we                             2.956107e-02
ne                             .           
factor(occ2)1                  2.794073e-01
factor(occ2)2                  2.112206e-01
factor(occ2)3                  9.837338e-02
factor(occ2)4                  1.040519e-01
factor(occ2)5                  .           
factor(occ2)6                 -8.393052e-02
factor(occ2)7                  .           
factor(occ2)8                 -6.