# Imports and setup

In [7]:
# Work from the Qwen2.5-Omni project directory so the local `commons` /
# `const_variable` modules import and the relative `datas/...` paths used below resolve.
# NOTE(review): hard-coded absolute path; the recorded `%pwd` output shows it resolves
# to /home/ldap-users-2/... on this machine — consider making it configurable.
%cd /home/is/dwipraseetyo-a/NAS_HAI/Project/Qwen2.5-Omni
%pwd

import commons, const_variable
# NOTE(review): `pydicom` is imported twice on the next line, and several of these
# modules (librosa, random, pickle, requests, json, numpy, PIL.Image) are not used in
# the cells visible here — confirm against the rest of the notebook before pruning.
import os, librosa, random, pickle, pydicom, requests, torch, re, json, pydicom
import pandas as pd
import numpy as np
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
from dotenv import load_dotenv

# Populate os.environ from a .env file; the model-loading cell reads HG_AUTH_KEY via os.getenv.
load_dotenv()

def fix_broken_bullet_blocks(text):
    """Repair bullet lists in which a stray lone ``*`` line splits two bullets.

    The text is scanned line by line. When a bullet line (after stripping, it
    starts with ``*`` but is not exactly ``*``) is followed by a line that is
    only ``*`` and then by another line starting with ``*``, the lone ``*``
    line is dropped. The third line is re-emitted, stripped, behind a
    four-space indent so that it reads as a nested continuation bullet. All
    other lines are kept exactly as they were.

    Args:
        text: Multi-line string to repair.

    Returns:
        The repaired lines joined with newline characters (a trailing newline
        in the input is not preserved).
    """
    src = text.splitlines()
    total = len(src)
    out = []
    pos = 0
    while pos < total:
        current = src[pos]
        head = current.strip()
        is_bullet = head.startswith("*") and head != "*"
        # Need two more lines after this one to inspect the broken pattern.
        if is_bullet and pos + 2 < total:
            middle = src[pos + 1].strip()
            tail = src[pos + 2].strip()
            if middle == "*" and tail.startswith("*"):
                # Keep the bullet, skip the lone "*", nest the following bullet.
                out.extend([current, "    " + tail])
                pos += 3
                continue
        out.append(current)
        pos += 1
    return "\n".join(out)

/home/ldap-users-2/dwipraseetyo-a/Project/Qwen2.5-Omni


# Generate and format MedGemma chest X-ray explanations

In [6]:
# Load MedGemma-4B-IT from a local checkpoint in bfloat16 and let transformers
# place the weights automatically. The optional HF token comes from the environment.
model_id = "/home/is/dwipraseetyo-a/NAS_HAI/Project/pretrain/medgemma-4b-it"
auth_key = os.getenv("HG_AUTH_KEY")

model = AutoModelForImageTextToText.from_pretrained(
    model_id, token=auth_key, device_map="auto", torch_dtype=torch.bfloat16
)
processor = AutoProcessor.from_pretrained(model_id)
# Batched generation with a decoder-only model needs left padding, so that every
# prompt ends at the same position and new tokens follow it directly.
processor.tokenizer.padding_side = "left"

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.21it/s]
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [2]:
# Root of the CIDRZ dataset. The per-split metadata CSVs are read from here, and each
# row's `path_file_image` is resolved relative to it.
# NOTE(review): hard-coded absolute path — consider reading it from the environment.
DATA_PATH = "/home/is/dwipraseetyo-a/NAS_HAI/Datasets/cidrz"

In [None]:
# Generate label-conditioned MedGemma explanations for each chest X-ray and save
# them to datas/reasoning/xray/medgemma_xray.csv.{split}.
# NOTE(review): only dev/test are generated here, but the formatting cell below also
# reads the train split — presumably it was generated earlier with these same prompts.
BATCH_SIZE = 24
MAX_NEW_TOKENS = 768
XRAY_OUTPUT_DIR = "datas/reasoning/xray"

# NOTE(review): prompt text (including the "healty" typo and the missing space in
# "Describe the X-ray." + answer) is kept byte-identical so new outputs stay
# comparable with previously generated splits.
NON_TB_SYSTEM_PROMPT = (
    "You are an expert radiologist. Analyze the chest X-ray and return only the following sections:\n\n"
    "**Specific Findings:** Provide a comprehensive and detailed description of all relevant radiographic abnormalities.\n\n"
    "**Differential Diagnosis:** Only if the X-ray not healty and describe the reason.\n\n"
    "**Conclusion:** Summarize your diagnostic impression clearly.\n\n"
    "Do not include any **Disclaimer** or unrelated content."
)
TB_SYSTEM_PROMPT = (
    "You are an expert radiologist. Analyze the chest X-ray and return only the following sections:\n\n"
    "**Key Features:** Highlight radiographic patterns that are characteristic of tuberculosis, explained clearly and in detail.\n\n"
    "**Conclusion:** Summarize your diagnostic impression clearly.\n\n"
    "Do not include any **Disclaimer** or unrelated content."
)


def build_xray_message(row):
    """Build the chat message (system prompt, user text, image) for one metadata row.

    ``row`` is a namedtuple from ``DataFrame.itertuples()`` that exposes
    ``path_file_image`` (relative to DATA_PATH) and ``ground_truth_tb``.
    The ground-truth label is written into the user prompt, so the model explains
    a known label rather than predicting it.
    """
    # A single flag drives both the prompt and the answer, so the two can never
    # disagree (e.g. for a label value other than 0/1).
    is_tb = row.ground_truth_tb == 1
    image = commons.crop_and_convert(f"{DATA_PATH}/{row.path_file_image}")
    answer = "This is Tuberculosis" if is_tb else "This is Not Tuberculosis"
    system_prompt = TB_SYSTEM_PROMPT if is_tb else NON_TB_SYSTEM_PROMPT
    return [
        {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
        {"role": "user", "content": [
            {"type": "text", "text": "Describe the X-ray." + answer},
            {"type": "image", "image": image},
        ]},
    ]


# Create the output directory up front, so a missing directory does not crash
# to_csv only after the whole (expensive) generation loop has finished.
os.makedirs(XRAY_OUTPUT_DIR, exist_ok=True)

for split in ['dev', 'test']:
    df = pd.read_csv(f"{DATA_PATH}/metadata_cut_processed.csv.{split}")
    paths = []
    messages_batch = []
    for now_row in tqdm(df.itertuples(), desc=f"Processing {split},", total=len(df)):
        paths.append(now_row.path_file_image)
        messages_batch.append(build_xray_message(now_row))

    results = []
    for i in tqdm(range(0, len(messages_batch), BATCH_SIZE), desc="Generating"):
        batch_messages = messages_batch[i:i + BATCH_SIZE]
        batch_paths = paths[i:i + BATCH_SIZE]

        inputs = processor.apply_chat_template(
            batch_messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt", padding=True,
        ).to(model.device, dtype=torch.bfloat16)

        # padding=True pads every row to one common length, and generate() appends
        # the new tokens after it, so the completion starts at this column.
        prompt_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            generation = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)

        decoded = processor.batch_decode(generation[:, prompt_len:], skip_special_tokens=True)
        results.extend(
            {"path_file_image": path, "llm_analyze_image": text}
            for path, text in zip(batch_paths, decoded)
        )

    # Explicit columns keep the header even for an empty split, so the
    # downstream read_csv does not fail.
    pd.DataFrame(results, columns=["path_file_image", "llm_analyze_image"]).to_csv(
        f'{XRAY_OUTPUT_DIR}/medgemma_xray.csv.{split}', index=False
    )

Processing dev,:   0%|                                                                                                                                                                                               | 0/269 [00:00<?, ?it/s]


In [None]:
# Parse each raw MedGemma response into its sections, drop any disclaimer, and
# re-emit the result as a nested bullet list in medgemma_xray_formatted.csv.{split}.
SECTION_PATTERN = re.compile(r'(?:\*\*)?\{?([A-Za-z ]+?)\}?:?(?:\*\*)?\n\n|(?:\.\n\n)?([A-Za-z ]+?):\n\n')
DISCLAIMER_PATTERN = re.compile(r'\n\n\*?\*?Disclaimer:.*', re.IGNORECASE | re.DOTALL)


def _strip_disclaimer(content):
    """Cut ``content`` at a "Disclaimer:" paragraph (and everything after it), if present."""
    found = DISCLAIMER_PATTERN.search(content)
    return content[:found.start()].strip() if found else content


def _parse_sections(text):
    """Split ``text`` into ``(title, content)`` pairs at each heading matched by SECTION_PATTERN.

    Text before the first heading is discarded. If no heading matches at all, the
    whole text is returned as one untitled section instead of being silently dropped.
    """
    matches = list(SECTION_PATTERN.finditer(text))
    sections = []
    for i, match in enumerate(matches):
        title = (match.group(1) or match.group(2)).strip()
        content_end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
        sections.append((title, _strip_disclaimer(text[match.end():content_end].strip())))
    if not sections and text.strip():
        sections.append(("", _strip_disclaimer(text.strip())))
    return sections


def format_llm_analysis(text):
    """Return the formatted bullet-list version of one raw model response.

    Every section except the last becomes ``* <title>``, followed by its content
    with ``*   `` bullets indented one level. The last section (presumably the
    conclusion) is emitted without a title and with its ``*   `` markers removed.
    Non-string input (e.g. NaN read from an empty CSV cell) is treated as "".
    """
    if not isinstance(text, str):
        text = ""
    sections = _parse_sections(text)
    pieces = ["\n\n"]
    for idx, (title, content) in enumerate(sections):
        if idx == len(sections) - 1:
            pieces.append(content.replace('*   ', '') + " \n\n")
        else:
            pieces.append(f"* {title}\n{content.replace('*   ', '    *   ')}" + " \n\n")
    return fix_broken_bullet_blocks("".join(pieces))


for split in ['train', 'dev', 'test']:
    df = pd.read_csv(f"datas/reasoning/xray/medgemma_xray.csv.{split}")
    records = [
        [now_row.path_file_image, format_llm_analysis(now_row.llm_analyze_image)]
        for now_row in tqdm(df.itertuples(), desc=f"Processing {split},", total=len(df))
    ]
    # Build the frame once, instead of appending row by row with .loc (quadratic).
    df_temp = pd.DataFrame(records, columns=['path_file_image', 'llm_analyze_image'])
    df_temp.to_csv(f'datas/reasoning/xray/medgemma_xray_formatted.csv.{split}', index=False)

Processing train,: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2151/2151 [00:02<00:00, 786.99it/s]
Processing dev,: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 269/269 [00:00<00:00, 835.83it/s]
Processing test,: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 270/270 [00:00<00:00, 828.18it/s]
