To run this notebook, click **Runtime** > **Run all**.

[![GitHub](https://img.shields.io/badge/GitHub-ART-blue?logo=github)](https://github.com/OpenPipe/ART)
[![Discord](https://img.shields.io/badge/Discord-Join-7289da?logo=discord&logoColor=white)](https://discord.gg/zbBHRUpwf4)
[![Docs](https://img.shields.io/badge/Docs-ART-green)](https://docs.art-e.dev/fundamentals/sft-training)

This notebook demonstrates **distillation** with ART — training a small model on completions generated by a larger teacher model.

We'll distill a text-to-SQL capability: a large teacher model generates SQL queries from natural-language questions over an e-commerce schema, and a smaller student model learns to produce queries of similar quality. This is a common production pattern: large models handle complex joins, subqueries, and aggregations well, but are often too slow or expensive for real-time use.

For training from a static JSONL dataset, see the [SFT notebook](https://github.com/OpenPipe/ART/blob/main/examples/sft/sft.ipynb).

Completions and metrics will be logged to [Weights & Biases](https://wandb.ai).

### Installation

In [None]:
# %%capture
!uv pip install openpipe-art==0.5.9 openai --prerelease allow --no-cache-dir

### Environment Variables

Set `WANDB_API_KEY` below (get a key at [wandb.ai](https://wandb.ai/settings)). The key is used to call the teacher and judge models through [W&B Inference](https://wandb.ai/site/inference) and to authenticate with the serverless training backend.

In [None]:
import os

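WANDB_API_KEY = ""  # required: paste your key here, or set WANDB_API_KEY in the environment beforehand

if WANDB_API_KEY:
    os.environ["WANDB_API_KEY"] = WANDB_API_KEY

# Fail early with a clear message instead of a KeyError in a later cell.
if not os.environ.get("WANDB_API_KEY"):
    raise ValueError("WANDB_API_KEY is not set. Get a key at https://wandb.ai/settings")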

### Generate Teacher Completions

We define an e-commerce database schema and a set of natural language questions ranging from simple lookups to complex aggregations with subqueries. The teacher model generates SQL for each question via [W&B Inference](https://wandb.ai/site/inference), and we collect these as training trajectories for the student.

In [None]:
import asyncio
import os

from openai import AsyncOpenAI

import art

TEACHER_MODEL = "zai-org/GLM-5-FP8"

teacher_client = AsyncOpenAI(
    api_key=os.environ["WANDB_API_KEY"],
    base_url="https://api.inference.wandb.ai/v1",
)

system_prompt = """You are a SQL expert. Given a database schema and a natural language question, write a single SQL query that answers the question. Return only the SQL query, no explanation.

Schema:
  customers (id INT, name TEXT, email TEXT, signup_date DATE, plan TEXT)
  orders (id INT, customer_id INT, total DECIMAL, status TEXT, created_at TIMESTAMP)
  order_items (id INT, order_id INT, product_id INT, quantity INT, unit_price DECIMAL)
  products (id INT, name TEXT, category TEXT, price DECIMAL, stock INT)

Notes:
  - orders.status is one of: 'pending', 'shipped', 'delivered', 'cancelled'
  - customers.plan is one of: 'free', 'pro', 'enterprise'
  - order_items.unit_price is the price at time of purchase (may differ from products.price)"""

questions = [
    "Which customers signed up in the last 30 days?",
    "What's the total revenue by product category for Q4 2024?",
    "Find customers who placed more than 5 orders but never bought anything in the 'Electronics' category.",
    "What's the average order value for each customer plan tier, excluding cancelled orders?",
    "List the top 10 products by revenue that have less than 50 units in stock.",
    "Find customers whose total spending exceeds the average customer spending by more than 2x.",
    "What percentage of orders from the last month are still in 'pending' status, broken down by customer plan?",
    "Show the month-over-month growth rate of new customer signups for the past 12 months.",
    "Which products have been ordered together in the same order more than 10 times?",
    "Find customers who haven't placed an order in the last 90 days but previously ordered at least once a month for 3 consecutive months.",
    "What is the customer lifetime value (total spending) for each signup cohort month?",
    "Rank product categories by their return customer rate (customers who bought from the same category more than once).",
]


async def get_teacher_completion(question: str) -> art.Trajectory:
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]
    # Ask the teacher for a SQL query; its answer becomes the student's training target.
    completion = await teacher_client.chat.completions.create(
        model=TEACHER_MODEL,
        messages=messages,
    )
    response = completion.choices[0].message.content
    print(f"Q: {question}\n{response}\n")
    # Store the whole conversation (system + user + teacher answer) as one SFT example.
    return art.Trajectory(
        messages_and_choices=[
            *messages,
            {"role": "assistant", "content": response},
        ],
    )


# Query the teacher for all questions concurrently.
trajectories = list(await asyncio.gather(*[get_teacher_completion(q) for q in questions]))

print(f"Generated {len(trajectories)} trajectories from teacher model.")
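
The system prompt asks for only the SQL query, but a teacher can still wrap its answer in a markdown fence or reference a column that is not in the schema. Filtering the collected completions before training may help. The sketch below is not part of ART. It assumes the teacher writes SQL that SQLite can parse and uses Python's built-in `sqlite3` to compile each query against an empty copy of the schema with `EXPLAIN`. Queries that use dialect-specific syntax (for example PostgreSQL's `DATE_TRUNC`) are rejected even when they are correct, so treat this as a rough filter. `extract_sql` and `compiles_against_schema` are hypothetical helper names.

```python
import re
import sqlite3

# DDL mirroring the schema in `system_prompt`. SQLite accepts these type names
# through type affinity but does not enforce them.
SCHEMA_DDL = """
CREATE TABLE customers (id INT, name TEXT, email TEXT, signup_date DATE, plan TEXT);
CREATE TABLE orders (id INT, customer_id INT, total DECIMAL, status TEXT, created_at TIMESTAMP);
CREATE TABLE order_items (id INT, order_id INT, product_id INT, quantity INT, unit_price DECIMAL);
CREATE TABLE products (id INT, name TEXT, category TEXT, price DECIMAL, stock INT);
"""

# Matches a ```sql ... ``` (or bare ``` ... ```) fence and captures its body.
FENCE_RE = re.compile(r"```(?:sql)?\s*(.*?)```", re.DOTALL | re.IGNORECASE)


def extract_sql(text: str) -> str:
    """Return the SQL inside a markdown fence if there is one, else the stripped text."""
    match = FENCE_RE.search(text)
    return (match.group(1) if match else text).strip()


def compiles_against_schema(sql: str) -> bool:
    """True if SQLite can prepare the query against an empty copy of the schema."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(SCHEMA_DDL)
        # EXPLAIN forces SQLite to compile the statement (resolving tables,
        # columns, and functions) without actually running it.
        conn.execute("EXPLAIN " + sql)
        return True
    except (sqlite3.Error, sqlite3.Warning):
        # sqlite3.Warning covers "one statement at a time" on older Pythons.
        return False
    finally:
        conn.close()
```

In the notebook, you could apply `extract_sql` to each trajectory's final assistant message (`t.messages_and_choices[-1]["content"]`) and keep only the trajectories whose query compiles. If the teacher writes PostgreSQL-flavored SQL, loosening or skipping this check may be the better choice.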

### Training

Use `create_sft_dataset_iterator` to fine-tune the student on the teacher's outputs. It computes the learning-rate schedule once over the full dataset and then yields the data in chunks, so each `train_sft` call logs its own metrics without restarting the schedule.

In [None]:
from art.serverless.backend import ServerlessBackend
from art.utils.sft import create_sft_dataset_iterator

backend = ServerlessBackend()
student = art.TrainableModel(
    name="distillation-text-to-sql",
    project="sft-distillation",
    base_model="OpenPipe/Qwen3-14B-Instruct",
)
await student.register(backend)

for chunk in create_sft_dataset_iterator(trajectories, epochs=3, peak_lr=2e-4):
    await student.train_sft(chunk.trajectories, chunk.config)

print("Training complete!")
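
Computing the schedule over the whole dataset means later chunks continue the schedule where earlier ones stopped, rather than restarting at `peak_lr` on every `train_sft` call. The plain-Python sketch below illustrates that idea. It is not ART's implementation: it assumes a simple linear decay from `peak_lr` toward zero, a hypothetical batch size and chunk size, and hypothetical names (`Chunk`, `iterate_chunks`). The actual schedule shape and chunking in `create_sft_dataset_iterator` may differ.

```python
from dataclasses import dataclass
from typing import Iterator, Sequence


@dataclass
class Chunk:
    items: list                 # training examples in this chunk
    learning_rates: list[float]  # one learning rate per batch in this chunk


def iterate_chunks(
    data: Sequence,
    epochs: int,
    peak_lr: float,
    batch_size: int = 2,         # hypothetical
    batches_per_chunk: int = 3,  # hypothetical
) -> Iterator[Chunk]:
    # Lay out every batch for every epoch first, so the total schedule length is known.
    batches = [
        list(data[i : i + batch_size])
        for _ in range(epochs)
        for i in range(0, len(data), batch_size)
    ]
    total = len(batches)
    # Assumed schedule: linear decay from peak_lr toward 0 over the whole run.
    lrs = [peak_lr * (1 - step / total) for step in range(total)]
    # Yield consecutive slices; each keeps its position in the global schedule.
    for start in range(0, total, batches_per_chunk):
        end = start + batches_per_chunk
        yield Chunk(
            items=[x for batch in batches[start:end] for x in batch],
            learning_rates=lrs[start:end],
        )
```

With the notebook's 12 trajectories, `epochs=3`, and these hypothetical sizes, the run has 18 batches split into 6 chunks. The second chunk starts just below the first chunk's final learning rate instead of jumping back to `peak_lr`.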

### RL Training with RULER

Now we'll improve the SFT-trained student using **reinforcement learning**. Instead of imitating the teacher's outputs, the model learns to maximize a reward signal from an LLM judge.

[RULER](https://docs.art-e.dev/fundamentals/ruler) scores multiple completions for the same question *relative to each other*, so no hand-crafted reward function is needed. As the judge, we use the same GLM-5 model that served as the teacher, again through W&B Inference.
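
Why can relative scores be enough? A common way to turn a group's scores into a training signal, used in GRPO-style methods, is to center each score on the group mean and divide by the group's standard deviation. The sketch below shows that normalization with made-up scores. It illustrates the general technique and is not ART's internal code; ART's exact advantage computation may differ. It also shows why each question is sampled several times (`NUM_GENERATIONS` in the loop below): if every completion in a group gets the same score, every advantage is zero and the group carries no reward signal.

```python
from statistics import mean, pstdev


def group_advantages(scores: list[float], eps: float = 1e-6) -> list[float]:
    """Center scores on the group mean and scale by the group's std (GRPO-style)."""
    mu = mean(scores)
    sigma = pstdev(scores)
    return [(s - mu) / (sigma + eps) for s in scores]


# Hypothetical judge scores for 4 completions of one question.
print(group_advantages([0.9, 0.6, 0.6, 0.3]))
# Shifting every score by the same amount leaves the advantages unchanged
# (up to floating-point rounding): only scores relative to each other matter.
print(group_advantages([0.8, 0.5, 0.5, 0.2]))
# Identical scores give zero advantage everywhere: no reward signal for this group.
print(group_advantages([0.7, 0.7, 0.7, 0.7]))
```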

In [None]:
from art.rewards import ruler_score_group

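# RULER calls the judge through LiteLLM. The "openai/" prefix tells LiteLLM to use
# its OpenAI-compatible client against the custom api_base below.
JUDGE_MODEL = "openai/zai-org/GLM-5-FP8"
judge_litellm_params = {
    "api_base": "https://api.inference.wandb.ai/v1",  # W&B Inference endpoint
    "api_key": os.environ["WANDB_API_KEY"],
}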

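# OpenAI-compatible client for the student (used for RL rollouts and for inference below).
client = student.openai_client()


async def rollout(question: str) -> art.Trajectory:
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]
    completion = await client.chat.completions.create(
        model=student.get_inference_name(),
        messages=messages,
    )
    # Append the raw Choice returned by the student as the assistant turn.
    # No reward is set here; RULER scores the trajectory together with its group.
    return art.Trajectory(
        messages_and_choices=[*messages, completion.choices[0]],
    )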

In [None]:
NUM_GENERATIONS = 4  # trajectories per question (RULER compares these)
RL_STEPS = 3

for step in range(RL_STEPS):
    groups = await art.gather_trajectory_groups(
        (
            art.TrajectoryGroup(
                rollout(question) for _ in range(NUM_GENERATIONS)
            )
            for question in questions
        ),
        after_each=lambda group: ruler_score_group(
            group,
            judge_model=JUDGE_MODEL,
            extra_litellm_params=judge_litellm_params,
            swallow_exceptions=True,
        ),
        pbar_desc=f"rl step {step}",
    )

    result = await backend.train(student, groups, learning_rate=5e-6)
    await student.log(groups, metrics=result.metrics, step=result.step, split="train")
    print(f"Step {step}: {result.metrics}")

print("RL training complete!")
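
The loop above rolls out `NUM_GENERATIONS` completions for every question concurrently and scores each group once all of its rollouts finish. The plain-`asyncio` sketch below mirrors that control flow with stand-in functions, so it runs without a model or a judge. `fake_rollout`, `fake_score_group`, and `gather_groups` are hypothetical. They do not reflect how `art.gather_trajectory_groups` is implemented; they only illustrate the shape of the computation: one group per question, scored as a unit.

```python
import asyncio
import random


async def fake_rollout(question: str, rng: random.Random) -> dict:
    # Stand-in for `rollout`: yields to the event loop as a real API call would.
    await asyncio.sleep(0)
    return {"question": question, "completion": f"SELECT {rng.randint(1, 9)};"}


async def fake_score_group(group: list[dict]) -> list[dict]:
    # Stand-in for a relative judge: reward each completion by its rank in the group
    # (here, arbitrarily, by sorting the completion text).
    ranked = sorted(group, key=lambda t: t["completion"])
    for rank, traj in enumerate(ranked):
        traj["reward"] = rank / max(len(ranked) - 1, 1)
    return group


async def gather_groups(question_list: list[str], n: int, seed: int = 0) -> list[list[dict]]:
    rng = random.Random(seed)

    async def one_group(question: str) -> list[dict]:
        # All n rollouts for one question run concurrently...
        group = await asyncio.gather(*(fake_rollout(question, rng) for _ in range(n)))
        # ...and the group is scored only after every rollout in it has finished.
        return await fake_score_group(list(group))

    # Groups for different questions are also gathered concurrently.
    return list(await asyncio.gather(*(one_group(q) for q in question_list)))
```

In the notebook, `demo_groups = await gather_groups(questions, n=NUM_GENERATIONS)` runs it directly, because Jupyter already has an event loop. In a plain script, wrap the call in `asyncio.run(...)`.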

### Inference

Try the trained model with a new question.

In [None]:
completion = await client.chat.completions.create(
    model=student.get_inference_name(),
    messages=[
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": "Which product category has the highest average order value?"},
    ],
)
print(completion.choices[0].message.content)
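
A common way to check a text-to-SQL model is execution accuracy: run the student's query and a reference query (for example, the teacher's answer to the same question) against the same database and compare the result sets. The sketch below does this with Python's built-in `sqlite3` on a tiny hand-made fixture containing two of the schema's tables. The fixture rows, the `results_match` helper, and the assumption that both queries are valid SQLite are all hypothetical. A real evaluation would need a realistic database and a SQL dialect that matches the generated queries.

```python
import sqlite3
from collections import Counter

# Hypothetical fixture: two tables from the schema with a few made-up rows.
SETUP_SQL = """
CREATE TABLE products (id INT, name TEXT, category TEXT, price DECIMAL, stock INT);
CREATE TABLE order_items (id INT, order_id INT, product_id INT, quantity INT, unit_price DECIMAL);
INSERT INTO products VALUES (1, 'Laptop', 'Electronics', 1000, 5), (2, 'Mug', 'Kitchen', 10, 100);
INSERT INTO order_items VALUES (1, 1, 1, 1, 950), (2, 1, 2, 3, 10), (3, 2, 2, 2, 9);
"""


def results_match(candidate_sql: str, reference_sql: str, ordered: bool = False) -> bool:
    """Execute both queries on a fresh fixture DB and compare their result rows."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(SETUP_SQL)
        try:
            got = conn.execute(candidate_sql).fetchall()
        except (sqlite3.Error, sqlite3.Warning):
            return False  # a candidate that fails to run counts as a mismatch
        want = conn.execute(reference_sql).fetchall()
    finally:
        conn.close()
    # Compare as multisets unless row order is part of the answer (e.g. ORDER BY ... LIMIT).
    return got == want if ordered else Counter(got) == Counter(want)
```

In the notebook, you might pass the student's `completion.choices[0].message.content` as `candidate_sql` (after stripping any markdown fence) and a teacher completion for the same question as `reference_sql`.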

---

For more details, see the [SFT Training docs](https://docs.art-e.dev/fundamentals/sft-training) and [RULER docs](https://docs.art-e.dev/fundamentals/ruler). For training from a static dataset, see the [SFT notebook](https://github.com/OpenPipe/ART/blob/main/examples/sft/sft.ipynb). Questions? Join the [Discord](https://discord.gg/zbBHRUpwf4)!