# In [None]:
import os
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BlipProcessor, BlipForConditionalGeneration

# Set directories and model IDs
hymba_model_id = "nvidia/Hymba-1.5B-Instruct"
blip_model_id = "Salesforce/blip-image-captioning-base"
# Root folder expected to contain one sub-directory per case.
cases_dir = '/path/to/cases/'

# Device both models are placed on.
device = 'cuda'

# Load models and processors.
# NOTE: trust_remote_code=True executes Python code shipped in the Hymba model
# repository; only enable it for a repository you trust.
tokenizer_hymba = AutoTokenizer.from_pretrained(hymba_model_id, trust_remote_code=True)
# Load the weights directly in bfloat16 instead of moving a float32 copy onto
# the GPU and casting afterwards; the float32 intermediate is twice the size
# of the final bf16 weights.
model_hymba = AutoModelForCausalLM.from_pretrained(
    hymba_model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device)

processor_blip = BlipProcessor.from_pretrained(blip_model_id)
model_blip = BlipForConditionalGeneration.from_pretrained(blip_model_id).to(device)

# Helper functions
def generate_caption(image, processor, model, **generate_kwargs):
    """Return a text caption for a single image using a captioning model.

    Args:
        image: The image to caption (a PIL image in this script).
        processor: Processor that turns ``image`` into model inputs and
            decodes generated token ids back to text.
        model: Captioning model exposing ``.device`` and ``.generate``.
        **generate_kwargs: Extra options forwarded to ``model.generate``
            (e.g. ``max_new_tokens``); by default none are passed.

    Returns:
        The decoded caption for the first (only) generated sequence.
    """
    # Move inputs to wherever the model actually lives rather than assuming
    # 'cuda', so the helper also works for CPU or other-device models.
    inputs = processor(images=image, return_tensors="pt").to(model.device)
    out = model.generate(**inputs, **generate_kwargs)
    return processor.decode(out[0], skip_special_tokens=True)

def load_images_from_dir(image_dir):
    image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    images = []
    for img_file in image_files:
        img_path = os.path.join(image_dir, img_file)
        img = Image.open(img_path).convert('RGB')
        images.append(img)
    return images

# Iterate through cases
# Upper bound on tokens generated per report. Unlike max_length it does not
# count prompt tokens, so long clinical notes cannot shrink (or eliminate)
# the space left for the report. Tune as needed.
MAX_NEW_TOKENS = 2048

SYSTEM_PROMPT = "You are a professional radiologist providing detailed diagnosis reports."


def _build_prompt(clinical_info, captions):
    """Combine clinical notes and per-image captions into one user prompt.

    Images are numbered from 1 in the order given. Surrounding whitespace of
    *clinical_info* (e.g. a trailing newline from the file) is stripped.
    """
    lines = [f"Clinical data: {clinical_info.strip()}", "", "Image descriptions:"]
    lines.extend(f"Image {i}: {cap}" for i, cap in enumerate(captions, 1))
    lines.extend(["", "Please provide a detailed diagnosis report based on the clinical data and images."])
    return "\n".join(lines)


def _generate_report(prompt):
    """Run Hymba on *prompt* with sampling disabled; return only the new text."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    # return_dict=True yields a mapping with input_ids and attention_mask;
    # without it apply_chat_template returns a bare tensor that has no
    # .input_ids attribute.
    inputs = tokenizer_hymba.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model_hymba.device)

    with torch.no_grad():
        outputs = model_hymba.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=MAX_NEW_TOKENS,
            # Sampling is off, so sampling knobs such as temperature would be
            # ignored (and only trigger a warning); none are passed.
            do_sample=False,
            use_cache=True,
        )

    # generate() returns prompt + continuation; slice off the prompt tokens.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer_hymba.decode(outputs[0][prompt_len:], skip_special_tokens=True)


# Iterate through cases in a deterministic (sorted) order.
for case in sorted(os.listdir(cases_dir)):
    case_dir = os.path.join(cases_dir, case)
    if not os.path.isdir(case_dir):
        continue  # skip non-directories

    clinical_info_path = os.path.join(case_dir, 'diagnostic_prompt.txt')
    if not os.path.isfile(clinical_info_path):
        print(f"Missing clinical info for case {case}")
        continue
    with open(clinical_info_path, 'r', encoding='utf-8') as f:
        clinical_info = f.read()

    image_dir = os.path.join(case_dir, 'images')
    if not os.path.isdir(image_dir):
        print(f"No images found for case {case}")
        continue

    images = load_images_from_dir(image_dir)
    if not images:
        print(f"No images loaded for case {case}")
        continue

    # One caption per image, in the order returned by load_images_from_dir.
    captions = [generate_caption(img, processor_blip, model_blip) for img in images]

    response = _generate_report(_build_prompt(clinical_info, captions))

    response_path = os.path.join(case_dir, 'hymba-report.txt')
    with open(response_path, 'w', encoding='utf-8') as f:
        f.write(response)

    print(f"Report generated for case {case}.")