In [1]:
import os
import sys

# Make the parent of the notebook's working directory importable so the local
# `vision_transformer` package resolves. Guard the append: without it, every
# re-run of this cell would add the same entry to sys.path again.
PROJECT_ROOT = os.path.join(os.getcwd(), "..")
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)


import torch
import torch.nn as nn
import torchinfo

from vision_transformer.EncoderBlock import EncoderBlock
from vision_transformer.Encoder import Encoder

from vision_transformer.MultiHeadAttn import MultiHeadAttention
from vision_transformer.FeedForwardBlock import FeedForwardBlock

In [2]:
# Seed torch's global RNG so that Restart & Run All rebuilds an identical model.
# The module constructors below presumably initialise weights randomly (and
# dropout is stochastic in any later forward pass). TODO confirm for the
# project modules.
SEED = 42
torch.manual_seed(SEED)

# Model hyper-parameters. 768-wide embeddings, 12 blocks, 12 heads and a 4x
# feed-forward expansion match the ViT-Base sizes.
D_MODEL = 768
ENCODER_BLOCKS_NUMBER = 12
DROPOUT_RATE = 0.3

HEADS_NUMBER = 12
FEED_FORWARD_EXPANSION = 4  # feed-forward hidden width as a multiple of D_MODEL
FEED_FORWARD_HIDDEN_SIZE = D_MODEL * FEED_FORWARD_EXPANSION

In [3]:
# Build ENCODER_BLOCKS_NUMBER encoder blocks. Each iteration creates new
# attention and feed-forward instances, so no two layers share a sub-module.
#
# Blocks are appended straight into an nn.ModuleList, so `encoder_blocks` is
# never rebound from a plain list to a different type. The loop index is
# `_block_idx` rather than `_`: in IPython, assigning to `_` shadows the
# last-output variable and stops `_` / `__` / `___` from updating for the
# rest of the session.
encoder_blocks = nn.ModuleList()

for _block_idx in range(ENCODER_BLOCKS_NUMBER):
    multihead_attention = MultiHeadAttention(D_MODEL, HEADS_NUMBER, DROPOUT_RATE)
    feed_forward_block = FeedForwardBlock(D_MODEL, FEED_FORWARD_HIDDEN_SIZE, DROPOUT_RATE)

    encoder_block = EncoderBlock(multihead_attention, feed_forward_block, DROPOUT_RATE)

    encoder_blocks.append(encoder_block)

In [4]:
# Assemble the top-level encoder from the ModuleList of encoder blocks built above.
encoder = Encoder(encoder_blocks)

In [7]:
# Dummy input used only to trace layer shapes and parameter counts. The input
# is assumed to be (batch, tokens, D_MODEL), batch-first. TODO confirm against
# the Encoder / EncoderBlock implementations.
SUMMARY_BATCH_SIZE = 15
SUMMARY_NUM_TOKENS = 256

torchinfo.summary(
    encoder,
    input_size=(SUMMARY_BATCH_SIZE, SUMMARY_NUM_TOKENS, D_MODEL),
    device="cpu",
    depth=2,  # show the Encoder, its ModuleList, and each EncoderBlock
)

Layer (type:depth-idx)                        Output Shape              Param #
Encoder                                       [15, 256, 768]            --
├─ModuleList: 1-1                             --                        --
│    └─EncoderBlock: 2-1                      [15, 256, 768]            7,084,804
│    └─EncoderBlock: 2-2                      [15, 256, 768]            7,084,804
│    └─EncoderBlock: 2-3                      [15, 256, 768]            7,084,804
│    └─EncoderBlock: 2-4                      [15, 256, 768]            7,084,804
│    └─EncoderBlock: 2-5                      [15, 256, 768]            7,084,804
│    └─EncoderBlock: 2-6                      [15, 256, 768]            7,084,804
│    └─EncoderBlock: 2-7                      [15, 256, 768]            7,084,804
│    └─EncoderBlock: 2-8                      [15, 256, 768]            7,084,804
│    └─EncoderBlock: 2-9                      [15, 256, 768]            7,084,804
│    └─EncoderBlock: 2-10       