In [None]:
import os
from PIL import Image
from transformers import LlavaForConditionalGeneration, AutoProcessor
import torch

# Set model ID and directories.
# Both locations can be overridden through environment variables so the notebook
# runs on machines where the RLAB disk is mounted elsewhere; the defaults are the
# original hard-coded paths.
model_id = "Intel/llava-gemma-7b"
cases_dir = os.environ.get(
    'LLAVA_CASES_DIR',
    '/media/elboardy/RLAB-Disk01/(final)merged_images_with_labels_order_and_folders_mask_normalized (Copy)/',
)
# Shared by the model download cache, the offload folder and the processor cache.
weights_dir = os.environ.get(
    'LLM_WEIGHTS_DIR',
    '/media/elboardy/RLAB-Disk01/Large-Language-Models-Weights',
)

# Load the Llava-Gemma model and processor
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # Use appropriate dtype based on hardware (e.g. bfloat16/float32 if fp16 is unsupported)
    device_map="auto",  # let transformers/accelerate place layers on the available devices
    cache_dir=weights_dir,
    offload_folder=weights_dir,  # used for weights the device map offloads to disk
)
# Use the same cache as the model so every artifact of this checkpoint lives in
# one place instead of partly in the default Hugging Face cache.
processor = AutoProcessor.from_pretrained(model_id, cache_dir=weights_dir)

In [None]:
# Function to load an image
def load_image(image_path):
    return Image.open(image_path)

# Generate a report for every case directory under `cases_dir`.
# A case directory is expected to contain PNG images and a
# 'diagnostic_prompt.txt' with the clinical information; the generated text is
# written to 'llava-gemma-response-7b.txt' inside the same directory.

# Upper bound on the number of tokens generated per report.
MAX_NEW_TOKENS = 4096

system_prompt = """Consider that you are a professional radiologist with several years of experience. Write a detailed diagnosis report for the provided images and clinical data."""

# Placeholder the processor expands into image-feature positions; fall back to
# the literal used in the LLaVA examples for processors without the attribute.
image_token = getattr(processor, "image_token", "<image>")

for case in sorted(os.listdir(cases_dir)):
    case_dir = os.path.join(cases_dir, case)
    # Only sub-directories are cases; os.listdir(case_dir) below would raise on
    # a stray file sitting in cases_dir.
    if not os.path.isdir(case_dir):
        continue

    clinical_information_path = os.path.join(case_dir, 'diagnostic_prompt.txt')
    if not os.path.exists(clinical_information_path):
        print(f"Missing clinical information file for case: {case}")
        continue

    # Sorted so the image order (and therefore the prompt) is reproducible;
    # os.listdir returns entries in arbitrary order.
    image_files = sorted(
        f for f in os.listdir(case_dir)
        if f.lower().endswith('.png')
    )
    if not image_files:
        print(f"No PNG images found for case: {case}")
        continue

    # Context manager closes the file; explicit encoding avoids locale dependence.
    with open(clinical_information_path, encoding='utf-8') as f:
        clinical_information = f.read()

    user_prompt = f"""You will be given batches of MRI images for patients likely to have brain tumors. Each batch includes sequences and segmentation masks.
    Clinical data:
    {clinical_information}"""

    # Load images
    images = [load_image(os.path.join(case_dir, image_file)) for image_file in image_files]

    # Gemma's chat template raises on a "system" role, so the system
    # instructions are folded into the single user turn. One image placeholder
    # per image is required: without them the model has no positions in which
    # to insert the image features.
    # NOTE(review): every image expands into many prompt tokens; cases with
    # many images may exceed the model's context window — confirm or subsample.
    image_placeholders = f"{image_token}\n" * len(images)
    messages = [
        {"role": "user", "content": f"{image_placeholders}{system_prompt}\n\n{user_prompt}"}
    ]

    # Apply the chat template
    input_text = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    try:
        # add_special_tokens=False: presumably the chat template already emits
        # BOS, so the processor must not add a second one.
        inputs = processor(
            images=images,
            text=input_text,
            add_special_tokens=False,
            return_tensors="pt"
        ).to(model.device)
    finally:
        # The pixel data now lives in tensors; release the PIL images.
        for image in images:
            image.close()

    # Generate the response
    output = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    # generate() returns prompt + continuation; keep only the new tokens so the
    # saved report does not repeat the prompt and clinical data.
    prompt_length = inputs["input_ids"].shape[1]
    response_text = processor.tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)

    # Save the response
    response_path = os.path.join(case_dir, 'llava-gemma-response-7b.txt')
    with open(response_path, 'w', encoding='utf-8') as f:
        f.write(response_text)

    print(f"Response saved for case {case}.")