In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from PIL import Image
import glob

# Step 1: BLIP-2 Q-Former tokens.
# For every PNG under BRIDGE_DATA_DIR, run the BLIP-2 vision encoder + Q-Former and
# mean-pool the query outputs into one vector per image. The stacked tensor is saved to
# mask_token_single.pt and is the *input* side of the linear bridge trained later.

BRIDGE_DATA_DIR = "/home/s2behappy4/data/gyuhyeong/code/bridge_data"  # TODO: make configurable
BLIP2_CHECKPOINT = "Salesforce/blip2-flan-t5-xl"
BLIP2_BATCH_SIZE = 8

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Distinct names so the LLaVA cell's `model`/`processor` can't be confused with these.
blip2_processor = Blip2Processor.from_pretrained(BLIP2_CHECKPOINT)
blip2_model = Blip2ForConditionalGeneration.from_pretrained(BLIP2_CHECKPOINT).to(device)
blip2_model.eval()


def extract_qformer_tokens(paths, model, processor, device, batch_size=8):
    """Return mean-pooled BLIP-2 Q-Former outputs, one row per image path.

    Args:
        paths: Image file paths; rows of the result follow this order.
        model: A ``Blip2ForConditionalGeneration``; only ``vision_model``,
            ``query_tokens`` and ``qformer`` are used.
        processor: The matching ``Blip2Processor`` (used for images only).
        device: Device holding ``model``; pixel values are moved there.
        batch_size: Number of images per forward pass.

    Returns:
        CPU tensor with one mean-pooled Q-Former vector per path.

    Raises:
        ValueError: If ``paths`` is empty (``torch.cat`` would otherwise fail
            with an opaque error).
    """
    if not paths:
        raise ValueError("no images found to encode")

    chunks = []
    with torch.no_grad():
        for start in range(0, len(paths), batch_size):
            images = []
            for path in paths[start : start + batch_size]:
                with Image.open(path) as im:
                    images.append(im.convert("RGB"))
            pixel_values = processor(images=images, return_tensors="pt").pixel_values.to(device)

            img_embeds = model.vision_model(pixel_values=pixel_values).last_hidden_state
            # All-ones mask: every image position is visible to the Q-Former cross-attention.
            img_mask = torch.ones(img_embeds.size()[:-1], device=device, dtype=torch.long)
            queries = model.query_tokens.expand(img_embeds.size(0), -1, -1)
            qf_out = model.qformer(
                query_embeds=queries,
                encoder_hidden_states=img_embeds,
                encoder_attention_mask=img_mask,
                return_dict=True,
            )
            # Average over the query tokens -> one vector per image.
            chunks.append(qf_out.last_hidden_state.mean(dim=1).cpu())
    return torch.cat(chunks, dim=0)


overlay_paths = sorted(glob.glob(f"{BRIDGE_DATA_DIR}/**/*.png", recursive=True))
print(f"Found {len(overlay_paths)} overlay images")
# NOTE(review): the bridge-training cell globs this exact pattern for its "orig" images, so
# inputs and targets are computed from the same files — confirm that is intended.

all_tokens = extract_qformer_tokens(
    overlay_paths, blip2_model, blip2_processor, device, batch_size=BLIP2_BATCH_SIZE
)
torch.save(all_tokens, f"{BRIDGE_DATA_DIR}/mask_token_single.pt")
print("✅ mask_token_single.pt saved")

# Release the BLIP-2 weights before the much larger LLaVA model is loaded.
del blip2_model, blip2_processor
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
import glob
from PIL import Image
import numpy as np

# Step 2: load LLaVA-OneVision (bfloat16).
# The bridge-training cell that follows uses only its vision tower, as a frozen feature extractor.

LLAVA_CHECKPOINT = "llava-hf/llava-onevision-qwen2-7b-ov-hf"

processor = AutoProcessor.from_pretrained(
    LLAVA_CHECKPOINT,
    trust_remote_code=True,
)

# device_map="auto" hands weight placement to accelerate (possibly sharded across GPUs or
# offloaded). A trailing .to(device) would fight that placement: it collapses a multi-GPU
# split onto one device and is refused when weights are offloaded, so it is omitted.
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    LLAVA_CHECKPOINT,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",
).eval()

# Later cells move inputs to `device`; point it at wherever the vision tower actually landed.
device = next(model.vision_tower.parameters()).device

In [8]:
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, SiglipImageProcessor
from PIL import Image
import glob

# Step 3: train the linear bridge.
#   inputs : BLIP-2 Q-Former tokens saved by the first cell (mask_token_single.pt)
#   targets: mean-pooled LLaVA-OneVision vision-tower outputs for the same image list
# A single nn.Linear is fit with MSE and its state_dict is saved.

BRIDGE_DATA_DIR = "/home/s2behappy4/data/gyuhyeong/code/bridge_data"  # TODO: make configurable
SEED = 0
NUM_EPOCHS = 5
LEARNING_RATE = 1e-3
# NOTE(review): images are resized to this size and the tower's position embeddings are
# interpolated; confirm against model.vision_tower.config.image_size if native-resolution
# features are wanted.
IMAGE_SIZE = 224
batch_size = 16

torch.manual_seed(SEED)  # bridge init + shuffling were previously unseeded

vision_tower = model.vision_tower
# Use the tower's own placement/dtype rather than a global `device` from another cell.
tower_param = next(vision_tower.parameters())
vision_device, vision_dtype = tower_param.device, tower_param.dtype

sig_processor = SiglipImageProcessor.from_pretrained(
    "google/siglip-base-patch16-224"
)

mask_tokens = torch.load(
    f"{BRIDGE_DATA_DIR}/mask_token_single.pt",
    weights_only=True,  # file holds a plain tensor; don't unpickle arbitrary objects
).to(vision_device).float()

orig_paths = sorted(glob.glob(f"{BRIDGE_DATA_DIR}/**/*.png", recursive=True))
# NOTE(review): identical glob to the Q-Former cell, so targets come from the same files.
# Rows are paired by sorted path order, so the counts must agree; check before the costly pass.
assert mask_tokens.size(0) == len(orig_paths), (
    f"{mask_tokens.size(0)} mask tokens vs {len(orig_paths)} images — "
    "re-run the Q-Former cell after changing the image set"
)


def extract_vision_features(paths, processor, tower, device, dtype, batch_size, image_size):
    """Return float32 CPU features: the tower's last hidden state mean-pooled per image.

    Args:
        paths: Image file paths; rows of the result follow this order.
        processor: SigLIP image processor used to build ``pixel_values``.
        tower: Vision module called as ``tower(pixel_values=..., interpolate_pos_encoding=True)``.
        device, dtype: Placement and dtype of ``tower``'s weights; inputs are cast to match.
        batch_size: Number of images per forward pass.
        image_size: Square side length images are resized to.

    Raises:
        ValueError: If ``paths`` is empty.
    """
    if not paths:
        raise ValueError("no images found to encode")

    chunks = []
    with torch.no_grad():
        for start in range(0, len(paths), batch_size):
            images = []
            for path in paths[start : start + batch_size]:
                with Image.open(path) as im:
                    images.append(im.convert("RGB"))
            pixel_values = processor(
                images=images,
                return_tensors="pt",
                do_resize=True,
                size={"height": image_size, "width": image_size},
            ).pixel_values
            if pixel_values.dim() == 5:  # drop a singleton extra axis if the processor adds one
                pixel_values = pixel_values.squeeze(1)
            out = tower(
                pixel_values=pixel_values.to(device=device, dtype=dtype),
                interpolate_pos_encoding=True,
            )
            chunks.append(out.last_hidden_state.mean(dim=1).cpu())
    return torch.cat(chunks, dim=0).float()


image_feats = extract_vision_features(
    orig_paths, sig_processor, vision_tower, vision_device, vision_dtype,
    batch_size, IMAGE_SIZE,
).to(vision_device)

# LLaVA only produced fixed targets above (under no_grad); freezing it is defensive.
for p in model.parameters():
    p.requires_grad = False

# Input width comes from the saved tokens instead of a hard-coded 768.
bridge = nn.Linear(mask_tokens.size(1), vision_tower.config.hidden_size).to(vision_device)
bridge.train()

ds = TensorDataset(mask_tokens, image_feats)
loader = DataLoader(
    ds,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),  # reproducible shuffle order
)
optimizer = torch.optim.Adam(bridge.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    for m_tok, img_f in loader:
        loss = criterion(bridge(m_tok), img_f)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch doesn't skew the epoch mean.
        total_loss += loss.item() * m_tok.size(0)
    print(f"Epoch {epoch+1} — Avg Loss: {total_loss/len(ds):.4f}")

torch.save(
    bridge.state_dict(),
    f"{BRIDGE_DATA_DIR}/bridge_weights_train.pt",
)
print("✅ bridge weights saved")

Epoch 1 — Avg Loss: 0.0715
Epoch 2 — Avg Loss: 0.0131
Epoch 3 — Avg Loss: 0.0108
Epoch 4 — Avg Loss: 0.0098
Epoch 5 — Avg Loss: 0.0092
✅ bridge weights saved
