Skip to content

Commit

Permalink
Remove const generic arguments on Channel
Browse files Browse the repository at this point in the history
  • Loading branch information
AldaronLau committed Aug 6, 2022
1 parent 6987e6e commit 7518cf8
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 105 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ and this project adheres to [Semantic Versioning](https://github.com/AldaronLau/

### Removed
- `Message`
- Const generic arguments on `Channel`

### Fixed
- Bug with wakers when using MPMC functionality that could possibly trigger UB

## [0.4.1] - 2022-07-23
### Fixed
Expand Down
182 changes: 77 additions & 105 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,13 @@

extern crate alloc;

use alloc::sync::{self, Arc};
use alloc::{
sync::{self, Arc},
vec::Vec,
};
use core::{
cell::UnsafeCell,
future::Future,
mem::MaybeUninit,
pin::Pin,
sync::atomic::{
self, AtomicBool,
Expand All @@ -109,43 +111,53 @@ use core::{
};

#[allow(unsafe_code)]
mod list {
mod wake {
use super::*;

/// A list
#[derive(Debug)]
pub(super) struct List<T, const N: usize> {
data: [MaybeUninit<T>; N],
size: usize,
/// Type for waking on send or receive
#[derive(Debug, Default)]
pub(super) struct Wake {
/// Channel unique identifier (the arc pointer casted to usize)
chan: usize,
/// Channel waker
wake: Option<Waker>,
/// Heap wakers
list: Vec<(usize, Waker)>,
}

impl<T, const N: usize> List<T, N> {
#[inline]
pub(super) fn new() -> Self {
let data = unsafe {
MaybeUninit::<[MaybeUninit<T>; N]>::uninit().assume_init()
};
let size = 0;

Self { data, size }
}

impl Wake {
/// Register a waker for a channel
#[inline(always)]
pub(super) fn push(&mut self, item: T, idx: usize) -> usize {
let idx = if idx == usize::MAX { self.size } else { idx };
self.data[idx] = MaybeUninit::new(item);
self.size += 1;
idx
pub(super) fn register(&mut self, chan: usize, waker: Waker) {
if self.list.is_empty() {
if let Some(wake) = self.wake.take() {
if self.chan == chan {
(self.chan, self.wake) = (chan, Some(waker));
} else {
self.list.extend([(self.chan, wake), (chan, waker)]);
}
} else {
(self.chan, self.wake) = (chan, Some(waker));
}
} else {
if let Some(wake) = self.list.iter_mut().find(|w| w.0 == chan) {
wake.1 = waker;
} else {
self.list.push((chan, waker));
}
}
}

/// Wake all channels and de-register all wakers
#[inline(always)]
pub(super) fn drain(&mut self) -> impl Iterator<Item = T> + '_ {
let mut size = 0;
(size, self.size) = (self.size, size);
self.data
.iter()
.take(size)
.map(|t| unsafe { t.assume_init_read() })
pub(super) fn wake(&mut self) {
if let Some(waker) = self.wake.take() {
waker.wake();
return;
}
for waker in self.list.drain(..) {
waker.1.wake();
}
}
}
}
Expand Down Expand Up @@ -183,87 +195,70 @@ mod spin {
}

#[derive(Debug)]
struct Locked<T: Send, const S: usize, const R: usize> {
struct Locked<T: Send> {
/// Receive wakers
recv: list::List<Waker, R>,
recv: wake::Wake,
/// Send wakers
send: list::List<Waker, S>,
send: wake::Wake,
/// Data in transit
data: Option<T>,
}

impl<T: Send, const S: usize, const R: usize> Default for Locked<T, S, R> {
impl<T: Send> Default for Locked<T> {
#[inline]
fn default() -> Self {
let data = None;
let send = list::List::new();
let recv = list::List::new();
let send = wake::Wake::default();
let recv = wake::Wake::default();

Self { data, send, recv }
}
}

#[derive(Debug, Default)]
struct Shared<T: Send, const S: usize, const R: usize> {
spin: spin::Spin<Locked<T, S, R>>,
struct Shared<T: Send> {
spin: spin::Spin<Locked<T>>,
}

/// A `Channel` notifies when another `Channel` sends a message.
///
/// Implemented as a multi-producer/multi-consumer queue of size 1.
///
/// Const generic `S` is the upper bound on the number of channels that can be
/// sending at once (doesn't include inactive channels).
///
/// Const generic `R` is the upper bound on the number of channels that can be
/// receiving at once (doesn't include inactive channels).
///
/// Enable the **`futures-core`** feature for `Channel` to implement
/// [`Stream`](futures_core::Stream) (generic `T` must be `Option<Item>`).
///
/// Enable the **`pasts`** feature for `Channel` to implement
/// [`Notifier`](pasts::Notifier).
#[derive(Debug)]
pub struct Channel<T: Send + Unpin, const S: usize = 1, const R: usize = 1>(
Arc<Shared<T, S, R>>,
usize,
);

impl<T, const S: usize, const R: usize> Clone for Channel<T, S, R>
where
T: Send + Unpin,
{
pub struct Channel<T: Send + Unpin>(Arc<Shared<T>>);

impl<T: Send + Unpin> Clone for Channel<T> {
#[inline]
fn clone(&self) -> Self {
Channel(Arc::clone(&self.0), usize::MAX)
Self(Arc::clone(&self.0))
}
}

impl<T, const S: usize, const R: usize> Default for Channel<T, S, R>
where
T: Send + Unpin,
{
impl<T: Send + Unpin> Default for Channel<T> {
#[inline]
fn default() -> Self {
Self::new()
}
}

impl<T: Send + Unpin, const S: usize, const R: usize> Channel<T, S, R> {
impl<T: Send + Unpin> Channel<T> {
/// Create a new channel.
#[inline]
pub fn new() -> Self {
let spin = spin::Spin::default();

Self(Arc::new(Shared { spin }), usize::MAX)
Self(Arc::new(Shared { spin }))
}

/// Send a message on this channel.
#[inline(always)]
pub fn send(&self, message: T) -> impl Future<Output = ()> + Send + Unpin {
let mut chan = (*self).clone();
chan.1 = usize::MAX;
Message(chan, Some(message))
Message((*self).clone(), Some(message))
}

/// Receive a message from this channel.
Expand All @@ -276,40 +271,33 @@ impl<T: Send + Unpin, const S: usize, const R: usize> Channel<T, S, R> {

/// Create a new corresponding [`Weak`] channel.
#[inline]
pub fn downgrade(&self) -> Weak<T, S, R> {
pub fn downgrade(&self) -> Weak<T> {
Weak(Arc::downgrade(&self.0))
}
}

impl<T, const S: usize, const R: usize> Future for Channel<T, S, R>
where
T: Send + Unpin,
{
impl<T: Send + Unpin> Future for Channel<T> {
type Output = T;

#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
let waker = cx.waker();
let uid = Arc::as_ptr(&this.0) as usize;
this.0.spin.with(|shared| {
if let Some(output) = shared.data.take() {
for waker in shared.send.drain() {
waker.wake();
}
shared.send.wake();
Ready(output)
} else {
this.1 = shared.recv.push(waker.clone(), this.1);
shared.recv.register(uid, waker.clone());
Pending
}
})
}
}

#[cfg(feature = "pasts")]
impl<T, const S: usize, const R: usize> pasts::Notifier for Channel<T, S, R>
where
T: Send + Unpin,
{
impl<T: Send + Unpin> pasts::Notifier for Channel<T> {
type Event = T;

#[inline(always)]
Expand All @@ -319,11 +307,7 @@ where
}

#[cfg(feature = "futures-core")]
impl<T, const S: usize, const R: usize> futures_core::Stream
for Channel<Option<T>, S, R>
where
T: Send + Unpin,
{
impl<T: Send + Unpin> futures_core::Stream for Channel<Option<T>> {
type Item = T;

#[inline(always)]
Expand All @@ -337,11 +321,9 @@ where

/// A weak version of a `Channel`.
#[derive(Debug, Default)]
pub struct Weak<T: Send + Unpin, const S: usize = 1, const R: usize = 1>(
sync::Weak<Shared<T, S, R>>,
);
pub struct Weak<T: Send + Unpin>(sync::Weak<Shared<T>>);

impl<T: Send + Unpin, const S: usize, const R: usize> Weak<T, S, R> {
impl<T: Send + Unpin> Weak<T> {
/// Calling `upgrade()` will always return `None`.
#[inline]
pub fn new() -> Self {
Expand All @@ -350,47 +332,37 @@ impl<T: Send + Unpin, const S: usize, const R: usize> Weak<T, S, R> {

/// Attempt to upgrade the Weak channel to a [`Channel`].
#[inline]
pub fn upgrade(&self) -> Option<Channel<T, S, R>> {
Some(Channel(self.0.upgrade()?, usize::MAX))
pub fn upgrade(&self) -> Option<Channel<T>> {
Some(Channel(self.0.upgrade()?))
}
}

/// A message in the process of being sent over a [`Channel`].
#[derive(Debug)]
struct Message<T: Send + Unpin, const S: usize, const R: usize>(
Channel<T, S, R>,
Option<T>,
);

impl<T, const S: usize, const R: usize> Future for Message<T, S, R>
where
T: Send + Unpin,
{
struct Message<T: Send + Unpin>(Channel<T>, Option<T>);

impl<T: Send + Unpin> Future for Message<T> {
type Output = ();

#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
let waker = cx.waker();
let uid = Arc::as_ptr(&this.0 .0) as usize;
this.0 .0.spin.with(|shared| {
if shared.data.is_none() {
shared.data = this.1.take();
for waker in shared.recv.drain() {
waker.wake();
}
shared.recv.wake();
Ready(())
} else {
this.0 .1 = shared.send.push(waker.clone(), this.0 .1);
shared.send.register(uid, waker.clone());
Pending
}
})
}
}

impl<T, const S: usize, const R: usize> Drop for Message<T, S, R>
where
T: Send + Unpin,
{
impl<T: Send + Unpin> Drop for Message<T> {
fn drop(&mut self) {
if self.1.is_some() {
panic!("Message dropped without sending");
Expand Down

0 comments on commit 7518cf8

Please sign in to comment.