In [1]:
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init
from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
from mmcv.utils import to_2tuple

from mmdet.utils import get_root_logger
from mmdet.models.builder import BACKBONES

import sys
sys.path.append('../')

# from ...utils import get_root_logger
# from ..builder import BACKBONES
from utils.ckpt_convert import swin_converter
from utils.transformer import PatchEmbed, PatchMerging

In [19]:

pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth'  # noqa
ckpt = _load_checkpoint(
    pretrained, logger=None, map_location='cpu')

load checkpoint from http path: https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth


In [20]:
ckpt.keys()

dict_keys(['model'])

In [21]:
_state_dict = ckpt['model']
print('************************',list(_state_dict.keys())[0])
_state_dict = swin_converter(_state_dict)
print('************************',list(_state_dict.keys())[0])

************************ patch_embed.proj.weight
************************ backbone.patch_embed.projection.weight


In [5]:
_state_dict.keys()

odict_keys(['backbone.patch_embed.projection.weight', 'backbone.patch_embed.projection.bias', 'backbone.patch_embed.norm.weight', 'backbone.patch_embed.norm.bias', 'backbone.stages.0.blocks.0.norm1.weight', 'backbone.stages.0.blocks.0.norm1.bias', 'backbone.stages.0.blocks.0.attn.w_msa.qkv.weight', 'backbone.stages.0.blocks.0.attn.w_msa.qkv.bias', 'backbone.stages.0.blocks.0.attn.w_msa.proj.weight', 'backbone.stages.0.blocks.0.attn.w_msa.proj.bias', 'backbone.stages.0.blocks.0.norm2.weight', 'backbone.stages.0.blocks.0.norm2.bias', 'backbone.stages.0.blocks.0.ffn.layers.0.0.weight', 'backbone.stages.0.blocks.0.ffn.layers.0.0.bias', 'backbone.stages.0.blocks.0.ffn.layers.1.weight', 'backbone.stages.0.blocks.0.ffn.layers.1.bias', 'backbone.stages.0.blocks.1.norm1.weight', 'backbone.stages.0.blocks.1.norm1.bias', 'backbone.stages.0.blocks.1.attn.w_msa.qkv.weight', 'backbone.stages.0.blocks.1.attn.w_msa.qkv.bias', 'backbone.stages.0.blocks.1.attn.w_msa.proj.weight', 'backbone.stages.0.bloc

In [6]:
state_dict = OrderedDict()
for k, v in _state_dict.items():
    if k.startswith('backbone.'):
        state_dict[k[9:]] = v

In [7]:
state_dict.keys()

odict_keys(['patch_embed.projection.weight', 'patch_embed.projection.bias', 'patch_embed.norm.weight', 'patch_embed.norm.bias', 'stages.0.blocks.0.norm1.weight', 'stages.0.blocks.0.norm1.bias', 'stages.0.blocks.0.attn.w_msa.qkv.weight', 'stages.0.blocks.0.attn.w_msa.qkv.bias', 'stages.0.blocks.0.attn.w_msa.proj.weight', 'stages.0.blocks.0.attn.w_msa.proj.bias', 'stages.0.blocks.0.norm2.weight', 'stages.0.blocks.0.norm2.bias', 'stages.0.blocks.0.ffn.layers.0.0.weight', 'stages.0.blocks.0.ffn.layers.0.0.bias', 'stages.0.blocks.0.ffn.layers.1.weight', 'stages.0.blocks.0.ffn.layers.1.bias', 'stages.0.blocks.1.norm1.weight', 'stages.0.blocks.1.norm1.bias', 'stages.0.blocks.1.attn.w_msa.qkv.weight', 'stages.0.blocks.1.attn.w_msa.qkv.bias', 'stages.0.blocks.1.attn.w_msa.proj.weight', 'stages.0.blocks.1.attn.w_msa.proj.bias', 'stages.0.blocks.1.norm2.weight', 'stages.0.blocks.1.norm2.bias', 'stages.0.blocks.1.ffn.layers.0.0.weight', 'stages.0.blocks.1.ffn.layers.0.0.bias', 'stages.0.blocks.1.f

In [8]:
# strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'):
    print('NOP')
    state_dict = {k[7:]: v for k, v in state_dict.items()}

In [9]:
# reshape absolute position embedding
if state_dict.get('absolute_pos_embed') is not None:
    print('NOP')
    absolute_pos_embed = state_dict['absolute_pos_embed']
    N1, L, C1 = absolute_pos_embed.size()
    N2, C2, H, W = self.absolute_pos_embed.size()
    if N1 != N2 or C1 != C2 or L != H * W:
        logger.warning('Error in loading absolute_pos_embed, pass')
    else:
        state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
            N2, H, W, C2).permute(0, 3, 1, 2).contiguous()

In [10]:
# interpolate position bias table if needed
relative_position_bias_table_keys = [
    k for k in state_dict.keys()
    if 'relative_position_bias_table' in k
]

In [13]:
relative_position_bias_table_keys

['stages.0.blocks.0.attn.w_msa.relative_position_bias_table',
 'stages.0.blocks.1.attn.w_msa.relative_position_bias_table',
 'stages.1.blocks.0.attn.w_msa.relative_position_bias_table',
 'stages.1.blocks.1.attn.w_msa.relative_position_bias_table',
 'stages.2.blocks.0.attn.w_msa.relative_position_bias_table',
 'stages.2.blocks.1.attn.w_msa.relative_position_bias_table',
 'stages.2.blocks.2.attn.w_msa.relative_position_bias_table',
 'stages.2.blocks.3.attn.w_msa.relative_position_bias_table',
 'stages.2.blocks.4.attn.w_msa.relative_position_bias_table',
 'stages.2.blocks.5.attn.w_msa.relative_position_bias_table',
 'stages.3.blocks.0.attn.w_msa.relative_position_bias_table',
 'stages.3.blocks.1.attn.w_msa.relative_position_bias_table']

In [None]:
state_dict.keys

odict_keys(['patch_embed.projection.weight', 'patch_embed.projection.bias', 'patch_embed.norm.weight', 'patch_embed.norm.bias', 'stages.0.blocks.0.norm1.weight', 'stages.0.blocks.0.norm1.bias', 'stages.0.blocks.0.attn.w_msa.qkv.weight', 'stages.0.blocks.0.attn.w_msa.qkv.bias', 'stages.0.blocks.0.attn.w_msa.proj.weight', 'stages.0.blocks.0.attn.w_msa.proj.bias', 'stages.0.blocks.0.norm2.weight', 'stages.0.blocks.0.norm2.bias', 'stages.0.blocks.0.ffn.layers.0.0.weight', 'stages.0.blocks.0.ffn.layers.0.0.bias', 'stages.0.blocks.0.ffn.layers.1.weight', 'stages.0.blocks.0.ffn.layers.1.bias', 'stages.0.blocks.1.norm1.weight', 'stages.0.blocks.1.norm1.bias', 'stages.0.blocks.1.attn.w_msa.qkv.weight', 'stages.0.blocks.1.attn.w_msa.qkv.bias', 'stages.0.blocks.1.attn.w_msa.proj.weight', 'stages.0.blocks.1.attn.w_msa.proj.bias', 'stages.0.blocks.1.norm2.weight', 'stages.0.blocks.1.norm2.bias', 'stages.0.blocks.1.ffn.layers.0.0.weight', 'stages.0.blocks.1.ffn.layers.0.0.bias', 'stages.0.blocks.1.f

In [18]:
for table_key in relative_position_bias_table_keys:
    print(table_key)
    table_pretrained = state_dict[table_key]
    print(table_pretrained.shape)
    table_current = state_dict()[table_key]
    L1, nH1 = table_pretrained.size()
    L2, nH2 = table_current.size()
    if nH1 != nH2:
        logger.warning(f'Error in loading {table_key}, pass')
    elif L1 != L2:
        S1 = int(L1**0.5)
        S2 = int(L2**0.5)
        table_pretrained_resized = F.interpolate(
            table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
            size=(S2, S2),
            mode='bicubic')
        state_dict[table_key] = table_pretrained_resized.view(
            nH2, L2).permute(1, 0).contiguous()

stages.0.blocks.0.attn.w_msa.relative_position_bias_table
torch.Size([169, 3])


TypeError: 'collections.OrderedDict' object is not callable

In [None]:
model = swin_tiny1(type='swin_tiny1',
        style='pytorch',
        out_indices=(1, 2, 3),
        init_cfg=dict(type='Pretrained', checkpoint=pretrained))
load_state_dict(state_dict, False)

In [16]:
state_dict.keys()

odict_keys(['patch_embed.projection.weight', 'patch_embed.projection.bias', 'patch_embed.norm.weight', 'patch_embed.norm.bias', 'stages.0.blocks.0.norm1.weight', 'stages.0.blocks.0.norm1.bias', 'stages.0.blocks.0.attn.w_msa.qkv.weight', 'stages.0.blocks.0.attn.w_msa.qkv.bias', 'stages.0.blocks.0.attn.w_msa.proj.weight', 'stages.0.blocks.0.attn.w_msa.proj.bias', 'stages.0.blocks.0.norm2.weight', 'stages.0.blocks.0.norm2.bias', 'stages.0.blocks.0.ffn.layers.0.0.weight', 'stages.0.blocks.0.ffn.layers.0.0.bias', 'stages.0.blocks.0.ffn.layers.1.weight', 'stages.0.blocks.0.ffn.layers.1.bias', 'stages.0.blocks.1.norm1.weight', 'stages.0.blocks.1.norm1.bias', 'stages.0.blocks.1.attn.w_msa.qkv.weight', 'stages.0.blocks.1.attn.w_msa.qkv.bias', 'stages.0.blocks.1.attn.w_msa.proj.weight', 'stages.0.blocks.1.attn.w_msa.proj.bias', 'stages.0.blocks.1.norm2.weight', 'stages.0.blocks.1.norm2.bias', 'stages.0.blocks.1.ffn.layers.0.0.weight', 'stages.0.blocks.1.ffn.layers.0.0.bias', 'stages.0.blocks.1.f