In [1]:
import numpy as np
import jax
import pyensmallen as pe
jax.config.update("jax_enable_x64", True)


## Linear IV

The linear moment condition $z (y - x\beta)$ is attached as a static-method (`iv_moment`) to the class for convenience. This covers OLS and 2SLS.

In [2]:
# Generate synthetic data for IV estimation
def generate_test_data(n=5000, seed=42):
    np.random.seed(seed)

    # Generate instruments
    z1 = np.random.normal(0, 1, n)
    z2 = np.random.normal(0, 1, n)
    Z = np.column_stack([np.ones(n), z1, z2])

    # Generate error terms with correlation
    error = np.random.normal(0, 1, n)
    v = 0.7 * error + np.random.normal(0, 0.5, n)

    # Generate endogenous variable
    x = 0.5 * z1 - 0.2 * z2 + v
    X = np.column_stack([np.ones(n), x])

    # Generate outcome
    true_beta = np.array([-0.5, 1.2])
    y = X @ true_beta + error

    return y, X, Z, true_beta

In [3]:
# Generate test data
y, X, Z, true_beta = generate_test_data()

# Create and fit GMM estimator
gmm = pe.EnsmallenEstimator(pe.EnsmallenEstimator.iv_moment, "optimal")
gmm.fit(Z, y, X, verbose=True)

# Display results
print("\nGMM Results:")
print(f"True parameters: {true_beta}")
print(f"Estimated parameters: {gmm.theta_}")
print(f"Standard errors: {gmm.std_errors_}")


GMM Results:
True parameters: [-0.5  1.2]
Estimated parameters: [-0.48933885  1.19956026]
Standard errors: [0.01412415 0.02603365]


#### slow bootstrap that bootstraps the whole procedure

In [4]:
%%time

# Get fast bootstrap standard errors
bootstrap_se = gmm.bootstrap_scores(n_bootstrap=1000)

# View results
gmm.summary()

CPU times: user 2.27 s, sys: 24.9 ms, total: 2.3 s
Wall time: 443 ms


Unnamed: 0,parameter,coef,std err,t,p-value,[0.025,0.975],boot_se,ratio
0,θ_0,-0.4893,0.0141,-34.6456,0.0,-0.517,-0.4617,0.0141,0.998
1,θ_1,1.1996,0.026,46.0773,0.0,1.1485,1.2506,0.0255,0.9796


#### fast bootstrap that bootstraps the influence function

In [5]:
%%time
bootstrap_se = gmm.bootstrap_full(n_bootstrap=100, seed=42, verbose=False)

gmm.summary()

CPU times: user 51.8 s, sys: 962 ms, total: 52.8 s
Wall time: 22.8 s


Unnamed: 0,parameter,coef,std err,t,p-value,[0.025,0.975],boot_se,ratio
0,θ_0,-0.4893,0.0141,-34.6456,0.0,-0.517,-0.4617,0.0154,1.0909
1,θ_1,1.1996,0.026,46.0773,0.0,1.1485,1.2506,0.0259,0.9965


## Logit

In [6]:
# logit DGP
n = 1000
p = 2
X = np.random.normal(size=(n, p))
X = np.c_[np.ones(n), X]
beta = np.array([0.5, -0.5, 0.5])
y = np.random.binomial(1, 1 / (1 + np.exp(-X @ beta)))
Z = X.copy()

In [7]:
import statsmodels.api as sm
logit_mod = sm.Logit(y, X)
logit_res = logit_mod.fit(disp=0)
print("Parameters: ", logit_res.params)

Parameters:  [ 0.51176059 -0.4413885   0.42650761]


### nonlinear GMM with ensmallen

define moment condition (in jax-compatible terms)

In [8]:
import jax.numpy as jnp
import jax.scipy.special as jsp

def ψ_logit(z, y, x, beta):
    # Use jax.scipy.special.expit instead of scipy.special.expit
    resid = y - jsp.expit(x @ beta)
    return z * resid[:, jnp.newaxis]

In [9]:
# Create and fit GMM estimator
gmm = pe.EnsmallenEstimator(ψ_logit, "optimal")
gmm.fit(Z, y, X, verbose=True)

# Display results
print(f"True parameters: {beta}")
print(f"Estimated parameters: {gmm.theta_}")
print(f"Standard errors: {gmm.std_errors_}")

True parameters: [ 0.5 -0.5  0.5]
Estimated parameters: [ 0.51176057 -0.44138837  0.42650768]
Standard errors: [0.06817014 0.07033721 0.06615554]


In [10]:
%%time

# Get fast bootstrap standard errors
bootstrap_se = gmm.bootstrap_scores(n_bootstrap=1000)

# View results
gmm.summary()

CPU times: user 802 ms, sys: 28 ms, total: 830 ms
Wall time: 692 ms


Unnamed: 0,parameter,coef,std err,t,p-value,[0.025,0.975],boot_se,ratio
0,θ_0,0.5118,0.0682,7.5071,0.0,0.3781,0.6454,0.0707,1.0366
1,θ_1,-0.4414,0.0703,-6.2753,0.0,-0.5792,-0.3035,0.07,0.9949
2,θ_2,0.4265,0.0662,6.447,0.0,0.2968,0.5562,0.0665,1.0048
