In [1]:
import os
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
import torch

  from .autonotebook import tqdm as notebook_tqdm


# Load Qwen2-VL-7B-Instruct model

In [2]:
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
)

# Load the processor
min_pixels = 1000000
max_pixels = 1000000
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)

# Diagnosis codes and descriptions
PTB_CODE = ['NORM', 'MI', 'STTC', 'CD', 'HYP']
CONDITIONS = ["Normal ECG", "Myocardial Infarction", "ST/T Change", "Conduction Disturbance", "Hypertrophy"]


`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.32s/it]


# Load images generated based on ConMIL interpretations and prediction sets

In [3]:
sample_path = 'PTB-XL'
    
# Load ground truth and prediction set
gt_path = os.path.join(sample_path, "ground_truth.txt")
with open(gt_path, "r") as f:
    gt_text = f.read().strip()
    gt = PTB_CODE.index(gt_text.split(":")[1].strip())

# Load prediction set figures
prediction_set = [fname.replace(".png", "") for fname in os.listdir(sample_path) if fname.endswith(".png") and fname != "org.png"]
pset_indices = [PTB_CODE.index(code) for code in prediction_set]

# Load images for predictions
conmil_paths = [os.path.join(sample_path, f"{PTB_CODE[idx]}.png") for idx in pset_indices]
images = [Image.open(path).convert("RGB") for path in conmil_paths]

# Generate instructions

In [4]:
# Create dynamic instruction
diagnosis_options = ", ".join([CONDITIONS[idx] for idx in sorted(pset_indices)])
instruction = (
    f"Given a 12-lead ECG and visual interpretations for the following possible diagnoses: {diagnosis_options}, "
    f"determine the most likely diagnosis. Use the provided model interpretations as reference and "
    f"base your decision solely on these visual features without additional analysis or introducing new criteria.\n\n"
    f"Provide your answer in the following format:\n"
    f"Diagnosis: <Selected Diagnosis>\n"
    f"Reason: <Brief reason for the choice based on visual features>"
)

# Prepare messages for the processor
content = [{"type": f"Model interpretation for diagnosis: {CONDITIONS[idx]}", "image": img} for idx, img in zip(pset_indices, images)]
content.append({"type": "Instruction", "text": instruction})
messages = [{"role": "user", "content": content}]

# Process vision info and prepare inputs
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=[image_inputs],
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")  # Move inputs to GPU if using CUDA

# LLM inference

In [5]:
 # Perform inference
generated_ids = model.generate(**inputs, max_new_tokens=512)
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=True
)

# Parse LLM response
response = output_text[0].strip()
diagnosis = response.split("Diagnosis: ")[1].split("\n")[0].strip() if "Diagnosis: " in response else "Unknown"

# Show LLM response

In [6]:
print("Instruction:", instruction)
print("Generated Response:", response)
print("Extracted Conclusion:", diagnosis)
print("Ground Truth:", CONDITIONS[gt])

Instruction: Given a 12-lead ECG and visual interpretations for the following possible diagnoses: Myocardial Infarction, Conduction Disturbance, determine the most likely diagnosis. Use the provided model interpretations as reference and base your decision solely on these visual features without additional analysis or introducing new criteria.

Provide your answer in the following format:
Diagnosis: <Selected Diagnosis>
Reason: <Brief reason for the choice based on visual features>
Generated Response: Diagnosis: Conduction Disturbance
Reason: The ECG shows a regular rhythm with a consistent PR interval and QRS duration, which is indicative of a conduction disturbance rather than a myocardial infarction. The absence of ST segment elevation or depression, T wave inversion, or Q wave abnormalities further supports the diagnosis of conduction disturbance.
Extracted Conclusion: Conduction Disturbance
Ground Truth: Conduction Disturbance


# Now we see how LLM works without ConMIL's support

In [7]:
# Load original image
org_image_path = os.path.join(sample_path, "org.png")

org_image = Image.open(org_image_path).convert("RGB")

# Create instruction
instruction = (
    f"Given a 12-lead ECG, determine the most likely diagnosis among the following categories:"
    f"Normal ECG, Myocardial Infarction, ST/T Change, Conduction Disturbance, or Hypertrophy."
    f"Base your decision solely on the visual features of the provided ECG without performing additional analysis or introducing new criteria.\n\n"
    f"Provide your answer in the following format:\n"
    f"Diagnosis: <Selected Diagnosis>\n"
    f"Reason: <Brief reason for the choice based on visual features>"
)

# Prepare messages for the processor
messages = [
    {
        "role": "user",
        "content": [
            {"type": "ECG Plot", "image": org_image},
            {"type": "Instruction", "text": instruction},
        ],
    }
]

# Process vision info and prepare inputs
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=[image_inputs],
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")  # Move inputs to GPU if using CUDA

# Perform inference

In [8]:
generated_ids = model.generate(**inputs, max_new_tokens=512)
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=True
)

# Parse LLM response
response = output_text[0].strip()
diagnosis = response.split("Diagnosis: ")[1].split("\n")[0].strip() if "Diagnosis: " in response else "Unknown"

In [9]:
print("Instruction:", instruction)
print("Generated Response:", response)
print("Extracted Conclusion:", diagnosis)
print("Ground Truth:", CONDITIONS[gt])

Instruction: Given a 12-lead ECG, determine the most likely diagnosis among the following categories:Normal ECG, Myocardial Infarction, ST/T Change, Conduction Disturbance, or Hypertrophy.Base your decision solely on the visual features of the provided ECG without performing additional analysis or introducing new criteria.

Provide your answer in the following format:
Diagnosis: <Selected Diagnosis>
Reason: <Brief reason for the choice based on visual features>
Generated Response: Diagnosis: ST/T Change
Reason: The ECG shows ST segment elevation in leads V1-V6, which is indicative of acute myocardial infarction. However, the presence of ST segment elevation in the absence of Q waves suggests a diagnosis of ST elevation myocardial infarction (STEMI) rather than non-ST elevation myocardial infarction (NSTEMI). The T waves are also elevated, which is consistent with the ST/T change pattern.
Extracted Conclusion: ST/T Change
Ground Truth: Conduction Disturbance


# We can also see how LLM works with only ConMIL prediction set

In [10]:
# Create instruction
instruction = (
    f"Given a 12-lead ECG, determine the most likely diagnosis among the following categories:"
    f"{diagnosis_options}"
    f"Base your decision solely on the visual features of the provided ECG without performing additional analysis or introducing new criteria.\n\n"
    f"Provide your answer in the following format:\n"
    f"Diagnosis: <Selected Diagnosis>\n"
    f"Reason: <Brief reason for the choice based on visual features>"
)

# Prepare messages for the processor
messages = [
    {
        "role": "user",
        "content": [
            {"type": "ECG Plot", "image": org_image},
            {"type": "Instruction", "text": instruction},
        ],
    }
]

# Process vision info and prepare inputs
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=[image_inputs],
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")  # Move inputs to GPU if using CUDA

In [11]:
generated_ids = model.generate(**inputs, max_new_tokens=512)
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=True
)

# Parse LLM response
response = output_text[0].strip()
diagnosis = response.split("Diagnosis: ")[1].split("\n")[0].strip() if "Diagnosis: " in response else "Unknown"

In [12]:
print("Instruction:", instruction)
print("Generated Response:", response)
print("Extracted Conclusion:", diagnosis)
print("Ground Truth:", CONDITIONS[gt])

Instruction: Given a 12-lead ECG, determine the most likely diagnosis among the following categories:Myocardial Infarction, Conduction DisturbanceBase your decision solely on the visual features of the provided ECG without performing additional analysis or introducing new criteria.

Provide your answer in the following format:
Diagnosis: <Selected Diagnosis>
Reason: <Brief reason for the choice based on visual features>
Generated Response: Diagnosis: Conduction Disturbance
Reason: The ECG shows a regular rhythm with a consistent PR interval and QRS duration, which suggests a conduction disturbance rather than a myocardial infarction. The absence of ST segment or T wave abnormalities that are typically seen in myocardial infarction further supports this diagnosis.
Extracted Conclusion: Conduction Disturbance
Ground Truth: Conduction Disturbance
