In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import time
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from transformers import get_scheduler
from transformers import AutoTokenizer, AutoModel
from torch.optim.lr_scheduler import ReduceLROnPlateau
from custom_tokenizers import Tokenizer
from configs.config import DataConfig, EncoderConfig, DecoderConfig, PaliGemmaConfig
from decoder_coba import KVCache, PaliGemmaForConditionalGeneration
from dataloaders_coba import CustomDataLoader
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

2025-07-01 22:34:58.630483: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-01 22:34:58.642040: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751409298.655266   11475 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751409298.659097   11475 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1751409298.670186   11475 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [3]:
def train(
    model,
    dataloader,
    optimizer,
    device,
    epoch,
    bert,
    grad_accumulation_steps=1,
    max_grad_norm=1.0,
    use_amp=False,
):
    model.train()
    kv_cache = KVCache()
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    config = DecoderConfig()

    total_loss = 0.0

    for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}")):
        input_ids, pixel_values, target, attention_mask, input_embeds = batch
        input_ids = input_ids.to(device)
        pixel_values = pixel_values.to(device)
        attention_mask = attention_mask.to(device)
        input_embeds = input_embeds.unsqueeze(1).to(device)
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                inputs_embeds=input_embeds
            )
            
            logits = outputs["logits"]
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()
            
            loss = torch.nn.functional.cross_entropy(
                shift_logits.view(-1, config.vocab_size),
                shift_labels.view(-1),
                label_smoothing=0.1
            )

        scaler.scale(loss).backward()

        if (step + 1) % grad_accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        total_loss += loss.item()

    return total_loss / len(dataloader)

@torch.no_grad()
def validate(model, dataloader, device, bert, use_amp=False):
    model.eval()
    total_loss = 0.0
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    config = DecoderConfig()

    for batch in tqdm(dataloader, desc="Validating"):
        input_ids, pixel_values, target, attention_mask, input_embeds = batch
        input_ids = input_ids.to(device)
        pixel_values = pixel_values.to(device)
        attention_mask = attention_mask.to(device)
        input_embeds = input_embeds.unsqueeze(1).to(device)
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                inputs_embeds=input_embeds
            )
            logits = outputs["logits"]  # shape: [B, T, V]
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()
            
            loss = torch.nn.functional.cross_entropy(
                shift_logits.view(-1, config.vocab_size),
                shift_labels.view(-1),
                label_smoothing=0.1
            )

        total_loss += loss.item()

    return total_loss / len(dataloader)

In [4]:
import os
import re
import json
import torch
from transformers import AutoTokenizer, AutoModel
from configs.config import DataConfig
from tqdm import tqdm

def clean_report_iu_xray(report):
    report_cleaner = lambda t: t.replace('..', '.').replace('1. ', '') \
        .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \
        .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
        .strip().lower().split('. ')
    sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '')
                                    .replace('\\', '').replace("'", '').strip().lower())
    tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
    return ' . '.join(tokens) + ' .'

def compute_and_save_bert_embeddings(split="train", save_path="cached_bert_embeds.pt"):
    # === Load BERT model ===
    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT").eval().cuda()

    # === Load annotations ===
    config = DataConfig()
    annotation_path = config.annotation_path
    with open(annotation_path, 'r') as f:
        data = json.load(f)[split]

    print(f"Computing embeddings for {len(data)} samples in split '{split}'...")

    embeddings = []
    for example in tqdm(data):
        report = clean_report_iu_xray(example["report"])

        # Tokenize
        inputs = tokenizer(report, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.cuda() for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            last_hidden_state = outputs.last_hidden_state
            attention_mask = inputs["attention_mask"]

            # Average the hidden states (excluding padding)
            masked = last_hidden_state * attention_mask.unsqueeze(-1)
            mean_embedding = masked.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)

        # Save as 1D CPU tensor (remove batch dimension)
        embeddings.append(mean_embedding.squeeze(0).cpu())

    # === Save all embeddings ===
    os.makedirs("cached_embeds", exist_ok=True)
    torch.save(embeddings, os.path.join("cached_embeds", f"{split}_bert_embeds.pt"))
    print(f"Saved {split} embeddings to 'cached_embeds/{split}_bert_embeds.pt'")


if __name__ == "__main__":
    compute_and_save_bert_embeddings(split="train")
    compute_and_save_bert_embeddings(split="val")
    compute_and_save_bert_embeddings(split="test")

Computing embeddings for 2069 samples in split 'train'...


  0%|          | 0/2069 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 2069/2069 [00:07<00:00, 260.14it/s]


Saved train embeddings to 'cached_embeds/train_bert_embeds.pt'
Computing embeddings for 296 samples in split 'val'...


  0%|          | 0/296 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 296/296 [00:01<00:00, 264.31it/s]


Saved val embeddings to 'cached_embeds/val_bert_embeds.pt'
Computing embeddings for 590 samples in split 'test'...


  0%|          | 0/590 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 590/590 [00:02<00:00, 266.93it/s]


Saved test embeddings to 'cached_embeds/test_bert_embeds.pt'


In [5]:
def main():
    # Hyperparameters
    epochs = 10
    batch_size = 64
    learning_rate = 3e-5
    grad_accumulation_steps = 1
    use_amp = True
    patience = 30

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using device:", device)

    # === Load Config ===
    text_config = DecoderConfig()
    vision_config = EncoderConfig()
    text_config.vocab_size += 1
    image_token_index = text_config.vocab_size - 1
    config = PaliGemmaConfig(
        text_config=text_config,
        vision_config=vision_config,
        image_token_index=image_token_index,
    )

    # === Load Model ===
    model = PaliGemmaForConditionalGeneration(config).to(device)

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    tokenizer.add_tokens('<image>')
    bert = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT").to(device)

    # === Dataloader ===
    train_loader = CustomDataLoader(
        split="train",
        batch_size=batch_size,
        num_workers=1,
        tokenizer=tokenizer,
        shuffle=False
    )
    val_loader = CustomDataLoader(
        split="val",
        batch_size=batch_size,
        num_workers=1,
        tokenizer=tokenizer,
        shuffle=False
    )

    # === Optimizer and Scheduler ===
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    num_training_steps = epochs * len(train_loader) // grad_accumulation_steps
    # lr_scheduler = get_scheduler(
    #     "linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    # )
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, threshold=0.0001, threshold_mode='abs')

    best_val_loss = float("inf")
    patience_counter = 0

    # === Training Loop ===
    start_time = time.time()
    train_loss = []
    val_loss = []
    for epoch in range(1, epochs + 1):
        avg_train_loss = train(
            model,
            train_loader,
            optimizer,
            # lr_scheduler,
            device,
            epoch,
            bert,
            grad_accumulation_steps,
            use_amp=use_amp,
        )
        avg_val_loss = validate(model, val_loader, device, bert, use_amp)
        lr_scheduler.step(avg_val_loss)
        train_loss.append(avg_train_loss)
        val_loss.append(avg_val_loss)
        print(f"Epoch {epoch} - Train Loss: {avg_train_loss:.4f} - Val Loss: {avg_val_loss:.4f}")

        # Save best model
        if avg_val_loss < best_val_loss:
            print("Validation loss improved. Saving model...")
            best_val_loss = avg_val_loss
            patience_counter = 0
            os.makedirs("checkpoints", exist_ok=True)
            torch.save(model.state_dict(), f"checkpoints/exp_1.pt")
        else:
            patience_counter += 1
            print(f"No improvement. Patience: {patience_counter}/{patience}")
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break
    save_dir = "plot"
    os.makedirs(save_dir, exist_ok=True)

    # Plotting
    epochs = list(range(1, len(train_loss) + 1))
    plt.figure(figsize=(8, 6))
    plt.plot(epochs, train_loss, label='Train Loss', marker='o')
    plt.plot(epochs, val_loss, label='Validation Loss', marker='x')
    plt.title("Training and Validation Loss per Epoch")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)

    # Save the plot
    plot_path = os.path.join(save_dir, "exp_1.png")
    plt.savefig(plot_path)
    plt.close()

    # Calculate elapsed time
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Elapsed time: {elapsed_time} seconds")

In [6]:
if __name__ == "__main__":
    main()

Using device: cuda


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Epoch 1:   0%|          | 0/33 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 1: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Validating: 100%|██████████| 5/5 [00:05<00:00,  1.07s/it]


Epoch 1 - Train Loss: 6.7990 - Val Loss: 4.1630
Validation loss improved. Saving model...


Epoch 2:   0%|          | 0/33 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 2: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Validating: 100%|██████████| 5/5 [00:05<00:00,  1.07s/it]


Epoch 2 - Train Loss: 3.7991 - Val Loss: 3.3207
Validation loss improved. Saving model...


Epoch 3:   0%|          | 0/33 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 3: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Validating: 100%|██████████| 5/5 [00:05<00:00,  1.08s/it]


Epoch 3 - Train Loss: 3.1113 - Val Loss: 2.9347
Validation loss improved. Saving model...


Epoch 4:   0%|          | 0/33 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 4: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Validating: 100%|██████████| 5/5 [00:05<00:00,  1.07s/it]


Epoch 4 - Train Loss: 2.8735 - Val Loss: 2.8092
Validation loss improved. Saving model...


Epoch 5:   0%|          | 0/33 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 5: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Validating: 100%|██████████| 5/5 [00:05<00:00,  1.07s/it]


Epoch 5 - Train Loss: 2.7840 - Val Loss: 2.7462
Validation loss improved. Saving model...


Epoch 6:   0%|          | 0/33 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 6: 100%|██████████| 33/33 [00:35<00:00,  1.08s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Validating: 100%|██████████| 5/5 [00:05<00:00,  1.07s/it]


Epoch 6 - Train Loss: 2.7340 - Val Loss: 2.7029
Validation loss improved. Saving model...


Epoch 7:   0%|          | 0/33 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 7: 100%|██████████| 33/33 [00:35<00:00,  1.06s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Validating: 100%|██████████| 5/5 [00:05<00:00,  1.06s/it]


Epoch 7 - Train Loss: 2.6901 - Val Loss: 2.6641
Validation loss improved. Saving model...


Epoch 8:   0%|          | 0/33 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 8: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Validating: 100%|██████████| 5/5 [00:05<00:00,  1.07s/it]


Epoch 8 - Train Loss: 2.6586 - Val Loss: 2.6399
Validation loss improved. Saving model...


Epoch 9:   0%|          | 0/33 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 9: 100%|██████████| 33/33 [00:35<00:00,  1.06s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Validating: 100%|██████████| 5/5 [00:05<00:00,  1.06s/it]


Epoch 9 - Train Loss: 2.6329 - Val Loss: 2.6174
Validation loss improved. Saving model...


Epoch 10:   0%|          | 0/33 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 10: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Validating: 100%|██████████| 5/5 [00:05<00:00,  1.07s/it]


Epoch 10 - Train Loss: 2.6089 - Val Loss: 2.5934
Validation loss improved. Saving model...
Elapsed time: 436.1700131893158 seconds


In [7]:
from PIL import Image
import clip
import torch
from torch.nn import functional as F
from tqdm import tqdm
import argparse
from dataloaders_coba import CustomDataLoader
from custom_tokenizers import Tokenizer
from types import SimpleNamespace
from decoder_coba import KVCache, PaliGemmaForConditionalGeneration
from configs.config import DataConfig, EncoderConfig, DecoderConfig, PaliGemmaConfig
from metrics import compute_scores

In [8]:
def beam_search_generate(
    model,
    tokenizer,
    inputs_embeds,
    pixel_values,
    image_token_index,
    max_length=60,
    num_beams=5,
    eos_token_id=2,
    pad_token_id=0,
    device="cuda",
):
    model.eval()
    batch_size = pixel_values.size(0)
    eos_token_id = eos_token_id or tokenizer.eos_token_id

    # Encode images
    B, N, C, H, W = pixel_values.shape
    pixel_values = pixel_values.view(B * N, C, H, W).to(device)
    pixel_values = pixel_values.to(dtype=next(model.vision_tower.parameters()).dtype, device=device)
    vision_features = model.vision_tower(pixel_values)
    vision_features = vision_features.view(B, N, *vision_features.shape[1:])
    image_features = torch.cat([vision_features[:, 0], vision_features[:, 1]], dim=1)
    image_features = model.multi_modal_projector(image_features)  # shape: [B, Seq, Hidden]

    # Init input_ids with <image> tokens
    input_ids = torch.full((batch_size * num_beams, 1), image_token_index, dtype=torch.long, device=device)
    attention_mask = torch.ones_like(input_ids)

    # Expand inputs for beams
    image_features = image_features.unsqueeze(1).repeat(1, num_beams, 1, 1)
    image_features = image_features.view(batch_size * num_beams, *image_features.shape[2:])

    inputs_embeds = inputs_embeds.unsqueeze(1).repeat(1, num_beams, 1, 1)  # [B, num_beams, 1, 768]
    inputs_embeds = inputs_embeds.view(batch_size * num_beams, 1, -1)      # [B*num_beams, 1, 768]


    beam_scores = torch.zeros((batch_size, num_beams), device=device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view(-1)

    sequences = input_ids
    is_done = [False] * batch_size

    for step in range(max_length):

        merged_input, attn_mask, pos_ids = model._merge_input_ids_with_image_features(
            input_ids=sequences,
            image_features=image_features,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            kv_cache=None
        )

        outputs = model.language_model(
            inputs_embeds=merged_input,
            attention_mask=attn_mask,
            position_ids=pos_ids,
        )
        logits = outputs["logits"]
        next_token_logits = logits[:, -1, :]
        next_token_log_probs = F.log_softmax(next_token_logits, dim=-1)

        next_token_log_probs = next_token_log_probs + beam_scores[:, None]

        vocab_size = next_token_log_probs.size(-1)
        next_token_log_probs = next_token_log_probs.view(batch_size, num_beams * vocab_size)
        next_token_log_probs[:, image_token_index] = -1e9
        topk_log_probs, topk_indices = torch.topk(next_token_log_probs, num_beams, dim=-1)

        beam_indices = topk_indices // vocab_size
        token_indices = topk_indices % vocab_size

        sequences = sequences.view(batch_size, num_beams, -1)
        new_sequences = [sequences[i, beam_indices[i]] for i in range(batch_size)]
        sequences = torch.stack(new_sequences).view(batch_size * num_beams, -1)
        sequences = torch.cat([sequences, token_indices.view(-1, 1)], dim=-1)

        beam_scores = topk_log_probs.view(-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(token_indices.view(-1, 1))], dim=1)

        if eos_token_id is not None:
            for i in range(batch_size):
                done_for_beam = True
                for beam_id in range(num_beams):
                    token = sequences[i * num_beams + beam_id, -1]
                    if token != eos_token_id:
                        done_for_beam = False
                        break
                is_done[i] = done_for_beam
            if all(is_done):
                break

    sequences = sequences.view(batch_size, num_beams, -1)
    beam_scores = beam_scores.view(batch_size, num_beams)
    best_indices = torch.argmax(beam_scores, dim=1)

    best_sequences = [sequences[i, best_indices[i]] for i in range(batch_size)]
    return torch.stack(best_sequences)

def evaluate_model(model, tokenizer, dataloader, device, image_token_index, max_len=60, num_beams=5):
    model.eval()
    gts = {}
    res = {}

    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader, desc="Evaluating")):
            # From your dataset: (image_id, image, ids, mask, bert, len(ids))
            targets, pixel_values, input_ids, masks, inputs_embeds = batch

            pixel_values = pixel_values.to(device)
            inputs_embeds = inputs_embeds.unsqueeze(1).to(device)

            generated_ids = beam_search_generate(
                model,
                tokenizer,
                inputs_embeds=inputs_embeds,
                pixel_values=pixel_values,
                image_token_index=image_token_index,
                eos_token_id=tokenizer.eos_token_id,
                device=device,
                num_beams=num_beams,
                max_length=max_len,
            )

            decoded_preds = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            decoded_preds = [pred.strip().lower() for pred in decoded_preds]

            references = targets
            references = [tokenizer.decode(ref.tolist(), skip_special_tokens=True).replace("<image>", "").strip().lower() for ref in references]


            batch_size = pixel_values.size(0)
            for j in range(batch_size):
                img_id = f"img_{i * batch_size + j}"
                gts[img_id] = [references[j]]
                res[img_id] = [decoded_preds[j]]
            for i in range(1):
                # Optional print
                print(f"Pred: {decoded_preds[i]}")
                print(f"Ref : {references[i]}")
                print()

    scores = compute_scores(gts, res)
    for metric, score in scores.items():
        print(f"{metric.upper()}: {score:.4f}")

    return scores, gts, res

In [9]:
config = PaliGemmaConfig(text_config=DecoderConfig(), vision_config=EncoderConfig())
model = PaliGemmaForConditionalGeneration(config)
model.load_state_dict(torch.load("checkpoints/exp_1.pt", map_location="cuda"))
model.eval()

tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
tokenizer.add_tokens('<image>')

device = "cuda" if torch.cuda.is_available() else "cpu"
model_clip, preprocess = clip.load("ViT-B/32", device=device)
model = model.to(device)

test_loader = CustomDataLoader(
    split="test",
    batch_size=64,
    num_workers=1,
    tokenizer=tokenizer,
    shuffle=True
)

scores, gts, res = evaluate_model(model, tokenizer, test_loader, device, image_token_index=28996)
scores

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Evaluating:  10%|█         | 1/10 [00:10<01:36, 10.75s/it]

Pred: <image>pped thexxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx <image>
Ref : no pneumothorax pleural effusion or airspace consolidation. heart size is upper limits of normal. pulmonary vasculature appear within normal limits. xxxx xxxx are intact.



Evaluating:  20%|██        | 2/10 [00:20<01:19,  9.88s/it]

Pred: <image> hostile no no no. no. no..... no.. no. no........... no.. no..... no <image>x.. no... <image>x......... <image>
Ref : heart size within normal limits. no focal alveolar consolidation no definite pleural effusion seen. no typical findings of pulmonary edema. no pneumothorax.



Evaluating:  30%|███       | 3/10 [00:29<01:07,  9.67s/it]

Pred: <image>lak the the............................ <image> and.......................... <image>
Ref : normal cardiac contour. clear hyperexpanded lungs bilaterally with no pneumothorax or pleural effusion.



Evaluating:  40%|████      | 4/10 [00:39<00:57,  9.65s/it]

Pred: <image> score the arex....................................................... <image>
Ref : lungs are clear bilaterally. there is no focal consolidation pleural effusion or pneumothoraces. cardiomediastinal silhouette is within normal limits. xxxx are unremarkable.



Evaluating:  50%|█████     | 5/10 [00:48<00:48,  9.66s/it]

Pred: <image>lak the the............................x........................... <image>
Ref : heart size within normal limits. small nodular opacity in the right upper lobe. this does not look like an acute infiltrate and more xxxx represents a granuloma. no pneumothorax or effusions.



Evaluating:  60%|██████    | 6/10 [00:58<00:38,  9.69s/it]

Pred: <image> hostile the the are are are are are......................................... the <image>x....... <image>
Ref : the cardiomediastinal silhouette and pulmonary vasculature are within normal limits in size. the lungs are clear of focal airspace disease pneumothorax or pleural effusion. there are no acute bony findings.



Evaluating:  70%|███████   | 7/10 [01:08<00:29,  9.73s/it]

Pred: <image> hostile the are are.. are are............... <image> and....................... <image>x <image>x....... <image>
Ref : both lungs are clear and expanded. heart and mediastinum normal.



Evaluating:  80%|████████  | 8/10 [01:18<00:19,  9.78s/it]

Pred: <image> exploits the the.............. <image> and........................................ <image>
Ref : heart size and mediastinal contours appear within normal limits. pulmonary vascularity is within normal limits. no focal consolidation suspicious pulmonary opacity pneumothorax or definite pleural effusion. visualized osseous structures appear intact.



Evaluating:  90%|█████████ | 9/10 [01:28<00:09,  9.83s/it]

Pred: <image> secondary <image> andststst.................................................
Ref : lungs are hyperinflated but clear. no focal infiltrate or effusion. heart and mediastinal contours within normal limits. calcified mediastinal xxxx identified.



Evaluating: 100%|██████████| 10/10 [01:30<00:00,  9.04s/it]

Pred: <image>lak thexxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx.xxxxxxxxxxxxxxxxxxxx <image>
Ref : xxxx sternotomy xxxx and mediastinal postsurgical changes. stable cardiomegaly. crowded bronchovascular and interstitial markings xxxx related to low lung volumes and technique. grossly stable appearance of the lungs compared to prior exam without overt edema or gross airspace consolidation.






BLEU_1: 0.0153
BLEU_2: 0.0014
BLEU_3: 0.0000
BLEU_4: 0.0000
METEOR: 0.0121
ROUGE_L: 0.0526


{'BLEU_1': 0.015308923701470682,
 'BLEU_2': 0.0013710466038697627,
 'BLEU_3': 3.7567492191074755e-09,
 'BLEU_4': 6.46292172358475e-12,
 'METEOR': 0.012082561070748834,
 'ROUGE_L': np.float64(0.052584822845214665)}