In [None]:
import torch
import torch.nn as nn
from timm.models.vision_transformer import Block, PatchEmbed
from functools import partial
import numpy as np

SinCos编码  
$$PE(pos, 2i) = \sin{\frac{pos}{10000^{2i/d_{model}}}}$$
$$PE(pos, 2i + 1) = \cos{\frac{pos}{10000^{2i/d_{model}}}}$$
引入位置信息：相隔 k 个词的两个位置 pos 和 pos+k 的位置编码是由 k 和pos的位置编码定义的一个线性变换  
$$PE(pos + k, 2i) = PE(pos, 2i)PE(k, 2i + 1) + PE(pos, 2i + 1)PE(k, 2i)$$
$$PE(pos + k, 2i + 1) = PE(pos, 2i + 1)PE(k, 2i + 1) - PE(pos, 2i)PE(k, 2i)$$

In [None]:
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
    pos_embed = np.concatenate([emb_h, emb_w], axis=1)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    omega = np.arange(embed_dim // 2, dtype=np.float)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega
    pos = pos.reshape(-1)
    out = np.einsum('m,d->md', pos, omega)
    emb_sin = np.sin(out)
    emb_cos = np.cos(out)
    emb = np.concatenate([emb_sin, emb_cos], axis=1)
    return emb

In [1]:
from vit_model import MaskedAutoEncoderViT, VisionTransformer


MAE_model = MaskedAutoEncoderViT()
ViT_model = VisionTransformer()

mae_name_param_size = dict()
vit_name_param_size = dict()

for name, param in MAE_model.named_parameters():
    mae_name_param_size[name] = param.shape
for name, param in ViT_model.named_parameters():
    vit_name_param_size[name] = param.shape


In [2]:
for mae, vit in zip()

dict_keys(['cls_token', 'pos_embed', 'mask_token', 'decoder_pos_embed', 'patch_embed.proj.weight', 'patch_embed.proj.bias', 'enc_blocks.0.norm1.weight', 'enc_blocks.0.norm1.bias', 'enc_blocks.0.attn.qkv.weight', 'enc_blocks.0.attn.qkv.bias', 'enc_blocks.0.attn.proj.weight', 'enc_blocks.0.attn.proj.bias', 'enc_blocks.0.norm2.weight', 'enc_blocks.0.norm2.bias', 'enc_blocks.0.mlp.fc1.weight', 'enc_blocks.0.mlp.fc1.bias', 'enc_blocks.0.mlp.fc2.weight', 'enc_blocks.0.mlp.fc2.bias', 'enc_blocks.1.norm1.weight', 'enc_blocks.1.norm1.bias', 'enc_blocks.1.attn.qkv.weight', 'enc_blocks.1.attn.qkv.bias', 'enc_blocks.1.attn.proj.weight', 'enc_blocks.1.attn.proj.bias', 'enc_blocks.1.norm2.weight', 'enc_blocks.1.norm2.bias', 'enc_blocks.1.mlp.fc1.weight', 'enc_blocks.1.mlp.fc1.bias', 'enc_blocks.1.mlp.fc2.weight', 'enc_blocks.1.mlp.fc2.bias', 'enc_blocks.2.norm1.weight', 'enc_blocks.2.norm1.bias', 'enc_blocks.2.attn.qkv.weight', 'enc_blocks.2.attn.qkv.bias', 'enc_blocks.2.attn.proj.weight', 'enc_bloc