# 🔍 MedGemma Basic Generation Test

Debug why the model generates nothing.

In [None]:
!pip install -q transformers torch accelerate bitsandbytes huggingface_hub

In [None]:
# Authenticate this notebook session with the Hugging Face Hub before any
# model files are requested.
# NOTE(review): presumably needed because the model repo loaded later is
# gated — confirm against the model card.
from huggingface_hub import login
login()  # no arguments passed; presumably prompts for a token — verify

In [None]:
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Best-effort: bring up a single-process "gloo" process group before loading.
# NOTE(review): presumably a workaround for a code path that expects
# torch.distributed to be initialized — confirm it is still needed.
try:
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method="file:///tmp/basic_test",
            rank=0,
            world_size=1,
        )
except Exception as exc:
    # Keep this step best-effort, but report the failure instead of hiding it.
    # A bare `except:` would also swallow KeyboardInterrupt / SystemExit.
    print(f"Warning: torch.distributed init skipped ({type(exc).__name__}: {exc})")

MODEL_ID = "google/medgemma-4b-it"
print(f"Loading {MODEL_ID}...")

# 4-bit NF4 weight quantization with bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# device_map={"": 0} maps the root module (i.e. the whole model) to device 0.
# NOTE(review): trust_remote_code=True permits running code shipped in the
# model repo — confirm it is actually required for this checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map={"": 0},
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
print("✅ Model loaded!")

# Print the special tokens and their ids so empty generations can be checked
# against them while debugging.
print(f"Pad token: {tokenizer.pad_token} ({tokenizer.pad_token_id})")
print(f"EOS token: {tokenizer.eos_token} ({tokenizer.eos_token_id})")
print(f"BOS token: {tokenizer.bos_token} ({tokenizer.bos_token_id})")

In [None]:
# Test 1: render a one-turn conversation through the tokenizer's own chat
# template, keeping the result as text (tokenize=False) so it can be inspected.
messages = [
    dict(role="user", content="Hello, can you help me? Say 'yes' if you can."),
]

prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

# Show the repr so control tokens and newlines in the template are visible,
# followed by one blank line.
print("PROMPT (native template):", repr(prompt), "", sep="\n")

In [None]:
# Generate a short completion for the Test 1 prompt and report token counts.
#
# `prompt` is the string rendered by apply_chat_template(tokenize=False), so
# the template has already written its special tokens into the text. Tokenize
# with add_special_tokens=False so the tokenizer does not add them a second
# time (e.g. a duplicated BOS), which can degrade or empty the generation.
inputs = tokenizer(
    prompt,
    return_tensors="pt",
    add_special_tokens=False,
).to(model.device)
prompt_len = inputs["input_ids"].shape[1]
print(f"Input tokens: {prompt_len}")

# Inference only: no gradients needed.
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=30,
        do_sample=False,  # no sampling, so repeated runs are comparable
    )

# The returned sequences include the prompt, so the difference is the number
# of newly generated tokens.
print(f"Output tokens: {outputs.shape[1]}")
print(f"New tokens: {outputs.shape[1] - prompt_len}")

In [None]:
# Decode the first returned sequence in full (prompt + completion), with
# special tokens stripped.
generated_sequence = outputs[0]
full_output = tokenizer.decode(generated_sequence, skip_special_tokens=True)
print("\nFULL OUTPUT:", full_output, sep="\n")

# Slice off the prompt tokens to isolate what the model actually produced.
prompt_token_count = inputs["input_ids"].shape[1]
new_tokens = generated_sequence[prompt_token_count:]
new_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

# Quote the text so an empty or whitespace-only completion is visible.
print("\nNEW TOKENS ONLY:", f"'{new_text}'", sep="\n")

In [None]:
# Test 2: simple medical question, rendered with the native chat template.
messages2 = [
    {"role": "user", "content": "What is a normal blood pressure reading?"}
]

prompt2 = tokenizer.apply_chat_template(
    messages2,
    tokenize=False,
    add_generation_prompt=True,
)

# The rendered template text already carries its special tokens, so tokenize
# with add_special_tokens=False to avoid adding them twice (e.g. a duplicated
# BOS), which can degrade generation.
inputs2 = tokenizer(
    prompt2,
    return_tensors="pt",
    add_special_tokens=False,
).to(model.device)

# Inference only: no gradients needed.
with torch.no_grad():
    outputs2 = model.generate(**inputs2, max_new_tokens=50, do_sample=False)

# Drop the prompt tokens so only the newly generated text is decoded.
new_tokens2 = outputs2[0][inputs2["input_ids"].shape[1]:]
response2 = tokenizer.decode(new_tokens2, skip_special_tokens=True)
print("\nMedical question response:")
print(response2)

In [None]:
# Test 3: one-shot function-calling prompt, sent as a single user turn through
# the native chat template.
fc_prompt = """Convert this clinical note into a function call.

Example:
Input: BP 120/80, pulse 72
Output: record_vitals(systolic=120, diastolic=80, heart_rate=72)

Now convert:
Input: BP is 110/70, pulse 68
Output:"""

messages3 = [{"role": "user", "content": fc_prompt}]
prompt3 = tokenizer.apply_chat_template(
    messages3,
    tokenize=False,
    add_generation_prompt=True,
)

# The rendered template text already carries its special tokens, so tokenize
# with add_special_tokens=False to avoid adding them twice (e.g. a duplicated
# BOS), which can degrade generation.
inputs3 = tokenizer(
    prompt3,
    return_tensors="pt",
    add_special_tokens=False,
).to(model.device)

# Inference only: no gradients needed.
with torch.no_grad():
    outputs3 = model.generate(**inputs3, max_new_tokens=50, do_sample=False)

# Drop the prompt tokens so only the newly generated text is decoded.
new_tokens3 = outputs3[0][inputs3["input_ids"].shape[1]:]
response3 = tokenizer.decode(new_tokens3, skip_special_tokens=True)
print("\nFunction calling response:")
print(response3)