In [8]:
import tensorflow as tf
from tensorflow.keras import layers

# 1. Dataset preparation (CIFAR-10 as an example)
def load_data():
    (train_images, _), (_, _) = tf.keras.datasets.cifar10.load_data()  # Load the CIFAR-10 dataset
    train_images = train_images.astype("float32") / 255.0  # Normalize images to [0, 1]
    return train_images

train_images = load_data()

# 2. Vision Transformer (ViT) model
class ViTImageGenerator(tf.keras.Model):
    def __init__(self, img_size=32, patch_size=4, embed_dim=64, num_heads=4, num_layers=6):
        super(ViTImageGenerator, self).__init__()

        # Parameters for the patch embedding
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2  # Total number of patches

        # Dense layer to embed the patches
        self.patch_embed = layers.Dense(embed_dim)

        # Positional embedding to retain spatial information
        self.position_embed = tf.Variable(tf.random.normal([num_patches, embed_dim]))

        # Transformer layers for processing the patches
        self.encoder_layers = [
            layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim) for _ in range(num_layers)
        ]
        self.norm_layers = [layers.LayerNormalization() for _ in range(num_layers)]

        # Output layer to generate the image (scaled to [0,1] using sigmoid)
        self.fc = layers.Dense(img_size * img_size * 3, activation='sigmoid')

    def call(self, x):
        # Step 1: Extract patches from the input image
        patches = tf.image.extract_patches(
            images=x,
            sizes=[1, self.patch_size, self.patch_size, 1],  # Patch size
            strides=[1, self.patch_size, self.patch_size, 1],  # Move by patch size
            rates=[1, 1, 1, 1],
            padding='VALID'
        )

        # Get the batch size dynamically using tf.shape
        batch_size = tf.shape(x)[0]

        # Reshape patches to [batch, patches, channels], using batch_size
        patches = tf.reshape(patches, (batch_size, -1, patches.shape[-1]))

        # Step 2: Embed patches and add positional information
        x = self.patch_embed(patches)
        x += self.position_embed  # Adding position embedding

        # Step 3: Pass the embedded patches through transformer layers
        for enc_layer, norm_layer in zip(self.encoder_layers, self.norm_layers):
            x = norm_layer(x + enc_layer(x, x))  # Apply attention + residual connection

        # Step 4: Reconstruct the image from the processed patches
        x = self.fc(x)
        return tf.reshape(x, (-1, 32, 32, 3))  # Reshape the output back to image dimensions

# 3. Model training
def train_vit():
    # Instantiate the ViT model
    vit_model = ViTImageGenerator()
    vit_model.compile(optimizer='adam', loss='mse')  # Compile the model with Adam optimizer and MSE loss

    # Train the model (autoencoder style) to reconstruct the input image
    vit_model.fit(train_images, train_images, epochs=10, batch_size=64)

train_vit()  # Start training


Epoch 1/10


InvalidArgumentError: Graph execution error:

Detected at node compile_loss/mse/sub defined at (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main

  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code

  File "/usr/local/lib/python3.10/dist-packages/colab_kernel_launcher.py", line 37, in <module>

  File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 992, in launch_instance

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 619, in start

  File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 195, in start

  File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever

  File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once

  File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run

  File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 685, in <lambda>

  File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 738, in _run_callback

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 825, in inner

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 786, in run

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 361, in process_one

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 261, in dispatch_shell

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 539, in execute_request

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 302, in do_execute

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py", line 539, in run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code

  File "<ipython-input-8-245532aafb4d>", line 73, in <cell line: 73>

  File "<ipython-input-8-245532aafb4d>", line 71, in train_vit

  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/trainer.py", line 318, in fit

  File "/usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/trainer.py", line 121, in one_step_on_iterator

  File "/usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/trainer.py", line 108, in one_step_on_data

  File "/usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/trainer.py", line 54, in train_step

  File "/usr/local/lib/python3.10/dist-packages/keras/src/trainers/trainer.py", line 357, in _compute_loss

  File "/usr/local/lib/python3.10/dist-packages/keras/src/trainers/trainer.py", line 325, in compute_loss

  File "/usr/local/lib/python3.10/dist-packages/keras/src/trainers/compile_utils.py", line 609, in __call__

  File "/usr/local/lib/python3.10/dist-packages/keras/src/trainers/compile_utils.py", line 645, in call

  File "/usr/local/lib/python3.10/dist-packages/keras/src/losses/loss.py", line 43, in __call__

  File "/usr/local/lib/python3.10/dist-packages/keras/src/losses/losses.py", line 27, in call

  File "/usr/local/lib/python3.10/dist-packages/keras/src/losses/losses.py", line 1286, in mean_squared_error

Incompatible shapes: [64,32,32,3] vs. [4096,32,32,3]
	 [[{{node compile_loss/mse/sub}}]] [Op:__inference_one_step_on_iterator_63275]

In [None]:
import tensorflow as tf
from tensorflow.keras import layers

# 1. Dataset preparation (CIFAR-10 as an example)
def load_data():
    (train_images, _), (_, _) = tf.keras.datasets.cifar10.load_data()  # Load the CIFAR-10 dataset
    train_images = train_images.astype("float32") / 255.0  # Normalize images to [0, 1]
    return train_images

train_images = load_data()

# 2. Vision Transformer (ViT) model
class ViTImageGenerator(tf.keras.Model):
    def __init__(self, img_size=32, patch_size=4, embed_dim=64, num_heads=4, num_layers=6):
        super(ViTImageGenerator, self).__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2  # Total number of patches

        # Dense layer to embed the patches
        self.patch_embed = layers.Dense(embed_dim)

        # Positional embedding to retain spatial information
        self.position_embed = tf.Variable(tf.random.normal([num_patches, embed_dim]))

        # Transformer layers for processing the patches
        self.encoder_layers = [
            layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim) for _ in range(num_layers)
        ]
        self.norm_layers = [layers.LayerNormalization() for _ in range(num_layers)]

        # Output layer to generate the image (scaled to [0,1] using sigmoid)
        self.fc = layers.Dense(patch_size * patch_size * 3, activation='sigmoid') # Output per patch

    def call(self, x):
        # Step 1: Extract patches from the input image
        patches = tf.image.extract_patches(
            images=x,
            sizes=[1, self.patch_size, self.patch_size, 1],  # Patch size
            strides=[1, self.patch_size, self.patch_size, 1],  # Move by patch size
            rates=[1, 1, 1, 1],
            padding='VALID'
        )

        # Get the batch size dynamically using tf.shape
        batch_size = tf.shape(x)[0]

        # Reshape patches to [batch, patches, channels], using batch_size
        patches = tf.reshape(patches, (batch_size, -1, patches.shape[-1]))

        # Step 2: Embed patches and add positional information
        x = self.patch_embed(patches)
        x += self.position_embed  # Adding position embedding

        # Step 3: Pass the embedded patches through transformer layers
        for enc_layer, norm_layer in zip(self.encoder_layers, self.norm_layers):
            x = norm_layer(x + enc_layer(x, x))  # Apply attention + residual connection

        # Step 4: Reconstruct the image from the processed patches
        x = self.fc(x) # Apply fc layer to each patch

        # Reshape the output to image dimensions
        x = tf.reshape(x, (batch_size, self.img_size // self.patch_size, self.img_size // self.patch_size, self.patch_size * self.patch_size * 3))

        # Reshape to the original image size
        reconstructed = tf.nn.depth_to_space(x, self.patch_size)

        return reconstructed # Reshape the output back to image dimensions

# 3. Model training
def train_vit():
    # Instantiate the ViT model
    vit_model = ViTImageGenerator()
    vit_model.compile(optimizer='adam', loss='mse')  # Compile the model with Adam optimizer and MSE loss

    # Train the model (autoencoder style) to reconstruct the input image
    vit_model.fit(train_images, train_images, epochs=10, batch_size=32)

    return vit_model

# Train the model
vit_model = train_vit()

Epoch 1/10
[1m 129/1563[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m7:49[0m 327ms/step - loss: 0.0682