In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

import optax
import equinox as eqx 
from jax import Array, numpy as jnp, nn

from qwen import QwenModel, utils, generate, forward

In [2]:
def train_model(model, dataset, batch_size=10, epochs=1, lr=1e-4):
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

    def loss_fn(model, tokens):
        logits = forward(model, tokens)
        logp = nn.log_softmax(logits[:, :-1, :], axis=-1)
        return -jnp.take_along_axis(logp, tokens[:, 1:, None], axis=-1).mean()

    @eqx.filter_jit
    def step(model, opt_state, tokens):
        loss, grads = eqx.filter_value_and_grad(loss_fn)(model, tokens)
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    for _ in range(epochs):
        for i in range(0, len(dataset), batch_size):
            batch = dataset[i : i + batch_size]
            model, opt_state, loss = step(model, opt_state, batch)
            print(f"loss: {loss}")
    return model


def predict(model, tokenizer, prompt, max_tokens=2):
    inputs = jnp.array(tokenizer(prompt, return_tensors="np").input_ids)
    out_ids = generate(model, inputs, max_tokens=max_tokens)
    return tokenizer.decode(out_ids[0], skip_special_tokens=True)


def create_dataset(text, tokenizer, repeat=40, seq_len=6):
    text = (text * repeat).strip()
    tokens = tokenizer(text, return_tensors="np").input_ids.squeeze()
    tokens = tokens[: (len(tokens) // seq_len) * seq_len]
    return jnp.array(tokens.reshape(-1, seq_len))

In [3]:
model_name = "Qwen/Qwen2.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float32, device_map="cpu"
)
model = utils.from_hf(hf_model)
dataset = create_dataset("Two times two is ten.", tokenizer)

pre_text = predict(model, tokenizer, "Two times two is")
model = train_model(model, dataset, epochs=1)
post_text = predict(model, tokenizer, "Two times two is")
print(f"{pre_text} \n{post_text}")

loss: 7.405213356018066
loss: 0.638215959072113
loss: 0.006730901543051004
loss: 0.21361371874809265
Two times two is 4 
Two times two is ten.T
