<a href="https://colab.research.google.com/github/OCWC22/arte-langraph-email-reply/blob/main/examples/langgraph/art-e-langgraph.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

To train this email search agent using LangGraph, click **Runtime** > **Run all**. Make sure you've enabled a free Tesla T4 GPU!

<div class="align-center">
<a href="https://github.com/openpipe/art"><img src="https://github.com/openpipe/art/raw/main/assets/ART_pill.png" height="50"></a>
<a href="https://discord.gg/zbBHRUpwf4"><img src="https://github.com/openpipe/art/raw/main/assets/Discord.png" height="50"></a>
<a href="https://art.openpipe.ai"><img src="https://github.com/openpipe/art/raw/main/assets/Documentation_pill.png" height="50"></a>

Questions? Join the Discord and ask away! For feature requests or to leave a star, visit our [GitHub](https://github.com/openpipe/art).

</div>

<a href="https://art.openpipe.ai/"><img src="https://github.com/openpipe/art/raw/main/assets/Header_separator.png" height="5"></a>

**Email Search Agent with LangGraph**

In this notebook, you will be using [ART](https://github.com/openpipe/art) together with [LangGraph](https://langchain-ai.github.io/langgraph/) to train your own ART•E agent from scratch! This implementation demonstrates how to integrate LangGraph's agent framework with ART's training capabilities.

Beginning with Qwen 2.5 7B Instruct as the base model, you will train it to search through emails and answer questions about them using LangGraph's ReAct agent pattern. You will construct an [agentic environment](#Environment), define a [rollout](#Rollout) using LangGraph, and run a [training loop](#Loop). You will also learn how to use [RULER](#ruler) to judge the quality of the agent's answers.

**RULER**

RULER is a technique for scoring an agent's answers with an LLM judge that compares several attempts at the same task, then training the agent to produce more of its best completions. To learn more about RULER, see the [RULER documentation](https://art.openpipe.ai/fundamentals/ruler).

Now let's get started!

In [None]:
%%capture
#@title 💿 Installation

# Portions adapted from Unsloth Notebooks (https://github.com/unslothai/notebooks)
# Copyright (c) Unsloth contributors.
# License: GNU LGPL v3.0.
# Modifications by OpenPipe:
# - switched to uv
# - changed vllm/triton pinning logic
# - added litellm/protobuf pins
# See /licenses/LGPL-3.0.txt and /licenses/GPL-3.0.txt for full text.

# Note: the %%capture cell magic must be the first line of the cell to take effect.
import os

if "COLAB_" not in "".join(os.environ.keys()):
    !uv pip install "openpipe-art[backend,langgraph]==0.4.9" langchain-core langgraph langchain_openai tenacity datasets "gql<4" --prerelease allow --no-cache-dir
else:
    try:
        import numpy

        get_numpy = f"numpy=={numpy.__version__}"
    except Exception:
        get_numpy = "numpy"
    try:
        import subprocess

        is_t4 = "Tesla T4" in str(subprocess.check_output(["nvidia-smi"]))
    except Exception:
        is_t4 = False
    get_vllm, get_triton = (
        ("vllm==0.9.2", "triton==3.2.0") if is_t4 else ("vllm", "triton")
    )
    !uv pip install --upgrade \
        "openpipe-art[backend,langgraph]==0.4.9" langchain-core langgraph langchain_openai tenacity datasets "gql<4" "protobuf==5.29.5" {get_vllm} {get_numpy} --prerelease allow --no-cache-dir
    !uv pip install -qqq {get_triton}

<a name="Environment-Variables"></a>

### Environment Variables

**OpenAI (used for RULER judge model)**

Our RULER reward function queries third-party models to judge the quality of the agent's performance. Any model supported by LiteLLM works. For this example we'll use OpenAI's o4-mini model, so we'll need to set the `OPENAI_API_KEY` environment variable.

**Weights & Biases (optional)**

Later on in the notebook, we'll be creating a model that can automatically log metrics to Weights & Biases and chat completions to Weave. In order to do so, you'll need to provide your Weights & Biases API key as an environment variable.

In [None]:
import os

from dotenv import load_dotenv

load_dotenv()

# Required
# os.environ["OPENAI_API_KEY"] = "YOUR_API_KEY"

if not os.environ.get("OPENAI_API_KEY"):
    raise ValueError(
        "OPENAI_API_KEY is required for RULER functionality when using openai/o4-mini."
    )

# Optional
# os.environ["WANDB_API_KEY"] = "YOUR_API_KEY"

if not os.environ.get("WANDB_API_KEY"):
    print("WANDB_API_KEY is not set. We'll skip logging metrics to Weights & Biases.")

<a name="Environment"></a>

### Email Search Environment

ART allows your agent to learn by interacting with its environment. In this example, we'll create an environment where the agent can search through emails and answer questions about them using LangGraph's tools integration.

The agent will have access to three tools:

1. `search_inbox` - Search for emails by keywords
2. `read_email` - Read a specific email by message ID
3. `return_final_answer` - Return the final answer with source email IDs
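
To make the keyword search concrete, here is a minimal sketch of the kind of full-text query that can sit behind `search_inbox`, assuming a SQLite build with the FTS5 extension enabled. The table name `demo_fts`, the sample rows, and the helpers `build_fts_query` and `search` are made up for illustration; they are not the notebook's schema.

```python
import sqlite3


def build_fts_query(keywords):
    """Quote each keyword so FTS5 treats it as a literal phrase; FTS5 ANDs the phrases."""
    return " ".join('"' + k.replace('"', '""') + '"' for k in keywords)


# Hypothetical in-memory table, for illustration only
conn = sqlite3.connect(":memory:")
conn.execute("CREATE VIRTUAL TABLE demo_fts USING fts5(subject, body)")
conn.executemany(
    "INSERT INTO demo_fts (subject, body) VALUES (?, ?)",
    [
        ("Budget meeting", "Moved to Friday at 10am."),
        ("Lunch", "Quick budget talk over lunch?"),
    ],
)


def search(keywords):
    """Return the rowids of rows that contain every keyword."""
    rows = conn.execute(
        "SELECT rowid FROM demo_fts WHERE demo_fts MATCH ? ORDER BY rowid",
        (build_fts_query(keywords),),
    ).fetchall()
    return [row[0] for row in rows]


both = search(["budget", "meeting"])  # only the first row contains both words
either = search(["budget"])  # both rows mention "budget"
```

Quoting every keyword keeps characters with special meaning in FTS5 query syntax from being interpreted as operators, at the cost of disabling prefix and boolean queries.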

In [None]:
#@title Email Search Code

import os
import random
import sqlite3
from dataclasses import asdict, dataclass
from datetime import datetime
from textwrap import dedent
from typing import List, Literal, Optional

from datasets import Dataset, Features, Sequence, Value, load_dataset
from pydantic import BaseModel, Field
from tqdm import tqdm


# Email and Scenario data models
class Email(BaseModel):
    message_id: str
    date: str  # ISO 8601 string 'YYYY-MM-DD HH:MM:SS'
    subject: Optional[str] = None
    from_address: Optional[str] = None
    to_addresses: List[str] = []  # Populated from recipients table
    cc_addresses: List[str] = []  # Populated from recipients table
    bcc_addresses: List[str] = []  # Populated from recipients table
    body: Optional[str] = None
    file_name: Optional[str] = None


class Scenario(BaseModel):
    id: int
    question: str
    answer: str
    message_ids: List[str]  # message_ids (strings) of referenced emails
    how_realistic: float
    inbox_address: str
    query_date: str
    split: Literal["train", "test"]


@dataclass
class SearchResult:
    message_id: str
    snippet: str


class FinalAnswer(BaseModel):
    answer: str
    source_ids: list[str]


# Database configuration
DB_PATH = "./enron_emails.db"
EMAIL_DATASET_REPO_ID = "corbt/enron-emails"
SCENARIO_DATASET_REPO_ID = "corbt/enron_emails_sample_questions"

# Global database connection
db_conn = None


def create_email_database():
    """Create the email database from Hugging Face dataset"""
    print("Creating email database from Hugging Face dataset...")
    print(
        "This will download and process the full Enron email dataset - this may take several minutes..."
    )

    # Database schema
    SQL_CREATE_TABLES = """
    DROP TABLE IF EXISTS recipients;
    DROP TABLE IF EXISTS emails_fts;
    DROP TABLE IF EXISTS emails;

    CREATE TABLE emails (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        message_id TEXT UNIQUE,
        subject TEXT,
        from_address TEXT,
        date TEXT,
        body TEXT,
        file_name TEXT
    );

    CREATE TABLE recipients (
        email_id TEXT,
        recipient_address TEXT,
        recipient_type TEXT
    );
    """

    SQL_CREATE_INDEXES_TRIGGERS = """
    CREATE INDEX idx_emails_from ON emails(from_address);
    CREATE INDEX idx_emails_date ON emails(date);
    CREATE INDEX idx_emails_message_id ON emails(message_id);
    CREATE INDEX idx_recipients_address ON recipients(recipient_address);
    CREATE INDEX idx_recipients_type ON recipients(recipient_type);
    CREATE INDEX idx_recipients_email_id ON recipients(email_id);
    CREATE INDEX idx_recipients_address_email ON recipients(recipient_address, email_id);

    CREATE VIRTUAL TABLE emails_fts USING fts5(
        subject,
        body,
        content='emails',
        content_rowid='id'
    );

    CREATE TRIGGER emails_ai AFTER INSERT ON emails BEGIN
        INSERT INTO emails_fts (rowid, subject, body)
        VALUES (new.id, new.subject, new.body);
    END;

    CREATE TRIGGER emails_ad AFTER DELETE ON emails BEGIN
        DELETE FROM emails_fts WHERE rowid=old.id;
    END;

    CREATE TRIGGER emails_au AFTER UPDATE ON emails BEGIN
        UPDATE emails_fts SET subject=new.subject, body=new.body WHERE rowid=old.id;
    END;
    """

    # Create database
    conn = sqlite3.connect(DB_PATH)
    cursor = conn.cursor()
    cursor.executescript(SQL_CREATE_TABLES)
    conn.commit()

    # Load dataset
    print("Loading full email dataset...")
    expected_features = Features(
        {
            "message_id": Value("string"),
            "subject": Value("string"),
            "from": Value("string"),
            "to": Sequence(Value("string")),
            "cc": Sequence(Value("string")),
            "bcc": Sequence(Value("string")),
            "date": Value("timestamp[us]"),
            "body": Value("string"),
            "file_name": Value("string"),
        }
    )

    dataset = load_dataset(
        EMAIL_DATASET_REPO_ID, features=expected_features, split="train"
    )
    print(f"Dataset contains {len(dataset)} total emails")

    # Populate database with ALL emails (not limited to 1000)
    print("Populating database with all emails...")
    conn.execute("PRAGMA synchronous = OFF;")
    conn.execute("PRAGMA journal_mode = MEMORY;")
    conn.execute("BEGIN TRANSACTION;")

    record_count = 0
    skipped_count = 0
    duplicate_count = 0
    processed_emails = set()  # Track (subject, body, from) tuples for deduplication

    for email_data in tqdm(dataset, desc="Inserting emails"):
        message_id = email_data["message_id"]
        subject = email_data["subject"]
        from_address = email_data["from"]
        date_obj: datetime = email_data["date"]
        body = email_data["body"]
        file_name = email_data["file_name"]
        to_list = [str(addr) for addr in email_data["to"] if addr]
        cc_list = [str(addr) for addr in email_data["cc"] if addr]
        bcc_list = [str(addr) for addr in email_data["bcc"] if addr]

        # Apply the same filters as the original project
        total_recipients = len(to_list) + len(cc_list) + len(bcc_list)

        # Filter out very long emails and those with too many recipients
        if len(body) > 5000:
            skipped_count += 1
            continue

        if total_recipients > 30:
            skipped_count += 1
            continue

        # Deduplication check (same as original project)
        email_key = (subject, body, from_address)
        if email_key in processed_emails:
            duplicate_count += 1
            continue
        else:
            processed_emails.add(email_key)

        date_str = date_obj.strftime("%Y-%m-%d %H:%M:%S")

        cursor.execute(
            """
            INSERT INTO emails (message_id, subject, from_address, date, body, file_name)
            VALUES (?, ?, ?, ?, ?, ?)
        """,
            (message_id, subject, from_address, date_str, body, file_name),
        )

        # Insert recipients
        recipient_data = []
        for addr in to_list:
            recipient_data.append((message_id, addr, "to"))
        for addr in cc_list:
            recipient_data.append((message_id, addr, "cc"))
        for addr in bcc_list:
            recipient_data.append((message_id, addr, "bcc"))

        if recipient_data:
            cursor.executemany(
                """
                INSERT INTO recipients (email_id, recipient_address, recipient_type)
                VALUES (?, ?, ?)
            """,
                recipient_data,
            )

        record_count += 1

    conn.commit()

    # Create indexes and triggers
    print("Creating indexes and FTS...")
    cursor.executescript(SQL_CREATE_INDEXES_TRIGGERS)
    # Use a single-quoted SQL string literal; double quotes denote identifiers in SQL
    cursor.execute("INSERT INTO emails_fts(emails_fts) VALUES('rebuild')")
    conn.commit()

    print(f"Successfully created database with {record_count} emails.")
    print(f"Skipped {skipped_count} emails due to length/recipient limits.")
    print(f"Skipped {duplicate_count} duplicate emails.")
    return conn


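def get_db_connection():
    """Get the shared database connection, creating the database on first use"""
    global db_conn
    # Reuse the cached connection instead of reconnecting on every tool call
    if db_conn is None:
        if os.path.exists(DB_PATH):
            print(f"Loading existing database from {DB_PATH}")
            db_conn = sqlite3.connect(DB_PATH, check_same_thread=False)
        else:
            db_conn = create_email_database()
    return db_conn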


def search_emails(
    inbox: str,
    keywords: List[str],
    from_addr: Optional[str] = None,
    to_addr: Optional[str] = None,
    sent_after: Optional[str] = None,
    sent_before: Optional[str] = None,
    max_results: int = 10,
) -> List[SearchResult]:
    """Search the email database based on keywords and filters"""
    conn = get_db_connection()
    cursor = conn.cursor()

    where_clauses: List[str] = []
    params: List[str | int] = []

    if not keywords:
        raise ValueError("No keywords provided for search.")

    if max_results > 10:
        raise ValueError("max_results must be less than or equal to 10.")

    # FTS5 default is AND, so just join keywords. Escape quotes for safety.
    fts_query = " ".join(f""" "{k.replace('"', '""')}" """ for k in keywords)
    where_clauses.append("fts.emails_fts MATCH ?")
    params.append(fts_query)

    # Inbox filter
    where_clauses.append("""
        (e.from_address = ? OR EXISTS (
            SELECT 1 FROM recipients r_inbox
            WHERE r_inbox.recipient_address = ? AND r_inbox.email_id = e.message_id
        ))
    """)
    params.extend([inbox, inbox])

    if from_addr:
        where_clauses.append("e.from_address = ?")
        params.append(from_addr)

    if to_addr:
        where_clauses.append("""
            EXISTS (
                SELECT 1 FROM recipients r_to
                WHERE r_to.recipient_address = ? AND r_to.email_id = e.message_id
            )
        """)
        params.append(to_addr)

    if sent_after:
        where_clauses.append("e.date >= ?")
        params.append(f"{sent_after} 00:00:00")

    if sent_before:
        where_clauses.append("e.date < ?")
        params.append(f"{sent_before} 00:00:00")

    sql = f"""
        SELECT
            e.message_id,
            snippet(emails_fts, -1, '<b>', '</b>', ' ... ', 15) as snippet
        FROM
            emails e JOIN emails_fts fts ON e.id = fts.rowid
        WHERE
            {" AND ".join(where_clauses)}
        ORDER BY
            e.date DESC
        LIMIT ?;
    """
    params.append(max_results)

    cursor.execute(sql, params)
    results = cursor.fetchall()

    return [SearchResult(message_id=row[0], snippet=row[1]) for row in results]


def read_email(message_id: str) -> Optional[Email]:
    """Retrieve a single email by its message_id"""
    conn = get_db_connection()
    cursor = conn.cursor()

    # Get email details
    cursor.execute(
        "SELECT message_id, date, subject, from_address, body, file_name FROM emails WHERE message_id = ?",
        (message_id,),
    )
    email_row = cursor.fetchone()

    if not email_row:
        return None

    msg_id, date, subject, from_addr, body, file_name = email_row

    # Get recipients
    cursor.execute(
        "SELECT recipient_address, recipient_type FROM recipients WHERE email_id = ?",
        (message_id,),
    )
    recipient_rows = cursor.fetchall()

    to_addresses = []
    cc_addresses = []
    bcc_addresses = []

    for addr, type_val in recipient_rows:
        if type_val.lower() == "to":
            to_addresses.append(addr)
        elif type_val.lower() == "cc":
            cc_addresses.append(addr)
        elif type_val.lower() == "bcc":
            bcc_addresses.append(addr)

    return Email(
        message_id=msg_id,
        date=date,
        subject=subject,
        from_address=from_addr,
        to_addresses=to_addresses,
        cc_addresses=cc_addresses,
        bcc_addresses=bcc_addresses,
        body=body,
        file_name=file_name,
    )


def load_training_scenarios(
    split: Literal["train", "test"] = "train",
    limit: Optional[int] = None,
    max_messages: Optional[int] = 1,
    shuffle: bool = False,
    seed: Optional[int] = None,
) -> List[Scenario]:
    """Load training scenarios from Hugging Face dataset"""
    print(f"Loading {split} scenarios from Hugging Face...")
    dataset: Dataset = load_dataset(SCENARIO_DATASET_REPO_ID, split=split)

    if max_messages is not None:
        dataset = dataset.filter(lambda x: len(x["message_ids"]) <= max_messages)

    if shuffle or (seed is not None):
        if seed is not None:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()

    # Convert each row to a Scenario object
    scenarios = [Scenario(**row, split=split) for row in dataset]

    if max_messages is not None:
        scenarios = [s for s in scenarios if len(s.message_ids) <= max_messages]

    if shuffle:
        if seed is not None:
            rng = random.Random(seed)
            rng.shuffle(scenarios)
        else:
            random.shuffle(scenarios)

    if limit is not None:
        scenarios = scenarios[:limit]

    print(f"Loaded {len(scenarios)} scenarios.")
    return scenarios


# Load training scenarios
training_scenarios = load_training_scenarios(
    split="train", limit=50, max_messages=1, shuffle=True, seed=42
)

print("Email search environment created with full Enron dataset!")
print(
    f"Database contains the complete email dataset, loaded {len(training_scenarios)} training scenarios."
)

# print first scenario
print("\nSample scenario")
print("id:", training_scenarios[0].id)
print("question:", training_scenarios[0].question)
print("answer:", training_scenarios[0].answer)
print("message_ids:", training_scenarios[0].message_ids)
print("how_realistic:", training_scenarios[0].how_realistic)
print("inbox_address:", training_scenarios[0].inbox_address)
print("query_date:", training_scenarios[0].query_date)
print("split:", training_scenarios[0].split)

Loading train scenarios from Hugging Face...
Loaded 50 scenarios.
Email search environment created with full Enron dataset!
Database contains the complete email dataset, loaded 50 training scenarios.

Sample scenario
id: 3296
question: Who can I contact for Power Operations when Sally is in London?
answer: Stacey White (x31870) and Leslie Reeves (x37962).
message_ids: ['<6033065.1075856098960.JavaMail.evans@thyme>']
how_realistic: 0.699999988079071
inbox_address: louise.kitchen@enron.com
query_date: 2001-01-25
split: train


### Creating a Model

Now that we've defined the rules of our environment, we can create a model that will learn to search emails effectively. We'll use a Qwen 2.5 7B model for this example.

In [None]:
import art
from art.local import LocalBackend

random.seed(42)

# Declare the model
model = art.TrainableModel(
    name="email-agent-langgraph-001",
    project="email-search-agent-langgraph",
    base_model="Qwen/Qwen2.5-7B-Instruct",
)

# To run on a T4, we need to override some config defaults.
model._internal_config = art.dev.InternalModelConfig(
    init_args=art.dev.InitArgs(
        max_seq_length=8192,
    ),
    engine_args=art.dev.EngineArgs(
        enforce_eager=True,
        gpu_memory_utilization=0.8,
    ),
)

# Initialize the server
backend = LocalBackend(
    # Normally we don't want to run the server in-process, but for the output
    # to show up properly on Google Colab we'll enable this.
    in_process=True,
    path="./.art",
)

# Register the model with the local Backend (sets up logging, inference, and training)
await model.register(backend)

<a name="Rollout"></a>

### Defining a Rollout with LangGraph

A rollout is a single episode of an agent performing its task. In this example, we'll use LangGraph's ReAct agent to handle the rollout. The rollout function presents the agent with an email search scenario, and the LangGraph agent uses the available tools to search for emails and answer the question.
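
As a rough mental model of what a ReAct-style agent does during a rollout, the sketch below alternates between a policy that picks a tool call and the tools themselves, stopping once a final answer is recorded or the turn budget is spent. Everything here (`run_episode`, `scripted_policy`, the toy `INBOX`) is a hypothetical stand-in for illustration; it is not LangGraph's implementation.

```python
from typing import Callable, Optional

# Toy inbox keyed by message ID (hypothetical data)
INBOX = {"m1": "Sally is in London; contact Stacey for Power Operations."}


def run_episode(policy: Callable[[list], tuple], max_turns: int = 20) -> Optional[str]:
    """Alternate policy decisions and tool calls until a final answer or the turn limit."""
    final_answer: Optional[str] = None

    def search_inbox(keyword: str) -> list:
        return [mid for mid, body in INBOX.items() if keyword.lower() in body.lower()]

    def read_email(message_id: str) -> Optional[str]:
        return INBOX.get(message_id)

    def return_final_answer(answer: str) -> str:
        # The closure records the answer so the caller can score it afterwards
        nonlocal final_answer
        final_answer = answer
        return answer

    tools = {
        "search_inbox": search_inbox,
        "read_email": read_email,
        "return_final_answer": return_final_answer,
    }
    history: list = []
    for _ in range(max_turns):
        tool_name, argument = policy(history)  # "reason": choose the next action
        observation = tools[tool_name](argument)  # "act": run the chosen tool
        history.append((tool_name, argument, observation))
        if final_answer is not None:
            break
    return final_answer


def scripted_policy(history: list) -> tuple:
    """Stand-in for the LLM: search, then read the first hit, then answer."""
    if not history:
        return ("search_inbox", "london")
    if len(history) == 1:
        first_hit = history[0][2][0]
        return ("read_email", first_hit)
    return ("return_final_answer", "Stacey")


answer = run_episode(scripted_policy)
```

Defining the tools inside the episode function lets the final-answer tool write to a local variable, so concurrent episodes do not share state.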

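When the agent provides a final answer, an LLM judge (`judge_correctness`) compares it with the scenario's reference answer, and the verdict is stored on the trajectory as the `correct` metric (1.0 if accepted, 0.0 otherwise).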
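
In this notebook an LLM judge decides whether the final answer is acceptable, and that verdict becomes the `correct` metric. The mapping can be sketched with the standard library alone. This is a simplified stand-in for the Pydantic validation used in the rollout cell: `parse_verdict`, `correct_metric`, and the sample JSON strings are illustrative, and a malformed judge response is treated as a rejection rather than an error.

```python
import json


def parse_verdict(raw_content: str) -> dict:
    """Parse the judge's JSON reply; fall back to a rejection if it is malformed."""
    try:
        data = json.loads(raw_content)
        return {"reasoning": str(data["reasoning"]), "accept": bool(data["accept"])}
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        return {"reasoning": f"Parse error: {e}", "accept": False}


def correct_metric(raw_content: str) -> float:
    """Convert the verdict into the 0.0 / 1.0 value stored under the `correct` metric."""
    return float(parse_verdict(raw_content)["accept"])


# Hypothetical judge replies
accepted = correct_metric('{"reasoning": "Matches the reference.", "accept": true}')
rejected = correct_metric('{"reasoning": "Wrong extension.", "accept": false}')
malformed = correct_metric("not json")
```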

In [None]:
import uuid

import weave
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent
from litellm import acompletion
from tenacity import retry, stop_after_attempt
from art.langgraph import init_chat_model

import art

if os.getenv("WANDB_API_KEY", ""):
    weave.init(model.project, settings={"print_call_link": False})

MAX_TURNS = 20

class CorrectnessJudgeResponse(BaseModel):
    reasoning: str = Field(description="Explanation of the reasoning process.")
    accept: bool = Field(description="Whether the AI answer should be accepted.")


@retry(stop=stop_after_attempt(3))
async def judge_correctness(
    scenario: Scenario, answer: str
) -> CorrectnessJudgeResponse:
    system_prompt = dedent(
        """
        You are given a question, the reference answer (labelled **Reference answer**), and an answer generated by an AI assistant (labelled **AI answer**).

        Your task is to decide whether the AI answer is correct and should be accepted. You should accept the answer if it contains the relevant information from the reference answer. You should not accept the answer if it is missing information relevant to the question, or if it contradicts the reference answer.
        """
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": (
                f"Question: {scenario.question}\n"
                f"Reference answer: {scenario.answer}\n"
                f"AI answer: {answer}"
            ),
        },
    ]

    response = await acompletion(
        model="openai/gpt-4.1",
        messages=messages,
        response_format=CorrectnessJudgeResponse,
    )

    first_choice = response.choices[0]
    raw_content = first_choice.message.content or "{}"

    try:
        return CorrectnessJudgeResponse.model_validate_json(raw_content)
    except Exception as e:
        return CorrectnessJudgeResponse(
            reasoning=f"Parse error: {e}\nRaw: {raw_content}", accept=False
        )


class ProjectTrajectory(art.Trajectory):
    final_answer: FinalAnswer | None = None


class EmailScenario(BaseModel):
    step: int
    scenario: Scenario


@weave.op
async def rollout(model: art.Model, email_scenario: EmailScenario) -> ProjectTrajectory:
    scenario = email_scenario.scenario

    traj = ProjectTrajectory(
        reward=0.0,
        messages_and_choices=[],
        metadata={
            "scenario_id": scenario.id,
            "step": email_scenario.step,
        },
    )

    system_prompt = dedent(
        f"""
        You are an email search agent. You are given a user query and a list of tools you can use to search the user's email. Use the tools to search the user's emails and find the answer to the user's query. You may take up to {MAX_TURNS} turns to find the answer, so if your first search doesn't find the answer, you can try with different keywords.

        User's email address is {scenario.inbox_address}
        Today's date is {scenario.query_date}

        When you have found the answer, use the return_final_answer_tool to provide your final answer along with the source message IDs.
        """
    )

    # Store final answer in trajectory
    final_answer = None

    # Define tools inside the rollout function to access local variables
    @tool
    def search_inbox_tool(keywords: list[str]) -> list[dict]:
        """Search the inbox for emails matching the given keywords and return
        a list of dictionaries so the LLM can easily consume them."""
        results = search_emails(
            inbox=scenario.inbox_address,
            keywords=keywords,
            sent_before=scenario.query_date,
        )
        return [asdict(result) for result in results]

    @tool
    def read_email_tool(message_id: str) -> dict | None:
        """Read a specific email by message ID."""
        email = read_email(message_id)
        if email:
            return email.model_dump()
        return None

    @tool
    def return_final_answer_tool(answer: str, reference_message_ids: list[str]) -> dict:
        """Return the final answer and the message IDs of the emails that were used to generate the answer."""
        nonlocal final_answer
        final_answer = FinalAnswer(answer=answer, source_ids=reference_message_ids)
        return final_answer.model_dump()

    # Create LangGraph tools
    tools = [search_inbox_tool, read_email_tool, return_final_answer_tool]

    chat_model = init_chat_model(model.name, temperature=1.0)

    # Create the LangGraph ReAct agent
    react_agent = create_react_agent(chat_model, tools)

    try:
        # Run the agent
        config = {
            "configurable": {"thread_id": str(uuid.uuid4())},
            "recursion_limit": MAX_TURNS,
        }

        await react_agent.ainvoke(
            {
                "messages": [
                    SystemMessage(content=system_prompt),
                    HumanMessage(content=scenario.question),
                ]
            },
            config=config,
        )

        # Check if we got a final answer
        if final_answer:
            traj.final_answer = final_answer
            # Score the trajectory
            correctness_judge_response = await judge_correctness(
                scenario, traj.final_answer.answer
            )
            traj.metrics["correct"] = float(correctness_judge_response.accept)

    except Exception as e:
        print(f"Error running LangGraph agent: {e}")
        # Add error information to trajectory
        traj.messages_and_choices.append(
            {"role": "assistant", "content": f"Error: {str(e)}"}
        )

    return traj


print("LangGraph rollout function defined!")

LangGraph rollout function defined!


<a name="ruler"></a>

### How RULER works

**RULER** leverages two key insights:

1. Relative scoring is easier than absolute scoring: It's easier for an LLM to rank several solutions relative to each other than to score them in isolation
2. GRPO only needs relative scores: Since GRPO normalizes scores within each group, only how the scores compare with one another matters, not their absolute values
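
The second insight can be checked numerically. The sketch below assumes the common GRPO formulation in which each reward is standardized by its group's mean and standard deviation; ART's exact normalization may differ, and `group_advantages` is a name chosen for this example.

```python
from statistics import mean, pstdev


def group_advantages(rewards, eps=1e-8):
    """Standardize rewards within one group: (r - mean) / (std + eps)."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


# Two hypothetical judges that agree on the ordering but use different scales
judge_a = [0.9, 0.7, 0.5]
judge_b = [1.0, 0.5, 0.0]

adv_a = group_advantages(judge_a)
adv_b = group_advantages(judge_b)

# Shifting and rescaling a group's scores leaves the advantages (almost) unchanged;
# the tiny residual comes from eps and floating-point rounding
max_gap = max(abs(a - b) for a, b in zip(adv_a, adv_b))
```

Because the two score sets differ only by a shift and a positive scale factor, they produce essentially the same training signal, which is why a judge does not need a calibrated absolute scale.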

The process:

1. Generate N trajectories for a given scenario
2. Pass all N trajectories to **RULER**
3. **RULER** deduplicates common prefixes (e.g., identical system messages)
4. An LLM judge scores each trajectory from 0 to 1 based on goal achievement
5. These scores are used directly as rewards in GRPO training
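
Step 3 can be pictured as finding the longest run of messages shared by every trajectory in the group, so that the judge sees that prefix only once. The helper below is an illustrative sketch, not RULER's actual implementation; `common_prefix_len` is a made-up name.

```python
def common_prefix_len(message_lists):
    """Count how many leading messages are identical across all trajectories."""
    if not message_lists:
        return 0
    shortest = min(len(messages) for messages in message_lists)
    for i in range(shortest):
        first = message_lists[0][i]
        if any(messages[i] != first for messages in message_lists[1:]):
            return i
    return shortest


# Hypothetical group: same system and user messages, different assistant replies
base = [
    {"role": "system", "content": "You count numbers using numeric symbols."},
    {"role": "user", "content": "Count to 10."},
]
trajectories = [
    base + [{"role": "assistant", "content": "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"}],
    base + [{"role": "assistant", "content": "a, b, c, d, e, f, g, h, i, j"}],
]

n_shared = common_prefix_len(trajectories)
shared_prefix = trajectories[0][:n_shared]  # shown to the judge once
unique_suffixes = [t[n_shared:] for t in trajectories]  # what actually differs
```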

To learn more about **RULER**, check out the [RULER docs](https://art.openpipe.ai/fundamentals/ruler).

In [None]:
#@title Sample RULER evaluation

import art
from art.rewards import ruler_score_group

# Test RULER with a simple example
base_messages = [
    {"role": "system", "content": "You count numbers using numeric symbols."},
    {"role": "user", "content": "Count to 10."},
]

good_trajectory = art.Trajectory(
    messages_and_choices=[
        *base_messages,
        {"role": "assistant", "content": "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"},
    ],
    reward=0,
)

mediocre_trajectory = art.Trajectory(
    messages_and_choices=[
        *base_messages,
        {
            "role": "assistant",
            "content": "one, two, three, four, five, six, seven, eight, nine, ten",
        },
    ],
    reward=0,
)

bad_trajectory = art.Trajectory(
    messages_and_choices=[
        *base_messages,
        {"role": "assistant", "content": "a, b, c, d, e, f, g, h, i, j"},
    ],
    reward=0,
)

sample_group = art.TrajectoryGroup(
    trajectories=[
        good_trajectory,
        mediocre_trajectory,
        bad_trajectory,
    ]
)

judged_group = await ruler_score_group(sample_group, "openai/o4-mini", debug=True)
assert judged_group is not None

# Display rankings
sorted_trajectories = sorted(
    judged_group.trajectories, key=lambda t: t.reward, reverse=True
)
for rank, traj in enumerate(sorted_trajectories, 1):
    messages = traj.messages()
    print(f"\nRank {rank}: Score {traj.reward:.3f}")
    print(f"  Response: {messages[-1]['content'][:50]}...")


Rank 1: Score 1.000
  Response: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10...

Rank 2: Score 0.500
  Response: one, two, three, four, five, six, seven, eight, ni...

Rank 3: Score 0.000
  Response: a, b, c, d, e, f, g, h, i, j...


<a name="Loop"></a>

### Training Loop with LangGraph

The training loop is where the magic happens. At each step, the rollout function is called multiple times in parallel using LangGraph's ReAct agent: every scenario in the batch produces a group of trajectories (`rollouts_per_group` rollouts each), and these groups are used to update the model.

The `gather` step waits for all of the trajectories to be generated. RULER then assigns relative scores to the trajectories within each group.
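
Conceptually, gathering amounts to launching every rollout concurrently and regrouping the results by scenario. The sketch below uses plain `asyncio` with a dummy rollout; `fake_rollout` and `gather_groups` are hypothetical stand-ins, and `art.gather_trajectory_groups` additionally takes care of the progress bar and the exception limit passed to it in this notebook.

```python
import asyncio


async def fake_rollout(scenario: str, attempt: int) -> dict:
    """Dummy rollout: yield to the event loop, then return a trajectory-like dict."""
    await asyncio.sleep(0)
    return {"scenario": scenario, "attempt": attempt, "reward": 0.0}


async def gather_groups(scenarios, rollouts_per_group):
    """Run every rollout concurrently, then regroup the results per scenario."""
    tasks = [
        fake_rollout(scenario, attempt)
        for scenario in scenarios
        for attempt in range(rollouts_per_group)
    ]
    results = await asyncio.gather(*tasks)  # results keep the order of `tasks`
    return [
        results[i : i + rollouts_per_group]
        for i in range(0, len(results), rollouts_per_group)
    ]


# In a notebook cell, use `await gather_groups(...)` instead of asyncio.run
groups = asyncio.run(gather_groups(["q1", "q2"], rollouts_per_group=4))
```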

Our notebook will then delete all but the most recent checkpoint and train the model on the scored trajectories.
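
A quick back-of-the-envelope calculation shows how much work this notebook's training configuration implies. It assumes `iterate_dataset` yields one batch of `groups_per_step` scenarios per step, which matches how the loop uses it but is an assumption rather than a documented guarantee.

```python
import math

# Values taken from this notebook's configuration
num_scenarios = 50  # scenarios loaded earlier with limit=50
groups_per_step = 2  # scenarios (trajectory groups) per training step
rollouts_per_group = 4  # rollouts generated for each scenario
num_epochs = 20

rollouts_per_step = groups_per_step * rollouts_per_group
steps_per_epoch = math.ceil(num_scenarios / groups_per_step)  # assumed batching
total_steps_available = steps_per_epoch * num_epochs

print(f"{rollouts_per_step} rollouts per step")
print(f"{steps_per_epoch} steps per epoch, {total_steps_available} steps over all epochs")
```

With `max_steps` set to 20 and training starting from step 0, the demo stops before it has finished a single pass over the 50 scenarios.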

In [None]:
# Training configuration
from art.utils import iterate_dataset
from art.langgraph import wrap_rollout

training_config = {
    "groups_per_step": 2,
    "num_epochs": 20,
    "rollouts_per_group": 4,
    "learning_rate": 1e-5,
    "max_steps": 20,
}

# Use iterate_dataset with real training scenarios (similar to train.py)
training_iterator = iterate_dataset(
    training_scenarios,  # Use real scenarios from Hugging Face
    groups_per_step=training_config["groups_per_step"],
    num_epochs=training_config["num_epochs"],
    initial_step=await model.get_step(),
)

for batch in training_iterator:
    print(
        f"Training step {batch.step}, epoch {batch.epoch}, epoch step {batch.epoch_step}"
    )
    print(f"Batch contains {len(batch.items)} scenarios")

    # Create trajectory groups for this batch (similar to train.py)
    groups = []
    for scenario in batch.items:
        groups.append(
            art.TrajectoryGroup(
                (
                    wrap_rollout(model, rollout)(
                        model, EmailScenario(step=batch.step, scenario=scenario)
                    )
                    for _ in range(training_config["rollouts_per_group"])
                )
            )
        )
    print(groups[0])
    # Gather all trajectory groups
    finished_groups = await art.gather_trajectory_groups(
        groups,
        pbar_desc="gather",
        max_exceptions=training_config["rollouts_per_group"] * len(batch.items),
    )

    judged_groups = []
    for group in finished_groups:
        # Use RULER to assign relative scores to each trajectory
        judged_group = await ruler_score_group(group, "openai/o4-mini", debug=True)
        judged_groups.append(judged_group)

    await model.delete_checkpoints()
    await model.train(
        judged_groups,
        config=art.TrainConfig(learning_rate=training_config["learning_rate"]),
        # Lowering the logprob_calculation_chunk_size is a memory saving measure
        # to allow longer sequences (up to 8192 tokens) to be processed on a T4.
        _config={"logprob_calculation_chunk_size": 8},
    )

    print(f"Completed training step {batch.step}")

    # Stop after max_steps for demo purposes (adjust as needed)
    if batch.step >= training_config["max_steps"]:
        break

### Using the Model

Just like that, you've trained an agent to search emails and answer questions using LangGraph! Now it's time to use your model outside of the training loop.

Check out the code below for a small demo of the model you just trained!

In [None]:
#@title Loading/inference code

# Test the trained model using the rollout function
# This avoids memory issues and uses the same inference path as training

print("Testing the trained LangGraph model with a real scenario...\n")


# Use a scenario from our training set
test_scenario = training_scenarios[1]

print(f"Test scenario ID: {test_scenario.id}")
print(f"Question: {test_scenario.question}")
print(f"Expected answer: {test_scenario.answer}")
print(f"Reference message IDs: {test_scenario.message_ids}")
print(f"Inbox: {test_scenario.inbox_address}")
print(f"Query date: {test_scenario.query_date}")
print("-" * 50)

# Run the rollout function with the trained model
test_email_scenario = EmailScenario.model_validate(
    {"step": 0, "scenario": test_scenario.model_dump()}
)
result_trajectory = await wrap_rollout(model, rollout)(model, test_email_scenario)

print("LangGraph Agent's trajectory:")
print("-" * 20)

# Display the conversation
messages = result_trajectory.messages()
for msg in messages:
    role = msg.get("role", "unknown")
    # Assistant messages that only carry tool calls may have content=None
    content = msg.get("content") or ""
    tool_calls = msg.get("tool_calls", [])

    if role == "system":
        print(
            f"[SYSTEM]: {content[:100]}..."
            if len(content) > 100
            else f"[SYSTEM]: {content}"
        )
    elif role == "user":
        print(f"[USER]: {content}")
    elif role == "assistant":
        if tool_calls:
            print(f"[ASSISTANT]: {tool_calls}")
        if content:
            print(f"[ASSISTANT]: {content}")
    elif role == "tool":
        tool_name = msg.get("name", "unknown_tool")
        print(
            f"[TOOL - {tool_name}]: {content[:200]}..."
            if len(content) > 200
            else f"[TOOL - {tool_name}]: {content}"
        )

    print()

print("-" * 50)
if result_trajectory.final_answer:
    print(f"Agent's Final Answer: {result_trajectory.final_answer.answer}")
    print(f"Source IDs Used: {result_trajectory.final_answer.source_ids}")
else:
    print("No final answer provided by the agent")

print(f"\nExpected Answer: {test_scenario.answer}")
print(f"Expected Source IDs: {test_scenario.message_ids}")

print("\n🎉 LangGraph email search agent testing completed!")
print(
    "The agent used LangGraph's ReAct pattern with the same inference path as during training."
)


# Task
Explain how to set up an RL task to write and reply to emails based on user criteria and examples of their writing style.

## Define the email environment

### Subtask:
Create a simulated environment where the agent receives emails and needs to compose replies or new emails. This would involve defining the state of the environment (e.g., the incoming email, the user's instructions/criteria).


**Reasoning**:
Define the data structures for the email environment state: the incoming email, the user's instructions, and a place for writing style examples.



In [1]:
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field

# Data structure to represent an email
class SimulatedEmail(BaseModel):
    """Represents a simulated email."""
    message_id: str = Field(description="Unique identifier for the email.")
    date: str = Field(description="Date the email was sent (ISO 8601 format).")
    subject: Optional[str] = Field(None, description="Subject line of the email.")
    from_address: str = Field(description="Sender's email address.")
    to_addresses: List[str] = Field([], description="List of 'To' recipient email addresses.")
    cc_addresses: List[str] = Field([], description="List of 'CC' recipient email addresses.")
    bcc_addresses: List[str] = Field([], description="List of 'BCC' recipient email addresses.")
    body: str = Field(description="Body content of the email.")
    attachments: List[str] = Field([], description="List of attachment names/identifiers.")

# Data structure to represent user instructions
class UserInstructions(BaseModel):
    """Represents user instructions for composing an email."""
    task_type: str = Field(description="Type of task: 'reply' or 'new_email'.")
    recipient: Optional[str | List[str]] = Field(None, description="Recipient(s) for a new email.")
    subject: Optional[str] = Field(None, description="Subject for a new email.")
    key_points: List[str] = Field([], description="Key points to include in the email.")
    tone: Optional[str] = Field(None, description="Desired tone for the email (e.g., 'formal', 'casual').")
    length: Optional[str] = Field(None, description="Desired length (e.g., 'brief', 'detailed').")
    other_criteria: Optional[Dict[str, Any]] = Field(None, description="Any other specific instructions.")

# Data structure to represent the state of the email environment for the agent
class EmailEnvironmentState(BaseModel):
    """Represents the state of the email environment for the agent."""
    incoming_email: Optional[SimulatedEmail] = Field(None, description="The incoming email if the task is a reply.")
    user_instructions: UserInstructions = Field(description="Instructions from the user for the email task.")
    writing_style_examples: List[SimulatedEmail] = Field([], description="Examples of the user's writing style.")
    current_date: str = Field(description="The current date in the simulated environment (ISO 8601 format).")

print("Defined data structures for the email environment state.")

Defined data structures for the email environment state.
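
At some point the environment state has to be rendered into chat messages for the policy model. The following is a minimal sketch of that step, not part of the notebook's pipeline. `build_prompt` is a hypothetical helper, and it takes a plain dict with the same field names as `EmailEnvironmentState` (the shape `model_dump()` would produce) so that it runs without pydantic.

```python
from typing import Any, Dict, List


def build_prompt(context: Dict[str, Any]) -> List[Dict[str, str]]:
    """Render an environment-state dict into chat messages (hypothetical helper)."""
    instructions = context["user_instructions"]
    system_lines = [
        f"You write emails on behalf of the user. Today is {context['current_date']}.",
        f"Task type: {instructions['task_type']}.",
    ]
    if instructions.get("tone"):
        system_lines.append(f"Tone: {instructions['tone']}.")
    if instructions.get("length"):
        system_lines.append(f"Length: {instructions['length']}.")

    # Style examples are included verbatim so the model can imitate them.
    for i, example in enumerate(context.get("writing_style_examples", []), start=1):
        system_lines.append(f"Style example {i}:\n{example['body']}")

    user_lines = []
    incoming = context.get("incoming_email")
    if incoming is not None:
        user_lines.append(
            f"Reply to this email from {incoming['from_address']} "
            f"(subject: {incoming.get('subject') or '(none)'}):\n{incoming['body']}"
        )
    if instructions.get("key_points"):
        bullets = "\n".join(f"- {p}" for p in instructions["key_points"])
        user_lines.append(f"Key points to include:\n{bullets}")

    return [
        {"role": "system", "content": "\n".join(system_lines)},
        {"role": "user", "content": "\n\n".join(user_lines)},
    ]


example_context = {
    "incoming_email": {
        "from_address": "colleague@example.com",
        "subject": "Meeting Tomorrow",
        "body": "Can you confirm if you'll be there?",
    },
    "user_instructions": {
        "task_type": "reply",
        "key_points": ["Confirm attendance"],
        "tone": "friendly",
        "length": "brief",
    },
    "writing_style_examples": [],
    "current_date": "2023-10-26T11:00:00",
}

for message in build_prompt(example_context):
    print(f"[{message['role']}]\n{message['content']}\n")
```

Keeping the rendering in one function means the same prompt can be reused for rollouts and for the judge, which makes rewards easier to interpret.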


## Develop tools/actions

### Subtask:
Define the actions the agent can take within the environment. This might include tools for drafting email content, sending the email, or perhaps even searching for information to include in the email.


**Reasoning**:
Define the actions as functions or classes that can be called by the agent, including drafting email content, sending the email, and searching for information.



In [2]:
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
import uuid
from datetime import datetime

# Data structure to represent drafted email components
class DraftedEmail(BaseModel):
    """Represents components of a drafted email."""
    to_addresses: List[str] = Field([], description="List of 'To' recipient email addresses.")
    cc_addresses: List[str] = Field([], description="List of 'CC' recipient email addresses.")
    bcc_addresses: List[str] = Field([], description="List of 'BCC' recipient email addresses.")
    subject: str = Field("", description="Subject line of the email.")
    body: str = Field("", description="Body content of the email.")

# Action 1: Draft email content
def draft_email_action(state: EmailEnvironmentState, draft_components: DraftedEmail) -> DraftedEmail:
    """
    Simulates the agent drafting email content.

    Args:
        state: The current state of the email environment.
        draft_components: The current draft components being built by the agent.

    Returns:
        The updated drafted email components.
    """
    # In a real RL setup, the agent would generate draft_components based on state
    # This is a placeholder function that just returns the provided components
    print("Agent is drafting email content...")
    return draft_components

# Action 2: Send the email
def send_email_action(drafted_email: DraftedEmail) -> SimulatedEmail:
    """
    Simulates sending the drafted email.

    Args:
        drafted_email: The drafted email components.

    Returns:
        A SimulatedEmail object representing the sent email.
    """
    print("Agent is sending the email...")
    # Simulate creating a sent email object
    sent_email = SimulatedEmail(
        message_id=str(uuid.uuid4()),
        date=datetime.now().isoformat(), # Wall-clock time; a full simulation would use the environment's current_date
        subject=drafted_email.subject,
        from_address="simulated_user@example.com", # Assume a fixed user address for simulation
        to_addresses=drafted_email.to_addresses,
        cc_addresses=drafted_email.cc_addresses,
        bcc_addresses=drafted_email.bcc_addresses,
        body=drafted_email.body,
        attachments=[] # For simplicity, no attachments in this simulation
    )
    print(f"Email sent with Subject: {sent_email.subject}")
    return sent_email

# Action 3: Search for information (Placeholder)
def search_information_action(query: str) -> str:
    """
    Simulates searching for information based on a query.

    Args:
        query: The search query from the agent.

    Returns:
        A string containing simulated search results.
    """
    print(f"Agent is searching for information with query: '{query}'")
    # In a real RL setup, this would interact with a knowledge base or search API
    # This is a placeholder
    simulated_results = f"Simulated search results for '{query}': Information found regarding the query."
    print(f"Simulated search results: {simulated_results[:50]}...")
    return simulated_results

print("Defined actions for the email agent.")

Defined actions for the email agent.
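
In an agent loop the model typically emits a tool name plus JSON arguments, and something has to map that onto Python functions like the ones above. Below is a minimal sketch of such a dispatcher. The registry, `dispatch_tool_call`, and the stand-in tool bodies are all hypothetical and use only the standard library; in the notebook, the real actions would be registered instead.

```python
import json
from typing import Any, Callable, Dict

# Hypothetical registry mapping tool names to callables.
TOOLS: Dict[str, Callable[..., Any]] = {}


def register(name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator that adds a function to the registry under `name`."""
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        TOOLS[name] = fn
        return fn
    return decorator


@register("search_information")
def search_information(query: str) -> str:
    # Stand-in for search_information_action.
    return f"Simulated search results for '{query}'"


@register("send_email")
def send_email(subject: str, body: str) -> dict:
    # Stand-in for send_email_action; returns what would have been sent.
    return {"subject": subject, "body": body, "sent": True}


def dispatch_tool_call(name: str, arguments_json: str) -> Any:
    """Look up a tool by name and call it with JSON-decoded keyword arguments."""
    if name not in TOOLS:
        return f"Error: unknown tool '{name}'"
    try:
        kwargs = json.loads(arguments_json)
        return TOOLS[name](**kwargs)
    except (json.JSONDecodeError, TypeError) as exc:
        # Return the error as text so the agent can see it and retry.
        return f"Error calling '{name}': {exc}"


print(dispatch_tool_call("search_information", '{"query": "Project Alpha"}'))
print(dispatch_tool_call("send_email", '{"subject": "Hi"}'))  # missing "body"
```

Returning errors as strings rather than raising keeps a malformed tool call from aborting the whole rollout, at the cost of the caller having to check the result type.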


## Create training data

### Subtask:
Gather a dataset of emails you have written or replied to, along with the context (the email you were replying to, or the situation for a new email) and your intended message or style.


**Reasoning**:
The subtask requires gathering a dataset of emails with context and extracting key information. Since I cannot access external files or user data directly, I will simulate the creation of a small, sample dataset in the required structured format based on the previously defined `SimulatedEmail` and `UserInstructions` models. This will demonstrate the structure of the dataset needed for training.



In [3]:
from datetime import datetime
import uuid

# Simulate creating a small sample dataset
# This dataset will contain examples of email scenarios, user instructions,
# and the corresponding desired email output (simulated from "past emails").
# In a real scenario, this data would be gathered from a user's actual emails.

sample_training_dataset = []

# Example 1: Replying to an email
incoming_email_1 = SimulatedEmail(
    message_id=str(uuid.uuid4()),
    date="2023-10-26T10:00:00",
    subject="Meeting Tomorrow",
    from_address="colleague@example.com",
    to_addresses=["simulated_user@example.com"],
    body="Hi, just a reminder about our meeting tomorrow at 2 PM. Can you confirm if you'll be there?",
)
user_instructions_1 = UserInstructions(
    task_type="reply",
    key_points=["Confirm attendance", "Looking forward to it"],
    tone="friendly",
    length="brief",
)
# Simulated "past email" (the desired reply)
desired_output_email_1 = SimulatedEmail(
    message_id=str(uuid.uuid4()), # This would be the ID of the email the user actually sent
    date="2023-10-26T11:00:00",
    subject="Re: Meeting Tomorrow",
    from_address="simulated_user@example.com",
    to_addresses=["colleague@example.com"],
    body="Hi, Yes, I'll be there. Looking forward to it!",
)

sample_training_dataset.append({
    "context": {
        "incoming_email": incoming_email_1.model_dump(),
        "user_instructions": user_instructions_1.model_dump(),
        "writing_style_examples": [], # Add relevant style examples here if available
        "current_date": "2023-10-26T11:00:00"
    },
    "desired_output": desired_output_email_1.model_dump()
})

# Example 2: Composing a new email
user_instructions_2 = UserInstructions(
    task_type="new_email",
    recipient="manager@example.com",
    subject="Project Update",
    key_points=["Project Alpha on track", "Milestone X completed", "Need approval for Phase 2"],
    tone="formal",
    length="detailed",
    other_criteria={"attachments_needed": False}
)
# Simulated "past email" (the desired new email)
desired_output_email_2 = SimulatedEmail(
    message_id=str(uuid.uuid4()),
    date="2023-10-27T09:00:00",
    subject="Project Update - Alpha",
    from_address="simulated_user@example.com",
    to_addresses=["manager@example.com"],
    body="Dear Manager, Please find below an update on Project Alpha. The project remains on track. We have successfully completed Milestone X. We require your approval to proceed with Phase 2. Thank you.",
)

sample_training_dataset.append({
    "context": {
        "incoming_email": None, # No incoming email for a new email task
        "user_instructions": user_instructions_2.model_dump(),
        "writing_style_examples": [], # Add relevant style examples here if available
        "current_date": "2023-10-27T09:00:00"
    },
    "desired_output": desired_output_email_2.model_dump()
})

# Display the structure of the sample dataset
print(f"Created a sample training dataset with {len(sample_training_dataset)} entries.")
import json
print("Sample dataset structure:")
print(json.dumps(sample_training_dataset[0], indent=2))


Created a sample training dataset with 2 entries.
Sample dataset structure:
{
  "context": {
    "incoming_email": {
      "message_id": "88d7ab39-4660-4654-9d09-fc7d68c4082f",
      "date": "2023-10-26T10:00:00",
      "subject": "Meeting Tomorrow",
      "from_address": "colleague@example.com",
      "to_addresses": [
        "simulated_user@example.com"
      ],
      "cc_addresses": [],
      "bcc_addresses": [],
      "body": "Hi, just a reminder about our meeting tomorrow at 2 PM. Can you confirm if you'll be there?",
      "attachments": []
    },
    "user_instructions": {
      "task_type": "reply",
      "recipient": null,
      "subject": null,
      "key_points": [
        "Confirm attendance",
        "Looking forward to it"
      ],
      "tone": "friendly",
      "length": "brief",
      "other_criteria": null
    },
    "writing_style_examples": [],
    "current_date": "2023-10-26T11:00:00"
  },
  "desired_output": {
    "message_id": "a5da1965-1035-49eb-9ad1-75446ff62e
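
Because every entry is built from `model_dump()` dictionaries, the dataset is already JSON-serializable. One plausible way to persist it between sessions is a JSON Lines file, sketched below. The helper names and the file name are hypothetical, and the records are a tiny stand-in with the same top-level keys as `sample_training_dataset`.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, List


def save_jsonl(records: List[Dict[str, Any]], path: Path) -> None:
    """Write one JSON object per line."""
    with path.open("w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Read records back, skipping blank lines."""
    with path.open("r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Tiny stand-in for sample_training_dataset (same top-level keys).
records = [
    {
        "context": {"incoming_email": None, "current_date": "2023-10-27T09:00:00"},
        "desired_output": {"subject": "Project Update - Alpha"},
    }
]

with tempfile.TemporaryDirectory() as tmp:
    dataset_path = Path(tmp) / "email_tasks.jsonl"  # hypothetical file name
    save_jsonl(records, dataset_path)
    reloaded = load_jsonl(dataset_path)

print(reloaded == records)
```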

## Design the reward function

### Subtask:
Design the reward function. This is a crucial step. You need a way to automatically evaluate how good the agent's generated emails are. This could involve:

*   Using an LLM judge (similar to RULER) to score emails based on your criteria (e.g., tone, clarity, inclusion of key points).
*   Potentially comparing generated emails to your examples or gold-standard replies.
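
For the LLM-judge option, two pieces are needed regardless of which model is called: a prompt that states the criteria, and a parser that turns the reply into a bounded score. The sketch below shows both under stated assumptions. The template wording, `build_judge_prompt`, and `parse_judge_reply` are hypothetical, and the model call itself is deliberately left out.

```python
import json
from typing import List, Tuple

JUDGE_TEMPLATE = """You are grading an email written on a user's behalf.
Criteria:
- Includes every key point: {key_points}
- Tone: {tone}
- Clarity and correctness
Email:
{email_body}

Respond with JSON only: {{"reasoning": "<one sentence>", "score": <number from 0 to 1>}}"""


def build_judge_prompt(email_body: str, key_points: List[str], tone: str) -> str:
    """Fill the (hypothetical) judge template."""
    return JUDGE_TEMPLATE.format(
        key_points="; ".join(key_points) or "(none)",
        tone=tone or "(unspecified)",
        email_body=email_body,
    )


def parse_judge_reply(reply: str) -> Tuple[float, str]:
    """Parse the judge's JSON reply; fall back to a score of 0.0 if it is malformed."""
    try:
        data = json.loads(reply)
        score = float(data["score"])
        reasoning = str(data.get("reasoning", ""))
    except (json.JSONDecodeError, KeyError, TypeError, ValueError, AttributeError):
        return 0.0, "Unparseable judge reply."
    # Clamp so a misbehaving judge cannot produce out-of-range rewards.
    return max(0.0, min(1.0, score)), reasoning


prompt = build_judge_prompt("Hi, I'll be there.", ["Confirm attendance"], "friendly")
print(prompt)
print(parse_judge_reply('{"reasoning": "Covers the key point.", "score": 0.8}'))
```

Treating an unparseable reply as a zero reward is one choice among several; retrying the judge or dropping the trajectory are equally defensible.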


**Reasoning**:
Define the `evaluate_email` function as requested, incorporating the use of an LLM judge (conceptually similar to RULER) and a comparison to the desired output email from the training data to generate a reward signal.



In [4]:
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
import json
import random # Used for simulating LLM judge variability

# Assuming SimulatedEmail, UserInstructions, and EmailEnvironmentState are defined in previous cells

class LLMJudgeScore(BaseModel):
    """Represents the output structure from a simulated LLM judge."""
    reasoning: str = Field(description="Explanation from the LLM judge.")
    score: float = Field(description="A numerical score from 0 to 1 based on evaluation criteria.")

# This is a simulated LLM judge function. In a real implementation, this would
# involve calling an actual LLM API with a carefully crafted prompt.
async def simulate_llm_judge(
    generated_email: SimulatedEmail,
    environment_state: EmailEnvironmentState,
    desired_output_email: Optional[SimulatedEmail] = None # Included for comparison
) -> LLMJudgeScore:
    """
    Simulates an LLM judging the quality of a generated email.

    Args:
        generated_email: The email generated by the agent.
        environment_state: The context in which the email was generated.
        desired_output_email: The gold-standard or desired email for comparison.

    Returns:
        A simulated LLMJudgeScore object.
    """
    print("Simulating LLM judge evaluation...")

    # Access user instructions and context from the environment state
    instructions = environment_state.user_instructions
    incoming_email = environment_state.incoming_email
    writing_style_examples = environment_state.writing_style_examples

    # --- LLM Judge Criteria Outline ---
    # The actual prompt for the LLM judge would detail these points:
    # 1. Adherence to UserInstructions:
    #    - Inclusion of all key_points?
    #    - Correct tone (formal, casual, etc.)?
    #    - Appropriate length (brief, detailed)?
    #    - Any other_criteria met?
    # 2. Contextual Appropriateness:
    #    - If a reply, does it logically follow the incoming_email?
    #    - Is it appropriate for the recipient(s)?
    # 3. Writing Style:
    #    - Does it match the style of writing_style_examples (if provided)? (This is harder for a simple judge, might need specific few-shot examples or fine-tuning).
    # 4. Overall Quality:
    #    - Clarity, coherence, grammar, spelling.

    # --- Scoring Logic (Simulated) ---
    # This is a simplified simulation. A real judge would use the LLM's output
    # to derive a structured score or ranking.

    score = 0.0
    reasoning_points = []

    # Simulate scoring based on instructions
    if all(point in generated_email.body for point in instructions.key_points):
        score += 0.4 # Reward for including key points
        reasoning_points.append("Included all key points.")
    else:
        reasoning_points.append("Missed some key points.")

    # Simple tone check simulation (very basic)
    if instructions.tone == "formal" and ("Dear" in generated_email.body and "Sincerely" in generated_email.body):
         score += 0.2
         reasoning_points.append("Tone appears formal.")
    elif instructions.tone == "friendly" and ("Hi," in generated_email.body or "Thanks," in generated_email.body):
         score += 0.2
         reasoning_points.append("Tone appears friendly.")
    else:
         reasoning_points.append(f"Tone evaluation based on criteria '{instructions.tone}' inconclusive.")


    # Simulate length check (very basic)
    generated_length = len(generated_email.body.split())
    if instructions.length == "brief" and generated_length < 100:
        score += 0.1
        reasoning_points.append("Email is brief as requested.")
    elif instructions.length == "detailed" and generated_length > 100:
        score += 0.1
        reasoning_points.append("Email is detailed as requested.")
    else:
        reasoning_points.append(f"Length evaluation based on criteria '{instructions.length}' inconclusive.")


    # --- Comparison to Desired Output (if available) ---
    # This is a crucial part for supervised-like signals.
    if desired_output_email:
        # Simulate comparison - a real implementation could use:
        # - String matching (exact or fuzzy)
        # - Semantic similarity (e.g., using embeddings)
        # - Overlap in key information extracted from both emails

        # Simple simulation: reward if the body is an exact match (unlikely in reality)
        # Or, a more realistic approach: reward based on similarity metrics.
        # For this simulation, let's use a simplified overlap check.
        desired_body_words = set(desired_output_email.body.lower().split())
        generated_body_words = set(generated_email.body.lower().split())
        common_words = desired_body_words.intersection(generated_body_words)
        overlap_ratio = len(common_words) / max(len(desired_body_words), 1) # Avoid division by zero

        # Reward based on word overlap (simple proxy for content similarity)
        score += overlap_ratio * 0.3 # Allocate some reward for content similarity

        if overlap_ratio > 0.5:
             reasoning_points.append(f"Generated email has significant word overlap with desired output ({overlap_ratio:.2f}).")
        else:
             reasoning_points.append(f"Generated email has low word overlap with desired output ({overlap_ratio:.2f}).")

    # Add a small random component to simulate LLM variability and exploration
    score = max(0.0, min(1.0, score + random.uniform(-0.05, 0.05))) # Ensure score is between 0 and 1

    reasoning = " ".join(reasoning_points) if reasoning_points else "Basic evaluation performed."

    return LLMJudgeScore(reasoning=reasoning, score=score)


# Define the main reward function
async def evaluate_email(
    generated_email: SimulatedEmail,
    environment_state: EmailEnvironmentState,
    desired_output_email: Optional[SimulatedEmail] = None # From training data
) -> float:
    """
    Evaluates the quality of a generated email and returns a reward score.

    This function combines the assessment from a simulated LLM judge
    and a comparison to the desired output email (if provided in training data).

    Args:
        generated_email: The email generated by the agent.
        environment_state: The context in which the email was generated.
        desired_output_email: The gold-standard or desired email from the training data.

    Returns:
        A numerical reward score for the generated email.
    """
    print("Evaluating generated email...")

    # Use the simulated LLM judge to get an initial score based on instructions and context
    judge_result = await simulate_llm_judge(generated_email, environment_state, desired_output_email)

    # The LLM judge score is a primary component of the reward
    reward_score = judge_result.score

    # Additional potential reward/penalty logic could go here, for example:
    # - Penalize for errors (e.g., sending to wrong recipient, empty body)
    # - Reward for efficiency (e.g., using minimal steps if applicable)

    # For this design, the LLM judge score (which already incorporates comparison
    # to desired output via simulation) is the main reward signal.

    print(f"Evaluation complete. Reward score: {reward_score:.4f}")
    print(f"Judge Reasoning: {judge_result.reasoning}")

    return reward_score

print("Defined the evaluate_email function for reward calculation.")


Defined the evaluate_email function for reward calculation.
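
The comments in `simulate_llm_judge` list fuzzy string matching as an alternative to the word-overlap check. A minimal sketch of that alternative, using the standard library's `difflib.SequenceMatcher`, is shown below. `body_similarity` is a hypothetical helper and is not wired into `evaluate_email`.

```python
from difflib import SequenceMatcher


def body_similarity(generated_body: str, desired_body: str) -> float:
    """Return a similarity ratio in [0, 1] between two email bodies."""
    # Normalize case and whitespace so formatting differences are not penalized.
    a = " ".join(generated_body.lower().split())
    b = " ".join(desired_body.lower().split())
    if not a and not b:
        return 1.0
    return SequenceMatcher(None, a, b).ratio()


desired = "Hi, Yes, I'll be there. Looking forward to it!"
close = "Hi, yes, I will be there. Looking forward to it!"
unrelated = "Please find the quarterly invoice attached."

print(f"close:     {body_similarity(close, desired):.2f}")
print(f"unrelated: {body_similarity(unrelated, desired):.2f}")
```

Unlike set-based word overlap, this ratio is sensitive to word order, so it penalizes an email that contains the right words in a scrambled arrangement. It still measures surface form only, not meaning.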


## Adapt the agent architecture

### Subtask:
Adapt the agent architecture. Modify or create a LangGraph agent (or similar) that can interact with your custom environment and use the defined tools to generate email content.


**Reasoning**:
Define the LangGraph state and workflow, incorporating the previously defined actions as nodes and handling transitions based on the agent's process of drafting and sending an email.



In [5]:
import operator
from typing import Annotated, Optional, Sequence, TypedDict

from langchain_core.messages import BaseMessage
from langgraph.graph import StateGraph, END

# Assuming SimulatedEmail, UserInstructions, EmailEnvironmentState,
# draft_email_action, send_email_action, and search_information_action
# are defined in previous cells.

# 1. Define the state graph
class AgentState(TypedDict):
    """
    Represents the state of the LangGraph agent.

    Attributes:
        environment_state: The current state of the email environment.
        drafted_email: The email being drafted by the agent.
        messages: A list of messages representing the conversation history (for agent reasoning).
        search_results: Results from information search (if performed).
        final_email_sent: Flag indicating if the email has been sent.
    """
    environment_state: EmailEnvironmentState
    drafted_email: DraftedEmail
    messages: Annotated[Sequence[BaseMessage], operator.add]
    search_results: Optional[str]
    final_email_sent: bool


# 2. Implement the LangGraph workflow
workflow = StateGraph(AgentState)

# Define nodes for the graph
def call_draft_email(state: AgentState) -> AgentState:
    """Calls the draft_email_action and updates the state."""
    print("Calling draft_email node...")
    # In a real agent, the LLM would generate draft_components based on state.
    # For this structure definition, we'll simulate a draft update.
    # A real agent would use a tool/LLM call here to get DraftedEmail components.
    # For now, let's assume the agent somehow produces a draft.
    # This part needs the LLM integration to decide what to draft.
    # As a placeholder, let's update the body based on instructions.
    instructions = state['environment_state'].user_instructions
    current_draft = state.get('drafted_email') or DraftedEmail() # .get avoids a KeyError when no draft exists yet

    # Simulate agent deciding to draft based on instructions
    simulated_body_draft = f"Draft based on instructions: {', '.join(instructions.key_points)}."
    if instructions.tone:
        simulated_body_draft = f"Using a {instructions.tone} tone. " + simulated_body_draft

    updated_draft = DraftedEmail(
        to_addresses=current_draft.to_addresses or ([instructions.recipient] if isinstance(instructions.recipient, str) else instructions.recipient) or [],
        cc_addresses=current_draft.cc_addresses,
        bcc_addresses=current_draft.bcc_addresses,
        subject=current_draft.subject or instructions.subject or (f"Re: {state['environment_state'].incoming_email.subject}" if state['environment_state'].incoming_email else "New Email"),
        body=simulated_body_draft # This is where LLM output would go
    )

    print("Drafting complete (simulated).")
    return {"drafted_email": updated_draft, "messages": [("tool_code", f"Drafted email subject: {updated_draft.subject[:50]}..., body: {updated_draft.body[:50]}...")]}


def call_send_email(state: AgentState) -> AgentState:
    """Calls the send_email_action and updates the state."""
    print("Calling send_email node...")
    drafted_email = state['drafted_email']
    if not drafted_email or not drafted_email.body:
         print("Cannot send empty email.")
         # Handle this case - perhaps transition to an error state or redraft
         return {"messages": [("tool_code", "Attempted to send an empty email.")]}

    sent_email = send_email_action(drafted_email)
    print("Email sent.")
    return {"final_email_sent": True, "messages": [("tool_code", f"Email sent: {sent_email.subject[:50]}...")]}

def call_search_information(state: AgentState) -> AgentState:
    """Calls the search_information_action and updates the state."""
    print("Calling search_information node...")
    # In a real agent, the LLM would decide the query based on state.
    # For this structure definition, we'll simulate a query.
    # A real agent would use a tool/LLM call here to get the query.
    simulated_query = f"information about {state['environment_state'].user_instructions.key_points[0] if state['environment_state'].user_instructions.key_points else 'general topic'}"
    print(f"Simulated search query: {simulated_query}")
    search_results = search_information_action(simulated_query)
    print("Search complete (simulated).")
    return {"search_results": search_results, "messages": [("tool_code", f"Search results: {search_results[:50]}...")]}


# Add nodes to the graph
workflow.add_node("draft_email", call_draft_email)
workflow.add_node("send_email", call_send_email)
workflow.add_node("search_information", call_search_information)

# Define the entry point
workflow.set_entry_point("draft_email") # Start by drafting the email

# Define edges and conditional transitions

# After drafting, the agent might need to search or decide to send
# This transition logic would be more complex in a real agent,
# potentially involving an LLM call to decide the next step.
# For this structure, let's add a simple condition: if search is needed
# based on instructions (simulated), go to search, otherwise go to send.
def should_search(state: AgentState) -> str:
    """Determines if the agent needs to search for information."""
    # Simulate the decision based on user instructions
    instructions = state['environment_state'].user_instructions
    # If user instructions mention needing info or a specific complex topic,
    # the agent might decide to search.
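    # This is a placeholder condition. other_criteria defaults to None, so guard
    # against that, and skip the search once results exist so the
    # draft -> search -> draft cycle cannot repeat forever.
    criteria = instructions.other_criteria or {}
    if "search_needed" in criteria and not state.get("search_results"):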
        print("Decision: Search needed.")
        return "search"
    else:
        print("Decision: No search needed, proceed to send or redraft.")
        # In a real agent, this might go to a review/revise node or directly to send
        # For this simplified graph, we'll go directly to send if no search.
        return "send"


workflow.add_conditional_edges(
    "draft_email",
    should_search, # Use the function to decide the next step
    {"search": "search_information", "send": "send_email"}
)

# After searching, the agent should go back to drafting to incorporate results
workflow.add_edge("search_information", "draft_email")

# After sending the email, the process is finished
workflow.add_edge("send_email", END)


# 4. Ensure the agent can access EmailEnvironmentState
# The EmailEnvironmentState is part of the AgentState and is passed between nodes.

# 5. Define the entry point and exit condition
# Entry point is set to "draft_email"
# Exit condition is reaching the END node after "send_email"

# Compile the graph
app = workflow.compile()

print("LangGraph agent workflow defined.")

ModuleNotFoundError: No module named 'langgraph'

**Reasoning**:
The previous command failed because the `langgraph` library is not installed. Install the required package.
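
A failure like this can be caught before running a long cell by checking whether the modules are importable. The sketch below uses `importlib.util.find_spec` for that. `missing_modules` is a hypothetical helper, and the list of required modules is an assumption based on the imports used in this notebook. Note that import names and pip package names do not always match.

```python
import importlib.util
from typing import Iterable, List


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Assumed set of top-level modules this notebook imports.
required = ["langgraph", "langchain_core", "pydantic"]
missing = missing_modules(required)
if missing:
    print(f"Missing modules: {', '.join(missing)}")
else:
    print("All required modules are importable.")
```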



In [6]:
%pip install langgraph

Collecting langgraph
  Downloading langgraph-0.6.6-py3-none-any.whl.metadata (6.8 kB)
Collecting langgraph-checkpoint<3.0.0,>=2.1.0 (from langgraph)
  Downloading langgraph_checkpoint-2.1.1-py3-none-any.whl.metadata (4.2 kB)
Collecting langgraph-prebuilt<0.7.0,>=0.6.0 (from langgraph)
  Downloading langgraph_prebuilt-0.6.4-py3-none-any.whl.metadata (4.5 kB)
Collecting langgraph-sdk<0.3.0,>=0.2.2 (from langgraph)
  Downloading langgraph_sdk-0.2.3-py3-none-any.whl.metadata (1.5 kB)
Collecting ormsgpack>=1.10.0 (from langgraph-checkpoint<3.0.0,>=2.1.0->langgraph)
  Downloading ormsgpack-1.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (43 kB)
Downloading langgraph-0.6.6-py3-none-any.whl (153 kB)
Downloading langgraph_chec

**Reasoning**:
Now that `langgraph` is installed, attempt to define the LangGraph agent architecture again using the previously defined actions and state.



In [7]:
import operator
from typing import Annotated, Sequence, TypedDict, Optional

from langchain_core.messages import BaseMessage
from langgraph.graph import StateGraph, END

# Assuming SimulatedEmail, UserInstructions, EmailEnvironmentState,
# DraftedEmail, draft_email_action, send_email_action, and search_information_action
# are defined in previous cells.

# 1. Define the state graph
class AgentState(TypedDict):
    """
    Represents the state of the LangGraph agent.

    Attributes:
        environment_state: The current state of the email environment.
        drafted_email: The email being drafted by the agent.
        messages: A list of messages representing the conversation history (for agent reasoning).
        search_results: Results from information search (if performed).
        final_email_sent: Flag indicating if the email has been sent.
    """
    environment_state: EmailEnvironmentState
    drafted_email: DraftedEmail
    messages: Annotated[Sequence[BaseMessage], operator.add]
    search_results: Optional[str]
    final_email_sent: bool


# 2. Implement the LangGraph workflow
workflow = StateGraph(AgentState)

# Define nodes for the graph
def call_draft_email(state: AgentState) -> AgentState:
    """Calls the draft_email_action and updates the state."""
    print("Calling draft_email node...")
    # Placeholder: in a real agent, an LLM or tool call would produce the
    # DraftedEmail components from the current state. Here the body is
    # derived directly from the user instructions.
    instructions = state['environment_state'].user_instructions
    current_draft = state.get('drafted_email') or DraftedEmail(to_addresses=[], cc_addresses=[], bcc_addresses=[], subject="", body="") # Initialize if None

    # Simulate agent deciding to draft based on instructions
    simulated_body_draft = f"Draft based on instructions: {', '.join(instructions.key_points)}."
    if instructions.tone:
        simulated_body_draft = f"Using a {instructions.tone} tone. " + simulated_body_draft

    # Normalize the recipient field to a list (it may be a str, a list, or None)
    if isinstance(instructions.recipient, str):
        recipients = [instructions.recipient]
    else:
        recipients = instructions.recipient or []

    updated_draft = DraftedEmail(
        to_addresses=current_draft.to_addresses or recipients,
        cc_addresses=current_draft.cc_addresses,
        bcc_addresses=current_draft.bcc_addresses,
        subject=current_draft.subject or instructions.subject or (f"Re: {state['environment_state'].incoming_email.subject}" if state['environment_state'].incoming_email else "New Email"),
        body=simulated_body_draft # This is where LLM output would go
    )

    # Call the simulated action
    drafted_email_result = draft_email_action(state['environment_state'], updated_draft)


    print("Drafting complete (simulated).")
    return {"drafted_email": drafted_email_result, "messages": [("tool_code", f"Drafted email subject: {drafted_email_result.subject[:50]}..., body: {drafted_email_result.body[:50]}...")]}


def call_send_email(state: AgentState) -> AgentState:
    """Calls the send_email_action and updates the state."""
    print("Calling send_email node...")
    drafted_email = state.get('drafted_email')
    if not drafted_email or not drafted_email.body:
         print("Cannot send empty email.")
         # Handle this case - perhaps transition to an error state or redraft
         return {"messages": [("tool_code", "Attempted to send an empty email.")]}

    sent_email = send_email_action(drafted_email)
    print("Email sent.")
    return {"final_email_sent": True, "messages": [("tool_code", f"Email sent: {sent_email.subject[:50]}...")]}

def call_search_information(state: AgentState) -> AgentState:
    """Calls the search_information_action and updates the state."""
    print("Calling search_information node...")
    # In a real agent, the LLM would decide the query based on state.
    # For this structure definition, we'll simulate a query.
    # A real agent would use a tool/LLM call here to get the query.
    simulated_query = f"information about {state['environment_state'].user_instructions.key_points[0] if state['environment_state'].user_instructions.key_points else 'general topic'}"
    print(f"Simulated search query: {simulated_query}")
    search_results = search_information_action(simulated_query)
    print("Search complete (simulated).")
    return {"search_results": search_results, "messages": [("tool_code", f"Search results: {search_results[:50]}...")]}


# Add nodes to the graph
workflow.add_node("draft_email", call_draft_email)
workflow.add_node("send_email", call_send_email)
workflow.add_node("search_information", call_search_information)

# Define the entry point
workflow.set_entry_point("draft_email") # Start by drafting the email

# Define edges and conditional transitions

# After drafting, the agent might need to search or decide to send
# This transition logic would be more complex in a real agent,
# potentially involving an LLM call to decide the next step.
# For this structure, let's add a simple condition: if search is needed
# based on instructions (simulated), go to search, otherwise go to send.
def should_search(state: AgentState) -> str:
    """Determines if the agent needs to search for information."""
    # Simulate the decision based on user instructions
    instructions = state['environment_state'].user_instructions
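    # Placeholder condition: search only when the instructions request it and
    # no results exist yet. Without the second check, the
    # search_information -> draft_email edge would route straight back to
    # search on every pass until LangGraph's recursion limit is hit.
    search_requested = bool(instructions.other_criteria and instructions.other_criteria.get("search_needed"))
    if search_requested and state.get("search_results") is None: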
        print("Decision: Search needed.")
        return "search"
    else:
        print("Decision: No search needed, proceed to send or redraft.")
        # In a real agent, this might go to a review/revise node or directly to send
        # For this simplified graph, we'll go directly to send if no search.
        return "send"


workflow.add_conditional_edges(
    "draft_email",
    should_search, # Use the function to decide the next step
    {"search": "search_information", "send": "send_email"}
)

# After searching, the agent should go back to drafting to incorporate results
workflow.add_edge("search_information", "draft_email")

# After sending the email, the process is finished
workflow.add_edge("send_email", END)


# 4. Ensure the agent can access EmailEnvironmentState
# The EmailEnvironmentState is part of the AgentState and is passed between nodes.

# 5. Define the entry point and exit condition
# Entry point is set to "draft_email"
# Exit condition is reaching the END node after "send_email"

# Compile the graph
app = workflow.compile()

print("LangGraph agent workflow defined.")

LangGraph agent workflow defined.
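Because `search_information` routes back to `draft_email`, the routing function has to stop requesting a search once results exist; otherwise a scenario that asks for a search would cycle between the two nodes. The sketch below walks the same three-node topology in plain Python to show the intended path. It does not use LangGraph, and `route_after_draft` and `simulate` are hypothetical helpers, not part of the notebook.

```python
from typing import List, Optional


def route_after_draft(search_needed: bool, search_results: Optional[str]) -> str:
    """Return "search" only while a needed search has not produced results."""
    if search_needed and search_results is None:
        return "search"
    return "send"


def simulate(search_needed: bool, max_steps: int = 10) -> List[str]:
    """Walk draft -> (search -> draft)? -> send and record the visited nodes."""
    visited: List[str] = []
    search_results: Optional[str] = None
    node = "draft_email"
    for _ in range(max_steps):
        visited.append(node)
        if node == "draft_email":
            decision = route_after_draft(search_needed, search_results)
            node = "search_information" if decision == "search" else "send_email"
        elif node == "search_information":
            search_results = "simulated results"
            node = "draft_email"
        else:  # send_email is terminal
            break
    return visited


print(simulate(search_needed=True))
print(simulate(search_needed=False))
```

With the guard in place, a search scenario visits `draft_email` twice (before and after the search) and then sends.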


## Implement the training loop

### Subtask:
Implement the training loop. Use a framework like ART (or a similar RL training setup) to train the agent. In each step the agent generates emails in the simulated environment, receives a reward from the reward function, and updates its policy so that later emails score higher.
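To make the "reward, then policy update" step concrete, here is a minimal sketch that assumes a group-relative update in the style of GRPO, where each rollout is compared against the other rollouts of the same scenario. The normalization shown (mean-centering and dividing by the standard deviation) is one common choice and an assumption here, not a description of ART's internals; `group_advantages` is a hypothetical helper.

```python
from statistics import mean, pstdev


def group_advantages(rewards, eps=1e-6):
    """Center and scale rewards within one group of rollouts."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    # eps avoids division by zero when every reward in the group is equal.
    return [(r - mu) / (sigma + eps) for r in rewards]


varied = group_advantages([0.2, 0.5, 0.8])
identical = group_advantages([0.5, 0.5, 0.5])
print(varied)
print(identical)
```

Under this assumption, a group whose rollouts all receive the same reward yields all-zero advantages and so carries no learning signal. That matters for a fully deterministic, simulated agent, since repeated rollouts of one scenario would produce identical emails.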


**Reasoning**:
Implement the training loop using the ART framework to train the LangGraph agent defined in the previous step. This involves iterating through training scenarios, running the agent workflow, evaluating the generated email with the reward function, and updating the model.
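The scenario iteration described above can be sketched independently of ART. The version below is a plain-Python illustration with hypothetical names (`iter_batches`, `seed`); it uses ceiling division for the number of steps per epoch so that step numbers stay unique and contiguous even when the dataset size is not a multiple of `groups_per_step`, and it shuffles a copy so the caller's list is left untouched.

```python
import math
import random


def iter_batches(scenarios, groups_per_step, num_epochs, initial_step=0, seed=0):
    """Yield (step, epoch, epoch_step, items) with contiguous step numbers."""
    rng = random.Random(seed)
    steps_per_epoch = math.ceil(len(scenarios) / groups_per_step)
    for epoch in range(num_epochs):
        shuffled = list(scenarios)  # copy so the caller's list is not mutated
        rng.shuffle(shuffled)
        for epoch_step in range(steps_per_epoch):
            start = epoch_step * groups_per_step
            step = initial_step + epoch * steps_per_epoch + epoch_step
            yield step, epoch, epoch_step, shuffled[start:start + groups_per_step]


batches = list(iter_batches(list("abcde"), groups_per_step=2, num_epochs=2))
steps = [step for step, *_ in batches]
print(steps)
```

With five scenarios and two groups per step, each epoch has three steps and the last batch of each epoch holds a single scenario. Floor division would count only two steps per epoch, so the first step of the second epoch would reuse the step number of the first epoch's final batch.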



In [8]:
import art
import asyncio
import random
import uuid
from typing import Any, Dict, List

from art.rewards import ruler_score_group  # Only needed if RULER relative scoring is enabled below

# Reuse the training scenarios loaded previously
# training_scenarios = load_training_scenarios(split="train", limit=50, max_messages=1, shuffle=True, seed=42)

# Reuse the defined rollout function and model
# from art.langgraph import wrap_rollout
# model = art.TrainableModel(...) # Assumes model is already defined and registered
# backend = LocalBackend(...) # Assumes backend is already defined
# await model.register(backend) # Assumes model is registered
# rollout = ... # Assumes rollout function is defined

# Reuse the evaluate_email reward function
# async def evaluate_email(...)

# Reuse the LangGraph app (workflow)
# app = workflow.compile()

# 1. Set up the training configuration parameters
# Using parameters similar to the example notebook, adapted for our task
training_config = {
    "groups_per_step": 2,       # Number of scenario groups processed per training step
    "num_epochs": 5,            # Number of times to iterate through the dataset
    "rollouts_per_group": 3,    # Number of times to run the agent for each scenario in a group
    "learning_rate": 1e-5,
    "max_steps": 5,             # Stop after this many training steps for the demo
}

print("Training configuration set.")

# 2. Define a mechanism to iterate through the training data scenarios.
# Adapt the iterate_dataset function for our scenario structure
class EmailScenarioBatchItem(art.BatchItem):
    scenario: Dict[str, Any] # Store the scenario data including context and desired_output

async def email_scenario_iterator(
    scenarios: List[Dict[str, Any]],
    groups_per_step: int,
    num_epochs: int,
    initial_step: int = 0,
):
    """
    Iterates through email scenarios, yielding batches for training steps.

    Args:
        scenarios: A list of scenario dictionaries (context and desired_output).
        groups_per_step: Number of scenario groups per step.
        num_epochs: Number of epochs to iterate through the dataset.
        initial_step: The starting training step number.

    Yields:
        Batch objects containing EmailScenarioBatchItems.
    """
    scenario_count = len(scenarios)
    scenarios_per_step = groups_per_step # Each scenario becomes one trajectory group

    for epoch in range(num_epochs):
        # Shuffle scenarios at the start of each epoch
        random.shuffle(scenarios)
        print(f"\nStarting epoch {epoch + 1}/{num_epochs}")

        for i in range(0, scenario_count, scenarios_per_step):
            current_scenarios = scenarios[i : i + scenarios_per_step]
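            # Ceiling division: a partial final batch still counts as a step,
            # so step numbers never collide across epochs.
            steps_per_epoch = (scenario_count + scenarios_per_step - 1) // scenarios_per_step
            current_step = initial_step + epoch * steps_per_epoch + (i // scenarios_per_step)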

            batch_items = [
                EmailScenarioBatchItem(
                    scenario=scenario_data,
                    group_id=str(uuid.uuid4()), # Unique group ID for ART
                    trajectory_id_prefix=str(uuid.uuid4()), # Prefix for trajectory IDs
                ) for scenario_data in current_scenarios
            ]

            yield art.Batch(
                step=current_step,
                epoch=epoch + 1,
                epoch_step=i // scenarios_per_step,
                items=batch_items,
            )


print("Email scenario iterator defined.")

# Prepare the sample_training_dataset for the iterator
# Ensure the scenario data structure matches what the iterator expects
# (a list of dictionaries, each with 'context' and 'desired_output')
# sample_training_dataset is already in this format from the previous step.
training_data_for_iterator = sample_training_dataset # Use the previously created sample dataset

# 3. Within the training loop, for each scenario in the current batch:
# This is handled inside run_training_loop below (rollouts, scoring, training).

# 4. Use the collected training examples to update the agent's policy.
# 5. Optionally, include logging or monitoring.
# 6. Continue the loop for the defined number of training steps.

async def run_training_loop(model: art.TrainableModel, scenarios: List[Dict[str, Any]], training_config: Dict):
    """
    Runs the main RL training loop using ART.

    Args:
        model: The ART trainable model.
        scenarios: The list of training scenario dictionaries.
        training_config: Dictionary of training configuration parameters.
    """
    print("Starting training loop...")

    initial_step = await model.get_step()
    print(f"Initial model step: {initial_step}")

    # Use the custom email scenario iterator
    training_iterator = email_scenario_iterator(
        scenarios,
        groups_per_step=training_config["groups_per_step"],
        num_epochs=training_config["num_epochs"],
        initial_step=initial_step,
    )

    async for batch in training_iterator:  # email_scenario_iterator is an async generator
        if batch.step >= training_config["max_steps"]:  # Steps are 0-indexed, so >= caps training at max_steps
            print(f"Reached max steps ({training_config['max_steps']}). Stopping training.")
            break

        print(
            f"\n--- Training step {batch.step}, epoch {batch.epoch}, epoch step {batch.epoch_step} ---"
        )
        print(f"Batch contains {len(batch.items)} scenarios.")

        # Create trajectory groups for this batch
        groups = []
        for batch_item in batch.items:
            scenario_data = batch_item.scenario
            environment_state_data = scenario_data["context"]
            desired_output_email_data = scenario_data["desired_output"]

            # Reconstruct Pydantic objects from dictionaries
            environment_state = EmailEnvironmentState.model_validate(environment_state_data)
            desired_output_email = SimulatedEmail.model_validate(desired_output_email_data) if desired_output_email_data else None

            # Define the function to run for each rollout in the group
            # This function needs to:
            # 1. Initialize the LangGraph agent state with environment_state.
            # 2. Run the LangGraph app (workflow).
            # 3. Capture the final generated email from the agent state.
            # 4. Evaluate the generated email using evaluate_email.
            # 5. Create an ART Trajectory object with messages and the reward.

            async def run_agent_rollout(model: art.Model, env_state: EmailEnvironmentState, desired_output: Optional[SimulatedEmail]) -> art.Trajectory:
                """Runs one agent rollout for a specific scenario."""
                print("Running agent rollout...")
                traj = art.Trajectory(
                    reward=0.0, # Initial reward
                    messages_and_choices=[], # Store interaction history
                    metadata={
                        "scenario_id": str(uuid.uuid4()), # Unique ID for this rollout
                        "batch_step": batch.step,
                        "epoch": batch.epoch,
                        "epoch_step": batch.epoch_step,
                        "instructions_task_type": env_state.user_instructions.task_type,
                        # Add other relevant metadata from environment_state
                    },
                )

                try:
                    # Initialize LangGraph state
                    initial_langgraph_state = AgentState(
                        environment_state=env_state,
                        drafted_email=DraftedEmail(to_addresses=[], cc_addresses=[], bcc_addresses=[], subject="", body=""), # Start with empty draft
                        messages=[], # Start with no messages
                        search_results=None,
                        final_email_sent=False,
                    )

                    # Run the LangGraph app
                    # Note: Running LangGraph compiled app directly here.
                    # The nodes within the app will use the actions (draft, send, search).
                    # We need to capture the final state/output from running the app.
                    # LangGraph's `invoke` or `ainvoke` returns the final state.
                    final_langgraph_state = await app.ainvoke(initial_langgraph_state)

                    # Capture the final generated email from the state
                    final_generated_email = final_langgraph_state.get('drafted_email') # Assuming the last state before END has the final draft

                    # Check if email was actually sent based on workflow
                    email_was_sent = final_langgraph_state.get('final_email_sent', False)

                    if final_generated_email and email_was_sent:
                        print("Agent successfully generated and sent an email.")
                        # Evaluate the generated email to get the reward
                        reward = await evaluate_email(
                            final_generated_email,
                            env_state,
                            desired_output # Pass the desired output for comparison
                        )
                        traj.reward = reward
                        traj.metrics["email_sent"] = 1.0
                        traj.metadata["final_subject"] = final_generated_email.subject
                        traj.metadata["final_body_snippet"] = final_generated_email.body[:100] + "..." if len(final_generated_email.body) > 100 else final_generated_email.body

                        # You might want to capture the messages/tool calls from the LangGraph run
                        # This requires integrating LangGraph's tracing or manual logging within nodes
                        # For simplicity here, we'll just note that the email was sent.
                        traj.messages_and_choices.append({"role": "user", "content": "Simulated Email Task Started"})
                        traj.messages_and_choices.append({"role": "assistant", "content": f"Simulated Agent Run Completed. Email Subject: {final_generated_email.subject}"})


                    else:
                        print("Agent did not successfully send an email.")
                        traj.reward = 0.0 # Zero reward if email wasn't sent or drafted properly
                        traj.metrics["email_sent"] = 0.0
                        traj.messages_and_choices.append({"role": "user", "content": "Simulated Email Task Started"})
                        traj.messages_and_choices.append({"role": "assistant", "content": "Simulated Agent Run Failed to Send Email."})


                except Exception as e:
                    print(f"Error during agent rollout: {e}")
                    traj.reward = 0.0 # Penalize errors
                    traj.metrics["email_sent"] = 0.0
                    traj.messages_and_choices.append({"role": "user", "content": "Simulated Email Task Started"})
                    traj.messages_and_choices.append({"role": "assistant", "content": f"Error during rollout: {str(e)}"})


                return traj

            # Create a group with multiple rollouts for this scenario
            groups.append(
                art.TrajectoryGroup(
                    (
                        run_agent_rollout(model, environment_state, desired_output_email)
                        for _ in range(training_config["rollouts_per_group"])
                    ),
                    group_id=batch_item.group_id,
                    trajectory_id_prefix=batch_item.trajectory_id_prefix,
                )
            )

        print(f"Gathering {len(groups)} trajectory groups with {training_config['rollouts_per_group']} rollouts each...")
        # Gather all trajectory groups in parallel
        # Using asyncio.gather to run all group futures
        finished_groups = await asyncio.gather(*(group.as_future() for group in groups))
        print("Finished gathering trajectory groups.")

        # RULER Scoring (or use the direct reward from evaluate_email)
        # If evaluate_email already provides a score from 0-1, you might directly use that
        # instead of relative ranking with ruler_score_group.
        # However, RULER can still be used for comparing multiple trajectories for the *same* scenario
        # which is what rollouts_per_group provides.

        judged_groups = []
        print("Scoring trajectory groups with RULER...")
        for group in finished_groups:
            # Use ruler_score_group if you want relative ranking within the group.
            # This requires the LLM judge to compare trajectories directly,
            # which is different from our evaluate_email that scores one email at a time
            # against criteria/desired output.
            # Option A: If evaluate_email gives a good absolute score (0-1),
            # use that directly and skip ruler_score_group.
            # Option B: Adapt ruler_score_group to compare the *final emails*
            # generated in each trajectory for the same scenario. This is more complex.

            # Let's assume for simplicity that evaluate_email provides a sufficient
            # score, and we'll use that directly. ART's train function can handle
            # trajectories with pre-assigned rewards.
            # If using RULER for relative ranking, you'd call it here:
            # judged_group = await ruler_score_group(group, "openai/o4-mini", debug=False)
            # judged_groups.append(judged_group)

            # Using Option A: Directly use the reward from evaluate_email
            judged_groups.append(group) # Pass the group as is, assuming rewards are set

        print("Finished scoring trajectory groups.")

        # Log metrics from the trajectories (optional)
        all_trajectories_in_batch = [t for group in judged_groups for t in group.trajectories]
        avg_reward = sum(t.reward for t in all_trajectories_in_batch) / max(len(all_trajectories_in_batch), 1)
        avg_email_sent = sum(t.metrics.get("email_sent", 0.0) for t in all_trajectories_in_batch) / max(len(all_trajectories_in_batch), 1)
        print(f"Batch average reward: {avg_reward:.4f}")
        print(f"Batch email sent rate: {avg_email_sent:.2f}")

        # 4. Use the collected training examples to update the agent's policy.
        # The ART train function takes the scored trajectory groups.
        print("Starting model training step...")
        await model.train(
            judged_groups,
            config=art.TrainConfig(learning_rate=training_config["learning_rate"]),
             # Add any necessary internal configs for model training
             _config={"logprob_calculation_chunk_size": 8} # Example from previous notebook
        )
        print(f"Completed training step {batch.step}.")

    print("Training loop finished.")

# To run the training loop, you would await this function call:
# await run_training_loop(model, training_data_for_iterator, training_config)

print("Training loop function defined.")

# Note: Running the full training loop here would take significant time and resources.
# The code block defines the function `run_training_loop` which contains the logic.
# To execute it, you would need to call `await run_training_loop(...)` in a subsequent cell
# after ensuring the `model`, `training_data_for_iterator`, and `training_config` are available.
# For the purpose of this subtask (implementing the loop), defining the function is sufficient.


ModuleNotFoundError: No module named 'art'

**Reasoning**:
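The previous cell failed because the `art` module could not be imported. A `ModuleNotFoundError` means the package is not installed in the environment the kernel is currently using, for example because the runtime was reset to a fresh environment after the earlier install. Re-running `import art` only succeeds once the package (published on PyPI as `openpipe-art`) is installed in that environment. The next cell re-imports the required libraries, including `art`, and redefines the training loop logic.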
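One way to tell a missing install apart from a module that simply has not been imported yet is to ask the import system directly before retrying. The sketch below uses only the standard library; `describe_module` is a hypothetical helper name.

```python
import importlib.util
import sys


def describe_module(name: str) -> str:
    """Report whether `name` is importable from the running interpreter."""
    if name in sys.modules:
        return "already imported"
    if importlib.util.find_spec(name) is not None:
        return "installed, not yet imported"
    return f"not installed for {sys.executable}"


print(describe_module("json"))
print(describe_module("art"))
```

If the second line reports "not installed", the package has to be installed into the interpreter shown in the message before the import can succeed.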



In [9]:
import art
import asyncio
import random
import uuid
from typing import Dict, Any, List, Optional

# Assuming SimulatedEmail, UserInstructions, EmailEnvironmentState,
# DraftedEmail, draft_email_action, send_email_action, and search_information_action
# are defined in previous cells.

# Assuming sample_training_dataset is defined in a previous cell

# Assuming evaluate_email function is defined in a previous cell

# Assuming the LangGraph app (workflow) is compiled and available as `app`

# Assuming model and backend are defined and the model is registered

# Reuse the training scenarios loaded previously
# training_scenarios = load_training_scenarios(split="train", limit=50, max_messages=1, shuffle=True, seed=42)
# NOTE: Using sample_training_dataset which has the structure {context: {...}, desired_output: {...}}

# 1. Set up the training configuration parameters
# Using parameters similar to the example notebook, adapted for our task
training_config = {
    "groups_per_step": 2,       # Number of scenario groups processed per training step
    "num_epochs": 5,            # Number of times to iterate through the dataset
    "rollouts_per_group": 3,    # Number of times to run the agent for each scenario in a group
    "learning_rate": 1e-5,
    "max_steps": 5,             # Stop after this many training steps for the demo
}

print("Training configuration set.")

# 2. Define a mechanism to iterate through the training data scenarios.
class EmailScenarioBatchItem(art.BatchItem):
    scenario: Dict[str, Any] # Store the scenario data including context and desired_output

async def email_scenario_iterator(
    scenarios: List[Dict[str, Any]],
    groups_per_step: int,
    num_epochs: int,
    initial_step: int = 0,
):
    """
    Iterates through email scenarios, yielding batches for training steps.

    Args:
        scenarios: A list of scenario dictionaries (context and desired_output).
        groups_per_step: Number of scenario groups per step.
        num_epochs: Number of epochs to iterate through the dataset.
        initial_step: The starting training step number.

    Yields:
        Batch objects containing EmailScenarioBatchItems.
    """
    scenario_count = len(scenarios)
    scenarios_per_step = groups_per_step # Each scenario becomes one trajectory group

    for epoch in range(num_epochs):
        # Shuffle scenarios at the start of each epoch
        random.shuffle(scenarios)
        print(f"\nStarting epoch {epoch + 1}/{num_epochs}")

        for i in range(0, scenario_count, scenarios_per_step):
            current_scenarios = scenarios[i : i + scenarios_per_step]
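            # Ceiling division: a partial final batch still counts as a step,
            # so step numbers never collide across epochs.
            steps_per_epoch = (scenario_count + scenarios_per_step - 1) // scenarios_per_step
            current_step = initial_step + epoch * steps_per_epoch + (i // scenarios_per_step)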

            batch_items = [
                EmailScenarioBatchItem(
                    scenario=scenario_data,
                    group_id=str(uuid.uuid4()), # Unique group ID for ART
                    trajectory_id_prefix=str(uuid.uuid4()), # Prefix for trajectory IDs
                ) for scenario_data in current_scenarios
            ]

            yield art.Batch(
                step=current_step,
                epoch=epoch + 1,
                epoch_step=i // scenarios_per_step,
                items=batch_items,
            )


print("Email scenario iterator defined.")

# Prepare the sample_training_dataset for the iterator
# Ensure the scenario data structure matches what the iterator expects
# (a list of dictionaries, each with 'context' and 'desired_output')
# sample_training_dataset is already in this format from the previous step.
training_data_for_iterator = sample_training_dataset # Use the previously created sample dataset

# 3. Within the training loop, for each scenario in the current batch:
# This is handled by run_agent_rollout_with_langgraph and run_training_loop below.

# 4. Use the collected training examples to update the agent's policy.
# 5. Optionally, include logging or monitoring.
# 6. Continue the loop for the defined number of training steps.

async def run_agent_rollout_with_langgraph(model: art.Model, env_state: EmailEnvironmentState, desired_output: Optional[SimulatedEmail], langgraph_app, batch_info: Dict) -> art.Trajectory:
    """Runs one agent rollout for a specific scenario using the LangGraph app."""
    print("Running agent rollout with LangGraph...")
    traj = art.Trajectory(
        reward=0.0, # Initial reward
        messages_and_choices=[], # Store interaction history
        metadata={
            "scenario_id": str(uuid.uuid4()), # Unique ID for this rollout
            "batch_step": batch_info.get("step"),
            "epoch": batch_info.get("epoch"),
            "epoch_step": batch_info.get("epoch_step"),
            "instructions_task_type": env_state.user_instructions.task_type,
            # Add other relevant metadata from environment_state
        },
    )

    try:
        # Initialize LangGraph state
        initial_langgraph_state = AgentState(
            environment_state=env_state,
            drafted_email=DraftedEmail(to_addresses=[], cc_addresses=[], bcc_addresses=[], subject="", body=""), # Start with empty draft
            messages=[], # Start with no messages
            search_results=None,
            final_email_sent=False,
        )

        # Run the LangGraph app
        # LangGraph's `invoke` or `ainvoke` returns the final state.
        # We need to capture the messages/tool calls from the LangGraph run
        # This might require integrating LangGraph's tracing or manual logging within nodes
        # For simplicity here, we'll just note the start and end of the run and capture final state.
        traj.messages_and_choices.append({"role": "system", "content": "LangGraph Agent Run Started"})

        # Note: The LangGraph app `app` needs to be available in this scope.
        # The nodes within `app` use the actions (draft_email_action, send_email_action, search_information_action).
        # These action functions need to access the environment_state from the LangGraph state.
        # The current action functions are defined globally and take env_state as an argument.
        # The LangGraph nodes defined previously (`call_draft_email`, etc.) already handle passing the env_state.

        final_langgraph_state = await langgraph_app.ainvoke(initial_langgraph_state)

        # Capture interaction history from the messages list in the final state
        # Assuming the LangGraph state's 'messages' key contains the history
        if 'messages' in final_langgraph_state and isinstance(final_langgraph_state['messages'], list):
             traj.messages_and_choices.extend([{"role": msg[0] if isinstance(msg, tuple) else "unknown", "content": str(msg[1] if isinstance(msg, tuple) else msg)} for msg in final_langgraph_state['messages']])


        # Capture the final generated email from the state
        final_generated_email_dict = final_langgraph_state.get('drafted_email')
        final_generated_email = DraftedEmail.model_validate(final_generated_email_dict) if final_generated_email_dict else None

        # Check if email was actually sent based on workflow
        email_was_sent = final_langgraph_state.get('final_email_sent', False)

        if final_generated_email and email_was_sent:
            print("Agent successfully generated and sent an email.")
            # Evaluate the generated email to get the reward
            reward = await evaluate_email(
                SimulatedEmail( # Convert DraftedEmail to SimulatedEmail for evaluation
                    message_id="generated", # Placeholder ID
                    date=env_state.current_date, # Use simulation date
                    subject=final_generated_email.subject,
                    from_address="simulated_user@example.com", # Assume user's address
                    to_addresses=final_generated_email.to_addresses,
                    cc_addresses=final_generated_email.cc_addresses,
                    bcc_addresses=final_generated_email.bcc_addresses,
                    body=final_generated_email.body,
                    attachments=[]
                ),
                env_state,
                desired_output # Pass the desired output for comparison
            )
            traj.reward = reward
            traj.metrics["email_sent"] = 1.0
            traj.metadata["final_subject"] = final_generated_email.subject
            traj.metadata["final_body_snippet"] = final_generated_email.body[:100] + "..." if len(final_generated_email.body) > 100 else final_generated_email.body

        else:
            print("Agent did not successfully send an email.")
            traj.reward = 0.0 # Zero reward if email wasn't sent or drafted properly
            traj.metrics["email_sent"] = 0.0
            traj.messages_and_choices.append({"role": "assistant", "content": "Simulated Agent Run Failed to Send Email."})


    except Exception as e:
        print(f"Error during agent rollout: {e}")
        traj.reward = 0.0 # Penalize errors
        traj.metrics["email_sent"] = 0.0
        traj.messages_and_choices.append({"role": "assistant", "content": f"Error during rollout: {str(e)}"})


    return traj


async def run_training_loop(model: art.TrainableModel, scenarios: List[Dict[str, Any]], training_config: Dict, langgraph_app):
    """
    Runs the main RL training loop using ART and a LangGraph agent.

    Args:
        model: The ART trainable model.
        scenarios: The list of training scenario dictionaries.
        training_config: Dictionary of training configuration parameters.
        langgraph_app: The compiled LangGraph workflow.
    """
    print("Starting training loop...")

    initial_step = await model.get_step()
    print(f"Initial model step: {initial_step}")

    # Use the custom email scenario iterator
    training_iterator = email_scenario_iterator(
        scenarios,
        groups_per_step=training_config["groups_per_step"],
        num_epochs=training_config["num_epochs"],
        initial_step=initial_step,
    )

    async for batch in training_iterator:
        if batch.step >= training_config["max_steps"]:  # Steps are 0-indexed, so >= caps training at max_steps
            print(f"Reached max steps ({training_config['max_steps']}). Stopping training.")
            break

        print(
            f"\n--- Training step {batch.step}, epoch {batch.epoch}, epoch step {batch.epoch_step} ---"
        )
        print(f"Batch contains {len(batch.items)} scenarios.")

        # Create trajectory groups for this batch
        groups = []
        for batch_item in batch.items:
            scenario_data = batch_item.scenario
            environment_state_data = scenario_data["context"]
            desired_output_email_data = scenario_data["desired_output"]

            # Reconstruct Pydantic objects from dictionaries
            environment_state = EmailEnvironmentState.model_validate(environment_state_data)
            desired_output_email = SimulatedEmail.model_validate(desired_output_email_data) if desired_output_email_data else None

            # Create a group with multiple rollouts for this scenario
            # Each rollout runs the LangGraph agent
            groups.append(
                art.TrajectoryGroup(
                    (
                        run_agent_rollout_with_langgraph(
                            model, # Pass the ART model if needed within the rollout (e.g., for inference)
                            environment_state,
                            desired_output_email,  # The SimulatedEmail validated above (was the undefined name `desired_output`)
                            langgraph_app, # Pass the compiled LangGraph app
                            {"step": batch.step, "epoch": batch.epoch, "epoch_step": batch.epoch_step} # Pass batch info for metadata
                        ) for _ in range(training_config["rollouts_per_group"])
                    ),
                    group_id=batch_item.group_id,
                    trajectory_id_prefix=batch_item.trajectory_id_prefix,
                )
            )

        print(f"Gathering {len(groups)} trajectory groups with {training_config['rollouts_per_group']} rollouts each...")
        # Gather all trajectory groups in parallel
        # Using asyncio.gather to run all group futures
        finished_groups = await asyncio.gather(*(group.as_future() for group in groups))
        print("Finished gathering trajectory groups.")

        # Score trajectories
        judged_groups = []
        print("Processing gathered trajectory groups...")
        for group in finished_groups:
            # Since evaluate_email provides a score per rollout, we can use that directly.
            # If relative ranking within the group is desired, ruler_score_group would be called here.
            # For this task, we'll assume evaluate_email's score is sufficient.
            judged_groups.append(group) # Pass the group as is, assuming rewards are set by run_agent_rollout_with_langgraph

        print("Finished processing trajectory groups.")

        # Log metrics from the trajectories (optional)
        all_trajectories_in_batch = [t for group in judged_groups for t in group.trajectories]
        avg_reward = sum(t.reward for t in all_trajectories_in_batch) / max(len(all_trajectories_in_batch), 1)
        avg_email_sent = sum(t.metrics.get("email_sent", 0.0) for t in all_trajectories_in_batch) / max(len(all_trajectories_in_batch), 1)
        print(f"Batch average reward: {avg_reward:.4f}")
        print(f"Batch email sent rate: {avg_email_sent:.2f}")

        # 4. Use the collected training examples to update the agent's policy.
        # The ART train function takes the scored trajectory groups.
        print("Starting model training step...")
        await model.train(
            judged_groups,
            config=art.TrainConfig(learning_rate=training_config["learning_rate"]),
            # Add any necessary internal configs for model training
            _config={"logprob_calculation_chunk_size": 8},  # Example from previous notebook
        )
        print(f"Completed training step {batch.step}.")

        # Optionally, save model checkpoints periodically
        # await model.save_checkpoint() # Uncomment to save checkpoints

    print("Training loop finished.")

# To run the training loop, you would await this function call:
# await run_training_loop(model, training_data_for_iterator, training_config, app)

print("Full training loop function `run_training_loop` defined.")

ModuleNotFoundError: No module named 'art'

**Reasoning**:
The `art` module is still not found after re-importing it, which suggests an environment issue: the installed packages are not accessible in the current kernel session. Re-running the installation command in a separate cell should make the package available before importing it and defining the training loop.
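Before reinstalling, it can help to confirm which interpreter the kernel is actually running and whether the package is visible to it. The following is a minimal diagnostic sketch using only the standard library; the helper name `module_available` is hypothetical and is not part of ART or the notebook.

```python
import importlib.util
import sys


def module_available(name: str) -> bool:
    """Return True if the top-level module `name` is importable by this interpreter."""
    return importlib.util.find_spec(name) is not None


# Packages must be installed into this interpreter for the kernel to see them.
print(f"Kernel interpreter: {sys.executable}")

# `art` is the import name from the traceback; `langgraph` is checked for comparison.
for name in ("art", "langgraph"):
    status = "available" if module_available(name) else "missing"
    print(f"{name}: {status}")
```

If `art` is reported as missing here, the install most likely went to a different environment than the one the kernel uses, or the kernel needs a restart after installation.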



In [10]:
%pip install "openpipe-art[backend,langgraph]==0.4.9" langchain-core langgraph langchain_openai tenacity datasets "gql<4" --prerelease allow --no-cache-dir


Usage:   
  pip3 install [options] <requirement specifier> [package-index-options] ...
  pip3 install [options] -r <requirements file> [package-index-options] ...
  pip3 install [options] [-e] <vcs project url> ...
  pip3 install [options] [-e] <local project path> ...
  pip3 install [options] <archive url/path> ...

no such option: --prerelease


**Reasoning**:
The previous command failed because `pip install` does not recognize the `--prerelease` option. This option was likely intended for `uv pip install`, as seen in the initial setup. Standard `pip install` should be run without it to install the packages.
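The difference between the two installers can be sketched as a small helper that builds the right argument list for each. This is an illustrative sketch only: `build_install_command` is a hypothetical name, and it assumes that pip's `--pre` flag is an acceptable stand-in for uv's `--prerelease allow`. Whether pre-releases are needed at all for this install is not established by the notebook.

```python
import shutil
import sys
from typing import List, Optional


def build_install_command(
    packages: List[str],
    allow_prerelease: bool = False,
    uv_path: Optional[str] = None,
) -> List[str]:
    """Build an install command for uv (if a path is given) or for plain pip."""
    if uv_path:
        cmd = [uv_path, "pip", "install"]
        if allow_prerelease:
            # uv spells the option `--prerelease allow`.
            cmd += ["--prerelease", "allow"]
    else:
        # Use the kernel's own interpreter so packages land in the right environment.
        cmd = [sys.executable, "-m", "pip", "install"]
        if allow_prerelease:
            # pip does not accept `--prerelease`; its pre-release flag is `--pre`.
            cmd.append("--pre")
    return cmd + list(packages)


packages = ["openpipe-art[backend,langgraph]==0.4.9", "gql<4"]
# shutil.which returns None when uv is not on PATH, which selects the pip branch.
print(build_install_command(packages, allow_prerelease=True, uv_path=shutil.which("uv")))
```

The returned list could be passed to `subprocess.run(cmd, check=True)`; in a notebook, the equivalent `%pip install` line is usually simpler.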



In [None]:
%pip install "openpipe-art[backend,langgraph]==0.4.9" langchain-core langgraph langchain_openai tenacity datasets "gql<4" --no-cache-dir

Collecting openpipe-art==0.4.9 (from openpipe-art[backend,langgraph]==0.4.9)
  Downloading openpipe_art-0.4.9-py3-none-any.whl.metadata (14 kB)
Collecting langchain_openai
  Downloading langchain_openai-0.3.31-py3-none-any.whl.metadata (2.4 kB)
Collecting gql<4
  Downloading gql-3.5.3-py2.py3-none-any.whl.metadata (9.4 kB)
Collecting litellm==1.74.1 (from openpipe-art==0.4.9->openpipe-art[backend,langgraph]==0.4.9)
  Downloading litellm-1.74.1-py3-none-any.whl.metadata (40 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 40.3/40.3 kB 5.3 MB/s eta 0:00:00
Collecting openai<=1.99.1,>=1.65.5 (from openpipe-art==0.4.9->openpipe-art[backend,langgraph]==0.4.9)
  Downloading openai-1.99.1-py3-none-any.whl.metadata (29 kB)
Collecting weave>=0.51.51 (from openpipe-art==0.4.9->openpipe-art[backend,langgraph]==0.4.9)
  Downloading weave-0.52.4-py3-none-any.whl.metadata (27 kB)
Collecting accelerate==1.7.0 (from openpipe-art[backend,langgraph]==0.4.9)