# In [None]:
# Experiment configuration. Trailing list comments record candidate values for
# parameter sweeps; the current value is not necessarily one of them.
param = {
    "test": "type1error",  # ['type1error', 'power']
    "sample_size": 200,  # [200, 400, 600, 800, 1000]
    "batch_size": 256,  # [32, 64, 128, 256]
    "z_dim": 1,  # [5, 50, 250]
    "dx": 1,
    "dy": 1,
    "n_test": 500,  # [200, 2000]
    "epochs_num": 500,  # [1000, 1500]
    "eps_std": 0.5,
    "dist_z": 'gaussian',  # ['laplace', 'gaussian']
    "alpha_x": 0.75,  # only used under the alternative
    "m_value": 100,  # [100, 200]
    "k_value": 2,  # [1, 2, 4]
    "j_value": 1000,  # [1000, 2000]
    "noise_dimension": 3,  # [5, 10, 20]
    "hidden_layer_size": 50,  # [64, 128, 256, 512, 1024]
    "normal_ini": False,  # [True, False]
    "preprocess": 'normalize',  # ['normalize', 'scale_Z', 'None']
    "G_lr": 5e-2,  # [5e-6, 1e-5, 2e-5, 5e-5]
    "alpha": 0.1,
    "alpha1": 0.05,
    "set_seeds": 42,
    "using_orcale": False,  # (sic) key spelling kept as-is for existing consumers
    "lambda_1": 1,  # weight of the loss with the Laplace kernel
    "lambda_2": 0,  # weight of the loss with the Gaussian kernel
    "using_Gen": '1',  # ['1', '2']: generator type; '1' is fully connected, '2' is non-fully connected
    "boor_rv_type": 'gaussian',  # ['rademacher', 'gaussian'] (sic: key spelling kept as-is)
    "wgt_decay": 1e-5,  # Adam weight decay (L2 regularization) parameter
    "lambda_3": 1e-5,  # L1 regularization parameter
    "drop_out_p": 0.2,  # probability of an element being zeroed (PyTorch default 0.5; 0.2 worked best)
    "M_train": 10
}

import torch
import torch.distributions as TD
from torch.utils.data import Dataset, DataLoader
from zmq import device
import torch.optim as optim
import numpy as np
from datetime import datetime
import functools
import tensorflow as tf
from scipy.stats import rankdata, ks_2samp, wilcoxon

# Pick the compute device: CUDA when it is available and `enable_cuda` is set,
# otherwise the CPU. NOTE: this rebinds the name `device`, shadowing the
# `device` imported from zmq in the import block above.
enable_cuda = True
if torch.cuda.is_available() and enable_cuda:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


def _min_max_scale(a):
    """Rescale array `a` to [0, 1] as (a - min) / (max - min).

    Empty arrays are returned (as float) unchanged. Constant arrays (max == min),
    for which the formula would divide by zero and yield NaN, map to all zeros.
    """
    if a.size == 0:
        return a.astype(float)
    lo, hi = a.min(), a.max()
    if hi == lo:
        return np.zeros_like(a, dtype=float)
    return (a - lo) / (hi - lo)


def generate_samples_random(Ax, Ay, size=1000, sType='CI', dx=1, dy=1, dz=20, nstd=0.05, alpha_x=0.05,
                            preprocess="None", dist_z='gaussian'):
    '''
    Generate `size` rows of binary (X, Y, Z) data, each variable a 0/1 column.

    Sampling scheme:
      1. The number of Z=0 rows is drawn from a two-category multinomial with
         probabilities (1/2, 1/2); the rest have Z=1.
      2. Within each Z group, (X, Y) is drawn uniformly from
         {(0,0), (0,1), (1,0), (1,1)}.
      3. Rows are randomly permuted.
    Hence X, Y and Z are mutually independent Bernoulli(1/2) variables.

    NOTE(review): `sType` is accepted but ignored, so 'CI' and any other value
    produce data from the same (independent) distribution. Ax, Ay, dx, dy, dz,
    nstd, alpha_x and dist_z are also unused; they are kept for signature
    compatibility with existing callers.

    Arguments:
        size: number of samples.
        preprocess: "normalize" min-max scales X, Y and Z to [0, 1] (a constant
                    column becomes all zeros); "scale_Z" divides Z by its max
                    (left unchanged when the max is 0); "None" or any other
                    value leaves the data unchanged.
    Output:
        X, Y, Z: float32 torch tensors of shape (size, 1).
    '''
    num = size

    # The RNG call sequence below is kept as-is so seeded runs reproduce.
    numbers_z = np.random.multinomial(num, [1 / 2.] * 2, size=1)
    number_z_zeros = numbers_z[0][0]
    number_z_ones = numbers_z[0][1]

    xy_arr = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
    xy_z_zero_index = np.random.choice(a=len(xy_arr), size=number_z_zeros, p=[1 / 4., 1 / 4., 1 / 4., 1 / 4.])
    xy_z_one_index = np.random.choice(a=len(xy_arr), size=number_z_ones, p=[1 / 4., 1 / 4., 1 / 4., 1 / 4.])

    xy = np.concatenate((xy_arr[xy_z_zero_index], xy_arr[xy_z_one_index]), axis=0)
    x = xy[:, 0]
    y = xy[:, 1]
    z = np.concatenate((np.zeros(number_z_zeros), np.ones(number_z_ones)), axis=0)

    indices = np.random.permutation(num)
    x, y, z = x[indices], y[indices], z[indices]
    X, Y, Z = x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)

    if preprocess == "normalize":
        Z = _min_max_scale(Z)
        X = _min_max_scale(X)
        Y = _min_max_scale(Y)
    elif preprocess == "scale_Z":
        # Guard the all-zero (or empty) case, which would otherwise give 0/0 = NaN.
        if Z.size and Z.max() != 0:
            Z = Z / Z.max()

    X = torch.from_numpy(np.array(X)).float()
    Y = torch.from_numpy(np.array(Y)).float()
    Z = torch.from_numpy(np.array(Z)).float()
    return X, Y, Z


def generate_samples_from_fixed_Z_random(Ax, Ay, Z, size=1000, sType='CI', dx=1, dy=1, dz=20, nstd=0.05, alpha_x=0.05,
                                         normalize=True, seed=None, dist_z='gaussian'):
    '''
    Generate post-nonlinear (X, Y) samples for a given conditioning matrix Z.

    With eps_x ~ N(0, I_dx) and eps_y ~ N(0, I_dy) drawn independently per row:
      - 'CI' (conditionally independent):
            X = sin(Z Ax + nstd * eps_x),  Y = cos(Z Ay + nstd * eps_y)
      - 'I' (independent):
            X = sin(nstd * eps_x),         Y = cos(nstd * eps_y)
      - any other value (dependent):
            X = sin(Z Ax + nstd * eps_x),  Y = cos(X Axy + Z Ay + nstd * eps_y)
        where Axy = alpha_x * ones(dx, dy).
    The noise has unit variance (not 0.25 as mentioned in the paper).

    Arguments:
        Ax, Ay: matrices such that Z @ Ax has shape (size, dx) and Z @ Ay has
                shape (size, dy).
        Z: conditioning tensor with `size` rows.
        size: number of samples.
        sType: 'CI', 'I', or anything else for the dependent case.
        dx, dy: dimensions of X and Y.
        nstd: noise scale.
        alpha_x: strength of the direct X -> Y effect in the dependent case.
        dz, normalize, seed, dist_z: unused; kept for signature compatibility.
    Output:
        X of shape (size, dx) and Y of shape (size, dy).
    '''
    num = size

    error_generator_x = TD.MultivariateNormal(torch.zeros(dx), torch.eye(dx))
    error_generator_y = TD.MultivariateNormal(torch.zeros(dy), torch.eye(dy))

    Axy = torch.ones((dx, dy)) * alpha_x

    if sType == 'CI':
        X = torch.sin(torch.matmul(Z, Ax) + nstd * error_generator_x.sample((num,)))
        Y = torch.cos(torch.matmul(Z, Ay) + nstd * error_generator_y.sample((num,)))
    elif sType == 'I':
        X = torch.sin(nstd * error_generator_x.sample((num,)))
        Y = torch.cos(nstd * error_generator_y.sample((num,)))
    else:
        X = torch.sin(torch.matmul(Z, Ax) + nstd * error_generator_x.sample((num,)))
        # Y must depend on the X that is actually returned. Re-drawing the
        # x-noise here (as before) made Y independent of X given Z, so the
        # "dependent" case was really conditionally independent.
        Y = torch.cos(torch.matmul(X, Axy) + torch.matmul(Z, Ay) + nstd * error_generator_y.sample((num,)))

    return X, Y


def _centered_laplace_gram(obs, gen, sigma, n, M):
    """Return the (n, n) centred Laplace-kernel matrix used by get_p_value_stat_1.

    With k(s, t) = exp(-|s - t| / sigma), o = obs[:, 0] and g = gen, entry [a, b] is
        k(o_a, o_b) - mean_j k(g_bj, o_a) - mean_j k(g_aj, o_b)
        + (1 / M) * sum_{i < M} mean_j k(g_ai, g_bj).
    `obs` is expected to be (n, 1) and `gen` (n, M') with M' generated values per row.
    """
    obs_rep = obs.repeat(1, n)
    gram_obs = torch.exp(-torch.abs(obs_rep - obs_rep.T) / sigma)

    gen_rep = gen.repeat(n, 1, 1)  # gen_rep[a, b, j] = gen[b, j]
    cross = torch.mean(torch.exp(-torch.abs(gen_rep - obs_rep.reshape(n, n, 1)) / sigma), dim=2)

    gen_sum = torch.mean(torch.exp(-torch.abs(gen_rep - gen_rep[:, :, 0].T.reshape(n, n, 1)) / sigma), dim=2)
    for i in range(1, M):
        gen_sum = gen_sum + torch.mean(
            torch.exp(-torch.abs(gen_rep - gen_rep[:, :, i].T.reshape(n, n, 1)) / sigma), dim=2)

    return gram_obs - cross - cross.T + 1 / M * gen_sum


def get_p_value_stat_1(boot_num, M, n, gen_x_all_torch, gen_y_all_torch, x_torch, y_torch, z_torch, sigma_w, sigma_u=1,
                       sigma_v=1,
                       boor_rv_type="gaussian"):
    """
    Compute the kernel test statistic and its multiplier-bootstrap replicates.

    FF[a, b] = U[a, b] * V[a, b] * W[a, b] for a != b (0 on the diagonal), where
    U / V are centred Laplace-kernel matrices for y / x against their generated
    samples (bandwidths sigma_u / sigma_v), and W[a, b] = exp(-||z_a - z_b||_1 / sigma_w).

    Args:
        boot_num: number of bootstrap replicates.
        M: number of generated-sample columns used in the gen-vs-gen term.
        n: number of rows (must match the tensors' first dimension).
        gen_x_all_torch, gen_y_all_torch: (n, M) generated samples for x and y.
        x_torch, y_torch: (n, 1) observed values.
        z_torch: (n, d_z) conditioning values.
        sigma_w, sigma_u, sigma_v: kernel bandwidths for z, y and x.
        boor_rv_type: "gaussian" (N(0, 1) multipliers) or "rademacher"
            (sign of N(0, 1) draws).

    Returns:
        stat: float, sum(FF) / (n - 1).
        boottemp: float64 numpy array of shape (boot_num,), where entry b is
            e_b^T FF e_b / (n - 1) for the b-th multiplier vector e_b.

    Raises:
        ValueError: if boor_rv_type is not "gaussian" or "rademacher".
    """
    if boor_rv_type not in ("rademacher", "gaussian"):
        raise ValueError(f"boor_rv_type must be 'rademacher' or 'gaussian', got {boor_rv_type!r}")

    w_mx = torch.linalg.vector_norm(z_torch.repeat(n, 1, 1) - torch.swapaxes(z_torch.repeat(n, 1, 1), 0, 1), ord=1,
                                    dim=2)
    w_mx = torch.exp(-w_mx / sigma_w)

    u_mx = _centered_laplace_gram(y_torch, gen_y_all_torch, sigma_u, n, M)
    v_mx = _centered_laplace_gram(x_torch, gen_x_all_torch, sigma_v, n, M)

    # Zero the diagonal; the mask lives on the same device as the inputs.
    off_diag = 1 - torch.eye(n, device=w_mx.device)
    FF_mx = u_mx * v_mx * w_mx * off_diag

    stat = 1 / (n - 1) * torch.sum(FF_mx).item()

    # Multipliers are drawn on the CPU generator (as before) and then moved.
    eboot = torch.randn(n, boot_num)
    if boor_rv_type == "rademacher":
        eboot = torch.sign(eboot)
    eboot = eboot.to(FF_mx.device)

    # All replicates at once: column b of eboot * (FF @ eboot) sums to e_b^T FF e_b.
    boot_sums = torch.sum(eboot * torch.matmul(FF_mx, eboot), dim=0)
    boottemp = (1 / (n - 1)) * boot_sums.detach().cpu().double().numpy()
    return stat, boottemp


class DatasetSelect(Dataset):
    """Map-style dataset returning the index-th row of X, Y and Z as a tuple.

    The dataset length is taken from the first dimension of X.
    """

    def __init__(self, X, Y, Z):
        self.sample_size = X.shape[0]
        self.X_real, self.Y_real, self.Z_real = X, Y, Z

    def __len__(self):
        return self.sample_size

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in (self.X_real, self.Y_real, self.Z_real))


# Dataset over (X, Y, Z) that also yields a shifted Z row; wrap it in a DataLoader for training.

class DatasetSelect_GAN(torch.utils.data.Dataset):
    """
    Map-style dataset used to build the DataLoader for generator training.

    Item `index` is (X[index], Y[index], Z[index], Z[(index + batch_size) % N]),
    where N = X.shape[0]. The last element is a Z row from a shifted position.

    Input:
    - X: tensor of shape (N, ...) with the training values of X.
    - Y: tensor of shape (N, ...) with the training values of Y.
    - Z: tensor of shape (N, ...) with the training values of Z.
    - batch_size: integer offset used to pick the shifted Z row.
    """

    def __init__(self, X, Y, Z, batch_size):
        self.sample_size = X.shape[0]
        self.batch_size = batch_size
        self.X_real = X
        self.Y_real = Y
        self.Z_real = Z

    def __len__(self):
        return self.sample_size

    def __getitem__(self, index):
        shifted = (index + self.batch_size) % self.sample_size
        return (self.X_real[index],
                self.Y_real[index],
                self.Z_real[index],
                self.Z_real[shifted])


# Dataset over (Y, Z) pairs; wrap it in a DataLoader for training.

class DatasetSelect_GAN_ver2(torch.utils.data.Dataset):
    """
    Map-style dataset over paired (Y, Z) tensors.

    Item `index` is (Y[index], Z[index]); the dataset length is Z.shape[0].

    Input:
    - Y: tensor indexed along its first dimension, giving the training data of Y.
    - Z: tensor indexed along its first dimension, giving the training data of Z.
    - batch_size: integer stored as `self.batch_size`; not used when indexing.
    """

    def __init__(self, Y, Z, batch_size):
        self.sample_size = Z.shape[0]
        self.batch_size = batch_size
        self.Z_real = Z
        self.Y_real = Y

    def __len__(self):
        return self.sample_size

    def __getitem__(self, index):
        y_row = self.Y_real[index]
        z_row = self.Z_real[index]
        return y_row, z_row


##### Auxiliary functions #####

def sample_noise(sample_size, noise_dimension, noise_type, input_var):
    """
    Generate a PyTorch Tensor of random noise from the specified reference distribution.

    Input:
    - sample_size: the number of noise rows to generate.
    - noise_dimension: the dimension of each noise row.
    - noise_type: reference distribution, one of:
        "normal": N(0, input_var * I), sampled on the module-level `device`;
        "unif":   Uniform[0, 1) per coordinate (torch's default device);
        "Cauchy": standard Cauchy per coordinate (torch's default device).
    - input_var: per-coordinate variance; only used for "normal".

    Output:
    - A PyTorch Tensor of shape (sample_size, noise_dimension).

    Raises:
    - ValueError: if noise_type is not one of the names above (previously this
      surfaced as an UnboundLocalError).
    """
    if noise_type == "normal":
        noise_generator = TD.MultivariateNormal(
            torch.zeros(noise_dimension).to(device), input_var * torch.eye(noise_dimension).to(device))
        return noise_generator.sample((sample_size,))
    elif noise_type == "unif":
        return torch.rand(sample_size, noise_dimension)
    elif noise_type == "Cauchy":
        # Batch shape (1,) -> samples of shape (sample_size, noise_dimension, 1); drop the last axis.
        return TD.Cauchy(torch.tensor([0.0]), torch.tensor([1.0])).sample((sample_size, noise_dimension)).squeeze(2)
    raise ValueError(f"noise_type must be 'normal', 'unif' or 'Cauchy', got {noise_type!r}")


##### GAN architecture #####

class Generator(torch.nn.Module):
    """
    Fully connected generator network.

    Forward pass:
        [drop_out0 if drop_input] ->
        fc1 -> [BN1] -> LeakyReLU -> drop_out1 ->
        fc2 -> [BN2] -> LeakyReLU -> drop_out2 ->
        fc_last -> Sigmoid (the sigmoid is applied only when BN_type is falsy).
    BN3, fc3 and drop_out3 are constructed (fc3 also owns parameters) but are
    not used in forward.

    Inputs:
    - input_dimension: integer width of the conditioning input.
    - output_dimension: integer width of the output.
    - noise_dimension: integer width of the noise concatenated to the input;
      fc1 expects input_dimension + noise_dimension features.
    - hidden_layer_size: width of the hidden layers.
    - BN_type: truthy -> BatchNorm1d after fc1/fc2 and no output sigmoid;
      falsy -> no batch norm, sigmoid output.
    - ReLU_coef: negative slope of the Leaky ReLU.
    - drop_out_p: dropout probability shared by all Dropout layers.
    - drop_input: if True, dropout is also applied to the raw input.

    Returns (forward):
    - Tensor of shape (batch, output_dimension).
    """

    def __init__(self, input_dimension, output_dimension, noise_dimension, hidden_layer_size, BN_type, ReLU_coef,
                 drop_out_p,
                 drop_input=False):
        super(Generator, self).__init__()
        self.BN_type = BN_type
        self.ReLU_coef = ReLU_coef
        # Submodules are created in a fixed order: it determines seeded weight
        # initialisation and the ordering of .parameters() / state_dict keys.
        self.fc1 = torch.nn.Linear(input_dimension + noise_dimension, hidden_layer_size, bias=True)
        if BN_type:
            # NOTE(review): BatchNorm1d's second positional argument is `eps`,
            # so this sets eps=0.8 (momentum keeps its default). Confirm intent.
            for bn_name in ("BN1", "BN2", "BN3"):
                setattr(self, bn_name, torch.nn.BatchNorm1d(hidden_layer_size, 0.8, affine=False))
        self.leakyReLU1 = torch.nn.LeakyReLU(ReLU_coef)
        self.fc2 = torch.nn.Linear(hidden_layer_size, hidden_layer_size, bias=True)
        self.fc3 = torch.nn.Linear(hidden_layer_size, hidden_layer_size, bias=True)
        self.fc_last = torch.nn.Linear(hidden_layer_size, output_dimension, bias=True)
        self.sigmoid = torch.nn.Sigmoid()
        for idx in range(4):
            setattr(self, f"drop_out{idx}", torch.nn.Dropout(p=drop_out_p))
        self.drop_input = drop_input

    def forward(self, x):
        h = self.drop_out0(x) if self.drop_input else x
        stages = (
            (self.fc1, "BN1", self.drop_out1),
            (self.fc2, "BN2", self.drop_out2),
        )
        for linear, bn_name, dropout in stages:
            h = linear(h)
            if self.BN_type:
                h = getattr(self, bn_name)(h)
            h = dropout(self.leakyReLU1(h))
        h = self.fc_last(h)
        if not self.BN_type:
            h = self.sigmoid(h)
        return h


class NonFullyConnected_1(torch.nn.Module):
    """
    Linear layer whose weight is restricted to a block-diagonal pattern.

    The wrapped nn.Linear maps m * size_in -> m * size_out features. Before every
    forward call its weight is multiplied in place by a block-diagonal 0/1 mask
    of m blocks of shape (size_out, size_in), so output group i only sees input
    group i. Both the layer and the mask are placed on the module-level `device`.
    """

    def __init__(self, size_in, size_out, m, bias=True):
        super(NonFullyConnected_1, self).__init__()
        total_in, total_out = m * size_in, m * size_out
        self.linear = torch.nn.Linear(total_in, total_out, bias=bias).to(device)
        ones_blocks = [torch.ones(size_out, size_in) for _ in range(m)]
        # NOTE(review): the mask is a plain attribute, not a registered buffer,
        # so module.to(...) will not move it along with the layer.
        self.mask = functools.reduce(torch.block_diag, ones_blocks).to(device)

    def forward(self, x):
        # Gradients are not masked, so optimizer steps can make off-block
        # weights non-zero again; re-apply the mask before each use.
        weight = self.linear.weight.data
        weight.mul_(self.mask)
        return self.linear(x)


class Generator_2(torch.nn.Module):
    """
    Generator with `ntargets_k` groups of hidden units.

    Forward pass:
        Linear(input_dimension + noise_dimension -> ntargets_k * hidden_layer_size)
        -> [BN1] -> LeakyReLU
        -> hidden_layer_depth x (NonFullyConnected_1 block-diagonal layer -> [BN1] -> LeakyReLU)
        -> linear8 (-> output_dimension), with no output activation.

    Inputs:
    - input_dimension: width of the conditioning input.
    - output_dimension: width of the output.
    - noise_dimension: width of the noise concatenated to the input.
    - hidden_layer_size: width of each of the ntargets_k hidden groups.
    - BN_type: truthy -> apply the BatchNorm1d `BN1` before each LeakyReLU.
    - ReLU_coef: negative slope of the Leaky ReLU.
    - hidden_layer_depth: number of block-diagonal hidden layers (0 allowed).
    - ntargets_k: number of hidden groups.
    """

    def __init__(
            self,
            input_dimension,
            output_dimension,
            noise_dimension,
            hidden_layer_size,
            BN_type,
            ReLU_coef,
            hidden_layer_depth=1,
            ntargets_k=5):
        super(Generator_2, self).__init__()
        self.input_dimension = input_dimension + noise_dimension
        self.output_dimension = output_dimension
        self.ntargets_k = ntargets_k
        self.hidden_layer_sizes = [hidden_layer_size] * hidden_layer_depth
        self.BN_type = BN_type
        # Width of every hidden activation (all groups side by side).
        hidden_width = ntargets_k * hidden_layer_size
        self.leakyrelu = torch.nn.LeakyReLU(ReLU_coef)
        self.linear_layers_from_input = torch.nn.Linear(self.input_dimension, hidden_width)

        self.linear_layers_between = torch.nn.ModuleList([
            NonFullyConnected_1(hidden_layer_size, hidden_layer_size, ntargets_k)
            for _ in range(len(self.hidden_layer_sizes))
        ])
        self.linear8 = torch.nn.Linear(hidden_width, self.output_dimension)
        if BN_type:
            # BN1 normalises the full hidden activation, which has
            # ntargets_k * hidden_layer_size features. Sizing it to
            # hidden_layer_size (as before) crashed whenever ntargets_k != 1.
            # NOTE(review): the same BN1 (and its running statistics) is shared
            # by every hidden layer; the second positional argument is eps=0.8.
            self.BN1 = torch.nn.BatchNorm1d(hidden_width, 0.8, affine=False)

    def _activate(self, hidden):
        """Apply the optional shared batch norm, then the Leaky ReLU."""
        if self.BN_type:
            hidden = self.BN1(hidden)
        return self.leakyrelu(hidden)

    def forward(self, input):
        output = self._activate(self.linear_layers_from_input(input))
        for layer in self.linear_layers_between:
            output = self._activate(layer(output))
        return self.linear8(output)


##### Training procedures #####


def find_loss(y_torch, gen_y_all_torch, z_torch, sigma_w, sigma_u, M):
    """
    Kernel training loss for a generator of y given z.

    With k(s, t) = exp(-|s - t| / sigma_u) and g = gen_y_all_torch, the centred matrix
        U[a, b] = k(y_a, y_b) - mean_j k(g_bj, y_a) - mean_j k(g_aj, y_b)
                  + (1 / M) * sum_{i < M} mean_j k(g_ai, g_bj)
    is weighted by W[a, b] = exp(-||z_a - z_b||_1 / sigma_w). The loss is the sum
    of U * W over a != b, divided by n.

    Args:
        y_torch: observed values, expected shape (n, 1).
        gen_y_all_torch: generated values of shape (n, M'); row a holds the
            samples paired with z_a. M is normally equal to M' (TODO confirm
            at call sites).
        z_torch: conditioning values of shape (n, d_z).
        sigma_w, sigma_u: kernel bandwidths for z and y.
        M: number of generated columns used in the last (gen-vs-gen) term.

    Returns:
        Scalar tensor, differentiable with respect to gen_y_all_torch.
    """
    n = z_torch.shape[0]
    w_mx = torch.linalg.vector_norm(z_torch.repeat(n, 1, 1) - torch.swapaxes(z_torch.repeat(n, 1, 1), 0, 1), ord=1,
                                    dim=2)
    w_mx = torch.exp(-w_mx / sigma_w)

    y_rep = y_torch.repeat(1, n)
    u_mx_1 = torch.exp(-torch.abs(y_rep - y_rep.T) / sigma_u)

    gen_y_all_torch_rep = gen_y_all_torch.repeat(n, 1, 1)  # [a, b, j] = gen[b, j]
    u_mx_2 = torch.mean(
        torch.exp(-torch.abs(gen_y_all_torch_rep - y_rep.reshape(n, n, 1)) / sigma_u), dim=2)
    u_mx_3 = u_mx_2.T

    temp_mx = gen_y_all_torch_rep[:, :, 0].T
    sum_mx = torch.mean(torch.exp(-torch.abs(gen_y_all_torch_rep - temp_mx.reshape(n, n, 1)) / sigma_u), dim=2)
    for i in range(1, M):
        temp_mx = gen_y_all_torch_rep[:, :, i].T
        sum_mx = sum_mx + torch.mean(
            torch.exp(-torch.abs(gen_y_all_torch_rep - temp_mx.reshape(n, n, 1)) / sigma_u), dim=2)

    u_mx_4 = 1 / M * sum_mx
    u_mx = u_mx_1 - u_mx_2 - u_mx_3 + u_mx_4

    # Zero the diagonal; build the mask on the inputs' device instead of
    # relying on the module-level `device` global.
    FF_mx = u_mx * w_mx * (1 - torch.eye(n, device=u_mx.device))

    loss = 1 / (n) * torch.sum(FF_mx)
    return loss


def train_ver3(X, Y, Z, X_test, Y_test, Z_test, M,
               noise_dimension, noise_type, G_lr, hidden_layer_size,
               DataLoader, BN_type, ReLU_coef,
               epochs_num=10, sigma_z=1, sigma_x=1, sigma_y=1,
               normal_ini=False,
               lambda_1=1, lambda_2=1, using_Gen='1', wgt_decay=0,
               lambda_3=0, drop_out_p=0.2, M_train=3):
    """
    Train the conditional generator G_zx: (Z, noise) -> X.

    Each batch minimises
        lambda_1 * find_loss(X_real, X_fake, Z_real, sigma_z, sigma_x, M_train)
        + lambda_3 * sum_p ||p||_1   (L1 penalty over G_zx parameters),
    with Adam (betas (0.5, 0.999), weight decay wgt_decay) and gradients
    clipped to norm 0.5.

    Inputs:
    - X, Y, Z: training tensors. Only their column counts (shape[1]) are read
      here; the batches come from `DataLoader`.
    - X_test, Y_test, Z_test, M, sigma_y, lambda_2: accepted for interface
      compatibility; not used.
    - noise_dimension: width of the noise fed to the generators.
    - noise_type: reference noise passed to sample_noise ("normal", "unif", "Cauchy").
    - G_lr: Adam learning rate.
    - hidden_layer_size: hidden width of the generators.
    - DataLoader: iterable of (X_real, Y_real, Z_real, Z_fake) batches, for
      example a DataLoader over DatasetSelect_GAN. Y_real and Z_fake are ignored.
    - BN_type, ReLU_coef, drop_out_p: generator architecture options.
    - epochs_num: number of passes over DataLoader.
    - sigma_z, sigma_x: kernel bandwidths on Z and X used by find_loss.
    - normal_ini: if True, re-initialise every generator parameter from
      N(0, 1 / (2 * hidden_layer_size)).
    - lambda_1, lambda_3: weights of the kernel loss and the L1 penalty.
    - using_Gen: '1' -> Generator, '2' -> Generator_2.
    - wgt_decay: Adam weight decay.
    - M_train: number of generated samples per Z row in each batch.

    Returns:
    - (G_zy, G_zx). NOTE(review): G_zy is constructed (and optionally
      re-initialised) but is never trained here; confirm this is intended.

    Raises:
    - ValueError: if using_Gen is not '1' or '2'.
    """
    input_dimension = Z.shape[1]
    output_dimension_y = Y.shape[1]
    output_dimension_x = X.shape[1]

    if using_Gen == '1':
        def make_generator(output_dimension):
            return Generator(input_dimension, output_dimension, noise_dimension, hidden_layer_size, BN_type,
                             ReLU_coef, drop_out_p)
    elif using_Gen == '2':
        def make_generator(output_dimension):
            return Generator_2(input_dimension, output_dimension, noise_dimension, hidden_layer_size, BN_type,
                               ReLU_coef)
    else:
        raise ValueError(f"using_Gen must be '1' or '2', got {using_Gen!r}")

    # Construction order (G_zy first, then G_zx) is kept for seeded reproducibility.
    G_zy = make_generator(output_dimension_y).to(device)
    G_zx = make_generator(output_dimension_x).to(device)
    G_zx_solver = optim.Adam(G_zx.parameters(), lr=G_lr, betas=(0.5, 0.999), weight_decay=wgt_decay)

    G_zy = G_zy.train()
    G_zx = G_zx.train()

    if normal_ini:
        for p in G_zy.parameters():
            p.data = torch.randn(
                p.shape, device=device,
                dtype=torch.float32) / np.sqrt(float(hidden_layer_size * 2))

        for p in G_zx.parameters():
            p.data = torch.randn(
                p.shape, device=device,
                dtype=torch.float32) / np.sqrt(float(hidden_layer_size * 2))

    for epoch in range(epochs_num):
        G_zy = G_zy.train()
        G_zx = G_zx.train()
        for X_real, _Y_real, Z_real, _Z_fake in DataLoader:
            X_real = X_real.to(device)
            Z_real = Z_real.to(device)

            batch_size = Z_real.shape[0]
            # repeat_interleave puts M_train consecutive copies of each Z row,
            # so after the reshape below row a of X_fake holds M_train draws
            # generated from Z_real[a]. The previous Z_real.repeat(M_train, 1)
            # tiled the whole batch, which paired row a with the wrong Z rows.
            Z_real_repeat = Z_real.repeat_interleave(M_train, dim=0)

            # Generate fake data
            Noise_fake = sample_noise(Z_real_repeat.shape[0], noise_dimension, noise_type, input_var=1.0 / 3.0).to(
                device)
            X_fake = G_zx(torch.cat((Z_real_repeat, Noise_fake), dim=1)).to(device)
            X_fake = X_fake.reshape(batch_size, M_train)

            # Generator step
            G_zx_solver.zero_grad()
            l1_regularization = sum(torch.linalg.vector_norm(p, ord=1) for p in G_zx.parameters())
            g_zx_error = lambda_1 * find_loss(X_real, X_fake, Z_real, sigma_z, sigma_x,
                                              M_train) + lambda_3 * l1_regularization
            g_zx_error.backward()
            torch.nn.utils.clip_grad_norm_(G_zx.parameters(), max_norm=0.5)
            G_zx_solver.step()

    return G_zy, G_zx

def rdc(x, y, f=np.sin, k=20, s=1 / 6., n=1):
    """
    Computes the Randomized Dependence Coefficient
    x,y: array-likes convertible with np.asarray (numpy arrays, CPU tensors, ...),
         with samples along the first axis.
         If 1-D, size (samples,) -> treated as a single variable
         If 2-D, size (samples, variables) -> used as is
         Higher-rank input is flattened to (samples, -1).
    f:   function to use for random projection
    k:   number of random projections to use
    s:   scale parameter
    n:   number of times to compute the RDC and
         return the median (for stability)
    According to the paper, the coefficient should be relatively insensitive to
    the settings of the f, k, and s parameters.

    Source: https://github.com/garydoranjr/rdc
    """
    # Plain NumPy is used instead of TensorFlow reshapes. The old flattening to
    # (samples,) also rejected genuine 2-D (samples, variables) input.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.ndim != 2:
        x = x.reshape(x.shape[0], -1)
    if y.ndim != 2:
        y = y.reshape(y.shape[0], -1)

    if n > 1:
        values = []
        for i in range(n):
            try:
                values.append(rdc(x, y, f, k, s, 1))
            except np.linalg.LinAlgError:  # public alias; np.linalg.linalg is private in NumPy 2
                pass
        return np.median(values)

    # Copula Transformation
    cx = np.column_stack([rankdata(xc, method='ordinal') for xc in np.transpose(x)]) / float(x.shape[0])
    cy = np.column_stack([rankdata(yc, method='ordinal') for yc in np.transpose(y)]) / float(y.shape[0])

    # Add a vector of ones so that w.x + b is just a dot product
    O = np.ones(cx.shape[0])
    X = np.column_stack([cx, O])
    Y = np.column_stack([cy, O])

    # Random linear projections
    Rx = (s / X.shape[1]) * np.random.randn(X.shape[1], k)
    Ry = (s / Y.shape[1]) * np.random.randn(Y.shape[1], k)
    X = np.dot(X, Rx)
    Y = np.dot(Y, Ry)

    # Apply non-linear function to random projections
    fX = f(X)
    fY = f(Y)

    # Compute full covariance matrix
    C = np.cov(np.hstack([fX, fY]).T)

    # Due to numerical issues, if k is too large,
    # then rank(fX) < k or rank(fY) < k, so we need
    # to find the largest k such that the eigenvalues
    # (canonical correlations) are real-valued
    k0 = k
    lb = 1
    ub = k
    while True:
        # Compute canonical correlations
        Cxx = C[:k, :k]
        Cyy = C[k0:k0 + k, k0:k0 + k]
        Cxy = C[:k, k0:k0 + k]
        Cyx = C[k0:k0 + k, :k]

        eigs = np.linalg.eigvals(np.dot(np.dot(np.linalg.pinv(Cxx), Cxy),
                                        np.dot(np.linalg.pinv(Cyy), Cyx)))

        # Binary search if k is too large
        if not (np.all(np.isreal(eigs)) and
                0 <= np.min(eigs) and
                np.max(eigs) <= 1):
            ub -= 1
            k = (ub + lb) // 2
            continue
        if lb == ub:
            break
        lb = k
        if ub == lb + 1:
            k = ub
        else:
            k = (ub + lb) // 2

    return np.sqrt(np.max(eigs))



def GCIT(Ax, Ay, n=500, z_dim=100, simulation='type1error', batch_size=64, epochs_num=1000,
         nstd=1.0, z_dist='gaussian', x_dims=1, y_dims=1, a_x=0.05, M=500, k=2, boot_num=1000,
         noise_dimension=10, hidden_layer_size=512, normal_ini=False, preprocess='normalize',
         G_lr=1e-5, using_orcale=False, lambda_1=1, lambda_2=1, using_Gen='1',
         boor_rv_type="gaussian", wgt_decay=0, lambda_3=1, drop_out_p=0.2, exp_num=0, M_train=3):
    """Run one simulated generative conditional-independence test and return its p-value.

    Draws ``n`` samples (x, y, z) with ``generate_samples_random``, splits them into a
    2/3 training part and a 1/3 held-out part, and trains generators with ``train_ver3``.
    It then compares the RDC between the real held-out x and y with the RDC between
    generated x (``G_zx`` applied to held-out z plus fresh noise) and the real held-out y,
    over 1000 independent noise draws.

    Args:
        Ax, Ay: coefficient matrices forwarded to ``generate_samples_random``.
        n: total number of simulated samples.
        z_dim, x_dims, y_dims: dimensions of z, x and y for the simulator.
        simulation: 'type1error' simulates conditionally independent data ('CI');
            'power' simulates dependent data.
        nstd, a_x, z_dist, preprocess: simulator settings (noise std, alpha_x,
            distribution of z, preprocessing mode).
        k: if 1, the generators are trained on the held-out part instead of the
            training part; any other value keeps the 2/3 / 1/3 split.
        batch_size, epochs_num, M, noise_dimension, hidden_layer_size, normal_ini,
        G_lr, lambda_1, lambda_2, using_Gen, wgt_decay, lambda_3, drop_out_p, M_train:
            training settings forwarded to ``train_ver3`` / the data loader.
        boot_num, using_orcale, boor_rv_type, exp_num: accepted for interface
            compatibility; not used by this implementation.

    Returns:
        ``min(P(rho < stat), P(rho > stat))`` over the 1000 generated statistics, as a
        scalar TensorFlow tensor. Because it is the smaller tail fraction, a two-sided
        level-alpha test rejects when it is below ``alpha / 2``.

    Raises:
        ValueError: if ``simulation`` is neither 'type1error' nor 'power'.
    """
    if simulation == 'type1error':
        # Null hypothesis: x and y are conditionally independent given z.
        sample_type = 'CI'
    elif simulation == 'power':
        # Alternative hypothesis: x and y are dependent.
        sample_type = 'dependent'
    else:
        raise ValueError('Test does not exist.')

    x, y, z = generate_samples_random(Ax, Ay, size=n, sType=sample_type, dx=x_dims, dy=y_dims, dz=z_dim,
                                      nstd=nstd, alpha_x=a_x, dist_z=z_dist, preprocess=preprocess)

    # Scale parameters handed to train_ver3 are fixed at 1.0 (the median-of-pairwise-distance
    # heuristic is disabled).
    sigma_w_train = 1.0
    sigma_u_train = 1.0
    sigma_v_train = 1.0

    # First 2/3 of the samples train the generators; the remaining 1/3 is held out.
    split_factor = 2 / 3.
    n_train = int(n * split_factor)
    X_train, Y_train, Z_train = x[:n_train], y[:n_train], z[:n_train]
    X_test, Y_test, Z_test = x[n_train:int(n)], y[n_train:int(n)], z[n_train:int(n)]

    if k == 1:
        # No sample splitting: train on the held-out part itself.
        X_train, Y_train, Z_train = X_test, Y_test, Z_test

    train_xyz = DatasetSelect_GAN(X_train, Y_train, Z_train, batch_size)
    train_loader = torch.utils.data.DataLoader(train_xyz, batch_size=batch_size, shuffle=True)
    G_zy, G_zx = train_ver3(X=X_train, Y=Y_train, Z=Z_train, M=M,
                            X_test=X_test, Y_test=Y_test, Z_test=Z_test,
                            noise_dimension=noise_dimension, noise_type="normal",
                            G_lr=G_lr, hidden_layer_size=hidden_layer_size,
                            DataLoader=train_loader, BN_type=False, ReLU_coef=0.1,
                            epochs_num=epochs_num, sigma_z=sigma_w_train, sigma_x=sigma_v_train,
                            sigma_y=sigma_u_train,
                            normal_ini=normal_ini, lambda_1=lambda_1, lambda_2=lambda_2,
                            using_Gen=using_Gen, wgt_decay=wgt_decay, lambda_3=lambda_3,
                            drop_out_p=drop_out_p, M_train=M_train)

    G_zx = G_zx.eval()
    G_zy = G_zy.eval()

    test_samples = 1000
    test_size = Z_test.shape[0]
    y_test = Y_test.numpy()
    x_test = X_test.numpy()
    # The device transfer is loop-invariant, so do it once.
    z_test_dev = Z_test.to(device)

    rho = []
    # The generator outputs are only converted to numpy, so no autograd graph is needed.
    with torch.no_grad():
        for _ in range(test_samples):
            noise_fake = sample_noise(test_size, noise_dimension, "normal", input_var=1.0 / 3.0).to(device)
            fake_x = G_zx(torch.cat((z_test_dev, noise_fake), dim=1))
            fake_x = fake_x.reshape(test_size, 1).cpu().detach().numpy()
            rho.append(rdc(fake_x, y_test))

    rho = tf.stack(rho)
    stat_real = rdc(x_test, y_test)
    # Two-sided Monte-Carlo p-value: the smaller tail fraction of the generated statistics.
    p_value = min(tf.reduce_sum(tf.cast(rho < stat_real, tf.float32)) / test_samples,
                  tf.reduce_sum(tf.cast(rho > stat_real, tf.float32)) / test_samples)

    return p_value


def run_experiment(params):
    """Repeat the GCIT simulation ``params["n_test"]`` times and report rejection rates.

    Seeds numpy, TensorFlow and torch with ``params["set_seeds"]``. Draws random
    non-negative coefficient matrices ``Ax`` (z_dim x dx) and ``Ay`` (z_dim x dy), each
    divided by the L1 norm of all its entries. Then calls ``GCIT`` once per replication.
    After every replication it prints:
      - the elapsed time;
      - the p-value;
      - the running fraction of p-values below ``alpha / 2`` and ``alpha1 / 2``.
    That fraction is the empirical type-I error when ``params["test"] == 'type1error'``
    and the empirical power when it is ``'power'``.

    Args:
        params: dict of experiment settings. Every key of the module-level ``param``
            dict is read.

    Returns:
        1-D float64 numpy array of the ``n_test`` p-values, in run order.

    Raises:
        ValueError: if ``params["test"]`` is neither 'type1error' nor 'power'
            (matching GCIT; previously such a value silently returned an empty array).
        KeyError: if a required setting is missing from ``params``.
    """
    test = params["test"]
    # The two modes differ only in how the running rejection rate is labelled.
    if test == 'type1error':
        label = 'Type 1 error'
    elif test == 'power':
        label = 'Power'
    else:
        raise ValueError('Test does not exist.')

    z_dim = params["z_dim"]
    dx = params["dx"]
    dy = params["dy"]
    n_test = params["n_test"]
    alpha = params["alpha"]
    alpha1 = params["alpha1"]
    set_seeds = params["set_seeds"]

    np.random.seed(set_seeds)
    tf.random.set_seed(set_seeds)
    torch.manual_seed(set_seeds)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(set_seeds)

    a_f = torch.rand((z_dim, dx))
    Ax = a_f / torch.linalg.vector_norm(a_f, ord=1)
    a_g = torch.rand((z_dim, dy))
    Ay = a_g / torch.linalg.vector_norm(a_g, ord=1)

    # Arguments shared by every replication; only exp_num changes between runs.
    gcit_kwargs = dict(
        Ax=Ax, Ay=Ay, n=params["sample_size"], z_dim=z_dim, simulation=test,
        batch_size=params["batch_size"], epochs_num=params["epochs_num"],
        nstd=params["eps_std"], z_dist=params["dist_z"], x_dims=dx, y_dims=dy,
        a_x=params["alpha_x"], M=params["m_value"], k=params["k_value"],
        boot_num=params["j_value"], noise_dimension=params["noise_dimension"],
        hidden_layer_size=params["hidden_layer_size"], normal_ini=params["normal_ini"],
        preprocess=params["preprocess"], G_lr=params["G_lr"],
        using_orcale=params["using_orcale"], lambda_1=params["lambda_1"],
        lambda_2=params["lambda_2"], using_Gen=params["using_Gen"],
        boor_rv_type=params["boor_rv_type"], wgt_decay=params["wgt_decay"],
        lambda_3=params["lambda_3"], drop_out_p=params["drop_out_p"],
        M_train=params["M_train"],
    )

    p_values = np.array([])
    for test_count in range(1, n_test + 1):
        start_time = datetime.now()
        p_value = GCIT(exp_num=test_count, **gcit_kwargs)
        print("--- The %d'th iteration take %s seconds ---" % (test_count, (datetime.now() - start_time)))

        p_values = np.append(p_values, p_value)
        # GCIT returns the smaller tail fraction, so a two-sided level-alpha test
        # rejects when the p-value falls below alpha / 2.
        final_result = np.mean(p_values < alpha / 2.0)
        final_result1 = np.mean(p_values < alpha1 / 2.0)

        print('The stat is {}'.format(p_value))
        print('{}: {} for z dimension {} with significance level {}'.format(label, final_result, z_dim, alpha))
        print('{}: {} for z dimension {} with significance level {}'.format(label, final_result1, z_dim, alpha1))
    return p_values


# Run the configured experiment. The returned array of per-replication p-values is the
# cell's final expression, so the notebook displays it.
run_experiment(params=param)


--- The 1'th iteration take 0:00:14.521534 seconds ---
The stat is 0.2809999883174896
Type 1 error: 0.0 for z dimension 1 with significance level 0.1
Type 1 error: 0.0 for z dimension 1 with significance level 0.05
--- The 2'th iteration take 0:00:11.209522 seconds ---
The stat is 0.22300000488758087
Type 1 error: 0.0 for z dimension 1 with significance level 0.1
Type 1 error: 0.0 for z dimension 1 with significance level 0.05
--- The 3'th iteration take 0:00:11.259131 seconds ---
The stat is 0.49000000953674316
Type 1 error: 0.0 for z dimension 1 with significance level 0.1
Type 1 error: 0.0 for z dimension 1 with significance level 0.05
--- The 4'th iteration take 0:00:11.272632 seconds ---
The stat is 0.019999999552965164
Type 1 error: 0.25 for z dimension 1 with significance level 0.1
Type 1 error: 0.25 for z dimension 1 with significance level 0.05
--- The 5'th iteration take 0:00:11.493549 seconds ---
The stat is 0.26100000739097595
Type 1 error: 0.2 for z dimension 1 with signif

array([0.28099999, 0.223     , 0.49000001, 0.02      , 0.26100001,
       0.34      , 0.48500001, 0.079     , 0.28299999, 0.156     ,
       0.12800001, 0.41299999, 0.233     , 0.435     , 0.31299999,
       0.051     , 0.059     , 0.032     , 0.22      , 0.43000001,
       0.48100001, 0.30500001, 0.062     , 0.069     , 0.226     ,
       0.36000001, 0.36000001, 0.43900001, 0.132     , 0.296     ,
       0.31600001, 0.147     , 0.117     , 0.012     , 0.13      ,
       0.454     , 0.29300001, 0.39899999, 0.074     , 0.32800001,
       0.22      , 0.461     , 0.49000001, 0.06      , 0.498     ,
       0.20100001, 0.373     , 0.11      , 0.257     , 0.101     ])