In [1]:
import numpy as np
import jax
import pyensmallen as pe

jax.config.update("jax_enable_x64", True)

## Linear IV

The linear moment condition $z (y - x\beta)$ is attached as a static-method (`iv_moment`) to the class for convenience. This covers OLS and 2SLS.

In [2]:
# Generate synthetic data for IV estimation
def generate_test_data(n=5000, seed=42):
    np.random.seed(seed)

    # Generate instruments
    z1 = np.random.normal(0, 1, n)
    z2 = np.random.normal(0, 1, n)
    Z = np.column_stack([np.ones(n), z1, z2])

    # Generate error terms with correlation
    error = np.random.normal(0, 1, n)
    v = 0.7 * error + np.random.normal(0, 0.5, n)

    # Generate endogenous variable
    x = 0.5 * z1 - 0.2 * z2 + v
    X = np.column_stack([np.ones(n), x])

    # Generate outcome
    true_beta = np.array([-0.5, 1.2])
    y = X @ true_beta + error

    return y, X, Z, true_beta

In [3]:
# Generate test data
y, X, Z, true_beta = generate_test_data()

# Create and fit GMM estimator
gmm = pe.EnsmallenEstimator(pe.EnsmallenEstimator.iv_moment, "optimal")
gmm.fit(Z, y, X, verbose=True)

# Display results
print("\nGMM Results:")
print(f"True parameters: {true_beta}")
print(f"Estimated parameters: {gmm.theta_}")
print(f"Standard errors: {gmm.std_errors_}")


GMM Results:
True parameters: [-0.5  1.2]
Estimated parameters: [-0.48933885  1.19956026]
Standard errors: [0.01412415 0.02603365]


#### Fast bootstrap that bootstraps the score / influence function

In [4]:
%%time

# Get fast bootstrap standard errors
fast_bootstrap_se = gmm.bootstrap_scores(n_bootstrap=1000)

CPU times: user 671 ms, sys: 1.28 ms, total: 672 ms
Wall time: 671 ms


#### Slow bootstrap that bootstraps the whole procedure

In [5]:
%%time
slow_bootstrap_se = gmm.bootstrap_full(n_bootstrap=100, seed=42, verbose=False)

CPU times: user 2min 32s, sys: 2.59 s, total: 2min 34s
Wall time: 1min 39s


In [6]:
import pandas as pd
pd.DataFrame(
    np.c_[true_beta, gmm.theta_, gmm.std_errors_, fast_bootstrap_se, slow_bootstrap_se],
    columns=["True", "Point Estimate", "Analytic", "Fast Bootstrap", "Slow Bootstrap"],
)

Unnamed: 0,True,Point Estimate,Analytic,Fast Bootstrap,Slow Bootstrap
0,-0.5,-0.489339,0.014124,0.014096,0.015408
1,1.2,1.19956,0.026034,0.025503,0.025942


## Nonlinear GMM: Logit

In [7]:
# logit DGP
n = 1000
p = 2
X = np.random.normal(size=(n, p))
X = np.c_[np.ones(n), X]
beta = np.array([0.5, -0.5, 0.5])
y = np.random.binomial(1, 1 / (1 + np.exp(-X @ beta)))
Z = X.copy()

IWLS solution

In [8]:
import statsmodels.api as sm
logit_mod = sm.Logit(y, X)
logit_res = logit_mod.fit(disp=0)
print("Parameters: ", logit_res.params)

Parameters:  [ 0.51176059 -0.4413885   0.42650761]


### nonlinear GMM with ensmallen

define moment condition (in jax-compatible terms)

In [9]:
import jax.numpy as jnp
import jax.scipy.special as jsp

def ψ_logit(z, y, x, beta):
    # Use jax.scipy.special.expit instead of scipy.special.expit
    resid = y - jsp.expit(x @ beta)
    return z * resid[:, jnp.newaxis]

In [10]:
# Create and fit GMM estimator
gmm = pe.EnsmallenEstimator(ψ_logit, "optimal")
gmm.fit(Z, y, X, verbose=True)

# Display results
print(f"True parameters: {beta}")
print(f"Estimated parameters: {gmm.theta_}")
print(f"Standard errors: {gmm.std_errors_}")

True parameters: [ 0.5 -0.5  0.5]
Estimated parameters: [ 0.51176057 -0.44138837  0.42650768]
Standard errors: [0.06817014 0.07033721 0.06615554]


In [14]:
%%time

# Get fast bootstrap standard errors
bootstrap_se = gmm.bootstrap_scores(n_bootstrap=1000)

CPU times: user 2.78 s, sys: 61 ms, total: 2.85 s
Wall time: 2.37 s


In [15]:
pd.DataFrame(
    np.c_[beta, logit_res.params, gmm.theta_, gmm.std_errors_, bootstrap_se],
    columns=["True", "IWLS Sol", "GMM Sol", "Analytic", "Fast Bootstrap"],
)

Unnamed: 0,True,IWLS Sol,GMM Sol,Analytic,Fast Bootstrap
0,0.5,0.511761,0.511761,0.06817,0.067359
1,-0.5,-0.441389,-0.441388,0.070337,0.070655
2,0.5,0.426508,0.426508,0.066156,0.066174
