In [None]:
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.regression import RandomForestRegressor
from pyspark.sql import SparkSession
from reina.metalearner import SparkSLearner
from reina.metalearner import SparkTLearner
from reina.metalearner import SparkXLearner

In [None]:
# Load the generated toy data.
# Assumes test_data.csv has already been uploaded; the path is resolved against
# Spark's default filesystem (HDFS in this setup).
DATA_PATH = "/test_data.csv"
# Columns read from the CSV as strings that must be numeric for the learners.
FLOAT_COLS = ["var1", "var2", "var3", "var4", "var5", "treatment", "outcome"]

# Reuse the session a pyspark kernel already provides, or create one on a plain
# Python kernel, so the notebook survives Restart & Run All.
spark = SparkSession.builder.getOrCreate()

raw_df = (spark.read
          .format("csv")
          .option("header", "true")
          .load(DATA_PATH))

missing_cols = [c for c in FLOAT_COLS if c not in raw_df.columns]
assert not missing_cols, f"test_data.csv is missing expected columns: {missing_cols}"

# Cast every numeric column in a single projection (chained withColumn calls each
# add a projection to the plan), preserving the original column order.
# `_c0` is the unnamed first column (Spark's name for an empty header field,
# presumably the index written by pandas.to_csv) and is dropped.
df = (raw_df
      .select([raw_df[c].cast("float").alias(c) if c in FLOAT_COLS else raw_df[c]
               for c in raw_df.columns])
      .drop("_c0"))
df.printSchema()

#### S-learner

In [None]:
# S-learner: fitted with a single base regressor.
treatments = ['treatment']
outcome = 'outcome'
# Explicit seed so re-running the notebook rebuilds the same forest instead of
# relying on the library default (assumes reina fits the estimator as passed).
SEED = 42
estimator = RandomForestRegressor(seed=SEED)

spark_slearner = SparkSLearner()
spark_slearner.fit(data=df, treatments=treatments, outcome=outcome, estimator=estimator)

# cate: treatment effects for individual samples; ate: the averaged treatment effect.
cate, ate = spark_slearner.effects()
print(cate)
print(ate)

#### T-learner

In [None]:
# T-learner: fitted with two base regressors, passed as estimator_0 and estimator_1
# (presumably one per treatment arm -- confirm against the reina docs).
treatments = ['treatment']
outcome = 'outcome'
# Explicit seed so re-running the notebook rebuilds the same forests instead of
# relying on the library default (assumes reina fits the estimators as passed).
SEED = 42
estimator_1 = RandomForestRegressor(seed=SEED)
estimator_0 = RandomForestRegressor(seed=SEED)

spark_tlearner = SparkTLearner()
spark_tlearner.fit(data=df, treatments=treatments, outcome=outcome,
                   estimator_0=estimator_0, estimator_1=estimator_1)

# cate: treatment effects for individual samples; ate: the averaged treatment effect.
cate, ate = spark_tlearner.effects()
print(cate)
print(ate)

#### X-learner

In [None]:
# X-learner: fitted with four base regressors (estimator_10/11/20/21) plus a
# propensity classifier. The index meaning (stage/arm) is reina's convention --
# confirm against its docs.
treatments = ['treatment']
outcome = 'outcome'
# Explicit seed so re-running the notebook rebuilds the same forests instead of
# relying on the library default (assumes reina fits the estimators as passed).
SEED = 42
estimator_11 = RandomForestRegressor(seed=SEED)
estimator_10 = RandomForestRegressor(seed=SEED)
estimator_21 = RandomForestRegressor(seed=SEED)
estimator_20 = RandomForestRegressor(seed=SEED)
propensity_estimator = RandomForestClassifier(seed=SEED)

spark_xlearner = SparkXLearner()
spark_xlearner.fit(data=df, treatments=treatments, outcome=outcome,
                   estimator_10=estimator_10, estimator_11=estimator_11,
                   estimator_20=estimator_20, estimator_21=estimator_21,
                   propensity_estimator=propensity_estimator)

# cate: treatment effects for individual samples; ate: the averaged treatment effect.
cate, ate = spark_xlearner.effects()
print(cate)
print(ate)