# James-Stein Estimator
Estimate the mean of correlated Gaussian variables. For three or more dimensions the James–Stein (JS) estimator dominates the least-squares (LSQ) estimator in terms of mean squared error (MSE), i.e. the JS estimator of the mean yields a lower MSE than the simple sample mean.

https://en.wikipedia.org/wiki/James%E2%80%93Stein_estimator

In [1]:
def james_stein(est, n, mcov, guess=None):
    ''' Return the positive-part James-Stein estimate of a mean vector.

    The raw estimate ``est`` is pulled towards ``guess`` by the factor
    ``max(1 - (p - 2) / (n * d' mcov^-1 d), 0)`` where ``d = est - guess``
    and ``p`` is the number of dimensions.

    Parameters
    ----------
    est : array_like, shape (p,)
        Unshrunk estimate of the mean (e.g. the sample mean).
    n : int
        Number of observations behind ``est``; ``mcov`` is scaled by
        ``1 / n`` inside the shrinkage factor.
    mcov : array_like, shape (p, p)
        Covariance matrix of a single observation.
    guess : array_like, shape (p,), optional
        Location to shrink towards. Defaults to the origin.

    Returns
    -------
    numpy.ndarray, shape (p,)
        ``guess + shrinkage * (est - guess)``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``mcov`` is singular.
    '''
    mcov = np.asarray(mcov, dtype=float)
    est = np.asarray(est, dtype=float)
    p = mcov.shape[0]
    guess = np.zeros(p) if guess is None else np.asarray(guess, dtype=float)

    diff = est - guess
    # Distance must be measured from the shrinkage target on BOTH sides of
    # the quadratic form; using `est` on one side is only right for guess=0.
    # solve() avoids forming the explicit inverse.
    quad = diff @ np.linalg.solve(mcov, diff)

    if quad == 0:
        # est coincides with guess: nothing to shrink, and the ratio below
        # would be a division by zero (nan when p == 2).
        shrinkage = 0.0
    else:
        # Positive-part rule: never shrink past the target.
        shrinkage = max(1 - (p - 2) / quad / n, 0)
    return guess + shrinkage * diff

def sqloss(est, target, mcov=None):
    ''' Quadratic loss between est and target.

    Without mcov this is the plain squared Euclidean distance. With mcov the
    residual is weighted by the pseudo-inverse of mcov, which makes the loss
    invariant to the covariance scaling.
    '''
    if mcov is None:
        mcov = np.eye(len(est))
    residual = est - target
    weight = np.linalg.pinv(mcov)
    return residual @ weight @ residual

def simulate_js_loss(dist, nsamples, loss_func, nsim):
    ''' Monte Carlo sample of the loss of the James-Stein estimator.

    Each of the nsim replications draws nsamples observations from dist,
    forms the sample mean and the sample covariance (ddof=1), shrinks the
    mean with james_stein (default shrinkage target) and evaluates
    loss_func on the result. Returns the array of nsim loss values.
    '''
    def _single_replication():
        draws = dist.rvs(nsamples)
        sample_mean = np.mean(draws, axis=0)
        sample_cov = np.cov(draws, rowvar=False, ddof=1)
        shrunk = james_stein(sample_mean, nsamples, sample_cov)
        return loss_func(shrunk)

    losses = np.empty(nsim)
    for rep in range(nsim):
        losses[rep] = _single_replication()
    return losses

In [2]:
import numpy as np
import scipy.stats as stats
from matplotlib import pyplot as plt

# ---------------------------
# --- Define distribution ---
# ---------------------------
# Number of dimensions.
p = 3
# Sample size of each simulated data set.
n = 10
# Number of Monte Carlo replications.
nsim = 5000

# Arbitrary cov matrix and mean.
mcov = np.diag(1 + np.arange(p))
mcov[1, 1] *= 2
# True mean is the origin, i.e. the default shrinkage target of james_stein.
mmu = np.arange(p) * 0
dist = stats.multivariate_normal(mean=mmu, cov=mcov)

# ---------------------------
# ------- Simulate  ---------
# ---------------------------
# MSE of JS estimator.
def loss(x):
    '''Squared error of an estimate x against the true mean.'''
    return sqloss(x, mmu)


mse_js = simulate_js_loss(dist, nsamples=n, loss_func=loss, nsim=nsim)
# (Monte Carlo mean of the loss, variance of that mean).
mse_js_mu = (np.mean(mse_js), np.var(mse_js, ddof=1) / nsim)

# Exact MSE of sample mean (MLE) estimator.
mse_mle_mu = np.trace(mcov) / n

# Histogram of diffed mse observations (negative values => JS mse lower).
plt.hist(mse_js - mse_mle_mu, bins='auto', density=True)
plt.grid()

# Hypothesis test that JS is better (one-sided z-test on the simulated mean).
print('H0: MSE_js_mu >= MSE_mle_mu')
print('H1: MSE_js_mu < MSE_mle_mu')
zscore = (mse_js_mu[0] - mse_mle_mu) / np.sqrt(mse_js_mu[1])
pval = stats.norm.cdf(zscore)
print('p-value {:.4f}'.format(pval))
alpha = 0.05
# Compare against alpha itself so the printed level is the one actually used.
# NOTE(review): the else branch only means H0 was not rejected; that is not
# positive evidence for the MLE. Message text kept unchanged.
if pval < alpha:
    print(f'====> JS likely better at {alpha} level <====')
else:
    print(f'MLE likely better at {alpha} level')

H0: MSE_js_mu >= MSE_mle_mu
H1: MSE_js_mu < MSE_mle_mu
p-value 0.0000
====> JS likely better at 0.05 level <====
