In [None]:
import os
from PIL import Image
import requests
from transformers import AutoModelForCausalLM, AutoProcessor

# Checkpoint to load and the root folder that is scanned for case sub-folders
model_id = "microsoft/Phi-3.5-vision-instruct"
cases_dir = '/media/elboardy/RLAB-Disk01/(final)merged_images_with_labels_order_and_folders_mask_normalized (Copy)/'

# Instantiate the causal-LM weights on the GPU. The checkpoint ships its own
# modelling code, hence trust_remote_code=True.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="cuda",
    _attn_implementation='flash_attention_2',
)

# Matching processor; it is used later for the chat template, for building the
# model inputs and for decoding the generated ids.
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True, num_crops=4)

# Function to load images from directory
def load_images_from_dir(image_dir):
    """Load every JPEG/PNG image found directly inside *image_dir*.

    Files are selected by a case-insensitive ``.jpg`` / ``.png`` suffix match
    and returned in sorted (lexicographic) file-name order. ``os.listdir``
    gives no ordering guarantee, so sorting keeps the position of each image
    in the returned list reproducible across runs and machines.

    Args:
        image_dir: Path of the directory to scan (sub-directories are not
            searched).

    Returns:
        A list of PIL images with their pixel data already loaded; the list
        is empty when the directory contains no matching files.
    """
    image_files = sorted(
        name for name in os.listdir(image_dir)
        if name.lower().endswith(('.jpg', '.png'))
    )
    images = []
    for image_file in image_files:
        image_path = os.path.join(image_dir, image_file)
        # Image.open is lazy and keeps the underlying file open. Read the
        # pixel data now and release the handle so that a long run over many
        # cases does not accumulate open file descriptors.
        with Image.open(image_path) as image:
            image.load()
        images.append(image)
    return images

# Prompt text and generation settings are the same for every case, so they are
# built once instead of on every loop iteration.
system_prompt = """Consider that you are a professional radiologist with several years of experience. Write a detailed diagnosis report for the provided images and clinical data."""

# Greedy decoding (do_sample=False) capped at 1000 new tokens
generation_args = {
    "max_new_tokens": 1000,
    "temperature": 0.0,
    "do_sample": False,
}

# Iterate through cases in the directory (sorted so runs are reproducible)
for case in sorted(os.listdir(cases_dir)):
    case_dir = os.path.join(cases_dir, case)
    if not os.path.isdir(case_dir):
        continue  # Skip non-directory files

    # Images are expected in a sub-directory named 'images'
    image_dir = os.path.join(case_dir, 'images')
    if not os.path.isdir(image_dir):
        print(f"No images found for case: {case}")
        continue

    clinical_information_path = os.path.join(case_dir, 'diagnostic_prompt.txt')
    if not os.path.exists(clinical_information_path):
        print(f"Missing clinical information file for case: {case}")
        continue

    # Read the clinical text through a context manager so the handle is closed
    with open(clinical_information_path, 'r', encoding='utf-8') as f:
        clinical_information = f.read()

    user_prompt = f"""Clinical data: {clinical_information}\n\nPlease provide a diagnosis report based on the images and clinical data."""

    # Load images; an existing but empty 'images' folder is skipped, because
    # the prompt would otherwise carry no image placeholder at all.
    images = load_images_from_dir(image_dir)
    if not images:
        print(f"No images found for case: {case}")
        continue

    # One numbered placeholder line per image, 1-based
    placeholder = "".join(f"<|image_{i}|>\n" for i in range(1, len(images) + 1))

    # Create messages
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": placeholder + user_prompt},
    ]

    # Render the chat template to plain text; tokenisation happens in the
    # processor call below.
    prompt = processor.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Prepare inputs
    inputs = processor(prompt, images, return_tensors="pt").to("cuda:0")

    # Generate response
    generate_ids = model.generate(
        **inputs,
        eos_token_id=processor.tokenizer.eos_token_id,
        **generation_args
    )

    # Drop the prompt tokens so that only the newly generated part is decoded
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]

    # Decode response
    response = processor.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )[0]

    # Save the response next to the case data
    response_path = os.path.join(case_dir, 'phi-3.5-vision-instruct-response.txt')
    with open(response_path, 'w', encoding='utf-8') as f:
        f.write(response)

    print(f"Response saved for case {case}.")