In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from jax import numpy as jnp

from qwen import QwenModel, utils, generate

In [2]:
def hf_predict(model, tokenizer, prompt, max_tokens=20):
    """Generate text for ``prompt`` with a Hugging Face model.

    Args:
        model: Model exposing ``generate(**inputs, max_new_tokens=...)`` and a
            ``device`` attribute.
        tokenizer: Callable tokenizer; called with ``return_tensors="pt"`` and
            expected to return a mapping that supports ``.to(device)``. Also
            used to decode the generated ids.
        prompt: Input text to tokenize and feed to ``model.generate``.
        max_tokens: Forwarded to ``generate`` as ``max_new_tokens``.

    Returns:
        The decoded text of the first sequence returned by ``generate``,
        with special tokens skipped.
    """
    # Move the encoded inputs to wherever the model lives; without this the
    # call fails with a device mismatch as soon as the model is not on CPU.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    out_ids = model.generate(**inputs, max_new_tokens=max_tokens)
    return tokenizer.decode(out_ids[0], skip_special_tokens=True)


def jax_predict(model, tokenizer, prompt, max_tokens=20):
    """Generate text for ``prompt`` with the JAX model.

    Args:
        model: Model object handed to ``generate`` as its first argument.
        tokenizer: Callable tokenizer; called with ``return_tensors="np"`` and
            also used to decode the generated ids.
        prompt: Input text to tokenize.
        max_tokens: Forwarded to ``generate`` as ``max_tokens``.

    Returns:
        The decoded text of the first sequence returned by ``generate``,
        with special tokens skipped.
    """
    encoded = tokenizer(prompt, return_tensors="np")
    # Wrap the tokenizer's NumPy ids in a JAX array before generation.
    token_ids = jnp.array(encoded.input_ids)
    generated = generate(model, token_ids, max_tokens=max_tokens)
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

In [3]:
model_name = "Qwen/Qwen2.5-0.5B"
# One prompt shared by both backends so they are guaranteed identical input.
prompt = "Once upon a time,"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reference model: loaded in float32 and placed on CPU.
hf_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float32, device_map="cpu"
)

# Build the JAX model from the loaded Hugging Face model
# (presumably a weight conversion -- see qwen.utils.from_hf).
jax_model = utils.from_hf(hf_model)
hf_text = hf_predict(hf_model, tokenizer, prompt=prompt)
jax_text = jax_predict(jax_model, tokenizer, prompt=prompt)

print(f"{hf_text} \n{jax_text}")

# Fail loudly on a mismatch instead of relying on eyeballing the two lines.
assert hf_text == jax_text, "JAX output diverges from the Hugging Face reference"

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Once upon a time, there was a little girl named Lily. She loved to play with her toys and watch the stars. 
Once upon a time, there was a little girl named Lily. She loved to play with her toys and watch the stars.
