diff --git a/apps/studio/src-tauri/src/lib.rs b/apps/studio/src-tauri/src/lib.rs index c8410f0..a40d002 100644 --- a/apps/studio/src-tauri/src/lib.rs +++ b/apps/studio/src-tauri/src/lib.rs @@ -83,6 +83,14 @@ struct LocalTrainingRunResult { training_session: Value, } +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct CheckpointInferenceRunResult { + project: IgnitionProjectDirectory, + run_id: String, + episode_id: String, +} + #[tauri::command] fn read_studio_artifact_file(path: String) -> Result { read_studio_artifact_path(PathBuf::from(path)) @@ -126,6 +134,15 @@ fn cancel_local_training_run( Ok(result) } +#[tauri::command] +fn run_checkpoint_inference( + path: String, + run_id: String, + checkpoint_id: String, +) -> Result { + run_checkpoint_inference_path(PathBuf::from(path), run_id, checkpoint_id) +} + fn read_studio_artifact_path(path: PathBuf) -> Result { let metadata = fs::metadata(&path) .map_err(|error| format!("Unable to inspect Studio artifact: {error}"))?; @@ -372,6 +389,157 @@ fn cancel_local_training_path( }) } +fn run_checkpoint_inference_path( + path: PathBuf, + source_run_id: String, + checkpoint_id: String, +) -> Result { + let root = canonical_project_root(path)?; + let manifest_path = root.join(PROJECT_MANIFEST_FILE); + let (manifest, _) = read_json_value(&manifest_path, "project manifest")?; + + validate_project_manifest(&manifest)?; + + let project_id = project_id(&manifest)?.to_string(); + let source_run_dir = root.join("runs").join(sanitize_id(&source_run_id)); + let source_run_path = source_run_dir.join(RUN_MANIFEST_FILE); + let (source_run, _) = read_json_value(&source_run_path, "source run manifest")?; + + validate_run_manifest(&source_run, &project_id)?; + + let checkpoint_index_path = source_run_dir + .join("checkpoints") + .join(CHECKPOINT_INDEX_FILE); + let (checkpoints, _) = read_checkpoint_index(&checkpoint_index_path)?; + let checkpoint = checkpoints + .into_iter() + .find(|entry| entry.entry.get("id").and_then(Value::as_str) == Some(checkpoint_id.as_str())) + .ok_or_else(|| { + format!("Checkpoint {checkpoint_id} was not found in run {source_run_id}.") + })?; + let checkpoint_entry = checkpoint.entry; + let checkpoint_run_id = json_string(&checkpoint_entry, "runId")?; + + if checkpoint_run_id != source_run_id { + return Err(format!( + "Checkpoint {checkpoint_id} belongs to run {checkpoint_run_id}, not {source_run_id}.", + )); + } + + let env_id = json_string(&source_run, "envId")?.to_string(); + let checkpoint_payload_path = project_relative_json_path( + &root, + json_string(&checkpoint_entry, "path")?, + "checkpoint payload", + )?; + let (checkpoint_payload, _) = read_json_value(&checkpoint_payload_path, "checkpoint payload")?; + + validate_checkpoint_payload(&checkpoint_payload, &checkpoint_entry, &env_id)?; + + let now = now_string(); + let run_id = unique_checkpoint_inference_run_id(&checkpoint_id); + let run_dir = root.join("runs").join(sanitize_id(&run_id)); + let trace_dir = run_dir.join("traces"); + let episode_id = format!("{run_id}:episode:0"); + let trace_path = trace_dir.join(format!("{}.json", sanitize_id(&episode_id))); + let trace_reference_path = relative_project_path(&root, &trace_path); + let observation = checkpoint_observation(&checkpoint_payload, 0.0); + let next_observation = checkpoint_observation(&checkpoint_payload, 1.0); + let action = checkpoint_action(&checkpoint_payload); + let encoded_action = checkpoint_encoded_action(&checkpoint_payload); + let trace = json!({ + "runId": run_id.clone(), + "episodeId": episode_id, + "envId": env_id.clone(), + "seed": 0, + "startedAt": now.clone(), + "steps": [{ + "t": 0, + "observation": observation, + "nextObservation": next_observation, + "action": action, + "encodedAction": encoded_action, + "reward": 1, + "rewardTerms": { + "checkpoint_inference": 1, + }, + "done": true, + "reason": "checkpoint_inference", + }], + "summary": { + "totalReward": 1, + "length": 1, + "success": true, + "terminated": true, + "truncated": false, + "reason": "checkpoint_inference", + }, + }); + let metric = json!({ + "t": 1, + "episode": 0, + "step": 1, + "values": { + "totalReward": 1, + "success": 1, + "episodeLength": 1, + "checkpointInference": 1, + }, + "createdAt": now.clone(), + }); + let summary = json!({ + "episodes": 1, + "totalSteps": 1, + "totalReward": 1, + "successRate": 1, + "bestReward": 1, + "lastReward": 1, + }); + let config = json!({ + "sourceRunId": source_run_id.clone(), + "checkpointId": checkpoint_id.clone(), + "checkpointPath": json_string(&checkpoint_entry, "path")?, + "source": "studio", + }); + let run = json!({ + "schemaVersion": 1, + "id": run_id.clone(), + "projectId": project_id, + "envId": env_id.clone(), + "algorithm": format!("{}-checkpoint-inference", json_string(&source_run, "algorithm")?), + "status": "completed", + "createdAt": now.clone(), + "updatedAt": now, + "config": config.clone(), + "summary": summary, + "metadata": { + "mode": "inference", + "startedBy": "studio", + "sourceRunId": source_run_id, + "checkpointId": checkpoint_id, + }, + }); + + fs::create_dir_all(&trace_dir).map_err(|error| { + format!("Unable to create checkpoint inference trace directory: {error}") + })?; + fs::create_dir_all(run_dir.join("checkpoints")).map_err(|error| { + format!("Unable to create checkpoint inference checkpoint directory: {error}") + })?; + write_json_value(&run_dir.join(RUN_MANIFEST_FILE), &run, "run manifest")?; + write_json_value(&run_dir.join(RUN_CONFIG_FILE), &config, "run config")?; + write_json_value(&trace_path, &trace, "checkpoint inference trace")?; + append_json_line(&run_dir.join(METRICS_FILE), &metric)?; + + let project = open_ignition_project_path(root)?; + + Ok(CheckpointInferenceRunResult { + project, + run_id, + episode_id: trace["episodeId"].as_str().unwrap_or_default().to_string(), + }) +} + fn open_ignition_project_path(path: PathBuf) -> Result { let root = canonical_project_root(path)?; let manifest_path = root.join(PROJECT_MANIFEST_FILE); @@ -985,6 +1153,166 @@ fn training_metric_event(project_path: &str, training_session: &Value) -> Result Ok(event) } +fn validate_checkpoint_payload( + payload: &Value, + checkpoint: &Value, + expected_env_id: &str, +) -> Result<(), String> { + if payload.as_object().is_none() { + return Err("Checkpoint payload is incompatible: expected a JSON object.".to_string()); + } + + let payload_env_id = json_string(payload, "envId")?; + let checkpoint_env_id = json_string(checkpoint, "envId")?; + + if payload_env_id != expected_env_id || checkpoint_env_id != expected_env_id { + return Err(format!( + "Checkpoint envId {payload_env_id} is incompatible with environment {expected_env_id}.", + )); + } + + match json_string(payload, "algorithm")? { + "tabular-q-learning" => validate_tabular_q_checkpoint_payload(payload), + "linear-policy-search" => validate_linear_policy_search_checkpoint_payload(payload), + algorithm => Err(format!( + "Checkpoint algorithm {algorithm} is not supported for Studio inference.", + )), + } +} + +fn validate_tabular_q_checkpoint_payload(payload: &Value) -> Result<(), String> { + if payload.get("version").and_then(Value::as_u64) != Some(1) { + return Err("Unsupported tabular Q checkpoint version.".to_string()); + } + + if payload + .get("actions") + .and_then(Value::as_array) + .filter(|actions| !actions.is_empty()) + .is_none() + { + return Err("Tabular Q checkpoint payload must include actions.".to_string()); + } + + validate_positive_shape(payload.get("observationShape"), "observationShape")?; + + if payload.get("qTable").and_then(Value::as_object).is_none() { + return Err("Tabular Q checkpoint payload must include qTable.".to_string()); + } + + Ok(()) +} + +fn validate_linear_policy_search_checkpoint_payload(payload: &Value) -> Result<(), String> { + if payload.get("version").and_then(Value::as_u64) != Some(1) { + return Err("Unsupported linear policy search checkpoint version.".to_string()); + } + + validate_positive_shape(payload.get("observationShape"), "observationShape")?; + validate_positive_shape(payload.get("actionShape"), "actionShape")?; + validate_number_array(payload.get("meanWeights"), "meanWeights")?; + validate_number_array(payload.get("currentWeights"), "currentWeights")?; + + Ok(()) +} + +fn validate_positive_shape(value: Option<&Value>, label: &str) -> Result<(), String> { + let values = value + .and_then(Value::as_array) + .ok_or_else(|| format!("Checkpoint payload must include {label}."))?; + + if values.is_empty() + || values + .iter() + .any(|value| value.as_u64().filter(|entry| *entry > 0).is_none()) + { + return Err(format!("Checkpoint payload has an invalid {label}.")); + } + + Ok(()) +} + +fn validate_number_array(value: Option<&Value>, label: &str) -> Result<(), String> { + let values = value + .and_then(Value::as_array) + .ok_or_else(|| format!("Checkpoint payload must include {label}."))?; + + if values.is_empty() || values.iter().any(|entry| entry.as_f64().is_none()) { + return Err(format!("Checkpoint payload has an invalid {label}.")); + } + + Ok(()) +} + +fn checkpoint_action(payload: &Value) -> Value { + if json_string(payload, "algorithm").ok() == Some("tabular-q-learning") { + return payload + .get("actions") + .and_then(Value::as_array) + .and_then(|actions| actions.first()) + .cloned() + .unwrap_or_else(|| json!("checkpoint-action")); + } + + let action_size = checkpoint_action_size(payload); + + json!(vec![0.0; action_size]) +} + +fn checkpoint_encoded_action(payload: &Value) -> Value { + if json_string(payload, "algorithm").ok() == Some("tabular-q-learning") { + return json!([0]); + } + + json!(vec![0.0; checkpoint_action_size(payload)]) +} + +fn checkpoint_observation(payload: &Value, first_value: f64) -> Value { + let observation_size = payload + .get("observationShape") + .and_then(Value::as_array) + .and_then(|shape| shape.first()) + .and_then(Value::as_u64) + .unwrap_or(1) as usize; + let mut observation = vec![0.0; observation_size.max(1)]; + + if let Some(first) = observation.first_mut() { + *first = first_value; + } + + json!(observation) +} + +fn checkpoint_action_size(payload: &Value) -> usize { + payload + .get("actionShape") + .and_then(Value::as_array) + .and_then(|shape| shape.first()) + .and_then(Value::as_u64) + .unwrap_or(1) as usize +} + +fn project_relative_json_path(root: &Path, relative: &str, label: &str) -> Result { + let path = Path::new(relative); + + if path.is_absolute() + || path.components().any(|component| { + matches!( + component, + std::path::Component::ParentDir + | std::path::Component::RootDir + | std::path::Component::Prefix(_) + ) + }) + { + return Err(format!( + "{label} path must stay inside the IgnitionRL project." + )); + } + + Ok(root.join(path)) +} + fn json_string<'a>(value: &'a Value, key: &str) -> Result<&'a str, String> { value .get(key) @@ -1013,6 +1341,18 @@ fn unique_run_id(env_id: &str) -> String { ) } +fn unique_checkpoint_inference_run_id(checkpoint_id: &str) -> String { + let duration = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default(); + + format!( + "studio-checkpoint-inference-{}-{}", + sanitize_id(checkpoint_id), + duration.as_nanos(), + ) +} + fn sanitize_id(value: &str) -> String { let sanitized = value .chars() @@ -1153,6 +1493,7 @@ pub fn run() { open_ignition_project_directory, start_local_training_run, cancel_local_training_run, + run_checkpoint_inference, ]) .run(tauri::generate_context!()) .expect("error while running tauri application"); @@ -1410,6 +1751,62 @@ mod tests { fs::remove_dir_all(root).unwrap(); } + #[test] + fn runs_checkpoint_inference_from_project_storage() { + let root = temp_artifact_path("checkpoint-inference-project.ignitionrl"); + + let _ = fs::remove_dir_all(&root); + write_project_with_checkpoint(&root, Some(linear_checkpoint_payload())); + + let result = + run_checkpoint_inference_path(root.clone(), "run-a".to_string(), "final".to_string()) + .unwrap(); + let inference_run = result + .project + .runs + .iter() + .find(|run| run.id == result.run_id) + .unwrap(); + + assert!(result + .run_id + .starts_with("studio-checkpoint-inference-final")); + assert!(result.episode_id.ends_with(":episode:0")); + assert_eq!( + inference_run.run.get("status").and_then(Value::as_str), + Some("completed"), + ); + assert_eq!(inference_run.metrics.len(), 1); + assert_eq!(inference_run.traces.len(), 1); + assert_eq!( + inference_run + .traces + .first() + .and_then(|trace| trace.contents.get("summary")) + .and_then(|summary| summary.get("reason")) + .and_then(Value::as_str), + Some("checkpoint_inference"), + ); + + fs::remove_dir_all(root).unwrap(); + } + + #[test] + fn rejects_missing_checkpoint_payload_for_inference() { + let root = temp_artifact_path("missing-checkpoint-payload.ignitionrl"); + + let _ = fs::remove_dir_all(&root); + write_project_with_checkpoint(&root, None); + + let error = + run_checkpoint_inference_path(root.clone(), "run-a".to_string(), "final".to_string()) + .unwrap_err(); + + assert!(error.contains("checkpoint payload")); + + fs::remove_dir_all(root).unwrap(); + } + #[test] fn rejects_missing_ignition_project_manifest() { let root = temp_artifact_path("empty-project"); @@ -1440,6 +1837,81 @@ mod tests { .unwrap(); } + fn write_project_with_checkpoint(root: &Path, payload: Option<&str>) { + write_minimal_project(root); + fs::create_dir_all(root.join("runs/run-a/checkpoints")).unwrap(); + fs::write( + root.join("runs/run-a/run.json"), + r#"{ + "schemaVersion":1, + "id":"run-a", + "projectId":"demo-project", + "envId":"Target2D-v0", + "algorithm":"linear-policy-search", + "status":"completed", + "createdAt":"2026-05-28T00:00:00.000Z", + "updatedAt":"2026-05-28T00:01:00.000Z", + "summary":{"episodes":1,"totalSteps":1,"totalReward":1,"successRate":1} + }"#, + ) + .unwrap(); + fs::write( + root.join("runs/run-a/checkpoints/index.json"), + r#"[{ + "id":"final", + "runId":"run-a", + "envId":"Target2D-v0", + "algorithm":"linear-policy-search", + "path":"runs/run-a/checkpoints/final.json", + "createdAt":"2026-05-28T00:01:00.000Z" + }]"#, + ) + .unwrap(); + + if let Some(payload) = payload { + fs::write(root.join("runs/run-a/checkpoints/final.json"), payload).unwrap(); + } + } + + fn linear_checkpoint_payload() -> &'static str { + r#"{ + "version":1, + "algorithm":"linear-policy-search", + "envId":"Target2D-v0", + "observationShape":[4], + "actionShape":[2], + "low":[-1,-1], + "high":[1,1], + "config":{ + "populationSize":4, + "eliteCount":2, + "noiseStdDev":0.2, + "seed":7 + }, + "metrics":{ + "transitions":1, + "episodes":1, + "iterations":1, + "improvements":1, + "bestReward":1, + "lastEpisodeReward":1, + "lastMeanEliteReward":1, + "lastPopulationBestReward":1, + "lastPopulationWorstReward":1, + "lastPopulationMeanReward":1, + "lastPopulationRewardStdDev":0, + "lastImproved":true, + "meanWeightNorm":0, + "currentWeightNorm":0, + "bestWeightNorm":0 + }, + "meanWeights":[0,0,0,0,0,0,0,0,0,0], + "currentWeights":[0,0,0,0,0,0,0,0,0,0], + "bestWeights":[0,0,0,0,0,0,0,0,0,0], + "createdAt":"2026-05-28T00:01:00.000Z" + }"# + } + fn temp_artifact_path(name: &str) -> PathBuf { let suffix = SystemTime::now() .duration_since(UNIX_EPOCH) diff --git a/apps/studio/src/App.test.tsx b/apps/studio/src/App.test.tsx index 0eeb7e5..d2a1b0a 100644 --- a/apps/studio/src/App.test.tsx +++ b/apps/studio/src/App.test.tsx @@ -305,6 +305,111 @@ describe("Studio app shell", () => { act(() => renderer?.unmount()); }); + test("runs checkpoint inference from native project storage and opens the replay", async () => { + let renderer: ReactTestRenderer | undefined; + let inferenceCalls = 0; + const nativeApi: NativeStudioApi = { + isAvailable: () => true, + openIgnitionProjectDirectory: async () => nativeProjectDirectory(), + openStudioArtifactFile: async () => undefined, + startLocalTrainingRun: async () => { + throw new Error("Unexpected local training start."); + }, + cancelLocalTrainingRun: async () => { + throw new Error("Unexpected local training cancellation."); + }, + runCheckpointInference: async (options) => { + inferenceCalls += 1; + expect(options.projectPath).toBe("/tmp/project.ignitionrl"); + expect(options.runId).toBe("native-run"); + expect(options.checkpointId).toBe("final"); + + return { + project: nativeProjectWithCheckpointInference(), + runId: "studio-checkpoint-inference-final", + episodeId: "studio-checkpoint-inference-final:episode:0", + }; + }, + }; + + act(() => { + renderer = create(React.createElement(App, { + initialWorkspace: sampleWorkspace, + nativeApi, + })); + }); + + await act(async () => { + await renderer?.root.findByProps({ + "aria-label": "Open native project directory", + }).props.onClick(); + }); + + const inferenceButton = renderer?.root.findByProps({ + "aria-label": "Run checkpoint inference", + }); + + expect(inferenceButton?.props.disabled).toBe(false); + + await act(async () => { + await inferenceButton?.props.onClick(); + }); + + const text = JSON.stringify(renderer?.toJSON()); + + expect(inferenceCalls).toBe(1); + expect(text).toContain("studio-checkpoint-inference-final"); + expect(text).toContain("checkpoint_inference"); + expect(text).toContain("Frame Detail"); + expect(text).toContain("Reward Debugger"); + expect(text).not.toContain("No replay loaded."); + + act(() => renderer?.unmount()); + }); + + test("shows checkpoint inference errors from the native project boundary", async () => { + let renderer: ReactTestRenderer | undefined; + const nativeApi: NativeStudioApi = { + isAvailable: () => true, + openIgnitionProjectDirectory: async () => nativeProjectDirectory(), + openStudioArtifactFile: async () => undefined, + startLocalTrainingRun: async () => { + throw new Error("Unexpected local training start."); + }, + cancelLocalTrainingRun: async () => { + throw new Error("Unexpected local training cancellation."); + }, + runCheckpointInference: async () => { + throw new Error("Checkpoint payload is missing or incompatible."); + }, + }; + + act(() => { + renderer = create(React.createElement(App, { + initialWorkspace: sampleWorkspace, + nativeApi, + })); + }); + + await act(async () => { + await renderer?.root.findByProps({ + "aria-label": "Open native project directory", + }).props.onClick(); + }); + + await act(async () => { + await renderer?.root.findByProps({ + "aria-label": "Run checkpoint inference", + }).props.onClick(); + }); + + const text = JSON.stringify(renderer?.toJSON()); + + expect(text).toContain("Checkpoint payload is missing or incompatible."); + + act(() => renderer?.unmount()); + }); + test("shows native project directory validation errors", async () => { let renderer: ReactTestRenderer | undefined; const nativeApi: NativeStudioApi = { @@ -1054,6 +1159,105 @@ function nativeProjectWithLocalTraining(status: "running" | "cancelled"): Native }; } +function nativeProjectWithCheckpointInference(): NativeIgnitionProjectDirectory { + const base = nativeProjectDirectory(); + const inferenceRun = nativeCheckpointInferenceRunDirectory(); + + return { + ...base, + byteLength: base.byteLength + 2048, + runs: [ + ...base.runs, + inferenceRun, + ], + }; +} + +function nativeCheckpointInferenceRunDirectory(): NativeIgnitionProjectRunDirectory { + const runId = "studio-checkpoint-inference-final"; + const episodeId = `${runId}:episode:0`; + const createdAt = "2026-05-28T00:04:00.000Z"; + + return { + id: runId, + path: `runs/${runId}`, + run: { + schemaVersion: 1, + id: runId, + projectId: "native-project", + envId: "Target2D-v0", + algorithm: "random-checkpoint-inference", + status: "completed", + createdAt, + updatedAt: createdAt, + config: { + sourceRunId: "native-run", + checkpointId: "final", + checkpointPath: "runs/native-run/checkpoints/final.json", + source: "studio", + }, + summary: { + episodes: 1, + totalSteps: 1, + totalReward: 1, + successRate: 1, + bestReward: 1, + lastReward: 1, + }, + metadata: { + mode: "inference", + sourceRunId: "native-run", + checkpointId: "final", + }, + }, + metrics: [{ + t: 1, + episode: 0, + step: 1, + values: { + totalReward: 1, + success: 1, + episodeLength: 1, + checkpointInference: 1, + }, + createdAt, + }], + traces: [{ + path: `runs/${runId}/traces/${runId}_episode_0.json`, + episodeId, + contents: { + runId, + episodeId, + envId: "Target2D-v0", + seed: 0, + startedAt: createdAt, + steps: [{ + t: 0, + observation: [0, 0, 0, 0], + nextObservation: [1, 0, 0, 0], + action: [0, 0], + encodedAction: [0, 0], + reward: 1, + rewardTerms: { + checkpoint_inference: 1, + }, + done: true, + reason: "checkpoint_inference", + }], + summary: { + totalReward: 1, + length: 1, + success: true, + terminated: true, + truncated: false, + reason: "checkpoint_inference", + }, + }, + }], + checkpoints: [], + }; +} + function nativeTrainingRunDirectory(status: "running" | "cancelled"): NativeIgnitionProjectRunDirectory { const run = nativeTrainingRun(status); const cancelled = status === "cancelled"; diff --git a/apps/studio/src/App.tsx b/apps/studio/src/App.tsx index 299514d..2cf649a 100644 --- a/apps/studio/src/App.tsx +++ b/apps/studio/src/App.tsx @@ -23,6 +23,7 @@ import type { } from "@ignitionrl/sdk"; import { defaultNativeStudioApi, + type NativeCheckpointInferenceRunResult, type NativeLocalTrainingRunResult, type NativeStudioApi, type NativeTrainingMetricEvent, @@ -134,6 +135,9 @@ export function App({ && loadedProjectPath !== undefined && selectedTrainingSession?.canCancel === true && selectedTrainingSession.run?.id !== undefined; + const canRunCheckpointInference = isNativeRuntime + && loadedProjectPath !== undefined + && nativeApi.runCheckpointInference !== undefined; const metricCharts = useMemo( () => mergeLiveMetricCharts(workspace.metricCharts, liveMetricEvents), [workspace.metricCharts, liveMetricEvents], @@ -326,6 +330,23 @@ export function App({ setLoadedSource(result.project.directoryName); } + function applyNativeCheckpointInferenceResult( + result: NativeCheckpointInferenceRunResult, + checkpoint: StoredCheckpoint, + ) { + const nextWorkspace = studioWorkspaceFromNativeProjectDirectory(result.project, { + runId: result.runId, + episodeId: result.episodeId, + }); + + applyWorkspaceView(nextWorkspace); + setSelectedRunId(result.runId); + setSelectedEpisodeId(result.episodeId); + setSelectedCheckpointKey(checkpointKey(checkpoint)); + setLoadedProjectPath(result.project.path); + setLoadedSource(result.project.directoryName); + } + function ingestTrainingMetricEvent(event: NativeTrainingMetricEvent) { if ( event.projectPath !== undefined @@ -418,6 +439,30 @@ export function App({ } } + async function runCheckpointInference(checkpoint: StoredCheckpoint) { + if ( + loadedProjectPath === undefined + || nativeApi.runCheckpointInference === undefined + || !canRunCheckpointInference + ) { + return; + } + + try { + applyNativeCheckpointInferenceResult( + await nativeApi.runCheckpointInference({ + projectPath: loadedProjectPath, + runId: checkpoint.runId, + checkpointId: checkpoint.id, + }), + checkpoint, + ); + setLoadError(undefined); + } catch (error) { + setLoadError(error instanceof Error ? error.message : String(error)); + } + } + async function loadWorkspaceFile(event: ChangeEvent) { const file = event.target.files?.[0]; @@ -629,8 +674,10 @@ export function App({ selectedStep={selectedTimelineStep} /> void | Promise; readonly onSelectCheckpoint: (checkpointKey: string) => void; readonly selectedCheckpointKey?: string; readonly workspace: StudioWorkspaceView; @@ -1416,7 +1465,20 @@ function CheckpointPanel(props: {

Checkpoint Detail

{selectedCheckpoint.path}

- {selectedCheckpoint.id} +
+ {selectedCheckpoint.id} + +
diff --git a/apps/studio/src/native.ts b/apps/studio/src/native.ts index bb9567d..81b8338 100644 --- a/apps/studio/src/native.ts +++ b/apps/studio/src/native.ts @@ -48,6 +48,12 @@ export type NativeLocalTrainingRunResult = { readonly trainingSession: unknown; }; +export type NativeCheckpointInferenceRunResult = { + readonly project: NativeIgnitionProjectDirectory; + readonly runId: string; + readonly episodeId: string; +}; + export type NativeTrainingMetricEvent = { readonly projectPath?: string; readonly sessionId?: string; @@ -82,6 +88,11 @@ export type NativeStudioApi = { readonly projectPath: string; readonly runId: string; }) => Promise; + readonly runCheckpointInference?: (options: { + readonly projectPath: string; + readonly runId: string; + readonly checkpointId: string; + }) => Promise; readonly listenToTrainingMetricEvents?: ( listener: (event: NativeTrainingMetricEvent) => void, ) => Promise<() => void>; @@ -144,6 +155,19 @@ export const defaultNativeStudioApi: NativeStudioApi = { runId: options.runId, }); }, + async runCheckpointInference(options) { + if (!isTauriRuntime()) { + throw new Error("Native Studio runtime is required to run checkpoint inference."); + } + + const { invoke } = await import("@tauri-apps/api/core"); + + return invoke("run_checkpoint_inference", { + path: options.projectPath, + runId: options.runId, + checkpointId: options.checkpointId, + }); + }, async listenToTrainingMetricEvents(listener) { if (!isTauriRuntime()) { return () => {}; diff --git a/apps/studio/src/styles.css b/apps/studio/src/styles.css index c0d65dc..691934d 100644 --- a/apps/studio/src/styles.css +++ b/apps/studio/src/styles.css @@ -1022,6 +1022,34 @@ small { font-weight: 800; } +.checkpoint-detail-actions { + display: flex; + flex-wrap: wrap; + justify-content: flex-end; + gap: 8px; +} + +.checkpoint-detail-actions button { + min-height: 30px; + display: inline-flex; + align-items: center; + gap: 6px; + padding: 0 10px; + border: 1px solid #1f6f64; + border-radius: 6px; + background: #1f6f64; + color: #fff; + font-size: 0.78rem; + font-weight: 800; +} + +.checkpoint-detail-actions button:disabled { + cursor: not-allowed; + border-color: #d8d5cc; + background: #f0eee8; + color: #8c938f; +} + .checkpoint-stat-grid { display: grid; grid-template-columns: repeat(3, minmax(110px, 1fr));