In [1]:
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model
from qat.export.utils import replace_module_by_name, fetch_module_by_name
from networks.vision_transformer import VisionTransformer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class AttnScore(nn.Module):
    """Compute attention probabilities from query and key tensors.

    Args:
        scale: Factor the queries are multiplied by before the dot product.
        attn_drop: Drop probability passed to ``nn.Dropout``.
    """

    def __init__(self, scale, attn_drop):
        super().__init__()
        self.scale = scale
        # The constructor argument is a probability; the attribute with the
        # same name is the dropout layer built from that probability.
        self.attn_drop = nn.Dropout(attn_drop)

    def forward(self, q, k):
        """Return ``dropout(softmax((q * scale) @ k^T))``, softmax over the last axis."""
        scaled_q = q * self.scale
        keys_t = k.transpose(-2, -1)
        probs = (scaled_q @ keys_t).softmax(dim=-1)
        return self.attn_drop(probs)

class Attention2(nn.Module):
    """Multi-head self-attention with separate query/key/value projections.

    Instead of a single fused ``qkv`` linear layer, this variant owns three
    ``dim -> dim`` layers (``q_linear``, ``k_linear``, ``v_linear``) and hands
    scaling, softmax and attention dropout over to an ``AttnScore`` sub-module.

    Args:
        dim: Size of the token embedding; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the three input projections carry a bias term.
        qk_norm: If true, ``norm_layer(head_dim)`` is applied to q and k.
        attn_drop: Dropout probability for the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        norm_layer: Factory used to build the optional q/k norm layers.
    """

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            qk_norm=False,
            attn_drop=0.,
            proj_drop=0.,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        per_head = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = per_head
        self.scale = per_head ** -0.5

        # Three independent projections replace the fused
        # ``nn.Linear(dim, dim * 3)`` layer.
        self.q_linear = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_linear = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_linear = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_score = AttnScore(scale=self.scale, attn_drop=attn_drop)
        if qk_norm:
            self.q_norm = norm_layer(per_head)
            self.k_norm = norm_layer(per_head)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        # NOTE: forward() never calls this layer; the attention dropout that
        # actually runs is the one owned by ``self.attn_score``.
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _split_heads(self, t, batch, tokens):
        """Reshape ``(batch, tokens, dim)`` to ``(batch, num_heads, tokens, head_dim)``."""
        t = t.reshape(batch, tokens, self.num_heads, self.head_dim)
        return t.permute(0, 2, 1, 3)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(B, N, C)``; the result has the same shape."""
        B, N, C = x.shape
        q = self._split_heads(self.q_linear(x), B, N)
        k = self._split_heads(self.k_linear(x), B, N)
        v = self._split_heads(self.v_linear(x), B, N)

        q = self.q_norm(q)
        k = self.k_norm(k)

        # The explicit (non-fused) attention path is always taken: the weights
        # come out of ``attn_score`` and are then applied to the values.
        weights = self.attn_score(q, k)
        merged = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(merged))


In [3]:
# from torch.jit import Final
# from timm.layers import use_fused_attn
# class Attention3(nn.Module):
#     fused_attn: Final[bool]
#     def __init__(
#             self,
#             dim,
#             num_heads=8,
#             qkv_bias=False,
#             qk_norm=False,
#             attn_drop=0.,
#             proj_drop=0.,
#             norm_layer=nn.LayerNorm,
#     ):
#         super().__init__()
#         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
#         self.num_heads = num_heads
#         self.head_dim = dim // num_heads
#         self.scale = self.head_dim ** -0.5
#         self.fused_attn = use_fused_attn()

#         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
#         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
#         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
#         self.attn_drop = nn.Dropout(attn_drop)
#         self.proj = nn.Linear(dim, dim)
#         self.proj_drop = nn.Dropout(proj_drop)

#     def forward(self, x):
#         B, N, C = x.shape
#         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
#         q, k, v = qkv.unbind(0)
#         q, k = self.q_norm(q), self.k_norm(k)

#         if self.fused_attn:
#             x = F.scaled_dot_product_attention(
#                 q, k, v,
#                 dropout_p=self.attn_drop.p,
#             )
#         else:
#             q = q * self.scale
#             attn = q @ k.transpose(-2, -1)
#             attn = attn.softmax(dim=-1)
#             attn_argmax = attn.argmax(dim=-1)
#             attn = self.attn_drop(attn)
#             output = attn - (attn - attn_argmax).detach() # output = real_output - (real_output - quantized_output).detach()
#             x = output @ v

#         x = x.transpose(1, 2).reshape(B, N, C)
#         x = self.proj(x)
#         x = self.proj_drop(x)
#         return x

In [4]:
# Checkpoint to convert. Larger variants of the same family, kept for reference:
# model_name = "deit3_huge_patch14_224.fb_in22k_ft_in1k"
# model_name = "deit3_base_patch16_224.fb_in22k_ft_in1k"
model_name = 'deit3_small_patch16_224.fb_in22k_ft_in1k'
# Build the timm model with its pretrained weights loaded.
model = create_model(model_name, pretrained=True)

In [6]:
# Block-relative names of the fused-qkv attention modules to convert.
# The value is the embedding dim to build with; ``None`` reads it from the
# pretrained ``qkv`` layer, so nothing has to be hard-coded per checkpoint
# (and ``head_dim`` is not needed from the original module at all).
settings = {
    "attn": None,
}


def split_fused_attention(module, dim=None):
    """Build an ``Attention2`` that reproduces a fused-qkv attention ``module``.

    Args:
        module: Attention module exposing ``qkv`` (an ``nn.Linear``),
            ``num_heads``, ``attn_drop``, ``proj`` and ``proj_drop``.
        dim: Embedding dim; defaults to ``module.qkv.in_features``.

    Returns:
        An ``Attention2`` whose q/k/v projections hold copies of the matching
        thirds of the fused parameters and which reuses the original dropout
        and output-projection layers.
    """
    fused = module.qkv
    has_bias = fused.bias is not None
    split = Attention2(
        dim=fused.in_features if dim is None else dim,
        num_heads=module.num_heads,
        qkv_bias=has_bias,
        qk_norm=False,
    )
    # Match the pretrained layer's device/dtype and its train/eval mode.
    split.to(device=fused.weight.device, dtype=fused.weight.dtype)
    split.train(module.training)

    targets = (split.q_linear, split.k_linear, split.v_linear)
    with torch.no_grad():
        # Rows of the fused parameters are laid out as [q; k; v].  copy_()
        # gives every projection its own storage (no aliasing of the fused
        # tensor) and raises on a shape mismatch instead of silently
        # installing a wrongly-sized weight.
        for target, weight in zip(targets, fused.weight.chunk(3, dim=0)):
            target.weight.copy_(weight)
        if has_bias:
            for target, bias in zip(targets, fused.bias.chunk(3, dim=0)):
                target.bias.copy_(bias)

    # Reuse q/k norm layers when the source module has them.
    split.q_norm = getattr(module, "q_norm", split.q_norm)
    split.k_norm = getattr(module, "k_norm", split.k_norm)
    # Attention2.forward() applies attention dropout inside ``attn_score``, so
    # the pretrained dropout has to be installed there; the top-level
    # attribute is kept in sync with it.
    split.attn_score.attn_drop = module.attn_drop
    split.attn_drop = module.attn_drop
    split.proj = module.proj
    split.proj_drop = module.proj_drop
    return split


# Convert every block (not a fixed count of 12) so deeper models work too.
for layer in model.blocks:
    for name, dim in settings.items():
        module = fetch_module_by_name(layer, name)
        if isinstance(module, Attention2):
            # Already converted: re-running this cell is a no-op.
            continue
        replace_module_by_name(layer, name, split_fused_attention(module, dim))

In [20]:
# Display the module tree to check that each block's `attn` is now an Attention2.
model

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention2(
        (q_linear): Linear(in_features=384, out_features=384, bias=True)
        (k_linear): Linear(in_features=384, out_features=384, bias=True)
        (v_linear): Linear(in_features=384, out_features=384, bias=True)
        (attn_score): AttnScore(
          (attn_drop): Dropout(p=0.0, inplace=False)
        )
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): Laye

In [7]:
# output_tensors = []
# def hook(module, input_tensor, output_tensor): # The hook will be called every time after :func:`forward` has computed an output.
#     tensor = output_tensor.cpu()
#     output_tensors.append(tensor)
    
# handles = []
# for name, module in model.named_modules(): 
#     if "attn_score" in name: # double check module_name
#         handles.append(module.register_forward_hook(hook))

In [8]:
from copy import deepcopy

# Work on a deep copy so the converted `model` itself is left untouched.
model_0 = deepcopy(model)
# Keep only the first transformer block and turn every later block into a
# pass-through.  The bound comes from the model rather than a hard-coded 12,
# so models with more blocks are truncated completely as well.
for i in range(1, len(model_0.blocks)):
    model_0.blocks[i] = nn.Identity()


In [9]:
# Move the single-block copy to the GPU (requires a CUDA device); the call
# returns the module, so the cell output is its repr.
model_0.cuda()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention2(
        (q_linear): Linear(in_features=384, out_features=384, bias=True)
        (k_linear): Linear(in_features=384, out_features=384, bias=True)
        (v_linear): Linear(in_features=384, out_features=384, bias=True)
        (attn_score): AttnScore(
          (attn_drop): Dropout(p=0.0, inplace=False)
        )
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): Laye

In [10]:
# Random input of shape (1, 3, 224, 224), sampled on the CPU and then moved to the GPU.
x = torch.randn(1, 3, 224, 224).cuda()

In [11]:
# Swap the final norm layer and the head for identities so that neither one
# changes the tensor returned by model_0.
model_0.norm = nn.Identity()
model_0.head = nn.Identity()

In [12]:
# Run one forward pass through the truncated model and show the output shape.
model_0(x).shape

torch.Size([1, 197, 384])

In [13]:
from torchinfo import summary
# Layer-by-layer summary of the converted model for a (1, 3, 224, 224) input.
summary(model,(1, 3, 224, 224)) 
# Disabled clean-up for the forward hooks (the hook registration cell is
# commented out as well):
# for handle in handles:
#     handle.remove()

Layer (type:depth-idx)                   Output Shape              Param #
VisionTransformer                        [1, 197, 1000]            75,648
├─PatchEmbed: 1-1                        [1, 196, 384]             --
│    └─Conv2d: 2-1                       [1, 384, 14, 14]          295,296
│    └─Identity: 2-2                     [1, 196, 384]             --
├─Dropout: 1-2                           [1, 197, 384]             --
├─Identity: 1-3                          [1, 197, 384]             --
├─Identity: 1-4                          [1, 197, 384]             --
├─Sequential: 1-5                        [1, 197, 384]             --
│    └─Block: 2-3                        [1, 197, 384]             --
│    │    └─LayerNorm: 3-1               [1, 197, 384]             768
│    │    └─Attention2: 3-2              [1, 197, 384]             591,360
│    │    └─LayerScale: 3-3              [1, 197, 384]             384
│    │    └─Identity: 3-4                [1, 197, 384]             --

In [14]:
# len(output_tensors)
# output_tensors[0].shape

In [15]:
# settings = {
#     "attn": 1280, # dimension of the model
# }

# for i in range(0, 31): 
#     for name in settings:
#         layer = model.blocks[i]
#         module = fetch_module_by_name(layer, name)
#         attn = Attention3(
#                 dim=settings[name],
#                 num_heads=getattr(module, "num_heads"),
#                 qkv_bias=True,
#                 qk_norm=False
#             )
#         qkv_weight = module.qkv.weight.data
#         qkv_bias = module.qkv.bias.data
#         attn.qkv.weight.data = qkv_weight
#         attn.qkv.bias.data = qkv_bias
#         attn.attn_drop = module.attn_drop
#         attn.proj = module.proj
#         attn.proj_drop = module.proj_drop
#         replace_module_by_name(layer, name, attn)

In [22]:
# Display the module tree of the converted model.
model

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention2(
        (q_linear): Linear(in_features=384, out_features=384, bias=True)
        (k_linear): Linear(in_features=384, out_features=384, bias=True)
        (v_linear): Linear(in_features=384, out_features=384, bias=True)
        (attn_score): AttnScore(
          (attn_drop): Dropout(p=0.0, inplace=False)
        )
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): Laye

In [23]:
# Show which checkpoint name is currently selected.
model_name # deit3_small_patch16_224.fb_in22k_ft_in1k

'deit3_small_patch16_224.fb_in22k_ft_in1k'

In [21]:
# Write the converted model's weights to "<model_name>.pth" in the current
# working directory.
converted_state = model.state_dict()
torch.save(converted_state, f'{model_name}.pth')

In [17]:
import timm
# List the timm model names that match the "deit3*" pattern.
timm.list_models("deit3*")

['deit3_base_patch16_224',
 'deit3_base_patch16_384',
 'deit3_huge_patch14_224',
 'deit3_large_patch16_224',
 'deit3_large_patch16_384',
 'deit3_medium_patch16_224',
 'deit3_small_patch16_224',
 'deit3_small_patch16_384']

In [18]:
# from torchinfo import summary
# model_name = 'deit3_huge_patch14_224.fb_in22k_ft_in1k'
# model = create_model(model_name, pretrained=True)
# summary(model,(1,3,224,224)) 