In [None]:
import torch
from PIL import Image
from transformers import SiglipVisionConfig, SiglipVisionModel, SiglipImageProcessor
from utils import from_path_to_vision_encoder, build_projector

In [None]:
# Base architecture: supplies the SigLIP vision config, initial weights and the
# image preprocessor used in the cells below.
vision_base_arch = "google/siglip2-so400m-patch16-384"

# Training checkpoint (stage-2 SFT according to its path) whose vision-tower and
# projector weights are loaded on top of the base architecture.
# NOTE(review): the directory name says "so500m" while the base arch is "so400m";
# confirm they describe the same tower (load_state_dict raises on shape mismatch).
weight_dir = "./checkpoints/stage1-full-train/stage2-sft-siglip2-so500m-Qwen2.5-1.5B-llava"

# Later cells index this object with "vision_tower" and "mm_projector".
vision_weight = from_path_to_vision_encoder(weight_dir)


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base SigLIP vision tower, a standalone copy of its config (not passed to the
# model; kept for inspection) and the matching image preprocessor.
config = SiglipVisionConfig.from_pretrained(vision_base_arch)
vit = SiglipVisionModel.from_pretrained(vision_base_arch)
processor = SiglipImageProcessor.from_pretrained(vision_base_arch)

# Drop the final encoder layer, presumably so that hidden_states[-1] yields the
# penultimate-layer features a LLaVA-style projector is trained on (confirm
# against the training config). Re-sync vit.config so it does not advertise a
# layer that no longer exists (matters if the model is saved or inspected).
del vit.vision_model.encoder.layers[-1:]
vit.config.num_hidden_layers = len(vit.vision_model.encoder.layers)

# Later cells read hidden_states rather than the pooled output, so the pooling
# head is swapped for a parameter-free Identity. The checkpoint must therefore
# contain no head weights for the strict load below to succeed.
vit.vision_model.head = torch.nn.Identity()

# Assign instead of ending the cell with a bare expression, which would make
# Jupyter print the entire module repr.
vit = vit.eval().to(device)

In [4]:
# ViT weight loading
# Keys in this sub-dict carry a leading "vision_tower." prefix that the bare
# SiglipVisionModel does not use. removeprefix strips it only when it is an
# actual prefix; str.replace(..., 1) would also cut a match in mid-key.
vit_sd = vision_weight["vision_tower"]
vit_sd = {k.removeprefix("vision_tower."): v for k, v in vit_sd.items()}

# strict=True raises on any missing or unexpected key, so both lists are always
# empty when this call returns; they are kept only for interactive inspection.
missing, unexpected = vit.load_state_dict(vit_sd, strict=True)

In [None]:
# Projector weight loading
proj_sd = vision_weight["mm_projector"]
projector = build_projector(proj_sd, device)
# Match the ViT cell: nn.Module starts in training mode, so switch to eval mode
# so that any dropout/normalization layers the projector may contain behave
# deterministically. Harmless if build_projector already did this (its body is
# not visible here).
projector.eval()
# strict=True: raises on any key mismatch, so both lists are empty on return.
missing_p, unexpected_p = projector.load_state_dict(proj_sd, strict=True)

In [7]:
IMAGE_PATH = "dog.jpg"

# The context manager closes the file handle deterministically. convert()
# returns a new in-memory image, so img stays valid after the block exits.
with Image.open(IMAGE_PATH) as raw_img:
    img = raw_img.convert("RGB")
batch = processor.preprocess(img, return_tensors="pt")
# Follow the model's parameter dtype instead of hard-coding float32, so the
# input still matches if the tower is ever loaded in half/bfloat16 precision.
pixel_values = batch["pixel_values"].to(device=device, dtype=vit.dtype)

# Running through vit only
with torch.no_grad():
    out = vit(pixel_values=pixel_values, output_hidden_states=True)
    # Output of the last remaining encoder layer (the original last layer was
    # deleted when the model was built).
    feats = out.hidden_states[-1]

print("feature shape (batch, num_patchs, vision_hidden_dim):", tuple(feats.shape))

feature shape (batch, num_patchs, vision_hidden_dim): (1, 576, 1152)


In [8]:
# End-to-end path: vision-tower patch features first, then the multimodal
# projector that maps them to the LLM hidden size.
with torch.no_grad():
    vit_out = vit(output_hidden_states=True, pixel_values=pixel_values)
    vit_feats = vit_out.hidden_states[-1]  # (batch, num_patches, vision_hidden_dim)
    proj_feats = projector(vit_feats)      # (batch, num_patches, llm_hidden_dim)

print("vit feature shape:", tuple(vit_feats.shape))
print("projected feature shape:", tuple(proj_feats.shape))

vit feature shape: (1, 576, 1152)
projected feature shape: (1, 576, 1536)
