In [1]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # https://stackoverflow.com/questions/35911252/disable-tensorflow-debugging-information

import ast
import math
import numpy as np

import tensorflow as tf
import tensorflow.keras as keras

import keras.layers as KL

In [2]:
img_size = 512
patch_size=16
frozen_stages = 0
in_chans = 3
embed_dim = [224, 336, 448]
partial_dim = [48, 72, 96] # partial_dim = r*embed_dim with r=1/4.67
qk_dim = [16, 16, 16]
depth = [4, 7, 6]
types = ["i", "s", "s"]
down_ops = [['subsample', 2], ['subsample', 2], ['']]
pretrained = None
distillation = False

train_bn = False

In [3]:
class GroupNorm(KL.Layer):
    """
    This implementation assumes the input tensor shape is [B, H, W, C], which is typical in TensorFlow/Keras, as opposed to [B, C, H, W] in PyTorch.
    The GroupNorm here normalizes over spatial dimensions (height and width) while keeping the channel dimension intact.
    mean and variance are computed across the spatial dimensions.
    """
    def __init__(self, num_channels, num_groups=1, **kwargs):
        super(GroupNorm, self).__init__(**kwargs)
        self.num_channels = num_channels
        self.num_groups = num_groups
        self.epsilon = 1e-5

    def call(self, inputs):
        # Reshape input to (B, H, W, C)
        inputs = tf.convert_to_tensor(inputs)
        B, H, W, C = tf.shape(inputs)[0], tf.shape(inputs)[1], tf.shape(inputs)[2], self.num_channels
        
        # Reshape for group normalization
        inputs = tf.reshape(inputs, (B, H, W, self.num_groups, C // self.num_groups))
        
        # Calculate mean and variance for each group
        mean, variance = tf.nn.moments(inputs, axes=[1, 2, 4], keepdims=True)

        # Normalize
        normalized = (inputs - mean) / tf.sqrt(variance + self.epsilon)

        # Reshape back to original dimensions
        normalized = tf.reshape(normalized, (B, H, W, C))

        return normalized

In [4]:
class Conv2d_BN(KL.Layer):
    def __init__(self, a=0, filters=16, kernel_size=1, strides=1, padding='same', dilation_rate=1, groups=1, use_bn=True, activation="relu"):
        super(Conv2d_BN, self).__init__()
        self.conv = KL.Conv2D(filters, kernel_size, strides=strides, padding=padding, dilation_rate=1, groups=1, use_bias=not use_bn)
        self.bn = KL.BatchNormalization() if use_bn else None
        self.activation = KL.Activation(activation) if activation else None

    def call(self, x, training=False):
        x = self.conv(x)
        if self.bn:
            x = self.bn(x, training=training)
        if self.activation:
            x = self.activation(x)
        return x

In [5]:
# copy from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/helpers.py
def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < round_limit * v:
        new_v += divisor
    return new_v

# translated with ChatGPT (original from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/squeeze_excite.py)
class SqueezeExcite(KL.Layer):
    def __init__(self, channels, rd_ratio=1/16, rd_channels=None, rd_divisor=8, add_maxpool=False,
                 bias=True, act_layer="relu", norm_layer=None, gate_layer="sigmoid"):
        super(SqueezeExcite, self).__init__()
        self.add_maxpool = add_maxpool
        if not rd_channels:
            rd_channels = make_divisible(int(channels * rd_ratio), rd_divisor)
        
        self.fc1 = KL.Conv2D(filters=rd_channels, kernel_size=1, use_bias=bias)
        self.bn = norm_layer() if norm_layer else KL.Lambda(lambda x: x)
        self.act = KL.Activation(act_layer)
        self.fc2 = KL.Conv2D(filters=channels, kernel_size=1, use_bias=bias)
        self.gate = KL.Activation(gate_layer)

    def call(self, x):
        x_se = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        if self.add_maxpool:
            x_se = 0.5 * x_se + 0.5 * tf.reduce_max(x, axis=[1, 2], keepdims=True)
        x_se = self.fc1(x_se)
        x_se = self.act(self.bn(x_se))
        x_se = self.fc2(x_se)
        return x * self.gate(x_se)

In [6]:
class PatchMerging(KL.Layer):
    """
    Initializes three Conv2d_BN layers and an activation function.
    The call method processes the input through these layers sequentially, applying ReLU activations and squeeze-and-excitation before the final convolution.
    """
    def __init__(self, dim, out_dim, train_bn=True, **kwargs):
        super(PatchMerging, self).__init__(**kwargs)
        hid_dim = int(dim * 4)
        self.conv1 = Conv2d_BN(a=dim, filters=hid_dim, kernel_size=1, use_bn=train_bn)
        self.act = keras.activations.relu
        self.conv2 = Conv2d_BN(a=hid_dim, filters=hid_dim, kernel_size=3, strides=2, padding="same", groups=hid_dim, use_bn=train_bn)
        self.se = SqueezeExcite(channels=hid_dim, rd_ratio=0.25)
        self.conv3 = Conv2d_BN(a=hid_dim, filters=out_dim, kernel_size=1, use_bn=train_bn)

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.act(x)
        x = self.conv2(x)
        x = self.act(x)
        x = self.se(x)
        x = self.conv3(x)
        return x

In [7]:
class Residual(KL.Layer):
    """
    Initialization: Takes a layer m and a dropout probability drop.

    Forward Pass:
        During training, if dropout is active, a random mask is applied to the output of m.
        If not in training mode, it simply adds the output of m to the input.

    Fusion:
        If m is a Conv2d_BN instance, it fuses the convolution and batch normalization layers.
        An identity tensor is created, padded, and added to the convolution weights.
    """
    def __init__(self, m, drop=0.0, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.m = m
        self.drop = drop

    def call(self, inputs, use_bn=None):
        if use_bn and self.drop > 0:
            # Generate a random mask for dropout
            rand_tensor = tf.random.uniform((tf.shape(inputs)[0], 1, 1, 1), 0, 1)
            mask = tf.cast(rand_tensor >= self.drop, tf.float32) / (1 - self.drop)
            return inputs + self.m(inputs) * mask
        else:
            return inputs + self.m(inputs)

In [8]:
class FFN(tf.keras.layers.Layer):
    """
    Initialization: Initializes two Conv2d_BN layers (pointwise convolutions) and a ReLU activation.
    Forward Pass: Applies the first convolution, then the ReLU activation, and finally the second convolution, returning the output.
    """
    def __init__(self, ed, h, train_bn=True, **kwargs):
        super(FFN, self).__init__(**kwargs)
        self.pw1 = Conv2d_BN(a=ed, filters=h, use_bn=train_bn)  # First pointwise convolution with BN
        self.act = keras.activations.relu  # ReLU activation
        self.pw2 = Conv2d_BN(a=h, filters=ed, use_bn=train_bn)  # Second pointwise convolution with BN

    def call(self, inputs):
        x = self.pw1(inputs)
        x = self.act(x)
        x = self.pw2(x)
        return x

In [9]:
class SHSA(KL.Layer):
    """Single-Head Self-Attention"""
    """
    Initialization: 
        Iitializes scaling factor, dimensions, normalization layer, and the query-key-value (QKV) convolutional layer, along with a projection layer.

    Forward Pass:
        Splits the input into two parts, applies normalization, and computes QKV.
        Flattens Q, K, and V, calculates the attention scores, applies softmax, and reshapes the result.
        Concatenates the processed part with the other part and passes it through the projection layer.
    """
    def __init__(self, dim, qk_dim, pdim, train_bn=True, **kwargs):
        super(SHSA, self).__init__(**kwargs)

        self.scale = qk_dim ** -0.5
        self.qk_dim = qk_dim
        self.dim = dim
        self.pdim = pdim

        self.pre_norm = GroupNorm(num_channels=pdim)  # Assuming GroupNorm is defined
        self.qkv = Conv2d_BN(a=pdim, filters=qk_dim * 2 + pdim, use_bn=train_bn)  # Conv2d_BN layer
        self.proj = tf.keras.Sequential([
            keras.layers.ReLU(),
            Conv2d_BN(a=dim, filters=dim, use_bn=train_bn)  # Another Conv2d_BN layer
        ])

    def call(self, x):
        B, H, W, C = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
        x1, x2 = tf.split(x, [self.pdim, self.dim - self.pdim], axis=-1)
        x1 = self.pre_norm(x1)
        
        qkv = self.qkv(x1)
        q, k, v = tf.split(qkv, [self.qk_dim, self.qk_dim, self.pdim], axis=-1)
        
        q = tf.reshape(q, (B, -1, self.qk_dim))
        k = tf.reshape(k, (B, -1, self.qk_dim))
        v = tf.reshape(v, (B, -1, self.pdim))

        attn = tf.matmul(q, k, transpose_b=True) * self.scale
        attn = tf.nn.softmax(attn, axis=-1)

        x1 = tf.reshape(tf.matmul(attn, v), (B, H, W, self.pdim))
        x = self.proj(tf.concat([x1, x2], axis=-1))

        return x

In [10]:
class BasicBlock(KL.Layer):
    """
    Initialization:
        For "s" (later stages): Initializes convolution, self-attention mixer, and feed-forward network (FFN) wrapped in residuals.
        For "i" (early stages): Initializes convolution and FFN as before but uses an identity layer for the mixer.

    Forward Pass:
        Calls the convolution layer, the mixer, and the feed-forward network sequentially, returning the output.
    """
    def __init__(self, dim, qk_dim, pdim, block_type, train_bn=True, **kwargs):
        super(BasicBlock, self).__init__(**kwargs)
        if block_type == "s":  # for later stages
            self.conv = Residual(Conv2d_BN(a=dim, filters=dim, kernel_size=3, strides=1, padding="same", groups=dim, use_bn=train_bn))
            self.mixer = Residual(SHSA(dim=dim, qk_dim=qk_dim, pdim=pdim, train_bn=train_bn))
            self.ffn = Residual(FFN(ed=dim, h=int(dim * 2), train_bn=train_bn))
        elif block_type == "i":  # for early stages
            self.conv = Residual(Conv2d_BN(a=dim, filters=dim, kernel_size=3, strides=1, padding="same", groups=dim, use_bn=train_bn))
            self.mixer = KL.Layer()  # Identity layer
            self.ffn = Residual(FFN(ed=dim, h=int(dim * 2), train_bn=train_bn))

    def call(self, x):
        return self.ffn(self.mixer(self.conv(x)))

In [11]:
# input image [batch, height, width, channels]
input_image = tf.random.normal([1, 512, 512, 3])

blocks1 = []
blocks2 = []
blocks3 = []
outs = []


# 16x16 Overlap PatchEmbed (Fig 2 and Fig 5), in SHViT-pytorch version of the code, each Conv2D is followed with BatchNormalization 
x = Conv2d_BN(filters=embed_dim[0]//8, kernel_size=3, strides=2, padding="same", use_bn=train_bn, activation="ReLU")(input_image)
x = Conv2d_BN(filters=embed_dim[0]//4, kernel_size=3, strides=2, padding="same", use_bn=train_bn, activation="ReLU")(x)
x = Conv2d_BN(filters=embed_dim[0]//2, kernel_size=3, strides=2, padding="same", use_bn=train_bn, activation="ReLU")(x)
outs.append(x)
outs.append(x)
x = Conv2d_BN(filters=embed_dim[0]//1, kernel_size=3, strides=2, padding="same", use_bn=train_bn, activation="ReLU")(x)

x.shape # [1, 32, 32, 224] represents the spatial grid of 16x16 patches (32 patches along each dimension) with each patch transformed into a 224-dimensional vector)

TensorShape([1, 32, 32, 224])

In [12]:
for i, (ed, kd, pd, dpth, do, t) in enumerate(zip(embed_dim, qk_dim, partial_dim, depth, down_ops, types)):
    print (i, ed, kd, pd, dpth, do, t)
    for d in range(dpth):
        eval("blocks" + str(i+1)).append(BasicBlock(ed, kd, pd, t))
    if do[0] == "subsample":
        # Build SHViT downsample block  
        blk = eval("blocks" + str(i+2))
        blk.append(keras.Sequential([
                    Residual(Conv2d_BN(a=embed_dim[i], filters=embed_dim[i], kernel_size=3, strides=1, padding="same", groups=embed_dim[i], use_bn=train_bn)),
                    Residual(FFN(ed=embed_dim[i], h=int(embed_dim[i] * 2), train_bn=train_bn)),
                ]))
        blk.append(PatchMerging(dim=embed_dim[i], out_dim=embed_dim[i + 1], train_bn=train_bn))
        blk.append(keras.Sequential([
                    Residual(Conv2d_BN(a=embed_dim[i + 1], filters=embed_dim[i + 1], kernel_size=3, strides=1, padding="same", groups=embed_dim[i + 1], use_bn=train_bn)),
                    Residual(FFN(ed=embed_dim[i + 1], h=int(embed_dim[i + 1] * 2), train_bn=train_bn)),
                ]))

print(eval('blocks1'))

print(len(blocks1),len(blocks2),len(blocks3)) # 4, 7+3, 6+3

blocks1 = tf.keras.Sequential(blocks1)
blocks2 = tf.keras.Sequential(blocks2)
blocks3 = tf.keras.Sequential(blocks3)

print("block1 in : ", x.shape)
x = blocks1(x)
outs.append(x)
print("block1 out: ", x.shape)
x = blocks2(x)
outs.append(x)
print("block2 out: ", x.shape)
x = blocks3(x)
outs.append(x)
print("block3 out: ", x.shape)

0 224 16 48 4 ['subsample', 2] i
1 336 16 72 7 ['subsample', 2] s
2 448 16 96 6 [''] s
[<__main__.BasicBlock object at 0x7fe5aa99c220>, <__main__.BasicBlock object at 0x7fe71c33e830>, <__main__.BasicBlock object at 0x7fe5aa9e31c0>, <__main__.BasicBlock object at 0x7fe5aa9f28c0>]
4 10 9
block1 in :  (1, 32, 32, 224)
block1 out:  (1, 32, 32, 224)
block2 out:  (1, 16, 16, 336)
block3 out:  (1, 8, 8, 448)


In [13]:
shapes = [tensor.shape for tensor in outs]
print(shapes)

_, C2, C3, C4, C5 = outs

[TensorShape([1, 64, 64, 112]), TensorShape([1, 64, 64, 112]), TensorShape([1, 32, 32, 224]), TensorShape([1, 16, 16, 336]), TensorShape([1, 8, 8, 448])]


### FPN

In [14]:
class Config(object):
    # Size of the top-down layers used to build the feature pyramid
    TOP_DOWN_PYRAMID_SIZE = 256
    
    # The strides of each layer of the FPN Pyramid. These values
    # are based on a SHViT  backbone.
    BACKBONE_STRIDES = [16, 32, 64, 128] # 128 added as they add 64 in the paper (original strides were [4,8,16,32] -> [4,8,16,32,64], maybe because P5 was upsampled by factor 2??)
    
    # Length of square anchor side in pixels
    RPN_ANCHOR_SCALES = (32, 64, 128, 256) 

    # Ratios of anchors at each cell (width/height)
    # A value of 1 represents a square anchor, and 0.5 is a wide anchor
    RPN_ANCHOR_RATIOS = [0.5, 1, 2]

    # Anchor stride
    # If 1 then anchors are created for each cell in the backbone feature map.
    # If 2, then anchors are created for every other cell, and so on.
    RPN_ANCHOR_STRIDE = 1    
    


config = Config()

In [15]:
P4 = KL.Conv2D(filters=config.TOP_DOWN_PYRAMID_SIZE, kernel_size=(1, 1), name='fpn_c4p4')(C4)
P3 = KL.Add(name="fpn_p3add")([
            KL.UpSampling2D(size=(2, 2), name="fpn_p4upsampled")(P4),
            KL.Conv2D(filters=config.TOP_DOWN_PYRAMID_SIZE, kernel_size=(1, 1), name='fpn_c3p3')(C3)])
P2 = KL.Add(name="fpn_p2add")([
            KL.UpSampling2D(size=(2, 2), name="fpn_p3upsampled")(P3),
            KL.Conv2D(filters=config.TOP_DOWN_PYRAMID_SIZE, kernel_size=(1, 1), name='fpn_c2p2')(C2)])

print(P4.shape, P3.shape, P2.shape)

(1, 16, 16, 256) (1, 32, 32, 256) (1, 64, 64, 256)


In [16]:
# Attach 3x3 conv to all P layers to get the final feature maps.

P2 = KL.Conv2D(filters=config.TOP_DOWN_PYRAMID_SIZE, kernel_size=(3, 3), padding="SAME", name="fpn_p2")(P2)
P3 = KL.Conv2D(filters=config.TOP_DOWN_PYRAMID_SIZE, kernel_size=(3, 3), padding="SAME", name="fpn_p3")(P3)
P4 = KL.Conv2D(filters=config.TOP_DOWN_PYRAMID_SIZE, kernel_size=(3, 3), padding="SAME", name="fpn_p4")(P4)
P5 = KL.MaxPooling2D(pool_size=(1, 1), strides=2, name="fpn_p5")(P4) # Here we introduce P5 only for covering a larger anchor scale of 256^2. P5 is simply a stride two subsampling of P4. (footnote page 4 of https://arxiv.org/pdf/1612.03144)

rpn_feature_maps = [P2, P3, P4, P5]
mrcnn_feature_maps = [P2, P3, P4]

print(P5.shape, P4.shape, P3.shape, P2.shape)

(1, 8, 8, 256) (1, 16, 16, 256) (1, 32, 32, 256) (1, 64, 64, 256)


In [17]:
image_shape = input_image[0, :, :, :].shape

In [18]:
def compute_backbone_shapes(config, image_shape):
    """Computes the width and height of each stage of the backbone network.

    Returns:
        [N, (height, width)]. Where N is the number of stages
    """

    x = np.array(
        [[int(math.ceil(image_shape[0] / stride)),
            int(math.ceil(image_shape[1] / stride))]
            for stride in config.BACKBONE_STRIDES])

    return x

backbone_shapes = compute_backbone_shapes(config, image_shape)
print("backbone_shapes: ", backbone_shapes)

backbone_shapes:  [[32 32]
 [16 16]
 [ 8  8]
 [ 4  4]]


In [19]:
############################################################
#  Anchors
############################################################

def generate_anchors(scales, ratios, shape, feature_stride, anchor_stride):
    """
    scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128]
    ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2]
    shape: [height, width] spatial shape of the feature map over which
            to generate anchors.
    feature_stride: Stride of the feature map relative to the image in pixels.
    anchor_stride: Stride of anchors on the feature map. For example, if the
        value is 2 then generate anchors for every other feature map pixel.
    """
    # Get all combinations of scales and ratios
    scales, ratios = np.meshgrid(np.array(scales), np.array(ratios))
    scales = scales.flatten()
    ratios = ratios.flatten()

    # Enumerate heights and widths from scales and ratios
    heights = scales / np.sqrt(ratios)
    widths = scales * np.sqrt(ratios)

    # Enumerate shifts in feature space
    shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride
    shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride
    shifts_x, shifts_y = np.meshgrid(shifts_x, shifts_y)

    # Enumerate combinations of shifts, widths, and heights
    box_widths, box_centers_x = np.meshgrid(widths, shifts_x)
    box_heights, box_centers_y = np.meshgrid(heights, shifts_y)

    # Reshape to get a list of (y, x) and a list of (h, w)
    box_centers = np.stack(
        [box_centers_y, box_centers_x], axis=2).reshape([-1, 2])
    box_sizes = np.stack([box_heights, box_widths], axis=2).reshape([-1, 2])

    # Convert to corner coordinates (y1, x1, y2, x2)
    boxes = np.concatenate([box_centers - 0.5 * box_sizes,
                            box_centers + 0.5 * box_sizes], axis=1)
    return boxes


def generate_pyramid_anchors(scales, ratios, feature_shapes, feature_strides,
                             anchor_stride):
    """Generate anchors at different levels of a feature pyramid. Each scale
    is associated with a level of the pyramid, but each ratio is used in
    all levels of the pyramid.

    Returns:
    anchors: [N, (y1, x1, y2, x2)]. All generated anchors in one array. Sorted
        with the same order of the given scales. So, anchors of scale[0] come
        first, then anchors of scale[1], and so on.
    """
    # Anchors
    # [anchor_count, (y1, x1, y2, x2)]
    print(len(scales))
    anchors = []
    for i in range(len(scales)):
        #print(i, anchors)
        anchors.append(generate_anchors(scales[i], ratios, feature_shapes[i],
                                        feature_strides[i], anchor_stride))
    return np.concatenate(anchors, axis=0)


a = generate_pyramid_anchors(
                config.RPN_ANCHOR_SCALES,
                config.RPN_ANCHOR_RATIOS,
                backbone_shapes,
                config.BACKBONE_STRIDES,
                config.RPN_ANCHOR_STRIDE)

4


In [20]:
print("RPN_ANCHOR_SCALES", config.RPN_ANCHOR_SCALES)
print("RPN_ANCHOR_RATIOS", config.RPN_ANCHOR_RATIOS)
print("backbone_shapes", backbone_shapes)
print("BACKBONE_STRIDES", config.BACKBONE_STRIDES)
print("RPN_ANCHOR_STRIDE", config.RPN_ANCHOR_STRIDE)

RPN_ANCHOR_SCALES (32, 64, 128, 256)
RPN_ANCHOR_RATIOS [0.5, 1, 2]
backbone_shapes [[32 32]
 [16 16]
 [ 8  8]
 [ 4  4]]
BACKBONE_STRIDES [16, 32, 64, 128]
RPN_ANCHOR_STRIDE 1


In [21]:
len(a)

4080