In [1]:
from PIL import Image
import clip
import torch
from torch.nn import functional as F
from tqdm import tqdm
import argparse
from dataloaders import CustomDataLoader
from custom_tokenizers import Tokenizer
from types import SimpleNamespace
from decoder_layers import KVCache, PaliGemmaForConditionalGeneration
from configs.config import DataConfig, EncoderConfig, DecoderConfig, PaliGemmaConfig
from metrics import compute_scores

2025-06-29 05:59:35.642072: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-06-29 05:59:35.655617: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751176775.668517     689 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751176775.672310     689 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1751176775.683005     689 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [2]:
def beam_search_generate(
    model,
    tokenizer,
    pixel_values,        # shape: (B, 2, C, H, W)
    image_token_index,   # placeholder index for image tokens
    max_length=60,
    num_beams=5,
    eos_token_id=2,
    pad_token_id=0,
    device="cuda",
):
    model.eval()
    batch_size = pixel_values.size(0)
    eos_token_id = eos_token_id or tokenizer.eos_token_id

    # Encode images
    B, N, C, H, W = pixel_values.shape
    pixel_values = pixel_values.view(B * N, C, H, W).to(device)
    pixel_values = pixel_values.to(dtype=next(model.vision_tower.parameters()).dtype, device="cuda")
    vision_features = model.vision_tower(pixel_values)
    vision_features = vision_features.view(B, N, *vision_features.shape[1:])
    image_features = torch.cat([vision_features[:, 0], vision_features[:, 1]], dim=1)
    image_features = model.multi_modal_projector(image_features)  # shape: [B, Seq, Hidden]

    # Init input_ids with image token index
    input_ids = torch.full((batch_size * num_beams, 1), image_token_index, dtype=torch.long, device=device)
    attention_mask = torch.ones_like(input_ids, device=device)

    # Expand inputs for each beam
    image_features = image_features.unsqueeze(1).repeat(1, num_beams, 1, 1)
    image_features = image_features.view(batch_size * num_beams, *image_features.shape[2:])

    beam_scores = torch.zeros((batch_size, num_beams), device=device)
    beam_scores[:, 1:] = -1e9  # mask beams other than first
    beam_scores = beam_scores.view(-1)  # shape: [B * num_beams]

    sequences = input_ids
    is_done = [False] * batch_size

    for step in range(max_length):
        # Embedding
        input_embeds = model.language_model.get_input_embeddings()(sequences)
        
        # Merge image + text features
        merged_input, attn_mask, pos_ids = model._merge_input_ids_with_image_features(
            input_ids = sequences, image_features = image_features, inputs_embeds = input_embeds, attention_mask = attention_mask, kv_cache=None
        )

        # Forward pass
        outputs = model.language_model(
            inputs_embeds=merged_input,
            attention_mask=attn_mask,
            position_ids=pos_ids,
        )
        logits = outputs["logits"]  # shape: [B * num_beams, Seq_len, Vocab]
        next_token_logits = logits[:, -1, :]  # take last token only
        next_token_log_probs = F.log_softmax(next_token_logits, dim=-1)
        # next_token_log_probs[:, image_token_index] = -1e9


        # Add current beam scores
        next_token_log_probs = next_token_log_probs + beam_scores[:, None]

        # Get top k * num_beams candidates
        vocab_size = next_token_log_probs.size(-1)
        next_token_log_probs = next_token_log_probs.view(batch_size, num_beams * vocab_size)
        topk_log_probs, topk_indices = torch.topk(next_token_log_probs, num_beams, dim=-1)

        # Prepare for next step
        beam_indices = topk_indices // vocab_size
        token_indices = topk_indices % vocab_size

        # Reorder sequences and image features
        sequences = sequences.view(batch_size, num_beams, -1)
        new_sequences = []
        for i in range(batch_size):
            new_sequences.append(sequences[i, beam_indices[i]])
        sequences = torch.stack(new_sequences).view(batch_size * num_beams, -1)
        sequences = torch.cat([sequences, token_indices.view(-1, 1)], dim=-1)

        # Update scores
        beam_scores = topk_log_probs.view(-1)

        # Update attention mask
        attention_mask = torch.cat([attention_mask, torch.ones_like(token_indices.view(-1, 1))], dim=1)

        # Check if all sequences have ended
        if eos_token_id is not None:
            for i in range(batch_size):
                done_for_beam = True
                for beam_id in range(num_beams):
                    token = sequences[i * num_beams + beam_id, -1]
                    if token != eos_token_id:
                        done_for_beam = False
                        break
                is_done[i] = done_for_beam
        
            if all(is_done):
                break

    # Reshape to [batch_size, num_beams, seq_len] and pick best beam
    sequences = sequences.view(batch_size, num_beams, -1)
    beam_scores = beam_scores.view(batch_size, num_beams)
    best_indices = torch.argmax(beam_scores, dim=1)

    best_sequences = []
    for i in range(batch_size):
        best_sequences.append(sequences[i, best_indices[i]])
    best_sequences = torch.stack(best_sequences)

    return best_sequences

In [6]:
# Load model
config = PaliGemmaConfig(text_config=DecoderConfig(), vision_config=EncoderConfig())
model = PaliGemmaForConditionalGeneration(config)
model.load_state_dict(torch.load("checkpoints/experiment_1.pt", map_location="cuda"))
model.eval()

tokenizer = Tokenizer(DataConfig())

device = "cuda" if torch.cuda.is_available() else "cpu"
model_clip, preprocess = clip.load("ViT-B/32", device=device)
model = model.to(device)

# Load and preprocess image
image_1 = preprocess(Image.open("data/updated_iu_xray/CXR38_IM-1911/0.png"))
image_2 = preprocess(Image.open("data/updated_iu_xray/CXR38_IM-1911/1.png"))
pixel_values = torch.stack((image_1, image_2), 0).unsqueeze(0)
generated_ids = beam_search_generate(
    model,
    tokenizer,
    pixel_values=pixel_values,
    image_token_index=config.image_token_index,
    eos_token_id=tokenizer.eos_token_id,
    device="cuda",
    num_beams=3,
    max_length=100,
)

# Decode to text
captions = tokenizer.decode_batch(generated_ids)
for caption in captions:
    print(caption)

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 47.50 GiB of which 10.56 MiB is free. Process 1687999 has 14.44 GiB memory in use. Process 1681600 has 33.04 GiB memory in use. Of the allocated memory 31.83 GiB is allocated by PyTorch, and 731.82 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [14]:
def evaluate_model(model, tokenizer, dataloader, device, image_token_index, max_len=60, num_beams=5):
    model.eval()
    gts = {}
    res = {}

    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader, desc="Evaluating")):
            input_ids, pixel_values, att_masks = batch
            pixel_values = pixel_values.to(device)

            generated_ids = beam_search_generate(
                model,
                tokenizer,
                pixel_values=pixel_values,
                image_token_index=image_token_index,
                eos_token_id=tokenizer.eos_token_id,
                device=device,
                num_beams=num_beams,
                max_length=max_len,
            )

            # Decode predictions
            decoded_preds = tokenizer.decode_batch(generated_ids)
            decoded_preds = [pred.strip().lower() for pred in decoded_preds]

            # Decode references
            references = input_ids
            if isinstance(references[0], torch.Tensor):
                references = [tokenizer.decode(ref).replace("<image>", "").strip().lower() for ref in references]

            batch_size = pixel_values.size(0)
            for j in range(batch_size):
                image_id = f"img_{i * batch_size + j}"
                gts[image_id] = [references[j]]  # list of refs
                res[image_id] = [decoded_preds[j]]  # model output
            for i in range(1):
                print(f"Pred: {decoded_preds[i]}")
                print(f"Ref: {references[i]}")
                print()

    scores = compute_scores(gts, res)

    # Optional print
    for metric, score in scores.items():
        print(f"{metric.upper()}: {score:.4f}")

    return scores

In [7]:
# Install Java for Meteor
# !apt-get update -y
# !apt-get install -y openjdk-17-jre-headless

In [15]:
tokenizer = Tokenizer(DataConfig())

test_loader = CustomDataLoader(
    split="test",
    batch_size=64,
    num_workers=1,
    tokenizer=tokenizer,
    shuffle=True
)

scores = evaluate_model(model, tokenizer, test_loader, device, image_token_index=763)
scores

Evaluating:  10%|█         | 1/10 [00:07<01:05,  7.24s/it]

Pred: <image> the lungs are clear without evidence of the lungs are clear without evidence of the lungs are clear without evidence of the lungs are clear without evidence of the lungs are clear without evidence of the cardiac monitor leads show an unchanged from prior granulomatous process the heart size and vasculature central venous catheter tip overlying external cardiac monitor leads
Ref: <bos> the heart size and pulmonary vascularity appear within normal limits . the lungs are free of focal airspace disease . no pleural effusion or pneumothorax is seen . <eos>



Evaluating:  20%|██        | 2/10 [00:13<00:52,  6.50s/it]

Pred: <image> there are clear without acute bony structures show an unchanged from prior examination consists normal limits . there is a small t-spine osteophytes mediastinal contours are clear without acute bony structures show an unchanged from prior examination consists normal limits . <eos>
Ref: <bos> lungs are clear . there is no pneumothorax or pleural effusion . the heart and mediastinum are within normal limits . bony structures are intact . <eos>



Evaluating:  30%|███       | 3/10 [00:19<00:44,  6.32s/it]

Pred: <image> status post limits . <eos>
Ref: <bos> the lungs are clear . there is no pleural effusion or pneumothorax . the heart is not significantly enlarged . the mediastinum is normal . arthritic changes of the skeletal structures are noted . <eos>



Evaluating:  40%|████      | 4/10 [00:25<00:37,  6.26s/it]

Pred: <image> the lungs are clear without evidence of the lungs are clear without evidence of the lungs are clear without evidence of the lungs are clear without evidence of the lungs are clear without evidence of the cardiac monitor leads show an unchanged from prior granulomatous process the heart size and vasculature central venous catheter tip overlying external cardiac monitor leads
Ref: <bos> there is persistent marked enlargement of the pulmonary arteries . normal heart size . no focal airspace consolidation . no pleural effusion or pneumothorax . visualized osseous structures are unremarkable in appearance . <eos>



Evaluating:  50%|█████     | 5/10 [00:31<00:31,  6.26s/it]

Pred: <image> both lungs are clear without acute bony structures show an unchanged from prior examination consists normal limits . there is a small t-spine osteophytes mediastinal contours are clear without acute bony structures show an unchanged from prior examination consists normal limits . <eos>
Ref: <bos> the cardiomediastinal silhouette is normal in size and contour . no focal consolidation pneumothorax or large pleural effusion . normal xxxx . <eos>



Evaluating:  60%|██████    | 6/10 [00:38<00:25,  6.28s/it]

Pred: <image> the lungs are clear without evidence of the lungs are clear without evidence of the lungs are clear without evidence of the lungs are clear without evidence of the lungs are clear without evidence of the cardiac monitor leads show an unchanged from prior granulomatous process the heart size and vasculature central venous catheter tip overlying external cardiac monitor leads
Ref: <bos> cardiomediastinal silhouette within normal limits . no acute bony abnormality . there are xxxx xxxx opacities atelectasis versus airspace disease . no large effusion or pneumothorax . <eos>



Evaluating:  70%|███████   | 7/10 [00:44<00:18,  6.32s/it]

Pred: <image> . <eos>
Ref: <bos> the heart size and pulmonary vascularity appear within normal limits . the lungs are free of focal airspace disease . no pleural effusion or pneumothorax is seen . <eos>



Evaluating:  80%|████████  | 8/10 [00:50<00:12,  6.37s/it]

Pred: <image> . <eos>
Ref: <bos> cardiac and mediastinal contours are within normal limits . prior granulomatous disease . the lungs are clear . thoracic spondylosis . <eos>



Evaluating:  90%|█████████ | 9/10 [00:57<00:06,  6.41s/it]

Pred: <image> result bronchovascular crowding as well aerated suspicious pulmonary vascularity appear within normal limits . there are clear without acute bony structures show an unchanged from prior examination consists normal limits for patient age . <eos>
Ref: <bos> lung volumes remain low . no infiltrates . heart and pulmonary xxxx remain normal . <eos>



Evaluating: 100%|██████████| 10/10 [00:58<00:00,  5.90s/it]

Pred: <image> are clear without acute bony structures show an unchanged from prior examination consists normal limits . there is a small t-spine osteophytes mediastinal contours are clear without acute bony structures show an unchanged from prior examination consists normal limits . <eos>
Ref: <bos> lungs are clear without focal consolidation effusion or pneumothorax . normal heart size . negative for <unk> . mild degenerative changes of the thoracic spine . <eos>






BLEU_1: 0.1270
BLEU_2: 0.0820
BLEU_3: 0.0470
BLEU_4: 0.0268
METEOR: 0.1035
ROUGE_L: 0.1241


{'BLEU_1': 0.1269983373832917,
 'BLEU_2': 0.08195098960400135,
 'BLEU_3': 0.04700835898767729,
 'BLEU_4': 0.02680478528886467,
 'METEOR': 0.10346761761716412,
 'ROUGE_L': np.float64(0.12414180940742026)}