In [None]:
'''
Patch Partition
Partitioning an image into patches is, in essence, just a rearrangement of tensor dimensions.
'''

# Number of patches along the height and the width of the image
grid_h, grid_w = h // self.patch_h, w // self.patch_w
num_patches = grid_h * grid_w

# Split h into (grid_h, patch_h) and w into (grid_w, patch_w):
# (b, c, h, w) -> (b, c, grid_h, patch_h, grid_w, patch_w)
patches = x.view(b, c, grid_h, self.patch_h, grid_w, self.patch_w)
# Bring the two grid axes forward and the channel axis last:
# -> (b, grid_h, grid_w, patch_h, patch_w, c)
patches = patches.permute(0, 2, 4, 3, 5, 1)
# Flatten the grid into one patch axis and each patch into one vector:
# -> (b, n_patches, patch_h * patch_w * c)
patches = patches.reshape(b, num_patches, -1)


In [None]:
'''
Masking
Following the preset mask ratio, a subset of patches is sampled uniformly at random and fed to the Encoder; the remaining patches are masked out.
'''

# Number of patches to mask, derived from the mask ratio
# (num_patches = (h // self.patch_h) * (w // self.patch_w))
num_masked = int(self.mask_ratio * num_patches)

# Shuffle: draw one random score per patch. torch.rand() samples from the
# uniform distribution on [0, 1); it only produces random numbers, so argsort()
# is what turns the scores into a random permutation of patch indices.
# (b, n_patches)
patch_scores = torch.rand(b, num_patches, device=device)
shuffle_indices = patch_scores.argsort()

# The first num_masked shuffled indices are masked, the rest stay visible
mask_ind = shuffle_indices[:, :num_masked]
unmask_ind = shuffle_indices[:, num_masked:]

# Index along the batch dimension, shaped (b, 1) so it broadcasts against the
# (b, k) patch indices above
batch_ind = torch.arange(b, device=device)[:, None]

# Gather the patches with the indices generated above, giving the masked group
# and the un-masked group
mask_patches = patches[batch_ind, mask_ind]
unmask_patches = patches[batch_ind, unmask_ind]

In [None]:
'''
Encoder: encode the un-masked patches
The un-masked patches are first embedded into tokens and position embeddings are added so that the tokens carry positional information; only then does the actual encoding happen. The encoding itself is simply a Transformer pass (query and key produce the attention, which is then applied to value).
'''

# Embed the visible patches into tokens
unmask_tokens = self.encoder.patch_embed(unmask_patches)

# Look up the position embedding of every visible patch.
# The index is shifted by 1 because index 0 belongs to the ViT cls_token.
pos_embed_per_sample = self.encoder.pos_embed.repeat(b, 1, 1)
unmask_pos = pos_embed_per_sample[batch_ind, unmask_ind + 1]
unmask_tokens += unmask_pos

# The actual encoding
encoded_tokens = self.encoder.transformer(unmask_tokens)

In [None]:
'''
Decoder
The encoded tokens and the mask tokens (with position information added) are put back in the original patch order and fed to the Decoder.
If the width of the encoded tokens differs from the input width the Decoder expects, a linear projection converts it.
'''

# Project the encoded tokens to the width the Decoder expects
enc_to_dec_tokens = self.enc_to_dec(encoded_tokens)

# There is only a single mask token, so it is repeated to get one copy per
# masked patch: (decoder_dim) -> (b, n_masked, decoder_dim)
shared_mask_token = self.mask_embed[None, None, :]
mask_tokens = shared_mask_token.repeat(b, num_masked, 1)
# Add position information to the mask tokens
mask_pos = self.decoder_pos_embed(mask_ind)
mask_tokens += mask_pos

# Concatenate in the same order as shuffle_indices (masked first, visible after)
# (b, n_patches, decoder_dim)
concat_tokens = torch.cat([mask_tokens, enc_to_dec_tokens], dim=1)

# Un-shuffle: scatter every token back to the position of its original patch
dec_input_tokens = torch.empty_like(concat_tokens, device=device)
dec_input_tokens[batch_ind, shuffle_indices] = concat_tokens

# Feed the full token sequence to the Decoder
decoded_tokens = self.decoder(dec_input_tokens)

In [None]:
'''
Loss Computation
The decoded mask tokens are passed through the head to predict pixel values; the predictions are compared with the masked patches using an MSE loss.
'''

# Pick the decoded tokens that sit at the masked positions
dec_mask_tokens = decoded_tokens[batch_ind, mask_ind]

# Predict the pixel values of the masked patches
# (b, n_masked, n_pixels_per_patch = patch_size**2 x c)
pred_mask_pixel_values = self.head(dec_mask_tokens)

# Mean squared error between predicted and true pixels of the masked patches
loss = F.mse_loss(pred_mask_pixel_values, mask_patches)

In [None]:
'''
Encoder ViT & Decoder Transformer
'''

import torch
import torch.nn as nn

def to_pair(t):
    """Normalize a size specification into a 2-tuple.

    A tuple is returned unchanged and a scalar ``t`` is duplicated into
    ``(t, t)``. A list is converted to a tuple, so ``[224, 224]`` behaves like
    ``(224, 224)`` instead of being duplicated into a pair of lists.
    """
    if isinstance(t, (tuple, list)):
        return tuple(t)
    return (t, t)

class PreNorm(nn.Module):
    """Wrap a sub-module so that its input is layer-normalized first.

    Args:
        dim: size of the last dimension, passed to ``nn.LayerNorm``.
        net: the module that receives the normalized input.
    """

    def __init__(self, dim, net):
        super().__init__()

        self.norm = nn.LayerNorm(dim)
        self.net = net

    def forward(self, x, **kwargs):
        """Normalize ``x`` and hand it to the wrapped module; extra keyword
        arguments are forwarded to that module untouched."""
        normed = self.norm(x)
        result = self.net(normed, **kwargs)
        return result

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: width of the input (and output) tokens.
        num_heads: number of attention heads.
        dim_per_head: width of each head.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, dim_per_head=64, dropout=0.):
        super().__init__()

        self.num_heads = num_heads
        self.scale = dim_per_head ** -0.5

        inner_dim = dim_per_head * num_heads
        # One bias-free linear layer produces query, key and value together
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.attend = nn.Softmax(dim=-1)

        # With a single head whose width already equals `dim` no output
        # projection is created.
        single_head_same_width = num_heads == 1 and dim_per_head == dim
        if single_head_same_width:
            self.out = nn.Identity()
        else:
            self.out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (b, l, dim); returns (b, l, dim)."""
        batch, seq_len, _ = x.shape

        # i. QKV projection
        # (b, l, 3 * dim_all_heads) -> (b, l, 3, num_heads, dim_per_head)
        qkv = self.to_qkv(x).view(batch, seq_len, 3, self.num_heads, -1)
        # Split along the axis of size 3 and move the head axis in front of the
        # sequence axis: 3 x (b, num_heads, l, dim_per_head)
        q, k, v = (part.transpose(1, 2).contiguous() for part in qkv.unbind(dim=2))

        # ii. Attention computation
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(scores)

        # iii. Put attention on Value & reshape
        # (b, num_heads, l, dim_per_head)
        z = torch.matmul(attn, v)
        # -> (b, l, num_heads, dim_per_head) -> (b, l, dim_all_heads)
        z = z.transpose(1, 2).reshape(batch, seq_len, -1)

        # iv. Project out
        # (b, l, dim_all_heads) -> (b, l, dim)
        return self.out(z)

class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: width of the input and output.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability used after each linear layer's activation/output.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()

        # The two linear layers are built in the same order as they appear in
        # the sequence (expand first, then contract).
        expand = nn.Linear(dim, hidden_dim)
        contract = nn.Linear(hidden_dim, dim)

        stages = [
            expand,
            nn.GELU(),
            nn.Dropout(p=dropout),
            contract,
            nn.Dropout(p=dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the feed-forward stack."""
        out = self.net(x)
        return out

class Transformer(nn.Module):
    """A stack of pre-norm Transformer layers (self-attention + feed-forward),
    each wrapped in a residual connection.

    Args:
        dim: token width.
        mlp_dim: hidden width of the feed-forward network.
        depth: number of layers.
        num_heads: attention heads per layer.
        dim_per_head: width of each attention head.
        dropout: dropout probability passed to attention and feed-forward.
    """

    def __init__(self, dim, mlp_dim, depth=6, num_heads=8, dim_per_head=64, dropout=0.):
        super().__init__()

        def build_layer():
            # Attention is constructed before the feed-forward network
            attention = SelfAttention(
                dim, num_heads=num_heads, dim_per_head=dim_per_head, dropout=dropout
            )
            norm_attn = PreNorm(dim, attention)
            norm_ffn = PreNorm(dim, FFN(dim, mlp_dim, dropout=dropout))
            return nn.ModuleList([norm_attn, norm_ffn])

        self.layers = nn.ModuleList([build_layer() for _ in range(depth)])

    def forward(self, x):
        """Pass ``x`` through every layer, adding each sub-layer's output to its input."""
        for layer in self.layers:
            norm_attn, norm_ffn = layer
            x = x + norm_attn(x)
            x = x + norm_ffn(x)

        return x

class ViT(nn.Module):
    """Vision Transformer classifier.

    The image is cut into non-overlapping patches, every patch is linearly
    embedded, a cls token and learned position embeddings are added, the tokens
    go through a Transformer, and the pooled result is classified.

    Args:
        image_size: image height/width, a single value or an (h, w) tuple.
        patch_size: patch height/width, a single value or an (h, w) tuple.
        num_classes: number of output classes.
        dim: token width.
        depth: number of Transformer layers.
        num_heads: attention heads per layer.
        mlp_dim: hidden width of the feed-forward networks.
        pool: 'cls' to use the cls token, 'mean' to average all tokens.
        channels: number of image channels.
        dim_per_head: width of each attention head.
        dropout: dropout probability inside the Transformer.
        embed_dropout: dropout probability applied to the embedded tokens.
    """

    def __init__(
        self, image_size, patch_size,
        num_classes=1000, dim=1024, depth=6, num_heads=8, mlp_dim=2048,
        pool='cls', channels=3, dim_per_head=64, dropout=0., embed_dropout=0.
    ):
        super().__init__()

        img_h, img_w = to_pair(image_size)
        self.patch_h, self.patch_w = to_pair(patch_size)
        assert img_h % self.patch_h == 0 and img_w % self.patch_w == 0, \
            f'Image dimensions ({img_h},{img_w}) must be divisible by the patch size ({self.patch_h},{self.patch_w}).'
        grid_h, grid_w = img_h // self.patch_h, img_w // self.patch_w
        num_patches = grid_h * grid_w

        assert pool in {'cls', 'mean'}, f'pool type must be either cls (cls token) or mean (mean pooling), got: {pool}'

        # Every flattened patch holds patch_h * patch_w pixels for each channel
        patch_dim = channels * self.patch_h * self.patch_w
        self.patch_embed = nn.Linear(patch_dim, dim)

        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # One extra position for the cls token
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, dim))

        self.dropout = nn.Dropout(p=embed_dropout)

        self.transformer = Transformer(
            dim, mlp_dim, depth=depth, num_heads=num_heads,
            dim_per_head=dim_per_head, dropout=dropout
        )

        self.pool = pool

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        """Classify a batch of images of shape (b, c, h, w); returns (b, num_classes) logits."""
        b, c, img_h, img_w = x.shape
        assert img_h % self.patch_h == 0 and img_w % self.patch_w == 0, \
            f'Input image dimensions ({img_h},{img_w}) must be divisible by the patch size ({self.patch_h},{self.patch_w}).'

        # i. Patch partition
        grid_h, grid_w = img_h // self.patch_h, img_w // self.patch_w
        num_patches = grid_h * grid_w
        # (b, c, h, w) -> (b, c, grid_h, patch_h, grid_w, patch_w)
        patches = x.view(b, c, grid_h, self.patch_h, grid_w, self.patch_w)
        # -> (b, grid_h, grid_w, patch_h, patch_w, c) -> (b, n_patches, patch_h * patch_w * c)
        patches = patches.permute(0, 2, 4, 3, 5, 1).reshape(b, num_patches, -1)

        # ii. Patch embedding
        # (b, n_patches, dim)
        tokens = self.patch_embed(patches)
        # Prepend one cls token per sample: (b, n_patches + 1, dim)
        cls_tokens = self.cls_token.repeat(b, 1, 1)
        tokens = torch.cat([cls_tokens, tokens], dim=1)
        tokens += self.pos_embed[:, :(num_patches + 1)]
        tokens = self.dropout(tokens)

        # iii. Transformer encoding
        enc_tokens = self.transformer(tokens)

        # iv. Pooling -> (b, dim)
        if self.pool == 'cls':
            pooled = enc_tokens[:, 0]
        else:
            pooled = enc_tokens.mean(dim=1)

        # v. Classification -> (b, n_classes)
        return self.mlp_head(pooled)

In [None]:
'''
Reconstruction(Inference)
'''

@torch.no_grad()
def predict(self, x):
    """Mask a batch of images, reconstruct the masked patches and return both views.

    Written to be used as a method of the MAE model: it reads ``self.patch_h``,
    ``self.patch_w``, ``self.mask_ratio``, ``self.encoder``, ``self.enc_to_dec``,
    ``self.mask_embed``, ``self.decoder_pos_embed``, ``self.decoder`` and
    ``self.head``. The model is switched to eval mode and gradients are disabled.

    Args:
        x: image batch of shape (b, c, h, w).

    Returns:
        A tuple ``(recons_img, patches_to_img)``, both of shape (b, c, h, w):
        the input with masked patches replaced by the model's predictions, and
        the input with masked patches replaced by random noise (to visualize
        the mask).
    """
    self.eval()

    device = x.device
    b, c, h, w = x.shape

    # i. Patch partition

    grid_h, grid_w = h // self.patch_h, w // self.patch_w
    num_patches = grid_h * grid_w
    # (b, c, h, w) -> (b, n_patches, patch_h * patch_w * c)
    patches = x.view(
        b, c,
        grid_h, self.patch_h,
        grid_w, self.patch_w
    ).permute(0, 2, 4, 3, 5, 1).reshape(b, num_patches, -1)

    # ii. Divide into masked & un-masked groups

    num_masked = int(self.mask_ratio * num_patches)

    # Shuffle: argsort of uniform random scores is a random permutation
    # (b, n_patches)
    shuffle_indices = torch.rand(b, num_patches, device=device).argsort()
    mask_ind, unmask_ind = shuffle_indices[:, :num_masked], shuffle_indices[:, num_masked:]

    # (b, 1), broadcasts against the (b, k) patch indices
    batch_ind = torch.arange(b, device=device).unsqueeze(-1)
    mask_patches, unmask_patches = patches[batch_ind, mask_ind], patches[batch_ind, unmask_ind]

    # iii. Encode

    unmask_tokens = self.encoder.patch_embed(unmask_patches)
    # Add position embeddings; +1 skips the slot reserved for the cls token
    unmask_tokens += self.encoder.pos_embed.repeat(b, 1, 1)[batch_ind, unmask_ind + 1]
    encoded_tokens = self.encoder.transformer(unmask_tokens)

    # iv. Decode

    enc_to_dec_tokens = self.enc_to_dec(encoded_tokens)

    # (decoder_dim) -> (b, n_masked, decoder_dim)
    mask_tokens = self.mask_embed[None, None, :].repeat(b, num_masked, 1)
    # Add position embeddings
    mask_tokens += self.decoder_pos_embed(mask_ind)

    # (b, n_patches, decoder_dim), ordered like shuffle_indices
    concat_tokens = torch.cat([mask_tokens, enc_to_dec_tokens], dim=1)
    dec_input_tokens = torch.empty_like(concat_tokens, device=device)
    # Un-shuffle: put every token back at its original patch position
    dec_input_tokens[batch_ind, shuffle_indices] = concat_tokens
    decoded_tokens = self.decoder(dec_input_tokens)

    # v. Mask pixel prediction

    dec_mask_tokens = decoded_tokens[batch_ind, mask_ind, :]
    # (b, n_masked, n_pixels_per_patch = patch_h * patch_w * c)
    pred_mask_pixel_values = self.head(dec_mask_tokens)

    # Compare prediction and ground truth. The error is squared so the reported
    # value really is a mean squared error, as the printed label says.
    mse_per_patch = (pred_mask_pixel_values - mask_patches).pow(2).mean(dim=-1)
    mse_all_patches = mse_per_patch.mean()

    print(f'mse per (masked)patch: {mse_per_patch} mse all (masked)patches: {mse_all_patches} total {num_masked} masked patches')
    print(f'all close: {torch.allclose(pred_mask_pixel_values, mask_patches, rtol=1e-1, atol=1e-1)}')

    # vi. Reconstruction

    # clone() gives the reconstruction its own storage; detach() would share
    # memory with `patches`, so the in-place writes below would alias each other.
    recons_patches = patches.clone()
    # Write the predictions at the masked positions (b, n_patches, patch_h * patch_w * c)
    recons_patches[batch_ind, mask_ind] = pred_mask_pixel_values
    # Reconstructed image: (b, n_patches, patch_h * patch_w * c) -> (b, c, h, w)
    recons_img = recons_patches.view(
        b, grid_h, grid_w,
        self.patch_h, self.patch_w, c
    ).permute(0, 5, 1, 3, 2, 4).reshape(b, c, h, w)

    # Masked view: the masked patches are replaced by random noise
    noise_patches = torch.randn_like(mask_patches, device=mask_patches.device)
    masked_view = patches.clone()
    masked_view[batch_ind, mask_ind] = noise_patches
    patches_to_img = masked_view.view(
        b, grid_h, grid_w,
        self.patch_h, self.patch_w, c
    ).permute(0, 5, 1, 3, 2, 4).reshape(b, c, h, w)

    return recons_img, patches_to_img

In [None]:
'''
Simple pipeline
'''

from PIL import Image
from torchvision.transforms import ToTensor, ToPILImage

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Read the image and scale it to the size the model expects
img_raw = Image.open('test.jpg')
h, w = img_raw.height, img_raw.width
print(f"image hxw: {h} x {w} mode: {img_raw.mode}")

# Both sizes are (height, width)
img_size, patch_size = (224, 224), (16, 16)
# Force 3 channels (the encoder is built with the default channels=3) so that
# grayscale / RGBA / palette images do not break the patch embedding.
# PIL's resize() expects (width, height), hence the reversed tuple.
img = img_raw.convert('RGB').resize(img_size[::-1])
rh, rw = img.height, img.width
print(f'resized image hxw: {rh} x {rw} mode: {img.mode}')
img.save('resized_test.jpg')

# Convert the image to a (1, c, h, w) tensor
img_ts = ToTensor()(img).unsqueeze(0).to(device)
print(f"input tensor shape: {img_ts.shape} dtype: {img_ts.dtype} device: {img_ts.device}")

# Instantiate the model and load the trained weights
encoder = ViT(img_size, patch_size, dim=512, mlp_dim=1024, dim_per_head=64)
decoder_dim = 512
mae = MAE(encoder, decoder_dim, decoder_depth=6)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
weight = torch.load('mae.pth', map_location='cpu')
# The checkpoint has to be applied to the model, otherwise inference runs on
# randomly initialized weights.
# Assumes 'mae.pth' stores a plain state_dict for MAE -- TODO confirm.
mae.load_state_dict(weight)
mae.to(device)

# Inference: reconstructed image and masked-view image
recons_img_ts, masked_img_ts = mae.predict(img_ts)
recons_img_ts, masked_img_ts = recons_img_ts.cpu().squeeze(0), masked_img_ts.cpu().squeeze(0)

# Save the results so they can be compared with the original image.
# Predictions and the noise used for the mask can fall outside [0, 1]; clamp
# them so the conversion to 8-bit does not wrap around.
recons_img = ToPILImage()(recons_img_ts.clamp(0, 1))
recons_img.save('recons_test.jpg')

masked_img = ToPILImage()(masked_img_ts.clamp(0, 1))
masked_img.save('masked_test.jpg')
