From b219a823b994f68b5a54e7ef8a10d82b047b68ef Mon Sep 17 00:00:00 2001 From: Xinye Tao Date: Wed, 21 Sep 2022 17:59:03 +0800 Subject: [PATCH] Revert "tikv_util: introduce future channel (#13407)" (#13507) ref tikv/tikv#13394, ref tikv/tikv#13407 Signed-off-by: tabokie --- components/tikv_util/Cargo.toml | 5 - components/tikv_util/src/mpsc/batch.rs | 509 ++++++++++++++++++ components/tikv_util/src/mpsc/future.rs | 463 ---------------- components/tikv_util/src/mpsc/mod.rs | 2 +- src/server/service/batch.rs | 14 +- src/server/service/kv.rs | 24 +- tests/Cargo.toml | 5 + .../benches/channel/bench_channel.rs | 19 +- .../benches/channel/mod.rs | 0 9 files changed, 547 insertions(+), 494 deletions(-) create mode 100644 components/tikv_util/src/mpsc/batch.rs delete mode 100644 components/tikv_util/src/mpsc/future.rs rename {components/tikv_util => tests}/benches/channel/bench_channel.rs (87%) rename {components/tikv_util => tests}/benches/channel/mod.rs (100%) diff --git a/components/tikv_util/Cargo.toml b/components/tikv_util/Cargo.toml index 5b508a4a4d4..d8964cf0301 100644 --- a/components/tikv_util/Cargo.toml +++ b/components/tikv_util/Cargo.toml @@ -73,8 +73,3 @@ regex = "1.0" tempfile = "3.0" toml = "0.5" utime = "0.2" - -[[bench]] -name = "channel" -path = "benches/channel/mod.rs" -test = true diff --git a/components/tikv_util/src/mpsc/batch.rs b/components/tikv_util/src/mpsc/batch.rs new file mode 100644 index 00000000000..0415f9376af --- /dev/null +++ b/components/tikv_util/src/mpsc/batch.rs @@ -0,0 +1,509 @@ +// Copyright 2018 TiKV Project Authors. Licensed under Apache-2.0. + +use std::{ + pin::Pin, + ptr::null_mut, + sync::{ + atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering}, + Arc, + }, + time::Duration, +}; + +use crossbeam::channel::{ + self, RecvError, RecvTimeoutError, SendError, TryRecvError, TrySendError, +}; +use futures::{ + stream::Stream, + task::{Context, Poll, Waker}, +}; + +struct State { + // If the receiver can't get any messages temporarily in `poll` context, it will put its + // current task here. + recv_task: AtomicPtr, + notify_size: usize, + // How many messages are sent without notify. + pending: AtomicUsize, + notifier_registered: AtomicBool, +} + +impl State { + fn new(notify_size: usize) -> State { + State { + // Any pointer that is put into `recv_task` must be a valid and owned + // pointer (it must not be dropped). When a pointer is retrieved from + // `recv_task`, the user is responsible for its proper destruction. + recv_task: AtomicPtr::new(null_mut()), + notify_size, + pending: AtomicUsize::new(0), + notifier_registered: AtomicBool::new(false), + } + } + + #[inline] + fn try_notify_post_send(&self) { + let old_pending = self.pending.fetch_add(1, Ordering::AcqRel); + if old_pending >= self.notify_size - 1 { + self.notify(); + } + } + + #[inline] + fn notify(&self) { + let t = self.recv_task.swap(null_mut(), Ordering::AcqRel); + if !t.is_null() { + self.pending.store(0, Ordering::Release); + // Safety: see comment on `recv_task`. + let t = unsafe { Box::from_raw(t) }; + t.wake(); + } + } + + /// When the `Receiver` that holds the `State` is running on an `Executor`, + /// the `Receiver` calls this to yield from the current `poll` context, + /// and puts the current task handle to `recv_task`, so that the `Sender` + /// respectively can notify it after sending some messages into the channel. + #[inline] + fn yield_poll(&self, waker: Waker) -> bool { + let t = Box::into_raw(Box::new(waker)); + let origin = self.recv_task.swap(t, Ordering::AcqRel); + if !origin.is_null() { + // Safety: see comment on `recv_task`. + unsafe { drop(Box::from_raw(origin)) }; + return true; + } + false + } +} + +impl Drop for State { + fn drop(&mut self) { + let t = self.recv_task.swap(null_mut(), Ordering::AcqRel); + if !t.is_null() { + // Safety: see comment on `recv_task`. + unsafe { drop(Box::from_raw(t)) }; + } + } +} + +/// `Notifier` is used to notify receiver whenever you want. +pub struct Notifier(Arc); +impl Notifier { + #[inline] + pub fn notify(self) { + drop(self); + } +} + +impl Drop for Notifier { + #[inline] + fn drop(&mut self) { + let notifier_registered = &self.0.notifier_registered; + if notifier_registered + .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire) + .is_err() + { + unreachable!("notifier_registered must be true"); + } + self.0.notify(); + } +} + +pub struct Sender { + sender: Option>, + state: Arc, +} + +impl Clone for Sender { + #[inline] + fn clone(&self) -> Sender { + Sender { + sender: self.sender.clone(), + state: Arc::clone(&self.state), + } + } +} + +impl Drop for Sender { + #[inline] + fn drop(&mut self) { + drop(self.sender.take()); + self.state.notify(); + } +} + +pub struct Receiver { + receiver: channel::Receiver, + state: Arc, +} + +impl Sender { + pub fn is_empty(&self) -> bool { + // When there is no sender references, it can't be known whether + // it's empty or not. + self.sender.as_ref().map_or(false, |s| s.is_empty()) + } + + #[inline] + pub fn send(&self, t: T) -> Result<(), SendError> { + self.sender.as_ref().unwrap().send(t)?; + self.state.try_notify_post_send(); + Ok(()) + } + + #[inline] + pub fn send_and_notify(&self, t: T) -> Result<(), SendError> { + self.sender.as_ref().unwrap().send(t)?; + self.state.notify(); + Ok(()) + } + + #[inline] + pub fn try_send(&self, t: T) -> Result<(), TrySendError> { + self.sender.as_ref().unwrap().try_send(t)?; + self.state.try_notify_post_send(); + Ok(()) + } + + #[inline] + pub fn get_notifier(&self) -> Option { + let notifier_registered = &self.state.notifier_registered; + if notifier_registered + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) + .is_ok() + { + return Some(Notifier(Arc::clone(&self.state))); + } + None + } +} + +impl Receiver { + #[inline] + pub fn recv(&self) -> Result { + self.receiver.recv() + } + + #[inline] + pub fn try_recv(&self) -> Result { + self.receiver.try_recv() + } + + #[inline] + pub fn recv_timeout(&self, timeout: Duration) -> Result { + self.receiver.recv_timeout(timeout) + } +} + +/// Creates a unbounded channel with a given `notify_size`, which means if there +/// are more pending messages in the channel than `notify_size`, the `Sender` +/// will auto notify the `Receiver`. +/// +/// # Panics +/// if `notify_size` equals to 0. +#[inline] +pub fn unbounded(notify_size: usize) -> (Sender, Receiver) { + assert!(notify_size > 0); + let state = Arc::new(State::new(notify_size)); + let (sender, receiver) = channel::unbounded(); + ( + Sender { + sender: Some(sender), + state: state.clone(), + }, + Receiver { receiver, state }, + ) +} + +/// Creates a bounded channel with a given `notify_size`, which means if there +/// are more pending messages in the channel than `notify_size`, the `Sender` +/// will auto notify the `Receiver`. +/// +/// # Panics +/// if `notify_size` equals to 0. +#[inline] +pub fn bounded(cap: usize, notify_size: usize) -> (Sender, Receiver) { + assert!(notify_size > 0); + let state = Arc::new(State::new(notify_size)); + let (sender, receiver) = channel::bounded(cap); + ( + Sender { + sender: Some(sender), + state: state.clone(), + }, + Receiver { receiver, state }, + ) +} + +impl Stream for Receiver { + type Item = T; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.try_recv() { + Ok(m) => Poll::Ready(Some(m)), + Err(TryRecvError::Empty) => { + if self.state.yield_poll(cx.waker().clone()) { + Poll::Pending + } else { + // For the case that all senders are dropped before the current task is saved. + self.poll_next(cx) + } + } + Err(TryRecvError::Disconnected) => Poll::Ready(None), + } + } +} + +/// A Collector Used in `BatchReceiver`. +pub trait BatchCollector { + /// If `elem` is collected into `collection` successfully, return `None`. + /// Otherwise return `elem` back, and `collection` should be spilled out. + fn collect(&mut self, collection: &mut Collection, elem: Elem) -> Option; +} + +pub struct VecCollector; + +impl BatchCollector, E> for VecCollector { + fn collect(&mut self, v: &mut Vec, e: E) -> Option { + v.push(e); + None + } +} + +/// `BatchReceiver` is a `futures::Stream`, which returns a batched type. +pub struct BatchReceiver { + rx: Receiver, + max_batch_size: usize, + elem: Option, + initializer: I, + collector: C, +} + +impl BatchReceiver +where + T: Unpin, + E: Unpin, + I: Fn() -> E + Unpin, + C: BatchCollector + Unpin, +{ + /// Creates a new `BatchReceiver` with given `initializer` and `collector`. + /// `initializer` is used to generate a initial value, and `collector` + /// will collect every (at most `max_batch_size`) raw items into the + /// batched value. + pub fn new(rx: Receiver, max_batch_size: usize, initializer: I, collector: C) -> Self { + BatchReceiver { + rx, + max_batch_size, + elem: None, + initializer, + collector, + } + } +} + +impl Stream for BatchReceiver +where + T: Unpin, + E: Unpin, + I: Fn() -> E + Unpin, + C: BatchCollector + Unpin, +{ + type Item = E; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let ctx = self.get_mut(); + let (mut count, mut received) = (0, None); + let finished = loop { + match ctx.rx.try_recv() { + Ok(m) => { + let collection = ctx.elem.get_or_insert_with(&ctx.initializer); + if let Some(m) = ctx.collector.collect(collection, m) { + received = Some(m); + break false; + } + count += 1; + if count >= ctx.max_batch_size { + break false; + } + } + Err(TryRecvError::Disconnected) => break true, + Err(TryRecvError::Empty) => { + if ctx.rx.state.yield_poll(cx.waker().clone()) { + break false; + } + } + } + }; + + if ctx.elem.is_none() && finished { + return Poll::Ready(None); + } else if ctx.elem.is_none() { + return Poll::Pending; + } + let elem = ctx.elem.take(); + if let Some(m) = received { + let collection = ctx.elem.get_or_insert_with(&ctx.initializer); + let _received = ctx.collector.collect(collection, m); + debug_assert!(_received.is_none()); + } + Poll::Ready(elem) + } +} + +#[cfg(test)] +mod tests { + use std::{ + sync::{mpsc, Mutex}, + thread, time, + }; + + use futures::{ + future::{self, BoxFuture, FutureExt}, + stream::{self, StreamExt}, + task::{self, ArcWake, Poll}, + }; + use tokio::runtime::Builder; + + use super::*; + + #[test] + fn test_receiver() { + let (tx, rx) = unbounded::(4); + + let msg_counter = Arc::new(AtomicUsize::new(0)); + let msg_counter1 = Arc::clone(&msg_counter); + let pool = Builder::new_multi_thread() + .worker_threads(1) + .build() + .unwrap(); + let _res = pool.spawn(rx.for_each(move |_| { + msg_counter1.fetch_add(1, Ordering::AcqRel); + future::ready(()) + })); + + // Wait until the receiver is suspended. + loop { + thread::sleep(time::Duration::from_millis(10)); + if !tx.state.recv_task.load(Ordering::SeqCst).is_null() { + break; + } + } + + // Send without notify, the receiver can't get batched messages. + tx.send(0).unwrap(); + thread::sleep(time::Duration::from_millis(10)); + assert_eq!(msg_counter.load(Ordering::Acquire), 0); + + // Send with notify. + let notifier = tx.get_notifier().unwrap(); + assert!(tx.get_notifier().is_none()); + notifier.notify(); + thread::sleep(time::Duration::from_millis(10)); + assert_eq!(msg_counter.load(Ordering::Acquire), 1); + + // Auto notify with more sendings. + for _ in 0..4 { + tx.send(0).unwrap(); + } + thread::sleep(time::Duration::from_millis(10)); + assert_eq!(msg_counter.load(Ordering::Acquire), 5); + } + + #[test] + fn test_batch_receiver() { + let (tx, rx) = unbounded::(4); + + let rx = BatchReceiver::new(rx, 8, || Vec::with_capacity(4), VecCollector); + let msg_counter = Arc::new(AtomicUsize::new(0)); + let msg_counter_spawned = Arc::clone(&msg_counter); + let (nty, polled) = mpsc::sync_channel(1); + let pool = Builder::new_multi_thread() + .worker_threads(1) + .build() + .unwrap(); + let _res = pool.spawn( + stream::select( + rx, + stream::poll_fn(move |_| -> Poll>> { + nty.send(()).unwrap(); + Poll::Ready(None) + }), + ) + .for_each(move |v| { + let len = v.len(); + assert!(len <= 8); + msg_counter_spawned.fetch_add(len, Ordering::AcqRel); + future::ready(()) + }), + ); + + // Wait until the receiver has been polled in the spawned thread. + polled.recv().unwrap(); + + // Send without notify, the receiver can't get batched messages. + tx.send(0).unwrap(); + thread::sleep(time::Duration::from_millis(10)); + assert_eq!(msg_counter.load(Ordering::Acquire), 0); + + // Send with notify. + let notifier = tx.get_notifier().unwrap(); + assert!(tx.get_notifier().is_none()); + notifier.notify(); + thread::sleep(time::Duration::from_millis(10)); + assert_eq!(msg_counter.load(Ordering::Acquire), 1); + + // Auto notify with more sendings. + for _ in 0..16 { + tx.send(0).unwrap(); + } + thread::sleep(time::Duration::from_millis(10)); + assert_eq!(msg_counter.load(Ordering::Acquire), 17); + } + + #[test] + fn test_switch_between_sender_and_receiver() { + let (tx, mut rx) = unbounded::(4); + let future = async move { rx.next().await }; + let task = Task { + future: Arc::new(Mutex::new(Some(future.boxed()))), + }; + // Receiver has not received any messages, so the future is not be finished + // in this tick. + task.tick(); + assert!(task.future.lock().unwrap().is_some()); + // After sender is dropped, the task will be waked and then it tick self + // again to advance the progress. + drop(tx); + assert!(task.future.lock().unwrap().is_none()); + } + + #[derive(Clone)] + struct Task { + future: Arc>>>>, + } + + impl Task { + fn tick(&self) { + let task = Arc::new(self.clone()); + let mut future_slot = self.future.lock().unwrap(); + if let Some(mut future) = future_slot.take() { + let waker = task::waker_ref(&task); + let cx = &mut Context::from_waker(&waker); + match future.as_mut().poll(cx) { + Poll::Pending => { + *future_slot = Some(future); + } + Poll::Ready(None) => {} + _ => unimplemented!(), + } + } + } + } + + impl ArcWake for Task { + fn wake_by_ref(arc_self: &Arc) { + arc_self.tick(); + } + } +} diff --git a/components/tikv_util/src/mpsc/future.rs b/components/tikv_util/src/mpsc/future.rs deleted file mode 100644 index f5bd3071c65..00000000000 --- a/components/tikv_util/src/mpsc/future.rs +++ /dev/null @@ -1,463 +0,0 @@ -// Copyright 2022 TiKV Project Authors. Licensed under Apache-2.0. - -//! A module provides the implementation of receiver that supports async/await. - -use std::{ - pin::Pin, - ptr, - sync::atomic::{self, AtomicPtr, AtomicUsize, Ordering}, - task::{Context, Poll, Waker}, -}; - -use crossbeam::{ - channel::{SendError, TryRecvError}, - queue::SegQueue, -}; -use futures::{Stream, StreamExt}; - -#[derive(Clone, Copy)] -pub enum WakePolicy { - Immediately, - TillReach(usize), -} - -struct Queue { - queue: SegQueue, - waker: AtomicPtr, - liveness: AtomicUsize, - policy: WakePolicy, -} - -impl Queue { - #[inline] - fn wake(&self, policy: WakePolicy) { - if let WakePolicy::TillReach(n) = policy { - if self.queue.len() < n { - return; - } - } - let ptr = self.waker.swap(ptr::null_mut(), Ordering::AcqRel); - unsafe { - if !ptr.is_null() { - Box::from_raw(ptr).wake(); - } - } - } - - // If there is already a waker, true is returned. - fn register_waker(&self, waker: &Waker) -> bool { - let w = Box::new(waker.clone()); - let ptr = self.waker.swap(Box::into_raw(w), Ordering::AcqRel); - unsafe { - if ptr.is_null() { - false - } else { - drop(Box::from_raw(ptr)); - true - } - } - } -} - -impl Drop for Queue { - #[inline] - fn drop(&mut self) { - let ptr = self.waker.swap(ptr::null_mut(), Ordering::SeqCst); - unsafe { - if !ptr.is_null() { - drop(Box::from_raw(ptr)); - } - } - } -} - -const SENDER_COUNT_BASE: usize = 1 << 1; -const RECEIVER_COUNT_BASE: usize = 1; - -pub struct Sender { - queue: *mut Queue, -} - -impl Sender { - /// Sends the message with predefined wake policy. - #[inline] - pub fn send(&self, t: T) -> Result<(), SendError> { - let policy = unsafe { (*self.queue).policy }; - self.send_with(t, policy) - } - - /// Sends the message with the specified wake policy. - #[inline] - pub fn send_with(&self, t: T, policy: WakePolicy) -> Result<(), SendError> { - let queue = unsafe { &*self.queue }; - if queue.liveness.load(Ordering::Acquire) & RECEIVER_COUNT_BASE != 0 { - queue.queue.push(t); - queue.wake(policy); - return Ok(()); - } - Err(SendError(t)) - } -} - -impl Clone for Sender { - fn clone(&self) -> Self { - let queue = unsafe { &*self.queue }; - queue - .liveness - .fetch_add(SENDER_COUNT_BASE, Ordering::Relaxed); - Self { queue: self.queue } - } -} - -impl Drop for Sender { - #[inline] - fn drop(&mut self) { - let queue = unsafe { &*self.queue }; - let previous = queue - .liveness - .fetch_sub(SENDER_COUNT_BASE, Ordering::Release); - if previous == SENDER_COUNT_BASE | RECEIVER_COUNT_BASE { - // The last sender is dropped, we need to wake up the receiver. - queue.wake(WakePolicy::Immediately); - } else if previous == SENDER_COUNT_BASE { - atomic::fence(Ordering::Acquire); - drop(unsafe { Box::from_raw(self.queue) }); - } - } -} - -unsafe impl Send for Sender {} -unsafe impl Sync for Sender {} - -pub struct Receiver { - queue: *mut Queue, -} - -impl Stream for Receiver { - type Item = T; - - #[inline] - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let queue = unsafe { &*self.queue }; - if let Some(t) = queue.queue.pop() { - return Poll::Ready(Some(t)); - } - // If there is no previous waker, we still need to poll again in case some - // task is pushed before registering current waker. - if !queue.register_waker(cx.waker()) { - // In case the message is pushed right before registering waker. - if let Some(t) = queue.queue.pop() { - return Poll::Ready(Some(t)); - } - } - if queue.liveness.load(Ordering::Acquire) & !RECEIVER_COUNT_BASE != 0 { - return Poll::Pending; - } - Poll::Ready(None) - } -} - -impl Receiver { - #[inline] - pub fn try_recv(&mut self) -> Result { - let queue = unsafe { &*self.queue }; - if let Some(t) = queue.queue.pop() { - return Ok(t); - } - if queue.liveness.load(Ordering::Acquire) & !RECEIVER_COUNT_BASE != 0 { - return Err(TryRecvError::Empty); - } - Err(TryRecvError::Disconnected) - } -} - -impl Drop for Receiver { - #[inline] - fn drop(&mut self) { - let queue = unsafe { &*self.queue }; - if RECEIVER_COUNT_BASE - == queue - .liveness - .fetch_sub(RECEIVER_COUNT_BASE, Ordering::Release) - { - atomic::fence(Ordering::Acquire); - drop(unsafe { Box::from_raw(self.queue) }); - } - } -} - -unsafe impl Send for Receiver {} - -pub fn unbounded(policy: WakePolicy) -> (Sender, Receiver) { - let queue = Box::into_raw(Box::new(Queue { - queue: SegQueue::new(), - waker: AtomicPtr::default(), - liveness: AtomicUsize::new(SENDER_COUNT_BASE | RECEIVER_COUNT_BASE), - policy, - })); - (Sender { queue }, Receiver { queue }) -} - -/// `BatchReceiver` is a `futures::Stream`, which returns a batched type. -pub struct BatchReceiver { - rx: Receiver, - max_batch_size: usize, - initializer: I, - collector: C, -} - -impl BatchReceiver { - /// Creates a new `BatchReceiver` with given `initializer` and `collector`. - /// `initializer` is used to generate a initial value, and `collector` - /// will collect every (at most `max_batch_size`) raw items into the - /// batched value. - pub fn new(rx: Receiver, max_batch_size: usize, initializer: I, collector: C) -> Self { - BatchReceiver { - rx, - max_batch_size, - initializer, - collector, - } - } -} - -impl Stream for BatchReceiver -where - T: Send + Unpin, - E: Unpin, - I: Fn() -> E + Unpin, - C: FnMut(&mut E, T) + Unpin, -{ - type Item = E; - - #[inline] - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let ctx = self.get_mut(); - let mut collector = match ctx.rx.poll_next_unpin(cx) { - Poll::Ready(Some(m)) => { - let mut c = (ctx.initializer)(); - (ctx.collector)(&mut c, m); - c - } - Poll::Ready(None) => return Poll::Ready(None), - Poll::Pending => return Poll::Pending, - }; - for _ in 1..ctx.max_batch_size { - if let Poll::Ready(Some(m)) = ctx.rx.poll_next_unpin(cx) { - (ctx.collector)(&mut collector, m); - } - } - Poll::Ready(Some(collector)) - } -} - -#[cfg(test)] -mod tests { - use std::{ - sync::{ - atomic::{AtomicBool, AtomicUsize}, - mpsc, Arc, Mutex, - }, - thread, time, - }; - - use futures::{ - future::{self, BoxFuture, FutureExt}, - stream::{self, StreamExt}, - task::{self, ArcWake, Poll}, - }; - use tokio::runtime::{Builder, Runtime}; - - use super::*; - - fn spawn_and_wait( - rx_builder: impl FnOnce() -> S, - ) -> (Runtime, Arc) { - let msg_counter = Arc::new(AtomicUsize::new(0)); - let msg_counter1 = msg_counter.clone(); - let pool = Builder::new_multi_thread() - .worker_threads(1) - .build() - .unwrap(); - let (nty, polled) = mpsc::sync_channel(1); - _ = pool.spawn( - stream::select( - rx_builder(), - stream::poll_fn(move |_| -> Poll> { - nty.send(()).unwrap(); - Poll::Ready(None) - }), - ) - .for_each(move |_| { - msg_counter1.fetch_add(1, Ordering::AcqRel); - future::ready(()) - }), - ); - - // Wait until the receiver has been polled in the spawned thread. - polled.recv().unwrap(); - (pool, msg_counter) - } - - #[test] - fn test_till_reach_wake() { - let (tx, rx) = unbounded::(WakePolicy::TillReach(4)); - - let (_pool, msg_counter) = spawn_and_wait(move || rx); - - // Receiver should not be woken up until its length reach specified value. - for _ in 0..3 { - tx.send(0).unwrap(); - } - thread::sleep(time::Duration::from_millis(10)); - assert_eq!(msg_counter.load(Ordering::Acquire), 0); - - tx.send(0).unwrap(); - thread::sleep(time::Duration::from_millis(10)); - assert_eq!(msg_counter.load(Ordering::Acquire), 4); - - // Should start new batch. - tx.send(0).unwrap(); - thread::sleep(time::Duration::from_millis(10)); - assert_eq!(msg_counter.load(Ordering::Acquire), 4); - - let tx1 = tx.clone(); - drop(tx); - thread::sleep(time::Duration::from_millis(10)); - assert_eq!(msg_counter.load(Ordering::Acquire), 4); - // If all senders are dropped, receiver should be woken up. - drop(tx1); - thread::sleep(time::Duration::from_millis(10)); - assert_eq!(msg_counter.load(Ordering::Acquire), 5); - } - - #[test] - fn test_immediately_wake() { - let (tx, rx) = unbounded::(WakePolicy::Immediately); - - let (_pool, msg_counter) = spawn_and_wait(move || rx); - - // Receiver should be woken up immediately. - for _ in 0..3 { - tx.send(0).unwrap(); - } - thread::sleep(time::Duration::from_millis(10)); - assert_eq!(msg_counter.load(Ordering::Acquire), 3); - - tx.send(0).unwrap(); - thread::sleep(time::Duration::from_millis(10)); - assert_eq!(msg_counter.load(Ordering::Acquire), 4); - } - - #[test] - fn test_batch_receiver() { - let (tx, rx) = unbounded::(WakePolicy::TillReach(4)); - - let len = Arc::new(AtomicUsize::new(0)); - let l = len.clone(); - let rx = BatchReceiver::new(rx, 8, || Vec::with_capacity(4), Vec::push); - let (_pool, msg_counter) = spawn_and_wait(move || { - stream::unfold((rx, l), |(mut rx, l)| async move { - rx.next().await.map(|i| { - l.fetch_add(i.len(), Ordering::SeqCst); - (i, (rx, l)) - }) - }) - }); - - tx.send(0).unwrap(); - thread::sleep(time::Duration::from_millis(10)); - assert_eq!(msg_counter.load(Ordering::SeqCst), 0); - - // Auto notify with more messages. - for _ in 0..16 { - tx.send(0).unwrap(); - } - thread::sleep(time::Duration::from_millis(10)); - let batch_count = msg_counter.load(Ordering::SeqCst); - assert!(batch_count < 17, "{}", batch_count); - assert_eq!(len.load(Ordering::SeqCst), 17); - } - - #[test] - fn test_switch_between_sender_and_receiver() { - let (tx, mut rx) = unbounded::(WakePolicy::TillReach(4)); - let future = async move { rx.next().await }; - let task = Task { - future: Arc::new(Mutex::new(Some(future.boxed()))), - }; - // Receiver has not received any messages, so the future is not be finished - // in this tick. - task.tick(); - assert!(task.future.lock().unwrap().is_some()); - // After sender is dropped, the task will be waked and then it tick self - // again to advance the progress. - drop(tx); - assert!(task.future.lock().unwrap().is_none()); - } - - #[derive(Clone)] - struct Task { - future: Arc>>>>, - } - - impl Task { - fn tick(&self) { - let task = Arc::new(self.clone()); - let mut future_slot = self.future.lock().unwrap(); - if let Some(mut future) = future_slot.take() { - let waker = task::waker_ref(&task); - let cx = &mut Context::from_waker(&waker); - match future.as_mut().poll(cx) { - Poll::Pending => { - *future_slot = Some(future); - } - Poll::Ready(None) => {} - _ => unimplemented!(), - } - } - } - } - - impl ArcWake for Task { - fn wake_by_ref(arc_self: &Arc) { - arc_self.tick(); - } - } - - #[derive(Default)] - struct SetOnDrop(Arc); - - impl Drop for SetOnDrop { - fn drop(&mut self) { - self.0.store(true, Ordering::Release); - } - } - - #[test] - fn test_drop() { - let dropped = Arc::new(AtomicBool::new(false)); - let (tx, rx) = super::unbounded(WakePolicy::Immediately); - tx.send(SetOnDrop(dropped.clone())).unwrap(); - drop(tx); - assert!(!dropped.load(Ordering::SeqCst)); - - drop(rx); - assert!(dropped.load(Ordering::SeqCst)); - - let dropped = Arc::new(AtomicBool::new(false)); - let (tx, rx) = super::unbounded(WakePolicy::Immediately); - tx.send(SetOnDrop(dropped.clone())).unwrap(); - drop(rx); - assert!(!dropped.load(Ordering::SeqCst)); - - tx.send(SetOnDrop::default()).unwrap_err(); - let tx1 = tx.clone(); - drop(tx); - assert!(!dropped.load(Ordering::SeqCst)); - - tx1.send(SetOnDrop::default()).unwrap_err(); - drop(tx1); - assert!(dropped.load(Ordering::SeqCst)); - } -} diff --git a/components/tikv_util/src/mpsc/mod.rs b/components/tikv_util/src/mpsc/mod.rs index 45249fed9bc..ccec5448d0b 100644 --- a/components/tikv_util/src/mpsc/mod.rs +++ b/components/tikv_util/src/mpsc/mod.rs @@ -3,7 +3,7 @@ //! This module provides an implementation of mpsc channel based on //! crossbeam_channel. Comparing to the crossbeam_channel, this implementation //! supports closed detection and try operations. -pub mod future; +pub mod batch; use std::{ cell::Cell, diff --git a/src/server/service/batch.rs b/src/server/service/batch.rs index ba377bed4d2..15a755c3468 100644 --- a/src/server/service/batch.rs +++ b/src/server/service/batch.rs @@ -3,11 +3,7 @@ // #[PerformanceCriticalPath] use api_version::KvFormat; use kvproto::kvrpcpb::*; -use tikv_util::{ - future::poll_future_notify, - mpsc::future::{Sender, WakePolicy}, - time::Instant, -}; +use tikv_util::{future::poll_future_notify, mpsc::batch::Sender, time::Instant}; use tracker::{with_tls_tracker, RequestInfo, RequestType, Tracker, TrackerToken, GLOBAL_TRACKERS}; use crate::{ @@ -188,7 +184,7 @@ impl ResponseBatchConsumer<(Option>, Statistics)> for GetCommandResponse let mesure = GrpcRequestDuration::new(begin, GrpcTypeKind::kv_batch_get_command, request_source); let task = MeasuredSingleResponse::new(id, res, mesure); - if self.tx.send_with(task, WakePolicy::Immediately).is_err() { + if self.tx.send_and_notify(task).is_err() { error!("KvService response batch commands fail"); } } @@ -219,7 +215,7 @@ impl ResponseBatchConsumer>> for GetCommandResponseConsumer { let mesure = GrpcRequestDuration::new(begin, GrpcTypeKind::raw_batch_get_command, request_source); let task = MeasuredSingleResponse::new(id, res, mesure); - if self.tx.send_with(task, WakePolicy::Immediately).is_err() { + if self.tx.send_and_notify(task).is_err() { error!("KvService response batch commands fail"); } } @@ -268,7 +264,7 @@ fn future_batch_get_command( source, ); let task = MeasuredSingleResponse::new(id, res, measure); - if tx.send_with(task, WakePolicy::Immediately).is_err() { + if tx.send_and_notify(task).is_err() { error!("KvService response batch commands fail"); } } @@ -314,7 +310,7 @@ fn future_batch_raw_get_command( source, ); let task = MeasuredSingleResponse::new(id, res, measure); - if tx.send_with(task, WakePolicy::Immediately).is_err() { + if tx.send_and_notify(task).is_err() { error!("KvService response batch commands fail"); } } diff --git a/src/server/service/kv.rs b/src/server/service/kv.rs index 35deb7e4107..ab2fc41c47c 100644 --- a/src/server/service/kv.rs +++ b/src/server/service/kv.rs @@ -39,7 +39,7 @@ use raftstore::{ use tikv_alloc::trace::MemoryTraceGuard; use tikv_util::{ future::{paired_future_callback, poll_future_notify}, - mpsc::future::{unbounded, BatchReceiver, Sender, WakePolicy}, + mpsc::batch::{unbounded, BatchCollector, BatchReceiver, Sender}, sys::memory_usage_reaches_high_water, time::{duration_to_ms, duration_to_sec, Instant}, worker::Scheduler, @@ -1049,7 +1049,7 @@ impl + 'static, E: Engine, L: LockManager, F: KvFor mut sink: DuplexSink, ) { forward_duplex!(self.proxy, batch_commands, ctx, stream, sink); - let (tx, rx) = unbounded(WakePolicy::TillReach(GRPC_MSG_NOTIFY_SIZE)); + let (tx, rx) = unbounded(GRPC_MSG_NOTIFY_SIZE); let ctx = Arc::new(ctx); let peer = ctx.peer(); @@ -1093,7 +1093,7 @@ impl + 'static, E: Engine, L: LockManager, F: KvFor rx, GRPC_MSG_MAX_BATCH_SIZE, MeasuredBatchResponse::default, - collect_batch_resp, + BatchRespCollector, ); let mut response_retriever = response_retriever.map(move |mut item| { @@ -1268,7 +1268,7 @@ fn response_batch_commands_request( source, }; let task = MeasuredSingleResponse::new(id, resp, measure); - if let Err(e) = tx.send_with(task, WakePolicy::Immediately) { + if let Err(e) = tx.send_and_notify(task) { error!("KvService response batch commands fail"; "err" => ?e); } } @@ -2354,10 +2354,18 @@ impl Default for MeasuredBatchResponse { } } -fn collect_batch_resp(v: &mut MeasuredBatchResponse, mut e: MeasuredSingleResponse) { - v.batch_resp.mut_request_ids().push(e.id); - v.batch_resp.mut_responses().push(e.resp.consume()); - v.measures.push(e.measure); +struct BatchRespCollector; +impl BatchCollector for BatchRespCollector { + fn collect( + &mut self, + v: &mut MeasuredBatchResponse, + mut e: MeasuredSingleResponse, + ) -> Option { + v.batch_resp.mut_request_ids().push(e.id); + v.batch_resp.mut_responses().push(e.resp.consume()); + v.measures.push(e.measure); + None + } } fn raftstore_error_to_region_error(e: RaftStoreError, region_id: u64) -> RegionError { diff --git a/tests/Cargo.toml b/tests/Cargo.toml index 5c573b6e809..b155ae4ab87 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -39,6 +39,11 @@ name = "deadlock_detector" harness = false path = "benches/deadlock_detector/mod.rs" +[[bench]] +name = "channel" +path = "benches/channel/mod.rs" +test = true + [features] default = ["failpoints", "testexport", "test-engine-kv-rocksdb", "test-engine-raft-raft-engine", "cloud-aws", "cloud-gcp", "cloud-azure"] failpoints = ["fail/failpoints", "tikv/failpoints"] diff --git a/components/tikv_util/benches/channel/bench_channel.rs b/tests/benches/channel/bench_channel.rs similarity index 87% rename from components/tikv_util/benches/channel/bench_channel.rs rename to tests/benches/channel/bench_channel.rs index 6867aab0f56..eb69412046d 100644 --- a/components/tikv_util/benches/channel/bench_channel.rs +++ b/tests/benches/channel/bench_channel.rs @@ -113,8 +113,8 @@ fn bench_crossbeam_channel(b: &mut Bencher) { } #[bench] -fn bench_receiver_stream_unbounded_batch(b: &mut Bencher) { - let (tx, rx) = mpsc::future::unbounded::(mpsc::future::WakePolicy::TillReach(8)); +fn bench_receiver_stream_batch(b: &mut Bencher) { + let (tx, rx) = mpsc::batch::bounded::(128, 8); for _ in 0..1 { let tx1 = tx.clone(); thread::spawn(move || { @@ -124,9 +124,12 @@ fn bench_receiver_stream_unbounded_batch(b: &mut Bencher) { }); } - let rx = mpsc::future::BatchReceiver::new(rx, 32, Vec::new, Vec::push); - - let mut rx = Some(block_on(rx.into_future()).1); + let mut rx = Some(mpsc::batch::BatchReceiver::new( + rx, + 32, + Vec::new, + mpsc::batch::VecCollector, + )); b.iter(|| { let mut count = 0; @@ -147,8 +150,8 @@ fn bench_receiver_stream_unbounded_batch(b: &mut Bencher) { } #[bench] -fn bench_receiver_stream_unbounded_nobatch(b: &mut Bencher) { - let (tx, rx) = mpsc::future::unbounded::(mpsc::future::WakePolicy::Immediately); +fn bench_receiver_stream(b: &mut Bencher) { + let (tx, rx) = mpsc::batch::bounded::(128, 1); for _ in 0..1 { let tx1 = tx.clone(); thread::spawn(move || { @@ -158,7 +161,7 @@ fn bench_receiver_stream_unbounded_nobatch(b: &mut Bencher) { }); } - let mut rx = Some(block_on(rx.into_future()).1); + let mut rx = Some(rx); b.iter(|| { let mut count = 0; let mut rx1 = rx.take().unwrap(); diff --git a/components/tikv_util/benches/channel/mod.rs b/tests/benches/channel/mod.rs similarity index 100% rename from components/tikv_util/benches/channel/mod.rs rename to tests/benches/channel/mod.rs