In [105]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

# Vision Transformer (ViT)

![](assets/attn.png)

In [106]:
class MultiHeadAtt(nn.Module):
    """Basic multi-head self-attention block.
    
    This is a simplified version referenced from: 
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L178-L202
    
    Args:
    - d_model: number of input/output channels C; must be divisible by ``head``
    - head: number of attention heads
    
    Raises:
    - ValueError: if ``d_model`` is not divisible by ``head``
    """
    def __init__(self, d_model, head=8):
        super().__init__()
        
        # Fail at construction with a clear message; otherwise the per-head reshape
        # in forward() would raise an opaque shape-mismatch error.
        if d_model % head != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by head ({head})")
        
        self.d_head = d_model // head
        self.head = head
        
        # We don't want to create *head* instances of Linear class
        # so we just group it to single Linear that takes in *d_model* channels and returns *d_head* x *head* channels
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        nn.init.xavier_uniform_(self.W_q.weight)
        
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        nn.init.xavier_uniform_(self.W_k.weight)
        
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        nn.init.xavier_uniform_(self.W_v.weight)
        
        # Output projection mixing the concatenated heads back into d_model channels
        self.W_o = nn.Linear(d_model, d_model)
        nn.init.xavier_uniform_(self.W_o.weight)
        nn.init.zeros_(self.W_o.bias)
    
    def forward(self, x):
        """
        Args:
        - x: a B x N x C tensor (C must equal d_model)
        
        Returns:
        - a B x N x C tensor
        
        Annotations:
        - B: batch size
        - N: number of token
        - C: number of channel
        """
        B, N, C = x.shape
        
        queries = self.W_q(x) # B x N x head*d_head
        keys = self.W_k(x)
        values = self.W_v(x)
        
        # Split the channel axis into independent heads
        queries = queries.reshape(B, N, self.head, self.d_head).permute(0, 2, 1, 3) # B x head x N x d_head
        keys = keys.reshape(B, N, self.head, self.d_head).permute(0, 2, 1, 3)
        values = values.reshape(B, N, self.head, self.d_head).permute(0, 2, 1, 3)
        
        # Scaled dot-product attention: divide by sqrt(d_head) before the softmax
        attn = (queries @ keys.transpose(-2, -1)) / self.d_head ** 0.5
        attn = F.softmax(attn, dim=-1) # B x head x N x N
        
        x = attn @ values # B x head x N x d_head
        x = x.transpose(1, 2) # B x N x head x d_head
        x = x.reshape(B, N, C) # B x N x head*d_head - Remind: d_model = C = head*d_head
        
        x = self.W_o(x)
        return x

In [107]:
# Test: self-attention keeps the B x N x C shape of its input
module = MultiHeadAtt(256)
input_tensor = torch.ones((1, 16, 256))

output_tensor = module(input_tensor)
assert output_tensor.shape == (1, 16, 256), output_tensor.shape

![](assets/vit.png)

In [108]:
class FFN(nn.Module):
    """Position-wise feed-forward network (Linear -> GELU -> Linear) used in attention blocks.

    Args:
    - d_in_out: channel count of both the input and the output
    - d_hidden: channel count of the hidden layer
    """
    def __init__(self, d_in_out, d_hidden):
        super().__init__()
        # fc1 is fully built and initialised before fc2 is created
        self.fc1 = nn.Linear(d_in_out, d_hidden)
        self._init_linear(self.fc1)

        self.fc2 = nn.Linear(d_hidden, d_in_out)
        self._init_linear(self.fc2)

    @staticmethod
    def _init_linear(layer):
        """Xavier-uniform weights and near-zero (std 1e-6) Gaussian bias."""
        nn.init.xavier_uniform_(layer.weight)
        nn.init.normal_(layer.bias, std=1e-6)

    def forward(self, x):
        """
        Args:
        - x: a B x N x C tensor

        Returns:
        - a B x N x C tensor

        Annotations:
        - B: batch size
        - N: number of token
        - C: number of channel
        """
        hidden = F.gelu(self.fc1(x))
        return self.fc2(hidden)

In [109]:
class TransformerEncoder(nn.Module):
    """Pre-norm Transformer encoder layer, the basic building block of ViT.

    Computes ``z = MHA(LN(x)) + x`` and then ``FFN(LN(z)) + z``.

    Args:
    - d_model: token channel count
    - head: number of attention heads
    - d_ff_hid: hidden width of the feed-forward network
    """
    def __init__(self, d_model, head, d_ff_hid):
        super().__init__()
        # attention sub-layer
        self.input_norm = nn.LayerNorm(d_model)
        self.multi_attn = MultiHeadAtt(d_model, head)
        # feed-forward sub-layer
        self.ff_norm = nn.LayerNorm(d_model)
        self.ff = FFN(d_model, d_ff_hid)

    def forward(self, x):
        """
        Args:
        - x: a B x N x C tensor

        Returns:
        - a B x N x C tensor

        Annotations:
        - B: batch size
        - N: number of token
        - C: number of channel
        """
        attended = self.multi_attn(self.input_norm(x)) + x
        return self.ff(self.ff_norm(attended)) + attended

In [110]:
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each patch with a strided convolution.

    A Conv2d whose kernel size and stride both equal the patch size is equivalent to
    flattening every patch and applying one shared linear projection to it.

    Args:
    - img_size: int or (H, W) tuple, the expected input resolution
    - patch_size: int or (pH, pW) tuple, the size of each patch
    - img_c: number of input image channels
    - d_model: embedding dimension of each patch token
    """
    def __init__(self, img_size, patch_size, img_c, d_model):
        super().__init__()
        
        self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
        self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
        
        # Patches per spatial axis; trailing pixels that do not fill a whole patch are dropped,
        # exactly as the unpadded strided convolution below does.
        self.grid_size = (self.img_size[0] // self.patch_size[0], self.img_size[1] // self.patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        
        self.conv = nn.Conv2d(img_c, d_model, kernel_size=self.patch_size, stride=self.patch_size)
    
    def forward(self, x):
        """
        Args:
        - x: a B x C x H x W image tensor
        
        Returns:
        - a B x N x d_model tensor, N = grid_H * grid_W patch tokens
        
        Annotations:
        - B: batch size
        - C: number of channel
        - H: height
        - W: width
        """
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        
        x = self.conv(x) # B x d_model x grid_H x grid_W
        x = torch.flatten(x, 2) # B x d_model x N
        x = x.transpose(1, 2) # B x N x d_model
        
        return x

In [111]:
# Test: a 224x224 image split into 16x16 patches yields a 14x14 grid of tokens
module = PatchEmbed(img_size=224, patch_size=16, img_c=3, d_model=256)
input_tensor = torch.ones((1, 3, 224, 224))

output_tensor = module(input_tensor)
assert output_tensor.shape == (1, 14 * 14, 256), output_tensor.shape

In [112]:
class ViT(nn.Module):
    """A skeleton of a typical ViT model.

    Pipeline: patch embedding -> prepend a learnable class token -> add a learnable
    positional embedding -> encoder stack -> linear head applied to the class token.

    Args:
    - img_size, patch_size, img_c, d_model: forwarded to PatchEmbed
    - num_class: number of output classes
    - encoders: module mapping a B x (N+1) x d_model tensor to a tensor of the same shape
    """
    def __init__(self, img_size, patch_size, img_c, d_model, num_class, encoders):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        
        self.patch_embeder = PatchEmbed(img_size, patch_size, img_c, d_model)
        
        num_patches = self.patch_embeder.num_patches
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        # One positional vector per patch token plus one for the prepended class token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, d_model))
        
        self.encoders = encoders
        self.mlp_head = nn.Linear(d_model, num_class)
        # NOTE(review): this LayerNorm normalises the logits (num_class channels); the referenced
        # timm ViT applies its final LayerNorm to the d_model features before the head instead.
        # Confirm this is intentional.
        self.cls_morm = nn.LayerNorm(num_class)
    
    def forward(self, x):
        """
        Args:
        - x: a B x C x H x W image tensor
        
        Returns:
        - a B x num_class tensor of class logits
        
        Annotations:
        - B: batch size
        - C: number of channel
        - H: height
        - W: width
        """
        cls_token = self.cls_token.expand(x.shape[0], -1, -1) # B x 1 x d_model
        x_embed = self.patch_embeder(x) # B x N x d_model
        x_embed = torch.cat([cls_token, x_embed], dim=1) # B x (N+1) x d_model, class token at index 0
        x_embed = x_embed + self.pos_embed
        
        x_transformed = self.encoders(x_embed)
        cls_transformed_token = x_transformed[:, 0, :] # B x d_model
        cls_logits = self.mlp_head(cls_transformed_token)
        cls_logits = self.cls_morm(cls_logits)
        return cls_logits

In [113]:
# ViT variants -- ViT-B/16
n_class = 10
n_layer = 12
d_model = 768
n_head = 12
mlp_hidden = 3072

vit_encoder = nn.Sequential(*[TransformerEncoder(d_model=d_model, head=n_head, d_ff_hid=mlp_hidden) for _ in range(n_layer)])
vit_b = ViT(img_size=224, patch_size=16, img_c=3, d_model=d_model, num_class=n_class, encoders=vit_encoder)

input_tensor = torch.ones((1, 3, 224, 224))

output_tensor = vit_b(input_tensor)
assert output_tensor.shape == (1, n_class), output_tensor.shape
# summary() prints the table itself; the trailing ';' stops Jupyter from echoing its return value as a second copy
summary(vit_b, input_size=(3, 224, 224), device='cpu');

Layer (type:depth-idx)                   Param #
├─PatchEmbed: 1-1                        --
|    └─Conv2d: 2-1                       590,592
├─Sequential: 1-2                        --
|    └─TransformerEncoder: 2-2           --
|    |    └─LayerNorm: 3-1               1,536
|    |    └─MultiHeadAtt: 3-2            2,360,064
|    |    └─LayerNorm: 3-3               1,536
|    |    └─FFN: 3-4                     4,722,432
|    └─TransformerEncoder: 2-3           --
|    |    └─LayerNorm: 3-5               1,536
|    |    └─MultiHeadAtt: 3-6            2,360,064
|    |    └─LayerNorm: 3-7               1,536
|    |    └─FFN: 3-8                     4,722,432
|    └─TransformerEncoder: 2-4           --
|    |    └─LayerNorm: 3-9               1,536
|    |    └─MultiHeadAtt: 3-10           2,360,064
|    |    └─LayerNorm: 3-11              1,536
|    |    └─FFN: 3-12                    4,722,432
|    └─TransformerEncoder: 2-5           --
|    |    └─LayerNorm: 3-13              1,536
|  

Layer (type:depth-idx)                   Param #
├─PatchEmbed: 1-1                        --
|    └─Conv2d: 2-1                       590,592
├─Sequential: 1-2                        --
|    └─TransformerEncoder: 2-2           --
|    |    └─LayerNorm: 3-1               1,536
|    |    └─MultiHeadAtt: 3-2            2,360,064
|    |    └─LayerNorm: 3-3               1,536
|    |    └─FFN: 3-4                     4,722,432
|    └─TransformerEncoder: 2-3           --
|    |    └─LayerNorm: 3-5               1,536
|    |    └─MultiHeadAtt: 3-6            2,360,064
|    |    └─LayerNorm: 3-7               1,536
|    |    └─FFN: 3-8                     4,722,432
|    └─TransformerEncoder: 2-4           --
|    |    └─LayerNorm: 3-9               1,536
|    |    └─MultiHeadAtt: 3-10           2,360,064
|    |    └─LayerNorm: 3-11              1,536
|    |    └─FFN: 3-12                    4,722,432
|    └─TransformerEncoder: 2-5           --
|    |    └─LayerNorm: 3-13              1,536
|  

In [114]:
# ViT variants -- ViT-L/16
n_class = 10
n_layer = 24
d_model = 1024
n_head = 16
mlp_hidden = 4096

vit_encoder = nn.Sequential(*[TransformerEncoder(d_model=d_model, head=n_head, d_ff_hid=mlp_hidden) for _ in range(n_layer)])
vit_l = ViT(img_size=224, patch_size=16, img_c=3, d_model=d_model, num_class=n_class, encoders=vit_encoder)

input_tensor = torch.ones((1, 3, 224, 224))

output_tensor = vit_l(input_tensor)
assert output_tensor.shape == (1, n_class), output_tensor.shape
# summary() prints the table itself; the trailing ';' stops Jupyter from echoing its return value as a second copy
summary(vit_l, input_size=(3, 224, 224), device='cpu');

Layer (type:depth-idx)                   Param #
├─PatchEmbed: 1-1                        --
|    └─Conv2d: 2-1                       787,456
├─Sequential: 1-2                        --
|    └─TransformerEncoder: 2-2           --
|    |    └─LayerNorm: 3-1               2,048
|    |    └─MultiHeadAtt: 3-2            4,195,328
|    |    └─LayerNorm: 3-3               2,048
|    |    └─FFN: 3-4                     8,393,728
|    └─TransformerEncoder: 2-3           --
|    |    └─LayerNorm: 3-5               2,048
|    |    └─MultiHeadAtt: 3-6            4,195,328
|    |    └─LayerNorm: 3-7               2,048
|    |    └─FFN: 3-8                     8,393,728
|    └─TransformerEncoder: 2-4           --
|    |    └─LayerNorm: 3-9               2,048
|    |    └─MultiHeadAtt: 3-10           4,195,328
|    |    └─LayerNorm: 3-11              2,048
|    |    └─FFN: 3-12                    8,393,728
|    └─TransformerEncoder: 2-5           --
|    |    └─LayerNorm: 3-13              2,048
|  

Layer (type:depth-idx)                   Param #
├─PatchEmbed: 1-1                        --
|    └─Conv2d: 2-1                       787,456
├─Sequential: 1-2                        --
|    └─TransformerEncoder: 2-2           --
|    |    └─LayerNorm: 3-1               2,048
|    |    └─MultiHeadAtt: 3-2            4,195,328
|    |    └─LayerNorm: 3-3               2,048
|    |    └─FFN: 3-4                     8,393,728
|    └─TransformerEncoder: 2-3           --
|    |    └─LayerNorm: 3-5               2,048
|    |    └─MultiHeadAtt: 3-6            4,195,328
|    |    └─LayerNorm: 3-7               2,048
|    |    └─FFN: 3-8                     8,393,728
|    └─TransformerEncoder: 2-4           --
|    |    └─LayerNorm: 3-9               2,048
|    |    └─MultiHeadAtt: 3-10           4,195,328
|    |    └─LayerNorm: 3-11              2,048
|    |    └─FFN: 3-12                    8,393,728
|    └─TransformerEncoder: 2-5           --
|    |    └─LayerNorm: 3-13              2,048
|  

In [115]:
# ViT variants -- ViT-H/16
n_class = 10
n_layer = 32
d_model = 1280
n_head = 16
mlp_hidden = 5120

vit_encoder = nn.Sequential(*[TransformerEncoder(d_model=d_model, head=n_head, d_ff_hid=mlp_hidden) for _ in range(n_layer)])
vit_h = ViT(img_size=224, patch_size=16, img_c=3, d_model=d_model, num_class=n_class, encoders=vit_encoder)

input_tensor = torch.ones((1, 3, 224, 224))

output_tensor = vit_h(input_tensor)
assert output_tensor.shape == (1, n_class), output_tensor.shape
# summary() prints the table itself; the trailing ';' stops Jupyter from echoing its return value as a second copy
summary(vit_h, input_size=(3, 224, 224), device='cpu');

Layer (type:depth-idx)                   Param #
├─PatchEmbed: 1-1                        --
|    └─Conv2d: 2-1                       984,320
├─Sequential: 1-2                        --
|    └─TransformerEncoder: 2-2           --
|    |    └─LayerNorm: 3-1               2,560
|    |    └─MultiHeadAtt: 3-2            6,554,880
|    |    └─LayerNorm: 3-3               2,560
|    |    └─FFN: 3-4                     13,113,600
|    └─TransformerEncoder: 2-3           --
|    |    └─LayerNorm: 3-5               2,560
|    |    └─MultiHeadAtt: 3-6            6,554,880
|    |    └─LayerNorm: 3-7               2,560
|    |    └─FFN: 3-8                     13,113,600
|    └─TransformerEncoder: 2-4           --
|    |    └─LayerNorm: 3-9               2,560
|    |    └─MultiHeadAtt: 3-10           6,554,880
|    |    └─LayerNorm: 3-11              2,560
|    |    └─FFN: 3-12                    13,113,600
|    └─TransformerEncoder: 2-5           --
|    |    └─LayerNorm: 3-13              2,560


Layer (type:depth-idx)                   Param #
├─PatchEmbed: 1-1                        --
|    └─Conv2d: 2-1                       984,320
├─Sequential: 1-2                        --
|    └─TransformerEncoder: 2-2           --
|    |    └─LayerNorm: 3-1               2,560
|    |    └─MultiHeadAtt: 3-2            6,554,880
|    |    └─LayerNorm: 3-3               2,560
|    |    └─FFN: 3-4                     13,113,600
|    └─TransformerEncoder: 2-3           --
|    |    └─LayerNorm: 3-5               2,560
|    |    └─MultiHeadAtt: 3-6            6,554,880
|    |    └─LayerNorm: 3-7               2,560
|    |    └─FFN: 3-8                     13,113,600
|    └─TransformerEncoder: 2-4           --
|    |    └─LayerNorm: 3-9               2,560
|    |    └─MultiHeadAtt: 3-10           6,554,880
|    |    └─LayerNorm: 3-11              2,560
|    |    └─FFN: 3-12                    13,113,600
|    └─TransformerEncoder: 2-5           --
|    |    └─LayerNorm: 3-13              2,560


## DeepViT

In [116]:
# source : https://github.com/zhoudaquan/dvit_repo/blob/master/models/layers.py
class ReAttention(nn.Module):
    """Re-attention module from DeepViT.

    Multi-head self-attention whose per-head attention maps are, when ``apply_transform``
    is True, mixed across heads by a learnable 1x1 convolution (a num_heads x num_heads
    transformation) followed by BatchNorm over the head axis.

    Args:
    - dim: token channel count C; expected to be divisible by ``num_heads``
    - num_heads: number of attention heads
    - qkv_bias: whether the fused q/k/v projection has a bias
    - qk_scale: optional override of the default head_dim ** -0.5 attention scale
    - attn_drop: dropout probability on the attention map
    - proj_drop: dropout probability on the output projection
    - expansion_ratio: width multiplier of the fused q/k/v projection; forward() splits
      it into exactly three parts (q, k, v), so only 3 is supported
    - apply_transform: enable the cross-head re-attention transform
    - transform_scale: if True, also multiply the transformed map by the attention scale

    Raises:
    - ValueError: if ``expansion_ratio`` is not 3
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., expansion_ratio=3, apply_transform=True, transform_scale=False):
        super().__init__()
        # forward() reshapes the projection into exactly 3 chunks; any other ratio can never run
        if expansion_ratio != 3:
            raise ValueError(f"expansion_ratio must be 3 (q, k, v), got {expansion_ratio}")
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.apply_transform = apply_transform
        
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        if apply_transform:
            # 1x1 conv over the head axis = learnable mixing of the num_heads attention maps
            self.reatten_matrix = nn.Conv2d(self.num_heads, self.num_heads, 1, 1)
            self.var_norm = nn.BatchNorm2d(self.num_heads)
            self.reatten_scale = self.scale if transform_scale else 1.0
        # Created after the optional transform layers, preserving the original creation order
        self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)
        
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, atten=None):
        """
        Args:
        - x: a B x N x C tensor
        - atten: unused by this implementation

        Returns:
        - (x, attn): the B x N x C output and the B x num_heads x N x N attention map
          (after the re-attention transform when ``apply_transform`` is True)
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if self.apply_transform:
            # Mix attention maps across heads, then normalise per head
            attn = self.var_norm(self.reatten_matrix(attn)) * self.reatten_scale
        attn_next = attn
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn_next

In [117]:
# Re-attention layer for 64-channel tokens (num_heads defaults to 8)
model = ReAttention(64)

In [118]:
# Random batch: 12 sequences of 196 tokens (e.g. a 14x14 patch grid) with 64 channels.
# ReAttention returns a tuple (output, attention map).
output = model(torch.randn(12, 196, 64))

In [119]:
# Shapes of the block output and of the (re-)attention map
print(*(t.shape for t in output))

torch.Size([12, 196, 64]) torch.Size([12, 8, 196, 196])


# CaiT

![](assets/cait.png)

In [120]:
class ClassAttention(MultiHeadAtt):
    """Class attention, the new module of CaiT.

    Only the class token is used as the query, so it attends over all tokens
    (class token + patch tokens) and yields an updated class token; the patch
    tokens themselves are not updated.
    
    This is the simplified version referenced from:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/cait.py#L74-L106
    """
    def forward(self, x):
        """
        Args:
        - x: a B x N x C tensor whose first token (index 0) is the class token
        
        Returns:
        - a B x 1 x C tensor, the attended class token
        
        Annotations:
        - B: batch size
        - N: number of token
        - C: number of channel
        """
        B, N, C = x.shape
        
        queries = self.W_q(x[:, :1]) # Only takes class token as query: B x 1 x head*d_head
        keys = self.W_k(x) # B x N x head*d_head
        values = self.W_v(x)
        
        queries = queries.reshape(B, 1, self.head, self.d_head).permute(0, 2, 1, 3) # B x head x 1 x d_head
        keys = keys.reshape(B, N, self.head, self.d_head).permute(0, 2, 1, 3) # B x head x N x d_head
        values = values.reshape(B, N, self.head, self.d_head).permute(0, 2, 1, 3)
        
        attn = (queries @ keys.transpose(-2, -1)) / self.d_head ** 0.5
        attn = F.softmax(attn, dim=-1) # B x head x 1 x N
        
        x = attn @ values # B x head x 1 x d_head
        x = x.transpose(1, 2) # B x 1 x head x d_head
        x = x.reshape(B, 1, C) # B x 1 x head*d_head - Remind: d_model = C = head*d_head
        
        x = self.W_o(x)
        return x

### LayerScale Architecture
![](assets/layer_scale.png)

In [121]:
class LayerScale(nn.Module):
    """LayerScale residual branch proposed by CaiT to make optimisation more stable.

    Computes ``x_res + gamma * agg_block(LayerNorm(x))`` where ``gamma`` is a learnable
    per-channel vector initialised to ``init_val``.

    Args:
    - n_channel: channel count C (size of gamma and of the LayerNorm)
    - agg_block: module applied to the normalised input (e.g. attention or MLP)
    - init_val: initial value of every entry of gamma
    """
    def __init__(self, n_channel, agg_block, init_val=1e-4):
        super().__init__()
        
        # One scale per channel, broadcast over the leading (batch/token) axes
        self.gamma = nn.Parameter(init_val * torch.ones(n_channel))
        self.layer_norm = nn.LayerNorm(n_channel)
        self.agg_block = agg_block
    
    def forward(self, x, x_res):
        """
        Args:
        - x: input of the branch, last axis of size n_channel
        - x_res: residual added to the scaled branch output; must broadcast with it

        Returns:
        - x_res + gamma * agg_block(LayerNorm(x))
        """
        return x_res + self.gamma * self.agg_block(self.layer_norm(x))

In [122]:
class CABlock(nn.Module):
    """Class-attention block of CaiT, used at the end of the network.

    Only the class token is updated: it attends over [class token, patch tokens]
    and then passes through an MLP, each step wrapped in a LayerScale residual.
    """
    def __init__(self, d_model, head, mlp_hidden):
        super().__init__()
        attention = ClassAttention(d_model, head)
        self.cls_attn = LayerScale(n_channel=d_model, agg_block=attention)
        feed_forward = FFN(d_model, mlp_hidden)
        self.mlp = LayerScale(n_channel=d_model, agg_block=feed_forward)

    def forward(self, x, x_cls):
        """
        Args:
        - x: patch tokens, B x N x C
        - x_cls: class token, B x 1 x C

        Returns:
        - the updated class token
        """
        # class token placed first, so it sits at token index 0
        tokens = torch.cat([x_cls, x], dim=1)
        refined = self.cls_attn(tokens, x_cls)
        return self.mlp(refined, refined)

class SABlock(nn.Module):
    """Self-attention block of CaiT, used at the beginning of the network.

    Applies multi-head self-attention and then an MLP to all tokens, each step
    wrapped in a LayerScale residual whose residual input is the step's own input.
    """
    def __init__(self, d_model, head, mlp_hidden):
        super().__init__()
        attention = MultiHeadAtt(d_model, head)
        self.attn = LayerScale(n_channel=d_model, agg_block=attention)
        feed_forward = FFN(d_model, mlp_hidden)
        self.mlp = LayerScale(n_channel=d_model, agg_block=feed_forward)

    def forward(self, x):
        """x: a B x N x C tensor; returns a tensor of the same shape."""
        for branch in (self.attn, self.mlp):
            x = branch(x, x)
        return x

In [123]:
class CaiT(nn.Module):
    """A skeleton of a typical CaiT model.

    Patch tokens (plus a learnable positional embedding; no class token) go through
    ``n_sa`` self-attention blocks. A learnable class token is then refined by ``n_ca``
    class-attention blocks, which read but never update the patch tokens, and the
    refined class token is mapped to class logits.
    """
    def __init__(self, img_size, patch_size, img_c, d_model, head, mlp_hidden, num_class, n_sa, n_ca):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size

        self.patch_embeder = PatchEmbed(img_size, patch_size, img_c, d_model)
        n_tokens = self.patch_embeder.num_patches

        # Learnable class token; positional vectors cover the patch tokens only
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_tokens, d_model))

        self.sa_blocks = nn.Sequential(*(SABlock(d_model, head, mlp_hidden) for _ in range(n_sa)))
        self.ca_blocks = nn.ModuleList(CABlock(d_model, head, mlp_hidden) for _ in range(n_ca))
        self.mlp_head = nn.Linear(d_model, num_class)
        self.cls_morm = nn.LayerNorm(num_class)

    def forward(self, x):
        """
        Args:
        - x: a B x C x H x W image tensor

        Returns:
        - a B x n_class tensor of class logits

        Annotations:
        - B: batch size
        - C: number of channel
        - H: height
        - W: width
        """
        batch = x.shape[0]
        patches = self.patch_embeder(x) + self.pos_embed
        patches = self.sa_blocks(patches)

        cls_token = self.cls_token.expand(batch, -1, -1) # B x 1 x d_model
        for ca_block in self.ca_blocks:
            cls_token = ca_block(patches, cls_token)

        logits = self.cls_morm(self.mlp_head(cls_token)) # B x 1 x n_class
        return logits.squeeze(dim=1) # B x n_class

In [124]:
# CaiT variants -- CaiT-XS36
n_class = 10
d_model = 288
n_head = 6 # equals d_model/48
mlp_hidden = 4 * d_model
n_sa = 36
n_ca = 2

cait_xs36 = CaiT(img_size=224, patch_size=16, img_c=3, num_class=n_class,
                d_model=d_model, head=n_head, mlp_hidden=mlp_hidden, n_sa=n_sa, n_ca=n_ca)

input_tensor = torch.ones((1, 3, 224, 224))

output_tensor = cait_xs36(input_tensor)
assert output_tensor.shape == (1, n_class), output_tensor.shape
# summary() prints the table itself; the trailing ';' stops Jupyter from echoing its return value as a second copy
summary(cait_xs36, input_size=(3, 224, 224), device='cpu');

Layer (type:depth-idx)                        Param #
├─PatchEmbed: 1-1                             --
|    └─Conv2d: 2-1                            221,472
├─Sequential: 1-2                             --
|    └─SABlock: 2-2                           --
|    |    └─LayerScale: 3-1                   332,928
|    |    └─LayerScale: 3-2                   665,856
|    └─SABlock: 2-3                           --
|    |    └─LayerScale: 3-3                   332,928
|    |    └─LayerScale: 3-4                   665,856
|    └─SABlock: 2-4                           --
|    |    └─LayerScale: 3-5                   332,928
|    |    └─LayerScale: 3-6                   665,856
|    └─SABlock: 2-5                           --
|    |    └─LayerScale: 3-7                   332,928
|    |    └─LayerScale: 3-8                   665,856
|    └─SABlock: 2-6                           --
|    |    └─LayerScale: 3-9                   332,928
|    |    └─LayerScale: 3-10                  665,856
|    └─SA

Layer (type:depth-idx)                        Param #
├─PatchEmbed: 1-1                             --
|    └─Conv2d: 2-1                            221,472
├─Sequential: 1-2                             --
|    └─SABlock: 2-2                           --
|    |    └─LayerScale: 3-1                   332,928
|    |    └─LayerScale: 3-2                   665,856
|    └─SABlock: 2-3                           --
|    |    └─LayerScale: 3-3                   332,928
|    |    └─LayerScale: 3-4                   665,856
|    └─SABlock: 2-4                           --
|    |    └─LayerScale: 3-5                   332,928
|    |    └─LayerScale: 3-6                   665,856
|    └─SABlock: 2-5                           --
|    |    └─LayerScale: 3-7                   332,928
|    |    └─LayerScale: 3-8                   665,856
|    └─SABlock: 2-6                           --
|    |    └─LayerScale: 3-9                   332,928
|    |    └─LayerScale: 3-10                  665,856
|    └─SA

In [125]:
# CaiT variants -- CaiT-S36
n_class = 10
d_model = 384
n_head = 8
mlp_hidden = 4 * d_model
n_sa = 36
n_ca = 2

cait_s36 = CaiT(img_size=224, patch_size=16, img_c=3, num_class=n_class,
                d_model=d_model, head=n_head, mlp_hidden=mlp_hidden, n_sa=n_sa, n_ca=n_ca)

input_tensor = torch.ones((1, 3, 224, 224))

output_tensor = cait_s36(input_tensor)
assert output_tensor.shape == (1, n_class), output_tensor.shape
# summary() prints the table itself; the trailing ';' stops Jupyter from echoing its return value as a second copy
summary(cait_s36, input_size=(3, 224, 224), device='cpu');

Layer (type:depth-idx)                        Param #
├─PatchEmbed: 1-1                             --
|    └─Conv2d: 2-1                            295,296
├─Sequential: 1-2                             --
|    └─SABlock: 2-2                           --
|    |    └─LayerScale: 3-1                   591,360
|    |    └─LayerScale: 3-2                   1,182,720
|    └─SABlock: 2-3                           --
|    |    └─LayerScale: 3-3                   591,360
|    |    └─LayerScale: 3-4                   1,182,720
|    └─SABlock: 2-4                           --
|    |    └─LayerScale: 3-5                   591,360
|    |    └─LayerScale: 3-6                   1,182,720
|    └─SABlock: 2-5                           --
|    |    └─LayerScale: 3-7                   591,360
|    |    └─LayerScale: 3-8                   1,182,720
|    └─SABlock: 2-6                           --
|    |    └─LayerScale: 3-9                   591,360
|    |    └─LayerScale: 3-10                  1,182,720

Layer (type:depth-idx)                        Param #
├─PatchEmbed: 1-1                             --
|    └─Conv2d: 2-1                            295,296
├─Sequential: 1-2                             --
|    └─SABlock: 2-2                           --
|    |    └─LayerScale: 3-1                   591,360
|    |    └─LayerScale: 3-2                   1,182,720
|    └─SABlock: 2-3                           --
|    |    └─LayerScale: 3-3                   591,360
|    |    └─LayerScale: 3-4                   1,182,720
|    └─SABlock: 2-4                           --
|    |    └─LayerScale: 3-5                   591,360
|    |    └─LayerScale: 3-6                   1,182,720
|    └─SABlock: 2-5                           --
|    |    └─LayerScale: 3-7                   591,360
|    |    └─LayerScale: 3-8                   1,182,720
|    └─SABlock: 2-6                           --
|    |    └─LayerScale: 3-9                   591,360
|    |    └─LayerScale: 3-10                  1,182,720

In [126]:
# CaiT variants -- CaiT-M36
n_class = 10
d_model = 768
n_head = 16 
mlp_hidden = 4 * d_model
n_sa = 36
n_ca = 2

cait_m36 = CaiT(img_size=224, patch_size=16, img_c=3, num_class=n_class,
                d_model=d_model, head=n_head, mlp_hidden=mlp_hidden, n_sa=n_sa, n_ca=n_ca)

input_tensor = torch.ones((1, 3, 224, 224))

output_tensor = cait_m36(input_tensor)
assert output_tensor.shape == (1, n_class), output_tensor.shape
# summary() prints the table itself; the trailing ';' stops Jupyter from echoing its return value as a second copy
summary(cait_m36, input_size=(3, 224, 224), device='cpu');

Layer (type:depth-idx)                        Param #
├─PatchEmbed: 1-1                             --
|    └─Conv2d: 2-1                            590,592
├─Sequential: 1-2                             --
|    └─SABlock: 2-2                           --
|    |    └─LayerScale: 3-1                   2,362,368
|    |    └─LayerScale: 3-2                   4,724,736
|    └─SABlock: 2-3                           --
|    |    └─LayerScale: 3-3                   2,362,368
|    |    └─LayerScale: 3-4                   4,724,736
|    └─SABlock: 2-4                           --
|    |    └─LayerScale: 3-5                   2,362,368
|    |    └─LayerScale: 3-6                   4,724,736
|    └─SABlock: 2-5                           --
|    |    └─LayerScale: 3-7                   2,362,368
|    |    └─LayerScale: 3-8                   4,724,736
|    └─SABlock: 2-6                           --
|    |    └─LayerScale: 3-9                   2,362,368
|    |    └─LayerScale: 3-10                 

Layer (type:depth-idx)                        Param #
├─PatchEmbed: 1-1                             --
|    └─Conv2d: 2-1                            590,592
├─Sequential: 1-2                             --
|    └─SABlock: 2-2                           --
|    |    └─LayerScale: 3-1                   2,362,368
|    |    └─LayerScale: 3-2                   4,724,736
|    └─SABlock: 2-3                           --
|    |    └─LayerScale: 3-3                   2,362,368
|    |    └─LayerScale: 3-4                   4,724,736
|    └─SABlock: 2-4                           --
|    |    └─LayerScale: 3-5                   2,362,368
|    |    └─LayerScale: 3-6                   4,724,736
|    └─SABlock: 2-5                           --
|    |    └─LayerScale: 3-7                   2,362,368
|    |    └─LayerScale: 3-8                   4,724,736
|    └─SABlock: 2-6                           --
|    |    └─LayerScale: 3-9                   2,362,368
|    |    └─LayerScale: 3-10                 

In [127]:
# CaiT variants -- CaiT-M48
n_class = 10
d_model = 768
n_head = 16 
mlp_hidden = 4 * d_model
n_sa = 48
n_ca = 2

cait_m48 = CaiT(img_size=224, patch_size=16, img_c=3, num_class=n_class,
                d_model=d_model, head=n_head, mlp_hidden=mlp_hidden, n_sa=n_sa, n_ca=n_ca)

input_tensor = torch.ones((1, 3, 224, 224))

output_tensor = cait_m48(input_tensor)
assert output_tensor.shape == (1, n_class), output_tensor.shape
# summary() prints the table itself; the trailing ';' stops Jupyter from echoing its return value as a second copy
summary(cait_m48, input_size=(3, 224, 224), device='cpu');

Layer (type:depth-idx)                        Param #
├─PatchEmbed: 1-1                             --
|    └─Conv2d: 2-1                            590,592
├─Sequential: 1-2                             --
|    └─SABlock: 2-2                           --
|    |    └─LayerScale: 3-1                   2,362,368
|    |    └─LayerScale: 3-2                   4,724,736
|    └─SABlock: 2-3                           --
|    |    └─LayerScale: 3-3                   2,362,368
|    |    └─LayerScale: 3-4                   4,724,736
|    └─SABlock: 2-4                           --
|    |    └─LayerScale: 3-5                   2,362,368
|    |    └─LayerScale: 3-6                   4,724,736
|    └─SABlock: 2-5                           --
|    |    └─LayerScale: 3-7                   2,362,368
|    |    └─LayerScale: 3-8                   4,724,736
|    └─SABlock: 2-6                           --
|    |    └─LayerScale: 3-9                   2,362,368
|    |    └─LayerScale: 3-10                 

Layer (type:depth-idx)                        Param #
├─PatchEmbed: 1-1                             --
|    └─Conv2d: 2-1                            590,592
├─Sequential: 1-2                             --
|    └─SABlock: 2-2                           --
|    |    └─LayerScale: 3-1                   2,362,368
|    |    └─LayerScale: 3-2                   4,724,736
|    └─SABlock: 2-3                           --
|    |    └─LayerScale: 3-3                   2,362,368
|    |    └─LayerScale: 3-4                   4,724,736
|    └─SABlock: 2-4                           --
|    |    └─LayerScale: 3-5                   2,362,368
|    |    └─LayerScale: 3-6                   4,724,736
|    └─SABlock: 2-5                           --
|    |    └─LayerScale: 3-7                   2,362,368
|    |    └─LayerScale: 3-8                   4,724,736
|    └─SABlock: 2-6                           --
|    |    └─LayerScale: 3-9                   2,362,368
|    |    └─LayerScale: 3-10                 

## CeiT

In [128]:
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

class Residual(nn.Module):
    """Wrap ``fn`` with an identity skip connection: returns ``fn(x, **kwargs) + x``."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        branch = self.fn(x, **kwargs)
        return branch + x

class PreNorm(nn.Module):
    """Pre-norm wrapper: LayerNorm over the last ``dim`` channels, then ``fn``."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)

class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output channel size.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # Linears are created in the same order as the Sequential lists them so
        # parameter initialisation consumes the RNG identically.
        expand = nn.Linear(dim, hidden_dim)
        contract = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(
            expand,
            nn.GELU(),
            nn.Dropout(dropout),
            contract,
            nn.Dropout(dropout),
        )

    def forward(self, x):
        out = self.net(x)
        return out

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: channel size of the input tokens.
        heads: number of attention heads.
        dim_head: channels per head; q, k and v are each ``heads * dim_head`` wide.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        # A single head that is already ``dim`` wide needs no output projection.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        # One fused projection produces q, k and v side by side.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Attend every token to every token: B x N x dim -> B x N x dim."""
        # The unpack only enforces a 3-D (batch, tokens, channels) input.
        _batch, _tokens, _channels = x.shape
        q, k, v = [
            rearrange(chunk, 'b n (h d) -> b h n d', h = self.heads)
            for chunk in self.to_qkv(x).chunk(3, dim = -1)
        ]
        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        weights = scores.softmax(dim=-1)
        mixed = einsum('b h i j, b h j d -> b h i d', weights, v)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)

#### LeFF (Locally-enhanced Feed-Forward)

![](assets/leff.png)

In [129]:


class LeFF(nn.Module):
    """Locally-enhanced Feed-Forward (LeFF) block from CeiT.

    Patch tokens are expanded by a linear layer and laid back out on their
    square spatial grid. They are then mixed with a depth-wise convolution,
    flattened back into tokens and projected down to ``dim`` channels. Each
    stage is followed by BatchNorm and GELU.

    Args:
        dim: channel size of the input/output tokens.
        scale: expansion ratio of the hidden width (``dim * scale``).
        depth_kernel: odd kernel size of the depth-wise convolution. Padding is
            ``depth_kernel // 2``, so the spatial grid size is preserved.

    Raises:
        ValueError: if ``depth_kernel`` is even (the grid size would change).
    """

    def __init__(self, dim = 192, scale = 4, depth_kernel = 3):
        super().__init__()
        if depth_kernel % 2 == 0:
            raise ValueError(
                f"depth_kernel must be odd to preserve the token grid, got {depth_kernel}")

        scale_dim = dim * scale
        # Sequential indices of the parameterised layers (0, 2 here; 0, 1 in
        # depth_conv) match the previous layout, so existing state_dicts load.
        self.up_proj = nn.Sequential(nn.Linear(dim, scale_dim),
                                     Rearrange('b n c -> b c n'),
                                     nn.BatchNorm1d(scale_dim),
                                     nn.GELU(),
                                     )

        self.depth_conv = nn.Sequential(nn.Conv2d(scale_dim, scale_dim, kernel_size=depth_kernel,
                                                  padding=depth_kernel // 2, groups=scale_dim, bias=False),
                                        nn.BatchNorm2d(scale_dim),
                                        nn.GELU(),
                                        Rearrange('b c h w -> b (h w) c'),
                                        )

        self.down_proj = nn.Sequential(nn.Linear(scale_dim, dim),
                                       Rearrange('b n c -> b c n'),
                                       nn.BatchNorm1d(dim),
                                       nn.GELU(),
                                       Rearrange('b c n -> b n c')
                                       )

    def forward(self, x):
        """Locally mix patch tokens.

        Args:
            x: B x N x dim patch tokens (no class token). N must be a perfect
               square; the tokens are read row-major as a sqrt(N) x sqrt(N) grid.

        Returns:
            B x N x dim tensor.

        Raises:
            ValueError: if N is not a perfect square.
        """
        num_tokens = x.shape[1]
        side = int(round(num_tokens ** 0.5))
        if side * side != num_tokens:
            raise ValueError(
                f"LeFF expects a square grid of patch tokens, got N={num_tokens}")

        x = self.up_proj(x)                                  # B x scale_dim x N
        x = x.reshape(x.shape[0], x.shape[1], side, side)    # B x scale_dim x side x side
        x = self.depth_conv(x)                               # B x N x scale_dim
        x = self.down_proj(x)                                # B x N x dim
        return x
    
    
class TransformerLeFF(nn.Module):
    """CeiT encoder: ``depth`` blocks of residual pre-norm attention and residual pre-norm LeFF.

    Args:
        dim: token channel size.
        depth: number of blocks.
        heads: attention heads per block.
        dim_head: channels per attention head.
        scale: LeFF hidden-width expansion ratio.
        depth_kernel: kernel size of LeFF's depth-wise convolution.
        dropout: dropout probability inside the attention output projection.
    """
    def __init__(self, dim, depth, heads, dim_head, scale = 4, depth_kernel = 3, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, LeFF(dim, scale, depth_kernel)))
            ]))

    def forward(self, x):
        """Run every block.

        Args:
            x: B x (1 + N) x dim tokens. Index 0 on the token axis is treated as
               the class token; the remaining N patch tokens go through LeFF.

        Returns:
            ``(x, cls_history)``: the final B x (1 + N) x dim tokens, and a
            B x depth x dim tensor with the class token taken after each
            block's attention.
        """
        cls_per_layer = []
        for attn, leff in self.layers:
            x = attn(x)
            cls_tokens = x[:, 0]
            cls_per_layer.append(cls_tokens)
            # Only the patch tokens have a spatial position, so LeFF (and its
            # residual) sees x[:, 1:]; the class token is re-attached afterwards.
            patches = leff(x[:, 1:])
            x = torch.cat((cls_tokens.unsqueeze(1), patches), dim=1)
        # depth x B x dim -> B x depth x dim
        return x, torch.stack(cls_per_layer).transpose(0, 1)

In [130]:
model = TransformerLeFF(dim=192, depth=2, heads=4, dim_head=64)

In [131]:
# 12 random samples of 197 tokens x 192 channels: token 0 is treated as the
# class token, the remaining 196 = 14 x 14 patch tokens are fed to LeFF.
output = model(torch.randn(12, 197, 192))

In [132]:
# output[1] stacks the class token after each block: B x depth x dim
output[1].shape

torch.Size([12, 2, 192])

![](assets/lca.png)

In [133]:
class LCAttention(nn.Module):
    """Multi-head attention in which only the last token acts as the query.

    Keys and values come from every token, but the query is taken from the last
    position only, so the output holds a single token per sample.

    Args:
        dim: channel size of the input tokens.
        heads: number of attention heads.
        dim_head: channels per head.
        dropout: dropout probability applied after the output projection.
    """
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        # A single head that is already ``dim`` wide needs no output projection.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """B x L x dim tokens -> B x 1 x dim (the attended last token)."""
        # The unpack only enforces a 3-D (batch, tokens, channels) input.
        _batch, _tokens, _channels = x.shape
        q, k, v = [
            rearrange(chunk, 'b n (h d) -> b h n d', h = self.heads)
            for chunk in self.to_qkv(x).chunk(3, dim = -1)
        ]
        # Keep only the last (L-th) token as query, with a length-1 token axis.
        q = q[:, :, -1:, :]

        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        weights = scores.softmax(dim=-1)
        mixed = einsum('b h i j, b h j d -> b h i d', weights, v)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)
    
class LCA(nn.Module):
    """Layer-wise class-token attention head from CeiT (a single block).

    The input is a sequence of class tokens, B x L x dim (e.g. one per encoder
    block). The last one attends over all of them and is then refined by an
    MLP, giving B x 1 x dim.

    The residual connections are written out in ``forward`` rather than by
    wrapping with ``Residual``. The attention output has one token while its
    input has L, so only the last input token is added back.

    Args:
        dim: token channel size.
        heads: attention heads.
        dim_head: channels per head.
        mlp_dim: hidden width of the feed-forward network.
        dropout: dropout probability for attention output and MLP.
    """
    def __init__(self, dim, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        block = nn.ModuleList([
            PreNorm(dim, LCAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
            PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)),
        ])
        self.layers = nn.ModuleList([block])

    def forward(self, x):
        for attn, ff in self.layers:
            # residual onto the last (query) token only
            x = attn(x) + x[:, -1:]
            # x now holds a single token, so x[:, -1:] is x itself
            x = x[:, -1:] + ff(x)
        return x
        

In [134]:
model = LCA(dim=192, heads=4, dim_head=64, mlp_dim=10)

In [135]:
# Feed the per-block class tokens (B x depth x dim) to LCA and drop the
# length-1 token axis -> B x dim.
# NOTE(review): this rebinds `output`, so re-running this cell (or the next)
# without re-running the TransformerLeFF cells fails; a distinct name would be
# safer.
output = model(output[1])[:, 0]

In [136]:
# B x dim: the refined class token
output.shape

torch.Size([12, 192])

---

# Appendix - VOLO

![](assets/volo_attn.png)

In [137]:
class OutlookAttention(nn.Module):
    """Outlook attention, the token mixer of VOLO (simplified, stride fixed to 1).

    Reference implementation:
    https://github.com/sail-sg/volo/blob/main/models/volo.py#L45-L100

    Each position predicts, per head, a k^2 x k^2 weight matrix straight from
    its own features (no query/key product). The matrix re-weights the values
    of its k x k neighbourhood, and the overlapping windows are summed back
    with ``F.fold``.

    Args:
        d_model: channel size of the channels-last input/output.
        head: number of heads; channels are split into ``d_model // head`` per head.
        kernel_size: side length k of the local window (padding is k // 2).
    """
    def __init__(self, d_model, head, kernel_size=3):
        super().__init__()
        self.head_dim = d_model // head
        self.head = head
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2

        self.W_v = nn.Linear(d_model, d_model, bias=False)
        # head * k^2 * k^2 logits per position
        self.attn = nn.Linear(d_model, head * kernel_size**4, bias=False)

        self.proj = nn.Linear(d_model, d_model)
        self.unfold = nn.Unfold(kernel_size=kernel_size, padding=self.padding)

    def forward(self, x):
        """x: B x H x W x C (channels-last) -> B x H x W x C."""
        B, H, W, C = x.shape
        n_tokens = H * W
        window = self.kernel_size**2

        # Per-position window-to-window attention weights.
        logits = self.attn(x).reshape(B, n_tokens, self.head, window, window)
        logits = logits.permute(0, 2, 1, 3, 4) / self.head_dim**0.5
        weights = F.softmax(logits, dim=-1)                      # B x head x N x k^2 x k^2

        # Gather each position's k x k neighbourhood of values.
        values = self.W_v(x).permute(0, 3, 1, 2)                 # B x C x H x W
        patches = self.unfold(values)                             # B x C*k^2 x N
        patches = patches.reshape(B, self.head, -1, window, n_tokens)  # B x head x head_dim x k^2 x N
        patches = patches.permute(0, 1, 4, 3, 2)                  # B x head x N x k^2 x head_dim

        mixed = weights @ patches                                 # B x head x N x k^2 x head_dim
        mixed = mixed.permute(0, 1, 4, 3, 2).reshape(B, C * window, n_tokens)  # B x C*k^2 x N
        # Sum overlapping windows back onto the H x W grid.
        summed = F.fold(mixed, output_size=(H, W),
                        kernel_size=self.kernel_size, padding=self.padding)  # B x C x H x W

        return self.proj(summed.permute(0, 2, 3, 1))

In [138]:
# Test
module = OutlookAttention(d_model=256, head=8, kernel_size=3)
input_tensor = torch.ones((1, 16, 16, 256))

output_tensor = module(input_tensor)
assert output_tensor.shape == (1, 16, 16, 256)

In [139]:
class Outlooker(nn.Module):
    """VOLO outlooker block, used at the beginning of the network.

    Two pre-norm residual branches on a channels-last map:
    ``x + OutlookAttention(LN(x))`` followed by ``x + FFN(LN(x))``.

    Args:
        d_model: channel size.
        d_mlp_hidden: hidden width passed to ``FFN`` (defined elsewhere in this notebook).
        head: number of outlook-attention heads.
        kernel_size: window size of the outlook attention.
    """
    def __init__(self, d_model, d_mlp_hidden, head, kernel_size):
        super().__init__()

        # token-mixing branch
        self.layer_norm_attn = nn.LayerNorm(d_model)
        self.outlook_attn = OutlookAttention(d_model, head, kernel_size)

        # channel-mixing branch
        self.layer_norm_mlp = nn.LayerNorm(d_model)
        self.mlp = FFN(d_model, d_mlp_hidden)

    def forward(self, x):
        attn_out = self.outlook_attn(self.layer_norm_attn(x))
        x = x + attn_out
        mlp_out = self.mlp(self.layer_norm_mlp(x))
        return x + mlp_out