<div style="font-size: 0.85em;">

# S-Learner for Conditional Average Treatment Effect (CATE) Estimation

This notebook demonstrates how to use an S-Learner to estimate heterogeneous treatment effects.

## What is S-Learner?

S-Learner (Single-Learner) is a meta-learner approach for estimating Conditional Average Treatment Effects (CATE). The key characteristics of S-Learner are:

1. **Single Model Approach**: Uses a single model for both treatment and control groups
2. **Treatment as Feature**: Includes treatment assignment as a feature in the model
3. **Simplicity**: Simpler implementation compared to other meta-learners
4. **Flexible Base Learners**: Can use any regression model that follows scikit-learn's API
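5. **Interaction Modeling**: Can capture interactions between treatment and covariates, but only if the base learner is flexible enough; a plain linear model without interaction terms yields the same effect for every individual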

## How S-Learner Works

The S-Learner approach follows these steps:
1. Combine treatment indicator with features into a single feature matrix
2. Train a single model to predict outcomes using this combined feature matrix
3. Estimate CATE by predicting outcomes with treatment set to 1 and 0, then taking the difference

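The main advantages of the S-Learner are its simplicity and its efficient use of small samples, since a single model is fit on all of the data. However, because treatment is just one feature among many, a regularized base learner can downweight or ignore it, which pulls the estimated effects toward zero. It may therefore underperform other meta-learners when treatment effects are highly heterogeneous or when propensity scores vary significantly across the population.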
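
In symbols, the three steps amount to fitting one regression and differencing its two counterfactual predictions. The notation here is an assumption of this note rather than the notebook's own: $\mu$ denotes the outcome model and $\hat{\tau}$ the estimated CATE.

$$
\mu(x, t) = \mathbb{E}[\,Y \mid X = x,\ T = t\,], \qquad \hat{\tau}(x) = \hat{\mu}(x, 1) - \hat{\mu}(x, 0)
$$

Here $\hat{\mu}$ is the single fitted base learner, so the quality of $\hat{\tau}$ depends on how well that model uses the treatment feature.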
</div>


<div style="font-size: 0.85em;">

# Library Imports

We import the necessary libraries for this demonstration:

- **synthetic_data_for_cate**: Custom class (from the local `synthetic_data` package) for generating synthetic data with heterogeneous treatment effects
- **SLearner**: Our implementation of the S-Learner meta-learner
- **sklearn models**: Various regression models to use as base learners
- **matplotlib/seaborn**: For visualization
- **numpy**: For numerical operations

We also attempt to import optional dependencies (LightGBM and XGBoost) which provide additional base learners.
</div>


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import importlib
from synthetic_data.synthetic_data_for_cate_class import synthetic_data_for_cate
# Force reload of the s_learner module to ensure we're using the latest version
import metalearners.s_learner
importlib.reload(metalearners.s_learner)
from metalearners.s_learner import SLearner
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.linear_model import LinearRegression

# Set plot style
try:
    # For newer matplotlib versions (>=3.6)
    plt.style.use('seaborn-v0_8-whitegrid')
except (OSError, ValueError):
    # Older matplotlib versions raise OSError for the unknown style name
    plt.style.use('seaborn-whitegrid')

sns.set_context("notebook", font_scale=1.2)

# Try to import optional dependencies
try:
    from lightgbm import LGBMRegressor
    LIGHTGBM_AVAILABLE = True
except ImportError:
    LIGHTGBM_AVAILABLE = False
    print("LightGBM not available. Some examples will be skipped.")

try:
    from xgboost import XGBRegressor
    XGBOOST_AVAILABLE = True
except ImportError:
    XGBOOST_AVAILABLE = False
    print("XGBoost not available. Some examples will be skipped.")


<div style="font-size: 0.85em;">

# Note on Dependencies

This notebook requires the following packages:
- numpy, pandas, matplotlib, seaborn: For data manipulation and visualization
- scikit-learn: For machine learning models
- lightgbm, xgboost (optional): For gradient boosting models

The notebook also uses local packages:
- synthetic_data.synthetic_data_for_cate_class: For generating synthetic data with heterogeneous treatment effects
- metalearners.s_learner: For the SLearner implementation

Make sure you have installed all required dependencies before running this notebook.
</div>


<div style="font-size: 0.85em;">

# Synthetic Data Generation

We generate synthetic data with known heterogeneous treatment effects using the `synthetic_data_for_cate` class. Its `get_synthetic_data()` method returns:

- A feature matrix with 5 covariates (by default)
- A binary treatment indicator (1=treated, 0=control)
- An outcome variable that depends on both covariates and treatment

The data generation process includes:
- Non-linear confounding (treatment assignment depends on X1 and X2)
- Heterogeneous treatment effects (effects vary based on all covariates)
- Non-linear baseline effects (outcome depends non-linearly on covariates)
- Heteroskedastic noise (noise level varies with X1)

This synthetic data allows us to evaluate the performance of different CATE estimation methods, as we know the true treatment effects.
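
To make the ingredients listed above concrete, here is a minimal sketch of a generator with the same four properties. It is a hypothetical stand-in, not the code of the `synthetic_data_for_cate` class: the uniform covariates, the propensity and baseline formulas, and the noise scale are all assumptions chosen for illustration. Only the threshold form of the effect is taken from the model2 expression quoted later in this notebook.

```python
import numpy as np


def make_toy_cate_data(n_samples=1000, seed=0):
    """Hypothetical generator; NOT the notebook's synthetic_data_for_cate class."""
    rng = np.random.default_rng(seed)
    X = rng.uniform(0.0, 1.0, size=(n_samples, 5))  # assumed: 5 uniform covariates

    # Non-linear confounding: treatment probability depends on X1 and X2.
    logits = 2.0 * np.sin(np.pi * X[:, 0]) + X[:, 1] ** 2 - 1.5
    propensity = 1.0 / (1.0 + np.exp(-logits))
    t = rng.binomial(1, propensity)

    # Heterogeneous effect: threshold form quoted later in this notebook.
    tau = (4.0 * (X[:, 0] > 0.5) - 3.0 * (X[:, 1] > 0.7)
           + 5.0 * (X[:, 2] * X[:, 3] > 0.5) - 2.0 * (X[:, 4] < 0.3))

    # Non-linear baseline plus noise whose scale grows with X1 (heteroskedastic).
    baseline = np.exp(X[:, 0]) + np.cos(2.0 * np.pi * X[:, 2])
    noise = rng.normal(0.0, 0.5 + X[:, 0], size=n_samples)
    y = baseline + tau * t + noise
    return X, t, y, tau


X, t, y, tau = make_toy_cate_data(n_samples=500, seed=42)
print("shape:", X.shape, "treated share:", round(float(t.mean()), 2))
```

Because `tau` is returned alongside the observed data, any estimator can be scored against the ground truth, which is the point of using synthetic data here.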
</div>


In [None]:
# Create an instance of synthetic_data_for_cate with model2
data_generator = synthetic_data_for_cate(model_type='model2')

# Generate synthetic data with heterogeneous treatment effects
features, treatment_vector, outcomes = data_generator.get_synthetic_data()

# Print basic information about the generated data
print(f"Generated data with {features.shape[0]} samples and {features.shape[1]} features")
print(f"Treatment assignment rate: {treatment_vector.mean():.2f}")
print(f"Outcome mean: {outcomes.mean():.2f}")


<div style="font-size: 0.85em;">

# Model Initialization and Fitting

We initialize several S-Learner models with different base learners:

1. **Random Forest**: A non-parametric model that can capture complex non-linear relationships
2. **Linear Regression**: A simple parametric model that assumes linear relationships
3. **Gradient Boosting**: An ensemble method that builds trees sequentially to correct errors
4. **LightGBM** (if available): A gradient boosting framework that uses tree-based learning
5. **XGBoost** (if available): Another gradient boosting framework with a different implementation

Each model has different strengths and weaknesses for CATE estimation. For this demonstration, we'll use the Random Forest-based S-Learner as our primary model.

The fitting process involves:
1. Combining treatment indicator with features
2. Training a single model on this combined feature matrix
3. Using the model to predict outcomes under different treatment conditions
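
The choice of base learner matters more than the list above might suggest. The sketch below uses plain NumPy least squares rather than the notebook's `SLearner`, and made-up noise-free data. It illustrates that a linear model with treatment as an ordinary feature can only return a constant effect, while adding an explicit `t * x` column lets the same model recover an effect that varies with `x`. The helper names are hypothetical.

```python
import numpy as np


def fit_least_squares(design, y):
    """Ordinary least squares with an intercept; returns the coefficient vector."""
    design = np.column_stack([np.ones(len(design)), design])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef


def predict(coef, design):
    """Predictions for a design matrix without the intercept column."""
    return np.column_stack([np.ones(len(design)), design]) @ coef


def design_with_interaction(x, t):
    """Columns [x, t, t * x]: treatment plus an explicit interaction term."""
    return np.column_stack([x, t, t * x])


rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=400)
t = rng.integers(0, 2, size=400)
true_effect = 1.0 + 2.0 * x          # effect grows with x
y = 0.5 * x + true_effect * t        # noise-free, so the fits are exact

ones, zeros = np.ones_like(x), np.zeros_like(x)

# (a) Treatment as a plain feature: columns [x, t].
coef_a = fit_least_squares(np.column_stack([x, t]), y)
cate_a = (predict(coef_a, np.column_stack([x, ones]))
          - predict(coef_a, np.column_stack([x, zeros])))

# (b) Same learner with the interaction column added.
coef_b = fit_least_squares(design_with_interaction(x, t), y)
cate_b = (predict(coef_b, design_with_interaction(x, ones))
          - predict(coef_b, design_with_interaction(x, zeros)))

print("plain linear CATE spread:", cate_a.std())
print("with interaction, max abs error:", np.abs(cate_b - true_effect).max())
```

Tree-based learners such as Random Forest can find these interactions on their own, which is one reason the notebook uses one as its primary model.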
</div>


In [None]:
# Initialize S-Learner with different base models
rf_model = RandomForestRegressor(n_estimators=100, random_state=42)
linear_model = LinearRegression()
gb_model = GradientBoostingRegressor(n_estimators=100, random_state=42)

# Create S-Learner instances with different models
s_learner_rf = SLearner(rf_model)
s_learner_linear = SLearner(linear_model)
s_learner_gb = SLearner(gb_model)

# Create LightGBM model if available
if 'LIGHTGBM_AVAILABLE' in globals() and LIGHTGBM_AVAILABLE:
    lgbm_model = LGBMRegressor(random_state=42)
    s_learner_lgbm = SLearner(lgbm_model)
    # Uncomment to use LightGBM model
    # sl = s_learner_lgbm

# Create XGBoost model if available
if 'XGBOOST_AVAILABLE' in globals() and XGBOOST_AVAILABLE:
    xgb_model = XGBRegressor(random_state=42)
    s_learner_xgb = SLearner(xgb_model)
    # Uncomment to use XGBoost model
    # sl = s_learner_xgb

# Select Random Forest as our primary model
sl = s_learner_rf

# Fit the model
print("Fitting S-Learner with Random Forest...")
sl.fit(
    X=features,             # Feature matrix
    t=treatment_vector,     # Treatment assignments (0/1)
    y=outcomes              # Observed outcomes
)
print("Model fitting complete.")


<div style="font-size: 0.85em;">

# CATE Estimation

After fitting the model, we can estimate the Conditional Average Treatment Effect (CATE) for each individual in our dataset. The CATE represents how much the treatment is expected to affect the outcome for an individual with specific characteristics.

The CATE estimation process for S-Learner involves:
1. Predicting outcomes with treatment set to 1 for all individuals
2. Predicting outcomes with treatment set to 0 for all individuals
3. Taking the difference between these predictions to get the CATE

This gives us an estimate of the treatment effect for each individual, conditional on their covariates.
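
The three steps above can be sketched in a few lines. The class below is an illustrative stand-in that mirrors the `fit(X, t, y)` and `effect(X)` calls used in this notebook. How `metalearners.s_learner.SLearner` is implemented internally is not shown here, so treat the column ordering and everything else in the sketch as assumptions. The demo data is made up.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor


class MinimalSLearner:
    """Illustrative stand-in; not the notebook's metalearners.s_learner.SLearner."""

    def __init__(self, model):
        self.model = model  # any regressor exposing fit(X, y) and predict(X)

    def fit(self, X, t, y):
        # Treatment becomes the last column of the design matrix.
        self.model.fit(np.column_stack([X, t]), y)
        return self

    def effect(self, X):
        n = X.shape[0]
        y1 = self.model.predict(np.column_stack([X, np.ones(n)]))   # everyone treated
        y0 = self.model.predict(np.column_stack([X, np.zeros(n)]))  # everyone control
        return y1 - y0


# Made-up data: the effect is 3 when the second covariate exceeds 0.5, else 0.
rng = np.random.default_rng(0)
X = rng.uniform(0.0, 1.0, size=(1000, 2))
t = rng.integers(0, 2, size=1000)
true_cate = 3.0 * (X[:, 1] > 0.5)
y = X[:, 0] + true_cate * t

learner = MinimalSLearner(RandomForestRegressor(n_estimators=50, random_state=0))
cate_hat = learner.fit(X, t, y).effect(X)
print("correlation with true effect:", np.corrcoef(true_cate, cate_hat)[0, 1])
```

Note that both counterfactual predictions come from the same fitted model; only the treatment column changes between them.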
</div>


In [None]:
# Estimate treatment effects (CATE)
cate = sl.effect(X=features)

# Print basic statistics about the estimated CATE
print("CATE Statistics:")
print(f"  Mean: {cate.mean():.4f}")
print(f"  Std Dev: {cate.std():.4f}")
print(f"  Min: {cate.min():.4f}")
print(f"  Max: {cate.max():.4f}")


<div style="font-size: 0.85em;">

# CATE Distribution Visualization

Visualizing the distribution of estimated treatment effects helps us understand the heterogeneity in treatment effects across the population. This can reveal:

1. **Average Effect**: The center of the distribution shows the average treatment effect
2. **Effect Heterogeneity**: The spread of the distribution shows how much treatment effects vary
3. **Subgroups**: Multiple peaks might indicate distinct subgroups with different responses
4. **Negative/Positive Effects**: The proportion of individuals with negative vs. positive effects

A narrow distribution suggests homogeneous treatment effects, while a wide distribution suggests high heterogeneity. Skewness in the distribution might indicate that certain types of individuals benefit more or less from the treatment.
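
The same properties can be read off numerically as well as visually. The helper below is a hypothetical convenience function, not part of the notebook's modules, and the input array is made up. It reports the center, spread, tail quantiles, and sign shares of an array of effects.

```python
import numpy as np


def summarize_effects(cate):
    """Numeric counterparts of what the histogram shows."""
    cate = np.asarray(cate, dtype=float)
    q05, q50, q95 = np.quantile(cate, [0.05, 0.5, 0.95])
    return {
        "mean": cate.mean(),                   # average effect
        "std": cate.std(),                     # heterogeneity
        "median": q50,
        "q05": q05,
        "q95": q95,
        "share_positive": (cate > 0).mean(),   # fraction with positive effect
        "share_negative": (cate < 0).mean(),   # fraction with negative effect
    }


example = np.array([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0])  # made-up values
summary = summarize_effects(example)
for name, value in summary.items():
    print(f"{name}: {value:.3f}")
```

In this notebook it could be called as `summarize_effects(cate)` once the estimates exist.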
</div>


In [None]:
# Calculate true treatment effects using the data generator
true_cate = data_generator.get_true_cate(features)

# Create a DataFrame for easier plotting with seaborn
import pandas as pd
cate_df = pd.DataFrame({
    'Estimated CATE': cate,
    'True CATE': true_cate
})

# Create images directory if it doesn't exist
import os
os.makedirs('images', exist_ok=True)

# Plot the distribution of estimated CATE
plt.figure(figsize=(12, 7))

# Plot the estimated CATE distribution with a KDE overlay
sns.histplot(data=cate_df['Estimated CATE'], kde=True, alpha=0.6, bins=30)

# Add vertical lines for mean values
plt.axvline(x=cate.mean(), color='blue', linestyle='--', linewidth=2,
            label=f'Mean Estimated CATE: {cate.mean():.4f}')
plt.axvline(x=true_cate.mean(), color='orange', linestyle='--', linewidth=2,
            label=f'Mean True CATE: {true_cate.mean():.4f}')

# Add vertical line at zero (no effect)
plt.axvline(x=0, color='black', linestyle='-', linewidth=1,
            label='No Effect')

# Add annotations
plt.title('Distribution of Estimated CATE - S-Learner', fontsize=14)
plt.xlabel('Treatment Effect', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)

# Show percentage of positive and negative effects for estimated CATE
est_pos_pct = (cate > 0).mean() * 100
est_neg_pct = (cate < 0).mean() * 100
plt.annotate(f'Estimated Positive Effects: {est_pos_pct:.1f}%',
             xy=(0.68, 0.90), xycoords='axes fraction', fontsize=11)
plt.annotate(f'Estimated Negative Effects: {est_neg_pct:.1f}%',
             xy=(0.68, 0.85), xycoords='axes fraction', fontsize=11)

# Show percentage of positive and negative effects for true CATE
true_pos_pct = (true_cate > 0).mean() * 100
true_neg_pct = (true_cate < 0).mean() * 100
plt.annotate(f'True Positive Effects: {true_pos_pct:.1f}%',
             xy=(0.68, 0.80), xycoords='axes fraction', fontsize=11)
plt.annotate(f'True Negative Effects: {true_neg_pct:.1f}%',
             xy=(0.68, 0.75), xycoords='axes fraction', fontsize=11)

plt.tight_layout()

# Save and display plot
plt.savefig('images/s_learner_cate_distribution.png', dpi=300, bbox_inches='tight')
print("CATE distribution plot saved to images/s_learner_cate_distribution.png")
plt.show()


<div style="font-size: 0.85em;">

# Overlaid CATE Distributions: Estimated vs. True

To better compare the estimated and true CATE distributions, we can overlay them on the same plot using different colors. This visualization allows us to:

1. **Directly Compare Shapes**: See how closely the estimated distribution matches the true distribution
2. **Identify Discrepancies**: Spot areas where the model over- or under-estimates treatment effects
3. **Assess Heterogeneity Capture**: Determine if the model captures the true heterogeneity in treatment effects
4. **Evaluate Peaks and Modes**: Compare the peaks and modes of both distributions

The plot below shows both distributions with different colors, allowing for a direct visual comparison of their shapes, centers, and spreads.
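
A numeric companion to the overlay can make the shape comparison less subjective. The sketch below uses a hypothetical helper and made-up arrays. It compares centers, spreads, and a few quantiles of two effect arrays; a spread ratio below 1 indicates that the estimates understate the true heterogeneity.

```python
import numpy as np


def compare_distributions(estimated, true, quantiles=(0.1, 0.25, 0.5, 0.75, 0.9)):
    """Side-by-side center, spread, and quantile gaps of two effect arrays."""
    estimated = np.asarray(estimated, dtype=float)
    true = np.asarray(true, dtype=float)
    return {
        "mean_gap": estimated.mean() - true.mean(),
        "std_ratio": estimated.std() / true.std(),  # < 1: heterogeneity understated
        "quantile_gaps": np.quantile(estimated, quantiles) - np.quantile(true, quantiles),
    }


true_demo = np.array([-3.0, 0.0, 0.0, 4.0, 4.0, 9.0])  # made-up "true" effects
est_demo = 0.5 * true_demo + 1.0                        # deliberately shrunken estimates
report = compare_distributions(est_demo, true_demo)
print("std ratio:", report["std_ratio"])
print("mean gap:", report["mean_gap"])
print("quantile gaps:", report["quantile_gaps"])
```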
</div>


In [None]:
# Create a plot that overlays both the estimated and true CATE distributions
plt.figure(figsize=(12, 7))

# Plot both distributions with KDE using different colors and line styles
sns.kdeplot(data=cate_df, x='Estimated CATE', color='blue', fill=False, linewidth=2.5, linestyle='-',
            label=f'Estimated CATE (Mean: {cate.mean():.4f})')
sns.kdeplot(data=cate_df, x='True CATE', color='red', fill=False, linewidth=2.5, linestyle='--',
            label=f'True CATE (Mean: {true_cate.mean():.4f})')

# Add vertical lines for mean values
plt.axvline(x=cate.mean(), color='blue', linestyle='-', linewidth=2)
plt.axvline(x=true_cate.mean(), color='red', linestyle='--', linewidth=2)

# Add vertical line at zero (no effect)
plt.axvline(x=0, color='black', linestyle='-', linewidth=1, label='No Effect')

# Add annotations
plt.title('Comparison of Estimated vs. True CATE Distributions - S-Learner', fontsize=14)
plt.xlabel('Treatment Effect', fontsize=12)
plt.ylabel('Density', fontsize=12)
plt.legend(fontsize=10)
plt.grid(True, alpha=0.3)

# Calculate and display statistics
est_std = cate.std()
true_std = true_cate.std()

# Calculate correlation and mean absolute error
correlation = np.corrcoef(true_cate, cate)[0, 1]
mae = np.mean(np.abs(true_cate - cate))

# Add statistics annotations
plt.annotate(f'Estimated CATE Std: {est_std:.4f}', xy=(0.05, 0.95), xycoords='axes fraction', fontsize=10)
plt.annotate(f'True CATE Std: {true_std:.4f}', xy=(0.05, 0.90), xycoords='axes fraction', fontsize=10)
plt.annotate(f'Mean Absolute Error: {mae:.4f}', xy=(0.05, 0.85), xycoords='axes fraction', fontsize=10)
plt.annotate(f'Correlation: {correlation:.4f}', xy=(0.05, 0.80), xycoords='axes fraction', fontsize=10)

plt.tight_layout()

# Save and display plot
plt.savefig('images/s_learner_cate_distributions_comparison.png', dpi=300, bbox_inches='tight')
print("CATE distributions comparison plot saved to images/s_learner_cate_distributions_comparison.png")
plt.show()


<div style="font-size: 0.85em;">

# CATE Accuracy Evaluation: Predicted vs. Actual Treatment Effects

To evaluate the accuracy of our S-Learner model, we can compare the predicted CATE values with the actual treatment effects. This comparison helps us:

1. **Assess Model Accuracy**: How well does the model recover the true treatment effects?
2. **Identify Patterns in Errors**: Are there systematic biases in the predictions?
3. **Compare Treatment Groups**: Do predictions differ in accuracy between treated and control groups?
4. **Detect Outliers**: Are there individuals for whom the model predictions are particularly inaccurate?

The scatter plot below shows:
- Predicted CATE values on the y-axis
- True treatment effects on the x-axis
- Different symbols for treated and control groups
- A diagonal line representing perfect prediction (y=x)

Points close to the diagonal line indicate accurate predictions, while deviations suggest estimation errors.
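
The question of whether accuracy differs between treated and control groups can also be checked numerically. The helper below is hypothetical and is exercised on made-up numbers. It computes the mean absolute error of the CATE estimates separately within each arm.

```python
import numpy as np


def error_by_group(true_cate, est_cate, treatment):
    """MAE of the CATE estimates within the treated and control arms."""
    true_cate = np.asarray(true_cate, dtype=float)
    est_cate = np.asarray(est_cate, dtype=float)
    treatment = np.asarray(treatment)
    abs_err = np.abs(est_cate - true_cate)
    return {
        "treated_mae": abs_err[treatment == 1].mean(),
        "control_mae": abs_err[treatment == 0].mean(),
        "overall_mae": abs_err.mean(),
    }


# Made-up numbers, purely to exercise the function.
true_demo = np.array([4.0, 4.0, -3.0, 0.0])
est_demo = np.array([3.0, 5.0, -2.5, 0.5])
t_demo = np.array([1, 0, 1, 0])
groups = error_by_group(true_demo, est_demo, t_demo)
print(groups)
```

With the variables defined in this notebook, the call would be `error_by_group(true_cate, cate, treatment_vector)`.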
</div>


In [None]:
# Create a scatter plot of predicted vs. true treatment effects
plt.figure(figsize=(10, 8))

# Separate treated and control groups for different markers
treated_indices = treatment_vector == 1
control_indices = treatment_vector == 0

# Plot treated group with one marker
plt.scatter(true_cate[treated_indices], cate[treated_indices], 
            marker='^', color='red', alpha=0.6, label='Treated Group')

# Plot control group with another marker
plt.scatter(true_cate[control_indices], cate[control_indices], 
            marker='o', color='blue', alpha=0.6, label='Control Group')

# Add diagonal line (perfect prediction)
min_val = min(true_cate.min(), cate.min())
max_val = max(true_cate.max(), cate.max())
plt.plot([min_val, max_val], [min_val, max_val], 'k--', label='Perfect Prediction')

# Display the correlation coefficient (computed in the previous cell)
plt.annotate(f'Correlation: {correlation:.4f}', 
             xy=(0.05, 0.95), xycoords='axes fraction', fontsize=12)

# Display the mean absolute error (computed in the previous cell)
plt.annotate(f'Mean Absolute Error: {mae:.4f}', 
             xy=(0.05, 0.90), xycoords='axes fraction', fontsize=12)

# Add labels and title
plt.title('Predicted vs. Actual Treatment Effects - S-Learner', fontsize=14)
plt.xlabel('Actual Treatment Effect', fontsize=12)
plt.ylabel('Predicted Treatment Effect (CATE)', fontsize=12)
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)

plt.tight_layout()

# Save the plot as a PNG file
plt.savefig('images/s_learner_cate_accuracy_evaluation.png', dpi=300, bbox_inches='tight')
print("CATE accuracy evaluation plot saved to images/s_learner_cate_accuracy_evaluation.png")

plt.show()

<div style="font-size: 0.85em;">

# Scatter Plot Analysis: Discrete True CATE vs. Continuous Estimated CATE

## Observation of Discrete vs. Continuous Values

The scatter plot reveals an important characteristic of our data and model:

- **True CATE Values (x-axis)**: Appear as discrete values, forming vertical clusters
- **Estimated CATE Values (y-axis)**: Appear as continuous values, spread across a range from approximately -6 to +10

This pattern occurs because:

1. **True CATE Generation**: In model2, true CATE values are generated using threshold functions:
   ```python
   true_cate = (
       4.0 * (features[:, 0] > 0.5) -  # +4.0 if X1 > 0.5
       3.0 * (features[:, 1] > 0.7) +  # -3.0 if X2 > 0.7
       5.0 * (features[:, 2] * features[:, 3] > 0.5) -  # +5.0 if X3*X4 > 0.5
       2.0 * (features[:, 4] < 0.3)    # -2.0 if X5 < 0.3
   )
   ```

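   Each component is a binary condition (0 or 1) multiplied by a coefficient. With 4 binary conditions, there are 2⁴ = 16 possible combinations, so the true CATE can only take values from a small discrete set. Several combinations add up to the same effect, so there are fewer than 16 distinct values.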

2. **Estimated CATE Generation**: The S-Learner uses RandomForestRegressor, which produces continuous predictions by averaging the outputs of many decision trees, resulting in a continuous range of estimated values.
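
The discreteness of the true effects can be checked by enumeration. The sketch below uses only the four coefficients from the model2 expression quoted above and lists the effect produced by every on/off pattern of the threshold conditions. Whether every pattern actually occurs in the generated data depends on the feature distribution, which is not examined here.

```python
from itertools import product

# Coefficients copied from the model2 expression quoted above.
coefficients = [4.0, -3.0, 5.0, -2.0]

# Every on/off pattern of the four threshold conditions and its resulting effect.
effects = [
    sum(c * flag for c, flag in zip(coefficients, flags))
    for flags in product([0, 1], repeat=len(coefficients))
]
distinct_effects = sorted(set(effects))

print(len(effects), "combinations")
print(len(distinct_effects), "distinct values:", distinct_effects)
```

The 16 combinations collapse to 13 distinct values between -5 and 9, because three pairs of combinations produce the same sum (0, 2, and 4 each arise twice). This explains the limited number of vertical clusters in the scatter plot.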

## Addressing the Key Questions

### 1. Model Accuracy Assessment

- **Correlation**: 0.9133 - Indicates a very strong positive correlation between true and estimated CATE values
- **Mean Absolute Error**: 0.8050 - Represents the average magnitude of errors
- **Visual Assessment**: Points cluster around the diagonal line but with considerable spread, suggesting moderate accuracy

The model captures the general direction of treatment effects (positive vs. negative) reasonably well, but struggles to precisely estimate the exact magnitude of effects. This is expected given the challenge of recovering discrete threshold-based effects using a continuous prediction model.

### 2. Patterns in Errors

Several systematic patterns are visible:

- **Regression to the Mean**: Extreme true CATE values tend to have less extreme predictions, a common phenomenon in predictive modeling
- **Horizontal Banding**: Predictions cluster around certain horizontal bands, suggesting the model is identifying some discrete effect levels but not capturing the full complexity
- **Shrinkage of Negative Effects**: For the most negative true CATE values, predictions tend to be less negative than the truth
- **Shrinkage of Positive Effects**: The highest true CATE values are often underestimated

These patterns suggest the model is smoothing out the sharp threshold effects present in the true data generation process.
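
One simple way to quantify this kind of shrinkage is to regress the estimates on the true effects and inspect the slope: a slope below 1 means extreme effects are being pulled toward the middle. The sketch below is a hypothetical diagnostic on made-up arrays, not something the notebook itself computes.

```python
import numpy as np


def calibration_slope(true_cate, est_cate):
    """Slope and intercept of the least-squares line est = slope * true + intercept."""
    slope, intercept = np.polyfit(
        np.asarray(true_cate, dtype=float),
        np.asarray(est_cate, dtype=float),
        deg=1,
    )
    return slope, intercept


true_demo = np.array([-5.0, -2.0, 0.0, 4.0, 9.0])  # made-up true effects
est_demo = 0.8 * true_demo + 0.3                    # estimates pulled toward the middle
slope, intercept = calibration_slope(true_demo, est_demo)
print(f"slope={slope:.2f}, intercept={intercept:.2f}")
```

With the notebook's variables, the call would be `calibration_slope(true_cate, cate)`.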

### 3. Treatment Group Comparison

Comparing the red triangles (treated group) and blue circles (control group):

- **Distribution**: Both groups appear similarly distributed across the true CATE spectrum
- **Accuracy**: There's no obvious difference in prediction accuracy between treated and control groups
- **Density**: The treated group may have slightly more points in certain regions, reflecting the non-random treatment assignment in the data generation process

The similar prediction accuracy across groups suggests that the S-Learner's errors do not depend on which treatment an individual actually received, which is reassuring given the non-random treatment assignment.

### 4. Outlier Detection

Several types of outliers are visible:

- **Far from Diagonal**: Some points deviate significantly from the perfect prediction line
- **Extreme Predictions**: A few points have estimated CATE values below -5 or above +8, which are more extreme than most predictions
- **Misclassified Direction**: Some points with positive true CATE have negative estimated CATE and vice versa, representing qualitative errors in effect direction

These outliers likely occur when:
1. Multiple feature interactions create complex patterns that the model fails to capture
2. Rare feature combinations appear in the data that weren't well-represented in the training data
3. The inherent noise in the data generation process leads to challenging cases
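
The outlier types listed above can be located programmatically instead of by eye. The helper below is a hypothetical sketch on made-up arrays. It returns the indices where estimated and true effects disagree in sign, and the indices of the largest absolute errors.

```python
import numpy as np


def flag_problem_cases(true_cate, est_cate, top_k=3):
    """Indices of sign disagreements and of the top_k largest absolute errors."""
    true_cate = np.asarray(true_cate, dtype=float)
    est_cate = np.asarray(est_cate, dtype=float)
    # A sign disagreement only counts where both values are non-zero.
    wrong_sign = np.flatnonzero(np.sign(true_cate) * np.sign(est_cate) < 0)
    abs_err = np.abs(est_cate - true_cate)
    largest = np.argsort(abs_err)[::-1][:top_k]  # largest error first
    return wrong_sign, largest


true_demo = np.array([4.0, -3.0, 1.0, 9.0, 0.0])  # made-up values
est_demo = np.array([3.5, 0.5, -0.2, 5.0, 0.4])
wrong_sign, largest = flag_problem_cases(true_demo, est_demo, top_k=2)
print("wrong sign at:", wrong_sign)
print("largest errors at:", largest)
```

The returned indices can then be used to look up the corresponding rows of `features` and check which covariate combinations are involved.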

## Conclusion

The S-Learner provides reasonable but imperfect estimates of the true CATE values. The discrete nature of the true effects presents a fundamental challenge for any continuous prediction model. While the model captures the general pattern of effects, it struggles with precise estimation of threshold-based effects. This highlights the importance of understanding the underlying data generation process when interpreting model predictions.

</div>
