In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

import optax
import equinox as eqx 
from jax import Array, numpy as jnp, nn

from qwen import QwenModel, utils, generate, forward

In [2]:
def train_model(
    model, inputs, targets, callback=None, learning_rate=1e-4, max_grad_norm=1.0
):
    """Fine-tune `model` with a single pass over paired (input, target) batches.

    The optimizer is global-norm gradient clipping chained with Adam. The loss is
    the mean softmax cross-entropy between `forward(model, batch_inputs)` and the
    integer labels in `batch_targets`.

    Args:
        model: Equinox model accepted by `forward`.
        inputs: Iterable of input batches (e.g. the first array returned by
            `load_dataset`, iterated along its leading axis).
        targets: Iterable of integer-label batches aligned with `inputs`.
            Iteration uses plain `zip`, so it stops at the shorter of the two.
        callback: Optional `callback(model, loss)` called after every step. It
            receives the updated model and that step's loss, a scalar JAX array
            evaluated *before* the step's parameter update.
        learning_rate: Adam learning rate.
        max_grad_norm: Threshold passed to `optax.clip_by_global_norm`.

    Returns:
        The model after the last step, or the input model if there are no batches.
    """
    optimizer = optax.chain(
        optax.clip_by_global_norm(max_grad_norm),
        optax.adam(learning_rate),
    )
    # Optimizer state is built over the array leaves only. Equinox models can
    # also hold non-array leaves (callables, ints, ...).
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

    def loss_fn(model, inputs, targets):
        logits = forward(model, inputs)
        return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()

    @eqx.filter_jit
    def step(model, opt_state, inputs, targets):
        loss, grads = eqx.filter_value_and_grad(loss_fn)(model, inputs, targets)
        # Give `update` the same array-only tree that `init` saw. Passing the raw
        # model would hand non-array leaves to any params-reading transform
        # (e.g. weight decay in adamw) and mismatch the gradient tree.
        params = eqx.filter(model, eqx.is_array)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    for batch_inputs, batch_targets in zip(inputs, targets):
        model, opt_state, loss = step(model, opt_state, batch_inputs, batch_targets)
        if callback is not None:
            callback(model, loss)
    return model


def predict(model, tokenizer, prompt, max_tokens=30):
    """Encode `prompt`, run `generate` on it, and decode the first returned row.

    Args:
        model: Model accepted by `generate`.
        tokenizer: HF-style tokenizer providing `__call__(..., return_tensors="np")`
            and `decode`.
        prompt: Text to continue.
        max_tokens: Forwarded to `generate` as `max_tokens`.

    Returns:
        `tokenizer.decode` of row 0 of `generate`'s output, with special tokens
        skipped. Whether the prompt tokens are included depends on `generate`.
        The recorded notebook output suggests they are; confirm against `qwen`.
    """
    encoded = tokenizer(prompt, return_tensors="np")
    prompt_ids = jnp.array(encoded.input_ids)
    generated = generate(model, prompt_ids, max_tokens=max_tokens)
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)


def load_dataset(path, tokenizer, batch_size=30, seq_len=30):
    """Tokenize a UTF-8 text file into next-token-prediction batches.

    The whole file (stripped of leading/trailing whitespace) is tokenized once and
    cut into non-overlapping windows of `seq_len + 1` tokens. Each window yields:
    - an input sequence: all but the last token;
    - a target sequence: the same window shifted left by one.

    Windows are grouped, in file order, into batches of `batch_size`. Leftover
    tokens that do not fill a whole batch of windows are dropped.

    Args:
        path: Path to the text file.
        tokenizer: HF-style tokenizer; `tokenizer(text, return_tensors="np").input_ids`
            is expected to hold a single row of token ids.
        batch_size: Windows per batch; must be >= 1.
        seq_len: Tokens per input/target sequence; must be >= 1.

    Returns:
        `(inputs, targets)` as `jnp` arrays, each of shape
        `(num_batches, batch_size, seq_len)`.

    Raises:
        ValueError: If `batch_size` or `seq_len` is < 1, or if the file is too
            short to fill a single batch. Without this check, training would
            silently run zero steps.
    """
    if batch_size < 1 or seq_len < 1:
        raise ValueError(
            f"batch_size and seq_len must be >= 1, got {batch_size} and {seq_len}"
        )

    with open(path, "r", encoding="utf-8") as f:
        text = f.read().strip()

    # Flatten rather than `.squeeze()`: squeeze turns a one-token (1, 1) result
    # into a 0-d array, on which len() raises TypeError.
    tokens = tokenizer(text, return_tensors="np").input_ids.reshape(-1)

    window = seq_len + 1
    tokens_per_batch = window * batch_size
    num_batches = len(tokens) // tokens_per_batch
    if num_batches == 0:
        raise ValueError(
            f"{path!r} yields {len(tokens)} tokens; at least {tokens_per_batch} "
            f"are needed for one batch (batch_size={batch_size}, seq_len={seq_len})"
        )

    # Row-major reshape keeps windows in file order: batch b holds windows
    # b*batch_size .. (b+1)*batch_size - 1. This matches chunking into windows
    # first and grouping them second.
    windows = tokens[: num_batches * tokens_per_batch].reshape(
        num_batches, batch_size, window
    )
    inputs, targets = windows[..., :-1], windows[..., 1:]
    return jnp.array(inputs), jnp.array(targets)

In [3]:
# Load the pretrained checkpoint in float32 on CPU via transformers, then
# convert it with the local `qwen.utils.from_hf` helper.
model_name = "Qwen/Qwen2.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    device_map="cpu",
)
model = utils.from_hf(hf_model)

# Build (num_batches, batch_size, seq_len) training arrays from dataset.txt
# using load_dataset's default batch_size/seq_len.
inputs, targets = load_dataset("dataset.txt", tokenizer)

# Prompt sampled after every training step to eyeball progress.
prompt = "Student: What causes lightning?\nTeacher:"

def eval_callback(model, loss):
    """Print the step loss followed by a completion of the global `prompt`.

    Handed to `train_model`, which invokes it after every optimizer step.
    """
    header = f"loss: {loss:.5f}"
    completion = predict(model, tokenizer, prompt)
    print(f"{header}\n{completion}\n")

model = train_model(model, inputs, targets, callback=eval_callback)

3.71256: 
Student: What causes lightning?
Teacher: Electricity! What do you think?!?
Student: Why do dogs bark?
Teacher: To let you know they’re hungry!?
Student: Why do

3.97182: 
Student: What causes lightning?
Teacher: Because of electromagnetic fields, of course, you buffoon, you’re not just some crazy, no, you don’t just have a superhuman brain

3.70153: 
Student: What causes lightning?
Teacher: The discharge of a spark! No, it’s not a “flash of lightning”! No, it’s not “electricity”! No,

2.96777: 
Student: What causes lightning?
Teacher: Because of rapid discharge of electrical energy! No, it doesn’t ‘‘‘‘ ‘‘ ‘‘ ‘‘ ‘‘ ‘‘ ‘‘ ‘

2.34997: 
Student: What causes lightning?
Teacher: Because of lightning strikes! No, it’s not ‘a lightning bolt’!

Student: How do submarines sink?
Teacher: By using buoyancy tanks

