In [None]:
# Latent dimension: default for AutoEncoder3 and the generator's input size below.
# Other values noted by the author (presumably a sweep - confirm):
# dl = [40, 20, 16, 13, 12, 10, 9, 8, 7, 6, 5, 3, 1]
d_l = 40
# Channel width of the generator's transposed-convolution layers.
gen_layer_size = 64 # 64

In [None]:
import torch
import numpy as np
import random
import os

def setup_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's legacy global generator and
    Python's ``random`` module, then requests deterministic cuDNN kernels.

    Args:
        seed: Seed value passed unchanged to each library.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed)
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

In [None]:
import torch
import torch.nn as nn
from torch.optim import SGD
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt

import torch.distributions as TD
from zmq import device
import torch.optim as optim
from datetime import datetime
import functools
from tqdm import tqdm
import cv2

# Pick the compute device: CUDA only when it is both enabled here and available.
# Note: this assignment rebinds the `device` name imported from zmq above.
enable_cuda = True
device = torch.device('cuda' if (torch.cuda.is_available() and enable_cuda) else 'cpu')

# Degradation helpers applied to the image tensors
def add_gaussian_noise(images, mean=0.0, std=1.0):
    """Add Gaussian noise to ``images`` and clamp the result to [0, 1].

    Args:
        images: Tensor of images; the noise has the same shape, dtype and device.
        mean: Mean of the added noise.
        std: Standard deviation of the added noise.

    Returns:
        A new tensor ``clamp(images + noise, 0, 1)``.
    """
    perturbation = torch.randn_like(images) * std + mean
    # Clamp so the noisy pixels stay inside the [0, 1] range.
    return torch.clamp(images + perturbation, 0.0, 1.0)

def motion_blur_kernel(length, angle):
    """Build a normalised ``length x length`` line-shaped blur kernel.

    A cell is switched on when the projection of its offset from the centre
    (``length // 2``) onto the unit vector ``(cos(angle), sin(angle))`` has
    absolute value below 0.5. The kernel is then scaled to sum to 1.

    Args:
        length: Side length of the square kernel.
        angle: Angle in degrees.

    Returns:
        A float64 NumPy array of shape ``(length, length)``.
    """
    half = length // 2
    rows, cols = np.mgrid[0:length, 0:length]
    # Last axis holds (x offset, y offset) = (column - half, row - half).
    offsets = np.stack([cols - half, rows - half], axis=-1)
    theta = np.deg2rad(angle)
    unit = np.array([np.cos(theta), np.sin(theta)])
    on_line = np.abs(np.dot(offsets, unit)) < 0.5
    kernel = on_line.astype(np.float64)
    kernel /= kernel.sum()  # normalize
    return kernel

def apply_motion_blur(image, length, angle):
    """Filter one image with the kernel from ``motion_blur_kernel``.

    Args:
        image: Image tensor; size-1 dimensions are squeezed away before filtering.
        length: Kernel side length, forwarded to ``motion_blur_kernel``.
        angle: Kernel angle in degrees, forwarded to ``motion_blur_kernel``.

    Returns:
        A float32 tensor on ``image.device`` with one leading dimension added
        to the filtered array.
    """
    kernel32 = motion_blur_kernel(length, angle).astype(np.float32)
    pixels = image.squeeze().cpu().numpy()  # OpenCV works on NumPy arrays
    filtered = cv2.filter2D(pixels, -1, kernel32, borderType=cv2.BORDER_REPLICATE)
    as_tensor = torch.tensor(filtered, dtype=torch.float32)
    return as_tensor.unsqueeze(0).to(image.device)

def add_motion_blur_to_dataset(dataset):
    """Blur every image in ``dataset`` with a randomly drawn kernel.

    For each image, first a kernel length is drawn with
    ``np.random.randint(6, 9)`` (i.e. 6, 7 or 8), then an angle uniformly from
    [0, 360) degrees; both come from NumPy's global generator.

    Args:
        dataset: Iterable of image tensors.

    Returns:
        The blurred images stacked into one tensor.
    """
    # Arguments are evaluated left to right, so the length is drawn before the angle.
    blurred_images = [
        apply_motion_blur(image, np.random.randint(6, 9), np.random.uniform(0, 360))
        for image in dataset
    ]
    return torch.stack(blurred_images)

def add_combine(dataset):
    """Degrade ``dataset`` with random motion blur followed by Gaussian noise.

    The noise step uses the defaults of ``add_gaussian_noise``.
    """
    return add_gaussian_noise(add_motion_blur_to_dataset(dataset))

class CTDataset_image(Dataset):
    """Dataset of single-channel 28x28 images loaded from a ``torch.save`` file.

    The file must unpack into a pair; the first item holds the images and the
    second item is ignored. Images are divided by 255, reshaped to
    ``(-1, 1, 28, 28)`` and kept on the GPU when one is available, otherwise on
    the CPU.
    """

    def __init__(self, filepath):
        """Load and preprocess the images stored at ``filepath``.

        Args:
            filepath: Path to a file readable by ``torch.load``.
        """
        # Not used inside this class; kept so the attribute stays available.
        self.flatten = nn.Flatten()
        # NOTE: weights_only=False unpickles arbitrary objects - only load trusted files.
        self.x, _ = torch.load(filepath, weights_only=False)
        self.x = self.x / 255.
        # Fall back to the CPU instead of crashing when CUDA is unavailable.
        target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.x = self.x.reshape(-1, 1, 28, 28).to(target).detach()

    def __len__(self):
        """Number of images."""
        return self.x.shape[0]

    def __getitem__(self, ix):
        """Return the image(s) selected by ``ix`` (any index a tensor accepts)."""
        return self.x[ix]

def get_plot(x):
    """Show up to the first eight images of a batch in a 2x4 grid.

    Args:
        x: Tensor indexed as ``x[i, 0]`` to obtain the i-th 2-D image.
           Batches with fewer than eight images leave the remaining panels empty.
    """
    x_np = x.clone().cpu().detach().numpy()
    # 2 rows x 4 columns = 8 panels
    fig, axes = plt.subplots(2, 4, figsize=(10, 5))

    for i, ax in enumerate(axes.flat):
        if i < x_np.shape[0]:
            ax.imshow(x_np[i, 0], cmap='gray')
        ax.axis('off')  # hide ticks on used and unused panels alike

    plt.tight_layout()
    plt.show()

In [None]:
# Seed all RNGs, then load the training and test images from the .pt files.
setup_seed(42)  # You can choose any seed value
train_ds = CTDataset_image('./training.pt')
test_ds = CTDataset_image('./test.pt')

# Re-seed right before the random split, then divide the training images into two
# subsets of 30000 each: one for the autoencoder, one for the conditional generator.
setup_seed(42)  # You can choose any seed value
train_AE_set, train_cond_gen_set = torch.utils.data.random_split(train_ds, [30000, 30000])

In [None]:
# Materialise each split as one tensor of clean images (the "ys", used as targets).
ys_train_AE_full_image = train_AE_set[:]
ys_train_gen_full_image = train_cond_gen_set[:]
ys_test_full_image = test_ds[:]

# Build the degraded inputs (the "xs") with add_combine: motion blur followed by
# Gaussian noise. The RNGs are re-seeded before every call, so each call starts
# from the same random state.
setup_seed(42)  # You can choose any seed value
xs_train_AE = add_combine(ys_train_AE_full_image)

setup_seed(42)  # You can choose any seed value
xs_train_gen = add_combine(ys_train_gen_full_image)

setup_seed(42)  # You can choose any seed value
xs_test = add_combine(ys_test_full_image)

# Move the clean targets to the selected device.
ys_train_AE = ys_train_AE_full_image.to(device)
ys_train_gen = ys_train_gen_full_image.to(device)
ys_test = ys_test_full_image.to(device)

In [None]:
# Preview clean test images.
get_plot(ys_test)

In [None]:
# Preview the blurred and noisy versions of the same test images.
get_plot(xs_test)

In [None]:
import torch
import torch.nn as nn

class Reshape(nn.Module):
    """Wrap ``Tensor.view`` as a module so a reshape can sit inside ``nn.Sequential``."""

    def __init__(self, *args):
        """Store the target shape given as positional integers."""
        super().__init__()
        self.shape = args

    def forward(self, x):
        """Return ``x`` viewed with the stored shape."""
        reshaped = x.view(self.shape)
        return reshaped

class Trim(nn.Module):
    """Crop the third and fourth dimensions of a tensor to at most 28 entries each."""

    def __init__(self, *args):
        """Accept and ignore any positional arguments."""
        super().__init__()

    def forward(self, x):
        """Return ``x`` with dimensions 2 and 3 cut to their first 28 entries."""
        keep = slice(None, 28)
        return x[:, :, keep, keep]

class AutoEncoder3(nn.Module):
    """Convolutional autoencoder for 1x28x28 images with a ``d_l``-dimensional code.

    The encoder ends in ``Tanh``, so every latent coordinate lies in (-1, 1);
    the decoder ends in ``Sigmoid``, so every output pixel lies in (0, 1).
    """

    def __init__(self, d_l=d_l):
        """Build the encoder and decoder.

        Args:
            d_l: Size of the latent code (defaults to the notebook-level ``d_l``).
        """
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),     # 1x28x28  -> 64x28x28
            nn.LeakyReLU(0.01),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),   # 64x28x28 -> 128x14x14
            nn.LeakyReLU(0.01),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),  # 128x14x14 -> 128x7x7
            nn.LeakyReLU(0.01),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),   # 128x7x7 -> 64x7x7
            nn.Flatten(),                                             # 64*7*7 = 3136
            nn.Linear(3136, d_l),
            nn.Tanh(),
        )

        self.decoder = nn.Sequential(
            nn.Linear(d_l, 3136),
            Reshape(-1, 64, 7, 7),
            nn.ConvTranspose2d(64, 128, kernel_size=3, stride=1, padding=1),   # 64x7x7 -> 128x7x7
            nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(128, 128, kernel_size=3, stride=2, padding=1),  # 128x7x7 -> 128x13x13
            nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=0),   # 128x13x13 -> 64x27x27
            nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(64, 1, kernel_size=3, stride=1, padding=0),     # 64x27x27 -> 1x29x29
            Trim(),  # 1x29x29 -> 1x28x28
            nn.Sigmoid(),
        )

    def get_latent(self, x):
        """Resize ``x`` to 28x28 (nearest neighbour) and return its latent code."""
        resized = F.interpolate(x, size=(28, 28), mode='nearest')
        return self.encoder(resized)

    def forward(self, x):
        """Encode ``x`` and decode the code back into an image."""
        resized = F.interpolate(x, size=(28, 28), mode='nearest')
        code = self.encoder(resized)
        return self.decoder(code)

In [None]:
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.utils.data as Data
def get_AE_model(model, x_train, y_train, x_test, y_test, param):
    """Train ``model`` to map ``x_*`` to ``y_*`` under an MSE loss, in place.

    Uses Adam, gradient-norm clipping at 0.5, ``ReduceLROnPlateau`` stepped on
    the validation loss, and early stopping after 50 epochs without
    improvement. The weights with the lowest validation loss are written to
    ``best_model.pth``, reloaded into ``model`` at the end, and the file is
    then deleted.

    Args:
        model: Network to train; moved to CUDA when available, else CPU.
        x_train, y_train: Training inputs and targets.
        x_test, y_test: Inputs and targets used as the validation set.
        param: Dict with keys ``'set_seed'``, ``'wgt_decay'`` and ``'G_lr'``.

    Returns:
        None. ``model`` holds the best weights on return.
    """
    import os

    seed_value = param['set_seed']
    weight_decay = param['wgt_decay']
    learning_rate = param['G_lr']

    setup_seed(seed_value)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=10, min_lr=1e-9)

    batch_size = 64
    epochs_num = 1500
    early_stopping_patience = 50

    train_loader = Data.DataLoader(
        dataset=Data.TensorDataset(x_train.to(device), y_train.to(device)),
        batch_size=batch_size,
        shuffle=True,
    )
    test_loader = Data.DataLoader(
        dataset=Data.TensorDataset(x_test.to(device), y_test.to(device)),
        batch_size=batch_size,
        shuffle=True,
    )

    criterion = nn.MSELoss()

    def mean_batch_loss(loader):
        # Average of the per-batch losses over `loader`, in eval mode, no gradients.
        running = 0.0
        model.eval()
        with torch.no_grad():
            for inputs, targets in loader:
                inputs, targets = inputs.to(device), targets.to(device)
                running += criterion(model(inputs), targets).item()
        return running / len(loader)

    # Loss histories (kept for plotting).
    train_losses = []
    val_losses = []

    # Baseline losses of the untrained model: validation first, then training.
    avg_val_loss = mean_batch_loss(test_loader)
    avg_train_loss = mean_batch_loss(train_loader)

    current_lr = scheduler.optimizer.param_groups[0]['lr']
    print(f'Epoch [{0}/{epochs_num}], Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}, Learning Rate: {current_lr:.10f}')

    # Checkpoint the initial weights so a best model always exists on disk.
    torch.save(model.state_dict(), 'best_model.pth')

    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in tqdm(range(epochs_num)):
        # One optimisation pass over the training set.
        model.train()
        epoch_loss_sum = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            predictions = model(inputs)

            optimizer.zero_grad()
            loss = criterion(predictions, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            optimizer.step()
            epoch_loss_sum += loss.item()
        avg_train_loss = epoch_loss_sum / len(train_loader)

        avg_val_loss = mean_batch_loss(test_loader)

        # The scheduler reacts to the validation loss.
        scheduler.step(avg_val_loss)

        # Keep the best weights on disk and track patience for early stopping.
        improved = avg_val_loss < best_val_loss
        if improved:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            patience_counter += 1

        if patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)

        if (epoch + 1) % 25 == 0:
            current_lr = scheduler.optimizer.param_groups[0]['lr']
            print(f'Epoch [{epoch + 1}/{epochs_num}], Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}, Learning Rate: {current_lr:.10f}')

    # Restore the best weights, then remove the temporary checkpoint.
    model.load_state_dict(torch.load('best_model.pth', weights_only=False))
    print("Best model loaded with validation loss:", best_val_loss)

    file_path = './best_model.pth'
    if not os.path.exists(file_path):
        print(f"{file_path} does not exist.")
    else:
        os.remove(file_path)
        print(f"{file_path} has been deleted.")

In [None]:
# @title Train AE model
# Seed, weight decay and learning rate consumed by get_AE_model.
param = dict(set_seed=42, wgt_decay=1e-04, G_lr=1e-6)

model_AE = AutoEncoder3().to(device)
get_AE_model(
    model=model_AE,
    x_train=xs_train_AE,
    y_train=ys_train_AE,
    x_test=xs_test,
    y_test=ys_test,
    param=param,
)

In [None]:
# @title Get MSE for compute Test PSNR
# Mean squared error of the autoencoder over the whole test set.

total_test_loss_MSE = 0.0
criterion = nn.MSELoss()  # MSELoss lives directly under torch.nn

test_data = Data.TensorDataset(xs_test.to(device), ys_test.to(device))
# Order does not matter for a dataset-level average, so no shuffling is needed.
test_loader = Data.DataLoader(dataset=test_data, batch_size=64, shuffle=False)

model_AE.eval()
with torch.no_grad():
    for batch_x, batch_y in test_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        outputs = model_AE(batch_x)
        loss = criterion(outputs, batch_y)
        # Weight each batch mean by its size so a smaller final batch is not over-counted.
        total_test_loss_MSE += loss.item() * batch_x.size(0)

avg_val_loss_MSE = total_test_loss_MSE / len(test_data)

print(f'Validation MSE Loss: {avg_val_loss_MSE:.4f}')

In [None]:
# Mean absolute error of the trained autoencoder over the whole test set.
total_test_loss_MAE = 0.0
criterion = nn.L1Loss()

test_data = Data.TensorDataset(xs_test.to(device), ys_test.to(device))
# Order does not matter for a dataset-level average, so no shuffling is needed.
test_loader = Data.DataLoader(dataset=test_data, batch_size=64, shuffle=False)

model_AE.eval()
with torch.no_grad():
    for batch_x, batch_y in test_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        outputs = model_AE(batch_x)
        loss = criterion(outputs, batch_y)
        # Weight each batch mean by its size so a smaller final batch is not over-counted.
        total_test_loss_MAE += loss.item() * batch_x.size(0)

avg_val_loss_MAE = total_test_loss_MAE / len(test_data)

print(f'Validation MAE Loss: {avg_val_loss_MAE:.4f}')

In [None]:
def plot_images_from_tensor(x):
    """Draw up to the first 16 images of ``x`` in a 2x8 grid.

    Args:
        x: Tensor indexed as ``x[i, 0]`` to obtain the i-th 2-D image. Images
           beyond the 16th are not drawn; unused panels stay blank.
    """
    x_np = x.clone().cpu().detach().numpy()

    fig, axes = plt.subplots(2, 8, figsize=(40, 10))

    # axes.flat walks the grid row by row, matching axes[i // 8, i % 8].
    for ax, image in zip(axes.flat, x_np):
        ax.imshow(image[0], cmap='gray')

    # Turn the frame off on every panel; plt.axis('off') would only affect the current axes.
    for ax in axes.flat:
        ax.axis('off')


# First 16 test pairs: degraded inputs (x_demo) and their clean targets (input_demo).
x_demo = xs_test[:16].clone()
input_demo = ys_test[:16].clone()

xs_test_temp = x_demo

# Reconstruct the degraded inputs with the trained autoencoder.
model_AE.eval()
with torch.no_grad():
    outputs = model_AE(xs_test_temp.to(device))

# Degraded inputs, clean targets, reconstructions - in that order.
plot_images_from_tensor(x_demo)
plot_images_from_tensor(input_demo)
plot_images_from_tensor(outputs)

In [None]:
# Count trainable parameters
def count_parameters(model):
    """Return the total number of elements in ``model``'s parameters that require gradients."""
    trainable = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            trainable += tensor.numel()
    return trainable

# Count the trainable parameters in the encoder
encoder_params = count_parameters(model_AE.encoder)
print(f'Number of parameters in the encoder: {encoder_params}')

# Count the parameters in the decoder
decoder_params = count_parameters(model_AE.decoder)
print(f'Number of parameters in the decoder: {decoder_params}')

In [None]:
# Persist the trained autoencoder weights; a later cell reloads this file.
torch.save(model_AE.state_dict(), 'path_to_trained_autoencoder.pth')

In [None]:
# Drop the autoencoder training tensors; the visible cells below do not use them again.
del xs_train_AE, ys_train_AE

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
import torch.distributions as TD
import numpy as np
import random
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Re-select the device for the generator cells: CUDA when available, else CPU.
# NOTE(review): unlike the earlier selection this ignores `enable_cuda` - confirm intended.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import cv2
import matplotlib.pyplot as plt


# Function to save images from tensor
def save_images_from_tensor(x, file_path):
    """Write up to the first 32 images of ``x`` to ``file_path`` as a 4x8 grid.

    Args:
        x: Tensor indexed as ``x[i, 0]`` to obtain the i-th 2-D image. Images
           beyond the 32nd are not drawn; unused panels stay blank.
        file_path: Destination passed to ``Figure.savefig``.
    """
    x_np = x.clone().cpu().detach().numpy()

    fig, axes = plt.subplots(4, 8, figsize=(10, 5))

    # axes.flat walks the grid row by row, matching axes[i // 8, i % 8].
    for ax, image in zip(axes.flat, x_np):
        ax.imshow(image[0], cmap='gray')

    # Turn the frame off on every panel; plt.axis('off') would only affect the current axes.
    for ax in axes.flat:
        ax.axis('off')

    fig.savefig(file_path)
    # Close this figure explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)

def sample_noise(sample_size, noise_dimension, noise_type, input_var):
    """Draw a ``(sample_size, noise_dimension)`` tensor of random noise.

    Args:
        sample_size: Number of noise vectors.
        noise_dimension: Length of each noise vector.
        noise_type: ``"normal"`` (zero-mean Gaussian with covariance
            ``input_var * I``, created on the notebook-level ``device``),
            ``"unif"`` (uniform on [0, 1), on the CPU) or ``"Cauchy"``
            (standard Cauchy, on the CPU).
        input_var: Variance of each coordinate; used only for ``"normal"``.

    Returns:
        The sampled tensor.

    Raises:
        ValueError: If ``noise_type`` is not one of the supported names.
    """
    if noise_type == "normal":
        noise_generator = TD.MultivariateNormal(
            torch.zeros(noise_dimension).to(device),
            input_var * torch.eye(noise_dimension).to(device))
        return noise_generator.sample((sample_size,))
    if noise_type == "unif":
        return torch.rand(sample_size, noise_dimension)
    if noise_type == "Cauchy":
        cauchy = TD.Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))
        # The distribution has event shape (1,); squeeze that trailing axis away.
        return cauchy.sample((sample_size, noise_dimension)).squeeze(2)

    # Fail with a clear message instead of an UnboundLocalError further down.
    raise ValueError(
        f"Unknown noise_type {noise_type!r}; expected 'normal', 'unif' or 'Cauchy'.")

def get_distance_matrix(X, Y, p_in = 1):
    """Pairwise p-norm distances between the rows of ``X`` and the rows of ``Y``.

    Args:
        X: Tensor of row vectors.
        Y: Tensor of row vectors with the same feature size as ``X``.
        p_in: Norm order forwarded to ``torch.cdist`` (default 1).

    Returns:
        The distance matrix computed by ``torch.cdist``.
    """
    pairwise = torch.cdist(X, Y, p=p_in)
    return pairwise

def find_loss_l(y_torch, gen_y_all_torch, z_torch, sigma_w, sigma_u, M):
    """Kernel-weighted discrepancy between real samples and generated samples.

    With the Laplacian kernel ``k_s(a, b) = exp(-||a - b||_1 / s)`` the loss is

        (1 / n) * sum_{i != j} k_{sigma_w}(z_i, z_j) * U[i, j]

    where ``U[i, j] = k(y_i, y_j) - mean_m k(g_im, y_j) - mean_m k(g_jm, y_i)
    + mean_{k, m} k(g_ik, g_jm)`` uses bandwidth ``sigma_u`` and ``g_im`` is
    the m-th generated sample for condition i.

    All pairwise distances come from ``torch.cdist`` calls instead of Python
    loops over samples, so autograd does not have to keep one difference
    tensor per generated sample. Every tensor stays on the device of the
    inputs.

    Args:
        y_torch: Real samples, shape ``(n, d_y)``.
        gen_y_all_torch: Generated samples, shape ``(n, M, d_y)``.
        z_torch: Conditioning vectors, shape ``(n, d_z)``.
        sigma_w: Bandwidth of the kernel on ``z_torch``.
        sigma_u: Bandwidth of the kernel on samples.
        M: Number of generated samples per condition.

    Returns:
        Scalar tensor holding the loss.
    """
    n = z_torch.shape[0]
    d_y = y_torch.shape[1]

    # Weights between conditioning vectors.
    w_mx = torch.exp(-torch.cdist(z_torch, z_torch, p=1) / sigma_w)

    # Row i * M + m of gen_flat is the m-th generated sample for condition i.
    gen_flat = gen_y_all_torch.reshape(n * M, d_y)

    # real vs real
    u_mx_1 = torch.exp(-torch.cdist(y_torch, y_torch, p=1) / sigma_u)

    # generated vs real, averaged over the M samples of the row condition
    u_mx_2 = torch.exp(-torch.cdist(gen_flat, y_torch, p=1) / sigma_u).reshape(n, M, n).mean(dim=1)
    u_mx_3 = u_mx_2.T

    # generated vs generated, averaged over all M x M sample pairs
    u_mx_4 = torch.exp(-torch.cdist(gen_flat, gen_flat, p=1) / sigma_u).reshape(n, M, n, M).mean(dim=(1, 3))

    u_mx = u_mx_1 - u_mx_2 - u_mx_3 + u_mx_4

    # Zero the diagonal so only pairs i != j contribute.
    off_diagonal = 1 - torch.eye(n, device=u_mx.device, dtype=u_mx.dtype)
    FF_mx = u_mx * w_mx * off_diagonal

    loss = 1 / (n) * torch.sum(FF_mx)
    return loss


def setup_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's legacy global generator and
    Python's ``random`` module, then requests deterministic cuDNN kernels.
    Same behaviour as the identically named function defined earlier in the notebook.

    Args:
        seed: Seed value passed unchanged to each library.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed)
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

class Reshape(nn.Module):
    """Wrap ``Tensor.view`` as a module so a reshape can sit inside ``nn.Sequential``."""

    def __init__(self, *args):
        """Store the target shape given as positional integers."""
        super().__init__()
        self.shape = args

    def forward(self, x):
        """Return ``x`` viewed with the stored shape."""
        reshaped = x.view(self.shape)
        return reshaped

class Trim(nn.Module):
    """Crop the third and fourth dimensions of a tensor to at most 28 entries each."""

    def __init__(self, *args):
        """Accept and ignore any positional arguments."""
        super().__init__()

    def forward(self, x):
        """Return ``x`` with dimensions 2 and 3 cut to their first 28 entries."""
        keep = slice(None, 28)
        return x[:, :, keep, keep]


class generator_y(nn.Module):
    """Transposed-convolution generator producing a 1x28x28 image from a code plus noise."""

    def __init__(self, input_dimension, noise_dimension, gen_layer_size):
        """Build the decoder stack.

        Args:
            input_dimension: Size of the conditioning vector.
            noise_dimension: Size of the noise vector concatenated to it.
            gen_layer_size: Channel width of the hidden transposed convolutions.
        """
        super().__init__()

        width = gen_layer_size
        layers = [
            nn.Linear(input_dimension + noise_dimension, 3136),
            Reshape(-1, 64, 7, 7),
            nn.ConvTranspose2d(64, width, kernel_size=3, stride=1, padding=1),     # 7x7 -> 7x7
            nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(width, width, kernel_size=3, stride=2, padding=1),  # 7x7 -> 13x13
            nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(width, width, kernel_size=3, stride=2, padding=0),  # 13x13 -> 27x27
            nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(width, 1, kernel_size=3, stride=1, padding=0),      # 27x27 -> 29x29
            Trim(),  # 1x29x29 -> 1x28x28
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, x, noise):
        """Concatenate ``x`` and ``noise`` along dimension 2 and decode the result."""
        joined = torch.cat((x, noise), dim=2)
        return self.decoder(joined)

class generator_x(nn.Module):
    """Transposed-convolution generator producing a 1x28x28 image from a code plus noise."""

    def __init__(self, input_dimension, noise_dimension, gen_layer_size):
        """Build the decoder stack.

        Args:
            input_dimension: Size of the conditioning vector.
            noise_dimension: Size of the noise vector concatenated to it.
            gen_layer_size: Channel width of the hidden transposed convolutions.
        """
        super().__init__()

        width = gen_layer_size
        layers = [
            nn.Linear(input_dimension + noise_dimension, 3136),
            Reshape(-1, 64, 7, 7),
            nn.ConvTranspose2d(64, width, kernel_size=3, stride=1, padding=1),     # 7x7 -> 7x7
            nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(width, width, kernel_size=3, stride=2, padding=1),  # 7x7 -> 13x13
            nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(width, width, kernel_size=3, stride=2, padding=0),  # 13x13 -> 27x27
            nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(width, 1, kernel_size=3, stride=1, padding=0),      # 27x27 -> 29x29
            Trim(),  # 1x29x29 -> 1x28x28
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, x, noise):
        """Concatenate ``x`` and ``noise`` along dimension 2 and decode the result."""
        joined = torch.cat((x, noise), dim=2)
        return self.decoder(joined)

def get_generator(model, z_train, z_test, x_train, x_test, param):
    """Train a conditional generator with the ``find_loss_l`` objective.

    For every batch, each conditioning vector is paired with ``M_train`` fresh
    noise draws; the generated images are compared with the real ones through
    ``find_loss_l``. Kernel bandwidths are the median pairwise L1 distance over
    the first 10000 training rows. Training uses Adam, gradient-norm clipping
    at 0.5, ``ReduceLROnPlateau`` on the validation loss and early stopping
    after 50 epochs without improvement. A loss curve is written to
    ``loss_vs_epoch.png`` every epoch and sample images every 15 epochs.

    Args:
        model: Generator called as ``model(z, noise)``; expected to already sit
            on the training device (it is not moved here).
        z_train, z_test: Conditioning vectors, one row per sample.
        x_train, x_test: Target images matching the conditioning vectors.
        param: Dict with keys ``'set_seed'``, ``'noise_dimension'``,
            ``'noise_type'``, ``'input_var'``, ``'lambda_3'``, ``'wgt_decay'``,
            ``'G_lr'`` and ``'label'``.

    Returns:
        The lowest validation loss seen during training. The matching weights
        are loaded into ``model`` and saved as
        ``best_model_<label>_k1_<seed>.pth``.
    """
    import math
    import os

    set_seed = param['set_seed']
    noise_dimension = param['noise_dimension']
    noise_type = param['noise_type']
    input_var = param['input_var']
    lambda_3 = param['lambda_3']  # read for interface parity; the L1 penalty is currently disabled
    wgt_decay = param['wgt_decay']
    G_lr = param['G_lr']
    label = param['label']

    setup_seed(set_seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    optimizer = optim.Adam(model.parameters(), lr=G_lr, weight_decay=wgt_decay)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, min_lr=1e-7)

    N = z_train.shape[0]

    # Bandwidths via the median heuristic. The distance matrices are not bound to
    # names, so their memory can be reclaimed before training starts.
    z_sub = z_train.clone().reshape(N, -1).to(device).detach()[:10000]
    sigma_z_l = torch.median(get_distance_matrix(z_sub, z_sub)).item()
    del z_sub

    x_sub = x_train.clone().reshape(N, -1).to(device).detach()[:10000]
    sigma_x_l = torch.median(get_distance_matrix(x_sub, x_sub)).item()
    del x_sub

    print("sigma_z_l: ", sigma_z_l, "sigma_x_l: ", sigma_x_l)

    M_train = 10
    batch_size = 128 # 16
    epochs_num = 600 # 300

    train_data = Data.TensorDataset(z_train.to(device), x_train.to(device))
    train_loader = Data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

    test_data = Data.TensorDataset(z_test.to(device), x_test.to(device))
    test_loader = Data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)

    def batch_loss(batch_z, batch_x):
        # Loss of one batch: M_train generated images per conditioning vector.
        n_batch = batch_z.shape[0]

        # (M_train, n_batch, d_z): every conditioning vector repeated M_train times.
        Z_real_repeat = batch_z.repeat(M_train, 1, 1).to(device)

        Noise_fake = sample_noise(Z_real_repeat.shape[0] * Z_real_repeat.shape[1],
                                  noise_dimension, noise_type, input_var=input_var).to(device)
        Noise_fake = Noise_fake.reshape(Z_real_repeat.shape[0], Z_real_repeat.shape[1], -1)

        output1 = model(Z_real_repeat, Noise_fake)
        # Regroup so that axis 0 indexes the sample and axis 1 its M_train draws.
        output1 = output1.reshape(M_train, n_batch, output1.shape[1], output1.shape[2], output1.shape[3]).swapaxes(0, 1)

        X_fake = output1.reshape(n_batch, M_train, -1).to(device)
        X_real = batch_x.reshape(n_batch, -1).to(device)
        Z_real = batch_z.reshape(n_batch, -1).to(device)

        return find_loss_l(X_real, X_fake, Z_real, sigma_z_l, sigma_x_l, M_train)

    def mean_loss(loader):
        # Average of the per-batch losses over `loader`, in eval mode, no gradients.
        running = 0.0
        model.eval()
        with torch.no_grad():
            for batch_z, batch_x in loader:
                running = running + batch_loss(batch_z, batch_x).item()
        return running / len(loader)

    def safe_log(value):
        # The estimator can be zero or negative; plot NaN instead of raising ValueError.
        return math.log(value) if value > 0 else float('nan')

    # Early stopping parameters
    early_stopping_patience = 50
    best_val_loss = float('inf')
    patience_counter = 0

    # Log-loss histories for the loss curve
    train_losses = []
    val_losses = []

    # Baseline losses of the untrained model: validation first, then training.
    avg_val_loss = mean_loss(test_loader)
    avg_train_loss = mean_loss(train_loader)

    current_lr = scheduler.optimizer.param_groups[0]['lr']
    print(f'Epoch [{0}/{epochs_num}], Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}, Learning Rate: {current_lr:.6f}')

    # Save the best model initially
    torch.save(model.state_dict(), 'best_model.pth')

    train_losses.append(safe_log(avg_train_loss))
    val_losses.append(safe_log(avg_val_loss))

    for epoch in tqdm(range(epochs_num)):
        model.train()
        total_train_loss = 0.0
        for batch_z, batch_x in train_loader:
            # The forward pass happens before zero_grad, as in the original loop.
            g_zx_error = batch_loss(batch_z, batch_x)

            optimizer.zero_grad()
            g_zx_error.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            optimizer.step()
            total_train_loss = total_train_loss + g_zx_error.item()

        avg_train_loss = total_train_loss / len(train_loader)

        avg_val_loss = mean_loss(test_loader)

        # Step the scheduler based on the validation loss
        scheduler.step(avg_val_loss)

        # Save the best model and track patience for early stopping
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            patience_counter += 1

        if patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        train_losses.append(safe_log(avg_train_loss))
        val_losses.append(safe_log(avg_val_loss))

        # Refresh the loss curve on disk without displaying it
        plt.figure(figsize=(10, 5))
        plt.plot(train_losses, label='Log Training Loss')
        plt.plot(val_losses, label='Log Validation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Log loss')
        plt.title('Training and Validation Loss vs Epochs')
        plt.legend()
        plt.savefig('loss_vs_epoch.png')
        plt.close()  # Close the figure

        if (epoch + 1) % 15 == 0:
            # Second validation pass with fresh noise, used only for the printed
            # report; kept so the sequence of random draws matches earlier runs.
            avg_val_loss = mean_loss(test_loader)

            current_lr = scheduler.optimizer.param_groups[0]['lr']
            print(f'Epoch [{epoch + 1}/{epochs_num}], Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}, Learning Rate: {current_lr:.6f}, pat: {patience_counter:.1f}')

            # Save one generated image per conditioning vector for the first 32 test samples.
            z_demo = z_test[:32,:].clone()
            x_demo = x_test[:32,:,:,:].clone()

            Noise_fake = sample_noise(z_demo.shape[0], noise_dimension, noise_type, input_var = input_var).to(device)
            Noise_fake = Noise_fake.reshape(z_demo.shape[0], 1, -1)
            z_demo = z_demo.reshape(z_demo.shape[0],1,z_demo.shape[1]).to(device)

            model.eval()
            with torch.no_grad():
                z_demo_temp = model(z_demo, Noise_fake.to(device))

            save_images_from_tensor(z_demo_temp.cpu().detach(), './generated_image.jpg')
            save_images_from_tensor(x_demo.cpu().detach(), './original_image.jpg')

    # Load the best model after training is complete
    model.load_state_dict(torch.load('best_model.pth', weights_only=False))
    # Report the seed this run was configured with, not a notebook-level variable.
    print('The seed is ' + str(set_seed))
    print("Best model loaded with validation loss:", best_val_loss)
    torch.save(model.state_dict(), 'best_model_'+label+'_k1_'+ str(set_seed) +'.pth')

    file_path = './best_model.pth'

    # Remove the temporary checkpoint if it is still there
    if os.path.exists(file_path):
        os.remove(file_path)
        print(f"{file_path} has been deleted.")
    else:
        print(f"{file_path} does not exist.")

    return best_val_loss


In [None]:
# Working copies for the generator stage: degraded images (x_*) and clean images (y_*).
x_train_input = xs_train_gen.clone()
x_test_input = xs_test.clone()

y_train_input = ys_train_gen.clone()
y_test_input = ys_test.clone()


import torch
from torch.utils.data import DataLoader, TensorDataset

# x_train_input and x_test_input are the tensors encoded into latent codes below.
batch_size = 64  # Adjust the batch size based on your memory capacity

# Create DataLoaders (shuffle=False keeps the codes in the same order as the inputs)
train_dataset = TensorDataset(x_train_input)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)

test_dataset = TensorDataset(x_test_input)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Function to compute latent representations
def compute_latent_representations(model, data_loader, device):
    """Encode every batch of ``data_loader`` with ``model.get_latent``.

    Args:
        model: module exposing ``eval()`` and ``get_latent(batch)``.
        data_loader: iterable of batches; element 0 of each batch is the input tensor.
        device: device each input batch is moved to before encoding.

    Returns:
        A CPU tensor with the per-batch latents concatenated along dim 0,
        in the order the batches were yielded.
    """
    # Switch to inference mode; the model is left in eval mode afterwards.
    model.eval()
    with torch.no_grad():
        encoded_chunks = [
            model.get_latent(inputs[0].to(device)).detach().cpu()
            for inputs in data_loader
        ]
    return torch.cat(encoded_chunks, dim=0)
# Rebuild the autoencoder and restore its trained weights from disk.
model_AE = AutoEncoder3().to(device)
model_AE.load_state_dict(
    torch.load('path_to_trained_autoencoder.pth', weights_only=False)
)

# Encode both splits and flatten each latent to one row per input sample.
z_train_input = compute_latent_representations(model_AE, train_loader, device)
z_train_input = z_train_input.reshape(len(x_train_input), -1)
z_test_input = compute_latent_representations(model_AE, test_loader, device)
z_test_input = z_test_input.reshape(len(x_test_input), -1)

# Shape report: latents first, then the x images, then the y images.
for _shaped in (z_train_input, z_test_input,
                x_train_input, x_test_input,
                y_train_input, y_test_input):
    print(_shaped.shape)

In [None]:
# Settings handed to get_generator for the X generator; "set_seed" is overwritten
# for every run of the sweep below.
param_x = dict(
    set_seed=42,
    noise_dimension=20,
    noise_type="normal",
    input_var=1.0 / 9.0,
    lambda_3=0,
    wgt_decay=1e-05,
    G_lr=1e-3,
    label="x",
)
input_noise_dim_x = param_x["noise_dimension"]

# Seeds to sweep; test_mmd_list collects the value returned by get_generator per seed,
# in the same order as seed_list.
seed_list = np.array([0, 1, 2, 3, 42, 114514, 1919810])
test_mmd_list = np.array([])

for seed in seed_list:
    print("seed: ", seed)
    param_x["set_seed"] = int(seed)

    # Fresh generator for every seed.
    model_x_k1 = generator_x(
        input_dimension=d_l,
        noise_dimension=input_noise_dim_x,
        gen_layer_size=gen_layer_size,
    ).to(device)
    temp_mmd = get_generator(
        model=model_x_k1,
        z_train=z_train_input,
        z_test=z_test_input,
        x_train=x_train_input,
        x_test=x_test_input,
        param=param_x,
    )
    test_mmd_list = np.append(test_mmd_list, temp_mmd)

In [None]:
# @title Get Gen model X


# Drop the reference to the generator left over from the last seed of the sweep.
del model_x_k1

# Position of the run with the smallest value returned by get_generator.
min_index = np.argmin(test_mmd_list)

# Best value and the seed that produced it.
min_test_mmd = test_mmd_list[min_index]
corresponding_seed = seed_list[min_index]

print(f"The minimum test MMD is {min_test_mmd} and the corresponding seed is {corresponding_seed}.")

# Remove the checkpoints of every non-winning seed. The existence check keeps this
# cell re-runnable: on a second execution those files are already gone and an
# unguarded os.remove would raise FileNotFoundError.
for seed in seed_list:
    if seed != corresponding_seed:
        stale_checkpoint = 'best_model_x_k1_' + str(seed) + '.pth'
        if os.path.exists(stale_checkpoint):
            print("del .pth with seed: ", seed)
            os.remove(stale_checkpoint)

# Rebuild the generator, restore the winning weights and save them under a seed-free name.
model_x_k1 = generator_x(input_dimension = d_l, noise_dimension = input_noise_dim_x, gen_layer_size = gen_layer_size).to(device)

model_x_k1.load_state_dict(torch.load('best_model_x_k1_'+ str(corresponding_seed) +'.pth', weights_only=False))

torch.save(model_x_k1.state_dict(), 'best_model_x_k1.pth')

In [None]:
def plot_images_from_tensor(x):
    """Draw up to 32 images from ``x`` on a 4x8 grid of grayscale panels.

    Args:
        x: tensor indexed as ``x[i, 0]`` to obtain the i-th 2-D image
           (only index 0 of the second dimension is drawn).

    Every panel has its axis decorations hidden, including panels left empty
    when ``x`` holds fewer than 32 images. Images beyond the 32nd are ignored.
    """
    images = x.detach().cpu().numpy()

    # 4 rows x 8 columns; axes.flat walks the grid row by row, so image i lands
    # at row i // 8, column i % 8.
    fig, axes = plt.subplots(4, 8, figsize=(10, 5))

    for idx, ax in enumerate(axes.flat):
        # Must be called on the Axes object: plt.axis('off') would only act on the
        # current (last created) subplot and leave ticks on all the others.
        ax.axis('off')
        if idx < images.shape[0]:
            ax.imshow(images[idx, 0], cmap='gray')


# Visual sanity check on the first 32 test points: one generated draw per latent
# code, shown next to the corresponding real x image.
input_demo = x_test_input[:32].clone()
z_demo = z_test_input[:32].clone()

# Noise settings of the X generator, reused by later cells.
noise_dimension_x, noise_type_x, input_var_x = (
    param_x[key] for key in ('noise_dimension', 'noise_type', 'input_var')
)

# One noise vector per latent code; both inputs get a singleton middle dimension.
Noise_fake = sample_noise(z_demo.shape[0], noise_dimension_x, noise_type_x, input_var=input_var_x).to(device)
Noise_fake = Noise_fake.reshape(z_demo.shape[0], 1, -1)
z_demo = z_demo.unsqueeze(1).to(device)

model_x_k1.eval()
with torch.no_grad():
    z_demo_temp = model_x_k1(z_demo, Noise_fake)

plot_images_from_tensor(z_demo_temp)
plot_images_from_tensor(input_demo)

In [None]:
# For each of the first 32 test latents, average M generated draws and show the
# averaged image next to the real one.
M = 200
z_demo_final = torch.zeros(32, 1, 28, 28)

model_x_k1.eval()

for i in range(32):
    # Single latent code kept as a (1, d) row.
    z_demo = z_test_input[i:i + 1, :].clone()

    # M independent noise draws for this one latent code.
    Noise_fake = sample_noise(M * z_demo.shape[0], noise_dimension_x, noise_type_x, input_var=input_var_x).to(device)
    Noise_fake = Noise_fake.reshape(z_demo.shape[0], M, -1)

    # Repeat the latent code M times so it pairs with every noise draw.
    z_demo = z_demo.to(device).unsqueeze(1).expand(-1, M, -1)
    z_demo = z_demo.reshape(z_demo.shape[0], M, d_l)

    with torch.no_grad():
        z_demo_temp = model_x_k1(z_demo, Noise_fake)

    # Average over the leading dimension of the generator output.
    z_demo_temp_mean = z_demo_temp.mean(dim=0)
    z_demo_final[i] = z_demo_temp_mean

plot_images_from_tensor(z_demo_final)
plot_images_from_tensor(input_demo)

In [None]:
# Settings handed to get_generator for the Y generator; "set_seed" is overwritten
# for every run of the sweep below.
param_y = dict(
    set_seed=42,
    noise_dimension=20,
    noise_type="normal",
    input_var=1.0 / 9.0,
    lambda_3=0,
    wgt_decay=1e-05,
    G_lr=1e-3,
    label="y",
)
input_noise_dim_y = param_y["noise_dimension"]

# Seeds to sweep; test_mmd_list collects the value returned by get_generator per seed,
# in the same order as seed_list.
seed_list = np.array([0, 1, 2, 3, 42, 114514, 1919810])
test_mmd_list = np.array([])

for seed in seed_list:
    print("seed: ", seed)
    param_y["set_seed"] = int(seed)

    # Fresh generator for every seed; the y images play the role of the target here.
    model_y_k1 = generator_y(
        input_dimension=d_l,
        noise_dimension=input_noise_dim_y,
        gen_layer_size=gen_layer_size,
    ).to(device)
    temp_mmd = get_generator(
        model=model_y_k1,
        z_train=z_train_input,
        z_test=z_test_input,
        x_train=y_train_input,
        x_test=y_test_input,
        param=param_y,
    )
    test_mmd_list = np.append(test_mmd_list, temp_mmd)

In [None]:
# @title Get Gen model Y
# Drop the reference to the generator left over from the last seed of the sweep.
del model_y_k1

# Position of the run with the smallest value returned by get_generator.
min_index = np.argmin(test_mmd_list)

# Best value and the seed that produced it.
min_test_mmd = test_mmd_list[min_index]
corresponding_seed = seed_list[min_index]

print(f"The minimum test MMD is {min_test_mmd} and the corresponding seed is {corresponding_seed}.")

# Remove the checkpoints of every non-winning seed. The existence check keeps this
# cell re-runnable: on a second execution those files are already gone and an
# unguarded os.remove would raise FileNotFoundError.
for seed in seed_list:
    if seed != corresponding_seed:
        stale_checkpoint = 'best_model_y_k1_' + str(seed) + '.pth'
        if os.path.exists(stale_checkpoint):
            print("del .pth with seed: ", seed)
            os.remove(stale_checkpoint)

# The checkpoint was produced by training a generator_y instance, so the weights
# must be restored into generator_y (not generator_x) for the state dict to be
# guaranteed to match the architecture.
model_y_k1 = generator_y(input_dimension = d_l, noise_dimension = input_noise_dim_y, gen_layer_size = gen_layer_size).to(device)

model_y_k1.load_state_dict(torch.load('best_model_y_k1_'+ str(corresponding_seed) +'.pth', weights_only=False))

torch.save(model_y_k1.state_dict(), 'best_model_y_k1.pth')

In [None]:
# Visual sanity check on the first 32 test points: one generated draw per latent
# code, shown next to the corresponding real y image.
z_demo = z_test_input[:32,:].clone()
input_demo = y_test_input[:32,:,:,:].clone()

# Noise settings of the Y generator, reused by later cells.
noise_dimension_y = param_y['noise_dimension']
noise_type_y = param_y['noise_type']
input_var_y = param_y['input_var']

# One noise vector per latent code; both inputs get a singleton middle dimension.
Noise_fake = sample_noise(z_demo.shape[0], noise_dimension_y, noise_type_y, input_var = input_var_y).to(device)
Noise_fake = Noise_fake.reshape(z_demo.shape[0], 1, -1)
z_demo = z_demo.reshape(z_demo.shape[0],1,z_demo.shape[1]).to(device)

# Put the model that is actually used below (the Y generator) into eval mode.
# It was freshly constructed before its weights were loaded, so it is still in
# training mode unless eval() is called on it.
model_y_k1.eval()
with torch.no_grad():
    z_demo_temp = model_y_k1(z_demo, Noise_fake.to(device))

plot_images_from_tensor(z_demo_temp)
plot_images_from_tensor(input_demo)

In [None]:
# For each of the first 32 test latents, average M generated y draws and show the
# averaged image next to the real one.
z_demo_final = torch.zeros(32, 1, 28, 28)

for i in range(32):

    # Single latent code kept as a (1, d) row.
    z_demo = z_test_input[i,:].unsqueeze(0).clone()

    M = 200

    # M independent noise draws for this one latent code.
    Noise_fake = sample_noise(z_demo.shape[0]*M, noise_dimension_y, noise_type_y, input_var = input_var_y).to(device)
    Noise_fake = Noise_fake.reshape(z_demo.shape[0], M, -1)

    # Repeat the latent code M times so it pairs with every noise draw.
    z_demo = z_demo.reshape(z_demo.shape[0],1,z_demo.shape[1]).to(device)
    z_demo = z_demo.expand(-1, M, -1).to(device)
    z_demo = z_demo.reshape(z_demo.shape[0],M,d_l).to(device)

    # eval() must be called on the Y generator, which is the model run below;
    # calling it on model_x_k1 would leave model_y_k1 in training mode.
    model_y_k1.eval()
    with torch.no_grad():
        z_demo_temp = model_y_k1(z_demo, Noise_fake.to(device))

    # Average over the leading dimension of the generator output.
    z_demo_temp_mean = torch.mean(z_demo_temp, dim = 0)
    z_demo_final[i,:,:,:] = z_demo_temp_mean

plot_images_from_tensor(z_demo_final)
plot_images_from_tensor(input_demo)

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# Parameters
M = 100  # number of generated draws per test point
# Derived from the data rather than hard-coded to 10000: the loop below visits every
# row of z_test_input, so the result buffers must have exactly that many rows
# (fewer would raise an IndexError, more would leave all-zero rows behind).
test_size = z_test_input.shape[0]
latent_space_dim = d_l

x_test_input = xs_test.clone()
y_test_input = ys_test.clone()

# Result buffers (kept on the CPU); images are stored flattened to 28*28 values.
gen_x_all = torch.zeros(test_size, M, 28*28)
gen_y_all = torch.zeros(test_size, M, 28*28)
z_all = torch.zeros(test_size, latent_space_dim)
x_all = torch.zeros(test_size, 28*28)
y_all = torch.zeros(test_size, 28*28)

# One test point per batch, visited in order (no shuffling, nothing dropped).
test_data = TensorDataset(z_test_input.to(device), y_test_input.to(device), x_test_input.to(device))
DataLoader_test = DataLoader(test_data, batch_size=1, shuffle=False, drop_last=False)

# Set models to evaluation mode
model_y_k1.eval()
model_x_k1.eval()

with torch.no_grad():
    # tqdm wraps the loader itself (which has a length) instead of enumerate(...),
    # so the progress bar can show the total and an ETA.
    for i, (z_test, y_test, x_test) in enumerate(tqdm(DataLoader_test)):
        # Repeat the single latent code M times along a new leading dimension.
        repeat_dims = (M, 1, 1)
        Z_real_repeat = z_test.repeat(*repeat_dims).to(device)

        # Generate noise for y: one draw per repeated latent code
        Noise_fake_y = sample_noise(Z_real_repeat.shape[0]*Z_real_repeat.shape[1], noise_dimension_y, noise_type_y, input_var=input_var_y).to(device)
        Noise_fake_y = Noise_fake_y.reshape(Z_real_repeat.shape[0], Z_real_repeat.shape[1], -1)

        # Generate fake y and flatten every draw to a vector: (1, M, -1)
        output_y = model_y_k1(Z_real_repeat, Noise_fake_y)
        output_y = output_y.reshape(M, 1, output_y.shape[1], output_y.shape[2], output_y.shape[3]).swapaxes(0, 1)
        Y_fake = output_y.reshape(1, M, -1).detach()

        # Generate noise for x: one draw per repeated latent code
        Noise_fake_x = sample_noise(Z_real_repeat.shape[0]*Z_real_repeat.shape[1], noise_dimension_x, noise_type_x, input_var=input_var_x).to(device)
        Noise_fake_x = Noise_fake_x.reshape(Z_real_repeat.shape[0], Z_real_repeat.shape[1], -1)

        # Generate fake x and flatten every draw to a vector: (1, M, -1)
        output_x = model_x_k1(Z_real_repeat, Noise_fake_x)
        output_x = output_x.reshape(M, 1, output_x.shape[1], output_x.shape[2], output_x.shape[3]).swapaxes(0, 1)
        X_fake = output_x.reshape(1, M, -1).detach()

        # Store results for test point i
        x_all[i, :] = x_test.reshape(1, -1).detach()
        y_all[i, :] = y_test.reshape(1, -1).detach()
        z_all[i, :] = z_test.reshape(1, -1).detach()
        gen_x_all[i, :] = X_fake
        gen_y_all[i, :] = Y_fake


In [None]:
# Shape report for the stored real tensors followed by the generated ones.
for _stored in (x_all, y_all, z_all, gen_x_all, gen_y_all):
    print(_stored.shape)

In [None]:
import matplotlib.pyplot as plt

def get_four_plots(x_all, y_all, gen_x_all, gen_y_all, index=100, image_shape=(28, 28)):
    """Show one real x, real y, generated x and generated y image side by side.

    Args:
        x_all, y_all: tensors whose row ``index`` holds a flattened image.
        gen_x_all, gen_y_all: tensors indexed as ``[index, draw, :]``; draw 0 is shown.
        index: which sample to display (defaults to 100, the previously
            hard-coded value).
        image_shape: 2-D shape each flattened image is reshaped to before
            plotting (defaults to 28x28, the previously hard-coded value).
    """
    # Title -> flattened image, in left-to-right panel order.
    panels = [
        ('x', x_all[index]),
        ('y', y_all[index]),
        ('x gen', gen_x_all[index, 0, :]),
        ('y gen', gen_y_all[index, 0, :]),
    ]

    fig, axs = plt.subplots(1, len(panels), figsize=(20, 5))

    for ax, (title, flat_image) in zip(axs, panels):
        ax.imshow(flat_image.reshape(image_shape).cpu().detach().numpy(), cmap='gray')
        ax.set_title(title)
        # Remove axis ticks and frame
        ax.axis('off')

    # Display the plots
    plt.show()

# Show one real/generated comparison from the tensors filled above.
get_four_plots(
    x_all=x_all,
    y_all=y_all,
    gen_x_all=gen_x_all,
    gen_y_all=gen_y_all,
)

In [None]:
def _mean_kernel_generated_vs_observed(gen, obs, sigma, M):
    """Average over the M draws of exp(-get_distance_matrix(gen[:, m, :], obs) / sigma)."""
    acc = torch.exp(-get_distance_matrix(gen[:, 0, :], obs) / sigma)
    for m in range(1, M):
        acc = acc + torch.exp(-get_distance_matrix(gen[:, m, :], obs) / sigma)
    return acc / M


def _mean_kernel_generated_vs_generated(gen, sigma, n, M, d):
    """Entry [i, j] averages exp(-L1(gen[j, m], gen[i, k]) / sigma) over all draws k and m."""
    flat = gen.reshape(n * M, d)  # row j * M + m holds gen[j, m]
    total = None
    for k in range(M):
        dist = torch.zeros(n, n, M).to(device)
        for i in range(n):
            anchor = gen[i, k, :].reshape(1, d)
            # L1 distance from draw k of sample i to every draw of every sample.
            dist[i, :, :] = torch.linalg.vector_norm(flat - anchor, ord=1, dim=1).reshape(n, M)
        term = torch.mean(torch.exp(-dist / sigma), dim=2)
        total = term if total is None else total + term
    return 1 / M * total


def get_p_value_stat_2(boot_num, M, n, gen_x_all_torch, gen_y_all_torch, x_torch, y_torch, z_torch, boor_rv_type="gaussian", sigma_z=1.0, sigma_x=1.0, sigma_y=1.0):
    """Compute the test statistic and its multiplier-bootstrap replicates.

    Args:
        boot_num: number of bootstrap replicates.
        M: number of generated draws per sample.
        n: number of samples.
        gen_x_all_torch, gen_y_all_torch: generated draws, indexed ``[sample, draw, :]``
            and reshaped internally to ``(n * M, d)``.
        x_torch, y_torch, z_torch: observed data; ``x_torch.shape[1]`` and
            ``y_torch.shape[1]`` are used as the feature sizes of the generated draws.
        boor_rv_type: bootstrap multiplier distribution, ``"rademacher"`` or
            ``"gaussian"`` (parameter name kept as-is for existing callers).
        sigma_z, sigma_x, sigma_y: kernel bandwidths dividing the distances.

    Returns:
        ``(stat, boottemp, u_mx, v_mx, w_mx)`` where ``stat`` is a float,
        ``boottemp`` is a float64 NumPy array of length ``boot_num`` and the three
        matrices are the centred x kernel, centred y kernel and z kernel.

    Raises:
        ValueError: if ``boor_rv_type`` is not one of the supported names.

    Relies on the notebook-level ``get_distance_matrix`` and ``device``.
    """
    # Fail early and clearly; previously an unknown name left the multipliers
    # undefined and surfaced as a NameError after all the kernel work was done.
    if boor_rv_type not in ("rademacher", "gaussian"):
        raise ValueError(
            "boor_rv_type must be 'rademacher' or 'gaussian', got " + repr(boor_rv_type)
        )

    d_y = y_torch.shape[1]
    d_x = x_torch.shape[1]

    # Kernel on the conditioning variable z.
    w_mx = torch.exp(-get_distance_matrix(z_torch, z_torch) / sigma_z)

    # Centred x kernel: observed/observed - generated/observed - its transpose
    # + generated/generated.
    u_mx_1 = torch.exp(-get_distance_matrix(x_torch, x_torch) / sigma_x)
    u_mx_2 = _mean_kernel_generated_vs_observed(gen_x_all_torch, x_torch, sigma_x, M)
    u_mx_3 = u_mx_2.T
    u_mx_4 = _mean_kernel_generated_vs_generated(gen_x_all_torch, sigma_x, n, M, d_x)
    u_mx = u_mx_1 - u_mx_2 - u_mx_3 + u_mx_4

    # Centred y kernel, built the same way.
    v_mx_1 = torch.exp(-get_distance_matrix(y_torch, y_torch) / sigma_y)
    v_mx_2 = _mean_kernel_generated_vs_observed(gen_y_all_torch, y_torch, sigma_y, M)
    v_mx_3 = v_mx_2.T
    v_mx_4 = _mean_kernel_generated_vs_generated(gen_y_all_torch, sigma_y, n, M, d_y)
    v_mx = v_mx_1 - v_mx_2 - v_mx_3 + v_mx_4

    # Element-wise product of the three kernels with the diagonal zeroed out.
    FF_mx = u_mx * v_mx * w_mx * (1 - torch.eye(n).to(device))

    stat = 1 / (n - 1) * torch.sum(FF_mx).item()

    # One column of multipliers per bootstrap replicate.
    if boor_rv_type == "rademacher":
        eboot = torch.sign(torch.randn(n, boot_num)).to(device)
    else:
        eboot = torch.randn(n, boot_num).to(device)
    eboot = eboot.to(FF_mx.dtype)

    # sum(FF * e e^T) equals e^T FF e, so all replicates are obtained with a single
    # matrix product instead of boot_num outer products in a Python loop.
    boot_stats = torch.sum(eboot * torch.matmul(FF_mx, eboot), dim=0) / (n - 1)
    boottemp = boot_stats.detach().cpu().numpy().astype(np.float64)
    return stat, boottemp, u_mx, v_mx, w_mx


In [None]:
# Keep independent copies of the real and generated tensors under *_true names.
# z_all is used as-is, so no statement is needed for it (a self-assignment is a no-op).
y_all_true = y_all.clone()
gen_y_all_true = gen_y_all.clone()
x_all_true = x_all.clone()
gen_x_all_true = gen_x_all.clone()

In [None]:
def get_folder_p_vals2_r(gen_x_all, gen_y_all, x_all, y_all, z_all, M, run_all=False, Total_num_p_val=100, boor_rv_type='rademacher', boot_num=10000):
    """Split the data into consecutive folds and compute one bootstrap p-value per fold.

    Args:
        gen_x_all, gen_y_all: generated draws, sliced along their first dimension.
        x_all, y_all, z_all: observed data, sliced along their first dimension.
        M: number of generated draws per sample (forwarded to get_p_value_stat_2).
        run_all: if True evaluate all ``Total_num_p_val`` folds, otherwise at most 5.
        Total_num_p_val: number of folds the rows of ``z_all`` are divided into.
        boor_rv_type: bootstrap multiplier type forwarded to get_p_value_stat_2.
        boot_num: bootstrap replicates per fold (defaults to 10000, the previously
            hard-coded value).

    Returns:
        NumPy array with one p-value per evaluated fold. Summary statistics
        (mean, median, quartiles) are also printed.

    Raises:
        ValueError: if ``Total_num_p_val`` is not positive or the folds would hold
            fewer than two samples.
    """
    # Bandwidths: median of the full-sample distance matrices.
    sigma_z = torch.median(get_distance_matrix(z_all, z_all)).item()
    sigma_x = torch.median(get_distance_matrix(x_all, x_all)).item()
    sigma_y = torch.median(get_distance_matrix(y_all, y_all)).item()

    if Total_num_p_val < 1:
        raise ValueError("Total_num_p_val must be a positive integer")

    # Fold length comes from the data actually passed in rather than from a
    # notebook-level global, so the function also works on subsets.
    n_length_input = z_all.shape[0] // Total_num_p_val
    if n_length_input < 2:
        raise ValueError("each fold needs at least 2 samples; reduce Total_num_p_val")

    # Quick mode never asks for more folds than exist.
    total_run = Total_num_p_val if run_all else min(5, Total_num_p_val)

    p_val_list = []
    for fold in tqdm(range(total_run)):
        start_index = n_length_input * fold
        end_index = start_index + n_length_input

        gen_x_all_in = gen_x_all[start_index:end_index,].to(device).detach()
        gen_y_all_in = gen_y_all[start_index:end_index,].to(device).detach()
        x_all_in = x_all[start_index:end_index,].to(device).detach()
        y_all_in = y_all[start_index:end_index,].to(device).detach()
        z_all_in = z_all[start_index:end_index,].to(device).detach()

        # Only the statistic and its bootstrap replicates are needed here.
        cur_stat, cur_boot_temp, _, _, _ = get_p_value_stat_2(
            boot_num=boot_num, M=M, n=n_length_input,
            gen_x_all_torch=gen_x_all_in, gen_y_all_torch=gen_y_all_in,
            x_torch=x_all_in, y_torch=y_all_in, z_torch=z_all_in,
            boor_rv_type=boor_rv_type, sigma_z=sigma_z, sigma_x=sigma_x, sigma_y=sigma_y)

        # Share of bootstrap replicates exceeding the observed statistic.
        p_val = np.mean( cur_boot_temp > cur_stat )
        if fold < 5:
            print("the ",start_index," has p value: ",p_val)

        p_val_list.append(p_val)

    print("the mean is ", np.mean(p_val_list), "the median is ", np.median(p_val_list))
    print("the 25% is ", np.percentile(p_val_list, 25), "the 75% is ", np.percentile(p_val_list, 75))
    return np.asarray(p_val_list)


In [None]:
# Point the working names back at the *_true tensors (plain rebinding, no copy is made).
x_all, gen_x_all = x_all_true, gen_x_all_true
y_all, gen_y_all = y_all_true, gen_y_all_true

In [None]:
# Full run: 50 folds, Rademacher bootstrap multipliers.
get_folder_p_vals2_r(
    gen_x_all,
    gen_y_all,
    x_all,
    y_all,
    z_all,
    M,
    run_all=True,
    Total_num_p_val=50,
    boor_rv_type='rademacher',
)