From 1710088ca1eff5e87f4c43470712b256c760e27c Mon Sep 17 00:00:00 2001 From: dragon-zhang Date: Thu, 16 Jan 2025 10:23:07 +0800 Subject: [PATCH] basic support IOCP Operator --- .github/workflows/ci-preemptive.sh | 10 + .github/workflows/ci.sh | 10 + core/Cargo.toml | 7 + core/src/net/event_loop.rs | 84 ++++- core/src/net/mod.rs | 37 +- core/src/net/operator/mod.rs | 6 + core/src/net/operator/windows/mod.rs | 482 +++++++++++++++++++++++++ core/src/net/operator/windows/tests.rs | 195 ++++++++++ core/src/syscall/windows/WSASend.rs | 12 +- core/src/syscall/windows/mod.rs | 212 +++++++++-- core/src/syscall/windows/shutdown.rs | 12 +- hook/Cargo.toml | 6 + hook/src/syscall/windows.rs | 2 +- open-coroutine/Cargo.toml | 6 + open-coroutine/build.rs | 6 + 15 files changed, 1035 insertions(+), 52 deletions(-) create mode 100644 core/src/net/operator/windows/mod.rs create mode 100644 core/src/net/operator/windows/tests.rs diff --git a/.github/workflows/ci-preemptive.sh b/.github/workflows/ci-preemptive.sh index 4a2598b6..fa2fe6d5 100644 --- a/.github/workflows/ci-preemptive.sh +++ b/.github/workflows/ci-preemptive.sh @@ -34,3 +34,13 @@ if [ "${TARGET}" = "x86_64-unknown-linux-gnu" ]; then "${CARGO}" test --target "${TARGET}" --no-default-features --features io_uring,preemptive,ci "${CARGO}" test --target "${TARGET}" --no-default-features --features io_uring,preemptive,ci --release fi + +# test IOCP +if [ "${OS}" = "windows-latest" ]; then + cd "${PROJECT_DIR}"/core + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp,preemptive,ci + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp,preemptive,ci --release + cd "${PROJECT_DIR}"/open-coroutine + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp,preemptive,ci + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp,preemptive,ci --release +fi diff --git a/.github/workflows/ci.sh b/.github/workflows/ci.sh index 5bddad7e..f9e4563e 100644 --- a/.github/workflows/ci.sh +++ b/.github/workflows/ci.sh @@ -34,3 +34,13 @@ if [ "${TARGET}" = "x86_64-unknown-linux-gnu" ]; then "${CARGO}" test --target "${TARGET}" --no-default-features --features io_uring,ci "${CARGO}" test --target "${TARGET}" --no-default-features --features io_uring,ci --release fi + +# test IOCP +if [ "${OS}" = "windows-latest" ]; then + cd "${PROJECT_DIR}"/core + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp,ci + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp,ci --release + cd "${PROJECT_DIR}"/open-coroutine + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp,ci + "${CARGO}" test --target "${TARGET}" --no-default-features --features iocp,ci --release +fi diff --git a/core/Cargo.toml b/core/Cargo.toml index fb221583..11c95328 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -58,6 +58,7 @@ windows-sys = { workspace = true, features = [ "Win32_Networking_WinSock", "Win32_System_SystemInformation", "Win32_System_Diagnostics_Debug", + "Win32_System_WindowsProgramming", ] } polling = { workspace = true, optional = true } @@ -95,5 +96,11 @@ net = ["korosensei", "polling", "mio", "crossbeam-utils", "core_affinity"] # Provide io_uring adaptation, this feature only works in linux. io_uring = ["net", "io-uring"] +# Provide IOCP adaptation, this feature only works in windows. +iocp = ["net"] + +# Provide completion IO adaptation +completion_io = ["io_uring", "iocp"] + # Provide syscall implementation. syscall = ["net"] diff --git a/core/src/net/event_loop.rs b/core/src/net/event_loop.rs index 4e61f3da..41d3f25c 100644 --- a/core/src/net/event_loop.rs +++ b/core/src/net/event_loop.rs @@ -24,16 +24,34 @@ cfg_if::cfg_if! { } } +cfg_if::cfg_if! { + if #[cfg(all(windows, feature = "iocp"))] { + use dashmap::DashMap; + use std::ffi::{c_longlong, c_uint}; + use windows_sys::core::{PCSTR, PSTR}; + use windows_sys::Win32::Networking::WinSock::{ + LPWSAOVERLAPPED_COMPLETION_ROUTINE, SEND_RECV_FLAGS, SOCKADDR, SOCKET, WSABUF, + }; + use windows_sys::Win32::System::IO::OVERLAPPED; + } +} + #[repr(C)] #[derive(Debug)] pub(crate) struct EventLoop<'e> { stop: Arc<(Mutex, Condvar)>, shared_stop: Arc<(Mutex, Condvar)>, cpu: usize, - #[cfg(all(target_os = "linux", feature = "io_uring"))] + #[cfg(any( + all(target_os = "linux", feature = "io_uring"), + all(windows, feature = "iocp") + ))] operator: crate::net::operator::Operator<'e>, #[allow(clippy::type_complexity)] - #[cfg(all(target_os = "linux", feature = "io_uring"))] + #[cfg(any( + all(target_os = "linux", feature = "io_uring"), + all(windows, feature = "iocp") + ))] syscall_wait_table: DashMap>, Condvar)>>, selector: Poller, pool: CoroutinePool<'e>, @@ -87,9 +105,15 @@ impl<'e> EventLoop<'e> { stop: Arc::new((Mutex::new(false), Condvar::new())), shared_stop, cpu, - #[cfg(all(target_os = "linux", feature = "io_uring"))] + #[cfg(any( + all(target_os = "linux", feature = "io_uring"), + all(windows, feature = "iocp") + ))] operator: crate::net::operator::Operator::new(cpu)?, - #[cfg(all(target_os = "linux", feature = "io_uring"))] + #[cfg(any( + all(target_os = "linux", feature = "io_uring"), + all(windows, feature = "iocp") + ))] syscall_wait_table: DashMap::new(), selector: Poller::new()?, pool: CoroutinePool::new(name, stack_size, min_size, max_size, keep_alive_time), @@ -222,6 +246,8 @@ impl<'e> EventLoop<'e> { cfg_if::cfg_if! { if #[cfg(all(target_os = "linux", feature = "io_uring"))] { left_time = self.adapt_io_uring(left_time)?; + } else if #[cfg(all(windows, feature = "iocp"))] { + left_time = self.adapt_iocp(left_time)?; } } @@ -267,6 +293,28 @@ impl<'e> EventLoop<'e> { Ok(left_time) } + #[cfg(all(windows, feature = "iocp"))] + fn adapt_iocp(&self, mut left_time: Option) -> std::io::Result> { + // use IOCP + let (count, mut cq, left) = self.operator.select(left_time, 0)?; + if count > 0 { + for cqe in &mut cq { + let token = cqe.token; + if let Some((_, pair)) = self.syscall_wait_table.remove(&token) { + let (lock, cvar) = &*pair; + let mut pending = lock.lock().expect("lock failed"); + *pending = Some(cqe.result); + cvar.notify_one(); + } + unsafe { self.resume(token) }; + } + } + if left != left_time { + left_time = Some(left.unwrap_or(Duration::ZERO)); + } + Ok(left_time) + } + unsafe fn resume(&self, token: usize) { if COROUTINE_TOKENS.remove(&token).is_none() { return; @@ -446,6 +494,34 @@ impl_io_uring!(mkdirat(dirfd: c_int, pathname: *const c_char, mode: mode_t) -> c impl_io_uring!(renameat(olddirfd: c_int, oldpath: *const c_char, newdirfd: c_int, newpath: *const c_char) -> c_int); impl_io_uring!(renameat2(olddirfd: c_int, oldpath: *const c_char, newdirfd: c_int, newpath: *const c_char, flags: c_uint) -> c_int); +macro_rules! impl_iocp { + ( $syscall: ident($($arg: ident : $arg_type: ty),*) -> $result: ty ) => { + #[cfg(all(windows, feature = "iocp"))] + impl EventLoop<'_> { + #[allow(non_snake_case, clippy::too_many_arguments)] + pub(super) fn $syscall( + &self, + $($arg: $arg_type),* + ) -> std::io::Result>, Condvar)>> { + let token = EventLoop::token(SyscallName::$syscall); + self.operator.$syscall(token, $($arg, )*)?; + let arc = Arc::new((Mutex::new(None), Condvar::new())); + assert!( + self.syscall_wait_table.insert(token, arc.clone()).is_none(), + "The previous token was not retrieved in a timely manner" + ); + Ok(arc) + } + } + } +} + +impl_iocp!(accept(fd: SOCKET, addr: *mut SOCKADDR, len: *mut c_int) -> c_int); +impl_iocp!(recv(fd: SOCKET, buf: PSTR, len: c_int, flags: SEND_RECV_FLAGS) -> c_int); +impl_iocp!(WSARecv(fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, lpnumberofbytesrecvd: *mut c_uint, lpflags : *mut c_uint, lpoverlapped: *mut OVERLAPPED, lpcompletionroutine : LPWSAOVERLAPPED_COMPLETION_ROUTINE) -> c_int); +impl_iocp!(send(fd: SOCKET, buf: PCSTR, len: c_int, flags: SEND_RECV_FLAGS) -> c_int); +impl_iocp!(WSASend(fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, lpnumberofbytesrecvd: *mut c_uint, dwflags : c_uint, lpoverlapped: *mut OVERLAPPED, lpcompletionroutine : LPWSAOVERLAPPED_COMPLETION_ROUTINE) -> c_int); + #[cfg(all(test, not(all(unix, feature = "preemptive"))))] mod tests { use crate::net::event_loop::EventLoop; diff --git a/core/src/net/mod.rs b/core/src/net/mod.rs index bcb2bf37..9a3d59dd 100644 --- a/core/src/net/mod.rs +++ b/core/src/net/mod.rs @@ -18,13 +18,27 @@ cfg_if::cfg_if! { } } +cfg_if::cfg_if! { + if #[cfg(all(windows, feature = "iocp"))] { + use std::ffi::c_uint; + use windows_sys::core::{PCSTR, PSTR}; + use windows_sys::Win32::Networking::WinSock::{ + LPWSAOVERLAPPED_COMPLETION_ROUTINE, SEND_RECV_FLAGS, SOCKADDR, SOCKET, WSABUF, + }; + use windows_sys::Win32::System::IO::OVERLAPPED; + } +} + /// 做C兼容时会用到 pub type UserFunc = extern "C" fn(usize) -> usize; mod selector; #[allow(clippy::too_many_arguments)] -#[cfg(all(target_os = "linux", feature = "io_uring"))] +#[cfg(any( + all(target_os = "linux", feature = "io_uring"), + all(windows, feature = "iocp") +))] mod operator; #[allow(missing_docs)] @@ -280,3 +294,24 @@ impl_io_uring!(fsync(fd: c_int) -> c_int); impl_io_uring!(mkdirat(dirfd: c_int, pathname: *const c_char, mode: mode_t) -> c_int); impl_io_uring!(renameat(olddirfd: c_int, oldpath: *const c_char, newdirfd: c_int, newpath: *const c_char) -> c_int); impl_io_uring!(renameat2(olddirfd: c_int, oldpath: *const c_char, newdirfd: c_int, newpath: *const c_char, flags: c_uint) -> c_int); + +macro_rules! impl_iocp { + ( $syscall: ident($($arg: ident : $arg_type: ty),*) -> $result: ty ) => { + #[allow(non_snake_case)] + #[cfg(all(windows, feature = "iocp"))] + impl EventLoops { + #[allow(missing_docs)] + pub fn $syscall( + $($arg: $arg_type),* + ) -> std::io::Result>, Condvar)>> { + Self::event_loop().$syscall($($arg, )*) + } + } + } +} + +impl_iocp!(accept(fd: SOCKET, addr: *mut SOCKADDR, len: *mut c_int) -> c_int); +impl_iocp!(recv(fd: SOCKET, buf: PSTR, len: c_int, flags: SEND_RECV_FLAGS) -> c_int); +impl_iocp!(WSARecv(fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, lpnumberofbytesrecvd: *mut c_uint, lpflags : *mut c_uint, lpoverlapped: *mut OVERLAPPED, lpcompletionroutine : LPWSAOVERLAPPED_COMPLETION_ROUTINE) -> c_int); +impl_iocp!(send(fd: SOCKET, buf: PCSTR, len: c_int, flags: SEND_RECV_FLAGS) -> c_int); +impl_iocp!(WSASend(fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, lpnumberofbytesrecvd: *mut c_uint, dwflags : c_uint, lpoverlapped: *mut OVERLAPPED, lpcompletionroutine : LPWSAOVERLAPPED_COMPLETION_ROUTINE) -> c_int); diff --git a/core/src/net/operator/mod.rs b/core/src/net/operator/mod.rs index 6a821a4d..1b246c8b 100644 --- a/core/src/net/operator/mod.rs +++ b/core/src/net/operator/mod.rs @@ -2,3 +2,9 @@ mod linux; #[cfg(all(target_os = "linux", feature = "io_uring"))] pub(crate) use linux::*; + +#[allow(non_snake_case)] +#[cfg(all(windows, feature = "iocp"))] +mod windows; +#[cfg(all(windows, feature = "iocp"))] +pub(crate) use windows::*; diff --git a/core/src/net/operator/windows/mod.rs b/core/src/net/operator/windows/mod.rs new file mode 100644 index 00000000..89ce0e89 --- /dev/null +++ b/core/src/net/operator/windows/mod.rs @@ -0,0 +1,482 @@ +use crate::common::constants::SyscallName; +use crate::common::{get_timeout_time, now}; +use crate::impl_display_by_debug; +use std::ffi::{c_int, c_longlong, c_uint}; +use std::io::{Error, ErrorKind}; +use std::marker::PhantomData; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::{Duration, Instant}; +use windows_sys::core::{PCSTR, PSTR}; +use windows_sys::Win32::Foundation::{ + ERROR_INVALID_PARAMETER, FALSE, HANDLE, INVALID_HANDLE_VALUE, +}; +use windows_sys::Win32::Networking::WinSock::{ + getsockopt, setsockopt, AcceptEx, WSAGetLastError, WSARecv, WSASend, WSASocketW, + INVALID_SOCKET, LPCONDITIONPROC, LPWSAOVERLAPPED_COMPLETION_ROUTINE, SEND_RECV_FLAGS, SOCKADDR, + SOCKADDR_IN, SOCKET, SOCKET_ERROR, SOL_SOCKET, SO_PROTOCOL_INFO, SO_UPDATE_ACCEPT_CONTEXT, + WSABUF, WSAPROTOCOL_INFOW, WSA_FLAG_OVERLAPPED, WSA_IO_PENDING, +}; +use windows_sys::Win32::Storage::FileSystem::SetFileCompletionNotificationModes; +use windows_sys::Win32::System::WindowsProgramming::FILE_SKIP_SET_EVENT_ON_HANDLE; +use windows_sys::Win32::System::IO::{ + CreateIoCompletionPort, GetQueuedCompletionStatusEx, OVERLAPPED, OVERLAPPED_ENTRY, +}; + +#[cfg(test)] +mod tests; + +/// The overlapped struct we actually used for IOCP. +#[repr(C)] +#[derive(educe::Educe)] +#[educe(Debug)] +pub(crate) struct Overlapped { + /// The base [`OVERLAPPED`]. + #[educe(Debug(ignore))] + base: OVERLAPPED, + from_fd: SOCKET, + pub token: usize, + syscall_name: SyscallName, + socket: SOCKET, + pub result: c_longlong, +} + +impl Default for Overlapped { + fn default() -> Self { + unsafe { std::mem::zeroed() } + } +} + +impl_display_by_debug!(Overlapped); + +#[repr(C)] +#[derive(Debug)] +pub(crate) struct Operator<'o> { + cpu: usize, + iocp: HANDLE, + entering: AtomicBool, + phantom_data: PhantomData<&'o Overlapped>, +} + +impl<'o> Operator<'o> { + pub(crate) fn new(cpu: usize) -> std::io::Result { + let iocp = + unsafe { CreateIoCompletionPort(INVALID_HANDLE_VALUE, std::ptr::null_mut(), 0, 0) }; + if iocp.is_null() { + return Err(Error::last_os_error()); + } + Ok(Self { + cpu, + iocp, + entering: AtomicBool::new(false), + phantom_data: PhantomData, + }) + } + + /// Associates a new `HANDLE` to this I/O completion port. + /// + /// This function will associate the given handle to this port with the + /// given `token` to be returned in status messages whenever it receives a + /// notification. + /// + /// Any object which is convertible to a `HANDLE` via the `AsRawHandle` + /// trait can be provided to this function, such as `std::fs::File` and + /// friends. + fn add_handle(&self, handle: HANDLE) -> std::io::Result<()> { + unsafe { + let ret = CreateIoCompletionPort(handle, self.iocp, self.cpu, 0); + if ret.is_null() + && ERROR_INVALID_PARAMETER == WSAGetLastError().try_into().expect("overflow") + { + // duplicate bind + return Ok(()); + } + debug_assert_eq!(ret, self.iocp); + if SetFileCompletionNotificationModes( + handle, + u8::try_from(FILE_SKIP_SET_EVENT_ON_HANDLE).expect("overflow"), + ) == 0 + { + return Err(Error::last_os_error()); + } + } + Ok(()) + } + + pub(crate) fn select( + &self, + timeout: Option, + want: usize, + ) -> std::io::Result<(usize, Vec, Option)> { + if self + .entering + .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) + .is_err() + { + return Ok((0, Vec::new(), timeout)); + } + let result = self.do_select(timeout, want); + self.entering.store(false, Ordering::Release); + result + } + + #[allow(clippy::cast_ptr_alignment)] + fn do_select( + &self, + timeout: Option, + want: usize, + ) -> std::io::Result<(usize, Vec, Option)> { + let start_time = Instant::now(); + let timeout_time = timeout.map_or(u64::MAX, get_timeout_time); + let mut cq = Vec::new(); + loop { + let left_ns = timeout_time.saturating_sub(now()); + if left_ns == 0 { + break; + } + let mut entries: Vec = Vec::with_capacity(1024); + let uninit = entries.spare_capacity_mut(); + let mut recv_count = 0; + unsafe { + let ret = GetQueuedCompletionStatusEx( + self.iocp, + uninit.as_mut_ptr().cast(), + uninit.len().try_into().expect("overflow"), + &mut recv_count, + left_ns + .saturating_div(1_000_000) + .try_into() + .unwrap_or(u32::MAX), + 0, + ); + if FALSE == ret { + let e = Error::last_os_error(); + if ErrorKind::TimedOut == e.kind() { + continue; + } + return Err(e); + } + entries.set_len(recv_count as _); + for entry in entries { + let mut cqe = *Box::from_raw(entry.lpOverlapped.cast::()); + // resolve completed read/write tasks + // todo refactor IOCP impl + cqe.result = match cqe.syscall_name { + SyscallName::accept => { + if setsockopt( + cqe.socket, + SOL_SOCKET, + SO_UPDATE_ACCEPT_CONTEXT, + std::ptr::from_ref(&cqe.from_fd).cast(), + c_int::try_from(size_of::()).expect("overflow"), + ) == 0 + { + cqe.socket.try_into().expect("result overflow") + } else { + -c_longlong::from(windows_sys::Win32::Foundation::GetLastError()) + } + } + SyscallName::recv + | SyscallName::WSARecv + | SyscallName::send + | SyscallName::WSASend => { + let r = entry.dwNumberOfBytesTransferred.into(); + if r > 0 { + r + } else { + -c_longlong::from(windows_sys::Win32::Foundation::GetLastError()) + } + } + _ => panic!("unsupported"), + }; + eprintln!("IOCP got:{cqe}"); + cq.push(cqe); + } + } + if cq.len() >= want { + break; + } + } + let cost = Instant::now().saturating_duration_since(start_time); + Ok((cq.len(), cq, timeout.map(|t| t.saturating_sub(cost)))) + } + + pub(crate) fn accept( + &self, + user_data: usize, + fd: SOCKET, + _address: *mut SOCKADDR, + _address_len: *mut c_int, + ) -> std::io::Result<()> { + self.acceptex(user_data, fd, SyscallName::accept) + } + + pub(crate) fn WSAAccept( + &self, + user_data: usize, + fd: SOCKET, + _address: *mut SOCKADDR, + _address_len: *mut c_int, + lpfncondition: LPCONDITIONPROC, + _dwcallbackdata: usize, + ) -> std::io::Result<()> { + if lpfncondition.is_some() { + return Err(Error::new( + ErrorKind::InvalidInput, + "the WSAAccept in Operator should be called without lpfncondition!", + )); + } + self.acceptex(user_data, fd, SyscallName::WSAAccept) + } + + fn acceptex( + &self, + user_data: usize, + fd: SOCKET, + syscall_name: SyscallName, + ) -> std::io::Result<()> { + unsafe { + let mut sock_info: WSAPROTOCOL_INFOW = std::mem::zeroed(); + let mut sock_info_len = size_of::() + .try_into() + .expect("protocol_len overflow"); + if getsockopt( + fd, + SOL_SOCKET, + SO_PROTOCOL_INFO, + std::ptr::from_mut(&mut sock_info).cast(), + &mut sock_info_len, + ) != 0 + { + return Err(Error::new(ErrorKind::Other, "get socket info failed")); + } + self.add_handle(fd as HANDLE)?; + let socket = WSASocketW( + sock_info.iAddressFamily, + sock_info.iSocketType, + sock_info.iProtocol, + &sock_info, + 0, + WSA_FLAG_OVERLAPPED, + ); + if INVALID_SOCKET == socket { + return Err(Error::new( + ErrorKind::Other, + format!("add {syscall_name} operation failed"), + )); + } + let size = size_of::() + .saturating_add(16) + .try_into() + .expect("size overflow"); + let overlapped: &'o mut Overlapped = Box::leak(Box::default()); + overlapped.from_fd = fd; + overlapped.token = user_data; + overlapped.syscall_name = syscall_name; + overlapped.socket = socket; + let mut buf: Vec = Vec::with_capacity(size as usize * 2); + while AcceptEx( + fd, + socket, + buf.as_mut_ptr().cast(), + 0, + size, + size, + std::ptr::null_mut(), + std::ptr::from_mut(overlapped).cast(), + ) == FALSE + { + if WSA_IO_PENDING == WSAGetLastError() { + break; + } + } + eprintln!("add {syscall_name} operation:{overlapped}"); + } + Ok(()) + } + + pub(crate) fn recv( + &self, + user_data: usize, + fd: SOCKET, + buf: PSTR, + len: c_int, + flags: SEND_RECV_FLAGS, + ) -> std::io::Result<()> { + let buf = [WSABUF { + len: len.try_into().expect("len overflow"), + buf: buf.cast(), + }]; + self.wsarecv( + user_data, + fd, + buf.as_ptr(), + buf.len().try_into().expect("len overflow"), + std::ptr::null_mut(), + &mut c_uint::try_from(flags).expect("overflow"), + None, + SyscallName::recv, + ) + } + + pub(crate) fn WSARecv( + &self, + user_data: usize, + fd: SOCKET, + buf: *const WSABUF, + dwbuffercount: c_uint, + lpnumberofbytesrecvd: *mut c_uint, + lpflags: *mut c_uint, + lpoverlapped: *mut OVERLAPPED, + lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE, + ) -> std::io::Result<()> { + if !lpoverlapped.is_null() { + return Err(Error::new( + ErrorKind::InvalidInput, + "the WSARecv in Operator should be called without lpoverlapped!", + )); + } + self.wsarecv( + user_data, + fd, + buf, + dwbuffercount, + lpnumberofbytesrecvd, + lpflags, + lpcompletionroutine, + SyscallName::WSARecv, + ) + } + + fn wsarecv( + &self, + user_data: usize, + fd: SOCKET, + buf: *const WSABUF, + dwbuffercount: c_uint, + lpnumberofbytesrecvd: *mut c_uint, + lpflags: *mut c_uint, + lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE, + syscall_name: SyscallName, + ) -> std::io::Result<()> { + self.add_handle(fd as HANDLE)?; + unsafe { + let overlapped: &'o mut Overlapped = Box::leak(Box::default()); + overlapped.from_fd = fd; + overlapped.token = user_data; + overlapped.syscall_name = syscall_name; + if WSARecv( + fd, + buf, + dwbuffercount, + lpnumberofbytesrecvd, + lpflags, + std::ptr::from_mut(overlapped).cast(), + lpcompletionroutine, + ) == SOCKET_ERROR + { + let errno = WSAGetLastError(); + if WSA_IO_PENDING != errno { + return Err(Error::new( + ErrorKind::Other, + format!("add {syscall_name} operation failed with {errno}"), + )); + } + } + eprintln!("add {syscall_name} operation:{overlapped}"); + } + Ok(()) + } + + pub(crate) fn send( + &self, + user_data: usize, + fd: SOCKET, + buf: PCSTR, + len: c_int, + flags: SEND_RECV_FLAGS, + ) -> std::io::Result<()> { + let buf = [WSABUF { + len: len.try_into().expect("len overflow"), + buf: buf.cast_mut(), + }]; + self.wsasend( + user_data, + fd, + buf.as_ptr(), + buf.len().try_into().expect("len overflow"), + std::ptr::null_mut(), + c_uint::try_from(flags).expect("overflow"), + None, + SyscallName::send, + ) + } + + pub(crate) fn WSASend( + &self, + user_data: usize, + fd: SOCKET, + buf: *const WSABUF, + dwbuffercount: c_uint, + lpnumberofbytesrecvd: *mut c_uint, + dwflags: c_uint, + lpoverlapped: *mut OVERLAPPED, + lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE, + ) -> std::io::Result<()> { + if !lpoverlapped.is_null() { + return Err(Error::new( + ErrorKind::InvalidInput, + "the WSASend in Operator should be called without lpoverlapped!", + )); + } + self.wsasend( + user_data, + fd, + buf, + dwbuffercount, + lpnumberofbytesrecvd, + dwflags, + lpcompletionroutine, + SyscallName::WSASend, + ) + } + + fn wsasend( + &self, + user_data: usize, + fd: SOCKET, + buf: *const WSABUF, + dwbuffercount: c_uint, + lpnumberofbytesrecvd: *mut c_uint, + dwflags: c_uint, + lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE, + syscall_name: SyscallName, + ) -> std::io::Result<()> { + self.add_handle(fd as HANDLE)?; + unsafe { + let overlapped: &'o mut Overlapped = Box::leak(Box::default()); + overlapped.from_fd = fd; + overlapped.token = user_data; + overlapped.syscall_name = syscall_name; + if WSASend( + fd, + buf, + dwbuffercount, + lpnumberofbytesrecvd, + dwflags, + std::ptr::from_mut(overlapped).cast(), + lpcompletionroutine, + ) == SOCKET_ERROR + { + let errno = WSAGetLastError(); + if WSA_IO_PENDING != errno { + return Err(Error::new( + ErrorKind::Other, + format!("add {syscall_name} operation failed with {errno}"), + )); + } + } + eprintln!("add {syscall_name} operation:{overlapped}"); + } + Ok(()) + } +} diff --git a/core/src/net/operator/windows/tests.rs b/core/src/net/operator/windows/tests.rs new file mode 100644 index 00000000..292d00ab --- /dev/null +++ b/core/src/net/operator/windows/tests.rs @@ -0,0 +1,195 @@ +use crate::net::operator::Operator; +use slab::Slab; +use std::io::{BufRead, BufReader, Write}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener, TcpStream}; +use std::os::windows::io::AsRawSocket; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::Duration; +use windows_sys::Win32::Networking::WinSock::{closesocket, recv, send, SOCKET}; + +#[derive(Clone, Debug)] +enum Token { + Accept, + Read { + fd: SOCKET, + buf_index: usize, + }, + Write { + fd: SOCKET, + buf_index: usize, + offset: usize, + len: usize, + }, +} + +fn crate_client(port: u16, server_started: Arc) { + //等服务端起来 + while !server_started.load(Ordering::Acquire) {} + let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port); + let mut stream = TcpStream::connect_timeout(&socket, Duration::from_secs(3)) + .unwrap_or_else(|_| panic!("connect to 127.0.0.1:{port} failed !")); + let mut data: [u8; 512] = [b'1'; 512]; + data[511] = b'\n'; + let mut buffer: Vec = Vec::with_capacity(512); + for _ in 0..3 { + //写入stream流,如果写入失败,提示"写入失败" + assert_eq!(512, stream.write(&data).expect("Failed to write!")); + print!("Client Send: {}", String::from_utf8_lossy(&data[..])); + + let mut reader = BufReader::new(&stream); + //一直读到换行为止(b'\n'中的b表示字节),读到buffer里面 + assert_eq!( + 512, + reader + .read_until(b'\n', &mut buffer) + .expect("Failed to read into buffer") + ); + print!("Client Received: {}", String::from_utf8_lossy(&buffer[..])); + assert_eq!(&data, &buffer as &[u8]); + buffer.clear(); + } + //发送终止符 + assert_eq!(1, stream.write(&[b'e']).expect("Failed to write!")); + println!("client closed"); +} + +fn crate_server2(port: u16, server_started: Arc) -> anyhow::Result<()> { + let operator = Operator::new(0)?; + let listener = TcpListener::bind(("127.0.0.1", port))?; + + let mut bufpool = Vec::with_capacity(64); + let mut buf_alloc = Slab::with_capacity(64); + let mut token_alloc = Slab::with_capacity(64); + + println!("listen {}", listener.local_addr()?); + server_started.store(true, Ordering::Release); + + operator.accept( + token_alloc.insert(Token::Accept), + listener.as_raw_socket() as _, + std::ptr::null_mut(), + std::ptr::null_mut(), + )?; + + let mut first = true; + loop { + let (_, mut cq, _) = operator.select(None, 1)?; + for cqe in &mut cq { + let token_index = cqe.token; + let token = &mut token_alloc[token_index]; + match token.clone() { + Token::Accept => { + println!("server accepted"); + let fd = cqe.socket; + let (buf_index, buf) = match bufpool.pop() { + Some(buf_index) => (buf_index, &mut buf_alloc[buf_index]), + None => { + let buf = vec![0u8; 2048].into_boxed_slice(); + let buf_entry = buf_alloc.vacant_entry(); + let buf_index = buf_entry.key(); + (buf_index, buf_entry.insert(buf)) + } + }; + *token = Token::Read { fd, buf_index }; + if first { + unsafe { + let len = recv(fd, buf.as_mut_ptr() as _, buf.len() as _, 0); + assert_ne!(0, len); + assert_ne!(0, send(fd, buf.as_ptr() as _, len, 0)); + } + first = false; + } + operator.recv(token_index, fd, buf.as_mut_ptr() as _, buf.len() as _, 0)?; + } + Token::Read { fd, buf_index } => { + println!("server received"); + let ret = cqe.result as _; + if ret == 0 { + bufpool.push(buf_index); + _ = token_alloc.remove(token_index); + println!("shutdown connection1"); + _ = unsafe { closesocket(fd) }; + println!("Server closed1"); + return Ok(()); + } else { + let len = ret; + let buf = &buf_alloc[buf_index]; + *token = Token::Write { + fd, + buf_index, + len, + offset: 0, + }; + operator.send(token_index, fd, buf.as_ptr() as _, len as _, 0)?; + } + } + Token::Write { + fd, + buf_index, + offset, + len, + } => { + println!("server sent"); + let write_len = cqe.result as usize; + if offset + write_len >= len { + bufpool.push(buf_index); + let (buf_index, buf) = match bufpool.pop() { + Some(buf_index) => (buf_index, &mut buf_alloc[buf_index]), + None => { + let buf = vec![0u8; 2048].into_boxed_slice(); + let buf_entry = buf_alloc.vacant_entry(); + let buf_index = buf_entry.key(); + (buf_index, buf_entry.insert(buf)) + } + }; + *token = Token::Read { fd, buf_index }; + if operator + .recv(token_index, fd, buf.as_mut_ptr() as _, buf.len() as _, 0) + .is_err() + { + bufpool.push(buf_index); + _ = token_alloc.remove(token_index); + println!("shutdown connection2"); + _ = unsafe { closesocket(fd) }; + println!("Server closed2"); + return Ok(()); + } + } else { + let offset = offset + write_len; + let len = len - offset; + let buf = &buf_alloc[buf_index][offset..]; + *token = Token::Write { + fd, + buf_index, + offset, + len, + }; + operator.send(token_index, fd, buf.as_ptr() as _, len as _, 0)?; + }; + } + } + } + } +} + +#[test] +fn framework() -> anyhow::Result<()> { + #[cfg(feature = "log")] + let _ = tracing_subscriber::fmt() + .with_thread_names(true) + .with_line_number(true) + .with_timer(tracing_subscriber::fmt::time::OffsetTime::new( + time::UtcOffset::from_hms(8, 0, 0).expect("create UtcOffset failed !"), + time::format_description::well_known::Rfc2822, + )) + .try_init(); + let port = 7061; + let server_started = Arc::new(AtomicBool::new(false)); + let clone = server_started.clone(); + let handle = std::thread::spawn(move || crate_server2(port, clone)); + std::thread::spawn(move || crate_client(port, server_started)) + .join() + .expect("client has error"); + handle.join().expect("server has error") +} diff --git a/core/src/syscall/windows/WSASend.rs b/core/src/syscall/windows/WSASend.rs index a5ef469a..34d05de3 100644 --- a/core/src/syscall/windows/WSASend.rs +++ b/core/src/syscall/windows/WSASend.rs @@ -19,7 +19,7 @@ pub extern "system" fn WSASend( fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, - lpnumberofbytesrecvd: *mut c_uint, + lpnumberofbytessent: *mut c_uint, dwflags: c_uint, lpoverlapped: *mut OVERLAPPED, lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE, @@ -31,7 +31,7 @@ pub extern "system" fn WSASend( fd, buf, dwbuffercount, - lpnumberofbytesrecvd, + lpnumberofbytessent, dwflags, lpoverlapped, lpcompletionroutine, @@ -55,7 +55,7 @@ trait WSARecvSyscall { fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, - lpnumberofbytesrecvd: *mut c_uint, + lpnumberofbytessent: *mut c_uint, dwflags: c_uint, lpoverlapped: *mut OVERLAPPED, lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE, @@ -67,7 +67,7 @@ impl_facade!(WSARecvSyscallFacade, WSARecvSyscall, fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, - lpnumberofbytesrecvd: *mut c_uint, + lpnumberofbytessent: *mut c_uint, dwflags: c_uint, lpoverlapped: *mut OVERLAPPED, lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE @@ -79,7 +79,7 @@ impl_nio_write_iovec!(NioWSARecvSyscall, WSARecvSyscall, fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, - lpnumberofbytesrecvd: *mut c_uint, + lpnumberofbytessent: *mut c_uint, dwflags: c_uint, lpoverlapped: *mut OVERLAPPED, lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE @@ -91,7 +91,7 @@ impl_raw!(RawWSARecvSyscall, WSARecvSyscall, windows_sys::Win32::Networking::Win fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, - lpnumberofbytesrecvd: *mut c_uint, + lpnumberofbytessent: *mut c_uint, dwflags: c_uint, lpoverlapped: *mut OVERLAPPED, lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE diff --git a/core/src/syscall/windows/mod.rs b/core/src/syscall/windows/mod.rs index d142634d..450ff167 100644 --- a/core/src/syscall/windows/mod.rs +++ b/core/src/syscall/windows/mod.rs @@ -7,7 +7,10 @@ use windows_sys::Win32::Networking::WinSock::{ }; macro_rules! impl_facade { - ( $struct_name:ident, $trait_name: ident, $syscall: ident($($arg: ident : $arg_type: ty),*) -> $result: ty ) => { + ( + $struct_name:ident, $trait_name: ident, + $syscall: ident($($arg: ident : $arg_type: ty),*) -> $result: ty + ) => { #[repr(C)] #[derive(Debug, Default)] struct $struct_name { @@ -43,8 +46,88 @@ macro_rules! impl_facade { } } +#[allow(unused_macros)] +macro_rules! impl_iocp { + ( + $struct_name:ident, $trait_name: ident, + $syscall: ident($($arg: ident : $arg_type: ty),*) -> $result: ty + ) => { + #[repr(C)] + #[derive(Debug, Default)] + #[cfg(all(windows, feature = "iocp"))] + struct $struct_name { + inner: I, + } + + #[cfg(all(windows, feature = "iocp"))] + impl $trait_name for $struct_name { + extern "system" fn $syscall( + &self, + fn_ptr: Option<&extern "system" fn($($arg_type),*) -> $result>, + $($arg: $arg_type),* + ) -> $result { + use $crate::common::constants::{CoroutineState, SyscallName, SyscallState}; + use $crate::scheduler::{SchedulableCoroutine, SchedulableSuspender}; + + if let Ok(arc) = $crate::net::EventLoops::$syscall($($arg, )*) { + if let Some(co) = SchedulableCoroutine::current() { + if let CoroutineState::Syscall((), syscall, SyscallState::Executing) = co.state() + { + let new_state = SyscallState::Suspend(u64::MAX); + if co.syscall((), syscall, new_state).is_err() { + $crate::error!( + "{} change to syscall {} {} failed !", + co.name(), + syscall, + new_state + ); + } + } + } + if let Some(suspender) = SchedulableSuspender::current() { + suspender.suspend(); + //回来的时候,系统调用已经执行完了 + } + if let Some(co) = SchedulableCoroutine::current() { + if let CoroutineState::Syscall((), syscall, SyscallState::Callback) = co.state() + { + let new_state = SyscallState::Executing; + if co.syscall((), syscall, new_state).is_err() { + $crate::error!( + "{} change to syscall {} {} failed !", + co.name(), syscall, new_state + ); + } + } + } + let (lock, cvar) = &*arc; + let mut syscall_result = cvar + .wait_while(lock.lock().expect("lock failed"), + |&mut result| result.is_none() + ) + .expect("lock failed") + .expect("no syscall result"); + if syscall_result < 0 { + $crate::syscall::set_errno((-syscall_result).try_into().expect("errno overflow")); + if SyscallName::accept == SyscallName::$syscall { + syscall_result = 0; + } else { + syscall_result = -1; + } + } + return <$result>::try_from(syscall_result).expect("overflow"); + } + self.inner.$syscall(fn_ptr, $($arg, )*) + } + } + } +} + macro_rules! impl_nio_read { - ( $struct_name:ident, $trait_name: ident, $syscall: ident($fd: ident : $fd_type: ty, $($arg: ident : $arg_type: ty),*) -> $result: ty ) => { + ( + $struct_name:ident, $trait_name: ident, + $syscall: ident($fd: ident : $fd_type: ty, $($arg: ident : $arg_type: ty),*) -> $result: ty + ) => { #[repr(C)] #[derive(Debug, Default)] struct $struct_name { @@ -99,8 +182,15 @@ macro_rules! impl_nio_read { } macro_rules! impl_nio_read_buf { - ( $struct_name:ident, $trait_name: ident, $syscall: ident($fd: ident : $fd_type: ty, - $buf: ident : $buf_type: ty, $len: ident : $len_type: ty, $($arg: ident : $arg_type: ty),*) -> $result: ty ) => { + ( + $struct_name:ident, $trait_name: ident, + $syscall: ident( + $fd: ident : $fd_type: ty, + $buf: ident : $buf_type: ty, + $len: ident : $len_type: ty + $(, $($arg: ident : $arg_type: ty),*)? + ) -> $result: ty + ) => { #[repr(C)] #[derive(Debug, Default)] struct $struct_name { @@ -110,11 +200,18 @@ macro_rules! impl_nio_read_buf { impl $trait_name for $struct_name { extern "system" fn $syscall( &self, - fn_ptr: Option<&extern "system" fn($fd_type, $buf_type, $len_type, $($arg_type),*) -> $result>, + fn_ptr: Option< + &extern "system" fn( + $fd_type, + $buf_type, + $len_type + $(, $($arg_type),*)? + ) -> $result + >, $fd: $fd_type, $buf: $buf_type, - $len: $len_type, - $($arg: $arg_type),* + $len: $len_type + $(, $($arg: $arg_type),*)? ) -> $result { let blocking = $crate::syscall::is_blocking($fd); if blocking { @@ -130,7 +227,7 @@ macro_rules! impl_nio_read_buf { $fd, ($buf as usize + usize::try_from(received).expect("overflow")) as windows_sys::core::PSTR, $len - received, - $($arg, )* + $($($arg, )*)? ); if r != -1 { $crate::syscall::reset_errno(); @@ -169,8 +266,16 @@ macro_rules! impl_nio_read_buf { } macro_rules! impl_nio_read_iovec { - ( $struct_name:ident, $trait_name: ident, $syscall: ident($fd: ident : $fd_type: ty, - $iov: ident : $iov_type: ty, $iovcnt: ident : $iovcnt_type: ty, $($arg: ident : $arg_type: ty),*) -> $result: ty ) => { + ( + $struct_name:ident, $trait_name: ident, + $syscall: ident( + $fd: ident : $fd_type: ty, + $iov: ident : $iov_type: ty, + $iovcnt: ident : $iovcnt_type: ty, + $recvd: ident : $recvd_type: ty, + $($arg: ident : $arg_type: ty),* + ) -> $result: ty + ) => { #[repr(C)] #[derive(Debug, Default)] struct $struct_name { @@ -180,10 +285,19 @@ macro_rules! impl_nio_read_iovec { impl $trait_name for $struct_name { extern "system" fn $syscall( &self, - fn_ptr: Option<&extern "system" fn($fd_type, $iov_type, $iovcnt_type, $($arg_type),*) -> $result>, + fn_ptr: Option< + &extern "system" fn( + $fd_type, + $iov_type, + $iovcnt_type, + $recvd_type, + $($arg_type),* + ) -> $result + >, $fd: $fd_type, $iov: $iov_type, $iovcnt: $iovcnt_type, + $recvd: $recvd_type, $($arg: $arg_type),* ) -> $result { let blocking = $crate::syscall::is_blocking($fd); @@ -228,20 +342,15 @@ macro_rules! impl_nio_read_iovec { std::ffi::c_uint::try_from(arg.len()).unwrap_or_else(|_| { panic!("{} iovcnt overflow", $crate::common::constants::SyscallName::$syscall) }), + $recvd, $($arg, )* ); - if r == 0 { - r = received.try_into().expect("overflow"); - std::mem::forget(vec); - if blocking { - $crate::syscall::set_blocking($fd); - } - return r; - } else if r != -1 { + if r != -1 { $crate::syscall::reset_errno(); received += usize::try_from(r).expect("overflow"); if received >= length { - r = received.try_into().expect("overflow"); + r = 0; + unsafe{ $recvd.write(received.try_into().expect("overflow")) }; break; } offset = received.saturating_sub(length); @@ -258,7 +367,8 @@ macro_rules! impl_nio_read_iovec { $fd.try_into().expect("overflow"), Some(wait_time) ).is_err() { - r = received.try_into().expect("overflow"); + r = 0; + unsafe{ $recvd.write(received.try_into().expect("overflow")) }; std::mem::forget(vec); if blocking { $crate::syscall::set_blocking($fd); @@ -288,8 +398,15 @@ macro_rules! impl_nio_read_iovec { } macro_rules! impl_nio_write_buf { - ( $struct_name:ident, $trait_name: ident, $syscall: ident($fd: ident : $fd_type: ty, - $buf: ident : $buf_type: ty, $len: ident : $len_type: ty, $($arg: ident : $arg_type: ty),*) -> $result: ty ) => { + ( + $struct_name:ident, $trait_name: ident, + $syscall: ident( + $fd: ident : $fd_type: ty, + $buf: ident : $buf_type: ty, + $len: ident : $len_type: ty + $(, $($arg: ident : $arg_type: ty),*)? + ) -> $result: ty + ) => { #[repr(C)] #[derive(Debug, Default)] struct $struct_name { @@ -299,11 +416,18 @@ macro_rules! impl_nio_write_buf { impl $trait_name for $struct_name { extern "system" fn $syscall( &self, - fn_ptr: Option<&extern "system" fn($fd_type, $buf_type, $len_type, $($arg_type),*) -> $result>, + fn_ptr: Option< + &extern "system" fn( + $fd_type, + $buf_type, + $len_type + $(, $($arg_type),*)? + ) -> $result + >, $fd: $fd_type, $buf: $buf_type, - $len: $len_type, - $($arg: $arg_type),* + $len: $len_type + $(, $($arg: $arg_type),*)? ) -> $result { let blocking = $crate::syscall::is_blocking($fd); if blocking { @@ -319,7 +443,7 @@ macro_rules! impl_nio_write_buf { $fd, ($buf as usize + usize::try_from(sent).expect("overflow")) as windows_sys::core::PSTR, $len - sent, - $($arg, )* + $($($arg, )*)? ); if r != -1 { $crate::syscall::reset_errno(); @@ -358,8 +482,15 @@ macro_rules! impl_nio_write_buf { } macro_rules! impl_nio_write_iovec { - ( $struct_name:ident, $trait_name: ident, $syscall: ident($fd: ident : $fd_type: ty, - $iov: ident : $iov_type: ty, $iovcnt: ident : $iovcnt_type: ty, $($arg: ident : $arg_type: ty),*) -> $result: ty ) => { + ( + $struct_name:ident, $trait_name: ident, + $syscall: ident( + $fd: ident : $fd_type: ty, + $iov: ident : $iov_type: ty, + $iovcnt: ident : $iovcnt_type: ty, + $sent: ident : $sent_type: ty, + $($arg: ident : $arg_type: ty),* + ) -> $result: ty ) => { #[repr(C)] #[derive(Debug, Default)] struct $struct_name { @@ -369,10 +500,19 @@ macro_rules! impl_nio_write_iovec { impl $trait_name for $struct_name { extern "system" fn $syscall( &self, - fn_ptr: Option<&extern "system" fn($fd_type, $iov_type, $iovcnt_type, $($arg_type),*) -> $result>, + fn_ptr: Option< + &extern "system" fn( + $fd_type, + $iov_type, + $iovcnt_type, + $sent_type, + $($arg_type),* + ) -> $result + >, $fd: $fd_type, $iov: $iov_type, $iovcnt: $iovcnt_type, + $sent: $sent_type, $($arg: $arg_type),* ) -> $result { let blocking = $crate::syscall::is_blocking($fd); @@ -417,13 +557,15 @@ macro_rules! impl_nio_write_iovec { std::ffi::c_uint::try_from(arg.len()).unwrap_or_else(|_| { panic!("{} iovcnt overflow", $crate::common::constants::SyscallName::$syscall) }), + $sent, $($arg, )* ); if r != -1 { $crate::syscall::reset_errno(); sent += usize::try_from(r).expect("overflow"); if sent >= length { - r = sent.try_into().expect("overflow"); + r = 0; + unsafe{ $sent.write(sent.try_into().expect("overflow")) }; break; } offset = sent.saturating_sub(length); @@ -440,7 +582,8 @@ macro_rules! impl_nio_write_iovec { $fd.try_into().expect("overflow"), Some(wait_time) ).is_err() { - r = sent.try_into().expect("overflow"); + r = 0; + unsafe{ $sent.write(sent.try_into().expect("overflow")) }; std::mem::forget(vec); if blocking { $crate::syscall::set_blocking($fd); @@ -470,7 +613,10 @@ macro_rules! impl_nio_write_iovec { } macro_rules! impl_raw { - ( $struct_name: ident, $trait_name: ident, $($mod_name: ident)::*, $syscall: ident($($arg: ident : $arg_type: ty),*) -> $result: ty ) => { + ( + $struct_name: ident, $trait_name: ident, $($mod_name: ident)::*, + $syscall: ident($($arg: ident : $arg_type: ty),*) -> $result: ty + ) => { #[repr(C)] #[derive(Debug, Copy, Clone, Default)] struct $struct_name {} diff --git a/core/src/syscall/windows/shutdown.rs b/core/src/syscall/windows/shutdown.rs index 0b624ead..53f2f8e9 100644 --- a/core/src/syscall/windows/shutdown.rs +++ b/core/src/syscall/windows/shutdown.rs @@ -2,7 +2,7 @@ use crate::net::EventLoops; use crate::syscall::set_errno; use once_cell::sync::Lazy; use std::ffi::c_int; -use windows_sys::Win32::Networking::WinSock::{SOCKET, WINSOCK_SHUTDOWN_HOW}; +use windows_sys::Win32::Networking::WinSock::{SOCKET, WINSOCK_SHUTDOWN_HOW, SD_RECEIVE, SD_SEND, SD_BOTH, WSAEINVAL}; #[must_use] pub extern "system" fn shutdown( @@ -42,13 +42,11 @@ impl ShutdownSyscall for NioShutdownSyscall { { let fd = fd.try_into().expect("overflow"); _ = match how { - windows_sys::Win32::Networking::WinSock::SD_RECEIVE => { - EventLoops::del_read_event(fd) - } - windows_sys::Win32::Networking::WinSock::SD_SEND => EventLoops::del_write_event(fd), - windows_sys::Win32::Networking::WinSock::SD_BOTH => EventLoops::del_event(fd), + SD_RECEIVE => EventLoops::del_read_event(fd), + SD_SEND => EventLoops::del_write_event(fd), + SD_BOTH => EventLoops::del_event(fd), _ => { - set_errno(windows_sys::Win32::Networking::WinSock::WSAEINVAL.try_into().expect("overflow")); + set_errno(WSAEINVAL.try_into().expect("overflow")); return -1; } }; diff --git a/hook/Cargo.toml b/hook/Cargo.toml index e1c324cc..7dfca4d3 100644 --- a/hook/Cargo.toml +++ b/hook/Cargo.toml @@ -49,6 +49,12 @@ net = ["open-coroutine-core/net"] # Provide io_uring adaptation, this feature only works in linux. io_uring = ["open-coroutine-core/io_uring"] +# Provide IOCP adaptation, this feature only works in windows. +iocp = ["open-coroutine-core/iocp"] + +# Provide completion IO adaptation +completion_io = ["open-coroutine-core/completion_io"] + # Provide syscall implementation. syscall = ["open-coroutine-core/syscall"] diff --git a/hook/src/syscall/windows.rs b/hook/src/syscall/windows.rs index 1cfb2970..7f154c7c 100644 --- a/hook/src/syscall/windows.rs +++ b/hook/src/syscall/windows.rs @@ -81,7 +81,7 @@ unsafe fn attach() -> std::io::Result<()> { impl_hook!("ws2_32.dll", SOCKET, socket(domain: c_int, ty: WINSOCK_SOCKET_TYPE, protocol: IPPROTO) -> SOCKET); impl_hook!("ws2_32.dll", SETSOCKOPT, setsockopt(socket: SOCKET, level: c_int, name: c_int, value: PSTR, option_len: c_int) -> c_int); impl_hook!("ws2_32.dll", WSARECV, WSARecv(fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, lpnumberofbytesrecvd: *mut c_uint, lpflags: *mut c_uint, lpoverlapped: *mut OVERLAPPED, lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE) -> c_int); - impl_hook!("ws2_32.dll", WSASEND, WSASend(fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, lpnumberofbytesrecvd: *mut c_uint, dwflags: c_uint, lpoverlapped: *mut OVERLAPPED, lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE) -> c_int); + impl_hook!("ws2_32.dll", WSASEND, WSASend(fd: SOCKET, buf: *const WSABUF, dwbuffercount: c_uint, lpnumberofbytessent: *mut c_uint, dwflags: c_uint, lpoverlapped: *mut OVERLAPPED, lpcompletionroutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE) -> c_int); impl_hook!("ws2_32.dll", WSASOCKETW, WSASocketW(domain: c_int, ty: WINSOCK_SOCKET_TYPE, protocol: IPPROTO, lpprotocolinfo: *const WSAPROTOCOL_INFOW, g: c_uint, dw_flags: c_uint) -> SOCKET); impl_hook!("ws2_32.dll", SELECT, select(nfds: c_int, readfds: *mut FD_SET, writefds: *mut FD_SET, errorfds: *mut FD_SET, timeout: *mut TIMEVAL) -> c_int); impl_hook!("ws2_32.dll", WSAPOLL, WSAPoll(fds: *mut WSAPOLLFD, nfds: c_uint, timeout: c_int) -> c_int); diff --git a/open-coroutine/Cargo.toml b/open-coroutine/Cargo.toml index dc948442..9615f8bc 100644 --- a/open-coroutine/Cargo.toml +++ b/open-coroutine/Cargo.toml @@ -60,5 +60,11 @@ net = ["open-coroutine-hook/net", "open-coroutine-core/net"] # This feature only works in linux. io_uring = ["open-coroutine-hook/io_uring", "open-coroutine-core/io_uring"] +# Provide IOCP adaptation, this feature only works in windows. +iocp = ["open-coroutine-hook/iocp", "open-coroutine-core/iocp"] + +# Provide completion IO adaptation +completion_io = ["open-coroutine-hook/completion_io", "open-coroutine-core/completion_io"] + # Provide syscall implementation. syscall = ["open-coroutine-hook/syscall", "open-coroutine-core/syscall"] diff --git a/open-coroutine/build.rs b/open-coroutine/build.rs index d2b167a6..72dd42ea 100644 --- a/open-coroutine/build.rs +++ b/open-coroutine/build.rs @@ -155,6 +155,12 @@ fn main() { if cfg!(feature = "io_uring") { features.push("io_uring"); } + if cfg!(feature = "iocp") { + features.push("iocp"); + } + if cfg!(feature = "completion_io") { + features.push("completion_io"); + } if cfg!(feature = "syscall") { features.push("syscall"); }