In [None]:
from PIL import Image
import clip
import torch
from torch.nn import functional as F
from tqdm import tqdm
import argparse
from dataloaders import CustomDataLoader
from custom_tokenizers import Tokenizer
from types import SimpleNamespace
from decoder_layers import KVCache, PaliGemmaForConditionalGeneration
from configs.config import DataConfig, EncoderConfig, DecoderConfig, PaliGemmaConfig



In [2]:
def beam_search_generate(
    model,
    tokenizer,
    pixel_values,        # shape: (B, 2, C, H, W)
    image_token_index,   # placeholder index for image tokens
    max_length=30,
    num_beams=3,
    eos_token_id=2,
    pad_token_id=0,
    device="cpu",
):
    model.eval()
    batch_size = pixel_values.size(0)
    eos_token_id = eos_token_id or tokenizer.eos_token_id

    # Encode images
    B, N, C, H, W = pixel_values.shape
    pixel_values = pixel_values.view(B * N, C, H, W).to(device)
    vision_features = model.vision_tower(pixel_values)
    vision_features = vision_features.view(B, N, *vision_features.shape[1:])
    image_features = torch.cat([vision_features[:, 0], vision_features[:, 1]], dim=1)
    image_features = model.multi_modal_projector(image_features)  # shape: [B, Seq, Hidden]

    # Init input_ids with image token index
    input_ids = torch.full((batch_size * num_beams, 1), image_token_index, dtype=torch.long, device=device)
    attention_mask = torch.ones_like(input_ids, device=device)

    # Expand inputs for each beam
    image_features = image_features.unsqueeze(1).repeat(1, num_beams, 1, 1)
    image_features = image_features.view(batch_size * num_beams, *image_features.shape[2:])

    beam_scores = torch.zeros((batch_size, num_beams), device=device)
    beam_scores[:, 1:] = -1e9  # mask beams other than first
    beam_scores = beam_scores.view(-1)  # shape: [B * num_beams]

    sequences = input_ids
    is_done = [False] * batch_size

    for step in range(max_length):
        # Embedding
        input_embeds = model.language_model.get_input_embeddings()(sequences)
        
        # Merge image + text features
        merged_input, attn_mask, pos_ids = model._merge_input_ids_with_image_features(
            sequences, image_features, input_embeds, attention_mask, kv_cache=None
        )

        # Forward pass
        outputs = model.language_model(
            inputs_embeds=merged_input,
            attention_mask=attn_mask,
            position_ids=pos_ids,
        )
        logits = outputs["logits"]  # shape: [B * num_beams, Seq_len, Vocab]
        next_token_logits = logits[:, -1, :]  # take last token only
        next_token_log_probs = F.log_softmax(next_token_logits, dim=-1)

        # Add current beam scores
        next_token_log_probs = next_token_log_probs + beam_scores[:, None]

        # Get top k * num_beams candidates
        vocab_size = next_token_log_probs.size(-1)
        next_token_log_probs = next_token_log_probs.view(batch_size, num_beams * vocab_size)
        topk_log_probs, topk_indices = torch.topk(next_token_log_probs, num_beams, dim=-1)

        # Prepare for next step
        beam_indices = topk_indices // vocab_size
        token_indices = topk_indices % vocab_size

        # Reorder sequences and image features
        sequences = sequences.view(batch_size, num_beams, -1)
        new_sequences = []
        for i in range(batch_size):
            new_sequences.append(sequences[i, beam_indices[i]])
        sequences = torch.stack(new_sequences).view(batch_size * num_beams, -1)
        sequences = torch.cat([sequences, token_indices.view(-1, 1)], dim=-1)

        # Update scores
        beam_scores = topk_log_probs.view(-1)

        # Update attention mask
        attention_mask = torch.cat([attention_mask, torch.ones_like(token_indices.view(-1, 1))], dim=1)

        # Check if all sequences have ended
        if eos_token_id is not None:
            for i in range(batch_size):
                for beam_id in range(num_beams):
                    if sequences[i * num_beams + beam_id, -1] == eos_token_id:
                        is_done[i] = True

        if all(is_done):
            break

    # Reshape to [batch_size, num_beams, seq_len] and pick best beam
    sequences = sequences.view(batch_size, num_beams, -1)
    beam_scores = beam_scores.view(batch_size, num_beams)
    best_indices = torch.argmax(beam_scores, dim=1)

    best_sequences = []
    for i in range(batch_size):
        best_sequences.append(sequences[i, best_indices[i]])
    best_sequences = torch.stack(best_sequences)

    return best_sequences

In [3]:
# Load model
config = PaliGemmaConfig(text_config=DecoderConfig(), vision_config=EncoderConfig())
model = PaliGemmaForConditionalGeneration(config)
model.load_state_dict(torch.load("checkpoints/best_model.pt", map_location="cpu"))
model.eval()

tokenizer = Tokenizer(DataConfig())

device = "cuda" if torch.cuda.is_available() else "cpu"
model_clip, preprocess = clip.load("ViT-B/32", device=device)

# Load and preprocess image
image_1 = preprocess(Image.open("data/updated_iu_xray/CXR1_1_IM-0001/0.png"))
image_2 = preprocess(Image.open("data/updated_iu_xray/CXR1_1_IM-0001/1.png"))
pixel_values = torch.stack((image_1, image_2), 0).unsqueeze(0)
generated_ids = beam_search_generate(
    model,
    tokenizer,
    pixel_values=pixel_values,
    image_token_index=config.image_token_index,
    eos_token_id=tokenizer.eos_token_id,
    device="cpu",
    num_beams=5,
    max_length=50,
)

# Decode to text
captions = tokenizer.decode_batch(generated_ids)
for caption in captions:
    print(caption)

overlies . desired nipple medial decrease obscuring correction ectatic combination


In [4]:
generated_ids

tensor([[978,   4, 407, 897, 843, 376, 932, 350, 466, 305]])