In [None]:
import numpy as np
import pandas as pd
from scipy import stats
from tqdm import tqdm

import matplotlib.pyplot as plt
import seaborn as sns
import warnings
from plotnine import *
import glob

import os
import sys
import time
sys.path.insert(0, "../mosaicperm/")
import mosaicperm as mp
from mosaicperm.utilities import vrange, elapsed

def calc_mean_sem(data, group_vals, meas, trunc_zero=True):
    """
    Summarize measurement columns within groups.

    Groups ``data`` by ``group_vals`` and, for every column ``m`` in ``meas``,
    computes the group mean, standard deviation and standard error of the
    mean, plus a band of mean -/+ 2 standard errors.

    Parameters
    ----------
    data : pd.DataFrame
        Input observations.
    group_vals : str or list of str
        Column(s) to group by.
    meas : str or list of str
        Measurement column(s) to summarize. A single column name is treated
        as a one-element list.
    trunc_zero : bool, default True
        If True, the lower end of the band (``{m}_ymin``) is clipped at 0.

    Returns
    -------
    pd.DataFrame
        One row per group with flat columns: the grouping columns followed by
        ``{m}_mean``, ``{m}_std``, ``{m}_sem``, ``{m}_ymin``, ``{m}_ymax``
        for each ``m`` in ``meas``.
    """
    # A bare string would otherwise be iterated character by character below.
    meas = [meas] if isinstance(meas, str) else list(meas)
    agg_df = data.groupby(group_vals)[meas].agg(['mean', 'std', 'sem']).reset_index()
    for m in meas:
        # Two-level (measure, statistic) columns -> one flat column per statistic.
        summary = agg_df[m]
        agg_df[f'{m}_mean'] = summary['mean']
        agg_df[f'{m}_std'] = summary['std']
        agg_df[f'{m}_sem'] = summary['sem']
        lower = agg_df[f'{m}_mean'] - 2 * agg_df[f'{m}_sem']
        if trunc_zero:
            lower = np.maximum(0, lower)
        agg_df[f'{m}_ymin'] = lower
        agg_df[f'{m}_ymax'] = agg_df[f'{m}_mean'] + 2 * agg_df[f'{m}_sem']

    # Keep only the columns whose second level is empty (grouping columns and
    # the flattened statistics), then drop the second level entirely.
    agg_df = agg_df.loc[:, agg_df.columns.get_level_values(1) == '']
    agg_df.columns = agg_df.columns.get_level_values(0)
    return agg_df

In [None]:
def naive_plugin(hateps, test_stat, R=100):
    """
    Naive permutation test applied directly to the OLS residuals ``hateps``.

    For each of the ``R`` draws, every column of ``hateps`` is reordered by a
    freshly shuffled row order and ``test_stat`` is evaluated on the result.

    Returns
    -------
    pval : float
        (1 + #{null statistics >= observed}) / (R + 1).
    S : the statistic on the unpermuted residuals.
    S0s : np.ndarray of shape (R,) with the permuted statistics.
    """
    n_obs, n_cols = hateps.shape
    observed = test_stat(hateps)
    permuted = hateps.copy()
    # The row order is shuffled in place, so each shuffle starts from the
    # previous one rather than from the identity ordering.
    order = np.arange(n_obs)
    null_stats = np.zeros(R)
    for draw in vrange(R, verbose=False):
        for col in range(n_cols):
            np.random.shuffle(order)
            permuted[:, col] = hateps[order, col]
        null_stats[draw] = test_stat(permuted)
    n_exceed = np.sum(observed <= null_stats)
    pval = (1 + n_exceed) / (R + 1)
    return pval, observed, null_stats

def resid_bootstrap(hateps, test_stat, R=100):
    """
    Residual bootstrap from paper.

    Each of the ``R`` draws resamples whole rows of ``hateps`` with
    replacement and evaluates ``test_stat`` on the resampled matrix.

    Parameters
    ----------
    hateps : np.ndarray, two-dimensional
        Residual matrix; rows are the units that get resampled.
    test_stat : callable
        Maps a residual matrix to a scalar statistic.
    R : int
        Number of bootstrap draws.

    Returns
    -------
    pval : float
        (1 + #{bootstrap statistics >= observed}) / (R + 1).
    S : the statistic on the original residuals.
    S0s : np.ndarray of shape (R,) with the bootstrap statistics.
    """
    n, _ = hateps.shape
    S = test_stat(hateps)
    inds = np.arange(n)
    S0s = np.zeros(R)
    for r in vrange(R, verbose=False):
        # Fancy indexing already yields a new array, so no working copy of
        # hateps is needed.
        S0s[r] = test_stat(hateps[np.random.choice(inds, size=n, replace=True)])
    pval = (1 + np.sum(S <= S0s)) / (R + 1)
    return pval, S, S0s

def bootstrap(Y, H, test_stat, R=100):
    """
    Nonparametric bootstrap (identical to residual bootstrap when exposures are constant).

    Rows of ``Y`` are resampled with replacement; each resampled matrix is
    multiplied by ``H`` to obtain residuals, on which ``test_stat`` is
    evaluated.

    Returns
    -------
    pval : float
        (1 + #{bootstrap statistics >= observed}) / (R + 1).
    S : the statistic on the residuals ``Y @ H`` of the original data.
    S0s : np.ndarray of shape (R,) with the bootstrap statistics.
    """
    n_obs, _ = Y.shape
    # Statistic on the observed residuals.
    observed = test_stat(Y @ H)
    row_ids = np.arange(n_obs)
    boot_stats = np.zeros(R)
    for draw in vrange(R, verbose=False):
        picked = np.random.choice(row_ids, size=n_obs, replace=True)
        boot_stats[draw] = test_stat(Y[picked] @ H)
    n_exceed = np.sum(observed <= boot_stats)
    pval = (1 + n_exceed) / (R + 1)
    return pval, observed, boot_stats

## 1. Invalidity of bootstrap (version 2)

In [None]:
# Factor exposures: use the cached file when it exists, otherwise fall back
# to the placeholder file.
L_PATH = "data/bfre_cache/simulation_exposures.npy"
L_PLACEHOLDER_PATH = "data/bfre_placeholder/simulation_exposures.npy"
if not os.path.exists(L_PATH):
    print("True exposures are not publicly available, using placeholder data instead.")
    L = np.load(L_PLACEHOLDER_PATH)
else:
    L = np.load(L_PATH)
p, k = L.shape

# H = I - Q Q^T, with Q from the QR factorization of L; right-multiplying the
# outcomes by H gives the residuals (hateps).
Q, R = np.linalg.qr(L)
H = np.eye(p) - Q @ Q.T

# Draw standard normal factors and noise, then form Y = X L^T + eps.
n = 500
X = np.random.randn(n, k)
eps = np.random.randn(n, p)
Y = X @ L.T + eps
hateps = Y @ H

In [None]:
# Simulation settings
reps = 1000         # number of simulated datasets
bs_reps = 50        # bootstrap draws per dataset
nboot_reps = reps   # bootstrap / naive permutation run on the first nboot_reps datasets
n = 300             # number of observations per simulated dataset
p, k = L.shape

# Per-dataset results. The residual bootstrap is not collected separately:
# bootstrap and resid bootstrap are the same here.
Ts = []
Ts_bs = np.zeros((reps, bs_reps))
Ts_naive = []
Ts_mosaic_null = []
Ts_mosaic = []

# test statistic
def test_stat(x):
    """Test statistic shared by all methods (mean_maxcorr_stat from mosaicperm)."""
    return mp.statistics.mean_maxcorr_stat(x)

t0 = time.time()
for rep in tqdm(range(reps)):
    # Seed per repetition so each simulated dataset is reproducible.
    np.random.seed(rep)
    # Sample Y
    X = np.random.randn(n, k)
    eps = np.random.randn(n, p)
    Y = X @ L.T + eps
    # Compute hateps
    hateps = Y @ H
    Ts.append(test_stat(hateps))
    # Strict inequality: exactly nboot_reps datasets get the bootstrap and
    # naive permutation draws (`<=` would run one dataset too many).
    if rep < nboot_reps:
        # bootstrap
        _, _, S0s = bootstrap(Y, H, test_stat, R=bs_reps)
        Ts_bs[rep] = S0s
        # naive perm: one permuted statistic per dataset
        _, _, S0s = naive_plugin(hateps, test_stat, R=1)
        Ts_naive.append(S0s[0])

    # mosaic: observed statistic and a single null draw per dataset
    mptest = mp.factor.MosaicFactorTest(outcomes=Y, exposures=L, test_stat=test_stat)
    mptest.fit(nrand=1, verbose=False)
    Ts_mosaic.append(mptest.statistic)
    Ts_mosaic_null.append(mptest.null_statistics.item())

# Ts_bs is already an array; convert the accumulated lists.
Ts_mosaic = np.array(Ts_mosaic)
Ts_mosaic_null = np.array(Ts_mosaic_null)
Ts_naive = np.array(Ts_naive)
Ts = np.array(Ts)

In [None]:
# Bootstrap z-statistics: the statistic minus the bootstrap bias estimate
# (bootstrap mean minus the statistic), scaled by the bootstrap standard
# deviation. Each row of Ts_bs holds the draws for one dataset.
bs_bias = Ts_bs.mean(axis=1) - Ts
bs_sd = Ts_bs.std(axis=1)
Zs_bs = (Ts - bs_bias) / bs_sd
# Standard normal draws, one per repetition, for comparison.
Zs_gaussian = np.random.randn(reps)

In [None]:
# Common bin edges (12 bins) spanning both the mosaic statistics and their
# null draws, so the two histograms share the same bins.
mondmin = min(arr.min() for arr in (Ts_mosaic, Ts_mosaic_null))
mondmax = max(arr.max() for arr in (Ts_mosaic, Ts_mosaic_null))
mondbins = np.linspace(mondmin, mondmax, num=13)

In [None]:
# Three panels: (a) OLS statistic vs. naive permutation draws,
# (b) bootstrap z-statistics vs. N(0,1) draws, (c) mosaic statistic vs. its
# permutation draws.
fig, axes = plt.subplots(1, 3, figsize=(12, 4.1))
# Raw string: '\h' and '\e' are not valid escape sequences in a normal string.
olslab = r'OLS statistic, S($\hat\epsilon^{OLS}$)'

# The loop variable is named `values` (not `stats`) so the imported
# `scipy.stats` module is not overwritten.
for axnum, color, values, label in zip(
    [1, 1, 0, 0, 2, 2],
    ['blue', 'green', 'blue', 'dimgrey', 'orangered', 'cornflowerblue'],
    [Zs_gaussian, Zs_bs, Ts, Ts_naive, Ts_mosaic_null, Ts_mosaic],
    [
        "N(0,1)", 'Bootstrap z-statistics', olslab, 'Naive perm. test',
        'Mosaic permutations', r'Mosaic statistic, S($\hat\epsilon$)'
    ],
):
    ax = axes[axnum]
    sns.histplot(
        values,
        color=color,
        alpha=0.5,
        ax=ax,
        label=label,
        linewidth=0.2,
        # Both mosaic histograms use the shared bin edges; the others use 12 bins.
        bins=mondbins if 'Mosaic' in label else 12,
    )
    # Only the left-most panel keeps its y-axis label.
    if axnum != 0:
        ax.set(ylabel='')
    ax.set_ylim(0, 300)

for axchar, ax in zip(['a', 'b', 'c'], axes):
    ax.set(title=f"({axchar})")
    ax.legend()

#plt.savefig("plots/naive_methods.png", dpi=500, bbox_inches='tight')
plt.show()

## 2. Power comparisons with OLS

In [None]:
# Load every per-run CSV written by the simulation job `jobid`.
jobid = 44876868
# glob returns matches in arbitrary order; sort so the row order is reproducible.
fnames = sorted(glob.glob(f"sim_data/main_sims/*/*/*{jobid}*.csv"))
if not fnames:
    # Without this check pd.concat fails with an error that does not mention
    # the missing files.
    raise FileNotFoundError(
        f"No simulation outputs found for jobid={jobid} under sim_data/main_sims/."
    )
df = pd.concat([pd.read_csv(fname) for fname in fnames], axis='index')
# Nominal level used for thresholds and rejections below.
alpha = 0.05

In [None]:
## OLS thresholds
# OLS-based methods are calibrated empirically: within each configuration the
# rejection threshold is the (1 - alpha) quantile of T among rows with rho == 0.
is_ols = df['method'].str.contains("OLS")
ols_sub = df.loc[is_ols].copy()
group_vals = ['sparsity', 'test_stat', 'test_stat_index', 'method']
thresh = (
    ols_sub.loc[ols_sub['rho'] == 0]
    .groupby(group_vals)['T']
    .quantile(1 - alpha)
    .reset_index()
    .rename(columns={"T": "threshold"})
)
ols_sub = pd.merge(
    ols_sub, thresh, on=group_vals, how='left'
)
# Cast to float so 'disc' has the same dtype as in mpt_sub; mixing bool and
# float columns in the concat below would produce an object-dtype column.
ols_sub['disc'] = (ols_sub['T'] > ols_sub['threshold']).astype(float)
## Non-OLS methods reject when the reported p-value is at most alpha
mpt_sub = df.loc[~is_ols].copy()
mpt_sub['disc'] = (mpt_sub['pval'] <= alpha).astype(float)
## final df
fdf = pd.concat([mpt_sub, ols_sub], axis='index').drop("threshold", axis='columns')
## aggregate statistics: mean rejection rate and standard error per configuration and rho
agg = calc_mean_sem(
    fdf,
    group_vals=group_vals+['rho'],
    meas=['disc']
)

In [None]:
# MPT rows that use the adaptive test-statistic index.
mpt_adaptive = agg.loc[
    (agg['method'] == 'MPT')
    & (agg['test_stat_index'] == 'adaptive')
]
# A-priori maximum: among the non-adaptive indices, keep for every remaining
# configuration the index with the highest mean rejection rate.
nonadapt = agg.loc[agg['test_stat_index'] != 'adaptive']
# Ordered list (rather than a set difference) so the grouping keys, and hence
# the row order of the result, are the same on every run.
oracle_keys = [col for col in group_vals + ['rho'] if col != 'test_stat_index']
ids = nonadapt.groupby(oracle_keys)['disc_mean'].idxmax().values
# Explicit copy: the relabelling below must not write into a slice of `agg`.
oracle_index_stats = nonadapt.loc[ids].copy()
oracle_index_stats['test_stat_index'] = 'oracle'
## to plot
df4plot = pd.concat([oracle_index_stats, mpt_adaptive], axis='index')

In [None]:
# Legend labels keyed by the concatenation of 'method' and 'test_stat_index';
# combinations not listed here map to NaN.
method_labels = {
    "MPTadaptive": "MPT (adaptive)",
    "MPToracle": "MPT (oracle)",
    "OLS oracleoracle": "OLS (double oracle)",
}
method_key = df4plot['method'] + df4plot['test_stat_index']
df4plot['Method'] = method_key.map(method_labels)

In [None]:
# Power curves for the 'quant_corr' test statistic, one facet per sparsity.
meas = 'disc'
quant_corr_results = df4plot.loc[df4plot['test_stat'] == 'quant_corr']
curve_aes = aes(x='rho', y=f'{meas}_mean', color='Method')
band_aes = aes(ymin=f"{meas}_ymin", ymax=f"{meas}_ymax")
g = (
    ggplot(quant_corr_results, curve_aes)
    + geom_point()
    + geom_line()
    + geom_errorbar(band_aes, width=0.01)
    + facet_wrap("~sparsity", labeller=lambda x: rf"$s_0$={x}")
    + theme_bw()
    + theme(figure_size=(8,3))
    # Dotted reference line at the nominal level alpha.
    + geom_hline(yintercept=alpha, color='black', linetype='dotted')
    + scale_color_manual(['blue', 'red', 'black'])
    + labs(x=r'Signal size ($\rho$)', y='Power')
)
#g.save("final_plots/main_sim_results.png", dpi=500)
print(g)