From 86df6db3af7ef337b5fa30f00d10a774637d54de Mon Sep 17 00:00:00 2001 From: salimlaimeche Date: Fri, 29 May 2026 04:38:26 +0200 Subject: [PATCH] chore(ci): add training regression smoke --- scripts/ci-smoke.mjs | 213 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) diff --git a/scripts/ci-smoke.mjs b/scripts/ci-smoke.mjs index fcad0c3..99abc2e 100644 --- a/scripts/ci-smoke.mjs +++ b/scripts/ci-smoke.mjs @@ -16,6 +16,7 @@ const results = [ runGridWorldSmoke(), await runDqnSmoke(), runDroneTargetSmoke(), + runTrainingRegressionSmoke(), ]; console.log( @@ -832,6 +833,142 @@ async function runDqnSmoke() { }); } +function runTrainingRegressionSmoke() { + return withProjectDir("ignitionrl-training-regression-smoke-", (projectDir) => { + const gridProjectDir = join(projectDir, "grid-world.ignitionrl"); + const droneProjectDir = join(projectDir, "drone-target.ignitionrl"); + + runCli(["init", "grid-world", gridProjectDir, "--json"]); + const gridTrain = runCli([ + "train", + gridProjectDir, + "GridWorld-v0", + "--learner", + "tabular-q", + "--episodes", + "12", + "--max-steps", + "20", + "--seed", + "ci-training-regression-grid", + "--run-id", + "grid-world-training-regression", + "--checkpoint-id", + "final", + "--json", + ]); + const gridMetrics = runCli([ + "metrics", + gridProjectDir, + "totalReward", + "--run-id", + "grid-world-training-regression", + "--json", + ]); + const gridMetricPoints = metricPointCount( + gridMetrics, + "grid-world-training-regression", + "totalReward", + ); + + assertTrainingRun(gridTrain, { + environmentId: "GridWorld-v0", + learner: "tabular-q", + runId: "grid-world-training-regression", + }); + assertTrainingMetricListed(gridTrain, "learner.transitions"); + assertTrainingMetricAtLeast( + gridTrain, + "summary.successRate", + gridTrain.summary.successRate, + 0.8, + ); + assertTrainingMetricAtLeast( + gridTrain, + "summary.totalReward", + gridTrain.summary.totalReward, + 100, + ); + assertTrainingMetricAtLeast(gridTrain, "metricPoints.totalReward", gridMetricPoints, 1); + + runCli(["init", "drone-target", droneProjectDir, "--json"]); + const droneTrain = runCli([ + "train", + droneProjectDir, + "DroneTarget-v0", + "--learner", + "linear-policy-search", + "--episodes", + "2", + "--max-steps", + "12", + "--seed", + "ci-training-regression-drone", + "--run-id", + "drone-target-training-regression", + "--checkpoint-id", + "final", + "--json", + ]); + const droneMetrics = runCli([ + "metrics", + droneProjectDir, + "learner.bestReward", + "--run-id", + "drone-target-training-regression", + "--json", + ]); + const droneMetricPoints = metricPointCount( + droneMetrics, + "drone-target-training-regression", + "learner.bestReward", + ); + + assertTrainingRun(droneTrain, { + environmentId: "DroneTarget-v0", + learner: "linear-policy-search", + runId: "drone-target-training-regression", + }); + assertTrainingMetricListed(droneTrain, "learner.bestReward"); + assertTrainingMetricAtLeast( + droneTrain, + "summary.successRate", + droneTrain.summary.successRate, + 0.5, + ); + assertTrainingMetricAtLeast( + droneTrain, + "summary.bestReward", + droneTrain.summary.bestReward, + 20, + ); + assertTrainingMetricAtLeast( + droneTrain, + "metricPoints.learner.bestReward", + droneMetricPoints, + 1, + ); + + return { + name: "TrainingRegression", + runs: [gridTrain.runId, droneTrain.runId], + doctorChecks: 0, + evalChecks: 5, + environments: 2, + environmentRuns: 2, + historyRows: 0, + studioSelectedRun: droneTrain.runId, + runDetailArtifacts: 0, + artifacts: 0, + episodes: gridTrain.summary.episodes + droneTrain.summary.episodes, + metricPoints: gridMetricPoints + droneMetricPoints, + rewardTerms: 0, + checkpoints: 2, + checkpointSource: `${gridTrain.runId}, ${droneTrain.runId}`, + }; + }); +} + function withProjectDir(prefix, fn) { const projectDir = mkdtempSync(join(tmpdir(), prefix)); @@ -879,3 +1016,79 @@ function assert(condition, message) { throw new Error(message); } } + +function assertTrainingRun(result, expected) { + assert( + result.command === "train", + "training regression command returned an unexpected payload", + ); + assert( + result.environmentId === expected.environmentId, + `training regression selected wrong environment: run=${result.runId} expected=${expected.environmentId} actual=${result.environmentId}`, + ); + assert( + result.learner === expected.learner, + `training regression selected wrong learner: run=${result.runId} expected=${expected.learner} actual=${result.learner}`, + ); + assert( + result.runId === expected.runId, + `training regression selected wrong run: expected=${expected.runId} actual=${result.runId}`, + ); + assert( + result.checkpoint?.id === "final", + `training regression did not persist final checkpoint: run=${result.runId}`, + ); +} + +function assertTrainingMetricListed(result, metric) { + if (!result.metricNames.includes(metric)) { + throw new Error( + [ + "Training regression failed:", + `run=${result.runId}`, + `environment=${result.environmentId}`, + `learner=${result.learner}`, + `metric=${metric}`, + "reason=missing_metric_name", + ].join(" "), + ); + } +} + +function assertTrainingMetricAtLeast(result, metric, actual, expected) { + if (typeof actual !== "number" || !Number.isFinite(actual) || actual < expected) { + throw new Error( + [ + "Training regression failed:", + `run=${result.runId}`, + `environment=${result.environmentId}`, + `learner=${result.learner}`, + `metric=${metric}`, + `actual=${String(actual)}`, + `expected>=${String(expected)}`, + ].join(" "), + ); + } +} + +function metricPointCount(result, runId, metric) { + assert( + result.command === "metrics", + "training regression metrics command returned an unexpected payload", + ); + + const run = result.runs.find((entry) => entry.id === runId); + + if (run === undefined) { + throw new Error( + [ + "Training regression failed:", + `run=${runId}`, + `metric=${metric}`, + "reason=missing_metric_run", + ].join(" "), + ); + } + + return run.points.length; +}