# Example Use Case: Medical Treatment Effect

This notebook demonstrates a complete workflow for estimating causal effects in a medical treatment scenario.

In [None]:
# Import necessary modules
import sys
import os

# Add the root directory to the path to make imports work
root_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
if root_dir not in sys.path:
    sys.path.append(root_dir)

# Import common libraries
import numpy as np
import matplotlib.pyplot as plt

# Import the causal meta-learning library
from causal_meta.graph import Graph, DirectedGraph, CausalGraph
import causal_meta.graph.visualization as viz

# Example Use Case: Simulating a Treatment Effect Study

This notebook demonstrates an end-to-end example of using the causal meta-learning library to simulate a treatment effect study. We'll:

1. Define a causal graph structure
2. Create a structural causal model (SCM)
3. Generate synthetic data
4. Analyze treatment effects, comparing naive, adjusted, and interventional estimates
5. Explore counterfactual outcomes for individual patients

Let's get started!

In [None]:
# Import the necessary modules
from causal_meta.graph import CausalGraph
from causal_meta.environments.scm import StructuralCausalModel
import causal_meta.graph.visualization as viz

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LinearRegression

## 1. Define Causal Graph Structure

We'll model a simple medical treatment scenario with the following variables:
- `Age`: Patient's age
- `Severity`: Disease severity
- `Treatment`: Whether the patient received treatment
- `Recovery`: Patient's recovery outcome

The causal structure represents a scenario where age affects disease severity, severity influences both treatment decisions and recovery outcomes, and treatment affects recovery.

In [None]:
# Create the causal graph
graph = CausalGraph()

# Add nodes
nodes = ['Age', 'Severity', 'Treatment', 'Recovery']
for node in nodes:
    graph.add_node(node)

# Add edges representing causal relationships
graph.add_edge('Age', 'Severity')        # Age affects disease severity
graph.add_edge('Severity', 'Treatment')  # Severity influences treatment decisions
graph.add_edge('Severity', 'Recovery')   # Severity affects recovery
graph.add_edge('Treatment', 'Recovery')  # Treatment affects recovery

# Visualize the causal graph
plt.figure(figsize=(10, 6))
ax = plt.gca()
viz.plot_causal_graph(graph, ax=ax, title="Causal Graph for Treatment Effect Study")
plt.show()
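
The graph above makes `Severity` a common cause of `Treatment` and `Recovery`, which is what confounds the naive comparison later in this notebook. As a minimal sketch, the snippet below mirrors the same edges in a plain Python list and looks up parents and ancestors by hand. It deliberately avoids any `CausalGraph` query methods, since none are shown in this notebook; `parents` and `ancestors` are hypothetical helpers written only for illustration.

```python
# Plain-Python mirror of the edges added to `graph` above (not the CausalGraph API).
edges = [
    ("Age", "Severity"),
    ("Severity", "Treatment"),
    ("Severity", "Recovery"),
    ("Treatment", "Recovery"),
]


def parents(node):
    """Return the direct causes of `node`."""
    return {src for src, dst in edges if dst == node}


def ancestors(node):
    """Return every variable with a directed path into `node`."""
    found = set()
    stack = list(parents(node))
    while stack:
        current = stack.pop()
        if current not in found:
            found.add(current)
            stack.extend(parents(current))
    return found


# Variables that directly cause both the treatment and the outcome.
common_parents = parents("Treatment") & parents("Recovery")
print("Shared parents of Treatment and Recovery:", sorted(common_parents))
print("Ancestors of Recovery:", sorted(ancestors("Recovery")))
```

`Severity` is the only shared parent. `Age` reaches `Recovery` only through `Severity`, so conditioning on `Severity` alone blocks the backdoor path `Treatment <- Severity -> Recovery`.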

## 2. Create a Structural Causal Model

Now we'll define the structural equations that govern the relationships between variables.

In [None]:
# Create the SCM
scm = StructuralCausalModel(causal_graph=graph)

# Add variables with appropriate domains
scm.add_variable('Age', domain='continuous')
scm.add_variable('Severity', domain='continuous')
scm.add_variable('Treatment', domain='binary')
scm.add_variable('Recovery', domain='continuous')

# Define structural equations

# Age is exogenous, normally distributed around 50 with std=15
scm.define_linear_gaussian_equation('Age', {}, intercept=50, noise_std=15)

# Severity increases with age (0.03 units per year of age)
scm.define_linear_gaussian_equation('Severity', {'Age': 0.03}, intercept=1, noise_std=0.5)

# Treatment is more likely with higher severity (logistic function)
def treatment_equation(Severity, noise):
    # Logistic function to determine treatment probability
    prob = 1 / (1 + np.exp(-2 * (Severity - 2.5)))
    # Treatment = 1 if noise < prob, else 0
    return 1 if noise < prob else 0

scm.define_probabilistic_equation('Treatment', treatment_equation, lambda rng: rng.uniform(0, 1))

# Recovery depends on severity (negatively) and treatment (positively)
scm.define_linear_gaussian_equation('Recovery', 
                                  {'Severity': -0.7, 'Treatment': 1.5}, 
                                  intercept=5, 
                                  noise_std=0.3)

print(scm)
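
To make the structural equations concrete without depending on the library, here is a NumPy-only sketch of the same data-generating process. It assumes that `define_linear_gaussian_equation` means "intercept plus weighted parents plus zero-mean Gaussian noise", which is the natural reading but is not confirmed by this notebook. `sample_observational` is a hypothetical helper, not part of `causal_meta`.

```python
import numpy as np


def sample_observational(n, seed=0):
    """Draw n samples from a NumPy re-implementation of the equations above.

    Assumption: each linear-Gaussian equation is
    intercept + sum(coefficient * parent) + Normal(0, noise_std).
    """
    rng = np.random.default_rng(seed)
    age = 50 + rng.normal(0, 15, size=n)
    severity = 1 + 0.03 * age + rng.normal(0, 0.5, size=n)
    # Logistic treatment assignment, thresholded against uniform noise.
    prob = 1 / (1 + np.exp(-2 * (severity - 2.5)))
    treatment = (rng.uniform(0, 1, size=n) < prob).astype(int)
    recovery = 5 - 0.7 * severity + 1.5 * treatment + rng.normal(0, 0.3, size=n)
    return {
        "Age": age,
        "Severity": severity,
        "Treatment": treatment,
        "Recovery": recovery,
    }


data = sample_observational(10_000)
print({name: round(float(values.mean()), 2) for name, values in data.items()})
```

Under these assumptions the mean severity is 1 + 0.03 × 50 = 2.5, which is also the midpoint of the logistic curve, so roughly half of the simulated patients end up treated.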

## 3. Generate Synthetic Data

Next we sample from the SCM without intervening. This mimics an observational study, in which patients with higher severity are more likely to receive treatment.

In [None]:
# Sample observational data
obs_data = scm.sample_data(1000, random_seed=42)

# Display the first few rows
print("Observational data:")
print(obs_data.head())

# Basic summary statistics
print("\nSummary statistics:")
print(obs_data.describe().round(2))

# Calculate treatment rate
treatment_rate = obs_data['Treatment'].mean() * 100
print(f"\nTreatment rate: {treatment_rate:.1f}%")

# Visualize data distributions
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
axes = axes.flatten()

# Age distribution
sns.histplot(obs_data['Age'], kde=True, ax=axes[0])
axes[0].set_title('Age Distribution')

# Severity distribution
sns.histplot(obs_data['Severity'], kde=True, ax=axes[1])
axes[1].set_title('Severity Distribution')

# Treatment distribution
sns.countplot(x='Treatment', data=obs_data, ax=axes[2])
axes[2].set_title('Treatment Distribution')

# Recovery distribution
sns.histplot(obs_data['Recovery'], kde=True, ax=axes[3])
axes[3].set_title('Recovery Distribution')

plt.tight_layout()
plt.show()

### Examining Relationships Between Variables

In [None]:
# Examine relationship between Severity and Treatment
plt.figure(figsize=(10, 6))
sns.boxplot(x='Treatment', y='Severity', data=obs_data)
plt.title('Severity by Treatment Group')
plt.grid(True, axis='y')
plt.show()

# Summary by treatment group
treatment_summary = obs_data.groupby('Treatment').agg({
    'Age': ['mean', 'std'],
    'Severity': ['mean', 'std'],
    'Recovery': ['mean', 'std', 'count']
}).round(2)

print("Summary by treatment group:")
print(treatment_summary)

## 4. Analyze Treatment Effects

Now let's estimate the treatment effect three ways: a naive difference in means, a regression that adjusts for severity, and a direct intervention on the SCM.

In [None]:
# Naive approach: directly compare treatment groups
treated = obs_data[obs_data['Treatment'] == 1]['Recovery'].mean()
untreated = obs_data[obs_data['Treatment'] == 0]['Recovery'].mean()
naive_effect = treated - untreated

print("Naive approach (simple difference in means):")
print(f"  Treated group mean recovery: {treated:.3f}")
print(f"  Untreated group mean recovery: {untreated:.3f}")
print(f"  Estimated treatment effect: {naive_effect:.3f}")

# Adjustment approach: control for confounding using regression
X = obs_data[['Treatment', 'Severity']]
y = obs_data['Recovery']
model = LinearRegression().fit(X, y)

# The coefficient for Treatment in this model represents the adjusted effect
adjusted_effect = model.coef_[0]

print("\nAdjustment approach (linear regression with severity adjustment):")
print(f"  Regression coefficients: Treatment={model.coef_[0]:.3f}, Severity={model.coef_[1]:.3f}, Intercept={model.intercept_:.3f}")
print(f"  Adjusted treatment effect: {adjusted_effect:.3f}")

# Causal approach: use the SCM to perform interventions
causal_effect = scm.compute_effect(
    treatment='Treatment', 
    outcome='Recovery', 
    treatment_value=1, 
    baseline_value=0,
    sample_size=5000, 
    random_seed=42
)

print("\nCausal approach (intervention-based using SCM):")
print(f"  Interventional estimate of the causal effect: {causal_effect:.3f}")
print("  True value defined in SCM: 1.5")

# Compare all approaches
print("\nComparison of approaches:")
print(f"  Naive estimate: {naive_effect:.3f}")
print(f"  Adjusted estimate: {adjusted_effect:.3f}")
print(f"  Interventional estimate: {causal_effect:.3f} (true value: 1.5)")
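
The size of the naive bias follows from the Recovery equation. Because the Recovery noise is independent of `Severity` and `Treatment`, the difference in group means equals 1.5 − 0.7 × (mean severity of the treated − mean severity of the untreated). The sketch below checks this numerically and shows that least squares on `Treatment` and `Severity` recovers the coefficient. It uses a NumPy re-implementation of the equations (same assumption as before about the linear-Gaussian form) and `np.linalg.lstsq` in place of scikit-learn, so it does not reproduce the exact numbers printed above.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50_000

# NumPy re-implementation of the structural equations (assumed semantics).
age = 50 + rng.normal(0, 15, size=n)
severity = 1 + 0.03 * age + rng.normal(0, 0.5, size=n)
prob = 1 / (1 + np.exp(-2 * (severity - 2.5)))
treatment = (rng.uniform(size=n) < prob).astype(float)
recovery = 5 - 0.7 * severity + 1.5 * treatment + rng.normal(0, 0.3, size=n)

# Naive estimate and the value predicted from the severity imbalance.
treated = treatment == 1
naive = recovery[treated].mean() - recovery[~treated].mean()
severity_gap = severity[treated].mean() - severity[~treated].mean()
predicted_naive = 1.5 - 0.7 * severity_gap

# Ordinary least squares of Recovery on [1, Treatment, Severity].
design = np.column_stack([np.ones(n), treatment, severity])
coefs, *_ = np.linalg.lstsq(design, recovery, rcond=None)
adjusted = coefs[1]

print(f"naive: {naive:.3f}  predicted from severity gap: {predicted_naive:.3f}")
print(f"adjusted: {adjusted:.3f}")
```

Treated patients are sicker on average, so the severity gap is positive and the naive estimate falls below 1.5. The adjusted coefficient should land close to 1.5 because the regression model matches the linear form of the Recovery equation.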

### Visualizing the Confounding Effect

In [None]:
# Create interventional datasets
scm.do_intervention('Treatment', 1)  # Everyone gets treatment
treated_data = scm.sample_data(1000, random_seed=43)
treated_data['Group'] = 'do(Treatment=1)'

scm.reset()
scm.do_intervention('Treatment', 0)  # No one gets treatment
untreated_data = scm.sample_data(1000, random_seed=43)
untreated_data['Group'] = 'do(Treatment=0)'

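# Restore the observational model, then label the observed data by treatment group
scm.reset()
obs_comparison = obs_data.copy()
# Cast to int first so the labels read 'Observed T=1' rather than 'Observed T=1.0'
obs_comparison['Group'] = 'Observed T=' + obs_comparison['Treatment'].astype(int).astype(str)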

# Combine all data for visualization
combined_data = pd.concat([treated_data, untreated_data, obs_comparison])

# Plot recovery distributions
plt.figure(figsize=(12, 6))
sns.boxplot(x='Group', y='Recovery', data=combined_data, 
            order=['do(Treatment=1)', 'do(Treatment=0)', 'Observed T=1', 'Observed T=0'])
plt.axhline(treated_data['Recovery'].mean(), color='r', linestyle='--', alpha=0.5)
plt.axhline(untreated_data['Recovery'].mean(), color='b', linestyle='--', alpha=0.5)
plt.title('Recovery Distributions: Interventional vs Observational')
plt.grid(True, axis='y')
plt.show()

# Calculate means for each group
means = combined_data.groupby('Group')['Recovery'].mean().reindex(
    ['do(Treatment=1)', 'do(Treatment=0)', 'Observed T=1', 'Observed T=0'])
print("Mean recovery by group:")
print(means)

# Calculate true and naive effects
true_effect = means['do(Treatment=1)'] - means['do(Treatment=0)']
obs_effect = means['Observed T=1'] - means['Observed T=0']
print(f"\nTrue causal effect: {true_effect:.3f}")
print(f"Naive observational effect: {obs_effect:.3f}")
print(f"Bias: {obs_effect - true_effect:.3f}")
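
The cell above reuses `random_seed=43` for both interventional samples. Whether the library then reuses the same exogenous noise for both arms is an assumption; the sketch below shows what that would imply, using a hypothetical NumPy helper `sample_recovery` rather than `scm.do_intervention`. The helper draws every noise term in a fixed order, so the only difference between the two arms is the value forced onto `Treatment`.

```python
import numpy as np


def sample_recovery(n, seed, do_treatment=None):
    """Sample Recovery, optionally under do(Treatment=do_treatment)."""
    rng = np.random.default_rng(seed)
    # Draw all noise terms first so both arms share them for a given seed.
    age = 50 + rng.normal(0, 15, size=n)
    severity = 1 + 0.03 * age + rng.normal(0, 0.5, size=n)
    u_treatment = rng.uniform(size=n)
    noise_recovery = rng.normal(0, 0.3, size=n)

    if do_treatment is None:
        # Observational regime: treatment depends on severity.
        prob = 1 / (1 + np.exp(-2 * (severity - 2.5)))
        treatment = (u_treatment < prob).astype(float)
    else:
        # Interventional regime: the Severity -> Treatment edge is cut.
        treatment = np.full(n, float(do_treatment))

    return 5 - 0.7 * severity + 1.5 * treatment + noise_recovery


recovery_treated = sample_recovery(1000, seed=43, do_treatment=1)
recovery_untreated = sample_recovery(1000, seed=43, do_treatment=0)
effect = recovery_treated.mean() - recovery_untreated.mean()
print(f"Interventional difference in means: {effect:.3f}")
```

With shared noise the two arms differ only by the 1.5 × `Treatment` term, so the difference in means equals 1.5 up to floating-point rounding. With independent seeds the same comparison would carry ordinary sampling error.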

### Relationship Between Severity and Treatment Assignment

In [None]:
# Plot the relationship between severity and treatment probability
severities = np.linspace(0, 5, 100)
treatment_probs = 1 / (1 + np.exp(-2 * (severities - 2.5)))

plt.figure(figsize=(10, 6))
plt.plot(severities, treatment_probs, 'b-', linewidth=2)
plt.xlabel('Severity')
plt.ylabel('Probability of Treatment')
plt.title('Treatment Assignment Probability by Disease Severity')
plt.grid(True)
plt.axhline(0.5, color='r', linestyle='--', alpha=0.5)
plt.axvline(2.5, color='r', linestyle='--', alpha=0.5)
plt.text(2.6, 0.52, 'Severity = 2.5', color='r')
plt.ylim(0, 1)
plt.show()
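
For a few concrete points on this curve, the short sketch below evaluates the same logistic function at selected severities. The function name `treatment_probability` and its keyword defaults are illustrative; the steepness of 2 and midpoint of 2.5 come from `treatment_equation` above.

```python
import math


def treatment_probability(severity, steepness=2.0, midpoint=2.5):
    """Logistic probability of treatment used in treatment_equation."""
    return 1 / (1 + math.exp(-steepness * (severity - midpoint)))


for severity in (1.5, 2.0, 2.5, 3.0, 3.5):
    prob = treatment_probability(severity)
    print(f"Severity {severity:.1f}: P(Treatment=1) = {prob:.3f}")
# Severity 1.5: P(Treatment=1) = 0.119
# Severity 2.0: P(Treatment=1) = 0.269
# Severity 2.5: P(Treatment=1) = 0.500
# Severity 3.0: P(Treatment=1) = 0.731
# Severity 3.5: P(Treatment=1) = 0.881
```

One unit of severity on either side of the midpoint moves the treatment probability from about 12% to about 88%, which is why the treated and untreated groups differ so much in severity.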

## 5. Exploring Counterfactuals

Finally, let's examine some counterfactual scenarios for individual patients.

In [None]:
# Reset the SCM
scm.reset()

# Sample a few patients with different characteristics
patients = scm.sample_data(3, random_seed=42)
print("Sample patients:")
print(patients)

# Perform counterfactual analysis for each patient
results = []

for i, patient in patients.iterrows():
    patient_df = pd.DataFrame([patient])
    actual_treatment = int(patient['Treatment'])  # iterrows() may upcast the 0/1 flag to float
    counterfactual_treatment = 1 - actual_treatment
    
    # Evaluate counterfactual with opposite treatment
    counterfactual = scm.evaluate_counterfactual(
        factual_data=patient_df,
        interventions={'Treatment': counterfactual_treatment}
    )
    
    # Extract outcomes
    factual_recovery = patient['Recovery']
    counterfactual_recovery = counterfactual['Recovery'].values[0]
    
    # Individual treatment effect: recovery with treatment minus recovery without it
    if actual_treatment == 1:
        ite = factual_recovery - counterfactual_recovery  # factual arm is the treated one
    else:
        ite = counterfactual_recovery - factual_recovery  # counterfactual arm is the treated one
    
    results.append({
        'Patient': i+1,
        'Age': patient['Age'],
        'Severity': patient['Severity'],
        'Actual Treatment': actual_treatment,
        'Actual Recovery': factual_recovery,
        'Counterfactual Treatment': counterfactual_treatment,
        'Counterfactual Recovery': counterfactual_recovery,
        'Individual Treatment Effect': ite
    })

# Create results dataframe
results_df = pd.DataFrame(results)
print("\nCounterfactual analysis results:")
print(results_df.round(3))

# Visualize the individual treatment effects
plt.figure(figsize=(12, 6))
for i, row in results_df.iterrows():
    plt.plot([row['Actual Treatment'], row['Counterfactual Treatment']], 
             [row['Actual Recovery'], row['Counterfactual Recovery']], 
             'o-', linewidth=2, label=f"Patient {int(row['Patient'])} (Severity={row['Severity']:.2f})")
    
plt.xlabel('Treatment')
plt.ylabel('Recovery')
plt.title('Individual Treatment Effects')
plt.xticks([0, 1], ['No Treatment', 'Treatment'])
plt.grid(True)
plt.legend()
plt.show()
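
For a linear equation with additive noise, a counterfactual can be computed by hand in three steps: recover the patient's noise term from the observed data (abduction), set the treatment to its new value (action), and re-evaluate the equation with the same noise (prediction). The sketch below applies these steps to the Recovery equation. It is not the implementation behind `scm.evaluate_counterfactual`, which this notebook does not show; the function names and the example patient are hypothetical.

```python
def recovery_mean(severity, treatment):
    """Deterministic part of the Recovery equation defined in the SCM."""
    return 5 - 0.7 * severity + 1.5 * treatment


def counterfactual_recovery(severity, treatment, recovery, new_treatment):
    """Recovery this patient would have had under `new_treatment`."""
    # Abduction: infer the patient's own noise term from the factual outcome.
    noise = recovery - recovery_mean(severity, treatment)
    # Action and prediction: swap the treatment, keep severity and noise fixed.
    # Severity is unchanged because it is not a descendant of Treatment.
    return recovery_mean(severity, new_treatment) + noise


# Hypothetical treated patient; the values are chosen only for illustration.
severity, treatment, recovery = 3.0, 1, 4.6
untreated_recovery = counterfactual_recovery(severity, treatment, recovery, new_treatment=0)
ite = recovery - untreated_recovery
print(f"Counterfactual recovery without treatment: {untreated_recovery:.2f}")
print(f"Individual treatment effect: {ite:.2f}")
```

Because treatment enters the Recovery equation additively with no interaction terms, this calculation gives an individual effect of 1.5 for every patient, regardless of age or severity. Heterogeneous individual effects would require an equation in which treatment interacts with other variables.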

## Summary

In this example, we demonstrated an end-to-end causal inference workflow:

1. We defined a causal graph representing treatment effects with confounding
2. We created a structural causal model with specific mechanisms
3. We generated synthetic observational data from the SCM
4. We analyzed the treatment effect using different approaches:
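   - The naive approach (direct comparison) was biased by confounding: sicker patients are both more likely to be treated and recover less well, so the difference in means understates the benefit of treatment
   - The adjusted approach (regression on `Treatment` and `Severity`) conditions on the confounder and gives an estimate much closer to the true effect
   - The interventional approach simulates `do(Treatment=1)` and `do(Treatment=0)` in the SCM and estimates the effect directly; the true value set in the Recovery equation is 1.5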
5. We evaluated counterfactual scenarios for individual patients

This example illustrates the importance of causal reasoning in correctly estimating treatment effects, especially in the presence of confounding variables.