In [3]:
import torch
import torch.nn as nn
from einops import rearrange, repeat, reduce, 


In [41]:
## Image 를 Patch Size로 자른다.
class image_embedding(nn.Module):
    def __init__(self, in_channels = 3, img_size = 224, patch_size = 16,emb_dim = 16*16*3 ):
        super().__init__()

        self.rearrange = rearrange(torch.zeros((1,1,1,1)),'b c (num_w p1)(num_h p2) -> b (num_w num_h)(p1 p2 c)', p1 = patch_size, p2 = patch_size)
        self.linear = nn.Linear(in_channels+patch_size*patch_size, emb_dim)

        self.cls_token = nn.Parameter(torch.randn(1,1,emb_dim))

        n_patchs = img_size * img_size // patch_size **2
        self.positions = nn.Parameter(torch.randn(n_patchs + 1, emb_dim))
    
    def forward(self, x):
        batch, channel, width, height  = x.shape
        # print("Before rearranfe x shape:", x.shape )
        x = self.rearrange(tensor = x)
        # print("After rearranfe x shape:", x.shape )
        x = self.linear(x)
        # print("cls_token shape:", self.cls_token.shape)
        c = repeat(self.cls_token, "() n d -> b n d", b = batch)
        x = torch.cat((c,x), 1)
        # print("add cls_token shape:", x.shape)
        # print("positions shape:", self.positions.shape)
        x = torch.add(x, self.positions)
        print('last shape', x.shape)

        return x

In [42]:
## Transformer Encorder Part
class mutil_head_attention(nn.Module):
    def __init__(self, emb_dim:int = 16*16*3, num_heads:int = 8, dropout_ratio:float = 0.2, verbose = False, **kwargs):
        super(mutil_head_attention,self).__init__()
        self.v = verbose
        self.emb_dim  = emb_dim
        self.num_heads = num_heads
        self.scaling = (self.emb_dim // self.num_heads) ** (-0.5)

        self.value = nn.Linear(emb_dim, emb_dim)
        self.key = nn.Linear(emb_dim, emb_dim)
        self.query = nn.Linear(emb_dim, emb_dim)
        self.att_drop = nn.Dropout(dropout_ratio)

        self.linear = nn.Linear(emb_dim, emb_dim)

    def forward(self,x):
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        if self.v : print(Q.size(), K.size(), V.size())
        # q=k=v = path_size * 2 + 1 & h*d = emd_dim
        Q = rearrange(Q, 'b q (h d) -> b h q d', h= self.num_heads)
        K = rearrange(K, 'b k (h d) -> b h d k', h = self.num_heads)
        V = rearrange(V, 'b V (h d) -> b h v d', h = self.num_heads)
        if self.v : print(Q.size(), K.size(), V.size())

        weight = torch.matmul(Q,K)
        weight = weight * self.scaling

        if self.v : print(weight.size())

        attention = torch.softmax(weight, dim = -1)
        attention = self.att_drop(attention)
        if self.v : print(weight.size())

        contex = torch.matmul(attention,V)
        contex = rearrange(contex, 'b h q d -> b q (h d)')
        if self.v: print(contex.size())

        x = self.linear(x)
        return x, attention

In [43]:
## MLP Block 
class mlp_block(nn.Module):
    def __init__(self, emb_dim:int = 16*16*3, forward_dim:int = 4, dropout_ratio:float = 0.2, **kwargs):
        super(mlp_block, self).__init__()
        self.linear_1 =  nn.Linear(emb_dim, forward_dim * emb_dim)
        self.dropout = nn.Dropout(dropout_ratio)
        self.linear_2 = nn.Linear(emb_dim * forward_dim, emb_dim)

    def foreard(self, x):
        x = self.linear_1(x)
        x = nn.ReLU(x)
        x = self.dropout(x)
        x = self.linear_2(x)

        return x

In [44]:
## Encoder Block
class encoder_block(nn.Module):
    def __init__(self, emb_dim:int = 16*16*3, num_heads:int = 8, forward_dim:int = 4, dropout_ratio:float = 0.2):
        super(encoder_block, self).__init__()

        self.norm_1 = nn.LayerNorm(emb_dim)
        self.norm_1_G = nn.GELU()
        self.mha = mutil_head_attention(emb_dim, num_heads, dropout_ratio)

        self.norm_2 = nn.LayerNorm(emb_dim)
        self.norm_2_G = nn.GELU()
        self.mlp = mlp_block(emb_dim, forward_dim, dropout_ratio)
        self.residual_dropout = nn.Dropout(dropout_ratio)

    def forward(self, x):
        x2 = self.norm_1(x)
        x2 = self.norm_1_G(x)
        x2, attention = self.mha(x2)
        x = torch.add(x2, x)

        x2 = self.norm_2(x)
        x2 = self.norm_2_G(x)
        x2 = self.mlp(x2)
        x = torch.add(x2,x)

        return x, attention

In [45]:
class vision_transformer(nn.Module):
    def __init__(self, in_channel:int = 3, img_size:int = 224, patch_size:int = 16, emb_dim:int = 16*16*3,
                 n_enc_layers:int = 15, num_heads:int= 4, forward_dim:int=4, dropout_ratio:float = 0.2, n_classes:int = 1000):
        super(vision_transformer, self).__init__()

        self.image_embedding = image_embedding(in_channel, img_size, patch_size, emb_dim)
        encoder_module = [ encoder_block(emb_dim, num_heads, forward_dim, dropout_ratio) for _ in range(n_enc_layers)]
        self.encoder_module = nn.ModuleList(encoder_module)

        self.reduce_layer = reduce('b n e -> b e', reduction='mean')
        self.nomalization = nn.LayerNorm(emb_dim)
        self.classification_head = nn.Linear(emb_dim, n_classes)

    def forward(self, x):
        x = self.image_embedding(x)
        attentions = [block(x)[1] for block in self.encoder_module]
        # print("before reduce x size:", x)
        x = self.reduce_layer(x)
        x = self.nomalization(x)
        # print(x.shape)
        x = self.classification_head(x)(8,1000)
        return x 


In [46]:
from torchsummary import summary
model = vision_transformer()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
summary(model, (8,3,224,224))

EinopsError:  Error while processing rearrange-reduction pattern "b c (num_w p1)(num_h p2) -> b (num_w num_h)(p1 p2 c)".
 Input tensor shape: torch.Size([1, 1, 1, 1]). Additional info: {'p1': 16, 'p2': 16}.
 Shape mismatch, can't divide axis of length 1 in chunks of 16