<a href="https://colab.research.google.com/github/AntoineChapel/metrics1_part2_hw1/blob/main/exo3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
import jax.scipy as jsp


def wald_test(X, y, R, c):
    X = jnp.array(X)
    y = jnp.array(y)
    R = jnp.array(R)
    c = jnp.array(c)

    n = X.shape[0]
    k = X.shape[1]
    beta_hat = jnp.linalg.solve(X.T@X, X.T@y)
    e_hat = y - X@beta_hat
    D1 = X.T@jnp.diag((e_hat**2).flatten())@X
    v_hc1 = (n/(n-k))*jnp.linalg.inv(X.T@X)@D1@jnp.linalg.inv(X.T@X)
    w_stat = (R.T@beta_hat - c).T@jnp.linalg.inv((R.T@v_hc1@R))@(R.T@beta_hat - c)
    q = R.shape[1]
    pval = (1 - jsp.stats.chi2.cdf(w_stat, q))
    return w_stat[0][0], pval[0][0]


def gen_random_sample(arr, rng):
    n = arr.shape[0]
    indices = jax.random.randint(rng, shape=(n, ), minval = 0, maxval = n)
    sample = arr[indices]
    return sample



def bootstrap_pairs(X, y, R, key):
    subkey = jax.random.split(key, num=2)[1]
    arr = jnp.hstack((y.reshape(-1, 1), X))
    sample_b = fast_b_sample(arr, subkey)
    n = X.shape[0]
    k = X.shape[1]

    beta_hat = jnp.linalg.solve(X.T @ X, X.T @ y)
    c_bootstrap = (R.T @ beta_hat).reshape(-1, 1)

    y_b = jax.lax.slice(sample_b, (0, 0), (n, 1)).reshape(-1, 1)
    X_b = jax.lax.slice(sample_b, (0, 1), (n, k+1)).reshape(n, k)
    w_b = fast_wald_test(X_b, y_b, R, c_bootstrap)[0]
    return w_b

def bootstrap_residuals(X, y, R, key):
    subkey = jax.random.split(key, num=2)[1]
    beta_hat = jnp.linalg.solve(X.T@X, X.T@y)
    e_hat = (y - X@beta_hat).reshape(-1, 1)

    estar_b = fast_b_sample(e_hat, subkey).reshape(-1, 1)
    ystar = X@beta_hat + estar_b

    c_bootstrap = (R.T @ beta_hat).reshape(-1, 1)
    w_b = fast_wald_test(X, ystar, R, c_bootstrap)[0]
    return w_b

def bootstrap_wild(X, y, R, key, wildtype="rademacher"):
    n = X.shape[0]
    if wildtype=="rademacher":
        x = jnp.array([-1, 1])
        xi = jax.random.choice(key, x, shape=(n, 1), p=jnp.array([1/2, 1/2]))
    elif wildtype=="mammen":
        x = jnp.array([(1+jnp.sqrt(5))/2, (1 - jnp.sqrt(5))/2])
        xi = jax.random.choice(key, x, shape=(n, 1), p=jnp.array([(jnp.sqrt(5) - 1)/(2*jnp.sqrt(5)), (jnp.sqrt(5)+1)/(2*jnp.sqrt(5))]))
    else:
        print("ERROR, please specify a valid wild method")
    xi = jnp.array(xi)
    beta_hat = jnp.linalg.solve(X.T@X, X.T@y)
    e_hat = (y - X@beta_hat).reshape(-1, 1)

    estar_b = jnp.multiply(e_hat, xi).reshape(-1, 1)
    ystar = X@beta_hat + estar_b
    c_bootstrap = (R.T @ beta_hat).reshape(-1, 1)
    w_b = fast_wald_test(X, ystar, R, c_bootstrap)[0]
    return w_b

fast_wald_test = jax.jit(wald_test)
fast_b_sample = jax.jit(gen_random_sample)


def bootstrapped_wald_test(X, y, R, c, B, key, boottype="wild", wildtypeval="rademacher"):
    """Runs a Wald test through the pairs bootstrap algorithm"""
    w_stat = fast_wald_test(X, y, R, c)[0]
    keys = jax.random.split(key, num=B)
    if boottype == "pairs":
        w_list = jax.vmap(lambda key: bootstrap_pairs(X, y, R, key))(keys)
    elif boottype == "residuals":
        w_list = jax.vmap(lambda key: bootstrap_residuals(X, y, R, key))(keys)
    elif boottype == "wild":
        w_list = jax.vmap(lambda key: bootstrap_wild(X, y, R, key, wildtype=wildtypeval))(keys)
    pval = jnp.round(jnp.sum(w_list > w_stat)/B, 5)
    return w_stat, pval



### CONFIDENCE INTERVAL FOR BETA
#asymptotic:
def confidence_interval(X, y, alpha):
    X = jnp.array(X)
    y = jnp.array(y).reshape(-1, 1)
    n = X.shape[0]
    k = X.shape[1]
    try:
        Qxx_hat = jnp.linalg.inv(X.T@X)
    except:
        print("WARNING: matrix not invertible")

    beta_hat = Qxx_hat@(X.T@y)

    e_hat = y - X@beta_hat
    D1 = X.T@jnp.diag((e_hat**2).flatten())@X
    v_hc1 = (n/(n-k))*Qxx_hat@D1@Qxx_hat

    q = jsp.stats.norm.ppf(1-alpha/2)

    se_vec = jnp.sqrt(jnp.diag(v_hc1)).reshape(-1, 1)
    ci_beta = jnp.hstack((beta_hat - q*se_vec , beta_hat + q*se_vec))
    return beta_hat, ci_beta


def pairs_t_test(X, y, key):
    n = X.shape[0]
    k = X.shape[1]
    try:
        Qxx_hat = jnp.linalg.inv(X.T@X)
    except:
        print("WARNING: matrix not invertible")
    beta_hat = Qxx_hat@(X.T@y).reshape(-1, 1)

    subkey = jax.random.split(key, num=2)[1]
    arr = jnp.hstack((y, X))
    subsample = fast_b_sample(arr, subkey)

    y_b = jax.lax.slice(subsample, (0, 0), (n, 1)).reshape(-1, 1)
    X_b = jax.lax.slice(subsample, (0, 1), (n, k+1)).reshape(n, k)
    beta_hat_b = jnp.linalg.solve(X_b.T@X_b, X_b.T@y_b).reshape(-1, 1)
    e_hat_b = y_b - X_b@beta_hat_b
    D1_b = X_b.T@jnp.diag((e_hat_b**2).flatten())@X_b
    v_hc1_b = (n/(n-k))*jnp.linalg.inv(X_b.T@X_b)@D1_b@jnp.linalg.inv(X_b.T@X_b)

    T_b = jnp.multiply((beta_hat_b - beta_hat), (1/(jnp.sqrt(jnp.diag(v_hc1_b))).reshape(-1, 1))).flatten()

    return T_b


def residual_t_test(X, y, key):
    n = X.shape[0]
    k = X.shape[1]
    try:
        Qxx_hat = jnp.linalg.inv(X.T@X) #this is Qxx^-1 but ok
    except:
        print("WARNING: matrix not invertible")
    beta_hat = Qxx_hat@(X.T@y).reshape(-1, 1)

    subkey = jax.random.split(key, num=2)[1]
    e_hat = y - X@beta_hat
    e_star = fast_b_sample(e_hat, subkey)
    y_b = X@beta_hat + e_star

    beta_hat_b = jnp.linalg.solve(X.T@X, X.T@y_b).reshape(-1, 1)
    e_hat_b = y_b - X@beta_hat_b
    D1_b = X.T@jnp.diag((e_hat_b**2).flatten())@X
    v_hc1_b = (n/(n-k))*Qxx_hat@D1_b@Qxx_hat

    T_b = jnp.multiply((beta_hat_b - beta_hat), (1/(jnp.sqrt(jnp.diag(v_hc1_b))).reshape(-1, 1))).flatten()

    return T_b



def wild_t_test(X, y, key, wildtype="rademacher"):
    n = X.shape[0]
    k = X.shape[1]
    subkey = jax.random.split(key, num=2)[1]
    if wildtype=="rademacher":
        x = jnp.array([-1, 1])
        xi = jax.random.choice(subkey, x, shape=(n, 1), p=jnp.array([1/2, 1/2]))
    elif wildtype=="mammen":
        x = jnp.array([(1+jnp.sqrt(5))/2, (1 - jnp.sqrt(5))/2])
        xi = jax.random.choice(subkey, x, shape=(n, 1), p=jnp.array([(jnp.sqrt(5) - 1)/(2*jnp.sqrt(5)), (jnp.sqrt(5)+1)/(2*jnp.sqrt(5))]))
    else:
        print("ERROR, please specify a valid wild method")
    xi = jnp.array(xi)

    try:
        Qxx_hat = jnp.linalg.inv(X.T@X) #this is Qxx^-1 but ok
    except:
        print("WARNING: matrix not invertible")
    beta_hat = Qxx_hat@(X.T@y).reshape(-1, 1)


    e_hat = y - X@beta_hat
    e_star = jnp.multiply(e_hat, xi)
    y_b = X@beta_hat + e_star

    beta_hat_b = jnp.linalg.solve(X.T@X, X.T@y_b).reshape(-1, 1)
    e_hat_b = y_b - X@beta_hat_b
    D1_b = X.T@jnp.diag((e_hat_b**2).flatten())@X
    v_hc1_b = (n/(n-k))*Qxx_hat@D1_b@Qxx_hat

    T_b = jnp.multiply((beta_hat_b - beta_hat), (1/(jnp.sqrt(jnp.diag(v_hc1_b))).reshape(-1, 1))).flatten()

    return T_b



def bootstrapped_confidence_interval(X, y, alpha, B, key, boottype="pairs", wildtype_val="rademacher"):
    n = X.shape[0]
    k = X.shape[1]
    try:
        Qxx_hat = jnp.linalg.inv(X.T@X)
    except:
        print("WARNING: matrix not invertible")
    beta_hat = Qxx_hat@(X.T@y).reshape(-1, 1)
    e_hat = y - X@beta_hat
    D1 = X.T@jnp.diag((e_hat**2).flatten())@X
    v_hc1 = (n/(n-k))*Qxx_hat@D1@Qxx_hat

    keys = jax.random.split(key, num=B)

    #plug_in_T_function
    if boottype == "pairs":
        T_list = jax.vmap(lambda key: pairs_t_test(X, y, key))(keys)
    elif boottype == "residuals":
        T_list = jax.vmap(lambda key: residual_t_test(X, y, key))(keys)
    elif boottype == "wild":
        T_list = jax.vmap(lambda key: wild_t_test(X, y, key, wildtype=wildtype_val))(keys)


    t = jnp.percentile(T_list, 100*(1 - alpha/2), axis=0).reshape(-1, 1)
    se_vec = jnp.sqrt(jnp.diag(v_hc1)).reshape(-1, 1)
    ci_beta = jnp.hstack((beta_hat - jnp.multiply(t, se_vec), beta_hat + jnp.multiply(t, se_vec)))
    return beta_hat, ci_beta

In [2]:
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd

key = jax.random.PRNGKey(6411)
np.random.seed(6411)

#CONSTANTS
N = 500
K = 3
B = 300

#DGP
beta = np.array([1, 1, 0.5]).reshape(-1, 1)
x = np.random.uniform(-2, 2, size=(N, 1))
epsilon = np.random.normal(0, 1, size=(N, 1))
sigma = np.sqrt(2 + 1/2*(x**2))
X = np.hstack((np.ones((N, 1)), x, x**2))
y = X@beta + np.multiply(sigma, epsilon)

#theta = R'beta
#hypothesis: R'beta = [1, 0].T
R = np.array([[1, 0],
              [0, 1],
              [0, -2]])
c = np.array([1, 0]).reshape(-1, 1)


asy_w, asy_pval = wald_test(X, y, R, c)
pairs_w, pairs_pval = bootstrapped_wald_test(X, y, R, c, B, key, boottype="pairs")
resid_w, resid_pval = bootstrapped_wald_test(X, y, R, c, B, key, boottype="residuals")
wild_w_r, wild_pval_r = bootstrapped_wald_test(X, y, R, c, B, key, boottype="wild", wildtypeval="rademacher")
wild_w_m, wild_pval_m = bootstrapped_wald_test(X, y, R, c, B, key, boottype="pairs", wildtypeval="mammen")

results = np.array([[asy_w, asy_pval],
                    [pairs_w, pairs_pval],
                    [resid_w, resid_pval],
                    [wild_w_r, wild_pval_r],
                    [wild_w_m, wild_pval_m]])

results_df = pd.DataFrame(results)

results_df.columns=["Wald Stat", "PVAL"]
results_df.index = ["Asymptotic", "Pairs Bootstrap", "Residuals Bootstrap", "Wild Bootstrap Rademacher", "Wild Bootstrap Mammen"]

print("Wald test for h0: b0 = 0 and b1 = 2b2")
print(results_df)


beta, ci = confidence_interval(X, y, 0.05)
beta_p, ci_p = bootstrapped_confidence_interval(X, y, 0.05, B, key)
beta_r, ci_r = bootstrapped_confidence_interval(X, y, 0.05, B, key, boottype="residuals")
beta_w_r, ci_w_r = bootstrapped_confidence_interval(X, y, 0.05, B, key, boottype="wild", wildtype_val="rademacher")
beta_w_m, ci_w_m = bootstrapped_confidence_interval(X, y, 0.05, B, key, boottype="wild", wildtype_val="mammen")

results_ci_b1 = np.vstack((ci[1, :], ci_p[1, :], ci_r[1, :], ci_w_r[1, :], ci_w_m[1, :]))
results_ci_df = pd.DataFrame(results_ci_b1)


results_ci_df.index = ["Asymptotic", "Pairs Bootstrap", "Residuals Bootstrap", "Wild Bootstrap Rademacher", "Wild Bootstrap Mammen"]
results_ci_df.columns = ["C 2.5", "C 97.5"]

print("Confidence interval for b1")
print(results_ci_df)

Wald test for h0: b0 = 0 and b1 = 2b2
                           Wald Stat      PVAL
Asymptotic                    1.7329  0.420441
Pairs Bootstrap               1.7329  0.410000
Residuals Bootstrap           1.7329  0.450000
Wild Bootstrap Rademacher     1.7329  0.480000
Wild Bootstrap Mammen         1.7329  0.410000
Confidence interval for b1
                              C 2.5    C 97.5
Asymptotic                 0.839161  1.107434
Pairs Bootstrap            0.834087  1.112508
Residuals Bootstrap        0.825603  1.120992
Wild Bootstrap Rademacher  0.843829  1.102766
Wild Bootstrap Mammen      0.845968  1.100627


In [3]:
import jax
import jax.numpy as jnp
import numpy as np

key = jax.random.PRNGKey(6411)
np.random.seed(6411)

#CONSTANTS
N = 500
K = 3
B = 300

#DGP
beta = np.array([1, 1, 0.5]).reshape(-1, 1)

#theta = R'beta
#hypothesis: R'beta = [1, 0].T
R = np.array([[1, 0],
              [0, 1],
              [0, -2]])
c = np.array([1, 0]).reshape(-1, 1)



def wald_test_by_repetition(N_REP, B, asy=False, boottype="pairs", wildtype="rademacher", verbose=1):
    arr_rep = np.empty(N_REP,)
    for i in range(N_REP):
        x = np.random.uniform(-2, 2, size=(N, 1))
        sigma = np.sqrt(2 + 1/2*(x**2))
        epsilon = np.random.normal(0, 1, size=(N, 1))
        X = np.hstack((np.ones((N, 1)), x, x**2))
        y = X@beta + np.multiply(sigma, epsilon)

        if asy==True:
            w, p = fast_wald_test(X, y, R, c)
        else:
            w, p = bootstrapped_wald_test(X, y, R, c, B, key, boottype, wildtype)

        arr_rep[i] = p
    if verbose >= 1:
        pct_typeI = 100*(np.sum(arr_rep < 0.05))/N_REP
        if asy==True:
            print(f"Type I error in the asymptotic Wald test: {pct_typeI} %")
        elif boottype=="wild":
            print(f"Type I error in the {boottype} {wildtype} bootstrapped wald test: {pct_typeI} %")
        else:
            print(f"Type I error in the {boottype} bootstrapped wald test: {pct_typeI} %")

    return pct_typeI






def ci_test_by_repetition(N_REP, B, asy=False, boottype="pairs", wildtype="rademacher", verbose=1):
    true_b1 = 1
    arr_rep = np.empty((N_REP, 2))
    for i in range(N_REP):
        x = np.random.uniform(-2, 2, size=(N, 1))
        sigma = np.sqrt(2 + 1/2*(x**2))
        epsilon = np.random.normal(0, 1, size=(N, 1))
        X = np.hstack((np.ones((N, 1)), x, x**2))
        y = X@beta + np.multiply(sigma, epsilon)

        if asy==False:
            b, ci = bootstrapped_confidence_interval(X, y, 0.05, B, key, boottype, wildtype)
        else:
            b, ci = confidence_interval(X, y, 0.05)
        ci_b0 = np.array(ci[0, :])
        if verbose >= 2:
            print(f"repetition {i}, {ci_b0}")
        arr_rep[i, :] = ci_b0
    if verbose >= 1:
        pct_typeI = 100*(np.sum(arr_rep[:, 0] > true_b1) + np.sum(arr_rep[:, 1] < true_b1))/N_REP
        if asy==True:
            print(f"Type I error in the asymptotic confidence interval: {pct_typeI} %")
        elif boottype=="wild":
            print(f"Type I error in the {boottype} {wildtype} bootstrapped confidence interval: {pct_typeI} %")
        else:
            print(f"Type I error in the {boottype} bootstrapped confidence interval: {pct_typeI} %")
    return pct_typeI


N_REP = 1000

print("--------------------------------------------------------")
wald_test_by_repetition(N_REP, B, asy=True, verbose=1)
wald_test_by_repetition(N_REP, B, boottype="pairs", verbose=1)
wald_test_by_repetition(N_REP, B, boottype="residuals", verbose=1)
wald_test_by_repetition(N_REP, B, boottype="wild", wildtype="rademacher", verbose=1)
wald_test_by_repetition(N_REP, B, boottype="wild", wildtype="mammen", verbose=1)



ci_test_by_repetition(N_REP, B, asy=True, verbose=1)
ci_test_by_repetition(N_REP, B, boottype="pairs", verbose=1)
ci_test_by_repetition(N_REP, B, boottype="residuals", verbose=1)
ci_test_by_repetition(N_REP, B, boottype="wild", wildtype="rademacher", verbose=1)
ci_test_by_repetition(N_REP, B, boottype="wild", wildtype="mammen", verbose=1)
print("--------------------------------------------------------")

--------------------------------------------------------
Type I error in the asymptotic Wald test: 6.1 %
Type I error in the pairs bootstrapped wald test: 4.1 %
Type I error in the residuals bootstrapped wald test: 4.6 %
Type I error in the wild rademacher bootstrapped wald test: 6.8 %
Type I error in the wild mammen bootstrapped wald test: 4.9 %
Type I error in the asymptotic confidence interval: 5.8 %
Type I error in the pairs bootstrapped confidence interval: 5.8 %
Type I error in the residuals bootstrapped confidence interval: 6.1 %
Type I error in the wild rademacher bootstrapped confidence interval: 5.2 %
Type I error in the wild mammen bootstrapped confidence interval: 5.6 %
--------------------------------------------------------
