<a href="https://colab.research.google.com/github/Ahtesham519/Genrative_Deep_learning_v2_2023/blob/main/Energy_Based_Models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np

import tensorflow as tf
from tensorflow.keras import (
    datasets,
    layers,
    models,
    optimizers,
    activations,
    metrics,
    callbacks,

)
import random

#0.  Parameters

In [2]:
IMAGE_SIZE = 32
CHANNELS = 1
STEP_SIZE = 10
STEPS = 60
NOISE = 0.005
ALPHA = 0.1
GRADIENT_CLIP = 0.03
BATCH_SIZE = 8192
LEARNING_RATE = 0.0001
EPOCHS = 60
LOAD_MODEL = False


In [3]:
#Load the data
(x_train, _), (x_test, _) = datasets.mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [4]:
#Preprocess the data

def preprocess(imgs):
  """
  Normalize and reshape the images
  """
  imgs = (imgs.astype("float32") - 127.5)/ 127.5
  imgs = np.pad(imgs, ((0,0) ,( 2,2) , (2,2)) , constant_values = -1.0)
  imgs = np.expand_dims(imgs, -1)
  return imgs

  x_train = preprocess(x_train)
  x_test = preprocess(x_test)



In [5]:
x_train = tf.data.Dataset.from_tensor_slices(x_train).batch(BATCH_SIZE)
x_test = tf.data.Dataset.from_tensor_slices(x_test).batch(BATCH_SIZE)

In [8]:
#Show some items of the clothing from the training set
x_train

<_BatchDataset element_spec=TensorSpec(shape=(None, 28, 28), dtype=tf.uint8, name=None)>

#2. Build the EBM network

In [9]:
ebm_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE , CHANNELS))
x = layers.Conv2D(
    16,kernel_size = 5, strides=2, padding= "same", activation= activations.swish
)(ebm_input)
x = layers.Conv2D(
    32, kernel_size = 3, strides= 2 , padding = "same", activation = activations.swish
)(x)
x = layers.Conv2D(
    64, kernel_size = 3, strides=2, padding = "same" , activation = activations.swish
)(x)
x = layers.Conv2D(
    64, kernel_size = 3, strides = 2, padding = "same", activation = activations.swish
)(x)
x = layers.Flatten()(x)
x = layers.Dense(64, activation = activations.swish)(x)
ebm_output = layers.Dense(1)(x)
model = models.Model(ebm_input,ebm_output)
model.summary()


Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 16, 16, 16)        416       
                                                                 
 conv2d_1 (Conv2D)           (None, 8, 8, 32)          4640      
                                                                 
 conv2d_2 (Conv2D)           (None, 4, 4, 64)          18496     
                                                                 
 conv2d_3 (Conv2D)           (None, 2, 2, 64)          36928     
                                                                 
 flatten (Flatten)           (None, 256)               0         
                                                                 
 dense (Dense)               (None, 64)                16448 

In [10]:
if LOAD_MODEL:
  model.load_weights("./models/model.h5")

#2. Set up a Langevin sampler function

In [11]:
#Function to generate samples using Langevin Dynamics
def generate_samples(
    model, inps_imgs, steps, step_size, noise, return_img_per_step = False
):
    imgs_per_step = []
    for _ in range(steps):
      inp_imgs += tf.random.normal(inp_imgs.shape, mean = 0 , stddev = noise)
      inp_imgs = tf.clip_by_value(inp_imgs, -1.0 , 1.0)
      with tf.GradientTape() as tape:
        tape.watch(inp_imgs)
        out_score = model(inp_imgs)
      grads = tape.gradient(out_score, inp_imgs)
      grads = tf.clip_by_value(grads, -GRADIENT_CLIP , GRADIENT_CLIP)
      inp_imgs += step_size * grads
      inp_imgs = tf.clip_by_value(inp_imgs, -1.0 , 1.0)
      if return_img_per_step:
        imgs_per_step.append(inp_imgs)
    if return_img_per_step:
      return tf.stack(imgs_per_step, axis = 0)
    else:
      return inp_imgs