In [None]:
from mmd import *
import csv
import numpy as np
import torch
import random
import matplotlib.pyplot as plt
import matplotlib
import os
import pywt
from PIL import Image
import tqdm
import functools
import itertools

In [None]:
def determine_bandwidth(data1, data2, metric):
    """Median-heuristic kernel bandwidth between two sample sets.

    Evaluates ``metric`` on every pair (point1 from ``data1``, point2 from
    ``data2``), prints the median pairwise distance and returns it.

    Args:
        data1: sized iterable of samples.
        data2: sized iterable of samples.
        metric: callable ``metric(point1, point2) -> float``.

    Returns:
        The median of the ``len(data1) x len(data2)`` pairwise distances.

    Raises:
        ValueError: if either sample set is empty (the median would be NaN,
            which would silently poison every downstream kernel evaluation).
    """
    if len(data1) == 0 or len(data2) == 0:
        raise ValueError("determine_bandwidth needs two non-empty sample sets")

    distances = np.zeros((len(data1), len(data2)))
    for row, point1 in enumerate(data1):
        for col, point2 in enumerate(data2):
            distances[row, col] = metric(point1, point2)

    # Compute the median once and reuse it for both the report and the return value.
    median_distance = np.median(distances)
    print("Best single KERNEL has bandwidth %.5f" % median_distance)
    return median_distance

def gray2haar(img, shape=(28, 28)):
    """Two-level Haar wavelet detail features of one grayscale image.

    Args:
        img: pixel values that can be reshaped to ``shape`` (e.g. a flattened image).
        shape: 2-D image shape. It defaults to (28, 28), the size this helper
            originally hard-coded, so existing calls are unchanged.

    Returns:
        1-D array. It holds the three level-1 detail arrays returned by
        ``pywt.dwt2`` flattened, followed by the three level-2 detail arrays
        computed from the level-1 approximation.
    """
    image = np.reshape(img, shape)
    approx1, details1 = pywt.dwt2(image, 'haar')
    _, details2 = pywt.dwt2(approx1, 'haar')
    # np.reshape on the 3-tuple of detail arrays stacks them, then flattens row-major.
    return np.hstack([np.reshape(details1, -1), np.reshape(details2, -1)])

def rgb2haar(datapoint):
    """Two-level Haar wavelet detail features of one RGB image.

    ``datapoint`` is indexed as ``datapoint[:, :, c]`` for channels c = 0, 1, 2.
    For each channel, ``pywt.dwt2`` is applied once to the channel and once more
    to the resulting level-1 approximation. Only the detail arrays are kept.

    Returns:
        1-D array laid out as
        [level-1 details of ch0, ch1, ch2, level-2 details of ch0, ch1, ch2],
        each channel's three detail arrays flattened row-major.
    """
    level1_parts = []
    level2_parts = []
    for channel in range(3):
        approx1, details1 = pywt.dwt2(datapoint[:, :, channel], 'haar')
        _, details2 = pywt.dwt2(approx1, 'haar')
        # np.ravel stacks the 3-tuple of equally sized detail arrays, then flattens it.
        level1_parts.append(np.ravel(details1))
        level2_parts.append(np.ravel(details2))
    return np.concatenate(level1_parts + level2_parts)

In [None]:
# Load 1000 samples per source, one flattened row per sample. np.fromfile is
# called without a dtype, so it reads the files as float64.
ground_images, wgan_images, pgan_images = [
    np.fromfile(sample_path).reshape((1000, -1))
    for sample_path in ('lsun_rgb_samples.bin', 'wgan_samples.bin', 'pgan_samples.bin')
]

# Haar wavelet detail features for every sample, viewed as 256x256x3 images.
ground_images_haar, wgan_images_haar, pgan_images_haar = [
    np.array([rgb2haar(image) for image in flat_images.reshape(1000, 256, 256, 3)])
    for flat_images in (ground_images, wgan_images, pgan_images)
]

In [None]:
# Visual sanity check: show the same sample index from each source side by side.
# NOTE(review): the arrays are float64 (np.fromfile default); imshow expects RGB
# floats in [0, 1] and clips values outside that range — confirm the stored scale.
sample_index = 9
fig, ax = plt.subplots(1, 3)
for panel, (panel_title, flat_images) in zip(
    ax,
    (("Ground truth", ground_images), ("PGAN", pgan_images), ("WGAN", wgan_images)),
):
    panel.imshow(flat_images[sample_index].reshape(256, 256, 3))
    panel.set_title(panel_title)
    panel.axis('off')
plt.show()

In [None]:
# MMD statistic used by the permutation tests. It is called with torch tensors
# built via torch.from_numpy, so it must be torch compatible.
mmd_func = mix_rbf_mmd2


def distance_metric(x, y):
    """Euclidean (L2) distance between two flattened samples."""
    return np.linalg.norm(x - y, 2)


# Kernel bandwidth read (as a global) by determine_delta / run_permutation_test.
# Each test cell below reassigns it before running.
bandwidth = None
# Median-heuristic bandwidths, each estimated from the first 100 samples of both sets.
bandwidth_euc_pgan = determine_bandwidth(pgan_images[:100], ground_images[:100], distance_metric)
bandwidth_haar_pgan = determine_bandwidth(pgan_images_haar[:100], ground_images_haar[:100], distance_metric)

bandwidth_euc_wgan = determine_bandwidth(wgan_images[:100], ground_images[:100], distance_metric)
bandwidth_haar_wgan = determine_bandwidth(wgan_images_haar[:100], ground_images_haar[:100], distance_metric)

# Number of random permutations per test; 250 is treated here as sufficient.
permutation_count = 250

# Seed NumPy's global RNG, which np.random.shuffle uses in the permutation test.
# This makes the p-values reproducible under Restart & Run All. Re-running a
# single test cell on its own still continues the RNG stream.
SEED = 0
np.random.seed(SEED)

# Run all cells.

In [None]:
def run_permutation_test(pooled):
    """Compute the biased MMD statistic for one random split of ``pooled``.

    ``pooled`` is shuffled IN PLACE along its first axis, so the caller's array
    is reordered. The first half of the rows (``len // 2``) forms one sample and
    the remainder forms the other. Both halves are passed to ``mmd_func`` as
    zero-copy torch views, using the global ``bandwidth``.

    Returns:
        The statistic as a Python float.
    """
    np.random.shuffle(pooled)
    split_at = len(pooled) // 2
    first_half = torch.from_numpy(pooled[:split_at])
    second_half = torch.from_numpy(pooled[split_at:])
    statistic = mmd_func(first_half, second_half, [bandwidth], biased=True)
    return statistic.item()

def determine_delta(element, real_element):
    """Observed biased MMD statistic between two sample arrays.

    Both arrays are wrapped with ``torch.from_numpy``, without copying, and
    passed to ``mmd_func`` with the global ``bandwidth``. The value is printed
    and returned as a Python float.
    """
    statistic = mmd_func(
        torch.from_numpy(element),
        torch.from_numpy(real_element),
        [bandwidth],
        biased=True,
    )
    delta = statistic.item()
    print("Delta for Synthetic Samples against Test Samples", " is:", delta)
    return delta

def determine_probs(pooled, delta):
    """Permutation-test p-value for an observed MMD statistic.

    Runs ``permutation_count`` calls of ``run_permutation_test(pooled)``. Each
    call shuffles ``pooled`` in place and returns the biased MMD of a random
    half/half split. The result is the achieved significance level
    ``#{estimate >= delta} / B`` (Efron & Tibshirani's ASL_perm). Ties with the
    observed statistic count as "at least as extreme", which is the
    conservative convention.

    Args:
        pooled: the combined sample array (it is reordered in place).
        delta: the observed statistic, e.g. from ``determine_delta``.

    Returns:
        ``(estimates, hat_asl_perm)``: a float array of the permuted
        statistics, and the p-value as a Python float.

    Raises:
        ValueError: if ``permutation_count`` is less than 1.
    """
    num_samples = permutation_count
    if num_samples < 1:
        raise ValueError("permutation_count must be at least 1")

    estimates = np.array([
        run_permutation_test(pooled)
        for _ in tqdm.tqdm(range(num_samples), total=num_samples)
    ])

    # Previously computed as 1 - #{est <= delta}/B == #{est > delta}/B, which
    # dropped ties; ASL_perm counts estimates >= the observed statistic.
    hat_asl_perm = float(np.mean(estimates >= delta))
    print("PValue = ", hat_asl_perm)
    return estimates, hat_asl_perm

    

In [None]:
# MMD permutation tests of PGAN samples against the ground-truth set. The first
# test uses raw pixels with the Euclidean bandwidth; the second uses Haar
# wavelet features with the Haar bandwidth.
# Each entry appended to a result list is (delta, permutation estimates, p-value).
euc_results_pgan = []
haar_results_pgan = []
i = 0

for test_label, test_bandwidth, synthetic, real, result_list in (
    ("EUCLIDEAN TEST PGAN ", bandwidth_euc_pgan, pgan_images, ground_images, euc_results_pgan),
    ("HAAR TEST PGAN ", bandwidth_haar_pgan, pgan_images_haar, ground_images_haar, haar_results_pgan),
):
    print(test_label, i)
    # determine_delta / run_permutation_test read the global `bandwidth`.
    bandwidth = test_bandwidth

    # Build the float32 pooled sample once: synthetic rows first, then real rows.
    # Slice assignment casts straight into the buffer. This avoids the separate
    # astype() copies (made twice per test) that the copy-pasted version allocated.
    n_synthetic = len(synthetic)
    pooled = np.empty((n_synthetic + len(real),) + synthetic.shape[1:], dtype=np.float32)
    pooled[:n_synthetic] = synthetic
    pooled[n_synthetic:] = real

    # Compute the observed statistic on the unshuffled views, because
    # determine_probs shuffles `pooled` in place.
    delta = determine_delta(pooled[n_synthetic:], pooled[:n_synthetic])
    estimates, pval = determine_probs(pooled, delta)
    print('Mean and Stdev:', np.mean(estimates), np.std(estimates))
    # Despite the label, this is delta's z-score relative to the permutation mean.
    print("Stdev from 50th percentile:", (delta - np.mean(estimates)) / np.std(estimates))
    result_list.append((delta, estimates, pval))


In [None]:
# MMD permutation tests of WGAN samples against the ground-truth set. The first
# test uses raw pixels with the Euclidean bandwidth; the second uses Haar
# wavelet features with the Haar bandwidth.
# Each entry appended to a result list is (delta, permutation estimates, p-value).
euc_results_wgan = []
haar_results_wgan = []
i = 0

for test_label, test_bandwidth, synthetic, real, result_list in (
    ("EUCLIDEAN TEST WGAN ", bandwidth_euc_wgan, wgan_images, ground_images, euc_results_wgan),
    ("HAAR TEST WGAN ", bandwidth_haar_wgan, wgan_images_haar, ground_images_haar, haar_results_wgan),
):
    print(test_label, i)
    # determine_delta / run_permutation_test read the global `bandwidth`.
    bandwidth = test_bandwidth

    # Build the float32 pooled sample once: synthetic rows first, then real rows.
    # Slice assignment casts straight into the buffer. This avoids the separate
    # astype() copies (made twice per test) that the copy-pasted version allocated.
    n_synthetic = len(synthetic)
    pooled = np.empty((n_synthetic + len(real),) + synthetic.shape[1:], dtype=np.float32)
    pooled[:n_synthetic] = synthetic
    pooled[n_synthetic:] = real

    # Compute the observed statistic on the unshuffled views, because
    # determine_probs shuffles `pooled` in place.
    delta = determine_delta(pooled[n_synthetic:], pooled[:n_synthetic])
    estimates, pval = determine_probs(pooled, delta)
    print('Mean and Stdev:', np.mean(estimates), np.std(estimates))
    # Despite the label, this is delta's z-score relative to the permutation mean.
    print("Stdev from 50th percentile:", (delta - np.mean(estimates)) / np.std(estimates))
    result_list.append((delta, estimates, pval))
