In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import ViltProcessor, ViltForQuestionAnswering
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
from tqdm import tqdm
from PIL import Image
import pandas as pd
import os
import json

In [None]:
# Fine-tuned ViLT checkpoint directory; both the processor and the QA model load from it.
model_path = "vilt_finetuned_vqa"
processor, model = (
    ViltProcessor.from_pretrained(model_path),
    ViltForQuestionAnswering.from_pretrained(model_path),
)

In [None]:
# Flatten the MIMIC-Ext-MIMIC-CXR-VQA eval split into a CSV (image_path, question, answer).
EVAL_JSON_PATH = '/scratch/bvs9764/physionet.org/files/mimic-ext-mimic-cxr-vqa/1.0.0/MIMIC-Ext-MIMIC-CXR-VQA/dataset/eval.json'
# Absolute path so the file lands exactly where the evaluation Dataset's csv_path points,
# independent of the kernel's working directory.
EVAL_CSV_PATH = '/scratch/bvs9764/eval_processed.csv'

with open(EVAL_JSON_PATH, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Binary label: 1 if 'yes' occurs in the answer, else 0. `in` works for a list of
# strings (membership) or a single string (substring match).
# NOTE(review): any answer without 'yes' becomes 0, including answers to non-yes/no
# questions if the split contains them — confirm these should count as "no".
processed_data = [
    {
        'image_path': record['image_path'],
        'question': record['question'],
        'answer': 1 if 'yes' in record['answer'] else 0,
    }
    for record in data
]

df = pd.DataFrame(processed_data)
df.to_csv(EVAL_CSV_PATH, index=False)


In [None]:
class MIMICCXRQA_Dataset(Dataset):
    """Image/question/label examples read from a CSV, encoded by a (ViLT) processor.

    Args:
        csv_path: CSV with ``image_path``, ``question`` and ``answer`` columns, where
            ``answer`` holds an integer class label (0/1 for no/yes).
        data_dir: Root directory that ``image_path`` values are relative to.
        processor: Callable taking ``images=``/``text=`` and returning a dict of
            tensors with a leading batch dimension of 1 (e.g. a ``ViltProcessor``).
        transform: Optional callable applied to the RGB PIL image before encoding.

    Each item is the processor's dict with the batch dim removed, plus a scalar
    ``torch.long`` tensor under ``'labels'``.
    """

    def __init__(self, csv_path, data_dir, processor, transform=None):
        self.data = pd.read_csv(csv_path)
        self.data_dir = data_dir
        self.processor = processor
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        img_path = row['image_path']
        question = row['question']
        # Already an integer label in the CSV; int() drops the numpy scalar type.
        label = int(row['answer'])

        full_img_path = f"{self.data_dir}/{img_path}"
        # Context manager releases the file handle; convert() returns an independent image.
        with Image.open(full_img_path) as img:
            image = img.convert("RGB")

        # NOTE(review): if `transform` already rescales/normalizes (e.g. ToTensor/Normalize),
        # the processor may do so again — confirm it matches the fine-tuning pipeline.
        if self.transform is not None:
            image = self.transform(image)

        encoding = self.processor(images=image, text=question, return_tensors="pt",
                                  padding="max_length", truncation=True)

        # Remove only the leading batch dim added by the processor; a bare squeeze()
        # would also collapse any other size-1 dimension and break batch collation.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        encoding['labels'] = torch.tensor(label, dtype=torch.long)
        return encoding

In [None]:
# Root directory containing the MIMIC-CXR-JPG images referenced by image_path.
data_dir = '/scratch/bvs9764/physionet.org/files/mimic-cxr-jpg/2.1.0/files'

# TODO: `xray_transforms` is not defined in any cell here, so a fresh kernel raised
# NameError. Re-create the exact preprocessing used during fine-tuning (if any) and
# assign it below; None means the ViLT processor's own resizing/normalization only.
eval_transforms = None

eval_dataset = MIMICCXRQA_Dataset(
    csv_path='/scratch/bvs9764/eval_processed.csv',
    data_dir=data_dir,
    processor=processor,
    transform=eval_transforms,
)

# Iterate the eval split in file order: no shuffling, so runs are reproducible and
# predictions line up with CSV rows.
eval_loader = DataLoader(eval_dataset, batch_size=4, shuffle=False)

In [None]:
# Run the fine-tuned model over eval_loader and report binary (0 = no, 1 = yes) metrics.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # explicit eval mode: dropout off, deterministic predictions

# Ground-truth and predicted class indices, accumulated across batches
true_labels = []
predicted_labels = []

with torch.no_grad():
    for batch in tqdm(eval_loader, desc="Evaluating"):
        inputs = {
            'input_ids': batch['input_ids'].to(device),
            'pixel_values': batch['pixel_values'].to(device),
            'attention_mask': batch['attention_mask'].to(device),
        }
        # Forward optional processor outputs when the batch carries them
        for key in ('token_type_ids', 'pixel_mask'):
            if key in batch:
                inputs[key] = batch[key].to(device)

        # The dataset stores targets under 'labels'; they stay on CPU for the metrics.
        labels = batch['labels']

        # Only logits are needed; targets are not passed to the model.
        logits = model(**inputs).logits
        preds = torch.argmax(logits, dim=-1)

        true_labels.extend(labels.tolist())
        predicted_labels.extend(preds.cpu().tolist())

# NOTE(review): argmax indexes the model's answer head; the binary metrics below assume
# index 0 means "no" and 1 means "yes" — confirm against model.config.id2label.
accuracy = accuracy_score(true_labels, predicted_labels)
precision = precision_score(true_labels, predicted_labels, average='binary', zero_division=0)
recall = recall_score(true_labels, predicted_labels, average='binary', zero_division=0)
f1 = f1_score(true_labels, predicted_labels, average='binary', zero_division=0)

print("\nEvaluation Metrics:")
print(f"Accuracy:  {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print(f"F1 Score:  {f1:.4f}")

print("\nClassification Report:")
# labels=[0, 1] keeps target_names aligned even if a class is absent from this split
print(classification_report(true_labels, predicted_labels, labels=[0, 1],
                            target_names=['No', 'Yes'], zero_division=0))
