In [None]:
from dataset.dataset import ChestDataset
from torch.utils.data import DataLoader
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
import torch

In [2]:
dataset= ChestDataset(index_path='index_files/mimic_index_frontal_train.json')

In [3]:
len(dataset)

204098

In [4]:
sample= dataset[1]

In [5]:
sample

{'image': <PIL.Image.Image image mode=RGB size=3056x2544>,
 'text': 'FINDINGS:\nFrontal and lateral views of the chest.  The lungs are clear of consolidation,\n effusion, or pulmonary vascular congestion.  The cardiac silhouette slightly\n enlarged and the aorta is tortuous.  No acute osseous abnormality detected.\n\nIMPRESSION:\nNo acute cardiopulmonary process.  Note evidence of congestive failure.'}

In [6]:
# System prompt read by the notebook-level `format_data` helper (two paragraphs separated by a blank line).
# NOTE(review): `QwenDataCollator` keeps its own copy of this prompt in which the two paragraphs are
# joined by a single space instead of a blank line — confirm which variant is intended for training.
system_message = """You are an expert radiologist trained in interpreting chest X-rays. Given one or two radiographic images of the chest (frontal and/or lateral views), generate a detailed and clinically accurate radiology report. The report should include a structured summary of the observed findings and, if possible, an impression that highlights key diagnoses or abnormalities. Do not speculate beyond the image content. If no abnormality is present, state "No acute cardiopulmonary process."

Use professional radiological language, and focus on anatomical structures such as lungs, heart, mediastinum, pleura, bones, and devices (e.g., lines, tubes). Maintain a neutral, factual tone suitable for inclusion in a patient medical record."""

In [7]:
print(system_message)

You are an expert radiologist trained in interpreting chest X-rays. Given one or two radiographic images of the chest (frontal and/or lateral views), generate a detailed and clinically accurate radiology report. The report should include a structured summary of the observed findings and, if possible, an impression that highlights key diagnoses or abnormalities. Do not speculate beyond the image content. If no abnormality is present, state "No acute cardiopulmonary process."

Use professional radiological language, and focus on anatomical structures such as lungs, heart, mediastinum, pleura, bones, and devices (e.g., lines, tubes). Maintain a neutral, factual tone suitable for inclusion in a patient medical record.


In [8]:
def format_data(sample):
    """Wrap one dataset sample in a three-turn chat conversation.

    Args:
        sample: Mapping with an ``"image"`` entry and a ``"text"`` entry.

    Returns:
        A list of three message dicts, in order: a system turn carrying the
        notebook-level ``system_message``, a user turn carrying the image,
        and an assistant turn carrying the text.
    """
    system_turn = {
        "role": "system",
        "content": [{"type": "text", "text": system_message}],
    }
    user_turn = {
        "role": "user",
        "content": [{"type": "image", "image": sample["image"]}],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": sample["text"]}],
    }
    return [system_turn, user_turn, assistant_turn]
    

In [9]:
format_data(sample)

[{'role': 'system',
  'content': [{'type': 'text',
    'text': 'You are an expert radiologist trained in interpreting chest X-rays. Given one or two radiographic images of the chest (frontal and/or lateral views), generate a detailed and clinically accurate radiology report. The report should include a structured summary of the observed findings and, if possible, an impression that highlights key diagnoses or abnormalities. Do not speculate beyond the image content. If no abnormality is present, state "No acute cardiopulmonary process."\n\nUse professional radiological language, and focus on anatomical structures such as lungs, heart, mediastinum, pleura, bones, and devices (e.g., lines, tubes). Maintain a neutral, factual tone suitable for inclusion in a patient medical record.'}]},
 {'role': 'user',
  'content': [{'type': 'image',
    'image': <PIL.Image.Image image mode=RGB size=3056x2544>}]},
 {'role': 'assistant',
  'content': [{'type': 'text',
    'text': 'FINDINGS:\nFrontal and 

In [10]:
class QwenDataCollator:
    """Collate ``{"image", "text"}`` samples into a supervised fine-tuning batch.

    Each sample is wrapped in a system / user(image) / assistant(text)
    conversation, rendered with the processor's chat template, processed
    together with its image, and given a ``labels`` tensor in which padding
    and image-placeholder positions are replaced by ``IGNORE_INDEX`` so they
    are excluded from the loss.

    Note: the system and user turns are NOT masked; only padding and image
    tokens are. The labels therefore cover the whole rendered conversation.
    """

    # Label value used for positions that must not contribute to the loss.
    IGNORE_INDEX = -100

    # Token IDs masked in the labels when the processor is a Qwen2VLProcessor.
    # NOTE(review): presumably <|vision_start|>, <|vision_end|> and <|image_pad|>
    # of the Qwen2-VL tokenizer — verify against the tokenizer in use.
    QWEN2_VL_IMAGE_TOKEN_IDS = (151652, 151653, 151655)

    # Default system prompt (single paragraph, sentences joined by spaces).
    DEFAULT_SYSTEM_MESSAGE = """You are an expert radiologist trained in interpreting chest X-rays. Given one or two radiographic images of the chest (frontal and/or lateral views), generate a detailed and clinically accurate radiology report. The report should include a structured summary of the observed findings and, if possible, an impression that highlights key diagnoses or abnormalities. Do not speculate beyond the image content. If no abnormality is present, state "No acute cardiopulmonary process." Use professional radiological language, and focus on anatomical structures such as lungs, heart, mediastinum, pleura, bones, and devices (e.g., lines, tubes). Maintain a neutral, factual tone suitable for inclusion in a patient medical record."""

    def __init__(self, processor, max_length=512, system_message=None):
        """Store the processor and prompt configuration.

        Args:
            processor: Processor object providing ``apply_chat_template``,
                ``__call__`` and a ``tokenizer`` attribute.
            max_length: Stored on the instance but not currently applied
                during tokenization (no truncation is performed, since
                truncating could cut image placeholder tokens).
            system_message: Optional system prompt. When ``None`` the
                built-in ``DEFAULT_SYSTEM_MESSAGE`` is used, which keeps the
                previous behaviour.
        """
        self.processor = processor
        self.max_length = max_length
        self.system_message = (
            self.DEFAULT_SYSTEM_MESSAGE if system_message is None else system_message
        )

    def format_data(self, sample):
        """Return the system / user(image) / assistant(text) conversation for ``sample``."""
        return [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.system_message}],
            },
            {
                "role": "user",
                "content": [{"type": "image", "image": sample["image"]}],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": sample["text"]}],
            },
        ]

    def _image_token_ids(self):
        """Return the token IDs that stand in for image content and must be masked."""
        if isinstance(self.processor, Qwen2VLProcessor):
            return list(self.QWEN2_VL_IMAGE_TOKEN_IDS)
        # Generic processors: mask the single image placeholder token they expose.
        return [self.processor.tokenizer.convert_tokens_to_ids(self.processor.image_token)]

    @staticmethod
    def _mask_labels(input_ids, attention_mask, pad_token_id, image_token_ids, ignore_index=-100):
        """Build the labels tensor from ``input_ids``.

        Padding is identified through ``attention_mask`` when it is available,
        so a real token that happens to share the pad ID (e.g. when the pad
        token is also the end-of-sequence token) keeps its label. When no
        attention mask is given, the pad-ID comparison is used as a fallback,
        and skipped entirely if ``pad_token_id`` is ``None``.

        Args:
            input_ids: Integer tensor of token IDs; it is not modified.
            attention_mask: Tensor of the same shape with 0 at padded
                positions, or ``None``.
            pad_token_id: Pad token ID, or ``None``.
            image_token_ids: Iterable of token IDs to exclude from the loss.
            ignore_index: Value written at excluded positions.

        Returns:
            A new tensor shaped like ``input_ids``.
        """
        labels = input_ids.clone()
        if attention_mask is not None:
            labels[attention_mask == 0] = ignore_index
        elif pad_token_id is not None:
            labels[labels == pad_token_id] = ignore_index
        for image_token_id in image_token_ids:
            labels[labels == image_token_id] = ignore_index
        return labels

    def __call__(self, examples):
        """Turn a list of raw samples into a model-ready batch with ``labels``.

        Args:
            examples: List of mappings, each with ``"image"`` and ``"text"``.

        Returns:
            The processor output with an added ``"labels"`` entry.
        """
        conversations = [self.format_data(sample) for sample in examples]
        texts = [
            self.processor.apply_chat_template(conversation, tokenize=False)
            for conversation in conversations
        ]
        # process_vision_info returns a tuple; element 0 holds the image inputs.
        image_inputs = [process_vision_info(conversation)[0] for conversation in conversations]

        # Tokenize the texts and process the images.
        batch = self.processor(
            text=texts, images=image_inputs, return_tensors="pt", padding=True
        )

        attention_mask = batch["attention_mask"] if "attention_mask" in batch else None
        batch["labels"] = self._mask_labels(
            batch["input_ids"],
            attention_mask,
            self.processor.tokenizer.pad_token_id,
            self._image_token_ids(),
            self.IGNORE_INDEX,
        )
        return batch
       

In [11]:
model_id = "Qwen/Qwen2-VL-7B-Instruct"

In [12]:
# NOTE: model loading is disabled in this cell, so nothing in the notebook as shown defines `model`.
# Re-enable the block below (or load the model elsewhere) before calling
# `generate_text_from_sample(model, processor, ...)` further down.
# model = Qwen2VLForConditionalGeneration.from_pretrained(
#     model_id,
#     device_map="auto",
#     torch_dtype=torch.bfloat16,
# )

# Load the processor for `model_id`.
processor = Qwen2VLProcessor.from_pretrained(model_id, use_fast=True)
# Configure the tokenizer to pad on the right-hand side of each sequence.
processor.tokenizer.padding_side = "right"

In [13]:
processor.image_processor

Qwen2VLImageProcessorFast {
  "crop_size": null,
  "data_format": "channels_first",
  "default_to_square": true,
  "device": null,
  "do_center_crop": null,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "Qwen2VLImageProcessorFast",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "input_data_format": null,
  "max_pixels": 12845056,
  "merge_size": 2,
  "min_pixels": 3136,
  "patch_size": 14,
  "processor_class": "Qwen2VLProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "return_tensors": null,
  "size": {
    "longest_edge": 12845056,
    "shortest_edge": 3136
  },
  "temporal_patch_size": 2
}

In [15]:
collator = QwenDataCollator(processor)

In [16]:
raw_batch = [dataset[i] for i in range(2)]

In [17]:
raw_batch

[{'image': <PIL.Image.Image image mode=RGB size=2544x3056>,
  'text': 'FINDINGS:\nFrontal and lateral views of the chest were obtained.  Lungs are\n clear without focal consolidation.  No pleural effusion or pneumothorax is\n seen.  Cardiac and mediastinal silhouettes are unremarkable.\n\nIMPRESSION:\nNo acute cardiopulmonary process.'},
 {'image': <PIL.Image.Image image mode=RGB size=3056x2544>,
  'text': 'FINDINGS:\nFrontal and lateral views of the chest.  The lungs are clear of consolidation,\n effusion, or pulmonary vascular congestion.  The cardiac silhouette slightly\n enlarged and the aorta is tortuous.  No acute osseous abnormality detected.\n\nIMPRESSION:\nNo acute cardiopulmonary process.  Note evidence of congestive failure.'}]

In [18]:
batch= collator(raw_batch)

In [19]:
batch['input_ids'].shape

torch.Size([2, 10153])

In [20]:
processor.tokenizer.decode(batch['input_ids'][0], skip_special_tokens=True)

'system\nYou are an expert radiologist trained in interpreting chest X-rays. Given one or two radiographic images of the chest (frontal and/or lateral views), generate a detailed and clinically accurate radiology report. The report should include a structured summary of the observed findings and, if possible, an impression that highlights key diagnoses or abnormalities. Do not speculate beyond the image content. If no abnormality is present, state "No acute cardiopulmonary process." Use professional radiological language, and focus on anatomical structures such as lungs, heart, mediastinum, pleura, bones, and devices (e.g., lines, tubes). Maintain a neutral, factual tone suitable for inclusion in a patient medical record.\nuser\n\nassistant\nFINDINGS:\nFrontal and lateral views of the chest were obtained.  Lungs are\n clear without focal consolidation.  No pleural effusion or pneumothorax is\n seen.  Cardiac and mediastinal silhouettes are unremarkable.\n\nIMPRESSION:\nNo acute cardiop

In [15]:
from qwen_vl_utils import process_vision_info


def generate_text_from_sample(
    model,
    processor,
    sample,
    max_new_tokens=1024,
    device="cpu",
    include_system_message=False,
):
    """Generate the assistant reply for one chat-formatted sample.

    The prompt consists of the turns that precede the first assistant turn.
    Selecting by role (instead of the fixed slice ``sample[1:2]``) keeps the
    reference answer out of the prompt even when the conversation has no
    leading system message.

    Args:
        model: Object exposing ``generate(**inputs, max_new_tokens=...)``.
        processor: Object exposing ``apply_chat_template``, ``__call__`` and
            ``batch_decode``.
        sample: List of message dicts, each with a ``"role"`` key, e.g. the
            output of ``format_data``.
        max_new_tokens: Upper bound on the number of generated tokens.
        device: Device the processed inputs are moved to before generation.
        include_system_message: When ``True`` the system turn(s) are kept in
            the prompt. Defaults to ``False``, which drops them as before.

    Returns:
        The decoded text of the newly generated tokens for the sample.

    Raises:
        ValueError: If no prompt turn precedes the first assistant turn.
    """
    # Collect the prompt: stop at the first assistant turn so the target text is never fed in.
    prompt_messages = []
    for message in sample:
        role = message.get("role")
        if role == "assistant":
            break
        if role == "system" and not include_system_message:
            continue
        prompt_messages.append(message)
    if not prompt_messages:
        raise ValueError("sample contains no prompt message before the assistant turn")

    # Prepare the text input by applying the chat template.
    text_input = processor.apply_chat_template(
        prompt_messages, tokenize=False, add_generation_prompt=True
    )

    # Extract vision inputs from the same messages that produced the text,
    # so the images line up with the placeholders in the prompt.
    image_inputs, _ = process_vision_info(prompt_messages)

    # Prepare the inputs for the model and move them to the requested device.
    model_inputs = processor(
        text=[text_input],
        images=image_inputs,
        return_tensors="pt",
    ).to(device)

    # Generate text with the model.
    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Drop the prompt tokens so that only the newly generated tokens are decoded.
    trimmed_generated_ids = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    # Decode the output text.
    output_text = processor.batch_decode(
        trimmed_generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    return output_text[0]  # Single-sample batch: return its decoded text

In [47]:
# Build a chat-formatted example (system, user image, assistant report) from the first dataset
# item using the notebook-level `format_data` function.
# Note: this rebinds `sample`, which earlier held a raw dataset item (a dict), to a list of messages.
sample = format_data(dataset[0])

In [48]:
sample

[{'role': 'system',
  'content': [{'type': 'text',
    'text': 'You are an expert radiologist trained in interpreting chest X-rays. Given one or two radiographic images of the chest (frontal and/or lateral views), generate a detailed and clinically accurate radiology report. The report should include a structured summary of the observed findings and, if possible, an impression that highlights key diagnoses or abnormalities. Do not speculate beyond the image content. If no abnormality is present, state "No acute cardiopulmonary process."\n\nUse professional radiological language, and focus on anatomical structures such as lungs, heart, mediastinum, pleura, bones, and devices (e.g., lines, tubes). Maintain a neutral, factual tone suitable for inclusion in a patient medical record.'}]},
 {'role': 'user',
  'content': [{'type': 'image',
    'image': <PIL.Image.Image image mode=RGB size=2544x3056>}]},
 {'role': 'assistant',
  'content': [{'type': 'text',
    'text': 'FINDINGS:\nFrontal and 

In [49]:
output = generate_text_from_sample(model, processor, sample)

RuntimeError: unable to mmap 3898649896 bytes from file </home/arpanp/.cache/huggingface/hub/models--Qwen--Qwen2-VL-7B-Instruct/snapshots/eed13092ef92e448dd6875b2a00151bd3f7db0ac/model-00001-of-00005.safetensors>: Cannot allocate memory (12)

In [50]:
process_vision_info(sample)

([<PIL.Image.Image image mode=RGB size=2548x3052>], None)