# query transformer according to blip

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )
        self.num_patches = (img_size // patch_size) ** 2

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x

class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim=768, num_heads=8, ff_hidden_dim=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        attn_output, _ = self.self_attn(x, x, x)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)

        ff_output = self.ff(x)
        x = x + self.dropout2(ff_output)
        x = self.norm2(x)
        return x

class QFormer(nn.Module):
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 embed_dim=768,
                 depth=6,
                 num_heads=8,
                 ff_hidden_dim=2048):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, ff_hidden_dim)
            for _ in range(depth)
        ])
        self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches, embed_dim))

    def forward(self, x):
        x = self.patch_embed(x)
        x = x + self.pos_embed

        for layer in self.encoder_layers:
            x = layer(x)

        return x

In [13]:
model = QFormer(img_size=224, patch_size=16, depth=4)
dummy_input = torch.randn(2, 3, 224, 224)
output = model(dummy_input)

In [14]:
from torchsummary import summary
summary(model, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
        PatchEmbed-2             [-1, 196, 768]               0
MultiheadAttention-3  [[-1, 196, 768], [-1, 196, 196]]               0
           Dropout-4             [-1, 196, 768]               0
         LayerNorm-5             [-1, 196, 768]           1,536
            Linear-6            [-1, 196, 2048]       1,574,912
              GELU-7            [-1, 196, 2048]               0
           Dropout-8            [-1, 196, 2048]               0
            Linear-9             [-1, 196, 768]       1,573,632
          Dropout-10             [-1, 196, 768]               0
          Dropout-11             [-1, 196, 768]               0
        LayerNorm-12             [-1, 196, 768]           1,536
TransformerEncoderLayer-13             [-1, 196, 768]               0
MultiheadAttention-14  [[-