In [1]:
import torch
from torch import nn
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt
import os
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torchvision import models
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from utils import *
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch.nn import functional as F
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [2]:
def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

In [3]:
class Attention(nn.Module):
    def __init__(self, dim, factor, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim*factor),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

In [57]:
class SwinTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3,
                 embed_dim=96, depths=[2, 2, 6], num_heads=[3, 6, 12],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        
        super().__init__()
        
        patches_resolution = [img_size // patch_size, img_size // patch_size]
        num_patches = patches_resolution[0] * patches_resolution[1]
        
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample= None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}


In [58]:
SwinTransformer()

SwinTransformer(
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0): BasicLayer(
      dim=96, input_resolution=(56, 56), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          dim=96, input_resolution=(56, 56), num_heads=3, window_size=7, shift_size=0, mlp_ratio=4.0
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=96, window_size=(7, 7), num_heads=3
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=96, out_features=96, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=96, out_features=384, bias=True)
            (act): G

In [35]:
class SkipCLS(nn.Module):
    def __init__(self, dim, factor):
        super().__init__()
        self.dim = dim
        self.factor = factor
        
        self.multi_head_attention = Attention(dim=self.dim, factor=self.factor)
        
    def forward(self,embeddings, CLS):
        in_MSA = torch.cat((CLS, embeddings), dim = 1)
        MSA_out = self.multi_head_attention(in_MSA)
        return MSA_out[:,0].unsqueeze(1)

In [36]:
from ml_collections import config_dict

# # # Resnet 18
# cfg = config_dict.ConfigDict()
# cfg.cnn_backbone = "resnet18"
# cfg.cnn_pyramid_fm  = [64,128,256,512]
# cfg.swin_pyramid_fm = [96, 192, 384, 768]
# cfg.image_size = 224
# cfg.patch_size = 4
# cfg.num_classes = 1000

# Resnet 50
# cfg = config_dict.ConfigDict()
# cfg.cnn_backbone = "resnet50"
# cfg.cnn_pyramid_fm  = [256,512,1024,2048]
# cfg.swin_pyramid_fm = [96, 192, 384, 768]
# cfg.image_size = 224
# cfg.patch_size = 4
# cfg.num_classes = 1000

# # # Resnet 34
cfg = config_dict.ConfigDict()
cfg.cnn_backbone = "resnet34"
cfg.cnn_pyramid_fm  = [64,128,256,512]
cfg.swin_pyramid_fm = [96, 192, 384, 768]
cfg.image_size = 224
cfg.patch_size = 4
cfg.num_classes = 1000

In [60]:
model_path = './weights/swin_tiny_patch4_window7_224.pth'
swin_transformer = SwinTransformer(224,in_chans = 3)
checkpoint = torch.load(model_path, map_location=torch.device(device))['model']
unexpected = ["patch_embed.proj.weight", "patch_embed.proj.bias", "patch_embed.norm.weight", "patch_embed.norm.bias",
             "head.weight", "head.bias", "layers.0.downsample.norm.weight", "layers.0.downsample.norm.bias",
             "layers.0.downsample.reduction.weight", "layers.1.downsample.norm.weight", "layers.1.downsample.norm.bias",
             "layers.1.downsample.reduction.weight", "layers.2.downsample.norm.weight", "layers.2.downsample.norm.bias",
             "layers.2.downsample.reduction.weight", "norm.weight", "norm.bias"]

for key in list(checkpoint.keys()):
    if key in unexpected or 'layers.3' in key:
        del checkpoint[key]
swin_transformer.load_state_dict(checkpoint)

<All keys matched successfully>

In [39]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class PyramidFeatures(nn.Module):
    def __init__(self, img_size = 224, in_channels = 3):
        super().__init__()
        
        self.swin_transformer = SwinTransformer(img_size)
        
        model_path = './weights/swin_tiny_patch4_window7_224.pth'
        self.swin_transformer = SwinTransformer(img_size,in_chans = 3)
        checkpoint = torch.load(model_path, map_location=torch.device(device))['model']
        unexpected = ["patch_embed.proj.weight", "patch_embed.proj.bias", "patch_embed.norm.weight", "patch_embed.norm.bias",
                     "head.weight", "head.bias", "layers.0.downsample.norm.weight", "layers.0.downsample.norm.bias",
                     "layers.0.downsample.reduction.weight", "layers.1.downsample.norm.weight", "layers.1.downsample.norm.bias",
                     "layers.1.downsample.reduction.weight", "layers.2.downsample.norm.weight", "layers.2.downsample.norm.bias",
                     "layers.2.downsample.reduction.weight"]
        for key in list(checkpoint.keys()):
            if key in unexpected or 'layers.3' in key:
                del checkpoint[key]
        self.swin_transformer.load_state_dict(checkpoint)

        
        resnet = eval(f"models.{cfg.cnn_backbone}()")
        self.resnet_layers = nn.ModuleList(resnet.children())[:7]
        
        self.cls_token = nn.Parameter(torch.zeros(1,1,cfg.swin_pyramid_fm[0])) # Class Token

        self.p1_ch = nn.Conv2d(cfg.cnn_pyramid_fm[0], cfg.swin_pyramid_fm[0] , kernel_size = 1)
        self.p1_pm = PatchMerging((cfg.image_size // cfg.patch_size, cfg.image_size // cfg.patch_size), cfg.swin_pyramid_fm[0])
        self.skip_CLS_1 = SkipCLS(dim=cfg.swin_pyramid_fm[0], factor=2)

        self.p2 = self.resnet_layers[5]
        self.p2_ch = nn.Conv2d(cfg.cnn_pyramid_fm[1], cfg.swin_pyramid_fm[1] , kernel_size = 1)
        self.p2_pm = PatchMerging((cfg.image_size // cfg.patch_size // 2, cfg.image_size // cfg.patch_size // 2), cfg.swin_pyramid_fm[1])
        self.skip_CLS_2 = SkipCLS(dim=cfg.swin_pyramid_fm[1], factor=2)
        
        self.p3 = self.resnet_layers[6]
        self.p3_ch = nn.Conv2d(cfg.cnn_pyramid_fm[2] , cfg.swin_pyramid_fm[2] , kernel_size =  1)
#         self.p3_pm = PatchMerging((cfg.image_size // cfg.patch_size // 4,cfg.image_size // cfg.patch_size // 4), cfg.swin_pyramid_fm[2])
#         self.skip_CLS_3 = SkipCLS(dim=cfg.swin_pyramid_fm[2], factor=2)
        
#         self.p4 = self.resnet_layers[7]
#         self.p4_ch = nn.Conv2d(cfg.cnn_pyramid_fm[3] , cfg.swin_pyramid_fm[3] , kernel_size = 1)
        
        trunc_normal_(self.cls_token, std=.02)
        

    def forward(self, x):
        
        for i in range(5):
            x = self.resnet_layers[i](x) 
        
        B, C, H, W = x.shape
        CLS = self.cls_token.expand(B, -1, -1)
        sw1_CLS = CLS

        # 1
        fm1 = x
        fm1_ch = self.p1_ch(x)
        fm1_reshaped = Rearrange('b c h w -> b (h w) c')(fm1_ch)               
        sw1 = self.swin_transformer.layers[0](fm1_reshaped)
        sw1_skipped = fm1_reshaped  + sw1
        CLS = self.skip_CLS_1(embeddings=sw1_skipped, CLS=CLS)
        fm1_sw1 = self.p1_pm(sw1_skipped)
        
        #2
        fm1_sw2 = self.swin_transformer.layers[1](fm1_sw1)
        fm2 = self.p2(fm1)
        fm2_ch = self.p2_ch(fm2)
        fm2_reshaped = Rearrange('b c h w -> b (h w) c')(fm2_ch) 
        fm2_sw2_skipped = fm2_reshaped  + fm1_sw2
        CLS = self.skip_CLS_2(embeddings=fm2_sw2_skipped, CLS=CLS)
        fm2_sw2 = self.p2_pm(fm2_sw2_skipped)
    
        #3
        fm2_sw3 = self.swin_transformer.layers[2](fm2_sw2)
        fm3 = self.p3(fm2)
        fm3_ch = self.p3_ch(fm3)
        fm3_reshaped = Rearrange('b c h w -> b (h w) c')(fm3_ch) 
        fm3_sw3_skipped = fm3_reshaped  + fm2_sw3
#         CLS = self.skip_CLS_3(embeddings=fm3_sw3_skipped, CLS=CLS)
#         fm3_sw3 = self.p3_pm(fm3_sw3_skipped)
        
        #4
#         fm3_sw4 = self.swin_transformer.layers[3](fm3_sw3)
#         fm4 = self.p4(fm3)
#         fm4_ch = self.p4_ch(fm4)
#         fm4_reshaped = Rearrange('b c h w -> b (h w) c')(fm4_ch) 
#         fm4_sw4_skipped = fm4_reshaped  + fm3_sw4

                
        return [torch.cat((sw1_CLS, sw1_skipped), dim=1), torch.cat((CLS, fm3_sw3_skipped), dim=1)]

In [40]:
pyramid = PyramidFeatures()

RuntimeError: Error(s) in loading state_dict for SwinTransformer:
	size mismatch for norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for norm.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).

In [None]:
out = pyramid(torch.rand(2,3,224,224))

In [42]:
class All2Cross(nn.Module):
    def __init__(self, img_size = 224, in_chans=3, embed_dim=(96, 384),
                 depth=[[1, 3, 1]],
                 num_heads=(6, 12), mlp_ratio=(2., 2., 4.), qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, multi_conv=False):
        super().__init__()
        self.pyramid = PyramidFeatures(img_size= img_size, in_channels=in_chans)
        
        n_p1 = (cfg.image_size // cfg.patch_size) ** 2       # default: 3136 
        n_p2 = (cfg.image_size // cfg.patch_size // 4) ** 2  # default: 49 
        num_patches = (n_p1, n_p2)
        
        self.num_branches = 2
        
        
        total_depth = sum([sum(x[-2:]) for x in depth])
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)]  # stochastic depth decay rule
        dpr_ptr = 0
        self.blocks = nn.ModuleList()
        for idx, block_cfg in enumerate(depth):
            curr_depth = max(block_cfg[:-1]) + block_cfg[-1]
            dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth]
            blk = MultiScaleBlock(embed_dim, num_patches, block_cfg, num_heads=num_heads, mlp_ratio=mlp_ratio,
                                  qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr_,
                                  norm_layer=norm_layer)
            dpr_ptr += curr_depth
            self.blocks.append(blk)

        self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)])

        self.apply(self._init_weights)
        
    
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    
    def forward(self, x):
        xs = self.pyramid(x)

        for blk in self.blocks:
            xs = blk(xs)
        xs = [self.norm[i](x) for i, x in enumerate(xs)]
        out = [x[:, 0] for x in xs]
        
        return xs

In [43]:
# test = All2Cross()

In [44]:
# test(torch.rand(1,3,224,224))[0].shape, test(torch.rand(1,3,224,224))[1].shape

In [45]:
# (1,49,768)  >>> (1,7*7,768)  >Permute> (1,768,7,7)     
# (1,3136,96) >>> (1,56*56,96) >Permute> (1,96,56,56)

# C1 = [96,56,56]  >conv -> [128,56,56] 
# C2 = [768,7,7]   >conv-2x -> [128,14,14] >conv-2x -> [128,28,28] >conv-2x -> [128,56,56]

# C = C1 + C2 >>> [128,56,56]
# [128,56,56] > Conv-4x -> [classes, 224,224]

In [46]:
class ConvUpsample(nn.Module):
    def __init__(self, in_chans=384, out_chans=[128], upsample=True):
        super().__init__()
        self.in_chans = in_chans
        self.out_chans = out_chans
        
        self.conv_tower = nn.ModuleList()
        for i, out_ch in enumerate(self.out_chans):
            if i>0: self.in_chans = out_ch
            self.conv_tower.append(nn.Conv2d(
                self.in_chans, out_ch,
                kernel_size=3, stride=1,
                padding=1, bias=False
            ))
            self.conv_tower.append(nn.GroupNorm(32, out_ch))
            self.conv_tower.append(nn.ReLU(inplace=False))
            if upsample:
                self.conv_tower.append(nn.Upsample(
                        scale_factor=2, mode='bilinear', align_corners=False))
            
        self.convs_level = nn.Sequential(*self.conv_tower)
        
    def forward(self, x):
        return self.convs_level(x)

In [47]:
class SegmentationHead(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        super().__init__(conv2d)


class SwinRetina_V2(nn.Module):
    def __init__(self, img_size=224, in_chans=3, n_classes=9):
        super().__init__()
        self.img_size = img_size
        self.patch_size = [4, 16]
        self.n_classes = n_classes
        self.All2Cross = All2Cross( img_size= img_size, in_chans=in_chans)
        
        self.ConvUp_s = ConvUpsample(in_chans=384, out_chans=[128,128])
        self.ConvUp_l = ConvUpsample(in_chans=96, out_chans=[128], upsample=False)
        
        self.segmentation_head = SegmentationHead(
            in_channels=16,
            out_channels=n_classes,
            kernel_size=3,
        )    

        self.conv_pred = nn.Sequential(
            nn.Conv2d(
                128, 16,
                kernel_size=1, stride=1,
                padding=0, bias=True),
#             nn.GroupNorm(8, self.n_classes), 
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        )
    
    def forward(self, x):
        xs = self.All2Cross(x)
        embeddings = [x[:, 1:] for x in xs]
        reshaped_embed = []
        for i, embed in enumerate(embeddings):

            embed = Rearrange('b (h w) d -> b d h w', h=(self.img_size//self.patch_size[i]), w=(self.img_size//self.patch_size[i]))(embed)
            embed = self.ConvUp_l(embed) if i == 0 else self.ConvUp_s(embed)
            
            reshaped_embed.append(embed)
        C = reshaped_embed[0] + reshaped_embed[1]
        C = self.conv_pred(C)

        out = self.segmentation_head(C)
        
        return out  

In [48]:
test_swin_retina = SwinRetina_V2(n_classes=9)

In [49]:
out_test = test_swin_retina(torch.rand(1,3,224,224))

In [50]:
out_test.shape

torch.Size([1, 9, 224, 224])

In [32]:
swin_par = get_n_params(PyramidFeatures().swin_transformer)
print('# Swin:', swin_par)

res_par = get_n_params(PyramidFeatures().resnet_layers)
remained_par = get_n_params(PyramidFeatures()) - swin_par - res_par
print('# Resnet:', res_par)
print('# SkipCLS + PM:', remained_par)


crossvit_par = get_n_params(All2Cross()) - res_par - swin_par
print('# CrossViT:', crossvit_par)

decoder_par = get_n_params(SwinRetina_V2()) - crossvit_par - res_par - swin_par
print('# Decoder:', decoder_par)

print('Total:', get_n_params(SwinRetina_V2(n_classes=9)))

# Swin: 11776170
# Resnet: 8170304
# SkipCLS + PM: 1238592
# CrossViT: 5640672
# Decoder: 704553
Total: 26291699


In [80]:
swin_par = get_n_params(PyramidFeatures().swin_transformer)
print('# Swin:', swin_par)

res_par = get_n_params(PyramidFeatures().resnet_layers)
remained_par = get_n_params(PyramidFeatures()) - swin_par - res_par
print('# Resnet:', res_par)
print('# SkipCLS + PM:', remained_par)


crossvit_par = get_n_params(All2Cross()) - res_par - swin_par
print('# CrossViT:', crossvit_par)

decoder_par = get_n_params(SwinRetina_V2()) - crossvit_par - res_par - swin_par
print('# Decoder:', decoder_par)

print('Total:', swin_par + res_par + crossvit_par + decoder_par)

# Swin: 18854826
# Resnet: 21284672
# SkipCLS + PM: 3799104
# CrossViT: 5024160
# Decoder: 631081
Total: 45794739
