In [None]:
import os
import json
from typing import List, Dict, Optional
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from peft import get_peft_model, LoraConfig, TaskType
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from tqdm.notebook import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.cuda.get_device_name raises for a non-CUDA device, so only query it on GPU;
# otherwise the CPU fallback above would crash right here.
device_label = torch.cuda.get_device_name(device) if device.type == "cuda" else "CPU"
print(f"Using device: {device}, {device_label}")

In [None]:
# Load pre-trained LLaVA-Next model and processor
model_name = "llava-hf/llava-v1.6-mistral-7b-hf"
processor = LlavaNextProcessor.from_pretrained(model_name)

# Use half precision only on GPU: many float16 ops are slow or unimplemented on
# CPU, so fall back to float32 when no CUDA device is available.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32

# Load and configure the model
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=model_dtype,
    low_cpu_mem_usage=True,
).to(device)

# Configure LoRA
# NOTE(review): target_modules are matched by module-name suffix, so "q_proj" /
# "v_proj" may also select the vision tower's attention projections, not only the
# language model's — confirm this is intended (inspect model.named_modules()).
lora_config = LoraConfig(
    r=16,  # rank
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

# Wrap the model with LoRA
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
class XRayReportDataset(Dataset):
    """Dataset pairing one X-ray image per patient with region-wise report text.

    The annotation JSON is keyed by split name. Each entry in a split must
    provide ``'id'``, which is also the name of the patient's image sub-folder
    under ``data_dir``. It must also provide ``'report'``, a mapping from region
    name to text. Items are returned as raw PIL images and strings; tokenization
    happens later in the pipeline.
    """

    def __init__(self,
                 data_dir: str,
                 annotation_file: str,
                 split: str = 'train',
                 processor: Optional[object] = None):
        """
        Args:
            data_dir: root folder holding one sub-folder of PNG images per patient.
            annotation_file: path to the JSON annotation file.
            split: top-level key of the split to load (e.g. 'train', 'val', 'test').
            processor: optional processor stored as ``self.processor``. When
                omitted, falls back to a module-level ``processor`` if one is
                defined, else ``None``. Not used by this class itself.
        """
        self.data_dir = data_dir
        self.split = split
        # Accept the processor explicitly so constructing the dataset does not
        # depend on cell execution order. The global fallback keeps old callers working.
        self.processor = processor if processor is not None else globals().get('processor')

        # Load annotations for the requested split only.
        with open(annotation_file, 'r') as f:
            self.annotations = json.load(f)[split]

        # Anatomical regions extracted from each report, in a fixed order.
        self.regions = ['lung', 'heart', 'mediastinal', 'bone']

    def __len__(self) -> int:
        """Number of annotated patients in this split."""
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Dict[str, object]:
        """Return the patient's first PNG (by sorted filename) and its region reports.

        Returns:
            dict with ``'image'`` (RGB PIL image), ``'reports'`` (region -> text;
            missing regions map to ``""``) and ``'patient_id'``.

        Raises:
            FileNotFoundError: if the patient folder is missing or has no ``.png`` files.
        """
        annotation = self.annotations[idx]
        patient_id = annotation['id']
        report = annotation['report']

        # Load the first image (in sorted filename order) for this patient.
        image_folder = os.path.join(self.data_dir, str(patient_id))
        image_files = sorted(f for f in os.listdir(image_folder) if f.endswith('.png'))
        if not image_files:
            raise FileNotFoundError(
                f"No .png images found for patient {patient_id!r} in {image_folder}"
            )
        image_path = os.path.join(image_folder, image_files[0])
        # Close the file handle deterministically; convert() returns a loaded copy
        # that stays valid after the source file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        # Prepare report text for each anatomical region.
        region_reports = {region: report.get(region, "") for region in self.regions}

        return {
            'image': image,
            'reports': region_reports,
            'patient_id': patient_id
        }

    @staticmethod
    def collate_fn(batch: List[Dict]) -> Dict[str, list]:
        """Group a list of items into parallel lists.

        Images stay as PIL objects because they are preprocessed later.
        """
        images = [item['image'] for item in batch]
        reports = [item['reports'] for item in batch]
        patient_ids = [item['patient_id'] for item in batch]

        return {
            'images': images,
            'reports': reports,
            'patient_ids': patient_ids
        }


In [None]:
# Create datasets and dataloaders
data_dir = 'data/images'
annotation_file = 'data/annotation_quiz_all_with_val.json'
BATCH_SIZE = 8
SEED = 42  # fixes the training shuffle order so runs are reproducible

train_dataset = XRayReportDataset(data_dir, annotation_file, split='train')
val_dataset = XRayReportDataset(data_dir, annotation_file, split='val')
test_dataset = XRayReportDataset(data_dir, annotation_file, split='test')

# A dedicated seeded generator makes shuffling deterministic without touching
# torch's global RNG state.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=XRayReportDataset.collate_fn,
    generator=torch.Generator().manual_seed(SEED),
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=XRayReportDataset.collate_fn)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=XRayReportDataset.collate_fn)

In [None]:
# Prepare inputs function
def prepare_inputs(batch, processor):
    """Turn a collated batch into model-ready tensors with training labels.

    Each sample's region reports are rendered as a user/assistant chat through
    the processor's chat template. The rendered text is then tokenized together
    with the sample's image.

    Args:
        batch: dict as produced by ``XRayReportDataset.collate_fn``. It has
            ``'images'`` (list of images) and ``'reports'`` (list of dicts with
            keys ``'lung'``, ``'heart'``, ``'mediastinal'``, ``'bone'``).
        processor: a ``LlavaNextProcessor`` (or compatible) instance.

    Returns:
        dict of tensors moved to the notebook-level ``device``. It contains
        everything the processor returns plus ``'labels'``: a copy of
        ``input_ids`` with padding and image-placeholder tokens set to -100 so
        the loss ignores them.
    """
    images = batch['images']

    prompts = []
    for report in batch['reports']:
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "Provide a detailed X-ray report for the following anatomical regions: lung, heart, mediastinal, and bone."},
                ],
            },
            {
                "role": "assistant",
                # Chat templates iterate over typed content parts, so the answer
                # must be a list of parts rather than a bare string.
                "content": [
                    {"type": "text", "text": f"Here's a detailed X-ray report for the anatomical regions:\n\nLung: {report['lung']}\n\nHeart: {report['heart']}\n\nMediastinal: {report['mediastinal']}\n\nBone: {report['bone']}"},
                ],
            },
        ]
        # Render the full exchange as text; no generation prompt because the
        # assistant answer is the training target.
        prompts.append(
            processor.apply_chat_template(conversation, add_generation_prompt=False, tokenize=False)
        )

    # The processor takes rendered prompts via `text=`; it has no
    # `conversations` argument and does not produce labels.
    inputs = processor(images=images, text=prompts, return_tensors="pt", padding=True)
    inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}

    # Causal-LM labels are the input ids themselves; mask positions that should
    # not contribute to the loss.
    labels = inputs['input_ids'].clone()
    tokenizer = processor.tokenizer
    if tokenizer.pad_token_id is not None:
        labels[labels == tokenizer.pad_token_id] = -100
    image_token_id = tokenizer.convert_tokens_to_ids(getattr(processor, "image_token", "<image>"))
    # Skip masking if the image token is unknown (it would resolve to unk).
    if image_token_id is not None and image_token_id != tokenizer.unk_token_id:
        labels[labels == image_token_id] = -100
    inputs['labels'] = labels

    return inputs