In [1]:
from dataset.dataset import ChestDataset
from torch.utils.data import DataLoader
from qwen_vl_utils import process_vision_info
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
import torch

In [2]:
# Training split built from the frontal-view MIMIC index. Items are dicts with an
# "image" (PIL image) and a "text" (report) field, as the sample output further down shows.
dataset= ChestDataset(index_path='index_files/mimic_index_frontal_train.json')

In [3]:
# Base checkpoint identifier, loaded with from_pretrained in later cells.
model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
# Per-image pixel-count bounds handed to the processor in the next cell
# (presumably the area range images are resized into -- confirm in the Qwen2.5-VL docs).
min_pixels, max_pixels = 224 * 224, 512 * 512

In [5]:
# Processor bundling the tokenizer (processor.tokenizer) and the image handling used by
# the collator; min/max_pixels presumably bound the resized image area -- confirm.
processor = Qwen2_5_VLProcessor.from_pretrained(model_id, use_fast=True, min_pixels=min_pixels, max_pixels=max_pixels)


In [6]:
# Right padding puts pad tokens after each sequence, i.e. after the assistant's closing
# tokens, so they fall outside the labelled span and keep label -100 in the collator.
processor.tokenizer.padding_side='right'

In [7]:
class QwenDataCollator:
    """Collate ``{"image", "text"}`` samples into a Qwen2.5-VL training batch.

    Each sample becomes a system / user(image) / assistant(report) chat. The
    processor tokenizes the rendered chats together with the images, and
    ``labels`` copy ``input_ids`` only inside assistant turns; every other
    position (system prompt, image pads, padding) is -100.

    NOTE: QwenDataCollator and find_assistant_content_sublist_indexes are
    defined again in a later cell (In [10]), which shadows this cell's
    definitions; consider deleting one copy.
    """

    def __init__(self, processor, max_length=512):
        """
        Args:
            processor: Qwen2.5-VL processor (chat template, tokenizer, images).
            max_length: stored on the instance but NOT applied -- no truncation
                happens in this collator. NOTE(review): naive truncation could
                cut image-pad tokens or the assistant's end marker; choose a
                strategy before using this.
        """
        self.processor = processor
        self.max_length = max_length
        self.system_message ="""You are an expert radiologist trained in interpreting chest X-rays. Given radiographic image of the chest (frontal view), generate a detailed and clinically accurate radiology report. The report should include a summary of the observed findings and, if possible, an impression that highlights key diagnoses or abnormalities. Do not speculate beyond the image content. Use professional radiological language. Maintain a neutral, factual tone suitable for inclusion in a patient medical record. Do NOT reference previous imaging or mention paging physicians. Instead, describe the chest x-ray in a stand-alone format."""

    def format_data(self, sample):
        """Turn one ``{"image", "text"}`` sample into a system/user/assistant chat.

        The user turn carries only the image; the assistant turn is the
        reference report, i.e. the training target.
        """
        return [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.system_message}],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": sample["image"],
                    }
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": sample["text"]},
                ],
            },
        ]

    def __call__(self, examples):
        """Build processor outputs plus assistant-only ``labels`` for a list of samples.

        Args:
            examples: list of ``{"image", "text"}`` dicts.

        Returns:
            The processor's batch with an added ``labels`` int64 tensor of the
            same shape as ``input_ids``.

        Raises:
            ValueError: if a sample's token ids contain no assistant span; its
                labels would be entirely -100, so it would contribute nothing
                to the loss.
        """
        # Keep the raw samples and their chat renderings under separate names.
        conversations = [self.format_data(sample) for sample in examples]
        texts = [
            self.processor.apply_chat_template(conversation, tokenize=False)
            for conversation in conversations
        ]
        # process_vision_info returns (image_inputs, video_inputs); only images are used.
        image_inputs = [process_vision_info(conversation)[0] for conversation in conversations]
        # Tokenize the texts and process the images in one call.
        batch = self.processor(
            text=texts, images=image_inputs, return_tensors="pt", padding=True
        )

        # Supervise only assistant content: copy ids inside assistant spans, -100 elsewhere.
        input_ids_lists = batch["input_ids"].tolist()
        assert len(conversations) == len(input_ids_lists)
        labels_list = []
        for row, ids_list in enumerate(input_ids_lists):
            spans = find_assistant_content_sublist_indexes(ids_list)
            if not spans:
                raise ValueError(
                    f"sample {row}: no assistant span found in input_ids; "
                    "its labels would be all -100"
                )
            label_ids = [-100] * len(ids_list)
            for begin, end in spans:
                label_ids[begin:end] = ids_list[begin:end]
            labels_list.append(label_ids)

        batch["labels"] = torch.tensor(labels_list, dtype=torch.int64)
        return batch
 
def find_assistant_content_sublist_indexes(l, start_seq=(151644, 77091, 198), end_seq=(151645, 198)):
    '''Locate assistant-turn content inside a token-id list, for building labels.

    A span begins right after each occurrence of ``start_seq`` and ends right
    after the first following occurrence of ``end_seq``. The end marker stays
    inside the span so the model learns to emit it and stop.

    The defaults are the Qwen2.5 ids recorded for these strings:
        processor.tokenizer.encode("<|im_start|>assistant\\n") -> [151644, 77091, 198]
        processor.tokenizer.encode("<|im_end|>\\n")            -> [151645, 198]
    Pass other sequences when using a tokenizer with different ids.

    Args:
        l: sequence of token ids (e.g. one row of ``input_ids`` as a list).
        start_seq: token ids that open an assistant turn.
        end_seq: token ids that close an assistant turn.

    Returns:
        List of half-open ``(begin, end)`` index pairs usable as ``l[begin:end]``.
        A start with no closing ``end_seq`` after it yields no pair.

    Raises:
        ValueError: if ``start_seq`` or ``end_seq`` is empty.
    '''
    start_seq = list(start_seq)
    end_seq = list(end_seq)
    if not start_seq or not end_seq:
        raise ValueError("start_seq and end_seq must be non-empty")
    ids = list(l)
    n_start, n_end = len(start_seq), len(end_seq)

    spans = []
    for i in range(len(ids) - n_start + 1):
        if ids[i:i + n_start] != start_seq:
            continue
        begin = i + n_start
        # First closing marker after this start ends the span.
        for j in range(begin, len(ids) - n_end + 1):
            if ids[j:j + n_end] == end_seq:
                spans.append((begin, j + n_end))
                break
    return spans

In [8]:
# Two raw samples for smoke-testing the collator below.
batch = [dataset[idx] for idx in (0, 1)]

In [9]:
# Inspect the two raw samples (image + report text).
batch

[{'image': <PIL.Image.Image image mode=RGB size=2539x3050>,
  'text': 'FINDINGS:\nNone\n\nIMPRESSION:\nThere is less bilateral pleural effusion, predominantly on the left. The cardiac\nsilhouette and bilateral basal parenchymal opacities appear stable. No evidence\nof pneumonia or pulmonary edema.'},
 {'image': <PIL.Image.Image image mode=RGB size=2022x2022>,
  'text': 'FINDINGS:\nNone\n\nIMPRESSION:\nThe dual-channel ICD device with leads in the right atrium and apex of the right ventricle is visible without any apparent abnormalities. The cardiac silhouette is within normal limits, and there is no evidence of vascular congestion, pleural effusion, or acute focal pneumonia.'}]

In [10]:
class QwenDataCollator:
    """Collate ``{"image", "text"}`` samples into a Qwen2.5-VL training batch.

    Each sample becomes a system / user(image) / assistant(report) chat. The
    processor tokenizes the rendered chats together with the images, and
    ``labels`` copy ``input_ids`` only inside assistant turns; every other
    position (system prompt, image pads, padding) is -100.

    NOTE: this cell re-defines QwenDataCollator and
    find_assistant_content_sublist_indexes from an earlier cell (In [7]).
    Because it runs later, these are the definitions ``collate_fn`` (In [11])
    actually uses; consider deleting one copy.
    """

    def __init__(self, processor, max_length=512):
        """
        Args:
            processor: Qwen2.5-VL processor (chat template, tokenizer, images).
            max_length: stored on the instance but NOT applied -- no truncation
                happens in this collator. NOTE(review): naive truncation could
                cut image-pad tokens or the assistant's end marker; choose a
                strategy before using this.
        """
        self.processor = processor
        self.max_length = max_length
        self.system_message ="""You are an expert radiologist trained in interpreting chest X-rays. Given radiographic image of the chest (frontal view), generate a detailed and clinically accurate radiology report. The report should include a summary of the observed findings and, if possible, an impression that highlights key diagnoses or abnormalities. Do not speculate beyond the image content. Use professional radiological language. Maintain a neutral, factual tone suitable for inclusion in a patient medical record. Do NOT reference previous imaging or mention paging physicians. Instead, describe the chest x-ray in a stand-alone format."""

    def format_data(self, sample):
        """Turn one ``{"image", "text"}`` sample into a system/user/assistant chat.

        The user turn carries only the image; the assistant turn is the
        reference report, i.e. the training target.
        """
        return [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.system_message}],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": sample["image"],
                    }
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": sample["text"]},
                ],
            },
        ]

    def __call__(self, examples):
        """Build processor outputs plus assistant-only ``labels`` for a list of samples.

        Args:
            examples: list of ``{"image", "text"}`` dicts.

        Returns:
            The processor's batch with an added ``labels`` int64 tensor of the
            same shape as ``input_ids``.

        Raises:
            ValueError: if a sample's token ids contain no assistant span; its
                labels would be entirely -100, so it would contribute nothing
                to the loss.
        """
        # Keep the raw samples and their chat renderings under separate names.
        conversations = [self.format_data(sample) for sample in examples]
        texts = [
            self.processor.apply_chat_template(conversation, tokenize=False)
            for conversation in conversations
        ]
        # process_vision_info returns (image_inputs, video_inputs); only images are used.
        image_inputs = [process_vision_info(conversation)[0] for conversation in conversations]
        # Tokenize the texts and process the images in one call.
        batch = self.processor(
            text=texts, images=image_inputs, return_tensors="pt", padding=True
        )

        # Supervise only assistant content: copy ids inside assistant spans, -100 elsewhere.
        input_ids_lists = batch["input_ids"].tolist()
        assert len(conversations) == len(input_ids_lists)
        labels_list = []
        for row, ids_list in enumerate(input_ids_lists):
            spans = find_assistant_content_sublist_indexes(ids_list)
            if not spans:
                raise ValueError(
                    f"sample {row}: no assistant span found in input_ids; "
                    "its labels would be all -100"
                )
            label_ids = [-100] * len(ids_list)
            for begin, end in spans:
                label_ids[begin:end] = ids_list[begin:end]
            labels_list.append(label_ids)

        batch["labels"] = torch.tensor(labels_list, dtype=torch.int64)
        return batch
 
def find_assistant_content_sublist_indexes(l, start_seq=(151644, 77091, 198), end_seq=(151645, 198)):
    '''Locate assistant-turn content inside a token-id list, for building labels.

    A span begins right after each occurrence of ``start_seq`` and ends right
    after the first following occurrence of ``end_seq``. The end marker stays
    inside the span so the model learns to emit it and stop.

    The defaults are the Qwen2.5 ids recorded for these strings:
        processor.tokenizer.encode("<|im_start|>assistant\\n") -> [151644, 77091, 198]
        processor.tokenizer.encode("<|im_end|>\\n")            -> [151645, 198]
    Pass other sequences when using a tokenizer with different ids.

    Args:
        l: sequence of token ids (e.g. one row of ``input_ids`` as a list).
        start_seq: token ids that open an assistant turn.
        end_seq: token ids that close an assistant turn.

    Returns:
        List of half-open ``(begin, end)`` index pairs usable as ``l[begin:end]``.
        A start with no closing ``end_seq`` after it yields no pair.

    Raises:
        ValueError: if ``start_seq`` or ``end_seq`` is empty.
    '''
    start_seq = list(start_seq)
    end_seq = list(end_seq)
    if not start_seq or not end_seq:
        raise ValueError("start_seq and end_seq must be non-empty")
    ids = list(l)
    n_start, n_end = len(start_seq), len(end_seq)

    spans = []
    for i in range(len(ids) - n_start + 1):
        if ids[i:i + n_start] != start_seq:
            continue
        begin = i + n_start
        # First closing marker after this start ends the span.
        for j in range(begin, len(ids) - n_end + 1):
            if ids[j:j + n_end] == end_seq:
                spans.append((begin, j + n_end))
                break
    return spans

In [11]:
# Collator passed to the trainer; uses the most recently executed QwenDataCollator definition.
collate_fn= QwenDataCollator(processor)

In [12]:
# Smoke-test the collator on the two raw samples.
batch_from_coll=collate_fn(batch)

In [13]:
# Expect input_ids, attention_mask, pixel_values, image_grid_thw and the added labels.
batch_from_coll.keys()

dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'image_grid_thw', 'labels'])

In [14]:
# Decode the second sample's full input to check the rendered chat (system, image pads, report).
print(processor.tokenizer.decode(batch_from_coll["input_ids"][1]))

<|im_start|>system
You are an expert radiologist trained in interpreting chest X-rays. Given radiographic image of the chest (frontal view), generate a detailed and clinically accurate radiology report. The report should include a summary of the observed findings and, if possible, an impression that highlights key diagnoses or abnormalities. Do not speculate beyond the image content. Use professional radiological language. Maintain a neutral, factual tone suitable for inclusion in a patient medical record. Do NOT reference previous imaging or mention paging physicians. Instead, describe the chest x-ray in a stand-alone format.<|im_end|>
<|im_start|>user
<|vision_start|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pa

In [15]:
# Label tensor built by the collator; -100 marks every non-assistant position.
labels=batch_from_coll['labels']

In [16]:
# Label ids of the second sample, keeping only supervised (non -100) positions.
lbl = [token_id for token_id in labels[1].tolist() if token_id != -100]

In [17]:
# Should decode to just the assistant report followed by its end-of-turn marker.
print(processor.tokenizer.decode(lbl))

FINDINGS:
None

IMPRESSION:
The dual-channel ICD device with leads in the right atrium and apex of the right ventricle is visible without any apparent abnormalities. The cardiac silhouette is within normal limits, and there is no evidence of vascular congestion, pleural effusion, or acute focal pneumonia.<|im_end|>



In [18]:
from trl import SFTConfig, SFTTrainer

In [19]:
class CustomTrainer(SFTTrainer):
    """SFTTrainer whose training loss is the model's own forward-pass loss.

    The collator supplies ``labels`` with non-assistant positions set to -100,
    so ``model(**inputs).loss`` covers only the assistant report tokens.

    NOTE(review): overriding ``compute_loss`` replaces SFTTrainer's own
    implementation entirely, so any extra metrics it logs (depending on the TRL
    version) are no longer computed -- confirm that is intended.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the loss (and optionally the raw outputs) for one batch.

        ``Trainer`` calls ``compute_loss`` on every step. The original method
        was named ``custom_loss``, so it was never invoked.

        Args:
            model: the (possibly wrapped) model being trained.
            inputs: the batch produced by the data collator.
            return_outputs: when True, also return the model outputs.
            num_items_in_batch: label-token count across the accumulated batch,
                supplied by ``Trainer``. It is forwarded to the model only when
                the trainer reports the model accepts loss kwargs, mirroring the
                base ``Trainer.compute_loss`` so gradient-accumulation scaling stays
                consistent.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        if num_items_in_batch is not None and getattr(self, "model_accepts_loss_kwargs", False):
            inputs = {**inputs, "num_items_in_batch": num_items_in_batch}
        outputs = model(**inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss

    # Backward-compatible alias for the original method name.
    custom_loss = compute_loss
    
        

In [20]:
# Load the base model with bfloat16 weights; device_map="auto" delegates device
# placement to transformers/accelerate.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [24]:
from trl import SFTConfig

# Configure training arguments
training_args = SFTConfig(
    output_dir="test",  # Directory to save checkpoints
    num_train_epochs=1,  # Number of training epochs
    per_device_train_batch_size=2,  # Batch size per device
    # Optimizer and scheduler settings
    optim="adamw_torch_fused",  # Optimizer type
    # NOTE(review): no PEFT/LoRA setup is visible in this notebook, so this looks like
    # full fine-tuning of a 7B model; 2e-4 is a LoRA-typical rate and high for full
    # fine-tuning -- confirm.
    learning_rate=2e-4,
    # Plain "constant" ignores warmup entirely; "constant_with_warmup" applies the
    # warmup_ratio below and then stays constant.
    lr_scheduler_type="constant_with_warmup",
    warmup_ratio=0.03,  # Fraction of total steps used for LR warmup
    # Logging and checkpointing
    logging_steps=10,  # Steps interval for logging
    save_strategy="steps",  # Strategy for saving the model
    save_steps=20,  # Steps interval for saving
    # Gradient settings
    max_grad_norm=0.3,  # Maximum norm for gradient clipping
    # "none" disables reporting integrations; in transformers 4.x, None is treated
    # as "all installed integrations".
    report_to="none",
    # Dataset configuration: the custom collator does all preprocessing, so TRL's
    # own dataset preparation is skipped.
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
    # max_seq_length=1024  # not applied: the collator performs no truncation
    label_names=["labels"],  # Batch key holding the labels
    # Keep the raw "image"/"text" fields; the collator needs them.
    remove_unused_columns=False,
)

In [None]:
# Train on the ChestDataset built from the frontal training index (`dataset`);
# no `train_dataset` variable exists in this notebook.
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn,
    # NOTE(review): passing the full `processor` would presumably also save the
    # image-processor config with checkpoints -- confirm support in the installed TRL.
    processing_class=processor.tokenizer,
)