#### Import reina and the other required libraries, then initialize a Spark session.

In [17]:
from reina.metalearners import SLearner
from reina.metalearners import TLearner
from reina.metalearners import XLearner
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.classification import RandomForestClassifier
from pyspark.sql import SparkSession

# Build (or reuse, via getOrCreate) the Spark session used by every cell below.
spark = (
    SparkSession.builder
    .appName('Meta-Learner-Spark')
    .getOrCreate()
)

#### Read the toy data. Replace the path passed to `.load()` with the location of test_data.csv — this can be a local path (no cluster) or a path on a distributed storage system (e.g., HDFS).

*Note: The code below assumes data generated by our script (for specifics, please refer to the toy data generation section of the README). You can also adapt the code to use your own data.*

In [None]:
# Load the toy dataset from CSV, using the first row as column names.
df = (
    spark.read
    .format("csv")
    .option("header", "true")
    .load("test_data.csv")  # replace with the location of test_data.csv
)

# Without inferSchema, the CSV reader loads every column as a string, so cast
# the covariates, treatment, and outcome to floats before fitting any learner.
# withColumn on an existing name replaces that column in place, so the column
# order is unchanged.
numeric_columns = ["var1", "var2", "var3", "var4", "var5", "treatment", "outcome"]
for column_name in numeric_columns:
    df = df.withColumn(column_name, df[column_name].cast("float"))

# Drop the unnamed leading column ("_c0"); presumably a row index written by
# the data-generation script — TODO confirm against the README's generator.
df = df.drop("_c0")

# Print out dataframe schema
print(df.schema)

## S-learner

In [None]:
# Column configuration: list of treatment column name(s) and the outcome column.
treatments, outcome = ['treatment'], 'outcome'

# Single base regressor for the S-learner; a random forest is an arbitrary
# choice and any other Spark ML regressor could be swapped in.
estimator = RandomForestRegressor()

# Fit the S-learner on the prepared dataframe.
spark_slearner = SLearner()
spark_slearner.fit(
    data=df,
    treatments=treatments,
    outcome=outcome,
    estimator=estimator,
)

# effects() returns per-sample effects (cate) and their average (ate).
cate, ate = spark_slearner.effects()
for result in (cate, ate):
    print(result)

## T-learner

In [None]:
# Column configuration: list of treatment column name(s) and the outcome column.
treatments, outcome = ['treatment'], 'outcome'

# One base regressor per arm (estimator_1 / estimator_0); random forests are an
# arbitrary choice and may be replaced by other Spark ML regressors.
estimator_1 = RandomForestRegressor()
estimator_0 = RandomForestRegressor()

# Fit the T-learner on the prepared dataframe.
spark_tlearner = TLearner()
spark_tlearner.fit(
    data=df,
    treatments=treatments,
    outcome=outcome,
    estimator_0=estimator_0,
    estimator_1=estimator_1,
)

# effects() returns per-sample effects (cate) and their average (ate).
cate, ate = spark_tlearner.effects()
for result in (cate, ate):
    print(result)

## X-learner

In [None]:
# Column configuration: list of treatment column name(s) and the outcome column.
treatments, outcome = ['treatment'], 'outcome'

# Four base regressors plus a propensity classifier for the X-learner. Random
# forests are an arbitrary choice; other Spark ML estimators can be used.
estimator_11 = RandomForestRegressor()
estimator_10 = RandomForestRegressor()
estimator_21 = RandomForestRegressor()
estimator_20 = RandomForestRegressor()
propensity_estimator = RandomForestClassifier()

# Fit the X-learner on the prepared dataframe.
spark_xlearner = XLearner()
spark_xlearner.fit(
    data=df,
    treatments=treatments,
    outcome=outcome,
    estimator_10=estimator_10,
    estimator_11=estimator_11,
    estimator_20=estimator_20,
    estimator_21=estimator_21,
    propensity_estimator=propensity_estimator,
)

# effects() returns per-sample effects (cate) and their average (ate).
cate, ate = spark_xlearner.effects()
for result in (cate, ate):
    print(result)