diff --git a/Cargo.toml b/Cargo.toml index 6209854..1328ac5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,4 +21,4 @@ redis = { version = "0.27", features = ["script"] } openssl-sys = { version = "0.9", features = ["vendored"] } pyo3 = { version = "0.22", features = ["multiple-pymethods"] } async-trait = "0.1" -dagron-core = { git = "https://github.com/ByteVeda/dagron.git" } +dagron-core = { git = "https://github.com/ByteVeda/dagron.git", rev = "d1b61aaf2ed2d516b9a239f089a55b143cb05f65" } diff --git a/crates/taskito-core/src/storage/redis_backend/jobs.rs b/crates/taskito-core/src/storage/redis_backend/jobs.rs index 76f8816..a2834fb 100644 --- a/crates/taskito-core/src/storage/redis_backend/jobs.rs +++ b/crates/taskito-core/src/storage/redis_backend/jobs.rs @@ -722,7 +722,7 @@ impl RedisStorage { } // Sort by created_at desc - jobs.sort_by(|a, b| b.created_at.cmp(&a.created_at)); + jobs.sort_by_key(|j| std::cmp::Reverse(j.created_at)); // Apply pagination let start = (offset.max(0) as usize).min(jobs.len()); @@ -900,7 +900,7 @@ impl RedisStorage { } } - jobs.sort_by(|a, b| b.created_at.cmp(&a.created_at)); + jobs.sort_by_key(|j| std::cmp::Reverse(j.created_at)); let start = (offset.max(0) as usize).min(jobs.len()); let end = start.saturating_add(limit.max(0) as usize).min(jobs.len()); diff --git a/crates/taskito-core/src/storage/redis_backend/locks.rs b/crates/taskito-core/src/storage/redis_backend/locks.rs index 577b43f..57a1913 100644 --- a/crates/taskito-core/src/storage/redis_backend/locks.rs +++ b/crates/taskito-core/src/storage/redis_backend/locks.rs @@ -162,10 +162,11 @@ impl RedisStorage { let mut conn = self.conn()?; let now = now_millis(); let ckey = self.key(&["exec_claim", job_id]); + let index_key = self.key(&["exec_claims", "by_time"]); // NX: set only if not exists. PX: auto-expire after 24 hours so // orphaned claims from dead workers don't block re-execution forever. - let result: bool = redis::cmd("SET") + let acquired: bool = redis::cmd("SET") .arg(&ckey) .arg(format!("{worker_id}:{now}")) .arg("NX") @@ -174,22 +175,51 @@ impl RedisStorage { .query(&mut conn) .map_err(map_err)?; - Ok(result) + if acquired { + // Mirror the claim into a time-indexed sorted set so the + // scheduler's maintenance loop can purge stale claims with an + // O(log n) range query. + conn.zadd::<_, _, _, ()>(&index_key, job_id, now as f64) + .map_err(map_err)?; + } + + Ok(acquired) } pub fn complete_execution(&self, job_id: &str) -> Result<()> { let mut conn = self.conn()?; let ckey = self.key(&["exec_claim", job_id]); + let index_key = self.key(&["exec_claims", "by_time"]); - conn.del::<_, ()>(&ckey).map_err(map_err)?; + let pipe = &mut redis::pipe(); + pipe.del(&ckey); + pipe.zrem(&index_key, job_id); + pipe.query::<()>(&mut conn).map_err(map_err)?; Ok(()) } - pub fn purge_execution_claims(&self, _older_than_ms: i64) -> Result { - // Redis doesn't have efficient timestamp-based scanning for simple keys. - // For production use, execution claims should use TTL on the key itself. - // For now, this is a no-op — claims are cleaned up on complete_execution. - Ok(0) + pub fn purge_execution_claims(&self, older_than_ms: i64) -> Result { + let mut conn = self.conn()?; + let index_key = self.key(&["exec_claims", "by_time"]); + + // Find all claims with `claimed_at <= older_than_ms`. 
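        // The index maps job_id -> claimed_at (millis), so a single
        // ZRANGEBYSCORE over (-inf, cutoff] finds every stale claim; each hit
        // is then removed with a pipelined DEL of its claim key plus a ZREM
        // from the index, keeping the key space and the index consistent.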
+ let expired_ids: Vec = conn + .zrangebyscore(&index_key, "-inf", older_than_ms as f64) + .map_err(map_err)?; + + if expired_ids.is_empty() { + return Ok(0); + } + + let pipe = &mut redis::pipe(); + for id in &expired_ids { + let ckey = self.key(&["exec_claim", id]); + pipe.del(&ckey); + pipe.zrem(&index_key, id); + } + pipe.query::<()>(&mut conn).map_err(map_err)?; + + Ok(expired_ids.len() as u64) } } diff --git a/crates/taskito-core/src/storage/redis_backend/logs.rs b/crates/taskito-core/src/storage/redis_backend/logs.rs index 722bd9b..4446483 100644 --- a/crates/taskito-core/src/storage/redis_backend/logs.rs +++ b/crates/taskito-core/src/storage/redis_backend/logs.rs @@ -131,7 +131,7 @@ impl RedisStorage { } // Sort by logged_at desc - rows.sort_by(|a, b| b.logged_at.cmp(&a.logged_at)); + rows.sort_by_key(|r| std::cmp::Reverse(r.logged_at)); if limit >= 0 { rows.truncate(limit as usize); } diff --git a/crates/taskito-core/src/storage/redis_backend/metrics.rs b/crates/taskito-core/src/storage/redis_backend/metrics.rs index 78dc335..66e4591 100644 --- a/crates/taskito-core/src/storage/redis_backend/metrics.rs +++ b/crates/taskito-core/src/storage/redis_backend/metrics.rs @@ -121,7 +121,7 @@ impl RedisStorage { } // Sort by recorded_at desc - rows.sort_by(|a, b| b.recorded_at.cmp(&a.recorded_at)); + rows.sort_by_key(|r| std::cmp::Reverse(r.recorded_at)); Ok(rows) } @@ -210,7 +210,7 @@ impl RedisStorage { } // Sort by replayed_at desc - rows.sort_by(|a, b| b.replayed_at.cmp(&a.replayed_at)); + rows.sort_by_key(|r| std::cmp::Reverse(r.replayed_at)); Ok(rows) } } diff --git a/crates/taskito-core/src/storage/sqlite/dead_letter.rs b/crates/taskito-core/src/storage/sqlite/dead_letter.rs index 4189b4d..3cabf03 100644 --- a/crates/taskito-core/src/storage/sqlite/dead_letter.rs +++ b/crates/taskito-core/src/storage/sqlite/dead_letter.rs @@ -49,17 +49,15 @@ impl SqliteStorage { Ok::<(), diesel::result::Error>(()) })?; - // Drop connection before cascade (needed for single-connection pools) + // Drop connection before cascade (needed for single-connection pools). drop(conn); - // Cascade cancel dependents — log warning on failure since the DLQ - // transaction already committed and we can't roll it back. - if let Err(e) = self.cascade_cancel(&job_id, "dependency failed") { - log::warn!( - "[taskito] cascade_cancel failed for job {}: {}. Dependent jobs may be left pending.", - job_id, e - ); - } + // Cascade cancel dependents. Errors propagate so callers can decide how + // to react; parity with the Postgres and Redis backends. Note: the DLQ + // row has already been committed, so a failure here leaves a partial + // state (DLQ entry present, dependents possibly uncancelled) — callers + // should log and alert, not silently retry `move_to_dlq`. + self.cascade_cancel(&job_id, "dependency failed")?; Ok(()) } diff --git a/crates/taskito-core/src/storage/traits.rs b/crates/taskito-core/src/storage/traits.rs index ded6ffb..aed2535 100644 --- a/crates/taskito-core/src/storage/traits.rs +++ b/crates/taskito-core/src/storage/traits.rs @@ -5,8 +5,9 @@ use crate::storage::{DeadJob, QueueStats}; /// Trait abstracting the storage backend for the task queue. /// -/// Implementations include `SqliteStorage` and `PostgresStorage`. This trait -/// enables alternative backends and simplifies testing with mock storage. +/// Implementations: `SqliteStorage` (default), `PostgresStorage` (feature +/// `postgres`), and `RedisStorage` (feature `redis`). 
The trait enables +/// alternative backends and simplifies testing with mock storage. pub trait Storage: Send + Sync + Clone { // ── Job operations ────────────────────────────────────────────── diff --git a/crates/taskito-core/tests/rust/storage_tests.rs b/crates/taskito-core/tests/rust/storage_tests.rs index d8e9ce1..60796e4 100644 --- a/crates/taskito-core/tests/rust/storage_tests.rs +++ b/crates/taskito-core/tests/rust/storage_tests.rs @@ -211,6 +211,37 @@ fn test_pause_resume_queue(s: &impl Storage) { assert!(!paused.contains(&q.to_string())); } +fn test_execution_claims_purge(s: &impl Storage) { + // Regression: Redis `purge_execution_claims` was a silent no-op. The + // scheduler's maintenance loop relies on this method to reap stale claims, + // so all backends must honor the `older_than_ms` cutoff. + let worker = "w-purge"; + let old_job = "old-claim-job-id"; + let fresh_job = "fresh-claim-job-id"; + + assert!(s.claim_execution(old_job, worker).unwrap()); + // Advance past the old claim so the cutoff below can catch it but miss + // the fresh claim (claimed after the cutoff below is computed). + std::thread::sleep(std::time::Duration::from_millis(20)); + let cutoff = now_millis(); + std::thread::sleep(std::time::Duration::from_millis(20)); + assert!(s.claim_execution(fresh_job, worker).unwrap()); + + let purged = s.purge_execution_claims(cutoff).unwrap(); + assert!( + purged >= 1, + "purge must delete at least the one claim older than the cutoff" + ); + + // The old claim is gone — a fresh claim_execution for the same job succeeds. + assert!(s.claim_execution(old_job, worker).unwrap()); + // The fresh claim must still be held. + assert!(!s.claim_execution(fresh_job, worker).unwrap()); + + s.complete_execution(old_job).unwrap(); + s.complete_execution(fresh_job).unwrap(); +} + fn test_circuit_breakers(s: &impl Storage) { let task = "cb-test-task"; let cb = s.get_circuit_breaker(task).unwrap(); @@ -256,6 +287,7 @@ fn run_storage_tests(s: &impl Storage) { test_workers(s); test_pause_resume_queue(s); test_circuit_breakers(s); + test_execution_claims_purge(s); } // ── Backend-specific wiring ────────────────────────────────────────── diff --git a/crates/taskito-python/src/py_queue/mod.rs b/crates/taskito-python/src/py_queue/mod.rs index 22ed33f..4026705 100644 --- a/crates/taskito-python/src/py_queue/mod.rs +++ b/crates/taskito-python/src/py_queue/mod.rs @@ -39,6 +39,11 @@ pub struct PyQueue { pub(crate) scheduler_reap_interval: u32, pub(crate) scheduler_cleanup_interval: u32, pub(crate) namespace: Option, + /// Cached workflow storage handle. Lazily initialized on first workflow API + /// call; migrations run exactly once per `PyQueue` instance instead of + /// per-call. + #[cfg(feature = "workflows")] + pub(crate) workflow_storage: std::sync::OnceLock, } #[pymethods] @@ -132,6 +137,8 @@ impl PyQueue { scheduler_reap_interval, scheduler_cleanup_interval, namespace, + #[cfg(feature = "workflows")] + workflow_storage: std::sync::OnceLock::new(), }) } diff --git a/crates/taskito-python/src/py_queue/workflow_ops.rs b/crates/taskito-python/src/py_queue/workflow_ops.rs index 31e070a..74c3e65 100644 --- a/crates/taskito-python/src/py_queue/workflow_ops.rs +++ b/crates/taskito-python/src/py_queue/workflow_ops.rs @@ -4,11 +4,12 @@ //! workflow-specific methods to `PyQueue` via a separate `#[pymethods]` //! impl block (enabled by pyo3's `multiple-pymethods` feature). 
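// A minimal sketch of the pattern this module relies on: one #[pyclass] whose
// methods are split across several #[pymethods] blocks. Assumes pyo3 built with
// the `multiple-pymethods` feature; `Counter` is an illustrative type, not part
// of taskito.
use pyo3::prelude::*;

#[pyclass]
struct Counter {
    value: i64,
}

#[pymethods]
impl Counter {
    #[new]
    fn new() -> Self {
        Counter { value: 0 }
    }
}

// A second #[pymethods] block on the same type: the same mechanism this module
// uses to add workflow methods to `PyQueue` without touching py_queue/mod.rs.
#[pymethods]
impl Counter {
    fn bump(&mut self) -> i64 {
        self.value += 1;
        self.value
    }
}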
-use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::prelude::*; +use taskito_core::error::Result as CoreResult; use taskito_core::job::{now_millis, NewJob}; use taskito_core::storage::{Storage, StorageBackend}; use taskito_workflows::{ @@ -19,15 +20,24 @@ use taskito_workflows::{ use crate::py_queue::PyQueue; use crate::py_workflow::{PyWorkflowHandle, PyWorkflowRunStatus}; -/// Build a `WorkflowSqliteStorage` from a `PyQueue`'s backend. +/// Return the queue's cached workflow storage, initializing it on first use. /// -/// Currently only SQLite is supported for workflows. Migrations run on -/// construction; repeated calls are cheap because the migrations use -/// `CREATE TABLE IF NOT EXISTS`. +/// Migrations run on first construction only; subsequent calls are a cheap +/// `OnceLock::get()`. Callers receive a cloned handle (the underlying +/// `SqliteStorage` is a pool handle so clones share the same connection pool). fn workflow_storage(queue: &PyQueue) -> PyResult { + if let Some(wf) = queue.workflow_storage.get() { + return Ok(wf.clone()); + } match &queue.storage { - StorageBackend::Sqlite(s) => WorkflowSqliteStorage::new(s.clone()) - .map_err(|e| PyRuntimeError::new_err(e.to_string())), + StorageBackend::Sqlite(s) => { + let wf = WorkflowSqliteStorage::new(s.clone()) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + // If another thread raced us to initialize, our value is ignored — + // either handle is equivalent because the underlying pool is shared. + let _ = queue.workflow_storage.set(wf.clone()); + Ok(wf) + } #[cfg(feature = "postgres")] StorageBackend::Postgres(_) => Err(PyRuntimeError::new_err( "workflows are currently only supported on the SQLite backend", @@ -44,18 +54,65 @@ fn parse_step_metadata(json: &str) -> PyResult> { .map_err(|e| PyValueError::new_err(format!("invalid step_metadata JSON: {e}"))) } +/// Build a job-metadata JSON blob that carries workflow routing info. +/// +/// Uses `serde_json` to guarantee proper escaping of node names containing +/// backslashes, control characters, or Unicode — hand-rolled escaping previously +/// produced invalid JSON for such inputs. fn build_metadata_json(run_id: &str, node_name: &str) -> String { - format!( - r#"{{"workflow_run_id":"{}","workflow_node_name":"{}"}}"#, - run_id.replace('"', "\\\""), - node_name.replace('"', "\\\""), - ) + serde_json::json!({ + "workflow_run_id": run_id, + "workflow_node_name": node_name, + }) + .to_string() } fn status_to_py(status: WorkflowState) -> String { status.as_str().to_string() } +/// Mark every pending/ready node in a run as skipped and cancel its job. +/// +/// Best-effort: per-node failures are logged but do not abort the sweep. 
+fn cascade_skip_pending_nodes( + storage: &StorageBackend, + wf_storage: &WorkflowSqliteStorage, + run_id: &str, + nodes: &[WorkflowNode], +) -> CoreResult<()> { + for node in nodes { + if !matches!( + node.status, + WorkflowNodeStatus::Pending | WorkflowNodeStatus::Ready + ) { + continue; + } + if let Some(job_id) = &node.job_id { + if let Err(e) = storage.cancel_job(job_id) { + log::warn!( + "[taskito] cancel_job({}) failed during cascade skip for run {}: {}", + job_id, + run_id, + e + ); + } + } + if let Err(e) = wf_storage.update_workflow_node_status( + run_id, + &node.node_name, + WorkflowNodeStatus::Skipped, + ) { + log::warn!( + "[taskito] skip node '{}' failed for run {}: {}", + node.node_name, + run_id, + e + ); + } + } + Ok(()) +} + #[pymethods] impl PyQueue { /// Submit a workflow for execution. @@ -262,83 +319,95 @@ impl PyQueue { } /// Fetch a snapshot of a workflow run's state and per-node status. - pub fn get_workflow_run_status(&self, run_id: &str) -> PyResult { + pub fn get_workflow_run_status( + &self, + py: Python<'_>, + run_id: &str, + ) -> PyResult { let wf_storage = workflow_storage(self)?; - let run = wf_storage - .get_workflow_run(run_id) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))? - .ok_or_else(|| PyValueError::new_err(format!("workflow run '{run_id}' not found")))?; - - let nodes = wf_storage - .get_workflow_nodes(run_id) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - - let node_rows = nodes - .into_iter() - .map(|n| { - ( - n.node_name, - n.status.as_str().to_string(), - n.job_id, - n.error, - ) - }) - .collect(); + let run_id_owned = run_id.to_string(); - Ok(PyWorkflowRunStatus { - run_id: run.id, - state: status_to_py(run.state), - started_at: run.started_at, - completed_at: run.completed_at, - error: run.error, - nodes: node_rows, - }) + let result: CoreResult> = py.allow_threads(|| { + let run = match wf_storage.get_workflow_run(&run_id_owned)? { + Some(r) => r, + None => return Ok(None), + }; + let nodes = wf_storage.get_workflow_nodes(&run_id_owned)?; + let node_rows = nodes + .into_iter() + .map(|n| { + ( + n.node_name, + n.status.as_str().to_string(), + n.job_id, + n.error, + ) + }) + .collect(); + Ok(Some(PyWorkflowRunStatus { + run_id: run.id, + state: status_to_py(run.state), + started_at: run.started_at, + completed_at: run.completed_at, + error: run.error, + nodes: node_rows, + })) + }); + + result + .map_err(|e| PyRuntimeError::new_err(e.to_string()))? + .ok_or_else(|| PyValueError::new_err(format!("workflow run '{run_id}' not found"))) } - /// Cancel a workflow run. + /// Cancel a workflow run and all of its sub-workflow descendants. /// - /// Marks the run `Cancelled`, skips any pending nodes, and cancels - /// their underlying jobs. Nodes already running are left alone + /// Marks each visited run `Cancelled`, skips pending/ready nodes, and + /// cancels their underlying jobs. Traversal is iterative with a visited + /// set so that any accidental cycle in `parent_run_id` links terminates + /// safely instead of recursing. Nodes already running are left alone /// (consistent with taskito's existing cancel semantics). 
- pub fn cancel_workflow_run(&self, run_id: &str) -> PyResult<()> { + pub fn cancel_workflow_run(&self, py: Python<'_>, run_id: &str) -> PyResult<()> { let wf_storage = workflow_storage(self)?; - let nodes = wf_storage - .get_workflow_nodes(run_id) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + let root = run_id.to_string(); - for node in &nodes { - if matches!( - node.status, - WorkflowNodeStatus::Pending | WorkflowNodeStatus::Ready - ) { - if let Some(job_id) = &node.job_id { - let _ = self.storage.cancel_job(job_id); + let result: CoreResult<()> = py.allow_threads(|| { + let mut visited: HashSet = HashSet::new(); + let mut stack: Vec = vec![root]; + let now = now_millis(); + + while let Some(rid) = stack.pop() { + if !visited.insert(rid.clone()) { + continue; } - let _ = wf_storage.update_workflow_node_status( - run_id, - &node.node_name, - WorkflowNodeStatus::Skipped, - ); - } - } - wf_storage - .update_workflow_run_state(run_id, WorkflowState::Cancelled, None) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - wf_storage - .set_workflow_run_completed(run_id, now_millis()) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + let nodes = wf_storage.get_workflow_nodes(&rid)?; + cascade_skip_pending_nodes(&self.storage, &wf_storage, &rid, &nodes)?; - // Cascade cancellation to child workflow runs (sub-workflows). - if let Ok(children) = wf_storage.get_child_workflow_runs(run_id) { - for child in children { - if !child.state.is_terminal() { - let _ = self.cancel_workflow_run(&child.id); + wf_storage.update_workflow_run_state(&rid, WorkflowState::Cancelled, None)?; + wf_storage.set_workflow_run_completed(&rid, now)?; + + match wf_storage.get_child_workflow_runs(&rid) { + Ok(children) => { + for child in children { + if !child.state.is_terminal() && !visited.contains(&child.id) { + stack.push(child.id); + } + } + } + Err(e) => { + log::warn!( + "[taskito] get_child_workflow_runs({}) failed during cancel: {}", + rid, + e + ); + } } } - } - Ok(()) + Ok(()) + }); + + result.map_err(|e| PyRuntimeError::new_err(e.to_string())) } /// Record the terminal outcome of a workflow node's job. @@ -358,6 +427,7 @@ impl PyQueue { #[pyo3(signature = (job_id, succeeded, error=None, skip_cascade=false, result_hash=None))] pub fn mark_workflow_node_result( &self, + py: Python<'_>, job_id: &str, succeeded: bool, error: Option, @@ -365,86 +435,82 @@ impl PyQueue { result_hash: Option, ) -> PyResult)>> { let wf_storage = workflow_storage(self)?; - let job = self - .storage - .get_job(job_id) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))? - .ok_or_else(|| PyValueError::new_err(format!("job '{job_id}' not found")))?; + let job_id_owned = job_id.to_string(); + + enum Outcome { + NotFound, + NoWorkflowMetadata, + Settled { + run_id: String, + node_name: String, + final_state: Option, + }, + } - let metadata_json = match &job.metadata { - Some(m) => m, - None => return Ok(None), - }; - let parsed: serde_json::Value = match serde_json::from_str(metadata_json) { - Ok(v) => v, - Err(_) => return Ok(None), - }; - let run_id = match parsed.get("workflow_run_id").and_then(|v| v.as_str()) { - Some(id) => id.to_string(), - None => return Ok(None), - }; - let node_name = match parsed.get("workflow_node_name").and_then(|v| v.as_str()) { - Some(n) => n.to_string(), - None => return Ok(None), - }; + let outcome: CoreResult = py.allow_threads(|| { + let job = match self.storage.get_job(&job_id_owned)? 
{ + Some(j) => j, + None => return Ok(Outcome::NotFound), + }; - let now = now_millis(); - if succeeded { - wf_storage - .set_workflow_node_completed(&run_id, &node_name, now, result_hash.as_deref()) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - } else { - let err_msg = error.clone().unwrap_or_else(|| "failed".to_string()); - wf_storage - .set_workflow_node_error(&run_id, &node_name, &err_msg) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - } + let metadata_json = match &job.metadata { + Some(m) => m, + None => return Ok(Outcome::NoWorkflowMetadata), + }; + let parsed: serde_json::Value = match serde_json::from_str(metadata_json) { + Ok(v) => v, + Err(_) => return Ok(Outcome::NoWorkflowMetadata), + }; + let run_id = match parsed.get("workflow_run_id").and_then(|v| v.as_str()) { + Some(id) => id.to_string(), + None => return Ok(Outcome::NoWorkflowMetadata), + }; + let node_name = match parsed.get("workflow_node_name").and_then(|v| v.as_str()) { + Some(n) => n.to_string(), + None => return Ok(Outcome::NoWorkflowMetadata), + }; - // Fail-fast: cascade failure to pending/ready nodes. - // Skipped when the Python tracker manages cascade (conditions / continue mode). - if !succeeded && !skip_cascade { - let nodes = wf_storage - .get_workflow_nodes(&run_id) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - for n in &nodes { - if matches!( - n.status, - WorkflowNodeStatus::Pending | WorkflowNodeStatus::Ready - ) { - if let Some(j) = &n.job_id { - let _ = self.storage.cancel_job(j); - } - let _ = wf_storage.update_workflow_node_status( - &run_id, - &n.node_name, - WorkflowNodeStatus::Skipped, - ); - } + let now = now_millis(); + if succeeded { + wf_storage.set_workflow_node_completed( + &run_id, + &node_name, + now, + result_hash.as_deref(), + )?; + } else { + let err_msg = error.clone().unwrap_or_else(|| "failed".to_string()); + wf_storage.set_workflow_node_error(&run_id, &node_name, &err_msg)?; } - } - // Note: fan-out parent status is NOT updated here. The Python - // tracker calls `check_fan_out_completion` which atomically marks - // the parent and triggers fan-in. Doing it here would race. + // Fail-fast: cascade failure to pending/ready nodes. Skipped when the + // Python tracker manages cascade (conditions / continue mode). + if !succeeded && !skip_cascade { + let nodes = wf_storage.get_workflow_nodes(&run_id)?; + cascade_skip_pending_nodes(&self.storage, &wf_storage, &run_id, &nodes)?; + } - // Check if the entire run is terminal. - let nodes = wf_storage - .get_workflow_nodes(&run_id) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - let all_terminal = nodes.iter().all(|n| n.status.is_terminal()); - if !all_terminal { - return Ok(Some((run_id, node_name, None))); - } + // Note: fan-out parent status is NOT updated here. The tracker calls + // `check_fan_out_completion` which uses a CAS to finalize exactly once. 
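            // (The CAS lives in `finalize_fan_out_parent` below: a conditional
            // UPDATE that only fires while the parent is still non-terminal,
            // so all but one concurrent caller see zero rows affected and
            // back off.)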
- let any_failed = nodes.iter().any(|n| n.status == WorkflowNodeStatus::Failed); - let final_state = if any_failed || !succeeded { - WorkflowState::Failed - } else { - WorkflowState::Completed - }; + let nodes = wf_storage.get_workflow_nodes(&run_id)?; + let all_terminal = nodes.iter().all(|n| n.status.is_terminal()); + if !all_terminal { + return Ok(Outcome::Settled { + run_id, + node_name, + final_state: None, + }); + } - wf_storage - .update_workflow_run_state( + let any_failed = nodes.iter().any(|n| n.status == WorkflowNodeStatus::Failed); + let final_state = if any_failed || !succeeded { + WorkflowState::Failed + } else { + WorkflowState::Completed + }; + + wf_storage.update_workflow_run_state( &run_id, final_state, if final_state == WorkflowState::Failed { @@ -452,17 +518,25 @@ impl PyQueue { } else { None }, - ) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - wf_storage - .set_workflow_run_completed(&run_id, now) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + )?; + wf_storage.set_workflow_run_completed(&run_id, now)?; - Ok(Some(( - run_id, - node_name, - Some(final_state.as_str().to_string()), - ))) + Ok(Outcome::Settled { + run_id, + node_name, + final_state: Some(final_state.as_str().to_string()), + }) + }); + + match outcome.map_err(|e| PyRuntimeError::new_err(e.to_string()))? { + Outcome::NotFound => Err(PyValueError::new_err(format!("job '{job_id}' not found"))), + Outcome::NoWorkflowMetadata => Ok(None), + Outcome::Settled { + run_id, + node_name, + final_state, + } => Ok(Some((run_id, node_name, final_state))), + } } // ── Fan-out / Fan-in helpers ──────────────────────────────── @@ -480,6 +554,7 @@ impl PyQueue { #[allow(clippy::too_many_arguments)] pub fn expand_fan_out( &self, + py: Python<'_>, run_id: &str, parent_node_name: &str, child_names: Vec, @@ -497,68 +572,68 @@ impl PyQueue { } let wf_storage = workflow_storage(self)?; - let now = now_millis(); - let count = child_names.len() as i32; - - // Empty fan-out: mark parent completed immediately. - if count == 0 { - wf_storage - .set_workflow_node_fan_out_count(run_id, parent_node_name, 0) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - wf_storage - .set_workflow_node_completed(run_id, parent_node_name, now, None) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - return Ok(vec![]); - } - - let mut child_job_ids = Vec::with_capacity(child_names.len()); - - for (child_name, payload) in child_names.iter().zip(child_payloads.into_iter()) { - let new_job = NewJob { - queue: queue.to_string(), - task_name: task_name.to_string(), - payload, - priority, - scheduled_at: now, - max_retries, - timeout_ms, - unique_key: None, - metadata: Some(build_metadata_json(run_id, child_name)), - depends_on: vec![], - expires_at: None, - result_ttl_ms: self.result_ttl_ms, - namespace: self.namespace.clone(), - }; + let run_id_owned = run_id.to_string(); + let parent_name_owned = parent_node_name.to_string(); + let task_name_owned = task_name.to_string(); + let queue_owned = queue.to_string(); + + let result: CoreResult> = py.allow_threads(|| { + let now = now_millis(); + let count = child_names.len() as i32; + + // Empty fan-out: mark parent completed immediately. 
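            // With zero children nothing will ever trigger
            // check_fan_out_completion (and even if called, it bails on an
            // empty child list), so the parent has to be finalized right here
            // or the run would never reach a terminal state.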
+ if count == 0 { + wf_storage.set_workflow_node_fan_out_count(&run_id_owned, &parent_name_owned, 0)?; + wf_storage.set_workflow_node_completed( + &run_id_owned, + &parent_name_owned, + now, + None, + )?; + return Ok(Vec::new()); + } - let job = self - .storage - .enqueue(new_job) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - child_job_ids.push(job.id.clone()); + let mut child_job_ids = Vec::with_capacity(child_names.len()); + for (child_name, payload) in child_names.iter().zip(child_payloads.into_iter()) { + let new_job = NewJob { + queue: queue_owned.clone(), + task_name: task_name_owned.clone(), + payload, + priority, + scheduled_at: now, + max_retries, + timeout_ms, + unique_key: None, + metadata: Some(build_metadata_json(&run_id_owned, child_name)), + depends_on: vec![], + expires_at: None, + result_ttl_ms: self.result_ttl_ms, + namespace: self.namespace.clone(), + }; + let job = self.storage.enqueue(new_job)?; + child_job_ids.push(job.id.clone()); - let wf_node = WorkflowNode { - id: uuid::Uuid::now_v7().to_string(), - run_id: run_id.to_string(), - node_name: child_name.clone(), - job_id: Some(job.id), - status: WorkflowNodeStatus::Pending, - result_hash: None, - fan_out_count: None, - fan_in_data: None, - started_at: None, - completed_at: None, - error: None, - }; - wf_storage - .create_workflow_node(&wf_node) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - } + let wf_node = WorkflowNode { + id: uuid::Uuid::now_v7().to_string(), + run_id: run_id_owned.clone(), + node_name: child_name.clone(), + job_id: Some(job.id), + status: WorkflowNodeStatus::Pending, + result_hash: None, + fan_out_count: None, + fan_in_data: None, + started_at: None, + completed_at: None, + error: None, + }; + wf_storage.create_workflow_node(&wf_node)?; + } - wf_storage - .set_workflow_node_fan_out_count(run_id, parent_node_name, count) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + wf_storage.set_workflow_node_fan_out_count(&run_id_owned, &parent_name_owned, count)?; + Ok(child_job_ids) + }); - Ok(child_job_ids) + result.map_err(|e| PyRuntimeError::new_err(e.to_string())) } /// Create a job for a deferred workflow node. 
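// The conversions above and below all follow the same GIL-release recipe; a
// minimal sketch of it in isolation (assumes pyo3 0.22; `load_row` and its
// string error type are illustrative stand-ins, not taskito APIs):
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;

// Stands in for a blocking SQLite round-trip.
fn load_row(key: &str) -> Result<String, String> {
    Ok(format!("row for {key}"))
}

#[pyfunction]
fn fetch(py: Python<'_>, key: &str) -> PyResult<String> {
    // Clone borrowed arguments into owned values first: the closure handed to
    // `allow_threads` must not capture GIL-bound references.
    let key = key.to_string();
    py.allow_threads(|| load_row(&key)).map_err(|e| PyRuntimeError::new_err(e))
}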
@@ -569,6 +644,7 @@ impl PyQueue { #[allow(clippy::too_many_arguments)] pub fn create_deferred_job( &self, + py: Python<'_>, run_id: &str, node_name: &str, payload: Vec, @@ -579,88 +655,86 @@ impl PyQueue { priority: i32, ) -> PyResult { let wf_storage = workflow_storage(self)?; - let now = now_millis(); - - let new_job = NewJob { - queue: queue.to_string(), - task_name: task_name.to_string(), - payload, - priority, - scheduled_at: now, - max_retries, - timeout_ms, - unique_key: None, - metadata: Some(build_metadata_json(run_id, node_name)), - depends_on: vec![], - expires_at: None, - result_ttl_ms: self.result_ttl_ms, - namespace: self.namespace.clone(), - }; + let run_id_owned = run_id.to_string(); + let node_name_owned = node_name.to_string(); + let task_name_owned = task_name.to_string(); + let queue_owned = queue.to_string(); - let job = self - .storage - .enqueue(new_job) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - - wf_storage - .set_workflow_node_job(run_id, node_name, &job.id) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + let result: CoreResult = py.allow_threads(|| { + let now = now_millis(); + let new_job = NewJob { + queue: queue_owned, + task_name: task_name_owned, + payload, + priority, + scheduled_at: now, + max_retries, + timeout_ms, + unique_key: None, + metadata: Some(build_metadata_json(&run_id_owned, &node_name_owned)), + depends_on: vec![], + expires_at: None, + result_ttl_ms: self.result_ttl_ms, + namespace: self.namespace.clone(), + }; + let job = self.storage.enqueue(new_job)?; + wf_storage.set_workflow_node_job(&run_id_owned, &node_name_owned, &job.id)?; + Ok(job.id) + }); - Ok(job.id) + result.map_err(|e| PyRuntimeError::new_err(e.to_string())) } /// Check whether all fan-out children of a parent node are terminal. /// - /// If all children are terminal, atomically marks the parent node as - /// `Completed` (all succeeded) or `Failed` (any failed) and returns - /// `Some((all_succeeded, child_job_ids))`. Returns `None` if not all - /// children are done yet or if the parent was already finalized by a - /// concurrent call. + /// When all children are terminal, performs an atomic compare-and-swap on + /// the parent node's status to finalize it exactly once, even across + /// concurrent callers. Returns `Some((all_succeeded, child_job_ids))` if + /// this caller performed the transition, `None` otherwise (either not all + /// children are done yet, or another concurrent caller already finalized). pub fn check_fan_out_completion( &self, + py: Python<'_>, run_id: &str, parent_node_name: &str, ) -> PyResult)>> { let wf_storage = workflow_storage(self)?; + let run_id_owned = run_id.to_string(); + let parent_name_owned = parent_node_name.to_string(); - // Guard: if the parent is already terminal, another call beat us. - let parent = wf_storage - .get_workflow_node(run_id, parent_node_name) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))? 
- .ok_or_else(|| { - PyValueError::new_err(format!( - "workflow node '{parent_node_name}' not found in run '{run_id}'" - )) - })?; - if parent.status.is_terminal() { - return Ok(None); - } - - let children = wf_storage - .get_workflow_nodes_by_prefix(run_id, &format!("{parent_node_name}[")) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + let result: CoreResult)>> = py.allow_threads(|| { + let prefix = format!("{parent_name_owned}["); + let children = wf_storage.get_workflow_nodes_by_prefix(&run_id_owned, &prefix)?; - if !children.iter().all(|n| n.status.is_terminal()) { - return Ok(None); - } + if children.is_empty() || !children.iter().all(|n| n.status.is_terminal()) { + return Ok(None); + } - let any_failed = children - .iter() - .any(|n| n.status == WorkflowNodeStatus::Failed); - let child_job_ids: Vec = children.iter().filter_map(|n| n.job_id.clone()).collect(); + let any_failed = children + .iter() + .any(|n| n.status == WorkflowNodeStatus::Failed); + let child_job_ids: Vec = + children.iter().filter_map(|n| n.job_id.clone()).collect(); + + let transitioned = wf_storage.finalize_fan_out_parent( + &run_id_owned, + &parent_name_owned, + !any_failed, + if any_failed { + Some("fan-out child failed") + } else { + None + }, + now_millis(), + )?; + if !transitioned { + return Ok(None); + } - let now = now_millis(); - if any_failed { - wf_storage - .set_workflow_node_error(run_id, parent_node_name, "fan-out child failed") - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - } else { - wf_storage - .set_workflow_node_completed(run_id, parent_node_name, now, None) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - } + Ok(Some((!any_failed, child_job_ids))) + }); - Ok(Some((!any_failed, child_job_ids))) + result.map_err(|e| PyRuntimeError::new_err(e.to_string())) } /// Check whether all workflow nodes are terminal and finalize the run. @@ -669,57 +743,87 @@ impl PyQueue { /// (e.g., after a failed fan-out). If all nodes are terminal, transitions /// the run to `Completed` or `Failed` and returns the final state string. /// Returns `None` if not all nodes are terminal yet. 
- pub fn finalize_run_if_terminal(&self, run_id: &str) -> PyResult> { + pub fn finalize_run_if_terminal( + &self, + py: Python<'_>, + run_id: &str, + ) -> PyResult> { let wf_storage = workflow_storage(self)?; - let nodes = wf_storage - .get_workflow_nodes(run_id) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + let run_id_owned = run_id.to_string(); - if !nodes.iter().all(|n| n.status.is_terminal()) { - return Ok(None); - } + let result: CoreResult> = py.allow_threads(|| { + let nodes = wf_storage.get_workflow_nodes(&run_id_owned)?; + if !nodes.iter().all(|n| n.status.is_terminal()) { + return Ok(None); + } - let any_failed = nodes.iter().any(|n| n.status == WorkflowNodeStatus::Failed); - let final_state = if any_failed { - WorkflowState::Failed - } else { - WorkflowState::Completed - }; + let any_failed = nodes.iter().any(|n| n.status == WorkflowNodeStatus::Failed); + let final_state = if any_failed { + WorkflowState::Failed + } else { + WorkflowState::Completed + }; - let now = now_millis(); - wf_storage - .update_workflow_run_state( - run_id, + let now = now_millis(); + wf_storage.update_workflow_run_state( + &run_id_owned, final_state, if final_state == WorkflowState::Failed { Some("fan-out child failed") } else { None }, - ) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - wf_storage - .set_workflow_run_completed(run_id, now) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + )?; + wf_storage.set_workflow_run_completed(&run_id_owned, now)?; + Ok(Some(final_state.as_str().to_string())) + }); - Ok(Some(final_state.as_str().to_string())) + result.map_err(|e| PyRuntimeError::new_err(e.to_string())) } /// Transition a workflow node to `WaitingApproval` status. /// /// Used by the Python tracker when a gate node becomes evaluable. - /// Sets `started_at` without overriding the status (unlike - /// `set_workflow_node_started` which forces `running`). pub fn set_workflow_node_waiting_approval( &self, + py: Python<'_>, run_id: &str, node_name: &str, ) -> PyResult<()> { let wf_storage = workflow_storage(self)?; - wf_storage - .update_workflow_node_status(run_id, node_name, WorkflowNodeStatus::WaitingApproval) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - Ok(()) + let run_id_owned = run_id.to_string(); + let node_name_owned = node_name.to_string(); + + py.allow_threads(|| { + wf_storage.update_workflow_node_status( + &run_id_owned, + &node_name_owned, + WorkflowNodeStatus::WaitingApproval, + ) + }) + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + } + + /// Transition a workflow node to `Running` with a `started_at` timestamp. + /// + /// Used by the Python tracker to promote sub-workflow parent nodes after + /// the child workflow has been successfully compiled and submitted. This + /// is the clean counterpart to the old "waiting-approval → skip → running" + /// dance that could leave nodes permanently skipped on compile failure. + pub fn set_workflow_node_running( + &self, + py: Python<'_>, + run_id: &str, + node_name: &str, + ) -> PyResult<()> { + let wf_storage = workflow_storage(self)?; + let run_id_owned = run_id.to_string(); + let node_name_owned = node_name.to_string(); + + py.allow_threads(|| { + wf_storage.set_workflow_node_running(&run_id_owned, &node_name_owned, now_millis()) + }) + .map_err(|e| PyRuntimeError::new_err(e.to_string())) } /// Fetch node data from a prior run for incremental caching. @@ -727,50 +831,72 @@ impl PyQueue { /// Returns a list of ``(node_name, status, result_hash)`` tuples. 
pub fn get_base_run_node_data( &self, + py: Python<'_>, base_run_id: &str, ) -> PyResult)>> { let wf_storage = workflow_storage(self)?; - let nodes = wf_storage - .get_workflow_nodes(base_run_id) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - Ok(nodes - .into_iter() - .map(|n| (n.node_name, n.status.as_str().to_string(), n.result_hash)) - .collect()) + let base_run_id_owned = base_run_id.to_string(); + + let result: CoreResult)>> = py.allow_threads(|| { + let nodes = wf_storage.get_workflow_nodes(&base_run_id_owned)?; + Ok(nodes + .into_iter() + .map(|n| (n.node_name, n.status.as_str().to_string(), n.result_hash)) + .collect()) + }); + + result.map_err(|e| PyRuntimeError::new_err(e.to_string())) } /// Return the DAG JSON bytes for a workflow run's definition. /// /// Used by the Python visualization layer to render diagrams. - pub fn get_workflow_definition_dag(&self, run_id: &str) -> PyResult> { + pub fn get_workflow_definition_dag(&self, py: Python<'_>, run_id: &str) -> PyResult> { let wf_storage = workflow_storage(self)?; - let run = wf_storage - .get_workflow_run(run_id) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))? - .ok_or_else(|| PyValueError::new_err(format!("run '{run_id}' not found")))?; - let def = wf_storage - .get_workflow_definition_by_id(&run.definition_id) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))? - .ok_or_else(|| { - PyRuntimeError::new_err(format!("definition '{}' not found", run.definition_id)) - })?; - Ok(def.dag_data) + let run_id_owned = run_id.to_string(); + + enum Outcome { + RunMissing, + DefinitionMissing(String), + Found(Vec), + } + + let outcome: CoreResult = py.allow_threads(|| { + let run = match wf_storage.get_workflow_run(&run_id_owned)? { + Some(r) => r, + None => return Ok(Outcome::RunMissing), + }; + match wf_storage.get_workflow_definition_by_id(&run.definition_id)? { + Some(def) => Ok(Outcome::Found(def.dag_data)), + None => Ok(Outcome::DefinitionMissing(run.definition_id)), + } + }); + + match outcome.map_err(|e| PyRuntimeError::new_err(e.to_string()))? { + Outcome::Found(data) => Ok(data), + Outcome::RunMissing => Err(PyValueError::new_err(format!("run '{run_id}' not found"))), + Outcome::DefinitionMissing(def_id) => Err(PyRuntimeError::new_err(format!( + "definition '{def_id}' not found" + ))), + } } /// Set a node's fan_out_count and transition to Running. - /// - /// Also used by the tracker to mark sub-workflow parent nodes as Running. pub fn set_workflow_node_fan_out_count( &self, + py: Python<'_>, run_id: &str, node_name: &str, count: i32, ) -> PyResult<()> { let wf_storage = workflow_storage(self)?; - wf_storage - .set_workflow_node_fan_out_count(run_id, node_name, count) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - Ok(()) + let run_id_owned = run_id.to_string(); + let node_name_owned = node_name.to_string(); + + py.allow_threads(|| { + wf_storage.set_workflow_node_fan_out_count(&run_id_owned, &node_name_owned, count) + }) + .map_err(|e| PyRuntimeError::new_err(e.to_string())) } /// Approve or reject an approval gate node. 
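// Several methods above route their `allow_threads` result through a small
// outcome enum so that "not found" and genuine storage failures map to
// different Python exceptions only after the closure returns. A minimal sketch
// of that shape (`lookup` and its string error type are illustrative, not
// taskito APIs):
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;

enum Lookup {
    Missing,
    Found(Vec<u8>),
}

// Stands in for the storage call made inside `allow_threads`.
fn lookup(id: &str) -> Result<Lookup, String> {
    if id.is_empty() {
        return Err("backend unavailable".to_string()); // a real storage error
    }
    Ok(Lookup::Missing) // a clean "no such row" answer
}

fn into_py(res: Result<Lookup, String>, id: &str) -> PyResult<Vec<u8>> {
    match res.map_err(|e| PyRuntimeError::new_err(e))? {
        Lookup::Found(bytes) => Ok(bytes),
        Lookup::Missing => Err(PyValueError::new_err(format!("'{id}' not found"))),
    }
}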
@@ -779,42 +905,96 @@ impl PyQueue { #[pyo3(signature = (run_id, node_name, approved, error=None))] pub fn resolve_workflow_gate( &self, + py: Python<'_>, run_id: &str, node_name: &str, approved: bool, error: Option, ) -> PyResult<()> { let wf_storage = workflow_storage(self)?; - let now = now_millis(); - if approved { - wf_storage - .set_workflow_node_completed(run_id, node_name, now, None) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - } else { - let err_msg = error.unwrap_or_else(|| "rejected".to_string()); - wf_storage - .set_workflow_node_error(run_id, node_name, &err_msg) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - } - Ok(()) + let run_id_owned = run_id.to_string(); + let node_name_owned = node_name.to_string(); + + let result: CoreResult<()> = py.allow_threads(|| { + let now = now_millis(); + if approved { + wf_storage.set_workflow_node_completed( + &run_id_owned, + &node_name_owned, + now, + None, + )?; + } else { + let err_msg = error.unwrap_or_else(|| "rejected".to_string()); + wf_storage.set_workflow_node_error(&run_id_owned, &node_name_owned, &err_msg)?; + } + Ok(()) + }); + + result.map_err(|e| PyRuntimeError::new_err(e.to_string())) + } + + /// Mark a workflow node as `Failed` with an error message. + /// + /// Used by the Python tracker when sub-workflow compilation or submission + /// fails — the parent node needs a terminal state so the outer run can + /// finalize instead of hanging. + pub fn fail_workflow_node( + &self, + py: Python<'_>, + run_id: &str, + node_name: &str, + error: &str, + ) -> PyResult<()> { + let wf_storage = workflow_storage(self)?; + let run_id_owned = run_id.to_string(); + let node_name_owned = node_name.to_string(); + let error_owned = error.to_string(); + + py.allow_threads(|| { + wf_storage.set_workflow_node_error(&run_id_owned, &node_name_owned, &error_owned) + }) + .map_err(|e| PyRuntimeError::new_err(e.to_string())) } /// Mark a single workflow node as `Skipped` and cancel its job. /// /// Used by the Python tracker for condition-based skip propagation. - pub fn skip_workflow_node(&self, run_id: &str, node_name: &str) -> PyResult<()> { + /// Cancel-job failures are logged but do not abort the skip — the node's + /// terminal status is more important than best-effort job cancellation. 
+ pub fn skip_workflow_node( + &self, + py: Python<'_>, + run_id: &str, + node_name: &str, + ) -> PyResult<()> { let wf_storage = workflow_storage(self)?; - let node = wf_storage - .get_workflow_node(run_id, node_name) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - if let Some(node) = node { - if let Some(job_id) = &node.job_id { - let _ = self.storage.cancel_job(job_id); + let run_id_owned = run_id.to_string(); + let node_name_owned = node_name.to_string(); + + let result: CoreResult<()> = py.allow_threads(|| { + let node = wf_storage.get_workflow_node(&run_id_owned, &node_name_owned)?; + if let Some(node) = node { + if let Some(job_id) = &node.job_id { + if let Err(e) = self.storage.cancel_job(job_id) { + log::warn!( + "[taskito] cancel_job({}) failed while skipping node '{}' in run {}: {}", + job_id, + node_name_owned, + run_id_owned, + e + ); + } + } + wf_storage.update_workflow_node_status( + &run_id_owned, + &node_name_owned, + WorkflowNodeStatus::Skipped, + )?; } - wf_storage - .update_workflow_node_status(run_id, node_name, WorkflowNodeStatus::Skipped) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - } - Ok(()) + Ok(()) + }); + + result.map_err(|e| PyRuntimeError::new_err(e.to_string())) } } diff --git a/crates/taskito-workflows/src/sqlite_store.rs b/crates/taskito-workflows/src/sqlite_store.rs index 3a38b03..b47ff1d 100644 --- a/crates/taskito-workflows/src/sqlite_store.rs +++ b/crates/taskito-workflows/src/sqlite_store.rs @@ -600,6 +600,61 @@ impl WorkflowStorage for WorkflowSqliteStorage { Ok(()) } + fn set_workflow_node_running( + &self, + run_id: &str, + node_name: &str, + started_at: i64, + ) -> Result<()> { + let mut conn = self.inner.conn()?; + diesel::sql_query( + "UPDATE workflow_nodes SET status = 'running', started_at = ? + WHERE run_id = ? AND node_name = ?", + ) + .bind::(started_at) + .bind::(run_id) + .bind::(node_name) + .execute(&mut conn)?; + Ok(()) + } + + fn finalize_fan_out_parent( + &self, + run_id: &str, + node_name: &str, + succeeded: bool, + error: Option<&str>, + completed_at: i64, + ) -> Result { + let mut conn = self.inner.conn()?; + // Compare-and-swap: only update if not already terminal. Exactly one + // concurrent caller will affect >0 rows; losers see 0 rows affected. + let affected = if succeeded { + diesel::sql_query( + "UPDATE workflow_nodes + SET status = 'completed', completed_at = ? + WHERE run_id = ? AND node_name = ? + AND status NOT IN ('completed', 'failed', 'skipped', 'cache_hit')", + ) + .bind::(completed_at) + .bind::(run_id) + .bind::(node_name) + .execute(&mut conn)? + } else { + diesel::sql_query( + "UPDATE workflow_nodes + SET status = 'failed', error = ? + WHERE run_id = ? AND node_name = ? + AND status NOT IN ('completed', 'failed', 'skipped', 'cache_hit')", + ) + .bind::(error.unwrap_or("fan-out child failed")) + .bind::(run_id) + .bind::(node_name) + .execute(&mut conn)? + }; + Ok(affected > 0) + } + fn get_workflow_nodes_by_prefix( &self, run_id: &str, diff --git a/crates/taskito-workflows/src/storage.rs b/crates/taskito-workflows/src/storage.rs index c88a649..fa7e188 100644 --- a/crates/taskito-workflows/src/storage.rs +++ b/crates/taskito-workflows/src/storage.rs @@ -82,6 +82,36 @@ pub trait WorkflowStorage: Send + Sync { count: i32, ) -> Result<()>; + /// Transition a node to `Running` and set its `started_at` timestamp. + /// + /// Used by the tracker to promote gate/sub-workflow nodes to running + /// without going through fan-out bookkeeping. 
+ fn set_workflow_node_running( + &self, + run_id: &str, + node_name: &str, + started_at: i64, + ) -> Result<()>; + + /// Atomically finalize a fan-out parent node if it is not already terminal. + /// + /// Issues a single conditional `UPDATE` that transitions the node to + /// `Completed` (with `completed_at`) when `succeeded` is true, or `Failed` + /// (with `error`) when false — but only when the current status is + /// non-terminal. Returns `true` if this call performed the transition, + /// `false` if another concurrent caller already finalized it. This is the + /// compare-and-swap that makes fan-in expansion exactly-once even when + /// multiple children complete on different worker threads at the same + /// instant. + fn finalize_fan_out_parent( + &self, + run_id: &str, + node_name: &str, + succeeded: bool, + error: Option<&str>, + completed_at: i64, + ) -> Result; + /// Return all nodes whose `node_name` starts with `prefix`. /// /// Used to find fan-out children (e.g., prefix `"process["` returns diff --git a/crates/taskito-workflows/src/tests.rs b/crates/taskito-workflows/src/tests.rs index 677954d..7eb3690 100644 --- a/crates/taskito-workflows/src/tests.rs +++ b/crates/taskito-workflows/src/tests.rs @@ -701,3 +701,140 @@ fn test_get_nodes_by_prefix() { .unwrap(); assert!(empty.is_empty()); } + +// ── Regression tests for correctness fixes (2026-04-24) ────────── + +/// Explicit `set_workflow_node_running` transitions without overloading +/// fan_out_count — used by the sub-workflow tracker after the child +/// successfully compiles and submits. +#[test] +fn test_set_workflow_node_running() { + let storage = test_storage(); + let def = make_definition("sub_wf_parent"); + storage.create_workflow_definition(&def).unwrap(); + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + storage + .create_workflow_node(&make_node(&run_id, "parent")) + .unwrap(); + + let before = now_millis(); + storage + .set_workflow_node_running(&run_id, "parent", before) + .unwrap(); + + let fetched = storage + .get_workflow_node(&run_id, "parent") + .unwrap() + .unwrap(); + assert_eq!(fetched.status, WorkflowNodeStatus::Running); + assert_eq!(fetched.started_at, Some(before)); + assert_eq!( + fetched.fan_out_count, None, + "set_workflow_node_running must not set fan_out_count" + ); +} + +/// The CAS-based finalize must transition exactly once even when called +/// twice in a row (simulating concurrent callers that both see all children +/// terminal). This is the storage-layer guarantee behind the P0 fan-in race fix. +#[test] +fn test_finalize_fan_out_parent_is_idempotent() { + let storage = test_storage(); + let def = make_definition("fan_in_race"); + storage.create_workflow_definition(&def).unwrap(); + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + storage + .create_workflow_node(&make_node(&run_id, "process")) + .unwrap(); + + // First call: not terminal → transition runs. + let now = now_millis(); + let first = storage + .finalize_fan_out_parent(&run_id, "process", true, None, now) + .unwrap(); + assert!(first, "first caller must perform the transition"); + + let after_first = storage + .get_workflow_node(&run_id, "process") + .unwrap() + .unwrap(); + assert_eq!(after_first.status, WorkflowNodeStatus::Completed); + + // Second call: already terminal → CAS returns false, state unchanged. 
+ let second = storage + .finalize_fan_out_parent(&run_id, "process", true, None, now + 1000) + .unwrap(); + assert!(!second, "second caller must be rejected by the CAS"); + + let after_second = storage + .get_workflow_node(&run_id, "process") + .unwrap() + .unwrap(); + assert_eq!(after_second.status, WorkflowNodeStatus::Completed); + assert_eq!( + after_second.completed_at, after_first.completed_at, + "second caller must not overwrite completed_at" + ); +} + +/// Rejected-path of the CAS: any child failure routes through the +/// failure branch and sets the node error. +#[test] +fn test_finalize_fan_out_parent_failure_branch() { + let storage = test_storage(); + let def = make_definition("fan_in_fail"); + storage.create_workflow_definition(&def).unwrap(); + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + storage + .create_workflow_node(&make_node(&run_id, "process")) + .unwrap(); + + let now = now_millis(); + let transitioned = storage + .finalize_fan_out_parent(&run_id, "process", false, Some("boom"), now) + .unwrap(); + assert!(transitioned); + + let node = storage + .get_workflow_node(&run_id, "process") + .unwrap() + .unwrap(); + assert_eq!(node.status, WorkflowNodeStatus::Failed); + assert_eq!(node.error.as_deref(), Some("boom")); +} + +/// Nodes that are already in a terminal state must not be re-transitioned +/// by the CAS — this guards against a late-arriving event from a +/// cascade-skipped child. +#[test] +fn test_finalize_fan_out_parent_no_op_on_terminal_node() { + let storage = test_storage(); + let def = make_definition("skipped_parent"); + storage.create_workflow_definition(&def).unwrap(); + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + storage + .create_workflow_node(&make_node(&run_id, "process")) + .unwrap(); + storage + .update_workflow_node_status(&run_id, "process", WorkflowNodeStatus::Skipped) + .unwrap(); + + let transitioned = storage + .finalize_fan_out_parent(&run_id, "process", true, None, now_millis()) + .unwrap(); + assert!(!transitioned, "already-skipped nodes must be left alone"); + + let node = storage + .get_workflow_node(&run_id, "process") + .unwrap() + .unwrap(); + assert_eq!(node.status, WorkflowNodeStatus::Skipped); +} diff --git a/docs/changelog.md b/docs/changelog.md index 374ae74..7197a9c 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,40 @@ All notable changes to taskito are documented here. +## Unreleased + +### Fixed + +- **Workflow fan-in race** -- concurrent child completions on different worker threads could both expand fan-in. The `check_fan_out_completion` Rust call now delegates to a new `WorkflowStorage::finalize_fan_out_parent` compare-and-swap, so the parent transitions at most once regardless of how many children complete simultaneously. +- **Sub-workflow compile failure** -- if a child workflow's factory or compile step raised, the parent node was left permanently `SKIPPED`, hanging the outer run. The parent is now promoted to `RUNNING` only after the child's compile + submit succeed, and is marked `FAILED` on error so the run finalizes. +- **Redis `purge_execution_claims`** -- previously a silent no-op. Execution claims are now mirrored into a time-indexed sorted set (`taskito:exec_claims:by_time`) so the scheduler's maintenance loop can reap stale claims in O(log n). Legacy keys still expire via the 24 h `PX` TTL. 
+- **SQLite `move_to_dlq` cascade** -- cascade-cancel errors on the dependent sweep are now propagated (parity with Postgres and Redis) instead of being swallowed as a warning. Callers see the failure and can decide whether to retry or alert. + +### Performance + +- **Workflow ops release the GIL during SQLite I/O** -- every method in `workflow_ops.rs` now wraps DB round-trips in `py.allow_threads(...)`. Event-bus callbacks that fire from worker threads no longer serialize the rest of the Python runtime on each fan-in / mark-result / cancel call. +- **`WorkflowSqliteStorage` cached per queue** -- migrations run once on first workflow API call via `OnceLock`, instead of re-running `CREATE TABLE IF NOT EXISTS` on every single call. + +### Safety + +- **`cancel_workflow_run` iterative** -- replaced recursive sub-workflow cascade with an iterative BFS plus `visited` set. No recursion deadlock, no connection-pool exhaustion on deep sub-workflow trees, and any accidental cycle in `parent_run_id` terminates safely. +- **Tracker state lock** -- `WorkflowTracker._state_lock` (RLock) now guards every access to `_run_configs`, `_job_to_run`, `_child_to_parent`, and `_gate_timers`, which are touched from worker threads, gate-timeout timers, and user threads. +- **Gate timer cleanup** -- `_cleanup_run` cancels any pending gate timers for the finishing run and drops stale child→parent mappings. Timers no longer fire on already-terminal runs. +- **Workflow metadata JSON escaping** -- `build_metadata_json` uses `serde_json::json!`; node names containing backslashes, control characters, or Unicode are now escaped correctly. Previously they produced malformed JSON that silently dropped the workflow event. +- **Narrower exception handling in the tracker** -- broad `except Exception:` clauses narrowed to `(RuntimeError, ValueError)` on Rust FFI call sites; the remaining broad catches are restricted to user callables and event emission with an explanatory `# noqa`. Silent `let _ = storage.cancel_job(...)` replaced with `log::warn!` via a shared helper. + +### Added + +- **`PrometheusMiddleware(task_filter=...)`** -- parity with `OTelMiddleware` and `SentryMiddleware`. A predicate `(task_name: str) -> bool` toggles metric export per task. + +### Changed + +- **`dagron-core` git dependency pinned** -- `Cargo.toml` now pins `dagron-core` to a specific commit SHA. Upstream pushes no longer cause silent build breakage. +- **`Storage` trait doc comment** -- now lists all three backends (SQLite, Postgres, Redis) instead of just the two Diesel ones. +- **`AsyncQueueMixin.metrics_timeseries` stub** -- parameter name corrected from `interval` to `bucket` to match the real sync signature. Call sites typed via the stub were silently wrong at runtime. + +--- + ## 0.11.0 ### Features diff --git a/docs/integrations/prometheus.md b/docs/integrations/prometheus.md index 625ae70..94a9e63 100644 --- a/docs/integrations/prometheus.md +++ b/docs/integrations/prometheus.md @@ -28,6 +28,7 @@ PrometheusMiddleware( namespace="myapp", extra_labels_fn=lambda ctx: {"env": "prod", "region": "us-east-1"}, disabled_metrics={"resource", "proxy"}, + task_filter=lambda name: not name.startswith("internal."), ) ``` @@ -36,6 +37,7 @@ PrometheusMiddleware( | `namespace` | `str` | `"taskito"` | Prefix for all metric names. | | `extra_labels_fn` | `Callable[[JobContext], dict[str, str]] | None` | `None` | Returns extra labels to add to job metrics. Receives `JobContext`. 
| | `disabled_metrics` | `set[str] | None` | `None` | Metric groups or individual names to skip. Groups: `"jobs"`, `"queue"`, `"resource"`, `"proxy"`, `"intercept"`. | +| `task_filter` | `Callable[[str], bool] | None` | `None` | Predicate that receives a task name. Return `True` to export metrics for the task, `False` to skip it. `None` exports all tasks. | ### Metrics Tracked diff --git a/docs/workflows/composition.md b/docs/workflows/composition.md index d8b4d6d..4aafe4b 100644 --- a/docs/workflows/composition.md +++ b/docs/workflows/composition.md @@ -61,7 +61,9 @@ run.cancel() # Cancels parent + all child sub-workflows ### Failure -If a child workflow fails, the parent node is marked `FAILED`. Downstream steps follow the parent's `on_failure` strategy. +If a child workflow fails at runtime, the parent node is marked `FAILED`. Downstream steps follow the parent's `on_failure` strategy. + +The same holds for failures at *submission* time -- if the child's factory raises or the DAG fails to compile when the parent node becomes evaluable, the parent node is marked `FAILED` immediately (rather than leaving the outer run hanging), and the parent run finalizes normally. ## Cron-scheduled workflows diff --git a/py_src/taskito/_taskito.pyi b/py_src/taskito/_taskito.pyi index f3590c8..564240a 100644 --- a/py_src/taskito/_taskito.pyi +++ b/py_src/taskito/_taskito.pyi @@ -290,6 +290,8 @@ class PyQueue: ) -> tuple[bool, list[str]] | None: ... def finalize_run_if_terminal(self, run_id: str) -> str | None: ... def set_workflow_node_waiting_approval(self, run_id: str, node_name: str) -> None: ... + def set_workflow_node_running(self, run_id: str, node_name: str) -> None: ... + def fail_workflow_node(self, run_id: str, node_name: str, error: str) -> None: ... def resolve_workflow_gate( self, run_id: str, diff --git a/py_src/taskito/async_support/mixins.py b/py_src/taskito/async_support/mixins.py index 7f6579d..9c14b8e 100644 --- a/py_src/taskito/async_support/mixins.py +++ b/py_src/taskito/async_support/mixins.py @@ -64,7 +64,7 @@ def metrics_timeseries( self, task_name: str | None = ..., since: int = ..., - interval: int = ..., + bucket: int = ..., ) -> list[dict]: ... def job_dag(self, job_id: str) -> dict[str, Any]: ... def job_errors(self, job_id: str) -> list[dict]: ... diff --git a/py_src/taskito/contrib/prometheus.py b/py_src/taskito/contrib/prometheus.py index a095b3b..4444a64 100644 --- a/py_src/taskito/contrib/prometheus.py +++ b/py_src/taskito/contrib/prometheus.py @@ -174,6 +174,10 @@ class PrometheusMiddleware(TaskMiddleware): disabled_metrics: Metric groups or individual metric names to skip. Groups: ``"jobs"``, ``"queue"``, ``"resource"``, ``"proxy"``, ``"intercept"``. + task_filter: Predicate that receives a task name and returns ``True`` + to export metrics for the task, ``False`` to ignore it. Defaults + to exporting all tasks. Parity with ``OTelMiddleware`` and + ``SentryMiddleware``. 
""" def __init__( @@ -182,6 +186,7 @@ def __init__( namespace: str = "taskito", extra_labels_fn: Callable[[JobContext], dict[str, str]] | None = None, disabled_metrics: set[str] | None = None, + task_filter: Callable[[str], bool] | None = None, ) -> None: if Counter is None: raise ImportError( @@ -190,10 +195,16 @@ def __init__( ) self._metrics = _get_or_create_metrics(namespace, disabled_metrics) self._extra_labels_fn = extra_labels_fn + self._task_filter = task_filter self._start_times: dict[str, float] = {} self._lock = threading.Lock() + def _should_track(self, task_name: str) -> bool: + return self._task_filter is None or self._task_filter(task_name) + def before(self, ctx: JobContext) -> None: + if not self._should_track(ctx.task_name): + return with self._lock: self._start_times[ctx.id] = time.monotonic() m = self._metrics["active_workers"] @@ -201,6 +212,8 @@ def before(self, ctx: JobContext) -> None: m.inc() def after(self, ctx: JobContext, result: Any, error: Exception | None) -> None: + if not self._should_track(ctx.task_name): + return m = self._metrics["active_workers"] if m is not None: m.dec() @@ -219,6 +232,8 @@ def after(self, ctx: JobContext, result: Any, error: Exception | None) -> None: m.labels(task=ctx.task_name).observe(duration) def on_retry(self, ctx: JobContext, error: Exception, retry_count: int) -> None: + if not self._should_track(ctx.task_name): + return m = self._metrics["retries_total"] if m is not None: m.labels(task=ctx.task_name).inc() diff --git a/py_src/taskito/workflows/tracker.py b/py_src/taskito/workflows/tracker.py index 4bcc718..7950a17 100644 --- a/py_src/taskito/workflows/tracker.py +++ b/py_src/taskito/workflows/tracker.py @@ -57,6 +57,12 @@ def __init__(self, queue: Queue): self._waiters_lock = threading.Lock() self._waiters: dict[str, list[threading.Event]] = {} self._event_bus = queue._event_bus + # `_state_lock` guards every read and write of `_run_configs`, + # `_job_to_run`, `_child_to_parent`, and `_gate_timers`. These dicts + # are accessed from worker threads (event bus), timer threads + # (gate timeouts), and user threads (approve_gate/reject_gate). + # A single lock is simple and adequate; tracker operations are short. + self._state_lock = threading.RLock() self._run_configs: dict[str, _RunConfig] = {} self._job_to_run: dict[str, str] = {} self._gate_timers: dict[tuple[str, str], threading.Timer] = {} @@ -128,17 +134,20 @@ def register_run( gate_configs=gate_configs or {}, sub_workflow_refs=sub_workflow_refs or {}, ) - self._run_configs[run_id] = config + with self._state_lock: + self._run_configs[run_id] = config # Populate job→run mapping for initial nodes. try: raw = self._queue._inner.get_workflow_run_status(run_id) - for _name, info in raw.node_statuses().items(): - jid = info.get("job_id") - if jid: - self._job_to_run[jid] = run_id - except Exception: # pragma: no cover - logger.exception("failed to populate job→run mapping for %s", run_id) + except (RuntimeError, ValueError): + logger.exception("failed to read workflow run status for %s", run_id) + else: + with self._state_lock: + for _name, info in raw.node_statuses().items(): + jid = info.get("job_id") + if jid: + self._job_to_run[jid] = run_id # Evaluate root deferred nodes (those with no predecessors). self._evaluate_root_deferred(run_id, config) @@ -182,8 +191,9 @@ def _handle( return # Determine if this job belongs to a managed run. 
- run_id = self._job_to_run.get(job_id) - config = self._run_configs.get(run_id) if run_id else None + with self._state_lock: + run_id = self._job_to_run.get(job_id) + config = self._run_configs.get(run_id) if run_id else None skip_cascade = config is not None # Compute result hash for successful completions. @@ -195,8 +205,12 @@ def _handle( result = self._queue._inner.mark_workflow_node_result( job_id, succeeded, error, skip_cascade, rh ) - except Exception: # pragma: no cover - defensive + except (RuntimeError, ValueError) as exc: logger.exception("mark_workflow_node_result failed for job %s", job_id) + # Notify any waiters so they don't block forever on a silent failure. + if run_id is not None: + self._emit_terminal(run_id, "failed", str(exc)) + self._cleanup_run(run_id) return if result is None: @@ -211,7 +225,8 @@ def _handle( # Re-fetch config now that we have the definitive run_id. if config is None: - config = self._run_configs.get(run_id) + with self._state_lock: + config = self._run_configs.get(run_id) if config is None: return # Static workflow — Rust cascade handled everything. @@ -240,13 +255,27 @@ def _emit_terminal(self, run_id: str, terminal_state: str, error: str | None) -> workflow_event, {"run_id": run_id, "state": terminal_state, "error": error}, ) - except Exception: # pragma: no cover - defensive + except Exception: logger.exception("failed to emit %s", workflow_event) self._release_waiters(run_id) def _cleanup_run(self, run_id: str) -> None: - self._run_configs.pop(run_id, None) - self._job_to_run = {jid: rid for jid, rid in self._job_to_run.items() if rid != run_id} + """Drop all tracker state tied to `run_id` and cancel any live timers.""" + with self._state_lock: + self._run_configs.pop(run_id, None) + self._job_to_run = {jid: rid for jid, rid in self._job_to_run.items() if rid != run_id} + # Cancel and remove any gate timers still scheduled for this run. + stale_timer_keys = [k for k in self._gate_timers if k[0] == run_id] + for key in stale_timer_keys: + timer = self._gate_timers.pop(key, None) + if timer is not None: + timer.cancel() + # Drop any child→parent mappings whose parent run is finishing. 
+ stale_child_ids = [ + cid for cid, (prid, _) in self._child_to_parent.items() if prid == run_id + ] + for cid in stale_child_ids: + self._child_to_parent.pop(cid, None) # ── Condition evaluation ─────────────────────────────────────── @@ -312,7 +341,7 @@ def _skip_and_propagate(self, run_id: str, node_name: str, config: _RunConfig) - """Mark a node as SKIPPED and recursively evaluate its successors.""" try: self._queue._inner.skip_workflow_node(run_id, node_name) - except Exception: + except (RuntimeError, ValueError): logger.exception("skip_workflow_node failed for %s", node_name) return config.deferred_nodes.discard(node_name) @@ -325,7 +354,7 @@ def _enter_gate(self, run_id: str, node_name: str, config: _RunConfig) -> None: """Transition a gate node to WAITING_APPROVAL and start timeout.""" try: self._queue._inner.set_workflow_node_waiting_approval(run_id, node_name) - except Exception: + except (RuntimeError, ValueError): logger.exception("set_workflow_node_waiting_approval failed for %s", node_name) return config.deferred_nodes.discard(node_name) @@ -340,7 +369,7 @@ def _enter_gate(self, run_id: str, node_name: str, config: _RunConfig) -> None: "message": gate.message if isinstance(gate.message, str) else None, }, ) - except Exception: # pragma: no cover + except Exception: logger.exception("failed to emit WORKFLOW_GATE_REACHED") if gate.timeout is not None and gate.timeout > 0: @@ -350,8 +379,9 @@ def _enter_gate(self, run_id: str, node_name: str, config: _RunConfig) -> None: args=(run_id, node_name, gate.on_timeout), ) timer.daemon = True + with self._state_lock: + self._gate_timers[(run_id, node_name)] = timer timer.start() - self._gate_timers[(run_id, node_name)] = timer def resolve_gate( self, @@ -363,15 +393,15 @@ def resolve_gate( ) -> None: """Approve or reject a gate, resuming the workflow.""" # Cancel any pending timeout timer. - timer = self._gate_timers.pop((run_id, node_name), None) + with self._state_lock: + timer = self._gate_timers.pop((run_id, node_name), None) + config = self._run_configs.get(run_id) if timer is not None: timer.cancel() - config = self._run_configs.get(run_id) - try: self._queue._inner.resolve_workflow_gate(run_id, node_name, approved, error) - except Exception: + except (RuntimeError, ValueError): logger.exception("resolve_workflow_gate failed for %s", node_name) return @@ -382,25 +412,19 @@ def resolve_gate( # ── Sub-workflows ────────────────────────────────────────────── def _submit_sub_workflow(self, run_id: str, node_name: str, config: _RunConfig) -> None: - """Submit a child workflow for a sub-workflow node.""" + """Submit a child workflow and transition the parent node to Running. + + The parent node is only promoted to `Running` after the child has + successfully compiled *and* been submitted. On any failure during + compile/submit, the parent is marked Failed so the run can finalize + instead of hanging in an indeterminate state. + """ ref = config.sub_workflow_refs.get(node_name) if ref is None: # pragma: no cover return - try: - child_wf = ref.proxy.build(**ref.params) - # Mark parent node as RUNNING. - self._queue._inner.set_workflow_node_waiting_approval(run_id, node_name) - # Override status to RUNNING (waiting_approval was just to set started_at). - self._queue._inner.skip_workflow_node(run_id, node_name) - # Actually, let me use a cleaner approach: just mark running via the - # node status update. Use the Rust set_workflow_node_fan_out_count - # trick (sets RUNNING). Or add a direct call. 
- except Exception: - logger.exception("failed to build sub-workflow for %s", node_name) - return try: - # Submit child workflow with parent linkage. + child_wf = ref.proxy.build(**ref.params) ( dag_bytes, meta_json, @@ -424,45 +448,65 @@ def _submit_sub_workflow(self, run_id: str, node_name: str, config: _RunConfig) run_id, # parent_run_id node_name, # parent_node_name ) + except Exception as exc: + logger.exception("submit sub-workflow failed for %s", node_name) + # Mark the parent node Failed so the outer run can finalize rather + # than hanging. This is the central fix for the old bug where a + # compile failure left the node permanently Skipped. + try: + self._queue._inner.fail_workflow_node( + run_id, node_name, f"sub-workflow submit failed: {exc}" + ) + except (RuntimeError, ValueError): + logger.exception("failed to mark sub-workflow parent %s as failed", node_name) + config.deferred_nodes.discard(node_name) + # Evaluate successors now that this node is terminal. + self._evaluate_successors(run_id, node_name, config) + return - child_run_id = handle.run_id + # Child compiled and submitted successfully — now promote the parent. + child_run_id = handle.run_id + with self._state_lock: self._child_to_parent[child_run_id] = (run_id, node_name) + try: + self._queue._inner.set_workflow_node_running(run_id, node_name) + except (RuntimeError, ValueError): + logger.exception( + "set_workflow_node_running failed for sub-workflow parent %s", + node_name, + ) - # Mark parent node as RUNNING (use fan_out_count trick). - self._queue._inner.set_workflow_node_fan_out_count(run_id, node_name, 1) - - # Register child with tracker if it has deferred nodes. - needs_child_tracker = ( - bool(deferred) - or bool(callables) - or bool(gates) - or bool(sub_refs) - or on_failure != "fail_fast" + # Register child with tracker if it has deferred nodes. + needs_child_tracker = ( + bool(deferred) + or bool(callables) + or bool(gates) + or bool(sub_refs) + or on_failure != "fail_fast" + ) + if needs_child_tracker: + child_payloads = {n: payloads[n] for n in deferred if n in payloads} + self.register_run( + child_run_id, + meta_json, + dag_bytes, + deferred, + child_payloads, + on_failure=on_failure, + callable_conditions=callables, + gate_configs=gates, + sub_workflow_refs=sub_refs, ) - if needs_child_tracker: - child_payloads = {n: payloads[n] for n in deferred if n in payloads} - self.register_run( - child_run_id, - meta_json, - dag_bytes, - deferred, - child_payloads, - on_failure=on_failure, - callable_conditions=callables, - gate_configs=gates, - sub_workflow_refs=sub_refs, - ) - config.deferred_nodes.discard(node_name) - except Exception: - logger.exception("submit sub-workflow failed for %s", node_name) + config.deferred_nodes.discard(node_name) def _on_child_workflow_terminal(self, _event_type: EventType, payload: dict[str, Any]) -> None: """Handle child workflow completion → update parent node.""" child_run_id = payload.get("run_id") if not child_run_id: return - parent_info = self._child_to_parent.pop(child_run_id, None) + with self._state_lock: + parent_info = self._child_to_parent.pop(child_run_id, None) if parent_info is None: return # Not a sub-workflow child. 
@@ -477,7 +521,7 @@ def _on_child_workflow_terminal(self, _event_type: EventType, payload: dict[str, succeeded, payload.get("error") if not succeeded else None, ) - except Exception: + except (RuntimeError, ValueError): logger.exception( "failed to update parent node %s for child %s", parent_node_name, @@ -485,14 +529,20 @@ def _on_child_workflow_terminal(self, _event_type: EventType, payload: dict[str, ) return - config = self._run_configs.get(parent_run_id) + with self._state_lock: + config = self._run_configs.get(parent_run_id) if config is not None: self._evaluate_successors(parent_run_id, parent_node_name, config) self._try_finalize(parent_run_id) def _on_gate_timeout(self, run_id: str, node_name: str, action: str) -> None: """Handle gate timeout expiry.""" - self._gate_timers.pop((run_id, node_name), None) + with self._state_lock: + # If the run was cleaned up (e.g., cancelled before timeout fired), + # the timer entry was already removed by `_cleanup_run` — stop. + if (run_id, node_name) not in self._gate_timers: + return + self._gate_timers.pop((run_id, node_name), None) approved = action == "approve" error = None if approved else "gate timeout" self.resolve_gate(run_id, node_name, approved=approved, error=error) @@ -603,16 +653,18 @@ def _create_deferred_job_for_node( timeout_ms, priority, ) - self._job_to_run[job_id] = run_id - config.deferred_nodes.discard(node_name) - except Exception: + except (RuntimeError, ValueError): logger.exception("create_deferred_job failed for %s", node_name) + return + with self._state_lock: + self._job_to_run[job_id] = run_id + config.deferred_nodes.discard(node_name) def _try_finalize(self, run_id: str) -> None: """If all nodes are terminal, finalize the run and emit the event.""" try: terminal_state = self._queue._inner.finalize_run_if_terminal(run_id) - except Exception: + except (RuntimeError, ValueError): logger.exception("finalize_run_if_terminal failed for %s", run_id) return if terminal_state is not None: @@ -683,11 +735,12 @@ def _expand_fan_out( timeout_ms, priority, ) - for jid in child_job_ids: - self._job_to_run[jid] = run_id - except Exception: + except (RuntimeError, ValueError): logger.exception("expand_fan_out failed for %s in run %s", fan_out_node, run_id) return + with self._state_lock: + for jid in child_job_ids: + self._job_to_run[jid] = run_id # Empty fan-out: parent is immediately COMPLETED with 0 children. 
if not child_names: @@ -705,7 +758,7 @@ def _handle_fan_out_child(self, run_id: str, child_name: str, config: _RunConfig parent_name = child_name.split("[")[0] try: completion = self._queue._inner.check_fan_out_completion(run_id, parent_name) - except Exception: + except (RuntimeError, ValueError): logger.exception("check_fan_out_completion failed for %s", parent_name) return @@ -736,7 +789,7 @@ def _handle_fan_out_child_failure( parent_name = child_name.split("[")[0] try: completion = self._queue._inner.check_fan_out_completion(run_id, parent_name) - except Exception: + except (RuntimeError, ValueError): logger.exception("check_fan_out_completion failed for %s", parent_name) return @@ -780,9 +833,11 @@ def _create_fan_in_job( timeout_ms, priority, ) - self._job_to_run[job_id] = run_id - except Exception: + except (RuntimeError, ValueError): logger.exception("create_deferred_job failed for fan-in %s", fan_in_node) + return + with self._state_lock: + self._job_to_run[job_id] = run_id # ── Helpers ──────────────────────────────────────────────────────── diff --git a/tests/python/test_contrib.py b/tests/python/test_contrib.py index 95e25e7..5df8773 100644 --- a/tests/python/test_contrib.py +++ b/tests/python/test_contrib.py @@ -244,6 +244,7 @@ def test_before_increments_active_workers(self) -> None: mw = prom.PrometheusMiddleware.__new__(prom.PrometheusMiddleware) mw._metrics = metrics mw._extra_labels_fn = None + mw._task_filter = None mw._start_times = {} mw._lock = threading.Lock() @@ -263,6 +264,7 @@ def test_after_tracks_counter_and_histogram(self) -> None: mw = prom.PrometheusMiddleware.__new__(prom.PrometheusMiddleware) mw._metrics = metrics mw._extra_labels_fn = None + mw._task_filter = None mw._start_times = {"job-1": 0.0} mw._lock = threading.Lock() @@ -284,6 +286,7 @@ def test_after_tracks_failure(self) -> None: mw = prom.PrometheusMiddleware.__new__(prom.PrometheusMiddleware) mw._metrics = metrics mw._extra_labels_fn = None + mw._task_filter = None mw._start_times = {"job-1": 0.0} mw._lock = threading.Lock() diff --git a/tests/python/test_workflows_subworkflow.py b/tests/python/test_workflows_subworkflow.py index 0a69631..4d6b2fc 100644 --- a/tests/python/test_workflows_subworkflow.py +++ b/tests/python/test_workflows_subworkflow.py @@ -97,6 +97,42 @@ def failing_sub() -> Workflow: assert final.nodes["after"].status == NodeStatus.SKIPPED +def test_sub_workflow_compile_failure_marks_parent_failed(queue: Queue) -> None: + """Regression: a factory that raises during `build()` must not leave the + parent node Skipped forever — it must be marked Failed so the outer run + can finalize. 
Before the fix, the compile/submit failure was only logged and the parent
+    node was never marked Failed, so the outer run never finalized."""
+
+    @queue.task()
+    def downstream() -> str:
+        return "should not run"
+
+    @queue.workflow("broken_sub")
+    def broken_sub() -> Workflow:
+        raise RuntimeError("factory blew up")
+
+    wf = Workflow(name="parent_compile_fail")
+    wf.step("sub", broken_sub.as_step())
+    wf.step("after", downstream, after="sub")
+
+    worker = _start_worker(queue)
+    try:
+        run = queue.submit_workflow(wf)
+        final = run.wait(timeout=15)
+    finally:
+        _stop_worker(queue, worker)
+
+    assert final.state == WorkflowState.FAILED, (
+        f"outer run must finalize as FAILED, got {final.state}"
+    )
+    assert final.nodes["sub"].status == NodeStatus.FAILED, (
+        f"sub-workflow parent must be FAILED (was {final.nodes['sub'].status}) — "
+        "the old code never marked it FAILED"
+    )
+    assert final.nodes["after"].status == NodeStatus.SKIPPED
+
+
 def test_cancel_parent_cascades(queue: Queue) -> None:
     """Cancelling a parent workflow cancels the child sub-workflow too."""