In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import importlib
from synthetic_data.synthetic_data_for_cate_class import synthetic_data_for_cate
# Force reload of the t_learner module to ensure we're using the latest version
import metalearners.t_learner
importlib.reload(metalearners.t_learner)
from metalearners.t_learner import TLearner
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.linear_model import LinearRegression

# Set plot style
try:
    # For newer matplotlib versions (>=3.6)
    plt.style.use('seaborn-v0_8-whitegrid')
except ValueError:
    # For older matplotlib versions
    plt.style.use('seaborn-whitegrid')

sns.set_context("notebook", font_scale=1.2)

# Try to import optional dependencies
try:
    from lightgbm import LGBMRegressor
    LIGHTGBM_AVAILABLE = True
except ImportError:
    LIGHTGBM_AVAILABLE = False
    print("LightGBM not available. Some examples will be skipped.")

try:
    from xgboost import XGBRegressor
    XGBOOST_AVAILABLE = True
except ImportError:
    XGBOOST_AVAILABLE = False
    print("XGBoost not available. Some examples will be skipped.")


<div style="font-size: 0.85em;">

# Note on Dependencies

This notebook requires the following packages:
- numpy, pandas, matplotlib, seaborn: For data manipulation and visualization
- scikit-learn: For machine learning models
- lightgbm, xgboost (optional): For gradient boosting models

The notebook also uses local packages:
- synthetic_data.synthetic_data_for_cate_class: Provides the `synthetic_data_for_cate` class for generating synthetic data with enhanced heterogeneity
- metalearners.t_learner: For the TLearner implementation

Make sure you have installed all required dependencies before running this notebook.
</div>


<div style="font-size: 0.85em;">

# Synthetic Data Generation

We generate synthetic data with known heterogeneous treatment effects using the `synthetic_data_for_cate` class (here with `model_type='model2'`). Its `get_synthetic_data()` method returns:

- A feature matrix with 5 covariates (by default)
- A binary treatment indicator (1=treated, 0=control)
- An outcome variable that depends on both covariates and treatment

The data generation process includes:
- Non-linear confounding (treatment assignment depends on X1 and X2)
- Heterogeneous treatment effects (effects vary based on all covariates)
- Non-linear baseline effects (outcome depends non-linearly on covariates)
- Heteroskedastic noise (noise level varies with X1)

This synthetic data allows us to evaluate the performance of different CATE estimation methods, as we know the true treatment effects.
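
To make these ingredients concrete, here is a minimal, hypothetical data-generating process with the same four properties. It is not the implementation behind `synthetic_data_for_cate(model_type='model2')`: the function name `simulate_cate_data`, the coefficients, and the functional forms are all assumptions chosen for illustration.

```python
import numpy as np

def simulate_cate_data(n=2000, p=5, seed=0):
    """Hypothetical DGP with confounding, heterogeneous effects and heteroskedastic noise."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, p))

    # Non-linear confounding: propensity depends on X1 and X2 (columns 0 and 1)
    logit = 0.8 * np.sin(X[:, 0]) + 0.5 * X[:, 1] ** 2 - 0.5
    propensity = 1.0 / (1.0 + np.exp(-logit))
    t = rng.binomial(1, propensity)

    # Heterogeneous treatment effect that varies with every covariate
    tau = 1.0 + X[:, 0] + 0.5 * np.tanh(X[:, 1]) + 0.3 * X[:, 2:].sum(axis=1)

    # Non-linear baseline outcome
    mu0 = np.cos(X[:, 0]) + 0.5 * X[:, 1] * X[:, 2]

    # Heteroskedastic noise: the standard deviation grows with |X1|
    noise = rng.normal(scale=0.5 + 0.5 * np.abs(X[:, 0]))

    y = mu0 + t * tau + noise
    return X, t, y, tau

# Usage: because tau is returned, any CATE estimate can be scored against the truth
X, t, y, tau = simulate_cate_data()
print(X.shape, round(float(t.mean()), 2), round(float(tau.mean()), 2))
```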
</div>


In [None]:
# Create an instance of synthetic_data_for_cate with model2
data_generator = synthetic_data_for_cate(model_type='model2')

# Generate synthetic data with heterogeneous treatment effects
features, treatment_vector, outcomes = data_generator.get_synthetic_data()

# Print basic information about the generated data
print(f"Generated data with {features.shape[0]} samples and {features.shape[1]} features")
print(f"Treatment assignment rate: {treatment_vector.mean():.2f}")
print(f"Outcome mean: {outcomes.mean():.2f}")


<div style="font-size: 0.85em;">

# Model Initialization and Fitting

We initialize several T-Learner models with different base learners:

1. **Random Forest**: A non-parametric model that can capture complex non-linear relationships
2. **Linear Regression**: A simple parametric model that assumes linear relationships
3. **Gradient Boosting**: An ensemble method that builds trees sequentially to correct errors
4. **LightGBM** (if available): A histogram-based gradient boosting framework that grows trees leaf-wise
5. **XGBoost** (if available): Another gradient boosting library, with regularized tree learning and its own split-finding algorithms

Each model has different strengths and weaknesses for CATE estimation. For this demonstration, we'll use the Random Forest-based T-Learner as our primary model.

The fitting process involves:
1. Splitting the data into treatment and control groups
2. Training a separate copy of the base model on each group
3. Using the two fitted models (at prediction time) to predict each individual's outcome both with and without treatment
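
These steps fit in a few lines of scikit-learn. The sketch below is a hypothetical `SimpleTLearner`, not the `metalearners.t_learner.TLearner` used in this notebook (whose internals may differ); it only assumes a scikit-learn-style regressor with `fit` and `predict`, and uses `sklearn.base.clone` so each group gets an independent, unfitted copy of the base model.

```python
import numpy as np
from sklearn.base import clone
from sklearn.ensemble import RandomForestRegressor

class SimpleTLearner:
    """Minimal T-Learner sketch: one outcome model per treatment group."""

    def __init__(self, base_model):
        self.base_model = base_model

    def fit(self, X, t, y):
        X, t, y = np.asarray(X), np.asarray(t), np.asarray(y)
        # Step 1: split the sample by treatment status
        treated, control = t == 1, t == 0
        # Step 2: fit an independent copy of the base model on each group
        self.model_1_ = clone(self.base_model).fit(X[treated], y[treated])
        self.model_0_ = clone(self.base_model).fit(X[control], y[control])
        return self

    def effect(self, X):
        X = np.asarray(X)
        # Step 3: predict both potential outcomes for every row and take the difference
        return self.model_1_.predict(X) - self.model_0_.predict(X)

# Usage on toy data with a known constant effect of 2
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))
t = rng.binomial(1, 0.5, size=1000)
y = X[:, 0] + 2.0 * t + rng.normal(scale=0.1, size=1000)
tl_sketch = SimpleTLearner(RandomForestRegressor(n_estimators=50, random_state=0)).fit(X, t, y)
print(tl_sketch.effect(X).mean())  # should be close to 2
```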
</div>


In [None]:
# Initialize T-Learner with different base models
rf_model = RandomForestRegressor(n_estimators=100, random_state=42)
linear_model = LinearRegression()
gb_model = GradientBoostingRegressor(n_estimators=100, random_state=42)

# Create T-Learner instances with different models
t_learner_rf = TLearner(rf_model)
t_learner_linear = TLearner(linear_model)
t_learner_gb = TLearner(gb_model)

# Create LightGBM model if available
if 'LIGHTGBM_AVAILABLE' in globals() and LIGHTGBM_AVAILABLE:
    lgbm_model = LGBMRegressor(random_state=42)
    t_learner_lgbm = TLearner(lgbm_model)
    # Uncomment to use LightGBM model
    # tl = t_learner_lgbm

# Create XGBoost model if available
if 'XGBOOST_AVAILABLE' in globals() and XGBOOST_AVAILABLE:
    xgb_model = XGBRegressor(random_state=42)
    t_learner_xgb = TLearner(xgb_model)
    # Uncomment to use XGBoost model
    # tl = t_learner_xgb

# Select Random Forest as our primary model
tl = t_learner_rf

# Fit the model
print("Fitting T-Learner with Random Forest...")
tl.fit(
    X=features,             # Feature matrix
    t=treatment_vector,     # Treatment assignments (0/1)
    y=outcomes              # Observed outcomes
)
print("Model fitting complete.")


<div style="font-size: 0.85em;">

# CATE Estimation

After fitting the model, we can estimate the Conditional Average Treatment Effect (CATE) for each individual in our dataset. The CATE represents how much the treatment is expected to affect the outcome for an individual with specific characteristics.

The CATE estimation process for T-Learner involves:
1. Using the treatment model to predict outcomes if treated
2. Using the control model to predict outcomes if not treated
3. Taking the difference between these predictions to get the CATE

This gives us an estimate of the treatment effect for each individual, conditional on their covariates.
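
In standard potential-outcomes notation (the symbols here are ours, not names from the `TLearner` implementation), the target and the T-Learner estimator are:

```latex
\tau(x) = \mathbb{E}\left[Y(1) - Y(0) \mid X = x\right],
\qquad
\hat{\mu}_t(x) \approx \mathbb{E}\left[Y \mid X = x, T = t\right], \quad t \in \{0, 1\},
\qquad
\hat{\tau}_{\mathrm{T}}(x) = \hat{\mu}_1(x) - \hat{\mu}_0(x).
```

Under unconfoundedness and overlap, $\mathbb{E}[Y \mid X = x, T = t] = \mathbb{E}[Y(t) \mid X = x]$, which is why the difference of the two group-specific regressions targets $\tau(x)$.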
</div>


In [None]:
# Estimate treatment effects (CATE)
cate = tl.effect(X=features)

# Print basic statistics about the estimated CATE
print("CATE Statistics:")
print(f"  Mean: {cate.mean():.4f}")
print(f"  Std Dev: {cate.std():.4f}")
print(f"  Min: {cate.min():.4f}")
print(f"  Max: {cate.max():.4f}")


<div style="font-size: 0.85em;">

# CATE Distribution Visualization

Visualizing the distribution of estimated treatment effects helps us understand the heterogeneity in treatment effects across the population. This can reveal:

1. **Average Effect**: The center of the distribution shows the average treatment effect
2. **Effect Heterogeneity**: The spread of the distribution shows how much treatment effects vary
3. **Subgroups**: Multiple peaks might indicate distinct subgroups with different responses
4. **Negative/Positive Effects**: The proportion of individuals with negative vs. positive effects

A narrow distribution suggests homogeneous treatment effects, while a wide distribution suggests high heterogeneity. Skewness in the distribution might indicate that certain types of individuals benefit more or less from the treatment.
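
The same features can be read off numerically. A small sketch, assuming `cate` is a 1-D NumPy array such as the output of `tl.effect(X=features)`; `summarize_cate` is a hypothetical helper name.

```python
import numpy as np

def summarize_cate(cate):
    """Numeric summary of a CATE distribution: centre, spread and sign split."""
    cate = np.asarray(cate, dtype=float)
    q05, q25, q50, q75, q95 = np.percentile(cate, [5, 25, 50, 75, 95])
    return {
        "mean": cate.mean(),                  # sample average of the CATEs (an ATE estimate)
        "std": cate.std(),                    # overall heterogeneity
        "iqr": q75 - q25,                     # spread that is robust to outliers
        "median": q50,
        "p05_p95": (q05, q95),                # range covering the central 90% of units
        "share_positive": (cate > 0).mean(),
        "share_negative": (cate < 0).mean(),
    }

# Usage: a bimodal toy distribution (two subgroups with effects near -1 and +2)
rng = np.random.default_rng(0)
toy = np.concatenate([rng.normal(-1, 0.3, 300), rng.normal(2, 0.3, 700)])
print(summarize_cate(toy))
```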
</div>


In [None]:
# Calculate true treatment effects using the data generator
true_cate = data_generator.get_true_cate(features)

# Create a DataFrame for easier plotting with seaborn
import pandas as pd
cate_df = pd.DataFrame({
    'Estimated CATE': cate,
    'True CATE': true_cate
})

# Create images directory if it doesn't exist
import os
os.makedirs('images', exist_ok=True)

# Plot the distribution of estimated CATE
plt.figure(figsize=(12, 7))

# Plot the estimated CATE distribution (histogram with KDE overlay)
sns.histplot(data=cate_df['Estimated CATE'], kde=True, alpha=0.6, bins=30)

# Add vertical lines for mean values
plt.axvline(x=cate.mean(), color='blue', linestyle='--', linewidth=2,
            label=f'Mean Estimated CATE: {cate.mean():.4f}')
plt.axvline(x=true_cate.mean(), color='orange', linestyle='--', linewidth=2,
            label=f'Mean True CATE: {true_cate.mean():.4f}')

# Add vertical line at zero (no effect)
plt.axvline(x=0, color='black', linestyle='-', linewidth=1,
            label='No Effect')

# Add annotations
plt.title('Distribution of Estimated CATE - T-Learner', fontsize=14)
plt.xlabel('Treatment Effect', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)

# Show percentage of positive and negative effects for estimated CATE
est_pos_pct = (cate > 0).mean() * 100
est_neg_pct = (cate < 0).mean() * 100
plt.annotate(f'Estimated Positive Effects: {est_pos_pct:.1f}%',
             xy=(0.68, 0.90), xycoords='axes fraction', fontsize=11)
plt.annotate(f'Estimated Negative Effects: {est_neg_pct:.1f}%',
             xy=(0.68, 0.85), xycoords='axes fraction', fontsize=11)

# Show percentage of positive and negative effects for true CATE
true_pos_pct = (true_cate > 0).mean() * 100
true_neg_pct = (true_cate < 0).mean() * 100
plt.annotate(f'True Positive Effects: {true_pos_pct:.1f}%',
             xy=(0.68, 0.80), xycoords='axes fraction', fontsize=11)
plt.annotate(f'True Negative Effects: {true_neg_pct:.1f}%',
             xy=(0.68, 0.75), xycoords='axes fraction', fontsize=11)

plt.tight_layout()

# Save and display plot
plt.savefig('images/t_learner_cate_distribution.png', dpi=300, bbox_inches='tight')
print("CATE distribution plot saved to images/t_learner_cate_distribution.png")
plt.show()


<div style="font-size: 0.85em;">

# Overlaid CATE Distributions: Estimated vs. True

To better compare the estimated and true CATE distributions, we can overlay them on the same plot using different colors. This visualization allows us to:

1. **Directly Compare Shapes**: See how closely the estimated distribution matches the true distribution
2. **Identify Discrepancies**: Spot areas where the model over- or under-estimates treatment effects
3. **Assess Heterogeneity Capture**: Determine if the model captures the true heterogeneity in treatment effects
4. **Evaluate Peaks and Modes**: Compare the peaks and modes of both distributions

The plot below shows both distributions with different colors, allowing for a direct visual comparison of their shapes, centers, and spreads.
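
Beyond eyeballing the overlaid curves, standard two-sample statistics summarise how far apart the two distributions are. This sketch uses `scipy.stats.ks_2samp` and `scipy.stats.wasserstein_distance` (SciPy is a scikit-learn dependency); `compare_distributions` is a hypothetical helper. Both statistics compare only the marginal distributions, so two vectors can match perfectly here while disagreeing on every individual; the scatter plot of predicted vs. actual effects checks the unit-level pairing.

```python
import numpy as np
from scipy.stats import ks_2samp, wasserstein_distance

def compare_distributions(estimated, true):
    """Compare the marginal distributions of estimated and true CATE."""
    estimated, true = np.asarray(estimated), np.asarray(true)
    ks = ks_2samp(estimated, true)
    return {
        "mean_diff": estimated.mean() - true.mean(),
        "std_ratio": estimated.std() / true.std(),  # < 1 means heterogeneity is understated
        "ks_statistic": ks.statistic,               # largest gap between the two empirical CDFs
        "wasserstein": wasserstein_distance(estimated, true),
    }

# Usage: a shrunken estimate, a common symptom of over-regularised learners
rng = np.random.default_rng(0)
true_toy = rng.normal(1.0, 1.0, 2000)
est_toy = 1.0 + 0.5 * (true_toy - 1.0)
print(compare_distributions(est_toy, true_toy))

# Same marginal distribution, but every unit-level pairing is scrambled
print(compare_distributions(rng.permutation(true_toy), true_toy)["wasserstein"])
```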
</div>


In [None]:
# Create a plot that overlays both the estimated and true CATE distributions
plt.figure(figsize=(12, 7))

# Plot both distributions with KDE using different colors and line styles
sns.kdeplot(data=cate_df, x='Estimated CATE', color='blue', fill=False, linewidth=2.5, linestyle='-',
            label=f'Estimated CATE (Mean: {cate.mean():.4f})')
sns.kdeplot(data=cate_df, x='True CATE', color='red', fill=False, linewidth=2.5, linestyle='--',
            label=f'True CATE (Mean: {true_cate.mean():.4f})')

# Add vertical lines for mean values
plt.axvline(x=cate.mean(), color='blue', linestyle='-', linewidth=2)
plt.axvline(x=true_cate.mean(), color='red', linestyle='--', linewidth=2)

# Add vertical line at zero (no effect)
plt.axvline(x=0, color='black', linestyle='-', linewidth=1, label='No Effect')

# Add annotations
plt.title('Comparison of Estimated vs. True CATE Distributions - T-Learner', fontsize=14)
plt.xlabel('Treatment Effect', fontsize=12)
plt.ylabel('Density', fontsize=12)
plt.legend(fontsize=10)
plt.grid(True, alpha=0.3)

# Calculate and display statistics
est_std = cate.std()
true_std = true_cate.std()

# Calculate correlation and mean absolute error
correlation = np.corrcoef(true_cate, cate)[0, 1]
mae = np.mean(np.abs(true_cate - cate))

# Add statistics annotations
plt.annotate(f'Estimated CATE Std: {est_std:.4f}', xy=(0.05, 0.95), xycoords='axes fraction', fontsize=10)
plt.annotate(f'True CATE Std: {true_std:.4f}', xy=(0.05, 0.90), xycoords='axes fraction', fontsize=10)
plt.annotate(f'Mean Absolute Error: {mae:.4f}', xy=(0.05, 0.85), xycoords='axes fraction', fontsize=10)
plt.annotate(f'Correlation: {correlation:.4f}', xy=(0.05, 0.80), xycoords='axes fraction', fontsize=10)

plt.tight_layout()

# Save and display plot
plt.savefig('images/t_learner_cate_distributions_comparison.png', dpi=300, bbox_inches='tight')
print("CATE distributions comparison plot saved to images/t_learner_cate_distributions_comparison.png")
plt.show()


<div style="font-size: 0.85em;">

# CATE Accuracy Evaluation: Predicted vs. Actual Treatment Effects

To evaluate the accuracy of our T-Learner model, we can compare the predicted CATE values with the actual treatment effects. This comparison helps us:

1. **Assess Model Accuracy**: How well does the model recover the true treatment effects?
2. **Identify Patterns in Errors**: Are there systematic biases in the predictions?
3. **Compare Treatment Groups**: Do predictions differ in accuracy between treated and control groups?
4. **Detect Outliers**: Are there individuals for whom the model predictions are particularly inaccurate?

The scatter plot below shows:
- Predicted CATE values on the y-axis
- True treatment effects on the x-axis
- Different symbols for treated and control groups
- A diagonal line representing perfect prediction (y=x)

Points close to the diagonal line indicate accurate predictions, while deviations suggest estimation errors.
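
The visual check can also be reported as numbers, overall and per treatment group. A sketch assuming 1-D arrays shaped like `true_cate`, `cate`, and a 0/1 `treatment_vector` from this notebook; `cate_accuracy_report` is a hypothetical helper. The RMSE against the true CATE is closely related to the PEHE metric common in the CATE literature.

```python
import numpy as np

def cate_accuracy_report(true_cate, est_cate, treatment):
    """MAE, RMSE and Pearson correlation between true and estimated CATE, by group."""
    true_cate = np.asarray(true_cate, dtype=float)
    est_cate = np.asarray(est_cate, dtype=float)
    treatment = np.asarray(treatment)
    groups = {
        "all": np.ones_like(treatment, dtype=bool),
        "treated": treatment == 1,
        "control": treatment == 0,
    }
    report = {}
    for name, mask in groups.items():
        err = est_cate[mask] - true_cate[mask]
        report[name] = {
            "n": int(mask.sum()),
            "mae": np.abs(err).mean(),
            "rmse": np.sqrt(np.mean(err ** 2)),
            "corr": np.corrcoef(true_cate[mask], est_cate[mask])[0, 1],
        }
    return report

# Usage on toy arrays: estimates are the truth plus noise
rng = np.random.default_rng(0)
true_toy = rng.normal(size=1000)
est_toy = true_toy + rng.normal(scale=0.2, size=1000)
t_toy = rng.binomial(1, 0.4, size=1000)
for group, metrics in cate_accuracy_report(true_toy, est_toy, t_toy).items():
    print(group, {k: round(float(v), 3) for k, v in metrics.items()})
```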
</div>


In [None]:
# Create a scatter plot of predicted vs. true treatment effects
plt.figure(figsize=(10, 8))

# Separate treated and control groups for different markers
treated_indices = treatment_vector == 1
control_indices = treatment_vector == 0

# Plot treated group with one marker
plt.scatter(true_cate[treated_indices], cate[treated_indices], 
            marker='^', color='red', alpha=0.6, label='Treated Group')

# Plot control group with another marker
plt.scatter(true_cate[control_indices], cate[control_indices], 
            marker='o', color='blue', alpha=0.6, label='Control Group')

# Add diagonal line (perfect prediction)
min_val = min(true_cate.min(), cate.min())
max_val = max(true_cate.max(), cate.max())
plt.plot([min_val, max_val], [min_val, max_val], 'k--', label='Perfect Prediction')

# Compute and display the Pearson correlation coefficient
correlation = np.corrcoef(true_cate, cate)[0, 1]
plt.annotate(f'Correlation: {correlation:.4f}', 
             xy=(0.05, 0.95), xycoords='axes fraction', fontsize=12)

# Compute and display the mean absolute error
mae = np.mean(np.abs(true_cate - cate))
plt.annotate(f'Mean Absolute Error: {mae:.4f}', 
             xy=(0.05, 0.90), xycoords='axes fraction', fontsize=12)

# Add labels and title
plt.title('Predicted vs. Actual Treatment Effects - T-Learner', fontsize=14)
plt.xlabel('Actual Treatment Effect', fontsize=12)
plt.ylabel('Predicted Treatment Effect (CATE)', fontsize=12)
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)

plt.tight_layout()

# Save the plot as a PNG file
plt.savefig('images/t_learner_cate_accuracy_evaluation.png', dpi=300, bbox_inches='tight')
print("CATE accuracy evaluation plot saved to images/t_learner_cate_accuracy_evaluation.png")

plt.show()


<div style="font-size: 0.85em;">

# Comparing T-Learner with Other Meta-Learners

The T-Learner is one of several meta-learner approaches for estimating heterogeneous treatment effects. Each approach has its own strengths and weaknesses:

## T-Learner vs. S-Learner
- **T-Learner**: Uses separate models for treatment and control groups
- **S-Learner**: Uses a single model with treatment as a feature
- **When to choose T-Learner**: When treatment and control groups have very different response surfaces
- **When to choose S-Learner**: When sample sizes are small or when treatment effects are relatively homogeneous

## T-Learner vs. X-Learner
- **T-Learner**: Directly estimates outcomes for each group
- **X-Learner**: Uses a multi-stage approach with imputed treatment effects
- **When to choose T-Learner**: When treatment and control groups are balanced
- **When to choose X-Learner**: When treatment and control groups are imbalanced

## T-Learner vs. R-Learner
- **T-Learner**: Doesn't model the propensity score
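- **R-Learner**: Residualizes the outcome and the treatment on the covariates (using an outcome model and a propensity model) and fits the CATE to these residuals, which separates confounding from treatment effects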
- **When to choose T-Learner**: When the propensity model is difficult to estimate
- **When to choose R-Learner**: When strong confounding is present

In practice, it's often valuable to try multiple meta-learners and compare their results, as different approaches may perform better in different scenarios.
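
The S-/T-Learner distinction can be made concrete on the same toy data. The helper names and the data-generating process are hypothetical, and gradient boosting is an arbitrary base learner. With a tree model, the S-Learner only splits on the treatment column when doing so reduces error, so weak or localized effects can be shrunk toward zero.

```python
import numpy as np
from sklearn.base import clone
from sklearn.ensemble import GradientBoostingRegressor

def s_learner_cate(model, X, t, y):
    """S-Learner: one model with the treatment indicator appended as a feature."""
    fitted = clone(model).fit(np.column_stack([X, t]), y)
    ones, zeros = np.ones(len(X)), np.zeros(len(X))
    return (fitted.predict(np.column_stack([X, ones]))
            - fitted.predict(np.column_stack([X, zeros])))

def t_learner_cate(model, X, t, y):
    """T-Learner: separate models for the treated and control groups."""
    m1 = clone(model).fit(X[t == 1], y[t == 1])
    m0 = clone(model).fit(X[t == 0], y[t == 0])
    return m1.predict(X) - m0.predict(X)

# Toy data: the true effect is tau(x) = 1 + x0
rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(2000, 3))
t = rng.binomial(1, 0.5, size=2000)
tau = 1.0 + X[:, 0]
y = np.sin(2 * X[:, 1]) + t * tau + rng.normal(scale=0.3, size=2000)

base = GradientBoostingRegressor(random_state=0)
for name, fn in [("S-Learner", s_learner_cate), ("T-Learner", t_learner_cate)]:
    est = fn(base, X, t, y)
    print(name, "RMSE vs true CATE:", round(float(np.sqrt(np.mean((est - tau) ** 2))), 3))
```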
</div>

<div style="font-size: 0.85em;">

# Library Imports

We import the necessary libraries for this demonstration:

- **synthetic_data_for_cate**: Custom class (from `synthetic_data.synthetic_data_for_cate_class`) for generating synthetic data with enhanced heterogeneity for treatment effects
- **TLearner**: Our implementation of the T-Learner meta-learner
- **sklearn models**: Various regression models to use as base learners
- **matplotlib/seaborn**: For visualization
- **numpy**: For numerical operations

We also attempt to import optional dependencies (LightGBM and XGBoost) which provide additional base learners.
</div>


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from synthetic_data.synthetic_data_for_cate_class import synthetic_data_for_cate
from metalearners.t_learner import TLearner
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression

# Set plot style
try:
    # For newer matplotlib versions (>=3.6)
    plt.style.use('seaborn-v0_8-whitegrid')
except ValueError:
    # For older matplotlib versions
    plt.style.use('seaborn-whitegrid')

sns.set_context("notebook", font_scale=1.2)

# Try to import optional dependencies
try:
    from lightgbm import LGBMRegressor
    LIGHTGBM_AVAILABLE = True
except ImportError:
    LIGHTGBM_AVAILABLE = False
    print("LightGBM not available. Some examples will be skipped.")

try:
    from xgboost import XGBRegressor
    XGBOOST_AVAILABLE = True
except ImportError:
    XGBOOST_AVAILABLE = False
    print("XGBoost not available. Some examples will be skipped.")






<div style="font-size: 0.85em;">

# Model Initialization and Fitting

We initialize several T-Learner models with different base learners:

1. **Random Forest**: A non-parametric model that can capture complex non-linear relationships
2. **Linear Regression**: A simple parametric model that assumes linear relationships
3. **LightGBM** (if available): A histogram-based gradient boosting framework that grows trees leaf-wise
4. **XGBoost** (if available): Another gradient boosting library, with regularized tree learning and its own split-finding algorithms

Each model has different strengths and weaknesses for CATE estimation. For this demonstration, we'll use the Random Forest-based T-Learner as our primary model.

The fitting process involves:
1. Splitting the data into treatment and control groups
2. Training separate models on each group
</div>


<div style="font-size: 0.85em;">

# Base Learners for T-Learner: Principles, Pros, and Cons

The choice of base learner significantly impacts the performance of T-Learner for CATE estimation. Here we describe the basic ideas, advantages, and limitations of each base learner.

## RandomForestRegressor

### Basic Principles
- **Ensemble Method**: Combines multiple decision trees to improve prediction accuracy and control overfitting
- **Bootstrap Aggregating (Bagging)**: Each tree is trained on a random subset of the data with replacement
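- **Feature Randomization**: At each split, only a random subset of features is considered (in scikit-learn's `RandomForestRegressor` this is set by `max_features`, whose default of `1.0` considers all features; a value below 1.0 enables the randomization)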
- **Averaging**: Final prediction is the average of predictions from all trees
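
The bagging-plus-averaging idea can be checked directly in scikit-learn: a fitted `RandomForestRegressor` exposes its trees as `estimators_`, and its prediction is the mean of the individual tree predictions. The toy data below is made up; `bootstrap=True` (the default) controls row resampling and `max_features=0.5` turns on per-split feature subsampling.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 5))
y = X[:, 0] ** 2 + rng.normal(scale=0.1, size=500)

# bootstrap=True resamples rows per tree; max_features=0.5 samples half the features at each split
rf = RandomForestRegressor(n_estimators=50, max_features=0.5, bootstrap=True,
                           random_state=0).fit(X, y)

# The forest's prediction is the plain average of its trees' predictions
tree_preds = np.stack([tree.predict(X) for tree in rf.estimators_])
print(np.allclose(tree_preds.mean(axis=0), rf.predict(X)))  # True
```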

### Pros for CATE Estimation
- **Captures Non-linear Relationships**: Can model complex, non-linear treatment effects without explicit specification
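- **Handles Interactions**: Automatically captures interactions among covariates; in a T-Learner, treatment-covariate interactions come from fitting a separate forest to each group
- **Robust to Outliers**: Less sensitive to extreme values in the data
- **No Distributional Assumptions**: Does not assume normality or homoscedasticity
- **Feature Importance**: Provides feature importances for each outcome model; these show which covariates drive the outcomes, which hints at, but does not directly measure, the drivers of treatment effect heterogeneity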

### Cons for CATE Estimation
- **Black Box**: Less interpretable than linear models
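- **Treatment Indicator (S-Learner only)**: When used in an S-Learner, the forest may not split on the treatment indicator often enough; a T-Learner sidesteps this because each forest sees only one group
- **Computational Cost**: Training many trees can be computationally intensive
- **Hyperparameter Sensitivity**: Performance depends on proper tuning of hyperparameters
- **Extrapolation Limitations**: Predictions are averages of training outcomes, so they cannot extend beyond the range seen in training and are unreliable in regions with sparse data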
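
The extrapolation limit shows up on a one-dimensional toy problem (the data is invented): every input beyond the largest training value follows the same path through each tree, so the forest returns a flat prediction no matter how far out the input is.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(0)
x_train = rng.uniform(0, 1, size=(300, 1))
y_train = 2.0 * x_train[:, 0]              # a simple linear trend on [0, 1]

rf = RandomForestRegressor(n_estimators=100, random_state=0).fit(x_train, y_train)

x_new = np.array([[0.5], [2.0], [5.0]])    # 2.0 and 5.0 lie outside the training range
print(rf.predict(x_new))                    # the last two stay near 2.0 instead of 4.0 and 10.0
```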

## LGBMRegressor (LightGBM)

### Basic Principles
- **Gradient Boosting Framework**: Builds trees sequentially, with each tree correcting errors of previous trees
- **Leaf-wise Growth**: Grows trees leaf-wise (best-first) rather than level-wise
- **Histogram-based Learning**: Buckets continuous features into discrete bins to speed up training
- **Gradient-based One-Side Sampling (GOSS)**: An optional sampling scheme that keeps instances with large gradients and randomly samples those with small gradients
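
Histogram-based split finding can be sketched in NumPy: bucket a feature into a few quantile bins, accumulate gradient sums per bin, and scan only the bin boundaries. This is a simplified illustration, not LightGBM's code. It assumes squared-error loss (so every hessian is 1), omits regularization, and uses a hypothetical helper `histogram_best_split`; only the parameter name `max_bin` mirrors LightGBM, whose default there is 255.

```python
import numpy as np

def histogram_best_split(x, grad, max_bin=16):
    """Best threshold for one feature using binned gradient statistics (squared loss)."""
    # Interior quantile edges define at most max_bin buckets
    edges = np.unique(np.quantile(x, np.linspace(0, 1, max_bin + 1)[1:-1]))
    bins = np.searchsorted(edges, x, side="right")  # bin b holds edges[b-1] <= x < edges[b]
    n_bins = len(edges) + 1
    g_sum = np.bincount(bins, weights=grad, minlength=n_bins)
    count = np.bincount(bins, minlength=n_bins)

    total_g, total_n = g_sum.sum(), count.sum()
    best_gain, best_threshold = -np.inf, None
    left_g, left_n = 0.0, 0
    # Scan n_bins - 1 candidate boundaries instead of every distinct x value
    for b in range(n_bins - 1):
        left_g += g_sum[b]
        left_n += count[b]
        right_g, right_n = total_g - left_g, total_n - left_n
        if left_n == 0 or right_n == 0:
            continue
        # Reduction in squared error from the split (hessian = 1 per sample)
        gain = left_g**2 / left_n + right_g**2 / right_n - total_g**2 / total_n
        if gain > best_gain:
            best_gain, best_threshold = gain, edges[b]
    return best_threshold, best_gain

# Usage: a step at x = 0.5; gradients of squared loss at the initial mean prediction
rng = np.random.default_rng(0)
x = rng.uniform(0, 1, 1000)
y = (x > 0.5).astype(float)
grad = y.mean() - y
print(histogram_best_split(x, grad))  # threshold close to 0.5
```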

### Pros for CATE Estimation
- **High Performance**: Often achieves state-of-the-art prediction accuracy
- **Efficiency**: Faster training speed and lower memory usage than traditional gradient boosting
- **Handles Large Datasets**: Scales well to datasets with many observations
- **Regularization Options**: Built-in L1/L2 regularization to prevent overfitting
- **Categorical Feature Support**: Native handling of categorical variables

### Cons for CATE Estimation
- **Complex Tuning**: Requires careful hyperparameter tuning for optimal performance
- **Overfitting Risk**: Can overfit on small datasets without proper regularization
- **Less Robust to Noisy Data**: May be more sensitive to noise than Random Forests
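- **Treatment Indicator (S-Learner only)**: As with Random Forests, when used in an S-Learner the model may not give sufficient importance to the treatment indicator; a T-Learner avoids this by fitting one model per group
- **Installation Challenges**: May require additional system libraries (for example, an OpenMP runtime) on some platforms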

## XGBRegressor (XGBoost)

### Basic Principles
- **Extreme Gradient Boosting**: Enhanced implementation of gradient boosting
- **Regularized Learning**: Includes L1 and L2 regularization terms in the objective function
- **Approximate Greedy Algorithm**: Uses a quantile sketch algorithm to find approximate best splits
- **Sparsity-Aware Split Finding**: Efficiently handles missing values and sparse data
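
The regularized objective behind the first two bullets can be written out. For a candidate split into left and right children, let $G_L, H_L$ and $G_R, H_R$ be the sums of first- and second-order gradients of the loss over the samples in each child, and $G_j, H_j$ the same sums for a leaf $j$. With an L2 penalty $\lambda$ on leaf weights (`reg_lambda`) and a per-leaf penalty $\gamma$ (`gamma`), XGBoost's second-order approximation gives the optimal leaf weight and the split gain:

```latex
w_j^{*} = -\frac{G_j}{H_j + \lambda},
\qquad
\mathrm{Gain} = \frac{1}{2}\left[\frac{G_L^2}{H_L + \lambda} + \frac{G_R^2}{H_R + \lambda} - \frac{(G_L + G_R)^2}{H_L + H_R + \lambda}\right] - \gamma
```

A split is kept only when the gain is positive, so $\gamma$ acts as a minimum loss reduction; the L1 term (`reg_alpha`) is omitted here for brevity.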

### Pros for CATE Estimation
- **Prediction Accuracy**: Consistently strong performance across many prediction tasks
- **Regularization**: Built-in mechanisms to prevent overfitting
- **Handles Missing Values**: Native support for missing data
- **Parallel Processing**: Efficient computation using multi-threading
- **Cross-validation**: Built-in cross-validation capabilities

### Cons for CATE Estimation
- **Complexity**: Many hyperparameters to tune
- **Black Box Nature**: Limited interpretability of the model
- **Memory Usage**: Can be memory-intensive for large datasets
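- **Treatment Indicator (S-Learner only)**: In an S-Learner, may not prioritize the treatment indicator appropriately; this is not an issue in a T-Learner, where each model is fit on one group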
- **Installation Requirements**: Depends on additional libraries

## LinearRegression

### Basic Principles
- **Linear Function**: Models outcome as a weighted sum of input features
- **Ordinary Least Squares (OLS)**: Minimizes the sum of squared differences between observed and predicted values
- **Closed-form Solution**: Parameters can be calculated directly without iterative optimization
- **Additive Effects**: Assumes features contribute additively to the outcome (interactions must be added as explicit terms)
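
The closed-form OLS solution is $\hat\beta = (X^\top X)^{-1} X^\top y$ with an intercept column added to $X$. The sketch below solves the same least-squares problem with `np.linalg.lstsq` (numerically preferable to forming the inverse) and checks it against scikit-learn's `LinearRegression` on made-up data.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = 1.5 + X @ np.array([2.0, -1.0, 0.5]) + rng.normal(scale=0.1, size=200)

# Prepend a column of ones so the first coefficient is the intercept
X1 = np.column_stack([np.ones(len(X)), X])
beta, *_ = np.linalg.lstsq(X1, y, rcond=None)  # least-squares solution of X1 @ beta = y

lr = LinearRegression().fit(X, y)
print(beta)                        # [intercept, b1, b2, b3]
print(lr.intercept_, lr.coef_)     # the same values from scikit-learn
```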

### Pros for CATE Estimation
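- **Interpretability**: Each group's coefficients are easy to interpret, and the CATE is the difference of the two linear fits, so it is itself linear in the covariates, with coefficients equal to the differences between the groups' coefficients
- **Treatment Interaction Modeling**: Fitting separate linear models per group gives the same coefficient estimates as a single regression with the treatment indicator and a treatment-by-covariate interaction for every covariate
- **Computational Efficiency**: Fast to train and make predictions
- **Stability**: Results are stable and reproducible
- **Statistical Inference**: Classical OLS theory gives standard errors and confidence intervals for the coefficients (scikit-learn's `LinearRegression` does not report them; a package such as statsmodels does)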
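
The equivalence between per-group fits and a fully interacted regression can be verified numerically. With linear base learners, the T-Learner's CATE is $(\hat\alpha_1 - \hat\alpha_0) + (\hat\beta_1 - \hat\beta_0)^\top x$, and these differences match the coefficients on $T$ and $T \times X$ in one pooled regression. The toy data and variable names are invented for the illustration.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
n = 500
X = rng.normal(size=(n, 2))
t = rng.binomial(1, 0.5, size=n)
# True CATE is linear: tau(x) = 1 + 2 * x0
y = X @ np.array([1.0, -0.5]) + t * (1.0 + 2.0 * X[:, 0]) + rng.normal(scale=0.5, size=n)

# T-Learner with linear base learners: one OLS fit per group
m1 = LinearRegression().fit(X[t == 1], y[t == 1])
m0 = LinearRegression().fit(X[t == 0], y[t == 0])
cate_intercept = m1.intercept_ - m0.intercept_
cate_slopes = m1.coef_ - m0.coef_      # CATE(x) = cate_intercept + cate_slopes @ x

# Pooled regression on [X, t, t * X]: coef_[2] is for t, coef_[3:] for the interactions
Z = np.column_stack([X, t, t[:, None] * X])
pooled = LinearRegression().fit(Z, y)
print(cate_intercept, pooled.coef_[2])
print(cate_slopes, pooled.coef_[3:])
```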

### Cons for CATE Estimation
- **Linearity Assumption**: Cannot capture non-linear treatment effects without manual feature engineering
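- **Homoscedasticity Assumption**: Classical standard errors assume constant error variance across observations, which does not hold for heteroskedastic data such as the synthetic data used here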
- **Sensitivity to Outliers**: Outliers can significantly impact coefficient estimates
- **Limited Flexibility**: May underfit complex relationships in the data
- **Collinearity Issues**: Highly correlated features make individual coefficient estimates unstable and hard to interpret

## Choosing the Right Base Learner

The optimal base learner depends on your specific use case:

- **RandomForestRegressor**: Good default choice that balances flexibility and robustness
- **LGBMRegressor/XGBRegressor**: Best for large datasets where prediction accuracy is paramount
- **LinearRegression**: Ideal when interpretability is critical or when the treatment effect is known to be linear

For complex, heterogeneous treatment effects, tree-based methods (Random Forest, LightGBM, XGBoost) typically outperform linear models. However, linear models offer better interpretability and explicit modeling of treatment interactions.
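
Because the ranking of base learners is problem-specific, one practical check on simulated data (where the true CATE is known, as in this notebook) is to fit a T-Learner with each candidate and compare errors on held-out points. A self-contained sketch using scikit-learn models only; the data-generating process and the helper names `t_learner_cate` and `simulate` are hypothetical. `LGBMRegressor` or `XGBRegressor` could be added to `candidates` the same way when installed, since both follow the scikit-learn estimator interface. With real data the true CATE is unavailable, so this comparison only works in simulations.

```python
import numpy as np
from sklearn.base import clone
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from sklearn.linear_model import LinearRegression

def t_learner_cate(model, X_train, t, y, X_eval):
    """Fit one model per group on the training data and return mu1(x) - mu0(x) on X_eval."""
    m1 = clone(model).fit(X_train[t == 1], y[t == 1])
    m0 = clone(model).fit(X_train[t == 0], y[t == 0])
    return m1.predict(X_eval) - m0.predict(X_eval)

def simulate(n, rng):
    X = rng.uniform(-2, 2, size=(n, 4))
    t = rng.binomial(1, 0.5, size=n)
    tau = np.where(X[:, 0] > 0, 2.0, 0.0) + 0.5 * X[:, 1] ** 2  # non-linear, heterogeneous effect
    y = np.cos(X[:, 2]) + t * tau + rng.normal(scale=0.5, size=n)
    return X, t, y, tau

rng = np.random.default_rng(0)
X_tr, t_tr, y_tr, _ = simulate(2000, rng)
X_te, _, _, tau_te = simulate(1000, rng)  # held-out points with known true CATE

candidates = {
    "LinearRegression": LinearRegression(),
    "RandomForest": RandomForestRegressor(n_estimators=200, min_samples_leaf=5, random_state=0),
    "GradientBoosting": GradientBoostingRegressor(random_state=0),
}
errors = {}
for name, model in candidates.items():
    est = t_learner_cate(model, X_tr, t_tr, y_tr, X_te)
    errors[name] = np.sqrt(np.mean((est - tau_te) ** 2))  # RMSE against the true CATE
    print(f"{name:>16}: RMSE = {errors[name]:.3f}")
```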
</div>


In [None]:
# Initialize T-Learner with different base models
t_learner_rf = TLearner(RandomForestRegressor(n_estimators=100, random_state=42))
t_learner_linear = TLearner(LinearRegression())

# Create LightGBM model if available
if 'LIGHTGBM_AVAILABLE' in globals() and LIGHTGBM_AVAILABLE:
    t_learner_lgbm = TLearner(LGBMRegressor(random_state=42))
    # Uncomment to use LightGBM model
    # tl = t_learner_lgbm

# Create XGBoost model if available
if 'XGBOOST_AVAILABLE' in globals() and XGBOOST_AVAILABLE:
    t_learner_xgb = TLearner(XGBRegressor(random_state=42))
    # Uncomment to use XGBoost model
    # tl = t_learner_xgb

# Select Random Forest as our primary model
tl = t_learner_rf

# Generate synthetic data with heterogeneous treatment effects if not already defined
if 'features' not in globals():
    # Create an instance of synthetic_data_for_cate with model2
    data_generator = synthetic_data_for_cate(model_type='model2')

    # Generate synthetic data
    features, treatment_vector, outcomes = data_generator.get_synthetic_data()
    print(f"Generated data with {features.shape[0]} samples and {features.shape[1]} features")
    print(f"Treatment assignment rate: {treatment_vector.mean():.2f}")
    print(f"Outcome mean: {outcomes.mean():.2f}")
else:
    # If features already exist, create a data generator to calculate true CATE
    data_generator = synthetic_data_for_cate(model_type='model2')

# Fit the model
print("Fitting T-Learner with Random Forest...")
tl.fit(
    X=features,             # Feature matrix
    t=treatment_vector,     # Treatment assignments (0/1)
    y=outcomes              # Observed outcomes
)
print("Model fitting complete.")


<div style="font-size: 0.85em;">

# CATE Estimation

After fitting the model, we can estimate the Conditional Average Treatment Effect (CATE) for each individual in our dataset. The CATE represents how much the treatment is expected to affect the outcome for an individual with specific characteristics.

The CATE estimation process for T-Learner involves:
1. Using the model fit on the treated group to predict each individual's outcome under treatment
2. Using the model fit on the control group to predict each individual's outcome under control
3. Computing the difference between the predicted outcomes (treatment - control)

This gives us an estimate of the treatment effect for each individual, conditional on their covariates.
</div>


In [None]:
# Generate synthetic data with heterogeneous treatment effects if not already defined
if 'features' not in globals():
    # Create an instance of synthetic_data_for_cate with model2
    data_generator = synthetic_data_for_cate(model_type='model2')

    # Generate synthetic data
    features, treatment_vector, outcomes = data_generator.get_synthetic_data()
    print(f"Generated data with {features.shape[0]} samples and {features.shape[1]} features")
    print(f"Treatment assignment rate: {treatment_vector.mean():.2f}")
    print(f"Outcome mean: {outcomes.mean():.2f}")
else:
    # If features already exist, create a data generator to calculate true CATE
    data_generator = synthetic_data_for_cate(model_type='model2')

# Estimate treatment effects (CATE)
cate = tl.effect(X=features)

# Print basic statistics about the estimated CATE
print("CATE Statistics:")
print(f"  Mean: {cate.mean():.4f}")
print(f"  Std Dev: {cate.std():.4f}")
print(f"  Min: {cate.min():.4f}")
print(f"  Max: {cate.max():.4f}")




In [None]:
# Generate synthetic data with heterogeneous treatment effects if not already defined
if 'features' not in globals():
    # Create an instance of synthetic_data_for_cate with model2
    data_generator = synthetic_data_for_cate(model_type='model2')

    # Generate synthetic data
    features, treatment_vector, outcomes = data_generator.get_synthetic_data()
    print(f"Generated data with {features.shape[0]} samples and {features.shape[1]} features")
    print(f"Treatment assignment rate: {treatment_vector.mean():.2f}")
    print(f"Outcome mean: {outcomes.mean():.2f}")
else:
    # If features already exist, create a data generator to calculate true CATE
    data_generator = synthetic_data_for_cate(model_type='model2')

# Calculate true treatment effects using the data generator
true_cate = data_generator.get_true_cate(features)

# Separate treated and control groups
treated_indices = treatment_vector == 1
control_indices = treatment_vector == 0

# Create DataFrames for each group
import pandas as pd

treated_df = pd.DataFrame({
    'Estimated CATE': cate[treated_indices],
    'True CATE': true_cate[treated_indices]
})

control_df = pd.DataFrame({
    'Estimated CATE': cate[control_indices],
    'True CATE': true_cate[control_indices]
})

# Create images directory if it doesn't exist
import os

os.makedirs('images', exist_ok=True)

# Plot for Treatment Group
plt.figure(figsize=(12, 7))

# Plot both distributions with KDE
sns.histplot(data=treated_df['Estimated CATE'], kde=True, alpha=0.6, bins=30)

# Add vertical lines for mean values
plt.axvline(x=treated_df['Estimated CATE'].mean(), color='blue', linestyle='--', linewidth=2,
            label=f'Mean Estimated CATE: {treated_df["Estimated CATE"].mean():.4f}')
plt.axvline(x=treated_df['True CATE'].mean(), color='orange', linestyle='--', linewidth=2,
            label=f'Mean True CATE: {treated_df["True CATE"].mean():.4f}')

# Add vertical line at zero (no effect)
plt.axvline(x=0, color='black', linestyle='-', linewidth=1,
            label='No Effect')

# Add annotations
plt.title('Distribution of True vs Estimated CATE - Treatment Group', fontsize=14)
plt.xlabel('Treatment Effect', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)

# Show percentage of positive and negative effects for estimated CATE
est_pos_pct = (treated_df['Estimated CATE'] > 0).mean() * 100
est_neg_pct = (treated_df['Estimated CATE'] < 0).mean() * 100
plt.annotate(f'Estimated Positive Effects: {est_pos_pct:.1f}%',
             xy=(0.68, 0.90), xycoords='axes fraction', fontsize=11)
plt.annotate(f'Estimated Negative Effects: {est_neg_pct:.1f}%',
             xy=(0.68, 0.85), xycoords='axes fraction', fontsize=11)

# Show percentage of positive and negative effects for true CATE
true_pos_pct = (treated_df['True CATE'] > 0).mean() * 100
true_neg_pct = (treated_df['True CATE'] < 0).mean() * 100
plt.annotate(f'True Positive Effects: {true_pos_pct:.1f}%',
             xy=(0.68, 0.80), xycoords='axes fraction', fontsize=11)
plt.annotate(f'True Negative Effects: {true_neg_pct:.1f}%',
             xy=(0.68, 0.75), xycoords='axes fraction', fontsize=11)

plt.tight_layout()

# Save and display plot
plt.savefig('images/t_learner_cate_distribution_treatment.png', dpi=300, bbox_inches='tight')
print("Treatment group CATE distribution plot saved to images/t_learner_cate_distribution_treatment.png")
plt.show()

# Plot for Control Group
plt.figure(figsize=(12, 7))

# Plot both distributions with KDE
sns.histplot(data=control_df['Estimated CATE'], kde=True, alpha=0.6, bins=30)

# Add vertical lines for mean values
plt.axvline(x=control_df['Estimated CATE'].mean(), color='blue', linestyle='--', linewidth=2,
            label=f'Mean Estimated CATE: {control_df["Estimated CATE"].mean():.4f}')
plt.axvline(x=control_df['True CATE'].mean(), color='orange', linestyle='--', linewidth=2,
            label=f'Mean True CATE: {control_df["True CATE"].mean():.4f}')

# Add vertical line at zero (no effect)
plt.axvline(x=0, color='black', linestyle='-', linewidth=1,
            label='No Effect')

# Add annotations
plt.title('Distribution of True vs Estimated CATE - Control Group', fontsize=14)
plt.xlabel('Treatment Effect', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)

# Show percentage of positive and negative effects for estimated CATE
est_pos_pct = (control_df['Estimated CATE'] > 0).mean() * 100
est_neg_pct = (control_df['Estimated CATE'] < 0).mean() * 100
plt.annotate(f'Estimated Positive Effects: {est_pos_pct:.1f}%',
             xy=(0.68, 0.90), xycoords='axes fraction', fontsize=11)
plt.annotate(f'Estimated Negative Effects: {est_neg_pct:.1f}%',
             xy=(0.68, 0.85), xycoords='axes fraction', fontsize=11)

# Show percentage of positive and negative effects for true CATE
true_pos_pct = (control_df['True CATE'] > 0).mean() * 100
true_neg_pct = (control_df['True CATE'] < 0).mean() * 100
plt.annotate(f'True Positive Effects: {true_pos_pct:.1f}%',
             xy=(0.68, 0.80), xycoords='axes fraction', fontsize=11)
plt.annotate(f'True Negative Effects: {true_neg_pct:.1f}%',
             xy=(0.68, 0.75), xycoords='axes fraction', fontsize=11)

plt.tight_layout()

# Save and display plot
plt.savefig('images/t_learner_cate_distribution_control.png', dpi=300, bbox_inches='tight')
print("Control group CATE distribution plot saved to images/t_learner_cate_distribution_control.png")
plt.show()

# Combined plot (original)
plt.figure(figsize=(12, 7))

# Create a DataFrame for easier plotting with seaborn
cate_df = pd.DataFrame({
    'Estimated CATE': cate,
    'True CATE': true_cate
})

# Plot both distributions with KDE
sns.histplot(data=cate_df['Estimated CATE'], kde=True, alpha=0.6, bins=30)

# Add vertical lines for mean values
plt.axvline(x=cate.mean(), color='blue', linestyle='--', linewidth=2,
            label=f'Mean Estimated CATE: {cate.mean():.4f}')
plt.axvline(x=true_cate.mean(), color='orange', linestyle='--', linewidth=2,
            label=f'Mean True CATE: {true_cate.mean():.4f}')

# Add vertical line at zero (no effect)
plt.axvline(x=0, color='black', linestyle='-', linewidth=1,
            label='No Effect')

# Add annotations
plt.title('Distribution of True vs Estimated CATE - All Groups', fontsize=14)
plt.xlabel('Treatment Effect', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)

# Show percentage of positive and negative effects for estimated CATE
est_pos_pct = (cate > 0).mean() * 100
est_neg_pct = (cate < 0).mean() * 100
plt.annotate(f'Estimated Positive Effects: {est_pos_pct:.1f}%',
             xy=(0.68, 0.90), xycoords='axes fraction', fontsize=11)
plt.annotate(f'Estimated Negative Effects: {est_neg_pct:.1f}%',
             xy=(0.68, 0.85), xycoords='axes fraction', fontsize=11)

# Show percentage of positive and negative effects for true CATE
true_pos_pct = (true_cate > 0).mean() * 100
true_neg_pct = (true_cate < 0).mean() * 100
plt.annotate(f'True Positive Effects: {true_pos_pct:.1f}%',
             xy=(0.68, 0.80), xycoords='axes fraction', fontsize=11)
plt.annotate(f'True Negative Effects: {true_neg_pct:.1f}%',
             xy=(0.68, 0.75), xycoords='axes fraction', fontsize=11)

plt.tight_layout()

# Save and display plot
plt.savefig('images/t_learner_cate_distribution.png', dpi=300, bbox_inches='tight')
print("Combined CATE distribution plot saved to images/t_learner_cate_distribution.png")
plt.show()


<div style="font-size: 0.85em;">

# Overlaid CATE Distributions: Estimated vs. True

To better compare the estimated and true CATE distributions, we can overlay them on the same plot using different colors. This visualization allows us to:

1. **Directly Compare Shapes**: See how closely the estimated distribution matches the true distribution
2. **Identify Discrepancies**: Spot areas where the model over- or under-estimates treatment effects
3. **Assess Heterogeneity Capture**: Determine if the model captures the true heterogeneity in treatment effects
4. **Evaluate Peaks and Modes**: Compare the peaks and modes of both distributions

The plot below overlays the two kernel density estimates (solid blue for the estimated CATE, dashed red for the true CATE), with vertical lines marking each mean, allowing a direct visual comparison of their shapes, centers, and spreads.
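The visual comparison can be complemented by distribution-level distances. The sketch below uses SciPy's `scipy.stats.wasserstein_distance` and `scipy.stats.ks_2samp`; SciPy is assumed to be available (it is installed as a scikit-learn dependency). These measures are a swapped-in technique that the notebook itself does not compute, and `compare_cate_distributions` is a hypothetical helper name.

```python
import numpy as np
from scipy.stats import ks_2samp, wasserstein_distance


def compare_cate_distributions(est_cate, true_cate):
    """Distances between the marginal distributions of estimated and true CATE (hypothetical helper)."""
    est_cate = np.asarray(est_cate, dtype=float)
    true_cate = np.asarray(true_cate, dtype=float)
    # Wasserstein-1 distance: average amount of "mass movement" needed to turn one
    # distribution into the other, in the same units as the treatment effect
    w1 = wasserstein_distance(est_cate, true_cate)
    # Two-sample Kolmogorov-Smirnov statistic: largest gap between the two empirical CDFs
    ks = ks_2samp(est_cate, true_cate)
    return {'wasserstein': w1, 'ks_statistic': ks.statistic, 'ks_pvalue': ks.pvalue}


# Usage with the notebook's arrays (assumes `cate` and `true_cate` are defined):
# compare_cate_distributions(cate, true_cate)
x = np.linspace(0.0, 1.0, 101)
print(compare_cate_distributions(x + 1.0, x))  # a pure shift of the distribution by 1
```

Both measures compare only the marginal distributions and ignore which estimate belongs to which individual: a model could reproduce the true distribution while assigning effects to the wrong people. Per-individual accuracy is what the mean absolute error and correlation capture.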
</div>


In [None]:
# Create a plot that overlays both the estimated and true CATE distributions
plt.figure(figsize=(12, 7))

# Plot both distributions with KDE using different colors and line styles
sns.kdeplot(data=cate_df, x='Estimated CATE', color='blue', fill=False, linewidth=2.5, linestyle='-',
            label=f'Estimated CATE (Mean: {cate.mean():.4f})')
sns.kdeplot(data=cate_df, x='True CATE', color='red', fill=False, linewidth=2.5, linestyle='--',
            label=f'True CATE (Mean: {true_cate.mean():.4f})')

# Add vertical lines for mean values
plt.axvline(x=cate.mean(), color='blue', linestyle='-', linewidth=2)
plt.axvline(x=true_cate.mean(), color='red', linestyle='--', linewidth=2)

# Add vertical line at zero (no effect)
plt.axvline(x=0, color='black', linestyle='-', linewidth=1, label='No Effect')

# Add annotations
plt.title('Comparison of Estimated vs. True CATE Distributions', fontsize=14)
plt.xlabel('Treatment Effect', fontsize=12)
plt.ylabel('Density', fontsize=12)
plt.legend(fontsize=10)
plt.grid(True, alpha=0.3)

# Calculate and display statistics
est_std = cate.std()
true_std = true_cate.std()

# Calculate correlation and mean absolute error
correlation = np.corrcoef(true_cate, cate)[0, 1]
mae = np.mean(np.abs(true_cate - cate))

# Approximate the KL divergence KL(true || estimated) between the two *distributions*:
# bin both samples on a shared grid, smooth empty bins with a small epsilon to avoid
# log(0) and division by zero, then normalize the counts to probabilities
bin_edges = np.histogram_bin_edges(np.concatenate([true_cate, cate]), bins=50)
p_true, _ = np.histogram(true_cate, bins=bin_edges)
p_est, _ = np.histogram(cate, bins=bin_edges)
eps = 1e-10
p_true = (p_true + eps) / (p_true + eps).sum()
p_est = (p_est + eps) / (p_est + eps).sum()
kl_div = np.sum(p_true * np.log(p_true / p_est))

# Add statistics annotations
plt.annotate(f'Estimated CATE Std: {est_std:.4f}', xy=(0.05, 0.95), xycoords='axes fraction', fontsize=10)
plt.annotate(f'True CATE Std: {true_std:.4f}', xy=(0.05, 0.90), xycoords='axes fraction', fontsize=10)
plt.annotate(f'Mean Absolute Error: {mae:.4f}', xy=(0.05, 0.85), xycoords='axes fraction', fontsize=10)
plt.annotate(f'Correlation: {correlation:.4f}', xy=(0.05, 0.80), xycoords='axes fraction', fontsize=10)
plt.annotate(f'Binned KL(true || est.): {kl_div:.4f}', xy=(0.05, 0.75), xycoords='axes fraction', fontsize=10)

plt.tight_layout()

# Save and display plot
plt.savefig('images/t_learner_cate_distributions_comparison.png', dpi=300, bbox_inches='tight')
print("CATE distributions comparison plot saved to images/t_learner_cate_distributions_comparison.png")
plt.show()


<div style="font-size: 0.85em;">

# CATE Accuracy Evaluation: Predicted vs. Actual Treatment Effects

Because the data are synthetic, the true treatment effect of every individual is known, so we can evaluate the accuracy of our T-Learner model by comparing its predicted CATE values with these true effects. This comparison helps us:

1. **Assess Model Accuracy**: How well does the model recover the true treatment effects?
2. **Identify Patterns in Errors**: Are there systematic biases in the predictions?
3. **Compare Treatment Groups**: Do predictions differ in accuracy between treated and control groups?
4. **Detect Outliers**: Are there individuals for whom the model predictions are particularly inaccurate?

The scatter plot below shows:
- Predicted CATE values on the y-axis
- True treatment effects on the x-axis
- Different symbols for treated and control groups
- A diagonal line representing perfect prediction (y=x)

Points close to the diagonal line indicate accurate predictions, while deviations suggest estimation errors.
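The accuracy questions above can also be answered with a few numbers per group. A common summary in the CATE literature is the precision in estimation of heterogeneous effects (PEHE), which is often reported as the root mean squared error between estimated and true CATE (some papers report it without the square root). The helper below is a hypothetical sketch, assuming 1-D arrays and a 0/1 treatment vector like the notebook's `treatment_vector`.

```python
import numpy as np


def cate_accuracy_by_group(true_cate, est_cate, treatment):
    """MAE, root-PEHE and correlation of CATE estimates, overall and per treatment group (hypothetical helper)."""
    true_cate = np.asarray(true_cate, dtype=float)
    est_cate = np.asarray(est_cate, dtype=float)
    treatment = np.asarray(treatment)
    groups = {
        'all': np.ones_like(treatment, dtype=bool),
        'treated': treatment == 1,
        'control': treatment == 0,
    }
    results = {}
    for name, mask in groups.items():
        err = est_cate[mask] - true_cate[mask]
        results[name] = {
            'n': int(mask.sum()),
            'mae': np.mean(np.abs(err)),
            'pehe': np.sqrt(np.mean(err ** 2)),  # root of the mean squared CATE error
            'correlation': np.corrcoef(true_cate[mask], est_cate[mask])[0, 1],
        }
    return results


# Usage with the notebook's arrays (assumes they are defined):
# cate_accuracy_by_group(true_cate, cate, treatment_vector)
res = cate_accuracy_by_group([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 5.0], [1, 1, 0, 0])
print(res['control'])
```

Similar values for the treated and control rows would indicate that the model is not systematically less accurate for one of the groups.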
</div>


In [None]:
# We already calculated the true CATE earlier, so we'll use that

# Create a scatter plot of predicted vs. true treatment effects
plt.figure(figsize=(10, 8))

# Separate treated and control groups for different markers
treated_indices = treatment_vector == 1
control_indices = treatment_vector == 0

# Plot treated group with one marker
plt.scatter(true_cate[treated_indices], cate[treated_indices], 
            marker='^', color='red', alpha=0.6, label='Treated Group')

# Plot control group with another marker
plt.scatter(true_cate[control_indices], cate[control_indices], 
            marker='o', color='blue', alpha=0.6, label='Control Group')

# Add diagonal line (perfect prediction)
min_val = min(true_cate.min(), cate.min())
max_val = max(true_cate.max(), cate.max())
plt.plot([min_val, max_val], [min_val, max_val], 'k--', label='Perfect Prediction')

# Calculate and display correlation coefficient
# Note: correlation is already calculated in the previous cell
plt.annotate(f'Correlation: {correlation:.4f}', 
             xy=(0.05, 0.95), xycoords='axes fraction', fontsize=12)

# Calculate and display mean absolute error
# Note: mae is already calculated in the previous cell
plt.annotate(f'Mean Absolute Error: {mae:.4f}', 
             xy=(0.05, 0.90), xycoords='axes fraction', fontsize=12)

# Add labels and title
plt.title('Predicted vs. Actual Treatment Effects', fontsize=14)
plt.xlabel('Actual Treatment Effect', fontsize=12)
plt.ylabel('Predicted Treatment Effect (CATE)', fontsize=12)
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)

plt.tight_layout()

# Save the plot as a PNG file
plt.savefig('images/t_learner_cate_accuracy_evaluation.png', dpi=300, bbox_inches='tight')
print("CATE accuracy evaluation plot saved to images/t_learner_cate_accuracy_evaluation.png")

plt.show()


<div style="font-size: 0.85em;">

# Comparison of T-Learner and S-Learner Performance

Visual inspection of the scatter plots and overlaid distributions in the T-Learner and S-Learner notebooks shows that the two learners perform very similarly on the synthetic 'model2' data. Despite their different approaches, both achieve comparable results in terms of:

1. **Correlation with True CATE**: Both models show a strong positive correlation with the true CATE values
2. **Mean Absolute Error**: The average magnitude of errors is similar between the two approaches
3. **Distribution Shape**: The estimated CATE distributions have similar shapes and spread
4. **Prediction Patterns**: Both models show similar patterns in their scatter plots, with comparable clustering around the diagonal line

## Why T-Learner Isn't Superior Despite Its Theoretical Advantages

In theory, the T-Learner should have advantages over the S-Learner in certain scenarios:

1. **Separate Models**: By training separate models for treatment and control groups, T-Learner can capture different functional forms for each group
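2. **No Reliance on the Treatment Indicator**: An S-Learner can shrink its estimated effects toward zero if its base model gives the treatment indicator little importance (for example, a tree ensemble that rarely splits on it); the T-Learner avoids this because treatment assignment determines which model is fitted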
3. **Flexibility**: T-Learner can adapt to different response surfaces in the treatment and control groups

However, for this specific synthetic data ('model2'), these advantages don't translate to superior performance because:

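1. **Sufficient Data in Both Groups**: Both the treatment and control groups contain enough samples, so the T-Learner's separate models are not data-starved, and the S-Learner sees enough treated and untreated examples to learn how the outcome depends on the treatment indicator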
2. **Random Forest Base Learner**: The RandomForestRegressor used in both approaches is flexible enough to capture the complex relationships, even when the treatment is just another feature
3. **Threshold-Based Effects**: The true CATE in 'model2' is based on threshold functions that create distinct subgroups with different effects, which both models can capture equally well
4. **Similar Functional Forms**: The underlying functional forms for treatment and control outcomes may not be different enough to benefit from T-Learner's separate modeling approach
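To make the two strategies concrete, the sketch below fits both meta-learners with `RandomForestRegressor` on a small hypothetical dataset with a threshold-shaped effect (tau(x) = 1 if x0 > 0.5, else 0). This data-generating process is an illustrative assumption, not the notebook's 'model2', and `s_learner_cate` / `t_learner_cate` are hypothetical minimal functions, not the `metalearners.t_learner.TLearner` API.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor


def s_learner_cate(X, t, y, seed=0):
    """S-Learner: a single model with the treatment indicator as an extra feature."""
    model = RandomForestRegressor(n_estimators=100, random_state=seed)
    model.fit(np.column_stack([X, t]), y)
    # CATE = prediction with treatment switched on minus prediction with it switched off
    pred_treated = model.predict(np.column_stack([X, np.ones(len(X))]))
    pred_control = model.predict(np.column_stack([X, np.zeros(len(X))]))
    return pred_treated - pred_control


def t_learner_cate(X, t, y, seed=0):
    """T-Learner: separate models for treated (t == 1) and control (t == 0) units."""
    model_1 = RandomForestRegressor(n_estimators=100, random_state=seed).fit(X[t == 1], y[t == 1])
    model_0 = RandomForestRegressor(n_estimators=100, random_state=seed).fit(X[t == 0], y[t == 0])
    return model_1.predict(X) - model_0.predict(X)


# Hypothetical data: 5 uniform covariates, random 50/50 treatment, threshold effect on x0
rng = np.random.default_rng(42)
n = 2000
X = rng.uniform(size=(n, 5))
t = rng.integers(0, 2, size=n)
tau = (X[:, 0] > 0.5).astype(float)
y = X[:, 1] + t * tau + rng.normal(scale=0.1, size=n)

for name, learner in [('S-Learner', s_learner_cate), ('T-Learner', t_learner_cate)]:
    est = learner(X, t, y)
    print(f"{name}: MAE = {np.mean(np.abs(est - tau)):.3f}, "
          f"corr = {np.corrcoef(tau, est)[0, 1]:.3f}")
```

The exact numbers depend on the seed and forest settings; varying `n`, the share of treated units, or the shape of `tau` is a quick way to explore how the relative performance of the two learners depends on the data.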

## Implications

This similarity in performance suggests that:

1. **Model Choice Flexibility**: For this type of data, either approach could be used with similar results
2. **Simplicity**: The S-Learner might be preferred because it requires fitting, tuning and maintaining one model instead of two (total training cost can be similar, since each T-Learner model is fitted on only part of the data)
3. **Data-Dependent Performance**: The relative performance of these meta-learners is highly dependent on the specific data characteristics
4. **Need for Model Comparison**: It's valuable to compare multiple approaches rather than relying on theoretical advantages alone

This observation aligns with findings in causal inference literature that no single meta-learner consistently outperforms others across all datasets and scenarios.
</div>
