In [4]:
import numpy as np 
from numpy.linalg import inv 
from scipy.stats import distributions as iid 
from scipy.optimize import minimize

#From here: https://stackoverflow.com/questions/4740172/how-do-you-a-double-factorial-in-python
def doublefactorial(n):
     if n <= 0:
         return 1
     else:
         return n * doublefactorial(n-2)

    
def gj(b, x, k): 
    '''
    b: [mu, sigma], parameters for normal dist.
    x: a single observaton
    k: number of moments
    '''
    (mu, sigma) = b
    res = []
    for i in range(1,k+1):
        if (i % 2) == 0:
            xk = (x - mu) ** i  - (sigma**i) * doublefactorial(i-1)
        else: 
            xk = (x - mu) ** i 
        res.append(xk)
    return res


def gN(b, x_lst, k):
    '''
    Average of gj across all observations
    b: [mu, sigma], parameters for normal dist.
    x_lst: list of all observations
    k: number of moments
    '''
    return np.mean([gj(b, x_lst[j], k) for j in range(len(x_lst))], axis=0)


def Omegahat(b, x_lst, k):
    e = np.array([gj(b, x_lst[j], k) for j in range(len(x_lst))])

    # Recenter! We have Eu=0 under null.
    # Important to use this information.
    e = e - e.mean(axis=0)
    
    return e.T@e/e.shape[0]


def J(b, W, x_lst, k): 
    m = gN(b, x_lst, k) # Sample moments @ b
    N = len(x_lst)

    return (N*m.T@W@m) # Scale by sample size


def two_step_gmm(x_lst, k):
    # First step uses identity weighting matrix; use mean and variance as initial guess 
    W1 = np.eye(len(gj([0, 1], x_lst[0], k)))
    b1 = minimize(lambda b: J(b, W1, x_lst, k), [np.mean(x_lst), np.var(x_lst)]).x 

    # Construct 2nd step weighting matrix using first step estimate of beta
    W2 = inv(Omegahat(b1, x_lst, k))

    return minimize(lambda b: J(b, W2, x_lst, k), b1)

In [5]:
X = iid.norm.rvs(loc=2, scale=2, size=(1000, )) 
k = 4
two_step_gmm(X, k)

  message: Optimization terminated successfully.
  success: True
   status: 0
      fun: 0.749096006795324
        x: [ 1.904e+00  1.923e+00]
      nit: 4
      jac: [ 4.545e-07  6.832e-06]
 hess_inv: [[ 1.860e-03  2.407e-05]
            [ 2.407e-05  8.571e-04]]
     nfev: 21
     njev: 7