In [None]:
import keras
from keras import ops
import math

class Attention(keras.layers.Layer):
    """Multi-head self-attention with optional spatial reduction of keys/values.

    Queries are computed from every input token. When ``sr_ratio > 1`` the
    tokens are folded back into an ``(H, W)`` grid, downsampled by a strided
    ``Conv2D`` (kernel size = stride = ``sr_ratio``) and layer-normalised before
    keys and values are computed, so attention runs against a shorter
    key/value sequence.

    Args:
        dim: token channel width; also the output projection width and the
            number of filters of the reduction convolution.
        num_heads: number of attention heads; each head has ``dim // num_heads``
            channels.
        sr_ratio: spatial-reduction factor for keys/values (``<= 1`` disables it).
        qkv_bias: accepted for signature compatibility but not read here; the
            q/k/v ``Dense`` layers use the Keras default (bias enabled).
        attn_drop: dropout rate applied to the attention weights.
        proj_drop: dropout rate applied after the output projection.
        **kwargs: forwarded to ``keras.layers.Layer``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        sr_ratio,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Width produced by q/k/v (equals dim when dim is divisible by num_heads).
        self.units = num_heads * self.head_dim
        # Logit scale; despite the attribute name this is sqrt(head_dim).
        self.sqrt_of_units = math.sqrt(self.head_dim)

        # Keep this assignment order: Keras tracks sub-layers (and orders their
        # weights) as they are assigned.
        self.q = keras.layers.Dense(self.units)
        self.k = keras.layers.Dense(self.units)
        self.v = keras.layers.Dense(self.units)
        self.attn_drop = keras.layers.Dropout(attn_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = keras.layers.Conv2D(
                filters=dim,
                kernel_size=sr_ratio,
                strides=sr_ratio,
                name='sr',
            )
            self.norm = keras.layers.LayerNormalization(epsilon=1e-05)

        self.proj = keras.layers.Dense(dim)
        self.proj_drop = keras.layers.Dropout(proj_drop)

    def _split_heads(self, tensor):
        """Reshape projected tokens to (batch, heads, tokens, head_dim)."""
        batch = ops.shape(tensor)[0]
        tensor = ops.reshape(tensor, (batch, -1, self.num_heads, self.head_dim))
        return ops.transpose(tensor, axes=[0, 2, 1, 3])

    def call(
        self,
        x,
        H,
        W,
    ):
        """Attend over ``x`` and return tokens of width ``dim``.

        ``x`` is indexed as (batch, tokens, channels); when ``sr_ratio > 1`` the
        token count must equal ``H * W`` so the grid reshape succeeds.
        """
        in_shape = ops.shape(x)
        batch = in_shape[0]
        channels = in_shape[2]

        queries = self._split_heads(self.q(x))

        kv_tokens = x
        if self.sr_ratio > 1:
            # Tokens -> (batch, H, W, C) grid -> strided conv -> tokens -> LayerNorm.
            grid = ops.reshape(x, (batch, H, W, channels))
            reduced = ops.reshape(self.sr(grid), (batch, -1, channels))
            kv_tokens = self.norm(reduced)

        keys = self._split_heads(self.k(kv_tokens))
        values = self._split_heads(self.v(kv_tokens))

        logits = ops.matmul(queries, ops.transpose(keys, axes=[0, 1, 3, 2]))
        logits = ops.divide(logits, ops.cast(self.sqrt_of_units, dtype=logits.dtype))
        weights = self.attn_drop(ops.softmax(logits, axis=-1))

        # (batch, heads, tokens, head_dim) -> (batch, tokens, units)
        context = ops.transpose(ops.matmul(weights, values), axes=[0, 2, 1, 3])
        context = ops.reshape(context, (batch, -1, self.units))
        return self.proj_drop(self.proj(context))

In [None]:
import keras
from keras import ops

class MLP(keras.layers.Layer):
    """Per-level linear projection used by the SegFormer decode head.

    Applies one ``Dense(decode_dim)`` to the last axis of its input.
    """

    def __init__(self, decode_dim):
        super().__init__()
        self.proj = keras.layers.Dense(decode_dim)

    def call(self, x):
        return self.proj(x)


class ConvModule(keras.layers.Layer):
    """1x1 convolution (no bias) -> BatchNormalization -> ReLU.

    Used by the decode head to fuse the concatenated multi-level features
    into ``decode_dim`` channels.
    """

    def __init__(self, decode_dim):
        super().__init__()
        self.conv = keras.layers.Conv2D(
            filters=decode_dim,
            kernel_size=1,
            use_bias=False,
        )
        self.bn = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9)
        self.activate = keras.layers.ReLU()

    def call(self, x):
        # Apply conv, norm and activation in sequence.
        for stage in (self.conv, self.bn, self.activate):
            x = stage(x)
        return x


class SegFormerHead(keras.layers.Layer):
    """All-MLP decode head that fuses a list of multi-level feature maps.

    Each input level is projected to ``decode_dim`` channels by its own ``MLP``.
    It is then bilinearly resized to the spatial size of ``inputs[0]``. The
    resized levels are concatenated in reverse order along axis 3 and fused by
    ``ConvModule``. Finally, dropout (rate 0.1) and a 1x1 convolution to
    ``num_classes`` channels are applied.

    Args:
        num_mlp_layers: number of per-level projections; ``zip`` pairs them with
            the inputs, so any extra inputs or projections are ignored.
        decode_dim: channel width of the projected and fused features.
        num_classes: number of output channels.
    """

    def __init__(self, num_mlp_layers=4, decode_dim=768, num_classes=19):
        super().__init__()
        self.linear_layers = [MLP(decode_dim) for _ in range(num_mlp_layers)]
        self.linear_fuse = ConvModule(decode_dim)
        self.dropout = keras.layers.Dropout(0.1)
        self.linear_pred = keras.layers.Conv2D(num_classes, kernel_size=1)

    def call(self, inputs):
        reference_shape = ops.shape(inputs[0])
        target_size = (reference_shape[1], reference_shape[2])

        resized_levels = [
            ops.image.resize(projector(level), size=target_size, interpolation="bilinear")
            for level, projector in zip(inputs, self.linear_layers)
        ]

        fused = self.linear_fuse(ops.concatenate(list(reversed(resized_levels)), axis=3))
        return self.linear_pred(self.dropout(fused))

In [None]:
import keras
from keras import ops


class DWConv(keras.layers.Layer):
    """3x3 depthwise convolution applied to a token sequence.

    Tokens are reshaped to an ``(H, W)`` grid, convolved with ``groups=filters``
    (one kernel per channel, "same" padding) and flattened back to tokens.
    """

    def __init__(self, filters=768, **kwargs):
        super().__init__(**kwargs)
        self.dwconv = keras.layers.Conv2D(
            filters=filters,
            kernel_size=3,
            strides=1,
            padding="same",
            groups=filters,
        )

    def call(self, x, H, W):
        token_shape = ops.shape(x)
        grid = ops.reshape(x, (token_shape[0], H, W, token_shape[-1]))
        conv_out = self.dwconv(grid)
        batch, height, width, channels = ops.shape(conv_out)
        return ops.reshape(conv_out, (batch, height * width, channels))


class Mlp(keras.layers.Layer):
    """Mix-FFN: Dense -> depthwise conv -> GELU -> dropout -> Dense -> dropout.

    Args:
        in_features: input channel width (also the fallback for the others).
        hidden_features: width of the first Dense / depthwise conv; falls back
            to ``in_features`` when falsy.
        out_features: output width; falls back to ``in_features`` when falsy.
        drop: dropout rate used after the activation and after the output Dense.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        drop=0.0,
    ):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        self.fc1 = keras.layers.Dense(hidden_width)
        self.dwconv = DWConv(hidden_width)
        self.act = keras.layers.Activation("gelu")
        self.fc2 = keras.layers.Dense(output_width)
        # The same Dropout layer is applied twice in call().
        self.drop = keras.layers.Dropout(drop)

    def call(self, x, H, W):
        hidden = self.dwconv(self.fc1(x), H=H, W=W)
        hidden = self.drop(self.act(hidden))
        return self.drop(self.fc2(hidden))


class Block(keras.layers.Layer):
    """Pre-norm transformer block: attention and Mix-FFN residual branches.

    Each branch output passes through the same ``DropPath`` (stochastic depth)
    layer before being added to the residual stream.

    Args:
        dim: token channel width.
        num_heads: attention heads.
        mlp_ratio: hidden width of the Mix-FFN as a multiple of ``dim``.
        qkv_bias: forwarded to ``Attention``.
        drop: dropout rate for the attention projection and the Mix-FFN.
        attn_drop: dropout rate on attention weights.
        drop_path: stochastic-depth rate for both residual branches.
        sr_ratio: spatial-reduction ratio forwarded to ``Attention``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        sr_ratio=1,
    ):
        super().__init__()
        self.norm1 = keras.layers.LayerNormalization(epsilon=1e-05)
        self.attn = Attention(
            dim,
            num_heads,
            sr_ratio,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path)
        self.norm2 = keras.layers.LayerNormalization(epsilon=1e-05)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            drop=drop,
        )

    def call(self, x, H, W):
        # Attention branch: x + DropPath(Attn(LN(x)))
        x = x + self.drop_path(self.attn(self.norm1(x), H=H, W=W))
        # Mix-FFN branch: x + DropPath(MLP(LN(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x), H=H, W=W))
        return x



class OverlapPatchEmbed(keras.layers.Layer):
    """Overlapping patch embedding: zero-pad, strided conv, flatten, LayerNorm.

    ``call`` returns ``(tokens, H, W)``, where ``H``/``W`` are the spatial size
    of the convolution output and ``tokens`` has shape ``(batch, H * W, filters)``.
    ``img_size`` is accepted but not used.
    """

    def __init__(
        self, img_size=224, patch_size=7, stride=4, filters=768, **kwargs
    ):
        super().__init__(**kwargs)
        # patch_size // 2 padding on each side, then a "VALID" convolution.
        self.pad = keras.layers.ZeroPadding2D(padding=patch_size // 2)
        self.conv = keras.layers.Conv2D(
            filters=filters,
            kernel_size=patch_size,
            strides=stride,
            padding="VALID",
            name='proj',
        )
        self.norm = keras.layers.LayerNormalization(epsilon=1e-05)

    def call(self, x):
        feature_map = self.conv(self.pad(x))
        fm_shape = ops.shape(feature_map)
        height, width, channels = fm_shape[1], fm_shape[2], fm_shape[3]
        tokens = self.norm(ops.reshape(feature_map, (-1, height * width, channels)))
        return tokens, height, width


class MixVisionTransformer(keras.layers.Layer):
    """Mix Transformer (MiT) encoder producing a four-level feature pyramid.

    Stage ``i`` (1..4) applies ``patch_embed{i}`` (overlapping strided conv +
    LayerNorm), the ``Block`` layers in ``block{i}``, and ``norm{i}``. It then
    reshapes the tokens back into a ``(batch, H, W, C)`` feature map, which
    feeds the next stage and is collected as an output.

    Args:
        img_size: forwarded to the patch embeddings (divided by 1, 4, 8, 16).
        embed_dims: channel width of each stage.
        num_heads: attention heads per stage.
        mlp_ratios: Mix-FFN hidden-width multiplier per stage.
        qkv_bias: forwarded to every ``Block``.
        drop_rate: dropout rate forwarded to every ``Block`` as ``drop``.
        attn_drop_rate: attention-dropout rate forwarded to every ``Block``.
        drop_path_rate: largest stochastic-depth rate; per-block rates rise
            linearly from 0.0 to this value across all ``sum(depths)`` blocks.
        depths: number of ``Block`` layers per stage.
        sr_ratios: attention spatial-reduction ratio per stage.

    The list defaults are only read, never mutated.
    """

    def __init__(
        self,
        img_size=224,
        embed_dims=[64, 128, 256, 512],
        num_heads=[1, 2, 4, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=False,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        depths=[3, 4, 6, 3],
        sr_ratios=[8, 4, 2, 1],
    ):
        super().__init__()
        self.depths = depths

        # Patch embeddings: a 7x7 / stride-4 stem, then 3x3 / stride-2 per stage.
        self.patch_embed1 = OverlapPatchEmbed(
            img_size=img_size, patch_size=7, stride=4, filters=embed_dims[0]
        )
        self.patch_embed2 = OverlapPatchEmbed(
            img_size=img_size // 4, patch_size=3, stride=2, filters=embed_dims[1]
        )
        self.patch_embed3 = OverlapPatchEmbed(
            img_size=img_size // 8, patch_size=3, stride=2, filters=embed_dims[2]
        )
        self.patch_embed4 = OverlapPatchEmbed(
            img_size=img_size // 16, patch_size=3, stride=2, filters=embed_dims[3]
        )

        # Per-block stochastic-depth rates as plain Python floats. Using floats
        # rather than elements of a backend tensor keeps the Block/DropPath
        # hyper-parameters backend-independent.
        dpr = self._drop_path_rates(drop_path_rate, sum(depths))

        def make_stage(stage, first):
            """Build the Block list for 0-based ``stage``, using rates from ``dpr[first:]``."""
            return [
                Block(
                    dim=embed_dims[stage],
                    num_heads=num_heads[stage],
                    mlp_ratio=mlp_ratios[stage],
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[first + i],
                    sr_ratio=sr_ratios[stage],
                )
                for i in range(depths[stage])
            ]

        # Stages are built in order (block1, norm1, block2, ...), so sub-layer
        # tracking order and auto-generated layer names follow stage order.
        cur = 0
        self.block1 = make_stage(0, cur)
        self.norm1 = keras.layers.LayerNormalization(epsilon=1e-05)

        cur += depths[0]
        self.block2 = make_stage(1, cur)
        self.norm2 = keras.layers.LayerNormalization(epsilon=1e-05)

        cur += depths[1]
        self.block3 = make_stage(2, cur)
        self.norm3 = keras.layers.LayerNormalization(epsilon=1e-05)

        cur += depths[2]
        self.block4 = make_stage(3, cur)
        self.norm4 = keras.layers.LayerNormalization(epsilon=1e-05)

    @staticmethod
    def _drop_path_rates(max_rate, count):
        """Return ``count`` floats evenly spaced from 0.0 to ``max_rate`` inclusive."""
        if count <= 0:
            return []
        if count == 1:
            return [0.0]
        return [max_rate * i / (count - 1) for i in range(count)]

    def call_features(self, x):
        """Run the four stages and return their outputs as feature maps.

        Each output is reshaped to ``(batch, H, W, C)`` using the spatial size
        reported by that stage's patch embedding.
        """
        batch = ops.shape(x)[0]
        stages = (
            (self.patch_embed1, self.block1, self.norm1),
            (self.patch_embed2, self.block2, self.norm2),
            (self.patch_embed3, self.block3, self.norm3),
            (self.patch_embed4, self.block4, self.norm4),
        )
        outs = []
        for patch_embed, blocks, norm in stages:
            x, H, W = patch_embed(x)
            for blk in blocks:
                x = blk(x, H=H, W=W)
            x = norm(x)
            # Tokens (batch, H*W, C) -> feature map (batch, H, W, C).
            x = ops.reshape(x, (batch, H, W, ops.shape(x)[-1]))
            outs.append(x)
        return outs

    def call(self, x):
        return self.call_features(x)

In [None]:
import keras
from keras import ops


# Per-variant settings: MiT encoder widths/depths and the decode-head width.
# Every other encoder setting uses the MixVisionTransformer defaults.
MODEL_CONFIGS = {
    "mit_b0": {
        "embed_dims": [32, 64, 160, 256],
        "depths": [2, 2, 2, 2],
        "decode_dim": 256,
    },
    "mit_b1": {
        "embed_dims": [64, 128, 320, 512],
        "depths": [2, 2, 2, 2],
        "decode_dim": 256,
    },
    "mit_b2": {
        "embed_dims": [64, 128, 320, 512],
        "depths": [3, 4, 6, 3],
        "decode_dim": 768,
    },
    "mit_b3": {
        "embed_dims": [64, 128, 320, 512],
        "depths": [3, 4, 18, 3],
        "decode_dim": 768,
    },
    "mit_b4": {
        "embed_dims": [64, 128, 320, 512],
        "depths": [3, 8, 27, 3],
        "decode_dim": 768,
    },
    "mit_b5": {
        "embed_dims": [64, 128, 320, 512],
        "depths": [3, 6, 40, 3],
        "decode_dim": 768,
    },
}


def build_segformer(variant, input_shape, num_classes):
    """Build a SegFormer ``keras.Model`` for one ``MODEL_CONFIGS`` variant.

    Args:
        variant: key of ``MODEL_CONFIGS`` ("mit_b0" ... "mit_b5").
        input_shape: per-image input shape (no batch axis). Indices 0 and 1
            are used as the output height and width.
        num_classes: number of output channels of the segmentation head.

    Returns:
        A model mapping images to per-pixel softmax scores (softmax over the
        last axis), resized to ``(input_shape[0], input_shape[1])``.

    Raises:
        ValueError: if ``variant`` is not a key of ``MODEL_CONFIGS``.
    """
    # Validate before touching Keras so a typo fails fast with a clear message.
    if variant not in MODEL_CONFIGS:
        raise ValueError(
            f"Unknown SegFormer variant {variant!r}; "
            f"expected one of {sorted(MODEL_CONFIGS)}"
        )
    config = MODEL_CONFIGS[variant]

    input_layer = keras.layers.Input(shape=input_shape)
    # NOTE(review): index 1 is the width for (H, W, C) inputs; kept unchanged.
    # img_size does not appear to influence layer shapes.
    x = MixVisionTransformer(
        img_size=input_shape[1],
        embed_dims=config["embed_dims"],
        depths=config["depths"],
    )(input_layer)
    x = SegFormerHead(
        num_classes=num_classes,
        decode_dim=config["decode_dim"],
    )(x)

    # The head predicts at the resolution of the first encoder stage;
    # upsample back to the input resolution before the softmax.
    x = ResizeLayer(input_shape[0], input_shape[1])(x)
    x = ops.softmax(x)
    return keras.Model(inputs=input_layer, outputs=x)


def SegFormer_B0(input_shape, num_classes):
    """SegFormer with the ``mit_b0`` configuration (see ``build_segformer``)."""
    return build_segformer("mit_b0", input_shape, num_classes)


def SegFormer_B1(input_shape, num_classes):
    """SegFormer with the ``mit_b1`` configuration (see ``build_segformer``)."""
    return build_segformer("mit_b1", input_shape, num_classes)


def SegFormer_B2(input_shape, num_classes):
    """SegFormer with the ``mit_b2`` configuration (see ``build_segformer``)."""
    return build_segformer("mit_b2", input_shape, num_classes)


def SegFormer_B3(input_shape, num_classes):
    """SegFormer with the ``mit_b3`` configuration (see ``build_segformer``)."""
    return build_segformer("mit_b3", input_shape, num_classes)


def SegFormer_B4(input_shape, num_classes):
    """SegFormer with the ``mit_b4`` configuration (see ``build_segformer``)."""
    return build_segformer("mit_b4", input_shape, num_classes)


def SegFormer_B5(input_shape, num_classes):
    """SegFormer with the ``mit_b5`` configuration (see ``build_segformer``)."""
    return build_segformer("mit_b5", input_shape, num_classes)

In [None]:
import keras
from keras import ops
import tensorflow as tf

class ResizeLayer(keras.layers.Layer):
    """Resize a batch of images to a fixed ``(height, width)`` with bilinear interpolation.

    Uses ``keras.ops.image.resize`` (default data format).
    """

    def __init__(self, height, width, **kwargs):
        super().__init__(**kwargs)
        self.height = height
        self.width = width

    def call(self, inputs):
        target_size = (self.height, self.width)
        return ops.image.resize(inputs, size=target_size, interpolation="bilinear")


class DropPath(keras.layers.Layer):
    """Stochastic depth: randomly zero an entire sample's branch output in training.

    During training, each sample along axis 0 is kept with probability
    ``keep_prob = 1 - drop_path`` and scaled by ``1 / keep_prob``; otherwise
    it is zeroed. Outside training the input is returned unchanged.

    Args:
        drop_path: probability of dropping a sample's branch output. At 1.0
            every sample is dropped and the output is all zeros.
    """

    def __init__(self, drop_path, **kwargs):
        super().__init__(**kwargs)
        self.drop_path = drop_path

    def call(self, x, training=None):
        if not training:
            return x
        keep_prob = tf.cast(1.0 - self.drop_path, dtype=x.dtype)
        # One random value per sample, broadcast over all remaining axes.
        shape = (ops.shape(x)[0],) + (1,) * (len(ops.shape(x)) - 1)
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
        random_tensor = tf.floor(
            keep_prob + tf.random.uniform(shape, 0, 1, dtype=x.dtype)
        )
        # divide_no_nan: with keep_prob == 0 (drop_path == 1), x / 0 would be
        # inf and 0 * inf = NaN; this yields zeros instead.
        return tf.math.divide_no_nan(x, keep_prob) * random_tensor
        
