In [None]:
import os
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch

# Checkpoint identifier and the root folder that is scanned for cases
model_id = "Qwen/Qwen2-VL-72B-Instruct"
cases_dir = '/media/elboardy/RLAB-Disk01/(final)merged_images_with_labels_order_and_folders_mask_normalized (Copy)/'

# The same on-disk location is used both as download cache and as offload folder
_weights_dir = '/media/elboardy/RLAB-Disk01/Large-Language-Models-Weights'

# Keyword arguments forwarded to the model loader
_model_load_kwargs = {
    "torch_dtype": "auto",  # Adjust if needed
    "device_map": "auto",
    "cache_dir": _weights_dir,
    "offload_folder": _weights_dir,
}

# Instantiate the Qwen2-VL model first, then the processor for the same checkpoint
model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, **_model_load_kwargs)
processor = AutoProcessor.from_pretrained(model_id)

# Function to load images
def load_image(image_path):
    """Open the file at ``image_path`` with ``PIL.Image.open`` and return the image object."""
    opened_image = Image.open(image_path)
    return opened_image

# The system prompt is the same for every case, so build it once before the loop
system_prompt = """Consider that you are a professional radiologist with several years of experience. Write a detailed diagnosis report for the provided images and clinical data."""

# Iterate through cases in the directory.
# For each case folder: read the clinical text, attach every PNG in the folder,
# run one generation with the model, and write the answer next to the inputs.
# Entries are visited in sorted order so repeated runs process cases identically.
for case in sorted(os.listdir(cases_dir)):
    case_dir = os.path.join(cases_dir, case)

    # Anything that is not a folder (stray files beside the case folders)
    # would make os.listdir(case_dir) raise, so it is skipped.
    if not os.path.isdir(case_dir):
        continue

    clinical_information_path = os.path.join(case_dir, 'diagnostic_prompt.txt')
    if not os.path.exists(clinical_information_path):
        print(f"Missing clinical information file for case: {case}")
        continue

    # os.listdir returns names in arbitrary order; sorting keeps the image
    # sequence handed to the model reproducible between runs.
    image_files = sorted(
        name for name in os.listdir(case_dir)
        if name.lower().endswith('.png')
    )
    if not image_files:
        # Without images the model would only see text; do not write a report for that.
        print(f"No PNG images found for case: {case}")
        continue

    # Context manager closes the handle promptly; UTF-8 matches how the response is written.
    with open(clinical_information_path, encoding='utf-8') as f:
        clinical_information = f.read()

    # Define prompts (the leading spaces inside the literal are part of the prompt text)
    user_prompt = f"""You will be given batches of MRI images for patients likely to have brain tumors. Each batch includes sequences and segmentation masks.
    Clinical data:
    {clinical_information}"""

    # Load images as chat-message content entries
    images = [
        {"type": "image", "image": load_image(os.path.join(case_dir, image_file))}
        for image_file in image_files
    ]

    # Prepare messages for Qwen2-VL: text first, followed by all images of the case
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": [{"type": "text", "text": user_prompt}] + images},
    ]

    # Prepare inputs for Qwen2-VL
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    # Generate the response
    generated_ids = model.generate(**inputs, max_new_tokens=4096)
    # Drop the first len(input) ids of every output sequence so that only the
    # tokens following the prompt are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    # A single prompt was submitted, so the first decoded string is the answer
    response_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    # Save the response inside the case folder
    response_path = os.path.join(case_dir, 'qwen-vl-72b-response.txt')
    with open(response_path, 'w', encoding='utf-8') as f:
        f.write(response_text)

    print(f"Response saved for case {case}.")
