From c8526652c0235944a6755f3e9cfd79f06b8dc629 Mon Sep 17 00:00:00 2001 From: dragon-zhang Date: Sun, 16 Mar 2025 11:38:48 +0800 Subject: [PATCH] support cancel coroutine/task --- CHANGELOG.md | 4 +++ README.md | 1 - README_ZH.md | 1 - core/src/co_pool/creator.rs | 2 +- core/src/co_pool/mod.rs | 56 +++++++++++++++++++++++++++++- core/src/co_pool/task.rs | 6 ++++ core/src/common/constants.rs | 4 ++- core/src/coroutine/korosensei.rs | 4 +++ core/src/coroutine/listener.rs | 8 +++++ core/src/coroutine/mod.rs | 31 +++++++++++++++++ core/src/coroutine/state.rs | 19 ++++++++++ core/src/coroutine/suspender.rs | 37 ++++++++++++++++++++ core/src/monitor.rs | 1 + core/src/net/event_loop.rs | 5 +++ core/src/net/mod.rs | 5 +++ core/src/scheduler.rs | 59 ++++++++++++++++++++++++++++---- hook/src/lib.rs | 12 +++++++ open-coroutine/src/lib.rs | 10 ++++++ 18 files changed, 254 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 418595bf..726dad66 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +### 0.7.x + +- [x] support cancel coroutine/task + ### 0.6.x - [x] support custom task and coroutine priority. diff --git a/README.md b/README.md index b95b34e2..a7e40f4e 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,6 @@ English | [中文](README_ZH.md) - [ ] add performance [benchmark](https://github.com/TechEmpower/FrameworkBenchmarks/wiki/Project-Information-Framework-Tests-Overview); -- [ ] cancel coroutine/task; - [ ] add metrics; - [ ] add synchronization toolkit; - [ ] support and compatibility for AF_XDP socket; diff --git a/README_ZH.md b/README_ZH.md index 8079d6bf..043633e1 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -30,7 +30,6 @@ - [ ] 增加性能[基准测试](https://github.com/TechEmpower/FrameworkBenchmarks/wiki/Project-Information-Framework-Tests-Overview); -- [ ] 取消协程/任务; - [ ] 增加性能指标监控; - [ ] 增加并发工具包; - [ ] 支持AF_XDP套接字; diff --git a/core/src/co_pool/creator.rs b/core/src/co_pool/creator.rs index 9c227e02..44a59b53 100644 --- a/core/src/co_pool/creator.rs +++ b/core/src/co_pool/creator.rs @@ -29,7 +29,7 @@ impl Listener<(), Option> for CoroutineCreator { .store(pool.get_running_size().saturating_sub(1), Ordering::Release); } } - CoroutineState::Error(_) => { + CoroutineState::Cancelled | CoroutineState::Error(_) => { if let Some(pool) = CoroutinePool::current() { //worker协程异常退出,需要先回收再创建 pool.running diff --git a/core/src/co_pool/mod.rs b/core/src/co_pool/mod.rs index f9ba8fca..f4841043 100644 --- a/core/src/co_pool/mod.rs +++ b/core/src/co_pool/mod.rs @@ -6,8 +6,9 @@ use crate::common::ordered_work_steal::{OrderedLocalQueue, OrderedWorkStealQueue use crate::common::{get_timeout_time, now, CondvarBlocker}; use crate::coroutine::suspender::Suspender; use crate::scheduler::{SchedulableCoroutine, Scheduler}; -use crate::{error, impl_current_for, impl_display_by_debug, impl_for_named, trace}; +use crate::{error, impl_current_for, impl_display_by_debug, impl_for_named, trace, warn}; use dashmap::{DashMap, DashSet}; +use once_cell::sync::Lazy; use std::cell::Cell; use std::ffi::c_longlong; use std::io::{Error, ErrorKind}; @@ -25,6 +26,11 @@ mod state; /// Creator for coroutine pool. mod creator; +/// `task_name` -> `co_name` +static RUNNING_TASKS: Lazy> = Lazy::new(DashMap::new); + +static CANCEL_TASKS: Lazy> = Lazy::new(DashSet::new); + /// The coroutine pool impls. #[repr(C)] #[derive(Debug)] @@ -383,7 +389,17 @@ impl<'p> CoroutinePool<'p> { fn try_run(&self) -> Option<()> { self.task_queue.pop().map(|task| { + let tname = task.get_name().to_string().leak(); + if CANCEL_TASKS.contains(tname) { + _ = CANCEL_TASKS.remove(tname); + warn!("Cancel task:{} successfully !", tname); + return; + } + if let Some(co) = SchedulableCoroutine::current() { + _ = RUNNING_TASKS.insert(tname, co.name()); + } let (task_name, result) = task.run(); + _ = RUNNING_TASKS.remove(tname); let n = task_name.clone().leak(); if self.no_waits.contains(n) { _ = self.no_waits.remove(n); @@ -406,6 +422,44 @@ impl<'p> CoroutinePool<'p> { } } + /// Try to cancel a task. + pub fn try_cancel_task(task_name: &str) { + // 检查正在运行的任务是否是要取消的任务 + if let Some(info) = RUNNING_TASKS.get(task_name) { + let co_name = *info; + // todo windows support + #[allow(unused_variables)] + if let Some(pthread) = Scheduler::get_scheduling_thread(co_name) { + // 发送SIGVTALRM信号,在运行时取消任务 + #[cfg(unix)] + if nix::sys::pthread::pthread_kill(pthread, nix::sys::signal::Signal::SIGVTALRM) + .is_ok() + { + warn!( + "Attempt to cancel task:{} running on coroutine:{} by thread:{}, cancelling...", + task_name, co_name, pthread + ); + } else { + error!( + "Attempt to cancel task:{} running on coroutine:{} by thread:{} failed !", + task_name, co_name, pthread + ); + } + } else { + // 添加到待取消队列 + Scheduler::try_cancel_coroutine(co_name); + warn!( + "Attempt to cancel task:{} running on coroutine:{}, cancelling...", + task_name, co_name + ); + } + } else { + // 添加到待取消队列 + _ = CANCEL_TASKS.insert(Box::leak(Box::from(task_name))); + warn!("Attempt to cancel task:{}, cancelling...", task_name); + } + } + /// Schedule the tasks. /// /// Allow multiple threads to concurrently submit task to the pool, diff --git a/core/src/co_pool/task.rs b/core/src/co_pool/task.rs index 12fffa45..ba9a003b 100644 --- a/core/src/co_pool/task.rs +++ b/core/src/co_pool/task.rs @@ -33,6 +33,12 @@ impl<'t> Task<'t> { } } + /// get the task name. + #[must_use] + pub fn get_name(&self) -> &str { + &self.name + } + /// execute the task /// /// # Errors diff --git a/core/src/common/constants.rs b/core/src/common/constants.rs index 17a8332d..bb8a4868 100644 --- a/core/src/common/constants.rs +++ b/core/src/common/constants.rs @@ -195,9 +195,11 @@ pub enum CoroutineState { Suspend(Y, u64), ///The coroutine enters the syscall. Syscall(Y, SyscallName, SyscallState), + /// The coroutine cancelled. + Cancelled, /// The coroutine completed with a return value. Complete(R), - /// The coroutine completed with a error message. + /// The coroutine completed with an error message. Error(&'static str), } diff --git a/core/src/coroutine/korosensei.rs b/core/src/coroutine/korosensei.rs index ac1b5a54..880b65c5 100644 --- a/core/src/coroutine/korosensei.rs +++ b/core/src/coroutine/korosensei.rs @@ -445,6 +445,10 @@ where let current = self.state(); match current { CoroutineState::Running => { + if Suspender::::is_cancel() { + self.cancel()?; + return Ok(CoroutineState::Cancelled); + } let timestamp = Suspender::::timestamp(); self.suspend(y, timestamp)?; Ok(CoroutineState::Suspend(y, timestamp)) diff --git a/core/src/coroutine/listener.rs b/core/src/coroutine/listener.rs index b4f00378..ce94f35b 100644 --- a/core/src/coroutine/listener.rs +++ b/core/src/coroutine/listener.rs @@ -27,6 +27,9 @@ pub trait Listener: Debug { /// callback when the coroutine enters syscall. fn on_syscall(&self, local: &CoroutineLocal, old_state: CoroutineState) {} + /// Callback when the coroutine is cancelled. + fn on_cancel(&self, local: &CoroutineLocal, old_state: CoroutineState) {} + /// Callback when the coroutine is completed. fn on_complete( &self, @@ -91,6 +94,11 @@ where old_state: CoroutineState ), "on_syscall"); + broadcast!(on_cancel( + local: &CoroutineLocal, + old_state: CoroutineState + ), "on_cancel"); + broadcast!(on_complete( local: &CoroutineLocal, old_state: CoroutineState, diff --git a/core/src/coroutine/mod.rs b/core/src/coroutine/mod.rs index cfb8c73c..ec220ebb 100644 --- a/core/src/coroutine/mod.rs +++ b/core/src/coroutine/mod.rs @@ -135,6 +135,35 @@ impl<'c, Param, Yield, Return> Coroutine<'c, Param, Yield, Return> { callback, ) } + + /// handle SIGVTALRM + #[cfg(unix)] + fn setup_sigvtalrm_handler() { + use nix::sys::signal::{sigaction, SaFlags, SigAction, SigHandler, SigSet, Signal}; + use std::sync::atomic::{AtomicBool, Ordering}; + static CANCEL_HANDLER_INITED: AtomicBool = AtomicBool::new(false); + if CANCEL_HANDLER_INITED + .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + { + extern "C" fn sigvtalrm_handler(_: libc::c_int) { + if let Some(suspender) = suspender::Suspender::::current() { + suspender.cancel(); + } + } + // install SIGVTALRM signal handler + let mut set = SigSet::empty(); + set.add(Signal::SIGVTALRM); + let sa = SigAction::new( + SigHandler::Handler(sigvtalrm_handler::), + SaFlags::SA_RESTART, + set, + ); + unsafe { + _ = sigaction(Signal::SIGVTALRM, &sa).expect("install SIGVTALRM handler failed !"); + } + } + } } impl Coroutine<'_, (), Yield, Return> @@ -170,6 +199,8 @@ where } Self::init_current(self); self.running()?; + #[cfg(unix)] + Self::setup_sigvtalrm_handler(); let r = self.raw_resume(arg); Self::clean_current(); r diff --git a/core/src/coroutine/state.rs b/core/src/coroutine/state.rs index 4ff97827..33e0c963 100644 --- a/core/src/coroutine/state.rs +++ b/core/src/coroutine/state.rs @@ -146,6 +146,25 @@ where ))) } + /// running -> cancel + /// + /// # Errors + /// if change state fails. + pub(super) fn cancel(&self) -> std::io::Result<()> { + let current = self.state(); + if CoroutineState::Running == current { + let new_state = CoroutineState::Cancelled; + let old_state = self.change_state(new_state); + self.on_cancel(self, old_state); + return Ok(()); + } + Err(Error::other(format!( + "{} unexpected {current}->{:?}", + self.name(), + CoroutineState::::Cancelled + ))) + } + /// running -> complete /// /// # Errors diff --git a/core/src/coroutine/suspender.rs b/core/src/coroutine/suspender.rs index ea5f2308..94a0f2e5 100644 --- a/core/src/coroutine/suspender.rs +++ b/core/src/coroutine/suspender.rs @@ -6,6 +6,10 @@ thread_local! { #[allow(clippy::missing_const_for_thread_local)] static TIMESTAMP: crossbeam_utils::atomic::AtomicCell> = const { crossbeam_utils::atomic::AtomicCell::new(std::collections::VecDeque::new()) }; + + #[allow(clippy::missing_const_for_thread_local)] + static CANCEL: crossbeam_utils::atomic::AtomicCell> = + const { crossbeam_utils::atomic::AtomicCell::new(std::collections::VecDeque::new()) }; } impl Suspender<'_, Param, Yield> { @@ -30,6 +34,23 @@ impl Suspender<'_, Param, Yield> { self.suspend_with(arg) } + /// Cancel the execution of the coroutine. + pub fn cancel(&self) -> ! { + CANCEL.with(|s| unsafe { + s.as_ptr() + .as_mut() + .unwrap_or_else(|| { + panic!( + "thread:{} init CANCEL current failed", + std::thread::current().name().unwrap_or("unknown") + ) + }) + .push_front(true); + }); + _ = self.suspend_with(unsafe { std::mem::zeroed() }); + unreachable!() + } + pub(crate) fn timestamp() -> u64 { TIMESTAMP .with(|s| unsafe { @@ -45,6 +66,22 @@ impl Suspender<'_, Param, Yield> { }) .unwrap_or(0) } + + pub(crate) fn is_cancel() -> bool { + CANCEL + .with(|s| unsafe { + s.as_ptr() + .as_mut() + .unwrap_or_else(|| { + panic!( + "thread:{} get CANCEL current failed", + std::thread::current().name().unwrap_or("unknown") + ) + }) + .pop_front() + }) + .unwrap_or(false) + } } #[allow(clippy::must_use_candidate)] diff --git a/core/src/monitor.rs b/core/src/monitor.rs index b57623b5..d1f35e47 100644 --- a/core/src/monitor.rs +++ b/core/src/monitor.rs @@ -199,6 +199,7 @@ impl Listener for MonitorListener { } CoroutineState::Suspend(_, _) | CoroutineState::Syscall(_, _, _) + | CoroutineState::Cancelled | CoroutineState::Complete(_) | CoroutineState::Error(_) => { if let Some(node) = local.get(NOTIFY_NODE) { diff --git a/core/src/net/event_loop.rs b/core/src/net/event_loop.rs index c3546a08..126dbdbe 100644 --- a/core/src/net/event_loop.rs +++ b/core/src/net/event_loop.rs @@ -136,6 +136,11 @@ impl<'e> EventLoop<'e> { }) } + /// Try to cancel a task from `CoroutinePool`. + pub(super) fn try_cancel_task(name: &str) { + CoroutinePool::try_cancel_task(name); + } + #[allow(trivial_numeric_casts, clippy::cast_possible_truncation)] fn token(syscall: SyscallName) -> usize { if let Some(co) = SchedulableCoroutine::current() { diff --git a/core/src/net/mod.rs b/core/src/net/mod.rs index 9a3d59dd..f32ff868 100644 --- a/core/src/net/mod.rs +++ b/core/src/net/mod.rs @@ -154,6 +154,11 @@ impl EventLoops { ) } + /// Try to cancel a task from event-loop. + pub fn try_cancel_task(name: &str) { + EventLoop::try_cancel_task(name); + } + /// Submit a new coroutine to event-loop. /// /// Allow multiple threads to concurrently submit coroutine to the pool, diff --git a/core/src/scheduler.rs b/core/src/scheduler.rs index 14ac0347..524e38ba 100644 --- a/core/src/scheduler.rs +++ b/core/src/scheduler.rs @@ -5,8 +5,11 @@ use crate::common::{get_timeout_time, now}; use crate::coroutine::listener::Listener; use crate::coroutine::suspender::Suspender; use crate::coroutine::Coroutine; -use crate::{co, impl_current_for, impl_display_by_debug, impl_for_named}; -use dashmap::DashMap; +use crate::{co, impl_current_for, impl_display_by_debug, impl_for_named, warn}; +use dashmap::{DashMap, DashSet}; +#[cfg(unix)] +use nix::sys::pthread::Pthread; +use once_cell::sync::Lazy; use std::collections::{BinaryHeap, HashMap, VecDeque}; use std::ffi::c_longlong; use std::io::Error; @@ -78,6 +81,13 @@ impl Ord for SyscallSuspendItem<'_> { } } +#[cfg(unix)] +static RUNNING_COROUTINES: Lazy> = Lazy::new(DashMap::new); +#[cfg(windows)] +static RUNNING_COROUTINES: Lazy> = Lazy::new(DashMap::new); + +static CANCEL_COROUTINES: Lazy> = Lazy::new(DashSet::new); + /// The scheduler impls. #[repr(C)] #[derive(Debug)] @@ -280,10 +290,27 @@ impl<'s> Scheduler<'s> { self.check_ready()?; // schedule coroutines if let Some(mut coroutine) = self.ready.pop() { - match coroutine.resume()? { + let co_name = coroutine.name().to_string().leak(); + if CANCEL_COROUTINES.contains(co_name) { + _ = CANCEL_COROUTINES.remove(co_name); + warn!("Cancel coroutine:{} successfully !", co_name); + continue; + } + cfg_if::cfg_if! { + if #[cfg(windows)] { + let current_thread = unsafe { + windows_sys::Win32::System::Threading::GetCurrentThread() + } as usize; + } else { + let current_thread = nix::sys::pthread::pthread_self(); + } + } + _ = RUNNING_COROUTINES.insert(co_name, current_thread); + match coroutine.resume().inspect(|_| { + _ = RUNNING_COROUTINES.remove(co_name); + })? { CoroutineState::Syscall((), _, state) => { //挂起协程到系统调用表 - let co_name = Box::leak(Box::from(coroutine.name())); //如果已包含,说明当前系统调用还有上层父系统调用,因此直接忽略插入结果 _ = self.syscall.insert(co_name, coroutine); if let SyscallState::Suspend(timestamp) = state { @@ -303,15 +330,14 @@ impl<'s> Scheduler<'s> { self.ready.push(coroutine); } } + CoroutineState::Cancelled => {} CoroutineState::Complete(result) => { - let co_name = Box::leak(Box::from(coroutine.name())); assert!( results.insert(co_name, Ok(result)).is_none(), "not consume result" ); } CoroutineState::Error(message) => { - let co_name = Box::leak(Box::from(coroutine.name())); assert!( results.insert(co_name, Err(message)).is_none(), "not consume result" @@ -359,6 +385,27 @@ impl<'s> Scheduler<'s> { } Ok(()) } + + /// Cancel the coroutine by name. + pub fn try_cancel_coroutine(co_name: &str) { + _ = CANCEL_COROUTINES.insert(Box::leak(Box::from(co_name))); + } + + /// Get the scheduling thread of the coroutine. + #[cfg(unix)] + pub fn get_scheduling_thread(co_name: &str) -> Option { + let co_name: &str = Box::leak(Box::from(co_name)); + RUNNING_COROUTINES.get(co_name).map(|r| *r) + } + + /// Get the scheduling thread of the coroutine. + #[cfg(windows)] + pub fn get_scheduling_thread(co_name: &str) -> Option { + let co_name: &str = Box::leak(Box::from(co_name)); + RUNNING_COROUTINES + .get(co_name) + .map(|r| *r as windows_sys::Win32::Foundation::HANDLE) + } } #[cfg(test)] diff --git a/hook/src/lib.rs b/hook/src/lib.rs index 8c885009..718f555c 100644 --- a/hook/src/lib.rs +++ b/hook/src/lib.rs @@ -103,6 +103,18 @@ pub extern "C" fn task_crate(f: UserTaskFunc, param: usize, priority: c_longlong ) } +///尝试异步取消任务 +#[no_mangle] +pub extern "C" fn task_cancel(handle: &JoinHandle) -> c_longlong { + match handle.get_name() { + Ok(name) => { + EventLoops::try_cancel_task(name); + 0 + } + Err(_) => -1, + } +} + ///等待任务完成 #[no_mangle] pub extern "C" fn task_join(handle: &JoinHandle) -> c_longlong { diff --git a/open-coroutine/src/lib.rs b/open-coroutine/src/lib.rs index cd97324e..de067257 100644 --- a/open-coroutine/src/lib.rs +++ b/open-coroutine/src/lib.rs @@ -87,6 +87,8 @@ extern "C" { fn task_join(handle: &open_coroutine_core::net::join::JoinHandle) -> c_longlong; + fn task_cancel(handle: &open_coroutine_core::net::join::JoinHandle) -> c_longlong; + fn task_timeout_join( handle: &open_coroutine_core::net::join::JoinHandle, ns_time: u64, @@ -203,6 +205,14 @@ impl JoinHandle { pub fn any_join>(iter: I) -> std::io::Result> { Self::any_timeout_join(Duration::MAX, Vec::from_iter(iter).as_slice()) } + + pub fn try_cancel(self) -> std::io::Result<()> { + let r = unsafe { task_cancel(&self) }; + match r.cmp(&0) { + Ordering::Equal => Ok(()), + _ => Err(Error::other("cancel failed")), + } + } } impl From for JoinHandle {