In [None]:
# Number of pooled features per image: passed as ``dl`` to avg_pooling_cov_net,
# used as the width of the conditioning input, and embedded in the checkpoint
# file names ('./AE<dim>_image.pth', './AE<dim>_label.pth').
latent_space_dim = 64

import torch
import torch.nn as nn
from torch.optim import SGD
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt

import torch.distributions as TD
from zmq import device
import torch.optim as optim
from datetime import datetime
import functools
from tqdm import tqdm

# Select the compute device: CUDA when a GPU is present and CUDA use is enabled,
# otherwise the CPU.  This rebinds ``device``, shadowing the (otherwise unused)
# ``from zmq import device`` import above.
enable_cuda = True
device = torch.device('cuda') if torch.cuda.is_available() and enable_cuda else torch.device('cpu')

import torch.nn as nn

class avg_pooling_cov_net(nn.Module):
    """Adaptive average pooling to a square grid, followed by flattening.

    Each input map is pooled down to a ``side x side`` grid, where
    ``side = sqrt(dl)``.  Everything after dimension 0 is then flattened, so
    an ``(N, H, W)`` input becomes an ``(N, dl)`` output.

    Args:
        dl: Number of pooled features per entry along dimension 0.  It must be
            a positive perfect square (e.g. 9, 16, 25, ..., 64, 81, 100).

    Raises:
        ValueError: If ``dl`` is not a positive perfect square.  Previously,
            unsupported values left ``layer1`` undefined, and the error only
            surfaced later inside ``forward``.
    """

    def __init__(self, dl = 100):
        super(avg_pooling_cov_net, self).__init__()
        import math

        side = math.isqrt(int(dl)) if dl >= 1 else 0
        if side < 1 or side * side != dl:
            raise ValueError(f"dl must be a positive perfect square, got {dl!r}")
        # Wrapped in nn.Sequential to keep the original module layout (layer1.0).
        self.layer1 = nn.Sequential(
            nn.AdaptiveAvgPool2d((side, side))
        )
        self.flatten = nn.Flatten()

    def forward(self, x):
        out = self.layer1(x)
        out = self.flatten(out)
        return out
class CTDataset_all(Dataset):
    """Dataset yielding (pooled features, one-hot label, flattened image) triples.

    The file at ``filepath`` is read with ``torch.load`` and must unpack into
    ``(images, labels)``.  Images are divided by 255.  ``avg_pooling_model``
    is put in eval mode and applied once, without gradients, to the whole
    scaled image tensor.  Labels are one-hot encoded over 10 classes and cast
    to float64.

    Attributes (indexed along dimension 0):
        x: output of ``avg_pooling_model`` on the scaled images.
        y: one-hot float64 labels.
        z: the scaled images, flattened from dimension 1 onward.
    """

    def __init__(self, filepath, avg_pooling_model):
        self.flatten = nn.Flatten()
        raw_images, raw_labels = torch.load(filepath)
        scaled = raw_images / 255.
        self.z = self.flatten(scaled)
        avg_pooling_model.eval()
        with torch.no_grad():
            self.x = avg_pooling_model(scaled)
        self.y = F.one_hot(raw_labels, num_classes=10).to(float)

    def __len__(self):
        return self.x.size(0)

    def __getitem__(self, ix):
        return tuple(tensor[ix] for tensor in (self.x, self.y, self.z))

class CTDataset(Dataset):
    """Dataset yielding (pooled features, one-hot label) pairs.

    The file at ``filepath`` is read with ``torch.load`` and must unpack into
    ``(images, labels)``.  Images are divided by 255 and passed once through
    ``avg_pooling_model``, which is put in eval mode with gradients disabled.
    Labels are one-hot encoded over 10 classes and cast to float64.
    """

    def __init__(self, filepath, avg_pooling_model):
        images, labels = torch.load(filepath)
        avg_pooling_model.eval()
        with torch.no_grad():
            self.x = avg_pooling_model(images / 255.)
        self.y = F.one_hot(labels, num_classes=10).to(float)

    def __len__(self):
        return self.x.size(0)

    def __getitem__(self, ix):
        features, target = self.x[ix], self.y[ix]
        return features, target

class Reshape(nn.Module):
    """Layer that views its input with the shape given at construction.

    Example: ``Reshape(-1, 64, 7, 7)`` turns a ``(batch, 3136)`` tensor into
    ``(batch, 64, 7, 7)``.  It uses ``Tensor.view``, so the input must be
    view-compatible (e.g. contiguous).
    """

    def __init__(self, *args):
        super().__init__()
        self.shape = args  # target shape, stored exactly as passed

    def forward(self, x):
        target_shape = self.shape
        return x.view(target_shape)


class Trim(nn.Module):
    """Crop dims 2 and 3 of an ``(N, C, H, W)`` tensor to at most 28 x 28.

    Keeps the top-left ``[:28, :28]`` window of every map.  The decoder uses
    this to turn its 29 x 29 output into 28 x 28.  Constructor arguments are
    accepted for call-site compatibility and ignored.
    """

    def __init__(self, *args):
        super().__init__()

    def forward(self, x):
        rows = cols = 28
        return x[:, :, :rows, :cols]


class Generator_image(torch.nn.Module):
    """Convolutional decoder mapping ``[features, noise]`` to a flattened 28x28 image.

    For a 2-D input of width ``input_dimension + noise_dimension``:
      * a linear layer maps it to 3136 = 64*7*7 units;
      * the result is reshaped to ``(-1, 64, 7, 7)``;
      * four 3x3 transposed convolutions follow, with LeakyReLU(0.01) between
        them, growing the spatial size 7 -> 7 -> 13 -> 27 -> 29 and ending
        with one channel;
      * the 29x29 map is trimmed to 28x28 and squashed by a sigmoid.
    ``forward`` returns the image flattened to ``(batch, 784)``.

    Args:
        input_dimension: Width of the conditioning features.
        noise_dimension: Width of the noise concatenated to them.

    Note: the layer order inside ``decoder`` determines the state-dict keys
    (``decoder.0.weight``, ``decoder.2.weight``, ...), so it must not change.
    """

    def __init__(self, input_dimension, noise_dimension):
        super(Generator_image, self).__init__()
        self.flatten = nn.Flatten()
        layers = [
            torch.nn.Linear(input_dimension + noise_dimension, 3136),
            Reshape(-1, 64, 7, 7),
        ]
        # (in_channels, out_channels, stride, padding) of each 3x3 transposed conv.
        conv_specs = [(64, 64, 1, 1), (64, 64, 2, 1), (64, 32, 2, 0), (32, 1, 1, 0)]
        last = len(conv_specs) - 1
        for idx, (c_in, c_out, step, pad) in enumerate(conv_specs):
            layers.append(nn.ConvTranspose2d(c_in, c_out, stride=(step, step), kernel_size=(3, 3), padding=pad))
            if idx != last:
                layers.append(nn.LeakyReLU(0.01))
        layers.append(Trim())  # 1x29x29 -> 1x28x28
        layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        image = self.decoder(x)  # (batch, 1, 28, 28)
        return self.flatten(image)  # (batch, 784)

class Generator(torch.nn.Module):
    """Fully connected conditional generator that outputs a probability vector.

    The input has width ``input_dimension + noise_dimension`` and goes through
    two hidden layers, ``fc1`` and ``fc2``, of width ``hidden_layer_size``.
    Each hidden layer is followed by:
      * BatchNorm, when ``BN_type`` is truthy;
      * a LeakyReLU with slope ``ReLU_coef``;
      * dropout with probability ``drop_out_p``.
    ``fc_last`` then maps to ``output_dimension`` logits, and a softmax over
    dim 1 turns them into probabilities of shape ``(batch, output_dimension)``.
    When ``drop_input`` is True, dropout is also applied to the raw input.

    ``fc3``, ``BN3``, ``drop_out3`` and ``sigmoid`` are built but not used in
    ``forward``.  ``fc3`` (and ``BN3``'s running statistics) still appear in
    the state dict, so they must stay for checkpoint compatibility.
    """

    def __init__(self, input_dimension, output_dimension, noise_dimension, hidden_layer_size, BN_type, ReLU_coef, drop_out_p,
                 drop_input = False):
        super(Generator, self).__init__()
        self.BN_type = BN_type
        self.ReLU_coef = ReLU_coef
        self.fc1 = torch.nn.Linear(input_dimension + noise_dimension, hidden_layer_size, bias=True)
        if BN_type:
            # NOTE(review): the positional 0.8 is BatchNorm1d's ``eps`` argument,
            # not ``momentum``.  It is kept as-is because changing it would change
            # the behaviour of already-trained weights.
            for bn_name in ("BN1", "BN2", "BN3"):
                setattr(self, bn_name, torch.nn.BatchNorm1d(hidden_layer_size, 0.8, affine=False))
        self.leakyReLU1 = torch.nn.LeakyReLU(ReLU_coef)
        self.fc2 = torch.nn.Linear(hidden_layer_size, hidden_layer_size, bias=True)
        self.fc3 = torch.nn.Linear(hidden_layer_size, hidden_layer_size, bias=True)  # unused in forward
        self.fc_last = torch.nn.Linear(hidden_layer_size, output_dimension, bias=True)
        self.sigmoid = torch.nn.Sigmoid()  # unused in forward
        for k in range(4):
            setattr(self, f"drop_out{k}", torch.nn.Dropout(p=drop_out_p))
        self.drop_input = drop_input
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        if self.drop_input:
            x = self.drop_out0(x)
        hidden_stages = (
            (self.fc1, "BN1", self.drop_out1),
            (self.fc2, "BN2", self.drop_out2),
        )
        for linear, bn_name, dropout in hidden_stages:
            h = linear(x)
            if self.BN_type:
                h = getattr(self, bn_name)(h)
            x = dropout(self.leakyReLU1(h))
        return self.softmax(self.fc_last(x))


##### Auxilliary functions #####

def sample_noise(sample_size, noise_dimension, noise_type, input_var):
    """
    Generate a PyTorch Tensor of random noise from the specified reference distribution.

    Input:
    - sample_size: the sample size of noise to generate.
    - noise_dimension: the dimension of noise to generate.
    - noise_type: "normal", "unif" or "Cauchy", giving the reference distribution.
    - input_var: per-coordinate variance of the "normal" noise (covariance
      ``input_var * I``).  It is ignored for "unif" and "Cauchy".

    Output:
    - A PyTorch Tensor of shape (sample_size, noise_dimension).
      * "normal" noise is created on the module-level ``device``.
      * "unif" noise is uniform on [0, 1) and is created on the CPU.
      * "Cauchy" noise has location 0 and scale 1 and is created on the CPU.

    Raises:
    - ValueError: if noise_type is not one of the supported names.  This
      previously surfaced as an UnboundLocalError.
    """
    if noise_type == "normal":
        noise_generator = TD.MultivariateNormal(
            torch.zeros(noise_dimension).to(device),
            input_var * torch.eye(noise_dimension).to(device))
        Z = noise_generator.sample((sample_size,))
    elif noise_type == "unif":
        Z = torch.rand(sample_size, noise_dimension)
    elif noise_type == "Cauchy":
        # Sampling a batch-shape-(1,) distribution yields (size, dim, 1); drop the last axis.
        Z = TD.Cauchy(torch.tensor([0.0]), torch.tensor([1.0])).sample(
            (sample_size, noise_dimension)).squeeze(2)
    else:
        raise ValueError(
            f"noise_type must be 'normal', 'unif' or 'Cauchy', got {noise_type!r}")

    return Z

def _mean_l1_kernel(points, samples, sigma):
    """Average Laplace (L1) kernel between each point and each group of samples.

    Args:
        points: (P, D) tensor.
        samples: (G, S, D) tensor holding S samples for each of G groups.
        sigma: kernel bandwidth.

    Returns:
        (P, G) tensor K with K[a, j] = mean_s exp(-||points[a] - samples[j, s]||_1 / sigma).
    """
    groups, draws, dim = samples.shape
    dist = torch.cdist(points.contiguous(), samples.reshape(groups * draws, dim).contiguous(), p=1)
    return torch.exp(-dist / sigma).reshape(points.shape[0], groups, draws).mean(dim=2)


def _centered_l1_kernel(obs, gen, sigma):
    """Laplace kernel matrix of observations, centred with generated samples.

    With k(s, t) = exp(-||s - t||_1 / sigma), obs of shape (n, D) and gen of
    shape (n, S, D), returns the (n, n) matrix

        k(a_i, a_j) - mean_s k(a_i, g_js) - mean_s k(g_is, a_j) + mean_{r,s} k(g_ir, g_js).
    """
    n, draws, _ = gen.shape
    k_obs = torch.exp(-torch.cdist(obs.contiguous(), obs.contiguous(), p=1) / sigma)
    k_obs_gen = _mean_l1_kernel(obs, gen, sigma)
    # Accumulate over the draws of the "left" sample, one draw at a time.
    # This keeps peak memory at O(n * n * S) instead of O((n * S) ** 2).
    k_gen_gen = torch.zeros(n, n, dtype=k_obs.dtype, device=k_obs.device)
    for r in range(draws):
        k_gen_gen = k_gen_gen + _mean_l1_kernel(gen[:, r, :], gen, sigma)
    k_gen_gen = 1 / draws * k_gen_gen
    return k_obs - k_obs_gen - k_obs_gen.T + k_gen_gen


def get_p_value_stat_1(boot_num, M, n, gen_x_all_torch, gen_y_all_torch, x_torch, y_torch, z_torch, sigma_w, sigma_u=1,
            sigma_v=1, boor_rv_type="gaussian", dy_g=10, dx_g = 28*28):
    """Kernel test statistic and its multiplier-bootstrap replicates.

    For n observed points (x_i, y_i, z_i), each with M generated samples
    gen_x_all_torch[i] / gen_y_all_torch[i], define:
        W[i, j] = exp(-||z_i - z_j||_1 / sigma_w)
        U       = L1-kernel matrix on y (bandwidth sigma_u), centred with the
                  generated y samples
        V       = the same on x (bandwidth sigma_v), centred with the
                  generated x samples
    With F = U * V * W and its diagonal set to zero:
        stat   = 1/(n-1) * sum_{i != j} F[i, j]
        boot_b = 1/(n-1) * e_b^T F e_b,
    where the e_b are Gaussian or Rademacher multiplier vectors.

    Args:
        boot_num: number of bootstrap replicates.
        M: number of generated samples per point.
        n: number of points; the first dimension of every input tensor.
        gen_x_all_torch: generated x samples, reshapeable to (n, M, dx_g).
        gen_y_all_torch: generated y samples, reshapeable to (n, M, dy_g).
        x_torch, y_torch, z_torch: observed values, one row per point.
        sigma_w, sigma_u, sigma_v: bandwidths of the z, y and x kernels.
        boor_rv_type: "gaussian" or "rademacher" multipliers.
        dy_g, dx_g: per-sample widths of the generated y and x.

    Returns:
        (stat, boottemp): a Python float and a float64 NumPy array of length
        boot_num.  All tensor work happens on the device of the inputs.

    Raises:
        ValueError: if boor_rv_type is not "gaussian" or "rademacher".  This
        is now checked before any computation; previously it surfaced as a
        NameError at the very end.
    """
    if boor_rv_type not in ("rademacher", "gaussian"):
        raise ValueError(f"boor_rv_type must be 'rademacher' or 'gaussian', got {boor_rv_type!r}")

    gen_x = gen_x_all_torch.reshape(n, M, dx_g)
    gen_y = gen_y_all_torch.reshape(n, M, dy_g)
    x_obs = x_torch.reshape(x_torch.shape[0], -1)
    y_obs = y_torch.reshape(y_torch.shape[0], -1)
    z_obs = z_torch.reshape(z_torch.shape[0], -1)

    w_mx = torch.exp(-torch.cdist(z_obs.contiguous(), z_obs.contiguous(), p=1) / sigma_w)
    u_mx = _centered_l1_kernel(y_obs, gen_y, sigma_u)
    v_mx = _centered_l1_kernel(x_obs, gen_x, sigma_v)

    # Zero the diagonal so only i != j pairs contribute.
    FF_mx = u_mx * v_mx * w_mx * (1 - torch.eye(n, device=w_mx.device))

    stat = 1 / (n - 1) * torch.sum(FF_mx).item()

    # NOTE: this re-seeds the *global* torch RNG, so every call draws the same
    # multipliers.  Kept from the original for reproducibility; it also affects
    # later random draws made by the caller.
    torch.manual_seed(42)
    if boor_rv_type == "rademacher":
        eboot = torch.sign(torch.randn(n, boot_num)).to(FF_mx.device)
    else:
        eboot = torch.randn(n, boot_num).to(FF_mx.device)
    # Column b holds e_b^T F e_b, i.e. sum(F * e_b e_b^T), for all b at once.
    quad_forms = ((FF_mx @ eboot) * eboot).sum(dim=0)
    boottemp = 1 / (n - 1) * quad_forms.double().cpu().numpy()
    return stat, boottemp

noise_dimension_image = 1
noise_dimension_label = 1
input_noise_type = "normal"
avg_pooling_model = avg_pooling_cov_net(dl = latent_space_dim).to(device)
torch.manual_seed(42)
train_ds = CTDataset('./training.pt', avg_pooling_model)

# Seeded so the AE / conditional-generator split is reproducible.
# NOTE(review): [30000, 30000] assumes training.pt holds exactly 60000 samples.
torch.manual_seed(42)
train_AE_set, train_cond_gen_set = torch.utils.data.random_split(train_ds, [30000, 30000])
train_ds = train_cond_gen_set
DataLoader_train = torch.utils.data.DataLoader(train_ds, batch_size=128, shuffle=True, drop_last= False)

# Median heuristic for sigma_w (the kernel on pooled features): the median of
# all pairwise L1 distances among the first 10000 points of the split, with
# the zero self-distances included.  A single torch.cdist call replaces the
# per-row Python loop of vector norms.
xs, ys = train_ds[0:10000]
pairwise_distance_x = torch.cdist(xs, xs, p=1)

sigma_w_train = torch.median(pairwise_distance_x).item()
print(f"[sigma_w_train {sigma_w_train}]")

# Remaining kernel bandwidths for the test statistic (sigma_w_train above is for z).
sigma_v_train = 130.29019165039062 # for x (flattened 28*28 images); hard-coded, derivation not shown here
sigma_u_train = 2.0 # for y: the L1 distance between two different one-hot labels is 2

torch.manual_seed(42)

# CTDataset_all yields (pooled features, one-hot label, flattened image).
test_ds = CTDataset_all('./test.pt', avg_pooling_model)

# One test point per batch, in shuffled order (the order depends on the seed above).
DataLoader_test = torch.utils.data.DataLoader(test_ds, batch_size=1, shuffle=True, drop_last= False, )

# map_location=device: checkpoints saved from a GPU session still load when
# `device` fell back to the CPU (and vice versa).
G_image = Generator_image(latent_space_dim,  noise_dimension_image).to(device)
G_image.load_state_dict(torch.load('./AE'+str(latent_space_dim)+'_image.pth', map_location=device))

G_label = Generator(input_dimension = latent_space_dim, output_dimension = 10, noise_dimension = noise_dimension_label,
           hidden_layer_size = 512, BN_type = True, ReLU_coef = 0.5, drop_out_p= 0.2).to(device)
G_label.load_state_dict(torch.load('./AE'+str(latent_space_dim)+'_label.pth', map_location=device))

# M: generator draws per test point.  test_size: number of test points.
# Total_num_p_val: number of disjoint chunks (one p-value each) used later.
M = 100
test_size = 10000
Total_num_p_val = 40

# Result buffers, one row per test point (torch.zeros with no device argument):
#   gen_x_all[i] / gen_y_all[i] : M generated images (28*28) / label vectors (10)
#   z_all[i]                    : pooled features that condition both generators
#   x_all[i] / y_all[i]         : the observed flattened image / one-hot label
gen_x_all = torch.zeros(test_size, M, 28*28)
gen_y_all = torch.zeros(test_size, M, 10)
z_all = torch.zeros(test_size, latent_space_dim)
x_all = torch.zeros(test_size, 28*28)
y_all = torch.zeros(test_size, 10)


G_label = G_label.eval()
G_image = G_image.eval()
# CTDataset_all.__getitem__ returns (pooled features, label, flattened image),
# so z_test is the pooled feature vector and x_test the flattened raw image.
# NOTE(review): assumes the loader yields at most test_size items; more would
# overflow the buffers above with an IndexError.
for i, (z_test, y_test, x_test) in tqdm(enumerate(DataLoader_test)):
    # Repeat the single conditioning row M times, so each generator produces
    # M draws for this test point in one forward pass.
    Z_test_repeat = z_test.repeat(M,1).to(device).detach()
    Noise_fake = sample_noise(Z_test_repeat.shape[0], noise_dimension_label, input_noise_type, input_var = 1.0/3.0).to(device)
    with torch.no_grad():
        gen_y = G_label(torch.cat((Z_test_repeat,Noise_fake),dim=1)).to(device).detach()

    # Fresh, independent noise for the image generator.
    Noise_fake = sample_noise(Z_test_repeat.shape[0], noise_dimension_image, input_noise_type, input_var = 1.0/3.0).to(device)
    with torch.no_grad():
        gen_x = G_image(torch.cat((Z_test_repeat,Noise_fake),dim=1)).to(device).detach()

    # Add a leading axis of size 1; the row assignments below drop it again.
    gen_x = gen_x.reshape(1, M, 28*28).detach().to(device)
    gen_y = gen_y.reshape(1, M, 10).detach().to(device)

    x_all[i,:] = x_test
    y_all[i,:] = y_test
    z_all[i,:] = z_test

    gen_x_all[i,:] = gen_x
    gen_y_all[i,:] = gen_y

100%|██████████| 10000/10000 [00:34<00:00, 289.86it/s]


[sigma_w_train 8.039461135864258]


10000it [00:38, 262.91it/s]


In [None]:
# Split the test points into Total_num_p_val consecutive, disjoint chunks and
# compute one bootstrap p-value per chunk.  Any remainder
# (test_size % Total_num_p_val points at the end) is not used.
n_length_input = int(test_size / Total_num_p_val)
p_val_list = []

# Settings shared by every chunk.
sigma_w = sigma_w_train
sigma_u = sigma_u_train
sigma_v = sigma_v_train
boot_num = 1000
boor_rv_type = 'gaussian'
n_length = n_length_input

for i in range(Total_num_p_val):
    start_index = i * n_length_input
    end_index = start_index + n_length
    window = slice(start_index, end_index)

    gen_x_all_in, gen_y_all_in, x_all_in, y_all_in, z_all_in = (
        buf[window].to(device).detach()
        for buf in (gen_x_all, gen_y_all, x_all, y_all, z_all)
    )

    cur_stat, cur_boot_temp = get_p_value_stat_1(
        boot_num, M, n_length, gen_x_all_in, gen_y_all_in,
        x_all_in, y_all_in, z_all_in, sigma_w, sigma_u, sigma_v,
        boor_rv_type,
    )
    # Bootstrap p-value: share of replicates strictly above the observed statistic.
    p_val = np.mean(cur_boot_temp > cur_stat)
    print("the ",start_index," has p value: ",p_val)
    p_val_list.append(p_val)

100%|██████████| 99/99 [12:57<00:00,  7.86s/it]


the  0  has p value:  0.032


100%|██████████| 99/99 [12:59<00:00,  7.88s/it]


the  250  has p value:  0.074


100%|██████████| 99/99 [13:02<00:00,  7.90s/it]


the  500  has p value:  0.597


100%|██████████| 99/99 [13:05<00:00,  7.94s/it]


the  750  has p value:  0.367


100%|██████████| 99/99 [13:04<00:00,  7.93s/it]


the  1000  has p value:  0.018


100%|██████████| 99/99 [13:10<00:00,  7.98s/it]


the  1250  has p value:  0.027


100%|██████████| 99/99 [13:09<00:00,  7.98s/it]


the  1500  has p value:  0.225


100%|██████████| 99/99 [13:14<00:00,  8.03s/it]


the  1750  has p value:  0.005


100%|██████████| 99/99 [13:16<00:00,  8.05s/it]


the  2000  has p value:  0.016


100%|██████████| 99/99 [13:08<00:00,  7.96s/it]


the  2250  has p value:  0.051


100%|██████████| 99/99 [13:06<00:00,  7.94s/it]


the  2500  has p value:  0.864


100%|██████████| 99/99 [13:01<00:00,  7.89s/it]


the  2750  has p value:  0.279


100%|██████████| 99/99 [12:58<00:00,  7.87s/it]


the  3000  has p value:  0.057


100%|██████████| 99/99 [12:57<00:00,  7.86s/it]


the  3250  has p value:  0.304


100%|██████████| 99/99 [13:34<00:00,  8.23s/it]


the  3500  has p value:  0.169


100%|██████████| 99/99 [13:18<00:00,  8.07s/it]


the  3750  has p value:  0.068


100%|██████████| 99/99 [13:14<00:00,  8.03s/it]


the  4000  has p value:  0.008


100%|██████████| 99/99 [13:15<00:00,  8.04s/it]


the  4250  has p value:  0.053


100%|██████████| 99/99 [13:13<00:00,  8.02s/it]


the  4500  has p value:  0.015


100%|██████████| 99/99 [13:29<00:00,  8.18s/it]


the  4750  has p value:  0.334


100%|██████████| 99/99 [13:21<00:00,  8.10s/it]


the  5000  has p value:  0.193


100%|██████████| 99/99 [13:14<00:00,  8.02s/it]


the  5250  has p value:  0.098


100%|██████████| 99/99 [13:20<00:00,  8.08s/it]


the  5500  has p value:  0.037


100%|██████████| 99/99 [13:21<00:00,  8.09s/it]


the  5750  has p value:  0.058


100%|██████████| 99/99 [13:07<00:00,  7.95s/it]


the  6000  has p value:  0.861


100%|██████████| 99/99 [12:59<00:00,  7.88s/it]


the  6250  has p value:  0.08


100%|██████████| 99/99 [12:57<00:00,  7.86s/it]


the  6500  has p value:  0.037


100%|██████████| 99/99 [12:57<00:00,  7.86s/it]


the  6750  has p value:  0.078


100%|██████████| 99/99 [12:58<00:00,  7.86s/it]


the  7000  has p value:  0.843


100%|██████████| 99/99 [12:59<00:00,  7.87s/it]


the  7250  has p value:  0.122


100%|██████████| 99/99 [12:56<00:00,  7.85s/it]


the  7500  has p value:  0.157


100%|██████████| 99/99 [13:13<00:00,  8.02s/it]


the  7750  has p value:  0.917


100%|██████████| 99/99 [13:12<00:00,  8.01s/it]


the  8000  has p value:  0.097


100%|██████████| 99/99 [13:14<00:00,  8.03s/it]


the  8250  has p value:  0.03


100%|██████████| 99/99 [13:23<00:00,  8.11s/it]


the  8500  has p value:  0.857


100%|██████████| 99/99 [13:11<00:00,  8.00s/it]


the  8750  has p value:  0.233


100%|██████████| 99/99 [13:19<00:00,  8.07s/it]


the  9000  has p value:  0.23


100%|██████████| 99/99 [13:11<00:00,  8.00s/it]


the  9250  has p value:  0.156


100%|██████████| 99/99 [13:01<00:00,  7.90s/it]


the  9500  has p value:  0.038


100%|██████████| 99/99 [13:08<00:00,  7.96s/it]

the  9750  has p value:  0.026





In [None]:
# Show the per-chunk bootstrap p-values computed above.
p_val_list

[0.032,
 0.074,
 0.597,
 0.367,
 0.018,
 0.027,
 0.225,
 0.005,
 0.016,
 0.051,
 0.864,
 0.279,
 0.057,
 0.304,
 0.169,
 0.068,
 0.008,
 0.053,
 0.015,
 0.334,
 0.193,
 0.098,
 0.037,
 0.058,
 0.861,
 0.08,
 0.037,
 0.078,
 0.843,
 0.122,
 0.157,
 0.917,
 0.097,
 0.03,
 0.857,
 0.233,
 0.23,
 0.156,
 0.038,
 0.026]

In [None]:
# First quartile, median and third quartile of the per-chunk p-values.
np.quantile(p_val_list, 0.25), np.median(p_val_list), np.quantile(p_val_list, 0.75)

(0.037, 0.0885, 0.24450000000000002)

In [None]:
# Fraction of chunks whose p-value falls below the 0.05 level.
np.mean(np.asarray(p_val_list) < 0.05)

0.3