In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViltForQuestionAnswering, ViltProcessor, ViltConfig, AdamW
from PIL import Image
import pandas as pd
from tqdm import tqdm
import os
import pandas as pd
import json

In [2]:
config = ViltConfig.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

In [3]:
f = open('/scratch/bvs9764/physionet.org/files/mimic-ext-mimic-cxr-vqa/1.0.0/MIMIC-Ext-MIMIC-CXR-VQA/dataset/train.json')

# Return JSON object as dictionary
data_questions = json.load(f)
print(data_questions[0])

{'split': 'train', 'idx': 0, 'subject_id': '17945608', 'study_id': '55914880', 'image_id': 'e557790f-48a1ede1-2d7b2605-5f6c34c1-6713a5c0', 'image_path': 'p17/p17945608/s55914880/e557790f-48a1ede1-2d7b2605-5f6c34c1-6713a5c0.jpg', 'question': 'Is there any occurrence of anatomical findings in the left hilar structures?', 'semantic_type': 'verify', 'content_type': 'presence', 'template': 'Is there any occurrence of ${category} in the ${object}?', 'template_program': 'program_1', 'template_arguments': {'object': {'0': 'left hilar structures'}, 'attribute': {}, 'category': {'0': 'anatomicalfinding'}, 'viewpos': {}, 'gender': {}}, 'answer': ['yes']}


In [4]:
import json
import pandas as pd

# Load the JSON file
with open('/scratch/bvs9764/physionet.org/files/mimic-ext-mimic-cxr-vqa/1.0.0/MIMIC-Ext-MIMIC-CXR-VQA/dataset/train.json', 'r') as f:
    data = json.load(f)

# Extract relevant fields
processed_data = []
for record in data:
    processed_data.append({
        'image_path': record['image_path'],
        'question': record['question'],
        'answer': 1 if 'yes' in record['answer'] else 0  # Convert 'yes'/'no' to 1/0
    })

# Convert to DataFrame and save as CSV
df = pd.DataFrame(processed_data)
#df.to_csv('train_processed.csv', index=False)


In [5]:
len(df)

290031

In [6]:
from torchvision import transforms

xray_transforms = transforms.Compose([
    # Resize to match ViLT model input size
    transforms.Resize((384, 384)),
    
    # Center crop to maintain focus on central anatomy
    transforms.CenterCrop(384),
    
    # Random horizontal flip for data augmentation (50% chance)
    transforms.RandomHorizontalFlip(p=0.5),
    
    # Random rotation within a small range (+/- 10 degrees)
    transforms.RandomRotation(degrees=10),
    
    # Adjust contrast to enhance visibility of structures
    transforms.ColorJitter(contrast=0.2),
    
    # Convert image to tensor and normalize
    transforms.ToTensor(),
])


In [9]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import ViltProcessor
from PIL import Image

class MIMICCXRQA_Dataset(Dataset):
    def __init__(self, csv_path, data_dir, processor, transform=None):
        self.data = pd.read_csv(csv_path)
        self.data_dir = data_dir
        self.processor = processor
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data.iloc[idx]['image_path']
        question = self.data.iloc[idx]['question']
        label = self.data.iloc[idx]['answer']  # Convert the string to 0 or 1

        # Load image
        full_img_path = os.path.join(self.data_dir, img_path)

        try:
                # Load image
            image = Image.open(full_img_path).convert("RGB")

                # Apply image transformations if any
            if self.transform:
                image = self.transform(image)

                # Process the image-question pair
            encoding = self.processor(
                    images=image,
                    text=question,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True
                )

                # Ensure the tensors are squeezed for batch loading
            encoding = {k: v.squeeze() for k, v in encoding.items()}
            encoding['labels'] = torch.tensor(label, dtype=torch.long)

            return encoding

        except (FileNotFoundError, UnidentifiedImageError) as e:
                print(f"Warning: Missing or corrupt file '{full_img_path}', skipping...")
                idx += 1  # Skip to the next data sample

        # If all subsequent images are missing, raise an IndexError
        raise IndexError(f"All images after index {idx} are missing or corrupt.")


In [10]:
# Path to the data directory where images are stored
data_dir = '/scratch/bvs9764/physionet.org/files/mimic-cxr-jpg/2.1.0/files'

processor = ViltProcessor.from_pretrained('dandelin/vilt-b32-mlm',do_rescale = False)
# Initialize dataset
train_dataset = MIMICCXRQA_Dataset(
    csv_path='/scratch/bvs9764/train_processed.csv',
    data_dir=data_dir,
    processor=processor,
    transform=xray_transforms
)

# Create DataLoader
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)


In [11]:
train_loader

<torch.utils.data.dataloader.DataLoader at 0x1514b8a26290>

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-mlm",num_labels = 2)
model.to(device)

Some weights of ViltForQuestionAnswering were not initialized from the model checkpoint at dandelin/vilt-b32-mlm and are newly initialized: ['classifier.0.bias', 'classifier.0.weight', 'classifier.1.bias', 'classifier.1.weight', 'classifier.3.bias', 'classifier.3.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViltForQuestionAnswering(
  (vilt): ViltModel(
    (embeddings): ViltEmbeddings(
      (text_embeddings): TextEmbeddings(
        (word_embeddings): Embedding(30522, 768)
        (position_embeddings): Embedding(40, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (patch_embeddings): ViltPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
      )
      (token_type_embeddings): Embedding(2, 768)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViltEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViltLayer(
          (attention): ViltAttention(
            (attention): ViltSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=76

In [None]:
from transformers import AdamW
from tqdm import tqdm
import torch

# Freeze ViLT transformer layers to speed up fine-tuning
for param in model.vilt.parameters():
    param.requires_grad = False

# Ensure the classification head is trainable
for param in model.classifier.parameters():
    param.requires_grad = True
    
# Optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# Training loop
model.train()
epochs = 3

for epoch in range(epochs):
    loop = tqdm(train_loader, leave=True)
    for batch in loop:
        try:
            # Move inputs to the device
            for k, v in batch.items():
                batch[k] = v.to(device)

            # Expand labels to match the logits shape
            labels = batch['labels']  # Assuming labels are in the batch
            batch_size = labels.shape[0]
            expanded_labels = torch.zeros(batch_size, 2, device=labels.device)
            expanded_labels[torch.arange(batch_size), labels] = 1

            # Forward pass
            outputs = model(input_ids=batch['input_ids'],
                            pixel_values=batch['pixel_values'],
                            attention_mask=batch['attention_mask'],
                            labels=expanded_labels)

            loss = outputs.loss

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update progress bar
            loop.set_description(f"Epoch {epoch + 1}")
            loop.set_postfix(loss=loss.item())

        except Exception as e:
            print(f"Error during training: {e}, skipping batch...")

Epoch 1:   2%|▏         | 1440/72508 [14:07<11:50:23,  1.67it/s, loss=0.437]

In [None]:
# Define the directory to save the model and processor
save_directory = "vilt_finetuned_vqa"

# Save the model
model.save_pretrained(save_directory)

# Save the processor
processor.save_pretrained(save_directory)

print(f"Model and processor saved to {save_directory}")