In [1]:
 import numpy as np
import jax
import jax.numpy as jnp
import pyensmallen
import time
from joblib import Parallel, delayed
from scipy import stats

# Set random seed
np.random.seed(123)
key = jax.random.PRNGKey(123)

In [2]:
# Simulate data
def generate_data(n_samples=1000, beta_true=None, corr_factor=0.5):
    """Generate data for IV regression with endogenous regressors"""
    # Number of instruments and endogenous variables
    n_instruments = 10
    n_endogenous = 2

    if beta_true is None:
        beta_true = np.array([1.0, -0.5])

    # Generate instruments Z ~ N(0, 1)
    Z = np.random.normal(0, 1, (n_samples, n_instruments))

    # Generate endogenous X related to Z
    X = np.zeros((n_samples, n_endogenous))
    # First stage: X = Z*Pi + V where V is correlated with the error in Y
    Pi = np.random.uniform(-1, 1, (n_instruments, n_endogenous))

    # Error terms - generate with correlation structure
    error_cov = np.eye(n_endogenous + 1)
    # Make first column/row correlate with others to create endogeneity
    error_cov[0, 1:] = corr_factor
    error_cov[1:, 0] = corr_factor

    errors = np.random.multivariate_normal(
        np.zeros(n_endogenous + 1), error_cov, n_samples
    )

    # Structural errors
    u = errors[:, 0]
    # First stage errors
    V = errors[:, 1:]

    # First stage
    X = Z @ Pi + V

    # Outcome equation
    y = X @ beta_true + u

    return y, X, Z, beta_true

In [3]:
# Generate data
n_samples = 1000
true_beta = np.array([1.0, -0.5])
y, X, Z, beta_true = generate_data(n_samples, true_beta)

# Define GMM estimation with JAX 


In [4]:
# Define moment conditions: E[Z'(y - X*beta)] = 0
@jax.jit
def moment_conditions(beta, y, X, Z):
    """Moment conditions for IV regression: Z'(y - X*beta) = 0"""
    residuals = y - X @ beta
    return (Z.T @ residuals) / Z.shape[0]

# Define GMM objective function
@jax.jit
def gmm_objective(beta, W, y, X, Z):
    """GMM objective: g'*W*g where g = Z'(y-X*beta)"""
    g = moment_conditions(beta, y, X, Z)
    return g.T @ W @ g

# Create JAX gradient function for the GMM objective
grad_gmm = jax.grad(gmm_objective)

# Vectorized matrix sandwich formula for asymptotic variance
@jax.jit
def asymptotic_variance(beta, y, X, Z, W):
    """Compute asymptotic variance for GMM estimator"""
    n = Z.shape[0]
    residuals = y - X @ beta

    # Create individual moment conditions (Z_i * e_i)
    g_i = Z.T * residuals.reshape(1, -1)  # Shape: (n_instruments, n_samples)

    # Compute gradient of moment conditions (Jacobian): E[Z'X]
    G = -Z.T @ X / n  # Shape: (n_instruments, n_endogenous)

    # Compute S matrix (covariance of moment conditions)
    S = (g_i @ g_i.T) / n  # Shape: (n_instruments, n_instruments)

    # Compute bread for sandwich formula: (G'WG)^{-1}
    bread = jnp.linalg.inv(G.T @ W @ G)

    # Compute meat for sandwich formula: G'WSW'G
    meat = G.T @ W @ S @ W @ G

    # Asymptotic variance: (G'WG)^{-1} * (G'WSW'G) * (G'WG)^{-1} / n
    avar = bread @ meat @ bread / n

    return avar

In [5]:
# Objective function for pyensmallen
def objective_for_ensmallen(beta, gradient, y, X, Z, W):
    """Objective for pyensmallen with JAX gradient calculation"""
    beta_jax = jnp.array(beta)

    # Calculate objective value
    obj_value = gmm_objective(beta_jax, W, y, X, Z)

    # Calculate gradient
    grad = grad_gmm(beta_jax, W, y, X, Z)

    # Copy gradient to output parameter
    gradient[:] = np.array(grad)

    return float(obj_value)

In [6]:
def two_step_gmm(
    y, X, Z, initial_beta=None, use_joblib_bootstrap=True, n_bootstrap=200
):
    """Two-step GMM estimation for IV regression"""
    n_instruments = Z.shape[1]
    n_endogenous = X.shape[1]
    n_samples = X.shape[0]

    # Convert data to JAX arrays for later use
    y_jax = jnp.array(y)
    X_jax = jnp.array(X)
    Z_jax = jnp.array(Z)

    # Initial weight matrix: identity
    W_1 = jnp.eye(n_instruments)

    # Initial guess if not provided
    if initial_beta is None:
        initial_beta = np.zeros(n_endogenous)

    # First step GMM
    start_time = time.time()
    optimizer = pyensmallen.L_BFGS(10, 1000)  # numBasis, maxIterations

    # Define objective function for first step
    def first_step_objective(beta, gradient):
        return objective_for_ensmallen(beta, gradient, y_jax, X_jax, Z_jax, W_1)

    # Optimize first step
    beta_1 = optimizer.optimize(first_step_objective, initial_beta)

    # Compute optimal weight matrix using first-step residuals
    residuals = y - X @ beta_1
    g_i = Z.T * residuals.reshape(1, -1)  # Individual moment contributions
    S = (g_i @ g_i.T) / n_samples  # Covariance of moment conditions

    # Optimal weight matrix: W = S^{-1}
    W_opt = jnp.linalg.inv(S)

    # Second step GMM with optimal weight matrix
    def second_step_objective(beta, gradient):
        return objective_for_ensmallen(beta, gradient, y_jax, X_jax, Z_jax, W_opt)

    # Optimize second step
    beta_2 = optimizer.optimize(second_step_objective, beta_1)
    estimation_time = time.time() - start_time

    # Compute analytical standard errors
    avar = asymptotic_variance(beta_2, y_jax, X_jax, Z_jax, W_opt)
    se_analytical = np.sqrt(np.diag(avar))

    # Bootstrap standard errors
    if use_joblib_bootstrap:
        # Bootstrap function
        def bootstrap_sample(b):
            # Sample with replacement
            idx = np.random.choice(n_samples, n_samples, replace=True)
            y_b, X_b, Z_b = y[idx], X[idx], Z[idx]

            # Convert to JAX arrays
            y_b_jax = jnp.array(y_b)
            X_b_jax = jnp.array(X_b)
            Z_b_jax = jnp.array(Z_b)

            # First step with identity weight matrix
            def bs_first_objective(beta, gradient):
                return objective_for_ensmallen(
                    beta, gradient, y_b_jax, X_b_jax, Z_b_jax, W_1
                )

            beta_b1 = optimizer.optimize(bs_first_objective, initial_beta)

            # Compute optimal weight matrix for this bootstrap sample
            residuals_b = y_b - X_b @ beta_b1
            g_i_b = Z_b.T * residuals_b.reshape(1, -1)
            S_b = (g_i_b @ g_i_b.T) / n_samples
            W_b = jnp.linalg.inv(S_b)

            # Second step with optimal weight matrix
            def bs_second_objective(beta, gradient):
                return objective_for_ensmallen(
                    beta, gradient, y_b_jax, X_b_jax, Z_b_jax, W_b
                )

            beta_b2 = optimizer.optimize(bs_second_objective, beta_b1)
            return beta_b2

        # Run bootstrap in parallel
        start_bootstrap = time.time()
        bootstrap_results = Parallel(n_jobs=-1)(
            delayed(bootstrap_sample)(b) for b in range(n_bootstrap)
        )
        bootstrap_time = time.time() - start_bootstrap

        # Compute bootstrap standard errors
        bootstrap_estimates = np.array(bootstrap_results)
        se_bootstrap = np.std(bootstrap_estimates, axis=0)

        # Compute bootstrap confidence intervals
        ci_bootstrap = np.percentile(bootstrap_estimates, [2.5, 97.5], axis=0).T
    else:
        se_bootstrap = None
        ci_bootstrap = None
        bootstrap_time = 0

    # Compute t-stats and p-values using analytical standard errors
    t_stats = beta_2 / se_analytical
    p_values = 2 * (1 - stats.t.cdf(np.abs(t_stats), df=n_samples - n_endogenous))

    # Analytical confidence intervals
    ci_analytical = np.column_stack(
        [beta_2 - 1.96 * se_analytical, beta_2 + 1.96 * se_analytical]
    )

    # Combine results
    results = {
        "beta_first_step": beta_1,
        "beta": beta_2,
        "se_analytical": se_analytical,
        "se_bootstrap": se_bootstrap,
        "ci_analytical": ci_analytical,
        "ci_bootstrap": ci_bootstrap,
        "p_values": p_values,
        "t_stats": t_stats,
        "estimation_time": estimation_time,
        "bootstrap_time": bootstrap_time,
        "n_bootstrap": n_bootstrap if use_joblib_bootstrap else 0,
    }

    return results

In [12]:
# Estimate with two-step GMM
results = two_step_gmm(y, X, Z, use_joblib_bootstrap=False, n_bootstrap=500)

# Print results
print("\n--- GMM IV Estimation Results ---")
print(f"Sample size: {n_samples}")
print(f"Number of instruments: {Z.shape[1]}")
print(f"Number of endogenous variables: {X.shape[1]}")
print(f"\nTrue coefficients: {true_beta}")
print(f"First-step estimates: {results['beta_first_step']}")
print(f"Second-step estimates: {results['beta']}")
print("\nAnalytical standard errors:")
for i, se in enumerate(results["se_analytical"]):
    print(f"  Beta_{i+1}: {se:.6f}")

print(results)
# print("\nBootstrap standard errors:")
# for i, se in enumerate(results["se_bootstrap"]):
#     print(f"  Beta_{i+1}: {se:.6f}")

# print("\nAnalytical 95% CI:")
# for i, ci in enumerate(results["ci_analytical"]):
#     print(f"  Beta_{i+1}: [{ci[0]:.6f}, {ci[1]:.6f}]")

# print("\nBootstrap 95% CI:")
# for i, ci in enumerate(results["ci_bootstrap"]):
#     print(f"  Beta_{i+1}: [{ci[0]:.6f}, {ci[1]:.6f}]")

# print("\nP-values:")
# for i, p in enumerate(results["p_values"]):
#     print(f"  Beta_{i+1}: {p:.6f}")

# print(f"\nEstimation time: {results['estimation_time']:.2f} seconds")
# print(
#     f"Bootstrap time ({results['n_bootstrap']} replications): {results['bootstrap_time']:.2f} seconds"
# )


--- GMM IV Estimation Results ---
Sample size: 1000
Number of instruments: 10
Number of endogenous variables: 2

True coefficients: [ 1.  -0.5]
First-step estimates: [ 0.98635808 -0.48494819]
Second-step estimates: [ 0.98251398 -0.48098669]

Analytical standard errors:
  Beta_1: 0.015188
  Beta_2: 0.022259
{'beta_first_step': array([ 0.98635808, -0.48494819]), 'beta': array([ 0.98251398, -0.48098669]), 'se_analytical': array([0.01518814, 0.02225898], dtype=float32), 'se_bootstrap': None, 'ci_analytical': array([[ 0.95274521,  1.01228274],
       [-0.52461429, -0.43735909]]), 'ci_bootstrap': None, 'p_values': array([0., 0.]), 't_stats': array([ 64.68953306, -21.6086579 ]), 'estimation_time': 0.10393977165222168, 'bootstrap_time': 0, 'n_bootstrap': 0}


In [10]:
# ------- Visualization -------
import matplotlib.pyplot as plt

# Plot histogram of bootstrap estimates
bootstrap_estimates = np.array([results["beta"] for _ in range(results["n_bootstrap"])])
if results["se_bootstrap"] is not None:
    fig, axes = plt.subplots(1, X.shape[1], figsize=(12, 4))
    for i in range(X.shape[1]):
        ax = axes[i]
        ax.hist(bootstrap_estimates[:, i], bins=30, alpha=0.7)
        ax.axvline(true_beta[i], color="red", linestyle="--", label="True")
        ax.axvline(results["beta"][i], color="blue", linestyle="-", label="GMM")
        ax.axvline(
            results["ci_bootstrap"][i, 0], color="green", linestyle=":", label="95% CI"
        )
        ax.axvline(results["ci_bootstrap"][i, 1], color="green", linestyle=":")
        ax.set_title(f"Bootstrap Distribution: Beta_{i+1}")
        if i == 0:
            ax.legend()
    plt.tight_layout()
    plt.show()

In [11]:
bootstrap_estimates = np.array([results['beta'] for _ in range(results['n_bootstrap'])])
if results['se_bootstrap'] is not None:
    fig, axes = plt.subplots(1, X.shape[1], figsize=(12, 4))
    for i in range(X.shape[1]):
        ax = axes[i]
        ax.hist(bootstrap_estimates[:, i], bins=30, alpha=0.7)
        ax.axvline(true_beta[i], color='red', linestyle='--', label='True')
        ax.axvline(results['beta'][i], color='blue', linestyle='-', label='GMM')
        ax.axvline(results['ci_bootstrap'][i, 0], color='green', linestyle=':', label='95% CI')
        ax.axvline(results['ci_bootstrap'][i, 1], color='green', linestyle=':')
        ax.set_title(f'Bootstrap Distribution: Beta_{i+1}')
        if i == 0:
            ax.legend()
    plt.tight_layout()
    plt.show()

# Compare analytical vs bootstrap standard errors
if results['se_bootstrap'] is not None:
    fig, ax = plt.subplots(figsize=(8, 5))
    x = np.arange(X.shape[1])
    width = 0.35
    ax.bar(x - width/2, results['se_analytical'], width, label='Analytical')
    ax.bar(x + width/2, results['se_bootstrap'], width, label='Bootstrap')
    ax.set_xticks(x)
    ax.set_xticklabels([f'Beta_{i+1}' for i in range(X.shape[1])])
    ax.set_ylabel('Standard Error')
    ax.set_title('Comparison of Standard Error Estimates')
    ax.legend()
    plt.tight_layout()
    plt.show()