This notebook contains Python code for the convergence test under the linear-model setting.

In [1]:
import sys
sys.path.append("../mypkg")

In [2]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from scipy.stats import pearsonr
from numbers import Number
import multiprocessing as mp

from easydict import EasyDict as edict
from tqdm import trange
from scipy.io import loadmat
from pprint import pprint

# Figure defaults: ggplot look, and saved figures are cropped to their content.
plt.style.use("ggplot")
plt.rcParams.update({"savefig.bbox": "tight"})

In [3]:
import importlib
import hdf_utils.data_gen
importlib.reload(hdf_utils.data_gen)

<module 'hdf_utils.data_gen' from '/data/rajlab1/user_data/jin/MyResearch/HDF_infer/notebooks/../mypkg/hdf_utils/data_gen.py'>

In [4]:
from constants import DATA_ROOT, RES_ROOT, FIG_ROOT

from hdf_utils.data_gen import gen_covs, gen_simu_psd
from hdf_utils.fns import fn1, fn2, fn3, fn4, fn5, zero_fn
from utils.matrix import col_vec_fn, col_vec2mat_fn, conju_grad, svd_inverse
from utils.functions import logit_fn
from utils.misc import save_pkl
from splines import obt_bsp_basis_Rfn
from projection import euclidean_proj_l1ball
from optimization.one_step_opt import OneStepOpt

from penalties.scad_pen import SCAD
from penalties.base_pen import PenaltyBase
from models.logistic_model import LogisticModel
from models.linear_model import LinearModel


In [5]:
# Make float64 the default floating-point dtype for newly created tensors.
# `torch.set_default_tensor_type` is deprecated in recent PyTorch releases;
# `torch.set_default_dtype` is the supported way to request double precision.
torch.set_default_dtype(torch.float64)

## Param and fns

### Params

In [6]:
# Frequency grid: 40 evenly spaced values running from 1 to 40 (inclusive).
ind_freq = np.linspace(start=1, stop=40, num=40)

In [17]:
# Build the global simulation configuration `paras` used by every later cell.
# NOTE: `run_fn` below overwrites `paras.bsp`, `paras.Gam_est` and `paras.Rmin`
# for the N it is called with, so the values set here are only the initial ones.
np.random.seed(0)
paras = edict()
paras.model = "linear"
paras.num_rep = 20 # num of repetitions; the parallel run uses seeds range(num_rep)
paras.n = 1000 # num of data obs to be generated
paras.ns = [100, 300, 900, 2700, 8100, 24300] # presumably candidate sample sizes; unused in the visible cells -- TODO confirm

paras.npts = 40 # num of pts to evaluate X(s)
paras.d = 10 # num of ROIs
paras.q = 5 # num of other covariates
paras.sigma2 = 1 # variance of the error
paras.sel_idx = np.arange(2, paras.d) # M^c set
paras.stop_cv = 5e-4 # stop cv for convergence
paras.max_iter = 2000 # upper bound on the number of optimisation iterations in run_fn
paras.can_lams = [1e-1, 3e-1, 1e0, 3e0, 9e0, 3e1] # candidate penalty levels (outer loop of the parallel run)
paras.can_Rfcts = [1,  2] # presumably candidate radius factors; unused in the visible cells -- TODO confirm
paras.can_Ns = [5, 10, 15, 20] # candidate numbers of B-spline basis functions (inner loop of the parallel run)


# B-spline basis evaluated on an equally spaced grid over [0, 1]
paras.bsp = edict()
paras.bsp.ord = 4
paras.bsp.N = int(8*paras.n**(1/paras.bsp.ord/2)) # num of basis for bsp
paras.bsp.x = np.linspace(0, 1, paras.npts)
# N-2 equally spaced knots on [0, 1]; the two end points are dropped to get the
# interior knots and are supplied separately as boundary knots.
paras.bsp.aknots_raw = np.linspace(0, 1, paras.bsp.N-2)
paras.bsp.iknots = paras.bsp.aknots_raw[1:-1]
paras.bsp.bknots = np.array([0, 1])
paras.bsp.basis_mat = obt_bsp_basis_Rfn(paras.bsp.x, 
                                        paras.bsp.iknots, 
                                        paras.bsp.bknots, 
                                        paras.bsp.ord)
# Sanity check: the basis matrix must have exactly N columns.
assert paras.bsp.N == paras.bsp.basis_mat.shape[1]
print(f"The number of B-spline basis is {paras.bsp.N:.0f}.")

# One type entry per covariate, passed to gen_covs (exact meaning of the codes
# is defined there -- TODO confirm), and the ground-truth covariate coefficients.
paras.types_ = ["int", 2, 2, "c", "c"]
paras.alp_GT = np.array([5, 1, -2, 3, -4])

# Ground-truth coefficient functions: three non-zero, then zeros, then two non-zero (d in total).
beta_type_GT = [fn1, fn2, fn3] + [zero_fn]*(paras.d-3-2) + [fn4, fn5]
# Evaluate every function on the grid and transpose, giving one column per function.
paras.beta_GT = np.array([_fn(paras.bsp.x) for _fn in beta_type_GT]).T
# Least-squares coefficients of beta_GT on the basis via the normal equations (B'B)^{-1} B' beta.
paras.Gam_est = (np.linalg.inv(paras.bsp.basis_mat.T 
                               @ paras.bsp.basis_mat) 
                 @ paras.bsp.basis_mat.T 
                 @ paras.beta_GT)

# Twice the sum of (column norms of Gam_est/sqrt(N)) plus the sum of |alp_GT|;
# run_fn multiplies this by Rfct to get the R passed to OneStepOpt.
paras.Rmin = 2*(np.linalg.norm(paras.Gam_est/np.sqrt(paras.bsp.N), axis=0).sum() + np.abs(paras.alp_GT).sum())

#paras.basis_mat = torch.DoubleTensor(paras.bsp.basis_mat) # npts x N

The number of B-spline basis is 18.


In [18]:
# Directory that collects the outputs of this linear-model experiment.
paras.save_dir = RES_ROOT/"linear_test"
# `parents=True` also creates a missing results root, and `exist_ok=True`
# makes the cell safe to re-run without a separate existence check.
paras.save_dir.mkdir(parents=True, exist_ok=True)
# Snapshot the configuration next to the results (is_force presumably allows
# overwriting an existing file -- confirm in utils.misc.save_pkl).
save_pkl(paras.save_dir/f"paras_{paras.n}.pkl", paras, is_force=True)

Save to /data/rajlab1/user_data/jin/MyResearch/HDF_infer/notebooks/../mypkg/../results/linear_test/paras_1000.pkl


In [19]:
# Print the full configuration for the record; the array-valued entries
# (e.g. Gam_est) make this output long.
pprint(paras)

{'Gam_est': matrix([[ 4.16614805e-03,  1.00005563e+01, -1.00000000e+01,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  5.09066801e-18,
         -6.00000072e+00],
        [ 4.16983155e+00,  1.27892804e+01, -9.95555556e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  4.44444444e-01,
         -6.00003083e+00],
        [ 1.23927172e+01,  1.83719557e+01, -9.87555556e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  1.33333333e+00,
         -5.99973900e+00],
        [ 7.68649742e+00,  2.12725504e+01, -9.78044444e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  2.66666667e+00,
         -6.00331056e+00],
        [-7.68006987e+00,  1.68164317e+01, -9.70844444e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+0

### Fns

In [20]:
def gen_simu_data_all(seed, paras):
    """Simulate one data set (X, Y, Z) for the linear model.

    The response is the covariate effect ``Z @ alp_GT`` plus the functional
    effect (``beta_GT`` times the simulated PSD, summed over the second axis
    and averaged over the last one) plus Gaussian noise with variance
    ``paras.sigma2``.

    Args:
        seed: Seed for the global NumPy random generator.
        paras: Configuration; reads ``n``, ``d``, ``q``, ``types_``,
            ``alp_GT``, ``beta_GT`` and ``sigma2``.

    Returns:
        An ``edict`` with torch tensors ``X`` (n x d x npts), ``Y`` and
        ``Z`` (n x q).
    """
    np.random.seed(seed)

    # One covariate type and one true coefficient per covariate.
    num_covs = paras.q
    assert len(paras.types_) == num_covs
    assert len(paras.alp_GT) == num_covs

    # Simulated PSD first, covariates second (order fixes the random stream).
    psd = gen_simu_psd(paras.n, paras.d, 10)
    covs = gen_covs(paras.n, paras.types_)

    # Functional part and covariate part of the linear predictor.
    weighted_psd = paras.beta_GT.T * psd[:, :, :]
    func_effect = np.sum(weighted_psd, axis=1).mean(axis=1)
    cov_effect = covs @ paras.alp_GT
    linear_pred = cov_effect + func_effect

    # The uniform draw is not used, but it is kept so that the random stream
    # (and therefore the noise below) stays exactly as before.
    np.random.rand(paras.n)
    noise = np.random.randn(paras.n)*np.sqrt(paras.sigma2)
    response = linear_pred + noise

    # Hand everything back as torch tensors.
    return edict(
        X=torch.tensor(psd),       # n x d x npts
        Y=torch.tensor(response),
        Z=torch.tensor(covs),      # n x q
    )

## Simu

In [21]:
def run_fn(seed, lam, N, Rfct=2, is_small=False):
    """Simulate one data set and iterate ``OneStepOpt`` until the stop rule is met.

    Side effect: ``paras.bsp``, ``paras.Gam_est`` and ``paras.Rmin`` of the
    module-level ``paras`` are rebuilt for the requested ``N`` and stay
    overwritten after the call.

    Args:
        seed: Seed for NumPy, torch and the simulated data set.
        lam: Penalty level handed to ``SCAD`` as ``lams``.
        N: Number of B-spline basis functions.
        Rfct: Multiplier applied to ``paras.Rmin`` to obtain the ``R`` passed
            to ``OneStepOpt``.
        is_small: If true, ``opt.model`` is set to ``None`` before returning
            (presumably to keep the returned/pickled object small).

    Returns:
        ``(opt, not_converged, seed)`` where ``opt`` is the last
        ``OneStepOpt`` instance and ``not_converged`` is ``True`` when the
        stop criterion was never met within ``paras.max_iter`` iterations.

    Raises:
        ValueError: If ``paras.max_iter`` is smaller than 1.
    """
    if paras.max_iter < 1:
        raise ValueError(f"paras.max_iter must be at least 1, got {paras.max_iter}")

    print((seed, lam, N, Rfct), "\n")
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Rebuild the B-spline basis for this N (overwrites the global config).
    paras.bsp = edict()
    paras.bsp.ord = 4
    paras.bsp.N = N
    paras.bsp.x = np.linspace(0, 1, paras.npts)
    paras.bsp.aknots_raw = np.linspace(0, 1, paras.bsp.N-2)
    paras.bsp.iknots = paras.bsp.aknots_raw[1:-1]
    paras.bsp.bknots = np.array([0, 1])
    paras.bsp.basis_mat = obt_bsp_basis_Rfn(paras.bsp.x,
                                            paras.bsp.iknots,
                                            paras.bsp.bknots,
                                            paras.bsp.ord)
    assert paras.bsp.N == paras.bsp.basis_mat.shape[1]

    # Least-squares coefficients of beta_GT on the new basis, and the matching radius.
    paras.Gam_est = (np.linalg.inv(paras.bsp.basis_mat.T
                                   @ paras.bsp.basis_mat)
                     @ paras.bsp.basis_mat.T
                     @ paras.beta_GT)

    paras.Rmin = 2*(np.linalg.norm(paras.Gam_est/np.sqrt(paras.bsp.N), axis=0).sum() + np.abs(paras.alp_GT).sum())

    # Start from the ground truth. The zero-scaled noise leaves the values
    # unchanged; the draws are kept so the torch random stream (and hence
    # rhok_init) is the same as before.
    alp_init = torch.tensor(paras.alp_GT) + torch.randn(paras.q)*0
    Gam_init = torch.tensor(paras.Gam_est) + torch.randn(paras.bsp.N, paras.d)*0
    theta_init = torch.cat([alp_init, col_vec_fn(Gam_init)/np.sqrt(paras.bsp.N)])
    rhok_init = torch.randn(paras.d*paras.bsp.N)
    # Previous iterates; 0 makes the first relative change equal to 1.
    last_Gamk = 0
    last_thetak = 0

    cur_data = gen_simu_data_all(seed, paras)
    model = LinearModel(Y=cur_data.Y, X=cur_data.X, Z=cur_data.Z,
                        basis_mat=torch.DoubleTensor(paras.bsp.basis_mat),
                        sigma2=paras.sigma2)
    # 3e0
    pen = SCAD(lams=lam, a=3.7, sel_idx=paras.sel_idx)

    # Track convergence explicitly: comparing the loop index with max_iter-1
    # mislabels a run that meets the stop rule exactly on the last iteration.
    is_converged = False
    for ix in range(paras.max_iter):
        opt = OneStepOpt(Gamk=Gam_init,
                         rhok=rhok_init,
                         theta_init=theta_init,
                         alpha=0.9, beta=1, model=model, penalty=pen,
                         q=paras.q, NR_eps=1e-5, NR_maxit=100, R=paras.Rmin*Rfct)
        opt()
        Gam_init = opt.Gamk
        rhok_init = opt.rhok
        theta_init = opt.thetak

        # Relative change of Gam and theta between consecutive iterations.
        Gam_diff = opt.Gamk - last_Gamk
        Gam_diff_norm = torch.norm(Gam_diff)/torch.norm(opt.Gamk)

        theta_diff = opt.thetak - last_thetak
        theta_diff_norm = torch.norm(theta_diff)/torch.norm(opt.thetak)

        # Relative gap between Gam and the Gam part of theta (rescaled by sqrt(N)).
        Gam_theta_diff = opt.Gamk - col_vec2mat_fn(opt.thetak[paras.q:], nrow=paras.bsp.N)*np.sqrt(paras.bsp.N)
        Gam_theta_diff_norm = torch.norm(Gam_theta_diff)/torch.norm(opt.Gamk)

        # Stop once the largest of the three criteria is below the threshold.
        stop_v = np.max([Gam_diff_norm.item(), theta_diff_norm.item(), Gam_theta_diff_norm.item()])
        if stop_v < paras.stop_cv:
            is_converged = True
            break

        last_Gamk = opt.Gamk
        last_thetak = opt.thetak

    if not is_converged:
        print(f"The optimization under seed {seed} may not converge with stop value {stop_v:.3f}")
    if is_small:
        opt.model = None
    return opt, not is_converged, seed

In [23]:
# Quick single-run check before launching the parallel sweep.
run_fn(seed=0, lam=2, N=5, Rfct=2, is_small=True)

(0, 2, 5, 2) 



(<optimization.one_step_opt.OneStepOpt at 0x7f5d8004fe80>, False, 0)

In [None]:
# Run every (lam, N) combination over all seeds in parallel and save one
# result file per combination.
num_core = 20
if __name__ == "__main__":
    # Radius factor passed to run_fn. The file name used to hard-code
    # "Rfct-20" for this value of 2; deriving the tag as Rfct*10 reproduces
    # the same names while keeping the label tied to the value actually used.
    cur_Rfct = 2
    for cur_lam in paras.can_lams:
        for cur_N in paras.can_Ns:
            fil_name = (f"result_lam1-{cur_lam*100:.0f}_Rfct-{cur_Rfct*10:.0f}"
                        f"_N-{cur_N:.0f}_n-{paras.n:.0f}.pkl")
            with mp.Pool(num_core) as pool:
                # One asynchronous task per seed; is_small=True drops the
                # model from each returned optimiser.
                res_proc = [
                    pool.apply_async(run_fn, (seed, cur_lam, cur_N, cur_Rfct, True))
                    for seed in range(paras.num_rep)
                ]
                opt_results = [res.get() for res in res_proc]
                # All results are in: let the workers exit cleanly before the
                # context manager tears the pool down.
                pool.close()
                pool.join()
            save_pkl(paras.save_dir/fil_name, opt_results)


(0, 0.1, 5, 2)(4, 0.1, 5, 2)(3, 0.1, 5, 2)(2, 0.1, 5, 2)(1, 0.1, 5, 2)(10, 0.1, 5, 2)(5, 0.1, 5, 2)(9, 0.1, 5, 2)(8, 0.1, 5, 2)(11, 0.1, 5, 2)(12, 0.1, 5, 2)(14, 0.1, 5, 2)(13, 0.1, 5, 2) (17, 0.1, 5, 2)  (16, 0.1, 5, 2)(15, 0.1, 5, 2)  (18, 0.1, 5, 2)  (19, 0.1, 5, 2)  
   
  


  

  










(7, 0.1, 5, 2)








(6, 0.1, 5, 2)

 



 






Save to /data/rajlab1/user_data/jin/MyResearch/HDF_infer/notebooks/../mypkg/../results/linear_test/result_lam1-10_Rfct-20_N-5_n-1000.pkl
(5, 0.1, 10, 2)(0, 0.1, 10, 2)(1, 0.1, 10, 2)(3, 0.1, 10, 2)(2, 0.1, 10, 2)(7, 0.1, 10, 2) (8, 0.1, 10, 2)(12, 0.1, 10, 2)(9, 0.1, 10, 2)(11, 0.1, 10, 2) (6, 0.1, 10, 2)(4, 0.1, 10, 2)(10, 0.1, 10, 2) (14, 0.1, 10, 2) 
(13, 0.1, 10, 2) (17, 0.1, 10, 2)(15, 0.1, 10, 2)(18, 0.1, 10, 2)(16, 0.1, 10, 2)   
(19, 0.1, 10, 2)    
  


  
   




 



























Save to /data/rajlab1/user_data/jin/MyResearch/HDF_infer/notebooks/../mypkg/../results/linear_test/result_lam1-10_Rfct-20_N-10_n-1000.pkl
(5, 0