In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Vision Transformer (ViT)

![](assets/attn.png)

In [12]:
class MultiHeadAtt(nn.Module):
    """Basic attention block (multi-head scaled dot-product self-attention)
    
    This is a simplified version referenced from: 
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L178-L202
    
    Args:
    - d_model: number of input/output channels (C)
    - head: number of attention heads, must be positive and divide *d_model*
    
    Raises:
    - ValueError: if *head* is not positive or *d_model* is not a multiple of *head*
    """
    def __init__(self, d_model, head=8):
        super().__init__()
        
        # Fail at construction time: otherwise head * d_head != d_model and the
        # per-head reshape in forward() dies with an opaque shape error.
        if head <= 0 or d_model % head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive number of heads (got head={head})."
            )
        
        self.d_head = d_model // head
        self.head = head
        
        # We don't want to create *head* instances of Linear class
        # so we just group it to single Linear that takes in *d_model* channels and returns *d_head* x *head* channels
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        nn.init.xavier_uniform_(self.W_q.weight)
        
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        nn.init.xavier_uniform_(self.W_k.weight)
        
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        nn.init.xavier_uniform_(self.W_v.weight)
        
        # Output projection that mixes the concatenated heads back together
        self.W_o = nn.Linear(d_model, d_model)
        nn.init.xavier_uniform_(self.W_o.weight)
        nn.init.zeros_(self.W_o.bias)
    
    def forward(self, x):
        """
        Args:
        - x: a B x N x C tensor
        
        Returns:
        - a B x N x C tensor
        
        Annotations:
        - B: batch size
        - N: number of token
        - C: number of channel
        """
        B, N, C = x.shape
        
        queries = self.W_q(x) # B x N x head*d_head
        keys = self.W_k(x)
        values = self.W_v(x)
        
        # Split the channel axis into heads and move the head axis next to the batch axis
        queries = queries.reshape(B, N, self.head, self.d_head).permute(0, 2, 1, 3) # B x head x N x d_head
        keys = keys.reshape(B, N, self.head, self.d_head).permute(0, 2, 1, 3)
        values = values.reshape(B, N, self.head, self.d_head).permute(0, 2, 1, 3)
        
        # Scaled dot product, normalised over the key axis
        attn = (queries @ keys.transpose(-2, -1)) / self.d_head ** 0.5
        attn = F.softmax(attn, dim=-1) # B x head x N x N
        
        x = attn @ values # B x head x N x d_head
        x = x.transpose(1, 2) # B x N x head x d_head
        x = x.reshape(B, N, C) # B x N x head*d_head - Remind: d_model = C = head*d_head
        
        x = self.W_o(x)
        return x

![](assets/vit.png)

In [13]:
class FFN(nn.Module):
    """Two-layer perceptron (feed forward network) applied to every token of an attention block"""
    def __init__(self, d_in_out, d_hidden):
        super().__init__()
        # Expand to the hidden width, then project back to the input width.
        # Each layer is created and initialised before the next one is built.
        self.fc1 = self._make_layer(d_in_out, d_hidden)
        self.fc2 = self._make_layer(d_hidden, d_in_out)
    
    @staticmethod
    def _make_layer(n_in, n_out):
        """Build a Linear layer with Xavier-uniform weights and tiny normally distributed biases"""
        layer = nn.Linear(n_in, n_out)
        nn.init.xavier_uniform_(layer.weight)
        nn.init.normal_(layer.bias, std=1e-6)
        return layer
    
    def forward(self, x):
        """
        Args:
        - x: a B x N x C tensor
        
        Returns:
        - a B x N x C tensor
        
        Annotations:
        - B: batch size
        - N: number of token
        - C: number of channel
        """
        hidden = self.fc1(x)
        activated = F.gelu(hidden)
        return self.fc2(activated)

In [14]:
class TransformerEncoder(nn.Module):
    """Basic building block of ViT: pre-norm self-attention followed by a pre-norm MLP, each with a skip connection"""
    def __init__(self, d_model, head, d_ff_hid):
        super().__init__()
        
        # Attention sub-layer and the normalisation applied to its input
        self.input_norm = nn.LayerNorm(d_model)
        self.multi_attn = MultiHeadAtt(d_model, head)
        
        # Feed-forward sub-layer and the normalisation applied to its input
        self.ff_norm = nn.LayerNorm(d_model)
        self.ff = FFN(d_model, d_ff_hid)
    
    def forward(self, x):
        """
        Args:
        - x: a B x N x C tensor
        
        Returns:
        - a B x N x C tensor
        
        Annotations:
        - B: batch size
        - N: number of token
        - C: number of channel
        """
        # The skip connections bypass the normalisation: the raw input is added back
        attended = self.multi_attn(self.input_norm(x)) + x
        return self.ff(self.ff_norm(attended)) + attended

In [15]:
class PatchEmbed(nn.Module):
    """Cuts an image into non-overlapping patches and embeds each patch with a strided convolution"""
    def __init__(self, img_size, patch_size, img_c, d_model):
        super().__init__()
        
        # A non-tuple size is treated as the side of a square
        if isinstance(img_size, tuple):
            self.img_size = img_size
        else:
            self.img_size = (img_size, img_size)
        
        if isinstance(patch_size, tuple):
            self.patch_size = patch_size
        else:
            self.patch_size = (patch_size, patch_size)
        
        # Number of whole patches that fit along each spatial axis
        n_rows = self.img_size[0] // self.patch_size[0]
        n_cols = self.img_size[1] // self.patch_size[1]
        self.grid_size = (n_rows, n_cols)
        self.num_patches = n_rows * n_cols
        
        # kernel == stride, so every patch is seen by exactly one kernel application
        self.conv = nn.Conv2d(img_c, d_model, kernel_size=self.patch_size, stride=patch_size)
    
    def forward(self, x):
        """
        Args:
        - x: a B x C x H x W image tensor
        
        Returns:
        - a B x N x d_model tensor, N being the number of patches
        
        Annotations:
        - B: batch size
        - C: number of channel
        - H: height
        - W: width
        """
        _, _, H, W = x.shape
        size_matches = H == self.img_size[0] and W == self.img_size[1]
        assert size_matches, \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        
        patches = self.conv(x) # B x d_model x grid_H x grid_W
        # Merge the two grid axes into one token axis, then put channels last
        return patches.flatten(2).transpose(1, 2) # B x N x d_model

In [16]:
class ViT(nn.Module):
    """A skeleton of a typical ViT model
    
    Args:
    - img_size: input image size, an int (square image) or a (H, W) tuple
    - patch_size: patch size, an int (square patch) or a (H, W) tuple
    - img_c: number of image channels
    - d_model: number of channels of every token
    - num_class: number of output classes
    - encoders: a module applied to the B x (N + 1) x d_model token tensor; its output
      is normalised with LayerNorm(d_model), so it must keep d_model channels
    """
    def __init__(self, img_size, patch_size, img_c, d_model, num_class, encoders):
        # Must run before any parameter/submodule is assigned, otherwise
        # nn.Module refuses the assignment with an AttributeError.
        super().__init__()
        
        self.img_size = img_size
        self.patch_size = patch_size
        
        self.patch_embeder = PatchEmbed(img_size, patch_size, img_c, d_model)
        
        num_patches = self.patch_embeder.num_patches
        
        # One learnable class token shared by every sample, plus one learnable
        # position vector per token (class token included, hence the + 1)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, d_model))
        
        self.encoders = encoders
        # The class token feature is normalised first, then classified
        self.cls_norm = nn.LayerNorm(d_model)
        self.mlp_head = nn.Linear(d_model, num_class)
    
    def forward(self, x):
        """
        Args:
        - x: a B x C x H x W image tensor
        
        Returns:
        - a B x num_class tensor of class logits
        
        Annotations:
        - B: batch size
        - C: number of channel
        - H: height
        - W: width
        """
        x_embed = self.patch_embeder(x) # B x N x d_model
        
        # cls_token is stored as 1 x 1 x d_model: broadcast it over the batch
        # (expand creates a view, no copy) so it can be concatenated
        cls_token = self.cls_token.expand(x_embed.shape[0], -1, -1)
        x_embed = torch.cat([cls_token, x_embed], dim=1) # B x (N + 1) x d_model, class token at index 0
        x_embed = x_embed + self.pos_embed
        
        x_transformed = self.encoders(x_embed)
        
        # Only the class token is used for classification
        cls_feat = self.cls_norm(x_transformed[:, 0]) # B x d_model
        cls_logits = self.mlp_head(cls_feat) # B x num_class
        return cls_logits

# DeiT

# CaiT

![](assets/cait.png)

In [17]:
class ClassAttention(MultiHeadAtt):
    """New module of CaiT: attention where the class token is the only query, so the
    patch tokens are only read (as keys/values) to update the class token
    
    This is the simplified version referenced from:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/cait.py#L74-L106
    """
    def forward(self, x):
        """
        Args:
        - x: a B x N x C tensor whose first token (index 0) is the class token
        
        Returns:
        - a B x 1 x C tensor, the attended class token
        
        Annotations:
        - B: batch size
        - N: number of token
        - C: number of channel
        """
        B, N, C = x.shape
        n_head, d_head = self.head, self.d_head
        
        def split_heads(proj, n_tok):
            # B x n_tok x head*d_head -> B x head x n_tok x d_head
            return proj.reshape(B, n_tok, n_head, d_head).permute(0, 2, 1, 3)
        
        cls_query = split_heads(self.W_q(x[:,0]), 1) # B x head x 1 x d_head - query is the class token only
        all_keys = split_heads(self.W_k(x), N) # B x head x N x d_head
        all_values = split_heads(self.W_v(x), N)
        
        scores = (cls_query @ all_keys.transpose(-2, -1)) / d_head ** 0.5
        weights = F.softmax(scores, dim=-1) # B x head x 1 x N
        
        pooled = weights @ all_values # B x head x 1 x d_head
        pooled = pooled.transpose(1, 2).reshape(B, 1, C) # B x 1 x head*d_head - Remind: d_model = C = head*d_head
        
        return self.W_o(pooled)

In [18]:
class LayerScale(nn.Module):
    """Proposed layer of CaiT to make the optimization more stable
    
    Wraps *agg_block* as a residual branch: the input is layer-normalised, passed
    through *agg_block*, scaled per channel by the learnable *gamma* (initialised
    to *init_val*) and added to the residual.
    """
    def __init__(self, n_channel, agg_block, init_val=1e-4):
        super().__init__()
        
        # One learnable scale per channel, all starting at init_val
        unit_scale = torch.ones((n_channel))
        self.gamma = nn.Parameter(init_val * unit_scale)
        self.layer_norm = nn.LayerNorm(n_channel)
        self.agg_block = agg_block
    
    def forward(self, x, x_res):
        """
        Args:
        - x: tensor fed to the normalised branch, channels on the last axis
        - x_res: residual tensor the scaled branch output is added to
        """
        branch = self.agg_block(self.layer_norm(x))
        return x_res + self.gamma * branch

In [19]:
class CABlock(nn.Module):
    """Basic block of CaiT, utilize at the end of the network
    
    Updates the class token only: it first attends over all tokens (class
    attention), then goes through an MLP; both steps are LayerScale residual branches.
    
    Args:
    - d_model: number of channels of every token
    - head: number of attention heads
    - mlp_hidden: number of hidden channels of the MLP
    """
    def __init__(self, d_model, head, mlp_hidden):
        # Must run before any submodule is assigned, otherwise nn.Module
        # refuses the assignment with an AttributeError.
        super().__init__()
        
        self.cls_attn = LayerScale(n_channel=d_model, agg_block=ClassAttention(d_model, head))
        self.mlp = LayerScale(n_channel=d_model, agg_block=FFN(d_model, mlp_hidden))
    
    def forward(self, x, x_cls):
        """
        Args:
        - x: a B x N x C tensor of patch tokens
        - x_cls: a B x 1 x C class token tensor
        
        Returns:
        - a B x 1 x C tensor, the updated class token
        """
        # Class token goes first: ClassAttention uses the token at index 0 as the query
        u = torch.cat([x_cls, x], dim=1)
        x_cls = self.cls_attn(u, x_cls)
        x_cls = self.mlp(x_cls, x_cls)
        return x_cls

class SABlock(nn.Module):
    """Basic block of CaiT, utilize at the beginning of the network
    
    Self-attention followed by an MLP, both wrapped as LayerScale residual branches.
    
    Args:
    - d_model: number of channels of every token
    - head: number of attention heads
    - mlp_hidden: number of hidden channels of the MLP
    """
    def __init__(self, d_model, head, mlp_hidden):
        # Must run before any submodule is assigned, otherwise nn.Module
        # refuses the assignment with an AttributeError.
        super().__init__()
        
        self.attn = LayerScale(n_channel=d_model, agg_block=MultiHeadAtt(d_model, head))
        self.mlp = LayerScale(n_channel=d_model, agg_block=FFN(d_model, mlp_hidden))
    
    def forward(self, x):
        """
        Args:
        - x: a B x N x C tensor
        
        Returns:
        - a B x N x C tensor
        """
        # Each branch takes the same tensor as both its input and its residual
        x = self.attn(x, x)
        x = self.mlp(x, x)
        return x

# VOLO

![](assets/volo_attn.png)

In [20]:
class OutlookAttention(nn.Module):
    """Custom module VOLO
    
    For every spatial location, a k x k window of value vectors is re-weighted by
    attention weights predicted directly from the feature at that location (no
    query-key product), and the re-weighted windows are summed back onto the map.
    
    This is the simplified version referenced from:
    https://github.com/sail-sg/volo/blob/main/models/volo.py#L45-L100
    
    Args:
    - d_model: number of input/output channels (C)
    - head: number of attention heads, must be positive and divide *d_model*
    - kernel_size: side k of the local window, must be a positive odd number
    
    Raises:
    - ValueError: if *head* does not divide *d_model* or *kernel_size* is not a positive odd number
    """
    def __init__(self, d_model, head, kernel_size=3):
        super().__init__()
        
        if head <= 0 or d_model % head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive number of heads (got head={head})."
            )
        # With padding = k // 2 and stride 1, only an odd k yields exactly one
        # window per location (H * W windows), which forward() relies on.
        if kernel_size <= 0 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd number (got {kernel_size}).")
        
        self.head_dim = d_model // head
        self.head = head
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        # Predicts, per head, a k^2 x k^2 attention matrix for every location
        self.attn = nn.Linear(d_model, head * kernel_size**4, bias=False)
        
        self.proj = nn.Linear(d_model, d_model)
        self.unfold = nn.Unfold(kernel_size=kernel_size, padding=self.padding)
    
    def forward(self, x):
        """
        Args:
        - x: a B x H x W x C tensor (channels last)
        
        Returns:
        - a B x H x W x C tensor
        
        Annotations:
        - B: batch size
        - H: height
        - W: width
        - C: number of channel
        - N: number of locations, N = H*W
        - k: kernel size
        """
        B, H, W, C = x.shape
        k2 = self.kernel_size**2
        
        attn_map = self.attn(x).reshape(B, H * W, self.head, k2, k2)
        attn_map = attn_map.permute(0, 2, 1, 3, 4) / self.head_dim**0.5
        attn_map = F.softmax(attn_map, dim=-1) # B x head x N x k^2 x k^2
        
        v = self.W_v(x).permute(0, 3, 1, 2) # B x head*head_dim x H x W
        unfolded_v = self.unfold(v) # B x head*head_dim*k^2 x N
        unfolded_v = unfolded_v.reshape(B, self.head, -1, k2, H * W) # B x head x head_dim x k^2 x N
        unfolded_v = unfolded_v.permute(0, 1, 4, 3, 2) # B x head x N x k^2 x head_dim
        
        agg_val = attn_map @ unfolded_v #  B x head x N x k^2 x head_dim
        # F.fold only accepts a 3D (B, C*k^2, N) input, so the channel and
        # window axes have to be merged (channel-major, as nn.Unfold produced them)
        agg_val = agg_val.permute(0, 1, 4, 3, 2).reshape(B, C * k2, H * W) # B x d_model*k^2 x N
        folded_val = F.fold(agg_val, output_size=(H, W),
                            kernel_size=self.kernel_size, padding=self.padding) # B x d_model x H x W
        folded_val = folded_val.permute(0, 2, 3, 1) # B x H x W x d_model
        
        prj_val = self.proj(folded_val)
        
        return prj_val

In [None]:
class Outlooker(nn.Module):
    """Basic block of VOLO that utilize at the beginning of the network
    
    Pre-norm outlook attention followed by a pre-norm MLP, each added back to
    its own input through a skip connection.
    """
    def __init__(self, d_model, d_mlp_hidden, head, kernel_size):
        super().__init__()
        
        # Outlook attention branch and the normalisation applied to its input
        self.layer_norm_attn = nn.LayerNorm(d_model)
        self.outlook_attn = OutlookAttention(d_model, head, kernel_size)
        
        # MLP branch and the normalisation applied to its input
        self.layer_norm_mlp = nn.LayerNorm(d_model)
        self.mlp = FFN(d_model, d_mlp_hidden)
    
    def forward(self, x):
        """
        Args:
        - x: a tensor with d_model channels on the last axis, in the layout expected by OutlookAttention
        """
        attn_branch = self.outlook_attn(self.layer_norm_attn(x))
        attended = x + attn_branch
        
        mlp_branch = self.mlp(self.layer_norm_mlp(attended))
        return attended + mlp_branch