In [1]:
# Codeblock 1
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
from torchinfo import summary

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Codeblock 2
BATCH_SIZE   = 1
IMAGE_SIZE   = 384     #(1)
IN_CHANNELS  = 3

PATCH_SIZE   = 16      #(2)
EMBED_DIM    = 768     #(3)
NUM_HEADS    = 12      #(4)
NUM_LAYERS   = 12      #(5)
FFN_SIZE     = EMBED_DIM * 4    #(6)

NUM_PATCHES  = (IMAGE_SIZE//PATCH_SIZE) ** 2    #(7)

NUM_CLASSES  = 1000    #(8)

In [5]:
# Codeblock 3
class Patcher(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=IN_CHANNELS,    #(1)
                              out_channels=EMBED_DIM, 
                              kernel_size=PATCH_SIZE,     #(2)
                              stride=PATCH_SIZE)          #(3)

        self.flatten = nn.Flatten(start_dim=2)            #(4)

    def forward(self, x):
        #print(f'original\t: {x.size()}')

        x = self.conv(x)        #(5)
        #print(f'after conv\t: {x.size()}')

        x = self.flatten(x)     #(6)
        #print(f'after flatten\t: {x.size()}')

        x = x.permute(0, 2, 1)  #(7)
        #print(f'after permute\t: {x.size()}')

        return x

In [4]:
# Codeblock 4
patcher = Patcher()
x = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)

x = patcher(x)

original	: torch.Size([1, 3, 384, 384])
after conv	: torch.Size([1, 768, 24, 24])
after flatten	: torch.Size([1, 768, 576])
after permute	: torch.Size([1, 576, 768])


In [8]:
# Codeblock 5
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.norm_0 = nn.LayerNorm(EMBED_DIM)    #(1)

        self.multihead_attention = nn.MultiheadAttention(EMBED_DIM,    #(2)
                                                         num_heads=NUM_HEADS, 
                                                         batch_first=True)

        self.norm_1 = nn.LayerNorm(EMBED_DIM)    #(3)

        self.ffn = nn.Sequential(                #(4)
            nn.Linear(in_features=EMBED_DIM, out_features=FFN_SIZE),
            nn.GELU(), 
            nn.Linear(in_features=FFN_SIZE, out_features=EMBED_DIM),
        )

    def forward(self, x):

        residual = x
        #print(f'residual dim\t: {residual.size()}')

        x = self.norm_0(x)
        #print(f'after norm\t: {x.size()}')

        x = self.multihead_attention(x, x, x)[0]
        #print(f'after attention\t: {x.size()}')

        x = x + residual
        #print(f'after addition\t: {x.size()}')

        residual = x
        #print(f'residual dim\t: {residual.size()}')

        x = self.norm_1(x)
        #print(f'after norm\t: {x.size()}')

        x = self.ffn(x)
        #print(f'after ffn\t: {x.size()}')

        x = x + residual
        #print(f'after addition\t: {x.size()}')

        return x

In [7]:
# Codeblock 6
encoder = Encoder()
x = torch.randn(BATCH_SIZE, NUM_PATCHES, EMBED_DIM)

x = encoder(x)

residual dim	: torch.Size([1, 576, 768])
after norm	: torch.Size([1, 576, 768])
after attention	: torch.Size([1, 576, 768])
after addition	: torch.Size([1, 576, 768])
residual dim	: torch.Size([1, 576, 768])
after norm	: torch.Size([1, 576, 768])
after ffn	: torch.Size([1, 576, 768])
after addition	: torch.Size([1, 576, 768])


In [11]:
# Codeblock 7a
class DeiT(nn.Module):
    def __init__(self):
        super().__init__()

        self.patcher = Patcher()    #(1)
        
        self.class_token = nn.Parameter(torch.zeros(BATCH_SIZE, 1, EMBED_DIM))  #(2)
        self.dist_token  = nn.Parameter(torch.zeros(BATCH_SIZE, 1, EMBED_DIM))  #(3)
        
        trunc_normal_(self.class_token, std=.02)    #(4)
        trunc_normal_(self.dist_token, std=.02)     #(5)

        self.pos_embedding = nn.Parameter(torch.zeros(BATCH_SIZE, NUM_PATCHES+2, EMBED_DIM))  #(6)
        trunc_normal_(self.pos_embedding, std=.02)  #(7)
        
        self.encoders = nn.ModuleList([Encoder() for _ in range(NUM_LAYERS)])  #(8)
        
        self.norm_out = nn.LayerNorm(EMBED_DIM)     #(9)

        self.class_head = nn.Linear(in_features=EMBED_DIM, out_features=NUM_CLASSES)  #(10)
        self.dist_head  = nn.Linear(in_features=EMBED_DIM, out_features=NUM_CLASSES)  #(11)
        
# Codeblock 7b
    def forward(self, x):
        #print(f'original\t\t: {x.size()}')
        
        x = self.patcher(x)           #(1)
        #print(f'after patcher\t\t: {x.size()}')
        
        x = torch.cat([self.class_token, self.dist_token, x], dim=1)  #(2)
        #print(f'after concat\t\t: {x.size()}')
        
        x = x + self.pos_embedding    #(3)
        #print(f'after pos embed\t\t: {x.size()}')
        
        for i, encoder in enumerate(self.encoders):
            x = encoder(x)            #(4)
            #print(f"after encoder #{i}\t: {x.size()}")

        x = self.norm_out(x)          #(5)
        #print(f'after norm\t\t: {x.size()}')
        
        class_out = x[:, 0]           #(6)
        #print(f'class_out\t\t: {class_out.size()}')
        
        dist_out  = x[:, 1]           #(7)
        #print(f'dist_out\t\t: {dist_out.size()}')
        
        class_out = self.class_head(class_out)    #(8)
        #print(f'after class_head\t: {class_out.size()}')
        
        dist_out  = self.dist_head(dist_out)       #(9)
        #print(f'after dist_head\t\t: {class_out.size()}')
        
        return class_out, dist_out

In [10]:
# Codeblock 8
deit = DeiT()
x = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)

class_out, dist_out = deit(x)

original		: torch.Size([1, 3, 384, 384])
after patcher		: torch.Size([1, 576, 768])
after concat		: torch.Size([1, 578, 768])
after pos embed		: torch.Size([1, 578, 768])
after encoder #0	: torch.Size([1, 578, 768])
after encoder #1	: torch.Size([1, 578, 768])
after encoder #2	: torch.Size([1, 578, 768])
after encoder #3	: torch.Size([1, 578, 768])
after encoder #4	: torch.Size([1, 578, 768])
after encoder #5	: torch.Size([1, 578, 768])
after encoder #6	: torch.Size([1, 578, 768])
after encoder #7	: torch.Size([1, 578, 768])
after encoder #8	: torch.Size([1, 578, 768])
after encoder #9	: torch.Size([1, 578, 768])
after encoder #10	: torch.Size([1, 578, 768])
after encoder #11	: torch.Size([1, 578, 768])
after norm		: torch.Size([1, 578, 768])
class_out		: torch.Size([1, 768])
dist_out		: torch.Size([1, 768])
after class_head	: torch.Size([1, 1000])
after dist_head		: torch.Size([1, 1000])


In [12]:
# Codeblock 9
summary(deit, input_size=(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE))

original		: torch.Size([1, 3, 384, 384])
after patcher		: torch.Size([1, 576, 768])
after concat		: torch.Size([1, 578, 768])
after pos embed		: torch.Size([1, 578, 768])
after encoder #0	: torch.Size([1, 578, 768])
after encoder #1	: torch.Size([1, 578, 768])
after encoder #2	: torch.Size([1, 578, 768])
after encoder #3	: torch.Size([1, 578, 768])
after encoder #4	: torch.Size([1, 578, 768])
after encoder #5	: torch.Size([1, 578, 768])
after encoder #6	: torch.Size([1, 578, 768])
after encoder #7	: torch.Size([1, 578, 768])
after encoder #8	: torch.Size([1, 578, 768])
after encoder #9	: torch.Size([1, 578, 768])
after encoder #10	: torch.Size([1, 578, 768])
after encoder #11	: torch.Size([1, 578, 768])
after norm		: torch.Size([1, 578, 768])
class_out		: torch.Size([1, 768])
dist_out		: torch.Size([1, 768])
after class_head	: torch.Size([1, 1000])
after dist_head		: torch.Size([1, 1000])


Layer (type:depth-idx)                   Output Shape              Param #
DeiT                                     [1, 1000]                 445,440
├─Patcher: 1-1                           [1, 576, 768]             --
│    └─Conv2d: 2-1                       [1, 768, 24, 24]          590,592
│    └─Flatten: 2-2                      [1, 768, 576]             --
├─ModuleList: 1-2                        --                        --
│    └─Encoder: 2-3                      [1, 578, 768]             --
│    │    └─LayerNorm: 3-1               [1, 578, 768]             1,536
│    │    └─MultiheadAttention: 3-2      [1, 578, 768]             2,362,368
│    │    └─LayerNorm: 3-3               [1, 578, 768]             1,536
│    │    └─Sequential: 3-4              [1, 578, 768]             4,722,432
│    └─Encoder: 2-4                      [1, 578, 768]             --
│    │    └─LayerNorm: 3-5               [1, 578, 768]             1,536
│    │    └─MultiheadAttention: 3-6      [1, 578, 76