# GMM in `pyensmallen+jax`

In [1]:
import numpy as np
import jax
import pyensmallen as pe
jax.config.update("jax_enable_x64", True)

## Linear IV

The linear moment condition $z (y - x\beta)$ is attached to the class as a static method (`iv_moment`) for convenience. It covers OLS (with $Z = X$) as well as 2SLS and overidentified linear IV GMM.
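The following plain-NumPy sketch illustrates what the linear moment and its GMM estimator compute. `iv_moment` and `linear_gmm` are hypothetical stand-ins written for illustration, not `pyensmallen`'s implementation. `linear_gmm` uses the closed-form solution of the linear GMM problem rather than numerical optimization. The data follow the same design as `generate_test_data`, but are drawn with a local NumPy `Generator`.

```python
import numpy as np

def iv_moment(z, y, x, beta):
    # Hypothetical stand-in: per-observation moments g_i = z_i * (y_i - x_i' beta), shape (n, m)
    resid = y - x @ beta
    return z * resid[:, np.newaxis]

def linear_gmm(Z, y, X, W):
    # Closed-form minimizer of gbar(b)' W gbar(b), with gbar(b) = Z'(y - X b) / n.
    # First-order condition: X'Z W Z'X b = X'Z W Z'y.
    XZ = X.T @ Z
    return np.linalg.solve(XZ @ W @ XZ.T, XZ @ W @ (Z.T @ y))

# Same design as generate_test_data
rng = np.random.default_rng(42)
n = 5000
z1, z2 = rng.normal(size=n), rng.normal(size=n)
Z = np.column_stack([np.ones(n), z1, z2])
error = rng.normal(size=n)
x = 0.5 * z1 - 0.2 * z2 + 0.7 * error + rng.normal(0, 0.5, n)
X = np.column_stack([np.ones(n), x])
y = X @ np.array([-0.5, 1.2]) + error

beta_2sls = linear_gmm(Z, y, X, np.linalg.inv(Z.T @ Z))  # W = (Z'Z)^{-1} gives 2SLS
beta_ols = linear_gmm(X, y, X, np.eye(2))                 # Z = X gives OLS (W drops out)
print("2SLS:", beta_2sls)
print("OLS: ", beta_ols)  # slope biased upward: x is correlated with the error
```

With $Z = X$ the system is just-identified, so the weighting matrix has no effect. With the extra instrument, the choice of $W$ matters.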

In [2]:
# Generate synthetic data for IV estimation
def generate_test_data(n=5000, seed=42):
    np.random.seed(seed)

    # Generate instruments
    z1 = np.random.normal(0, 1, n)
    z2 = np.random.normal(0, 1, n)
    Z = np.column_stack([np.ones(n), z1, z2])

    # Correlated errors: v loads on the structural error, which makes x endogenous
    error = np.random.normal(0, 1, n)
    v = 0.7 * error + np.random.normal(0, 0.5, n)

    # Generate endogenous variable
    x = 0.5 * z1 - 0.2 * z2 + v
    X = np.column_stack([np.ones(n), x])

    # Generate outcome
    true_beta = np.array([-0.5, 1.2])
    y = X @ true_beta + error

    return y, X, Z, true_beta

In [3]:
# Generate test data
y, X, Z, true_beta = generate_test_data()

# Create and fit GMM estimator
gmm = pe.EnsmallenEstimator(pe.EnsmallenEstimator.iv_moment, "optimal")
gmm.fit(Z, y, X, verbose=True)

# Display results
print("\nGMM Results:")
print(f"True parameters: {true_beta}")
print(f"Estimated parameters: {gmm.theta_}")
print(f"Standard errors: {gmm.std_errors_}")


GMM Results:
True parameters: [-0.5  1.2]
Estimated parameters: [-0.48933885  1.19956026]
Standard errors: [0.01412415 0.02603365]


### The Fast Score Bootstrap

The "slow" or "full" non-parametric bootstrap is conceptually simple:
1.  Pretend the sample is the true population.
2.  Draw a new sample (with replacement) from this "population".
3.  Re-calculate your estimator $\hat{\theta}$ from scratch on this new sample.
4.  Repeat thousands of times to build a distribution of $\hat{\theta}^*$. The standard deviation of this distribution is the bootstrap standard error.
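Here is a minimal sketch of these steps. It uses closed-form 2SLS as a cheap stand-in for the notebook's numerically optimized `"optimal"` GMM fit. `tsls` and `full_bootstrap_se` are hypothetical helpers, not part of `pyensmallen`. Each replication resamples whole rows and re-estimates from scratch.

```python
import numpy as np

def tsls(Z, y, X):
    # 2SLS written as linear GMM with W = (Z'Z)^{-1}
    XZ = X.T @ Z
    W = np.linalg.inv(Z.T @ Z)
    return np.linalg.solve(XZ @ W @ XZ.T, XZ @ W @ (Z.T @ y))

def full_bootstrap_se(Z, y, X, estimator, n_bootstrap=200, seed=0):
    rng = np.random.default_rng(seed)
    n = len(y)
    draws = []
    for _ in range(n_bootstrap):
        idx = rng.integers(0, n, size=n)                 # step 2: resample rows with replacement
        draws.append(estimator(Z[idx], y[idx], X[idx]))  # step 3: re-estimate from scratch
    return np.std(draws, axis=0, ddof=1)                 # step 4: SD of the bootstrap estimates

# Same design as generate_test_data
rng = np.random.default_rng(42)
n = 5000
z1, z2 = rng.normal(size=n), rng.normal(size=n)
Z = np.column_stack([np.ones(n), z1, z2])
error = rng.normal(size=n)
x = 0.5 * z1 - 0.2 * z2 + 0.7 * error + rng.normal(0, 0.5, n)
X = np.column_stack([np.ones(n), x])
y = X @ np.array([-0.5, 1.2]) + error

se_full = full_bootstrap_se(Z, y, X, tsls)
print(se_full)
```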

This is robust but computationally expensive, because Step 3 re-runs the full optimization on every replication. The "fast" score bootstrap avoids that re-optimization entirely by relying on a first-order Taylor approximation.

**The Logic:**

1.  **The Starting Point:** The M-estimator $\hat{\theta}$ is defined by the first-order condition (FOC) that the sample average of the moment conditions is (approximately) zero:
    
$$
(1/n) \sum g(z_i, \hat{\theta}) \approx 0
$$

2.  **The Bootstrap Sample:** Now, imagine we have a bootstrap sample. The bootstrap estimate $\hat{\theta}^*$ would be the one that solves the FOC for *that* sample:
    
$$
(1/n) \sum g(z_i^*, \hat{\theta}^*) \approx 0 
$$
    
where the $z_i^*$ are drawn with replacement from the original sample.

3.  **The Taylor Expansion Trick:** We don't want to actually *solve* for $\hat{\theta}^*$. Instead, we can approximate it. Let's do a first-order Taylor expansion of the bootstrap FOC around the *original* estimate $\hat{\theta}$:
    
$$
(1/n) \sum g(z_i^*, \hat{\theta}^*) \approx (1/n) \sum g(z_i^*, \hat{\theta}) + [ (1/n) \sum (\partial g(z_i^*, \theta)/\partial \theta) \mid_{\hat{\theta}} ]  (\hat{\theta}^* - \hat{\theta})
$$
    
Since the left side is (approximately) zero, rearranging gives the difference $\hat{\theta}^* - \hat{\theta}$:

$$
(\hat{\theta}^* - \hat{\theta}) \approx
- [ (1/n) \sum (\partial g(z_i^*, \theta)/\partial \theta) \mid_{\hat{\theta}} ]^{-1} 
\; \;  [ (1/n) \sum g(z_i^*, \hat{\theta}) ]
$$
    

4.  **Simplifying the Approximation:**
*   The bracketed matrix $[ (1/n) \sum (\partial g(z_i^*, \theta)/\partial \theta) \mid_{\hat{\theta}}]$ is the average Jacobian of the moments, evaluated on the bootstrap sample. By a law of large numbers it is close to the Jacobian on the original sample, which we call $G$.
*   The second term $(1/n) \sum g(z_i^*, \hat{\theta})$  is simply the average of the *original* moment residuals $g_i(\hat{\theta})$ over the bootstrap sample. Let's call this $\bar{g}^*$.

This simplifies our approximation to:


$$
(\hat{\theta}^* - \hat{\theta}) \approx - G^{-1} \cdot \bar{g}^*
$$

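This is the just-identified case, where $G$ is square and invertible. For overidentified GMM with weighting matrix $W$, applying the same argument to the GMM first-order condition gives $\hat{\theta}^* - \hat{\theta} \approx -(G'WG)^{-1} G'W \, \bar{g}^*$. This expression reduces to $-G^{-1}\bar{g}^*$ when $G$ is square.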
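The sketch below checks how good the linearization is. It uses a just-identified linear IV model with a constant and a single instrument (an illustrative design, not the notebook's), draws one bootstrap sample, and compares the exact bootstrap estimate with $\hat{\theta} - G^{-1}\bar{g}^*$. For linear moments, the only approximation is replacing the bootstrap Jacobian by $G$. The gap should therefore be much smaller than $\hat{\theta}^* - \hat{\theta}$ itself.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 5000
z1 = rng.normal(size=n)
error = rng.normal(size=n)
x = 0.5 * z1 + 0.7 * error + rng.normal(0, 0.5, n)
Z = np.column_stack([np.ones(n), z1])   # just-identified: 2 instruments, 2 parameters
X = np.column_stack([np.ones(n), x])
y = X @ np.array([-0.5, 1.2]) + error

def just_id_iv(Z, y, X):
    # Solves the sample moment equation Z'(y - X b) / n = 0 exactly
    return np.linalg.solve(Z.T @ X, Z.T @ y)

theta_hat = just_id_iv(Z, y, X)
g = Z * (y - X @ theta_hat)[:, None]    # g_i(theta_hat), shape (n, 2)
G = -(Z.T @ X) / n                       # Jacobian of gbar(theta) = Z'(y - X theta) / n

idx = rng.integers(0, n, size=n)         # one bootstrap resample
exact = just_id_iv(Z[idx], y[idx], X[idx])                    # re-solved from scratch
linear = theta_hat - np.linalg.solve(G, g[idx].mean(axis=0))  # theta_hat - G^{-1} gbar*
print("exact  deviation:", exact - theta_hat)
print("linear deviation:", linear - theta_hat)
```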

**The Fast Bootstrap Algorithm:**

This approximation is the heart of the fast bootstrap:

1.  Calculate the original estimate $\hat{\theta}$ once.
2.  Calculate the $n$ individual moment residuals $g_i(\hat{\theta})$ once.
3.  Calculate the matrix $M = -(G'WG)^{-1}G'W$ once.
4.  Loop $B$ times:
    *   (a) Resample the *moment residuals* $g_i(\hat{\theta})$ with replacement (not the raw data).
    *   (b) Compute their mean, $\bar{g}^*$.
    *   (c) Compute the bootstrap estimate directly: $\hat{\theta}^* = \hat{\theta} + M \bar{g}^*$.
5.  Compute the standard deviation of the $B$ values of $\hat{\theta}^*$.
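The algorithm translates almost line for line into NumPy. This sketch uses the notebook's linear IV design and a hand-written two-step efficient GMM fit. `two_step_gmm` and `fast_score_bootstrap` are hypothetical names, not `pyensmallen`'s `bootstrap_scores`. The Jacobian $G = -Z'X/n$ is specific to linear moments.

```python
import numpy as np

# Same design as generate_test_data
rng = np.random.default_rng(42)
n = 5000
z1, z2 = rng.normal(size=n), rng.normal(size=n)
Z = np.column_stack([np.ones(n), z1, z2])
error = rng.normal(size=n)
x = 0.5 * z1 - 0.2 * z2 + 0.7 * error + rng.normal(0, 0.5, n)
X = np.column_stack([np.ones(n), x])
y = X @ np.array([-0.5, 1.2]) + error

def two_step_gmm(Z, y, X):
    # Linear GMM: first step with W = (Z'Z/n)^{-1} (2SLS), second step with W = Omega_hat^{-1}
    n = len(y)
    XZ = X.T @ Z
    def solve(W):
        return np.linalg.solve(XZ @ W @ XZ.T, XZ @ W @ (Z.T @ y))
    theta1 = solve(np.linalg.inv(Z.T @ Z / n))
    g1 = Z * (y - X @ theta1)[:, None]
    W = np.linalg.inv(g1.T @ g1 / n)
    return solve(W), W

def fast_score_bootstrap(Z, y, X, theta_hat, W, n_bootstrap=1000, seed=0):
    rng = np.random.default_rng(seed)
    n = len(y)
    g = Z * (y - X @ theta_hat)[:, None]        # step 2: moment residuals g_i(theta_hat)
    G = -(Z.T @ X) / n                           # Jacobian of gbar(theta) for linear IV
    M = -np.linalg.solve(G.T @ W @ G, G.T @ W)   # step 3: M = -(G'WG)^{-1} G'W
    draws = np.empty((n_bootstrap, len(theta_hat)))
    for b in range(n_bootstrap):                 # step 4
        idx = rng.integers(0, n, size=n)         # (a) resample residuals, not data
        g_bar_star = g[idx].mean(axis=0)         # (b) their mean
        draws[b] = theta_hat + M @ g_bar_star    # (c) no re-optimization
    return draws.std(axis=0, ddof=1)             # step 5

theta_hat, W = two_step_gmm(Z, y, X)             # step 1
G = -(Z.T @ X) / n
se_analytic = np.sqrt(np.diag(np.linalg.inv(G.T @ W @ G)) / n)
se_fast = fast_score_bootstrap(Z, y, X, theta_hat, W)
print(theta_hat, se_analytic, se_fast)
```

The bootstrap SDs should land close to the analytic efficient-GMM standard errors $\sqrt{\mathrm{diag}[(G'WG)^{-1}]/n}$. This mirrors the comparison the notebook makes for `bootstrap_scores`.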

This is orders of magnitude faster, because each replication only requires a resampled mean and one matrix-vector product, not a full re-optimization.

### Connection to the Newey (1994) Efficiency Framework

The link between the two is the **influence function**, a concept central to Newey (1994).

*   **What is the Influence Function?** The influence function $\varphi(z_i)$ of an estimator $\hat{\theta}$ measures the first-order contribution of a single observation $z_i$ to the estimate: adding $z_i$ to a sample of size $n$ moves $\hat{\theta}$ by roughly $\varphi(z_i)/n$. It is the building block of the estimator's asymptotic distribution:
    
$$
\sqrt{n} (\hat{\theta} - \theta_0) = \frac{1}{\sqrt{n}} \sum \varphi(z_i) + o_p(1)
$$

*   **The Influence Function for GMM:** With $G = E[\partial g(z, \theta_0)/\partial \theta']$ and weighting matrix $W$, the influence function of a GMM estimator is:

$$
\varphi(z_i) = - (G'WG)^{-1} G'W \cdot g(z_i, \theta_0)
$$


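The fast bootstrap is built on a sample version of this influence function. With $M = -(G'WG)^{-1}G'W$ evaluated at the estimates, $\hat{\varphi}_i = M\, g(z_i, \hat{\theta})$. Each bootstrap deviation $\hat{\theta}^* - \hat{\theta} = M \bar{g}^*$ is then simply the average of the $\hat{\varphi}_i$ over the resampled observations. $M$ is the matrix that translates the moment conditions $g$ into their effect on the parameter estimate $\hat{\theta}$.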
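The following sketch illustrates this link using 2SLS (GMM with $W = (Z'Z/n)^{-1}$) on the notebook's linear IV design. Each row of `phi` is $\hat{\varphi}_i = M g_i(\hat{\theta})$. The sandwich variance is $\frac{1}{n^2}\sum_i \hat{\varphi}_i\hat{\varphi}_i'$, and a fast-bootstrap draw is exactly $\hat{\theta}$ plus the resampled average of the $\hat{\varphi}_i$. The GMM first-order condition $G'W\bar{g}(\hat{\theta}) = 0$ holds in sample, so the $\hat{\varphi}_i$ average to zero even though $\bar{g}(\hat{\theta}) \neq 0$ in this overidentified model.

```python
import numpy as np

rng = np.random.default_rng(7)
n = 5000
z1, z2 = rng.normal(size=n), rng.normal(size=n)
Z = np.column_stack([np.ones(n), z1, z2])
error = rng.normal(size=n)
x = 0.5 * z1 - 0.2 * z2 + 0.7 * error + rng.normal(0, 0.5, n)
X = np.column_stack([np.ones(n), x])
y = X @ np.array([-0.5, 1.2]) + error

# 2SLS as GMM with W = (Z'Z/n)^{-1}; gbar(theta) = Z'y/n + G theta
W = np.linalg.inv(Z.T @ Z / n)
G = -(Z.T @ X) / n
theta_hat = np.linalg.solve(G.T @ W @ G, -G.T @ W @ (Z.T @ y / n))  # solves G'W gbar = 0

g = Z * (y - X @ theta_hat)[:, None]         # g_i(theta_hat); its mean is not zero here
M = -np.linalg.solve(G.T @ W @ G, G.T @ W)   # M = -(G'WG)^{-1} G'W
phi = g @ M.T                                # row i is phi_hat_i = M g_i

# Influence-function (sandwich) variance: Var(theta_hat) ~ (1/n^2) sum_i phi_i phi_i'
V_if = phi.T @ phi / n**2
se_if = np.sqrt(np.diag(V_if))

# A fast-bootstrap draw equals theta_hat plus the resampled mean of phi
idx = rng.integers(0, n, size=n)
draw = theta_hat + M @ g[idx].mean(axis=0)
print(se_if, draw)
```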


**How it relates to Newey's efficiency analysis:**

Newey's framework is all about comparing the influence functions of different estimators.

1.  **A Common Structure:** Newey classifies estimators by writing their influence functions in a common form:
    
$$
\varphi(z, \tau) = D(\tau)^{-1} m(z, \tau)
$$
    
where $\tau$ indexes the class of estimators (for GMM, $\tau$ is the weighting matrix $W$), $D(\tau)$ is a non-stochastic matrix, and $m(z, \tau)$ has mean zero. For GMM, $D(W) = G'WG$ and $m(z, W) = -G'W g(z, \theta_0)$, which reproduces the influence function above.

2.  **The Efficiency Condition:** An estimator indexed by $\bar{\tau}$ is efficient within its class if, for every $\tau$ in the class,

$$
D(\tau) = E[m(z, \tau)\, m(z, \bar{\tau})']
$$

Here "efficient" means that its asymptotic variance $E[\varphi(z, \bar{\tau})\varphi(z, \bar{\tau})']$ is smallest in the positive semidefinite sense. Setting $\tau = \bar{\tau}$ shows that, for the efficient estimator, $D$ equals the variance of $m$.

3.  **Optimal GMM:** For GMM, let $\Omega = E[g(z, \theta_0) g(z, \theta_0)']$. Then $E[m(z, W)\, m(z, \bar{W})'] = G'W \Omega \bar{W} G$, so the efficiency condition becomes:

$$
G'WG = G'W \Omega \bar{W} G \quad \text{for all } W
$$

This holds when $\bar{W} = \Omega^{-1}$. It follows that the GMM estimator using the inverse variance of the moments as the weighting matrix is efficient *within the class of GMM estimators defined by those moments*.
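The efficiency argument can be checked numerically. The sketch below uses a randomly generated Jacobian $G$ and moment variance $\Omega$, which are purely illustrative and not estimated from the notebook's data. For a few non-optimal weights $W$, it verifies the condition $G'W\Omega\bar{W}G = G'WG$ with $\bar{W} = \Omega^{-1}$. It then confirms that $V(W) - V(\bar{W})$ is positive semidefinite, where $V(W) = (G'WG)^{-1}G'W\Omega WG(G'WG)^{-1}$.

```python
import numpy as np

rng = np.random.default_rng(0)
m, k = 4, 2                                   # 4 moments, 2 parameters (overidentified)
G = rng.normal(size=(m, k))                   # illustrative Jacobian, full column rank
A = rng.normal(size=(m, m))
Omega = A @ A.T + m * np.eye(m)               # illustrative positive definite E[g g']

def avar(W):
    # Asymptotic variance of GMM with weight W: (G'WG)^{-1} G'W Omega W G (G'WG)^{-1}
    bread = np.linalg.inv(G.T @ W @ G)
    return bread @ G.T @ W @ Omega @ W @ G @ bread

W_bar = np.linalg.inv(Omega)                  # candidate efficient weight
candidates = [np.eye(m),
              np.diag(rng.uniform(0.5, 2.0, size=m)),
              np.linalg.inv(A @ A.T + np.eye(m))]

condition_holds, min_eigs = [], []
for W in candidates:
    # Condition: D(W) = G'WG equals E[m(z, W) m(z, W_bar)'] = G'W Omega W_bar G
    condition_holds.append(np.allclose(G.T @ W @ G, G.T @ W @ Omega @ W_bar @ G))
    # Efficiency: avar(W) - avar(W_bar) should be positive semidefinite
    min_eigs.append(np.linalg.eigvalsh(avar(W) - avar(W_bar)).min())

print(condition_holds, min_eigs)
```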


In [4]:
%%time
# Get fast bootstrap standard errors
fast_bootstrap_se = gmm.bootstrap_scores(n_bootstrap=1000)

CPU times: user 177 ms, sys: 1.91 ms, total: 179 ms
Wall time: 178 ms


In [5]:
%%time
slow_bootstrap_se = gmm.bootstrap_full(n_bootstrap=100)

CPU times: user 48.5 s, sys: 968 ms, total: 49.5 s
Wall time: 21.9 s


In [6]:
import pandas as pd
pd.DataFrame(
    np.c_[true_beta, gmm.theta_, gmm.std_errors_, fast_bootstrap_se, slow_bootstrap_se],
    columns=["True", "GMM", "Analytic", "Fast Bootstrap", "Full Bootstrap"],
)

|   | True | GMM | Analytic | Fast Bootstrap | Full Bootstrap |
|---|------|-----|----------|----------------|----------------|
| 0 | -0.5 | -0.489339 | 0.014124 | 0.014096 | 0.013001 |
| 1 | 1.2 | 1.19956 | 0.026034 | 0.025503 | 0.024544 |


## Nonlinear GMM: Logit

In [7]:
# logit DGP
n = 1000
p = 2
X = np.random.normal(size=(n, p))
X = np.c_[np.ones(n), X]
beta = np.array([0.5, -0.5, 0.5])
y = np.random.binomial(1, 1 / (1 + np.exp(-X @ beta)))
Z = X.copy()

Maximum-likelihood (IWLS) benchmark from `statsmodels`:

In [8]:
import statsmodels.api as sm
logit_mod = sm.Logit(y, X)
logit_res = logit_mod.fit(disp=0)
print("Parameters: ", logit_res.params)

Parameters:  [ 0.45130964 -0.44964216  0.53112315]


### Nonlinear GMM with ensmallen

Define the moment condition using JAX-compatible operations (`jax.numpy`, `jax.scipy`):

In [9]:
import jax.numpy as jnp
import jax.scipy.special as jsp

def ψ_logit(z, y, x, beta):
    # Per-observation moments z_i * (y_i - expit(x_i' beta)); jax.scipy.special.expit (not scipy's) keeps this JAX-traceable
    resid = y - jsp.expit(x @ beta)
    return z * resid[:, jnp.newaxis]
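Because the notebook sets `Z = X`, these moments are exactly the logit score $x_i(y_i - \Lambda(x_i'\beta))$. The system is therefore just-identified, and the GMM estimate should coincide with the maximum-likelihood fit from `statsmodels`. As a rough cross-check, here is a plain-NumPy sketch that solves $\bar{g}(\beta) = 0$ by Newton's method. It uses the analytic Jacobian $G(\beta) = -\frac{1}{n}\sum_i z_i x_i' \Lambda_i(1-\Lambda_i)$. `psi_logit_np` and `solve_moments` are hypothetical helpers, not part of `pyensmallen`.

```python
import numpy as np

def expit(t):
    return 1.0 / (1.0 + np.exp(-t))

def psi_logit_np(z, y, x, beta):
    # NumPy mirror of ψ_logit: rows are z_i * (y_i - expit(x_i' beta))
    resid = y - expit(x @ beta)
    return z * resid[:, np.newaxis]

def solve_moments(z, y, x, beta0, tol=1e-10, max_iter=50):
    # Newton's method on gbar(beta) = 0 (just-identified, so the Jacobian is square)
    beta = np.array(beta0, dtype=float)
    for _ in range(max_iter):
        g_bar = psi_logit_np(z, y, x, beta).mean(axis=0)
        p = expit(x @ beta)
        G = -(z * (p * (1 - p))[:, np.newaxis]).T @ x / len(y)  # d gbar / d beta'
        step = np.linalg.solve(G, g_bar)
        beta = beta - step
        if np.max(np.abs(step)) < tol:
            break
    return beta

# Logit design as in the notebook (fresh draws, so numbers differ from the cells above)
rng = np.random.default_rng(0)
n = 1000
X = np.c_[np.ones(n), rng.normal(size=(n, 2))]
beta_true = np.array([0.5, -0.5, 0.5])
y = rng.binomial(1, expit(X @ beta_true))
Z = X.copy()

beta_hat = solve_moments(Z, y, X, np.zeros(3))
print(beta_hat)
```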

In [10]:
# Create and fit GMM estimator
gmm = pe.EnsmallenEstimator(ψ_logit, "optimal")
gmm.fit(Z, y, X, verbose=True)

# Display results
print(f"True parameters: {beta}")
print(f"Estimated parameters: {gmm.theta_}")
print(f"Standard errors: {gmm.std_errors_}")

True parameters: [ 0.5 -0.5  0.5]
Estimated parameters: [ 0.45130965 -0.44964226  0.53112299]
Standard errors: [0.06921393 0.07012597 0.06711118]


In [11]:
%%time

# Get fast bootstrap standard errors
bootstrap_se = gmm.bootstrap_scores(n_bootstrap=1000)

CPU times: user 857 ms, sys: 21.4 ms, total: 878 ms
Wall time: 696 ms


In [12]:
np.c_[gmm.theta_, gmm.std_errors_, bootstrap_se]

array([[ 0.45130965,  0.06921393,  0.07022184],
       [-0.44964226,  0.07012597,  0.07125831],
       [ 0.53112299,  0.06711118,  0.06535547]])