To train this document summarization agent, click **Runtime** > **Run all**. Don't forget to set your environment variables.

[![GitHub](https://img.shields.io/badge/GitHub-ART-blue?logo=github)](https://github.com/OpenPipe/ART)
[![Discord](https://img.shields.io/badge/Discord-Join-7289da?logo=discord&logoColor=white)](https://discord.gg/zbBHRUpwf4)
[![Docs](https://img.shields.io/badge/Docs-ART-green)](https://art.openpipe.ai)

**Document Summarization Agent**

In this notebook, you will use [ART](https://github.com/openpipe/art) to train a document summarization agent using SFT+RL (supervised fine-tuning followed by reinforcement learning). First, the model is warmed up via distillation from a larger teacher model (SFT). It is then further improved through RL with a judge-based reward signal.

The agent learns to summarize documents into 350 words or fewer while maximizing the number of questions that can be answered from the summary alone. Documents come from the [Repliqa](https://huggingface.co/datasets/ServiceNow/repliqa) dataset, and [Gemini 3 Flash](https://ai.google.dev/gemini-api/docs/models#gemini-3-flash) is used as a judge to evaluate whether questions can be answered correctly from the summary.

**Baselines** (average % of questions answered correctly from summary):
- GPT-4o: 38%
- GPT-4.1: 45%
- Gemini-2.5-pro: 36%
- Claude Sonnet-4: 57%

The goal is to train a model that outperforms all of them.

Now let's get started!

### Installation

In [None]:
!uv pip install openpipe-art datasets regex async-lru

### Environment Variables

This notebook uses the **ServerlessBackend** for training and inference, which requires a Weights & Biases API key.

We also need an **OpenRouter API key** to access Gemini 3 Flash, which serves as the judge model for evaluating summary quality.

In [None]:
import os

# Required for Gemini 3 Flash judge model
os.environ["OPENROUTER_API_KEY"] = ""

# Required for serverless training
os.environ["WANDB_API_KEY"] = ""

if not os.environ.get("OPENROUTER_API_KEY"):
    raise ValueError(
        "OPENROUTER_API_KEY is required for the Gemini 3 Flash judge model."
    )

if not os.environ.get("WANDB_API_KEY"):
    raise ValueError(
        "WANDB_API_KEY is required for inference, training, and logging to Weights & Biases."
    )

### Loading Documents

We use the [Repliqa](https://huggingface.co/datasets/ServiceNow/repliqa) dataset, which contains 3591 documents each paired with 5 questions and answers. The dataset is split into a validation set (91 documents) and a training set (3500 documents).

In [None]:
import random
from typing import Dict, List, Tuple

from datasets import load_dataset
from pydantic import BaseModel


class Question(BaseModel):
    q: str
    a: str


class Document(BaseModel):
    document_text: str
    questions: List[Question]


def load_documents(
    val_size: int = 91, train_size: int = 3500
) -> Tuple[List[Document], List[Document]]:
    ds = load_dataset("ServiceNow/repliqa")
    documents: Dict[str, Document] = {}

    for data in ds["repliqa_0"]:
        if data["document_id"] not in documents:
            documents[data["document_id"]] = Document(
                document_text=data["document_extracted"],
                questions=[],
            )
        documents[data["document_id"]].questions.append(
            Question(q=data["question"], a=data["answer"])
        )

    all_documents = list(documents.values())

    random.seed(80)
    random.shuffle(all_documents)

    if train_size + val_size > len(all_documents):
        raise ValueError(
            f"Train size + val size ({train_size + val_size}) is greater than "
            f"the total number of documents ({len(all_documents)})"
        )

    val_documents = all_documents[:val_size]
    train_documents = all_documents[val_size : val_size + train_size]

    print(f"Loaded {len(all_documents)} documents")
    print(f"Train set size: {len(train_documents)}")
    print(f"Val set size: {len(val_documents)}")

    return val_documents, train_documents


val_documents, train_documents = load_documents()
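
As a toy illustration of what `load_documents` does, the sketch below groups question rows by document ID and then takes a seeded validation/training split. The rows are made up for the example (hypothetical data, not taken from Repliqa), and the sizes are shrunk so it runs instantly.

```python
import random
from collections import defaultdict

# Hypothetical rows: one row per (document, question) pair.
rows = [
    {"document_id": f"doc{d}", "question": f"q{d}-{i}", "answer": f"a{d}-{i}"}
    for d in range(10)
    for i in range(5)
]

# Group the question/answer pairs under their document ID.
grouped = defaultdict(list)
for row in rows:
    grouped[row["document_id"]].append((row["question"], row["answer"]))

doc_ids = list(grouped)

# A dedicated Random instance keeps the shuffle reproducible without
# touching the global random state.
rng = random.Random(80)
rng.shuffle(doc_ids)

val_size, train_size = 2, 8
val_ids = doc_ids[:val_size]
train_ids = doc_ids[val_size : val_size + train_size]

print(len(val_ids), len(train_ids))
```

Because the validation slice is taken before the training slice from the same shuffled list, the two sets cannot share a document.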

### Creating a Model

We'll use a Qwen 3 14B model as the base, trained via the **ServerlessBackend**, which handles GPU provisioning, inference, and training through Weights & Biases.

In [None]:
import art
from art.serverless import ServerlessBackend

backend = ServerlessBackend()

model = art.TrainableModel(
    name="summarizer-002",
    project="summarize",
    base_model="OpenPipe/Qwen3-14B-Instruct",
)

await model.register(backend)

### SFT Warmup via Distillation

Before RL training, we warm up the student model by distilling from a larger teacher model (**Qwen3-235B**). The teacher generates summaries for the first 200 training documents, and the student learns to imitate them via supervised fine-tuning. These 200 documents will be excluded from the subsequent RL training to avoid data overlap.

In [None]:
import asyncio

from openai import AsyncOpenAI

TEACHER_MODEL = "Qwen/Qwen3-235B-A22B-Instruct-2507"
SFT_NUM_EXAMPLES = 200

teacher_client = AsyncOpenAI(
    api_key=os.environ["WANDB_API_KEY"],
    base_url="https://api.inference.wandb.ai/v1",
)

sft_documents = train_documents[:SFT_NUM_EXAMPLES]

system_prompt = "You are a specialized AI assistant that generates concise, informative summaries for documents."


async def get_teacher_completion(document: Document) -> art.Trajectory:
    summarize_prompt = (
        f"Here is a document: {document.document_text}\n\n"
        "Generate a summary that conveys all relevant information in a concise manner."
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": summarize_prompt},
    ]
    completion = await teacher_client.chat.completions.create(
        model=TEACHER_MODEL,
        messages=messages,
        max_tokens=1000,
    )
    response = completion.choices[0].message.content
    print(f"Generated summary ({len(response.split())} words)")
    return art.Trajectory(
        messages_and_choices=[
            *messages,
            {"role": "assistant", "content": response},
        ],
    )


# Generate teacher completions for the first 200 training documents
sft_trajectories = await asyncio.gather(
    *[get_teacher_completion(doc) for doc in sft_documents]
)
sft_trajectories = list(sft_trajectories)

print(f"\nGenerated {len(sft_trajectories)} teacher trajectories for SFT warmup.")
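
The cell above launches all teacher requests at once through `asyncio.gather`. If the inference endpoint rate-limits, one option is to cap concurrency with an `asyncio.Semaphore`, the same pattern the judge client uses later in this notebook. The following is a minimal sketch under stated assumptions: `fake_teacher_call` is a hypothetical stub standing in for `get_teacher_completion`, and the limit of 8 is arbitrary.

```python
import asyncio
from typing import List

MAX_CONCURRENT = 8  # arbitrary limit chosen for illustration

in_flight = 0
peak_in_flight = 0


async def fake_teacher_call(doc_id: int) -> str:
    """Hypothetical stand-in for a real teacher request."""
    global in_flight, peak_in_flight
    in_flight += 1
    peak_in_flight = max(peak_in_flight, in_flight)
    await asyncio.sleep(0.01)  # simulate network latency
    in_flight -= 1
    return f"summary for document {doc_id}"


async def bounded(semaphore: asyncio.Semaphore, doc_id: int) -> str:
    # At most MAX_CONCURRENT coroutines are inside this block at a time.
    async with semaphore:
        return await fake_teacher_call(doc_id)


async def main() -> List[str]:
    semaphore = asyncio.Semaphore(MAX_CONCURRENT)
    # gather preserves input order, so results line up with the documents.
    return await asyncio.gather(*(bounded(semaphore, i) for i in range(50)))


results = asyncio.run(main())
print(len(results), peak_in_flight)
```

Inside a notebook, where an event loop is already running, `await main()` would replace `asyncio.run(main())`.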

In [None]:
from art.utils.sft import create_sft_dataset_iterator

for chunk in create_sft_dataset_iterator(sft_trajectories, peak_lr=2e-4):
    await model.train_sft(chunk.trajectories, chunk.config)

print("SFT warmup complete!")

### Judge Model

We use **Gemini 3 Flash** (via OpenRouter) as the judge model. For each question, the judge:
1. Answers the question using only the summary
2. Answers the question using the full document (as a reference)
3. Compares each generated answer against the dataset's golden answer and scores it as correct (1) or incorrect (0)

Responses are cached (up to 1024 entries) to avoid redundant API calls.

In [None]:
from async_lru import alru_cache

judge_semaphore = asyncio.Semaphore(20)

judge_client = AsyncOpenAI(
    api_key=os.environ["OPENROUTER_API_KEY"],
    base_url="https://openrouter.ai/api/v1",
)


@alru_cache(maxsize=1024)
async def get_judge_completion(
    prompt, temperature=0.0, max_tokens=600, retries=3
) -> str:
    for attempt in range(1, retries + 1):
        try:
            async with judge_semaphore:
                completion = await judge_client.chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    model="google/gemini-3-flash-preview",
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
            return completion.choices[0].message.content.strip()
        except Exception as e:
            if attempt < retries:
                print(
                    f"[Retry {attempt}/{retries}] get_judge_completion failed: {e}. Retrying..."
                )
                await asyncio.sleep(3)
            else:
                print(
                    f"[Failure] get_judge_completion failed after {retries} attempts: {e}"
                )
                return "ERROR: Get judge completion failed"


def clear_judge_cache():
    get_judge_completion.cache_clear()
    print("Judge cache cleared")

### Defining a Rollout

A rollout is a single episode of the agent performing its task. The process:

1. The model receives a document and generates a summary
2. For each of the 5 questions associated with the document:
   - Gemini 3 Flash answers the question using **only the summary**
   - Gemini 3 Flash answers the question using the **full document** (reference)
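   - The judge compares each answer against the golden answer from the dataset and scores it 0 or 1
3. The **reward** is the number of questions answered correctly from the summary (0-5); the full-document score is logged only as a reference metric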

The agent never sees the questions during summarization - it must learn to write summaries that capture all important information.
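
The scoring in steps 2 and 3 can be sketched without any API calls. The helper names (`is_scorable`, `parse_verdict`, `score_rollout`) are hypothetical, the verdict strings are made up, and the character range is a rough standard-library approximation of the `\p{Han}` check used by the rollout in this notebook. Note that `parse_verdict` is stricter than a plain `int()` conversion: anything other than a clean `1` counts as 0.

```python
import re
from typing import Dict, List

MAX_SUMMARY_CHARS = 3000
# Rough approximation of \p{Han}: CJK Unified Ideographs only.
HAN_PATTERN = re.compile(r"[\u4e00-\u9fff]")


def is_scorable(summary: str) -> bool:
    """A summary earns no credit if it contains Han characters or is too long."""
    return not HAN_PATTERN.search(summary) and len(summary) <= MAX_SUMMARY_CHARS


def parse_verdict(raw: str) -> int:
    """Map a judge reply to 0 or 1; anything other than a clean "1" counts as 0."""
    return 1 if raw.strip() == "1" else 0


def score_rollout(
    summary: str, verdicts: List[str], verdicts_full: List[str]
) -> Dict[str, float]:
    """Combine per-question verdicts into the reward and logged metrics."""
    n = len(verdicts_full)
    correct = sum(parse_verdict(v) for v in verdicts) if is_scorable(summary) else 0
    correct_full = sum(parse_verdict(v) for v in verdicts_full)
    return {
        "reward": correct,
        "percent": correct / n,
        "percent_full": correct_full / n,
        "percent_diff": (correct - correct_full) / n,
    }


result = score_rollout(
    "A short English summary.",
    verdicts=["1", "0", "1", "ERROR: Get judge completion failed", "1"],
    verdicts_full=["1", "1", "1", "1", "0"],
)
print(result)
```

A negative `percent_diff` means the summary lost information the judge could still find in the full document.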

In [None]:
import openai
import regex


class SummarizerScenario(BaseModel):
    doc: Document


@art.retry(exceptions=(openai.LengthFinishReasonError,))
async def rollout(
    model: art.Model, scenario: SummarizerScenario
) -> art.Trajectory:
    client = model.openai_client()

    trajectory = art.Trajectory(
        messages_and_choices=[
            {
                "role": "system",
                "content": "You are a specialized AI assistant that generates concise, informative summaries for documents.",
            }
        ],
        reward=0,
        metrics={
            "word_count": 0,
            "len": 0,
            "percent": 0,
            "percent_full": 0,
            "percent_diff": 0,
        },
    )

    summarize_prompt = (
        f"Here is a document: {scenario.doc.document_text}\n\n"
        "Generate a summary that conveys all relevant information in a concise manner."
    )

    trajectory.messages_and_choices.append(
        {"role": "user", "content": summarize_prompt}
    )

    messages = trajectory.messages()
    completion = await client.chat.completions.create(
        model=model.get_inference_name(), messages=messages, max_tokens=1000
    )
    choice = completion.choices[0]
    trajectory.messages_and_choices.append(choice)
    # Fall back to an empty string if the model returned no text content
    summary = choice.message.content or ""

    total_score = 0
    total_score_full = 0
    total_questions = 0

    for question in scenario.doc.questions:
        total_questions += 1

        # Score from summary (skip if summary contains Chinese characters or is too long)
        if not regex.search(r"\p{Han}", summary) and len(summary) <= 3000:
            prompt = (
                f"Here is a document: {summary}\n\n"
                f"Answer this question to the best of your ability in one sentence, "
                f"if the document does not contain the answer, just state so: {question.q}"
            )
            response = await get_judge_completion(prompt)

            judge_prompt = (
                f"Here is a document: {scenario.doc.document_text}\n\n"
                f"Here is a question: {question.q}\n\n"
                f"Here is a generated answer: {response}\n\n"
                f"Here is the golden answer: {question.a}\n\n"
                "If the answers mostly match return a 1, if they do not match return a 0. "
                "Do not return any other text."
            )
            score = await get_judge_completion(judge_prompt)
            try:
                total_score += int(score)
            except ValueError:
                # Non-numeric judge output (including the error string) counts as 0
                pass

        # Score from full document (reference)
        prompt_full = (
            f"Here is a document: {scenario.doc.document_text}\n\n"
            f"Answer this question to the best of your ability in one sentence, "
            f"if the document does not contain the answer, just state so: {question.q}"
        )
        response_full = await get_judge_completion(prompt_full)

        judge_prompt_full = (
            f"Here is a document: {scenario.doc.document_text}\n\n"
            f"Here is a question: {question.q}\n\n"
            f"Here is a generated answer: {response_full}\n\n"
            f"Here is the golden answer: {question.a}\n\n"
            "If the answers mostly match return a 1, if they do not match return a 0. "
            "Do not return any other text."
        )
        score_full = await get_judge_completion(judge_prompt_full)
        try:
            total_score_full += int(score_full)
        except ValueError:
            # Non-numeric judge output (including the error string) counts as 0
            pass

        # Debug logging (5% sample)
        if not regex.search(r"\p{Han}", summary) and len(summary) <= 3000:
            if random.random() < 0.05:
                print(f"Question: {question.q}")
                print(f"Golden: {question.a}")
                print(f"Generated: {response}")
                print(f"Score: {score}, Score-full: {score_full}")
                print()

    trajectory.metrics["percent"] = total_score / total_questions
    trajectory.metrics["percent_full"] = total_score_full / total_questions
    trajectory.metrics["percent_diff"] = (
        trajectory.metrics["percent"] - trajectory.metrics["percent_full"]
    )
    trajectory.metrics["word_count"] = len(summary.split())
    trajectory.metrics["len"] = len(summary)
    trajectory.reward = total_score

    return trajectory


print("Rollout function defined!")

### RL Training Loop

Now we continue with reinforcement learning. The first 200 documents used for SFT distillation are excluded to avoid data overlap, leaving 3300 documents for RL training.

For each training step:
1. Generate training rollouts (4 per batch document)
2. Every 10 steps, also generate validation rollouts (2 per validation document) and log metrics
3. Train the model on the training trajectories

We use `iterate_dataset` to handle batching and epoch management.

In [None]:
from art.utils import iterate_dataset

GROUPS_PER_STEP = 5
ROLLOUTS_PER_GROUP = 4
VAL_ROLLOUTS_PER_GROUP = 2
VAL_STEP_INTERVAL = 10
LEARNING_RATE = 5e-5
NUM_EPOCHS = 1
MAX_RL_STEPS = 1000

# Skip the documents used for SFT distillation
rl_train_documents = train_documents[SFT_NUM_EXAMPLES:]
rl_start_step = await model.get_step()
print(f"RL training on {len(rl_train_documents)} documents (excluding {SFT_NUM_EXAMPLES} used for SFT)")
print(f"Starting from step {rl_start_step}, will run up to {MAX_RL_STEPS} RL steps")

training_iterator = iterate_dataset(
    rl_train_documents,
    groups_per_step=GROUPS_PER_STEP,
    num_epochs=NUM_EPOCHS,
    initial_step=rl_start_step,
)

for batch in training_iterator:
    if batch.step >= rl_start_step + MAX_RL_STEPS:
        break

    print(
        f"Step {batch.step}, Epoch {batch.epoch}, "
        f"Epoch step {batch.epoch_step}, "
        f"Batch size {len(batch.items)}"
    )

    # Generate training rollouts
    train_groups = await art.gather_trajectory_groups(
        (
            art.TrajectoryGroup(
                rollout(model, SummarizerScenario(doc=document))
                for _ in range(ROLLOUTS_PER_GROUP)
            )
            for document in batch.items
        ),
        pbar_desc=f"gather train (step {batch.step})",
    )

    # Run validation every VAL_STEP_INTERVAL steps
    if batch.step % VAL_STEP_INTERVAL == 0:
        val_groups = await art.gather_trajectory_groups(
            (
                art.TrajectoryGroup(
                    rollout(
                        model,
                        SummarizerScenario(doc=document),
                    )
                    for _ in range(VAL_ROLLOUTS_PER_GROUP)
                )
                for document in val_documents
            ),
            pbar_desc=f"gather val (step {batch.step})",
        )
        await model.log(val_groups, split="val")

    # Train on training trajectories
    result = await backend.train(
        model, train_groups, learning_rate=LEARNING_RATE
    )
    await model.log(
        train_groups,
        metrics=result.metrics,
        step=result.step,
        split="train",
    )

    print(f"Completed step {batch.step}")
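
For a rough sense of scale, the arithmetic below works through the constants used above. It assumes that `iterate_dataset` yields one batch of `groups_per_step` documents per step and that `initial_step` does not change the number of batches; neither assumption is confirmed here, so treat the output as an estimate.

```python
import math

# Values taken from the notebook.
train_size = 3500
sft_examples = 200
groups_per_step = 5
rollouts_per_group = 4
num_epochs = 1
max_rl_steps = 1000
val_docs = 91
val_rollouts_per_group = 2
questions_per_doc = 5

rl_docs = train_size - sft_examples
steps_per_epoch = math.ceil(rl_docs / groups_per_step)
# The loop stops at whichever limit is hit first.
expected_steps = min(max_rl_steps, steps_per_epoch * num_epochs)

train_rollouts_per_step = groups_per_step * rollouts_per_group
val_rollouts_per_eval = val_docs * val_rollouts_per_group
# Each rollout asks the judge about every question attached to the document.
questions_judged_per_step = train_rollouts_per_step * questions_per_doc

print(rl_docs, steps_per_epoch, expected_steps)
print(train_rollouts_per_step, val_rollouts_per_eval, questions_judged_per_step)
```

Under these assumptions a single epoch ends before `MAX_RL_STEPS` is reached, so that cap would only come into play if `NUM_EPOCHS` were raised.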

### Use the Model

Try the trained model on a validation document to see how it summarizes.

In [None]:
test_doc = val_documents[0]

client = model.openai_client()
completion = await client.chat.completions.create(
    model=model.get_inference_name(),
    messages=[
        {
            "role": "system",
            "content": "You are a specialized AI assistant that generates concise, informative summaries for documents.",
        },
        {
            "role": "user",
            "content": (
                f"Here is a document: {test_doc.document_text}\n\n"
                "Generate a summary that conveys all relevant information in a concise manner."
            ),
        },
    ],
    max_tokens=1000,
)

summary = completion.choices[0].message.content
print(f"Summary ({len(summary.split())} words):\n")
print(summary)
print("\n--- Questions for this document ---")
for i, q in enumerate(test_doc.questions, 1):
    print(f"\n{i}. {q.q}")
    print(f"   Answer: {q.a}")

### Benchmarking (Optional)

Uncomment and run the cell below to benchmark SOTA models against your trained model on the validation set. This requires an `OPENROUTER_API_KEY` to access the models via OpenRouter.

In [None]:
# PROJECT_NAME = "summarize"
# BENCHMARK_ROLLOUTS = 2
#
# # Define baseline models (all accessed via OpenRouter)
# gpt_4o = art.Model(
#     name="gpt-4o",
#     project=PROJECT_NAME,
#     inference_model_name="openai/gpt-4o",
#     inference_api_key=os.environ["OPENROUTER_API_KEY"],
#     inference_base_url="https://openrouter.ai/api/v1",
# )
#
# gpt_4_1 = art.Model(
#     name="gpt-4.1",
#     project=PROJECT_NAME,
#     inference_model_name="openai/gpt-4.1",
#     inference_api_key=os.environ["OPENROUTER_API_KEY"],
#     inference_base_url="https://openrouter.ai/api/v1",
# )
#
# gemini_2_5_pro = art.Model(
#     name="gemini-2.5-pro",
#     project=PROJECT_NAME,
#     inference_model_name="google/gemini-2.5-pro-preview",
#     inference_api_key=os.environ["OPENROUTER_API_KEY"],
#     inference_base_url="https://openrouter.ai/api/v1",
# )
#
# sonnet_4 = art.Model(
#     name="sonnet-4",
#     project=PROJECT_NAME,
#     inference_model_name="anthropic/claude-sonnet-4",
#     inference_api_key=os.environ["OPENROUTER_API_KEY"],
#     inference_base_url="https://openrouter.ai/api/v1",
# )
#
#
# async def benchmark_model(bm_model: art.Model) -> None:
#     trajectory_groups = await art.gather_trajectory_groups(
#         (
#             art.TrajectoryGroup(
#                 rollout(bm_model, SummarizerScenario(doc=document))
#                 for _ in range(BENCHMARK_ROLLOUTS)
#             )
#             for document in val_documents
#         ),
#         pbar_desc=bm_model.name,
#     )
#     await bm_model.log(trajectories=trajectory_groups, split="val")
#
#
# benchmark_models = [gpt_4o, gpt_4_1, gemini_2_5_pro, sonnet_4]
# for bm_model in benchmark_models:
#     await bm_model.register(backend)
#
# # Benchmark all models simultaneously
# await asyncio.gather(*[benchmark_model(bm_model) for bm_model in benchmark_models])
#
# print("Benchmarking complete!")