This notebook contains Python code to check the hypothesis-testing procedure.

In [122]:
# Notebook-level switches and inputs.
RUN_PYTHON_SCRIPT = False

# Subject indices to drop as outliers, given within each group's own PSD array.
# To keep every subject use: OUTLIER_IDXS = {"AD": [], "ctrl": []}
OUTLIER_IDXS = {
    "AD": [49],
    "ctrl": [14, 19, 30, 38],
}

# Sub-folder (under RES_ROOT) where fitted results are stored.
SAVED_FOLDER = "real_data_nlinear_test"

# PSD pickle files: the AD group first, the control group second.
DATA = [
    "AD88_PSD89_all_nosm.pkl",
    "Ctrl92_PSD89_all_nosm.pkl",
]

['AD88_PSD89_all_nosm.pkl', 'Ctrl92_PSD89_all_nosm.pkl']

In [123]:
import sys
sys.path.append("../../mypkg")


In [124]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr
from numbers import Number
import itertools

from easydict import EasyDict as edict
from tqdm import trange, tqdm
from scipy.io import loadmat
from pprint import pprint
from IPython.display import display
from joblib import Parallel, delayed

In [125]:
# This will reload all imports as soon as the code changes
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [126]:
from constants import DATA_ROOT, RES_ROOT, FIG_ROOT, MIDRES_ROOT
from hdf_utils.data_gen import gen_simu_sinica_dataset
from hdf_utils.SIS import SIS_GLIM
from utils.matrix import col_vec_fn, col_vec2mat_fn, conju_grad, svd_inverse, cholesky_inv
from utils.functions import logit_fn
from utils.misc import save_pkl, load_pkl
from splines import obt_bsp_obasis_Rfn, obt_bsp_basis_Rfn_wrapper
from projection import euclidean_proj_l1ball
from optimization.opt import HDFOpt
from hdf_utils.fns_sinica import  fourier_basis_fn

from joblib import Parallel, delayed

In [127]:
# Apply the matplotlib style sheet stored at FIG_ROOT/base.mplstyle.
plt.style.use(FIG_ROOT/"base.mplstyle")

In [128]:
# Make float64 the default floating-point dtype for newly created tensors.
# torch.set_default_tensor_type is deprecated; torch.set_default_dtype is the
# supported call and is the one _run_fn already uses.
torch.set_default_dtype(torch.float64)
def_dtype = torch.get_default_dtype()

torch.float64

# Load and prepare the data

In [129]:
# Read the two PSD pickles (AD first, then control) and the baseline table.
data_root = DATA_ROOT/"AD_vs_Ctrl_PSD/"
AD_PSD = load_pkl(data_root/DATA[0])
ctrl_PSD = load_pkl(data_root/DATA[1])
baseline = pd.read_csv(data_root/"AllDataBaselineOrdered_r_ncpt.csv")

# Binary encodings: "female" -> 0, anything else -> 1; "AD" -> 1, anything else -> 0.
baseline["Gender_binary"] = (baseline["Gender"] != "female").astype("int64")
baseline["Grp_binary"] = (baseline["Grp"] == "AD").astype("int64")

Load file /data/rajlab1/user_data/jin/MyResearch/HDF_infer/notebooks/real_data/../../mypkg/../data/AD_vs_Ctrl_PSD/AD88_PSD89_all_nosm.pkl
Load file /data/rajlab1/user_data/jin/MyResearch/HDF_infer/notebooks/real_data/../../mypkg/../data/AD_vs_Ctrl_PSD/Ctrl92_PSD89_all_nosm.pkl


In [141]:
# Seed and standard deviation of the Gaussian jitter added to X further down.
# A fixed seed makes the jittered data (and everything fitted on it)
# reproducible across kernel restarts.
NOISE_SEED = 0
NOISE_SD = 0.1

# Outlier indices to remove: ctrl indices are offset by the number of AD
# subjects because the AD block comes first in the concatenated arrays.
outlier_idxs = np.concatenate([OUTLIER_IDXS["AD"], len(AD_PSD.PSDs)+np.array(OUTLIER_IDXS["ctrl"])])
outlier_idxs = outlier_idxs.astype(int)

# log10 PSD, standardized along the last axis
# (the factor 10 of a dB conversion would cancel in the standardization)
raw_X = np.concatenate([AD_PSD.PSDs, ctrl_PSD.PSDs])  # n x d x npts
X_dB = np.log10(raw_X)
X = (X_dB - X_dB.mean(axis=-1, keepdims=True))/X_dB.std(axis=-1, keepdims=True)

Y = np.array(baseline["MMSE"])[:X.shape[0]]
# if logi
#Yb = np.array(baseline["Grp_binary"])[:X.shape[0]];

sel_cov = ["MEG_Age", "Gender_binary"]
Z_raw = np.array(baseline[sel_cov])[:X.shape[0]]

grp_idxs = np.array(baseline["Grp"])[:X.shape[0]]

# Slicing silently yields fewer rows if the baseline table is shorter than X,
# so fail early with a clear message instead of a later indexing error.
assert Y.shape[0] == X.shape[0], "baseline table has fewer rows than PSD subjects"

# remove outliers
X = np.delete(X, outlier_idxs, axis=0)
Y = np.delete(Y, outlier_idxs, axis=0)
Z_raw = np.delete(Z_raw, outlier_idxs, axis=0)
grp_idxs = np.delete(grp_idxs, outlier_idxs, axis=0)

# remove subjects whose response or covariates contain NaN
keep_idx = ~np.bitwise_or(np.isnan(Y), np.isnan(Z_raw.sum(axis=1)))
X = X[keep_idx]
Y = Y[keep_idx]
Z_raw = Z_raw[keep_idx]
grp_idxs = grp_idxs[keep_idx]

Z = np.concatenate([np.ones((Z_raw.shape[0], 1)), Z_raw], axis=1)  # add intercept

# All three arrays must describe the same subjects after filtering
assert X.shape[0] == Y.shape[0] == Z.shape[0]
print(X.shape, Y.shape, Z.shape)

# Tensors handed to the model; X gets seeded Gaussian jitter of sd NOISE_SD
rng = np.random.default_rng(NOISE_SEED)
all_data = edict()
all_data.X = torch.tensor(X + rng.standard_normal(X.shape)*NOISE_SD)
all_data.Y = torch.tensor(Y)
all_data.Z = torch.tensor(Z)

freqs = AD_PSD.freqs

(152, 68, 89) (152,) (152, 3)


In [142]:
# Atlas: load the entries of the DK68 ROI-order text file as strings.
rois = np.loadtxt(DATA_ROOT/"dk68_utils/ROI_order_DK68.txt", dtype=str);

# Params and functions

## Params

In [143]:
from easydict import EasyDict as edict
from hdf_utils.fns_sinica import coef_fn, fourier_basis_fn
from copy import deepcopy
from scenarios.base_params import get_base_params

# Start from the base parameter set for the "linear" model.
base_params = get_base_params("linear")

# Data dimensions, read off the predictor tensor, plus the frequency grid.
base_params.data_params = edict({
    "d": all_data.X.shape[1],
    "n": all_data.X.shape[0],
    "npts": all_data.X.shape[-1],
    "freqs": AD_PSD.freqs,
})

# Candidate values of N, SIS configuration, optimizer beta, candidate lambdas.
base_params.can_Ns = [4, 6, 8, 10, 12, 14]
base_params.SIS_params = edict({
    "SIS_pen": 0.02,
    "SIS_basis_N": 8,
    "SIS_ws": "simpson",
})
base_params.opt_params.beta = 10
base_params.can_lams = [0.6, 0.8, 1, 1.2, 1.4, 1.6, 2.0, 4.0]

# Working copy of the parameters with the run-specific entries merged in.
setting = edict(deepcopy(base_params))
add_params = edict({"setting": "real_data_linear", "SIS_ratio": 1})
setting.update(add_params)

In [144]:
# Folder for the fitted results. Create it together with any missing parents;
# exist_ok avoids the check-then-create race of a separate exists() test.
save_dir = RES_ROOT/SAVED_FOLDER
save_dir.mkdir(parents=True, exist_ok=True)

In [145]:
# Edges of each named band; `pts` lists the upper edge of every band in order.
bands_cut = edict({
    "delta": [1, 4],
    "theta": [4, 8],
    "alpha": [8, 12],
    "beta": [12, 35],
    "pts": [4, 8, 12, 35],
})

# For every cut point, the index of the closest entry in `freqs`.
cut_pts = np.argmin(np.abs(freqs.reshape(-1, 1) - bands_cut.pts), axis=0)

array([ 6, 14, 22, 68])

In [146]:
def _run_fn(roi_idx, lam, N, setting, is_save=False, is_cv=False, verbose=2):
    """Fit the HDF model with one ROI left out of `sel_idx`, or load a saved fit.

    A result file under `save_dir` is reused whenever it exists, regardless of
    `is_save` and `is_cv`; otherwise a new HDFOpt object is fitted on `all_data`.

    Args:
        roi_idx: Index removed from arange(setting.data_params.d) to form sel_idx.
        lam: Value stored as `lam` and forwarded to HDFOpt.
        N: Value stored as `N` and forwarded to HDFOpt.
        setting: Parameter container; model_type, SIS_ratio, the *_params
            groups and (when is_cv is True) num_cv_fold are read from it.
        is_save: If True, save a freshly fitted object to the result file.
        is_cv: If True, call get_cv_est on a freshly fitted object.
        verbose: Forwarded to HDFOpt; load_pkl is verbose when verbose >= 2.

    Returns:
        The fitted HDFOpt object, or whatever load_pkl returns for the saved file.
    """
    torch.set_default_dtype(torch.double)

    # Per-call copy of the settings with the run-specific entries filled in.
    run_cfg = edict(setting.copy())
    run_cfg.lam = lam
    run_cfg.N = N
    all_idxs = np.arange(setting.data_params.d)
    run_cfg.sel_idx = np.delete(all_idxs, [roi_idx])

    fit_path = save_dir/f"roi_{roi_idx:.0f}-lam_{lam*1000:.0f}-N_{N:.0f}_fit.pkl"

    # Reuse a previously saved fit when one is on disk.
    if fit_path.exists():
        return load_pkl(fit_path, verbose>=2)

    hdf_kwargs = dict(
        lam=run_cfg.lam,
        sel_idx=run_cfg.sel_idx,
        model_type=run_cfg.model_type,
        verbose=verbose,
        SIS_ratio=run_cfg.SIS_ratio,
        N=run_cfg.N,
        is_std_data=True,
        cov_types=None,
        inits=None,
        model_params=run_cfg.model_params,
        SIS_params=run_cfg.SIS_params,
        opt_params=run_cfg.opt_params,
        bsp_params=run_cfg.bsp_params,
        pen_params=run_cfg.pen_params,
    )
    hdf_fit = HDFOpt(**hdf_kwargs)
    hdf_fit.add_data(all_data.X, all_data.Y, all_data.Z)
    hdf_fit.fit()

    if is_cv:
        hdf_fit.get_cv_est(run_cfg.num_cv_fold)
    if is_save:
        hdf_fit.save(fit_path, is_compact=False, is_force=True)

    return hdf_fit


In [147]:
# Override two entries of the shared `setting` object (this mutates it in place),
# then fit with ROI 0 removed from sel_idx, lam=1 and N=10.
# is_save keeps its default (False), so a new fit is not written to disk;
# if a matching result file already exists, _run_fn loads that instead.
setting.model_params.ws = "sim"
setting.bsp_params.is_orth_basis = True
res = _run_fn(0, lam=1, N=10, setting=setting);

2024-03-05 21:49:06,856 - optimization.opt - INFO - opt params is {'stop_cv': 0.0005, 'max_iter': 2000, 'one_step_verbose': 0, 'alpha': 0.9, 'beta': 10, 'R': 200000.0, 'linear_theta_update': 'cholesky_inv'}.
2024-03-05 21:49:06,857 - optimization.opt - INFO - SIS params is {'SIS_pen': 0.02, 'SIS_basis_N': 8, 'SIS_basis_ord': 4, 'SIS_ratio': 1, 'SIS_ws': 'simpson'}.
2024-03-05 21:49:06,858 - optimization.opt - INFO - model params is {'norminal_sigma2': 1, 'ws': 'sim'}.
2024-03-05 21:49:06,858 - optimization.opt - INFO - penalty params is {'a': 3.7, 'lam': 1}.
2024-03-05 21:49:06,859 - optimization.opt - INFO - bspline params is {'basis_ord': 4, 'is_orth_basis': True, 'N': 10}.
2024-03-05 21:49:06,860 - optimization.opt - INFO - As cov_types is not provided, inferring the continuous covariates.
Main Loop:  48%|██████████████████████████████████████████▏                                            | 969/2000 [00:02<00:02, 482.00it/s, error=0.000505, GamL0=1, CV=0.0005]


In [148]:
# Call the private preparation step on the fitted object before testing.
# NOTE(review): presumably this builds the quantities read later from
# res.hypo_utils — confirm in HDFOpt.
res._prepare_hypotest()

In [149]:
# Run the hypothesis test with a 1x1 identity matrix as Cmat; the returned
# dict holds the p-value ('pval') and the test statistic ('T_v').
res.hypo_test(Cmat=np.eye(1))

2024-03-05 21:49:10,370 - optimization.opt - INFO - hypo params is {'svdinv_eps_Q': 1e-07, 'svdinv_eps_Psi': 1e-07, 'Cmat': array([[1.]])}.


{'pval': 3.9207715768051176e-19, 'T_v': tensor(110.6635)}

In [150]:
# Singular values of Q_mat_part (to inspect its conditioning).
# torch.svd is deprecated; torch.linalg.svdvals returns the same singular
# values in descending order without computing U and V.
torch.linalg.svdvals(res.hypo_utils.Q_mat_part)

tensor([1.1900e+00, 1.0541e+00, 2.3633e-01, 1.9493e-01, 1.4809e-01, 3.3426e-02,
        1.4293e-02, 7.8611e-03, 7.2049e-03, 3.9487e-03, 2.7440e-03, 2.2244e-03,
        1.0220e-03])

In [151]:
# Display the `beta` attribute of the object returned by get_covmat().
res.get_covmat().beta

tensor([[12.9630,  3.1737, 11.1189,  8.5950,  8.7949,  6.7072,  8.9414,  2.4872,
          3.1607, -2.0657],
        [ 3.1737, 11.9497,  7.6595, 10.8575,  9.1821, 10.0050,  9.6870,  6.5429,
          6.9639,  7.7386],
        [11.1189,  7.6595, 15.1519, 11.7482, 11.7627, 11.7995, 13.1353,  5.6330,
          7.1011,  5.2273],
        [ 8.5950, 10.8575, 11.7482, 16.2650,  9.7389, 12.4699, 16.3653,  7.1542,
         10.7170,  7.2861],
        [ 8.7949,  9.1821, 11.7627,  9.7389, 16.6788,  8.6524, 10.9819, 13.8106,
          8.1629,  3.7891],
        [ 6.7072, 10.0050, 11.7995, 12.4699,  8.6524, 16.7455, 10.7615,  7.6003,
          8.2856, 14.5641],
        [ 8.9414,  9.6870, 13.1353, 16.3653, 10.9819, 10.7615, 25.0305,  5.1681,
         14.6212, 11.8124],
        [ 2.4872,  6.5429,  5.6330,  7.1542, 13.8106,  7.6003,  5.1681, 33.8418,
          6.9636,  3.7851],
        [ 3.1607,  6.9639,  7.1011, 10.7170,  8.1629,  8.2856, 14.6212,  6.9636,
         22.2621,  0.9446],
        [-2.0657,  