In [1]:
import os
from io import BytesIO

import requests
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Report which accelerator PyTorch can see: "cuda" when a GPU is available, otherwise "cpu".
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"device: {device}")

device: cuda


In [None]:
# Checkpoint to load. google/medgemma-4b-it is gated on the Hugging Face Hub (accept the licence
# and log in first). To run offline, set MEDGEMMA_MODEL_ID to a local directory,
# e.g. "./models/medgemma-4b-it".
model_id = os.environ.get("MEDGEMMA_MODEL_ID", "google/medgemma-4b-it")

# Gemma 3-family weights are published in bfloat16 and are prone to overflow (inf/NaN) in float16,
# so compute in bf16 whenever the GPU supports it and fall back to fp16 otherwise.
compute_dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)

# 4-bit bitsandbytes quantization so the 4B model fits on a small GPU.
bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=compute_dtype)

model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    quantization_config=bnb_cfg,             # apply the 4-bit quantization above
    torch_dtype=compute_dtype,               # dtype for modules bitsandbytes leaves unquantized
    device_map="auto",                       # let accelerate place layers across GPU/CPU
    max_memory={0: "6GiB", "cpu": "30GiB"},  # per-device caps used by device_map to avoid OOM
)

# MedGemma uses the Gemma 3 architecture, which transformers supports natively, so
# trust_remote_code is omitted: it is not needed and would execute arbitrary code from the repo.
processor = AutoProcessor.from_pretrained(model_id)

OSError: We couldn't connect to 'https://huggingface.co' to load the files, and couldn't find them in the cached files.
Check your internet connection or see how to run the library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'.

In [None]:
# Load the example chest X-ray (Stillwaterising, CC0, via Wikimedia Commons).
image_url = "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png"

# Wikimedia rejects requests that lack a descriptive User-Agent, so send one. Use a timeout so a
# stalled download cannot hang the kernel, and raise on HTTP errors rather than handing an HTML
# error page to PIL.
response = requests.get(
    image_url,
    headers={"User-Agent": "medgemma-notebook-example"},
    timeout=30,
)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert("RGB")

# No manual resize here. Per the MedGemma / Gemma 3 model card, the processor normalizes every
# image to 896x896 and encodes it as a fixed 256 tokens. Shrinking to 224x224 first saves no
# tokens or attention cost; it only throws away detail and distorts the aspect ratio.

messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are an expert radiologist."}]},
    {"role": "user", "content": [
        {"type": "text", "text": "Describe this X-ray"},
        {"type": "image", "image": image}
    ]}
]

In [None]:
# Render the chat (text + image) through the processor's chat template into model-ready
# tensors, placed on the model's (first) device.
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

# Prompt length in tokens. generate() returns prompt + continuation, so this prefix is dropped below.
input_len = inputs["input_ids"].shape[-1]

# Low-cost decoding settings: greedy (deterministic), KV cache on, and at most 16 new tokens.
# NOTE: 16 tokens will usually cut a radiology description off mid-sentence.
generation_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True)

with torch.inference_mode():
    # Keep only the newly generated tokens of the first (only) sequence.
    generation = model.generate(**inputs, **generation_kwargs)[0, input_len:]

decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)