# In [None]:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Dense, Flatten, Reshape, Layer
from tensorflow.keras.models import Model

# Fetch CIFAR-10; the labels are discarded because training is unsupervised.
(x_train, _), (x_test, _) = cifar10.load_data()

# Scale pixel values into [0, 1] as float32 and pin the layout to
# (num_images, 32, 32, 3). The reshape is presumably a no-op for the standard
# CIFAR-10 arrays and acts only as a layout guard -- TODO confirm.
x_train, x_test = (
    (images.astype(np.float32) / 255.0).reshape(-1, 32, 32, 3)
    for images in (x_train, x_test)
)

class VectorQuantizer(Layer):
    """Snap each input vector to its nearest entry in a learned codebook.

    Args:
        num_embeddings: Number of codebook entries.
        embedding_dim: Length of each codebook vector; the last axis of the
            layer input must have this size.
        commitment_cost: Weight of the commitment term that pulls the inputs
            towards their selected codebook entries.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        # Codebook of shape (num_embeddings, embedding_dim).
        self.embedding = self.add_weight(
            name='embedding',
            shape=(self.num_embeddings, self.embedding_dim),
            initializer='uniform',
            trainable=True,
        )

    def call(self, inputs):
        """Quantize `inputs` and compute the vector-quantization loss.

        Args:
            inputs: Tensor whose last axis has size `embedding_dim`.

        Returns:
            A tuple `(quantized, vq_loss)`: `quantized` has the same shape as
            `inputs` and carries straight-through gradients; `vq_loss` is the
            scalar `codebook_loss + commitment_cost * commitment_loss`.
        """
        flattened_inputs = tf.reshape(inputs, [-1, self.embedding_dim])

        # Squared Euclidean distance to every codebook entry:
        #   ||x - e||^2 = ||x||^2 - 2 * x.e + ||e||^2
        # The ||e||^2 term differs per entry, so it must be included for the
        # argmin to select the true nearest neighbour.
        distances = (
            tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
            - 2 * tf.matmul(flattened_inputs, self.embedding, transpose_b=True)
            + tf.reduce_sum(self.embedding ** 2, axis=1)
        )
        encoding_indices = tf.argmin(distances, axis=1)
        encoding_indices = tf.reshape(encoding_indices, tf.shape(inputs)[:-1])

        # Look up the selected codebook vectors.
        quantized = tf.gather(self.embedding, encoding_indices)

        # Codebook term: moves the selected entries towards the (frozen) inputs.
        codebook_loss = tf.reduce_mean(
            (quantized - tf.stop_gradient(inputs)) ** 2
        )
        # Commitment term: moves the inputs towards the (frozen) entries.
        commitment_loss = tf.reduce_mean(
            (tf.stop_gradient(quantized) - inputs) ** 2
        )
        vq_loss = codebook_loss + self.commitment_cost * commitment_loss

        # Straight-through estimator: the forward pass yields the quantized
        # values while gradients flow back to `inputs` unchanged.
        quantized = inputs + tf.stop_gradient(quantized - inputs)

        return quantized, vq_loss

class Encoder(Model):
    """Convolutional encoder producing one `latent_dim`-sized vector per input.

    Three stride-2 convolutions (32, 64 and 128 filters) are followed by a
    flatten and a linear projection to `latent_dim`.

    Args:
        latent_dim: Size of the output vector.
    """

    def __init__(self, latent_dim=64):
        super().__init__()
        # Settings shared by all three convolutions.
        conv_kwargs = dict(
            kernel_size=(4, 4), strides=2, padding="same", activation="relu"
        )
        self.conv1 = Conv2D(32, **conv_kwargs)
        self.conv2 = Conv2D(64, **conv_kwargs)
        self.conv3 = Conv2D(128, **conv_kwargs)
        self.flatten = Flatten()
        self.dense = Dense(latent_dim)

    def call(self, inputs):
        """Run `inputs` through the conv stack and project to the latent vector."""
        features = inputs
        for conv in (self.conv1, self.conv2, self.conv3):
            features = conv(features)
        return self.dense(self.flatten(features))

class Decoder(Model):
    """Transposed-convolution decoder mapping latent vectors to images.

    A dense layer expands each latent vector into a small feature map, which
    three stride-2 transposed convolutions upsample by a total factor of 8.
    A final stride-1 layer with sigmoid activation produces the image.

    Args:
        output_shape: `(height, width, channels)` of the images to produce.
            Height and width must be positive multiples of 8 so that the 8x
            upsampling lands exactly on the requested size.

    Raises:
        ValueError: If height or width is not a positive multiple of 8.
    """

    def __init__(self, output_shape=(32, 32, 3)):
        super().__init__()
        height, width, channels = output_shape
        if height <= 0 or width <= 0 or height % 8 or width % 8:
            raise ValueError(
                "output_shape height and width must be positive multiples "
                f"of 8, got {output_shape!r}"
            )
        # Start from 1/8 of the target resolution: the three stride-2 layers
        # below each double it (for 32x32: 4 -> 8 -> 16 -> 32).
        base_height, base_width = height // 8, width // 8

        self.dense = Dense(base_height * base_width * 128, activation="relu")
        self.reshape = Reshape((base_height, base_width, 128))
        self.convT1 = Conv2DTranspose(128, (4, 4), strides=2, padding="same", activation="relu")
        self.convT2 = Conv2DTranspose(64, (4, 4), strides=2, padding="same", activation="relu")
        self.convT3 = Conv2DTranspose(32, (4, 4), strides=2, padding="same", activation="relu")
        # Sigmoid keeps the output in [0, 1].
        self.convT4 = Conv2DTranspose(channels, (3, 3), strides=1, padding="same", activation="sigmoid")

    def call(self, inputs):
        """Decode a batch of latent vectors into images of `output_shape`."""
        x = self.dense(inputs)
        x = self.reshape(x)
        x = self.convT1(x)
        x = self.convT2(x)
        x = self.convT3(x)
        x = self.convT4(x)
        return x

class VQVAE(Model):
    """Chain an encoder, a vector quantizer and a decoder into one model.

    Args:
        encoder: Model mapping inputs to latent vectors.
        decoder: Model mapping quantized latents back to input space.
        quantizer: Layer returning `(quantized, loss)` for a latent tensor.
    """

    def __init__(self, encoder, decoder, quantizer):
        super().__init__()
        self.encoder, self.decoder, self.quantizer = encoder, decoder, quantizer

    def call(self, inputs):
        """Return `(reconstruction, quantizer_loss)` for a batch of inputs."""
        latents = self.encoder(inputs)
        codes, quantizer_loss = self.quantizer(latents)
        return self.decoder(codes), quantizer_loss

# Model hyperparameters.
# NOTE(review): the encoder emits `latent_dim`-sized vectors and the quantizer
# reshapes its input to `embedding_dim` columns, so these two presumably need
# to stay equal -- confirm before changing either one.
latent_dim = embedding_dim = 64
num_embeddings = 512
commitment_cost = 0.25

# Training hyperparameters.
batch_size = 64
epochs = 10

# Build the three components, then wire them into the full model.
encoder = Encoder(latent_dim=latent_dim)
decoder = Decoder()
quantizer = VectorQuantizer(
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    commitment_cost=commitment_cost,
)
vqvae = VQVAE(encoder=encoder, decoder=decoder, quantizer=quantizer)

# Adam optimizer and pixel-wise mean-squared-error reconstruction loss.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
mse_loss = tf.keras.losses.MeanSquaredError()

# Training loop
num_samples = len(x_train)
for epoch in range(epochs):
    # Visit the samples in a fresh random order every epoch so consecutive
    # epochs do not see identical batches.
    sample_order = np.random.permutation(num_samples)
    epoch_loss_sum = 0.0

    for batch in range(0, num_samples, batch_size):
        x_batch = x_train[sample_order[batch:batch + batch_size]]

        with tf.GradientTape() as tape:
            reconstructed, commitment_loss = vqvae(x_batch)
            reconstruction_loss = mse_loss(x_batch, reconstructed)
            total_loss = reconstruction_loss + commitment_loss

        grads = tape.gradient(total_loss, vqvae.trainable_variables)
        optimizer.apply_gradients(zip(grads, vqvae.trainable_variables))

        # Weight by batch size so a smaller final batch does not skew the mean.
        epoch_loss_sum += float(total_loss.numpy()) * len(x_batch)

    # Report the mean loss over the whole epoch rather than the last batch's;
    # max(..., 1) keeps an empty dataset from dividing by zero.
    epoch_loss = epoch_loss_sum / max(num_samples, 1)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss}")

# Generate new images after training
def generate_images(model, num_images=10):
    """Decode randomly chosen codebook vectors into images.

    The decoder consumes latent vectors, not image-shaped noise, so this
    samples `num_images` codebook entries uniformly at random and decodes
    those.

    Args:
        model: Model exposing `quantizer.embedding` (the codebook) and
            `decoder`.
        num_images: Number of images to generate.

    Returns:
        The decoder output for the sampled latents, one image per sample.
    """
    codebook = model.quantizer.embedding
    num_codes = int(codebook.shape[0])
    code_indices = np.random.randint(0, num_codes, size=num_images)
    # Shape (num_images, embedding_dim): the latent layout the decoder expects.
    latents = tf.gather(codebook, code_indices)
    return model.decoder(latents)

generated_images = generate_images(vqvae)

# Display only the first generated sample.
first_image = generated_images[0]
plt.imshow(first_image)
plt.show()
