# Building a PixelCNN model from scratch

There are **three main categories** of deep neural network **generative** algorithms:
- Generative Adversarial Networks
- Variational Autoencoders
- Autoregressive models

## Autoregressive Models

In machine learning terminology, _regress_ means _predict new values_. An "autoregressive" model therefore predicts each new data point from the data points that came before it.

An assumption we make here is that the value of a pixel depends only on the pixels that come before it in a fixed ordering (row by row, left to right). In other words, pixel $x_i$ is conditioned on all of the previous pixels: $p(x_i \mid x_{i-1}, \dots, x_1)$.
So, by the chain rule of probability, the joint probability is the product of these conditional probabilities:<br>
$p(x)=p(x_n, x_{n-1}, \dots, x_2, x_1)$
<br>
$p(x)=p(x_n|x_{n-1}, \dots, x_1) \cdots p(x_3|x_2, x_1)\,p(x_2|x_1)\,p(x_1) = \prod_{i=1}^{n} p(x_i \mid x_{i-1}, \dots, x_1)$

_In simple words: say we have images of red apples surrounded by green leaves. If the pixel at (0, 0) is green, then the probability that the next pixel, at (0, 1), is also green is high._

**PixelRNN** - a recurrent neural network based on LSTMs. It reads the image one row per LSTM step, processes each row with a 1D convolutional layer, and then feeds the activations into subsequent layers. It is slow and has fallen out of fashion.

**PixelCNN** - made up of convolutional layers, making it a lot faster than PixelRNN.

## PixelCNN

### Masking
When performing a convolution to predict the current pixel, a conventional convolution kernel sees the current input pixel together with its surrounding pixels, including future pixels. This breaks our conditional probability assumption.
To avoid that, we use masking: a mask is applied to the convolutional kernel weights before performing the convolution. The mask zeroes every weight that would look at a future pixel. Mask type A (used in the first layer) also zeroes the weight for the current pixel itself, while mask type B (used in later layers) keeps it. (_See pages 18-19 of the book for more details._)

In [3]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.activations import relu
from tensorflow.keras.models import Sequential

import tensorflow_datasets as tfds

import numpy as np
# `subplots` / `show` live in the pyplot submodule, not in the top-level
# `matplotlib` package.
import matplotlib.pyplot as plt

from tqdm.keras import TqdmCallback

import warnings
# Hide all library warnings to keep the notebook output readable.
warnings.filterwarnings('ignore')

from packaging.version import parse as parse_version

# This notebook was written against older TensorFlow releases; refuse to run
# on anything at or above this version. An explicit raise (rather than an
# `assert`) keeps the check active even under `python -O`.
_TF_VERSION_LIMIT = "2.5.1"
if parse_version(tf.__version__) >= parse_version(_TF_VERSION_LIMIT):
    raise RuntimeError(
        f"Please install a TensorFlow version older than {_TF_VERSION_LIMIT}. "
        f"Your current version is {tf.__version__}.")

### Load MNIST dataset

In [2]:
# Load MNIST as (image, label) pairs: train on the 'train' split and validate
# on the held-out 'test' split (previously both came from 'test').
(ds_train, ds_test), ds_info = tfds.load('mnist',
                                         split=['train', 'test'],
                                         shuffle_files=True,
                                         as_supervised=True,
                                         with_info=True)

def binarize(image, label):
    """Turn an image into a {0, 1} float image plus an int32 copy as target.

    Intensities are divided by 255 and rounded, so (assuming 0-255 input
    pixels) every pixel becomes exactly 0.0 or 1.0. The class label is
    ignored: the autoregressive model learns to predict the pixels
    themselves, so the image doubles as its own training target.
    """
    scaled = tf.cast(image, tf.float32) / 255.
    binary = tf.math.round(scaled)
    target = tf.cast(binary, tf.int32)
    return binary, target

# Training pipeline: binarize -> cache in memory -> shuffle over a buffer as
# large as the training split -> batch -> prefetch, so the next batch is
# prepared while the current training step runs (the test pipeline below
# already prefetches).
ds_train = (ds_train
            .map(binarize)
            .cache()
            .shuffle(ds_info.splits['train'].num_examples)
            .batch(64)
            .prefetch(tf.data.experimental.AUTOTUNE))
# Evaluation pipeline: no shuffling needed.
ds_test = ds_test.map(binarize).batch(64).cache().prefetch(64)

### Create Custom Layers and PixelCNN

In [3]:
class MaskedConv2D(layers.Layer):
    """ReLU-activated 2D convolution with a causal (raster-order) kernel mask.

    Kernel positions are ordered row by row, left to right. Every weight after
    the kernel centre in that order is zeroed, so the output at a pixel cannot
    depend on inputs at later pixels. With mask_type 'A' the centre weight is
    zeroed too, so the output cannot see the current pixel's own input (used
    for the first layer). Any other mask_type keeps the centre weight, which is
    the 'B' behaviour used by later layers.

    Args:
        mask_type: 'A' or 'B' (see above).
        kernel: Side length of the square kernel. Must be odd: with TF's
            "SAME" padding only an odd kernel has its centre aligned with the
            output pixel. An even kernel would shift the mask and leak future
            pixels.
        filters: Number of output channels.
        **kwargs: Forwarded to `layers.Layer` (e.g. `name`).

    Raises:
        ValueError: If `kernel` is not odd.
    """

    def __init__(self, mask_type, kernel=5, filters=1, **kwargs):
        super(MaskedConv2D, self).__init__(**kwargs)
        if kernel % 2 != 1:
            raise ValueError(f"kernel must be odd, got {kernel}")
        self.kernel = kernel
        self.filters = filters
        self.mask_type = mask_type

    def build(self, input_shape):
        # Kernel layout expected by tf.nn.conv2d:
        # (kernel, kernel, in_channels, filters).
        self.w = self.add_weight(shape=[self.kernel,
                                        self.kernel,
                                        input_shape[-1],
                                        self.filters],
                                 initializer='glorot_normal',
                                 trainable=True)

        self.b = self.add_weight(shape=(self.filters,),
                                 initializer='zeros',
                                 trainable=True)

        # Build the mask over the flattened kernel in raster order: keep
        # positions up to and including the centre, zero everything after.
        mask = np.ones(self.kernel**2, dtype=np.float32)
        center = len(mask) // 2
        mask[center + 1:] = 0
        if self.mask_type == 'A':
            # Type 'A' must not see the current pixel itself.
            mask[center] = 0

        # The trailing 1x1 axes broadcast against the kernel's channel axes,
        # so the same spatial mask applies to every (in, out) channel pair.
        mask = mask.reshape((self.kernel, self.kernel, 1, 1))

        self.mask = tf.constant(mask, dtype='float32')

    def call(self, inputs):
        # Re-apply the mask on every call, so masked weights never contribute
        # even though the optimizer still updates the unmasked variable.
        masked_w = tf.math.multiply(self.w, self.mask)

        # Low-level conv: stride 1, "SAME" padding keeps the spatial size.
        output = tf.nn.conv2d(inputs, masked_w, 1, "SAME") + self.b

        return tf.nn.relu(output)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(MaskedConv2D, self).get_config()
        config.update({'mask_type': self.mask_type,
                       'kernel': self.kernel,
                       'filters': self.filters})
        return config
    
class ResidualBlock(layers.Layer):
    """Bottleneck residual block made of type-'B' masked convolutions.

    The residual path is 1x1 (h filters) -> 3x3 (h filters) -> 1x1 (2*h
    filters). Each step is a ReLU-activated MaskedConv2D, and the path's
    output is added back onto the block input. The input should therefore
    carry 2*h channels for the addition to line up.
    """

    def __init__(self, h=32):
        super(ResidualBlock, self).__init__()
        squeeze_in = MaskedConv2D('B', kernel=1, filters=h)
        spatial = MaskedConv2D('B', kernel=3, filters=h)
        expand_out = MaskedConv2D('B', kernel=1, filters=h * 2)
        self.forward = Sequential([squeeze_in, spatial, expand_out])

    def call(self, inputs):
        residual = self.forward(inputs)
        return inputs + residual
        
def SimplePixelCnn(hidden_features=64,
                   output_features=64,
                   resblocks_num=7,
                   input_shape=(28, 28, 1)):
    """Build a small PixelCNN that outputs one sigmoid value per pixel.

    Architecture: a 7x7 type-'A' masked convolution with 2*hidden_features
    filters, then `resblocks_num` ResidualBlocks, then two 1x1 convolutions,
    the last with a sigmoid. 1x1 convolutions only mix channels at the same
    position, so they preserve the causal ordering set up by the masked
    layers.

    Args:
        hidden_features: Bottleneck width h of each ResidualBlock; feature
            maps between blocks have 2*h channels.
        output_features: Channels of the penultimate 1x1 convolution.
        resblocks_num: Number of residual blocks.
        input_shape: (height, width, channels) of the input images. Defaults
            to 28x28 single-channel images. The model always emits a single
            output channel, so it is meant for single-channel inputs.

    Returns:
        An uncompiled `tf.keras.Model` named 'PixelCnn'. Its output keeps the
        input's spatial size ("same" padding throughout) with one channel.
    """
    inputs = layers.Input(shape=list(input_shape))
    x = inputs

    x = MaskedConv2D('A', kernel=7, filters=2*hidden_features)(x)

    for _ in range(resblocks_num):
        x = ResidualBlock(hidden_features)(x)

    x = layers.Conv2D(output_features, (1, 1), padding='same', activation='relu')(x)
    x = layers.Conv2D(1, (1, 1), padding='same', activation='sigmoid')(x)

    return tf.keras.Model(inputs=inputs, outputs=x, name='PixelCnn')

# Instantiate the network with its default hyper-parameters (spelled out
# explicitly here) and print the per-layer output shapes and parameter counts.
pixel_cnn = SimplePixelCnn(hidden_features=64,
                           output_features=64,
                           resblocks_num=7)
pixel_cnn.summary()

Model: "PixelCnn"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
masked_conv2d (MaskedConv2D) (None, 28, 28, 128)       6400      
_________________________________________________________________
residual_block (ResidualBloc (None, 28, 28, 128)       53504     
_________________________________________________________________
residual_block_1 (ResidualBl (None, 28, 28, 128)       53504     
_________________________________________________________________
residual_block_2 (ResidualBl (None, 28, 28, 128)       53504     
_________________________________________________________________
residual_block_3 (ResidualBl (None, 28, 28, 128)       53504     
_________________________________________________________________
residual_block_4 (ResidualBl (None, 28, 28, 128)       535

In [None]:
# Each pixel is modelled as a Bernoulli variable, so the loss is per-pixel
# binary cross-entropy on the sigmoid outputs.
pixel_cnn.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
    # Use the metric class (a running mean across batches), not a loss object.
    metrics=[tf.keras.metrics.BinaryCrossentropy()])

# TqdmCallback draws its own progress bar; verbose=0 turns off Keras'
# built-in bar so progress is not printed twice.
pixel_cnn.fit(ds_train,
              epochs=50,
              validation_data=ds_test,
              verbose=0,
              callbacks=[TqdmCallback(verbose=1)])

In [1]:
# Autoregressive sampling: generate `batch` images pixel by pixel in raster
# order. At each position, run the model on the partially generated images,
# read the sigmoid output at that position as P(pixel == 1), and draw a
# Bernoulli sample. Not-yet-sampled pixels hold placeholder values; the
# causal kernel masks keep them from affecting the current prediction.
grid_row = 5
grid_col = 5
batch = grid_row * grid_col
h, w = pixel_cnn.input_shape[1:3]  # spatial size the model was built for
images = np.ones((batch, h, w, 1), dtype=np.float32)

for row in range(h):
    for col in range(w):
        # Calling the model directly avoids predict()'s per-call overhead,
        # which dominates when it is invoked h*w times on one small batch.
        prob = pixel_cnn(images, training=False).numpy()[:, row, col, 0]
        # Bernoulli(prob): uniform draw below prob -> 1.0, otherwise 0.0.
        images[:, row, col, 0] = np.random.random(batch) < prob

NameError: name 'np' is not defined

In [None]:
# Display the sampled images in a grid_row x grid_col grid, row by row.
import matplotlib.pyplot as plt  # pyplot provides subplots() and show()

f, axarr = plt.subplots(grid_row, grid_col,
                        figsize=(grid_col * 1.1, grid_row),
                        squeeze=False)  # always a 2D array of Axes

for ax, image in zip(axarr.flat, images):
    ax.imshow(image[:, :, 0], cmap='gray')
    ax.axis('off')

# Pass `pad` by keyword: the first positional parameter of tight_layout is
# `renderer` in older Matplotlib releases, and newer releases make the
# parameters keyword-only.
f.tight_layout(pad=0.1, h_pad=0.2, w_pad=0.1)
plt.show()

## Sample and Display Images

In [None]:
# Autoregressive sampling: generate `batch` images pixel by pixel in raster
# order. At each position, run the model on the partially generated images,
# read the sigmoid output at that position as P(pixel == 1), and draw a
# Bernoulli sample. Not-yet-sampled pixels hold placeholder values; the
# causal kernel masks keep them from affecting the current prediction.
grid_row = 5
grid_col = 5
batch = grid_row * grid_col
h, w = pixel_cnn.input_shape[1:3]  # spatial size the model was built for
images = np.ones((batch, h, w, 1), dtype=np.float32)

for row in range(h):
    for col in range(w):
        # Calling the model directly avoids predict()'s per-call overhead,
        # which dominates when it is invoked h*w times on one small batch.
        prob = pixel_cnn(images, training=False).numpy()[:, row, col, 0]
        # Bernoulli(prob): uniform draw below prob -> 1.0, otherwise 0.0.
        images[:, row, col, 0] = np.random.random(batch) < prob

In [None]:
# Display the sampled images in a grid_row x grid_col grid, row by row.
import matplotlib.pyplot as plt  # pyplot provides subplots() and show()

f, axarr = plt.subplots(grid_row, grid_col,
                        figsize=(grid_col * 1.1, grid_row),
                        squeeze=False)  # always a 2D array of Axes

for ax, image in zip(axarr.flat, images):
    ax.imshow(image[:, :, 0], cmap='gray')
    ax.axis('off')

# Pass `pad` by keyword: the first positional parameter of tight_layout is
# `renderer` in older Matplotlib releases, and newer releases make the
# parameters keyword-only.
f.tight_layout(pad=0.1, h_pad=0.2, w_pad=0.1)
plt.show()