In [None]:
# ignore this cell when running on colab
# Shows the NVIDIA GPUs/driver visible to this machine via an IPython shell escape.
# Purely informational: nothing later in the notebook reads its output.
!nvidia-smi

In [None]:
# this as well (skip on colab)
# Restrict the process to the first GPU. CUDA reads this variable when the GPU is first
# initialized, so it has to be set before TensorFlow touches the device -- which is why
# this cell comes before the TensorFlow import cell.
import os
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
import numpy as np
from matplotlib import pyplot as plt

In [None]:
"""Some utilities"""

def repeated_gibbs(initial_sample, n_iterations, gibbs_update_fn,
                   return_all=False, **guf_kwargs):
    """Repeatedly apply Gibbs updates for a given number of iterations.

    Alternative wording: Run a Markov chain.

    Parameters:
        initial_sample: Batch of input samples to start with. Needs to be in the
                        appropriate format for the update function (e.g. tuple of
                        visible/hidden for RBMs).
        n_iterations: How many Gibbs updates to do.
        gibbs_update_fn: Function that takes a batch of input samples and
                         computes a new one with the same structure.
        return_all: If true, return all samples, not just the last one. Can be used
                    if you want to plot the entire chain, for example. The initial
                    sample is NOT included: every returned tensor gets a leading axis
                    of length n_iterations (one entry per update).
        guf_kwargs: Keyword arguments passed to gibbs_update_fn.

    Returns:
        New batch of input samples (same structure as initial_sample) or, with
        return_all=True, that structure with each tensor stacked over iterations.
        No gradient flows back through the chain.

    """
    iteration_dummy = tf.range(n_iterations)

    def loop_body(sample, _):
        # the iteration index is unused; it only drives the number of steps
        return gibbs_update_fn(sample, **guf_kwargs)

    loop_fn = tf.scan if return_all else tf.foldl
    chain_result = loop_fn(loop_body, iteration_dummy, initializer=initial_sample)
    # `back_prop=False` is deprecated for tf.scan/tf.foldl; tf.stop_gradient is the
    # recommended replacement and likewise blocks gradients through the sampled chain.
    return tf.nest.map_structure(tf.stop_gradient, chain_result)


def repeated_gibbs_python(initial_sample, n_iterations, gibbs_update_fn,
                          return_all=False, **guf_kwargs):
    """Included for pedagogical reasons. ;)

    It does (almost) the same thing as repeated_gibbs, but is perhaps easier to
    understand because it uses a plain Python loop.

    The only annoying thing is that, for this to work properly in a tf.function, we
    need TensorArrays to collect the per-iteration samples. But for the case
    return_all=False, you can ignore them completely.

    Differences from repeated_gibbs:
        - With return_all=True, initial_sample must be a (visible, hidden) pair of
          float32 tensors: the two TensorArrays below are hard-coded for exactly that.
        - With return_all=True the stacked chains INCLUDE the initial sample, so their
          leading axis has length n_iterations + 1 (repeated_gibbs, built on tf.scan,
          returns only the n_iterations updated samples).

    Returns:
        The final sample, or with return_all=True a (visible_chain, hidden_chain) pair.
    """
    if return_all:
        visible_samples = tf.TensorArray(tf.float32, size=n_iterations+1)
        hidden_samples = tf.TensorArray(tf.float32, size=n_iterations+1)
        # slot 0 holds the initial state, slots 1..n_iterations the updated states
        visible_samples = visible_samples.write(0, initial_sample[0])
        hidden_samples = hidden_samples.write(0, initial_sample[1])

    # "core logic" starts here
    sample = initial_sample
    for index in tf.range(n_iterations):
        sample = gibbs_update_fn(sample, **guf_kwargs)
        # "core logic" ends here -- the rest of the loop body only records the chain

        if return_all:
            visible_samples = visible_samples.write(index + 1, sample[0])
            hidden_samples = hidden_samples.write(index + 1, sample[1])

    return (visible_samples.stack(), hidden_samples.stack()) if return_all else sample


def gibbs_update_brbm(previous_sample, w_vh, b_v, b_h):
    """Gibbs update step for binary RBMs.

    Given an input sample, take a hidden sample and then a new input sample
    (one full v -> h -> v sweep).

    Parameters:
        previous_sample: Tuple of b x d_v tensor and b x d_h tensor: Both batches
                         of input/hidden samples. Only the visible part is read; the
                         hidden part is ignored because h is resampled from v right
                         away (it is only carried so the chain state keeps a fixed
                         (visible, hidden) structure).
        w_vh: Connection matrix of RBM, d_v x d_h.
        b_v: Bias vector for inputs, d_v-dimensional.
        b_h: Bias vector for hidden variables, d_h-dimensional.

    Returns:
        New batch of input/hidden samples as tuple (sample_v, sample_h), both float32,
        where sample_v was drawn conditioned on sample_h.

    """
    v, _ = previous_sample

    # p(h = 1 | v) = sigmoid(v W + b_h), independently per hidden unit
    p_h_given_v = tf.nn.sigmoid(tf.matmul(v, w_vh) + b_h)
    sample_h = tfd.Bernoulli(probs=p_h_given_v, dtype=tf.float32).sample()

    # p(v = 1 | h) = sigmoid(h W^T + b_v); transpose_b avoids a separate transpose op
    p_v_given_h = tf.nn.sigmoid(tf.matmul(sample_h, w_vh, transpose_b=True) + b_v)
    sample_v = tfd.Bernoulli(probs=p_v_given_h, dtype=tf.float32).sample()

    return sample_v, sample_h


def energy_rbm(v, h, w_vh, b_v, b_h):
    """Energy E(v, h) = -v.b_v - h.b_h - v^T W h of an RBM, one value per batch row.

    Parameters:
        v: Batch of inputs, b x d_v.
        h: Batch of hidden units, b x d_h.
        w_vh: Connection matrix of RBM, d_v x d_h.
        b_v: Bias vector for inputs, d_v-dimensional.
        b_h: Bias vector for hidden variables, d_h-dimensional.

    Returns:
        b-dimensional vector, energy for each batch element.

    """
    visible_bias_term = tf.linalg.matvec(v, b_v)
    hidden_bias_term = tf.linalg.matvec(h, b_h)
    # row-wise v_n^T W h_n: project v through W, then dot with h per row
    # (alternative: batched matmul)
    interaction_term = tf.reduce_sum(tf.matmul(v, w_vh) * h, axis=-1)
    return -(visible_bias_term + hidden_bias_term + interaction_term)


In [None]:
"""data"""

# data
batch_size = 1024
n_v = 28*28  # number of visible units = pixels per MNIST image

(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape((-1, n_v)).astype(np.float32) / 255.
test_images = test_images.reshape((-1, n_v)).astype(np.float32) / 255.

# binarize data
train_images = (train_images >= 0.5).astype(np.float32)
test_images = (test_images >= 0.5).astype(np.float32)

# Shuffle over the whole training set (buffer = dataset size). drop_remainder=True
# discards the ragged last batch of each epoch (60000 % 1024 = 608 images), so every
# batch has the same static shape: the XLA-compiled train function is not re-traced
# for a second shape, and a persistent PCD chain of batch_size rows always matches
# the data batch.
data = (tf.data.Dataset.from_tensor_slices(train_images)
        .shuffle(len(train_images))
        .batch(batch_size, drop_remainder=True)
        .repeat())

# compute marginal for better sampling: per-pixel frequency of 1s in the training set
marginals = tf.reduce_mean(train_images, axis=0)

# uniform start: every pixel is 1 with p=0.5
start_sampler = tfd.Bernoulli(probs=0.5, dtype=tf.float32)
# data-informed start: pixel i is 1 with probability marginals[i]
marginal_sampler = tfd.Bernoulli(probs=marginals, dtype=tf.float32)


In [None]:
# For every pixel position: the probability that marginal_sampler draws it as 1.
marginal_image = marginals.numpy().reshape(28, 28)
fig, ax = plt.subplots()
marginal_plot = ax.imshow(marginal_image, cmap="Greys", vmin=0, vmax=1)
fig.colorbar(marginal_plot, ax=ax)
plt.show()

In [None]:
# model
# "simple" is the "naive" algorithm 18.1: every training step starts fresh chains from
# noise. "cd" starts the chains at the data batch, "pcd" carries them over between steps.
mode = "simple"
if mode not in ("pcd", "cd", "simple"):
    raise ValueError("Invalid mode!")

n_h = 512  # number of hidden units; could be tuned
w_vh = tf.Variable(tf.random.uniform([n_v, n_h], minval=-0.1, maxval=0.1))
b_v = tf.Variable(tf.zeros([n_v]))
b_h = tf.Variable(tf.zeros([n_h]))
weights = [w_vh, b_v, b_h]

train_steps = 5000
chain_length = 500  # how long to run Markov chains ("cd"/"pcd" use chain_length // 10)

# SGD with a decaying learning rate (0.1 -> 1e-3 over train_steps) worked much better
# for me than adam (tf.optimizers.Adam()).
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=0.1, decay_steps=train_steps, end_learning_rate=1e-3)
optimizer = tf.optimizers.SGD(lr_schedule)

# if true, initial samples for v are taken from the marginal distribution above.
# if false, we just sample each pixel randomly with p=0.5.
sample_from_marginal = True

In [None]:
"""Train"""

@tf.function(jit_compile=True)
def train(batch, v_sampled=None, h_sampled=None):
    """One contrastive training step of the RBM (compiled with XLA).

    v_sampled and h_sampled are used only for PCD.

    It's always passed because I'm lazy.

    Parameters:
        batch: Batch of binarized training images, b x n_v.
        v_sampled, h_sampled: Chain state from the previous step. Only read when
            mode == "pcd" (they must not be None then); ignored otherwise.

    Returns:
        (loss, v_sampled, h_sampled): the monitoring "loss" and the negative-phase
        samples, which the caller feeds back in for PCD.

    NOTE: `mode`, `sample_from_marginal` and `chain_length` are plain Python globals.
    tf.function reads them while tracing, so changing them later is only picked up
    on a re-trace -- re-run this cell after changing them.
    """
    v_data = batch
    # positive phase: binary hidden sample drawn from p(h | v_data)
    h_data = tf.nn.sigmoid(tf.matmul(v_data, w_vh) + b_h)
    h_data = tfd.Bernoulli(probs=h_data, dtype=tf.float32).sample()
    
    # gibbs sampling -- naive/cd/pcd only differ in how the chains are initialized
    # (plain Python `if` on a global: resolved once at trace time, not per step)
    if mode == "cd":
        # CD: chains start at the data, chain_length//10 Gibbs steps
        v_sampled, h_sampled = repeated_gibbs(
            (v_data, h_data), chain_length//10, gibbs_update_brbm,
            w_vh=w_vh, b_v=b_v, b_h=b_h)
        
    elif mode == "pcd":
        # PCD: chains continue from the state returned by the previous call
        v_sampled, h_sampled = repeated_gibbs(
            (v_sampled, h_sampled), chain_length//10, gibbs_update_brbm,
            w_vh=w_vh, b_v=b_v, b_h=b_h)
        
    else:
        # "simple": fresh chains from noise every step, full chain_length Gibbs steps
        if sample_from_marginal:
            v_random = marginal_sampler.sample(tf.shape(batch)[0])
        else:
            v_random = start_sampler.sample([tf.shape(batch)[0], n_v])
        # h_random is just a dummy, it will immediately be overwritten by a sample conditioned 
        # on v, but the code is set up such that it needs an initial value for h...
        h_random = start_sampler.sample([tf.shape(batch)[0], n_h])
        v_sampled, h_sampled = repeated_gibbs(
            (v_random, h_random), chain_length, gibbs_update_brbm,
            w_vh=w_vh, b_v=b_v, b_h=b_h)

    # compute "loss" and take the gradient
    # loss = mean energy of the data minus mean energy of the model samples. The samples
    # were drawn before the tape started, so they act as constants for the gradient.
    # although this loss can take any value (real number), it should converge to around 0
    # if it goes to -infinity there is something wrong!
    with tf.GradientTape() as tape:
        logits_positive = tf.reduce_mean(-energy_rbm(v_data, h_data, w_vh, b_v, b_h))
        logits_negative = tf.reduce_mean(
            -energy_rbm(v_sampled, h_sampled, w_vh, b_v, b_h))
        loss = -(logits_positive - logits_negative)
    gradients = tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(gradients, weights))
    
    # returning the sampled values is once again for PCD only
    return loss, v_sampled, h_sampled


# training loop

def _sample_initial_visible(n):
    """Draw n initial visible states from the start distribution chosen in the config."""
    if sample_from_marginal:
        return marginal_sampler.sample(n)
    return start_sampler.sample([n, n_v])


def _show_image_grid(images, title):
    """Show up to 64 flattened 28x28 images as an 8x8 grid with a common title."""
    images = np.asarray(images).reshape((-1, 28, 28))[:64]
    f = plt.figure(figsize=(15, 15))
    f.suptitle(title)
    for ind, image in enumerate(images):
        plt.subplot(8, 8, ind+1)
        plt.imshow(image, cmap="Greys", vmin=0, vmax=1)
        plt.axis("off")
    plt.show()


# chain state carried between steps (only used for PCD)
v_sampled = _sample_initial_visible(batch_size)
h_sampled = start_sampler.sample([batch_size, n_h])
for step, img_batch in enumerate(data):
    if step > train_steps:  # runs steps 0..train_steps inclusive
        break

    loss, v_sampled, h_sampled = train(img_batch, v_sampled, h_sampled)
    if not step % 250:
        print("Step", step)
        print("Loss:", loss)

        # look at some samples: fresh chains, independent of the training chains
        v_random = _sample_initial_visible(64)
        h_random = start_sampler.sample([64, n_h])
        image_sample, final_h = repeated_gibbs(
            (v_random, h_random), chain_length, gibbs_update_brbm,
            w_vh=w_vh, b_v=b_v, b_h=b_h)

        # these are the "true" binary samples
        _show_image_grid(image_sample, "binary samples")

        # here, we plot the _probabilities_ of v in the last step (at the end of the markov chain)
        # much smoother, less noisy due to lack of randomness
        images_prob = tf.nn.sigmoid(tf.matmul(final_h, w_vh, transpose_b=True) + b_v)
        _show_image_grid(images_prob, "probabilities")

In [None]:
"""random samples"""
# Draw 64 fresh Markov chains from the trained model and look at where they end up.
v_random = (marginal_sampler.sample([64]) if sample_from_marginal
            else start_sampler.sample([64, n_v]))
h_random = start_sampler.sample([64, n_h])
image_sample, final_h = repeated_gibbs(
    (v_random, h_random), chain_length, gibbs_update_brbm,
    w_vh=w_vh, b_v=b_v, b_h=b_h)

# Two views of the same chains:
#  - "binary samples": the "true" binary states of v at the end of each chain;
#  - "probabilities": p(v = 1 | h) for the final hidden state -- much smoother,
#    less noisy because the last Bernoulli draw is left out.
images_prob = tf.nn.sigmoid(tf.matmul(final_h, tf.transpose(w_vh)) + b_v)
for title, panel_images in (("binary samples", image_sample),
                            ("probabilities", images_prob)):
    f = plt.figure(figsize=(15, 15))
    f.suptitle(title)
    for ind, image in enumerate(panel_images):
        plt.subplot(8, 8, ind + 1)
        plt.imshow(image.numpy().reshape((28, 28)), cmap="Greys", vmin=0, vmax=1)
        plt.axis("off")
    plt.show()

In [None]:
# A sort of "introspection": there is exactly one visible bias per pixel, so b_v can be
# reshaped into an image. It shows the model's general "preference" to activate pixels.
bias_image = np.reshape(b_v.numpy(), (28, 28))
fig, ax = plt.subplots()
bias_plot = ax.imshow(bias_image, cmap="Greys_r")
fig.colorbar(bias_plot, ax=ax)
plt.show()

In [None]:
# similarly, each hidden unit is connected to each visible unit, so we can do the same thing.
# each unit encodes an input pattern that would maximally activate it.
weight_images = w_vh.numpy().T.reshape((-1, 28, 28))

# symmetric color range around 0, shared by all tiles, so colors are comparable
absmax = np.abs(weight_images).max()
# here, only for a selection of units (the first 64, fewer if there are fewer units)
n_shown = min(64, len(weight_images))
fig, axes = plt.subplots(8, 8, figsize=(15, 15))
for ind, ax in enumerate(axes.flat):
    ax.axis("off")
    if ind < n_shown:
        weight_plot = ax.imshow(weight_images[ind], vmin=-absmax, vmax=absmax, cmap="coolwarm")
# One colorbar for the whole grid (all tiles share vmin/vmax). A bare plt.colorbar()
# would attach to -- and shrink -- only the last subplot.
fig.colorbar(weight_plot, ax=axes.ravel().tolist(), shrink=0.6)
plt.show()

In [None]:
# this cell takes quite long!!

import os

# another thing we can do: plot the development of the chain over time.
plot_frequency = 25  # show every plot_frequency-th update inline (steps 1, 26, 51, ...)
# optionally we can create a GIF of the chains. to do that, set a folder name here.
gif_folder = None
if gif_folder:
    # also creates missing parent folders; no-op if the folder already exists
    os.makedirs(gif_folder, exist_ok=True)

if sample_from_marginal:
    v_random = marginal_sampler.sample([64])
else:
    v_random = start_sampler.sample([64, n_v])
h_random = start_sampler.sample([64, n_h])
# with return_all=True, image_sample stacks the visible states after each of the
# chain_length updates; the initial state is plotted separately as step 0 below
image_sample, _ = repeated_gibbs(
    (v_random, h_random), chain_length, gibbs_update_brbm,
    return_all=True, w_vh=w_vh, b_v=b_v, b_h=b_h)


def _plot_chain_states(states, step):
    """Draw 64 chain states as an 8x8 grid; save it as a GIF frame if gif_folder is set."""
    states = np.asarray(states).reshape((-1, 28, 28))
    fig = plt.figure(figsize=(15, 15))
    if gif_folder is None:  # don't want title in the images if making a gif
        fig.suptitle("Step {}".format(step))
    for ind, image in enumerate(states):
        plt.subplot(8, 8, ind+1)
        plt.imshow(image, cmap="Greys", vmin=0, vmax=1)
        plt.axis("off")
    if gif_folder:
        plt.savefig(os.path.join(gif_folder, "step{}.png".format(step)))
    return fig


_plot_chain_states(v_random, 0)
plt.show()

for ii, chain_states in enumerate(image_sample):
    show_inline = not ii % plot_frequency
    # building a 64-panel figure is the slow part: skip it unless it is shown or saved
    if show_inline or gif_folder:
        _plot_chain_states(chain_states, ii + 1)
        if show_inline:
            plt.show()
    plt.close("all")

In [None]:
# this is for creating a gif from the images saved above (requires gif_folder to be set there).
# you may need to !pip install imageio
if not gif_folder:
    print("gif_folder is not set, so no frames were saved -- skipping GIF creation.")
else:
    try:
        # imageio >= 2.16: the v2 API, without the deprecation warning of imageio.imread
        import imageio.v2 as imageio
    except ImportError:
        import imageio  # older releases only provide the top-level (v2-style) API
    # frames step0.png .. step{chain_length}.png, written by the chain-development cell
    filenames = [os.path.join(gif_folder, "step{}.png".format(step)) for step in range(chain_length+1)]
    with imageio.get_writer(os.path.join(gif_folder, "markov_chains.gif"), mode='I') as writer:
        for filename in filenames:
            image = imageio.imread(filename)
            writer.append_data(image)

In [None]:
# bonus 1
# This is an alternative formulation with hand-derived gradients rather than using gradient tape.
# it seems to be slower than the one above... it's probably possible to do the computations for W
# more efficiently.

# this is basically a drop-in replacement for the "Train" cell (the one with the tape-based
# train function), so run this cell INSTEAD of that one.
# this could also be integrated into the other cell like
#
# if use_tape:
#     train_with_GradientTape
# else:
#     train_with_manual_gradients
#
# ...since nothing else changes


@tf.function
def train_manual(batch, lr, v_sampled=None, h_sampled=None):
    """One RBM training step with hand-derived gradients instead of a GradientTape.

    v_sampled and h_sampled are used only for PCD.

    It's always passed because I'm lazy.

    Parameters:
        batch: Batch of binarized training images, b x n_v.
        lr: Learning rate for this step. Pass it as a tensor: a new Python float on
            every call would make tf.function re-trace every step.
        v_sampled, h_sampled: Chain state from the previous step; only read when
            mode == "pcd" (must not be None then).

    Returns:
        (loss, v_sampled, h_sampled): a monitoring "loss" and the negative-phase
        samples, which the caller feeds back in for PCD.
    """
    v_data = batch
    # positive phase: binary hidden sample drawn from p(h | v_data)
    h_data = tf.nn.sigmoid(tf.matmul(v_data, w_vh) + b_h)
    h_data = tfd.Bernoulli(probs=h_data, dtype=tf.float32).sample()

    # negative phase: naive/cd/pcd only differ in how the chains are initialized
    if mode == "cd":
        v_sampled, h_sampled = repeated_gibbs(
            (v_data, h_data), chain_length//10, gibbs_update_brbm,
            w_vh=w_vh, b_v=b_v, b_h=b_h)
    elif mode == "pcd":
        v_sampled, h_sampled = repeated_gibbs(
            (v_sampled, h_sampled), chain_length//10, gibbs_update_brbm,
            w_vh=w_vh, b_v=b_v, b_h=b_h)
    else:
        # same start distribution as the tape-based training cell
        if sample_from_marginal:
            v_random = marginal_sampler.sample(tf.shape(batch)[0])
        else:
            v_random = start_sampler.sample([tf.shape(batch)[0], n_v])
        # this is just a dummy: h is resampled from v in the first Gibbs step
        h_random = start_sampler.sample([tf.shape(batch)[0], n_h])
        v_sampled, h_sampled = repeated_gibbs(
            (v_random, h_random), chain_length, gibbs_update_brbm,
            w_vh=w_vh, b_v=b_v, b_h=b_h)

    # next lines are not necessary -- only computing this term to get an
    # idea of how training is going
    logits_pos = tf.reduce_mean(-energy_rbm(v_data, h_data, w_vh, b_v, b_h))
    logits_neg = tf.reduce_mean(
        -energy_rbm(v_sampled, h_sampled, w_vh, b_v, b_h))
    loss = -(logits_pos - logits_neg)

    # Gradients of the (negative!) energy: data statistics minus model-sample statistics.
    # Means are taken per term, so the data batch and the sample batch may have different
    # sizes (e.g. PCD chains of batch_size rows vs. a smaller final data batch).
    n_data = tf.cast(tf.shape(v_data)[0], tf.float32)
    n_model = tf.cast(tf.shape(v_sampled)[0], tf.float32)
    b_v_grads = tf.reduce_mean(v_data, axis=0) - tf.reduce_mean(v_sampled, axis=0)
    b_h_grads = tf.reduce_mean(h_data, axis=0) - tf.reduce_mean(h_sampled, axis=0)
    # The batch mean of the outer products v_n h_n^T equals V^T H / n. Computing it as a
    # single matmul avoids materializing b x n_v x n_h tensors (~1.6 GB each for
    # 1024 x 784 x 512 float32), which is what made the elementwise version slow.
    w_grads = (tf.matmul(v_data, h_data, transpose_a=True) / n_data
               - tf.matmul(v_sampled, h_sampled, transpose_a=True) / n_model)

    # gradient *ascent* on the log-likelihood
    b_v.assign_add(lr * b_v_grads)
    b_h.assign_add(lr * b_h_grads)
    w_vh.assign_add(lr * w_grads)

    return loss, v_sampled, h_sampled


# PCD chain state (ignored for "cd"/"simple"), initialized like in the tape-based loop
if sample_from_marginal:
    v_samp = marginal_sampler.sample(batch_size)
else:
    v_samp = start_sampler.sample([batch_size, n_v])
h_samp = start_sampler.sample([batch_size, n_h])
for step, img_batch in enumerate(data):
    if step > train_steps:
        break

    # note we compute the learning rate decay by hand and pass it to the train function:
    # linear decay from 0.1 to 1e-3 over train_steps, the same formula as the
    # PolynomialDecay(0.1, train_steps, 1e-3) schedule used with the optimizer
    lr = (0.1 - 1e-3) * (1-step/train_steps) + 1e-3

    loss, v_samp, h_samp = train_manual(img_batch, tf.convert_to_tensor(lr), v_samp, h_samp)
    if not step % 50:
        print("Step", step)
        print("Loss:", loss)

In [None]:
# bonus #2: pseudolikelihood? this was optional reading.
# another training-replacement-cell.

@tf.function(jit_compile=True)
def train(batch):
    """One pseudolikelihood training step (replaces the contrastive train function).

    Parameters:
        batch: Batch of binarized training images, b x n_v.

    Returns:
        Negative log-pseudolikelihood of (batch, sampled h), averaged over the batch.
    """
    # get a set of hidden values
    h_sample = tf.nn.sigmoid(tf.matmul(batch, w_vh) + b_h)
    h_sample = tfd.Bernoulli(probs=h_sample, dtype=tf.float32).sample()
    with tf.GradientTape() as tape:
        # in pseudolikelihood we should compute the probability of each variable, conditioned
        # on all the other variables. the bipartite RBM structure helps once again:
        # we can compute all v conditionals and all h conditionals at once, respectively
        # -> only 2 steps instead of n_hidden + n_visible steps
        # For binary x: log p(x_i | rest) = log sigmoid((2 x_i - 1) * activation_i).
        # tf.math.log_sigmoid is used instead of log(sigmoid(.)): in float32 the latter
        # underflows to log(0) = -inf (and NaN gradients) for large negative activations.
        v_log_probs = tf.math.log_sigmoid(
            (2*batch - 1) * (tf.matmul(h_sample, w_vh, transpose_b=True) + b_v))
        h_log_probs = tf.math.log_sigmoid(
            (2*h_sample - 1) * (tf.matmul(batch, w_vh) + b_h))

        pll = -tf.reduce_mean(tf.reduce_sum(v_log_probs, axis=-1)
                              + tf.reduce_sum(h_log_probs, axis=-1))
    grads = tape.gradient(pll, weights)
    optimizer.apply_gradients(zip(grads, weights))

    return pll

# you probably want to change the decay schedule of the optimizer to work over more steps!!
# pseudolikelihood seems to take many steps to converge. maybe use 10 times as many.
# but because we don't run markov chains in training, each step is MUCH faster
for step, img_batch in enumerate(data):
    if step > train_steps:
        break

    loss = train(img_batch)
    if step % 5000:
        continue

    # for sampling, we still use a markov chain.
    # at least I don't know any other way xd
    print("Step", step)
    print("Loss:", loss)
    # uniform (p=0.5) start; marginal_sampler.sample([64]) would be the alternative
    v_random = start_sampler.sample([64, n_v])
    h_random = start_sampler.sample([64, n_h])
    img_sample, final_h = repeated_gibbs(
            (v_random, h_random), chain_length, gibbs_update_brbm,
            w_vh=w_vh, b_v=b_v, b_h=b_h)

    # "binary samples" are the true binary states; "probabilities" are p(v = 1 | h) for
    # the final hidden state -- much smoother, less noisy due to lack of randomness
    imgs_p = tf.nn.sigmoid(tf.matmul(final_h, tf.transpose(w_vh)) + b_v)
    for title, panel in (("binary samples", img_sample), ("probabilities", imgs_p)):
        f = plt.figure(figsize=(15, 15))
        f.suptitle(title)
        for ind, img in enumerate(panel):
            plt.subplot(8, 8, ind + 1)
            plt.imshow(img.numpy().reshape((28, 28)), cmap="Greys", vmin=0, vmax=1)
            plt.axis("off")
        plt.show()

# once this is trained, you "should" be able to draw samples via gibbs sampling just like
# before (cells above). but for me it didn't work well :( I only got 0s as samples...