In [1]:
import torch
from jax import Array, numpy as jnp
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizerBase,
)

from qwen import QwenModel, utils, generate

In [2]:
def compare_generation(
    hf_model: torch.nn.Module,
    tokenizer: "PreTrainedTokenizerBase",
    model: QwenModel,
    prompt: str,
    max_tokens: int = 20,
) -> tuple[str, str]:
    """Continue ``prompt`` with the reference model and with ``generate``.

    The prompt is tokenized once; the resulting ids are passed to
    ``hf_model.generate`` as PyTorch tensors and, converted through NumPy to a
    JAX array, to the project's ``generate`` function. Both outputs are decoded
    with the same tokenizer so the returned strings can be compared directly.

    Args:
        hf_model: Reference model; must expose ``generate``. It is switched to
            eval mode as a side effect of this call.
        tokenizer: Tokenizer used to encode ``prompt`` and to decode both
            outputs.
        model: Model handed to ``generate`` together with the JAX input ids.
        prompt: Text to continue.
        max_tokens: Forwarded as ``max_new_tokens`` to ``hf_model.generate``
            and as ``max_tokens`` to ``generate``.

    Returns:
        ``(hf_text, text)``: the decoded first sequence from the reference
        model and the decoded first sequence from ``generate``, both with
        special tokens skipped.
    """
    hf_model.eval()

    # Encode once so both implementations are guaranteed to see identical ids.
    hf_inputs = tokenizer(prompt, return_tensors="pt")

    hf_out_ids = hf_model.generate(**hf_inputs, max_new_tokens=max_tokens)
    hf_text = tokenizer.decode(hf_out_ids[0], skip_special_tokens=True)

    # Reuse the already-encoded ids; torch -> NumPy -> JAX array.
    inputs = jnp.array(hf_inputs.input_ids.numpy())
    out_ids = generate(model, inputs, max_tokens=max_tokens)
    # NOTE(review): indexing [0] assumes `generate` returns batch-first ids,
    # like the HF call above — confirm against qwen.generate.
    text = tokenizer.decode(out_ids[0], skip_special_tokens=True)
    return hf_text, text

In [3]:
# Checkpoint used as the reference for the comparison below.
model_name = "Qwen/Qwen2.5-0.5B"

# Load the tokenizer and the reference model (float32, on CPU).
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    device_map="cpu",
)

# Build the model under test from the loaded reference model.
model = utils.from_hf(hf_model)

# Run both models on the same prompt and show the two continuations,
# reference output first.
hf_text, text = compare_generation(
    hf_model=hf_model,
    tokenizer=tokenizer,
    model=model,
    prompt="Once upon a time,",
)
print(hf_text + " \n" + text)

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Once upon a time, there was a little girl named Lily. She loved to play with her toys and watch the stars. 
Once upon a time, there was a little girl named Lily. She loved to play with her toys and watch the stars.
