# In [None]:
import numpy as np
import matplotlib.pyplot as plt

def soft_threshold(z, t):
    """Apply the soft-thresholding operator elementwise.

    For a non-negative threshold ``t``, entries with ``|z| <= t`` map to zero.
    Every other entry keeps its sign and has its magnitude reduced by ``t``.
    """
    magnitude = np.maximum(np.abs(z) - t, 0.0)
    return magnitude * np.sign(z)

def sure_soft_threshold(z, t, sigma=1.0):
    """Return Stein's unbiased risk estimate, per coordinate, for soft thresholding.

    Assume ``z_i = mu_i + N(0, sigma**2)``, with independent noise and known
    ``sigma``. Then Stein's identity gives an unbiased estimate of the risk
    ``E||soft_threshold(z, t) - mu||**2``:

        SURE(t) = -n*sigma**2 + sum_i min(z_i**2, t**2)
                  + 2*sigma**2 * #{i : |z_i| > t}

    Parameters
    ----------
    z : array_like
        Noisy observations.
    t : float
        Non-negative threshold.
    sigma : float, default 1.0
        Noise standard deviation.

    Returns
    -------
    float
        ``SURE(t) / n``, the estimated mean squared error per coordinate.
        Returns ``nan`` when ``z`` is empty.
    """
    z = np.asarray(z, dtype=float)
    n = z.size
    if n == 0:
        return float('nan')
    # (soft_threshold(z, t) - z)**2 equals z**2 where |z| <= t and t**2
    # elsewhere, so the residual needs no explicit thresholding.
    rss = np.sum(np.minimum(z**2, t**2))
    # Divergence term: the soft-threshold map has derivative 1 exactly
    # where |z| > t and 0 elsewhere.
    df = np.count_nonzero(np.abs(z) > t)
    # The -sigma**2 term is Stein's -n*sigma**2 divided by n. It shifts the
    # value to the true risk estimate but does not move the minimiser.
    return (rss + 2.0 * sigma**2 * df) / n - sigma**2

def find_best_threshold_sure(z, sigma=1.0):
    """Return the threshold minimising SURE over the observed magnitudes.

    The candidate thresholds are the distinct values of ``|z|``. The score
    at each candidate equals ``sure_soft_threshold(z, t, sigma)``. All scores
    are computed at once from the sorted magnitudes and their prefix sums,
    which costs O(n log n) instead of O(n) work per candidate. When scores
    tie, the smallest threshold wins.

    Parameters
    ----------
    z : array_like
        Noisy observations.
    sigma : float, default 1.0
        Noise standard deviation.

    Returns
    -------
    (float, float)
        The best threshold and its per-coordinate SURE value.
        Returns ``(0.0, inf)`` when ``z`` is empty.
    """
    a = np.sort(np.abs(np.asarray(z, dtype=float)).ravel())
    n = a.size
    if n == 0:
        return 0.0, float('inf')
    candidates = np.unique(a)  # ascending distinct magnitudes
    # How many |z_i| <= t for each candidate; side='right' counts ties as <=.
    n_le = np.searchsorted(a, candidates, side='right')
    # cum_sq[k] = sum of the k smallest z_i**2 (cum_sq[0] = 0).
    cum_sq = np.concatenate(([0.0], np.cumsum(a**2)))
    df = n - n_le
    rss = cum_sq[n_le] + df * candidates**2
    scores = (rss + 2.0 * sigma**2 * df) / n - sigma**2
    k = int(np.argmin(scores))  # first minimum, i.e. the smallest t on ties
    return candidates[k], scores[k]

def one_lambda_sure(n, sigma, p, mu_mean, mu_std, rng):
    """Simulate one sparse-means data set and return its SURE threshold.

    Each of the ``n`` true means is exactly zero with probability ``p``.
    Otherwise it is drawn from ``N(mu_mean, mu_std**2)``. Observations add
    independent ``N(0, sigma**2)`` noise. The returned value is the threshold
    selected by ``find_best_threshold_sure`` on those observations.
    ``rng`` is a NumPy ``Generator``. Its draws happen in a fixed order:
    the zero mask, then the nonzero means, then the noise.
    """
    nonzero = ~(rng.random(n) < p)
    mu = np.zeros(n)
    mu[nonzero] = rng.normal(mu_mean, mu_std, size=nonzero.sum())
    noisy = mu + rng.normal(0.0, sigma, size=n)
    threshold, _score = find_best_threshold_sure(noisy, sigma)
    return threshold

n = 6000
sigma = 1.0
mu_mean = 0.0
mu_std = 4.70
n_repeats = 1000
# Set to an integer for a reproducible run. None (the default) seeds the
# generator from fresh OS entropy on every run.
seed = None
rng = np.random.default_rng(seed)

# Probability that a true mean is exactly zero, one study per level.
# The levels are simulated in this order from a single generator, so each
# level consumes the random stream in the same order as before.
sparsity_levels = (0.90, 0.95, 0.98)
lam_sure_by_p = {
    p: np.array([one_lambda_sure(n, sigma, p, mu_mean, mu_std, rng)
                 for _ in range(n_repeats)])
    for p in sparsity_levels
}
lam_sure_090 = lam_sure_by_p[0.90]
lam_sure_095 = lam_sure_by_p[0.95]
lam_sure_098 = lam_sure_by_p[0.98]


plt.style.use('ggplot')
fig, ax = plt.subplots(1, 3, figsize=(14, 4))

# One histogram per sparsity level: (lambda samples, title suffix, bar colour).
# Panels are filled left to right in list order.
hist_specs = [
    (lam_sure_090, r'$p = 0.90$', 'steelblue'),
    (lam_sure_095, r'$p = 0.95$', 'darkorange'),
    (lam_sure_098, r'$p = 0.98$', 'seagreen'),
]

for panel, (data, title, color) in zip(ax, hist_specs):
    panel.hist(data, bins=30, color=color, alpha=0.8)
    panel.set_title(r'$\lambda_{\mathrm{SURE}}$ distribution, ' + title)
    panel.set_xlabel(r'$\lambda_{\mathrm{SURE}}$')
    panel.set_ylabel('frequency')

plt.tight_layout()
plt.show()

plt.figure(figsize=(6, 4))
box_labels = [r'$p=0.90$', r'$p=0.95$', r'$p=0.98$']
plt.boxplot([lam_sure_090, lam_sure_095, lam_sure_098],
            patch_artist=True,
            boxprops=dict(facecolor='lightgray'))
# Matplotlib 3.9 renamed boxplot's `labels=` to `tick_labels=` and deprecated
# the old name. Setting the ticks explicitly works on old and new versions.
# boxplot() places the boxes at x = 1..N by default.
plt.xticks(range(1, len(box_labels) + 1), box_labels)

plt.ylabel(r'$\lambda_{\mathrm{SURE}}$')
plt.title(r'Box-whisker comparison of $\lambda_{\mathrm{SURE}}$')
plt.tight_layout()
plt.show()

# Report mean and sample standard deviation (ddof=1) of the selected
# thresholds at each sparsity level.
# The header previously contained "Î»", which is UTF-8 "λ" mis-decoded as Latin-1.
print("Values of λ_SURE (mean with sd):")
for label, samples in (("0.90", lam_sure_090),
                       ("0.95", lam_sure_095),
                       ("0.98", lam_sure_098)):
    print(f"p = {label}: {samples.mean():.3f} +- {samples.std(ddof=1):.3f}")