<a href="https://colab.research.google.com/github/SamarthAdat/PixelCNN/blob/main/PixelCNN_for_Single_Image.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from PIL import Image
from google.colab import files
from tensorflow.keras import layers


In [None]:
# Load the source image that the PixelCNN will be trained on.
# Replace "img.jpg" with the path to your own image file.
#
# The image is converted to RGB so that `input_shape` is always
# (height, width, 3). Without the conversion a grayscale or palette image
# becomes a 2-D array and the later `input_shape[2]` lookup raises IndexError;
# an RGBA image would silently add a fourth channel.
# The `with` block closes the underlying file as soon as the pixels are copied.
with Image.open("img.jpg") as image_file:
    input_image = np.array(image_file.convert("RGB"))
input_shape = input_image.shape

In [None]:
# Hyperparameters shared by the cells below.
# Number of ResidualBlock layers stacked when the network is assembled.
n_residual_blocks = 5
# NOTE(review): not referenced by any cell visible here; confirm before removing.
num_classes = 10

In [None]:
class PixelConvLayer(layers.Layer):
    """Conv2D wrapper whose kernel is multiplied by a fixed 0/1 mask on every call.

    The mask keeps (value 1.0) every kernel row above the centre row and, on
    the centre row, every column to the left of the centre column. All other
    spatial positions are zeroed. Mask type "B" additionally keeps the centre
    position itself; mask type "A" zeroes it.

    Args:
        mask_type: Either "A" or "B".
        **kwargs: Forwarded unchanged to `layers.Conv2D` (filters,
            kernel_size, activation, padding, ...). They are not passed to the
            base `Layer` constructor.

    Raises:
        ValueError: If `mask_type` is neither "A" or "B".
    """

    def __init__(self, mask_type, **kwargs):
        # Reject typos early: any value other than "B" would otherwise fall
        # through to the "A" mask without warning.
        if mask_type not in ("A", "B"):
            raise ValueError(f'mask_type must be "A" or "B", got {mask_type!r}')
        super().__init__()
        self.mask_type = mask_type
        self.conv = layers.Conv2D(**kwargs)

    def build(self, input_shape):
        """Create the wrapped conv's variables and the mask matching its kernel."""
        # Build the conv2d layer to initialize kernel variables
        self.conv.build(input_shape)
        # Use the initialized kernel to create the mask.
        # `.shape` is used instead of the legacy `get_shape()` accessor.
        kernel_shape = self.conv.kernel.shape
        center_row = kernel_shape[0] // 2
        center_col = kernel_shape[1] // 2
        self.mask = np.zeros(shape=kernel_shape)
        # Rows strictly above the centre row are fully visible.
        self.mask[:center_row, ...] = 1.0
        # On the centre row, only columns strictly left of the centre are visible.
        self.mask[center_row, :center_col, ...] = 1.0
        if self.mask_type == "B":
            # Type "B" also lets the centre position through.
            self.mask[center_row, center_col, ...] = 1.0

    def call(self, inputs):
        """Overwrite the conv kernel with its masked version, then apply the conv."""
        # The masked kernel is written back into the variable itself, so the
        # zeroed positions are re-zeroed on every forward pass.
        self.conv.kernel.assign(self.conv.kernel * self.mask)
        return self.conv(inputs)

In [None]:
class ResidualBlock(keras.layers.Layer):
    """Three-stage convolutional block with a skip connection.

    The stages are a 1x1 ReLU convolution with `filters` outputs, a 3x3
    type-"B" PixelConvLayer with `filters // 2` outputs, and another 1x1 ReLU
    convolution with `filters` outputs. The block input is added to the
    result of the last stage.

    Args:
        filters: Output channel count of the two 1x1 convolutions.
        **kwargs: Forwarded to the base `Layer` constructor.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        # Both 1x1 convolutions share the same configuration.
        pointwise_config = dict(filters=filters, kernel_size=1, activation="relu")
        self.conv1 = keras.layers.Conv2D(**pointwise_config)
        self.pixel_conv = PixelConvLayer(
            mask_type="B",
            filters=filters // 2,
            kernel_size=3,
            activation="relu",
            padding="same",
        )
        self.conv2 = keras.layers.Conv2D(**pointwise_config)

    def call(self, inputs):
        """Run the three stages in order and add the original input back."""
        branch = inputs
        for stage in (self.conv1, self.pixel_conv, self.conv2):
            branch = stage(branch)
        return keras.layers.add([inputs, branch])

In [None]:
# Size (rows, cols) of the image patches the network consumes.
patch_size = (32, 32)
# Number of whole patches that fit along the first and second image axes.
num_patches_x = input_shape[0] // patch_size[0]
num_patches_y = input_shape[1] // patch_size[1]

# Network input: a single patch with as many channels as the source image.
inputs = keras.Input(shape=(*patch_size, input_shape[2]))

# Stem: one 7x7 type-"A" masked convolution.
x = PixelConvLayer(
    mask_type="A",
    filters=128,
    kernel_size=7,
    activation="relu",
    padding="same",
)(inputs)

# Body: a stack of residual blocks.
for _ in range(n_residual_blocks):
    x = ResidualBlock(filters=128)(x)

# Head: two 1x1 type-"B" masked convolutions.
for _ in range(2):
    x = PixelConvLayer(
        mask_type="B",
        filters=128,
        kernel_size=1,
        strides=1,
        activation="relu",
        padding="valid",
    )(x)

# Output: 1x1 sigmoid convolution with one output channel per image channel.
out = keras.layers.Conv2D(
    filters=input_shape[2],
    kernel_size=1,
    strides=1,
    activation="sigmoid",
    padding="valid",
)(x)


IndexError: ignored

In [None]:
# Compile the model
pixel_cnn = keras.Model(inputs, out)
adam = keras.optimizers.Adam(learning_rate=0.0005)
pixel_cnn.compile(optimizer=adam, loss="binary_crossentropy", metrics=["accuracy"])

# Prepare patches from the input image.
# The patch size must match the spatial size of the model's input layer.
patch_size = (32, 32)
num_patches_x = input_shape[0] // patch_size[0]
num_patches_y = input_shape[1] // patch_size[1]

# Fail with a clear message instead of handing an empty array to `fit`.
if num_patches_x == 0 or num_patches_y == 0:
    raise ValueError(
        f"Image of shape {input_shape} is smaller than one {patch_size} patch; "
        "use a larger image or a smaller patch size."
    )

# Non-overlapping tiles in row-major order; leftover pixels that do not fill a
# whole patch along the bottom and right edges are dropped.
patches = [
    input_image[
        i * patch_size[0] : (i + 1) * patch_size[0],
        j * patch_size[1] : (j + 1) * patch_size[1],
    ]
    for i in range(num_patches_x)
    for j in range(num_patches_y)
]

# The network ends in a sigmoid and is trained with binary cross-entropy, so
# both inputs and targets must lie in [0, 1]. Raw 0-255 pixel values would
# put the targets outside the range the loss is defined for.
# Assumes 8-bit pixel data (what PIL produces for a JPEG).
patches = np.asarray(patches, dtype="float32") / 255.0

pixel_cnn.summary()
# The model learns to reproduce each patch, so the patches serve as both
# inputs and targets.
pixel_cnn.fit(
    x=patches, y=patches, batch_size=128, epochs=50, verbose=2
)


Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 pixel_conv_layer_8 (PixelC  (None, 32, 32, 128)       18944     
 onvLayer)                                                       
                                                                 
 residual_block_5 (Residual  (None, 32, 32, 128)       98624     
 Block)                                                          
                                                                 
 residual_block_6 (Residual  (None, 32, 32, 128)       98624     
 Block)                                                          
                                                                 
 residual_block_7 (Residual  (None, 32, 32, 128)       98624     
 Block)                                                          

KeyboardInterrupt: ignored