In [1]:
# # https://github.com/maxw1489/Mask_RCNN (tensorflow 2.9.1)

In [2]:
import torch
import torch.nn as nn

from timm.layers import SqueezeExcite

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Input / stem configuration.
img_size = 256
patch_size = 16
frozen_stages = 0
in_chans = 3

# Per-stage settings; the lists below are zipped together when the stages are built.
embed_dim = [224, 336, 448]
partial_dim = [48, 72, 96]  # partial_dim = r*embed_dim with r=1/4.67
qk_dim = [16, 16, 16]
depth = [4, 7, 6]
types = ["i", "s", "s"]  # "i": block without attention mixer, "s": block with SHSA mixer
down_ops = [
    ['subsample', 2],
    ['subsample', 2],
    [''],
]

# Remaining options (not read by the cells visible in this notebook section).
pretrained = None
distillation = False

In [4]:
class GroupNorm(torch.nn.GroupNorm):
    """
    torch.nn.GroupNorm pinned to a single group.

    Args:
        num_channels: channel count forwarded to torch.nn.GroupNorm.
        **kwargs: any further torch.nn.GroupNorm options, passed through as-is.

    Input: tensor in shape [B, C, H, W]
    """
    def __init__(self, num_channels, **kwargs):
        single_group = 1
        super().__init__(single_group, num_channels, **kwargs)

In [5]:
class Conv2d_BN(torch.nn.Sequential):
    """Bias-free Conv2d ('c') followed by BatchNorm2d ('bn').

    Args:
        a: input channels of the convolution.
        b: output channels of the convolution (and BatchNorm features).
        ks, stride, pad, dilation, groups: forwarded positionally to
            torch.nn.Conv2d as kernel_size, stride, padding, dilation, groups.
        bn_weight_init: constant used to initialise the BatchNorm weight;
            the BatchNorm bias is always initialised to 0.
    """
    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1):
        super().__init__()
        self.add_module('c', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', torch.nn.BatchNorm2d(b))
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
    def fuse(self):
        """Fold the BatchNorm running statistics into one biased Conv2d.

        Returns:
            A new torch.nn.Conv2d with the same stride/padding/dilation/groups
            as 'c', whose weight and bias reproduce conv + BN as computed from
            the BN running statistics (i.e. the eval-mode behaviour).
        """
        c, bn = self._modules.values()
        std = (bn.running_var + bn.eps)**0.5
        w = c.weight * (bn.weight / std)[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / std
        # Build the fused conv on the source conv's device and dtype; with the
        # defaults (CPU / default dtype) a module that was moved or cast would
        # get back a conv that no longer matches its inputs.
        m = torch.nn.Conv2d(
            w.size(1) * c.groups, w.size(0), w.shape[2:],
            stride=c.stride, padding=c.padding, dilation=c.dilation,
            groups=c.groups, device=c.weight.device, dtype=c.weight.dtype)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

In [6]:
class BN_Linear(torch.nn.Sequential):
    """BatchNorm1d ('bn') followed by a Linear layer ('l').

    Args:
        a: input features (BatchNorm features and Linear in_features).
        b: output features of the Linear layer.
        bias: whether the Linear layer has a bias; when present it is
            initialised to 0.
        std: standard deviation for the truncated-normal init of the
            Linear weight.
    """
    def __init__(self, a, b, bias=True, std=0.02):
        super().__init__()
        self.add_module('bn', torch.nn.BatchNorm1d(a))
        self.add_module('l', torch.nn.Linear(a, b, bias=bias))
        # The notebook never imports a bare `trunc_normal_`; use the torch
        # initialiser so constructing this module does not raise NameError.
        torch.nn.init.trunc_normal_(self.l.weight, std=std)
        if bias:
            torch.nn.init.constant_(self.l.bias, 0)

    @torch.no_grad()
    def fuse(self):
        """Fold the BatchNorm running statistics into one biased Linear layer.

        Returns:
            A new torch.nn.Linear (always with a bias) reproducing BN + Linear
            as computed from the BN running statistics.
        """
        bn, l = self._modules.values()
        std = (bn.running_var + bn.eps)**0.5
        w = bn.weight / std
        b = bn.bias - bn.running_mean * bn.weight / std
        w = l.weight * w[None, :]
        if l.bias is None:
            b = b @ l.weight.T
        else:
            b = (l.weight @ b[:, None]).view(-1) + l.bias
        # Keep the fused layer on the same device and dtype as the source layer.
        m = torch.nn.Linear(w.size(1), w.size(0),
                            device=l.weight.device, dtype=l.weight.dtype)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

In [7]:
class PatchMerging(torch.nn.Module):
    """Channel-changing downsample block.

    Pipeline: 1x1 Conv2d_BN expanding to 4*dim channels -> ReLU ->
    3x3 stride-2 depthwise Conv2d_BN -> ReLU -> SqueezeExcite ->
    1x1 Conv2d_BN projecting to out_dim channels.

    Args:
        dim: input channels.
        out_dim: output channels.
    """
    def __init__(self, dim, out_dim):
        super().__init__()
        hid_dim = int(dim * 4)
        self.conv1 = Conv2d_BN(dim, hid_dim, ks=1, stride=1, pad=0)
        self.act = torch.nn.ReLU()
        # groups == channels makes this a depthwise convolution.
        self.conv2 = Conv2d_BN(hid_dim, hid_dim, ks=3, stride=2, pad=1, groups=hid_dim)
        self.se = SqueezeExcite(hid_dim, .25)
        self.conv3 = Conv2d_BN(hid_dim, out_dim, ks=1, stride=1, pad=0)

    def forward(self, x):
        expanded = self.act(self.conv1(x))
        reduced = self.act(self.conv2(expanded))
        gated = self.se(reduced)
        return self.conv3(gated)

In [8]:
class Residual(torch.nn.Module):
    """Skip connection around a module: returns x + m(x).

    Args:
        m: wrapped module.
        drop: when > 0 and the module is in training mode, the branch output
            is multiplied by a per-sample random 0/1 mask rescaled by
            1 / (1 - drop) before being added to x.
    """
    def __init__(self, m, drop=0.):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        branch = self.m(x)
        if not (self.training and self.drop > 0):
            return x + branch
        # One random draw per batch element, broadcast over the other three dims.
        keep = torch.rand(x.size(0), 1, 1, 1, device=x.device)
        keep = keep.ge_(self.drop).div(1 - self.drop).detach()
        return x + branch * keep

    @torch.no_grad()
    def fuse(self):
        """Fold the skip connection into a fused depthwise conv when possible.

        Returns the fused Conv2d when the wrapped module is a Conv2d_BN,
        otherwise returns this module unchanged.
        """
        if not isinstance(self.m, Conv2d_BN):
            return self
        fused = self.m.fuse()
        assert(fused.groups == fused.in_channels)
        out_c, in_c = fused.weight.shape[0], fused.weight.shape[1]
        # A 1x1 block of ones zero-padded by 1 on each side: a 3x3 kernel with
        # a single 1 at its centre, which is added onto the fused weight.
        centre = torch.ones(out_c, in_c, 1, 1)
        centre = torch.nn.functional.pad(centre, [1,1,1,1])
        fused.weight += centre.to(fused.weight.device)
        return fused

In [9]:
class FFN(torch.nn.Module):
    """Two pointwise (1x1) Conv2d_BN layers with a ReLU in between.

    Args:
        ed: input and output channels.
        h: hidden channels between the two pointwise layers.
    """
    def __init__(self, ed, h):
        super().__init__()
        self.pw1 = Conv2d_BN(ed, h)
        self.act = torch.nn.ReLU()
        # The second layer's BatchNorm weight is initialised to 0.
        self.pw2 = Conv2d_BN(h, ed, bn_weight_init=0)

    def forward(self, x):
        hidden = self.act(self.pw1(x))
        return self.pw2(hidden)

In [10]:
class SHSA(torch.nn.Module):
    """Single-Head Self-Attention

    Only the first `pdim` channels of the input go through attention; the
    remaining `dim - pdim` channels are concatenated back unchanged before
    the final projection.

    Args:
        dim: total number of input/output channels.
        qk_dim: channels of the query and of the key.
        pdim: number of leading channels that are attended (also the value width).
    """
    def __init__(self, dim, qk_dim, pdim):
        super().__init__()
        self.dim = dim
        self.pdim = pdim
        self.qk_dim = qk_dim
        self.scale = qk_dim ** -0.5

        self.pre_norm = GroupNorm(pdim)

        # One pointwise layer emits query, key and value stacked along channels.
        self.qkv = Conv2d_BN(pdim, qk_dim * 2 + pdim)
        self.proj = torch.nn.Sequential(
            torch.nn.ReLU(),
            Conv2d_BN(dim, dim, bn_weight_init = 0),
        )

    def forward(self, x):
        B, _, H, W = x.shape
        attended, passthrough = torch.split(x, [self.pdim, self.dim - self.pdim], dim = 1)

        attended = self.pre_norm(attended)
        q, k, v = self.qkv(attended).split([self.qk_dim, self.qk_dim, self.pdim], dim = 1)
        # Collapse the two spatial dims into one token axis: [B, channels, H*W].
        q = q.flatten(2)
        k = k.flatten(2)
        v = v.flatten(2)

        scores = (q.transpose(-2, -1) @ k) * self.scale
        weights = scores.softmax(dim = -1)
        attended = (v @ weights.transpose(-2, -1)).reshape(B, self.pdim, H, W)

        return self.proj(torch.cat([attended, passthrough], dim = 1))

In [11]:
class BasicBlock(torch.nn.Module):
    """Depthwise-conv -> token mixer -> FFN block, each with a skip connection.

    Args:
        dim: number of channels.
        qk_dim: query/key channels passed to SHSA (used for type "s" only).
        pdim: attended channels passed to SHSA (used for type "s" only).
        type: "s" uses a residual SHSA mixer (for later stages); "i" uses an
            identity mixer (for early stages).

    Raises:
        ValueError: if `type` is neither "s" nor "i". Previously such a block
            was built without any sub-modules and only failed later in forward.
    """
    def __init__(self, dim, qk_dim, pdim, type):
        super().__init__()
        if type not in ("s", "i"):
            raise ValueError(
                f"BasicBlock type must be 's' or 'i', got {type!r}")
        # Sub-modules are created in conv -> mixer -> ffn order for both types.
        self.conv = Residual(Conv2d_BN(dim, dim, 3, 1, 1, groups = dim, bn_weight_init = 0))
        if type == "s":    # for later stages
            self.mixer = Residual(SHSA(dim, qk_dim, pdim))
        else:              # "i": for early stages
            self.mixer = torch.nn.Identity()
        self.ffn = Residual(FFN(dim, int(dim * 2)))

    def forward(self, x):
        return self.ffn(self.mixer(self.conv(x)))

In [12]:
# Random 3-channel 512x512 tensor used to trace feature-map shapes.
input_image = torch.randn(1, 3, 512, 512)

# Stem: four 3x3 stride-2 Conv2d_BN layers with a ReLU between consecutive
# layers, widening from in_chans to embed_dim[0].
stem_widths = [in_chans] + [embed_dim[0] // div for div in (8, 4, 2)] + [embed_dim[0]]
stem_layers = []
for c_in, c_out in zip(stem_widths[:-1], stem_widths[1:]):
    if stem_layers:
        stem_layers.append(torch.nn.ReLU())
    stem_layers.append(Conv2d_BN(c_in, c_out, 3, 2, 1))
patch_embed = torch.nn.Sequential(*stem_layers)

x = patch_embed(input_image)
print("patch_embed: ", x.shape)

patch_embed:  torch.Size([1, 224, 32, 32])


In [13]:
# Build the three stages and run the patch-embedded features `x` through them,
# collecting each stage's output in `outs`.
blocks1 = []
blocks2 = []
blocks3 = []
outs = []

# Address the per-stage lists by index instead of resolving their variable
# names through eval().
stage_lists = [blocks1, blocks2, blocks3]

for i, (ed, kd, pd, dpth, do, t) in enumerate(zip(embed_dim, qk_dim, partial_dim, depth, down_ops, types)):
    print (i, ed, kd, pd, dpth, do, t)
    stage_lists[i].extend(BasicBlock(ed, kd, pd, t) for _ in range(dpth))
    if do[0] == 'subsample':
        # Build SHViT downsample block
        # ('Subsample' stride)
        # The downsample layers are placed at the front of the NEXT stage.
        blk = stage_lists[i + 1]
        blk.append(torch.nn.Sequential(Residual(Conv2d_BN(embed_dim[i], embed_dim[i], 3, 1, 1, groups=embed_dim[i])),
                            Residual(FFN(embed_dim[i], int(embed_dim[i] * 2))),))
        blk.append(PatchMerging(*embed_dim[i:i + 2]))

        blk.append(torch.nn.Sequential(Residual(Conv2d_BN(embed_dim[i + 1], embed_dim[i + 1], 3, 1, 1, groups=embed_dim[i + 1])),
                            Residual(FFN(embed_dim[i + 1], int(embed_dim[i + 1] * 2))),))

print(len(blocks1),len(blocks2),len(blocks3)) # 4, 7+3, 6+3

blocks1 = torch.nn.Sequential(*blocks1)
blocks2 = torch.nn.Sequential(*blocks2)
blocks3 = torch.nn.Sequential(*blocks3)

# Run the stages in order; each stage consumes the previous stage's output.
for stage_no, stage in enumerate((blocks1, blocks2, blocks3), start=1):
    print(f"block{stage_no} in : ", x.shape)
    x = stage(x)
    outs.append(x)
    print(f"block{stage_no} out: ", x.shape)

0 224 16 48 4 ['subsample', 2] i
1 336 16 72 7 ['subsample', 2] s
2 448 16 96 6 [''] s
4 10 9
block1 in :  torch.Size([1, 224, 32, 32])
block1 out:  torch.Size([1, 224, 32, 32])
block2 in :  torch.Size([1, 224, 32, 32])
block2 out:  torch.Size([1, 336, 16, 16])
block3 in :  torch.Size([1, 336, 16, 16])
block3 out:  torch.Size([1, 448, 8, 8])


In [14]:
# Integer side lengths obtained by dividing 512 by 16, 32 and 64.
512//16, 512//32, 512//64

(32, 16, 8)

In [15]:
# Gather the shape of every tensor collected in `outs` and show them together.
shapes = [feature_map.shape for feature_map in outs]
print(shapes)

[torch.Size([1, 224, 32, 32]), torch.Size([1, 336, 16, 16]), torch.Size([1, 448, 8, 8])]
