<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Load-Data" data-toc-modified-id="Load-Data-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Load Data</a></span><ul class="toc-item"><li><span><a href="#Exploration-Data" data-toc-modified-id="Exploration-Data-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>Exploration Data</a></span></li><li><span><a href="#Modeling-Data" data-toc-modified-id="Modeling-Data-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>Modeling Data</a></span></li></ul></li><li><span><a href="#Data-Exploration" data-toc-modified-id="Data-Exploration-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Data Exploration</a></span></li><li><span><a href="#Standard-Models" data-toc-modified-id="Standard-Models-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Standard Models</a></span><ul class="toc-item"><li><span><a href="#Two-Model" data-toc-modified-id="Two-Model-3.1"><span class="toc-item-num">3.1&nbsp;&nbsp;</span>Two Model</a></span></li><li><span><a href="#Interaction-Term" data-toc-modified-id="Interaction-Term-3.2"><span class="toc-item-num">3.2&nbsp;&nbsp;</span>Interaction Term</a></span></li><li><span><a href="#Class-Transformations" data-toc-modified-id="Class-Transformations-3.3"><span class="toc-item-num">3.3&nbsp;&nbsp;</span>Class Transformations</a></span><ul class="toc-item"><li><span><a href="#Binary-Transformation" data-toc-modified-id="Binary-Transformation-3.3.1"><span class="toc-item-num">3.3.1&nbsp;&nbsp;</span>Binary Transformation</a></span></li><li><span><a href="#Quaternary-Transformation" data-toc-modified-id="Quaternary-Transformation-3.3.2"><span class="toc-item-num">3.3.2&nbsp;&nbsp;</span>Quaternary Transformation</a></span></li></ul></li><li><span><a href="#Evaluation" data-toc-modified-id="Evaluation-3.4"><span class="toc-item-num">3.4&nbsp;&nbsp;</span>Evaluation</a></span></li></ul></li><li><span><a href="#Generalized-Random-Forest" data-toc-modified-id="Generalized-Random-Forest-4"><span 
class="toc-item-num">4&nbsp;&nbsp;</span>Generalized Random Forest</a></span><ul class="toc-item"><li><span><a href="#Modeling" data-toc-modified-id="Modeling-4.1"><span class="toc-item-num">4.1&nbsp;&nbsp;</span>Modeling</a></span></li><li><span><a href="#Evaluation" data-toc-modified-id="Evaluation-4.2"><span class="toc-item-num">4.2&nbsp;&nbsp;</span>Evaluation</a></span></li></ul></li></ul></div>

In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import HTML, display

from causeinfer.data import mayo_pbc
from causeinfer.utilities import plot_unit_distributions
from causeinfer.utilities import train_test_split
from causeinfer.standard_algorithms import TwoModel
from causeinfer.standard_algorithms import InteractionTerm
from causeinfer.standard_algorithms import BinaryClassTransformation
from causeinfer.standard_algorithms import QuaternaryClassTransformation

# Notebook display configuration: cap displayed rows, show every column,
# and widen the notebook container to (almost) the full browser width.
# `display` / `HTML` are imported from `IPython.display` in the import cell;
# importing `display` from `IPython.core.display` is deprecated.
pd.set_option("display.max_rows", 16)
pd.set_option("display.max_columns", None)
display(HTML("<style>.container { width:99% !important; }</style>"))

In [None]:
# Show the kernel's current working directory, which is where relative paths resolve.
os.getcwd()

In [None]:
# `helpers_py` is never defined in this notebook, so `helpers_py.head_shape` raised
# NameError on a fresh kernel. `head_shape` is not used anywhere below.
# Bind it only if a helper module of that name is importable; otherwise use None.
# NOTE(review): if no helper module exists, this cell can simply be deleted.
try:
    from helpers_py import head_shape
except ImportError:
    head_shape = None

# Load Data

In [None]:
# Download the Mayo PBC data through causeinfer before the load cells below.
# NOTE(review): confirm where this saves the file. The loaders below read
# "/datasets/mayo_pbc.text", an absolute path at the filesystem root.
mayo_pbc.download_mayo_pbc()

## Exploration Data

In [None]:
# Load the raw, unnormalized dataset so every original column is available for exploration.
# NOTE(review): "/datasets/mayo_pbc.text" is an absolute path at the filesystem root;
# confirm it matches where download_mayo_pbc() saved the file (possibly cwd-relative).
data_raw = mayo_pbc.load_mayo_pbc(
    data_path="/datasets/mayo_pbc.text",
    load_raw_data=True,
    normalize=False,
)

# Wrap the raw values in a labelled frame for inspection.
df_full = pd.DataFrame(
    data=data_raw["dataset_full"],
    columns=data_raw["dataset_full_names"],
)

display(df_full.head())
df_full.shape

## Modeling Data

In [None]:
# Load the modeling version of the data: formatted (load_raw_data=False) and normalized.
# NOTE(review): normalization appears to run on the full dataset before the train/test
# split, so test-set statistics may leak into the training features; confirm.
# NOTE(review): absolute path at the filesystem root; confirm it matches the download location.
data_mayo_pbc = mayo_pbc.load_mayo_pbc(
    data_path="/datasets/mayo_pbc.text",
    load_raw_data=False,
    normalize=True,
)

# Unpack covariates (X), responses (y) and treatment assignments (w).
X, y, w = (data_mayo_pbc[key] for key in ("features", "response", "treatment"))

In [None]:
# Split covariates, responses and treatments in a single call.
# `train_test_split` is causeinfer's splitter (imported from causeinfer.utilities in the
# import cell). It takes `percent_train`, which scikit-learn's version does not.
# percent_train=0.6 presumably keeps 60% of units for training and 40% for testing.
# random_state fixes the split so re-runs are reproducible.
X_train, X_test, y_train, y_test, w_train, w_test = train_test_split(
    X, y, w, percent_train=0.6, random_state=42
)

# Data Exploration

In [None]:
# Configure the seaborn theme in ONE call. `sns.set` re-applies the whole theme each
# time it is called, so a separate `sns.set(rc=...)` would reset the style to its
# default ("darkgrid") and silently discard the "whitegrid" choice.
sns.set(style="whitegrid", rc={"figure.figsize": (15, 5)})

In [None]:
# Side-by-side distribution plots of one variable: ungrouped (left) and grouped by a
# column of df_full (right), sharing the y-axis for comparison.
# TODO: 'variable', 'x', 'y' and 'title' are placeholders. Replace them with a real column
# of df_full and meaningful labels; as written the call may fail on a missing column.
# NOTE(review): confirm 'status' is the intended grouping column.
fig, (ax1, ax2) = plt.subplots(ncols=2, sharey=True, figsize=(20, 5))

for panel_axis, group_col in ((ax1, None), (ax2, 'status')):
    plot_unit_distributions(df=df_full,
                            variable='variable', treatment=group_col,
                            plot_x_lab='x', plot_y_lab='y', plot_title='title',
                            fontsize=20, bins=20, axis=panel_axis)

# Render the figure without echoing the return value of the last call.
plt.show()

# Standard Models

## Two Model

In [None]:
# Fit causeinfer's TwoModel (default settings) on the training covariates, responses and treatments.
two_model = TwoModel()
two_model.fit(X_train, y_train, w_train)

In [None]:
# Predict on the held-out covariates.
# NOTE(review): the `_effect` name suggests per-unit treatment effects; confirm what
# TwoModel.predict actually returns (it may be separate treatment/control predictions).
two_model_effect = two_model.predict(X_test)

## Interaction Term

In [None]:
# Fit causeinfer's InteractionTerm model (default settings) on the training split.
interaction_term = InteractionTerm()
interaction_term.fit(X_train, y_train, w_train)

In [None]:
# Predict on the held-out covariates.
# NOTE(review): confirm the shape/meaning of InteractionTerm.predict's output before evaluating it.
interaction_term_effect = interaction_term.predict(X_test)

## Class Transformations

### Binary Transformation

In [None]:
# Fit causeinfer's BinaryClassTransformation model (default settings) on the training split.
b_transformation = BinaryClassTransformation()
b_transformation.fit(X_train, y_train, w_train)

In [None]:
# Predict on the held-out covariates.
# NOTE(review): confirm the shape/meaning of BinaryClassTransformation.predict's output.
b_transformation_effect = b_transformation.predict(X_test)

### Quaternary Transformation

In [None]:
# Fit causeinfer's QuaternaryClassTransformation model (default settings) on the training split.
q_transformation = QuaternaryClassTransformation()
q_transformation.fit(X_train, y_train, w_train)

In [None]:
# Predict on the held-out covariates.
# NOTE(review): confirm the shape/meaning of QuaternaryClassTransformation.predict's output.
q_transformation_effect = q_transformation.predict(X_test)

## Evaluation

**TODO:** evaluate the four standard models' predictions on the test split.

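# Generalized Random Forest

## Modeling

**TODO:** fit a generalized random forest on the same train/test split.

## Evaluation

**TODO:** evaluate the generalized random forest and compare it with the standard models.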