In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Embedding, LayerNormalization, MultiHeadAttention, Flatten, Resizing, Normalization,Input, add
import tensorflow_datasets as tfds

2024-06-14 19:40:33.655445: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-14 19:40:33.655501: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-14 19:40:33.655546: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-06-14 19:40:33.663910: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
BATCH_SIZE = 32

'/home/shriyam/Crux-r3/Task-3'

In [None]:

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()


In [None]:
# for batch in train_data.take(1):
#     inp,out = batch
# print(inp)

In [None]:
class Patches(tf.keras.layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size
    
    def call(self, inputs):
        (batch_size, h,w,c) = inputs.shape
        patches = tf.image.extract_patches( inputs, sizes = (1,self.patch_size, self.patch_size, 1), strides = (1,self.patch_size, self.patch_size,1), rates = (1,1,1,1), padding="VALID")
        vert_num_patches = h // self.patch_size
        horiz_num_patches = w // self.patch_size
        print(patches)
        if(batch_size == None): 
            patches = tf.reshape(patches, ( BATCH_SIZE, vert_num_patches * horiz_num_patches, self.patch_size**2 * c ) )
        else:
            patches = tf.reshape(patches, ( batch_size, vert_num_patches * horiz_num_patches, self.patch_size**2 * c ) )

        return patches

class PatchEncoder(tf.keras.layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projected_patches = Dense( projection_dim )
        self.pos_embedding = Embedding( input_dim = num_patches, output_dim = projection_dim )

    def call(self, patch):
        positions = tf.expand_dims( tf.range( start = 0, limit = self.num_patches, delta = 1 ), axis = 0 )
        projected_patches = self.projected_patches(patch)
        encoded = projected_patches + self.pos_embedding(positions)
        return encoded
    
class MLP(tf.keras.layers.Layer):
    def __init__(self, input_units, output_units):
        super().__init__()
        self.dense_1 = Dense(input_units)
        self.dense_2 = Dense(output_units)
    
    def call(self, x):
        x = self.dense_1(x)
        x = self.dense_2(x)
        return x

In [None]:
class Encoder(tf.keras.layers.Layer):
    def __init__( self, num_heads, projection_dim, mlp_input_units, mlp_output_units ):
        super().__init__()
        self.layerNorm_1 = LayerNormalization()
        self.layerNorm_2 = LayerNormalization()
        self.mha = MultiHeadAttention( num_heads=num_heads, key_dim=projection_dim )
        self.mlp = MLP(mlp_input_units, mlp_output_units)
    def call(self, input):
        x = self.layerNorm_1(input)
        x = self.mha(x,x)
        attention = add([input,x])
        x = self.layerNorm_2(attention)
        x = self.mlp(x)
        x = add([attention, x])

        return x

In [None]:

class ViT(tf.keras.Model):
    def __init__(self, img_size, num_layers, patch_size, num_patches, projection_dim, num_heads, transformer_mlp_input_units, transformer_mlp_output_units, mlp_input_units, mlp_output_units, num_classes):
        super().__init__()
        self.inputs = Input( shape= (32,32,3) )
        self.resizing = Resizing(img_size, img_size)
        self.patches = Patches(patch_size)
        self.patch_encoder = PatchEncoder(num_patches, projection_dim)
        self.encoder_layers = []
        self.layerNorm = LayerNormalization()
        self.flatten = Flatten()
        self.mlp = MLP(mlp_input_units, mlp_output_units)
        self.out = Dense(num_classes)

        for i in range(num_layers):
            self.encoder_layers.append(Encoder( num_heads, projection_dim, transformer_mlp_input_units, transformer_mlp_output_units ))
        
    def call(self, x):
        x = self.resizing(x)
        x = self.patches(x)
        x = self.patch_encoder(x)
        for i in range( len(self.encoder_layers) ):
            x = self.encoder_layers[i](x)
        x = self.layerNorm(x)
        x = self.flatten(x)
        x = self.mlp(x)
        x = self.out(x)
        return x

In [None]:
inp_shape = (32,32,3)
num_layers = 8
patch_size = 6
img_size = 72
num_patches = (img_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_mlp_input_units = projection_dim * 2
transformer_mlp_output_units = projection_dim
mlp_input_units = 2048
mlp_output_units = 1024
num_classes = 100

vision_transformer = ViT(img_size, num_layers, patch_size, num_patches, projection_dim, num_heads, transformer_mlp_input_units, transformer_mlp_output_units, mlp_input_units, mlp_output_units, num_classes)

vision_transformer.compile(
    optimizer=tf.keras.optimizers.AdamW(weight_decay = 0.0001, learning_rate = 0.001),
    loss= tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True),
    metrics=['accuracy']
)


2024-06-14 19:36:11.393854: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-06-14 19:36:11.708408: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-06-14 19:36:11.708460: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-06-14 19:36:11.717971: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-06-14 19:36:11.718028: I tensorflow/compile

In [None]:

history = vision_transformer.fit(
    x = x_train,
    y = y_train,
    batch_size=32,
    epochs=10,
)

Epoch 1/10
Tensor("vi_t/patches/ExtractImagePatches:0", shape=(None, 12, 12, 108), dtype=float32)
Tensor("vi_t/patches/ExtractImagePatches:0", shape=(None, 12, 12, 108), dtype=float32)


2024-06-14 19:36:23.492831: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:442] Loaded cuDNN version 8907
2024-06-14 19:36:23.745359: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7efca9b98050 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-06-14 19:36:23.745400: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA GeForce RTX 3050 Ti Laptop GPU, Compute Capability 8.6
2024-06-14 19:36:23.773977: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-06-14 19:36:24.015465: I ./tensorflow/compiler/jit/device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.




InvalidArgumentError: Graph execution error:

Detected at node vi_t/patch_encoder/dense/Tensordot/Reshape defined at (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main

  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code

  File "/home/shriyam/.local/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>

  File "/home/shriyam/.local/lib/python3.10/site-packages/traitlets/config/application.py", line 1053, in launch_instance

  File "/home/shriyam/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 737, in start

  File "/home/shriyam/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 195, in start

  File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever

  File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once

  File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run

  File "/home/shriyam/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 524, in dispatch_queue

  File "/home/shriyam/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 513, in process_one

  File "/home/shriyam/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 418, in dispatch_shell

  File "/home/shriyam/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 758, in execute_request

  File "/home/shriyam/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 426, in do_execute

  File "/home/shriyam/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 549, in run_cell

  File "/home/shriyam/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3024, in run_cell

  File "/home/shriyam/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3079, in _run_cell

  File "/home/shriyam/.local/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner

  File "/home/shriyam/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3284, in run_cell_async

  File "/home/shriyam/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3466, in run_ast_nodes

  File "/home/shriyam/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3526, in run_code

  File "/tmp/ipykernel_359318/2273860336.py", line 1, in <module>

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/engine/training.py", line 1783, in fit

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/engine/training.py", line 1377, in train_function

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/engine/training.py", line 1360, in step_function

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/engine/training.py", line 1349, in run_step

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/engine/training.py", line 1126, in train_step

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/engine/training.py", line 589, in __call__

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/tmp/ipykernel_359318/3023161626.py", line 20, in call

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/tmp/ipykernel_359318/1292755962.py", line 25, in call

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/home/shriyam/.local/lib/python3.10/site-packages/keras/src/layers/core/dense.py", line 244, in call

Input to reshape is a tensor with 248832 values, but the requested shape has 497664
	 [[{{node vi_t/patch_encoder/dense/Tensordot/Reshape}}]] [Op:__inference_train_function_16155]