In [1]:
import os

# Choose a particular GPU. This is set before TensorFlow is imported so the
# restriction is already in the environment when TensorFlow starts up.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import tensorflow as tf

# `gpus` is an empty list when no GPU is visible (CPU-only machine, missing
# driver, ...). Indexing it unconditionally would raise IndexError, so the
# GPU-specific configuration is only applied when a device was found.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
    # Allocate GPU memory on demand instead of reserving it all up front.
    tf.config.experimental.set_memory_growth(gpus[0], True)

from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model
from tensorflow.keras.utils import plot_model

import tensorflow_probability as tfp

import numpy as np

%matplotlib inline
import matplotlib.pyplot as plt
import random
import time

# Single seed shared by every random number generator used in the notebook.
seed = 1234

# Python, NumPy and TensorFlow each keep their own generator state.
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

# Environment switches related to reproducibility.
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# setting it here presumably only affects child processes - confirm.
os.environ['PYTHONHASHSEED'] = str(seed)
os.environ['TF_DETERMINISTIC_OPS'] = 'true'

2022-09-06 13:35:09.149772: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
# Short module-level aliases used by the cells below.
# `stop` is tf.stop_gradient; it is not referenced in the cells visible here.
stop = tf.stop_gradient
# `log1mexp` is TensorFlow Probability's log1mexp; the helpers below apply it
# to log-probabilities to obtain the log of the complementary probability.
log1mexp = tfp.math.log1mexp

@tf.function
def log_sigmoid(logits):
    """Return log(sigmoid(logits)), capped from above at -1e-7.

    The result is clipped to the interval [-inf, -1e-7], i.e. it is never
    allowed to reach exactly 0. Callers in this notebook feed the result to
    `log1mexp`, presumably so that the complementary log-probability stays
    finite.
    """
    raw_log_prob = tf.math.log_sigmoid(logits)
    return tf.clip_by_value(
        raw_log_prob,
        clip_value_min=-float('inf'),
        clip_value_max=-1e-7,
    )

@tf.function
def logaddexp(x1, x2):
    """Elementwise log(exp(x1) + exp(x2)), computed as max + softplus(-|x1 - x2|)."""
    # Where both inputs are equal the gap is set to 0 explicitly; this also
    # covers the case x1 == x2 == -inf, where x1 - x2 would be nan.
    gap = tf.where(x1 == x2, 0., x1 - x2)
    larger = tf.math.maximum(x1, x2)
    correction = tf.math.softplus(-tf.math.abs(gap))
    return larger + correction

@tf.function
def log_pr_exactly_k(logp, logq, k):
    """Build the log-space dynamic-programming table over "number of ones so far".

    Args:
        logp: tensor indexed as (batch, n) with a static shape; column i holds
            the log-probability that variable i is on.
        logq: tensor of the same layout holding the log-probability that
            variable i is off (callers pass log1mexp(logp)).
        k: Python int; the table gets k + 2 columns.

    Returns:
        float32 tensor of shape (batch, n + 1, k + 2). Column 0 is always
        -inf padding; entry [:, i, c] for c >= 1 is the log-probability that
        exactly c - 1 of the first i variables are on.
    """
    rows, num_vars = logp.shape[0], logp.shape[1]

    # Before any variable is seen, "zero ones" (column 1) has probability 1.
    initial = np.full((rows, k + 2), -float('inf'))
    initial[:, 1] = 0
    current = tf.convert_to_tensor(initial, dtype=tf.float32)

    table = tf.TensorArray(tf.float32, size=num_vars + 1)
    table = table.write(0, current)

    # The left-most column never receives probability mass.
    neg_inf_column = tf.ones([rows, 1]) * -float('inf')

    for idx in range(num_vars):
        on_term = current[:, :-1] + logp[:, idx:idx + 1]   # variable idx is on
        off_term = current[:, 1:] + logq[:, idx:idx + 1]   # variable idx is off
        current = tf.concat([neg_inf_column, logaddexp(on_term, off_term)], 1)
        table = table.write(idx + 1, current)

    # stack() yields (n + 1, batch, k + 2); move the batch axis to the front.
    return tf.transpose(table.stack(), perm=[1, 0, 2])

# @tf.function
def marginals(theta, k):
    """Differentiate log Pr(exactly k ones) with respect to the log-probabilities.

    Args:
        theta: logits; the per-variable log-probabilities are log_sigmoid(theta).
        k: Python int, the required number of ones. It is used both to size
            the table and to select the column that is differentiated.
            (Previously the table was always built for 10, so any other k
            indexed a wrong or non-existent column.)

    Returns:
        A pair (grad, a): `grad` is the gradient of log Pr(exactly k of n)
        with respect to log_p, and `a` is the full table returned by
        `log_pr_exactly_k`.
    """
    log_p = log_sigmoid(theta)
    # Computed before the tape is opened, so this tape treats the complement
    # as a constant and only differentiates through the explicit log_p terms.
    log_p_complement = log1mexp(log_p)
    with tf.GradientTape() as tape:
        tape.watch(log_p)
        a = log_pr_exactly_k(log_p, log_p_complement, k)
        # Last row = all n variables processed; column k + 1 = exactly k ones.
        log_pr = a[:, -1, k+1:k+2]
    return tape.gradient(log_pr, log_p), a

In [3]:
@tf.function
def sample(a, probs):
    """Draw one 0/1 vector per batch row by walking the table `a` backwards.

    Args:
        a: table indexed as (batch, n + 1, columns); the caller in this
            notebook passes the output of `log_pr_exactly_k`.
        probs: per-variable values indexed as (batch, n). They are added to
            log-space table entries, and the caller passes log-probabilities.

    Returns:
        float32 tensor of shape (batch, n) holding the sampled 0/1 values.
        An assertion checks that every row sums to (number of columns - 2).
    """
    batch = a.shape[0]
    num_vars = a.shape[-2] - 1
    last_col = a.shape[-1] - 1

    # Per-row column pointer into `a`; starts at the last column and moves one
    # step left every time a 1 is drawn.
    j = tf.fill((batch,), last_col)
    draws = tf.TensorArray(tf.int32, size=num_vars, clear_after_read=False)

    for i in tf.range(num_vars, 0, -1):

        prev_row = tf.fill((batch,), i-1)

        # Table entry reached if variable i is on: row i-1, column j-1.
        on_idx = tf.stack([prev_row, j-1], axis=1)
        # Table entry for the current state: row i, column j.
        here_idx = tf.stack([prev_row + 1, j], axis=1)

        log_on = tf.gather_nd(a, on_idx, batch_dims=1)
        log_here = tf.gather_nd(a, here_idx, batch_dims=1)

        # Normalised log-probability that variable i is on, and its complement.
        log_on = (log_on + probs[:, i-1]) - log_here
        log_off = log1mexp(log_on)

        # The Bernoulli logit is the log-odds log_on - log_off.
        X = tfp.distributions.Bernoulli(logits=(log_on - log_off)).sample()

        # A drawn 1 consumes one unit of the remaining count.
        j = tf.where(X>0, j - 1, j)

        draws = draws.write(i-1, X)

    # stack() yields (n, batch); put the batch axis first.
    draws = tf.transpose(draws.stack(), perm=[1, 0])

    # Our samples should always satisfy the constraint
    tf.debugging.assert_equal(tf.math.reduce_sum(draws, axis=-1), last_col-1)

    return tf.cast(draws, tf.float32)

In [4]:
@tf.function
def xexpx(x):
    """Return x * exp(x), with exactly 0 wherever exp(x) evaluates to 0."""
    weight = tf.exp(x)
    underflowed = weight == 0
    product = x * weight
    # When exp(x) is exactly 0 return that 0 rather than the product.
    return tf.where(underflowed, weight, product)

@tf.function
def xexpy(x,y):
    """Return x * exp(y), with exactly 0 wherever exp(y) evaluates to 0."""
    weight = tf.exp(y)
    underflowed = weight == 0
    product = x * weight
    # When exp(y) is exactly 0 return that 0 rather than the product.
    return tf.where(underflowed, weight, product)

@tf.function
def entropy(a, logprobs):
    """Accumulate an entropy value along the rows of the table `a`.

    Args:
        a: table indexed as (batch, rows, columns); the caller passes the
            table from `log_pr_exactly_k` with -inf entries replaced by -1000.
        logprobs: per-variable log-probabilities indexed as (batch, n).

    Returns:
        Tensor of shape (batch,): the negated last column of the running
        accumulator, clipped to be non-negative.
    """
    batch = a.shape[0]
    zero_column = tf.zeros((batch, 1))
    running = tf.zeros((batch, a.shape[-1]))

    # NOTE(review): the loop starts at row 10 rather than row 1; this looks
    # tied to the subset size used elsewhere in the notebook - confirm.
    for i in range(10, a.shape[-2]):

        step_logp = logprobs[:, i-1:i]
        normaliser = a[:, i, 1:]

        # Log transition probabilities into each column of row i, coming from
        # the column to the left (variable on) or the same column (variable off).
        p_left = (a[:, i-1, :-1] + step_logp) - normaliser
        p_right = (a[:, i-1, 1:] + log1mexp(step_logp)) - normaliser

        # Entropy of this step's choice plus the carried-over entropy of both
        # predecessors, each weighted by its transition probability.
        updated = (xexpx(p_left) + xexpx(p_right)
                   + xexpy(running[:, :-1], p_left)
                   + xexpy(running[:, 1:], p_right))

        running = tf.concat([zero_column, updated], 1)

    return tf.clip_by_value(-running[:, -1], clip_value_max=float('inf'), clip_value_min=0)


In [5]:
class IMLESubsetkLayer(tf.keras.layers.Layer):
    """Keras layer that outputs exact k-subset samples with a custom backward pass.

    Forward pass: the logits are turned into per-variable log-probabilities,
    the table from `log_pr_exactly_k` is built, and `sample` draws one 0/1
    vector per row. Backward pass: the incoming gradient is pushed through
    `marginals` as a Jacobian-vector product (see `imle_layer`).
    """

    def __init__(self, _k=10, _tau=1.0, _lambda=1.0):
        """Store the hyper-parameters.

        Args:
            _k: subset size passed to `log_pr_exactly_k` and `marginals`.
            _tau: scale applied to the noise in `sample_gumbel_k`.
            _lambda: stored on the instance; not read by any method of this class.
        """
        super(IMLESubsetkLayer, self).__init__()

        self.k = _k
        self._tau = _tau
        self._lambda = _lambda
        self.samples = None
        self.gumbel_dist = tfp.distributions.Gumbel(loc=0.0, scale=1.0)

    @tf.function
    def sample_gumbel(self, shape, eps=1e-20):
        """Sample Gumbel(0, 1) noise of the given shape.

        `eps` is accepted for signature compatibility and is not used.
        """
        return self.gumbel_dist.sample(shape)

    @tf.function
    def sample_gumbel_k(self, shape):
        """Sample noise built from a sum of ten gamma draws, scaled by tau / k.

        Not called by any other method of this class.
        """
        s = tf.map_fn(fn=lambda t: tf.random.gamma(shape, 1.0/self.k,  t/self.k),
                  elems=tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]))
        # now add the samples
        s = tf.reduce_sum(s, 0)
        # the log(m) term
        s = s - tf.math.log(10.0)
        # divide by k --> each s[c] has k samples whose sum is distributed as Gumbel(0, 1)
        s = self._tau * (s / self.k)

        return s

    @tf.custom_gradient
    def imle_layer(self, logits, hard=False):
        """Sample on the forward pass and define the gradient used on the backward pass.

        Args:
            logits: per-variable logits, indexed as (rows, n).
            hard: passed through unchanged; it is only echoed back in the
                gradient tuple and does not affect the sample.

        Returns:
            The pair (samples, grad_fn) required by `tf.custom_gradient`.
        """
        # ZK: Should be exact sampling: we're going to pass it to the decoder on the forward pass
        logp = log_sigmoid(logits)
        logq = log1mexp(logp)

        a = log_pr_exactly_k(logp, logq, self.k)
        samples_p = sample(a, logp)

        def custom_grad(dy):
            # Jacobian-vector product of the marginals w.r.t. the logits in
            # the direction of the upstream gradient `dy`.
            with tf.autodiff.ForwardAccumulator(logits, dy) as accumulate:
                y = marginals(logits, self.k)[0]
            # One entry per argument of imle_layer: the JVP for `logits`, and
            # `hard` echoed back in the slot of the non-differentiable flag.
            # (An unreachable second return that referenced an undefined name
            # used to follow this statement and has been removed.)
            return accumulate.jvp(y), hard

        return samples_p, custom_grad

    def call(self, logits, hard=False):
        """Keras entry point; forwards to `imle_layer`."""
        return self.imle_layer(logits, hard)

In [6]:
# Hyper-parameters shared by the model definition and the training cell.
PARAMS = {
    # Expected number of examples per batch; DiscreteVAE.call asserts on it.
    "batch_size": 100,
    # Width of the decoder's final Dense layer.
    "data_dim": 784,
    # The encoder emits N*M logits per example, reshaped to rows of width M.
    "M": 20,
    "N": 20,
    # Number of epochs run by the training loop.
    "nb_epoch": 200, 
    # Not referenced by any code visible in this notebook.
    "epsilon_std": 0.01,
    # Temperature schedule used by the training loop: starting value, decay
    # rate and lower bound.
    "anneal_rate": 0.0003,
    "init_temperature": 1.0,
    "min_temperature": 0.5,
    # Initial value of the "LR" variable and default `init` of get_learning_rate.
    "learning_rate": 5e-4,
    # Passed as `hard` to compute_gradients during training.
    "hard": False,
}

class DiscreteVAE(tf.keras.Model):
    """Dense encoder/decoder with a discrete latent layer in between.

    The encoder maps an input to N*M logits, which are reshaped to rows of
    width M. Each row is turned into a discrete sample by `IMLESubsetkLayer`
    (the Gumbel-softmax helpers are kept as a commented-out alternative in
    `call`). The decoder maps the flattened samples to `data_dim` logits.
    """

    def __init__(self, params):
        """Create the layers.

        Args:
            params: dict providing at least "N", "M" and "data_dim".
        """
        super(DiscreteVAE, self).__init__()

        self.params = params

        # encoder
        self.enc_dense1 = tf.keras.layers.Dense(512, activation='relu')
        self.enc_dense2 = tf.keras.layers.Dense(256, activation='relu')
        self.enc_dense3 = tf.keras.layers.Dense(params["N"]*params["M"])

        # this is our new Gumbel layer
        self.imleLayer = IMLESubsetkLayer(_k=10, _tau=1.0, _lambda=10.0)

        # decoder
        self.flatten = Flatten()
        self.dec_dense1 = tf.keras.layers.Dense(256, activation='relu')
        self.dec_dense2 = tf.keras.layers.Dense(512, activation='relu')
        self.dec_dense3 = tf.keras.layers.Dense(params["data_dim"])

    def sample_gumbel(self, shape, eps=1e-20):
        """Sample from Gumbel(0, 1)"""
        U = tf.random.uniform(shape, minval=0, maxval=1)
        # eps keeps both logarithms away from log(0).
        return -tf.math.log(-tf.math.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature):
        """ Draw a sample from the Gumbel-Softmax distribution"""
        # logits: [batch_size, n_class] unnormalized log-probs
        y = logits + self.sample_gumbel(tf.shape(logits))
        return tf.nn.softmax(y / temperature)

    def gumbel_softmax(self, logits, temperature, hard=True):
        """
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar
        hard: if True, take argmax, but differentiate w.r.t. soft sample y
        """
        y = self.gumbel_softmax_sample(logits, temperature)
        if hard:
            # One-hot of the row maximum on the forward pass; the
            # stop_gradient makes the backward pass see only the soft sample.
            y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keepdims=True)),y.dtype)
            y = tf.stop_gradient(y_hard - y) + y
        return y

    def decoder(self, x):
        """Flatten the latent sample and map it to `data_dim` output logits."""
        h = self.flatten(x)
        h = self.dec_dense1(h)
        h = self.dec_dense2(h)
        h = self.dec_dense3(h)
        return h

    def call(self, x, tau, hard=False):
        """Encode, sample the discrete latent, and decode.

        Args:
            x: input batch.
            tau: temperature; only used by the commented-out Gumbel-softmax path.
            hard: accepted for interface compatibility; the active IMLE path
                always passes hard=True to the layer.

        Returns:
            (logits_y, logits_x): the encoder logits reshaped to rows of width
            M, and the decoder output logits.
        """
        N = self.params["N"]
        M = self.params["M"]

        # encoder
        x = self.enc_dense1(x)
        x = self.enc_dense2(x)
        x = self.enc_dense3(x)   # (batch, N*M)
        logits_y = tf.reshape(x, [-1, M])   # (batch*N, M)

        ###################################################################
        ## here we toggle between methods #################################
        # here we can switch between traditional and our method
        # "traditional" Gumbel Softmax trick
        #y = self.gumbel_softmax(logits=logits_y, temperature=tau, hard=False)
        # IMLE approach -- note: we don't anneal so set temperature once at init
        y = self.imleLayer(logits=logits_y, hard=True)
        ###################################################################

        # The sample must line up element-for-element with the logits. This
        # is checked against the actual logits shape rather than against
        # params["batch_size"], so batches of any size are accepted.
        assert y.shape == logits_y.shape
        y = tf.reshape(y, [-1, N, M])
        # Keep the most recent latent sample around for inspection.
        self.sample_y = y

        # decoder
        logits_x = self.decoder(y)
        return logits_y, logits_x

def gumbel_loss(model, x, tau, hard=True):
    """Reconstruction cross-entropy plus a KL-style term for one batch.

    Args:
        model: callable as model(x, tau, hard) returning (logits_y, logits_x).
        x: input batch; also used as the reconstruction target.
        tau: temperature forwarded to the model.
        hard: flag forwarded to the model.

    Returns:
        Scalar tensor: mean per-example cross-entropy plus mean per-example KL term.
    """
    # Number of latent rows per example; read from the shared configuration
    # instead of repeating the literal here.
    N = PARAMS["N"]
    logits_y, logits_x = model(x, tau, hard)

    # cross-entropy: summed over output units, averaged over the batch
    cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=logits_x)
    cross_ent = tf.math.reduce_sum(cross_ent, 1)
    cross_ent = tf.math.reduce_mean(cross_ent, 0)

    # KL loss
    logprobs_q = log_sigmoid(logits_y)
    # Only the table is needed here, so it is built directly rather than via
    # marginals(), which would also compute a gradient that is then discarded.
    a_q = log_pr_exactly_k(logprobs_q, log1mexp(logprobs_q), 10)
    # Replace -inf entries by a large finite negative value before the
    # entropy recursion.
    a_q = tf.where(a_q == -float('inf'), -1000., a_q)
    q_entropy = entropy(a_q, logprobs_q)
    # 184756 = C(20, 10), the number of 10-element subsets of 20 items.
    # NOTE(review): presumably the log-normaliser of a uniform prior over
    # such subsets - confirm; it is tied to M = 20 and k = 10.
    kl = tf.math.log(184756.) - tf.reshape(q_entropy, [-1,N])
    kl = tf.math.reduce_sum(kl, 1)
    kl = tf.math.reduce_mean(kl)

    return cross_ent + kl


def compute_gradients(model, x, tau, hard):
    """Evaluate `gumbel_loss` on one batch and differentiate it.

    Returns:
        (gradients, loss): gradients of the loss with respect to
        `model.trainable_variables`, and the loss itself.
    """
    with tf.GradientTape() as grad_tape:
        batch_loss = gumbel_loss(model, x, tau, hard)
    grads = grad_tape.gradient(batch_loss, model.trainable_variables)
    return grads, batch_loss


def apply_gradients(optimizer, gradients, variables):
    """Pair each gradient with its variable and hand the pairs to the optimizer."""
    grads_and_vars = zip(gradients, variables)
    optimizer.apply_gradients(grads_and_vars)


def get_learning_rate(step, init=PARAMS["learning_rate"]):
    """Return `init` decayed by a factor of 0.95 per 1000 steps, as a float32 tensor."""
    decay = 0.95 ** (step / 1000.)
    return tf.convert_to_tensor(init * decay, dtype=tf.float32)

In [None]:
# %%time

np.set_printoptions(precision=4,linewidth=200)
model = DiscreteVAE(PARAMS)
learning_rate = tf.Variable(PARAMS["learning_rate"], trainable=False, name="LR")

# The optimizer's step size comes from the shared configuration rather than a
# repeated literal, so changing PARAMS["learning_rate"] actually takes effect.
optimizer = tf.keras.optimizers.Adam(learning_rate=PARAMS["learning_rate"])

# data: MNIST scaled to [0, 1] and flattened to one vector per image
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

TRAIN_BUF = 60000
# Kept in sync with PARAMS["batch_size"], the value the model is configured with.
BATCH_SIZE = PARAMS["batch_size"]
TEST_BUF = 10000
# Evaluate on the test set every EVAL_EVERY epochs.
EVAL_EVERY = 1

train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(TRAIN_BUF).batch(BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices(x_test).shuffle(TEST_BUF).batch(BATCH_SIZE)

# temperature
tau = PARAMS["init_temperature"]
anneal_rate = PARAMS["anneal_rate"]
min_temperature = PARAMS["min_temperature"]

# One evaluation loss per evaluated epoch.
results = []

# Train
for epoch in range(1, PARAMS["nb_epoch"] + 1):

    # this is only needed for the standard Gumbel softmax trick
    # The decay is applied to the already-decayed tau with an exponent that
    # grows with the epoch, so the schedule compounds:
    # tau_E = init * exp(-anneal_rate * E * (E + 1) / 2), floored at min_temperature.
    tau = np.maximum(tau * np.exp(-anneal_rate*epoch), min_temperature)

    for train_x in train_dataset:
        gradients, loss = compute_gradients(model, train_x, tau, hard=PARAMS["hard"])
        apply_gradients(optimizer, gradients, model.trainable_variables)

    # `loss` is the loss of the last training batch of this epoch.
    print("Epoch:", epoch, ", TRAIN loss:", loss.numpy(), ", Temperature:", tau)

    if epoch % EVAL_EVERY == 0:
        # Convert each batch loss to a NumPy scalar as it is produced so the
        # mean is taken over plain numbers rather than a list of tensors.
        losses = []
        for test_x in test_dataset:
            losses.append(gumbel_loss(model, test_x, tau, hard=True).numpy())
        eval_loss = np.mean(losses)
        results.append(eval_loss)
        print("Eval Loss:", eval_loss, "\n")

2022-09-06 13:35:11.900861: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-09-06 13:35:12.437097: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 7025 MB memory:  -> device: 0, name: NVIDIA TITAN RTX, pci bus id: 0000:5e:00.0, compute capability: 7.5


Epoch: 1 , TRAIN loss: 222.86719 , Temperature: 0.9997000449955004
Eval Loss: 219.2978 

Epoch: 2 , TRAIN loss: 204.12341 , Temperature: 0.9991004048785274
Eval Loss: 204.57222 

Epoch: 3 , TRAIN loss: 198.81891 , Temperature: 0.9982016190284373
Eval Loss: 197.22336 

Epoch: 4 , TRAIN loss: 197.72275 , Temperature: 0.997004495503373
Eval Loss: 193.98055 

Epoch: 5 , TRAIN loss: 189.35 , Temperature: 0.9955101098295706
Eval Loss: 191.29666 

Epoch: 6 , TRAIN loss: 190.1796 , Temperature: 0.9937198033910547
Eval Loss: 189.26395 

Epoch: 7 , TRAIN loss: 191.48802 , Temperature: 0.9916351814230984
Eval Loss: 186.94223 

Epoch: 8 , TRAIN loss: 188.79102 , Temperature: 0.9892581106136482
Eval Loss: 185.77408 

Epoch: 9 , TRAIN loss: 184.5834 , Temperature: 0.9865907163177327
Eval Loss: 184.91791 

Epoch: 10 , TRAIN loss: 180.92511 , Temperature: 0.9836353793906725
Eval Loss: 184.20523 

Epoch: 11 , TRAIN loss: 185.68341 , Temperature: 0.9803947326466972
Eval Loss: 183.14648 

Epoch: 12 , TRA

Epoch: 48 , TRAIN loss: 169.47008 , Temperature: 0.7027177228683977
Epoch: 48 , TRAIN loss: 169.47008 , Temperature: 0.7027177228683977
Eval Loss: 174.8649 

Eval Loss: 174.8649 

Eval Loss: 174.8649 

Epoch: 49 , TRAIN loss: 171.39128 , Temperature: 0.6924633268086435
Epoch: 49 , TRAIN loss: 171.39128 , Temperature: 0.6924633268086435
Epoch: 49 , TRAIN loss: 171.39128 , Temperature: 0.6924633268086435
Eval Loss: 174.74059 

Eval Loss: 174.74059 

Eval Loss: 174.74059 

Epoch: 50 , TRAIN loss: 174.15659 , Temperature: 0.6821538909764523
Epoch: 50 , TRAIN loss: 174.15659 , Temperature: 0.6821538909764523
Epoch: 50 , TRAIN loss: 174.15659 , Temperature: 0.6821538909764523
Eval Loss: 174.45892 

Eval Loss: 174.45892 

Eval Loss: 174.45892 

Epoch: 51 , TRAIN loss: 174.13226 , Temperature: 0.6717963735016784
Epoch: 51 , TRAIN loss: 174.13226 , Temperature: 0.6717963735016784
Epoch: 51 , TRAIN loss: 174.13226 , Temperature: 0.6717963735016784
Eval Loss: 174.35344 

Eval Loss: 174.35344 

Ev

Epoch: 80 , TRAIN loss: 172.05139 , Temperature: 0.5
Eval Loss: 173.15005 

Eval Loss: 173.15005 

Eval Loss: 173.15005 

Epoch: 81 , TRAIN loss: 172.23923 , Temperature: 0.5
Epoch: 81 , TRAIN loss: 172.23923 , Temperature: 0.5
Epoch: 81 , TRAIN loss: 172.23923 , Temperature: 0.5
Eval Loss: 172.63911 

Eval Loss: 172.63911 

Eval Loss: 172.63911 

Epoch: 83 , TRAIN loss: 172.92651 , Temperature: 0.5
Eval Loss: 172.52286 

Epoch: 84 , TRAIN loss: 170.39667 , Temperature: 0.5
Eval Loss: 172.58745 

Epoch: 85 , TRAIN loss: 169.82507 , Temperature: 0.5
Eval Loss: 172.43915 

Epoch: 86 , TRAIN loss: 172.76514 , Temperature: 0.5
Eval Loss: 172.3853 

Epoch: 87 , TRAIN loss: 171.57738 , Temperature: 0.5
Eval Loss: 172.30728 

Epoch: 88 , TRAIN loss: 172.87787 , Temperature: 0.5
Eval Loss: 172.42871 

Epoch: 89 , TRAIN loss: 171.5992 , Temperature: 0.5
Eval Loss: 172.26154 

Epoch: 90 , TRAIN loss: 168.21303 , Temperature: 0.5
Eval Loss: 172.19038 

Epoch: 91 , TRAIN loss: 171.00125 , Temperat