# Multi-Level Causal Inference in Large Language Models: From Prompt Engineering to Attention Mechanisms

## Abstract

This notebook provides a comprehensive guide to applying causal inference methods to Large Language Models (LLMs), bridging the gap between traditional causal analysis and modern deep learning. Readers will learn how to identify and estimate causal effects in LLMs through two complementary approaches: external causality (prompt-level interventions using propensity score matching) and internal causality (attention mechanisms as mediators using mediation analysis). Causal inference is critical for understanding *why* models behave as they do: it moves beyond correlation to support causal claims about prompt effectiveness, feature importance, and model behavior. Practical applications include optimizing prompt engineering strategies, interpreting attention patterns in reasoning, and improving model interpretability through rigorous causal methods. This work shows how causal frameworks can be used to analyze LLM behavior systematically, providing methodological tools for researchers and practitioners seeking to understand and improve language model performance.

**Key Learning Outcomes:**
- Foundational causal inference concepts with LLM applications
- Propensity score matching for prompt-level causal analysis
- Causal mediation analysis for understanding attention mechanisms
- Data preparation techniques for causal inference in ML contexts
- Practical implementation using DoWhy, CausalML, and modern LLM frameworks

---

# Part 1: Theory of Causal Inference in Machine Learning

## 1.1 Causality Fundamentals

### Correlation vs Causation in LLMs

In machine learning, we frequently observe correlations such as "Models trained on more data tend to perform better" or "Longer prompts yield better completions." However, correlation does not imply causation. For example, longer prompts might correlate with better completions because they tend to carry more task-relevant context, not because length itself causes improvement; padding a prompt with irrelevant text would not be expected to reproduce the gain. This distinction is critical when designing experiments and interpreting model behavior.

### The Potential Outcomes Framework (Rubin Causal Model)

The **potential outcomes framework**, developed by Donald Rubin, provides a formal language for causal inference. For each unit (e.g., a prompt-task pair), we define:

- **Y_i(1)**: Potential outcome if unit receives treatment
- **Y_i(0)**: Potential outcome if unit does not receive treatment
- **T_i**: Treatment indicator (1 if treated, 0 if control)

The fundamental problem of causal inference is that we can only observe one outcome per unit (the realized outcome), never both potential outcomes simultaneously. This motivates the need for causal methods that can estimate counterfactuals.

**Individual Treatment Effect (ITE):** 
$$\tau_i = Y_i(1) - Y_i(0)$$

**Average Treatment Effect (ATE):**
$$\tau = \mathbb{E}[Y(1) - Y(0)] = \mathbb{E}[Y(1)] - \mathbb{E}[Y(0)]$$

In LLM contexts, examples include:
- Treatment: Using chain-of-thought prompting vs. direct prompting
- Outcome: Task performance (accuracy, completion quality)
- Causal question: Does chain-of-thought *cause* improved performance, or correlate with task types where it's naturally more effective?
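To make the fundamental problem concrete, here is a small simulation with assumed numbers (none come from a real LLM experiment): task difficulty drives both the choice of chain-of-thought (CoT) and accuracy. Because the simulation generates both potential outcomes, it can compute the true ATE, which real data never allows, and compare it with the naive treated-minus-control difference and a difficulty-stratified estimate.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# Hypothetical confounder: task difficulty (1 = hard, 0 = easy)
hard = rng.binomial(1, 0.5, size=n)

# Assumed assignment mechanism: CoT is chosen far more often for hard tasks
t = rng.binomial(1, np.where(hard == 1, 0.8, 0.2))

# Assumed potential outcomes (1 = correct answer): CoT adds 0.10 accuracy
# on both easy and hard tasks, so the true ATE is 0.10
y0 = rng.binomial(1, np.where(hard == 1, 0.3, 0.7))
y1 = rng.binomial(1, np.where(hard == 1, 0.4, 0.8))

# Only one potential outcome is ever observed per unit
y = np.where(t == 1, y1, y0)

true_ate = (y1 - y0).mean()                   # requires BOTH potential outcomes
naive = y[t == 1].mean() - y[t == 0].mean()   # confounded by difficulty

# Backdoor adjustment: compare within difficulty strata, then weight by stratum size
adjusted = sum(
    (hard == h).mean()
    * (y[(t == 1) & (hard == h)].mean() - y[(t == 0) & (hard == h)].mean())
    for h in (0, 1)
)

print(f"true ATE={true_ate:.3f}  naive={naive:.3f}  adjusted={adjusted:.3f}")
```

Under these assumed numbers the naive contrast is negative (about -0.14 in expectation) even though CoT helps on every task, because CoT is concentrated on hard tasks; stratifying on difficulty recovers the +0.10 effect.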

### Causal Graphs (Directed Acyclic Graphs - DAGs)

Causal graphs, formalized by Judea Pearl, represent causal relationships visually and mathematically:

- **Nodes**: Variables (treatment, outcome, confounders)
- **Edges**: Direct causal relationships
- **Paths**: Sequences of edges connecting variables

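Key path types (from treatment T to outcome Y):
1. **Causal path**: T → Y (treatment causes outcome)
2. **Backdoor path**: T ← X → Y (a non-causal path that starts with an arrow into T; here, confounding through X)
3. **Mediated (front-door) path**: T → M → Y (the effect is transmitted through mediator M; the front-door criterion exploits such paths)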
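These path types can be made concrete with a small, dependency-free sketch that enumerates every path between treatment and outcome in DAG 1 (the instruction-format example drawn later in this notebook) and labels each one. The edge list is copied from the DAG 1 cell; the function name `classify_paths` is hypothetical, and the sketch only labels paths, it does not test whether they are blocked (d-separation).

```python
# DAG 1 edges as (cause, effect) pairs, copied from the DAG 1 cell
EDGES = [
    ("Task_Difficulty", "Instruction_Format"),
    ("Task_Difficulty", "Task_Completion"),
    ("Prompt_Length", "Instruction_Format"),
    ("Prompt_Length", "Task_Completion"),
    ("Instruction_Format", "Task_Completion"),
]


def classify_paths(edges, treatment, outcome):
    """Enumerate simple paths between treatment and outcome, ignoring edge
    direction, and label each one as causal, backdoor, or other non-causal."""
    directed = set(edges)
    neighbors = {}
    for a, b in edges:
        neighbors.setdefault(a, set()).add(b)
        neighbors.setdefault(b, set()).add(a)

    paths = []

    def walk(node, path):
        if node == outcome:
            paths.append(list(path))
            return
        for nxt in sorted(neighbors.get(node, ())):
            if nxt not in path:  # simple paths only: no repeated nodes
                path.append(nxt)
                walk(nxt, path)
                path.pop()

    walk(treatment, [treatment])

    labelled = []
    for p in paths:
        steps = list(zip(p, p[1:]))
        if all(step in directed for step in steps):
            kind = "causal"      # every arrow points toward the outcome
        elif (p[1], p[0]) in directed:
            kind = "backdoor"    # path starts with an arrow INTO the treatment
        else:
            kind = "other non-causal"
        rendered = p[0] + "".join(
            (" → " if (a, b) in directed else " ← ") + b for a, b in steps
        )
        labelled.append((kind, rendered))
    return labelled


results = classify_paths(EDGES, "Instruction_Format", "Task_Completion")
for kind, path in results:
    print(f"{kind:>16}: {path}")
```

The output lists one causal path (Format → Completion) and two backdoor paths, one through Task_Difficulty and one through Prompt_Length, which is why both confounders must be adjusted for.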

### Key Assumptions

1. **Unconfoundedness**: Given observed covariates X, treatment assignment is independent of potential outcomes:
   $$\{Y(1), Y(0)\} \perp\perp T \mid X$$
   
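2. **Stable Unit Treatment Value Assumption (SUTVA)**: One unit's treatment does not affect another unit's outcome (no interference), and each treatment level is well defined, with no hidden versions of treatment. For example, "chain-of-thought prompting" must refer to one fixed template rather than several differently worded variants.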

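3. **Positivity (Overlap)**: Every unit has a non-zero probability of receiving each treatment level:
   $$0 < P(T=1 \mid X=x) < 1 \quad \text{for all } x \text{ in the support of } X$$

These assumptions are crucial for causal identification. Positivity can be checked empirically, and we check it in our examples; unconfoundedness and SUTVA cannot be verified from data alone and must be justified with domain knowledge and probed with sensitivity analyses.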
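Positivity can be checked empirically by looking at the treatment rate within each covariate stratum. The sketch below uses a hypothetical helper `positivity_report` and a simulated assignment policy (an assumption for illustration) in which hard reasoning tasks always receive CoT, so that stratum has no untreated units to compare against.

```python
import numpy as np
from collections import defaultdict


def positivity_report(strata, treatment, eps=0.05):
    """Empirical treatment rate per covariate stratum.

    Flags strata whose rate falls outside [eps, 1 - eps], i.e. strata where
    (almost) every unit is treated, or (almost) none is.
    """
    counts = defaultdict(lambda: [0, 0])  # stratum -> [n_treated, n_total]
    for s, t in zip(strata, treatment):
        counts[s][0] += int(t)
        counts[s][1] += 1
    report = {}
    for s, (n_treated, n_total) in sorted(counts.items()):
        rate = n_treated / n_total
        report[s] = (rate, n_total, not (eps <= rate <= 1 - eps))
    return report


rng = np.random.default_rng(1)
n = 2000
difficulty = rng.choice(["easy", "hard"], size=n)
task_type = rng.choice(["qa", "reasoning"], size=n)
# Assumed assignment policy: hard reasoning tasks ALWAYS get CoT (positivity violation)
p_cot = np.where((difficulty == "hard") & (task_type == "reasoning"), 1.0, 0.5)
cot = rng.binomial(1, p_cot)

strata = [(str(d), str(k)) for d, k in zip(difficulty, task_type)]
report = positivity_report(strata, cot)
for stratum, (rate, size, flagged) in report.items():
    print(f"{stratum}: P(CoT)={rate:.2f}  n={size}  {'VIOLATION' if flagged else 'ok'}")
```

The threshold `eps=0.05` is an arbitrary choice; with continuous covariates the same check is usually done on estimated propensity scores instead of discrete strata.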

In [None]:
# Install required packages for theory section
import sys
import subprocess

packages = [
    'matplotlib', 'seaborn', 'networkx', 'numpy', 'pandas',
    'scikit-learn', 'dowhy', 'causalml'
]

for package in packages:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', package, '--quiet'])

print("All packages installed successfully!")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
import warnings
warnings.filterwarnings('ignore')

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 6)

print("Libraries imported successfully!")

### DAG 1: Confounding in LLM Experiments

This DAG illustrates a common confounding scenario: the relationship between instruction format (treatment) and task completion quality (outcome) is confounded by task difficulty.

In [None]:
# Create DAG for Example 1: Confounding in Instruction Format
G1 = nx.DiGraph()

# Add nodes
G1.add_node('Task_Difficulty', pos=(0, 2), color='lightblue', style='filled')
G1.add_node('Prompt_Length', pos=(2, 2), color='lightblue', style='filled')
G1.add_node('Instruction_Format', pos=(1, 0), color='lightgreen', style='filled')
G1.add_node('Task_Completion', pos=(1, -2), color='salmon', style='filled')

# Add edges
G1.add_edge('Task_Difficulty', 'Instruction_Format')
G1.add_edge('Task_Difficulty', 'Task_Completion')
G1.add_edge('Prompt_Length', 'Instruction_Format')
G1.add_edge('Prompt_Length', 'Task_Completion')
G1.add_edge('Instruction_Format', 'Task_Completion')

# Draw
fig, ax = plt.subplots(figsize=(10, 8))
pos = nx.get_node_attributes(G1, 'pos')
colors = [G1.nodes[n]['color'] for n in G1.nodes()]

nx.draw(G1, pos, ax=ax, with_labels=True, node_color=colors, 
        node_size=4000, font_size=10, font_weight='bold',
        arrowsize=20, edge_color='gray', width=2)

# Add annotations
ax.text(0.5, 1.5, 'Confounders', ha='center', fontsize=12, fontweight='bold', color='blue')
ax.text(1, -0.5, 'Treatment', ha='center', fontsize=12, fontweight='bold', color='green')
ax.text(1, -2.7, 'Outcome', ha='center', fontsize=12, fontweight='bold', color='red')

plt.title('Figure 1: Confounding in LLM Instruction Format Experiment', fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
plt.show()

print("""\nKey Insights from DAG:
-------------------------
• Task Difficulty affects BOTH instruction format choice and completion quality
  → Creates spurious correlation between format and completion
• Prompt Length is another confounder
• Backdoor paths: Format ← Difficulty → Completion and Format ← Length → Completion (both must be blocked)
• Causal path: Format → Completion (true effect we want to estimate)
""")

## 1.2 Propensity Score Matching

### Why Matching?

Propensity Score Matching (PSM) creates balanced treatment and control groups by matching units with similar probabilities of receiving treatment, given their covariates. This addresses confounding by ensuring that treated and control units are comparable on observed characteristics.

**Intuition**: If we compare tasks that had the same probability of receiving Format B rather than Format A, any remaining difference in outcomes can be attributed to the format itself rather than to task difficulty or other *observed* confounders, provided unconfoundedness holds.

### Propensity Score Definition

The propensity score is the probability of treatment assignment conditional on observed covariates:

$$e(X) = P(T=1 \mid X)$$

where X represents observed confounders (task difficulty, prompt length, etc.).

**Key Property (Rosenbaum & Rubin, 1983):**
If treatment assignment is strongly ignorable given X (unconfoundedness plus positivity), then it is also strongly ignorable given the propensity score e(X):

$$\{Y(1), Y(0)\} \perp\perp T \mid e(X)$$

This allows us to reduce multidimensional covariates to a single scalar while preserving balancing properties.
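Estimating e(X) is an ordinary logistic regression of T on X. The sketch below fits one by Newton-Raphson in plain NumPy so every step is visible; the covariates (standardized difficulty and prompt length) and their coefficients are simulated assumptions, and `fit_propensity` is a hypothetical helper, not a library function.

```python
import numpy as np


def fit_propensity(X, t, n_iter=25):
    """Unpenalized logistic regression of T on X via Newton-Raphson.

    Returns the estimated propensity scores e(X) and the coefficient vector
    [intercept, beta_1, ..., beta_k].
    """
    Xd = np.column_stack([np.ones(len(X)), X])  # prepend intercept column
    beta = np.zeros(Xd.shape[1])
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-Xd @ beta))
        grad = Xd.T @ (t - p)                          # score of the log-likelihood
        hess = Xd.T @ (Xd * (p * (1.0 - p))[:, None])  # Fisher information
        beta = beta + np.linalg.solve(hess, grad)
    e = 1.0 / (1.0 + np.exp(-Xd @ beta))
    return e, beta


# Simulated confounders: standardized task difficulty and prompt length (assumed)
rng = np.random.default_rng(0)
n = 5000
X = rng.normal(size=(n, 2))
true_beta = np.array([-0.5, 1.5, 0.8])
t = rng.binomial(1, 1.0 / (1.0 + np.exp(-(true_beta[0] + X @ true_beta[1:]))))

e_hat, beta_hat = fit_propensity(X, t)
print("estimated coefficients:", np.round(beta_hat, 2))
print("propensity range:", e_hat.min().round(3), e_hat.max().round(3))
```

In practice `sklearn.linear_model.LogisticRegression` (imported earlier in this notebook) does the same job; note that it applies an L2 penalty by default (`C=1.0`), which slightly shrinks the coefficients.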

### Matching Algorithms

1. **Nearest Neighbor Matching**: Match each treated unit to the control unit with the closest propensity score
2. **Caliper Matching**: Only match if scores are within a specified tolerance (e.g., 0.1 standard deviations of the propensity score or of its logit)
3. **Kernel Matching**: Use weighted averages of all controls, with weights that decrease with distance according to a kernel function (e.g., Gaussian or Epanechnikov)
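Nearest-neighbor matching with a caliper can be sketched in a few lines of NumPy. This hypothetical `nn_match` matches with replacement (a control may serve several treated units) and expresses the caliper in standard deviations of the propensity score, both of which are assumptions; the scores here are simulated rather than estimated.

```python
import numpy as np


def nn_match(ps, t, caliper_sd=0.1):
    """1:1 nearest-neighbour matching on the propensity score, with replacement.

    Returns (treated_idx, control_idx) arrays of matched pairs; treated units
    whose nearest control lies outside the caliper are dropped.
    """
    caliper = caliper_sd * ps.std()
    treated = np.flatnonzero(t == 1)
    controls = np.flatnonzero(t == 0)
    # |e_i - e_j| for every treated unit i (rows) and control unit j (columns)
    dist = np.abs(ps[treated][:, None] - ps[controls][None, :])
    nearest = dist.argmin(axis=1)
    keep = dist[np.arange(len(treated)), nearest] <= caliper
    return treated[keep], controls[nearest[keep]]


rng = np.random.default_rng(0)
ps = rng.uniform(0.1, 0.9, size=300)   # simulated propensity scores
t = rng.binomial(1, ps)
treated_idx, control_idx = nn_match(ps, t, caliper_sd=0.1)
print(f"{len(treated_idx)} of {int(t.sum())} treated units matched")
```

The full distance matrix costs O(N_treated × N_control) memory, which is fine for notebook-sized data; larger datasets usually sort the scores or use a tree-based neighbor search instead.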

### Balance Checking

Before and after matching, we must check balance:

- **Standardized Mean Difference (SMD)**: $\frac{\bar{X}_T - \bar{X}_C}{\sqrt{(s_T^2 + s_C^2)/2}}$
  - An absolute SMD below 0.1 is a common rule of thumb for good balance
- **Variance Ratio**: $s_T^2 / s_C^2$; values close to 1 indicate similar variances
- **Visual inspection**: Density plots of covariates by treatment status
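The SMD formula and the variance ratio translate directly into code. The helpers below (`smd`, `variance_ratio`, `balance_table` are hypothetical names) use sample variances with `ddof=1`, and the demo data are simulated so that treatment depends strongly on one covariate and not at all on the other.

```python
import numpy as np


def smd(x_t, x_c):
    """Standardized mean difference with the pooled-SD denominator defined above."""
    x_t, x_c = np.asarray(x_t, float), np.asarray(x_c, float)
    pooled_sd = np.sqrt((x_t.var(ddof=1) + x_c.var(ddof=1)) / 2)
    return (x_t.mean() - x_c.mean()) / pooled_sd


def variance_ratio(x_t, x_c):
    return np.var(x_t, ddof=1) / np.var(x_c, ddof=1)


def balance_table(covariates, t, label=""):
    """Print SMD and variance ratio for each covariate (dict: name -> array)."""
    for name, x in covariates.items():
        d = smd(x[t == 1], x[t == 0])
        vr = variance_ratio(x[t == 1], x[t == 0])
        flag = "ok" if abs(d) < 0.1 else "IMBALANCED"
        print(f"{label}{name:>12}: SMD={d:+.3f}  VR={vr:.2f}  {flag}")


rng = np.random.default_rng(0)
difficulty = rng.normal(size=1000)
t = rng.binomial(1, 1 / (1 + np.exp(-2 * difficulty)))  # treatment depends on difficulty
unrelated = rng.normal(size=1000)                       # covariate unrelated to treatment
balance_table({"difficulty": difficulty, "unrelated": unrelated}, t, label="before: ")
```

Running the same table on the matched sample (treated units plus their matched controls) gives the "after" column; matching has done its job when every |SMD| drops below the chosen threshold.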

### Treatment Effect Estimation

After matching, we estimate:

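- **ATE (Average Treatment Effect)**: $\frac{1}{N} \sum_{i=1}^N (2T_i - 1) \cdot (Y_i - Y_{m(i)})$, where $m(i)$ is the index of unit $i$'s matched counterpart from the opposite treatment group (so every unit, treated or control, is matched)
- **ATT (Average Treatment Effect on the Treated)**: $\frac{1}{N_1} \sum_{i:T_i=1} (Y_i - Y_{m(i)})$, where $N_1$ is the number of treated units. This is more common in practice because PSM typically matches each treated unit to one or more controls.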

**Confidence Intervals**: Use bootstrapping or analytical methods for inference.
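Given matched pairs, the ATT is the mean within-pair difference, and a percentile bootstrap over pairs gives a simple interval. The sketch below is an assumption-laden illustration: `att_with_bootstrap_ci` is a hypothetical helper, the matched outcomes are simulated with an assumed effect of 0.1, and resampling pairs treats the matches as fixed, so it ignores uncertainty from estimating the propensity score (the naive bootstrap is known to be unreliable for nearest-neighbor matching with replacement).

```python
import numpy as np


def att_with_bootstrap_ci(y_treated, y_control_matched, n_boot=2000, alpha=0.05, seed=0):
    """ATT from matched pairs plus a percentile bootstrap CI over pairs."""
    diffs = np.asarray(y_treated, float) - np.asarray(y_control_matched, float)
    att = diffs.mean()
    rng = np.random.default_rng(seed)
    # Each row resamples len(diffs) pairs with replacement
    idx = rng.integers(0, len(diffs), size=(n_boot, len(diffs)))
    boot = diffs[idx].mean(axis=1)
    lo, hi = np.quantile(boot, [alpha / 2, 1 - alpha / 2])
    return att, (lo, hi)


# Simulated matched outcomes (e.g. completion-quality scores); assumed effect = 0.1
rng = np.random.default_rng(1)
y_control_matched = rng.normal(0.5, 0.1, size=500)
y_treated = y_control_matched + 0.1 + rng.normal(0, 0.05, size=500)

att, (lo, hi) = att_with_bootstrap_ci(y_treated, y_control_matched)
print(f"ATT = {att:.3f}, 95% CI [{lo:.3f}, {hi:.3f}]")
```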

### DAG 2: Mediation in Attention Mechanisms

This DAG illustrates causal mediation: the effect of prompt intervention on reasoning quality is mediated through attention patterns.

In [None]:
# Create DAG for Example 2: Mediation in Attention Mechanisms
G2 = nx.DiGraph()

# Add nodes
G2.add_node('Task_Complexity', pos=(0, 2), color='lightblue', style='filled')
G2.add_node('Prompt_Intervention', pos=(1, 0), color='lightgreen', style='filled')
G2.add_node('Attention_Patterns', pos=(2, 0), color='lightyellow', style='filled')
G2.add_node('Reasoning_Quality', pos=(1, -2), color='salmon', style='filled')

# Add edges
G2.add_edge('Task_Complexity', 'Prompt_Intervention')
G2.add_edge('Task_Complexity', 'Attention_Patterns')
G2.add_edge('Prompt_Intervention', 'Attention_Patterns')
G2.add_edge('Prompt_Intervention', 'Reasoning_Quality')
G2.add_edge('Attention_Patterns', 'Reasoning_Quality')

# Draw
fig, ax = plt.subplots(figsize=(10, 8))
pos = nx.get_node_attributes(G2, 'pos')
colors = [G2.nodes[n]['color'] for n in G2.nodes()]

nx.draw(G2, pos, ax=ax, with_labels=True, node_color=colors,
        node_size=4500, font_size=9, font_weight='bold',
        arrowsize=20, edge_color='gray', width=2)

# Add path annotations
ax.arrow(1, -0.2, 0.9, 0, head_width=0.1, head_length=0.1, 
         fc='red', ec='red', alpha=0.5, linestyle='--', linewidth=2)
ax.arrow(2, -0.2, -0.9, -1.5, head_width=0.1, head_length=0.1,
         fc='red', ec='red', alpha=0.5, linestyle='--', linewidth=2)

ax.text(1.5, -0.5, 'Mediator', ha='center', fontsize=11, fontweight='bold', color='orange')
ax.text(1.5, -1.3, 'Indirect Effect', ha='center', fontsize=9, color='red', rotation=-30)

plt.title('Figure 2: Causal Mediation in LLM Attention Mechanisms', fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
plt.show()

print("""\nKey Insights from DAG:
-------------------------
• Prompt Intervention affects Reasoning Quality through TWO paths:
  1. Direct Effect: Prompt → Reasoning (unmediated)
  2. Indirect Effect: Prompt → Attention → Reasoning (mediated)
• Task Complexity confounds the Prompt → Attention link, opening the backdoor path Prompt ← Complexity → Attention → Reasoning
• Attention Patterns act as a mediator (explanation mechanism)
• Total Effect = Direct Effect + Indirect Effect
""")

## 1.3 Causal Mediation Analysis

### Direct vs Indirect Effects

Causal mediation analysis decomposes the total treatment effect into:

- **Direct Effect (DE)**: Effect of treatment on outcome *not* through the mediator
- **Indirect Effect (IE)**: Effect of treatment on outcome *through* the mediator

Mathematically:
$$\tau_{\text{total}} = \tau_{\text{direct}} + \tau_{\text{indirect}}$$

### Mediator Definition and Role

A **mediator** is a variable that lies on the causal path between treatment and outcome:

Treatment (T) → Mediator (M) → Outcome (Y)

Key requirements:
1. M is causally affected by T
2. M causally affects Y
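3. M occurs after T and before Y and is not itself affected by Y (if both T and Y caused M, M would be a collider rather than a mediator)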

### Mediation in LLM Context

In our Example 2:
- **Treatment**: Chain-of-thought prompting vs. direct prompting
- **Mediator**: Attention patterns in specific heads/layers
- **Outcome**: Reasoning quality (accuracy, logical coherence)

**Question**: Does CoT improve reasoning *because* it changes how the model attends to information (indirect effect), or does it have other mechanisms (direct effect)?

### Path Decomposition

Using the counterfactual framework (Imai et al., 2010):

**Total Effect**:
$$TE = \mathbb{E}[Y(1, M(1)) - Y(0, M(0))]$$

**Direct Effect** (natural):
$$NDE = \mathbb{E}[Y(1, M(0)) - Y(0, M(0))]$$

**Indirect Effect** (natural):
$$NIE = \mathbb{E}[Y(1, M(1)) - Y(1, M(0))]$$

where $M(t)$ is the counterfactual mediator value under treatment $t$.
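Under additional assumptions (linear models, no treatment-mediator interaction, and sequential ignorability, stated in the next subsection), the natural effects reduce to regression coefficients: NDE is the coefficient of T in the outcome model, and NIE is the product $a \cdot b$ of the T → M and M → Y coefficients (the product-of-coefficients approach associated with Baron & Kenny, 1986). The simulation below follows the structure of DAG 2 with assumed coefficients; `ols` is a small hypothetical helper.

```python
import numpy as np


def ols(y, *cols):
    """Least-squares coefficients [intercept, coef_1, coef_2, ...]."""
    X = np.column_stack([np.ones(len(y)), *cols])
    return np.linalg.lstsq(X, y, rcond=None)[0]


rng = np.random.default_rng(42)
n = 100_000
x = rng.normal(size=n)                          # task complexity (confounder)
t = rng.binomial(1, 1 / (1 + np.exp(-x)))       # CoT more likely for complex tasks
m = 0.8 * t + 0.5 * x + rng.normal(size=n)      # attention summary (mediator)
y = 0.3 * t + 0.6 * m + rng.normal(size=n)      # reasoning score: true NDE=0.3, NIE=0.48

a = ols(m, t, x)[1]               # effect of T on M
_, nde, b, _ = ols(y, t, m, x)    # direct effect of T, and effect of M on Y
te = ols(y, t, x)[1]              # total effect
nie = a * b

print(f"NDE={nde:.3f}  NIE={nie:.3f}  TE={te:.3f}  NDE+NIE={nde + nie:.3f}")
```

With ordinary least squares and the same covariates in every model, TE = NDE + NIE holds exactly in-sample, mirroring the decomposition above; with interactions or nonlinear outcomes the simulation-based approach of Imai et al. (2010) is needed instead.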

### Sequential Ignorability Assumption

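For mediation analysis, we require **sequential ignorability** (Imai et al., 2010), stated for all treatment levels $t, t'$ and mediator values $m$:

1. **No unmeasured confounding of treatment**: $\{Y(t', m), M(t)\} \perp\perp T \mid X$. This covers both the T → M and T → Y relationships.
2. **No unmeasured confounding of M → Y**: $Y(t', m) \perp\perp M(t) \mid T = t, X$. In particular, no confounder of the mediator and outcome may itself be affected by treatment.

This is a strong assumption: randomizing the prompt guarantees the first condition but not the second, which motivates sensitivity analyses.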
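A short simulation shows why the second condition matters. Here T is randomized, but an assumed unmeasured variable `u` affects both the mediator and the outcome; the product-of-coefficients estimates are then badly biased even though the total effect remains identified. All coefficients are illustrative assumptions.

```python
import numpy as np


def ols(y, *cols):
    """Least-squares coefficients [intercept, coef_1, coef_2, ...]."""
    X = np.column_stack([np.ones(len(y)), *cols])
    return np.linalg.lstsq(X, y, rcond=None)[0]


rng = np.random.default_rng(3)
n = 100_000
t = rng.binomial(1, 0.5, size=n)                  # randomized prompt: T is unconfounded
u = rng.normal(size=n)                            # UNMEASURED mediator-outcome confounder
m = 0.8 * t + u + rng.normal(size=n)              # mediator
y = 0.3 * t + 0.6 * m + u + rng.normal(size=n)    # true NDE = 0.3, true NIE = 0.8 * 0.6 = 0.48

a = ols(m, t)[1]
_, nde_hat, b_hat = ols(y, t, m)
nie_hat = a * b_hat

print(f"NDE estimate={nde_hat:.3f} (true 0.30)  NIE estimate={nie_hat:.3f} (true 0.48)")
```

Under these assumptions the estimated NIE converges to about 0.88 and the NDE to about -0.10: the bias in $b$ (the M → Y coefficient) propagates into both effects, while their sum still matches the true total effect of 0.78.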

## 1.4 Data Preparation for Causal Inference

### Feature Selection: Confounders vs Colliders

In causal analysis, feature selection is not just about predictive power:

- **Confounders**: Variables that affect BOTH treatment and outcome. **Must be included** to avoid bias.
- **Mediators**: Variables on the causal path. **Exclude** when estimating total effect, **include** when doing mediation analysis.
- **Colliders**: Variables affected by BOTH treatment and outcome. **Must be excluded** to avoid introducing bias.
- **Instrumental Variables**: Variables that affect treatment but not outcome (except through treatment). Useful in IV methods but require strong assumptions.

**LLM Example**: In analyzing prompt effects:
- Include: Task difficulty, prompt length (confounders)
- Exclude: Completion tokens (mediator/outcome)
- Exclude: User satisfaction rating (collider: affected by both the prompt format and the actual completion quality)
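The danger of adjusting for a collider can be demonstrated directly. In this simulation (all structure assumed), the format is randomized and has *no* effect on quality, yet satisfaction depends on both. Regressing quality on format alone gives roughly zero, the truth; adding satisfaction as a "control" manufactures a spurious negative effect.

```python
import numpy as np


def ols_coef(y, *cols):
    """Least-squares coefficients [intercept, coef_1, coef_2, ...]."""
    X = np.column_stack([np.ones(len(y)), *cols])
    return np.linalg.lstsq(X, y, rcond=None)[0]


rng = np.random.default_rng(7)
n = 100_000
fmt = rng.binomial(1, 0.5, size=n)                  # treatment: randomized prompt format
quality = rng.normal(size=n)                        # outcome: NOT affected by format (assumed)
satisfaction = fmt + quality + rng.normal(size=n)   # collider: caused by both

effect_unadjusted = ols_coef(quality, fmt)[1]               # close to 0 (the truth)
effect_adjusted = ols_coef(quality, fmt, satisfaction)[1]   # biased by collider adjustment

print(f"unadjusted={effect_unadjusted:+.3f}  adjusted for satisfaction={effect_adjusted:+.3f}")
```

With these Gaussian assumptions, $E[\text{quality} \mid \text{fmt}, \text{sat}] = (\text{sat} - \text{fmt})/2$, so the adjusted coefficient converges to -0.5.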

### Handling Missing Data in Causal Contexts

Missing data is particularly problematic in causal inference:

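- **Missing Completely at Random (MCAR)**: Missingness is independent of all variables. Complete-case analysis remains unbiased, though less efficient; simple imputation can still distort variances and correlations.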
- **Missing at Random (MAR)**: Missingness depends on observed data. Use model-based imputation (e.g., MICE).
- **Missing Not at Random (MNAR)**: Missingness depends on unobserved data. Requires sensitivity analysis or specialized methods.

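**Caution**: Don't drop incomplete observations indiscriminately. If missingness depends on treatment, covariates, or outcomes, listwise deletion introduces selection bias and can shrink or remove the overlap between treatment groups.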

### Encoding Categorical Variables

- **Binary treatment**: Use 0/1 encoding
- **Multi-level treatment**: Use dummy coding (n-1 indicators relative to a chosen reference level), or estimate pairwise contrasts between levels
- **Ordinal variables**: Consider whether natural ordering is meaningful for causal interpretation
- **Nominal confounders**: One-hot encode, drop one category as reference

**Important**: Maintain consistent encoding across treatment and control groups.
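One way to keep encoding consistent is to fix the category list up front, so that a subset missing some categories (e.g. the treated group) still produces the same columns. The sketch below is a minimal illustration with pandas; `encode_confounders`, the category lists, and treating difficulty as ordinal integer codes are all assumptions for this example.

```python
import pandas as pd

TASK_TYPES = ["classification", "qa", "reasoning", "summarization", "translation"]
DIFFICULTY = ["easy", "medium", "hard"]


def encode_confounders(df):
    out = df.copy()
    # Fixed categories guarantee identical dummy columns in every subset
    out["task_type"] = pd.Categorical(out["task_type"], categories=TASK_TYPES)
    # Assumption: difficulty is ordinal, encoded as 0 (easy) < 1 (medium) < 2 (hard)
    out["difficulty"] = pd.Categorical(out["difficulty"], categories=DIFFICULTY, ordered=True).codes
    # drop_first=True uses "classification" as the reference category
    return pd.get_dummies(out, columns=["task_type"], drop_first=True, dtype=int)


df = pd.DataFrame({
    "difficulty": ["easy", "hard", "medium", "hard"],
    "task_type": ["qa", "reasoning", "translation", "qa"],
    "cot": [1, 1, 0, 0],
})
enc_t = encode_confounders(df[df.cot == 1])
enc_c = encode_confounders(df[df.cot == 0])
print(enc_t.columns.tolist())
```

Both groups get the same four task-type indicators even though neither contains every task type.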

### Common Support and Overlap

The **common support assumption** requires that propensity scores have substantial overlap between treatment groups:

```
Treatment:     ████████████████████████
Control:           ████████████████████████████
Overlap:            ████████████████████
```

Units outside the common support region should be excluded (trimmed) or analyzed separately.
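A simple trimming rule keeps only units whose propensity score lies inside both groups' observed ranges (a min-max rule). The sketch below applies it to simulated scores; `common_support_mask` is a hypothetical helper, and a common alternative rule of thumb is to keep only scores in [0.1, 0.9].

```python
import numpy as np


def common_support_mask(ps, t):
    """True for units whose score lies inside the overlap of both groups' ranges."""
    lo = max(ps[t == 1].min(), ps[t == 0].min())
    hi = min(ps[t == 1].max(), ps[t == 0].max())
    return (ps >= lo) & (ps <= hi)


rng = np.random.default_rng(0)
ps_treated = rng.uniform(0.30, 0.95, size=400)   # assumed score distributions
ps_control = rng.uniform(0.05, 0.70, size=600)
ps = np.concatenate([ps_treated, ps_control])
t = np.concatenate([np.ones(400, dtype=int), np.zeros(600, dtype=int)])

keep = common_support_mask(ps, t)
print(f"kept {keep.sum()} of {len(ps)} units; "
      f"score range [{ps[keep].min():.3f}, {ps[keep].max():.3f}]")
```

Trimming changes the estimand: the effect is then estimated for the overlap population only, which should be reported alongside the result.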

### Sensitivity to Preprocessing Choices

Causal estimates can be sensitive to:
- Scaling of continuous variables
- Binning strategies for continuous confounders
- Outlier treatment
- Feature engineering choices

**Best Practice**: Conduct sensitivity analyses across different preprocessing choices to assess robustness.

## 1.5 LLM-Specific Considerations

### Prompts as Interventions/Treatments

In LLM research, prompts, and more broadly the inputs and generation settings we control, function as **treatment variables**:

- **Binary treatment**: CoT vs. no-CoT
- **Continuous or ordered treatment**: Decoding parameters such as temperature (continuous) or top-k (integer-valued)
- **Multi-level treatment**: Different prompting strategies (few-shot, zero-shot, chain-of-thought)

**Key Challenge**: Prompt effects are often **highly context-dependent**. A prompt that works for mathematical reasoning may not work for creative writing. This motivates careful covariate control (task type, difficulty, domain).

### Model Outputs as Outcomes

LLM outcomes can be measured in various ways:

- **Binary**: Correct/incorrect (classification, QA)
- **Ordinal**: Quality score (human evaluation, automated metrics)
- **Continuous**: Token likelihood, perplexity
- **Multivariate**: Multiple dimensions (accuracy, coherence, creativity)

**Consideration**: Different outcome types require different causal estimators. Binary outcomes may use logistic regression; continuous outcomes may use linear models.

### Confounders in LLM Experiments

Common confounders in LLM research:

1. **Task Difficulty**: Harder tasks may receive different prompting strategies
2. **Prompt Length**: Longer prompts may correlate with better performance
3. **Task Type**: Some prompts work better for certain task categories
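4. **Token Count**: Longer completions may appear higher quality to evaluators. Note, however, that completion length is itself affected by the prompt, so it is a post-treatment variable (a mediator, as noted in Section 1.4) rather than a confounder; adjust for it only in a mediation analysis, not when estimating total effects.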
5. **Domain Knowledge**: Models perform better on topics they've seen more in training

**Solution**: Explicitly measure and control for these confounders using domain-specific metrics.

### Attention Mechanisms as Causal Mediators

Attention weights offer a unique opportunity for **internal causal analysis**:

- They offer a partial mechanistic view of model processing (attention weights alone do not fully explain model behavior)
- They can be intervened upon (attention intervention studies)
- They vary systematically with prompting strategies
- They can be analyzed at different layers and heads

**Key Insight**: By analyzing how prompt interventions affect attention patterns, and how those patterns affect final outputs, we can understand *why* prompts work (or don't work).
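Using attention as a mediator requires reducing the attention tensor to a scalar per example. One hypothetical choice, sketched below, is the mean entropy of each head's attention distributions (low entropy means focused attention). The function assumes an array of shape (layers, heads, query_len, key_len), such as one example's slice of the `attentions` tuple that Hugging Face models return with `output_attentions=True` (one tensor per layer, stacked with `np.stack`); the demo uses a synthetic causal-uniform pattern instead of real GPT-2 output.

```python
import numpy as np


def attention_entropy(attn):
    """Mean entropy (in nats) of each head's attention distributions.

    attn: array of shape (layers, heads, query_len, key_len) whose last axis
    sums to 1. Returns an array of shape (layers, heads).
    """
    p = np.clip(attn, 1e-12, 1.0)             # avoid log(0) for masked positions
    ent = -(attn * np.log(p)).sum(axis=-1)    # entropy of each query position's distribution
    return ent.mean(axis=-1)                  # average over query positions


# Synthetic causal attention: query i attends uniformly to keys 0..i
L = 6
causal_uniform = np.tril(np.ones((L, L)))
causal_uniform /= causal_uniform.sum(axis=-1, keepdims=True)
attn = np.broadcast_to(causal_uniform, (2, 3, L, L))  # 2 layers, 3 heads

ent = attention_entropy(attn)
print(ent.round(3))
```

For this uniform pattern, row $i$ has entropy $\log(i+1)$, so every head's value equals the mean of $\log 1, \dots, \log 6$; real attention maps will differ by head and layer, and which summary to treat as the mediator is a modeling choice.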

### Reproducibility Challenges

LLM experiments face unique reproducibility challenges:

1. **Stochasticity**: Random sampling, temperature, seed dependence
2. **Hardware**: GPU/CPU differences can affect floating-point computations
3. **Versioning**: Model checkpoints, library versions, API changes
4. **Evaluation**: Human evaluation variance, automated metric drift

**Best Practices**:
- Set random seeds (Python, NumPy, PyTorch, CUDA)
- Document all versions (model, libraries, CUDA)
- Use multiple random seeds and report variance
- Provide complete configuration files
- Use standardized evaluation benchmarks
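The seeding advice can be wrapped in a single helper. This `set_seed` function is a hypothetical sketch; it seeds Python and NumPy, and PyTorch only if it is installed. For stricter GPU determinism, PyTorch also offers `torch.use_deterministic_algorithms(True)`, at some cost in speed.

```python
import os
import random

import numpy as np


def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy and (if available) PyTorch random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    # Only affects subprocesses started after this point, not the running interpreter
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)           # seeds the CPU generator and CUDA devices
    torch.cuda.manual_seed_all(seed)  # explicit; silently ignored without CUDA


set_seed(123)
first = (random.random(), float(np.random.rand()))
set_seed(123)
second = (random.random(), float(np.random.rand()))
print(first == second)
```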

---

# Part 2: Dataset Preparation and GPT-2 Testing

In this section, we'll:
1. Download the two datasets needed for our examples
2. Perform exploratory analysis
3. Test basic GPT-2 functionality

## 2.1 Install Required Libraries

In [None]:
# Install Hugging Face libraries
!pip install transformers torch datasets --quiet

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from datasets import load_dataset
import os

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("Libraries installed successfully!")

## 2.2 GPT-2 Model Testing

Let's verify that GPT-2 loads and generates text correctly on your system.

In [None]:
# Load GPT-2 model and tokenizer
print("Loading GPT-2 model...")
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model.eval()  # disable dropout so forward passes are deterministic

# GPT-2 has no padding token; reuse the end-of-sequence token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Model loaded successfully!")
print(f"Model parameters: {model.num_parameters():,}")  # num_parameters is a method
print(f"Vocabulary size: {len(tokenizer):,}")

In [None]:
# Test basic text generation
test_prompts = [
    "The capital of France is",
    "To solve this math problem,",
    "Translate this sentence:"
]

print("Testing GPT-2 text generation...\n")
print("=" * 60)

for prompt in test_prompts:
    inputs = tokenizer(prompt, return_tensors='pt')
    outputs = model.generate(
        **inputs,
        max_length=30,
        num_return_sequences=1,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    print(f"\nPrompt: {prompt}")
    print(f"Output: {generated_text}")
    print("-" * 60)

print("\n✓ GPT-2 is working correctly!")

## 2.3 Example 1 Dataset: Instruction Format Experiment

We'll use SuperNaturalInstructions (or a similar instruction-following dataset) to study how different instruction formats affect task completion.

If SuperNaturalInstructions is too large or unavailable, the code falls back to P3 (Public Pool of Prompts) and, failing that, to a small synthetic dataset.

In [None]:
try:
    # Try loading SuperNaturalInstructions
    print("Attempting to load SuperNaturalInstructions dataset...")
    dataset1 = load_dataset("super_natural_instructions", split="train[:1000]")
    print(f"✓ Successfully loaded {len(dataset1)} samples from SuperNaturalInstructions")
except Exception as e:
    print(f"Could not load SuperNaturalInstructions: {e}")
    print("\nTrying alternative: P3 (Public Pool of Prompts) dataset...")
    try:
        dataset1 = load_dataset("bigscience/P3", split="train[:1000]")
        print(f"✓ Successfully loaded {len(dataset1)} samples from P3 dataset")
    except Exception as e2:
        print(f"Could not load P3: {e2}")
        print("\nCreating synthetic instruction format dataset for demonstration...")
        
        # Create a synthetic dataset with instruction formats
        from datasets import Dataset
        
        task_types = ['translation', 'qa', 'summarization', 'classification', 'reasoning']
        difficulties = ['easy', 'medium', 'hard']
        
        synthetic_data = []
        for i in range(1000):
            # Sample the task type once so the instruction matches the 'task_type' column
            task_type = str(np.random.choice(task_types))
            task = {
                'task_id': i,
                'task_type': task_type,
                'difficulty': str(np.random.choice(difficulties)),
                'instruction': f"Perform {task_type} on this input",
                'input': f"Sample input text {i}",
            }
            synthetic_data.append(task)
        
        dataset1 = Dataset.from_list(synthetic_data)
        print(f"✓ Created synthetic dataset with {len(dataset1)} samples")

# Display sample data
print("\n" + "=" * 60)
print("Example 1 Dataset Sample")
print("=" * 60)
print(dataset1[:2])

In [None]:
# Exploratory Data Analysis for Example 1
import pandas as pd

# Convert to DataFrame for easier analysis
df1 = pd.DataFrame(dataset1)

print("Dataset Statistics:")
print("=" * 60)
print(f"Total samples: {len(df1)}")
print(f"\nColumns: {list(df1.columns)}")

# Check for categorical variables
categorical_cols = df1.select_dtypes(include=['object']).columns
print(f"\nCategorical variables ({len(categorical_cols)}):")
for col in categorical_cols[:5]:
    unique_vals = df1[col].nunique()
    print(f"  - {col}: {unique_vals} unique values")

## 2.4 Example 2 Dataset: Attention and Reasoning

We'll use GSM8K (Grade School Math) for studying attention patterns in mathematical reasoning.

In [None]:
# Load GSM8K dataset for math reasoning
print("Loading GSM8K dataset...")
dataset2 = load_dataset("gsm8k", "main", split="train[:500]")
print(f"✓ Successfully loaded {len(dataset2)} math problems from GSM8K")

# Display sample data
print("\n" + "=" * 60)
print("Example 2 Dataset Sample (GSM8K)")
print("=" * 60)
for i in range(2):
    print(f"\n--- Question {i+1} ---")
    print(f"Question: {dataset2[i]['question'][:100]}...")
    print(f"Answer: {dataset2[i]['answer'][:100]}...")

In [None]:
# Exploratory Data Analysis for Example 2
df2 = pd.DataFrame(dataset2)

print("Dataset Statistics:")
print("=" * 60)
print(f"Total samples: {len(df2)}")
print(f"\nColumns: {list(df2.columns)}")

# Analyze question and answer lengths
df2['question_length'] = df2['question'].str.len()
df2['answer_length'] = df2['answer'].str.len()

print(f"\nQuestion length: mean={df2['question_length'].mean():.1f}, std={df2['question_length'].std():.1f}")
print(f"Answer length: mean={df2['answer_length'].mean():.1f}, std={df2['answer_length'].std():.1f}")

## 2.5 Save Processed Datasets

We'll save the datasets for use in our example notebooks.

In [None]:
# Save datasets to disk
output_dir1 = '../Example1_Dataset'
output_dir2 = '../Example2_Dataset'

# Ensure directories exist
os.makedirs(output_dir1, exist_ok=True)
os.makedirs(output_dir2, exist_ok=True)

# Save datasets
df1.to_csv(f'{output_dir1}/instruction_format_data.csv', index=False)
df2.to_csv(f'{output_dir2}/attention_reasoning_data.csv', index=False)

print(f"✓ Saved Example 1 dataset to {output_dir1}/instruction_format_data.csv")
print(f"✓ Saved Example 2 dataset to {output_dir2}/attention_reasoning_data.csv")
print(f"\nFile sizes:")
print(f"  - Example 1: {os.path.getsize(f'{output_dir1}/instruction_format_data.csv') / 1024:.1f} KB")
print(f"  - Example 2: {os.path.getsize(f'{output_dir2}/attention_reasoning_data.csv') / 1024:.1f} KB")

---

# Part 3: Conclusion and Next Steps

## Summary of Phase 0

In this notebook, we've successfully:

1. ✅ **Established theoretical foundations** in causal inference
   - Covered correlation vs causation in LLMs
   - Explained the potential outcomes framework
   - Introduced causal graphs (DAGs) with LLM examples
   - Detailed propensity score matching methodology
   - Explored causal mediation analysis
   - Discussed data preparation for causal inference
   - Addressed LLM-specific considerations

2. ✅ **Created visual representations** of causal relationships
   - DAG showing confounding in instruction format experiments
   - DAG showing mediation in attention mechanisms

3. ✅ **Set up computational environment**
   - Installed all required libraries
   - Tested GPT-2 model functionality

4. ✅ **Downloaded and processed datasets**
   - Example 1: Instruction format dataset (1000 samples)
   - Example 2: GSM8K math reasoning dataset (500 samples)
   - Saved processed datasets for efficient reuse

## Importance of Data Preparation in Causal Analysis

Data preparation is **not** merely a preprocessing step in causal inference—it is fundamental to valid causal claims. Unlike predictive modeling, where feature engineering focuses on maximizing predictive power, causal data preparation must:

- **Identify and control for confounders** to avoid biased effect estimates
- **Exclude colliders** (and mediators, when estimating total effects) to prevent spurious associations
- **Ensure common support** to make valid within-sample comparisons
- **Address missing data** carefully to avoid selection bias and loss of overlap between treatment groups

Poor data preparation can lead to completely incorrect causal conclusions, even with sophisticated estimation methods. This underscores the importance of careful causal diagramming, domain knowledge integration, and rigorous sensitivity analyses.

## Impact on Model Development

Causal inference transforms model development from "what works" to "why it works":

- **Interpretability**: Understanding causal mechanisms makes models more interpretable and trustworthy
- **Optimization**: Causal insights can guide feature engineering and model architecture decisions
- **Robustness**: Causal frameworks identify sensitive dependencies and potential failure modes
- **Decision-making**: Causal effects support better policy and design decisions in AI systems

## Next Steps

The foundation is now set for the two main examples:

1. **Example 1 Notebook** (`01_example1_psm.ipynb`): Propensity Score Matching for instruction format causality
   - Generate instruction format variations
   - Measure task completion quality with GPT-2
   - Estimate propensity scores
   - Perform matching and balance checks
   - Estimate and interpret treatment effects

2. **Example 2 Notebook** (`02_example2_mediation.ipynb`): Mediation analysis for attention mechanisms
   - Generate reasoning tasks with and without CoT
   - Extract attention patterns from GPT-2
   - Perform mediation analysis
   - Visualize indirect effects through attention
   - Interpret causal mechanisms

## References

### Causal Inference Foundations
- Pearl, J. (2009). *Causality: Models, Reasoning, and Inference*. Cambridge University Press.
- Rubin, D. B. (1974). "Estimating causal effects of treatments in randomized and nonrandomized studies." *Journal of Educational Psychology*, 66(5), 688-701.
- Imbens, G. W., & Rubin, D. B. (2015). *Causal Inference for Statistics, Social, and Biomedical Sciences*. Cambridge University Press.

### Propensity Score Matching
- Rosenbaum, P. R., & Rubin, D. B. (1983). "The central role of the propensity score in observational studies for causal effects." *Biometrika*, 70(1), 41-55.
- Stuart, E. A. (2010). "Matching methods for causal inference: A review and a look forward." *Statistical Science*, 25(1), 1-21.

### Mediation Analysis
- Imai, K., Keele, L., & Tingley, D. (2010). "A general approach to causal mediation analysis." *Psychological Methods*, 15(4), 309-334.
- Baron, R. M., & Kenny, D. A. (1986). "The moderator-mediator variable distinction in social psychological research." *Journal of Personality and Social Psychology*, 51(6), 1173-1182.

### Causal Inference in Machine Learning
- Pearl, J., & Mackenzie, D. (2018). *The Book of Why: The New Science of Cause and Effect*. Basic Books.
- Peters, J., Janzing, D., & Schölkopf, B. (2017). *Elements of Causal Inference: Foundations and Learning Algorithms*. MIT Press.

### LLM and Attention Mechanisms
- Vaswani, A., et al. (2017). "Attention is all you need." *Advances in Neural Information Processing Systems*, 30.
- Wei, J., et al. (2022). "Chain-of-thought prompting elicits reasoning in large language models." *Advances in Neural Information Processing Systems*, 35.
- Vig, J. (2019). "A multiscale visualization of attention in the transformer model." *Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics: System Demonstrations*, 37-42.

---

**End of Phase 0: Setup and Theory**

This notebook provides the theoretical foundation and computational setup for the causal analysis examples. The next notebooks will demonstrate practical implementation of these concepts using real LLM data.