In [None]:
# Number of Monte-Carlo replications run for each hypothesis setting
# (passed to run_experiment as "n_test").
input_n_test = 500
# Sample size per replication (passed to run_experiment as "sample_size").
input_sample_size = 400 # 200, 400, 600, 800
# Seed given to the numpy / torch generators at the start of every run_experiment call.
input_seed = 42

import math
import torch
import torch.distributions as TD
from torch.utils.data import Dataset, DataLoader
from zmq import device
import torch.optim as optim
import numpy as np
from datetime import datetime
import functools
from scipy.linalg import toeplitz
import xgboost as xgb
from sklearn.linear_model import LassoCV
from tqdm import tqdm

# Move model on GPU if available
# `enable_cuda` is a manual switch: set it to False to stay on CPU even when CUDA is present.
enable_cuda = True
# NOTE: this assignment rebinds the module-level name `device`, replacing the object
# brought in by `from zmq import device` in the import section; the functions below
# read this global.
device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')

def get_xzy_randn_nl(n_points, ground_truth='H0', rho=0.3, p1 = 25, p2 = 25, device='cuda:0', a = 0.0, **ignored):
    """Simulate one data set (X, Y, Z) from the non-linear Gaussian design.

    (Z, X) is drawn jointly from a zero-mean Gaussian whose covariance is the
    Toeplitz matrix with entries ``rho ** |i - j|``; Z is the first ``p1``
    coordinates and X the remaining ``p2``.  The response is
    ``Y = Z @ beta_z + (X @ beta_x) ** 2 + 0.5 * eps`` with standard normal
    ``eps``; ``beta_z`` has ones in its first two entries.

    Args:
        n_points: number of rows to draw.
        ground_truth: ``'H0'`` (``beta_x`` is zero), ``'H1_sparse'`` (5 active
            entries equal to ``a / sqrt(5)``) or ``'H1_dense'`` (12 active
            entries equal to ``a / sqrt(12)``).
        rho: decay of the Toeplitz covariance.
        p1, p2: number of columns of Z and X.
        device: torch device on which all tensors are created.
        a: signal strength used under the alternatives.
        **ignored: extra keyword arguments are accepted and discarded.

    Returns:
        Tuple ``(X, Y, Z)`` of shapes ``(n_points, p2)``, ``(n_points, 1)`` and
        ``(n_points, p1)``.

    Raises:
        NotImplementedError: if ``ground_truth`` is not one of the three labels.
    """
    total_dim = p1 + p2
    first_column = rho ** np.arange(total_dim)
    sigma = torch.FloatTensor(toeplitz(first_column)).to(device)

    # One joint draw, then split the columns into the Z part and the X part.
    joint_law = TD.MultivariateNormal(torch.zeros(total_dim).to(device), sigma)
    joint_draw = joint_law.sample((n_points,))
    Z, X = joint_draw[:, :p1], joint_draw[:, p1:]

    def as_column(values):
        # Float column vector living on the requested device.
        return torch.FloatTensor(values).reshape((-1, 1)).to(device)

    beta_z = as_column([1] * 2 + [0] * (p1 - 2))

    if ground_truth == 'H0':
        x_coefs = [0] * 2 + [0] * (p2 - 2)
    else:
        if ground_truth == 'H1_sparse':
            n_active = 5
        elif ground_truth == 'H1_dense':
            n_active = 12
        else:
            raise NotImplementedError(f'{ground_truth} has to be H0, H1_sparse or H1_dense')
        x_coefs = [a / math.sqrt(n_active)] * n_active + [0] * (p2 - n_active)
    beta_x = as_column(x_coefs)

    noise = torch.randn(n_points, 1, device=device) * 0.5
    linear_part = torch.matmul(Z, beta_z)
    squared_part = torch.matmul(X, beta_x) ** 2
    Y = linear_part + squared_part + noise

    return X, Y, Z

def get_xzy_randn_nl_fix(Z, n_points, ground_truth='H0', rho=0.3, p1 = 25, p2 = 25, device='cuda:0', a = 0.0, **ignored):
    """Sample X given one fixed Z and return the noise-free conditional mean of Y.

    Uses the same design as ``get_xzy_randn_nl``: (Z, X) is zero-mean Gaussian
    with Toeplitz covariance ``rho ** |i - j|``, Z being the first ``p1``
    coordinates (block 1) and X the last ``p2`` (block 2).  With the covariance
    partitioned accordingly, the conditional law is

        X | Z = z  ~  N(S21 S11^{-1} z,  S22 - S21 S11^{-1} S12)

    and, because ``Y = Z beta_z + (X beta_x)^2 + noise`` with zero-mean noise,

        E[Y | Z = z] = beta_z' z + beta_x' Cov(X|z) beta_x + (beta_x' E[X|z])^2.

    The previous implementation conditioned with the two blocks swapped
    (``S12 S22^{-1}`` and ``S11 - S12 S22^{-1} S21``), which describes Z given X
    and is not even shape-consistent unless ``p1 == p2``.

    Args:
        Z: a single conditioning vector with ``p1`` entries.
        n_points: number of X samples to draw from the conditional law.
        ground_truth: ``'H0'``, ``'H1_sparse'`` or ``'H1_dense'`` (selects ``beta_x``).
        rho: decay of the Toeplitz covariance.
        p1, p2: dimensions of Z and X.
        device: torch device used for all tensors.
        a: signal strength under the alternatives.
        **ignored: extra keyword arguments are accepted and discarded.

    Returns:
        ``(X, Y)`` where X has shape ``(n_points, p2)`` and Y is a 1-element
        tensor holding ``E[Y | Z]``.

    Raises:
        NotImplementedError: if ``ground_truth`` is not one of the three labels.
    """
    p = p1 + p2
    cov_vec = rho ** np.arange(p)
    cov_mx = torch.FloatTensor(toeplitz(cov_vec)).to(device)

    Cov_11 = cov_mx[:p1, :p1]
    Cov_12 = cov_mx[:p1, p1:]
    Cov_21 = cov_mx[p1:, :p1]
    Cov_22 = cov_mx[p1:, p1:]

    if ground_truth == 'H0':
        x_coefs = [0] * 2 + [0] * (p2 - 2)
    elif ground_truth == 'H1_sparse':
        x_coefs = [a / math.sqrt(5)] * 5 + [0] * (p2 - 5)
    elif ground_truth == 'H1_dense':
        x_coefs = [a / math.sqrt(12)] * 12 + [0] * (p2 - 12)
    else:
        raise NotImplementedError(f'{ground_truth} has to be H0, H1_sparse or H1_dense')

    beta_x = torch.FloatTensor(x_coefs).reshape((-1, 1)).to(device)
    beta_z = torch.FloatTensor([1] * 2 + [0] * (p1 - 2)).reshape((-1, 1)).to(device)

    z_col = Z.reshape((-1, 1)).to(device)

    # Regression of X on Z: gain = S21 S11^{-1}, shape (p2, p1).
    gain = torch.matmul(Cov_21, torch.inverse(Cov_11))
    cond_mean = torch.matmul(gain, z_col)                 # (p2, 1)
    cond_cov = Cov_22 - torch.matmul(gain, Cov_12)        # (p2, p2)
    # Remove round-off asymmetry so the distribution's symmetry check passes.
    cond_cov = 0.5 * (cond_cov + cond_cov.T)

    X = TD.MultivariateNormal(cond_mean.reshape(-1), cond_cov).sample((n_points,))

    first_term = torch.matmul(beta_z.T, z_col)
    sec_term = torch.matmul(torch.matmul(beta_x.T, cond_cov), beta_x)
    third_term = torch.matmul(beta_x.T, cond_mean) ** 2

    Y = first_term + sec_term + third_term

    return X, Y.reshape(-1)


def get_p_value_stat(boot_num, M, n, gen_x_all_torch, gen_y_all_torch, x_torch, y_torch, z_torch, boor_rv_type="gaussian"):
    """Compute the kernel test statistic and its multiplier-bootstrap replicates.

    The statistic is ``sum_{i != j} U_ij * V_ij * W_ij / (n - 1)`` where

    * ``W_ij = exp(-|z_i - z_j|_1 / sigma_w)``,
    * ``U_ij = k(x_i, x_j) - mean_m k(g_jm, x_i) - mean_m k(g_im, x_j)
      + mean_m k(g_jm, g_im)`` with ``k(a, b) = exp(-|a - b|_1 / sigma_u)`` and
      ``g`` the generated X samples,
    * ``V_ij = <gen_y_i - y_i, gen_y_j - y_j>``.

    ``sigma_w`` and ``sigma_u`` are the medians of the pairwise L1 distances of
    ``z_torch`` and ``x_torch``.  Each bootstrap replicate re-weights the
    summand by ``e_i * e_j`` for a random multiplier vector ``e``.

    Compared with the earlier version, pairwise distances are computed once
    via broadcasting, the diagonal mask is created on the inputs' device
    instead of a module-level global, and all bootstrap replicates are
    evaluated together as quadratic forms ``e' F e``.

    Args:
        boot_num: number of bootstrap replicates.
        M: unused; kept for interface compatibility.
        n: number of rows in ``x_torch`` / ``y_torch`` / ``z_torch``.
        gen_x_all_torch: generated X samples, reshaped internally to ``(n, -1, d_x)``.
        gen_y_all_torch: generated Y values, same shape as ``y_torch``.
        x_torch, y_torch, z_torch: observed data, 2-D with ``n`` rows.
        boor_rv_type: ``"rademacher"`` (signs of normals) or ``"gaussian"``.

    Returns:
        ``(stat, boottemp)``: a Python float and a float64 numpy array of
        length ``boot_num``.

    Raises:
        ValueError: if ``boor_rv_type`` is not a supported multiplier type.
    """
    if boor_rv_type not in ("rademacher", "gaussian"):
        raise ValueError(f'{boor_rv_type} has to be rademacher or gaussian')

    d_x = x_torch.shape[1]

    def l1_dist(a, b):
        # L1 distance along the last axis; a and b broadcast against each other.
        return torch.linalg.vector_norm(a - b, ord=1, dim=-1)

    z_dist = l1_dist(z_torch.unsqueeze(0), z_torch.unsqueeze(1))
    x_dist = l1_dist(x_torch.unsqueeze(0), x_torch.unsqueeze(1))

    # Median heuristic for the two kernel bandwidths.
    sigma_w = torch.median(z_dist).item()
    sigma_u = torch.median(x_dist).item()

    w_mx = torch.exp(-z_dist / sigma_w)

    gen_x = gen_x_all_torch.reshape(n, -1, d_x)
    u_mx_1 = torch.exp(-x_dist / sigma_u)
    # Entry [i, j] averages k(g_jm, x_i) over the generated samples m.
    u_mx_2 = torch.exp(-l1_dist(gen_x.unsqueeze(0), x_torch.reshape(n, 1, 1, d_x)) / sigma_u).mean(dim=2)
    u_mx_3 = u_mx_2.T
    # Entry [i, j] averages k(g_jm, g_im) over m (samples paired by index m).
    u_mx_4 = torch.exp(-l1_dist(gen_x.unsqueeze(0), gen_x.unsqueeze(1)) / sigma_u).mean(dim=2)
    u_mx = u_mx_1 - u_mx_2 - u_mx_3 + u_mx_4

    v_mx_temp = gen_y_all_torch - y_torch
    v_mx = torch.matmul(v_mx_temp, v_mx_temp.T)

    summand = u_mx * v_mx * w_mx
    off_diag = 1 - torch.eye(n, device=summand.device, dtype=summand.dtype)
    FF_mx = summand * off_diag

    stat = 1 / (n - 1) * torch.sum(FF_mx).item()

    # Multipliers are drawn on the default device (as before) so seeded runs
    # consume the same random stream, then moved next to FF_mx.
    eboot = torch.randn(n, boot_num)
    if boor_rv_type == "rademacher":
        eboot = torch.sign(eboot)
    eboot = eboot.to(device=FF_mx.device, dtype=FF_mx.dtype)

    # sum_ij F_ij e_i e_j = e' F e, evaluated for every column of eboot at once.
    quad_forms = (eboot * torch.matmul(FF_mx, eboot)).sum(dim=0)
    boottemp = (quad_forms / (n - 1)).detach().cpu().numpy().astype(np.float64)
    return stat, boottemp


class DatasetSelect(Dataset):
    """Map-style dataset serving index-aligned (X, Y, Z) triples."""

    def __init__(self, X, Y, Z):
        self.X_real, self.Y_real, self.Z_real = X, Y, Z
        # The length is taken from the leading axis of X.
        self.sample_size = X.shape[0]

    def __len__(self):
        return self.sample_size

    def __getitem__(self, index):
        # Same position from each of the three stored arrays.
        sources = (self.X_real, self.Y_real, self.Z_real)
        return tuple(source[index] for source in sources)

# Dataset for generator training: each item is (X, Y, Z) plus a second Z row taken batch_size positions ahead (cyclically)

class DatasetSelect_GAN(torch.utils.data.Dataset):
    """Dataset of (X, Y, Z, Z_shifted) tuples.

    ``Z_shifted`` is the Z row located ``batch_size`` positions after the
    requested index, wrapping around at the end of the data.
    """

    def __init__(self, X, Y, Z, batch_size):
        self.X_real, self.Y_real, self.Z_real = X, Y, Z
        self.batch_size = batch_size
        # The length is taken from the leading axis of X.
        self.sample_size = X.shape[0]

    def __len__(self):
        return self.sample_size

    def __getitem__(self, index):
        shifted_index = (self.batch_size + index) % self.sample_size
        triple = (self.X_real[index], self.Y_real[index], self.Z_real[index])
        return triple + (self.Z_real[shifted_index],)

# Dataset exposing only index-aligned (Y, Z) pairs

class DatasetSelect_GAN_ver2(torch.utils.data.Dataset):
    """Dataset of index-aligned (Y, Z) pairs; ``batch_size`` is stored but not used here."""

    def __init__(self, Y, Z, batch_size):
        self.Y_real, self.Z_real = Y, Z
        self.batch_size = batch_size
        # The length is taken from the leading axis of Z.
        self.sample_size = Z.shape[0]

    def __len__(self):
        return self.sample_size

    def __getitem__(self, index):
        y_item = self.Y_real[index]
        z_item = self.Z_real[index]
        return y_item, z_item

##### Auxiliary functions #####

def sample_noise(sample_size, noise_dimension, noise_type, input_var):
    """Draw a ``(sample_size, noise_dimension)`` matrix of generator input noise.

    Args:
        sample_size: number of rows.
        noise_dimension: number of columns.
        noise_type: ``"normal"`` (independent N(0, input_var) entries, created
            on the module-level ``device``), ``"unif"`` (Uniform[0, 1) entries)
            or ``"Cauchy"`` (standard Cauchy entries).  The last two are created
            on torch's default device; callers move them as needed.
        input_var: variance of the entries; only used for ``"normal"``.

    Returns:
        The noise tensor.

    Raises:
        ValueError: if ``noise_type`` is not one of the supported names
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if noise_type == "normal":
        noise_generator = TD.MultivariateNormal(
            torch.zeros(noise_dimension).to(device), input_var * torch.eye(noise_dimension).to(device))
        return noise_generator.sample((sample_size,))

    if noise_type == "unif":
        return torch.rand(sample_size, noise_dimension)

    if noise_type == "Cauchy":
        # Location/scale have shape (1,), so samples carry a trailing axis of size 1.
        cauchy = TD.Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))
        return cauchy.sample((sample_size, noise_dimension)).squeeze(2)

    raise ValueError(f'{noise_type} has to be normal, unif or Cauchy')

##### GAN architecture #####

class Generator(torch.nn.Module):
    """Single-hidden-layer network used as generator / regressor.

    The forward pass is ``fc1 -> (BN1 if BN_type) -> LeakyReLU -> dropout ->
    fc_last``.  The layers ``fc2``, ``fc3``, ``fc_temp``, ``BN2``, ``BN3``, the
    sigmoid and the remaining dropouts are constructed but not used by
    ``forward``; they are still registered sub-modules, so their weights are
    part of ``parameters()``.
    """

    def __init__(self, input_dimension, output_dimension, noise_dimension, hidden_layer_size, BN_type, ReLU_coef, drop_out_p,
                 drop_input = False):
        super().__init__()
        nn = torch.nn
        in_features = input_dimension + noise_dimension

        self.BN_type = BN_type
        self.ReLU_coef = ReLU_coef

        # Registration order below fixes the order of parameters() and of the
        # random weight initialisation, so it is kept deliberately.
        self.fc1 = nn.Linear(in_features, hidden_layer_size, bias=True)
        if BN_type:
            # NOTE(review): the second positional argument of BatchNorm1d is
            # `eps`, so 0.8 is the epsilon here, not the momentum.
            for idx in (1, 2, 3):
                setattr(self, f"BN{idx}", nn.BatchNorm1d(hidden_layer_size, 0.8, affine=False))
        self.leakyReLU1 = nn.LeakyReLU(ReLU_coef)
        self.fc2 = nn.Linear(hidden_layer_size, hidden_layer_size, bias=True)
        self.fc3 = nn.Linear(hidden_layer_size, hidden_layer_size, bias=True)
        self.fc_last = nn.Linear(hidden_layer_size, output_dimension, bias=True)
        self.fc_temp = nn.Linear(in_features, output_dimension, bias=True)
        self.sigmoid = nn.Sigmoid()
        for idx in range(4):
            setattr(self, f"drop_out{idx}", nn.Dropout(p=drop_out_p))
        self.drop_input = drop_input

    def forward(self, x):
        hidden = self.fc1(x)
        if self.BN_type:
            hidden = self.BN1(hidden)
        hidden = self.drop_out1(self.leakyReLU1(hidden))
        return self.fc_last(hidden)

##### Training procedures #####


def find_loss_l(y_torch, gen_y_all_torch, z_torch, sigma_w, sigma_u, M):
    """Kernel discrepancy between real and generated samples, Laplacian (L1) kernels.

    With ``k(a, b) = exp(-|a - b|_1 / sigma_u)`` and
    ``w_ij = exp(-|z_i - z_j|_1 / sigma_w)`` the loss is

        (1 / n) * sum_{i != j} w_ij * [ k(y_i, y_j) - mean_m k(g_jm, y_i)
                                        - mean_m k(g_im, y_j) + mean_m k(g_jm, g_im) ]

    where ``g`` are the generated samples.  Pairwise differences are formed by
    broadcasting (no tiled copies), and the diagonal mask is created on the
    device/dtype of the inputs rather than on a module-level global.

    Args:
        y_torch: real samples, shape ``(n, d)``.
        gen_y_all_torch: generated samples, reshaped internally to ``(n, -1, d)``.
        z_torch: conditioning variables, ``n`` rows.
        sigma_w: bandwidth of the kernel on ``z``.
        sigma_u: bandwidth of the kernel on the samples.
        M: unused; kept for interface compatibility.

    Returns:
        Scalar tensor (differentiable w.r.t. ``gen_y_all_torch``).
    """
    n = z_torch.shape[0]
    d_y = y_torch.shape[1]

    def laplace_kernel(a, b, bandwidth):
        # a and b broadcast against each other; the last axis holds the features.
        return torch.exp(-torch.linalg.vector_norm(a - b, ord=1, dim=-1) / bandwidth)

    gen = gen_y_all_torch.reshape(n, -1, d_y)

    w_mx = laplace_kernel(z_torch.unsqueeze(0), z_torch.unsqueeze(1), sigma_w)

    u_mx_1 = laplace_kernel(y_torch.unsqueeze(0), y_torch.unsqueeze(1), sigma_u)
    # Entry [i, j]: average over m of k(g_jm, y_i).
    u_mx_2 = laplace_kernel(gen.unsqueeze(0), y_torch.reshape(n, 1, 1, d_y), sigma_u).mean(dim=2)
    u_mx_3 = u_mx_2.T
    # Entry [i, j]: average over m of k(g_jm, g_im).
    u_mx_4 = laplace_kernel(gen.unsqueeze(0), gen.unsqueeze(1), sigma_u).mean(dim=2)

    u_mx = u_mx_1 - u_mx_2 - u_mx_3 + u_mx_4

    weighted = u_mx * w_mx
    off_diag = 1 - torch.eye(n, device=weighted.device, dtype=weighted.dtype)
    FF_mx = weighted * off_diag

    loss = 1 / (n) * torch.sum(FF_mx)
    return loss

def find_loss_g(y_torch, gen_y_all_torch, z_torch, sigma_w, sigma_u, M):
    """Kernel discrepancy between real and generated samples, Gaussian kernels.

    With ``k(a, b) = exp(-|a - b|_2^2 / sigma_u)`` and
    ``w_ij = exp(-|z_i - z_j|_2^2 / sigma_w)`` the loss is

        (1 / n) * sum_{i != j} w_ij * [ k(y_i, y_j) - mean_m k(g_jm, y_i)
                                        - mean_m k(g_im, y_j) + mean_m k(g_jm, g_im) ]

    where ``g`` are the generated samples.  Pairwise differences are formed by
    broadcasting (no tiled copies), and the diagonal mask is created on the
    device/dtype of the inputs rather than on a module-level global.

    Args:
        y_torch: real samples, shape ``(n, d)``.
        gen_y_all_torch: generated samples, reshaped internally to ``(n, -1, d)``.
        z_torch: conditioning variables, ``n`` rows.
        sigma_w: bandwidth of the kernel on ``z``.
        sigma_u: bandwidth of the kernel on the samples.
        M: unused; kept for interface compatibility.

    Returns:
        Scalar tensor (differentiable w.r.t. ``gen_y_all_torch``).
    """
    n = z_torch.shape[0]
    d_y = y_torch.shape[1]

    def gauss_kernel(a, b, bandwidth):
        # Squared Euclidean distance along the last axis, written as norm ** 2
        # exactly like the tiled version it replaces.
        sq_dist = torch.linalg.vector_norm(a - b, ord=2, dim=-1) ** 2
        return torch.exp(-sq_dist / bandwidth)

    gen = gen_y_all_torch.reshape(n, -1, d_y)

    w_mx = gauss_kernel(z_torch.unsqueeze(0), z_torch.unsqueeze(1), sigma_w)

    u_mx_1 = gauss_kernel(y_torch.unsqueeze(0), y_torch.unsqueeze(1), sigma_u)
    # Entry [i, j]: average over m of k(g_jm, y_i).
    u_mx_2 = gauss_kernel(gen.unsqueeze(0), y_torch.reshape(n, 1, 1, d_y), sigma_u).mean(dim=2)
    u_mx_3 = u_mx_2.T
    # Entry [i, j]: average over m of k(g_jm, g_im).
    u_mx_4 = gauss_kernel(gen.unsqueeze(0), gen.unsqueeze(1), sigma_u).mean(dim=2)

    u_mx = u_mx_1 - u_mx_2 - u_mx_3 + u_mx_4

    weighted = u_mx * w_mx
    off_diag = 1 - torch.eye(n, device=weighted.device, dtype=weighted.dtype)
    FF_mx = weighted * off_diag

    loss = 1 / (n) * torch.sum(FF_mx)
    return loss


def get_Gzx(X, Y, Z, noise_dimension = 50, noise_type = "normal", input_var = 1.0/3.0,
      sigma_z_l = 1, sigma_x_l = 1,
      sigma_z_g = 1, sigma_x_g = 1):
    """Train a generator that maps (Z, noise) to samples of X.

    Each mini-batch of Z is tiled ``M_train`` times, concatenated with fresh
    noise and pushed through the network; the generated samples are compared
    with the observed X of the batch through
    ``lambda_1 * find_loss_g + lambda_2 * find_loss_l`` plus an L1 penalty
    (``lambda_3``) on all network parameters.  Gradients are clipped to norm
    0.5 before every Adam step.

    Cleanup relative to the previous version (training behaviour unchanged):
    the loader is no longer bound to the name ``DataLoader`` (which shadowed
    the imported class), the penalty loop no longer rebinds the
    hyper-parameter dict, the configured ``batch_size`` is no longer
    overwritten by the size of the last batch, and unused per-batch tensors
    are not moved to the device.

    Args:
        X: observed X, shape ``(n, d_x)``.
        Y: observed Y; only stored in the training dataset, not used in the loss.
        Z: observed Z, shape ``(n, d_z)``.
        noise_dimension: width of the noise fed to the generator.
        noise_type: noise law understood by ``sample_noise``.
        input_var: variance passed to ``sample_noise``.
        sigma_z_l, sigma_x_l: bandwidths for ``find_loss_l``.
        sigma_z_g, sigma_x_g: bandwidths for ``find_loss_g``.

    Returns:
        The trained ``Generator`` (left in training mode).
    """
    hyper = {
      "hidden_layer_size": 128,
      "ReLU_coef": 0.8,
      "drop_out_p": 0.05,
      "lambda_2": 0.25,
      "lambda_3": 0.001,
      "wgt_decay": 1e-05,
      "G_lr": 0.004950377967576449
    }
    lambda_2 = hyper["lambda_2"]
    lambda_3 = hyper["lambda_3"]

    BN_type = False
    lambda_1 = 1
    M_train = 10          # generated samples per observed Z
    batch_size = 128
    epochs_num = 1000

    input_dimension = Z.shape[1]
    output_dimension_x = X.shape[1]

    G_zx = Generator(input_dimension, output_dimension_x, noise_dimension, hyper["hidden_layer_size"],
                     BN_type, hyper["ReLU_coef"], hyper["drop_out_p"]).to(device)
    G_zx_solver = optim.Adam(G_zx.parameters(), lr=hyper["G_lr"], betas=(0.5, 0.999),
                             weight_decay=hyper["wgt_decay"])

    train_xyz = DatasetSelect_GAN(X, Y, Z, batch_size)
    train_loader = torch.utils.data.DataLoader(train_xyz, batch_size=batch_size, shuffle=True)

    G_zx.train()

    for _ in range(epochs_num):
        # The dataset also yields Y and a shifted Z; neither enters the loss.
        for X_real, _, Z_real, _ in train_loader:
            X_real = X_real.to(device)
            Z_real = Z_real.to(device)

            cur_batch = Z_real.shape[0]   # the last batch may be smaller
            Z_real_repeat = Z_real.repeat(M_train, 1)

            # Generate fake data
            Noise_fake = sample_noise(Z_real_repeat.shape[0], noise_dimension, noise_type, input_var = input_var).to(device)
            X_fake = G_zx(torch.cat((Z_real_repeat, Noise_fake), dim=1))

            # Rows are ordered copy-major (M_train blocks of the batch), so
            # reshape to (M_train, batch, d_x) and swap to (batch, M_train, d_x).
            X_fake = X_fake.reshape(M_train, cur_batch, output_dimension_x).swapaxes(0, 1)

            # Generator step
            G_zx_solver.zero_grad()

            l1_regularization = sum(torch.linalg.vector_norm(weights, ord = 1) for weights in G_zx.parameters())

            g_zx_error = lambda_1 * find_loss_g(X_real, X_fake, Z_real, sigma_z_g, sigma_x_g, M_train) + \
                    lambda_2 * find_loss_l(X_real, X_fake, Z_real, sigma_z_l, sigma_x_l, M_train) + \
                    lambda_3 * l1_regularization

            g_zx_error.backward()
            torch.nn.utils.clip_grad_norm_(G_zx.parameters(), max_norm=0.5)
            G_zx_solver.step()
    return G_zx

def get_Gzy(Y, Z):
    """Fit a network that predicts Y from Z by full-batch regression.

    The network is a noise-free ``Generator`` trained with Adam on
    ``MSE(G(Z), Y)`` plus an L1 penalty on all parameters; gradients are
    clipped to norm 0.5 before each step.

    The input and output widths are now read from ``Z`` and ``Y`` instead of
    being fixed to 25 and 1, so the function works for other dimensions and is
    unchanged for 25-column Z with single-column Y.  The penalty loop no
    longer rebinds the hyper-parameter dict.

    Args:
        Y: responses, shape ``(n, d_y)``.
        Z: predictors, shape ``(n, d_z)``.

    Returns:
        The trained ``Generator`` (left in training mode).
    """
    hyper = {
      "epochs_num_zy": 1000,
      "hidden_layer_size_zy": 256,
      "BN_type_zy": False,
      "ReLU_coef_zy": 0.6957593034943316,
      "drop_out_p_zy": 0.15,
      "G_lr_zy": 0.00021957327899192503,
      "weight_decay_zy": 0.00010530996833286099,
      "lambda_1_zy" : 1,
      "lambda_3_zy" : 0.020062972725244245
    }
    MSE_loss = torch.nn.MSELoss()

    lambda_1_zy = hyper['lambda_1_zy']
    lambda_3_zy = hyper['lambda_3_zy']

    # Third argument 0: no noise columns are appended to Z.
    G_zy = Generator(Z.shape[1], Y.shape[1], 0,
                     hidden_layer_size = hyper['hidden_layer_size_zy'],
                     BN_type = hyper['BN_type_zy'],
                     ReLU_coef = hyper['ReLU_coef_zy'],
                     drop_out_p = hyper['drop_out_p_zy']).to(device)
    G_zy_solver = optim.Adam(G_zy.parameters(), lr=hyper['G_lr_zy'], betas=(0.5, 0.999),
                             weight_decay=hyper['weight_decay_zy'])

    G_zy.train()

    for _ in range(hyper['epochs_num_zy']):
        # Full-batch forward pass.
        Y_fake = G_zy(Z)

        G_zy_solver.zero_grad()

        l1_regularization = sum(torch.linalg.vector_norm(weights, ord = 1) for weights in G_zy.parameters())

        g_zy_error = lambda_1_zy * MSE_loss(Y_fake, Y) + lambda_3_zy * l1_regularization

        g_zy_error.backward()
        torch.nn.utils.clip_grad_norm_(G_zy.parameters(), max_norm=0.5)
        G_zy_solver.step()

    return G_zy

def mGAN(n=500, z_dim=2, simulation='type1error', x_dims=2, y_dims=2, a_x=0.05, M=500, k=2, boot_num=1000,
     boor_rv_type = "gaussian", noise_dimension = 50, noise_type = "normal", input_var = 1.0/3.0):
    """Run one replication of the cross-fitted test and return its p-value.

    Steps: simulate (X, Y, Z) with ``get_xzy_randn_nl``; compute median-based
    kernel bandwidths on the full sample; for each of ``k`` contiguous folds,
    train ``G_zx`` and ``G_zy`` on the remaining rows, generate ``M`` X-samples
    and one Y prediction per held-out Z, standardise everything column-wise
    and evaluate ``get_p_value_stat``.  The p-value is the share of bootstrap
    replicates (averaged over folds) that exceed the fold-averaged statistic.

    Fixes relative to the previous version:
      * the simulator is called with the module-level ``device`` instead of
        its ``'cuda:0'`` default, so the function also runs on CPU-only hosts;
      * fold boundaries use integer arithmetic and every reshape uses the
        actual fold size, so ``n`` need not be divisible by ``k``;
      * placeholder ``torch.zeros`` buffers that were immediately overwritten
        are gone.

    Args:
        n: sample size of the simulated data set.
        z_dim: not used inside this function; kept for interface compatibility.
        simulation: ground-truth label forwarded to ``get_xzy_randn_nl``.
        x_dims: number of columns of X (used to reshape generated samples).
        y_dims: number of columns of Y (used to reshape predictions).
        a_x: signal strength forwarded to ``get_xzy_randn_nl``.
        M: number of generated X samples per held-out observation.
        k: number of folds.
        boot_num: number of bootstrap replicates.
        boor_rv_type: multiplier law for the bootstrap.
        noise_dimension, noise_type, input_var: generator noise settings.

    Returns:
        The p-value as a numpy float.
    """
    sim_x, sim_y, sim_z = get_xzy_randn_nl(n_points = n, ground_truth = simulation, a = a_x, device = device)

    x, y, z = sim_x.to(device), sim_y.to(device), sim_z.to(device)

    def median_pairwise(points, norm_ord, squared):
        # Median over all n x n pairwise distances (diagonal zeros included).
        dist = torch.linalg.vector_norm(points.unsqueeze(0) - points.unsqueeze(1), ord = norm_ord, dim = 2)
        if squared:
            dist = dist ** 2
        return torch.median(dist).item()

    def standardise(values):
        # Centre and scale along the first axis.
        return (values - torch.mean(values, dim=0, keepdim=True)) / torch.std(values, dim=0, keepdim=True)

    # Bandwidths for the Laplacian (L1) and Gaussian (squared L2) kernels.
    sigma_z_l = median_pairwise(z, 1, False)
    sigma_x_l = median_pairwise(x, 1, False)
    sigma_z_g = median_pairwise(z, 2, True) * 2
    sigma_x_g = median_pairwise(x, 2, True) * 2

    stat_all = torch.zeros(k, 1)
    boot_temp_all = torch.zeros(k, boot_num)

    for k_fold in range(k):
        k_fold_start = (n * k_fold) // k
        k_fold_end = (n * (k_fold + 1)) // k
        X_test, Y_test, Z_test = x[k_fold_start:k_fold_end], y[k_fold_start:k_fold_end], z[k_fold_start:k_fold_end]
        X_train = torch.cat((x[0:k_fold_start], x[k_fold_end:]))
        Y_train = torch.cat((y[0:k_fold_start], y[k_fold_end:]))
        Z_train = torch.cat((z[0:k_fold_start], z[k_fold_end:]))
        fold_size = X_test.shape[0]

        G_zx = get_Gzx(X = X_train, Y = Y_train, Z = Z_train, noise_dimension = noise_dimension,
                noise_type = noise_type, input_var = input_var,
                sigma_z_l = sigma_z_l, sigma_x_l = sigma_x_l,
                sigma_z_g = sigma_z_g, sigma_x_g = sigma_x_g)

        G_zy = get_Gzy(Y = Y_train, Z = Z_train)

        G_zx = G_zx.eval()

        Z_test_repeat = Z_test.repeat(M, 1).to(device)

        # Generate fake data
        Noise_fake = sample_noise(Z_test_repeat.shape[0], noise_dimension, noise_type, input_var = input_var).to(device)
        with torch.no_grad():
            gen_x_all = G_zx(torch.cat((Z_test_repeat, Noise_fake), dim=1)).to(device)

        # Rows are ordered copy-major (M blocks of the fold); bring them to
        # (fold_size, M, x_dims).
        gen_x_all = gen_x_all.reshape(M, fold_size, x_dims).swapaxes(0, 1).detach().to(device)

        G_zy = G_zy.eval()
        with torch.no_grad():
            gen_y_all = G_zy(Z_test.to(device)).to(device)
        gen_y_all = gen_y_all.reshape(-1, y_dims).detach().to(device)

        gen_x_all = standardise(gen_x_all)
        gen_y_all = standardise(gen_y_all)
        x_all = standardise(X_test.to(device))
        y_all = standardise(Y_test.to(device))
        z_all = standardise(Z_test.to(device))

        cur_stat, cur_boot_temp = get_p_value_stat(boot_num, M, fold_size, gen_x_all, gen_y_all,
                                x_all, y_all, z_all, boor_rv_type)
        stat_all[k_fold, :] = cur_stat
        boot_temp_all[k_fold, :] = torch.from_numpy(cur_boot_temp)

    return np.mean(torch.mean(boot_temp_all, dim = 0).numpy() > torch.mean(stat_all).item() )

def run_experiment(params):
    """Repeat the mGAN test ``n_test`` times and report rejection rates.

    The random generators of numpy and torch (and CUDA, when available) are
    seeded with ``params["set_seeds"]`` first.  After all replications the
    fraction of p-values below ``alpha`` and below ``alpha1`` is printed,
    labelled "Type 1 error" for ``test == 'H0'`` and "Power" for the two
    alternatives.

    The two previously duplicated loops are merged; an unknown ``test`` label
    now raises instead of silently returning an empty array, and the rejection
    rates are computed once after the loop (``nan`` when ``n_test`` is 0
    rather than a ``NameError``).

    Args:
        params: dict with the keys ``test``, ``sample_size``, ``z_dim``,
            ``dx``, ``dy``, ``n_test``, ``alpha_x``, ``m_value``, ``k_value``,
            ``j_value``, ``alpha``, ``alpha1``, ``set_seeds`` and
            ``boor_rv_type``.

    Returns:
        numpy array with the ``n_test`` p-values.

    Raises:
        ValueError: if ``params["test"]`` is not ``'H0'``, ``'H1_dense'`` or
            ``'H1_sparse'``.
    """
    test = params["test"]
    sample_size = params["sample_size"]
    z_dim = params["z_dim"]
    dx = params["dx"]
    dy = params["dy"]
    n_test = params["n_test"]
    alpha_x = params["alpha_x"]
    m_value = params["m_value"]
    k_value = params["k_value"]
    j_value = params["j_value"]
    alpha = params["alpha"]
    alpha1 = params["alpha1"]
    set_seeds = params["set_seeds"]
    boor_rv_type = params["boor_rv_type"]

    if test == 'H0':
        label = 'Type 1 error'
    elif test == 'H1_dense' or test == 'H1_sparse':
        label = 'Power'
    else:
        raise ValueError(f'{test} has to be H0, H1_sparse or H1_dense')

    np.random.seed(set_seeds)
    torch.manual_seed(set_seeds)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(set_seeds)

    collected = []
    for _ in tqdm(range(n_test)):
        collected.append(
            mGAN(n=sample_size, z_dim=z_dim, simulation=test, x_dims=dx, y_dims=dy, a_x=alpha_x, M=m_value,
                 k=k_value, boot_num=j_value, boor_rv_type = boor_rv_type))
    p_values = np.array(collected, dtype=float)

    # Rejection rate = share of replications whose p-value falls below the level.
    for level in (alpha, alpha1):
        rate = np.mean(p_values < level) if p_values.size else np.nan
        print((label + ': {} with significance level {}').format(rate, level))

    return p_values

In [None]:
print("---                     ---")
print("---                     ---")
print("--- The n = ",input_sample_size," case ---")


def _experiment_params(test_name):
    """Build the run_experiment configuration for one hypothesis label."""
    return {
      "test": test_name, # ['H0', 'H1_dense', 'H1_sparse']
      "sample_size":  input_sample_size, # [100, 200, 300, 400]
      "z_dim": 25, # [5, 50, 250]
      "dx": 25,
      "dy": 1,
      "n_test": input_n_test, # [500] in the paper
      "alpha_x": 0.50, # only used under alternative ['power_sparse': a = 0.5; 'power_dense': 1/sqrt(2*p2) {1/math.sqrt(2*25)}]
      "m_value": 100, # [500]
      "k_value": 2, # [1, 2, 4]
      "j_value": 1000, # [1000, 2000]
      "alpha": 0.1,
      "alpha1": 0.05,
      "set_seeds": input_seed,
      "boor_rv_type":  'rademacher' # ['rademacher', 'gaussian']
    }


def _rejections(p_vals, cutoff):
    """Return the per-replication rejection flags at `cutoff` and their mean."""
    flags = [pval < cutoff  for pval in p_vals]
    return flags, np.mean(flags)


# --- Null case: its empirical p-value quantiles serve as size-adjusted cut-offs below.
param = _experiment_params("H0")

print("--- The H0 case ---")
p_val_list = run_experiment(param)

import numpy as np
quantile_5 = np.quantile(p_val_list, 0.05)
quantile_10 = np.quantile(p_val_list, 0.10)

# --- Dense alternative.
param = _experiment_params("H1_dense")

print("--- The H1_dense case Empirical Power ---")
p_val_list_H1_dense = run_experiment(param)

fp, final_result = _rejections(p_val_list_H1_dense, quantile_10)
fp1, final_result1 = _rejections(p_val_list_H1_dense, quantile_5)


print('Adjusted Power: {} with significance level {}'.format(final_result, quantile_10))
print('Adjusted Power: {} with significance level {}'.format(final_result1, quantile_5))

# --- Sparse alternative.
param = _experiment_params("H1_sparse")

print("--- The H1_sparse case Empirical Power ---")
p_val_list_H1_sparse = run_experiment(param)

fp, final_result = _rejections(p_val_list_H1_sparse, quantile_10)
fp1, final_result1 = _rejections(p_val_list_H1_sparse, quantile_5)


print('Adjusted Power: {} with significance level {}'.format(final_result, quantile_10))
print('Adjusted Power: {} with significance level {}'.format(final_result1, quantile_5))

---                     ---
---                     ---
--- The n =  400  case ---
--- The H0 case ---


100%|██████████| 500/500 [5:45:08<00:00, 41.42s/it]


Type 1 error: 0.13 with significance level 0.1
Type 1 error: 0.07 with significance level 0.05
--- The H1_dense case Empirical Power ---


100%|██████████| 500/500 [5:44:02<00:00, 41.29s/it]


Power: 0.784 with significance level 0.1
Power: 0.648 with significance level 0.05
Adjusted Power: 0.744 with significance level 0.084
Adjusted Power: 0.578 with significance level 0.03495000000000001
--- The H1_sparse case Empirical Power ---


100%|██████████| 500/500 [5:42:20<00:00, 41.08s/it]

Power: 0.932 with significance level 0.1
Power: 0.812 with significance level 0.05
Adjusted Power: 0.906 with significance level 0.084
Adjusted Power: 0.762 with significance level 0.03495000000000001



