# Structural Causal Models

This notebook explores structural causal models (SCMs) and their applications in causal inference.

In [None]:
# Import necessary modules
import sys
import os

# Make the root directory (the parent of the current working directory) importable
root_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if root_dir not in sys.path:
    sys.path.append(root_dir)

# Import common libraries
import numpy as np
import matplotlib.pyplot as plt

# Import the causal meta-learning library
from causal_meta.graph import Graph, DirectedGraph, CausalGraph
import causal_meta.graph.visualization as viz

# Structural Causal Models

This notebook demonstrates how to work with Structural Causal Models (SCMs) in the causal meta-learning library. We'll cover:

1. Creating and defining SCMs
2. Defining structural equations and mechanisms
3. Sampling data from SCMs
4. Performing interventions
5. Calculating causal effects

Let's get started!

In [None]:
# Import necessary modules
import sys
import os

# Resolve the parent of the current working directory and register it on
# sys.path (only once) so the project package can be imported
root_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if not any(entry == root_dir for entry in sys.path):
    sys.path.append(root_dir)

# Import the necessary modules
from causal_meta.graph import CausalGraph
from causal_meta.environments.scm import StructuralCausalModel
import causal_meta.graph.visualization as viz

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

## 1. Creating and Defining SCMs

A Structural Causal Model (SCM) represents a system of causal relationships where each variable is determined by a function of its parent variables and some independent noise. Let's create a simple SCM.

In [None]:
# Build the causal graph underlying our SCM
graph = CausalGraph()

# Register the three variables as nodes
for node_name in ('X', 'Y', 'Z'):
    graph.add_node(node_name)

# Each (cause, effect) pair becomes one directed edge of the causal structure
causal_edges = [
    ('X', 'Y'),  # X causes Y
    ('Z', 'X'),  # Z causes X
    ('Z', 'Y'),  # Z causes Y (directly)
]
for cause, effect in causal_edges:
    graph.add_edge(cause, effect)

# Draw the graph on an explicitly created axes
fig, ax = plt.subplots(figsize=(8, 6))
viz.plot_causal_graph(graph, ax=ax, title="Causal Graph for our SCM")
plt.show()

# Wrap the causal graph in a structural causal model
scm = StructuralCausalModel(causal_graph=graph)

# Register every variable of the graph with a continuous domain
for variable_name in ('X', 'Y', 'Z'):
    scm.add_variable(variable_name, domain='continuous')

## 2. Defining Structural Equations and Mechanisms

Now we need to define the structural equations that determine how each variable is influenced by its parents.

In [None]:
# Linear Gaussian structural equations, one entry per variable:
# (variable, {parent: coefficient}, noise_std)
linear_equations = [
    ('Z', {}, 1.0),                    # Z is an exogenous variable (no parents)
    ('X', {'Z': 0.7}, 0.5),            # X depends on Z
    ('Y', {'X': 0.6, 'Z': 0.3}, 0.3),  # Y depends on both X and Z
]

for variable, coefficients, noise_std in linear_equations:
    scm.define_linear_gaussian_equation(
        variable, coefficients, intercept=0, noise_std=noise_std
    )

# Inspect the SCM
print(scm)

### Using Custom Structural Equations

We can also define custom non-linear equations for more complex relationships.

In [None]:
# Two-node graph (X -> Y) for the non-linear example
nonlinear_graph = CausalGraph()
nonlinear_graph.add_node('X')
nonlinear_graph.add_node('Y')
nonlinear_graph.add_edge('X', 'Y')

# SCM over that graph, with both variables declared continuous
nonlinear_scm = StructuralCausalModel(causal_graph=nonlinear_graph)
for variable_name in ('X', 'Y'):
    nonlinear_scm.add_variable(variable_name, domain='continuous')

# Structural equation for the exogenous variable X
def x_equation(noise):
    """Return the noise term unchanged: X has no parents, so X equals its noise.

    Args:
        noise: The noise value drawn for X.

    Returns:
        The same value that was passed in.
    """
    return noise

# Structural equation for Y: quadratic in X, with additive noise
def y_equation(X, noise):
    """Compute Y as a quadratic function of its parent X plus a noise term.

    Args:
        X: Value of the parent variable X.
        noise: Noise term added to the deterministic part.

    Returns:
        0.5 * X**2 - 0.3 * X + noise.
    """
    quadratic_term = 0.5 * (X ** 2)
    linear_term = 0.3 * X
    return quadratic_term - linear_term + noise

# Noise sampler used for X
def normal_noise(rng):
    """Draw one value via ``rng.normal(0, 1)``.

    Args:
        rng: Random generator exposing a ``normal`` method
            (presumably a NumPy generator; confirm against the SCM library).

    Returns:
        Whatever ``rng.normal(0, 1)`` returns.
    """
    mean, spread = 0, 1
    return rng.normal(mean, spread)

# Pair each variable with its structural function and its noise sampler,
# then register them on the non-linear SCM in order (X first, then Y)
equation_specs = (
    ('X', x_equation, normal_noise),
    ('Y', y_equation, lambda rng: rng.normal(0, 0.5)),
)
for variable, equation, noise_sampler in equation_specs:
    nonlinear_scm.define_probabilistic_equation(variable, equation, noise_sampler)

print(nonlinear_scm)

## 3. Sampling Data from SCMs

Once we've defined our structural equations, we can sample data from the SCM.

In [None]:
# Sample observational data from our linear SCM (seeded for reproducibility)
data = scm.sample_data(1000, random_seed=42)

# Display the first few rows
print("Data from linear SCM:")
print(data.head())

# Calculate the correlations
print("\nCorrelation matrix:")
print(data.corr().round(3))

# Visualize the relationships with a pairplot.
# sns.pairplot is a figure-level function that builds its own figure, so no
# plt.figure() call precedes it: that would only leave a stray empty figure
# in the cell output.
sns.pairplot(data)
# The figure created by pairplot is the current figure here, so the title lands on it
plt.suptitle("Relationships between variables", y=1.02)
plt.show()

In [None]:
# Draw a seeded sample from the non-linear SCM
nonlinear_data = nonlinear_scm.sample_data(1000, random_seed=42)

# Show the first few rows
print("Data from non-linear SCM:")
print(nonlinear_data.head())

# Scatter the sampled (X, Y) pairs on an explicit axes
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(nonlinear_data['X'], nonlinear_data['Y'], alpha=0.5)
ax.set_title("Non-linear Relationship: Y = 0.5X² - 0.3X + ε")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.grid(True)

# Overlay the noise-free structural function for reference
x_range = np.linspace(-3, 3, 100)
true_curve = 0.5 * (x_range ** 2) - 0.3 * x_range
ax.plot(x_range, true_curve, 'r-', label="True function")
ax.legend()
plt.show()

## 4. Performing Interventions

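One of the key features of SCMs is the ability to perform interventions: we fix a variable to a chosen value, overriding the mechanism that normally determines it, and observe how the rest of the system responds. This lets us simulate experiments ("what happens if we set X to 2?") rather than merely observing correlations.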

In [None]:
# Intervene on X in the linear SCM: do(X=2.0)
scm.do_intervention('X', 2.0)

# Draw a seeded sample while the intervention is active
interventional_data = scm.sample_data(1000, random_seed=42)

# Show the first few rows
print("Interventional data with do(X=2.0):")
print(interventional_data.head())

# Summary statistics of X, printed to check that it was held at 2.0
x_under_intervention = interventional_data['X']
print(f"\nMean of X: {x_under_intervention.mean():.5f}")
print(f"Standard deviation of X: {x_under_intervention.std():.5f}")

# Remove the intervention so later cells see the original model again
scm.reset()

# Side-by-side X-vs-Y scatter plots: observational (left) and interventional (right)
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

panels = [
    (data, "Observational Data"),
    (interventional_data, "Interventional Data with do(X=2.0)"),
]
for panel_ax, (frame, panel_title) in zip(axes, panels):
    panel_ax.scatter(frame['X'], frame['Y'], alpha=0.5)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel("X")
    panel_ax.set_ylabel("Y")
    panel_ax.grid(True)

plt.tight_layout()
plt.show()

### Multiple Interventions

We can also perform multiple interventions simultaneously.

In [None]:
# Intervene on two variables at once: do(X=1.0, Z=-1.0)
joint_intervention = {'X': 1.0, 'Z': -1.0}
scm.multiple_interventions(joint_intervention)

# Draw a seeded sample while both interventions are active
multi_interventional_data = scm.sample_data(1000, random_seed=42)

# Show the first few rows
print("Data with do(X=1.0, Z=-1.0):")
print(multi_interventional_data.head())

# Means of the intervened variables, printed to check they were held fixed
x_mean = multi_interventional_data['X'].mean()
z_mean = multi_interventional_data['Z'].mean()
print(f"\nMean of X: {x_mean:.5f}")
print(f"Mean of Z: {z_mean:.5f}")

# Remove the interventions so later cells see the original model again
scm.reset()

## 5. Calculating Causal Effects

SCMs allow us to estimate causal effects by measuring how changes in one variable affect another.

In [None]:
# Average causal effect of X on Y: compare do(X=1.0) against the baseline do(X=0.0)
effect = scm.compute_effect(
    treatment='X',
    outcome='Y',
    treatment_value=1.0,
    baseline_value=0.0,
    sample_size=5000,
    random_seed=42,
)

print(f"Average Causal Effect of X on Y (when X changes from 0 to 1): {effect:.4f}")

# The estimate should be close to the X coefficient (0.6) in Y's structural equation
print("True coefficient in the structural equation: 0.6")

### Direct and Indirect Effects

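We can also decompose the total causal effect into direct and indirect components. For the effect of Z on Y in our linear SCM, the direct effect is the Z→Y coefficient (0.3) and the indirect effect runs through X (0.7 × 0.6 = 0.42), so the total effect should be approximately 0.3 + 0.42 = 0.72.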

In [None]:
# Settings shared by all three effect estimates: Z moves from 0.0 to 1.0, outcome is Y
z_on_y = dict(
    treatment='Z',
    outcome='Y',
    treatment_value=1.0,
    baseline_value=0.0,
    sample_size=5000,
    random_seed=42,
)

# Direct effect of Z on Y
direct_effect = scm.compute_direct_effect(**z_on_y)

print(f"Direct Effect of Z on Y: {direct_effect:.4f}")
print("True direct effect in the structural equation: 0.3")

# Indirect effect of Z on Y, with X as the mediator
indirect_effect = scm.compute_indirect_effect(mediators=['X'], **z_on_y)

print(f"\nIndirect Effect of Z on Y through X: {indirect_effect:.4f}")
print("Expected indirect effect: Z→X coefficient × X→Y coefficient = 0.7 × 0.6 = 0.42")

# Total effect of Z on Y, printed next to the sum of the two components
total_effect = scm.compute_effect(**z_on_y)
combined_effect = direct_effect + indirect_effect

print(f"\nTotal Effect of Z on Y: {total_effect:.4f}")
print(f"Direct + Indirect Effect: {combined_effect:.4f}")

## Visualizing Intervention Effects

Let's visualize how interventions on a variable affect the distribution of an outcome.

In [None]:
# Sweep a range of interventions on the treatment and collect the outcome samples
treatment = 'X'
outcome = 'Y'
intervention_values = np.linspace(-2, 2, 9)  # Intervention values from -2 to 2

all_data = []       # one labelled sample frame per intervention value
distributions = []  # outcome values per intervention, in the same order

for value in intervention_values:
    # Fix the treatment at this value
    scm.do_intervention(treatment, value)

    # Sample under the intervention (same seed every time)
    int_data = scm.sample_data(1000, random_seed=42)

    # Store the distribution of the outcome
    distributions.append(int_data[outcome].values)

    # Label the rows with the intervention they came from (used for plotting)
    int_data['Intervention'] = f"{treatment}={value:.1f}"
    all_data.append(int_data)

    # Remove the intervention before the next iteration
    scm.reset()

# Stack all samples into one frame. ignore_index=True gives the result a fresh
# unique index; a plain concat would repeat each sample frame's own index labels
# once per intervention, and duplicate labels can trip up label-based operations.
combined_data = pd.concat(all_data, ignore_index=True)

# Boxplot of the outcome, one box per intervention value, drawn on an explicit axes
fig, ax = plt.subplots(figsize=(12, 6))
sns.boxplot(x='Intervention', y=outcome, data=combined_data, ax=ax)
ax.set_title(f"Distribution of {outcome} under different interventions on {treatment}")
ax.set_xlabel(f"Intervention on {treatment}")
ax.set_ylabel(outcome)
ax.grid(True, axis='y')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

## Counterfactual Analysis

SCMs also allow us to reason about counterfactuals - what would have happened in a specific case if we had intervened differently.

In [None]:
# Start from a model with no interventions and draw a single "factual" observation
scm.reset()
factual_data = scm.sample_data(1, random_seed=42)
print("Factual observation:")
print(factual_data)

# Counterfactual query: what would Y have been for this observation had X been 2.0?
counterfactual_intervention = {'X': 2.0}
counterfactual_data = scm.evaluate_counterfactual(
    factual_data=factual_data,
    interventions=counterfactual_intervention,
)

print("\nCounterfactual scenario (do(X=2.0)):")
print(counterfactual_data)

# Compare Y in the factual and counterfactual worlds
y_factual = factual_data['Y'].values[0]
y_counterfactual = counterfactual_data['Y'].values[0]
y_difference = y_counterfactual - y_factual
print(f"\nFactual Y: {y_factual:.4f}")
print(f"Counterfactual Y: {y_counterfactual:.4f}")
print(f"Difference: {y_difference:.4f}")

## Summary

In this notebook, we've explored Structural Causal Models (SCMs) in the causal meta-learning library:

1. We created SCMs with both linear and non-linear structural equations
2. We sampled observational data from these models
3. We performed interventions to simulate experimental data
4. We calculated direct, indirect, and total causal effects
5. We visualized how interventions affect outcome distributions
6. We performed counterfactual analysis on specific observations

SCMs provide a powerful framework for causal modeling, enabling not just observational inference but also interventional and counterfactual reasoning.