<a href="https://colab.research.google.com/github/Ahtesham519/Genrative_Deep_learning_v2_2023/blob/main/PixelCNN_md.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np

import tensorflow as tf
from tensorflow.keras import datasets, layers, models, optimizers, callbacks
import tensorflow_probability as tfp
from IPython.display import Image, display


# 0. Parameters

In [None]:
# --- Data / model configuration ---
IMAGE_SIZE = 32      # side length images are resized to (IMAGE_SIZE x IMAGE_SIZE)
N_COMPONENTS = 5     # logistic mixture components per pixel (num_logistic_mix)

# --- Training configuration ---
BATCH_SIZE = 128
EPOCHS = 10

# 1. Prepare the data

In [None]:
# Load the data.
# Only the training images are used in this notebook, so the labels and the
# test split are discarded. Indexing into the result (rather than unpacking
# into `_`) avoids shadowing IPython's `_` last-output variable; once the user
# binds `_`, IPython stops updating `_`/`__`/`___` for the rest of the session.
x_train = datasets.fashion_mnist.load_data()[0][0]

In [None]:
# Preprocess the data


def preprocess(imgs):
    """Add a trailing channel axis and resize to IMAGE_SIZE x IMAGE_SIZE.

    Args:
        imgs: batch of greyscale images, assumed shaped (N, H, W) like the
            raw Fashion-MNIST arrays (TODO confirm for other inputs).

    Returns:
        NumPy array of shape (N, IMAGE_SIZE, IMAGE_SIZE, 1), exactly as
        produced by `tf.image.resize` (float32 for its default bilinear method).
    """
    with_channel = np.expand_dims(imgs, -1)
    resized = tf.image.resize(with_channel, (IMAGE_SIZE, IMAGE_SIZE))
    return resized.numpy()


input_data = preprocess(x_train)

# 2. Build the PixelCNN

In [None]:
# Define a PixelCNN network.
# `tfp.distributions.PixelCNN` (the original `tfp.distributations` was a typo
# and raised AttributeError) is a distribution over whole images, built here
# for single-channel IMAGE_SIZE x IMAGE_SIZE images with an N_COMPONENTS-way
# logistic mixture per pixel.
dist = tfp.distributions.PixelCNN(
    image_shape=(IMAGE_SIZE, IMAGE_SIZE, 1),
    num_resnet=1,
    num_hierarchies=2,
    num_filters=32,
    num_logistic_mix=N_COMPONENTS,
    dropout_p=0.3,
)

# Define the model input: a batch of images with the same shape as
# `image_shape` above.
image_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 1))

# Log-likelihood of each input image under the PixelCNN; the loss is built
# from this.
log_prob = dist.log_prob(image_input)

# Wrap the computation in a Keras model so it can be compiled and trained.
pixelcnn = models.Model(inputs=image_input, outputs=log_prob)
# Maximum-likelihood objective: minimise the mean negative log-likelihood.
pixelcnn.add_loss(-tf.reduce_mean(log_prob))

# 3. Train the PixelCNN

In [None]:
# Compile the model. No `loss=` argument is given: the negative
# log-likelihood is attached to the model via `add_loss` when it is built.
pixelcnn.compile(optimizer=optimizers.Adam(learning_rate=0.001))

In [None]:
tensorboard_callback = callbacks.TensorBoard(log_dir="./logs")


class ImageGenerator(callbacks.Callback):
    """Keras callback that samples images from the PixelCNN after every epoch.

    At the end of each epoch it draws `num_img` samples from the module-level
    `dist`, tiles them side by side into one strip, writes the strip to
    ``<output_dir>/generated_img_<epoch>.png`` and shows it inline.

    Args:
        num_img: number of images to sample per call.
        output_dir: directory the per-epoch PNG files are written to
            (created if missing).
    """

    def __init__(self, num_img, output_dir="./output"):
        # Run the base initialiser so Keras-managed attributes (e.g. `model`)
        # exist before Keras attaches the callback.
        super().__init__()
        self.num_img = num_img
        self.output_dir = output_dir

    def generate(self):
        """Draw `num_img` images from `dist` and return them as a NumPy array."""
        return dist.sample(self.num_img).numpy()

    def on_epoch_end(self, epoch, logs=None):
        """Sample, save and display a strip of generated images for `epoch`."""
        generated_images = self.generate()
        # Tile the (num_img, H, W, C) batch into a single (H, num_img * W, C) strip.
        strip = np.concatenate(list(generated_images), axis=1)
        # PNG encoding needs uint8. Samples are expected to lie in the PixelCNN
        # pixel range [0, 255] (TODO confirm); clip before casting to be safe.
        strip = np.clip(strip, 0, 255).astype(np.uint8)
        png_bytes = tf.io.encode_png(strip).numpy()

        # The original path "./output.generated_img_..." was a typo, and the
        # directory was never created; fix both.
        tf.io.gfile.makedirs(self.output_dir)
        save_to = "%s/generated_img_%03d.png" % (self.output_dir, epoch)
        tf.io.write_file(save_to, png_bytes)

        # IPython's `display` does not accept `n=`/`save_to=`, so the
        # original call raised; render the encoded PNG explicitly instead.
        display(Image(data=png_bytes))


img_generator_callback = ImageGenerator(num_img=2)

# Train the PixelCNN. The loss is attached with `add_loss`, so no targets are
# passed. The callbacks log to TensorBoard and sample images after every epoch.
pixelcnn.fit(
    input_data,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    verbose=True,
    callbacks=[tensorboard_callback, img_generator_callback],
)

In [None]:
# 4. Generate images
# Draw a final batch of samples from the trained PixelCNN, using the same
# `generate` method the ImageGenerator callback calls after each epoch.
generated_images = img_generator_callback.generate()

# Sanity check: one image per requested sample, each matching the
# (IMAGE_SIZE, IMAGE_SIZE, 1) `image_shape` the distribution was built with.
assert generated_images.shape == (
    img_generator_callback.num_img, IMAGE_SIZE, IMAGE_SIZE, 1
), generated_images.shape