VAE

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os

# Dependency imports
from absl import flags
import numpy as np
from six.moves import urllib
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

# Short alias for the TensorFlow Probability distributions module.
tfd = tfp.distributions

# Shape of one image passed to `make_decoder` as `output_shape` (see `model_fn`).
IMAGE_SHAPE = [28, 28, 1]

In [None]:

from scipy.io import loadmat


def load_omniglot(batch_size):
    """Builds train/validation input functions for the Omniglot character data.

    Reads `/content/chardata.mat`, flattens every image to a 784-vector, holds
    out the last 1345 shuffled training rows for validation and binarizes the
    validation rows by Bernoulli sampling. Training rows are left as loaded
    (not binarized). Labels are placeholder zeros.

    Args:
      batch_size: Number of examples per batch produced by both input functions.

    Returns:
      A `(train, eval)` pair of zero-argument callables. Each builds a
      `tf.data.Dataset` of `(image, label)` pairs and returns the `get_next()`
      tensors of a one-shot iterator over it.
    """
    n_validation = 1345

    def reshape_data(data):
        # Go through (N, 28, 28) and flatten in Fortran order, as the original
        # loader did, so the pixel ordering of each row is unchanged.
        return data.reshape((-1, 28, 28)).reshape((-1, 28 * 28), order='F')

    omni_raw = loadmat(os.path.join('/content', 'chardata.mat'))
    train_data = reshape_data(omni_raw['data'].T.astype('float32'))

    # NOTE: this shuffle runs before the seed below is set, so the
    # train/validation split differs between runs (kept as in the original).
    np.random.shuffle(train_data)

    x_train = train_data[:-n_validation]
    x_val = train_data[-n_validation:]

    # Fixed binarization of the validation set.
    np.random.seed(777)
    x_val = np.random.binomial(1, x_val)

    # Placeholder labels; the model ignores them.
    y_train = np.zeros((x_train.shape[0], 1))
    y_val = np.zeros((x_val.shape[0], 1))

    def train():
        dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
        # Dataset transformations return a NEW dataset; the result must be
        # kept, otherwise the iterator yields single unshuffled examples.
        dataset = dataset.shuffle(50000).repeat().batch(batch_size)
        return tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()

    def evaluate():
        dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
        dataset = dataset.batch(batch_size)
        return tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()

    return train, evaluate

In [None]:
def _softplus_inverse(x):
  """Returns `log(exp(x) - 1)`, the functional inverse of `tf.nn.softplus`."""
  exp_minus_one = tf.math.expm1(x)
  return tf.math.log(exp_minus_one)


def make_encoder(activation, latent_size, base_depth):
  """Creates the encoder function.

  The encoder is a fully connected network: the input is flattened, passed
  through one linear 300-unit layer and two 300-unit SiLU layers, and projected
  to `2 * latent_size` outputs that parameterize a diagonal Gaussian.

  Args:
    activation: Unused. Kept for interface compatibility; the hidden layers
      use `tf.nn.silu`.
    latent_size: The dimensionality of the encoding.
    base_depth: Unused. Kept for interface compatibility.

  Returns:
    encoder: A `callable` mapping a `Tensor` of images to a
      `tfd.Distribution` instance over encodings.
  """
  del activation, base_depth  # Not used by the dense architecture.

  # No `input_dim` on the first Dense layer: it follows `Flatten`, and Keras
  # only uses an input-shape argument on the first layer of a Sequential model.
  encoder_net = tf.keras.Sequential([
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(300),
      tf.keras.layers.Dense(300, activation=tf.nn.silu),
      tf.keras.layers.Dense(300, activation=tf.nn.silu),
      tf.keras.layers.Dense(2 * latent_size, activation=None),
  ])

  def encoder(images):
    # Affinely rescale the input: 0 -> -1 and 1 -> +1.
    images = 2 * tf.cast(images, dtype=tf.float32) - 1
    net = encoder_net(images)
    # First half of the outputs is the mean; second half is the pre-softplus
    # scale, offset so that a raw output of 0 gives a scale of exactly 1.
    return tfd.MultivariateNormalDiag(
        loc=net[..., :latent_size],
        scale_diag=tf.nn.softplus(net[..., latent_size:] +
                                  _softplus_inverse(1.0)),
        name="code")

  return encoder


def make_decoder(activation, latent_size, output_shape, base_depth):
  """Creates the decoder function.

  The decoder is a fully connected network: one linear 300-unit layer, two
  300-unit SiLU layers and a linear layer producing `28 * 28` Bernoulli logits.

  Args:
    activation: Unused. Kept for interface compatibility; the hidden layers
      use `tf.nn.silu`.
    latent_size: Dimensionality of the encoding.
    output_shape: The output image shape. Its total size must equal the
      `28 * 28` logits produced by the network.
    base_depth: Unused. Kept for interface compatibility.

  Returns:
    decoder: A `callable` mapping a `Tensor` of encodings to a
      `tfd.Distribution` instance over images.
  """
  del activation, base_depth  # Not used by the dense architecture.

  decoder_net = tf.keras.Sequential([
      tf.keras.layers.Dense(300, input_dim=latent_size),
      tf.keras.layers.Dense(300, activation=tf.nn.silu),
      tf.keras.layers.Dense(300, activation=tf.nn.silu),
      tf.keras.layers.Dense(28 * 28, activation=None),
  ])

  def decoder(codes):
    original_shape = tf.shape(input=codes)
    # Collapse all leading (sample and batch) dimensions into one so the input
    # is rank 2, matching the `input_dim=latent_size` declared on the first
    # Dense layer. The earlier rank-4 reshape was only needed by the
    # convolutional decoder this network replaced.
    codes = tf.reshape(codes, (-1, latent_size))
    logits = decoder_net(codes)
    # Restore the leading dimensions and give each example the image shape.
    logits = tf.reshape(
        logits, shape=tf.concat([original_shape[:-1], output_shape], axis=0))
    return tfd.Independent(tfd.Bernoulli(logits=logits),
                           reinterpreted_batch_ndims=len(output_shape),
                           name="image")

  return decoder

In [None]:
def make_mixture_prior(latent_size, mixture_components):
  """Creates the mixture of Gaussians prior distribution.

  Args:
    latent_size: The dimensionality of the latent representation.
    mixture_components: Number of elements of the mixture.

  Returns:
    random_prior: A `tfd.Distribution` instance representing the distribution
      over encodings in the absence of any evidence.
  """
  if mixture_components == 1:
    # Single component: a fixed standard normal with no learnable parameters.
    # `scale_diag` of ones replaces the deprecated
    # `scale_identity_multiplier=1.0`; the distribution is the same.
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros([latent_size]),
        scale_diag=tf.ones([latent_size]))

  # Learnable parameters: one mean vector, one raw scale vector and one
  # mixing logit per component.
  loc = tf.compat.v1.get_variable(
      name="loc", shape=[mixture_components, latent_size])
  raw_scale_diag = tf.compat.v1.get_variable(
      name="raw_scale_diag", shape=[mixture_components, latent_size])
  mixture_logits = tf.compat.v1.get_variable(
      name="mixture_logits", shape=[mixture_components])

  return tfd.MixtureSameFamily(
      components_distribution=tfd.MultivariateNormalDiag(
          loc=loc,
          # softplus keeps every scale strictly positive.
          scale_diag=tf.nn.softplus(raw_scale_diag)),
      mixture_distribution=tfd.Categorical(logits=mixture_logits),
      name="prior")


def pack_images(images, rows, cols):
  """Tiles a stack of images into a single grid image of at most rows x cols."""
  dims = tf.shape(input=images)
  tile_a, tile_b, channels = dims[-3], dims[-2], dims[-1]

  # Merge every leading dimension into one image axis.
  flat = tf.reshape(images, (-1, tile_a, tile_b, channels))
  count = tf.shape(input=flat)[0]

  # Shrink the grid if there are fewer images than requested cells.
  n_rows = tf.minimum(rows, count)
  n_cols = tf.minimum(count // n_rows, cols)

  grid = tf.reshape(flat[:n_rows * n_cols],
                    (n_rows, n_cols, tile_a, tile_b, channels))
  # Interleave grid and tile axes so each (grid, tile) pair can be merged.
  grid = tf.transpose(a=grid, perm=[0, 2, 1, 3, 4])
  return tf.reshape(grid, [1, n_rows * tile_a, n_cols * tile_b, channels])


def image_tile_summary(name, tensor, rows=8, cols=8):
  """Writes `tensor` as one tiled image summary called `name`."""
  tiled = pack_images(tensor, rows, cols)
  tf.compat.v1.summary.image(name, tiled, max_outputs=1)


def model_fn(features, labels, mode, params, config):
  """Builds the model function for use in an estimator.

  Builds the encoder, decoder and latent prior, computes the ELBO and its
  importance-weighted variant, registers image and scalar summaries, and
  minimizes the negative ELBO with Adam under a cosine-decayed learning rate.

  Args:
    features: The input features for the estimator.
    labels: The labels, unused here.
    mode: Signifies whether it is train or test or predict.
    params: Some hyperparameters as a dictionary. Keys read here:
      "activation", "latent_size", "base_depth", "mixture_components",
      "n_samples", "learning_rate" and "max_steps".
    config: The RunConfig, unused here.

  Returns:
    EstimatorSpec: A tf.estimator.EstimatorSpec instance.
  """
  del labels, config

  # Build the networks and the prior from the hyperparameters.
  encoder = make_encoder(params["activation"],
                         params["latent_size"],
                         params["base_depth"])
  decoder = make_decoder(params["activation"],
                         params["latent_size"],
                         IMAGE_SHAPE,
                         params["base_depth"])
  latent_prior = make_mixture_prior(params["latent_size"],
                                    params["mixture_components"])

  image_tile_summary(
      "input", tf.cast(features, dtype=tf.float32), rows=1, cols=16)

  # Encode the inputs, draw `n_samples` codes from the approximate posterior
  # and decode every sampled code.
  approx_posterior = encoder(features)
  approx_posterior_sample = approx_posterior.sample(params["n_samples"])
  decoder_likelihood = decoder(approx_posterior_sample)
  image_tile_summary(
      "recon/sample",
      tf.cast(decoder_likelihood.sample()[:3, :16], dtype=tf.float32),
      rows=3,
      cols=16)
  image_tile_summary(
      "recon/mean",
      decoder_likelihood.mean()[:3, :16],
      rows=3,
      cols=16)

  # `distortion` is just the negative log likelihood.
  distortion = -decoder_likelihood.log_prob(features)
  avg_distortion = tf.reduce_mean(input_tensor=distortion)
  tf.compat.v1.summary.scalar("distortion", avg_distortion)
  # NOTE(review): `-avg_distortion` is a tensor here, so this presumably prints
  # its symbolic description once at graph construction rather than a numeric
  # value per step -- confirm whether that is intended.
  print("LOGLIKELIHOOD : ", -avg_distortion)

  # `rate` is log q(z|x) - log p(z) evaluated at the sampled codes.
  rate = (approx_posterior.log_prob(approx_posterior_sample)
            - latent_prior.log_prob(approx_posterior_sample))
  
  avg_rate = tf.reduce_mean(input_tensor=rate)
  tf.compat.v1.summary.scalar("rate", avg_rate)

  # Per-sample, per-example ELBO term.
  elbo_local = -(rate + distortion)

  elbo = tf.reduce_mean(input_tensor=elbo_local)
  loss = -elbo
  tf.compat.v1.summary.scalar("elbo", elbo)

  # Importance-weighted bound: log-mean-exp of `elbo_local` over axis 0 (the
  # axis added by `sample(n_samples)`), then averaged over remaining axes.
  importance_weighted_elbo = tf.reduce_mean(
      input_tensor=tf.reduce_logsumexp(input_tensor=elbo_local, axis=0) -
      tf.math.log(tf.cast(params["n_samples"], dtype=tf.float32)))
  tf.compat.v1.summary.scalar("elbo/importance_weighted",
                              importance_weighted_elbo)

  # Decode samples from the prior for visualization.
  random_image = decoder(latent_prior.sample(16))
  image_tile_summary(
      "random/sample",
      tf.cast(random_image.sample(), dtype=tf.float32),
      rows=4,
      cols=4)
  image_tile_summary("random/mean", random_image.mean(), rows=4, cols=4)

  # Perform variational inference by minimizing the -ELBO.
  global_step = tf.compat.v1.train.get_or_create_global_step()
  # The learning rate is cosine-decayed over `params["max_steps"]` steps.
  learning_rate = tf.compat.v1.train.cosine_decay(
      params["learning_rate"], global_step, params["max_steps"])
  tf.compat.v1.summary.scalar("learning_rate", learning_rate)
  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)
  # NOTE(review): the train op is created for every `mode`, not only training.
  train_op = optimizer.minimize(loss, global_step=global_step)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      # Each metric is the running mean of the per-batch scalar.
      eval_metric_ops={
          "elbo":
              tf.compat.v1.metrics.mean(elbo),
          "elbo/importance_weighted":
              tf.compat.v1.metrics.mean(importance_weighted_elbo),
          "rate":
              tf.compat.v1.metrics.mean(avg_rate),
          "distortion":
              tf.compat.v1.metrics.mean(avg_distortion),
      },
  )

In [None]:
# Base URL that `download` fetches dataset files from.
ROOT_PATH = "http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/"
# File name pattern; `{split}` is filled in by `static_mnist_dataset`.
FILE_TEMPLATE = "binarized_mnist_{split}.amat"

# Alternative data source, currently disabled.
# ROOT_PATH = "https://people.cs.umass.edu/~marlin/data/"
# FILE_TEMPLATE = "caltech101_silhouettes_28_split1.mat"


def download(directory, filename):
  """Downloads `filename` from `ROOT_PATH` into `directory` if not yet present.

  Args:
    directory: Local directory to store the file in; created if missing.
    filename: Name of the file, relative to `ROOT_PATH`.

  Returns:
    The local path of the (possibly already existing) file.
  """
  filepath = os.path.join(directory, filename)
  if tf.io.gfile.exists(filepath):
    return filepath
  if not tf.io.gfile.exists(directory):
    tf.io.gfile.makedirs(directory)
  # URLs always use "/" as the separator. `os.path.join` uses the OS path
  # separator instead, which is not "/" on every platform.
  url = ROOT_PATH.rstrip("/") + "/" + filename
  print("Downloading %s to %s" % (url, filepath))
  urllib.request.urlretrieve(url, filepath)
  return filepath


def static_mnist_dataset(directory, split_name):
  """Returns the static binarized MNIST split as a `tf.data.Dataset`.

  Every element is an `(image, label)` pair: a float32 image of shape
  `[28, 28, 1]` and a constant int32 label of 0.
  """
  local_path = download(directory, FILE_TEMPLATE.format(split=split_name))

  def _line_to_bools(line):
    # One whitespace-separated token per pixel; b"1" marks an "on" pixel.
    return np.array([token == b"1" for token in line.split()])

  def _parse_line(line):
    flat_pixels = tf.compat.v1.py_func(_line_to_bools, [line], tf.bool)
    image = tf.reshape(flat_pixels, [28, 28, 1])
    return tf.cast(image, dtype=tf.float32), tf.constant(0, tf.int32)

  return tf.data.TextLineDataset(local_path).map(_parse_line)



def build_input_fns(data_dir, batch_size):
  """Returns `(train_input_fn, eval_input_fn)` for the static MNIST data."""

  def _next_element(dataset):
    return tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()

  def train_input_fn():
    # Shuffled, endlessly repeating batches of the "train" split.
    train_batches = (
        static_mnist_dataset(data_dir, "train")
        .shuffle(50000)
        .repeat()
        .batch(batch_size))
    return _next_element(train_batches)

  def eval_input_fn():
    # A single in-order pass over the "valid" split.
    heldout_batches = static_mnist_dataset(data_dir, "valid").batch(batch_size)
    return _next_element(heldout_batches)

  return train_input_fn, eval_input_fn


def main(argv):
  """Trains the VAE on static binarized MNIST, then evaluates it once.

  Checkpoints and summaries are written to `/content/Results`; the data is
  cached under `$TEST_TMPDIR/vae/data` (or `/tmp/vae/data`).

  Args:
    argv: Command-line arguments passed by the app runner; unused.
  """
  del argv  # unused

  params = {
      "learning_rate": 0.001,
      "max_steps": 500,
      "latent_size": 16,
      "base_depth": 32,
      "activation": tf.nn.leaky_relu,
      "batch_size": 32,
      "n_samples": 16,
      "mixture_components": 10,  # standard
      "viz_steps": 4000,
  }

  data_dir = os.path.join(os.getenv("TEST_TMPDIR", "/tmp"), "vae/data")
  train_input_fn, eval_input_fn = build_input_fns(
      data_dir, params["batch_size"])

  # Build the estimator once. The earlier version first built a second
  # Estimator with `config=None` that was overwritten before use.
  estimator = tf.estimator.Estimator(
      model_fn,
      params=params,
      config=tf.estimator.RunConfig(
          model_dir="/content/Results",
          save_checkpoints_steps=params["viz_steps"],
      )
  )

  # `model_fn` decays the learning rate over `params["max_steps"]` steps, so
  # train for that many. Training for `viz_steps` (4000) ran far past the end
  # of the 500-step schedule; per the `cosine_decay` documentation the rate
  # stays at its final value (`alpha`, default 0) after `decay_steps`.
  estimator.train(train_input_fn, steps=params["max_steps"])
  print("##################### TRAINING FINISHED ############################")
  eval_results = estimator.evaluate(eval_input_fn)
  print("Evaluation_results:\n\t%s\n " % eval_results)

# Script entry point: hand control to the TensorFlow app runner.
# NOTE(review): presumably this parses command-line flags and then calls
# `main`, exiting the interpreter afterwards -- confirm this is the desired
# behavior when the cell is run inside a notebook.
if __name__ == "__main__":
  tf.compat.v1.app.run()

In [None]:
# Notebook cell: open TensorBoard on the directory the estimator writes to.
logdir = "/content/Results/"
# NOTE(review): `tensorboard_callback` is not used by any code visible here.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
# The two lines below are IPython magics and only work inside a notebook.
%load_ext tensorboard
%tensorboard --logdir /content/Results/