Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions packages/prime/src/prime_cli/api/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class RLRun(BaseModel):
max_steps: int = Field(..., alias="maxSteps")
max_tokens: Optional[int] = Field(None, alias="maxTokens")
batch_size: int = Field(..., alias="batchSize")
loss: Optional[str] = "rl"
teacher_config: Optional[Dict[str, Any]] = Field(None, alias="teacherConfig")
base_model: str = Field(..., alias="baseModel")
environments: List[Dict[str, Any]] = Field(default_factory=list)
run_config: Optional[Dict[str, Any]] = Field(None, alias="runConfig")
Expand Down Expand Up @@ -205,6 +207,8 @@ def create_run(
enable_thinking: Optional[bool] = None,
reasoning_effort: Optional[Literal["low", "medium", "high"]] = None,
run_config: Optional[Dict[str, Any]] = None,
loss: str = "rl",
teacher: Optional[Dict[str, Any]] = None,
) -> RLRun:
"""Create a new RL training run."""
try:
Expand All @@ -222,6 +226,12 @@ def create_run(
"secrets": secrets_list,
}

if loss != "rl":
payload["loss"] = loss

if teacher:
payload["teacher"] = teacher

if name:
payload["name"] = name

Expand Down
43 changes: 43 additions & 0 deletions packages/prime/src/prime_cli/commands/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def generate_rl_config_template(environment: str | None = None) -> str:

return f'''\
model = "Qwen/Qwen3.5-0.8B"
loss = "rl" # "rl" | "sft"; OPD is not yet supported on hosted runtimes
max_steps = 100

# env_files = ["secrets.env"] # optional file(s) for secrets
Expand All @@ -231,6 +232,15 @@ def generate_rl_config_template(environment: str | None = None) -> str:
# enable_thinking = false # supported models: Qwen3.5, Nemotron
# reasoning_effort = "high" # supported models: GPT-OSS ("low" | "medium" | "high")

# Optional: SFT distillation teacher
# loss = "sft"
# [teacher]
# model = "openai/gpt-oss-120b"
#
# [teacher.sampling]
# max_tokens = 2048
# reasoning_effort = "medium"

[[env]]
id = "{env_value}"

Expand Down Expand Up @@ -374,6 +384,19 @@ def _reasoning_controls_mutually_exclusive(self) -> "SamplingConfig":
return self


class TeacherConfig(BaseModel):
    """Teacher model settings read from the ``[teacher]`` config table.

    Unknown keys are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Teacher model identifier, e.g. "openai/gpt-oss-120b".
    model: str
    # Optional sampling overrides for the teacher; omitted from the API dict when unset.
    sampling: SamplingConfig | None = None

    def to_api_dict(self) -> Dict[str, Any]:
        """Return the teacher settings as the dict handed to ``create_run``.

        The model id is nested as ``{"model": {"name": ...}}``. A ``"sampling"``
        entry is added only when sampling is configured, with ``None``-valued
        fields dropped from the dump.
        """
        payload: Dict[str, Any] = {"model": {"name": self.model}}
        if self.sampling is None:
            return payload
        payload["sampling"] = self.sampling.model_dump(exclude_none=True)
        return payload


class EvalConfig(BaseModel):
model_config = ConfigDict(extra="forbid")

Expand Down Expand Up @@ -578,6 +601,8 @@ class RLConfig(BaseModel):

name: str | None = None
model: str
loss: Literal["rl", "sft", "opd"] = "rl"
teacher: TeacherConfig | None = None
max_steps: int = 100
batch_size: int = 128
rollouts_per_example: int = 8
Expand All @@ -601,6 +626,19 @@ class RLConfig(BaseModel):
env_file: List[str] = Field(default_factory=list) # deprecated, use env_files
env_files: List[str] = Field(default_factory=list)

@model_validator(mode="after")
def _validate_loss_teacher(self) -> "RLConfig":
    """Check that ``loss`` and ``teacher`` form a supported combination.

    Raises:
        ValueError: if a teacher is given with ``loss="rl"``, if
            ``loss="sft"`` has no teacher, or if ``loss="opd"`` is requested.
    """
    has_teacher = self.teacher is not None
    if self.loss == "rl":
        if has_teacher:
            raise ValueError("teacher can only be set when loss is 'sft' or 'opd'")
    elif self.loss == "sft":
        if not has_teacher:
            raise ValueError("teacher is required when loss is 'sft'")
    elif self.loss == "opd":
        # OPD is rejected outright, whether or not a teacher is configured.
        raise ValueError(
            "loss='opd' is not supported for hosted runs yet; OPD requires "
            "teacher logprob scoring support in the hosted runtime"
        )
    return self


def _format_validation_errors(errors: list[Any]) -> list[str]:
"""Format Pydantic validation errors into user-friendly messages."""
Expand Down Expand Up @@ -871,6 +909,9 @@ def _fetch_pricing() -> None:
# Model & Environment
console.print("[cyan]Model & Environment[/cyan]")
console.print(f" Model: {cfg.model}")
console.print(f" Loss: {cfg.loss}")
if cfg.teacher is not None:
console.print(f" Teacher: {cfg.teacher.model}")
console.print(f" Environments: {', '.join(e.id for e in cfg.env)}")
if app_config.team_id:
console.print(f" Team: {app_config.team_id}")
Expand Down Expand Up @@ -1087,6 +1128,8 @@ def _format(list_p: Any, eff_p: Any) -> str:
enable_thinking=cfg.sampling.enable_thinking,
reasoning_effort=cfg.sampling.reasoning_effort,
run_config=cfg.run_config if cfg.run_config else None,
loss=cfg.loss,
teacher=cfg.teacher.to_api_dict() if cfg.teacher else None,
)

if output == "json":
Expand Down
49 changes: 49 additions & 0 deletions packages/prime/tests/test_rl_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,55 @@ def test_load_config_accepts_sampling_enable_thinking(tmp_path: Path) -> None:
assert cfg.sampling.reasoning_effort is None


def test_load_config_accepts_sft_teacher(tmp_path: Path) -> None:
    """An SFT config with a ``[teacher]`` table loads and serializes for the API."""
    toml_text = (
        'model = "openai/gpt-oss-20b"\n'
        'loss = "sft"\n'
        "[teacher]\n"
        'model = "openai/gpt-oss-120b"\n'
        "[teacher.sampling]\n"
        "max_tokens = 2048\n"
        'reasoning_effort = "medium"\n'
    )
    config_path = tmp_path / "sft.toml"
    config_path.write_text(toml_text)

    cfg = load_config(str(config_path))
    teacher = cfg.teacher

    assert cfg.loss == "sft"
    assert teacher is not None
    assert teacher.model == "openai/gpt-oss-120b"
    assert teacher.sampling is not None
    assert teacher.sampling.max_tokens == 2048

    # Only the sampling fields that were set should reach the API payload.
    expected_payload = {
        "model": {"name": "openai/gpt-oss-120b"},
        "sampling": {
            "max_tokens": 2048,
            "reasoning_effort": "medium",
        },
    }
    assert teacher.to_api_dict() == expected_payload


def test_load_config_rejects_sft_without_teacher(tmp_path: Path) -> None:
    """``loss = "sft"`` with no ``[teacher]`` table makes ``load_config`` exit."""
    toml_text = 'model = "openai/gpt-oss-20b"\nloss = "sft"\n'
    config_path = tmp_path / "sft.toml"
    config_path.write_text(toml_text)

    with pytest.raises(typer.Exit):
        load_config(str(config_path))


def test_load_config_rejects_opd_until_hosted_scoring_exists(tmp_path: Path) -> None:
    """``loss = "opd"`` makes ``load_config`` exit even when a teacher is configured."""
    toml_text = (
        'model = "openai/gpt-oss-20b"\n'
        'loss = "opd"\n'
        "[teacher]\n"
        'model = "openai/gpt-oss-120b"\n'
    )
    config_path = tmp_path / "opd.toml"
    config_path.write_text(toml_text)

    with pytest.raises(typer.Exit):
        load_config(str(config_path))


def test_load_config_rejects_both_reasoning_controls(tmp_path: Path) -> None:
config_path = tmp_path / "rl.toml"
config_path.write_text(
Expand Down