
JackDunnNZ/PrescriptionUtils.jl


PrescriptionUtils

This package provides utilities for evaluating solutions to prescriptive problems, where the task is to choose a treatment for each observation so as to obtain the best outcome.

Data

The package assumes your data is in the following format:

  • X: A Matrix or DataFrame where each row contains the covariates for one observation.
  • T: A Vector{Int} giving the treatment applied to each observation. We assume the treatments are labelled as integers from 1 to the number of treatments.
  • y: A Vector{Float64} giving the outcome for each observation. We adopt the convention that smaller outcomes are better.
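As a quick sanity check of these conventions, the sketch below builds a tiny toy dataset and validates it. The helper check_prescription_data is hypothetical (it is not part of PrescriptionUtils.jl), handles only the Matrix case for X, and assumes that every treatment label must be a positive integer.

```julia
# Hypothetical helper (not part of PrescriptionUtils.jl): checks that X, T and y
# follow the data conventions described above and returns the number of treatments.
function check_prescription_data(X::AbstractMatrix, T::AbstractVector{<:Integer},
                                 y::AbstractVector{<:Real})
    n = size(X, 1)
    length(T) == n || throw(ArgumentError("T must have one entry per row of X"))
    length(y) == n || throw(ArgumentError("y must have one entry per row of X"))
    # Treatments are assumed to be labelled 1, 2, ..., number of treatments.
    all(t -> t >= 1, T) || throw(ArgumentError("treatment labels must start at 1"))
    return maximum(T)
end

# Toy data: 6 observations, 2 covariates, 3 treatments (smaller y is better).
X = [0.1 1.0; 0.4 0.8; 0.9 0.2; 0.5 0.5; 0.3 0.7; 0.8 0.1]
T = [1, 2, 3, 1, 2, 3]
y = [3.2, 2.9, 1.5, 3.0, 2.7, 1.8]
check_prescription_data(X, T, y)   # returns 3
```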

Counterfactual estimation

For the test set, the counterfactuals can be imputed with the following function:

cf = getcounterfactuals(X, y, T, impute_method)

X, y, and T are the test data, and impute_method specifies the regression method to use for counterfactual estimation (one of :knn, :random_forest, or :lasso).

cf is a Matrix{Float64} containing the estimated outcome for each observation under each treatment.
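To illustrate what a counterfactual matrix contains, here is one possible kNN imputation: for each treatment, average the outcomes of the k nearest observations (by Euclidean distance) that actually received that treatment. This is a sketch under stated assumptions, not necessarily how getcounterfactuals implements :knn; the function knn_counterfactuals and its keyword k are hypothetical, and an observation is allowed to be its own neighbour.

```julia
# Hypothetical sketch: per-treatment kNN imputation of counterfactual outcomes.
# Assumes Euclidean distance and an unweighted mean of the k nearest neighbours.
function knn_counterfactuals(X::AbstractMatrix, y::AbstractVector,
                             T::AbstractVector{<:Integer}; k::Int = 2)
    n = size(X, 1)
    num_treatments = maximum(T)
    cf = fill(NaN, n, num_treatments)      # NaN marks treatments with no data
    for t in 1:num_treatments
        idx = findall(==(t), T)            # observations that received treatment t
        isempty(idx) && continue
        kk = min(k, length(idx))
        for i in 1:n
            # Squared Euclidean distance from observation i to each treated observation.
            d = [sum(abs2, view(X, i, :) .- view(X, j, :)) for j in idx]
            nearest = idx[partialsortperm(d, 1:kk)]
            cf[i, t] = sum(y[nearest]) / kk
        end
    end
    return cf
end

X = reshape([0.0, 1.0, 10.0, 11.0], :, 1)
T = [1, 1, 2, 2]
y = [1.0, 3.0, 5.0, 7.0]
knn_counterfactuals(X, y, T; k = 1)   # [1.0 5.0; 3.0 5.0; 3.0 5.0; 3.0 7.0]
```

Each row is an observation and each column a treatment, so cf[i, t] is the estimated outcome of observation i had it received treatment t.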

Evaluation

Baseline evaluation

Evaluates the outcomes of the treatments actually assigned in the data X, y and T using the counterfactuals cf:

baseline_outcomes = evaluatebaseline(cf, X, y, T)

baseline_outcomes is a Vector{Float64} containing the predicted outcome for each observation.

Oracle evaluation

Evaluates the outcomes of a clairvoyant oracle, which prescribes the treatment with the best (smallest) estimated outcome, using the counterfactuals cf:

oracle_outcomes, oracle_prescriptions = evaluateoracle(cf, allowed_prescriptions)

oracle_outcomes is a Vector{Float64} containing the predicted outcome for each observation. oracle_prescriptions is a Vector{Int} containing the prescribed treatment for each observation.

allowed_prescriptions is an optional argument specifying, for each observation, the set of treatments that may be prescribed, so that the oracle respects any prescription rules. If it is not specified, all treatments are assumed to be available for all observations.
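The oracle idea can be sketched as a row-wise minimum over the allowed columns of cf. The function oracle_sketch is hypothetical, and it assumes allowed_prescriptions is a vector holding one collection of allowed treatment labels per observation, with nothing meaning every treatment is allowed; the package's actual representation may differ.

```julia
# Hypothetical sketch of the oracle: for each observation, pick the allowed
# treatment with the smallest counterfactual outcome (smaller is better).
function oracle_sketch(cf::AbstractMatrix, allowed_prescriptions = nothing)
    n, num_treatments = size(cf)
    outcomes = Vector{Float64}(undef, n)
    prescriptions = Vector{Int}(undef, n)
    for i in 1:n
        allowed = if allowed_prescriptions === nothing
            1:num_treatments                       # no restrictions
        else
            collect(allowed_prescriptions[i])      # assumed: labels allowed for row i
        end
        # findmin returns (smallest value, its position within `allowed`).
        best, pos = findmin(cf[i, allowed])
        outcomes[i] = best
        prescriptions[i] = allowed[pos]
    end
    return outcomes, prescriptions
end

cf = [3.0 2.5 4.0;
      1.0 1.5 0.5]
oracle_sketch(cf)                     # ([2.5, 0.5], [2, 3])
oracle_sketch(cf, [[1, 3], [1, 2]])   # ([3.0, 1.0], [1, 1])
```

Restricting the allowed set can only make the oracle's outcome equal or worse, which is why respecting prescription rules matters when the oracle is used as a best-case reference.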

Prescription evaluation

Evaluates the outcomes of prescriptions using the counterfactuals cf:

predicted_outcomes = evaluateprescriptions(cf, prescriptions)

predicted_outcomes is a Vector{Float64} containing the predicted outcome for each observation.
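One natural reading of this evaluation is a lookup: the estimated outcome of observation i is cf[i, prescriptions[i]]. The sketch below assumes exactly that; evaluate_sketch is a hypothetical name and the real evaluateprescriptions may do more (for example, additional validation).

```julia
using Statistics

# Hypothetical sketch: score each prescription by looking up its counterfactual
# estimate, i.e. cf[i, prescriptions[i]] for observation i.
function evaluate_sketch(cf::AbstractMatrix, prescriptions::AbstractVector{<:Integer})
    size(cf, 1) == length(prescriptions) ||
        throw(DimensionMismatch("need one prescription per row of cf"))
    return [cf[i, prescriptions[i]] for i in eachindex(prescriptions)]
end

cf = [1.5 2.5;
      0.75 0.5]
scores = evaluate_sketch(cf, [2, 1])   # [2.5, 0.75]
mean(scores)                           # 1.625
```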

Regress-and-compare methods

The package also has simple utility functions for applying regress-and-compare methods to prescription problems of the form described.

Training

Train the regress-and-compare method (one of :knn, :random_forest, or :lasso) on training data train_X, train_y and train_T and predict the outcomes for each treatment on testing data test_X:

outcomes = getoutcomes(train_X, train_y, train_T, test_X, method)

outcomes is a Matrix{Float64} containing the predicted outcome from the regress-and-compare models for each treatment and observation pair.
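Regress-and-compare fits one outcome model per treatment, using only the training observations that received that treatment, and then predicts every test observation under every treatment. The sketch below shows this structure with ordinary least squares as the regression, which is a swapped-in technique and not one of the package's :knn, :random_forest or :lasso options; ols_outcomes is a hypothetical name, and each treatment needs enough observations for the fit.

```julia
using LinearAlgebra  # least-squares solve via `\` for non-square matrices

# Hypothetical sketch of regress-and-compare with OLS standing in for the
# package's regression methods.
function ols_outcomes(train_X::AbstractMatrix, train_y::AbstractVector,
                      train_T::AbstractVector{<:Integer}, test_X::AbstractMatrix)
    num_treatments = maximum(train_T)
    outcomes = Matrix{Float64}(undef, size(test_X, 1), num_treatments)
    add_intercept(A) = hcat(ones(size(A, 1)), A)
    for t in 1:num_treatments
        idx = findall(==(t), train_T)
        # Fit a separate linear model on the observations that received t.
        beta = add_intercept(train_X[idx, :]) \ train_y[idx]
        # Predict the outcome under t for every test observation.
        outcomes[:, t] = add_intercept(test_X) * beta
    end
    return outcomes
end

# Treatment 1 follows y = 1 + 2x, treatment 2 follows y = 5 - x.
train_X = reshape([0.0, 1.0, 2.0, 0.0, 1.0, 2.0], :, 1)
train_T = [1, 1, 1, 2, 2, 2]
train_y = [1.0, 3.0, 5.0, 5.0, 4.0, 3.0]
test_X  = reshape([3.0, 4.0], :, 1)
ols_outcomes(train_X, train_y, train_T, test_X)   # ≈ [7.0 2.0; 9.0 1.0]
```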

Prescribing

Make prescriptions from the predicted regress-and-compare outcomes, subject to the optional allowed_prescriptions (see 'Oracle evaluation'):

prescriptions = makeprescriptions(outcomes, allowed_prescriptions)

prescriptions is a Vector{Int} containing the prescribed treatment for each observation.
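Without restrictions, prescribing from predicted outcomes amounts to choosing the column with the smallest prediction in each row. The sketch below, with the hypothetical name prescribe_sketch, covers only that unrestricted case; ties go to the lower treatment label because Julia's argmin returns the first minimal index.

```julia
# Hypothetical sketch: prescribe the treatment with the smallest predicted
# outcome in each row (no allowed_prescriptions restrictions).
prescribe_sketch(outcomes::AbstractMatrix) =
    [argmin(view(outcomes, i, :)) for i in axes(outcomes, 1)]

outcomes = [2.0 1.0 1.0;
            0.5 0.9 0.7]
prescribe_sketch(outcomes)   # returns [2, 1]
```

Note that these prescriptions are chosen using the model's own predictions but are then scored against the counterfactuals cf, so a model that predicts poorly can prescribe worse treatments than the oracle.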

Complete example

Assume we have training and testing data: train_X, train_y, train_T, test_X, test_y and test_T.

We can estimate the counterfactual outcomes on the test set with kNN using:

cf = getcounterfactuals(test_X, test_y, test_T, :knn)

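We can get the baseline and oracle outcomes to serve as reference points: the baseline reflects the treatments actually assigned, and the oracle gives the best outcome achievable under the counterfactual estimates: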

baseline_outcomes = evaluatebaseline(cf, test_X, test_y, test_T)
oracle_outcomes, oracle_prescriptions = evaluateoracle(cf)

Now we can compare the various regress-and-compare methods. First kNN:

knn = getoutcomes(train_X, train_y, train_T, test_X, :knn)
knn_prescriptions = makeprescriptions(knn)
knn_outcomes = evaluateprescriptions(cf, knn_prescriptions)

Similarly for random forests:

rf = getoutcomes(train_X, train_y, train_T, test_X, :random_forest)
rf_prescriptions = makeprescriptions(rf)
rf_outcomes = evaluateprescriptions(cf, rf_prescriptions)

And for lasso regression:

lasso = getoutcomes(train_X, train_y, train_T, test_X, :lasso)
lasso_prescriptions = makeprescriptions(lasso)
lasso_outcomes = evaluateprescriptions(cf, lasso_prescriptions)

This gives a vector of estimated test-set outcomes for the baseline, the oracle, and each of the three regress-and-compare methods, which can then be compared using any metric of interest.
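As one possible metric (an assumption, not something the package provides), the sketch below compares mean outcomes and reports the share of the baseline-to-oracle gap that a method closes. gap_closed is a hypothetical name; a value of 1 means the method matches the oracle, 0 means it matches the baseline, and negative values mean it is worse than the baseline.

```julia
using Statistics

# Hypothetical metric: fraction of the gap between baseline and oracle mean
# outcomes that a method closes (smaller outcomes are better).
function gap_closed(baseline::AbstractVector, oracle::AbstractVector,
                    method_outcomes::AbstractVector)
    b, o, m = mean(baseline), mean(oracle), mean(method_outcomes)
    b == o && return NaN          # no room for improvement over the baseline
    return (b - m) / (b - o)
end

baseline_outcomes = [4.0, 2.0]
oracle_outcomes   = [1.0, 1.0]
knn_outcomes      = [2.5, 1.5]
gap_closed(baseline_outcomes, oracle_outcomes, knn_outcomes)   # returns 0.5
```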
