# -----------------------------------搭建MAE的模型架构--------------------------------------------

In [1]:
from functools import partial 
# partial函数是一个非常实用的高阶函数，它用于创建一个新的可调用对象（如函数），这个新对象“部分地”固定了原函数的一些参数。
import numpy as np

import torch
import torch.nn as nn
from timm.models.vision_transformer import PatchEmbed, Block
# patchEmbed:将二维的图像数据转换为一维的向量序列
# Block：VIT的基本构建单元，包含自注意力层，多层感知机，归一化和激活函数等


In [2]:
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out) # (M, D/2)
    emb_cos = np.cos(out) # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed

In [3]:
class MaskedAutoencoderViT(nn.Module):
    
    """1、搭建模型需要用到的默认层"""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024,
                 depth=24, num_heads=16, decoder_embed_dim=512, decoder_depth=8,
                 decoder_num_heads=16, mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()
        
        
        """1.1定义编码器用到的层"""
        
        # 定义图像分割和转换操作
        # img_size: 224*224输入图像大小 patch_size：16*16分割成块的大小，也就是分成了14块
        # in_channs: 输入图像的通道数，embed_dim：嵌入维度，也就是原来矩形两维的图像块，被线性投影成1024维的
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        
        # 一个图像分成patch的总数
        num_patches = self.patch_embed.num_patches
        
        # 特殊起始标记cls，用于记住图片全局信息
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # 使用nn.Parameter表示这是一个可以学习的参数
        
        # 固定的位置嵌入参数
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)
        
        # 直接调用BLOCK构建vit块
        self.blocks = nn.ModuleList([Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for i in range(depth)])
        
        # 定义规范化层
        self.norm = norm_layer(embed_dim)
        
        """1.2定义解码器用到的层"""
        # 全连接层
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        
        # 用于填充或遮蔽位置的占位符
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        
        # 解码器位置嵌入
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)
        
        # 使用vit块搭建解码器 qk_scale=None,这个报错
        self.decoder_block = nn.ModuleList([Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, 
                                                  qkv_bias=True, norm_layer=norm_layer)
                                           for i in range(decoder_depth)])
        # 规范化层，用于模型规范输出
        self.decoder_norm = norm_layer(decoder_embed_dim)
        # 线形层，将输出重构为原来图像大小
        self.deocder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)
        
        
        """1.3其它"""
        self.norm_pix_loss = norm_pix_loss
        self.initialize_weights()
    
    """2、初始化权重的方法"""
    def initialize_weights(self):
        # 计算图片输入时的位置编码信息
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        # 将上一步计算结果转换为pytorch然后复制给模型的储存参数
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        
        # 计算解码器输入时需要的位置编码信息
        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
        
        # 获取patchembed层中类似线性层部分proj的权重参数，然后进行Xavier均匀初始化
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        
        # 对特殊的起始cls和掩蔽mask进行正态分布初始化，均值为0，标准差为0.02
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        
        # 对所有线性层和norm层进行参数初始化操作，调用下面定义的方法
        self.apply(self._init_weights) # 。apply会遍历模型所有子模块并应用传递给它的回调参数
        
    """2.1定义线性层和归一化层的初始化参数"""
    def _init_weights(self, m):
        # 首先检查m是否是线性层
        if isinstance(m, nn.Linear):
            # 是的话,使用xvaier初始化权重
            torch.nn.init.xavier_uniform_(m.weight)
            # 加一个判断有无偏差的条件
            if isinstance(m, nn.Linear) and m.bias is not None:
                # 将偏差初始化为0
                # nn.init.constant_用于将张量的所有元素设置为给定的常数值
                nn.init.constant_(m.bias, 0)
        # 检查是否为归一化层
        elif isinstance(m, nn.LayerNorm):
            # 将归一化层的参数初始化
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        
    """3、将输入图像分割后再展平为一个新的张量"""
    def patchify(self, imgs):
        # 输入imgs:(N, 3, H, W) N是批量，3是通道， HW是高和宽
        # 输出x：(N, L, patch_size**2*3) N是批量， L是分割成patch的数量，Patch...是将小patch转成一个向量
        
        # 获取patch的大小p
        p = self.patch_embed.patch_size[0] 
        
        # 断言图像为正方形并且能被patch整除
        assert imgs.shape[2] == imges.shape[3] and imgs.shape[2] % p == 0
        
        # 为了获取patch数量的中间步骤
        h = w = imgs.shape[2] // p
        
        # 调整张量形状
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        
        # .einsum函数对张量进行重拍和展平，使每个patch先展成一维，然后连续储存
        x = torch.einsum('nchpwq->nhwpqc', x)
        
        # 改变张量形状
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2*3))
        return x
    
    """4、反patch化，将分割的patch重新组合成图像"""
    def unpatchify(self, x):
        # 输入x:(N, L, patch_size**2*3) 分割的patch
        # 输出imgs:(N, 3, H, W)输出批量的图像
        
        # 获取patch的大小
        p = self.patch_embed.patch_size[0]
        # 反向求patch数量
        h = w = int(x.shape[1]**.5)
        # 断言patch数量一致
        assert h * w == x.shape[1]
        
        # 将patch组合成图像
        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 2, h * p, h * p))
        return imgs
    
    """5、对输入序列进行随机掩码的操作"""
    def random_masking(self, x, mask_ratio):
        # 输入x为分割并处理后的图像，也就是patchify函数返回的值:(N, L, D)
        # mask_ratio为掩码率？
        
        N, L, D = x.shape
        
        # 计算需要保留不掩避的序列长度
        len_keep = int(L * (1 - mask_ratio))
        
        # 设置随机噪音张量
        noise = torch.rand(N, L, device=x.device) 
        
        # 用于实现样本随机打乱和恢复的操作
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        
        # 根据索引提取不被掩蔽的序列
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        
        # 建一个全为1的二进制掩码张量，形状为 [N, L]，然后将前 len_keep 个元素设置为0，表示这些位置是未被遮蔽的位置
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # 根据之前保存的恢复原顺序的索引 ids_restore，对掩码张量进行同样的操作，确保掩码张量与原始输入序列保持一致的顺序
        mask = torch.gather(mask, dim=1, index=ids_restore)
        
        # 返回：经过掩码处理后的未被掩码部分的序列， 二进制掩码表示， 用于回复原始序列的索引
        return x_mask, mask, ids_restore
        
    """6、编码器的前向传播"""
    def forward_encoder(self, x, mask_ratio):
        # 输入x:批量图片
        
        # 1-将输入的批量图片进行分割patch和转换
        x = self.patch_embed(x)
        
        # 2-添加位置编码
        x = x + self.pos_embed[:, 1:, :]
        
        # 3-对输入图像进行掩蔽操作
        x, mask, ids_restore = self.random_masking(x, mask_ratio)
        
        # 4-添加cls特殊起始标记，用用于记录全局信息
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cles_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        
        # 5-输入transform块进行计算
        for blk in self.blocks:
            x = blk(x)
        
        # 6-归一化计算
        x = self.norm(x)
        
        # 返回：编码器计算结果，而掩码标记mask，用于回复原始序列的索引
        return x, mask, ids_restore
    
    """7、解码器的前向传播"""
    def forward_decoder(self, x, ids_restore):
        
        # 1-将输入token转化为对应的高维向量表示
        x = self.decoder_embed(x)
        
        # 2-创建于掩码数量相同的mask token向量
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        
        # 3-将原始序列 x 中除第一个元素（即cls token）外的部分与刚创建的mask tokens拼接起来，形成一个新的序列 x_
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        
        # 4-反序列化
        x_ = token.gather(x_, dim=1, index=ids_restore.unqueeze(-1).repeat(1, 1, x.shape[2]))
        
        # 5-将cls重新加入向量开头
        x = torch.cat([x[:, :1, :], x_], dim=1)
        
        # 6-添加位置编码
        x = x + self.decoder_pos_embed
        
        # 7-输入解码器的transform块
        for blk in self.decoder_blocks:
            x = blk(x)
        
        # 8-归一化操作
        x = self.decoder_norm(x)
        
        # 9-像素级预测层
        x = self.decoder_pred(x)
        
        # 10-去除cls标记
        x = x[:, 1:, :]
        
        # 返回的x和原来的输入图像是一样的
        return x
    
    """8、定义前向传播的损失loss"""
    def forward_loss(self, imgs, pred, mask):
        # 输入图像imgs:(N, 3, H, W) 批量-通道-高-宽
        # 预测值（解码器的输出）pred(N, L, p*p*3) 批量-patch数量-每个patch所有像素的值
        # 掩码表示mask:(N, L), 0是未掩码，1是掩码了的 批量-patch数量
        
        # 1-先将imgs转为patchs
        target = self.patchify(imgs)
        
        # 2-对目标patch进行标准化处理，通过减去每patch各通道像素的均值并除以其方差（加了一个很小的数值以防止除零错误）来规范化数据
        if norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5
            
        # 3-计算预测patch (pred) 与目标patch (target) 之间的平方误差损失，即每个样本每个patch的平均像素损失
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)
        
        # 4-将损失乘以掩码 mask 后求和，得到的是基于被移除patch的平均损失。
        loss = (loss * mask).sum() / mask.sum()
        
        return loss
    
    """9、定义模型前向传播"""
    def forward(self, imgs, mask_ratio=0.75):
        # 1-编码器前向传播
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        
        # 2-解码器前向传播
        pred = self.forward_decoder(latent, ids_restore)
        
        # 3-计算loss
        loss = self.forward_loss(imgs, pred, mask)
        
        return loss, pred, mask
        

# 下面都是测试

In [4]:
from torchinfo import summary 

In [8]:
model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), )

In [9]:
summary(model, (1, 3, 224, 224))

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [PatchEmbed: 1, Conv2d: 2, Identity: 2]

In [6]:
model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), )

In [7]:
model

MaskedAutoencoderViT(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop2

In [None]:
class MaskedAutoencoderViT2(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True,  norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True,  norm_layer=norm_layer)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        
        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        
        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove, 
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask

In [51]:
model2 = MaskedAutoencoderViT2(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), )

In [52]:
model2

MaskedAutoencoderViT2(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop

In [12]:
def mae_vit_base_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

In [13]:
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b 

In [14]:
mae_vit_base_patch16

<function __main__.mae_vit_base_patch16_dec512d8b(**kwargs)>

In [15]:
type(mae_vit_base_patch16)

function