In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
import numpy as np
from matplotlib import pyplot as plt
import tqdm
from scipy.io import savemat
import cv2
from scipy.special import factorial
import itertools
from numpy.linalg import norm as np_norm

import sys
sys.path.append('../dependencies/')
import dataset_utils
import network_ec_bm as network
import utils
import zca

import os
# Enumerate GPUs in PCI bus order and expose only device 0 to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

# Enable memory growth on every visible GPU so TensorFlow allocates device
# memory as needed rather than reserving it all at start-up.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

def derangement(n):
    """Return fixed-point-free patch orderings for an ``n x n`` grid.

    Each row of the returned array is a permutation of ``range(n * n)`` in
    which no patch index stays in its own slot, suitable as the ``orders``
    argument of ``patch_shuffle`` (row ``k``, column ``s`` = index of the
    source patch placed in slot ``s``).

    For ``n == 2`` the three hand-picked orderings used by the original
    experiments are returned unchanged. Together with the identity they form
    a 4x4 Latin square: across the three shuffles every slot receives each of
    the other three patches exactly once. For ``n > 2`` the ``n*n - 1``
    non-trivial cyclic shifts of ``range(n*n)`` are returned, which have the
    same property.

    Parameters
    ----------
    n : int
        Number of patches along each image side; must be >= 2.

    Returns
    -------
    orders : np.ndarray of int, shape (num_orders, n * n)
    num_orders : int
        ``orders.shape[0]``.

    Raises
    ------
    ValueError
        If ``n < 2`` (a single patch cannot be moved out of its slot).
    """
    if n == 2:
        orders = np.array([[2, 0, 3, 1], [3, 2, 1, 0], [1, 3, 0, 2]])
    elif n > 2:
        m = int(n) * int(n)
        slots = np.arange(m)
        # Shifting every index by s (1 <= s < m) modulo m never leaves an
        # index in place, and distinct shifts never repeat a slot assignment.
        orders = np.array([(slots + shift) % m for shift in range(1, m)])
    else:
        raise ValueError(f"derangement requires n >= 2, got {n!r}")
    return orders, orders.shape[0]

def patch_shuffle(imgs, n, orders):
    """Cut each image into an ``n x n`` grid and re-tile it once per ordering.

    Patches are numbered row-major: patch ``r * n + c`` is grid row ``r``,
    column ``c``. For ordering ``k``, output slot ``r * n + c`` is filled with
    patch ``orders[k, r * n + c]``.

    Parameters
    ----------
    imgs : np.ndarray, shape (N, H, W, C)
    n : int
        Grid size along each spatial axis.
    orders : np.ndarray, shape (K, n * n)
        Source-patch index for every output slot, one row per ordering.

    Returns
    -------
    np.ndarray of float64, shape (N, H, W, C, K)
        ``result[..., k]`` is ``imgs`` re-tiled according to ``orders[k]``.
        Patches are ``int(H / n) x int(W / n)``; if ``H`` or ``W`` is not a
        multiple of ``n`` the leftover bottom rows / right columns stay zero.
    """
    shape = imgs.shape
    n_imgs, height, width, channels = shape[0], shape[1], shape[2], shape[3]
    ph, pw = int(height / n), int(width / n)

    # Row-major list of (row-slice, col-slice) pairs, one per grid cell.
    cells = [
        (slice(r * ph, (r + 1) * ph), slice(c * pw, (c + 1) * pw))
        for r in range(n)
        for c in range(n)
    ]

    # Gather every patch into its own slot along a trailing axis.
    tiles = np.zeros((n_imgs, ph, pw, channels, int(n * n)))
    for idx, (rows, cols) in enumerate(cells):
        tiles[..., idx] = imgs[:, rows, cols, :]

    # Scatter the patches back, one output image set per ordering.
    out = np.zeros((n_imgs, height, width, channels, orders.shape[0]))
    for k, order in enumerate(orders):
        for slot, (rows, cols) in enumerate(cells):
            out[:, rows, cols, :, k] = tiles[..., np.squeeze(order[slot])]
    return out
    
def dissimilarity(a, b):
    """Per-sample cosine dissimilarity, rescaled to ``[0, 1]``.

    Each sample is flattened to a vector and ``(1 - cos(a[i], b[i])) / 2`` is
    returned: 0 for vectors pointing the same way, 0.5 for orthogonal vectors
    and 1 for opposite vectors.

    Parameters
    ----------
    a : array-like, shape (N, ...)
    b : array-like, shape (M, ...)
        ``M >= N`` with the same per-sample size as ``a``; only the first
        ``N`` samples are used.

    Returns
    -------
    np.ndarray of float64, shape (N,)
        Samples where either vector has zero norm yield ``nan`` (0 / 0).
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    num = a.shape[0]
    # Flatten each sample. The per-row size is computed explicitly rather than
    # via reshape(num, -1) so that num == 0 still works.
    a_flat = a.reshape(num, int(np.prod(a.shape[1:])))
    b_flat = b[:num].reshape(num, int(np.prod(b.shape[1:])))
    dots = np.sum(a_flat * b_flat, axis=1)
    norms = np_norm(a_flat, axis=1) * np_norm(b_flat, axis=1)
    return (1.0 - dots / norms) / 2.0

# Fix TensorFlow's global random seed so runs are repeatable.
tf.random.set_seed(42)

# Namespace shorthands.
tfd, tfk = tfp.distributions, tf.keras
tfkl = tf.keras.layers

# ---- Experiment configuration ----
train_set = 'cifar10'      # dataset whose checkpoint is loaded (part of model_dir)
norm = None                # preprocessing variant, forwarded to get_dataset(normalize=...)
input_shape = (32, 32, 3)  # default; reassigned once `mode` is resolved
n_patches = 2              # images are split into an n_patches x n_patches grid
orders, shuffles = derangement(n_patches)  # patch orderings and how many there are
mode = 'color'             # 'color' or 'grayscale'
pre = 'ec_bm/'             # path prefix for saved models and output probabilities

# Evaluation datasets and PixelCNN width depend on the channel mode.
if mode == 'color':
    input_shape = (32, 32, 3)
    datasets = [
        'svhn_cropped',
        'cifar10',
        'celeb_a',
        'gtsrb',
        'compcars',
        'noise'
    ]
    num_filters = 64
elif mode == 'grayscale':
    input_shape = (32, 32, 1)
    datasets = [
        'mnist',
        'fashion_mnist',
        'emnist/letters',
        'sign_lang',
        'noise'
    ]
    num_filters = 32
else:
    # Fail fast: otherwise `datasets` / `num_filters` stay undefined and the
    # script dies later with a NameError far from the real cause.
    raise ValueError(f"Unknown mode {mode!r}; expected 'color' or 'grayscale'")
    
# PixelCNN architecture hyperparameters. These presumably must match the
# configuration the checkpoint in model_dir was trained with — TODO confirm.
num_resnet = 2
num_hierarchies = 4
num_logistic_mix = 5
num_filters = num_filters  # no-op rebind; the value is chosen by the mode branch
dropout_p = 0.3
use_weight_norm = True
reg_weight = 0

# Training settings. Here the optimizer is only needed for model.compile();
# `epochs` is not referenced further in this script.
learning_rate = 1e-3
epochs = 100
optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
    
# Directory name used for saved models and output probabilities for each
# preprocessing variant. Note that 'channelwhiten' is stored under 'zca'.
_NORM_DIRS = {
    None: 'original',
    'pctile-5': 'pctile-5',
    'channelwhiten': 'zca',
    'zca_original': 'zca_original',
    'histeq': 'histeq',
}
if norm not in _NORM_DIRS:
    # Previously an unknown value left `dir_str` undefined (NameError later).
    raise ValueError(f"Unknown norm {norm!r}; expected one of {list(_NORM_DIRS)}")
dir_str = _NORM_DIRS[norm]
    
# Checkpoint location: ../saved_models/<pre><dir_str>/<train_set>/
model_dir = f'../saved_models/{pre}{dir_str}/{train_set}/'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

# Only the 'zca_original' variant needs a ZCA transform, computed from
# `train_set` by zca.compute_zca; every other variant passes None.
zca_transform = zca.compute_zca(train_set) if norm == 'zca_original' else None

# PixelCNN density model. Its log_prob is used both as the Keras model output
# and, directly, for scoring in the evaluation loop.
dist = network.PixelCNN(
    image_shape=input_shape,
    num_resnet=num_resnet,
    num_hierarchies=num_hierarchies,
    num_filters=num_filters,
    num_logistic_mix=num_logistic_mix,
    dropout_p=dropout_p,
    use_weight_norm=use_weight_norm,
)

# Wrap log_prob in a Keras model (loss = mean negative log-likelihood) and
# restore the trained weights from model_dir.
image_input = tfkl.Input(shape=input_shape)
log_prob = dist.log_prob(image_input)
model = tfk.Model(inputs=image_input, outputs=log_prob)
model.add_loss(-tf.reduce_mean(log_prob))
model.compile(optimizer=optimizer)

model.build([None, *input_shape])
model.load_weights(model_dir + 'weights')

probs = {}


def _stack_log_probs(chunks):
    """Concatenate per-batch log-prob arrays along axis 0 and append a trailing axis."""
    return np.expand_dims(np.concatenate(chunks, axis=0), axis=-1)


def _score_dataset(ds_test):
    """Score every test batch as-is and under each patch ordering.

    Returns ``(regular, shuffled)``. ``regular`` holds ``dist.log_prob`` of
    the unmodified images. ``shuffled[k]`` holds ``dist.log_prob`` of the
    images re-tiled with ``orders[k]``. Both are stacked by
    ``_stack_log_probs``.
    """
    regular = []
    shuffled = [[] for _ in range(shuffles)]
    for test_batch in tqdm.tqdm(ds_test):
        batch = tf.cast(test_batch, tf.float32).numpy()
        patches = patch_shuffle(batch, n_patches, orders)
        regular.append(dist.log_prob(batch, training=False).numpy())
        for k in range(shuffles):
            # Shuffled copies are rounded before scoring (the original batch is not).
            shuffled[k].append(
                dist.log_prob(np.round(patches[..., k]), training=False).numpy()
            )
    return _stack_log_probs(regular), [_stack_log_probs(c) for c in shuffled]


for dataset in datasets:
    _, _, ds_test = dataset_utils.get_dataset(
        dataset,
        1024,
        mode,
        normalize=norm,
        dequantize=False,
        visible_dist='categorical',
        zca_transform=zca_transform,
        mutation_rate=0,
    )
    regular, shuffled = _score_dataset(ds_test)

    for k, shuffled_k in enumerate(shuffled):
        probs[dataset + '_shuffle_' + str(k)] = shuffled_k
        # probs[dataset] accumulates sum_k (log p(x) - log p(shuffle_k(x))).
        if k == 0:
            probs[dataset] = regular - shuffled_k
        else:
            probs[dataset] += regular - shuffled_k
    probs[dataset + '_regular'] = regular

    
# A '/' in the dataset name (e.g. 'emnist/letters') would be read as a
# sub-directory in the .mat path, so flatten it ('emnist_letters').
train_set = train_set.replace('/', '_')

dir_str1 = pre + dir_str + '_patch_shuffled_v2'

save_dir = '../probs/' + dir_str1 + '/'
os.makedirs(save_dir, exist_ok=True)

savemat(save_dir + train_set + '.mat', probs)

2022-05-09 23:54:54.604633: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-05-09 23:54:55.019712: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 10414 MB memory:  -> device: 0, name: NVIDIA GeForce GTX 1080 Ti, pci bus id: 0000:08:00.0, compute capability: 6.1
  0%|                                                                                                                                                                                     | 0/26 [00:00<?, ?it/s]2022-05-09 23:55:06.034567: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
2022-05-09 23:55:06.615004:







100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [02:42<00:00,  6.25s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [01:04<00:00,  6.44s/it]
39it [03:54,  6.01s/it]
13it [01:19,  6.11s/it]
14it [01:54,  8.21s/it]
10it [00:56,  5.68s/it]
