diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 640ee1f33..798acee83 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -460,11 +460,6 @@ async def train( # type: ignore[override] """ groups_list = list(trajectory_groups) - # Record provenance in W&B - wandb_run = model._get_wandb_run() - if wandb_run is not None: - record_provenance(wandb_run, "local-rl") - # Build config objects from explicit kwargs config = TrainConfig(learning_rate=learning_rate, beta=beta) dev_config: dev.TrainConfig = { @@ -521,6 +516,11 @@ async def train( # type: ignore[override] if not os.path.exists(checkpoint_path): checkpoint_path = None + # Record provenance on the latest W&B artifact + wandb_run = model._get_wandb_run() + if wandb_run is not None: + record_provenance(wandb_run, "local-rl") + return LocalTrainResult( step=step, metrics=avg_metrics, diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 82f76e5ad..cabd064c6 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -210,11 +210,6 @@ async def train( # type: ignore[override] """ groups_list = list(trajectory_groups) - # Record provenance in W&B - wandb_run = model._get_wandb_run() - if wandb_run is not None: - record_provenance(wandb_run, "serverless-rl") - # Build config objects from explicit kwargs config = TrainConfig(learning_rate=learning_rate, beta=beta) dev_config: dev.TrainConfig = { @@ -260,6 +255,11 @@ async def train( # type: ignore[override] if model.entity is not None: artifact_name = f"{model.entity}/{model.project}/{model.name}:step{step}" + # Record provenance on the latest W&B artifact + wandb_run = model._get_wandb_run() + if wandb_run is not None: + record_provenance(wandb_run, "serverless-rl") + return ServerlessTrainResult( step=step, metrics=avg_metrics, @@ -645,6 +645,20 @@ async def _experimental_fork_checkpoint( run.log_artifact(dest_artifact, aliases=aliases) run.finish() + # Copy provenance from the source model's W&B run to the destination model + api = wandb.Api(api_key=self._client.api_key) # ty:ignore[possibly-missing-attribute] + try: + source_run = api.run(f"{model.entity}/{from_project}/{from_model}") + source_provenance = source_run.config.get("wandb.provenance") + if source_provenance is not None: + dest_run = model._get_wandb_run() + if dest_run is not None: + dest_run.config.update( + {"wandb.provenance": list(source_provenance)} + ) + except Exception: + pass # Source run may not exist (e.g., S3-only models) + if verbose: print( f"Successfully forked checkpoint from {from_model} " diff --git a/src/art/utils/deployment/common.py b/src/art/utils/deployment/common.py index b1bf34f46..e104bcf11 100644 --- a/src/art/utils/deployment/common.py +++ b/src/art/utils/deployment/common.py @@ -90,6 +90,7 @@ async def deploy_model( model=model, checkpoint_path=checkpoint_path, step=step, + config=config, verbose=verbose, ) return DeploymentResult(inference_model_name=inference_name) diff --git a/src/art/utils/deployment/wandb.py b/src/art/utils/deployment/wandb.py index 08a15fb53..9ddf778e8 100644 --- a/src/art/utils/deployment/wandb.py +++ b/src/art/utils/deployment/wandb.py @@ -21,7 +21,8 @@ class WandbDeploymentConfig(DeploymentConfig): - Qwen/Qwen2.5-14B-Instruct """ - pass + provenance: list[str] + """The training provenance history for this model (e.g. ["local-rl", "serverless-rl"]).""" WANDB_SUPPORTED_BASE_MODELS = [ @@ -36,6 +37,7 @@ def deploy_wandb( model: "TrainableModel", checkpoint_path: str, step: int, + config: "WandbDeploymentConfig | None" = None, verbose: bool = False, ) -> str: """Deploy a model to W&B by uploading a LoRA artifact. @@ -44,6 +46,7 @@ def deploy_wandb( model: The TrainableModel to deploy. checkpoint_path: Local path to the checkpoint directory. step: The step number of the checkpoint. + config: Optional WandbDeploymentConfig with provenance metadata. verbose: Whether to print verbose output. Returns: @@ -74,10 +77,13 @@ def deploy_wandb( settings=wandb.Settings(api_key=os.environ["WANDB_API_KEY"]), ) try: + metadata: dict[str, object] = {"wandb.base_model": model.base_model} + if config is not None: + metadata["wandb.provenance"] = config.provenance artifact = wandb.Artifact( model.name, type="lora", - metadata={"wandb.base_model": model.base_model}, + metadata=metadata, storage_region="coreweave-us", ) artifact.add_dir(checkpoint_path) diff --git a/src/art/utils/record_provenance.py b/src/art/utils/record_provenance.py index 8f3faed1e..84a8bce7f 100644 --- a/src/art/utils/record_provenance.py +++ b/src/art/utils/record_provenance.py @@ -7,11 +7,22 @@ def record_provenance(run: wandb.Run, provenance: str) -> None: - """Record provenance in run metadata, ensuring it's the last value in the array.""" - if "provenance" in run.config: - existing = list(run.config["provenance"]) + """Record provenance on the latest artifact version's metadata.""" + import wandb as wandb_module + + api = wandb_module.Api() + artifact_path = f"{run.entity}/{run.project}/{run.name}:latest" + try: + artifact = api.artifact(artifact_path, type="lora") + except wandb_module.errors.CommError: + return # No artifact exists yet + + existing = artifact.metadata.get("wandb.provenance") + if existing is not None: + existing = list(existing) if existing[-1] != provenance: existing.append(provenance) - run.config.update({"provenance": existing}) + artifact.metadata["wandb.provenance"] = existing else: - run.config.update({"provenance": [provenance]}) + artifact.metadata["wandb.provenance"] = [provenance] + artifact.save() diff --git a/tests/integration/test_provenance.py b/tests/integration/test_provenance.py index 0f1f0c900..187fcad88 100644 --- a/tests/integration/test_provenance.py +++ b/tests/integration/test_provenance.py @@ -1,9 +1,10 @@ -"""Integration test: verify provenance tracking in W&B run config via ServerlessBackend.""" +"""Integration test: verify provenance tracking on W&B artifact metadata via ServerlessBackend.""" import asyncio from datetime import datetime from dotenv import load_dotenv +import wandb import art from art.serverless.backend import ServerlessBackend @@ -36,8 +37,13 @@ async def simple_rollout(model: art.TrainableModel) -> art.Trajectory: return traj -async def make_group(model: art.TrainableModel) -> art.TrajectoryGroup: - return art.TrajectoryGroup(simple_rollout(model) for _ in range(4)) +def get_latest_artifact_provenance( + entity: str, project: str, name: str +) -> list[str] | None: + """Fetch provenance from the latest W&B artifact's metadata.""" + api = wandb.Api() + artifact = api.artifact(f"{entity}/{project}/{name}:latest", type="lora") + return artifact.metadata.get("wandb.provenance") async def main() -> None: @@ -49,25 +55,35 @@ async def main() -> None: base_model="OpenPipe/Qwen3-14B-Instruct", ) await model.register(backend) + assert model.entity is not None - # --- Step 1: first training call --- - groups = await art.gather_trajectory_groups(make_group(model) for _ in range(1)) - result = await backend.train(model, groups) - await model.log(groups, metrics=result.metrics, step=result.step, split="train") - - # Check provenance after first train call - run = model._get_wandb_run() - assert run is not None, "W&B run should exist" - provenance = run.config.get("provenance") + # --- Step 1: first training call (retry on transient server errors) --- + for attempt in range(3): + groups = await art.gather_trajectory_groups( + [art.TrajectoryGroup(simple_rollout(model) for _ in range(4))] # ty: ignore[invalid-argument-type] + ) + try: + result = await backend.train(model, groups) + await model.log( + groups, metrics=result.metrics, step=result.step, split="train" + ) + break + except RuntimeError as e: + print(f"Step 1 attempt {attempt + 1} failed: {e}") + if attempt == 2: + raise + + # Check provenance on the latest artifact after first train call + provenance = get_latest_artifact_provenance(model.entity, model.project, model.name) print(f"After step 1: provenance = {provenance}") assert provenance == ["serverless-rl"], ( f"Expected ['serverless-rl'], got {provenance}" ) # --- Step 2: second training call (same technique, should NOT duplicate) --- - # Provenance is recorded at the start of train(), before the remote call, - # so we can verify deduplication even if the server-side training fails. - groups2 = await art.gather_trajectory_groups(make_group(model) for _ in range(1)) + groups2 = await art.gather_trajectory_groups( + [art.TrajectoryGroup(simple_rollout(model) for _ in range(4))] # ty: ignore[invalid-argument-type] + ) try: result2 = await backend.train(model, groups2) await model.log( @@ -76,7 +92,7 @@ async def main() -> None: except RuntimeError as e: print(f"Step 2 training failed (transient server error, OK for this test): {e}") - provenance = run.config.get("provenance") + provenance = get_latest_artifact_provenance(model.entity, model.project, model.name) print(f"After step 2: provenance = {provenance}") assert provenance == ["serverless-rl"], ( f"Expected ['serverless-rl'] (no duplicate), got {provenance}"