# query transformer according to blip

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )
        self.num_patches = (img_size // patch_size) ** 2

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x

class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim=768, num_heads=8, ff_hidden_dim=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        attn_output, _ = self.self_attn(x, x, x)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)

        ff_output = self.ff(x)
        x = x + self.dropout2(ff_output)
        x = self.norm2(x)
        return x

class QFormer(nn.Module):
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 embed_dim=768,
                 depth=6,
                 num_heads=8,
                 ff_hidden_dim=2048):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, ff_hidden_dim)
            for _ in range(depth)
        ])
        self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches, embed_dim))

    def forward(self, x):
        x = self.patch_embed(x)
        x = x + self.pos_embed

        for layer in self.encoder_layers:
            x = layer(x)

        return x

In [13]:
model = QFormer(img_size=224, patch_size=16, depth=4)
dummy_input = torch.randn(2, 3, 224, 224)
output = model(dummy_input)

In [14]:
from torchsummary import summary
summary(model, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
        PatchEmbed-2             [-1, 196, 768]               0
MultiheadAttention-3  [[-1, 196, 768], [-1, 196, 196]]               0
           Dropout-4             [-1, 196, 768]               0
         LayerNorm-5             [-1, 196, 768]           1,536
            Linear-6            [-1, 196, 2048]       1,574,912
              GELU-7            [-1, 196, 2048]               0
           Dropout-8            [-1, 196, 2048]               0
            Linear-9             [-1, 196, 768]       1,573,632
          Dropout-10             [-1, 196, 768]               0
          Dropout-11             [-1, 196, 768]               0
        LayerNorm-12             [-1, 196, 768]           1,536
TransformerEncoderLayer-13             [-1, 196, 768]               0
MultiheadAttention-14  [[-

# multiple QFormers

In [15]:
import torch
import torch.nn as nn

class MultiQFormer(nn.Module):
    def __init__(self,
                 num_encoders=3,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 embed_dim=768,
                 depth=6,
                 num_heads=8,
                 ff_hidden_dim=2048,
                 output_dim=1024):
        super().__init__()

        self.encoders = nn.ModuleList([
            QFormer(
                img_size=img_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=embed_dim,
                depth=depth,
                num_heads=num_heads,
                ff_hidden_dim=ff_hidden_dim
            )
            for _ in range(num_encoders)
        ])

        self.output_proj = nn.Linear(num_encoders * embed_dim, output_dim)

    def forward(self, x):
        feats= []
        for i, img in enumerate(x):
            feat = self.encoders[i](img)
            pooled = feat.mean(dim=1)
            feats.append(pooled)

        concat = torch.cat(feats, dim=-1)

        return self.output_proj(concat)

In [34]:
model = MultiQFormer(num_encoders=4, output_dim=1024)
images = [torch.randn(1, 3, 224, 224) for _ in range(4)]
output = model(images)

this doesn't run because forward does not get regular input. fix it if needed.

In [37]:
from torchsummary import summary
summary(model, images)

TypeError: rand(): argument 'size' failed to unpack the object at pos 2 with error "type must be tuple of ints,but got Tensor"

# llava integration

In [2]:
%%shell
git clone https://github.com/haotian-liu/LLaVA.git
cd LLaVA
pip install -e .

Cloning into 'LLaVA'...
remote: Enumerating objects: 2297, done.[K
remote: Total 2297 (delta 0), reused 0 (delta 0), pack-reused 2297 (from 1)[K
Receiving objects: 100% (2297/2297), 13.71 MiB | 13.74 MiB/s, done.
Resolving deltas: 100% (1405/1405), done.
Obtaining file:///content/LLaVA
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Collecting torch==2.1.2 (from llava==1.2.2.post1)
  Downloading torch-2.1.2-cp311-cp311-manylinux1_x86_64.whl.metadata (25 kB)
Collecting torchvision==0.16.2 (from llava==1.2.2.post1)
  Downloading torchvision-0.16.2-cp311-cp311-manylinux1_x86_64.whl.metadata (6.6 kB)
Collecting transformers==4.37.2 (from llava==1.2.2.post1)
  Downloading transformers-4.37.2-py3-none-any.whl.metadata (129 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m 



## build custom tower

In [16]:
import torch
import torch.nn as nn

class MultiQFormerTower(nn.Module):
    def __init__(self, img_size=224, embed_dim=768, depth=6, num_heads=8, ff_hidden_dim=2048, output_dim=768, aggregation='mean'):
        super().__init__()
        self.encoder = MultiQFormer(
            num_encoders=1,
            img_size=img_size,
            patch_size=16,
            in_channels=3,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            ff_hidden_dim=ff_hidden_dim,
            output_dim=output_dim
        )
        assert aggregation in ['mean', 'max'], "Unsupported aggregation"
        self.aggregation = aggregation
        self.output_dim = output_dim

    def forward(self, images):
        """
        images: list of [B, 3, H, W] tensors, length = variable
        returns: Tensor of shape [B, output_dim]
        """
        if not isinstance(images, list):
            raise ValueError("Expected a list of images")

        encoded = [self.encoder([img]) for img in images]
        encoded = torch.stack(encoded, dim=1)

        if self.aggregation == 'mean':
            agg = encoded.mean(dim=1)
        elif self.aggregation == 'max':
            agg = encoded.max(dim=1).values

        return agg

    @property
    def hidden_size(self):
        return self.output_dim

    @property
    def config(self):
        class Dummy:
            hidden_size = self.output_dim
        return Dummy()

In [17]:
tower = MultiQFormerTower()

In [18]:
tower.eval()

batch_size = 1
img_size = 224

imgs = [torch.randn(batch_size, 3, img_size, img_size) for _ in range(2)]

out = tower(imgs)
print(out.shape)

torch.Size([1, 768])


## add tower to llava

In [14]:
import torch
from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM
from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("liuhaotian/llava-v1.5-7b")

model = LlavaLlamaForCausalLM.from_pretrained(
    "liuhaotian/llava-v1.5-7b",
    device_map="auto",
    torch_dtype=torch.float16,
    offload_folder="./offload"
)

You are using a model of type llava to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [11]:
model.model.vision_tower

CLIPVisionTower()

In [19]:
device="cuda"
tower.to(device)

MultiQFormerTower(
  (encoder): MultiQFormer(
    (encoders): ModuleList(
      (0): QFormer(
        (patch_embed): PatchEmbed(
          (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        )
        (encoder_layers): ModuleList(
          (0-5): 6 x TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout1): Dropout(p=0.1, inplace=False)
            (dropout2): Dropout(p=0.1, inplace=False)
            (ff): Sequential(
              (0): Linear(in_features=768, out_features=2048, bias=True)
              (1): GELU(approximate='none')
              (2): Dropout(p=0.1, inplace=False)
              (3): Linear(in_features=2048, out_features=768, bias=True)
              (

In [21]:
model.model.vision_tower=tower

In [22]:
model.model.vision_tower

MultiQFormerTower(
  (encoder): MultiQFormer(
    (encoders): ModuleList(
      (0): QFormer(
        (patch_embed): PatchEmbed(
          (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        )
        (encoder_layers): ModuleList(
          (0-5): 6 x TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout1): Dropout(p=0.1, inplace=False)
            (dropout2): Dropout(p=0.1, inplace=False)
            (ff): Sequential(
              (0): Linear(in_features=768, out_features=2048, bias=True)
              (1): GELU(approximate='none')
              (2): Dropout(p=0.1, inplace=False)
              (3): Linear(in_features=2048, out_features=768, bias=True)
              (