In [4]:
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from PIL import Image
import requests
import torch

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

device: cuda


In [None]:
model_id = "google/medgemma-4b-it"
# model_id = "./models/medgemma-4b-it"  # path local

bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    quantization_config=bnb_cfg,    # ใช้ quantization
    device_map="auto",              # กระจาย GPU/CPU อัตโนมัติ
    max_memory={0: "6GiB", "cpu": "30GiB"},  # กัน OOM
    trust_remote_code=True,
    local_files_only=True,          # path local
)

processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
    local_files_only=True,          # path local
)

In [22]:
# โหลดและย่อภาพเพื่อลด token

# image_url = "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png"
# image = Image.open(requests.get(image_url, stream=True).raw)

image_path = "./img/Chest_Xray_1.png"
image = Image.open(image_path)

messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are an expert radiologist."}]},
    {"role": "user", "content": [
        {"type": "text", "text": "Describe this X-ray"},
        {"type": "image", "image": image}
    ]}
]

In [23]:
# เตรียม input tensors
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)

input_len = inputs["input_ids"].shape[-1]

# inference แบบประหยัด
with torch.inference_mode():
    generation = model.generate(
        **inputs,
        max_new_tokens=200,    # ลดความยาว output
        do_sample=False,      # deterministic
        use_cache=True,
    )
    generation = generation[0][input_len:]

decoded = processor.decode(generation, skip_special_tokens=True)
print(type(decoded))
print(len(decoded))
print(decoded)

<class 'str'>
0



In [26]:
print(model.config.to_dict().keys())
print(model.config.output_attentions)

dict_keys(['text_config', 'vision_config', 'mm_tokens_per_image', 'boi_token_index', 'eoi_token_index', 'image_token_index', 'initializer_range', 'return_dict', 'output_hidden_states', 'torchscript', 'dtype', 'pruned_heads', 'tie_word_embeddings', 'chunk_size_feed_forward', 'is_encoder_decoder', 'is_decoder', 'cross_attention_hidden_size', 'add_cross_attention', 'tie_encoder_decoder', 'architectures', 'finetuning_task', 'id2label', 'label2id', 'task_specific_params', 'problem_type', 'tokenizer_class', 'prefix', 'bos_token_id', 'pad_token_id', 'eos_token_id', 'sep_token_id', 'decoder_start_token_id', 'max_length', 'min_length', 'do_sample', 'early_stopping', 'num_beams', 'temperature', 'top_k', 'top_p', 'typical_p', 'repetition_penalty', 'length_penalty', 'no_repeat_ngram_size', 'encoder_no_repeat_ngram_size', 'bad_words_ids', 'num_return_sequences', 'output_scores', 'return_dict_in_generate', 'forced_bos_token_id', 'forced_eos_token_id', 'remove_invalid_values', 'exponential_decay_leng

In [27]:
with torch.inference_mode():
    out = model(
        **inputs,
        output_attentions=True,
        return_dict=True,
        use_cache=False
    )

print(type(out.attentions))

<class 'tuple'>


In [28]:
print(f"len attentions : {len(out.attentions)}")      # จำนวนเลเยอร์
print(f"out.attentions : {out.attentions}")           # เลเยอร์สุดท้าย

len attentions : 34
out.attentions : (None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)


✅ สรุปสั้น ๆ
* (None, None, …) = โมเดลนี้ไม่รองรับการดึง attention matrix โดยตรง

* แม้จะใส่ output_attentions=True ก็จะได้ None เพราะทีม Google ไม่ expose ค่าเหล่านี้ออกมาใน checkpoint