In [None]:
# Dimension of the autoencoder latent code. It is also interpolated into the
# checkpoint file names loaded further down ('./AE_<dim>.pth', './AE<dim>_image.pth',
# './AE<dim>_label.pth'). The trailing list presumably records the values that
# were tried -- TODO confirm.
latent_space_dim = 15 # [3, 4, 5, 6, 7, 8, 10, 12, 15, 16, 17, 20, 30, 40, 100]

import torch
import torch.nn as nn
from torch.optim import SGD
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt

import torch.distributions as TD
from zmq import device
import torch.optim as optim
from datetime import datetime
import functools
from tqdm import tqdm

# Move model on GPU if available
# NOTE: the assignment below rebinds the module-level name `device`, replacing
# the object brought in by `from zmq import device` in the import section.
enable_cuda = True
device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')

import torch.nn as nn

# Single seed value passed to the TensorFlow, NumPy and PyTorch seeding calls
# that follow in this file.
set_seed = 42

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Utilities related to Sinkhorn computations and training for TensorFlow 2.0
import tensorflow as tf
import logging
import tensorflow_probability as tfp
from sklearn.metrics.pairwise import rbf_kernel
from scipy.stats import rankdata, ks_2samp, wilcoxon
from sklearn.model_selection import KFold
from datetime import datetime
import decimal
import torch
# Same device selection as earlier in the file (repeated verbatim).
enable_cuda = True
device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')
import gc # Garbage Collector

# Seed every random number generator used in this file with the same value.
tf.random.set_seed(set_seed)
np.random.seed(set_seed)
torch.manual_seed(set_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(set_seed)

# Disable the 'tensorflow' logger and make Keras default to float32.
logging.getLogger('tensorflow').disabled = True
tf.keras.backend.set_floatx('float32')


# The seeding calls are repeated here; no random draws happen between the two
# groups, so this leaves the generators in the same state.
tf.random.set_seed(set_seed)
np.random.seed(set_seed)
torch.manual_seed(set_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(set_seed)

class DatasetSelect_GAN(torch.utils.data.Dataset):
    """Map-style dataset over three aligned containers X, Y and Z.

    Item ``index`` is the 4-tuple ``(X[index], Y[index], Z[index], Z[shifted])``
    where ``shifted`` is ``index + batch_size`` wrapped around the number of
    samples. The number of samples is read from ``X.shape[0]``.
    """

    def __init__(self, X, Y, Z, batch_size):
        self.X_real, self.Y_real, self.Z_real = X, Y, Z
        self.batch_size = batch_size
        # Length of the dataset is the size of X's leading axis.
        self.sample_size = X.shape[0]

    def __len__(self):
        return self.sample_size

    def __getitem__(self, index):
        # Position of the "other" Z value: batch_size steps ahead, cyclically.
        shifted = (self.batch_size + index) % self.sample_size
        return (
            self.X_real[index],
            self.Y_real[index],
            self.Z_real[index],
            self.Z_real[shifted],
        )

class CharacteristicFunction:
    '''
    Randomly initialised two-layer sigmoid network used as a test function.

    The network maps inputs of width ``x_dims`` through ``hidden_dims`` sigmoid
    units to a single sigmoid output. Weights are float64 ``tf.Variable``s drawn
    by ``xavier_var_creator``; ``update`` draws a fresh set.
    '''

    def __init__(self, size, x_dims, z_dims, test_size):
        self.n_samples = size
        self.hidden_dims = 20 # default: 20
        self.test_size = test_size

        self.x_dims = x_dims
        self.z_dims = z_dims
        self.input_dim = z_dims + x_dims

        # Weight shapes, stored as [rows, cols] lists.
        hidden = self.hidden_dims
        self.input_shape1x = [self.x_dims, hidden]
        self.input_shape1z = [self.z_dims, hidden]
        self.input_shape1 = [self.input_dim, hidden]
        self.input_shape2 = [hidden, 1]

        # Draws w1x, b1 and w2, in that order.
        self.update()
        # Output bias: created as zeros; not read by call().
        self.b2 = tf.Variable(tf.zeros(self.input_shape2[1], tf.float64))

    def xavier_var_creator(self, input_shape):
        """Return a trainable float64 variable of ``input_shape``.

        Entries are drawn from a zero-mean normal with standard deviation
        ``sqrt(2 / input_shape[0])`` and then cast to float64.
        """
        stddev = tf.sqrt(2.0 / (input_shape[0]))
        draw = tf.random.normal(shape=input_shape, mean=0.0, stddev=stddev)
        return tf.Variable(
            tf.cast(draw, tf.float64),
            shape=tf.TensorShape(input_shape),
            trainable=True,
        )

    def update(self):
        """Replace w1x, b1 and w2 with freshly drawn values (b2 is untouched)."""
        self.w1x = self.xavier_var_creator(self.input_shape1x)
        self.b1 = tf.squeeze(self.xavier_var_creator([self.hidden_dims, 1]))
        self.w2 = self.xavier_var_creator(self.input_shape2)

    def call(self, x, z):
        """Evaluate the network on ``x``; returns shape [test_size, -1, 1].

        ``z`` is reshaped to [test_size, -1, z_dims] (which fails if its size
        is incompatible) but does not contribute to the returned value.
        """
        x = tf.reshape(tensor=x, shape=[self.test_size, -1, self.x_dims])
        z = tf.reshape(tensor=z, shape=[self.test_size, -1, self.z_dims])
        # Only x feeds the hidden layer; there is no weight applied to z.
        hidden = tf.nn.sigmoid(tf.matmul(x, self.w1x) + self.b1)
        return tf.nn.sigmoid(tf.matmul(hidden, self.w2))

# Re-seed TensorFlow, NumPy and PyTorch (CPU and, when available, CUDA) with
# the shared seed value.
tf.random.set_seed(set_seed)
np.random.seed(set_seed)
torch.manual_seed(set_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(set_seed)

#
# test statistics for DGCIT
#


def t_and_sigma(psy_x_i, psy_y_i, phi_x_i, phi_y_i):
    """Return per-pair means and standard deviations of the product terms.

    All four inputs are rank-2 tensors of shape [b, n]. For every pair of rows
    (one from the x differences, one from the y differences) the elementwise
    product over the n columns is formed, giving b*b rows.

    Returns:
    - t_b: shape [b*b, 1], the mean of each product row.
    - std_b: shape [b*b], the sample standard deviation (divisor n-1) of each row.
    """
    _, n = psy_x_i.shape
    diff_x = phi_x_i - psy_x_i
    diff_y = phi_y_i - psy_y_i

    # Outer pairing of rows: result[(iy, ix), :] = diff_x[ix, :] * diff_y[iy, :]
    products = tf.reshape(diff_x[None, :, :] * diff_y[:, None, :], [-1, n])

    row_means = tf.reduce_sum(products, axis=1) / tf.cast(n, tf.float64)
    t_b = tf.expand_dims(row_means, axis=1)

    centred = products - t_b
    sum_sq = tf.reduce_sum(centred**2, axis=1)
    std_b = tf.sqrt(sum_sq / tf.cast(n-1, tf.float64))
    return t_b, std_b


def test_statistics(psy_x_i, psy_y_i, phi_x_i, phi_y_i, t_b, std_b, j):
    """Compute the max-type test statistic and simulated null maxima.

    Inputs:
    - psy_x_i, psy_y_i, phi_x_i, phi_y_i: rank-2 tensors of shape [b, n].
    - t_b: tensor of shape [b*b, 1], mean of each product row.
    - std_b: tensor of shape [b*b], standard deviation of each product row.
    - j: number of Gaussian draws used to simulate the null distribution.

    Returns:
    - test_stat: scalar, max over the b*b pairs of |sqrt(n) * t_b / std_b|.
    - t_j: tensor of shape [j]; each entry is the maximum coordinate of one
      Gaussian vector whose covariance is the estimated correlation matrix
      of the product rows.
    """
    b, n = psy_x_i.shape
    x_mtx = phi_x_i - psy_x_i
    y_mtx = phi_y_i - psy_y_i
    matrix = tf.reshape(x_mtx[None, :, :] * y_mtx[:, None, :], [-1, n])
    crit_matrix = matrix - t_b
    test_stat = tf.reduce_max(tf.abs(tf.sqrt(tf.cast(n, tf.float64)) * tf.squeeze(t_b) / std_b))

    # Correlation matrix of the b*b product rows.
    sig = tf.reduce_sum(crit_matrix[None, :, :] * crit_matrix[:, None, :], axis=2)
    coef = std_b[None, :] * std_b[:, None] * tf.cast(n-1, tf.float64)
    sig_xy = sig / coef

    # Matrix square root through the symmetric eigendecomposition.
    eigenvalues, eigenvectors = tf.linalg.eigh(sig_xy)
    # Round-off can push eigenvalues of this positive semi-definite matrix
    # slightly below zero. Without the clamp, sqrt would return NaN, which
    # propagates into every simulated maximum and silently yields p = 0.
    eig_vals = tf.sqrt(tf.maximum(eigenvalues + 1e-12, 0.0))
    lamda = tf.linalg.diag(eig_vals)
    sig_sqrt = tf.matmul(tf.matmul(eigenvectors, lamda), tf.linalg.inv(eigenvectors))

    # Draw j standard-normal vectors (sampled in float32, then cast so the
    # random stream is unchanged) and colour them with the square root.
    z_dist = tfp.distributions.Normal(0.0, scale=1.0)
    z_samples = tf.cast(z_dist.sample([b*b, j]), tf.float64)
    vals = tf.matmul(sig_sqrt, z_samples)
    t_j = tf.reduce_max(vals, axis=0)
    return test_stat, t_j

#
# Training algorithm for DGCIT
#

def _torch_to_tf64(tensor, shape):
    """Copy a torch tensor to host memory and return it as a float64 TF tensor of ``shape``."""
    as_tf = tf.convert_to_tensor(tensor.cpu().detach().numpy())
    return tf.cast(tf.reshape(as_tf, shape), tf.float64)


def _standardise_tf(values):
    """Stack a list of TF tensors and standardise by the global mean and standard deviation."""
    stacked = tf.stack(values)
    return (stacked - tf.reduce_mean(stacked)) / tf.math.reduce_std(stacked)


def dgcit(x, y, z, generator_x, generator_y, test_size=500, z_dim=100,
      x_dims=1, y_dims=1, M=100, k=1,
      b=30, j=1000):
    """Run the DGCIT conditional-independence test and return its p-value.

    For every test row, ``M`` samples of x and of y are generated from the
    conditioning value z. Observed and generated data are then compared
    through ``b`` randomly initialised test functions per variable, and the
    resulting max-type statistic is compared with ``j`` simulated null maxima.

    Relies on the module-level names ``device``, ``sample_noise``,
    ``DatasetSelect_GAN``, ``CharacteristicFunction``, ``t_and_sigma`` and
    ``test_statistics``.

    Inputs:
    - x, y, z: tensors indexed along their first axis; row i of each is one
      observation. Rows are reshaped to (1, x_dims), (1, y_dims), (1, z_dim).
    - generator_x, generator_y: torch modules called on the concatenation of
      z (repeated M times) and a noise block; switched to eval mode here.
    - test_size: number of observations; used for the reshapes below.
    - z_dim, x_dims, y_dims: widths of z, x and y.
    - M: number of generated samples per observation.
    - k: number of folds in the final reshapes. Only one fold of data is
      collected here, so values other than 1 presumably fail to reshape --
      TODO confirm.
    - b: number of random test functions per variable.
    - j: number of simulated null maxima.

    Returns:
    - p_value: fraction of the j simulated maxima that are >= the statistic.
    """
    # Noise configuration fed to the generators. Presumably this has to match
    # the configuration the generators were trained with -- TODO confirm.
    noise_dimension_image = 50
    noise_dimension_label = 1
    input_noise_type = "normal"
    noise_var = 1.0/3.0

    test_samples = b

    x_samples = []
    y_samples = []
    z_input = []
    x_input = []
    y_input = []

    test_xyz = DatasetSelect_GAN(x, y, z, 1)
    testing_dataset = torch.utils.data.DataLoader(test_xyz, batch_size=1, shuffle=False)

    G_image = generator_x.eval()
    G_label = generator_y.eval()

    # The fourth field of each item (the shifted z) is not needed here.
    for test_x, test_y, test_z, _ in testing_dataset:

        Z_test_repeat = test_z.repeat(M,1).to(device).detach()

        # Generate fake data. The label noise is drawn before the image noise
        # so the random stream is consumed in a fixed order.
        with torch.no_grad():
            noise_y = sample_noise(Z_test_repeat.shape[0], noise_dimension_label, input_noise_type, input_var = noise_var).to(device)
            gen_y = G_label(torch.cat((Z_test_repeat, noise_y), dim=1))

            noise_x = sample_noise(Z_test_repeat.shape[0], noise_dimension_image, input_noise_type, input_var = noise_var).to(device)
            gen_x = G_image(torch.cat((Z_test_repeat, noise_x), dim=1))

        x_samples.append(_torch_to_tf64(gen_x, (M, x_dims)))
        y_samples.append(_torch_to_tf64(gen_y, (M, y_dims)))
        z_input.append(_torch_to_tf64(test_z, (1, z_dim)))
        x_input.append(_torch_to_tf64(test_x, (1, x_dims)))
        y_input.append(_torch_to_tf64(test_y, (1, y_dims)))

    # At this point each list has one entry per observation:
    # x_samples / y_samples entries have shape [M, dx] / [M, dy],
    # z_input / x_input / y_input entries have shape [1, dz] / [1, dx] / [1, dy].

    # Standardise every collection by its own global mean and std.
    x_samples = _standardise_tf(x_samples)
    y_samples = _standardise_tf(y_samples)
    x_input = _standardise_tf(x_input)
    y_input = _standardise_tf(y_input)
    z_input = _standardise_tf(z_input)

    psy_x_b = []
    phi_x_b = []
    psy_y_b = []
    phi_y_b = []

    f1 = CharacteristicFunction(M, x_dims, z_dim, test_size)
    f2 = CharacteristicFunction(M, y_dims, z_dim, test_size)
    for _ in range(test_samples):
        # phi: test function averaged over the M generated samples;
        # psy: test function evaluated at the observed value.
        phi_x_b.append(tf.reduce_mean(f1.call(x_samples, z_input), axis=1))
        phi_y_b.append(tf.reduce_mean(f2.call(y_samples, z_input), axis=1))
        psy_x_b.append(tf.squeeze(f1.call(x_input, z_input)))
        psy_y_b.append(tf.squeeze(f2.call(y_input, z_input)))
        # Fresh random weights for the next pair of test functions.
        f1.update()
        f2.update()

    # reshape
    psy_x_all = tf.reshape(psy_x_b, [k, test_samples, test_size])
    psy_y_all = tf.reshape(psy_y_b, [k, test_samples, test_size])
    phi_x_all = tf.reshape(phi_x_b, [k, test_samples, test_size])
    phi_y_all = tf.reshape(phi_y_b, [k, test_samples, test_size])

    # Average the per-fold means and standard deviations.
    t_b = 0.0
    std_b = 0.0
    for n in range(k):
        t, std = t_and_sigma(psy_x_all[n], psy_y_all[n], phi_x_all[n], phi_y_all[n])
        t_b += t
        std_b += std
    t_b = t_b / tf.cast(k, tf.float64)
    std_b = std_b / tf.cast(k, tf.float64)

    # Fold axis goes next to the sample axis, then both are merged.
    merged_shape = [test_samples, test_size*k]
    psy_x_all = tf.reshape(tf.transpose(psy_x_all, (1, 0, 2)), merged_shape)
    psy_y_all = tf.reshape(tf.transpose(psy_y_all, (1, 0, 2)), merged_shape)
    phi_x_all = tf.reshape(tf.transpose(phi_x_all, (1, 0, 2)), merged_shape)
    phi_y_all = tf.reshape(tf.transpose(phi_y_all, (1, 0, 2)), merged_shape)

    stat, critical_vals = test_statistics(psy_x_all, psy_y_all, phi_x_all, phi_y_all, t_b, std_b, j)
    # Vectorised comparison instead of a Python loop over the j tensor entries.
    comparison = np.reshape((critical_vals >= stat).numpy(), (-1,))
    p_value = np.sum(comparison.astype(np.float32)) / j

    return p_value

class Reshape(nn.Module):
    """Layer that views its input with the target shape given at construction."""

    def __init__(self, *args):
        super().__init__()
        # Target shape, kept as the tuple of positional arguments.
        self.shape = args

    def forward(self, x):
        target = self.shape
        return x.view(*target)


class Trim(nn.Module):
    """Layer that keeps at most the first 28 entries of axes 2 and 3."""

    def __init__(self, *args):
        # Positional arguments are accepted but not used.
        super().__init__()

    def forward(self, x):
        keep = slice(None, 28)
        return x[:, :, keep, keep]


class AutoEncoder(nn.Module):
    """Convolutional autoencoder with a ``d_l``-dimensional latent code.

    The encoder ends in ``Linear(3136, d_l)`` and the decoder starts with
    ``Linear(d_l, 3136)`` reshaped to 64x7x7; the decoder output is trimmed to
    28x28 and passed through a sigmoid.
    """

    def __init__(self, d_l):
        super().__init__()

        # Encoder: two stride-2 convolutions halve the spatial size twice
        # before flattening to 3136 (= 64 * 7 * 7) features.
        encoder_layers = [
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.01),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.01),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.Flatten(),
            nn.Linear(3136, d_l),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(d_l, 3136),
            Reshape(-1, 64, 7, 7),
            # 64x7x7 -> 64x7x7
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.01),
            # 64x7x7 -> 64x13x13
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.01),
            # 64x13x13 -> 32x27x27
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0),
            nn.LeakyReLU(0.01),
            # 32x27x27 -> 1x29x29
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=0),
            # 1x29x29 -> 1x28x28
            Trim(),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and decode it again."""
        return self.decoder(self.encoder(x))

    def get_latent_space(self, x):
        """Return the encoder output for ``x``."""
        return self.encoder(x)

    def get_decoded_images(self, x):
        """Return the decoder output for latent codes ``x``."""
        return self.decoder(x)

class Generator_image(torch.nn.Module):
    """
    Transposed-convolution generator that produces flattened 28x28 images.

    The input is a concatenation of a conditioning vector and a noise vector.
    It is mapped by a linear layer to 3136 features, reshaped to 64x7x7,
    passed through four transposed convolutions (Leaky ReLU with slope 0.01
    between them), trimmed to 28x28, squashed by a sigmoid and flattened.

    Inputs:
    - input_dimension: Integer giving the width of the conditioning vector.
    - noise_dimension: Integer giving the width of the noise vector.

    Returns:
    - x: PyTorch Tensor with one flattened image (784 values) per input row.
    """

    def __init__(self, input_dimension, noise_dimension):
        super().__init__()
        self.flatten = nn.Flatten()
        stages = [
            nn.Linear(input_dimension + noise_dimension, 3136),
            Reshape(-1, 64, 7, 7),
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0),
            nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=0),
            # 1x29x29 -> 1x28x28
            Trim(),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        images = self.decoder(x)
        # 1x28x28 -> 1x784
        return self.flatten(images)

class Generator(torch.nn.Module):
    """
    Fully connected generator whose output rows are softmax probabilities.

    The forward pass uses two hidden layers (fc1, fc2), each optionally
    followed by batch normalization, then a Leaky ReLU and dropout. The output
    layer is linear followed by a softmax over dimension 1. A third hidden
    layer (fc3, BN3, drop_out3) and a sigmoid are constructed but are not used
    in the forward pass.

    Inputs:
    - input_dimension: Integer giving the width of the conditioning input.
    - output_dimension: Integer giving the width of the output.
    - noise_dimension: Integer giving the width of the noise input.
    - hidden_layer_size: Integer giving the width of every hidden layer.
    - BN_type: truthy to include batch normalization.
    - ReLU_coef: Scalar giving the negative slope of the Leaky ReLU.
    - drop_out_p: Dropout probability shared by all dropout layers.
    - drop_input: truthy to apply dropout to the network input as well.

    Returns:
    - x: PyTorch Tensor of shape (batch, output_dimension).
    """

    def __init__(self, input_dimension, output_dimension, noise_dimension, hidden_layer_size, BN_type, ReLU_coef, drop_out_p,
                drop_input = False):
        super().__init__()
        width = hidden_layer_size
        self.BN_type = BN_type
        self.ReLU_coef = ReLU_coef

        # Sub-modules are registered in a fixed order so that state_dict keys
        # keep their order.
        self.fc1 = nn.Linear(input_dimension + noise_dimension, width, bias=True)
        if BN_type:
            # NOTE(review): the second positional argument of BatchNorm1d is
            # `eps`, so 0.8 is used as eps here -- confirm this is intended
            # before changing it; trained checkpoints depend on it.
            for idx in (1, 2, 3):
                setattr(self, f"BN{idx}", nn.BatchNorm1d(width, 0.8, affine=False))
        self.leakyReLU1 = nn.LeakyReLU(ReLU_coef)
        self.fc2 = nn.Linear(width, width, bias=True)
        self.fc3 = nn.Linear(width, width, bias=True)
        self.fc_last = nn.Linear(width, output_dimension, bias=True)
        self.sigmoid = nn.Sigmoid()
        for idx in range(4):
            setattr(self, f"drop_out{idx}", nn.Dropout(p=drop_out_p))
        self.drop_input = drop_input
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        if self.drop_input:
            x = self.drop_out0(x)

        # First hidden layer.
        hidden = self.fc1(x)
        if self.BN_type:
            hidden = self.BN1(hidden)
        hidden = self.drop_out1(self.leakyReLU1(hidden))

        # Second hidden layer.
        hidden = self.fc2(hidden)
        if self.BN_type:
            hidden = self.BN2(hidden)
        hidden = self.drop_out2(self.leakyReLU1(hidden))

        logits = self.fc_last(hidden)
        return self.softmax(logits)


##### Auxiliary functions #####

def sample_noise(sample_size, noise_dimension, noise_type, input_var):
    """
    Generate a PyTorch Tensor of random noise from the specified reference distribution.

    Input:
    - sample_size: the sample size of noise to generate.
    - noise_dimension: the dimension of noise to generate.
    - noise_type: "normal", "unif" or "Cauchy", giving the reference distribution.
    - input_var: variance of each coordinate for "normal" noise (the covariance
      is input_var times the identity). Not used by "unif" or "Cauchy".

    Output:
    - A PyTorch Tensor of shape (sample_size, noise_dimension). "normal" noise
      is created on the module-level `device`; "unif" and "Cauchy" noise are
      created on the default (CPU) device.

    Raises:
    - ValueError: if noise_type is not one of the three supported names.
    """

    if noise_type == "normal":
        mean = torch.zeros(noise_dimension).to(device)
        covariance = input_var * torch.eye(noise_dimension).to(device)
        return TD.MultivariateNormal(mean, covariance).sample((sample_size,))

    if noise_type == "unif":
        # Uniform on [0, 1).
        return torch.rand(sample_size, noise_dimension)

    if noise_type == "Cauchy":
        # Standard Cauchy; the sample carries a trailing axis of size 1 from
        # the one-element location/scale tensors, which is squeezed away.
        standard_cauchy = TD.Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))
        return standard_cauchy.sample((sample_size, noise_dimension)).squeeze(2)

    # Fail with a clear message instead of an UnboundLocalError on return.
    raise ValueError(
        f"Unknown noise_type {noise_type!r}; expected 'normal', 'unif' or 'Cauchy'."
    )

def _l1_self_distances(points, n):
    """Return the n x n matrix of L1 distances between the rows of ``points``."""
    dist = torch.zeros(n, n).to(device)
    for i in range(n):
        dist[i,:] = torch.linalg.vector_norm(points[i].reshape(1,-1) - points, ord = 1, dim = 1)
    return dist


def _l1_real_to_generated(real, generated, n, M):
    """Return the n x n x M tensor of L1 distances from real[i] to the M draws in generated[j]."""
    dist = torch.zeros(n, n, M).to(device)
    for i in range(n):
        anchor = real[i].reshape(1,-1)
        for j in range(n):
            dist[i,j,:] = torch.linalg.vector_norm(anchor - generated[j,], ord = 1, dim = 1)
    return dist


def _generated_kernel_slice(generated, k, n, M, dim, sigma):
    """Return the n x n matrix of mean kernels between draw k of generated[i] and all draws of generated[j]."""
    dist = torch.zeros(n, n, M).to(device)
    for i in range(n):
        anchor = generated[i,k,:].reshape(1,1,dim)
        for j in range(n):
            dist[i,j,:] = torch.linalg.vector_norm(generated[j,:,:].reshape(1,M,dim) - anchor, ord = 1, dim = 2)
    return torch.mean(torch.exp(-dist / sigma), dim=2)


def get_p_value_stat_1(boot_num, M, n, gen_x_all_torch, gen_y_all_torch, x_torch, y_torch, z_torch, sigma_w, sigma_u=1,
            sigma_v=1, boor_rv_type="gaussian", dy_g=10, dx_g = 28*28):
    """Compute a kernel-based test statistic and its multiplier-bootstrap replicates.

    Kernels are exp(-L1 distance / sigma). For y (and likewise x) a centred
    kernel matrix is built from four parts: real-vs-real, real-vs-generated
    (and its transpose), and generated-vs-generated averaged over all draws.
    The statistic is the sum over off-diagonal entries of the elementwise
    product of the y, x and z kernel matrices, divided by (n - 1).

    Relies on the module-level `device`. Has the side effect of calling
    `torch.manual_seed(42)` before the bootstrap multipliers are drawn.

    Inputs:
    - boot_num: number of bootstrap replicates.
    - M: number of generated draws per observation.
    - n: number of observations.
    - gen_x_all_torch, gen_y_all_torch: generated samples indexed as
      [observation, draw, feature] with feature widths dx_g and dy_g.
    - x_torch, y_torch, z_torch: observed data, one row per observation.
    - sigma_w, sigma_u, sigma_v: kernel bandwidths for z, y and x.
    - boor_rv_type: "gaussian" or "rademacher" bootstrap multipliers.
    - dy_g, dx_g: feature widths of the generated y and x samples.

    Returns:
    - stat: Python float, the test statistic.
    - boottemp: NumPy float64 array of length boot_num with the bootstrap statistics.

    Raises:
    - ValueError: if boor_rv_type is not a supported name (checked before the
      expensive kernel computation rather than failing afterwards).
    """
    if boor_rv_type not in ("rademacher", "gaussian"):
        raise ValueError(
            f"Unknown boor_rv_type {boor_rv_type!r}; expected 'rademacher' or 'gaussian'."
        )

    # Kernel on the conditioning variable z.
    w_mx = torch.exp(-_l1_self_distances(z_torch, n) / sigma_w)

    # Kernel parts for y: real-real, real-generated and its transpose.
    u_mx_1 = torch.exp(-_l1_self_distances(y_torch, n) / sigma_u)
    u_mx_2 = torch.mean( torch.exp(-_l1_real_to_generated(y_torch, gen_y_all_torch, n, M) / sigma_u), dim=2)
    u_mx_3 = u_mx_2.T

    # Kernel parts for x: real-real, real-generated and its transpose.
    v_mx_1 = torch.exp(-_l1_self_distances(x_torch, n) / sigma_v)
    v_mx_2 = torch.mean( torch.exp(-_l1_real_to_generated(x_torch, gen_x_all_torch, n, M) / sigma_v), dim=2)
    v_mx_3 = v_mx_2.T

    # Generated-vs-generated part: accumulate over the anchor draw k = 0..M-1
    # (draw 0 first, then the rest under a progress bar), averaged below.
    sum_mx = _generated_kernel_slice(gen_y_all_torch, 0, n, M, dy_g, sigma_u)
    sum2_mx = _generated_kernel_slice(gen_x_all_torch, 0, n, M, dx_g, sigma_v)
    for k in tqdm(range(1, M)):
        sum_mx = sum_mx + _generated_kernel_slice(gen_y_all_torch, k, n, M, dy_g, sigma_u)
        sum2_mx = sum2_mx + _generated_kernel_slice(gen_x_all_torch, k, n, M, dx_g, sigma_v)

    u_mx_4 = 1 / M * sum_mx
    u_mx = u_mx_1 - u_mx_2 - u_mx_3 + u_mx_4
    v_mx_4 = 1 / M * sum2_mx
    v_mx = v_mx_1 - v_mx_2 - v_mx_3 + v_mx_4

    # Zero the diagonal so only pairs of distinct observations contribute.
    FF_mx = u_mx * v_mx * w_mx * (1 - torch.eye(n).to(device))

    stat = 1 / (n - 1) * torch.sum(FF_mx).item()

    # Multiplier bootstrap: reweight entry (a, c) by e_a * e_c.
    torch.manual_seed(42)
    if boor_rv_type == "rademacher":
        eboot = torch.sign(torch.randn(n, boot_num)).to(device)
    else:
        eboot = torch.randn(n, boot_num).to(device)

    boottemp = np.empty(boot_num)
    for bb in range(boot_num):
        multipliers = eboot[:, bb].reshape(-1, 1)
        random_mx = torch.matmul(multipliers, multipliers.T)
        boottemp[bb] = 1 / (n - 1) * torch.sum(FF_mx * random_mx).item()
    return stat, boottemp

class CTDataset_all(Dataset):
    """Dataset of (latent code, one-hot label, flattened image) triples.

    The file at ``filepath`` must hold a pair ``(images, labels)``. Images are
    scaled by 1/255; a flattened copy is kept as ``z``, and the images are
    encoded once, up front, with ``AE_model.get_latent_space`` to give ``x``.
    Labels become float64 one-hot vectors of width 10.

    Inputs:
    - filepath: path passed to ``torch.load``.
    - AE_model: module providing ``get_latent_space``; put in eval mode here.
    """

    def __init__(self, filepath, AE_model):
        self.flatten = nn.Flatten()
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # files from a trusted source.
        self.x, self.y = torch.load(filepath, weights_only=False)
        self.x = self.x / 255.
        self.z = self.flatten(self.x)
        # Encode on the device the autoencoder lives on rather than assuming
        # CUDA, so the dataset also works on CPU-only machines.
        first_param = next(AE_model.parameters(), None)
        target = self.x.device if first_param is None else first_param.device
        self.x = self.x.reshape(-1, 1, 28, 28).to(target).detach()
        AE_model.eval()
        with torch.no_grad():
            self.x = AE_model.get_latent_space(self.x)
        self.x = self.x.detach()
        self.y = F.one_hot(self.y, num_classes=10).to(float)

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, ix):
        return self.x[ix], self.y[ix], self.z[ix]

class CTDataset(Dataset):
    """Dataset of (latent code, one-hot label) pairs.

    The file at ``filepath`` must hold a pair ``(images, labels)``. Images are
    scaled by 1/255 and encoded once, up front, with
    ``AE_model.get_latent_space``. Labels become float64 one-hot vectors of
    width 10.

    Inputs:
    - filepath: path passed to ``torch.load``.
    - AE_model: module providing ``get_latent_space``; put in eval mode here.
    """

    def __init__(self, filepath, AE_model):
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # files from a trusted source.
        self.x, self.y = torch.load(filepath, weights_only=False)
        self.x = self.x / 255.
        # Encode on the device the autoencoder lives on rather than assuming
        # CUDA, so the dataset also works on CPU-only machines.
        first_param = next(AE_model.parameters(), None)
        target = self.x.device if first_param is None else first_param.device
        self.x = self.x.reshape(-1, 1, 28, 28).to(target).detach()
        AE_model.eval()
        with torch.no_grad():
            self.x = AE_model.get_latent_space(self.x)
        self.x = self.x.detach()
        self.y = F.one_hot(self.y, num_classes=10).to(float)

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, ix):
        return self.x[ix], self.y[ix]

# Build the autoencoder for the chosen latent size and load its checkpoint.
AE_model = AutoEncoder(d_l = latent_space_dim)
AE_model.load_state_dict(torch.load('./AE_'+ str(latent_space_dim) +'.pth', weights_only=True))
AE_model.to(device)

# Noise configuration used when constructing the two generators below.
noise_dimension_image = 50
noise_dimension_label = 1
input_noise_type = "normal"

# Training data: images are encoded to latent codes inside CTDataset.
torch.manual_seed(42)
train_ds = CTDataset('./training.pt', AE_model)

# Split into two parts of 30000 items and keep the second part.
# NOTE(review): the fixed lengths presumably require exactly 60000 training
# items -- confirm against the data file.
torch.manual_seed(42)
train_AE_set, train_cond_gen_set = torch.utils.data.random_split(train_ds, [30000, 30000])
train_ds = train_cond_gen_set
DataLoader_train = torch.utils.data.DataLoader(train_ds, batch_size=128, shuffle=True, drop_last= False)

# NOTE(review): xs and ys are not referenced again in the code shown here.
xs, ys = train_ds[0:10000]

torch.manual_seed(42)

# Test data: each item is (latent code, one-hot label, flattened image).
test_ds = CTDataset_all('./test.pt', AE_model)

DataLoader_test = torch.utils.data.DataLoader(test_ds, batch_size=1, shuffle=True, drop_last= False, )

# Conditional image generator: input is latent code plus image noise.
G_image = Generator_image(latent_space_dim,  noise_dimension_image).to(device)
G_image.load_state_dict(torch.load('./AE'+str(latent_space_dim)+'_image.pth', weights_only=True))

# Conditional label generator: input is latent code plus label noise,
# output has width 10.
G_label = Generator(input_dimension = latent_space_dim, output_dimension = 10, noise_dimension = noise_dimension_label,
          hidden_layer_size = 512, BN_type = True, ReLU_coef = 0.5, drop_out_p= 0.2).to(device)
G_label.load_state_dict(torch.load('./AE'+str(latent_space_dim)+'_label.pth', weights_only=True))

M = 100
test_size = 10000
Total_num_p_val = 40

# Buffers for the whole test set. Naming used from here on:
# z = latent code, x = flattened 28*28 image, y = one-hot label.
z_all = torch.zeros(test_size, latent_space_dim)
x_all = torch.zeros(test_size, 28*28)
y_all = torch.zeros(test_size, 10)


# The test dataset yields (latent code, label, flattened image), so the three
# fields are bound to z_test, y_test and x_test in that order.
# NOTE(review): the buffers hold test_size rows, so the loader presumably
# yields at most test_size items -- confirm against the data file.
for i, (z_test, y_test, x_test) in tqdm(enumerate(DataLoader_test)):
    x_all[i,:] = x_test
    y_all[i,:] = y_test
    z_all[i,:] = z_test

# Number of observations per p-value (test_size split into equal chunks).
n_length_input = int(test_size/Total_num_p_val)
p_val_list = []

# One DGCIT p-value per consecutive, non-overlapping chunk of the test set.
for i in tqdm(range(0, Total_num_p_val)):
    n_length = n_length_input
    start_index = n_length_input*(i)
    end_index = start_index + n_length

    x_all_in = x_all[start_index:end_index,].to(device).detach()
    y_all_in = y_all[start_index:end_index,].to(device).detach()
    z_all_in = z_all[start_index:end_index,].to(device).detach()

    # x = images (width 28*28), y = labels (width 10), z = latent codes.
    p_val = dgcit(x_all_in, y_all_in, z_all_in, G_image, G_label, test_size=n_length_input, z_dim=latent_space_dim,
      x_dims=28*28, y_dims=10, M=100, k=1, b=30, j=1000)

    # print("the ",start_index," has p value: ",p_val)
    p_val_list.append(p_val)
# Summarise the p-values by their quartiles.
p_lower, p_med, p_higher = np.quantile(p_val_list, 0.25), np.median(p_val_list), np.quantile(p_val_list, 0.75)
print("the latent_space is: ", latent_space_dim ,"\n")
print("the lower bound is: ", p_lower, "the median is: ", p_med, "the higher bound is: ", p_higher,"\n")

10000it [00:01, 5044.97it/s]
100%|██████████| 40/40 [01:44<00:00,  2.62s/it]

the latent_space is:  7 

the lower bound is:  0.28099999999999997 the median is:  0.3845 the higher bound is:  0.48724999999999996 




