Skip to content

Commit

Permalink
Auto merge of rust-lang#2638 - DrMeepster:windows-condvars, r=RalfJung
Browse files Browse the repository at this point in the history
Implement condvars for Windows

Adds 3 shims for Windows: `SleepConditionVariableSRW`, `WakeConditionVariable`, `WakeAllConditionVariable` to add support for condvars (which fixes rust-lang#2628).

Salvaged from what was removed from rust-lang#2231
  • Loading branch information
bors committed Nov 6, 2022
2 parents 4a3ed29 + 958ca31 commit bf0f47b
Show file tree
Hide file tree
Showing 9 changed files with 504 additions and 30 deletions.
24 changes: 18 additions & 6 deletions src/concurrency/sync.rs
Expand Up @@ -116,13 +116,25 @@ struct RwLock {

declare_id!(CondvarId);

#[derive(Debug, Copy, Clone)]
pub enum RwLockMode {
Read,
Write,
}

#[derive(Debug)]
pub enum CondvarLock {
Mutex(MutexId),
RwLock { id: RwLockId, mode: RwLockMode },
}

/// A thread waiting on a conditional variable.
#[derive(Debug)]
struct CondvarWaiter {
/// The thread that is waiting on this variable.
thread: ThreadId,
/// The mutex on which the thread is waiting.
mutex: MutexId,
/// The mutex or rwlock on which the thread is waiting.
lock: CondvarLock,
}

/// The conditional variable state.
Expand Down Expand Up @@ -569,16 +581,16 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
}

/// Mark that the thread is waiting on the conditional variable.
fn condvar_wait(&mut self, id: CondvarId, thread: ThreadId, mutex: MutexId) {
fn condvar_wait(&mut self, id: CondvarId, thread: ThreadId, lock: CondvarLock) {
let this = self.eval_context_mut();
let waiters = &mut this.machine.threads.sync.condvars[id].waiters;
assert!(waiters.iter().all(|waiter| waiter.thread != thread), "thread is already waiting");
waiters.push_back(CondvarWaiter { thread, mutex });
waiters.push_back(CondvarWaiter { thread, lock });
}

/// Wake up some thread (if there is any) sleeping on the conditional
/// variable.
fn condvar_signal(&mut self, id: CondvarId) -> Option<(ThreadId, MutexId)> {
fn condvar_signal(&mut self, id: CondvarId) -> Option<(ThreadId, CondvarLock)> {
let this = self.eval_context_mut();
let current_thread = this.get_active_thread();
let condvar = &mut this.machine.threads.sync.condvars[id];
Expand All @@ -592,7 +604,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
if let Some(data_race) = data_race {
data_race.validate_lock_acquire(&condvar.data_race, waiter.thread);
}
(waiter.thread, waiter.mutex)
(waiter.thread, waiter.lock)
})
}

Expand Down
21 changes: 15 additions & 6 deletions src/shims/unix/sync.rs
Expand Up @@ -3,6 +3,7 @@ use std::time::SystemTime;
use rustc_hir::LangItem;
use rustc_middle::ty::{layout::TyAndLayout, query::TyCtxtAt, Ty};

use crate::concurrency::sync::CondvarLock;
use crate::concurrency::thread::{MachineCallback, Time};
use crate::*;

Expand Down Expand Up @@ -696,8 +697,12 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
fn pthread_cond_signal(&mut self, cond_op: &OpTy<'tcx, Provenance>) -> InterpResult<'tcx, i32> {
let this = self.eval_context_mut();
let id = this.condvar_get_or_create_id(cond_op, CONDVAR_ID_OFFSET)?;
if let Some((thread, mutex)) = this.condvar_signal(id) {
post_cond_signal(this, thread, mutex)?;
if let Some((thread, lock)) = this.condvar_signal(id) {
if let CondvarLock::Mutex(mutex) = lock {
post_cond_signal(this, thread, mutex)?;
} else {
panic!("condvar should not have an rwlock on unix");
}
}

Ok(0)
Expand All @@ -710,8 +715,12 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
let this = self.eval_context_mut();
let id = this.condvar_get_or_create_id(cond_op, CONDVAR_ID_OFFSET)?;

while let Some((thread, mutex)) = this.condvar_signal(id) {
post_cond_signal(this, thread, mutex)?;
while let Some((thread, lock)) = this.condvar_signal(id) {
if let CondvarLock::Mutex(mutex) = lock {
post_cond_signal(this, thread, mutex)?;
} else {
panic!("condvar should not have an rwlock on unix");
}
}

Ok(0)
Expand All @@ -729,7 +738,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
let active_thread = this.get_active_thread();

release_cond_mutex_and_block(this, active_thread, mutex_id)?;
this.condvar_wait(id, active_thread, mutex_id);
this.condvar_wait(id, active_thread, CondvarLock::Mutex(mutex_id));

Ok(0)
}
Expand Down Expand Up @@ -768,7 +777,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
};

release_cond_mutex_and_block(this, active_thread, mutex_id)?;
this.condvar_wait(id, active_thread, mutex_id);
this.condvar_wait(id, active_thread, CondvarLock::Mutex(mutex_id));

// We return success for now and override it in the timeout callback.
this.write_scalar(Scalar::from_i32(0), dest)?;
Expand Down
19 changes: 19 additions & 0 deletions src/shims/windows/foreign_items.rs
Expand Up @@ -273,6 +273,25 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
let result = this.InitOnceComplete(ptr, flags, context)?;
this.write_scalar(result, dest)?;
}
"SleepConditionVariableSRW" => {
let [condvar, lock, timeout, flags] =
this.check_shim(abi, Abi::System { unwind: false }, link_name, args)?;

let result = this.SleepConditionVariableSRW(condvar, lock, timeout, flags, dest)?;
this.write_scalar(result, dest)?;
}
"WakeConditionVariable" => {
let [condvar] =
this.check_shim(abi, Abi::System { unwind: false }, link_name, args)?;

this.WakeConditionVariable(condvar)?;
}
"WakeAllConditionVariable" => {
let [condvar] =
this.check_shim(abi, Abi::System { unwind: false }, link_name, args)?;

this.WakeAllConditionVariable(condvar)?;
}

// Dynamic symbol loading
"GetProcAddress" => {
Expand Down
161 changes: 161 additions & 0 deletions src/shims/windows/sync.rs
Expand Up @@ -3,11 +3,45 @@ use std::time::Duration;
use rustc_target::abi::Size;

use crate::concurrency::init_once::InitOnceStatus;
use crate::concurrency::sync::{CondvarLock, RwLockMode};
use crate::concurrency::thread::MachineCallback;
use crate::*;

const SRWLOCK_ID_OFFSET: u64 = 0;
const INIT_ONCE_ID_OFFSET: u64 = 0;
const CONDVAR_ID_OFFSET: u64 = 0;

impl<'mir, 'tcx> EvalContextExtPriv<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
trait EvalContextExtPriv<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
/// Try to reacquire the lock associated with the condition variable after we
/// were signaled.
fn reacquire_cond_lock(
&mut self,
thread: ThreadId,
lock: RwLockId,
mode: RwLockMode,
) -> InterpResult<'tcx> {
let this = self.eval_context_mut();
this.unblock_thread(thread);

match mode {
RwLockMode::Read =>
if this.rwlock_is_write_locked(lock) {
this.rwlock_enqueue_and_block_reader(lock, thread);
} else {
this.rwlock_reader_lock(lock, thread);
},
RwLockMode::Write =>
if this.rwlock_is_locked(lock) {
this.rwlock_enqueue_and_block_writer(lock, thread);
} else {
this.rwlock_writer_lock(lock, thread);
},
}

Ok(())
}
}

impl<'mir, 'tcx> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
#[allow(non_snake_case)]
Expand Down Expand Up @@ -327,4 +361,131 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {

Ok(())
}

fn SleepConditionVariableSRW(
&mut self,
condvar_op: &OpTy<'tcx, Provenance>,
lock_op: &OpTy<'tcx, Provenance>,
timeout_op: &OpTy<'tcx, Provenance>,
flags_op: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, Scalar<Provenance>> {
let this = self.eval_context_mut();

let condvar_id = this.condvar_get_or_create_id(condvar_op, CONDVAR_ID_OFFSET)?;
let lock_id = this.rwlock_get_or_create_id(lock_op, SRWLOCK_ID_OFFSET)?;
let timeout_ms = this.read_scalar(timeout_op)?.to_u32()?;
let flags = this.read_scalar(flags_op)?.to_u32()?;

let timeout_time = if timeout_ms == this.eval_windows("c", "INFINITE")?.to_u32()? {
None
} else {
let duration = Duration::from_millis(timeout_ms.into());
Some(this.machine.clock.now().checked_add(duration).unwrap())
};

let shared_mode = 0x1; // CONDITION_VARIABLE_LOCKMODE_SHARED is not in std
let mode = if flags == 0 {
RwLockMode::Write
} else if flags == shared_mode {
RwLockMode::Read
} else {
throw_unsup_format!("unsupported `Flags` {flags} in `SleepConditionVariableSRW`");
};

let active_thread = this.get_active_thread();

let was_locked = match mode {
RwLockMode::Read => this.rwlock_reader_unlock(lock_id, active_thread),
RwLockMode::Write => this.rwlock_writer_unlock(lock_id, active_thread),
};

if !was_locked {
throw_ub_format!(
"calling SleepConditionVariableSRW with an SRWLock that is not locked by the current thread"
);
}

this.block_thread(active_thread);
this.condvar_wait(condvar_id, active_thread, CondvarLock::RwLock { id: lock_id, mode });

if let Some(timeout_time) = timeout_time {
struct Callback<'tcx> {
thread: ThreadId,
condvar_id: CondvarId,
lock_id: RwLockId,
mode: RwLockMode,
dest: PlaceTy<'tcx, Provenance>,
}

impl<'tcx> VisitTags for Callback<'tcx> {
fn visit_tags(&self, visit: &mut dyn FnMut(SbTag)) {
let Callback { thread: _, condvar_id: _, lock_id: _, mode: _, dest } = self;
dest.visit_tags(visit);
}
}

impl<'mir, 'tcx: 'mir> MachineCallback<'mir, 'tcx> for Callback<'tcx> {
fn call(&self, this: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> {
this.reacquire_cond_lock(self.thread, self.lock_id, self.mode)?;

this.condvar_remove_waiter(self.condvar_id, self.thread);

let error_timeout = this.eval_windows("c", "ERROR_TIMEOUT")?;
this.set_last_error(error_timeout)?;
this.write_scalar(this.eval_windows("c", "FALSE")?, &self.dest)?;
Ok(())
}
}

this.register_timeout_callback(
active_thread,
Time::Monotonic(timeout_time),
Box::new(Callback {
thread: active_thread,
condvar_id,
lock_id,
mode,
dest: dest.clone(),
}),
);
}

this.eval_windows("c", "TRUE")
}

fn WakeConditionVariable(&mut self, condvar_op: &OpTy<'tcx, Provenance>) -> InterpResult<'tcx> {
let this = self.eval_context_mut();
let condvar_id = this.condvar_get_or_create_id(condvar_op, CONDVAR_ID_OFFSET)?;

if let Some((thread, lock)) = this.condvar_signal(condvar_id) {
if let CondvarLock::RwLock { id, mode } = lock {
this.reacquire_cond_lock(thread, id, mode)?;
this.unregister_timeout_callback_if_exists(thread);
} else {
panic!("mutexes should not exist on windows");
}
}

Ok(())
}

fn WakeAllConditionVariable(
&mut self,
condvar_op: &OpTy<'tcx, Provenance>,
) -> InterpResult<'tcx> {
let this = self.eval_context_mut();
let condvar_id = this.condvar_get_or_create_id(condvar_op, CONDVAR_ID_OFFSET)?;

while let Some((thread, lock)) = this.condvar_signal(condvar_id) {
if let CondvarLock::RwLock { id, mode } = lock {
this.reacquire_cond_lock(thread, id, mode)?;
this.unregister_timeout_callback_if_exists(thread);
} else {
panic!("mutexes should not exist on windows");
}
}

Ok(())
}
}
20 changes: 4 additions & 16 deletions tests/pass/concurrency/sync.rs
Expand Up @@ -230,20 +230,8 @@ fn main() {
check_once();
park_timeout();
park_unpark();

if !cfg!(windows) {
// ignore-target-windows: Condvars on Windows are not supported yet
check_barriers();
check_conditional_variables_notify_one();
check_conditional_variables_timed_wait_timeout();
check_conditional_variables_timed_wait_notimeout();
} else {
// We need to fake the same output...
for _ in 0..10 {
println!("before wait");
}
for _ in 0..10 {
println!("after wait");
}
}
check_barriers();
check_conditional_variables_notify_one();
check_conditional_variables_timed_wait_timeout();
check_conditional_variables_timed_wait_notimeout();
}
1 change: 0 additions & 1 deletion tests/pass/concurrency/sync_nopreempt.rs
@@ -1,4 +1,3 @@
//@ignore-target-windows: Condvars on Windows are not supported yet.
// We are making scheduler assumptions here.
//@compile-flags: -Zmiri-strict-provenance -Zmiri-preemption-rate=0

Expand Down

0 comments on commit bf0f47b

Please sign in to comment.