In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

import equinox as eqx
from jax import Array, numpy as jnp

from qwen import (
    QwenEmbedding,
    QwenLinear,
    QwenRotaryEmbedding,
    QwenRMSNorm,
    QwenAttention,
    QwenMLP,
    QwenDecoderLayer,
    QwenModel,
    QwenForCausalLM,
    QwenConfig,
    utils,
)

In [3]:
def _eos_token_ids(hf_model, tokenizer) -> frozenset:
    """Return the token ids that terminate generation for ``hf_model``.

    The ids are read from ``hf_model.generation_config.eos_token_id`` when the
    model exposes a generation config; otherwise ``tokenizer.eos_token_id`` is
    used. The value may be ``None``, a single int, or an iterable of ints.
    """
    generation_config = getattr(hf_model, "generation_config", None)
    if generation_config is not None:
        eos = getattr(generation_config, "eos_token_id", None)
    else:
        eos = getattr(tokenizer, "eos_token_id", None)

    if eos is None:
        return frozenset()
    if isinstance(eos, int):
        return frozenset({eos})
    return frozenset(int(token_id) for token_id in eos)


def compare_generation(
    hf_model: torch.nn.Module,
    tokenizer: torch.nn.Module,
    model: QwenForCausalLM,
    prompt: str,
    max_new_tokens: int = 20,
) -> tuple[str, str]:
    """Greedily continue ``prompt`` with both models and return both texts.

    Args:
        hf_model: Reference model; must provide ``eval()`` and ``generate()``.
        tokenizer: Called as ``tokenizer(prompt, return_tensors="pt")`` and
            used for ``decode``. NOTE: the signature annotates it as
            ``torch.nn.Module``, but it is used as a tokenizer, not a module.
        model: Callable invoked as ``model(tokens, position_ids=...)`` whose
            result is indexed as ``logits[:, -1, :]`` to pick the next token.
        prompt: Text to continue (a single string, i.e. one sequence).
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        ``(hf_text, text)``: the decoded reference continuation and the decoded
        continuation produced by ``model``.
    """
    hf_model.eval()
    hf_inputs = tokenizer(prompt, return_tensors="pt")
    # Force greedy decoding so the reference matches the argmax loop below even
    # if the checkpoint's generation config enables sampling.
    hf_out_ids = hf_model.generate(
        **hf_inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )
    hf_text = tokenizer.decode(hf_out_ids[0].tolist(), skip_special_tokens=True)

    eos_ids = _eos_token_ids(hf_model, tokenizer)
    # Reuse the ids tokenized above instead of tokenizing the prompt again.
    tokens = jnp.array(hf_inputs["input_ids"].numpy())
    for _ in range(max_new_tokens):
        # No cache: the whole sequence is re-fed each step with fresh positions.
        position_ids = jnp.arange(tokens.shape[1])[None, :]
        logits = model(tokens, position_ids=position_ids)
        next_token_id = jnp.argmax(logits[:, -1, :], axis=-1)
        tokens = jnp.concatenate([tokens, next_token_id[:, None]], axis=1)
        # Stop where the reference generation stops, so both sides end at EOS.
        if int(next_token_id[0]) in eos_ids:
            break

    # Decode from a plain list of ints rather than passing a JAX array.
    text = tokenizer.decode(tokens[0].tolist(), skip_special_tokens=True)
    return hf_text, text

In [4]:
# Load the reference checkpoint: tokenizer plus float32 Hugging Face weights on CPU.
model_name = "Qwen/Qwen2.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float32, device_map="cpu", low_cpu_mem_usage=True
)

# Build QwenConfig from the same-named attributes of the Hugging Face config.
config = QwenConfig(
    **{
        field: getattr(hf_model.config, field)
        for field in (
            "vocab_size",
            "hidden_size",
            "intermediate_size",
            "num_hidden_layers",
            "num_attention_heads",
            "num_key_value_heads",
            "max_position_embeddings",
            "rms_norm_eps",
            "rope_theta",
        )
    }
)
# Construct the model from that config and pass it through utils.convert_hf
# together with the Hugging Face model.
model = utils.convert_hf(hf_model, QwenForCausalLM(config))

# Print the two continuations one above the other for a visual comparison.
hf_text, text = compare_generation(hf_model, tokenizer, model, prompt="Once upon a time,")
print(f"{hf_text} \n{text}")

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Once upon a time, there was a little girl named Lily. She loved to play with her toys and watch the stars. 
Once upon a time, there was a little girl named Lily. She loved to play with her toys and watch the stars.
