In [None]:
param = {
    # ----- experiment setup -----
    "test": "power",  # simulation mode, one of ['type1error', 'power']
    "sample_size": 600,  # samples drawn per repetition; for power, sample_size = 600
    "batch_size": 128,  # training mini-batch size, candidates [32, 64, 128, 256]
    "z_dim": 2,  # dimension of Z, candidates [5, 50, 250]
    "dx": 2,  # dimension of X
    "dy": 2,  # dimension of Y
    "n_test": 1000,  # number of repetitions of the whole test, candidates [200, 2000]
    "epochs_num": 400,  # training epochs per generator, candidates [1000, 1500]
    "eps_std": 0.5,  # forwarded to mGAN as `nstd`
    "dist_z": 'gaussian',  # forwarded to mGAN as `z_dist`, one of ['laplace', 'gaussian']
    "alpha_x": 0.05,  # only used under the alternative, candidates [0.05, 0.10, 0.15, 0.20, 0.25]

    # ----- test statistic -----
    "m_value": 100,  # generated samples per test point (M), candidates [100, 200]
    "k_value": 2,  # number of cross-fitting folds (k), candidates [1, 2, 4]
    "j_value": 1000,  # number of wild-bootstrap draws, candidates [1000, 2000]

    # ----- generator / optimisation -----
    "noise_dimension": 50,  # dimension of the generator noise input, candidates [5, 10, 20]
    "hidden_layer_size": 50,  # hidden width of the generator, candidates [64, 128, 256, 512, 1024]
    "normal_ini": False,  # one of [True, False]
    "preprocess": 'normalize',  # one of ['normalize', 'scale_Z', 'None']
    "G_lr": 5e-4,  # generator learning rate, candidates [5e-6, 1e-5, 2e-5, 5e-5]

    # ----- reporting -----
    "alpha": 0.1,  # first significance level used for the rejection rate
    "alpha1": 0.05,  # second significance level used for the rejection rate
    "set_seeds": 0,  # seed for numpy and torch
    "using_orcale": False,  # key spelling kept as-is: it is read verbatim by run_experiment

    # ----- loss weights / regularisation -----
    "lambda_1": 1,  # loss with Laplace kernel
    "lambda_2": 0,  # loss with Gaussian kernel
    "using_Gen": '1',  # ['1', '2'], generator type: "1" is fully connected, "2" is not
    "boor_rv_type": 'gaussian',  # bootstrap multiplier law, one of ['rademacher', 'gaussian']
    "wgt_decay": 1e-5,  # weight decay of the Adam optimizer (L2 regularization)
    "lambda_3": 1e-5,  # L1 regularization parameter
    "drop_out_p": 0.1,  # probability of an element being zeroed. Default: 0.5, best 0.2
    "M_train": 10,  # generated samples per training point inside the training loss
}


import torch
import torch.distributions as TD
from torch.utils.data import Dataset, DataLoader
from zmq import device
import torch.optim as optim
import numpy as np
from datetime import datetime
import functools
from tqdm import tqdm

# Pick the compute device: the GPU when CUDA is both available and enabled, otherwise the CPU.
enable_cuda = True
if torch.cuda.is_available() and enable_cuda:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def get_xzy_randn_nl(n_points, ground_truth='H0', dim=2, device='cuda:0', p = 0.0, **ignored):
    '''
    Draw a synthetic sample for a conditional-independence experiment.

    A latent Gaussian vector (local name ``y``) is drawn first and two noisy
    copies of it are built:
      * ``z = y + noise_z``
      * ``x = y + fresh noise`` under 'H0'.
        Under 'H1' each row of ``x`` reuses ``noise_z`` with probability ``p``
        (and gets fresh noise otherwise), which ties ``x`` to ``z`` beyond ``y``.

    Arguments:
        n_points: number of samples.
        ground_truth: 'H0' or 'H1'.
        dim: common dimension of the three returned tensors.
        device: device to create the tensors on. If a CUDA device is requested
            but CUDA is not available, the CPU is used instead of failing.
        p: per-row probability of reusing ``noise_z`` in ``x``; only used in 'H1'.
        ignored: extra keyword arguments are accepted and ignored.
    Output:
        Tuple ``(x, z, y)`` of tensors of shape (n_points, dim). Note the local
        naming: the third tensor is the latent/conditioning variable, the first
        two are its noisy copies.
    Raises:
        NotImplementedError: if ``ground_truth`` is neither 'H0' nor 'H1'.
    '''
    # Validate before doing any sampling work.
    if ground_truth not in ('H0', 'H1'):
        raise NotImplementedError(f'{ground_truth} has to be H0 or H1')

    # The default device is CUDA; degrade gracefully on CPU-only machines.
    if device is not None and torch.device(device).type == 'cuda' and not torch.cuda.is_available():
        device = 'cpu'

    y = torch.randn(n_points, dim, device=device) / np.sqrt(dim)
    # Drawn for both hypotheses so the random stream does not depend on ground_truth.
    m = TD.Bernoulli(torch.tensor([p]))
    delta = m.sample((n_points,)).to(device)  # shape (n_points, 1), broadcasts over dim

    noise_z = 0.1 * torch.randn(n_points, dim, device=device) / np.sqrt(dim)
    z = y + noise_z
    x = y.clone()

    fresh_noise = 0.1 * torch.randn_like(x) / np.sqrt(dim)
    if ground_truth == 'H1':
        x = x + (1 - delta) * fresh_noise + delta * noise_z
    else:
        x = x + fresh_noise

    return x, z, y

def _centered_l1sq_kernel(obs, gen, bandwidth, M):
    """Centered kernel matrix between observed rows and their generated replicates.

    Uses k(a, b) = exp(-||a - b||_1 ** 2 / bandwidth). Entry [a, b] of the result is
    k(obs_a, obs_b) - mean_m k(gen_b_m, obs_a) - mean_m k(gen_a_m, obs_b)
    + (1 / M) * sum_{i < M} mean_m k(gen_a_i, gen_b_m).

    obs has shape (n, d); gen has shape (n, M', d) with M <= M'.
    """
    n, dim = obs.shape[0], obs.shape[-1]

    def kernel(diff, axis):
        return torch.exp(-torch.linalg.vector_norm(diff, ord=1, dim=axis)**2 / bandwidth)

    # Broadcasting replaces the explicit repeat/reshape: entry [a, b, m] pairs gen[b, m] with row a.
    gen_row = gen.unsqueeze(0)

    k_obs = kernel(obs.unsqueeze(0) - obs.unsqueeze(1), 2)
    k_cross = torch.mean(kernel(gen_row - obs.reshape(n, 1, 1, dim), 3), dim=2)

    k_gen_sum = torch.zeros_like(k_obs)
    for i in range(M):
        anchor = gen[:, i].reshape(n, 1, 1, dim)
        k_gen_sum = k_gen_sum + torch.mean(kernel(gen_row - anchor, 3), dim=2)

    return k_obs - k_cross - k_cross.T + 1 / M * k_gen_sum


def get_p_value_stat_1(boot_num, M, n, gen_x_all_torch, gen_y_all_torch, x_torch, y_torch, z_torch, sigma_w, sigma_u=1,
                       sigma_v=1,
                       boor_rv_type="gaussian"):
    """
    Compute the test statistic and its wild-bootstrap replicates.

    All kernels have the form exp(-||a - b||_1 ** 2 / sigma).

    Input:
    - boot_num: Integer giving the number of bootstrap replicates.
    - M: Integer giving the number of generated samples per observation.
    - n: Integer giving the number of observations (rows of x/y/z).
    - gen_x_all_torch: PyTorch Tensor (n, M, dimension_X) of generated data of X.
    - gen_y_all_torch: PyTorch Tensor (n, M, dimension_Y) of generated data of Y.
    - x_torch: PyTorch Tensor (n, dimension_X) of observed X.
    - y_torch: PyTorch Tensor (n, dimension_Y) of observed Y.
    - z_torch: PyTorch Tensor (n, dimension_Z) of observed Z.
    - sigma_w: Float giving the bandwidth of the kernel on Z.
    - sigma_u: Float giving the bandwidth of the kernel on Y.
    - sigma_v: Float giving the bandwidth of the kernel on X.
    - boor_rv_type: "rademacher" or "gaussian", the law of the bootstrap multipliers.

    Output:
    - stat: Float, sum of the off-diagonal entries of U * V * W divided by (n - 1).
    - boottemp: numpy array of shape (boot_num,), the same quantity with entry
      [a, b] multiplied by e_a * e_b for each bootstrap multiplier vector e.

    Raises:
    - ValueError: if boor_rv_type is not one of the supported values.
    """
    if boor_rv_type not in ("rademacher", "gaussian"):
        raise ValueError(f'boor_rv_type must be "rademacher" or "gaussian", got {boor_rv_type!r}')

    # Work on whatever device the inputs live on instead of a module-level global.
    dev = z_torch.device

    w_mx = torch.exp(
        -torch.linalg.vector_norm(z_torch.unsqueeze(0) - z_torch.unsqueeze(1), ord=1, dim=2)**2 / sigma_w)
    u_mx = _centered_l1sq_kernel(y_torch, gen_y_all_torch, sigma_u, M)
    v_mx = _centered_l1sq_kernel(x_torch, gen_x_all_torch, sigma_v, M)

    # Zero the diagonal so only pairs of distinct observations contribute.
    FF_mx = u_mx * v_mx * w_mx * (1 - torch.eye(n, device=dev))

    stat = 1 / (n - 1) * torch.sum(FF_mx).item()

    # Multipliers are drawn on the CPU generator (as before) and then moved to the data device.
    draws = torch.randn(n, boot_num)
    if boor_rv_type == "rademacher":
        draws = torch.sign(draws)
    eboot = draws.to(device=dev, dtype=FF_mx.dtype)

    # e_b^T FF e_b for every bootstrap column b at once, instead of one matrix per replicate.
    quad_forms = torch.sum(eboot * torch.matmul(FF_mx, eboot), dim=0)
    boottemp = 1 / (n - 1) * quad_forms.detach().cpu().numpy().astype(np.float64)
    return stat, boottemp

class DatasetSelect(Dataset):
    """
    Dataset wrapper over aligned (X, Y, Z) samples, for use with a DataLoader.

    Input:
    - X: PyTorch Tensor of shape (N, dimension_X) giving the data of X.
    - Y: PyTorch Tensor of shape (N, dimension_Y) giving the data of Y.
    - Z: PyTorch Tensor of shape (N, dimension_Z) giving the data of Z.

    Item ``index`` is the tuple (X[index], Y[index], Z[index]); the length is
    the number of rows of X.
    """

    def __init__(self, X, Y, Z):
        self.X_real = X
        self.Y_real = Y
        self.Z_real = Z
        self.sample_size = X.shape[0]

    def __len__(self):
        return self.sample_size

    def __getitem__(self, index):
        return self.X_real[index], self.Y_real[index], self.Z_real[index]

# Dataset yielding (X, Y, Z) rows plus a second, index-shifted Z row; wrap it in a DataLoader for training.

class DatasetSelect_GAN(torch.utils.data.Dataset):
    """
    Dataset over aligned (X, Y, Z) samples that also returns a second Z row.

    Input:
    - X: PyTorch Tensor of shape (N, dimension_X) giving the training data of X.
    - Y: PyTorch Tensor of shape (N, dimension_Y) giving the training data of Y.
    - Z: PyTorch Tensor of shape (N, dimension_Z) giving the training data of Z.
    - batch_size: Integer giving the batch size; used as the index shift below.

    Item ``index`` is (X[index], Y[index], Z[index], Z[shifted]) where
    ``shifted = (batch_size + index) % N``.
    """

    def __init__(self, X, Y, Z, batch_size):
        self.X_real, self.Y_real, self.Z_real = X, Y, Z
        self.batch_size = batch_size
        self.sample_size = X.shape[0]

    def __len__(self):
        return self.sample_size

    def __getitem__(self, index):
        # Wrap around so the shifted row always exists.
        shifted = (self.batch_size + index) % self.sample_size
        row = (self.X_real[index], self.Y_real[index], self.Z_real[index])
        return row + (self.Z_real[shifted],)

# Dataset yielding (Y, Z) rows only; wrap it in a DataLoader for training.

class DatasetSelect_GAN_ver2(torch.utils.data.Dataset):
    """
    Dataset over aligned (Y, Z) samples, for use with a DataLoader.

    Input:
    - Y: PyTorch Tensor of shape (N, dimension_Y) giving the training data of Y.
    - Z: PyTorch Tensor of shape (N, dimension_Z) giving the training data of Z.
    - batch_size: Integer giving the batch size (stored, not used for indexing).

    Item ``index`` is (Y[index], Z[index]); the length is the number of rows of Z.
    """

    def __init__(self, Y, Z, batch_size):
        self.Y_real, self.Z_real = Y, Z
        self.batch_size = batch_size
        self.sample_size = Z.shape[0]

    def __len__(self):
        return self.sample_size

    def __getitem__(self, index):
        pair = (self.Y_real[index], self.Z_real[index])
        return pair

##### Auxiliary functions #####

def sample_noise(sample_size, noise_dimension, noise_type, input_var):
    """
    Generate a PyTorch Tensor of random noise from the specified reference distribution.

    Input:
    - sample_size: the sample size of noise to generate.
    - noise_dimension: the dimension of noise to generate.
    - noise_type: "normal", "unif" or "Cauchy", giving the reference distribution.
    - input_var: variance of each coordinate; only used when noise_type is "normal".

    Output:
    - A PyTorch Tensor of shape (sample_size, noise_dimension). "normal" noise is
      created on the module-level `device`; "unif" and "Cauchy" noise is created
      on the default device, so callers should move the result where they need it.

    Raises:
    - ValueError: if noise_type is not one of the supported names.
    """
    if noise_type == "normal":
        # Isotropic Gaussian with covariance input_var * I.
        noise_generator = TD.MultivariateNormal(
            torch.zeros(noise_dimension).to(device), input_var * torch.eye(noise_dimension).to(device))
        return noise_generator.sample((sample_size,))

    if noise_type == "unif":
        return torch.rand(sample_size, noise_dimension)

    if noise_type == "Cauchy":
        # The distribution has batch shape (1,), hence the trailing axis to squeeze away.
        return TD.Cauchy(torch.tensor([0.0]), torch.tensor([1.0])).sample((sample_size, noise_dimension)).squeeze(2)

    # Previously an unknown name fell through to `return Z` with Z unbound.
    raise ValueError(f'noise_type must be "normal", "unif" or "Cauchy", got {noise_type!r}')

##### GAN architecture #####

class Generator(torch.nn.Module):
    """
    Feed-forward generator network.

    The forward pass applies two fully connected hidden layers, each followed by
    optional batch normalization, a Leaky ReLU and dropout, and then a fully
    connected output layer without activation.

    The constructor also creates `fc3`, `BN3`, `fc_temp`, `sigmoid`, `drop_out0`
    and `drop_out3`; `forward` does not use them. They are kept so that the
    module's parameters and the order in which they are initialized stay unchanged.

    Inputs:
    - input_dimension: Integer giving the dimension of input Z.
    - output_dimension: Integer giving the dimension of output X or Y.
    - noise_dimension: Integer giving the dimension of random noise.
    - hidden_layer_size: Integer giving the size of the hidden layers of the generator.
    - BN_type: truthy/falsy flag specifying whether batch normalization is included.
    - ReLU_coef: Scalar giving the negative slope of the Leaky ReLU layer.
    - drop_out_p: Float giving the dropout probability.
    - drop_input: Boolean stored on the module; not read by `forward`.

    Returns (from `forward`):
    - PyTorch Tensor of shape (batch, output_dimension).
    """

    def __init__(self, input_dimension, output_dimension, noise_dimension, hidden_layer_size, BN_type, ReLU_coef, drop_out_p,
                 drop_input = False):
      super().__init__()
      joint_dimension = input_dimension + noise_dimension

      self.BN_type = BN_type
      self.ReLU_coef = ReLU_coef
      self.fc1 = torch.nn.Linear(joint_dimension, hidden_layer_size, bias=True)
      if BN_type:
        # NOTE(review): the second positional argument of BatchNorm1d is `eps`, so 0.8 is
        # used as eps here (not as momentum). Left unchanged to keep results identical.
        for idx in (1, 2, 3):
          setattr(self, f"BN{idx}", torch.nn.BatchNorm1d(hidden_layer_size, 0.8, affine=False))
      self.leakyReLU1 = torch.nn.LeakyReLU(ReLU_coef)
      self.fc2 = torch.nn.Linear(hidden_layer_size, hidden_layer_size, bias=True)
      self.fc3 = torch.nn.Linear(hidden_layer_size, hidden_layer_size, bias=True)
      self.fc_last = torch.nn.Linear(hidden_layer_size, output_dimension, bias=True)
      self.fc_temp = torch.nn.Linear(joint_dimension, output_dimension, bias=True)
      self.sigmoid = torch.nn.Sigmoid()
      for idx in range(4):
        setattr(self, f"drop_out{idx}", torch.nn.Dropout(p=drop_out_p))
      self.drop_input = drop_input

    def forward(self, x):
      # First hidden layer.
      hidden = self.fc1(x)
      if self.BN_type:
        hidden = self.BN1(hidden)
      hidden = self.drop_out1(self.leakyReLU1(hidden))

      # Second hidden layer.
      hidden = self.fc2(hidden)
      if self.BN_type:
        hidden = self.BN2(hidden)
      hidden = self.drop_out2(self.leakyReLU1(hidden))

      # Linear read-out, no activation.
      return self.fc_last(hidden)

##### Training procedures #####


def find_loss(y_torch, gen_y_all_torch, z_torch, sigma_w, sigma_u, M):
    """
    Compute the kernel (MMD-type) training loss of a conditional generator.

    All kernels are Gaussian: k(a, b) = exp(-||a - b||_2 ** 2 / sigma).

    Inputs:
    - y_torch: PyTorch Tensor (n, d) of observed targets (X or Y).
    - gen_y_all_torch: PyTorch Tensor (n, M, d) of generated targets.
    - z_torch: PyTorch Tensor (n, dimension_Z) of conditioning inputs Z.
    - sigma_w: Float giving the bandwidth of the kernel on Z.
    - sigma_u: Float giving the bandwidth of the kernel on the targets.
    - M: Number of generated samples per observation.

    Outputs:
    - loss: 0-dim PyTorch Tensor, the sum over pairs a != b of
      u[a, b] * w[a, b], divided by n, where w is the Z kernel and
      u[a, b] = k(y_a, y_b) - mean_m k(gen_b_m, y_a) - mean_m k(gen_a_m, y_b)
                + (1 / M) * sum_{i < M} mean_m k(gen_a_i, gen_b_m).
    """
    n = z_torch.shape[0]
    # Feature dimension is read from the data rather than assumed to be 2.
    dim = y_torch.shape[-1]

    def gaussian_kernel(diff, axis, bandwidth):
        return torch.exp(-torch.linalg.vector_norm(diff, ord=2, dim=axis)**2 / bandwidth)

    w_mx = gaussian_kernel(z_torch.unsqueeze(0) - z_torch.unsqueeze(1), 2, sigma_w)

    # Broadcasting replaces the explicit repeat/reshape: entry [a, b, m] pairs gen[b, m] with row a.
    gen_row = gen_y_all_torch.unsqueeze(0)

    u_mx_1 = gaussian_kernel(y_torch.unsqueeze(0) - y_torch.unsqueeze(1), 2, sigma_u)
    u_mx_2 = torch.mean(gaussian_kernel(gen_row - y_torch.reshape(n, 1, 1, dim), 3, sigma_u), dim=2)
    u_mx_3 = u_mx_2.T

    sum_mx = torch.zeros_like(u_mx_1)
    for i in range(M):
        anchor = gen_y_all_torch[:, i].reshape(n, 1, 1, dim)
        sum_mx = sum_mx + torch.mean(gaussian_kernel(gen_row - anchor, 3, sigma_u), dim=2)

    u_mx_4 = 1 / M * sum_mx
    u_mx = u_mx_1 - u_mx_2 - u_mx_3 + u_mx_4

    # Mask the diagonal on the inputs' own device instead of a module-level global.
    FF_mx = u_mx * w_mx * (1 - torch.eye(n, device=z_torch.device))

    loss = 1 / (n) * torch.sum(FF_mx)
    return loss


def train_ver3(X, Y, Z, X_test, Y_test, Z_test, M,
      noise_dimension, noise_type, G_lr, hidden_layer_size,
      DataLoader, BN_type, ReLU_coef,
      epochs_num=10,  sigma_z = 1, sigma_x = 1, sigma_y = 1,
      normal_ini = False,
      lambda_1 = 1, lambda_2 = 1, using_Gen = '1', wgt_decay = 0,
      lambda_3 = 0, drop_out_p = 0.2, M_train = 3):
    """
    Train the two conditional generators (Z -> Y and Z -> X) with the kernel loss.

    Each mini-batch draws M_train generated samples per observation from each
    generator, evaluates `find_loss` against the observed batch, and takes one
    Adam step per generator with the gradient norm clipped to 0.5.

    Inputs:
    - X, Y, Z: PyTorch Tensors (sample_size, dimension) of training data; only
      their second dimension is read here, the batches come from `DataLoader`.
    - X_test, Y_test, Z_test, M: accepted for interface compatibility, not used.
    - noise_dimension: Integer giving the dimension of the generator noise.
    - noise_type: "normal", "unif" or "Cauchy", giving the reference distribution.
    - G_lr: Float giving the learning rate of both generators.
    - hidden_layer_size: Integer giving the size of the hidden layers of the generators.
    - DataLoader: iterable yielding (X_batch, Y_batch, Z_batch, Z_shifted_batch).
    - BN_type: flag specifying whether batch normalization is included.
    - ReLU_coef: Float giving the coefficient of the Leaky ReLU layer.
    - epochs_num: Number of passes over `DataLoader`.
    - sigma_z, sigma_x, sigma_y: Floats giving the kernel bandwidths for Z, X and Y.
    - normal_ini, lambda_1, lambda_2, using_Gen, lambda_3: accepted, not used.
    - wgt_decay: Float giving the Adam weight decay (L2 regularization).
    - drop_out_p: Float giving the dropout probability.
    - M_train: Number of generated samples per observation in the loss.

    Outputs:
    - G_zy: trained generator of Y given Z (left in training mode).
    - G_zx: trained generator of X given Z (left in training mode).
    """
    dim_z, dim_y, dim_x = Z.shape[1], Y.shape[1], X.shape[1]
    adam_betas = (0.5, 0.999)

    G_zy = Generator(dim_z, dim_y, noise_dimension, hidden_layer_size, BN_type, ReLU_coef, drop_out_p).to(device)
    G_zy_solver = optim.Adam(G_zy.parameters(), lr=G_lr, betas=adam_betas, weight_decay=wgt_decay)

    G_zx = Generator(dim_z, dim_x, noise_dimension, hidden_layer_size, BN_type, ReLU_coef, drop_out_p).to(device)
    G_zx_solver = optim.Adam(G_zx.parameters(), lr=G_lr, betas=adam_betas, weight_decay=wgt_decay)

    G_zy.train()
    G_zx.train()

    for _epoch in range(epochs_num):
        G_zy.train()
        G_zx.train()

        # The fourth element of each batch (shifted Z rows) is not used by this procedure.
        for X_real, Y_real, Z_real, _Z_shifted in DataLoader:
            X_real, Y_real, Z_real = X_real.to(device), Y_real.to(device), Z_real.to(device)

            batch_size = Z_real.shape[0]
            # Row r of the tiled tensor is Z_real[r % batch_size].
            Z_tiled = Z_real.repeat(M_train, 1)
            rows = Z_tiled.shape[0]

            # Sampling order (Y noise, Y forward, X noise, X forward) fixes the random stream.
            noise_for_y = sample_noise(rows, noise_dimension, noise_type, input_var = 1.0/3.0).to(device)
            Y_fake = G_zy(torch.cat((Z_tiled, noise_for_y), dim=1)).to(device)

            noise_for_x = sample_noise(rows, noise_dimension, noise_type, input_var = 1.0/3.0).to(device)
            X_fake = G_zx(torch.cat((Z_tiled, noise_for_x), dim=1)).to(device)

            # (M_train, batch, d) -> (batch, M_train, d), the layout find_loss expects.
            Y_fake = torch.swapaxes(Y_fake.reshape(M_train, batch_size, dim_y), 0, 1)
            X_fake = torch.swapaxes(X_fake.reshape(M_train, batch_size, dim_x), 0, 1)

            # --- update the Z -> Y generator ---
            G_zy_solver.zero_grad()
            G_zx_solver.zero_grad()
            loss_zy = find_loss(Y_real, Y_fake, Z_real, sigma_z, sigma_y, M_train)
            loss_zy.backward()
            torch.nn.utils.clip_grad_norm_(G_zy.parameters(), max_norm=0.5)
            G_zy_solver.step()

            # --- update the Z -> X generator ---
            G_zy_solver.zero_grad()
            G_zx_solver.zero_grad()
            loss_zx = find_loss(X_real, X_fake, Z_real, sigma_z, sigma_x, M_train)
            loss_zx.backward()
            torch.nn.utils.clip_grad_norm_(G_zx.parameters(), max_norm=0.5)
            G_zx_solver.step()

    return G_zy, G_zx

def mGAN(n=500, z_dim=2, simulation='type1error', batch_size=64, epochs_num=1000,
    nstd=1.0, z_dist='gaussian', x_dims=2, y_dims=2, a_x=0.05, M=500, k=2, boot_num=1000,
    noise_dimension = 1, hidden_layer_size = 512, normal_ini = False, preprocess = 'normalize',
    G_lr = 1e-5, using_orcale = False, lambda_1 = 1, lambda_2 = 1, using_Gen = '1',
    boor_rv_type = "gaussian", wgt_decay = 0, lambda_3 = 1, drop_out_p = 0.2, exp_num = 0, M_train = 3):
    """
    Simulate one data set, run the cross-fitted GAN test and return its p-value.

    For each of the k folds the two generators are trained on the remaining
    data, M samples of X and Y are generated for every held-out Z, and the
    statistic plus its wild-bootstrap replicates are computed on the held-out
    fold. Statistics and replicates are averaged over folds.

    Inputs:
    - n: Integer giving the number of samples to generate.
    - z_dim, x_dims, y_dims: Integers giving the nominal dimensions of Z, X and Y.
      They are not forwarded to the data generator (it uses its own default
      dimension); tensor shapes are read from the generated data.
    - simulation: 'type1error' or 'power'.
    - batch_size: Integer giving the batch size.
    - epochs_num: Number of epochs over the training dataset to use for training.
    - a_x: Float used in the alternative case (alpha_x).
    - M: Integer giving the number of generated samples per held-out observation.
    - k: Integer giving the number of folds; with k == 1 the generators are
      trained and evaluated on the same data.
    - boot_num: Integer giving the number of wild-bootstrap replicates.
    - noise_dimension: Integer giving the dimension of the generator noise.
    - hidden_layer_size: Integer giving the size of the hidden layers of the generators.
    - G_lr: Float giving the learning rate of the generators.
    - boor_rv_type: 'rademacher' or 'gaussian', type of the bootstrap random variable.
    - wgt_decay: Float giving the weight decay (L2 regularization).
    - drop_out_p: Float giving the dropout probability.
    - M_train: Number of generated samples per observation in the training loss.
    - normal_ini, lambda_1, lambda_2, lambda_3, using_Gen: forwarded to train_ver3.
    - nstd, z_dist, preprocess, using_orcale, exp_num: accepted, not used here.

    Outputs:
    - p-value: fraction of fold-averaged bootstrap replicates that exceed the
      fold-averaged statistic.

    Raises:
    - ValueError: if `simulation` is not 'type1error' or 'power'.
    """
    # Generate on the notebook's device rather than the generator's hard-coded CUDA default.
    if simulation == 'type1error':
        sim_x, sim_y, sim_z = get_xzy_randn_nl(n_points = n, ground_truth = 'H0', device = device)

    elif simulation == 'power':
        sim_x, sim_y, sim_z = get_xzy_randn_nl(n_points = n, ground_truth = 'H1', p = a_x, device = device)

    else:
        raise ValueError('Test does not exist.')

    x, y, z = sim_x.to(device), sim_y.to(device), sim_z.to(device)

    sigma_w_train, sigma_u_train, sigma_v_train =  1.0, 1.0, 1.0

    stat_all = torch.zeros(k, 1)
    boot_temp_all = torch.zeros(k, boot_num)

    for k_fold in range(k):
        k_fold_start = int(n/k * k_fold)
        k_fold_end = int(n/k * (k_fold+1))
        X_test, Y_test, Z_test = x[k_fold_start:k_fold_end], y[k_fold_start:k_fold_end], z[k_fold_start:k_fold_end]
        X_train, Y_train, Z_train = torch.cat((x[0:k_fold_start], x[k_fold_end:])), torch.cat((y[0:k_fold_start], y[k_fold_end:])), torch.cat((z[0:k_fold_start], z[k_fold_end:]))

        if (k == 1):
            X_train, Y_train, Z_train = X_test, Y_test, Z_test

        # Use the real size of this fold: int(n / k) is wrong for some folds when n % k != 0.
        fold_size = Z_test.shape[0]

        train_xyz = DatasetSelect_GAN(X_train, Y_train, Z_train, batch_size)
        DataLoader_xyz = torch.utils.data.DataLoader(train_xyz, batch_size=batch_size, shuffle=True)

        G_zy, G_zx = train_ver3(X = X_train, Y = Y_train, Z = Z_train, M = M,
                X_test = X_test, Y_test = Y_test, Z_test = Z_test,
                noise_dimension = noise_dimension, noise_type = "normal",
                G_lr = G_lr, hidden_layer_size = hidden_layer_size,
                DataLoader = DataLoader_xyz, BN_type = True, ReLU_coef = 0.2,
                epochs_num=epochs_num, sigma_z = sigma_w_train, sigma_x = sigma_v_train, sigma_y = sigma_u_train,
                normal_ini = normal_ini, lambda_1 = lambda_1, lambda_2 = lambda_2,
                using_Gen = using_Gen, wgt_decay = wgt_decay, lambda_3 = lambda_3,
                drop_out_p = drop_out_p, M_train = M_train)

        G_zx = G_zx.eval()
        G_zy = G_zy.eval()

        # Row r of the tiled tensor is Z_test[r % fold_size].
        Z_test_repeat = Z_test.repeat(M,1).to(device)

        # Generate fake data (Y first, then X, to keep the random stream unchanged).
        Noise_fake = sample_noise(Z_test_repeat.shape[0], noise_dimension, "normal", input_var = 1.0/3.0).to(device)
        with torch.no_grad():
            gen_y_all = G_zy(torch.cat((Z_test_repeat,Noise_fake),dim=1)).to(device)

        Noise_fake = sample_noise(Z_test_repeat.shape[0], noise_dimension, "normal", input_var = 1.0/3.0).to(device)
        with torch.no_grad():
            gen_x_all = G_zx(torch.cat((Z_test_repeat,Noise_fake),dim=1)).to(device)

        # (M * fold, d) -> (M, fold, d) -> (fold, M, d); d is taken from the data itself.
        gen_x_all = gen_x_all.reshape(M, fold_size, X_test.shape[1]).detach().to(device)
        gen_y_all = gen_y_all.reshape(M, fold_size, Y_test.shape[1]).detach().to(device)

        gen_x_all = torch.swapaxes(gen_x_all, 0, 1)
        gen_y_all = torch.swapaxes(gen_y_all, 0, 1)

        z_all = Z_test.to(device)
        x_all = X_test.to(device)
        y_all = Y_test.to(device)

        sigma_w, sigma_u, sigma_v = 1.0, 1.0, 1.0

        standardise = True

        if standardise:
            # Centre and scale every tensor along its first (observation) axis.
            gen_x_all = (gen_x_all - torch.mean(gen_x_all, dim=0, keepdim=True)) / torch.std(gen_x_all, dim=0, keepdim=True)
            gen_y_all = (gen_y_all - torch.mean(gen_y_all, dim=0, keepdim=True)) / torch.std(gen_y_all, dim=0, keepdim=True)
            x_all = (x_all - torch.mean(x_all, dim=0, keepdim=True)) / torch.std(x_all, dim=0, keepdim=True)
            y_all = (y_all - torch.mean(y_all, dim=0, keepdim=True)) / torch.std(y_all, dim=0, keepdim=True)
            z_all = (z_all - torch.mean(z_all, dim=0, keepdim=True)) / torch.std(z_all, dim=0, keepdim=True)

        cur_stat, cur_boot_temp = get_p_value_stat_1(boot_num, M, fold_size, gen_x_all.to(device), gen_y_all.to(device),
                                x_all.to(device), y_all.to(device), z_all.to(device), sigma_w, sigma_u, sigma_v,
                                boor_rv_type)
        stat_all[k_fold,:] = cur_stat
        boot_temp_all[k_fold,:] = torch.from_numpy(cur_boot_temp)

    return np.mean(torch.mean(boot_temp_all, dim = 0).numpy() > torch.mean(stat_all).item() )

def run_experiment(params):
    """
    Repeat the mGAN test `n_test` times and report empirical rejection rates.

    Inputs:
    - params: dict with the keys of the notebook's `param` configuration.
      `test` selects the simulation ('type1error' or 'power'), `n_test` the
      number of repetitions, `alpha` / `alpha1` the two significance levels,
      `set_seeds` the numpy/torch seed; the remaining keys are forwarded to mGAN.

    Outputs:
    - p_values: numpy array of shape (n_test,) with one p-value per repetition.

    Side effects:
    - Seeds numpy and torch, and prints the rejection rate at both levels
      (nan when n_test is 0).

    Raises:
    - ValueError: if params["test"] is not 'type1error' or 'power'.
    - KeyError: if a required key is missing from params.
    """
    test = params["test"]
    # Fail early: previously an unknown mode skipped both loops and crashed at the final print.
    if test not in ('type1error', 'power'):
        raise ValueError("params['test'] must be 'type1error' or 'power', got {!r}".format(test))

    z_dim = params["z_dim"]
    n_test = params["n_test"]
    alpha = params["alpha"]
    alpha1 = params["alpha1"]
    set_seeds = params["set_seeds"]

    # Keyword arguments shared by every repetition (both modes call mGAN identically).
    mgan_kwargs = dict(
        n=params["sample_size"], z_dim=z_dim, simulation=test, batch_size=params["batch_size"],
        epochs_num=params["epochs_num"],
        nstd=params["eps_std"], z_dist=params["dist_z"], x_dims=params["dx"], y_dims=params["dy"],
        a_x=params["alpha_x"], M=params["m_value"],
        k=params["k_value"], boot_num=params["j_value"],
        noise_dimension=params["noise_dimension"], hidden_layer_size=params["hidden_layer_size"],
        normal_ini=params["normal_ini"],
        preprocess=params["preprocess"], G_lr=params["G_lr"], using_orcale=params["using_orcale"],
        lambda_1=params["lambda_1"], lambda_2=params["lambda_2"], using_Gen=params["using_Gen"],
        boor_rv_type=params["boor_rv_type"],
        wgt_decay=params["wgt_decay"], lambda_3=params["lambda_3"], drop_out_p=params["drop_out_p"],
        M_train=params["M_train"],
    )

    np.random.seed(set_seeds)
    torch.manual_seed(set_seeds)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(set_seeds)

    collected = []
    for n in tqdm(range(n_test)):
        collected.append(mGAN(exp_num=n + 1, **mgan_kwargs))
    p_values = np.array(collected, dtype=float)

    # Empirical rejection rate = share of repetitions with p-value below the level.
    if p_values.size:
        final_result = np.mean(p_values < alpha)
        final_result1 = np.mean(p_values < alpha1)
    else:
        final_result = final_result1 = float('nan')

    print('Emp Rej Rate: {} for z dimension {} with significance level {}'.format(final_result, z_dim, alpha))
    print('Emp Rej Rate: {} for z dimension {} with significance level {}'.format(final_result1, z_dim, alpha1))

    return p_values


# p_val_list = run_experiment(param)
# p_val_list

In [None]:
# @title code to get Size Adjusted Power for Figure 6 (b) hat T_2


# Size-adjusted critical values replace the nominal levels for the power runs.
param["alpha"] = 0.039  # 5% quantile of the p_val_list from ours_size.ipynb when param["test"] = "type1error" param["sample_size"] = 600
param["alpha1"] = 0.083 # 10% quantile of the p_val_list from ours_size.ipynb when param["test"] = "type1error" param["sample_size"] = 600


# Keep the p-values of every alternative; `p_val_list` alone only holds the last run.
p_val_by_alpha_x = {}
for alpha_x in [0.05, 0.10, 0.15, 0.20, 0.25]:
    param["alpha_x"] = alpha_x
    p_val_list = run_experiment(param)
    p_val_by_alpha_x[alpha_x] = p_val_list