In [1]:
from segmentation_models.transformers import Swin_Unet
import tensorflow as tf
import numpy as np
import cv2
from segmentation_models import blocks

In [2]:
# Swin-UNet hyper-parameters, collected in one mapping so they are easy to
# inspect and tune; the per-stage lists (`depths`, `num_heads`) have five
# entries each.
swin_unet_kwargs = dict(
    classes=1,
    patch_size=(2, 2),
    embed_dim=32,
    window_size=8,
    depths=[2, 2, 2, 4, 2],
    num_heads=[3, 6, 12, 12, 24],
)

# Build the model for 256x256 three-channel inputs.
model = Swin_Unet((256, 256, 3), **swin_unet_kwargs)


In [3]:
# Print the layer-by-layer table (output shape, parameter count, connections)
# of the Swin-UNet built in the previous cell.
model.summary()

Model: "swin_tiny_224"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 patch_embed (PatchEmbed)       (None, 16384, 32)    480         ['input_1[0][0]']                
                                                                                                  
 dropout (Dropout)              (None, 16384, 32)    0           ['patch_embed[0][0]']            
                                                                                                  
 stage_0_block_0 (SwinTransform  (None, 16384, 32)   17475       ['dropout[0][0]']    

In [4]:
# Create a structural copy of `model`.
# NOTE(review): per the Keras docs, clone_model re-instantiates the layers with
# freshly initialised weights rather than copying them — call
# `cloned.set_weights(model.get_weights())` if identical weights are needed.
# `cloned` is not used in any cell visible here; confirm it is still required.
cloned = tf.keras.models.clone_model(model)

### Verify that metrics are consistent

In [59]:
# IoU: check that `iou_score` agrees with tf.keras.metrics.MeanIoU on random
# multi-class label maps.
# NOTE(review): `iou_score` is not imported or defined in any visible cell —
# presumably it comes from earlier kernel state; import it explicitly in the
# top import cell so this check survives Restart & Run All.

NUM_CLASSES = 8
SCORE_SHAPE = (4, 256, 256, NUM_CLASSES)  # (batch, height, width, classes)
NUM_TRIALS = 50
TOLERANCE = 1e-6

for i in range(NUM_TRIALS):

    # Stateless RNG keyed on the trial index makes BOTH tensors reproducible;
    # the differing second seed component keeps the two draws distinct.
    scores_true = tf.random.stateless_uniform(
        SCORE_SHAPE, seed=[i, 0], minval=0, maxval=1, dtype=tf.float32
    )
    scores_pred = tf.random.stateless_uniform(
        SCORE_SHAPE, seed=[i, 1], minval=0, maxval=1, dtype=tf.float32
    )

    # Collapse the class axis into integer label maps.
    labels_true = tf.argmax(scores_true, axis=-1)
    labels_pred = tf.argmax(scores_pred, axis=-1)

    # Reference value: a fresh Keras metric per trial so no state carries over.
    mean_iou_metric = tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES)
    mean_iou_metric.update_state(labels_true, labels_pred)
    tf_iou = mean_iou_metric.result().numpy()

    # `iou_score` is fed float32 one-hot encodings of the same label maps.
    onehot_true = tf.one_hot(labels_true, depth=NUM_CLASSES, dtype=tf.float32)
    onehot_pred = tf.one_hot(labels_pred, depth=NUM_CLASSES, dtype=tf.float32)

    my_iou = iou_score(onehot_true, onehot_pred).numpy()

    # `not (diff < tol)` also fails on NaN, instead of silently passing.
    if not np.abs(tf_iou - my_iou) < TOLERANCE:
        raise ValueError(
            f'IoU mismatch at trial {i}: MeanIoU={tf_iou}, iou_score={my_iou}'
        )

print('IoU metrics are equal')

0.06696768 0.06696768
0.06632534 0.06632534
0.065832615 0.065832615
0.067066774 0.067066774
0.06708064 0.06708064
0.066732034 0.066732034
0.06698288 0.06698288
0.065923944 0.065923944
0.066766985 0.066766985
0.066738255 0.066738255
0.066242844 0.066242844
0.06594121 0.06594121
0.06622486 0.06622486
0.066171244 0.066171244
0.06657505 0.06657505
0.066771105 0.066771105
0.06678361 0.06678361
0.066918805 0.066918805
0.06667042 0.06667042
0.066364154 0.066364154
0.066614285 0.066614285
0.0668468 0.0668468
0.06674723 0.06674723
0.06718521 0.06718521
0.06607112 0.06607112
0.0665019 0.0665019
0.06615333 0.06615333
0.06625714 0.06625714
0.06668316 0.06668316
0.067141294 0.067141294
0.06671895 0.06671895
0.06707591 0.06707591
0.06604767 0.06604767
0.06654802 0.06654802
0.06652586 0.06652586
0.06661217 0.06661217
0.06730746 0.06730746
0.06735326 0.06735326
0.066957176 0.066957176
0.06665303 0.06665303
0.06709623 0.06709623
0.06653639 0.06653639
0.06638626 0.06638626
0.066337615 0.066337615
0.0670

In [3]:
from segmentation_models.convnext import UNext

In [4]:
# Build the UNext model with its default arguments.
# NOTE(review): this rebinds `model`, which earlier in this file refers to the
# Swin-UNet instance — consider a distinct name if both are needed.
model = UNext()

In [5]:
# Print the layer-by-layer table (output shape, parameter count, connections)
# of the UNext model built in the previous cell.
model.summary()

Model: "ConvNextUnet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 Input (InputLayer)             [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 PatchEmbedding (Conv2D)        (None, 64, 64, 96)   4704        ['Input[0][0]']                  
                                                                                                  
 first_LayerNorm (LayerNormaliz  (None, 64, 64, 96)  192         ['PatchEmbedding[0][0]']         
 ation)                                                                                           
                                                                                       