In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from swin_decoder import SwinTransDecoder
from _base import BaseModel2D
from timm.models.layers import trunc_normal_
from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.unet.decoder import UnetDecoder
from segmentation_models_pytorch.base import SegmentationHead, ClassificationHead


def kaiming_normal_init_weight(model):
    """Re-initialize, in place, the parameters of the supported layers in ``model``.

    Every sub-module yielded by ``model.modules()`` is inspected and handled
    according to its type:

    * ``nn.Conv2d``      -- weight drawn with ``kaiming_normal_``; the bias is
      left untouched.
    * ``nn.BatchNorm2d`` -- weight set to 1, bias set to 0.
    * ``nn.Linear``      -- weight drawn with ``trunc_normal_(std=.02)``; bias
      set to 0 when the layer has one.
    * ``nn.LayerNorm``   -- weight set to 1, bias set to 0.

    Normalization layers created without learnable parameters (for example
    ``nn.BatchNorm2d(..., affine=False)`` or
    ``nn.LayerNorm(..., elementwise_affine=False)``) expose ``weight`` and/or
    ``bias`` as ``None``; those missing parameters are skipped instead of
    raising. Modules of any other type are ignored.

    Args:
        model: module whose sub-modules (including itself) are re-initialized.

    Returns:
        None. ``model`` is modified in place.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)
        elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
            # Both normalization types get the identity affine transform.
            # Either parameter may be None when the layer is non-affine.
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configuration for the model-construction cell below. All names are kept
# as module-level globals because later cells read them directly.

# --- Arguments for get_encoder / UnetDecoder / SegmentationHead ---
encoder_name = "resnet50"
encoder_depth = 5
encoder_weights = "imagenet"
decoder_use_batchnorm = True
decoder_channels = (256, 128, 64, 32, 16)
decoder_attention_type = None
in_channels = 3
classes = 4
activation = None

# --- Arguments for SwinTransDecoder (passed positionally below) ---
embed_dim = 96
norm_layer = nn.LayerNorm
img_size = 224
patch_size = 4
depths = [2, 2, 2, 2]
num_heads = [3, 6, 12, 24]
window_size = 7
qkv_bias = True
qk_scale = None
drop_rate = 0.
attn_drop_rate = 0.
use_checkpoint = False
ape = True
mlp_ratio = 4.
drop_path_rate = 0.1
final_upsample = "expand_first"
# Derived from img_size and patch_size rather than hard-coded, so the two
# cannot drift apart; 224 // 4 gives the original [56, 56].
patches_resolution = [img_size // patch_size, img_size // patch_size]

# --- Optional heads / extras ---
cls = True  # when truthy, a ClassificationHead is built in the next cell
# The contrast_* settings are not referenced by the cells visible here.
contrast_embed = False
contrast_embed_dim = 256
contrast_embed_index = -3

# Always a torch.device (previously a bare 'cpu' string on CPU-only hosts).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Single random input whose shape follows in_channels and img_size.
x = torch.randn(1, in_channels, img_size, img_size).to(device)

In [3]:
# CNN backbone that produces the multi-scale feature list used by both decoders.
encoder = get_encoder(
    encoder_name,
    in_channels=in_channels,
    depth=encoder_depth,
    weights=encoder_weights,
).to(device)

# Channel counts reported by the encoder; shared by every head/decoder below.
encoder_channels = encoder.out_channels

# Convolutional U-Net decoder plus its segmentation head.
cnn_decoder = UnetDecoder(
    encoder_channels=encoder_channels,
    decoder_channels=decoder_channels,
    n_blocks=encoder_depth,
    use_batchnorm=decoder_use_batchnorm,
    # Only VGG-style encoders get the extra center block.
    center=encoder_name.startswith("vgg"),
    attention_type=decoder_attention_type,
).to(device)

seg_head = SegmentationHead(
    in_channels=decoder_channels[-1],
    out_channels=classes,
    activation=activation,
    kernel_size=3,
).to(device)

# Transformer decoder. Arguments stay positional and in the original order
# because the SwinTransDecoder signature is not visible from this notebook.
swin_decoder = SwinTransDecoder(
    classes, embed_dim, norm_layer, img_size, patch_size, depths, num_heads,
    window_size, qkv_bias, qk_scale, drop_rate, attn_drop_rate, use_checkpoint,
    ape, mlp_ratio, drop_path_rate, final_upsample, patches_resolution,
    encoder_channels
).to(device)

# Optional classification head on the last encoder stage. It uses `classes`
# rather than a literal 4, so it follows the configured class count.
cls_head = (
    ClassificationHead(in_channels=encoder_channels[-1], classes=classes).to(device)
    if cls
    else None
)

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth
100%|██████████| 97.8M/97.8M [00:05<00:00, 20.4MB/s]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [4]:
# Run the encoder once and report the shape of each returned feature map.
features = encoder(x)
f_each_layer = [str(feat.shape) for feat in features]
for idx, shape_str in enumerate(f_each_layer):
    print(f"{idx} {shape_str}")
print()

# Disabled check of the CNN decoder path:
# seg = seg_head(cnn_decoder(*features))
# print(seg.shape)

# Transformer decoder path: print the shape of its output.
seg_tf = swin_decoder(features, device)
print(seg_tf.shape)

# Disabled check of the classification head:
# cls = cls_head(features[-1]) if cls_head else None
# print(cls)

0 torch.Size([1, 3, 224, 224])
1 torch.Size([1, 64, 112, 112])
2 torch.Size([1, 256, 56, 56])
3 torch.Size([1, 512, 28, 28])
4 torch.Size([1, 1024, 14, 14])
5 torch.Size([1, 2048, 7, 7])

torch.Size([1, 4, 224, 224])
