In [None]:
input_n_test = 500

import math
import torch
import torch.distributions as TD
from torch.utils.data import Dataset, DataLoader
from zmq import device
import torch.optim as optim
import numpy as np
from datetime import datetime
import functools
from scipy.linalg import toeplitz
import xgboost as xgb
from sklearn.linear_model import LassoCV
from tqdm import tqdm

# Move model on GPU if available
enable_cuda = True
device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')

def get_xzy_randn_nl(n_points, ground_truth='H0', rho=0.3, p1 = 25, p2 = 25, device='cuda:0', a = 0.0, **ignored):
    p = p1 + p2
    cov_vec = rho ** np.arange(p)
    cov_mx = torch.FloatTensor(toeplitz(cov_vec)).to(device)
    X_all_generator = TD.MultivariateNormal(torch.zeros(p).to(device), cov_mx)
    X_all = X_all_generator.sample((n_points,))
    Z = X_all[:, :p1]
    X = X_all[:, p1:]

    beta_z = torch.FloatTensor([1] * 2 + [0] * (p1 - 2)).reshape((-1, 1)).to(device)

    if ground_truth == 'H0':

        beta_x = torch.FloatTensor([0] * 2 + [0] * (p2 - 2)).reshape((-1, 1)).to(device)

    elif ground_truth == 'H1_sparse':

        beta_x = torch.FloatTensor([a/math.sqrt(5)] * 5 + [0] * (p2 - 5)).reshape((-1, 1)).to(device)

    elif ground_truth == 'H1_dense':

        beta_x = torch.FloatTensor([a/math.sqrt(12)] * 12 + [0] * (p2 - 12)).reshape((-1, 1)).to(device)

    else:
        raise NotImplementedError(f'{ground_truth} has to be H0, H1_sparse or H1_dense')

    epsilon = torch.randn(n_points, 1, device=device)*0.5
    Y = torch.matmul(Z, beta_z) + torch.matmul(X, beta_x) ** 2 + epsilon

    return X, Y, Z

def get_xzy_randn_nl_fix(Z, n_points, ground_truth='H0', rho=0.3, p1 = 25, p2 = 25, device='cuda:0', a = 0.0, **ignored):
    p = p1 + p2
    cov_vec = rho ** np.arange(p)
    cov_mx = torch.FloatTensor(toeplitz(cov_vec)).to(device)

    Cov_11 = cov_mx[:p1,:p1]
    Cov_12 = cov_mx[:p1,p1:]
    Cov_21 = cov_mx[p1:,:p1]
    Cov_22 = cov_mx[p1:,p1:]


    if ground_truth == 'H0':

        beta_x = torch.FloatTensor([0] * 2 + [0] * (p2 - 2)).reshape((-1, 1)).to(device)

    elif ground_truth == 'H1_sparse':

        beta_x = torch.FloatTensor([a/math.sqrt(5)] * 5 + [0] * (p2 - 5)).reshape((-1, 1)).to(device)

    elif ground_truth == 'H1_dense':

        beta_x = torch.FloatTensor([a/math.sqrt(12)] * 12 + [0] * (p2 - 12)).reshape((-1, 1)).to(device)

    else:
        raise NotImplementedError(f'{ground_truth} has to be H0, H1_sparse or H1_dense')

    beta_z = torch.FloatTensor([1] * 2 + [0] * (p1 - 2)).reshape((-1, 1)).to(device)

    Cov_22_inv = torch.inverse(Cov_22)
    Condi_Cov = Cov_11 - torch.matmul(torch.matmul(Cov_12, Cov_22_inv), Cov_21)
    Condi_Mean_vec = torch.matmul(torch.matmul(Cov_12, Cov_22_inv), Z.T)
    temp_generator = TD.MultivariateNormal(Condi_Mean_vec, Condi_Cov)
    X = temp_generator.sample((n_points,))

    first_term = torch.matmul(beta_z.T, Z.reshape((-1, 1)))

    var_term = Cov_11 -  torch.matmul(torch.matmul(Cov_12, Cov_22_inv), Cov_21)
    sec_term = torch.matmul(torch.matmul(beta_x.T, var_term), beta_x)

    condi_exp_term = torch.matmul(torch.matmul(Cov_12, Cov_22_inv), Z.reshape((-1, 1)))
    third_term_temp = torch.matmul(condi_exp_term, condi_exp_term.T)
    third_term = torch.matmul(torch.matmul(beta_x.T, third_term_temp), beta_x)

    Y = first_term + sec_term + third_term

    return X, Y.reshape(-1)



def get_p_value_stat(boot_num, M, n, gen_x_all_torch, gen_y_all_torch, x_torch, y_torch, z_torch, boor_rv_type="gaussian"):

    d_y = y_torch.shape[1]
    d_x = x_torch.shape[1]

    w_mx = torch.linalg.vector_norm(z_torch.repeat(n,1,1) - torch.swapaxes(z_torch.repeat(n,1,1), 0, 1), ord = 1, dim = 2)
    sigma_w = torch.median(w_mx).item()

    u_mx = torch.linalg.vector_norm(x_torch.repeat(n,1,1) - torch.swapaxes(x_torch.repeat(n,1,1), 0, 1), ord = 1, dim = 2)
    sigma_u = torch.median(u_mx).item()

    # print(sigma_w)
    # print(sigma_u)

    # sigma_w, sigma_u = 1.0, 1.0

    w_mx = torch.linalg.vector_norm(z_torch.repeat(n, 1, 1) - torch.swapaxes(z_torch.repeat(n, 1, 1), 0, 1), ord=1, dim=2)
    w_mx = torch.exp(-w_mx / sigma_w)

    u_mx_1 = torch.exp(-torch.linalg.vector_norm(x_torch.repeat(n, 1, 1) - torch.swapaxes(x_torch.repeat(n, 1, 1), 0, 1), ord=1, dim=2) / sigma_u)
    u_mx_2 = torch.mean(
        torch.exp(-torch.linalg.vector_norm(gen_x_all_torch.repeat(n, 1, 1).reshape(n, n, -1, d_x) - x_torch.repeat(1, n).reshape(n, n, 1, d_x), ord=1, dim=3) / sigma_u), dim=2)
    u_mx_3 = u_mx_2.T

    gen_x_all_torch_rep = gen_x_all_torch.repeat(n, 1, 1).reshape(n, n, -1, d_x)

    u_mx_4 = torch.mean(torch.exp(-torch.linalg.vector_norm(gen_x_all_torch_rep - torch.swapaxes(gen_x_all_torch_rep, 0, 1), ord=1, dim=3) / sigma_u) , dim=2)

    u_mx = u_mx_1 - u_mx_2 - u_mx_3 + u_mx_4
    v_mx_temp = (gen_y_all_torch - y_torch)
    v_mx = torch.matmul(v_mx_temp, v_mx_temp.T)
    FF_mx = u_mx * v_mx * w_mx * (1 - torch.eye(n).to(device))

    stat = 1 / (n - 1) * torch.sum(FF_mx).item()

    boottemp = np.array([])
    if boor_rv_type == "rademacher":
        eboot = torch.sign(torch.randn(n, boot_num)).to(device)
    elif boor_rv_type == "gaussian":
        eboot = torch.randn(n, boot_num).to(device)
    for bb in range(boot_num):
        random_mx = torch.matmul(eboot[:, bb].reshape(-1, 1), eboot[:, bb].reshape(-1, 1).T)
        bootmatrix = FF_mx * random_mx
        stat_boot = 1 / (n - 1) * torch.sum(bootmatrix).item()
        boottemp = np.append(boottemp, stat_boot)
    return stat, boottemp


class DatasetSelect(Dataset):
    def __init__(self, X, Y, Z):
        self.X_real = X
        self.Y_real = Y
        self.Z_real = Z
        self.sample_size = X.shape[0]

    def __len__(self):
        return self.sample_size

    def __getitem__(self, index):
        return self.X_real[index], self.Y_real[index], self.Z_real[index]

# Create a DataLoader for given (X, Y)

class DatasetSelect_GAN(torch.utils.data.Dataset):

  def __init__(self, X, Y, Z, batch_size):
    self.X_real = X
    self.Y_real = Y
    self.Z_real = Z
    self.batch_size = batch_size
    self.sample_size = X.shape[0]

  def __len__(self):
    return self.sample_size

  def __getitem__(self, index):
    return self.X_real[index], self.Y_real[index], self.Z_real[index], self.Z_real[(self.batch_size+index) % self.sample_size]

# Create a DataLoader for given (X, Y)

class DatasetSelect_GAN_ver2(torch.utils.data.Dataset):

  def __init__(self, Y, Z, batch_size):
    self.Y_real = Y
    self.Z_real = Z
    self.batch_size = batch_size
    self.sample_size = Z.shape[0]

  def __len__(self):
    return self.sample_size

  def __getitem__(self, index):
    return self.Y_real[index], self.Z_real[index]

##### Auxilliary functions #####



def mGAN(n=500, z_dim=2, simulation='type1error', x_dims=2, y_dims=2, a_x=0.05, M=500, k=2, boot_num=1000,
     boor_rv_type = "gaussian", noise_dimension = 50, noise_type = "normal", input_var = 1.0/3.0):

    sim_x, sim_y, sim_z = get_xzy_randn_nl(n_points = n, ground_truth = simulation, a = a_x)

    x, y, z = sim_x.to(device), sim_y.to(device), sim_z.to(device)

    test_size = int(n/k)
    stat_all = torch.zeros(k, 1)
    boot_temp_all = torch.zeros(k, boot_num)
    cur_k = 0

    for k_fold in range(k):
        k_fold_start = int(n/k * k_fold)
        k_fold_end = int(n/k * (k_fold+1))
        X_test, Y_test, Z_test = x[k_fold_start:k_fold_end], y[k_fold_start:k_fold_end], z[k_fold_start:k_fold_end]
        X_train, Y_train, Z_train = torch.cat((x[0:k_fold_start], x[k_fold_end:])), torch.cat((y[0:k_fold_start], y[k_fold_end:])), torch.cat((z[0:k_fold_start], z[k_fold_end:]))

        gen_x_all = torch.zeros(test_size, M, x_dims)
        gen_y_all = torch.zeros(test_size, y_dims)
        z_all = torch.zeros(test_size, z_dim)
        x_all = torch.zeros(test_size, x_dims)
        y_all = torch.zeros(test_size, y_dims)

        # Generate fake data
        for i in range(test_size):
            gen_x_all[i,:,:], gen_y_all[i,:] = get_xzy_randn_nl_fix(Z = Z_test[i,:], ground_truth = simulation, n_points = M, a = a_x)

        gen_y_all = gen_y_all.reshape(-1,y_dims).to(device)

        gen_x_all = gen_x_all.detach().to(device)
        gen_y_all = gen_y_all.detach().to(device)
        z_all = Z_test.to(device)
        x_all = X_test.to(device)
        y_all = Y_test.to(device)

        standardise = True

        if standardise:
            gen_x_all = (gen_x_all - torch.mean(gen_x_all, dim=0, keepdim=True)) / torch.std(gen_x_all, dim=0, keepdim=True)
            gen_y_all = (gen_y_all - torch.mean(gen_y_all, dim=0, keepdim=True)) / torch.std(gen_y_all, dim=0, keepdim=True)
            x_all = (x_all - torch.mean(x_all, dim=0, keepdim=True)) / torch.std(x_all, dim=0, keepdim=True)
            y_all = (y_all - torch.mean(y_all, dim=0, keepdim=True)) / torch.std(y_all, dim=0, keepdim=True)
            z_all = (z_all - torch.mean(z_all, dim=0, keepdim=True)) / torch.std(z_all, dim=0, keepdim=True)

        cur_stat, cur_boot_temp = get_p_value_stat(boot_num, M, test_size, gen_x_all.to(device), gen_y_all.to(device),
                                x_all.to(device), y_all.to(device), z_all.to(device), boor_rv_type)
        stat_all[cur_k,:] = cur_stat
        boot_temp_all[cur_k,:] = torch.from_numpy(cur_boot_temp)
        cur_k = cur_k + 1

    return np.mean(torch.mean(boot_temp_all, dim = 0).numpy() > torch.mean(stat_all).item() )

def run_experiment(params):
    test = params["test"]
    sample_size = params["sample_size"]
    z_dim = params["z_dim"]
    dx = params["dx"]
    dy = params["dy"]
    n_test = params["n_test"]
    alpha_x = params["alpha_x"]
    m_value = params["m_value"]
    k_value = params["k_value"]
    j_value = params["j_value"]
    alpha = params["alpha"]
    alpha1 = params["alpha1"]
    set_seeds = params["set_seeds"]
    boor_rv_type = params["boor_rv_type"]


    np.random.seed(set_seeds)
    torch.manual_seed(set_seeds)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(set_seeds)

    p_values = np.array([])
    test_count = 0
    if test == 'H0':
        for n in tqdm(range(n_test)):
            # start_time = datetime.now()

            p_value = mGAN(n=sample_size, z_dim=z_dim, simulation=test, x_dims=dx, y_dims=dy, a_x=alpha_x, M=m_value,
                    k=k_value, boot_num=j_value, boor_rv_type = boor_rv_type)

            # test_count += 1
            # print("--- The %d'th iteration take %s seconds ---" % (test_count, (datetime.now() - start_time)))

            p_values = np.append(p_values, p_value)
            fp = [pval < alpha  for pval in p_values]
            final_result = np.mean(fp)
            fp1 = [pval < alpha1 for pval in p_values]
            final_result1 = np.mean(fp1)

            # print('The stat is {}'.format(p_value))
            # print('Type 1 error: {} for z dimension {} with significance level {}'.format(final_result, z_dim, alpha))
            # print('Type 1 error: {} for z dimension {} with significance level {}'.format(final_result1, z_dim, alpha1))

            final_result_list = np.array([final_result])
            final_result1_list = np.array([final_result1])

        print('Type 1 error: {} with significance level {}'.format(final_result, alpha))
        print('Type 1 error: {} with significance level {}'.format(final_result1, alpha1))

    if test == 'H1_dense' or test == 'H1_sparse':
        for n in tqdm(range(n_test)):
            # start_time = datetime.now()

            p_value = mGAN(n=sample_size, z_dim=z_dim, simulation=test, x_dims=dx, y_dims=dy, a_x=alpha_x, M=m_value,
                    k=k_value, boot_num=j_value, boor_rv_type = boor_rv_type)

            # test_count += 1
            # print("--- The %d'th iteration take %s seconds ---" % (test_count, (datetime.now() - start_time)))

            p_values = np.append(p_values, p_value)
            fp = [pval < alpha  for pval in p_values]
            final_result = np.mean(fp)
            fp1 = [pval < alpha1 for pval in p_values]
            final_result1 = np.mean(fp1)

            # print('The stat is {}'.format(p_value))
            # print('Power: {} for z dimension {} with significance level {}'.format(final_result, z_dim, alpha))
            # print('Power: {} for z dimension {} with significance level {}'.format(final_result1, z_dim, alpha1))

            final_result_list = np.array([final_result])
            final_result1_list = np.array([final_result1])

        print('Power: {} with significance level {}'.format(final_result, alpha))
        print('Power: {} with significance level {}'.format(final_result1, alpha1))

    return p_values

In [2]:
for input_sample_size in [800, 600, 400, 200]:
    print("---                     ---")
    print("---                     ---")
    print("--- The n = ",input_sample_size," case ---")

    param = {
      "test": "H0", # ['H0', 'H1_dense', 'H1_sparse']
      "sample_size":  input_sample_size, # [100, 200, 300, 400]
      "z_dim": 25, # [5, 50, 250]
      "dx": 25,
      "dy": 1,
      "n_test": input_n_test, # [500] in the paper
      "alpha_x": 0.50, # only used under alternative ['power_sparse': a = 0.5; 'power_dense': 1/sqrt(2*p2) {1/math.sqrt(2*25)}]
      "m_value": 100, # [500]
      "k_value": 2, # [1, 2, 4]
      "j_value": 1000, # [1000, 2000]
      "alpha": 0.1,
      "alpha1": 0.05,
      "set_seeds": 42,
      "boor_rv_type":  'rademacher' # ['rademacher', 'gaussian']
    }

    print("--- The H0 case ---")
    p_val_list = run_experiment(param)

    import numpy as np
    quantile_5, quantile_10 = np.quantile(p_val_list, 0.05), np.quantile(p_val_list, 0.10)

    param = {
      "test": "H1_dense", # ['H0', 'H1_dense', 'H1_sparse']
      "sample_size":  input_sample_size, # [100, 200, 300, 400]
      "z_dim": 25, # [5, 50, 250]
      "dx": 25,
      "dy": 1,
      "n_test": input_n_test, # [500] in the paper
      "alpha_x": 0.50, # only used under alternative ['power_sparse': a = 0.5; 'power_dense': 1/sqrt(2*p2) {1/math.sqrt(2*25)}]
      "m_value": 100, # [500]
      "k_value": 2, # [1, 2, 4]
      "j_value": 1000, # [1000, 2000]
      "alpha": 0.1,
      "alpha1": 0.05,
      "set_seeds": 42,
      "boor_rv_type":  'rademacher' # ['rademacher', 'gaussian']
    }

    print("--- The H1_dense case Empirical Power ---")
    p_val_list_H1_dense = run_experiment(param)

    fp = [pval < quantile_10  for pval in p_val_list_H1_dense]
    final_result = np.mean(fp)
    fp1 = [pval < quantile_5 for pval in p_val_list_H1_dense]
    final_result1 = np.mean(fp1)


    print('Adjusted Power: {} with significance level {}'.format(final_result, quantile_10))
    print('Adjusted Power: {} with significance level {}'.format(final_result1, quantile_5))

    param = {
      "test": "H1_sparse", # ['H0', 'H1_dense', 'H1_sparse']
      "sample_size":  input_sample_size, # [100, 200, 300, 400]
      "z_dim": 25, # [5, 50, 250]
      "dx": 25,
      "dy": 1,
      "n_test": input_n_test, # [500] in the paper
      "alpha_x": 0.50, # only used under alternative ['power_sparse': a = 0.5; 'power_dense': 1/sqrt(2*p2) {1/math.sqrt(2*25)}]
      "m_value": 100, # [500]
      "k_value": 2, # [1, 2, 4]
      "j_value": 1000, # [1000, 2000]
      "alpha": 0.1,
      "alpha1": 0.05,
      "set_seeds": 42,
      "boor_rv_type":  'rademacher' # ['rademacher', 'gaussian']
    }

    print("--- The H1_sparse case Empirical Power ---")
    p_val_list_H1_sparse = run_experiment(param)

    fp = [pval < quantile_10  for pval in p_val_list_H1_sparse]
    final_result = np.mean(fp)
    fp1 = [pval < quantile_5 for pval in p_val_list_H1_sparse]
    final_result1 = np.mean(fp1)


    print('Adjusted Power: {} with significance level {}'.format(final_result, quantile_10))
    print('Adjusted Power: {} with significance level {}'.format(final_result1, quantile_5))


---                     ---
---                     ---
--- The n =  800  case ---
--- The H0 case ---


  Condi_Mean_vec = torch.matmul(torch.matmul(Cov_12, Cov_22_inv), Z.T)
100%|██████████| 500/500 [16:00<00:00,  1.92s/it]


Type 1 error: 0.138 with significance level 0.1
Type 1 error: 0.072 with significance level 0.05
--- The H1_dense case Empirical Power ---


100%|██████████| 500/500 [15:47<00:00,  1.90s/it]


Power: 1.0 with significance level 0.1
Power: 1.0 with significance level 0.05
Adjusted Power: 1.0 with significance level 0.07250000000000002
Adjusted Power: 0.998 with significance level 0.03380000000000001
--- The H1_sparse case Empirical Power ---


100%|██████████| 500/500 [15:46<00:00,  1.89s/it]


Power: 1.0 with significance level 0.1
Power: 1.0 with significance level 0.05
Adjusted Power: 1.0 with significance level 0.07250000000000002
Adjusted Power: 1.0 with significance level 0.03380000000000001
---                     ---
---                     ---
--- The n =  600  case ---
--- The H0 case ---


100%|██████████| 500/500 [12:29<00:00,  1.50s/it]


Type 1 error: 0.12 with significance level 0.1
Type 1 error: 0.056 with significance level 0.05
--- The H1_dense case Empirical Power ---


100%|██████████| 500/500 [13:33<00:00,  1.63s/it]


Power: 0.968 with significance level 0.1
Power: 0.92 with significance level 0.05
Adjusted Power: 0.962 with significance level 0.0879
Adjusted Power: 0.912 with significance level 0.04385000000000001
--- The H1_sparse case Empirical Power ---


100%|██████████| 500/500 [12:24<00:00,  1.49s/it]


Power: 0.998 with significance level 0.1
Power: 0.994 with significance level 0.05
Adjusted Power: 0.998 with significance level 0.0879
Adjusted Power: 0.992 with significance level 0.04385000000000001
---                     ---
---                     ---
--- The n =  400  case ---
--- The H0 case ---


100%|██████████| 500/500 [11:04<00:00,  1.33s/it]


Type 1 error: 0.128 with significance level 0.1
Type 1 error: 0.068 with significance level 0.05
--- The H1_dense case Empirical Power ---


100%|██████████| 500/500 [09:35<00:00,  1.15s/it]


Power: 0.82 with significance level 0.1
Power: 0.68 with significance level 0.05
Adjusted Power: 0.77 with significance level 0.07390000000000001
Adjusted Power: 0.574 with significance level 0.031950000000000006
--- The H1_sparse case Empirical Power ---


100%|██████████| 500/500 [08:53<00:00,  1.07s/it]


Power: 0.944 with significance level 0.1
Power: 0.862 with significance level 0.05
Adjusted Power: 0.916 with significance level 0.07390000000000001
Adjusted Power: 0.8 with significance level 0.031950000000000006
---                     ---
---                     ---
--- The n =  200  case ---
--- The H0 case ---


100%|██████████| 500/500 [05:32<00:00,  1.51it/s]


Type 1 error: 0.13 with significance level 0.1
Type 1 error: 0.07 with significance level 0.05
--- The H1_dense case Empirical Power ---


100%|██████████| 500/500 [05:31<00:00,  1.51it/s]


Power: 0.44 with significance level 0.1
Power: 0.276 with significance level 0.05
Adjusted Power: 0.4 with significance level 0.07980000000000001
Adjusted Power: 0.238 with significance level 0.039900000000000005
--- The H1_sparse case Empirical Power ---


100%|██████████| 500/500 [05:38<00:00,  1.48it/s]

Power: 0.56 with significance level 0.1
Power: 0.408 with significance level 0.05
Adjusted Power: 0.514 with significance level 0.07980000000000001
Adjusted Power: 0.368 with significance level 0.039900000000000005



