In [1]:
from transformers import (
    AutoImageProcessor,
    ResNetForImageClassification,
    GPT2LMHeadModel,
    AutoTokenizer,
    DataCollatorWithPadding,
)
import torch
from datasets import Dataset
from PIL import Image
from torch.utils.data import DataLoader, Subset
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import requests
from io import BytesIO


# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [2]:
import os
import kagglehub
from kagglehub import KaggleDatasetAdapter
# Load the CSV files using kagglehub
print("Loading dataset from Kaggle using kagglehub...")

# Load the projections CSV
df_image = kagglehub.load_dataset(
    KaggleDatasetAdapter.PANDAS,
    "raddar/chest-xrays-indiana-university",
    "indiana_projections.csv"
)

# Load the reports CSV
df_report = kagglehub.load_dataset(
    KaggleDatasetAdapter.PANDAS,
    "raddar/chest-xrays-indiana-university", 
    "indiana_reports.csv"
)

print("Dataset loaded successfully!")
print(f"Images dataset shape: {df_image.shape}")
print(f"Reports dataset shape: {df_report.shape}")

# Create a DataFrame for images and captions
data = []
for i in range(len(df_image)):
    uid = df_image.iloc[i]['uid']
    image = df_image.iloc[i]['filename']
    index = df_report.loc[df_report['uid'] == uid]
    
    if not index.empty:    
        index = index.index[0]
        caption = df_report.iloc[index]['findings']
        if isinstance(caption, float):  # Skip rows with missing captions
            continue
        data.append({'imgs': image, 'captions': caption})

# Convert to a DataFrame
df = pd.DataFrame(data)
print(f"Final dataset with valid captions: {len(df)} samples")

# Download the dataset files to get the images path
dataset_path = kagglehub.dataset_download("raddar/chest-xrays-indiana-university")
print(f"Dataset downloaded to: {dataset_path}")

# Update image paths to use the downloaded dataset path
images_path = os.path.join(dataset_path, "images", "images_normalized")
df['imgs'] = df['imgs'].apply(lambda x: os.path.join(images_path, x))

# Verify first few image paths exist
print("Checking if image files exist:")
for i in range(min(3, len(df))):
    img_path = df.iloc[i]['imgs']
    exists = os.path.exists(img_path)
    print(f"  {img_path}: {'✓' if exists else '✗'}")

# Convert pandas DataFrame to a Dataset object
dataset = Dataset.from_pandas(df)
print(f"Dataset object created with {len(dataset)} samples")

Loading dataset from Kaggle using kagglehub...


  df_image = kagglehub.load_dataset(
  df_report = kagglehub.load_dataset(


Dataset loaded successfully!
Images dataset shape: (7466, 3)
Reports dataset shape: (3851, 8)
Final dataset with valid captions: 6469 samples
Dataset downloaded to: /home/alexandre/.cache/kagglehub/datasets/raddar/chest-xrays-indiana-university/versions/2
Checking if image files exist:
  /home/alexandre/.cache/kagglehub/datasets/raddar/chest-xrays-indiana-university/versions/2/images/images_normalized/1_IM-0001-4001.dcm.png: ✓
  /home/alexandre/.cache/kagglehub/datasets/raddar/chest-xrays-indiana-university/versions/2/images/images_normalized/1_IM-0001-3001.dcm.png: ✓
  /home/alexandre/.cache/kagglehub/datasets/raddar/chest-xrays-indiana-university/versions/2/images/images_normalized/2_IM-0652-1001.dcm.png: ✓
Dataset object created with 6469 samples


# Loading pretrained models

In [3]:
# Load ResNet-50 for feature extraction (frozen)
resnet_model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50").to(device)
resnet_model.eval()  # We won't train the ResNet, just use it for feature extraction

# Load GPT-2 for language generation
gpt2_model_name = "gpt2"  # or "distilgpt2" for a lighter version
tokenizer = AutoTokenizer.from_pretrained(gpt2_model_name)
# GPT-2 doesn't have a pad token by default, let's assign one:
tokenizer.pad_token = tokenizer.eos_token

# Preprocessing data

In [4]:
# Processor for ResNet images
image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50", use_fast=True)
def preprocess_images_and_captions(example):
    # Process the image
    image = Image.open(example["imgs"]).convert("L")  # Convert to grayscale
    image = Image.merge("RGB", [image, image, image])  # Convert grayscale to RGB
    image_inputs = image_processor(image, return_tensors="pt")
    pixel_values = image_inputs["pixel_values"].squeeze(0)  # Shape [3, 224, 224]

    # Tokenize the caption
    text_inputs = tokenizer(
        example["captions"],
        truncation=True,
        max_length=32, 
        return_tensors="pt"
    )

    return {
        "pixel_values": pixel_values.tolist(),  # Convert tensor to list
        "input_ids": text_inputs["input_ids"].squeeze(0).tolist(),  # Convert tensor to list
        "attention_mask": text_inputs["attention_mask"].squeeze(0).tolist(),  # Convert tensor to list
    }

Using `use_fast=True` but `torchvision` is not available. Falling back to the slow image processor.


In [5]:
# Process dataset in smaller chunks to avoid memory overflow
def process_dataset_in_chunks(dataset, chunk_size=1000,num_workers=5):
    total_samples = len(dataset)
    processed_chunks = []
    
    for i in range(0, total_samples, chunk_size):
        end_idx = min(i + chunk_size, total_samples)
        chunk = dataset.select(range(i, end_idx))
        
        print(f"Processing chunk {i//chunk_size + 1}/{(total_samples-1)//chunk_size + 1}")
        
        processed_chunk = chunk.map(
            preprocess_images_and_captions,
            num_proc=num_workers,  
            batched=False,
            desc=f"Chunk {i//chunk_size + 1}"
        )
        
        processed_chunks.append(processed_chunk)
        
        # Clean up memory
        del chunk
        gc.collect()
    
    # Concatenate all chunks
    from datasets import concatenate_datasets
    return concatenate_datasets(processed_chunks)

In [7]:
num_workers = 10  # Adjust based on your CPU cores
batch_size = 50  # Process in batches for better efficiency

split = dataset.train_test_split(test_size=0.1, seed=42)  # 60% train, 40% test

training_dataset = split['train']
testing_dataset = split['test']
# clean up the memory
del dataset, split
import gc
gc.collect()
print("Preprocessing training dataset with multiple workers...")
print("Processing training dataset in chunks...")
training_dataset = process_dataset_in_chunks(training_dataset, chunk_size=1500, num_workers=num_workers)

print("Processing testing dataset in chunks...")
testing_dataset = process_dataset_in_chunks(testing_dataset, chunk_size=1500, num_workers=num_workers)


print("Preprocessing complete!")

Preprocessing training dataset with multiple workers...
Processing training dataset in chunks...
Processing chunk 1/4


Chunk 1 (num_proc=10): 100%|██████████| 1500/1500 [00:22<00:00, 66.89 examples/s] 


Processing chunk 2/4


Chunk 2 (num_proc=10): 100%|██████████| 1500/1500 [00:23<00:00, 63.84 examples/s] 


Processing chunk 3/4


Chunk 3 (num_proc=10): 100%|██████████| 1500/1500 [00:23<00:00, 63.24 examples/s] 


Processing chunk 4/4


Chunk 4 (num_proc=10): 100%|██████████| 1322/1322 [00:21<00:00, 61.11 examples/s] 


Processing testing dataset in chunks...
Processing chunk 1/1


Chunk 1 (num_proc=10): 100%|██████████| 647/647 [00:10<00:00, 62.74 examples/s] 


Preprocessing complete!


In [8]:
text_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

def combined_collate_fn(batch):
    pixel_values_list = []
    input_ids_list = []
    attention_mask_list = []
    
    for item in batch:
        # Validate and convert pixel values
        pv = torch.tensor(item["pixel_values"], dtype=torch.float32)
        if pv.shape != torch.Size([3, 224, 224]):
            raise ValueError(f"Expected pixel_values shape [3, 224, 224], got {pv.shape}")
        pixel_values_list.append(pv)
        
        # Convert text data
        input_ids_list.append(torch.tensor(item["input_ids"], dtype=torch.long))
        attention_mask_list.append(torch.tensor(item["attention_mask"], dtype=torch.long))
    
    # Stack pixel values into a single tensor
    pixel_values = torch.stack(pixel_values_list, dim=0)  # [batch_size, 3, 224, 224]
    
    # Use Hugging Face DataCollatorWithPadding for tokenized text
    text_batch = {
        "input_ids": input_ids_list,
        "attention_mask": attention_mask_list,
    }
    text_batch = text_collator(text_batch)
    
    # Add pixel values to the text batch
    text_batch["pixel_values"] = pixel_values
    
    return text_batch

dataloader = DataLoader(
    training_dataset,
    batch_size=10,         # only 2 samples in this example
    shuffle=True,
    collate_fn=combined_collate_fn,
    drop_last=True,      # can be True if you have many samples
)

In [9]:
class FeatureToCaption(nn.Module):
    """
    We:
      - Extract features from ResNet (outside this class, in the training loop, frozen)
      - Project them to GPT-2 hidden dim
      - Sum them with the GPT-2 token embeddings
    """
    def __init__(self, feature_dim=2048, hidden_dim=768, gpt2_name="gpt2"):
        super().__init__()
        self.linear = nn.Linear(feature_dim, hidden_dim)
        
        self.llm = GPT2LMHeadModel.from_pretrained(gpt2_name)
        # Because GPT-2 doesn't define pad_token by default
        self.llm.config.pad_token_id = tokenizer.eos_token_id

    def forward(self, resnet_features, input_ids, attention_mask):
        """
        resnet_features: [batch_size, feature_dim]
        input_ids:       [batch_size, seq_len]
        attention_mask:  [batch_size, seq_len]
        """
        # 1) Project the ResNet features to GPT-2 hidden size
        #    shape: [batch_size, hidden_dim]
        projected = self.linear(resnet_features)

        # 2) Expand them along seq_len dimension
        #    shape: [batch_size, 1, hidden_dim] -> [batch_size, seq_len, hidden_dim]
        batch_size, seq_len = input_ids.shape
        projected = projected.unsqueeze(1).expand(batch_size, seq_len, -1)

        # 3) GPT-2 token embeddings
        #    shape: [batch_size, seq_len, hidden_dim]
        token_embeds = self.llm.transformer.wte(input_ids)

        # 4) Sum them (the simplest approach)
        inputs_embeds = token_embeds + projected

        # 5) Forward pass through GPT-2
        outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=input_ids,  # for CrossEntropyLoss
        )
        return outputs


In [10]:
from transformers import get_scheduler
from tqdm import tqdm

# Instantiate our feature-to-caption model
model = FeatureToCaption(gpt2_name=gpt2_model_name).to(device)

################################################################################
# Training Loop with Progress Bars
################################################################################
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
epochs = 3
warmup_ratio = 0.1
num_training_steps = len(dataloader) * epochs
scaler = torch.amp.GradScaler()
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=int(warmup_ratio * num_training_steps),  # 10% warmup
    num_training_steps=num_training_steps,
)

for epoch in range(1, epochs + 1):
    model.train()
    total_loss = 0.0
    
    # Create progress bar for the current epoch
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}/{epochs}", leave=True)
    
    for batch_idx, batch in enumerate(progress_bar):
        # batch has "pixel_values", "input_ids", "attention_mask"
        pixel_values = batch["pixel_values"].to(device)      # [batch_size, 3, 224, 224]
        input_ids = batch["input_ids"].to(device)            # [batch_size, seq_len]
        attention_mask = batch["attention_mask"].to(device)  # [batch_size, seq_len]

        # -------------------- Freeze ResNet & Extract Features -------------------
        with torch.no_grad():
            # 1) Embeddings
            emb_out = resnet_model.resnet.embedder(pixel_values)
            # 2) Encoder
            enc_out = resnet_model.resnet.encoder(emb_out)
            # 3) Pool & Flatten -> shape: [batch_size, 2048]
            pooled_features = resnet_model.resnet.pooler(enc_out.last_hidden_state).flatten(1)
        # -------------------------------------------------------------------------

        # Forward pass
        outputs = model(pooled_features, input_ids, attention_mask)
        loss = outputs.loss
        
        # Backward pass with scaling
        scaler.scale(loss).backward()

        # Unscale gradients BEFORE stepping
        scaler.unscale_(optimizer)

        # Clip gradients after unscaling
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Step optimizer and scaler
        scaler.step(optimizer)
        scaler.update()

        # Step learning rate scheduler
        lr_scheduler.step()

        # Zero gradients for next iteration
        optimizer.zero_grad()

        total_loss += loss.item()
        
        # Update progress bar with current loss and running average
        current_avg_loss = total_loss / (batch_idx + 1)
        progress_bar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Avg Loss': f'{current_avg_loss:.4f}',
            'LR': f'{lr_scheduler.get_last_lr()[0]:.6f}'
        })

    # Final epoch summary
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch}/{epochs} completed - Average Loss: {avg_loss:.4f}")
    print("-" * 50)

Epoch 1/3:   0%|          | 0/582 [00:00<?, ?it/s]`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
Epoch 1/3: 100%|██████████| 582/582 [04:17<00:00,  2.26it/s, Loss=0.8602, Avg Loss=1.6300, LR=0.000084]


Epoch 1/3 completed - Average Loss: 1.6300
--------------------------------------------------


Epoch 2/3: 100%|██████████| 582/582 [04:17<00:00,  2.26it/s, Loss=0.8094, Avg Loss=0.8543, LR=0.000030]


Epoch 2/3 completed - Average Loss: 0.8543
--------------------------------------------------


Epoch 3/3: 100%|██████████| 582/582 [04:16<00:00,  2.27it/s, Loss=1.2525, Avg Loss=0.6581, LR=0.000000]

Epoch 3/3 completed - Average Loss: 0.6581
--------------------------------------------------





In [None]:
from datetime import datetime

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_dir = "models"
model_name = f"xray_captioning_v1_{timestamp}"
model_path = os.path.join(model_dir, f"{model_name}.pt")

if not os.path.exists(model_dir):
    os.makedirs(model_dir)
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epochs,
    'loss': avg_loss,
    'model_config': {
        'feature_dim': 2048,
        'hidden_dim': 768,
        'gpt2_name': gpt2_model_name
    }
}
torch.save(checkpoint, model_path)
torch.save(model, model_path)

# ---- Loading ----
#model = torch.load(model_path)
#model.eval()

## Testing the model

In [16]:
model.eval()

def evaluate_model_on_test_set(model, test_dataset, num_samples=100):
    """Properly evaluate model on test set"""
    evalRefs = []
    evalHyps = []
    
    model.eval()
    
    # Use tqdm for evaluation progress
    test_samples = list(test_dataset.select(range(min(num_samples, len(test_dataset)))))
    
    for idx, sample in enumerate(tqdm(test_samples, desc="Evaluating")):
        with torch.no_grad():
            # Convert to tensors
            pixel_values = torch.tensor(sample["pixel_values"]).unsqueeze(0).to(device)
            input_ids = torch.tensor(sample["input_ids"]).unsqueeze(0).to(device)
            attention_mask = torch.tensor(sample["attention_mask"]).unsqueeze(0).to(device)
        
            # Extract ResNet features
            emb_out = resnet_model.resnet.embedder(pixel_values)
            enc_out = resnet_model.resnet.encoder(emb_out)
            pooled_features = resnet_model.resnet.pooler(enc_out.last_hidden_state).flatten(1)
        
            # Generate caption
            outputs = model.llm.generate(
                inputs_embeds=(model.linear(pooled_features).unsqueeze(1) + 
                              model.llm.transformer.wte(input_ids)),
                attention_mask=attention_mask,
                max_length=100,
                num_beams=2,  # Use beam search for better quality
                do_sample=False,  # Deterministic for evaluation
                pad_token_id=tokenizer.eos_token_id
            )
        
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Add to evaluation lists
        evalRefs.append({
            "image_id": idx,
            "caption": sample["captions"]
        })
        
        evalHyps.append({
            "image_id": idx,
            "caption": generated_text
        })
    
    return evalRefs, evalHyps

# Run evaluation
evalRefs, evalHyps = evaluate_model_on_test_set(model, testing_dataset, num_samples=100)
print(f"Evaluated on {len(evalHyps)} test samples")


Evaluating: 100%|██████████| 100/100 [00:09<00:00, 10.46it/s]


Evaluated on 100 test samples


In [24]:
# Fix NLTK downloads - add this at the top of your evaluation cell
import nltk
import ssl

try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context

# Download required NLTK data
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('stopwords')
print("NLTK data downloaded successfully!")

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.tokenize import word_tokenize
import sacrebleu


NLTK data downloaded successfully!


[nltk_data] Downloading package punkt to /home/alexandre/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /home/alexandre/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/alexandre/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


# COMPUTE EVALUATION METRICS

In [35]:
#!pip install git+https://github.com/salaniz/pycocoevalcap.git
evalRefs = []
evalHyps = []

model.eval()

total_samples = len(testing_dataset)
print(f"Processing {total_samples} testing samples...")


for idx, sample in enumerate(tqdm(testing_dataset, desc="Generating captions", total=total_samples)):
    with torch.no_grad():
        # Convert pixel_values back to tensor
        pixel_values = torch.tensor(sample["pixel_values"]).unsqueeze(0).to(device)
        
        # Convert input_ids and attention_mask to tensors
        input_ids = torch.tensor(sample["input_ids"]).unsqueeze(0).to(device)
        attention_mask = torch.tensor(sample["attention_mask"]).unsqueeze(0).to(device)
    
        # ResNet Features
        emb_out = resnet_model.resnet.embedder(pixel_values)
        enc_out = resnet_model.resnet.encoder(emb_out)
        pooled_features = resnet_model.resnet.pooler(enc_out.last_hidden_state).flatten(1)
    
        # Generate a caption
        combined_embeds = (model.linear(pooled_features).unsqueeze(1) + model.llm.transformer.wte(input_ids))
    
        # Generate a caption
        outputs = model.llm.generate(
            inputs_embeds=combined_embeds,
            attention_mask=attention_mask,
            max_length=50,  # Increase this
            min_length=10,  # Add minimum length
            num_beams=3,    # More beams
            no_repeat_ngram_size=2,  # Avoid repetition
            do_sample=True,  # Add some randomness
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id
        )
    
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Convert references into the required format
    # If sample["captions"] is a single string:
    if isinstance(sample["captions"], str):
        gt_captions = [sample["captions"]]
    else:
        gt_captions = sample["captions"]
    
    # Append references
    # pycocoevalcap expects something like:
    # {"image_id": <id>, "caption": "some reference caption"}
    for ref in gt_captions:
        evalRefs.append({
            "image_id": idx,
            "caption": ref
        })

    # Append hypothesis
    evalHyps.append({
        "image_id": idx,
        "caption": generated_text
    })
print(f"Number of references: {len(evalRefs)}")
print(f"Number of hypotheses: {len(evalHyps)}")

Processing 647 testing samples...


Generating captions: 100%|██████████| 647/647 [01:07<00:00,  9.59it/s]

Number of references: 647
Number of hypotheses: 647





In [36]:
# Better evaluation metrics for medical captioning
from rouge_score import rouge_scorer
from bert_score import score as bert_score
def evaluate_medical_captions(references, predictions):
    """Enhanced evaluation for medical image captioning with better error handling"""
    
    # Filter out empty predictions for better analysis
    valid_pairs = [(ref, pred) for ref, pred in zip(references, predictions) 
                   if pred.strip() != "" and ref.strip() != ""]
    
    if len(valid_pairs) == 0:
        print("ERROR: No valid prediction-reference pairs found!")
        return {'error': 'no_valid_pairs'}
    
    valid_refs, valid_preds = zip(*valid_pairs)
    
    print(f"Evaluating {len(valid_pairs)} valid pairs out of {len(references)} total")
    
    # 1. BLEU Score with better tokenization
    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
    import re
    smooth_fn = SmoothingFunction().method1
    bleu_scores = []
    
    for ref, pred in zip(valid_refs, valid_preds):
        # Better tokenization for medical text
        ref_tokens = re.findall(r'\b\w+\b', ref.lower())
        pred_tokens = re.findall(r'\b\w+\b', pred.lower())
        
        if len(pred_tokens) > 0:  # Avoid empty predictions
            score = sentence_bleu([ref_tokens], pred_tokens, smoothing_function=smooth_fn)
            bleu_scores.append(score)
    
    # 2. ROUGE-L Score
    from rouge_score import rouge_scorer
    rouge_scorer_obj = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    rouge_scores = []
    
    for ref, pred in zip(valid_refs, valid_preds):
        if pred.strip():  # Only if prediction is not empty
            score = rouge_scorer_obj.score(ref, pred)['rougeL'].fmeasure
            rouge_scores.append(score)
    
    # 3. BERTScore (semantic similarity) - only on valid pairs
    bert_f1 = 0.0
    if len(valid_preds) > 0:
        try:
            from bert_score import score as bert_score
            # Filter out empty predictions for BERTScore
            non_empty_preds = [p for p in valid_preds if p.strip()]
            non_empty_refs = [valid_refs[i] for i, p in enumerate(valid_preds) if p.strip()]
            
            if len(non_empty_preds) > 0:
                P, R, F1 = bert_score(non_empty_preds, non_empty_refs, lang="en", verbose=False)
                bert_f1 = F1.mean().item()
        except Exception as e:
            print(f"BERTScore error: {e}")
            bert_f1 = 0.0
    
    # 4. Enhanced Medical keyword overlap
    medical_keywords = [
        # Basic anatomy
        'heart', 'lungs', 'chest', 'ribs', 'diaphragm', 'mediastinum',
        # Positions/orientations
        'bilateral', 'right', 'left', 'upper', 'lower', 'middle', 'base', 'apex',
        'posterior', 'anterior', 'lateral', 'medial',
        # Normal findings
        'normal', 'clear', 'unremarkable', 'stable', 'unchanged',
        # Abnormal findings
        'abnormal', 'pneumonia', 'consolidation', 'effusion', 'infiltrate',
        'cardiomegaly', 'atelectasis', 'pneumothorax', 'opacity', 'nodule',
        'mass', 'lesion', 'edema', 'congestion', 'hyperinflation',
        # Descriptors
        'increased', 'decreased', 'enlarged', 'small', 'large', 'prominent',
        'mild', 'moderate', 'severe', 'acute', 'chronic'
    ]
    
    keyword_overlap_scores = []
    medical_recall_scores = []
    
    for ref, pred in zip(valid_refs, valid_preds):
        ref_words = set(re.findall(r'\b\w+\b', ref.lower()))
        pred_words = set(re.findall(r'\b\w+\b', pred.lower()))
        
        ref_keywords = ref_words.intersection(set(medical_keywords))
        pred_keywords = pred_words.intersection(set(medical_keywords))
        
        # Precision: how many predicted keywords are correct
        if len(pred_keywords) == 0:
            precision = 1.0 if len(ref_keywords) == 0 else 0.0
        else:
            precision = len(ref_keywords.intersection(pred_keywords)) / len(pred_keywords)
        
        # Recall: how many reference keywords were predicted
        if len(ref_keywords) == 0:
            recall = 1.0 if len(pred_keywords) == 0 else 0.0
        else:
            recall = len(ref_keywords.intersection(pred_keywords)) / len(ref_keywords)
        
        keyword_overlap_scores.append(precision)
        medical_recall_scores.append(recall)
    
    # 5. Caption length analysis
    avg_ref_length = sum(len(ref.split()) for ref in valid_refs) / len(valid_refs)
    avg_pred_length = sum(len(pred.split()) for pred in valid_preds) / len(valid_preds)
    
    # 6. Vocabulary overlap
    all_ref_words = set()
    all_pred_words = set()
    for ref, pred in zip(valid_refs, valid_preds):
        all_ref_words.update(ref.lower().split())
        all_pred_words.update(pred.lower().split())
    
    vocab_overlap = len(all_ref_words.intersection(all_pred_words)) / len(all_ref_words.union(all_pred_words))
    
    return {
        'bleu': sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0.0,
        'rouge_l': sum(rouge_scores) / len(rouge_scores) if rouge_scores else 0.0,
        'bert_f1': bert_f1,
        'medical_keyword_precision': sum(keyword_overlap_scores) / len(keyword_overlap_scores) if keyword_overlap_scores else 0.0,
        'medical_keyword_recall': sum(medical_recall_scores) / len(medical_recall_scores) if medical_recall_scores else 0.0,
        'vocab_overlap': vocab_overlap,
        'avg_ref_length': avg_ref_length,
        'avg_pred_length': avg_pred_length,
        'valid_predictions': len(valid_pairs),
        'total_samples': len(references),
        'empty_predictions': len(references) - len(valid_pairs)
    }

# Run improved evaluation
references = [ref['caption'] for ref in evalRefs]
predictions = [hyp['caption'] for hyp in evalHyps]

print("Computing enhanced evaluation metrics...")
results = evaluate_medical_captions(references, predictions)

print("\n" + "="*60)
print("ENHANCED EVALUATION RESULTS")
print("="*60)

if 'error' not in results:
    print("SIMILARITY METRICS:")
    print(f"  BLEU Score:           {results['bleu']:.4f}")
    print(f"  ROUGE-L:              {results['rouge_l']:.4f}")
    print(f"  BERT F1:              {results['bert_f1']:.4f}")
    print(f"  Vocabulary Overlap:   {results['vocab_overlap']:.4f}")
    
    print("\nMEDICAL TERMINOLOGY:")
    print(f"  Keyword Precision:    {results['medical_keyword_precision']:.4f}")
    print(f"  Keyword Recall:       {results['medical_keyword_recall']:.4f}")
    
    print("\nCAPTION STATISTICS:")
    print(f"  Avg Reference Length: {results['avg_ref_length']:.1f} words")
    print(f"  Avg Prediction Length:{results['avg_pred_length']:.1f} words")
    print(f"  Valid Predictions:    {results['valid_predictions']}/{results['total_samples']}")
    print(f"  Empty Predictions:    {results['empty_predictions']}")
else:
    print("ERROR in evaluation:", results)

print("="*60)

# Show some example predictions for debugging
print("\nSAMPLE PREDICTIONS:")
print("-" * 40)
for i in range(min(5, len(references))):
    print(f"Reference {i+1}: {references[i][:100]}...")
    print(f"Prediction {i+1}: {predictions[i][:100]}...")
    print("-" * 40)


Computing enhanced evaluation metrics...
Evaluating 567 valid pairs out of 647 total


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



ENHANCED EVALUATION RESULTS
SIMILARITY METRICS:
  BLEU Score:           0.0199
  ROUGE-L:              0.1895
  BERT F1:              0.8625
  Vocabulary Overlap:   0.1949

MEDICAL TERMINOLOGY:
  Keyword Precision:    0.4927
  Keyword Recall:       0.1673

CAPTION STATISTICS:
  Avg Reference Length: 33.4 words
  Avg Prediction Length:8.4 words
  Valid Predictions:    567/647
  Empty Predictions:    80

SAMPLE PREDICTIONS:
----------------------------------------
Reference 1: The lungs are clear without focal consolidation, effusion, or pneumothorax. Normal heart size. Bony ...
Prediction 1: ...
----------------------------------------
Reference 2: Heart size within normal limits. Prominent right perihilar density consistent with lymphadenopathy, ...
Prediction 2: X unchanged. No focal alveolar consolidation, no definite pleural effusion seen....
----------------------------------------
Reference 3: The heart is borderline size. Aorta is atherosclerotic. The mediastinum is stable. The 