## 1. Setup and Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
# Suppress library warnings so the demo output stays readable. Remove this line
# when debugging: it also hides numerical warnings (e.g. division by zero).
warnings.filterwarnings('ignore')

# Set style. 'seaborn-v0_8-whitegrid' only exists in matplotlib >= 3.6; fall back
# to the pre-3.6 name, and finally to the default style, so older installs still run.
for _style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 11

print("✓ Libraries imported")

In [None]:
# Import the KRL Multi-Unit SCM API. USE_KRL records whether the import worked;
# the estimation cells below check this flag and print a demonstration message
# instead of fitting when the package is missing.
try:
    from krl_models.causal import (
        InferenceMethod,
        MultiUnitSCM,
        MultiUnitSCMConfig,
        MultiUnitSCMResult,
        WeightMethod,
        multi_unit_scm,
    )
except ImportError:
    USE_KRL = False
    print("⚠ KRL Model Zoo not available, using demonstration mode")
else:
    USE_KRL = True
    print("✓ KRL Model Zoo Multi-Unit SCM imported")

## 2. Generate Simulated Multi-State Data

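We simulate a scenario where **5 states** adopt a policy in 2010 and **25 control states** never do. The outcome is a simulated unemployment rate for 1995–2024. Each treated state is built as a weighted average of the control states, so a well-fitting synthetic control exists by construction. A known, state-specific treatment effect is then added from 2010 onward, which lets us check the estimates against the truth in Section 7.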

In [None]:
np.random.seed(42)

# Simulation parameters. The period count, the treatment index and the unit counts
# are derived from `years`, `treatment_year` and the state lists, so they can't
# drift out of sync when any of those is edited.
years = list(range(1995, 2025))  # 1995-2024
n_periods = len(years)  # 30
treatment_year = 2010
treatment_period = years.index(treatment_year)  # Index of the first treated year (15)

# State names
treated_states = ['California', 'New York', 'Illinois', 'Massachusetts', 'Washington']
control_states = [
    'Texas', 'Florida', 'Ohio', 'Georgia', 'North Carolina',
    'Michigan', 'New Jersey', 'Virginia', 'Arizona', 'Tennessee',
    'Indiana', 'Missouri', 'Maryland', 'Wisconsin', 'Colorado',
    'Minnesota', 'South Carolina', 'Alabama', 'Louisiana', 'Kentucky',
    'Oregon', 'Oklahoma', 'Connecticut', 'Utah', 'Iowa'
]
n_treated = len(treated_states)  # 5
n_control = len(control_states)  # 25

# True treatment effects (state-specific). They are divided by `effect_scale`
# before being added to the outcome, so the injected effect is true_effect / 10.
true_effects = {
    'California': -12,
    'New York': -8,
    'Illinois': -10,
    'Massachusetts': -15,
    'Washington': -6,
}
effect_scale = 10  # Scale to unemployment rate
assert set(true_effects) == set(treated_states), "every treated state needs a true effect"

# Generate control state outcomes (unemployment rate):
# shared cyclical trend + per-state offset + i.i.d. noise.
base_trend = 5 + 0.5 * np.sin(np.linspace(0, 4*np.pi, n_periods))  # Cyclical

Y_control = np.zeros((n_control, n_periods))
for j in range(n_control):
    state_effect = np.random.normal(0, 1.5)
    noise = np.random.normal(0, 0.3, n_periods)
    Y_control[j] = base_trend + state_effect + noise

# Generate treated state outcomes
Y_treated = np.zeros((n_treated, n_periods))
for m, state in enumerate(treated_states):
    # Exact convex combination of the controls (Dirichlet weights, no extra noise),
    # so a perfect pre-treatment synthetic control exists by construction.
    weights = np.random.dirichlet(np.ones(n_control) * 2)
    Y_treated[m] = Y_control.T @ weights

    # Add treatment effect post-treatment
    Y_treated[m, treatment_period:] += true_effects[state] / effect_scale

print(f"✓ Data generated:")
print(f"  - {n_treated} treated states")
print(f"  - {n_control} control states")
print(f"  - {n_periods} years ({years[0]}-{years[-1]})")
print(f"  - Treatment year: {treatment_year}")

In [None]:
# Visualize the raw panel: (A) every state's trajectory, (B) group averages.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_states, ax_avg = axes

# Shared marker for the policy start, drawn on both panels
treatment_marker = dict(x=treatment_year, color='black', linestyle='--', linewidth=1.5, alpha=0.7)

# Panel A: faint gray control trajectories with the treated states highlighted on top
for control_series in Y_control:
    ax_states.plot(years, control_series, color='gray', alpha=0.3, linewidth=0.8)

colors = plt.cm.Set1(np.linspace(0, 1, n_treated))
for state_name, treated_series, line_color in zip(treated_states, Y_treated, colors):
    ax_states.plot(years, treated_series, color=line_color, linewidth=2, label=state_name)

ax_states.axvline(**treatment_marker)
ax_states.set_xlabel('Year', fontsize=12)
ax_states.set_ylabel('Outcome (e.g., Unemployment Rate)', fontsize=12)
ax_states.set_title('(A) Treated States vs Control States', fontsize=13, fontweight='bold')
ax_states.legend(loc='upper right', fontsize=9)

# Panel B: group means, plus the min-max envelope of the controls
ax_avg.plot(years, Y_control.mean(axis=0), 'gray', linewidth=2, label='Control Average')
ax_avg.plot(years, Y_treated.mean(axis=0), 'blue', linewidth=2, label='Treated Average')
ax_avg.fill_between(
    years,
    Y_control.min(axis=0),
    Y_control.max(axis=0),
    alpha=0.2,
    color='gray',
    label='Control Range',
)
ax_avg.axvline(**treatment_marker)

ax_avg.set_xlabel('Year', fontsize=12)
ax_avg.set_ylabel('Outcome', fontsize=12)
ax_avg.set_title('(B) Average Trends', fontsize=13, fontweight='bold')
ax_avg.legend(loc='upper right', fontsize=10)

plt.tight_layout()
plt.show()

## 3. Multi-Unit SCM Estimation

### 3.1 Individual Weights Method

Each treated state gets its own set of donor weights.

In [None]:
# Configure and fit Multi-Unit SCM with individual weights: each treated state gets
# its own donor weights, and uncertainty comes from a seeded bootstrap.
# The config is built inside the USE_KRL guard because it needs the KRL enums.
# The printed CI label is derived from confidence_level, so the two cannot disagree.
if USE_KRL:
    confidence_level = 0.95
    config_individual = MultiUnitSCMConfig(
        weight_method=WeightMethod.INDIVIDUAL,
        inference_method=InferenceMethod.BOOTSTRAP,
        n_bootstrap=100,
        confidence_level=confidence_level,
        random_state=42,
    )

    # Fit model
    model_individual = MultiUnitSCM(
        treatment_period=treatment_period,
        config=config_individual,
    )

    result_individual = model_individual.fit(
        Y_treated=Y_treated,
        Y_control=Y_control,
        treated_names=treated_states,
        control_names=control_states,
    )

    print("═" * 60)
    print("INDIVIDUAL WEIGHTS METHOD RESULTS")
    print("═" * 60)
    print(f"\nAverage Treatment Effect (ATT): {result_individual.att.mean():.4f}")
    print(f"Cumulative ATT: {result_individual.att_cumulative:.4f}")
    # NOTE(review): averaging the per-period CI bounds is a rough summary, not a
    # confidence interval for the time-averaged ATT.
    print(f"\n{confidence_level:.0%} CI: [{result_individual.ci_lower.mean():.4f}, {result_individual.ci_upper.mean():.4f}]")
    print(f"\nUnit-specific effects:")
    for m, state in enumerate(treated_states):
        effect = result_individual.treatment_effects[m].mean()
        print(f"  {state}: {effect:.4f}")
else:
    config_individual = None
    print("Demonstration mode - install krl_models to run")

### 3.2 Hierarchical Weights Method

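Each treated state's donor weights are shrunk toward a single pooled set of weights. This trades some unit-specific fit for more stable estimates. `shrinkage_alpha` sets the degree of pooling (0.5 = 50% pooling).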

In [None]:
if USE_KRL:
    # Configure with hierarchical weights: unit-specific donor weights are shrunk
    # toward pooled weights. confidence_level is set explicitly, as in the
    # individual-weights fit, so the "95% CI" label printed below is guaranteed
    # to match the configuration rather than relying on the library default.
    config_hierarchical = MultiUnitSCMConfig(
        weight_method=WeightMethod.HIERARCHICAL,
        shrinkage_alpha=0.5,  # 50% pooling
        inference_method=InferenceMethod.BOOTSTRAP,
        n_bootstrap=100,
        confidence_level=0.95,
        random_state=42,
    )

    model_hierarchical = MultiUnitSCM(
        treatment_period=treatment_period,
        config=config_hierarchical,
    )

    result_hierarchical = model_hierarchical.fit(
        Y_treated=Y_treated,
        Y_control=Y_control,
        treated_names=treated_states,
        control_names=control_states,
    )

    print("═" * 60)
    print("HIERARCHICAL WEIGHTS METHOD RESULTS (α = 0.5)")
    print("═" * 60)
    print(f"\nAverage Treatment Effect (ATT): {result_hierarchical.att.mean():.4f}")
    print(f"95% CI: [{result_hierarchical.ci_lower.mean():.4f}, {result_hierarchical.ci_upper.mean():.4f}]")
    # `se` may be None for some inference settings; report "n/a" instead of crashing.
    se_text = (
        f"{result_hierarchical.se.mean():.4f}"
        if result_hierarchical.se is not None
        else "n/a"
    )
    print(f"\nStandard Error: {se_text}")
else:
    print("Demonstration mode")

### 3.3 Pooled Weights Method

Single set of weights for all treated units (assumes homogeneous donor relevance).

In [None]:
if not USE_KRL:
    print("Demonstration mode")
else:
    # One shared donor-weight vector for every treated unit, fitted through the
    # convenience function with bootstrap inference.
    result_pooled = multi_unit_scm(
        Y_treated=Y_treated,
        Y_control=Y_control,
        treatment_period=treatment_period,
        method='pooled',
        inference='bootstrap',
        n_bootstrap=100,
    )

    rule = "═" * 60
    print(rule)
    print("POOLED WEIGHTS METHOD RESULTS")
    print(rule)
    print(f"\nAverage Treatment Effect (ATT): {result_pooled.att.mean():.4f}")
    pooled_lo = result_pooled.ci_lower.mean()
    pooled_hi = result_pooled.ci_upper.mean()
    print(f"95% CI: [{pooled_lo:.4f}, {pooled_hi:.4f}]")

## 4. Visualize Results

In [None]:
if USE_KRL:
    result = result_individual  # Use individual weights for visualization

    # Layout: the top subfigure has one panel per treated state, so every state is
    # shown, not just the first three. The bottom subfigure holds the three summary
    # panels. Subfigures need matplotlib >= 3.4; constrained layout replaces
    # tight_layout, which does not support subfigures.
    fig = plt.figure(figsize=(max(15, 3 * n_treated), 10), layout='constrained')
    fig_states, fig_summary = fig.subfigures(2, 1)
    state_axes = fig_states.subplots(1, n_treated, squeeze=False)[0]
    ax_att, ax_units, ax_weights = fig_summary.subplots(1, 3)

    # Top row: Treated vs Synthetic for each state
    for m, (state, ax) in enumerate(zip(treated_states, state_axes)):
        ax.plot(years, Y_treated[m], 'b-', linewidth=2, label=state)
        ax.plot(years, result.synthetic_controls[m], 'r--', linewidth=2, label='Synthetic')
        ax.axvline(x=treatment_year, color='gray', linestyle=':', linewidth=1.5)
        ax.set_xlabel('Year')
        ax.set_ylabel('Outcome')
        ax.set_title(state, fontsize=12, fontweight='bold')
        ax.legend(loc='upper right', fontsize=9)

    # Bottom left: ATT over time with CI
    post_years = years[treatment_period:]
    ax_att.plot(post_years, result.att, 'b-', linewidth=2, label='ATT')
    ax_att.fill_between(post_years, result.ci_lower, result.ci_upper,
                        alpha=0.3, color='blue', label='95% CI')
    ax_att.axhline(y=0, color='gray', linestyle='-', linewidth=1)
    ax_att.set_xlabel('Year')
    ax_att.set_ylabel('Average Treatment Effect')
    ax_att.set_title('ATT with Confidence Interval', fontsize=12, fontweight='bold')
    ax_att.legend()

    # Bottom center: Unit-specific effects (time-averaged per treated state)
    unit_effects = [result.treatment_effects[m].mean() for m in range(n_treated)]
    colors = plt.cm.RdYlGn(np.linspace(0, 1, n_treated))
    ax_units.barh(treated_states, unit_effects, color=colors)
    ax_units.axvline(x=0, color='gray', linestyle='-', linewidth=1)
    ax_units.set_xlabel('Average Treatment Effect')
    ax_units.set_title('Unit-Specific Effects', fontsize=12, fontweight='bold')

    # Bottom right: Weight distribution (donor weights averaged over treated units)
    avg_weights = result.weights.mean(axis=0)
    top_k = 10
    top_idx = np.argsort(avg_weights)[-top_k:][::-1]  # Largest weight first
    top_states = [control_states[i] for i in top_idx]
    top_weights = avg_weights[top_idx]

    ax_weights.barh(top_states, top_weights, color='steelblue')
    ax_weights.invert_yaxis()  # barh draws the first bar at the bottom; put the top donor on top
    ax_weights.set_xlabel('Average Weight')
    ax_weights.set_title('Top Donor States', fontsize=12, fontweight='bold')

    fig.savefig('multi_unit_scm_results.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("\n✓ Figure saved: multi_unit_scm_results.png")
else:
    print("Demonstration mode - install krl_models to generate figures")

## 5. Compare Inference Methods

In [None]:
if USE_KRL:
    inference_methods = ['bootstrap', 'jackknife', 'placebo', 'conformal']
    # Method-specific tuning knobs. Only the knob relevant to a method is passed,
    # so other methods keep the library defaults instead of receiving an explicit
    # None. NOTE(review): confirm multi_unit_scm's defaults for these arguments.
    method_kwargs = {
        'bootstrap': {'n_bootstrap': 50},
        'placebo': {'n_placebo': 10},
    }
    results = {}

    for method in inference_methods:
        try:
            results[method] = multi_unit_scm(
                Y_treated=Y_treated,
                Y_control=Y_control,
                treatment_period=treatment_period,
                method='individual',
                inference=method,
                **method_kwargs.get(method, {}),
            )
        except Exception as e:
            # Best effort: report and skip a failing method so the table still
            # covers the others.
            print(f"  {method}: Error - {e}")

    print("\n" + "=" * 70)
    print("INFERENCE METHOD COMPARISON")
    print("=" * 70)
    print(f"\n{'Method':<15} {'ATT':>10} {'SE':>10} {'CI Width':>12} {'P-value':>10}")
    print("-" * 70)

    # A distinct loop name avoids clobbering the shared `result` variable used by
    # the plotting and diagnostics cells.
    for method, method_result in results.items():
        att = method_result.att.mean()
        se = method_result.se.mean() if method_result.se is not None else np.nan
        has_ci = method_result.ci_lower is not None and method_result.ci_upper is not None
        ci_width = (method_result.ci_upper - method_result.ci_lower).mean() if has_ci else np.nan
        pval = method_result.p_values.mean() if method_result.p_values is not None else np.nan

        print(f"{method:<15} {att:>10.4f} {se:>10.4f} {ci_width:>12.4f} {pval:>10.4f}")

    print("=" * 70)
else:
    print("Demonstration mode")

## 6. Model Diagnostics

In [None]:
if USE_KRL:
    result = result_individual
    # Heuristic cut-off (not a formal test): a post/pre RMSPE ratio above this
    # means the synthetic control stops tracking the treated units after treatment.
    rmspe_ratio_threshold = 2

    pre_rmspe = np.asarray(result.pre_rmspe, dtype=float)
    post_rmspe = np.asarray(result.post_rmspe, dtype=float)

    print("\n" + "=" * 60)
    print("MODEL DIAGNOSTICS")
    print("=" * 60)

    print("\nPre-treatment Fit (RMSPE):")
    for state, value in zip(treated_states, pre_rmspe):
        print(f"  {state}: {value:.4f}")
    print(f"  Average: {pre_rmspe.mean():.4f}")

    print("\nPost-treatment RMSPE:")
    for state, value in zip(treated_states, post_rmspe):
        print(f"  {state}: {value:.4f}")
    print(f"  Average: {post_rmspe.mean():.4f}")

    # The simulated treated series are exact convex combinations of the controls,
    # so pre-treatment RMSPE may be close to 0. A zero denominator should give
    # inf/nan quietly rather than a runtime warning.
    with np.errstate(divide='ignore', invalid='ignore'):
        ratios = post_rmspe / pre_rmspe
        # Ratio of the mean RMSPEs, which is not the same as the mean of per-unit ratios.
        avg_ratio = post_rmspe.mean() / pre_rmspe.mean()

    print("\nRMSPE Ratio (Post/Pre):")
    for state, ratio in zip(treated_states, ratios):
        print(f"  {state}: {ratio:.2f}x")

    print(f"\n  Ratio of mean RMSPEs (post/pre): {avg_ratio:.2f}x")

    if avg_ratio > rmspe_ratio_threshold:
        print("  ✓ Large divergence post-treatment suggests real effect")
    else:
        print("  ⚠ Small divergence - treatment effect may be modest")
else:
    print("Demonstration mode")

In [None]:
# Show the library's full text summary of the individual-weights model; the
# conditional only touches model_individual when the KRL package was imported.
print(model_individual.summary() if USE_KRL else "Demonstration mode")

## 7. Comparison with True Effects

In [None]:
if not USE_KRL:
    print("Demonstration mode")
else:
    divider = "=" * 60
    print("\n" + divider)
    print("VALIDATION: ESTIMATED vs TRUE EFFECTS")
    print(divider)
    print(f"\n{'State':<20} {'True':>10} {'Estimated':>12} {'Error':>10}")
    print("-" * 60)

    # The simulation divided each true effect by 10 before adding it to the
    # outcome, so compare estimates against the effects on that same scale.
    scaled_truth = {name: true_effects[name] / 10 for name in treated_states}

    for idx, name in enumerate(treated_states):
        truth = scaled_truth[name]
        estimate = result_individual.treatment_effects[idx].mean()
        print(f"{name:<20} {truth:>10.4f} {estimate:>12.4f} {estimate - truth:>10.4f}")

    print("-" * 60)
    true_avg = np.mean(list(scaled_truth.values()))
    est_avg = result_individual.att.mean()
    print(f"{'Average':<20} {true_avg:>10.4f} {est_avg:>12.4f} {est_avg-true_avg:>10.4f}")
    print(divider)

## 8. Key Takeaways

### When to Use Multi-Unit SCM

1. **Multiple units adopt a policy simultaneously**
   - Pooling information improves efficiency
   - Can estimate average and heterogeneous effects

2. **Limited donor pool for individual SCM**
   - Hierarchical weights stabilize estimates
   - Shrinkage reduces variance

3. **Need for proper uncertainty quantification**
   - Multiple inference methods available
   - Bootstrap for general use
   - Placebo for permutation-based inference

### Method Selection Guide

| Scenario | Recommended Method |
|----------|-------------------|
| Heterogeneous treated units | Individual weights |
| Similar treated units | Pooled weights |
| Moderate heterogeneity | Hierarchical (α=0.5) |
| Small donor pool | Hierarchical with high α (more pooling toward shared weights) |

### Inference Method Selection

| Method | Best For |
|--------|----------|
| Bootstrap | General purpose, confidence intervals |
| Jackknife | Bias assessment, leave-one-out stability |
| Placebo | Permutation-based p-values |
| Conformal | Prediction intervals, distribution-free |

In [None]:
# Closing banner: recap of what the notebook covered and the public API it used.
banner_rule = "=" * 60
recap = """
This notebook demonstrated:
  ✓ Multi-Unit SCM for multiple treated units
  ✓ Three weight methods: Individual, Pooled, Hierarchical
  ✓ Four inference methods: Bootstrap, Jackknife, Placebo, Conformal
  ✓ Model diagnostics and validation

Key classes from krl_models.causal:
  • MultiUnitSCM - Main estimation class
  • MultiUnitSCMConfig - Configuration options
  • MultiUnitSCMResult - Results container
  • WeightMethod - Enum for weight estimation
  • InferenceMethod - Enum for inference
  • multi_unit_scm() - Convenience function
"""

for line in ("\n" + banner_rule, "NB25: MULTI-UNIT SCM DEMONSTRATION COMPLETE", banner_rule, recap):
    print(line)