# Test the fine-tuned PaliGemma model

In [1]:
import warnings

import torch
from peft import PeftModel, PeftConfig
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

# Hugging Face Hub repo holding the fine-tuned LoRA adapter
model_id = "SiyunHE/medical-pilagemma-lora"
# Base PaliGemma checkpoint the adapter is applied on top of
base_model_id = "google/paligemma-3b-pt-224"

# Load processor (tokenizer + image preprocessing) from the base checkpoint
processor = AutoProcessor.from_pretrained(base_model_id, use_fast=True)

# Load the adapter's PEFT config. It records which base model the adapter was
# trained against; use it to flag a mismatched base_model_id early instead of
# silently applying the LoRA weights to the wrong checkpoint.
peft_config = PeftConfig.from_pretrained(model_id)
adapter_base = peft_config.base_model_name_or_path
if adapter_base is not None and adapter_base != base_model_id:
    warnings.warn(
        f"LoRA adapter {model_id!r} reports base model {adapter_base!r}, "
        f"but {base_model_id!r} is being loaded; outputs may be degraded.",
        stacklevel=1,
    )

# Load base model weights
base_model = PaliGemmaForConditionalGeneration.from_pretrained(base_model_id)

# Attach the LoRA weights to the base model and switch to eval mode for inference
model = PeftModel.from_pretrained(base_model, model_id).eval()

# Inference function
def analyze_with_gemma(query, image_path, *, max_new_tokens=512):
    """Answer a text query about one image with the LoRA-tuned PaliGemma model.

    Uses the module-level ``processor`` and ``model`` loaded above.

    Args:
        query: Free-text question about the image.
        image_path: Path to an image file readable by PIL.
        max_new_tokens: Maximum number of tokens to generate (default 512,
            the previously hard-coded value).

    Returns:
        The decoded answer containing only the newly generated tokens, with
        special tokens stripped.
    """
    # Open the file in a context manager so its handle is released right away;
    # convert() returns a new, fully loaded image that outlives the handle.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Prefix the image placeholder token the PaliGemma processor looks for.
    # NOTE(review): confirm this "<image> " + query layout (including the space)
    # matches the prompt format used during fine-tuning.
    prompt = "<image> " + query

    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
    # Prompt length, used below to strip the echoed prompt from the output.
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # no sampling -> deterministic decoding
        )
        # generate() returns prompt + continuation; keep only the new tokens.
        generated = outputs[0][input_len:]

        response = processor.decode(generated, skip_special_tokens=True)

    return response

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 3/3 [00:08<00:00,  2.80s/it]


In [2]:
# Example run: a symptom description plus a photo (path relative to the
# notebook's working directory).
query = "My arm is red and itchy. What is it?"
image_path = "skin_rash.jpg"
response = analyze_with_gemma(query, image_path)
print(response)

bug bite
