In [17]:
from ViTHelper import MasterEncoder, BasicEncoder, SelfAttention, AttentionHead

import torch
import torchsummary
master_encoder = MasterEncoder(max_seq_length=17, embedding_size=8, how_many_basic_encoders=1, num_atten_heads=2)
# test = torch.rand(1, 17, 16*16*3)
encoder = BasicEncoder(max_seq_length=17, embedding_size=2*2*3, num_atten_heads=1)
self_attn = SelfAttention(max_seq_length=17, embedding_size=2*2*3, num_atten_heads=3)
attn_head = AttentionHead(max_seq_length=17, qkv_size=2*2*3)

creating model...
creating basic encoder...
creating self attention layer...
creating basic encoder...
creating self attention layer...
creating self attention layer...


In [20]:
# print(torchsummary.summary(attn_head, (17, 2*2*3)))
# print(torchsummary.summary(self_attn, (17, 2*2*3)))
# print(torchsummary.summary(encoder, (17, 2*2*3)))
print(torchsummary.summary(master_encoder, ( 17, 8)))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         LayerNorm-1                [-1, 17, 8]              16
            Linear-2                   [-1, 68]           4,692
            Linear-3                   [-1, 68]           4,692
            Linear-4                   [-1, 68]           4,692
           Softmax-5               [-1, 17, 17]               0
     AttentionHead-6                [-1, 17, 4]               0
            Linear-7                   [-1, 68]           4,692
            Linear-8                   [-1, 68]           4,692
            Linear-9                   [-1, 68]           4,692
          Softmax-10               [-1, 17, 17]               0
    AttentionHead-11                [-1, 17, 4]               0
    SelfAttention-12                [-1, 17, 8]               0
        LayerNorm-13                [-1, 17, 8]              16
           Linear-14                  [

In [21]:
master_encoder = MasterEncoder(max_seq_length=17, embedding_size=8, how_many_basic_encoders=1, num_atten_heads=2)
test = torch.rand(1, 17, 8)
print(master_encoder(test).shape)


creating model...
creating basic encoder...
creating self attention layer...
torch.Size([1, 17, 8])


In [16]:
import torch 
import torch.nn as nn
from einops import rearrange

class ViTEmbeddings(nn.Module):
    def __init__(self, img_size, patch_size, num_classes, embedding_size):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.embedding_size = embedding_size
        self.num_patches = (img_size // patch_size) ** 2
        print(self.num_patches)
        self.patch_embedding = nn.Conv2d(3, embedding_size, kernel_size=patch_size, stride=patch_size)
        self.positional_embedding = nn.Parameter(torch.zeros(1, self.num_patches + 1, embedding_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embedding_size))

    def forward(self, x):
        x = self.patch_embedding(x)
        # flatten  using einops
        x = rearrange(x, 'b c h w -> b (h w) c')
        print(self.cls_token.shape, self.cls_token.expand(x.size(0), -1, -1).shape, x.shape)
        x = torch.cat((self.cls_token.expand(x.size(0), -1, -1), x), dim=1)
        x = x + self.positional_embedding
        return x
    
vit_embeddings = ViTEmbeddings(img_size=64, patch_size=16, num_classes=5, embedding_size=8)
test = torch.rand(2, 3, 64, 64)
# test.backward()
out = vit_embeddings(test)
out.shape

16
torch.Size([1, 1, 8]) torch.Size([2, 1, 8]) torch.Size([2, 16, 8])
torch.Size([2, 17, 8])


torch.Size([2, 17, 8])

In [None]:
import torch
import torch.nn as nn
from einops import rearrange

# from self_attention_cv import TransformerEncoder


class ViT(nn.Module):
    def __init__(self, *,
                 img_dim,
                 in_channels=3,
                 patch_dim=16,
                 num_classes=10,
                 dim=512,
                 blocks=6,
                 heads=4,
                 dim_linear_block=1024,
                 dim_head=None,
                 dropout=0, transformer=None, classification=True):
        """
        Args:
            img_dim: the spatial image size
            in_channels: number of img channels
            patch_dim: desired patch dim
            num_classes: classification task classes
            dim: the linear layer's dim to project the patches for MHSA
            blocks: number of transformer blocks
            heads: number of heads
            dim_linear_block: inner dim of the transformer linear block
            dim_head: dim head in case you want to define it. defaults to dim/heads
            dropout: for pos emb and transformer
            transformer: in case you want to provide another transformer implementation
            classification: creates an extra CLS token
        """
        super().__init__()
        assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
        self.p = patch_dim
        tokens = (img_dim // patch_dim) ** 2
        self.token_dim = in_channels * (patch_dim ** 2)
        self.dim = dim
        self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
        self.project_patches = nn.Linear(self.token_dim, dim)

        self.emb_dropout = nn.Dropout(dropout)
        if self.classification:
            self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
            self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
            self.mlp_head = nn.Linear(dim, num_classes)
        else:
            self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

        # if transformer is None:
        #     self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
        #                                           dim_head=self.dim_head,
        #                                           dim_linear_block=dim_linear_block,
        #                                           dropout=dropout)
        # else:
        #     self.transformer = transformer

    def expand_cls_to_batch(self, batch):
        """
        Args:
            batch: batch size
        Returns: cls token expanded to the batch size
        """
        return self.cls_token.expand([batch, -1, -1])

    def forward(self, img, mask=None):
        batch_size = img.shape[0]
        img_patches = rearrange(
            img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                                patch_x=self.p, patch_y=self.p)
        # project patches with linear layer + add pos emb
        img_patches = self.project_patches(img_patches)

        if self.classification:
            img_patches = torch.cat(
                (self.expand_cls_to_batch(batch_size), img_patches), dim=1)

        patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)

        # feed patch_embeddings and output of transformer. shape: [batch, tokens, dim]
        y = self.transformer(patch_embeddings, mask)

        if self.classification:
            # we index only the cls token for classification. nlp tricks :P
            return self.mlp_head(y[:, 0, :])
        else:
            return y
