<div style="font-size: 0.85em;">

# X-Learner for Conditional Average Treatment Effect (CATE) Estimation

This notebook demonstrates the use of X-Learner for estimating heterogeneous treatment effects. 

## What is X-Learner?

X-Learner is a meta-learner for estimating Conditional Average Treatment Effects (CATE); the "X" refers to the way it crosses information between the treatment and control groups. Its key characteristics are:

1. **Multi-Stage Approach**: Fits outcome models, then models of imputed individual effects, and finally combines the two effect models
2. **Separate Models**: Fits separate models for control and treatment groups
3. **Crossover Estimation**: Estimates treatment effects by crossing over between groups
4. **Flexible Base Learners**: Can use any regression model that follows scikit-learn's API
5. **Robust to Imbalance**: Performs well when treatment and control groups have different sizes

## How X-Learner Works

The X-Learner approach follows these steps:
1. Fit separate outcome models for treatment and control groups
2. Predict counterfactual outcomes for each individual
3. Compute "imputed" treatment effects for each individual
4. Train models to predict these imputed treatment effects
5. Combine the two effect models' predictions, weighted by a function g(x) (typically the estimated propensity score), to get the final CATE estimate

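The main advantage of X-Learner is its robustness to imbalanced treatment and control groups: the outcome model fitted on the larger group is used to impute effects for the smaller group, and (with propensity weighting) the final step puts more weight on the effect estimates built from that better-estimated outcome model. In such settings it often outperforms the S-Learner and T-Learner and gives more stable estimates.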
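
The five steps above can be sketched in a few lines of scikit-learn. This is a minimal, textbook-style illustration, not the `metalearners.x_learner.XLearner` implementation used later; the helper name `x_learner_sketch` is hypothetical, and it assumes the final weighting uses a logistic-regression propensity score.

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LinearRegression, LogisticRegression


def x_learner_sketch(X, y, t, base_model, X_new):
    """Hypothetical helper: textbook X-Learner with propensity-score weighting."""
    X, y, t = np.asarray(X), np.asarray(y), np.asarray(t)
    X0, y0 = X[t == 0], y[t == 0]
    X1, y1 = X[t == 1], y[t == 1]

    # Step 1: separate outcome models for the control and treated groups
    mu0 = clone(base_model).fit(X0, y0)
    mu1 = clone(base_model).fit(X1, y1)

    # Steps 2-3: impute individual effects by crossing each group over to the other model
    d1 = y1 - mu0.predict(X1)  # treated: observed outcome minus predicted control outcome
    d0 = mu1.predict(X0) - y0  # control: predicted treated outcome minus observed outcome

    # Step 4: regress the imputed effects on the covariates
    tau1 = clone(base_model).fit(X1, d1)
    tau0 = clone(base_model).fit(X0, d0)

    # Step 5: weight the two effect models by the estimated propensity g(x)
    g = LogisticRegression().fit(X, t).predict_proba(X_new)[:, 1]
    return g * tau0.predict(X_new) + (1 - g) * tau1.predict(X_new)


# Usage on imbalanced toy data (about 20% treated) with a constant true effect of 2.0
rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 3))
t = rng.binomial(1, 0.2, size=2000)
y = X @ np.array([1.0, -0.5, 0.2]) + 2.0 * t + rng.normal(scale=0.1, size=2000)
cate_hat = x_learner_sketch(X, y, t, LinearRegression(), X)
print(f"Mean estimated effect: {cate_hat.mean():.3f}")
```

With few treated units, g(x) is small, so most of the weight goes to `tau1`, whose imputed effects were built from the control-group model fitted on many units.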
</div>


<div style="font-size: 0.85em;">

# Library Imports

We import the necessary libraries for this demonstration:

- **synthetic_data_for_cate** (from `synthetic_data.synthetic_data_for_cate_class`): Custom class for generating synthetic data with enhanced heterogeneity for treatment effects
- **XLearner**: Our implementation of the X-Learner meta-learner
- **sklearn models**: Various regression models to use as base learners
- **matplotlib/seaborn**: For visualization
- **numpy**: For numerical operations

We also attempt to import optional dependencies (LightGBM and XGBoost) which provide additional base learners.
</div>


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from synthetic_data.synthetic_data_for_cate_class import synthetic_data_for_cate
from metalearners.x_learner import XLearner
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression

# Set plot style
try:
    # For newer matplotlib versions (>=3.6)
    plt.style.use('seaborn-v0_8-whitegrid')
except OSError:
    # Older matplotlib versions raise OSError for unknown style names
    plt.style.use('seaborn-whitegrid')

sns.set_context("notebook", font_scale=1.2)

# Try to import optional dependencies
try:
    from lightgbm import LGBMRegressor
    LIGHTGBM_AVAILABLE = True
except ImportError:
    LIGHTGBM_AVAILABLE = False
    print("LightGBM not available. Some examples will be skipped.")

try:
    from xgboost import XGBRegressor
    XGBOOST_AVAILABLE = True
except ImportError:
    XGBOOST_AVAILABLE = False
    print("XGBoost not available. Some examples will be skipped.")


<div style="font-size: 0.85em;">

# Note on Dependencies

This notebook requires the following packages:
- numpy, pandas, matplotlib, seaborn: For data manipulation and visualization
- scikit-learn: For machine learning models
- lightgbm, xgboost (optional): For gradient boosting models

The notebook also uses local packages:
- synthetic_data.synthetic_data_for_cate_class: For generating synthetic data with enhanced heterogeneity
- metalearners.x_learner: For the XLearner implementation

Make sure you have installed all required dependencies before running this notebook.
</div>


<div style="font-size: 0.85em;">

# Synthetic Data Generation

We generate synthetic data with known heterogeneous treatment effects using the `synthetic_data_for_cate` class (with `model_type='model2'`) and its `get_synthetic_data()` method. This creates:

- A feature matrix with 5 covariates (by default)
- A binary treatment indicator (1=treated, 0=control)
- An outcome variable that depends on both covariates and treatment

The data generation process includes:
- Non-linear confounding (treatment assignment depends on X1 and X2)
- Heterogeneous treatment effects (effects vary based on all covariates)
- Non-linear baseline effects (outcome depends non-linearly on covariates)
- Heteroskedastic noise (noise level varies with X1)

This synthetic data allows us to evaluate the performance of different CATE estimation methods, as we know the true treatment effects.
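
To make these ingredients concrete, here is a hypothetical generator with the same four properties. It is not the `synthetic_data_for_cate` class used in this notebook: the function name `make_cate_data` and every functional form below are assumptions chosen for illustration.

```python
import numpy as np


def make_cate_data(n=2000, p=5, seed=0):
    """Hypothetical DGP with confounding, heterogeneous effects and heteroskedastic noise."""
    if p < 4:
        raise ValueError("this sketch uses at least 4 covariates")
    rng = np.random.default_rng(seed)
    X = rng.uniform(-1.0, 1.0, size=(n, p))

    # Non-linear confounding: the propensity score depends on X1 and X2 only
    propensity = 1.0 / (1.0 + np.exp(-(np.sin(np.pi * X[:, 0]) + X[:, 1] ** 2 - 0.5)))
    t = rng.binomial(1, propensity)

    # Heterogeneous treatment effect that depends on every covariate
    tau = 1.0 + X[:, 0] - 0.5 * X[:, 1] ** 2 + 0.3 * X[:, 2:].sum(axis=1)

    # Non-linear baseline (control) outcome
    mu0 = np.cos(np.pi * X[:, 0]) + X[:, 1] * X[:, 2] + 0.5 * X[:, 3] ** 2

    # Heteroskedastic noise: the standard deviation grows with |X1|
    noise = rng.normal(loc=0.0, scale=0.5 + 0.5 * np.abs(X[:, 0]))

    y = mu0 + tau * t + noise
    return X, t, y, tau


X, t, y, true_tau = make_cate_data()
print(f"{X.shape[0]} samples, {X.shape[1]} covariates, treated share {t.mean():.2f}")
```

Because the generator returns `tau` alongside the observed data, any estimator's CATE predictions can be scored against the truth, which is exactly what the known-effects setup enables.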
</div>


In [None]:
# Create an instance of synthetic_data_for_cate with model2
data_generator = synthetic_data_for_cate(model_type='model2')

# Generate synthetic data with heterogeneous treatment effects
features, treatment_vector, outcomes = data_generator.get_synthetic_data()

# Print basic information about the generated data
print(f"Generated data with {features.shape[0]} samples and {features.shape[1]} features")
print(f"Treatment assignment rate: {treatment_vector.mean():.2f}")
print(f"Outcome mean: {outcomes.mean():.2f}")


<div style="font-size: 0.85em;">

# Model Initialization and Fitting

We initialize several X-Learner models with different base learners:

1. **Random Forest**: A non-parametric model that can capture complex non-linear relationships
2. **Linear Regression**: A simple parametric model that assumes linear relationships
3. **LightGBM** (if available): A gradient boosting framework that uses tree-based learning
4. **XGBoost** (if available): Another gradient boosting framework with different implementation

Each model has different strengths and weaknesses for CATE estimation. For this demonstration, we'll use the Random Forest-based X-Learner as our primary model.

The fitting process involves:
1. Fitting separate models for control and treatment groups
2. Predicting counterfactual outcomes for each group
3. Computing individual treatment effects
4. Training models to predict these treatment effects
</div>


<div style="font-size: 0.85em;">

# Base Learners for X-Learner: Principles, Pros, and Cons

The choice of base learner significantly impacts the performance of X-Learner for CATE estimation. Here we describe the basic ideas, advantages, and limitations of each base learner.

## RandomForestRegressor

### Basic Principles
- **Ensemble Method**: Combines multiple decision trees to improve prediction accuracy and control overfitting
- **Bootstrap Aggregating (Bagging)**: Each tree is trained on a random subset of the data with replacement
- **Feature Randomization**: At each split, only a random subset of features is considered
- **Averaging**: Final prediction is the average of predictions from all trees
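
The four ingredients above can be reproduced by hand with scikit-learn decision trees. This is an illustrative sketch (the helper `bagged_trees_predict` is hypothetical), not how `RandomForestRegressor` is implemented internally, though it follows the same recipe: bootstrap rows, a random feature subset at each split, and averaging.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor


def bagged_trees_predict(X, y, X_new, n_trees=50, max_features=0.5, seed=0):
    """Hypothetical helper: a hand-rolled random forest prediction."""
    rng = np.random.default_rng(seed)
    n = len(X)
    predictions = []
    for _ in range(n_trees):
        # Bagging: draw n rows with replacement
        rows = rng.integers(0, n, size=n)
        # Feature randomization: each split considers a random fraction of the features
        tree = DecisionTreeRegressor(
            max_features=max_features, random_state=int(rng.integers(1_000_000))
        )
        tree.fit(X[rows], y[rows])
        predictions.append(tree.predict(X_new))
    # Averaging: the forest prediction is the mean over trees
    return np.mean(predictions, axis=0)


rng = np.random.default_rng(1)
X_train = rng.uniform(-1, 1, size=(500, 3))
y_train = X_train[:, 0] ** 2 + rng.normal(scale=0.05, size=500)
X_test = rng.uniform(-1, 1, size=(200, 3))
y_test = X_test[:, 0] ** 2
y_pred = bagged_trees_predict(X_train, y_train, X_test)
print(f"Test MSE: {np.mean((y_pred - y_test) ** 2):.4f}")
```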

### Pros for CATE Estimation
- **Captures Non-linear Relationships**: Can model complex, non-linear treatment effects without explicit specification
- **Handles Interactions**: Automatically captures interactions among covariates, which is what drives treatment effect heterogeneity in the X-Learner's effect models
- **Robust to Outliers**: Less sensitive to extreme values in the data
- **No Distributional Assumptions**: Does not assume normality or homoscedasticity
- **Feature Importance**: Provides insights into which covariates drive treatment effect heterogeneity

### Cons for CATE Estimation
- **Black Box**: Less interpretable than linear models
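- **Treatment Variable Importance**: When a Random Forest is used as an S-Learner, it may not give sufficient importance to the treatment indicator; the X-Learner avoids this because the indicator is never a model input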
- **Computational Cost**: Training many trees can be computationally intensive
- **Hyperparameter Sensitivity**: Performance depends on proper tuning of hyperparameters
- **Extrapolation Limitations**: May not extrapolate well to regions with sparse data

## LGBMRegressor (LightGBM)

### Basic Principles
- **Gradient Boosting Framework**: Builds trees sequentially, with each tree correcting errors of previous trees
- **Leaf-wise Growth**: Grows trees leaf-wise (best-first) rather than level-wise
- **Histogram-based Learning**: Buckets continuous features into discrete bins to speed up training
- **Gradient-based One-Side Sampling (GOSS)**: Optionally focuses on instances with larger gradients, subsampling those with small gradients
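
The histogram idea can be illustrated in NumPy: bin a feature once by quantiles, accumulate gradient sums per bin, and scan the bin boundaries for the best split. This is a simplified sketch of the technique for squared-error loss (unit Hessians, no regularization), not LightGBM's code; the helper names `quantile_bins` and `best_histogram_split` are hypothetical.

```python
import numpy as np


def quantile_bins(x, max_bins=255):
    """Map a continuous feature to integer bin ids using quantile edges."""
    inner = np.quantile(x, np.linspace(0.0, 1.0, max_bins + 1)[1:-1])
    edges = np.unique(inner)
    # Bin i holds values in [edges[i-1], edges[i])
    return np.searchsorted(edges, x, side="right"), edges


def best_histogram_split(bins, grad, n_bins):
    """Return k such that 'bin <= k goes left' maximizes the squared-loss gain."""
    g_sum = np.bincount(bins, weights=grad, minlength=n_bins)
    count = np.bincount(bins, minlength=n_bins)
    # One cumulative pass over bins replaces a pass over every sorted sample
    g_left, n_left = np.cumsum(g_sum)[:-1], np.cumsum(count)[:-1]
    g_right, n_right = g_sum.sum() - g_left, count.sum() - n_left
    valid = (n_left > 0) & (n_right > 0)
    gain = np.full(n_bins - 1, -np.inf)
    gain[valid] = g_left[valid] ** 2 / n_left[valid] + g_right[valid] ** 2 / n_right[valid]
    return int(np.argmax(gain))


rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, size=5000)
y = (x > 0.3).astype(float) + rng.normal(scale=0.1, size=5000)
bins, edges = quantile_bins(x, max_bins=64)
grad = y.mean() - y  # gradient of squared loss at a constant prediction
k = best_histogram_split(bins, grad, n_bins=len(edges) + 1)
print(f"Chosen threshold: x < {edges[k]:.3f} goes left")
```

Split search now costs one pass over at most `max_bins` bins per feature instead of one pass over every sorted sample, which is where the speed-up comes from.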

### Pros for CATE Estimation
- **High Performance**: Often achieves state-of-the-art prediction accuracy
- **Efficiency**: Faster training speed and lower memory usage than traditional gradient boosting
- **Handles Large Datasets**: Scales well to datasets with many observations
- **Regularization Options**: Built-in L1/L2 regularization to prevent overfitting
- **Categorical Feature Support**: Native handling of categorical variables

### Cons for CATE Estimation
- **Complex Tuning**: Requires careful hyperparameter tuning for optimal performance
- **Overfitting Risk**: Can overfit on small datasets without proper regularization
- **Less Robust to Noisy Data**: May be more sensitive to noise than Random Forests
- **Treatment Importance**: Like Random Forests, may underweight the treatment indicator when used as an S-Learner (not an issue in the X-Learner, which fits separate models per group)
- **Installation Challenges**: Requires additional system dependencies

## XGBRegressor (XGBoost)

### Basic Principles
- **Extreme Gradient Boosting**: Enhanced implementation of gradient boosting
- **Regularized Learning**: Includes L1 and L2 regularization terms in the objective function
- **Approximate Greedy Algorithm**: Uses a quantile sketch algorithm to find approximate best splits
- **Sparsity-Aware Split Finding**: Efficiently handles missing values and sparse data
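
Under a second-order Taylor expansion of the loss, the regularized objective has a closed form per leaf. Following the original XGBoost formulation, let $g_i$ and $h_i$ be the first and second derivatives of the loss for sample $i$, $G_j = \sum_{i \in I_j} g_i$ and $H_j = \sum_{i \in I_j} h_i$ their sums over the samples in leaf $j$, $T$ the number of leaves, $\lambda$ the L2 penalty (`reg_lambda`) and $\gamma$ the per-leaf penalty (`gamma`):

$$
\mathcal{L} = \sum_i \ell(y_i, \hat{y}_i) + \gamma T + \tfrac{1}{2}\lambda \sum_{j=1}^{T} w_j^2,
\qquad
w_j^{*} = -\frac{G_j}{H_j + \lambda}
$$

$$
\text{Gain} = \frac{1}{2}\left[\frac{G_L^2}{H_L + \lambda} + \frac{G_R^2}{H_R + \lambda} - \frac{(G_L + G_R)^2}{H_L + H_R + \lambda}\right] - \gamma
$$

Larger $\lambda$ shrinks leaf weights toward zero, and a split is kept only if its gain is positive, i.e. the loss reduction exceeds $\gamma$. The L1 penalty (`reg_alpha`) soft-thresholds $G_j$ and is omitted here for brevity.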

### Pros for CATE Estimation
- **Prediction Accuracy**: Consistently strong performance across many prediction tasks
- **Regularization**: Built-in mechanisms to prevent overfitting
- **Handles Missing Values**: Native support for missing data
- **Parallel Processing**: Efficient computation using multi-threading
- **Cross-validation**: Built-in cross-validation capabilities

### Cons for CATE Estimation
- **Complexity**: Many hyperparameters to tune
- **Black Box Nature**: Limited interpretability of the model
- **Memory Usage**: Can be memory-intensive for large datasets
- **Treatment Variable Importance**: May not prioritize the treatment indicator appropriately when used as an S-Learner; the X-Learner sidesteps this by fitting separate models per group
- **Installation Requirements**: Depends on additional libraries

## LinearRegression

### Basic Principles
- **Linear Function**: Models outcome as a weighted sum of input features
- **Ordinary Least Squares (OLS)**: Minimizes the sum of squared differences between observed and predicted values
- **Closed-form Solution**: Parameters can be calculated directly without iterative optimization
- **Additive Effects**: Assumes features contribute independently to the outcome
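
The closed-form OLS solution, $\hat{\beta} = (X^\top X)^{-1} X^\top y$ with an intercept column in $X$, can be checked against scikit-learn in a few lines. The sketch below solves the least-squares problem with `np.linalg.lstsq` (numerically safer than forming the inverse explicitly) and compares it with `LinearRegression`; the data are synthetic.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 3))
y = 1.5 + X @ np.array([2.0, -1.0, 0.5]) + rng.normal(scale=0.2, size=300)

# Closed-form OLS: prepend a column of ones for the intercept and solve min ||A b - y||^2
A = np.column_stack([np.ones(len(X)), X])
beta_closed_form, *_ = np.linalg.lstsq(A, y, rcond=None)

# scikit-learn fits the same model (intercept stored separately from coef_)
model = LinearRegression().fit(X, y)
beta_sklearn = np.concatenate([[model.intercept_], model.coef_])

print(np.round(beta_closed_form, 3))
print(np.allclose(beta_closed_form, beta_sklearn))
```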

### Pros for CATE Estimation
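- **Interpretability**: Clear interpretation of coefficients; in the X-Learner's second stage, each coefficient shows how the estimated treatment effect changes with a covariate
- **Treatment Interaction Modeling**: A linear effect model corresponds directly to explicit treatment-by-covariate interaction terms
- **Computational Efficiency**: Fast to train and make predictions
- **Stability**: Results are stable and reproducible
- **Statistical Inference**: OLS fitted with a statistics package (e.g. statsmodels) provides p-values and confidence intervals; scikit-learn's `LinearRegression` does not, and such intervals ignore the uncertainty from the X-Learner's first-stage imputation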

### Cons for CATE Estimation
- **Linearity Assumption**: Cannot capture non-linear treatment effects without manual feature engineering
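- **Homoscedasticity Assumption**: Standard OLS inference assumes constant error variance across observations, which the heteroskedastic noise in this notebook's synthetic data violates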
- **Sensitivity to Outliers**: Outliers can significantly impact coefficient estimates
- **Limited Flexibility**: May underfit complex relationships in the data
- **Collinearity Issues**: Performance degrades with highly correlated features

## Choosing the Right Base Learner

The optimal base learner depends on your specific use case:

- **RandomForestRegressor**: Good default choice that balances flexibility and robustness
- **LGBMRegressor/XGBRegressor**: Best for large datasets where prediction accuracy is paramount
- **LinearRegression**: Ideal when interpretability is critical or when the treatment effect is known to be linear

For complex, heterogeneous treatment effects, tree-based methods (Random Forest, LightGBM, XGBoost) typically outperform linear models. However, linear models are easier to interpret: the coefficients of the second-stage effect models show directly how the estimated treatment effect varies with each covariate.
</div>


In [None]:
# Initialize X-Learner with different base models
x_learner_rf = XLearner(RandomForestRegressor(n_estimators=100, random_state=42))
x_learner_linear = XLearner(LinearRegression())

# Create LightGBM model if available
if 'LIGHTGBM_AVAILABLE' in globals() and LIGHTGBM_AVAILABLE:
    x_learner_lgbm = XLearner(LGBMRegressor(random_state=42))
    # Uncomment to use LightGBM model
    # xl = x_learner_lgbm

# Create XGBoost model if available
if 'XGBOOST_AVAILABLE' in globals() and XGBOOST_AVAILABLE:
    x_learner_xgb = XLearner(XGBRegressor(random_state=42))
    # Uncomment to use XGBoost model
    # xl = x_learner_xgb

# Select Random Forest as our primary model
xl = x_learner_rf

# Generate synthetic data with heterogeneous treatment effects if not already defined
if 'features' not in globals():
    # Create an instance of synthetic_data_for_cate with model2
    data_generator = synthetic_data_for_cate(model_type='model2')

    # Generate synthetic data
    features, treatment_vector, outcomes = data_generator.get_synthetic_data()
    print(f"Generated data with {features.shape[0]} samples and {features.shape[1]} features")
    print(f"Treatment assignment rate: {treatment_vector.mean():.2f}")
    print(f"Outcome mean: {outcomes.mean():.2f}")
else:
    # If features already exist, create a data generator to calculate true CATE
    data_generator = synthetic_data_for_cate(model_type='model2')

# Fit the model
print("Fitting X-Learner with Random Forest...")
xl.fit(
    X=features,             # Feature matrix
    y=outcomes,             # Observed outcomes
    treatment=treatment_vector      # Treatment assignments (0/1)
)
print("Model fitting complete.")


<div style="font-size: 0.85em;">

# CATE Estimation

After fitting the model, we can estimate the Conditional Average Treatment Effect (CATE) for each individual in our dataset. The CATE represents how much the treatment is expected to affect the outcome for an individual with specific characteristics.

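The CATE estimation process involves:
1. Predicting the effect with the model trained on the treated group's imputed effects, $\hat{\tau}_1(x)$
2. Predicting the effect with the model trained on the control group's imputed effects, $\hat{\tau}_0(x)$
3. Combining the two as a weighted average, $\hat{\tau}(x) = g(x)\,\hat{\tau}_0(x) + (1 - g(x))\,\hat{\tau}_1(x)$, where $g(x)$ is typically the estimated propensity score (a plain average corresponds to $g(x) = 0.5$; the exact weighting depends on the `XLearner` implementation)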

This gives us an estimate of the treatment effect for each individual, conditional on their covariates.
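
In the original X-Learner formulation, the two effect predictions are combined with a per-unit weight $g(x)$, often the estimated propensity score; a fixed $g(x) = 0.5$ gives the plain average. Whether `XLearner` here uses propensity weights is an implementation detail not shown in this notebook. A tiny numeric illustration with a hypothetical helper `combine_effects`:

```python
import numpy as np


def combine_effects(tau0_pred, tau1_pred, g):
    """Blend the two effect models: tau(x) = g(x) * tau0(x) + (1 - g(x)) * tau1(x)."""
    g = np.clip(np.asarray(g, dtype=float), 0.0, 1.0)
    return g * np.asarray(tau0_pred, dtype=float) + (1.0 - g) * np.asarray(tau1_pred, dtype=float)


tau0_pred = np.array([1.2, 0.4])  # predictions of the model fit on the control group's imputed effects
tau1_pred = np.array([0.8, 0.6])  # predictions of the model fit on the treated group's imputed effects
g = np.array([0.25, 0.5])         # e.g. estimated propensity scores

# Unit 1: 0.25 * 1.2 + 0.75 * 0.8 = 0.9; unit 2 (g = 0.5) is the plain average 0.5
print(combine_effects(tau0_pred, tau1_pred, g))
```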
</div>


In [None]:
# `features`, `treatment_vector`, `outcomes` and `data_generator` are defined in the
# data-generation cell, and `xl` is the X-Learner fitted in the previous cell

# Estimate treatment effects (CATE)
cate = xl.predict(X=features)

# Print basic statistics about the estimated CATE
print("CATE Statistics:")
print(f"  Mean: {cate.mean():.4f}")
print(f"  Std Dev: {cate.std():.4f}")
print(f"  Min: {cate.min():.4f}")
print(f"  Max: {cate.max():.4f}")


<div style="font-size: 0.85em;">

# CATE Distribution Visualization

Visualizing the distribution of estimated treatment effects helps us understand the heterogeneity in treatment effects across the population. This can reveal:

1. **Average Effect**: The center of the distribution shows the average treatment effect
2. **Effect Heterogeneity**: The spread of the distribution shows how much treatment effects vary
3. **Subgroups**: Multiple peaks might indicate distinct subgroups with different responses
4. **Negative/Positive Effects**: The proportion of individuals with negative vs. positive effects

A narrow distribution suggests homogeneous treatment effects, while a wide distribution suggests high heterogeneity. Skewness in the distribution might indicate that certain types of individuals benefit more or less from the treatment.
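
The quantities this plot conveys can also be computed directly from the array of estimates. A minimal sketch (hypothetical helper `summarize_cate`) that works on any 1-D array such as `cate` or `true_cate`:

```python
import numpy as np


def summarize_cate(effects):
    """Summary statistics describing the centre and spread of a CATE distribution."""
    effects = np.ravel(np.asarray(effects, dtype=float))
    q25, q50, q75 = np.percentile(effects, [25, 50, 75])
    centred = effects - effects.mean()
    return {
        "mean": effects.mean(),            # average effect (estimates the ATE for this sample)
        "median": q50,
        "std": effects.std(),              # overall heterogeneity
        "iqr": q75 - q25,                  # spread that is robust to outliers
        "skew": np.mean(centred ** 3) / effects.std() ** 3,  # > 0: long right tail
        "share_positive": np.mean(effects > 0),
        "share_negative": np.mean(effects < 0),
    }


example = np.array([-0.5, 0.0, 0.5, 1.0, 1.5, 4.0])
for name, value in summarize_cate(example).items():
    print(f"{name:>15}: {value:.3f}")
```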
</div>


In [None]:
# Calculate true treatment effects using the data generator from the data-generation cell
true_cate = data_generator.get_true_cate(features)

# Separate treated and control groups
treated_indices = treatment_vector == 1
control_indices = treatment_vector == 0

# Create DataFrames for each group
import pandas as pd

treated_df = pd.DataFrame({
    'Estimated CATE': cate[treated_indices],
    'True CATE': true_cate[treated_indices]
})

control_df = pd.DataFrame({
    'Estimated CATE': cate[control_indices],
    'True CATE': true_cate[control_indices]
})

# Create images directory if it doesn't exist
import os

os.makedirs('images', exist_ok=True)

# Plot for Treatment Group
plt.figure(figsize=(12, 7))

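# Plot both distributions with KDE
sns.histplot(x=treated_df['Estimated CATE'], kde=True, alpha=0.6, bins=30, color='blue', label='Estimated CATE')
sns.histplot(x=treated_df['True CATE'], kde=True, alpha=0.4, bins=30, color='orange', label='True CATE')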

# Add vertical lines for mean values
plt.axvline(x=treated_df['Estimated CATE'].mean(), color='blue', linestyle='--', linewidth=2,
            label=f'Mean Estimated CATE: {treated_df["Estimated CATE"].mean():.4f}')
plt.axvline(x=treated_df['True CATE'].mean(), color='orange', linestyle='--', linewidth=2,
            label=f'Mean True CATE: {treated_df["True CATE"].mean():.4f}')

# Add vertical line at zero (no effect)
plt.axvline(x=0, color='black', linestyle='-', linewidth=1,
            label='No Effect')

# Add annotations
plt.title('Distribution of True vs Estimated CATE - Treatment Group', fontsize=14)
plt.xlabel('Treatment Effect', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)

# Show percentage of positive and negative effects for estimated CATE
est_pos_pct = (treated_df['Estimated CATE'] > 0).mean() * 100
est_neg_pct = (treated_df['Estimated CATE'] < 0).mean() * 100
plt.annotate(f'Estimated Positive Effects: {est_pos_pct:.1f}%',
             xy=(0.68, 0.90), xycoords='axes fraction', fontsize=11)
plt.annotate(f'Estimated Negative Effects: {est_neg_pct:.1f}%',
             xy=(0.68, 0.85), xycoords='axes fraction', fontsize=11)

# Show percentage of positive and negative effects for true CATE
true_pos_pct = (treated_df['True CATE'] > 0).mean() * 100
true_neg_pct = (treated_df['True CATE'] < 0).mean() * 100
plt.annotate(f'True Positive Effects: {true_pos_pct:.1f}%',
             xy=(0.68, 0.80), xycoords='axes fraction', fontsize=11)
plt.annotate(f'True Negative Effects: {true_neg_pct:.1f}%',
             xy=(0.68, 0.75), xycoords='axes fraction', fontsize=11)

plt.tight_layout()

# Save and display plot
plt.savefig('images/x_learner_cate_distribution_treatment.png', dpi=300, bbox_inches='tight')
print("Treatment group CATE distribution plot saved to images/x_learner_cate_distribution_treatment.png")
plt.show()

# Plot for Control Group
plt.figure(figsize=(12, 7))

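# Plot both distributions with KDE
sns.histplot(x=control_df['Estimated CATE'], kde=True, alpha=0.6, bins=30, color='blue', label='Estimated CATE')
sns.histplot(x=control_df['True CATE'], kde=True, alpha=0.4, bins=30, color='orange', label='True CATE')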

# Add vertical lines for mean values
plt.axvline(x=control_df['Estimated CATE'].mean(), color='blue', linestyle='--', linewidth=2,
            label=f'Mean Estimated CATE: {control_df["Estimated CATE"].mean():.4f}')
plt.axvline(x=control_df['True CATE'].mean(), color='orange', linestyle='--', linewidth=2,
            label=f'Mean True CATE: {control_df["True CATE"].mean():.4f}')

# Add vertical line at zero (no effect)
plt.axvline(x=0, color='black', linestyle='-', linewidth=1,
            label='No Effect')

# Add annotations
plt.title('Distribution of True vs Estimated CATE - Control Group', fontsize=14)
plt.xlabel('Treatment Effect', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)

# Show percentage of positive and negative effects for estimated CATE
est_pos_pct = (control_df['Estimated CATE'] > 0).mean() * 100
est_neg_pct = (control_df['Estimated CATE'] < 0).mean() * 100
plt.annotate(f'Estimated Positive Effects: {est_pos_pct:.1f}%',
             xy=(0.68, 0.90), xycoords='axes fraction', fontsize=11)
plt.annotate(f'Estimated Negative Effects: {est_neg_pct:.1f}%',
             xy=(0.68, 0.85), xycoords='axes fraction', fontsize=11)

# Show percentage of positive and negative effects for true CATE
true_pos_pct = (control_df['True CATE'] > 0).mean() * 100
true_neg_pct = (control_df['True CATE'] < 0).mean() * 100
plt.annotate(f'True Positive Effects: {true_pos_pct:.1f}%',
             xy=(0.68, 0.80), xycoords='axes fraction', fontsize=11)
plt.annotate(f'True Negative Effects: {true_neg_pct:.1f}%',
             xy=(0.68, 0.75), xycoords='axes fraction', fontsize=11)

plt.tight_layout()

# Save and display plot
plt.savefig('images/x_learner_cate_distribution_control.png', dpi=300, bbox_inches='tight')
print("Control group CATE distribution plot saved to images/x_learner_cate_distribution_control.png")
plt.show()

# Combined plot (all units)
plt.figure(figsize=(12, 7))

# Create a DataFrame for easier plotting with seaborn
cate_df = pd.DataFrame({
    'Estimated CATE': cate,
    'True CATE': true_cate
})

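# Plot both distributions with KDE
sns.histplot(x=cate_df['Estimated CATE'], kde=True, alpha=0.6, bins=30, color='blue', label='Estimated CATE')
sns.histplot(x=cate_df['True CATE'], kde=True, alpha=0.4, bins=30, color='orange', label='True CATE')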

# Add vertical lines for mean values
plt.axvline(x=cate.mean(), color='blue', linestyle='--', linewidth=2,
            label=f'Mean Estimated CATE: {cate.mean():.4f}')
plt.axvline(x=true_cate.mean(), color='orange', linestyle='--', linewidth=2,
            label=f'Mean True CATE: {true_cate.mean():.4f}')

# Add vertical line at zero (no effect)
plt.axvline(x=0, color='black', linestyle='-', linewidth=1,
            label='No Effect')

# Add annotations
plt.title('Distribution of True vs Estimated CATE - All Groups', fontsize=14)
plt.xlabel('Treatment Effect', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)

# Show percentage of positive and negative effects for estimated CATE
est_pos_pct = (cate > 0).mean() * 100
est_neg_pct = (cate < 0).mean() * 100
plt.annotate(f'Estimated Positive Effects: {est_pos_pct:.1f}%',
             xy=(0.68, 0.90), xycoords='axes fraction', fontsize=11)
plt.annotate(f'Estimated Negative Effects: {est_neg_pct:.1f}%',
             xy=(0.68, 0.85), xycoords='axes fraction', fontsize=11)

# Show percentage of positive and negative effects for true CATE
true_pos_pct = (true_cate > 0).mean() * 100
true_neg_pct = (true_cate < 0).mean() * 100
plt.annotate(f'True Positive Effects: {true_pos_pct:.1f}%',
             xy=(0.68, 0.80), xycoords='axes fraction', fontsize=11)
plt.annotate(f'True Negative Effects: {true_neg_pct:.1f}%',
             xy=(0.68, 0.75), xycoords='axes fraction', fontsize=11)

plt.tight_layout()

# Save and display plot
plt.savefig('images/x_learner_cate_distribution.png', dpi=300, bbox_inches='tight')
print("Combined CATE distribution plot saved to images/x_learner_cate_distribution.png")
plt.show()


<div style="font-size: 0.85em;">

# Overlaid CATE Distributions: Estimated vs. True

To better compare the estimated and true CATE distributions, we can overlay them on the same plot using different colors. This visualization allows us to:

1. **Directly Compare Shapes**: See how closely the estimated distribution matches the true distribution
2. **Identify Discrepancies**: Spot areas where the model over- or under-estimates treatment effects
3. **Assess Heterogeneity Capture**: Determine if the model captures the true heterogeneity in treatment effects
4. **Evaluate Peaks and Modes**: Compare the peaks and modes of both distributions

The plot below shows both distributions with different colors, allowing for a direct visual comparison of their shapes, centers, and spreads.
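
Visual overlap can be complemented by a single-number distance between the two distributions. The sketch below (hypothetical helper `distribution_distance`) uses SciPy's `wasserstein_distance` on the raw values and a histogram-based KL divergence via `scipy.stats.entropy`. KL compares probability distributions, so both sets of effects are first binned on shared edges and lightly smoothed; the bin count and smoothing constant are arbitrary choices.

```python
import numpy as np
from scipy.stats import entropy, wasserstein_distance


def distribution_distance(estimated, true, bins=50, eps=1e-10):
    """Compare two samples of treatment effects as distributions, not unit by unit."""
    estimated, true = np.ravel(estimated), np.ravel(true)
    # Shared bin edges so both histograms describe the same support
    edges = np.histogram_bin_edges(np.concatenate([estimated, true]), bins=bins)
    p, _ = np.histogram(true, bins=edges)
    q, _ = np.histogram(estimated, bins=edges)
    # Smooth empty bins so the log ratio stays finite; entropy() renormalizes p and q
    p = p + eps
    q = q + eps
    return {
        "wasserstein": wasserstein_distance(true, estimated),  # in units of the effect
        "kl_true_vs_estimated": entropy(p, q),                 # KL(P_true || P_estimated)
    }


rng = np.random.default_rng(0)
true_effects = rng.normal(loc=1.0, scale=0.5, size=5000)
print(distribution_distance(true_effects, true_effects))        # identical samples
print(distribution_distance(true_effects + 0.3, true_effects))  # shifted estimates
```

In this notebook it could be called as `distribution_distance(cate, true_cate)`. Unlike correlation or MAE, both numbers ignore which unit received which estimate and only compare the overall shapes.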
</div>


In [None]:
# Create a plot that overlays both the estimated and true CATE distributions
plt.figure(figsize=(12, 7))

# Plot both distributions with KDE using different colors and line styles
sns.kdeplot(data=cate_df, x='Estimated CATE', color='blue', fill=False, linewidth=2.5, linestyle='-',
            label=f'Estimated CATE (Mean: {cate.mean():.4f})')
sns.kdeplot(data=cate_df, x='True CATE', color='red', fill=False, linewidth=2.5, linestyle='--',
            label=f'True CATE (Mean: {true_cate.mean():.4f})')

# Add vertical lines for mean values
plt.axvline(x=cate.mean(), color='blue', linestyle='-', linewidth=2)
plt.axvline(x=true_cate.mean(), color='red', linestyle='--', linewidth=2)

# Add vertical line at zero (no effect)
plt.axvline(x=0, color='black', linestyle='-', linewidth=1, label='No Effect')

# Add annotations
plt.title('Comparison of Estimated vs. True CATE Distributions', fontsize=14)
plt.xlabel('Treatment Effect', fontsize=12)
plt.ylabel('Density', fontsize=12)
plt.legend(fontsize=10)
plt.grid(True, alpha=0.3)

# Calculate and display statistics
est_std = cate.std()
true_std = true_cate.std()

# Calculate correlation and mean absolute error
correlation = np.corrcoef(true_cate, cate)[0, 1]
mae = np.mean(np.abs(true_cate - cate))

# Add statistics annotations
plt.annotate(f'Estimated CATE Std: {est_std:.4f}', xy=(0.05, 0.95), xycoords='axes fraction', fontsize=10)
plt.annotate(f'True CATE Std: {true_std:.4f}', xy=(0.05, 0.90), xycoords='axes fraction', fontsize=10)
plt.annotate(f'Mean Absolute Error: {mae:.4f}', xy=(0.05, 0.85), xycoords='axes fraction', fontsize=10)
plt.annotate(f'Correlation: {correlation:.4f}', xy=(0.05, 0.80), xycoords='axes fraction', fontsize=10)

plt.tight_layout()

# Save and display plot
plt.savefig('images/x_learner_cate_distributions_comparison.png', dpi=300, bbox_inches='tight')
print("CATE distributions comparison plot saved to images/x_learner_cate_distributions_comparison.png")
plt.show()


<div style="font-size: 0.85em;">

# CATE Accuracy Evaluation: Predicted vs. Actual Treatment Effects

To evaluate the accuracy of our X-Learner model, we can compare the predicted CATE values with the actual treatment effects. This comparison helps us:

1. **Assess Model Accuracy**: How well does the model recover the true treatment effects?
2. **Identify Patterns in Errors**: Are there systematic biases in the predictions?
3. **Compare Treatment Groups**: Do predictions differ in accuracy between treated and control groups?
4. **Detect Outliers**: Are there individuals for whom the model predictions are particularly inaccurate?

The scatter plot below shows:
- Predicted CATE values on the y-axis
- True treatment effects on the x-axis
- Different symbols for treated and control groups
- A diagonal line representing perfect prediction (y=x)

Points close to the diagonal line indicate accurate predictions, while deviations suggest estimation errors.
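
The metrics shown on the plot can also be broken down by treatment arm, which answers the group-comparison question numerically. A minimal sketch (hypothetical helper `cate_accuracy_by_group`): the RMSE between estimated and true effects is the square root of the PEHE metric common in the CATE literature, and the signed mean error (bias) reveals systematic over- or under-estimation.

```python
import numpy as np


def cate_accuracy_by_group(true_effects, estimated, treatment):
    """Accuracy metrics for CATE estimates, overall and per treatment arm."""
    true_effects = np.ravel(true_effects)
    estimated = np.ravel(estimated)
    treatment = np.ravel(treatment)
    groups = {
        "all": np.ones_like(treatment, dtype=bool),
        "treated": treatment == 1,
        "control": treatment == 0,
    }
    results = {}
    for name, mask in groups.items():
        err = estimated[mask] - true_effects[mask]
        results[name] = {
            "n": int(mask.sum()),
            "bias": err.mean(),                  # > 0 means effects are overestimated
            "mae": np.abs(err).mean(),
            "rmse": np.sqrt(np.mean(err ** 2)),  # square root of PEHE
            "corr": np.corrcoef(true_effects[mask], estimated[mask])[0, 1],
        }
    return results


rng = np.random.default_rng(0)
tau = rng.normal(size=1000)
t = rng.binomial(1, 0.3, size=1000)
tau_hat = tau + 0.2 + rng.normal(scale=0.1, size=1000)  # biased, noisy estimates
for group, metrics in cate_accuracy_by_group(tau, tau_hat, t).items():
    print(group, {k: round(float(v), 3) for k, v in metrics.items()})
```

In this notebook the call would be `cate_accuracy_by_group(true_cate, cate, treatment_vector)`.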
</div>


In [None]:
# We already calculated the true CATE earlier, so we'll use that

# Create a scatter plot of predicted vs. true treatment effects
plt.figure(figsize=(10, 8))

# Separate treated and control groups for different markers
treated_indices = treatment_vector == 1
control_indices = treatment_vector == 0

# Plot treated group with one marker
plt.scatter(true_cate[treated_indices], cate[treated_indices], 
            marker='^', color='red', alpha=0.6, label='Treated Group')

# Plot control group with another marker
plt.scatter(true_cate[control_indices], cate[control_indices], 
            marker='o', color='blue', alpha=0.6, label='Control Group')

# Add diagonal line (perfect prediction)
min_val = min(true_cate.min(), cate.min())
max_val = max(true_cate.max(), cate.max())
plt.plot([min_val, max_val], [min_val, max_val], 'k--', label='Perfect Prediction')

# Calculate and display correlation coefficient
# Note: correlation is already calculated in the previous cell
plt.annotate(f'Correlation: {correlation:.4f}', 
             xy=(0.05, 0.95), xycoords='axes fraction', fontsize=12)

# Calculate and display mean absolute error
# Note: mae is already calculated in the previous cell
plt.annotate(f'Mean Absolute Error: {mae:.4f}', 
             xy=(0.05, 0.90), xycoords='axes fraction', fontsize=12)

# Add labels and title
plt.title('Predicted vs. Actual Treatment Effects', fontsize=14)
plt.xlabel('Actual Treatment Effect', fontsize=12)
plt.ylabel('Predicted Treatment Effect (CATE)', fontsize=12)
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)

plt.tight_layout()

# Save the plot as a PNG file
plt.savefig('images/x_learner_cate_accuracy_evaluation.png', dpi=300, bbox_inches='tight')
print("CATE accuracy evaluation plot saved to images/x_learner_cate_accuracy_evaluation.png")

plt.show()


<div style="font-size: 0.85em;">

# Comparison of X-Learner Results with S-Learner and T-Learner

Examining the CATE distribution and accuracy evaluation results from the X-Learner, we can observe several similarities with the S-Learner and T-Learner approaches:

## CATE Distribution Similarities

1. **Distribution Shape**: The X-Learner produces a CATE distribution with a similar bell-shaped curve to both S-Learner and T-Learner, indicating comparable heterogeneity capture across all three methods.

2. **Mean CATE Values**: All three learners estimate mean CATE values that are close to the true mean, with similar slight deviations from the ground truth.

3. **Positive/Negative Effect Proportions**: The proportion of positive vs. negative treatment effects is consistent across all three learners, suggesting they identify similar subpopulations that benefit or don't benefit from treatment.

4. **Spread of Distribution**: The standard deviation of estimated CATE values is comparable across all three approaches, indicating similar levels of captured effect heterogeneity.

## CATE Accuracy Evaluation Similarities

1. **Correlation with True Effects**: The X-Learner achieves a correlation coefficient with true effects that is in the same range as the S-Learner and T-Learner, suggesting similar predictive performance.

2. **Mean Absolute Error**: The MAE values are comparable across all three methods, indicating similar levels of estimation accuracy.

3. **Prediction Patterns**: All three learners show similar patterns in the scatter plots of predicted vs. actual treatment effects, with comparable dispersion around the perfect prediction line.

4. **Treatment/Control Group Differences**: The differences in prediction accuracy between treated and control groups follow similar patterns across all three methods.

These similarities suggest that for this particular dataset and problem setting, the choice between X-Learner, S-Learner, and T-Learner may not dramatically affect the quality of CATE estimation. The X-Learner may still be preferable when treatment groups are imbalanced, because it imputes effects for the smaller group using the outcome model fitted on the larger group.

</div>
