Skip to content

Commit

Permalink
Fix data race in send/recv (continue)
Browse files Browse the repository at this point in the history
The previous commit f31c466 was not complete, as the tests failed.

I had to rewrite the way we perform `send` to make it safe again.

The issue this time was mainly caused by inconsistent updates of the `lap` variable.

The writer updates the `lap` and only then starts writing the node; a reader can slip in between these two events and observe an incorrect lap value.
  • Loading branch information
Amjad50 committed Feb 22, 2024
1 parent f31c466 commit df01906
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 66 deletions.
152 changes: 87 additions & 65 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ use core::{
cell::{Cell, UnsafeCell},
cmp,
mem::MaybeUninit,
sync::atomic::AtomicU64,
};

#[cfg(loom)]
Expand All @@ -114,6 +115,21 @@ const STATE_WRITING: usize = 2;
const STATE_READING: usize = 4;
const READING_MASK: usize = usize::MAX & !(STATE_READING - 1);

/// Splits a packed head word into its `(lap, index)` pair.
///
/// The upper 32 bits of `index` carry the lap counter and the lower 32 bits
/// carry the slot index; this is the inverse of `pack_data_index`.
const fn unpack_data_index(index: u64) -> (usize, usize) {
    // Width of the slot-index field stored in the low half of the word.
    const INDEX_BITS: u32 = 32;
    // Mask selecting only the slot-index field.
    const INDEX_MASK: u64 = (1u64 << INDEX_BITS) - 1;

    (
        (index >> INDEX_BITS) as usize,
        (index & INDEX_MASK) as usize,
    )
}

/// Packs a lap counter and a slot index into a single `u64` head word.
///
/// The lap occupies the upper 32 bits and the index the lower 32 bits, so the
/// pair can be read and updated with one 64-bit atomic operation. This is the
/// inverse of `unpack_data_index`.
///
/// In debug builds this asserts that both `lap` and `index` fit in 32 bits.
/// In release builds an oversized `index` is masked to its low 32 bits and the
/// high bits of an oversized `lap` are shifted out.
const fn pack_data_index(lap: usize, index: usize) -> u64 {
    // Do the range checks in `u64`: on 32-bit targets `1usize << 32` is an
    // overflowing shift and is rejected at compile time, while every `usize`
    // value converts losslessly to `u64` on the supported pointer widths.
    debug_assert!(lap as u64 <= u32::MAX as u64);
    debug_assert!(index as u64 <= u32::MAX as u64);
    ((lap as u64) << 32) | ((index as u64) & 0xFFFF_FFFF)
}

#[inline]
fn is_reading(state: usize) -> bool {
state & READING_MASK != 0
Expand Down Expand Up @@ -148,13 +164,9 @@ struct ReaderData {

struct InnerChannel<T, const N: usize> {
buffer: Box<[Node<T>]>,
head: AtomicUsize,
producer_lap: AtomicUsize,
head: AtomicU64,
}

unsafe impl<T: Clone + Send, const N: usize> Send for InnerChannel<T, N> {}
unsafe impl<T: Clone + Send, const N: usize> Sync for InnerChannel<T, N> {}

impl<T: Clone, const N: usize> InnerChannel<T, N> {
fn new() -> Self {
let mut buffer = Vec::with_capacity(N);
Expand All @@ -164,70 +176,77 @@ impl<T: Clone, const N: usize> InnerChannel<T, N> {
let buffer = buffer.into_boxed_slice();
Self {
buffer,
head: AtomicUsize::new(0),
producer_lap: AtomicUsize::new(0),
head: AtomicU64::new(0),
}
}

fn push(&self, value: T) {
let mut current_head;
let mut next_head;
let mut node;
let mut should_drop;

let mut current_head = self.head.load(Ordering::Acquire);

// reserves a slot for writing not shared with other writers
loop {
current_head = self.head.load(Ordering::Acquire);
next_head = (current_head + 1) % self.buffer.len();

if self
.head
.compare_exchange_weak(
current_head,
next_head,
Ordering::Release,
Ordering::Relaxed,
)
.is_ok()
{
break;
}
}
let (producer_lap, producer_index) = unpack_data_index(current_head);

let current_lap = self.producer_lap.load(Ordering::Relaxed);
if next_head == 0 {
self.producer_lap.fetch_add(1, Ordering::Release);
}
node = &self.buffer[producer_index % self.buffer.len()];

let node = &self.buffer[current_head % self.buffer.len()];
// acquire the node
let mut state;
loop {
state = node.state.load(Ordering::Acquire);

let should_drop;
loop {
let mut state = node.state.load(Ordering::Acquire);
while is_reading(state) {
core::hint::spin_loop();
// wait until the reader is done
state = node.state.load(Ordering::Acquire);
}

while is_reading(state) {
core::hint::spin_loop();
// wait until the reader is done
state = node.state.load(Ordering::Acquire);
match state {
STATE_EMPTY | STATE_AVAILABLE => {
if node
.state
.compare_exchange_weak(
state,
STATE_WRITING,
Ordering::AcqRel,
Ordering::Relaxed,
)
.is_ok()
{
should_drop = state == STATE_AVAILABLE;
break;
}
}
STATE_WRITING => unreachable!("There should be no writer writing"),
s => unreachable!("Invalid state: {}", s),
}
}

match state {
STATE_EMPTY | STATE_AVAILABLE => {
if node
.state
.compare_exchange_weak(
state,
STATE_WRITING,
Ordering::Release,
Ordering::Relaxed,
)
.is_ok()
{
should_drop = state == STATE_AVAILABLE;
break;
}
let next_index = (producer_index + 1) % self.buffer.len();
let next_lap = if next_index == 0 {
producer_lap + 1
} else {
producer_lap
};
let next_head = pack_data_index(next_lap, next_index);

match self.head.compare_exchange(
current_head,
next_head,
Ordering::AcqRel,
Ordering::Relaxed,
) {
Ok(_) => {
node.lap.set(producer_lap);
break;
}
Err(x) => {
current_head = x;
}
STATE_WRITING => unreachable!("There should be no writer writing"),
s => unreachable!("Invalid state: {}", s),
}
// rollback and try again
node.state.store(state, Ordering::Release);
}

if should_drop {
Expand All @@ -241,53 +260,50 @@ impl<T: Clone, const N: usize> InnerChannel<T, N> {
unsafe {
node.data.get().write(MaybeUninit::new(value));
}
node.lap.set(current_lap);

// publish the value
// release the node
node.state.store(STATE_AVAILABLE, Ordering::Release);
}

fn pop(&self, reader: &mut ReaderData) -> Option<T> {
let current_head = self.head.load(Ordering::Acquire);
let (producer_lap, producer_index) = unpack_data_index(self.head.load(Ordering::Acquire));
let mut reader_index = reader.index;
let reader_lap = reader.lap;

let producer_lap = self.producer_lap.load(Ordering::Relaxed);
match reader_lap.cmp(&producer_lap) {
// the reader is before the writer
// so there must be something to read
cmp::Ordering::Less => {
let lap_diff = producer_lap - reader_lap;
let head_diff = current_head as isize - reader_index as isize;
let head_diff = producer_index as isize - reader_index as isize;

if (lap_diff > 0 && head_diff > 0) || lap_diff > 1 {
// there is an overflow
// we need to update the reader index
// we will take the latest readable value, the furthest from the writer
// and this is the value at [head, producer_lap - 1]
let new_index = current_head % self.buffer.len();
let new_index = producer_index % self.buffer.len();
reader_index = new_index;

reader.lap = producer_lap - 1;
}
}
cmp::Ordering::Equal => {
if reader_index >= current_head {
if reader_index >= producer_index {
return None;
}
}
cmp::Ordering::Greater => {
unreachable!("The reader is after the writer");
}
}

let mut node = &self.buffer[reader_index % self.buffer.len()];

// acquire the node
loop {
let state = node.state.load(Ordering::Acquire);

if is_readable(state) {
let old = node.state.fetch_add(STATE_READING, Ordering::Release);
let old = node.state.fetch_add(STATE_READING, Ordering::AcqRel);

if is_readable(old) {
break;
Expand Down Expand Up @@ -355,6 +371,9 @@ pub struct Sender<T, const N: usize> {
queue: Arc<InnerChannel<T, N>>,
}

// Manual `Send`/`Sync` opt-in for `Sender` whenever `T: Clone + Send`.
// NOTE(review): the invariant that makes these impls sound is not stated
// here. `Sender` only holds an `Arc<InnerChannel<T, N>>`, so presumably the
// argument rests on the node state protocol inside `InnerChannel`; confirm and
// record it as a `// SAFETY:` comment.
unsafe impl<T: Clone + Send, const N: usize> Send for Sender<T, N> {}
unsafe impl<T: Clone + Send, const N: usize> Sync for Sender<T, N> {}

impl<T: Clone, const N: usize> Sender<T, N> {
/// Sends a message to the channel.
/// If the channel is full, the oldest message will be overwritten.
Expand Down Expand Up @@ -406,6 +425,9 @@ pub struct Receiver<T, const N: usize> {
reader: ReaderData,
}

// Manual `Send`/`Sync` opt-in for `Receiver` whenever `T: Clone + Send`.
// NOTE(review): the invariant that makes these impls sound is not stated
// here. `Receiver` carries its own `ReaderData` cursor, so `Sync` in
// particular presumably depends on every method that touches that cursor
// taking `&mut self`; confirm and record it as a `// SAFETY:` comment.
unsafe impl<T: Clone + Send, const N: usize> Send for Receiver<T, N> {}
unsafe impl<T: Clone + Send, const N: usize> Sync for Receiver<T, N> {}

impl<T: Clone, const N: usize> Receiver<T, N> {
/// Receives a message from the channel.
///
Expand Down
2 changes: 1 addition & 1 deletion src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ fn test_sender_receiver_conflict() {

let barrier = Arc::new(std::sync::Barrier::new(2));

for _ in 0..10000 {
for _ in 0..10 {
// setup
// fill the channel
for i in 0..4 {
Expand Down

0 comments on commit df01906

Please sign in to comment.