In [1]:
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np


In [2]:
class PixelRNN(tf.keras.Model):
  def __init__(self, hidden_dim=128, num_pixels=256):
    super(PixelRNN, self).__init__()
    # Define the hidden dimension and number of pixels
    self.hidden_dim = hidden_dim
    self.num_pixels = num_pixels
    # Define the LSTM layer with the specified hidden dimension
    self.lstm = layers.LSTM(hidden_dim, return_sequences=True)
    # Define the Dense layer with softmax activation function
    self.dense = layers.Dense(num_pixels)
  def call(self, x):
    # Pass the input through the LSTM layer
    x = self.lstm(x)
    # Pass the output of LSTM layer through the Dense layer
    return self.dense(x)

In [3]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# Normalize the pixel values to be between 0 and 1
x_train, x_test = x_train / 255.0, x_test / 255.0


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [4]:
x_train = x_train.reshape((-1, 3072, 1))
x_test = x_test.reshape((-1, 3072, 1))

In [5]:
x_train_inputs = x_train[:, :-1]
x_train_targets = x_train[:, 1:]
x_test_inputs = x_test[:, :-1]
x_test_targets = x_test[:, 1:]


In [6]:
model = PixelRNN()


In [7]:
model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])

In [8]:
model.fit(x_train_inputs, x_train_targets, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.src.callbacks.History at 0x79d58ed55600>

In [9]:
model.evaluate(x_test_inputs, x_test_targets)



[0.021195199340581894, 0.008212113752961159]

In [10]:
predictions = model.predict(x_test_inputs[:10])



In [23]:
predictions.shape

TensorShape([10, 786176])

The output shape (10, 786176) means that we are getting a sequence of 2 timesteps, where each timestep has a height of 32 pixels, a width of 32 pixels, and 2 feature maps (or "channels"), and each feature map has 256 values.

Specifically, the shape of the output is (batch_size, timesteps, height, width, channels), where batch_size is the number of samples in the batch (in our case, 10), timesteps is the number of timesteps in the sequence (in our case, 2), height is the height of each timestep (in our case, 32 pixels), width is the width of each timestep (in our case, 32 pixels), and channels is the number of feature maps (in our case, 2).

Therefore, the total number of elements in the output array is timesteps * height * width * channels, which is equal to 2 * 32 * 32 * 2 * 256 = 786176 in our case.

##But there should be 3 channels !!!
So there's some error in model architecture.

In [24]:
class PixelRNN(tf.keras.Model):
  def __init__(self, hidden_dim=128, num_pixels=256):
    super(PixelRNN, self).__init__()
    # Define the hidden dimension and number of pixels
    self.hidden_dim = hidden_dim
    self.num_pixels = num_pixels
    # Define the LSTM layer with the specified hidden dimension
    self.lstm = layers.LSTM(hidden_dim, return_sequences=True)
    # Define the Dense layer with softmax activation function
    self.dense = layers.Dense(num_pixels * 3)

  def call(self, x):
    # Pass the input through the LSTM layer
    x = self.lstm(x)
    # Pass the output of LSTM layer through the Dense layer
    return self.dense(x)

In [25]:
model = PixelRNN()

In [26]:
model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])

In [27]:
model.fit(x_train_inputs, x_train_targets, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.src.callbacks.History at 0x79d5805731f0>

In [28]:
model.evaluate(x_test_inputs, x_test_targets)



[0.02342275157570839, 0.00800029281526804]

In [29]:
predictions = model.predict(x_test_inputs[:10])



In [31]:
predicted_images = predictions.reshape((-1, 32, 32, 3))

ValueError: cannot reshape array of size 23585280 into shape (32,32,2)

The total number of elements in the output array is timesteps * height * width * channels, which is equal to 3 * 32 * 32 * 3 * 256 = 786176 in our case.

In this case idk why it is taking timestamp = 3.