In [103]:
import requests
from PIL import Image
import matplotlib.pyplot as plt
import torch
import numpy as np

from transformers import SamModel, SamProcessor
from transformers import CLIPVisionModel

import math

In [2]:
device = "cpu"

In [4]:
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14").to(device)

In [6]:
print(f"SAM: {sum(p.numel() for p in sam_model.parameters()):,d}")
print(f"CLIP: {sum(p.numel() for p in clip_model.parameters()):,d}")

SAM: 93,735,472
CLIP: 303,179,776


In [7]:
# dummy input

sample = torch.rand(2, 3, 1024, 1024) # [B, C, H, W]

In [9]:
class conv_block(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.layer_1 = torch.nn.Conv2d(256, 512, stride = 2, kernel_size = 2)
        self.layer_2 = torch.nn.Conv2d(512, 1024, stride = 2, kernel_size = 2)
    
    def forward(self, x):
        return self.layer_2(self.layer_1(x))
    
Conv = conv_block()

In [8]:
with torch.no_grad():
    local_features = sam_model.vision_encoder(sample)

torch.Size([2, 256, 64, 64])

In [10]:
sam_output = local_features.last_hidden_state
sam_output.shape

torch.Size([2, 256, 64, 64])

In [55]:
sam_output_conv = Conv(sam_output)
sam_features = sam_output_conv.flatten(2,3).transpose(1,2)
sam_features.shape

torch.Size([2, 256, 1024])

In [77]:
type(sam_features)

torch.Tensor

In [66]:
batch_size = sample.shape[0]
class_embeds_clip = clip_model.vision_model.embeddings.class_embedding.expand(batch_size, 1, -1)
class_embeds_clip.shape

torch.Size([2, 1, 1024])

In [92]:
embeddings = torch.cat((class_embeds_clip, sam_features), dim=1)
embeddings.shape

torch.Size([2, 257, 1024])

In [None]:
# with torch.no_grad():
#     clip_output = clip_model(sample, sam_output_conv)

# clip_output = local_features.last_hidden_state
# clip_output.shape

In [68]:
clip_model.vision_model.embeddings.position_embedding

Embedding(257, 1024)

In [94]:
embeddings = clip_model.vision_model.embeddings.position_embedding(
                clip_model.vision_model.embeddings.position_ids[:, :embeddings.shape[1]]
            )

embeddings.shape


torch.Size([1, 257, 1024])

In [None]:
pooled_output = clip_model.vision_model.post_layernorm(
                    clip_model.vision_model.encoder(
                        clip_model.vision_model.pre_layrnorm(embeddings)
                    ).last_hidde
                    n_state
                )
pooled_output.shape

torch.Size([1, 257, 1024])

In [None]:
class CLIP_modified(torch.nn.Module):
    def __init__(self, clip: torch.nn.ModuleDict, sam_features: torch.Tensor):
        super().__init__()
        assert len(sam_features.shape) == 3, "do the flattening and transpose"

        self.batch_size = sam_features.shape[0]
        self.sam_features = sam_features
        self.clip = clip


    def forward(self):
        class_embeds_clip = self.clip.vision_model.embeddings.class_embedding.expand(batch_size, 1, -1) # [B, 1, 1024]
        embeddings = torch.cat((class_embeds_clip, self.sam_features), dim=1)   # [B, 257, 1024]
        
        # 1
        pos_embed = clip_model.vision_model.embeddings.position_embedding(
                clip_model.vision_model.embeddings.position_ids[:, :embeddings.shape[1]]
            )                                                                   # [B, 257, 1024]
        
        # 2, 3, 4
        embeddings = embeddings + pos_embed  # ‚Üê FIX: ADD instead of replace
        # [B, 257, 1024]

        pooled_output = clip_model.vision_model.post_layernorm(
                    clip_model.vision_model.encoder(
                        clip_model.vision_model.pre_layrnorm(embeddings)
                    ).last_hidden_state
                )
        
        return pooled_output


In [80]:
with torch.no_grad():
    glabal_features = sam_model.vision_encoder(sample)

In [97]:
sam_output = glabal_features.last_hidden_state
sam_output_conv = Conv(sam_output)
sam_features = sam_output_conv.flatten(2,3).transpose(1,2)
sam_features.shape

torch.Size([2, 256, 1024])

In [99]:
vision_model = CLIP_modified(clip = clip_model, sam_features = sam_features)
with torch.no_grad():
    clip_fearures = vision_model()
clip_fearures.shape

torch.Size([2, 257, 1024])

In [101]:
global_features_total = torch.cat((clip_fearures[:, 1:], sam_features), dim = -1)
global_features_total.shape

torch.Size([2, 256, 2048])

In [102]:
projector = torch.nn.Linear(2048, 1280)  # TODO MlpProjector
projected_features = projector(global_features_total)
projected_features.shape

torch.Size([2, 256, 1280])

In [112]:
batch_size = sample.shape[0]
batch_size

2

In [105]:
# 2. Add spatial separators (newline tokens between rows
h = w = int(math.sqrt(256))  # h=16, w=16
features_2d = projected_features.view(batch_size, h, w, 1280)
features_2d.shape

torch.Size([2, 16, 16, 1280])

In [109]:
# add new line
image_newline = torch.nn.Parameter(torch.randn(1280))  # Learnable token
features_with_newlines = torch.cat([
    features_2d, 
    image_newline[None, None, None, :].expand(batch_size, h, 1, 1280)
], dim=2)  # [2, 16, 17, 1280]

features_with_newlines.shape

torch.Size([2, 16, 17, 1280])

In [113]:
vision_tokens = features_with_newlines.reshape(batch_size, -1, 1280)
vision_tokens.shape

torch.Size([2, 272, 1280])

In [116]:
# 3. Add separator token at the end
view_separator = torch.nn.Parameter(torch.randn(1280))
vision_tokens = torch.cat([
    vision_tokens,
    view_separator[None, None, :].expand(batch_size, 1, 1280)
], dim=1)  # [2, 273, 1280]

vision_tokens.shape

torch.Size([2, 275, 1280])

## all together

In [None]:
class CLIP_modified(torch.nn.Module):
    def __init__(self, clip: torch.nn.ModuleDict):
        super().__init__()
        
        self.clip = clip

    def forward(self, sam_features: torch.Tensor):
        assert len(sam_features.shape) == 3, "do the flattening and transpose"

        self.batch_size = sam_features.shape[0]
        self.sam_features = sam_features
        
        class_embeds_clip = self.clip.vision_model.embeddings.class_embedding.expand(batch_size, 1, -1) # [B, 1, 1024]
        embeddings = torch.cat((class_embeds_clip, self.sam_features), dim=1)   # [B, 257, 1024]
        
        # 1
        pos_embed = clip_model.vision_model.embeddings.position_embedding(
                clip_model.vision_model.embeddings.position_ids[:, :embeddings.shape[1]]
            )                                                                   # [B, 257, 1024]
        
        # 2, 3, 4
        embeddings = embeddings + pos_embed  # ADD instead of replace
        # [B, 257, 1024]

        pooled_output = clip_model.vision_model.post_layernorm(
                    clip_model.vision_model.encoder(
                        clip_model.vision_model.pre_layrnorm(embeddings)
                    ).last_hidden_state
                )
        
        return pooled_output


In [None]:
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14").to(device)

vision_model = CLIP_modified(clip = clip_model)

projector = torch.nn.Linear(2048, 1280, bias = True)  # TODO MlpProjector

In [121]:
# dummy input

sample = torch.rand(2, 3, 1024, 1024) # [B, C, H, W]
batch_size = sample.shape[0]

In [None]:
with torch.no_grad():
    sam_features = sam_model.vision_encoder(sample)

sam_output = sam_features.last_hidden_state
sam_output_conv = Conv(sam_output)
sam_features = sam_output_conv.flatten(2,3).transpose(1,2)


with torch.no_grad():
    clip_fearures = vision_model(sam_features = sam_features)

features_total = torch.cat((clip_fearures[:, 1:], sam_features), dim = -1)

# 1. Project to language model dimension
projected_features = projector(features_total)                             # [B, 256, 1280]

# 2. Add spatial separators (newline tokens between rows
h = w = int(math.sqrt(256))  # h=16, w=16
features_2d = projected_features.view(batch_size, h, w, 1280)              # [B, 16, 16, 1280]

# add new line
image_newline = torch.nn.Parameter(torch.randn(1280))  # Learnable token
features_with_newlines = torch.cat([
    features_2d, 
    image_newline[None, None, None, :].expand(batch_size, h, 1, 1280)
], dim=2)                                                                   # [B, 16, 17, 1280]

# flatten back
vision_tokens = features_with_newlines.reshape(batch_size, -1, 1280)        # [B, 272, 1280]

# 3. Add separator token at the end
view_separator = torch.nn.Parameter(torch.randn(1280))
vision_tokens = torch.cat([
    vision_tokens,
    view_separator[None, None, :].expand(batch_size, 1, 1280)
], dim=1)                                                                    # [B, 273, 1280]

vision_tokens.shape

torch.Size([2, 273, 1280])