In [1]:
import numpy as np
import torch
from sklearn.utils import check_random_state
from utils import *
from matplotlib import pyplot as plt
import torch.nn as nn
dtype = torch.float
device = torch.device("cuda:0")

class ConvNet_CIFAR10(nn.Module):
    def __init__(self):
        super(ConvNet_CIFAR10, self).__init__()
        def discriminator_block(in_filters, out_filters, bn=True):
            block =([nn.Conv2d(in_filters, out_filters, 3, 2, 1), 
                     nn.LeakyReLU(0.2, inplace=True),  
                     nn.Dropout2d(0)])
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block
        self.model = nn.Sequential(
            nn.Unflatten(1,(3,32,32)),
            *discriminator_block(3, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        ds_size = 2
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 300))
    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        feature = self.adv_layer(out)
        return feature
    
def crit(mmd_val, mmd_var, liuetal=True, Sharpe=False):
    if liuetal:
        mmd_std_temp = torch.sqrt(mmd_var+10**(-8)) #this is std
        return -1 * torch.div(mmd_val, mmd_std_temp)
    elif Sharpe:
        return mmd_val - 2.0 * mmd_var

def mmdGT(X, Y, model_u, n, sigma, sigma0, ep):
    S = torch.cat((X, Y), dim=0)
    Fea = model_u(S)
    n = X.shape[0]
    return MMD_General(Fea, n, S, sigma, sigma0, ep, use1sample=True)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_diffusion_cifar_32():
    diffusion = np.load("../../Diffusion/ddpm_generated_images2.npy").transpose(0,3,1,2)
    cifar10 = np.load('../../data/cifar_data.npy')
    dataset_P = diffusion.reshape(diffusion.shape[0], -1)
    dataset_Q = cifar10.reshape(cifar10.shape[0], -1)
    return dataset_P, dataset_Q[:10000, :], dataset_Q[10000:, :]

DP, DQ_1, DQ_2 = load_diffusion_cifar_32()
mix_rate=2 #For each DP, match with mix_rate*DQ data points

test_DP1=np.concatenate((DP[:2000, :], DQ_1[:4000, :]), axis=0)
test_DQ1=DQ_1[4000: 10000, :]

train_DP1=np.concatenate((DP[2000:7000, :], DQ_2[:10000, :]), axis=0)
train_DQ1=DQ_2[10000: 25000, :]

print(train_DP1.shape, train_DQ1.shape, test_DP1.shape, test_DQ1.shape)
#generate a random shuffle over train_DP1, print the first item
train_DP1 = train_DP1[np.random.choice(train_DP1.shape[0], train_DP1.shape[0], replace=False), :]
train_DQ1 = train_DQ1[np.random.choice(train_DQ1.shape[0], train_DQ1.shape[0], replace=False), :]
test_DP1 = test_DP1[np.random.choice(test_DP1.shape[0], test_DP1.shape[0], replace=False), :]
test_DQ1 = test_DQ1[np.random.choice(test_DQ1.shape[0], test_DQ1.shape[0], replace=False), :]

DP1_t = MatConvert(train_DP1, device, dtype)
DQ1_t = MatConvert(train_DQ1, device, dtype)
DP2_t = MatConvert(test_DP1, device, dtype)
DQ2_t = MatConvert(test_DQ1, device, dtype)

def gen_fun1(n): #n at most 15000
    X = train_DP1[np.random.choice(train_DP1.shape[0], n, replace=False), :]
    Y = train_DQ1[np.random.choice(train_DQ1.shape[0], n, replace=False), :]
    return X, Y
def gen_fun2(n): #n at most 6000
    X = test_DP1[np.random.choice(test_DP1.shape[0], n, replace=False), :]
    Y = test_DQ1[np.random.choice(test_DQ1.shape[0], n, replace=False), :]
    return X, Y

def gen_fun1_t(n):
    X = DP1_t[np.random.choice(DP1_t.shape[0], n, replace=False), :]
    Y = DQ1_t[np.random.choice(DQ1_t.shape[0], n, replace=False), :]
    return X, Y

def gen_fun2_t(n):
    X = DP2_t[np.random.choice(DP2_t.shape[0], n, replace=False), :]
    Y = DQ2_t[np.random.choice(DQ2_t.shape[0], n, replace=False), :]
    return X, Y

(15000, 3072) (15000, 3072) (6000, 3072) (6000, 3072)


In [3]:
def train_d(n, learning_rate=5e-4, N_epoch=50, print_every=20, batch_size=32):  
    batches=n//batch_size
    assert n%batch_size==0
    print("##### Starting N_epoch=%d epochs per data trial #####"%(N_epoch))
    if True:
        X, Y = gen_fun1(n)
        total_S=[(X[i*batch_size:i*batch_size+batch_size], 
                    Y[i*batch_size:i*batch_size+batch_size]) for i in range(batches)]
        total_S=[MatConvert(np.concatenate((X, Y), axis=0), device, dtype) for (X, Y) in total_S]
        model_u = ConvNet_CIFAR10().cuda()
        epsilonOPT = MatConvert(np.array([-1.0]), device, dtype)
        epsilonOPT.requires_grad = True
        sigmaOPT = MatConvert(np.array([10000.0]), device, dtype)
        sigmaOPT.requires_grad = True
        sigma0OPT = MatConvert(np.array([0.1]), device, dtype)
        sigma0OPT.requires_grad = True
        cst=MatConvert(np.ones((1,)), device, dtype) # set to 1 to meet liu etal objective
        optimizer_u = torch.optim.Adam(list(model_u.parameters())+[epsilonOPT]+[sigmaOPT]+[sigma0OPT], lr=learning_rate)
        for t in range(N_epoch):
            for ind in range(batches):
                ep = torch.exp(epsilonOPT)/(1+torch.exp(epsilonOPT))
                sigma = sigmaOPT ** 2
                sigma0_u = sigma0OPT ** 2
                S=total_S[ind]
                modelu_output = model_u(S) 
                TEMP = MMDu(modelu_output, batch_size, S, sigma, sigma0_u, ep, cst)
                mmd_val = TEMP[0]
                mmd_var = TEMP[1]
                STAT_u = crit(mmd_val, mmd_var) 
                optimizer_u.zero_grad()
                STAT_u.backward(retain_graph=True)
                optimizer_u.step()
        return model_u, torch.exp(epsilonOPT)/(1+torch.exp(epsilonOPT)), sigmaOPT ** 2, sigma0OPT ** 2, torch.tensor(X).to(device, dtype), torch.tensor(Y).to(device, dtype)

In [5]:
n=1920
N=300
m_list=[64, 96, 128, 192, 256, 384]
model_u, ep, sigma, sigma0, X_t, Y_t=train_d(n, learning_rate=5e-4, N_epoch=80, print_every=20, batch_size=32)
with torch.no_grad():
        H_u = np.zeros(N) 
        H_v = np.zeros(N)
        R_u = np.zeros(N)
        R_v = np.zeros(N)
        P_u = np.zeros(N)
        P_v = np.zeros(N)
        print("Under this trained kernel, we run N = %d times LFI: "%N)
        for i in range(len(m_list)):
            print("start testing m = %d"%m_list[i])
            m = m_list[i]
            for k in range(N):     
                stat=[]
                for j in range(100):
                        Z_temp, _ = gen_fun2_t(m) ###DEMO PURPOSES ONLY: check Phase 2, Algorithm 1,  in our paper
                        mmd_XZ = mmdGT(X_t, Z_temp, model_u, n, sigma, sigma0, ep)[0] 
                        mmd_YZ = mmdGT(Y_t, Z_temp, model_u, n, sigma, sigma0, ep)[0]
                        stat.append(float(mmd_XZ - mmd_YZ))
                stat = np.sort(stat)
                thres = stat[94]
                Z1, Z2 = gen_fun2_t(m)
                mmd_XZ = mmdGT(X_t, Z1, model_u, n, sigma, sigma0, ep)[0] 
                mmd_YZ = mmdGT(Y_t, Z1, model_u, n, sigma, sigma0, ep)[0] 
                H_u[k] = mmd_XZ - mmd_YZ < 0.0
                R_u[k] = mmd_XZ - mmd_YZ < thres
                P_u[k] = np.searchsorted(stat, float(mmd_XZ - mmd_YZ), side="right")/100.0
                mmd_XZ = mmdGT(X_t, Z2, model_u, n, sigma, sigma0, ep)[0] 
                mmd_YZ = mmdGT(Y_t, Z2, model_u, n, sigma, sigma0, ep)[0] 
                H_v[k] = mmd_XZ - mmd_YZ > 0.0
                R_v[k] = mmd_XZ - mmd_YZ > thres
                P_v[k] = np.searchsorted(stat, float(mmd_XZ - mmd_YZ), side="right")/100.0
            print("n, m=",str(n)+str('  ')+str(m),"--- P(max|Z~X): ", H_u.mean())
            print("n, m=",str(n)+str('  ')+str(m),"--- P(max|Z~Y): ", H_v.mean())
            print("n, m=",str(n)+str('  ')+str(m),"--- P(95|Z~X): ", R_u.mean())
            print("n, m=",str(n)+str('  ')+str(m),"--- P(95|Z~Y): ", R_v.mean())
            print("n, m=",str(n)+str('  ')+str(m),"--- P(p|Z~X): ", P_u.mean())
            print("n, m=",str(n)+str('  ')+str(m),"--- P(p|Z~Y): ", P_v.mean())

##### Starting N_epoch=80 epochs per data trial #####
Under this trained kernel, we run N = 300 times LFI: 
start testing m = 64
n, m= 1920  64 --- P(max|Z~X):  0.7266666666666667
n, m= 1920  64 --- P(max|Z~Y):  0.7266666666666667
n, m= 1920  64 --- P(95|Z~X):  0.9566666666666667
n, m= 1920  64 --- P(95|Z~Y):  0.35
n, m= 1920  64 --- P(p|Z~X):  0.4851666666666667
n, m= 1920  64 --- P(p|Z~Y):  0.7778666666666667
start testing m = 96
n, m= 1920  96 --- P(max|Z~X):  0.7333333333333333
n, m= 1920  96 --- P(max|Z~Y):  0.8133333333333334
n, m= 1920  96 --- P(95|Z~X):  0.9266666666666666
n, m= 1920  96 --- P(95|Z~Y):  0.43666666666666665
n, m= 1920  96 --- P(p|Z~X):  0.4919333333333333
n, m= 1920  96 --- P(p|Z~Y):  0.8507666666666668
start testing m = 128
n, m= 1920  128 --- P(max|Z~X):  0.71
n, m= 1920  128 --- P(max|Z~Y):  0.8666666666666667
n, m= 1920  128 --- P(95|Z~X):  0.9333333333333333
n, m= 1920  128 --- P(95|Z~Y):  0.52
n, m= 1920  128 --- P(p|Z~X):  0.5135666666666666
n, m= 1920  1

In [6]:
from scipy.stats import binom
def find_percentile(m, p, percentile=0.95):
    return binom.ppf(percentile, m, p)/m
def find_p_value(m, p, p_obs):
    return binom.cdf(p_obs*m, m, p)
def EPV(m, p_1, p_2):
    result=0
    for _ in range(50000):
        #p_obs is random observation Binomial(m, p_2)/m
        p_obs=np.random.binomial(m, p_2)/m
        result+=binom.cdf(p_obs*m, m, p_1)
    return result/50000
def MMD_LFI(Fea, n, Fea_org, sigma, sigma0=0.1, epsilon=10 ** (-10), cst = 1.0, is_smooth=True, one_sample=False):
    X = Fea[0:n, :] # fetch the sample 1 (features of deep networks)
    Y = Fea[n:2*n, :] # fetch the sample 2 (features of deep networks)
    Z = Fea[2*n:, :] # fetch the sample 3 (features of deep networks)
    X_org = Fea_org[0:n, :] # fetch the original sample 1
    Y_org = Fea_org[n:2*n, :] # fetch the original sample 2
    Z_org = Fea_org[2*n:, :] # fetch the original sample 3
    Dxx = Pdist2(X, X)
    Dyy = Pdist2(Y, Y)
    Dxz = Pdist2(X, Z)
    Dyz = Pdist2(Y, Z)
    Dxx_org = Pdist2(X_org, X_org)
    Dyy_org = Pdist2(Y_org, Y_org)
    Dxz_org = Pdist2(X_org, Z_org)
    Dyz_org = Pdist2(Y_org, Z_org)
    Kx = cst*((1-epsilon) * torch.exp(-(Dxx / sigma0) - (Dxx_org / sigma)) + epsilon * torch.exp(-Dxx_org / sigma))
    Ky = cst*((1-epsilon) * torch.exp(-(Dyy / sigma0) - (Dyy_org / sigma)) + epsilon * torch.exp(-Dyy_org / sigma))
    Kxz = cst*((1-epsilon) * torch.exp(-(Dxz / sigma0) - (Dxz_org / sigma)) + epsilon * torch.exp(-Dxz_org / sigma))
    Kyx = cst*((1-epsilon) * torch.exp(-(Dyz / sigma0) - (Dyz_org / sigma)) + epsilon * torch.exp(-Dyz_org / sigma))
    return MMD_LFI_SQUARE(Kx, Ky, Kxz, Kyx, n, len(Fea)-2*n, one_sample_U=one_sample)
def MMD_LFI_SQUARE(Kx, Ky, Kyz, Kxz, batch_n, batch_m, one_sample_U=False):
    nx = batch_n
    nz = batch_m
    if one_sample_U:
        xx = torch.div((torch.sum(Kx) - torch.sum(torch.diag(Kx))), (nx * (nx - 1)))
        yy = torch.div((torch.sum(Ky) - torch.sum(torch.diag(Ky))), (nx * (nx - 1)))
        xz = torch.div((torch.sum(Kxz) - torch.sum(torch.diag(Kxz))), (nx * (nz - 1)))
        yz = torch.div((torch.sum(Kyz) - torch.sum(torch.diag(Kyz))), (nx * (nz - 1)))
        mmd2 = xx - yy + 2* xz - 2* yz
    else:
        xx = torch.div((torch.sum(Kx) - torch.sum(torch.diag(Kx))), (nx * (nx - 1)))
        yy = torch.div((torch.sum(Ky) - torch.sum(torch.diag(Ky))), (nx * (nx - 1)))
        xz = torch.div((torch.sum(Kxz)), (nx * nz))
        yz = torch.div((torch.sum(Kyz)), (nx * nz))
        mmd2 = xx - yy - 2* xz + 2* yz
    return mmd2
def K(x, A, t=1, one_sample=False):
    X_t=A[0]
    Y_t=A[1]
    Z_t=x.reshape(t,-1)
    return float(MMD_LFI(torch.cat((X_t, Y_t, Z_t), 0), len(X_t), torch.cat((X_t, Y_t, Z_t), 0), sigma, sigma0, ep, one_sample=one_sample) )
def find_threshold(arr1, arr2):
    n=len(arr1)
    max_val=0
    max_i=0
    for i in range(int(0.2*n), int(0.8*n)):
        j=np.searchsorted(arr2, arr1[i], side='right')
        val=(i-j)**2/(i*(n-i))
        if val>max_val:
            max_val, max_i=val, i
    return arr1[max_i]
def find_threshold_full(cal_data, eva_data):
    stat_X=[]
    stat_Y=[]
    for i in range(len(eva_data[0])):
        X=eva_data[0][i]
        Y=eva_data[1][i]
        stat_X.append(K(X, cal_data))
        stat_Y.append(K(Y, cal_data))
    sort_X=np.sort(stat_X)
    sort_Y=np.sort(stat_Y)
    return find_threshold(sort_X, sort_Y)
def evaluat(Z, cal_data, t):
    x=0
    A=[]
    for ind in range(len(Z)):
        A.append(K(Z[ind], cal_data))
        x+=(A[-1]<t)
    return x/len(Z)

In [8]:
cal_data=(X_t, Y_t)
eva_data=gen_fun2_t(n)
threshold=find_threshold_full(cal_data, eva_data)
Z1, Z2 = gen_fun2_t(6000) #entire set of test data
p_1=evaluat(Z1, cal_data, threshold)
p_2=evaluat(Z2, cal_data, threshold)
print('p_1, p_2=', p_1, p_2) #the thresholded bias for each hypothesis
m=256
print("n, m=",str(n)+str('  ')+str(m),"--- P(95|Z~Y): ", 1-find_p_value(m, p_2, find_percentile(m, p_1)))
print("n, m=",str(n)+str('  ')+str(m),"--- P(Ep|Z~Y): ", EPV(m, p_1, p_2))

p_1, p_2= 0.4201666666666667 0.4806666666666667
n, m= 1920  256 --- P(95|Z~Y):  0.5765752932341279
n, m= 1920  256 --- P(Ep|Z~Y):  0.9225572315829828
