<div style="font-size: 0.85em;">

# Synthetic Data Generation for Conditional Average Treatment Effect (CATE) Estimation

This notebook demonstrates the generation and analysis of synthetic data specifically designed for Conditional Average Treatment Effect (CATE) estimation. CATE represents how treatment effects vary across different subgroups or individuals based on their characteristics.
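In standard potential-outcomes notation (the symbols $Y(1)$, $Y(0)$, and $\tau$ are introduced here for illustration and are not used elsewhere in this notebook), the CATE at a covariate profile $x$ and the overall average treatment effect (ATE) can be written as:

$$\tau(x) = \mathbb{E}\left[\,Y(1) - Y(0) \mid X = x\,\right], \qquad \text{ATE} = \mathbb{E}\left[\,\tau(X)\,\right]$$

Here $Y(1)$ and $Y(0)$ denote the outcomes a unit would have with and without treatment. Only one of the two is observed for each unit, which is why synthetic data with a known $\tau(x)$ is useful for evaluation.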

The synthetic data generated in this notebook exhibits several key properties:

1. **Heterogeneous Treatment Effects**: Effects vary based on covariates
2. **Non-linear Interactions**: Complex relationships between treatment and covariates
3. **Variable Effect Sizes**: Different subgroups experience different magnitudes of effects
4. **Complex Confounding**: Treatment assignment depends non-linearly on covariates

This notebook serves as a practical example of how to generate, visualize, and analyze synthetic data for causal inference research, particularly for evaluating CATE estimation methods.
</div>


<div style="font-size: 0.85em;">

# Library Imports

- **numpy**: For numerical operations and array manipulation
- **pandas**: For data manipulation and analysis using DataFrames
- **matplotlib.pyplot**: For creating visualizations
- **seaborn**: For statistical data visualization
- **synthetic_data_for_cate**: Custom class for generating synthetic data with enhanced heterogeneity for treatment effects
</div>


In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from synthetic_data.synthetic_data_for_cate_class import synthetic_data_for_cate

<div style="font-size: 0.85em;">

# Synthetic Data Generation

In this section, we generate synthetic data using the `synthetic_data_for_cate` class with `model_type='model2'`. This creates:

- A feature matrix `X` with 5 covariates (by default)
- A binary treatment indicator `treatment` (1=treated, 0=control)
- An outcome variable `y` that depends on both covariates and treatment

The data generation process includes:
- Non-linear confounding (treatment assignment depends on X1 and X2)
- Heterogeneous treatment effects (effects vary based on all covariates)
- Non-linear baseline effects (outcome depends non-linearly on covariates)
- Heteroskedastic noise (noise level varies with X1)

After generating the data, we organize it into a pandas DataFrame for easier analysis.
</div>
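To make the four ingredients listed above concrete, here is a minimal, self-contained sketch of a data-generating process with the same qualitative properties. It is not the implementation of `synthetic_data_for_cate`; the function name `toy_cate_data`, the uniform covariates, and every coefficient are assumptions chosen only for illustration.

```python
import numpy as np

def toy_cate_data(n=1000, seed=0):
    """Hypothetical DGP for illustration; NOT the synthetic_data_for_cate code."""
    rng = np.random.default_rng(seed)

    # Assumption: five independent covariates, uniform on [0, 1]
    X = rng.uniform(0.0, 1.0, size=(n, 5))
    x1, x2, x3, x4, x5 = X.T

    # Non-linear confounding: treatment probability depends on X1 and X2
    logit = 2.0 * np.sin(np.pi * x1) + 2.0 * x2**2 - 1.5
    propensity = 1.0 / (1.0 + np.exp(-logit))
    treatment = rng.binomial(1, propensity)

    # Heterogeneous treatment effect (illustrative functional form)
    tau = 2.0 * (x1 > 0.5) + x3 * x4 - x5

    # Non-linear baseline and heteroskedastic noise (scale grows with X1)
    baseline = np.sin(2.0 * np.pi * x1) + x2**2 + np.exp(x3)
    noise = rng.normal(0.0, 0.5 + x1)

    y = baseline + tau * treatment + noise
    return X, treatment, y, tau

X_toy, t_toy, y_toy, tau_toy = toy_cate_data(n=500)
print(X_toy.shape, t_toy.mean(), tau_toy.mean())
```

Returning `tau` alongside the data is the design choice that matters: it gives a ground truth against which estimated effects can later be compared.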


In [None]:

# Create an instance of synthetic_data_for_cate with model2
data_generator = synthetic_data_for_cate(model_type='model2')

# Generate synthetic data with heterogeneous treatment effects
X, treatment, y = data_generator.get_synthetic_data()

# Create DataFrame with all variables
data = pd.DataFrame({
    'treatment': treatment,
    'y': y,
    'X1': X[:, 0],
    'X2': X[:, 1],
    'X3': X[:, 2],
    'X4': X[:, 3],
    'X5': X[:, 4]
})


<div style="font-size: 0.85em;">

# Data Preview

Here we display the first few rows of our synthetic dataset to examine its structure. The DataFrame contains:

- `treatment`: Binary indicator (1=treated, 0=control)
- `y`: Outcome variable
- `X1` to `X5`: Covariates; `X1` and `X2` drive treatment assignment, and all five enter the outcome and the treatment effect

This preview helps us understand the data structure before proceeding with further analysis.
</div>


In [None]:
print("Data Preview:")
print(data.head())


<div style="font-size: 0.85em;">

# Required Package Installation

Before proceeding with visualization, we need to install several packages:

- **networkx**: For creating and manipulating complex networks/graphs
- **graphviz**: For graph visualization (Python binding for Graphviz)
- **seaborn**: For statistical data visualization

Note that Graphviz requires both the Python package and the system-level software.
</div>


In [None]:
# Note: The 'graphviz' package is the Python binding for Graphviz
# You may need to install the system-level Graphviz software separately:
# - On Ubuntu/Debian: sudo apt-get install graphviz
# - On macOS: brew install graphviz
# - On Windows: download and install from https://graphviz.org/download/
# %pip installs into the environment of the running kernel
%pip install networkx graphviz seaborn

In [None]:
import networkx as nx
import graphviz


<div style="font-size: 0.85em;">

# Causal Graph Visualization

Here we create a directed graph to visualize the causal relationships in our synthetic data:

- **Nodes**:
  - `X`: Covariates (X1-X5) - represented as a box
  - `T`: Treatment variable - represented as a circle
  - `Y`: Outcome variable - represented as a circle

- **Edges**:
  - `X → T`: Covariates influence treatment assignment (confounding)
  - `X → Y`: Covariates directly affect the outcome
  - `T → Y`: Treatment affects the outcome

This graph illustrates the causal structure that was used to generate our synthetic data. Understanding this structure is crucial for properly estimating conditional average treatment effects.
</div>
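The same three-node structure can also be encoded with `networkx`, which makes it possible to query the graph programmatically. The sketch below is an optional illustration and assumes the simplified graph described above, with all covariates collapsed into a single node `X`. It checks that the graph is acyclic and lists the common parents of treatment and outcome, which are the confounders in this structure.

```python
import networkx as nx

# Assumption: the simplified graph from the text, with X1-X5 merged into one node
dag = nx.DiGraph([("X", "T"), ("X", "Y"), ("T", "Y")])

# A causal graph must not contain directed cycles
is_dag = nx.is_directed_acyclic_graph(dag)

# Nodes that point into both T and Y confound the T -> Y relationship
confounders = set(dag.predecessors("T")) & set(dag.predecessors("Y"))

# A valid causal ordering of the variables
order = list(nx.topological_sort(dag))

print(is_dag, confounders, order)
```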


In [None]:
# Create and visualize causal graph
G = graphviz.Digraph()
G.attr(rankdir='LR')

# Add nodes
G.node('X', 'Covariates\n(X1-X5)', shape='box')
G.node('T', 'Treatment', shape='circle')
G.node('Y', 'Outcome', shape='circle')

# Add edges
G.edge('X', 'T')
G.edge('X', 'Y')
G.edge('T', 'Y')

# Save the graph as a PNG file
G.format = 'png'
G.render('images/causal_graph', cleanup=True)
print("Causal graph saved to images/causal_graph.png")

# Display the graph
G

<div style="font-size: 0.85em;">

# Correlation Analysis

In this section, we analyze the correlations between all variables in our dataset:

- We calculate the Pearson correlation coefficients between all pairs of variables
- We visualize these correlations using a heatmap where:
  - Red indicates positive correlation
  - Blue indicates negative correlation
  - The intensity of color represents correlation strength
  - Numerical values show the exact correlation coefficients

This analysis helps us understand:
- How strongly covariates are related to treatment assignment
- How strongly covariates are related to the outcome
- How treatment is related to the outcome (before controlling for confounding)
- Potential multicollinearity among covariates

Note that correlation does not imply causation. The causal graph above shows the true causal relationships, while this correlation analysis shows only the statistical associations.
</div>
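One caveat when reading a Pearson heatmap is that the coefficient measures linear association only. The toy example below uses made-up arrays, not the notebook's data, and shows a variable that is fully determined by another yet has a Pearson correlation of essentially zero.

```python
import numpy as np

# Symmetric grid around zero (illustrative values only)
x = np.linspace(-1.0, 1.0, 201)

y_linear = 3.0 * x + 1.0   # exact linear dependence
y_square = x**2            # exact non-linear dependence

r_linear = np.corrcoef(x, y_linear)[0, 1]
r_square = np.corrcoef(x, y_square)[0, 1]

print(f"linear:    r = {r_linear:.3f}")
print(f"quadratic: r = {r_square:.3f}")
```

A near-zero cell in the heatmap therefore rules out a linear relationship, not a relationship in general.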


In [None]:
import os

# Create the output directory if it doesn't exist, so the savefig call below cannot fail
os.makedirs('images', exist_ok=True)

# Analyze correlations between variables
plt.figure(figsize=(10, 6))
correlation_matrix = data[['treatment', 'y', 'X1', 'X2', 'X3', 'X4', 'X5']].corr()
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0)
plt.title('Correlation Heatmap')
plt.tight_layout()

# Save the heatmap as a PNG file
plt.savefig('images/correlation_heatmap.png', dpi=300, bbox_inches='tight')
print("Correlation heatmap saved to images/correlation_heatmap.png")

# Display the plot
plt.show()


<div style="font-size: 0.85em;">

# Correlation Heatmap Analysis

Based on the correlation heatmap, we can make several important observations:

1. **Covariates and Treatment Assignment**:
   - X1 and X2 show moderate positive correlations with treatment (approximately 0.4-0.5), consistent with their role in the treatment assignment mechanism
   - X3, X4, and X5 have weak correlations with treatment, suggesting minimal influence on treatment assignment
   - This pattern is consistent with confounding, since X1 and X2 are associated with both treatment assignment and outcome

2. **Covariates and Outcome**:
   - X1 shows a moderate positive correlation with the outcome
   - X2 shows a weak negative correlation with the outcome
   - X3 and X4 show weak to moderate positive correlations with the outcome
   - These associations are consistent with covariates influencing the outcome directly, although a correlation alone cannot separate a direct effect from one that runs through treatment

3. **Treatment and Outcome Relationship**:
   - The treatment shows a positive correlation with the outcome (before controlling for confounding)
   - This raw correlation is confounded by X1 and X2, which influence both treatment and outcome
   - The observed correlation does not represent the true causal effect, highlighting the importance of proper causal inference methods

4. **Multicollinearity Among Covariates**:
   - Most covariates show weak correlations with each other, indicating limited multicollinearity
   - This is beneficial for modeling as it reduces estimation problems associated with highly correlated predictors
   - However, even with low correlation, there may still be non-linear relationships between covariates that aren't captured by the Pearson correlation coefficient

These observations highlight the complexity of the causal structure in this dataset and the importance of methods that can properly account for confounding when estimating treatment effects.
</div>
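The gap between a raw treatment-outcome association and the causal effect can be reproduced in a few lines. The sketch below uses a hypothetical one-covariate data-generating process, not the notebook's `model2`, with a true effect of 1.0 for every unit. The naive difference in means is biased upward because treated units tend to have larger `x`, while a regression that adjusts for `x` recovers the effect. Adjustment works this simply here only because the toy outcome is linear in `x`.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000

# Hypothetical confounded data: x raises both treatment probability and outcome
x = rng.uniform(0.0, 1.0, size=n)
t = rng.binomial(1, 0.2 + 0.6 * x)
y = 2.0 * x + 1.0 * t + rng.normal(0.0, 1.0, size=n)  # true effect = 1.0

# Naive estimate: difference in mean outcomes between treated and control
naive = y[t == 1].mean() - y[t == 0].mean()

# Adjusted estimate: OLS coefficient on t when x is included as a regressor
design = np.column_stack([np.ones(n), t, x])
coef, *_ = np.linalg.lstsq(design, y, rcond=None)
adjusted = coef[1]

print(f"naive:    {naive:.3f}")
print(f"adjusted: {adjusted:.3f}")
```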


<div style="font-size: 0.85em;">

# True CATE Computation

In this section, we compute the true Conditional Average Treatment Effect (CATE) for our synthetic data. The true CATE is the expected causal effect of the treatment on the outcome for individuals who share a given covariate profile. Because the data is synthetic, this quantity is known rather than estimated.

For our synthetic data, the true CATE is calculated using the `get_true_cate` method of the `synthetic_data_for_cate` class. For model2, the true CATE is defined by the formula:

$$\text{CATE}(X) = 4.0 \times (X_1 > 0.5) - 3.0 \times (X_2 > 0.7) + 5.0 \times (X_3 \times X_4 > 0.5) - 2.0 \times (X_5 < 0.3)$$

This formula creates heterogeneous treatment effects with the following components:
- Positive effect of 4.0 if X1 > 0.5
- Negative effect of 3.0 if X2 > 0.7
- Positive effect of 5.0 if X3*X4 > 0.5
- Negative effect of 2.0 if X5 < 0.3

By computing these true CATE values, we can:
1. Understand the actual treatment effect heterogeneity in our data
2. Use them as ground truth when evaluating CATE estimation methods
3. Analyze the distribution of treatment effects across the population
</div>
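For checking individual rows by hand, the formula above can be transcribed directly into NumPy. This is a re-implementation of the stated formula under the hypothetical name `model2_cate_formula`; the internals of `get_true_cate` are not shown in this notebook, so the two are assumed, not guaranteed, to agree.

```python
import numpy as np

def model2_cate_formula(X):
    """Transcription of the model2 CATE formula stated in the text."""
    X = np.asarray(X, dtype=float)
    return (
        4.0 * (X[:, 0] > 0.5)              # +4 if X1 > 0.5
        - 3.0 * (X[:, 1] > 0.7)            # -3 if X2 > 0.7
        + 5.0 * (X[:, 2] * X[:, 3] > 0.5)  # +5 if X3 * X4 > 0.5
        - 2.0 * (X[:, 4] < 0.3)            # -2 if X5 < 0.3
    )

# Two hand-picked rows: all four terms active, then none active
example = np.array([
    [0.8, 0.9, 0.9, 0.9, 0.1],   # 4 - 3 + 5 - 2 = 4
    [0.2, 0.1, 0.3, 0.3, 0.9],   # 0
])
example_cate = model2_cate_formula(example)
print(example_cate)  # [4. 0.]
```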


In [None]:
# Compute true CATE values using the data generator
true_cate = data_generator.get_true_cate(X)

# Add true CATE to the DataFrame
data['true_cate'] = true_cate

# Print basic statistics about the true CATE
print("True CATE Statistics:")
print(f"  Mean: {true_cate.mean():.4f}")
print(f"  Std Dev: {true_cate.std():.4f}")
print(f"  Min: {true_cate.min():.4f}")
print(f"  Max: {true_cate.max():.4f}")

# Count the number of positive and negative treatment effects
positive_effects = (true_cate > 0).sum()
negative_effects = (true_cate < 0).sum()
neutral_effects = (true_cate == 0).sum()

print("\nTreatment Effect Distribution:")
print(f"  Positive Effects: {positive_effects} ({positive_effects/len(true_cate)*100:.1f}%)")
print(f"  Negative Effects: {negative_effects} ({negative_effects/len(true_cate)*100:.1f}%)")
print(f"  Neutral Effects: {neutral_effects} ({neutral_effects/len(true_cate)*100:.1f}%)")


<div style="font-size: 0.85em;">

# True CATE Distribution Visualization

Visualizing the distribution of true treatment effects helps us understand the heterogeneity in treatment effects across the population. This can reveal:

1. **Average Effect**: The center of the distribution shows the average treatment effect
2. **Effect Heterogeneity**: The spread of the distribution shows how much treatment effects vary
3. **Subgroups**: Multiple peaks might indicate distinct subgroups with different responses
4. **Negative/Positive Effects**: The proportion of individuals with negative vs. positive effects

A narrow distribution suggests homogeneous treatment effects, while a wide distribution suggests high heterogeneity. Skewness in the distribution might indicate that certain types of individuals benefit more or less from the treatment.
</div>


In [None]:
# Create images directory if it doesn't exist
import os
os.makedirs('images', exist_ok=True)

# Plot the distribution of true CATE values
plt.figure(figsize=(12, 7))

# Plot the distribution with KDE
sns.histplot(data=true_cate, kde=True, alpha=0.6, bins=30)

# Add vertical lines for mean and zero
plt.axvline(x=true_cate.mean(), color='red', linestyle='--', linewidth=2,
            label=f'Mean CATE: {true_cate.mean():.4f}')
plt.axvline(x=0, color='black', linestyle='-', linewidth=1,
            label='No Effect')

# Add annotations
plt.title('Distribution of True Conditional Average Treatment Effects', fontsize=14)
plt.xlabel('Treatment Effect', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)

# Show percentage of positive and negative effects
pos_pct = (true_cate > 0).mean() * 100
neg_pct = (true_cate < 0).mean() * 100
plt.annotate(f'Positive Effects: {pos_pct:.1f}%',
             xy=(0.68, 0.90), xycoords='axes fraction', fontsize=11)
plt.annotate(f'Negative Effects: {neg_pct:.1f}%',
             xy=(0.68, 0.85), xycoords='axes fraction', fontsize=11)

plt.tight_layout()

# Save and display plot
plt.savefig('images/true_cate_distribution.png', dpi=300, bbox_inches='tight')
print("True CATE distribution plot saved to images/true_cate_distribution.png")
plt.show()


<div style="font-size: 0.85em;">

# True CATE Distribution Analysis

From the true CATE distribution plot, we can draw several important conclusions:

1. **Heterogeneous Treatment Effects**: The distribution shows significant spread, indicating that treatment effects vary substantially across individuals. This confirms that the treatment does not affect all individuals equally.

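2. **Multi-modal Distribution**: The distribution has multiple peaks, marking distinct subgroups with different treatment responses. This is a direct result of the threshold-based effects in our data generation process: the CATE formula is a sum of four on/off terms, so it can only take a small set of discrete values, and the smooth KDE curve between the bars is a plotting artifact.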

3. **Positive and Negative Effects**: There's a substantial proportion of both positive and negative treatment effects, indicating that the treatment benefits some individuals while harming others. This highlights the importance of personalized treatment decisions.

4. **Average Treatment Effect**: The mean CATE (shown by the red dashed line) is positive, suggesting that the treatment is beneficial on average across the entire population. However, this average masks the significant heterogeneity.

These conclusions present several challenges for CATE estimation:

1. **Threshold Detection**: The true effects are generated using threshold-based rules (e.g., X1 > 0.5), which create sharp discontinuities. Standard regression methods might struggle to capture these abrupt changes without prior knowledge of the thresholds.

2. **Subgroup Identification**: The multi-modal nature of the distribution suggests distinct subgroups with different treatment responses. Identifying these subgroups without prior knowledge is challenging and may require specialized methods like clustering or tree-based approaches.

3. **Sign Prediction**: Accurately predicting whether an individual will experience a positive or negative effect is crucial for decision-making but challenging due to the complex interactions between covariates.

4. **Balancing Bias and Variance**: Capturing the complex heterogeneity requires flexible models, but such models risk overfitting. Finding the right balance between model complexity and generalizability is a key challenge.

These challenges highlight why advanced methods for heterogeneous treatment effect estimation are necessary, as simple approaches may fail to capture the complex patterns in the data.
</div>
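Because the model2 formula is a sum of four indicator terms, the set of values the CATE can take may be enumerated directly. The sketch below switches each term on or off and collects the distinct sums. Whether every combination actually occurs depends on the covariate distribution, which is not specified here, so this is an upper bound on the number of peaks.

```python
from itertools import product

# Weights of the four indicator terms in the model2 CATE formula
weights = [4.0, -3.0, 5.0, -2.0]

# Each term is either active (1) or inactive (0): 2**4 = 16 combinations
possible_values = sorted({
    sum(w * s for w, s in zip(weights, switches))
    for switches in product((0, 1), repeat=4)
})

print(len(possible_values), possible_values)
```

Sixteen combinations collapse to 13 distinct values ranging from -5 to 9, since a few combinations share the same sum.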


<div style="font-size: 0.85em;">

# True CATE vs. Covariates Visualization

In this section, we visualize the relationship between each covariate and the true CATE values. These plots help us:

1. **Identify Threshold Effects**: Detect if treatment effects change abruptly at certain covariate values
2. **Discover Interactions**: Spot covariates whose influence on the CATE only appears jointly with another covariate, which shows up as scatter with no clear single-variable pattern
3. **Visualize Heterogeneity**: See which covariates contribute most to treatment effect variation
4. **Detect Patterns**: Identify linear, non-linear, or step-function relationships

Each plot shows one covariate (X-axis) against the true CATE (Y-axis), with a horizontal line at y=0 separating positive from negative effects.
</div>
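Threshold effects can also be located numerically. The sketch below is a hypothetical helper, `best_split`, that performs the single-split search used by regression trees: it tries every cut point on one covariate and keeps the one that minimizes the squared error of a two-level step fit. It is demonstrated on made-up data with a jump at 0.5, not on the notebook's dataset.

```python
import numpy as np

def best_split(x, y):
    """Threshold on x minimizing the squared error of a two-level step fit."""
    order = np.argsort(x)
    xs, ys = x[order], y[order]
    n = len(ys)

    csum = np.cumsum(ys)
    csum2 = np.cumsum(ys**2)
    k = np.arange(1, n)  # size of the left group for each candidate cut

    # Sum of squared errors around the group mean, left and right of each cut
    left_sse = csum2[:-1] - csum[:-1] ** 2 / k
    right_sum = csum[-1] - csum[:-1]
    right_sse = (csum2[-1] - csum2[:-1]) - right_sum**2 / (n - k)

    i = int(np.argmin(left_sse + right_sse))
    return 0.5 * (xs[i] + xs[i + 1])  # midpoint between neighbouring points

# Toy data (assumption): effect jumps by 4 at x = 0.5, plus noise
rng = np.random.default_rng(0)
x_toy = rng.uniform(0.0, 1.0, size=2000)
effect_toy = 4.0 * (x_toy > 0.5) + rng.normal(0.0, 0.5, size=2000)

threshold = best_split(x_toy, effect_toy)
print(f"estimated threshold: {threshold:.3f}")
```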


In [None]:
# Plot the relationship between covariates and true CATE
fig, axs = plt.subplots(2, 3, figsize=(18, 12))
axs = axs.flatten()

# Plot CATE vs each covariate
for i in range(5):
    axs[i].scatter(X[:, i], true_cate, alpha=0.5)
    axs[i].set_xlabel(f'X{i+1}', fontsize=12)
    axs[i].set_ylabel('True CATE', fontsize=12)
    axs[i].set_title(f'True CATE vs X{i+1}', fontsize=14)
    axs[i].grid(True, alpha=0.3)

    # Add a horizontal line at y=0
    axs[i].axhline(y=0, color='black', linestyle='-', linewidth=1)

# Remove the unused subplot
fig.delaxes(axs[5])

plt.tight_layout()

# Save and display plot
plt.savefig('images/true_cate_vs_covariates.png', dpi=300, bbox_inches='tight')
print("True CATE vs covariates plot saved to images/true_cate_vs_covariates.png")
plt.show()


<div style="font-size: 0.85em;">

# True CATE vs. Covariates Analysis

The plots showing the relationship between each covariate and the true CATE reveal several important patterns:

1. **Threshold Effects**: For X1 and X2, we observe clear threshold effects where the CATE values jump at specific values (X1 ≈ 0.5 and X2 ≈ 0.7). This reflects the threshold-based rules used in the data generation process.

2. **Interaction Effects**: The relationship between X3/X4 and CATE is more complex and doesn't show a clear pattern when viewed individually. This is because the true effect depends on the interaction X3*X4 > 0.5, which isn't visible in single-variable plots.

3. **Step Function for X5**: The plot for X5 shows a clear step function with lower CATE values when X5 < 0.3, directly reflecting the data generation rule.

4. **Heterogeneity Sources**: These plots help identify which covariates contribute most to treatment effect heterogeneity. All five covariates influence the treatment effect, but in different ways and with different patterns.

These observations present several challenges for CATE estimation:

1. **Non-linear Relationships**: The sharp discontinuities and threshold effects are difficult for many standard regression methods to capture accurately. Linear models would completely miss these patterns.

2. **Interaction Detection**: The plots for X3 and X4 individually don't reveal the true pattern because it depends on their interaction. Detecting such interactions automatically is challenging and often requires specialized methods or domain knowledge.

3. **Variable Selection**: While all covariates influence the treatment effect, their importance and patterns differ. Determining which variables to include in a model and how to model their effects is challenging.

4. **Model Specification**: The diverse patterns across covariates suggest that a single functional form (e.g., linear, quadratic) won't adequately capture all relationships. Flexible, non-parametric approaches may be needed.

5. **Extrapolation Risk**: The clear threshold effects mean that extrapolating beyond the observed data range is particularly risky, as the pattern might change dramatically at unobserved thresholds.

These challenges highlight the importance of using flexible, adaptive methods for CATE estimation that can capture complex, non-linear relationships and interactions between covariates and treatment effects.
</div>
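The interaction point can be checked numerically by comparing two ways of splitting the population. The sketch below assumes covariates drawn uniformly from [0, 1] (an assumption; the notebook does not state the covariate distribution) and applies the model2 formula from the text. Splitting on X3 alone yields a modest gap in mean CATE, while splitting on the product X3*X4 recovers a gap close to the full 5.0.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumption: independent uniform covariates on [0, 1]
X_sim = rng.uniform(0.0, 1.0, size=(20_000, 5))

# model2 CATE formula as stated in the text
cate_sim = (
    4.0 * (X_sim[:, 0] > 0.5)
    - 3.0 * (X_sim[:, 1] > 0.7)
    + 5.0 * (X_sim[:, 2] * X_sim[:, 3] > 0.5)
    - 2.0 * (X_sim[:, 4] < 0.3)
)

def mean_gap(mask):
    """Difference in mean CATE between rows inside and outside the mask."""
    return cate_sim[mask].mean() - cate_sim[~mask].mean()

gap_x3 = mean_gap(X_sim[:, 2] > 0.5)                       # single-variable split
gap_product = mean_gap(X_sim[:, 2] * X_sim[:, 3] > 0.5)    # interaction split

print(f"split on X3 > 0.5:      gap = {gap_x3:.2f}")
print(f"split on X3*X4 > 0.5:   gap = {gap_product:.2f}")
```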
