In [1]:
!pip install einops



In [2]:
import torch
import numpy as np
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn

In [20]:
from timm.layers import trunc_normal_
from timm.models.vision_transformer import Block, PatchEmbed

In [None]:
def random_indexes(size : int):
    forward_indexes = np.arange(size)
    np.random.shuffle(forward_indexes)
    backward_indexes = np.argsort(forward_indexes)
    return forward_indexes, backward_indexes

def take_indexes(sequences, indexes):
    return torch.gather(sequences, 0, repeat(indexes, 't b -> t b c', c=sequences.shape[-1]))

class PatchShuffle(torch.nn.Module):
    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches : torch.Tensor):
        T, B, C = patches.shape
        remain_T = int(T * (1 - self.ratio))

        indexes = [random_indexes(T) for _ in range(B)]
        forward_indexes = torch.as_tensor(np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)
        backward_indexes = torch.as_tensor(np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)

        patches = take_indexes(patches, forward_indexes)
        patches = patches[:remain_T]

        return patches, forward_indexes, backward_indexes

In [18]:
class Patchify(torch.nn.Module):
    def __init__(self, patch_size, emb_dim, in_chans = 3) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.emb_dim = emb_dim
        self.in_chans = in_chans
        self.proj = nn.Conv2d(in_chans, emb_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        x = rearrange(x, 'b c h w -> b (h w) c') # Final shape is (#batches, #patches, #emb_dim)
        return x

In [22]:
#test Patchify
x = torch.randn(1, 3, 224, 224)
patchify = Patchify(16, 768)
patches = patchify(x)
print(patches.shape)

model = PatchEmbed(
    img_size=224, patch_size=16, in_chans=3, embed_dim=768
)
x = torch.randn(1, 3, 224, 224)
patches = model(x)
print(patches.shape)

torch.Size([1, 196, 768])
torch.Size([1, 196, 768])


In [24]:
img_size = 224
patch_size = 16
emb_dim = 768
posemb = nn.Parameter(torch.empty(1, (img_size // patch_size) ** 2 + 1, emb_dim))

In [29]:
posemb[:,1:,:].shape

torch.Size([1, 196, 768])

In [None]:
class MAE_Encoder(torch.nn.Module):
    def __init__(
            self,
            img_size = 32,
            patch_size = 2,
            in_chans = 3,
            emb_dim = 192,
            num_layers = 12,
            num_heads = 3,
            mask_ratio = 0.75,
            mlp_dim = 768
            ) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))
        self.pos_embedding  = nn.Parameter(torch.empty(1, (img_size // patch_size) ** 2 + 1, emb_dim))

        self.patchify = PatchEmbed(
            img_size = img_size,
            patch_size = patch_size,
            in_chans = in_chans,
            embed_dim = emb_dim
        )
        ### Encoder model
        self.encoder = nn.ModuleList([
            Block(
                dim = emb_dim,
                num_heads = num_heads,
                mlp_ratio = mlp_dim / emb_dim,
                qkv_bias = True,
                qk_scale = None,
                norm_layer = nn.LayerNorm,
            ) for _ in range(num_layers)
        ])

        self.norm_layer = nn.LayerNorm(emb_dim)

        self.initialize_weights()
    
    def initialize_weights(self):
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.pos_embedding, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT: #Code taken from FAIR
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        
        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore


    def forward(self, x, mask_ratio = 0.75):
        x = self.patchify(x)

        #Add position embedding w/o cls token
        x = x + self.pos_embedding[:, 1:, :]

        #masking
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        #Append cls token
        cls_token = self.cls_token + self.pos_embedding[:, 0:1, :]
        cls_token = cls_token.expand(x.shape[0], -1, -1) #Expand cls token to all batches
        x = torch.cat((cls_token, x), dim=1)

        for block in self.encoder:
            x = block(x)

        x = self.norm_layer(x)

        return x, mask, ids_restore


In [None]:
class MAE_Decoder(torch.nn.Module):
    def __init__(
        self, 
        image_size = 32,
        patch_size = 2,
        emb_dim = 192,
        num_layers = 4,
        num_heads = 3,
        out_chans = 3,
        mlp_dim = 768
    ) -> None:

        super().__init__()

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embedding = nn.Parameter(torch.empty(1, (img_size // patch_size) ** 2 + 1, emb_dim))

        # self.decoder_emb = nn.Linear(encoder_emb_dim, decoder_embed_dim, bias=True)
        self.decoder = nn.ModuleList([
            Block(
                dim = emb_dim,
                num_heads = num_heads,
                mlp_ratio = mlp_dim / emb_dim,
                qkv_bias = True,
                qk_scale = None,
                norm_layer = nn.LayerNorm,
            ) for _ in range(num_layers)
        ])


        self.decoder_norm = nn.LayerNorm(emb_dim)
        self.decoder_pred = nn.Linear(emb_dim, patch_size * patch_size * out_chans, bias=True)
        self.initialize_weights()

    def initialize_weights(self):
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embedding, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT: #Code taken from FAIR
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, ids_restore):
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        #add position embedding
        x = x + self.decoder_pos_embedding

        for block in self.decoder:
            x = block(x)

        x = self.decoder_norm(x)

        x = self.decoder_pred(x)

        #remove cls token
        x = x[:, 1:, :]

        return x

In [33]:
#test MAE_Encoder and MAE_Decoder
encoder = MAE_Encoder()
decoder = MAE_Decoder()

TypeError: Block.__init__() got an unexpected keyword argument 'qk_scale'

In [None]:
torch.nn.Sequential(*[Block(192, 3,) for _ in range(4)])

Sequential(
  (0): Block(
    (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    (attn): Attention(
      (qkv): Linear(in_features=192, out_features=576, bias=False)
      (q_norm): Identity()
      (k_norm): Identity()
      (attn_drop): Dropout(p=0.0, inplace=False)
      (proj): Linear(in_features=192, out_features=192, bias=True)
      (proj_drop): Dropout(p=0.0, inplace=False)
    )
    (ls1): Identity()
    (drop_path1): Identity()
    (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    (mlp): Mlp(
      (fc1): Linear(in_features=192, out_features=768, bias=True)
      (act): GELU(approximate='none')
      (drop1): Dropout(p=0.0, inplace=False)
      (norm): Identity()
      (fc2): Linear(in_features=768, out_features=192, bias=True)
      (drop2): Dropout(p=0.0, inplace=False)
    )
    (ls2): Identity()
    (drop_path2): Identity()
  )
  (1): Block(
    (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    (attn): Attention(
      (q

In [16]:
192*3

576

In [None]:
x = torch.randn(batch_size, channels, height, width) #batch, channel, height, width
print(x.shape)
#define the patchify layer
patchify = nn.Conv2d(channels, D, kernel_size=patch_size, stride=patch_size)
#apply the patchify layer
x = patchify(x)
print(x.shape)
#reshape the patches to (batch, height * width, channel)
x = rearrange(x, 'b c h w -> b (h w) c')
print(x.shape)
#Add class token
cls_token = nn.Parameter(torch.randn(1, 1, 192))
#expand cls token to the batch size
batch_cls_token = cls_token.expand(batch_size, -1, -1)
x = torch.cat((batch_cls_token, x), dim=1)
print(x.shape)
#position embedding
pos_embed = nn.Parameter(torch.empty(1, 257, 192).normal_(std=0.02))
x += pos_embed
print(x.shape)


In [19]:
batch_size = 4
channels = 3
height = 32
width = 32
patch_size = 2
num_classes = 10

D = 192 # hidden dimension

torch.Size([4, 3, 32, 32])
torch.Size([4, 192, 16, 16])
torch.Size([4, 256, 192])
torch.Size([4, 257, 192])
torch.Size([4, 257, 192])


In [31]:
x = torch.randn(4,61,192)
encoder_model.encoder.layers(x).shape

torch.Size([4, 61, 192])

In [12]:
from torchvision.models import VisionTransformer
encoder_model = VisionTransformer(
    image_size=32,
    patch_size=2,
    num_classes=10,
    num_layers=4,
    num_heads=3,
    hidden_dim=192,
    mlp_dim=768,
    dropout=0.0,
    attention_dropout=0.0,
)

VisionTransformer(
  (conv_proj): Conv2d(3, 192, kernel_size=(2, 2), stride=(2, 2))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=192, out_features=192, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=192, out_features=768, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=768, out_features=192, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (self_attenti

In [7]:
def random_indexes(size : int):
    forward_indexes = np.arange(size)
    np.random.shuffle(forward_indexes)
    backward_indexes = np.argsort(forward_indexes)
    return forward_indexes, backward_indexes

In [13]:
def take_indexes(sequences, indexes):
    return torch.gather(sequences, 0, repeat(indexes, 't b -> t b c', c=sequences.shape[-1]))

In [14]:
class PatchShuffle(torch.nn.Module):
    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches : torch.Tensor):
        T, B, C = patches.shape
        remain_T = int(T * (1 - self.ratio))

        indexes = [random_indexes(T) for _ in range(B)]
        forward_indexes = torch.as_tensor(np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)
        backward_indexes = torch.as_tensor(np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)

        patches = take_indexes(patches, forward_indexes)
        patches = patches[:remain_T]

        return patches, forward_indexes, backward_indexes


In [15]:
class MAE_Encoder(torch.nn.Module):
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=12,
                 num_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2, 1, emb_dim))
        self.shuffle = PatchShuffle(mask_ratio)

        self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        self.layer_norm = torch.nn.LayerNorm(emb_dim)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, img):
        patches = self.patchify(img)
        patches = rearrange(patches, 'b c h w -> (h w) b c')
        patches = patches + self.pos_embedding

        patches, forward_indexes, backward_indexes = self.shuffle(patches)

        patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
        patches = rearrange(patches, 't b c -> b t c')
        features = self.layer_norm(self.transformer(patches))
        features = rearrange(features, 'b t c -> t b c')

        return features, backward_indexes

class MAE_Decoder(torch.nn.Module):
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=4,
                 num_head=3,
                 ) -> None:
        super().__init__()

        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size//patch_size)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.mask_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, features, backward_indexes):
        T = features.shape[0]
        backward_indexes = torch.cat([torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], dim=0)
        features = torch.cat([features, self.mask_token.expand(backward_indexes.shape[0] - features.shape[0], features.shape[1], -1)], dim=0)
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        features = rearrange(features, 't b c -> b t c')
        features = self.transformer(features)
        features = rearrange(features, 'b t c -> t b c')
        features = features[1:] # remove global feature

        patches = self.head(features)
        mask = torch.zeros_like(patches)
        mask[T-1:] = 1
        mask = take_indexes(mask, backward_indexes[1:] - 1)
        img = self.patch2img(patches)
        mask = self.patch2img(mask)

        return img, mask

class MAE_ViT(torch.nn.Module):
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 encoder_layer=12,
                 encoder_head=3,
                 decoder_layer=4,
                 decoder_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.encoder = MAE_Encoder(image_size, patch_size, emb_dim, encoder_layer, encoder_head, mask_ratio)
        self.decoder = MAE_Decoder(image_size, patch_size, emb_dim, decoder_layer, decoder_head)

    def forward(self, img):
        features, backward_indexes = self.encoder(img)
        predicted_img, mask = self.decoder(features,  backward_indexes)
        return predicted_img, mask

class ViT_Classifier(torch.nn.Module):
    def __init__(self, encoder : MAE_Encoder, num_classes=10) -> None:
        super().__init__()
        self.cls_token = encoder.cls_token
        self.pos_embedding = encoder.pos_embedding
        self.patchify = encoder.patchify
        self.transformer = encoder.transformer
        self.layer_norm = encoder.layer_norm
        self.head = torch.nn.Linear(self.pos_embedding.shape[-1], num_classes)

    def forward(self, img):
        patches = self.patchify(img)
        patches = rearrange(patches, 'b c h w -> (h w) b c')
        patches = patches + self.pos_embedding
        patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
        patches = rearrange(patches, 't b c -> b t c')
        features = self.layer_norm(self.transformer(patches))
        features = rearrange(features, 'b t c -> t b c')
        logits = self.head(features[0])
        return logits

In [None]:

# encoder_model.conv_proj = nn.Identity()
# patchify = nn.Conv2d(3, 192, kernel_size=2, stride=2)
# encoder_model.heads = nn.Identity()

In [9]:
encoder_model

VisionTransformer(
  (conv_proj): Conv2d(3, 192, kernel_size=(2, 2), stride=(2, 2))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=192, out_features=192, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=192, out_features=768, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=768, out_features=192, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (self_attenti

In [8]:
x = torch.randn(4, 3, 32, 32) 

In [10]:
encoder_model.conv_proj(x).shape

torch.Size([4, 192, 16, 16])