In [1]:
import tensorflow as tf
from dataset_preprocess import train_ds, val_ds
from transformers import TFSegformerForSemanticSegmentation
from transformers import create_optimizer
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
from transformers.keras_callbacks import KerasMetricCallback
import pandas as pd
import evaluate
# from tensorflow.keras.models import load_model

img_0001.png
(<tf.Tensor: shape=(), dtype=string, numpy=b'../dataset/png_images/IMAGES/img_0001.png'>, <tf.Tensor: shape=(), dtype=string, numpy=b'../dataset/png_masks/MASKS/seg_0001.png'>)
(<tf.Tensor: shape=(), dtype=string, numpy=b'../dataset/png_images/IMAGES/img_0002.png'>, <tf.Tensor: shape=(), dtype=string, numpy=b'../dataset/png_masks/MASKS/seg_0002.png'>)
(<tf.Tensor: shape=(), dtype=string, numpy=b'../dataset/png_images/IMAGES/img_0004.png'>, <tf.Tensor: shape=(), dtype=string, numpy=b'../dataset/png_masks/MASKS/seg_0004.png'>)
(<tf.Tensor: shape=(), dtype=string, numpy=b'../dataset/val_dataset/png_images/IMAGES/img_0003.png'>, <tf.Tensor: shape=(), dtype=string, numpy=b'../dataset/val_dataset/png_masks/MASKS/seg_0003.png'>)
(<tf.Tensor: shape=(), dtype=string, numpy=b'../dataset/val_dataset/png_images/IMAGES/img_0010.png'>, <tf.Tensor: shape=(), dtype=string, numpy=b'../dataset/val_dataset/png_masks/MASKS/seg_0010.png'>)
(<tf.Tensor: shape=(), dtype=string, numpy=b'../datase

## Model

In [2]:
# Class-name table for the segmentation masks; each row's index is used as the
# class id when building id2label.
df = pd.read_csv("../dataset/labels.csv")

In [3]:
# Map class index (the DataFrame row index) -> label name.
# Row 0 is always named 'nan', whatever the CSV holds in that row.
id2label = {}
for row_idx, label_name in df["label_list"].items():
    id2label[row_idx] = "nan" if row_idx == 0 else label_name

In [4]:
# Inverse lookup: label name -> class index (a repeated name keeps its last index).
label2id = dict((name, class_idx) for class_idx, name in id2label.items())

In [5]:
# Show the label mapping, then the number of classes, on separate lines.
print(label2id, len(label2id), sep="\n")

{'nan': 0, 'accessories': 1, 'bag': 2, 'belt': 3, 'blazer': 4, 'blouse': 5, 'bodysuit': 6, 'boots': 7, 'bra': 8, 'bracelet': 9, 'cape': 10, 'cardigan': 11, 'clogs': 12, 'coat': 13, 'dress': 14, 'earrings': 15, 'flats': 16, 'glasses': 17, 'gloves': 18, 'hair': 19, 'hat': 20, 'heels': 21, 'hoodie': 22, 'intimate': 23, 'jacket': 24, 'jeans': 25, 'jumper': 26, 'leggings': 27, 'loafers': 28, 'necklace': 29, 'panties': 30, 'pants': 31, 'pumps': 32, 'purse': 33, 'ring': 34, 'romper': 35, 'sandals': 36, 'scarf': 37, 'shirt': 38, 'shoes': 39, 'shorts': 40, 'skin': 41, 'skirt': 42, 'sneakers': 43, 'socks': 44, 'stockings': 45, 'suit': 46, 'sunglasses': 47, 'sweater': 48, 'sweatshirt': 49, 'swimwear': 50, 't-shirt': 51, 'tie': 52, 'tights': 53, 'top': 54, 'vest': 55, 'wallet': 56, 'watch': 57, 'wedges': 58}
59


In [6]:
# Start from the Cityscapes-finetuned SegFormer-B5 checkpoint. Its 19-class
# classifier does not match our label count, so ignore_mismatched_sizes lets the
# classifier kernel/bias be freshly initialised at the new size.
model_id = "nvidia/segformer-b5-finetuned-cityscapes-1024-1024"
model = TFSegformerForSemanticSegmentation.from_pretrained(
    model_id,
    id2label=id2label,
    label2id=label2id,
    num_labels=len(label2id),
    ignore_mismatched_sizes=True,
)

All model checkpoint layers were used when initializing TFSegformerForSemanticSegmentation.

Some weights of TFSegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b5-finetuned-cityscapes-1024-1024 and are newly initialized because the shapes did not match:
- decode_head/classifier/kernel:0: found shape (1, 1, 768, 19) in the checkpoint and (1, 1, 768, 59) in the model instantiated
- decode_head/classifier/bias:0: found shape (19,) in the checkpoint and (59,) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# Print the top-level layer breakdown and parameter counts of the loaded model.
model.summary()

Model: "tf_segformer_for_semantic_segmentation"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 segformer (TFSegformerMainL  multiple                 81443008  
 ayer)                                                           
                                                                 
 decode_head (TFSegformerDec  multiple                 3197243   
 odeHead)                                                        
                                                                 
Total params: 84,640,251
Trainable params: 84,638,715
Non-trainable params: 1,536
_________________________________________________________________


In [8]:
# Dry run on a zero, channels-first (N, C, H, W) batch to inspect the logits
# shape; the displayed result is the cell's last expression.
H = W = 512
model(tf.zeros((1, 3, H, W))).logits.shape

TensorShape([1, 59, 128, 128])

## Training & Evaluation

In [9]:
# Mean-IoU metric from the `evaluate` library. It is loaded once, here, so
# compute_metrics does not reload it every time the metric callback fires.
mean_iou = evaluate.load("mean_iou")


def compute_metrics(eval_pred):
    """Compute segmentation metrics (overall and per class) for a validation pass.

    Intended as the ``metric_fn`` of ``KerasMetricCallback``.

    Args:
        eval_pred: ``(logits, labels)`` pair. ``logits`` is channels-first,
            ``(batch, num_labels, h, w)``; for a 512x512 input the model emits
            ``(1, 59, 128, 128)``. ``labels`` is assumed to be integer masks of
            shape ``(batch, H, W)``. The resize below uses ``shape[1:]`` as the
            target size, so a trailing channel axis would break it. TODO confirm
            against ``dataset_preprocess``.

    Returns:
        dict mapping ``"val_<name>"`` to a value. It holds the aggregate metrics
        returned by ``mean_iou`` plus one ``val_accuracy_<label>`` and one
        ``val_iou_<label>`` entry per class.
    """
    logits, labels = eval_pred

    # Model logits are channels-first; tf.image.resize expects channels-last.
    logits = tf.transpose(logits, perm=[0, 2, 3, 1])

    # Logits come out at reduced resolution; upsample to mask size before argmax.
    logits_resized = tf.image.resize(
        logits,
        size=tf.shape(labels)[1:],
        method="bilinear",
    )
    # evaluate metrics consume array-likes, so hand over a numpy array.
    pred_labels = tf.argmax(logits_resized, axis=-1).numpy()

    results = mean_iou.compute(
        predictions=pred_labels,
        references=labels,
        num_labels=len(label2id),
        # Pixels labelled 0 ('nan') are excluded from scoring, so the
        # per-class entries for class 0 are expected to be NaN.
        ignore_index=0,
    )

    per_category_accuracy = results.pop("per_category_accuracy").tolist()
    per_category_iou = results.pop("per_category_iou").tolist()

    results.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
    results.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})

    # Prefix with "val_" so these land next to Keras' own validation logs.
    return {"val_" + k: v for k, v in results.items()}

In [10]:
# Training hyper-parameters.
batch_size = 1          # also used as the metric callback's evaluation batch size
num_epochs = 15
learning_rate = 6e-5
weight_decay_rate = 0.01
# Total schedule length. This assumes len(train_ds) counts steps per epoch
# (i.e. train_ds is already batched); TODO confirm in dataset_preprocess.
num_train_steps = num_epochs * len(train_ds)

# Build the optimizer together with its learning-rate schedule (no warm-up).
optimizer, lr_schedule = create_optimizer(
    init_lr=learning_rate,
    num_train_steps=num_train_steps,
    weight_decay_rate=weight_decay_rate,
    num_warmup_steps=0,
)


In [11]:
# Callback that runs compute_metrics over val_ds during training; the "labels"
# column of each batch is passed to it as the reference masks.
metric_callback = KerasMetricCallback(
    metric_fn=compute_metrics,
    eval_dataset=val_ds,
    label_cols=["labels"],
    batch_size=batch_size,
)
callbacks = [
    metric_callback,
]

In [13]:
# No loss is passed, so the model falls back to its internal loss computation
# (the warning printed by this cell says so).
model.compile(optimizer=optimizer)

No loss specified in compile() - the model's internal loss computation will be used as the loss. Don't panic - this is a common way to train TensorFlow models in Transformers! To disable this behaviour please pass a loss argument, or explicitly pass `loss=None` if you do not want your model to compute a loss.


In [14]:
# Fine-tune on train_ds. val_ds feeds Keras' built-in validation loss, and the
# callbacks add the custom segmentation metrics.
# NOTE(review): the recorded run raised ResourceExhaustedError (GPU OOM) in the
# encoder's attention softmax during epoch 1. A smaller SegFormer checkpoint,
# a lower input resolution, or mixed precision may be needed; confirm on the
# target GPU.
history = model.fit(
    train_ds,
    epochs=num_epochs,
    validation_data=val_ds,
    callbacks=callbacks,
)

Epoch 1/15


ResourceExhaustedError: 2 root error(s) found.
  (0) RESOURCE_EXHAUSTED:  OOM when allocating tensor with shape[1,2,4096,256] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[node tf_segformer_for_semantic_segmentation/segformer/encoder/block.1.4/attention/self/Softmax
 (defined at c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\tf_utils.py:70)
]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.

	 [[tf_segformer_for_semantic_segmentation/sparse_categorical_crossentropy/cond/then/_0/tf_segformer_for_semantic_segmentation/sparse_categorical_crossentropy/cond/cond/then/_59/tf_segformer_for_semantic_segmentation/sparse_categorical_crossentropy/cond/cond/remove_squeezable_dimensions/Equal/_2792]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.

  (1) RESOURCE_EXHAUSTED:  OOM when allocating tensor with shape[1,2,4096,256] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[node tf_segformer_for_semantic_segmentation/segformer/encoder/block.1.4/attention/self/Softmax
 (defined at c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\tf_utils.py:70)
]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.

0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_161663]

Errors may have originated from an input operation.
Input Source operations connected to node tf_segformer_for_semantic_segmentation/segformer/encoder/block.1.4/attention/self/Softmax:
In[0] tf_segformer_for_semantic_segmentation/segformer/encoder/block.1.4/attention/self/add:

Operation defined at: (most recent call last)
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\runpy.py", line 194, in _run_module_as_main
>>>     return _run_code(code, main_globals, None,
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\runpy.py", line 87, in _run_code
>>>     exec(code, run_globals)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel_launcher.py", line 18, in <module>
>>>     app.launch_new_instance()
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\traitlets\config\application.py", line 1075, in launch_instance
>>>     app.start()
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel\kernelapp.py", line 739, in start
>>>     self.io_loop.start()
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\tornado\platform\asyncio.py", line 205, in start
>>>     self.asyncio_loop.run_forever()
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\asyncio\base_events.py", line 570, in run_forever
>>>     self._run_once()
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\asyncio\base_events.py", line 1859, in _run_once
>>>     handle._run()
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\asyncio\events.py", line 81, in _run
>>>     self._context.run(self._callback, *self._args)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel\kernelbase.py", line 545, in dispatch_queue
>>>     await self.process_one()
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel\kernelbase.py", line 534, in process_one
>>>     await dispatch(*args)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel\kernelbase.py", line 437, in dispatch_shell
>>>     await result
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel\ipkernel.py", line 362, in execute_request
>>>     await super().execute_request(stream, ident, parent)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel\kernelbase.py", line 778, in execute_request
>>>     reply_content = await reply_content
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel\ipkernel.py", line 449, in do_execute
>>>     res = shell.run_cell(
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel\zmqshell.py", line 549, in run_cell
>>>     return super().run_cell(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\IPython\core\interactiveshell.py", line 3009, in run_cell
>>>     result = self._run_cell(
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\IPython\core\interactiveshell.py", line 3064, in _run_cell
>>>     result = runner(coro)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner
>>>     coro.send(None)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\IPython\core\interactiveshell.py", line 3269, in run_cell_async
>>>     has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\IPython\core\interactiveshell.py", line 3448, in run_ast_nodes
>>>     if await self.run_code(code, result, async_=asy):
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\IPython\core\interactiveshell.py", line 3508, in run_code
>>>     exec(code_obj, self.user_global_ns, self.user_ns)
>>> 
>>>   File "C:\Users\risha\AppData\Local\Temp\ipykernel_17844\2340480664.py", line 1, in <module>
>>>     history = model.fit(
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\training.py", line 1216, in fit
>>>     tmp_logs = self.train_function(iterator)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\training.py", line 878, in train_function
>>>     return step_function(self, iterator)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\training.py", line 867, in step_function
>>>     outputs = model.distribute_strategy.run(run_step, args=(data,))
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\training.py", line 860, in run_step
>>>     outputs = model.train_step(data)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\modeling_tf_utils.py", line 1398, in train_step
>>>     y_pred = self(x, training=True)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\modeling_tf_utils.py", line 833, in run_call_with_unpacked_inputs
>>>     """
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\models\segformer\modeling_tf_segformer.py", line 860, in call
>>>     outputs = self.segformer(
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\modeling_tf_utils.py", line 833, in run_call_with_unpacked_inputs
>>>     """
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\models\segformer\modeling_tf_segformer.py", line 482, in call
>>>     encoder_outputs = self.encoder(
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\models\segformer\modeling_tf_segformer.py", line 414, in call
>>>     for idx, x in enumerate(zip(self.embeddings, self.block, self.layer_norms)):
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\models\segformer\modeling_tf_segformer.py", line 421, in call
>>>     for i, blk in enumerate(block_layer):
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\models\segformer\modeling_tf_segformer.py", line 422, in call
>>>     layer_outputs = blk(
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\models\segformer\modeling_tf_segformer.py", line 324, in call
>>>     self_attention_outputs = self.attention(
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\models\segformer\modeling_tf_segformer.py", line 232, in call
>>>     self_outputs = self.self(hidden_states, height, width, output_attentions)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\models\segformer\modeling_tf_segformer.py", line 182, in call
>>>     attention_probs = stable_softmax(logits=attention_scores, axis=-1)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\tf_utils.py", line 70, in stable_softmax
>>>     return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
>>> 

Input Source operations connected to node tf_segformer_for_semantic_segmentation/segformer/encoder/block.1.4/attention/self/Softmax:
In[0] tf_segformer_for_semantic_segmentation/segformer/encoder/block.1.4/attention/self/add:

Operation defined at: (most recent call last)
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\runpy.py", line 194, in _run_module_as_main
>>>     return _run_code(code, main_globals, None,
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\runpy.py", line 87, in _run_code
>>>     exec(code, run_globals)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel_launcher.py", line 18, in <module>
>>>     app.launch_new_instance()
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\traitlets\config\application.py", line 1075, in launch_instance
>>>     app.start()
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel\kernelapp.py", line 739, in start
>>>     self.io_loop.start()
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\tornado\platform\asyncio.py", line 205, in start
>>>     self.asyncio_loop.run_forever()
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\asyncio\base_events.py", line 570, in run_forever
>>>     self._run_once()
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\asyncio\base_events.py", line 1859, in _run_once
>>>     handle._run()
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\asyncio\events.py", line 81, in _run
>>>     self._context.run(self._callback, *self._args)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel\kernelbase.py", line 545, in dispatch_queue
>>>     await self.process_one()
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel\kernelbase.py", line 534, in process_one
>>>     await dispatch(*args)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel\kernelbase.py", line 437, in dispatch_shell
>>>     await result
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel\ipkernel.py", line 362, in execute_request
>>>     await super().execute_request(stream, ident, parent)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel\kernelbase.py", line 778, in execute_request
>>>     reply_content = await reply_content
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel\ipkernel.py", line 449, in do_execute
>>>     res = shell.run_cell(
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\ipykernel\zmqshell.py", line 549, in run_cell
>>>     return super().run_cell(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\IPython\core\interactiveshell.py", line 3009, in run_cell
>>>     result = self._run_cell(
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\IPython\core\interactiveshell.py", line 3064, in _run_cell
>>>     result = runner(coro)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner
>>>     coro.send(None)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\IPython\core\interactiveshell.py", line 3269, in run_cell_async
>>>     has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\IPython\core\interactiveshell.py", line 3448, in run_ast_nodes
>>>     if await self.run_code(code, result, async_=asy):
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\IPython\core\interactiveshell.py", line 3508, in run_code
>>>     exec(code_obj, self.user_global_ns, self.user_ns)
>>> 
>>>   File "C:\Users\risha\AppData\Local\Temp\ipykernel_17844\2340480664.py", line 1, in <module>
>>>     history = model.fit(
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\training.py", line 1216, in fit
>>>     tmp_logs = self.train_function(iterator)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\training.py", line 878, in train_function
>>>     return step_function(self, iterator)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\training.py", line 867, in step_function
>>>     outputs = model.distribute_strategy.run(run_step, args=(data,))
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\training.py", line 860, in run_step
>>>     outputs = model.train_step(data)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\modeling_tf_utils.py", line 1398, in train_step
>>>     y_pred = self(x, training=True)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\modeling_tf_utils.py", line 833, in run_call_with_unpacked_inputs
>>>     """
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\models\segformer\modeling_tf_segformer.py", line 860, in call
>>>     outputs = self.segformer(
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\modeling_tf_utils.py", line 833, in run_call_with_unpacked_inputs
>>>     """
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\models\segformer\modeling_tf_segformer.py", line 482, in call
>>>     encoder_outputs = self.encoder(
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\models\segformer\modeling_tf_segformer.py", line 414, in call
>>>     for idx, x in enumerate(zip(self.embeddings, self.block, self.layer_norms)):
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\models\segformer\modeling_tf_segformer.py", line 421, in call
>>>     for i, blk in enumerate(block_layer):
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\models\segformer\modeling_tf_segformer.py", line 422, in call
>>>     layer_outputs = blk(
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\models\segformer\modeling_tf_segformer.py", line 324, in call
>>>     self_attention_outputs = self.attention(
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\models\segformer\modeling_tf_segformer.py", line 232, in call
>>>     self_outputs = self.self(hidden_states, height, width, output_attentions)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\engine\base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\keras\utils\traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\models\segformer\modeling_tf_segformer.py", line 182, in call
>>>     attention_probs = stable_softmax(logits=attention_scores, axis=-1)
>>> 
>>>   File "c:\Users\risha\anaconda3\envs\tf_new\lib\site-packages\transformers\tf_utils.py", line 70, in stable_softmax
>>>     return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
>>> 

Function call stack:
train_function -> train_function
