Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,11 +460,6 @@ async def train( # type: ignore[override]
"""
groups_list = list(trajectory_groups)

# Record provenance in W&B
wandb_run = model._get_wandb_run()
if wandb_run is not None:
record_provenance(wandb_run, "local-rl")

# Build config objects from explicit kwargs
config = TrainConfig(learning_rate=learning_rate, beta=beta)
dev_config: dev.TrainConfig = {
Expand Down Expand Up @@ -521,6 +516,11 @@ async def train( # type: ignore[override]
if not os.path.exists(checkpoint_path):
checkpoint_path = None

# Record provenance on the latest W&B artifact
wandb_run = model._get_wandb_run()
if wandb_run is not None:
record_provenance(wandb_run, "local-rl")

return LocalTrainResult(
step=step,
metrics=avg_metrics,
Expand Down
24 changes: 19 additions & 5 deletions src/art/serverless/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,6 @@ async def train( # type: ignore[override]
"""
groups_list = list(trajectory_groups)

# Record provenance in W&B
wandb_run = model._get_wandb_run()
if wandb_run is not None:
record_provenance(wandb_run, "serverless-rl")

# Build config objects from explicit kwargs
config = TrainConfig(learning_rate=learning_rate, beta=beta)
dev_config: dev.TrainConfig = {
Expand Down Expand Up @@ -260,6 +255,11 @@ async def train( # type: ignore[override]
if model.entity is not None:
artifact_name = f"{model.entity}/{model.project}/{model.name}:step{step}"

# Record provenance on the latest W&B artifact
wandb_run = model._get_wandb_run()
if wandb_run is not None:
record_provenance(wandb_run, "serverless-rl")

return ServerlessTrainResult(
step=step,
metrics=avg_metrics,
Expand Down Expand Up @@ -645,6 +645,20 @@ async def _experimental_fork_checkpoint(
run.log_artifact(dest_artifact, aliases=aliases)
run.finish()

# Copy provenance from the source model's W&B run to the destination model
api = wandb.Api(api_key=self._client.api_key) # ty:ignore[possibly-missing-attribute]
try:
source_run = api.run(f"{model.entity}/{from_project}/{from_model}")
source_provenance = source_run.config.get("wandb.provenance")
if source_provenance is not None:
dest_run = model._get_wandb_run()
if dest_run is not None:
dest_run.config.update(
{"wandb.provenance": list(source_provenance)}
)
except Exception:
pass # Source run may not exist (e.g., S3-only models)

if verbose:
print(
f"Successfully forked checkpoint from {from_model} "
Expand Down
1 change: 1 addition & 0 deletions src/art/utils/deployment/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ async def deploy_model(
model=model,
checkpoint_path=checkpoint_path,
step=step,
config=config,
verbose=verbose,
)
return DeploymentResult(inference_model_name=inference_name)
Expand Down
10 changes: 8 additions & 2 deletions src/art/utils/deployment/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ class WandbDeploymentConfig(DeploymentConfig):
- Qwen/Qwen2.5-14B-Instruct
"""

pass
provenance: list[str]
"""The training provenance history for this model (e.g. ["local-rl", "serverless-rl"])."""


WANDB_SUPPORTED_BASE_MODELS = [
Expand All @@ -36,6 +37,7 @@ def deploy_wandb(
model: "TrainableModel",
checkpoint_path: str,
step: int,
config: "WandbDeploymentConfig | None" = None,
verbose: bool = False,
) -> str:
"""Deploy a model to W&B by uploading a LoRA artifact.
Expand All @@ -44,6 +46,7 @@ def deploy_wandb(
model: The TrainableModel to deploy.
checkpoint_path: Local path to the checkpoint directory.
step: The step number of the checkpoint.
config: Optional WandbDeploymentConfig with provenance metadata.
verbose: Whether to print verbose output.

Returns:
Expand Down Expand Up @@ -74,10 +77,13 @@ def deploy_wandb(
settings=wandb.Settings(api_key=os.environ["WANDB_API_KEY"]),
)
try:
metadata: dict[str, object] = {"wandb.base_model": model.base_model}
if config is not None:
metadata["wandb.provenance"] = config.provenance
artifact = wandb.Artifact(
model.name,
type="lora",
metadata={"wandb.base_model": model.base_model},
metadata=metadata,
storage_region="coreweave-us",
)
artifact.add_dir(checkpoint_path)
Expand Down
21 changes: 16 additions & 5 deletions src/art/utils/record_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,22 @@


def record_provenance(run: wandb.Run, provenance: str) -> None:
    """Record provenance on the latest artifact version's metadata.

    Looks up the ``lora`` artifact ``{run.entity}/{run.project}/{run.name}:latest``
    and ensures ``provenance`` is the last entry of the list stored under the
    ``"wandb.provenance"`` metadata key. Consecutive duplicates are not
    appended, so repeated training with the same technique leaves the list
    unchanged.

    Args:
        run: W&B run whose entity, project and name identify the artifact.
        provenance: Label of the training technique to record
            (e.g. ``"serverless-rl"``).

    Returns:
        None. If the artifact lookup raises ``CommError`` the function returns
        without recording anything.
    """
    import wandb as wandb_module

    api = wandb_module.Api()
    artifact_path = f"{run.entity}/{run.project}/{run.name}:latest"
    try:
        artifact = api.artifact(artifact_path, type="lora")
    except wandb_module.errors.CommError:
        return  # No artifact exists yet

    # ``or []`` treats a missing key and an empty stored list the same way,
    # so indexing ``history[-1]`` can never raise IndexError.
    history = list(artifact.metadata.get("wandb.provenance") or [])
    if history and history[-1] == provenance:
        return  # Already the most recent entry; nothing to persist.

    history.append(provenance)
    artifact.metadata["wandb.provenance"] = history
    artifact.save()
48 changes: 32 additions & 16 deletions tests/integration/test_provenance.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Integration test: verify provenance tracking in W&B run config via ServerlessBackend."""
"""Integration test: verify provenance tracking on W&B artifact metadata via ServerlessBackend."""

import asyncio
from datetime import datetime

from dotenv import load_dotenv
import wandb

import art
from art.serverless.backend import ServerlessBackend
Expand Down Expand Up @@ -36,8 +37,13 @@ async def simple_rollout(model: art.TrainableModel) -> art.Trajectory:
return traj


async def make_group(model: art.TrainableModel) -> art.TrajectoryGroup:
return art.TrajectoryGroup(simple_rollout(model) for _ in range(4))
def get_latest_artifact_provenance(
    entity: str, project: str, name: str
) -> list[str] | None:
    """Return the ``wandb.provenance`` metadata of the latest LoRA artifact.

    Args:
        entity: W&B entity that owns the artifact.
        project: W&B project containing the artifact.
        name: Artifact name; the ``:latest`` alias is appended to it.

    Returns:
        The value stored under the ``"wandb.provenance"`` metadata key, or
        ``None`` when that key is absent.
    """
    latest_ref = f"{entity}/{project}/{name}:latest"
    lora_artifact = wandb.Api().artifact(latest_ref, type="lora")
    return lora_artifact.metadata.get("wandb.provenance")


async def main() -> None:
Expand All @@ -49,25 +55,35 @@ async def main() -> None:
base_model="OpenPipe/Qwen3-14B-Instruct",
)
await model.register(backend)
assert model.entity is not None

# --- Step 1: first training call ---
groups = await art.gather_trajectory_groups(make_group(model) for _ in range(1))
result = await backend.train(model, groups)
await model.log(groups, metrics=result.metrics, step=result.step, split="train")

# Check provenance after first train call
run = model._get_wandb_run()
assert run is not None, "W&B run should exist"
provenance = run.config.get("provenance")
# --- Step 1: first training call (retry on transient server errors) ---
for attempt in range(3):
groups = await art.gather_trajectory_groups(
[art.TrajectoryGroup(simple_rollout(model) for _ in range(4))] # ty: ignore[invalid-argument-type]
)
try:
result = await backend.train(model, groups)
await model.log(
groups, metrics=result.metrics, step=result.step, split="train"
)
break
except RuntimeError as e:
print(f"Step 1 attempt {attempt + 1} failed: {e}")
if attempt == 2:
raise

# Check provenance on the latest artifact after first train call
provenance = get_latest_artifact_provenance(model.entity, model.project, model.name)
print(f"After step 1: provenance = {provenance}")
assert provenance == ["serverless-rl"], (
f"Expected ['serverless-rl'], got {provenance}"
)

# --- Step 2: second training call (same technique, should NOT duplicate) ---
# Provenance is recorded at the end of train(), after the remote call.
# If the server-side training fails, no new provenance is written, so the
# latest artifact is expected to still carry only the entry from step 1.
groups2 = await art.gather_trajectory_groups(make_group(model) for _ in range(1))
groups2 = await art.gather_trajectory_groups(
[art.TrajectoryGroup(simple_rollout(model) for _ in range(4))] # ty: ignore[invalid-argument-type]
)
try:
result2 = await backend.train(model, groups2)
await model.log(
Expand All @@ -76,7 +92,7 @@ async def main() -> None:
except RuntimeError as e:
print(f"Step 2 training failed (transient server error, OK for this test): {e}")

provenance = run.config.get("provenance")
provenance = get_latest_artifact_provenance(model.entity, model.project, model.name)
print(f"After step 2: provenance = {provenance}")
assert provenance == ["serverless-rl"], (
f"Expected ['serverless-rl'] (no duplicate), got {provenance}"
Expand Down