From 54eeac1c042fb3eddb97773d67206447d52cf9d0 Mon Sep 17 00:00:00 2001 From: rkuklik Date: Tue, 18 Jun 2024 22:42:25 +0200 Subject: [PATCH 1/8] fix: resolver should be set to "2" in top level manifest --- Cargo.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index a547d77..d16be66 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,6 @@ [workspace] +# cargo complains that this defaults to one in virtual package manifests (for some reason) +resolver = "2" members = ["client", "daemon", "common"] default-members = ["client", "daemon"] From 6d49d25ef6a04cfa3b17919263ac257b5ea20478 Mon Sep 17 00:00:00 2001 From: rkuklik Date: Tue, 18 Jun 2024 22:43:23 +0200 Subject: [PATCH 2/8] refactor(ipc): get rid of unused `Answer::Err` --- client/src/main.rs | 2 -- common/src/ipc/mod.rs | 19 ------------------- 2 files changed, 21 deletions(-) diff --git a/client/src/main.rs b/client/src/main.rs index f49f86a..5d68af5 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -47,7 +47,6 @@ fn process_swww_args(args: &Swww) -> Result<(), String> { let bytes = read_socket(&socket)?; drop(socket); match Answer::receive(bytes) { - Answer::Err(msg) => return Err(msg.to_string()), Answer::Info(info) => info.iter().for_each(|i| println!("{}", i)), Answer::Ok => { if let Swww::Kill = args { @@ -249,7 +248,6 @@ fn get_format_dims_and_outputs( Ok((format, dims, outputs)) } } - Answer::Err(e) => Err(format!("daemon error when sending query: {e}")), _ => unreachable!(), } } diff --git a/common/src/ipc/mod.rs b/common/src/ipc/mod.rs index e8511c5..7c8d936 100644 --- a/common/src/ipc/mod.rs +++ b/common/src/ipc/mod.rs @@ -239,7 +239,6 @@ pub enum Answer { Ok, Ping(bool), Info(Box<[BgInfo]>), - Err(String), } impl Answer { @@ -250,7 +249,6 @@ impl Answer { Self::Ping(true) => 1u64.to_ne_bytes(), Self::Ping(false) => 2u64.to_ne_bytes(), Self::Info(_) => 3u64.to_ne_bytes(), - Self::Err(_) => 4u64.to_ne_bytes(), }); let mmap = match self { @@ -268,14 +266,6 @@ impl Answer { Some(mmap) } - Self::Err(s) => { - let len = 4 + s.len(); - let mut mmap = Mmap::create(len); - let bytes = mmap.slice_mut(); - bytes[0..4].copy_from_slice(&(s.as_bytes().len() as u32).to_ne_bytes()); - bytes[4..len].copy_from_slice(s.as_bytes()); - Some(mmap) - } _ => None, }; @@ -308,15 +298,6 @@ impl Answer { Self::Info(bg_infos.into()) } - 4 => { - let mmap = socket_msg.shm.unwrap(); - let bytes = mmap.slice(); - let size = u32::from_ne_bytes(bytes[0..4].try_into().unwrap()) as usize; - let s = std::str::from_utf8(&bytes[4..4 + size]) - .expect("received a non utf8 string from socket") - .to_string(); - Self::Err(s) - } _ => panic!("Received malformed answer from daemon"), } } From 089ab9ef63047f877f85eb54f26b0942406fd180 Mon Sep 17 00:00:00 2001 From: rkuklik Date: Tue, 18 Jun 2024 22:48:44 +0200 Subject: [PATCH 3/8] refactor(socket): introduce new `IpcSocket` type --- common/src/ipc/error.rs | 78 +++++++++++++++ common/src/ipc/mod.rs | 2 + common/src/ipc/socket.rs | 203 ++++++++++++++++++++++++++------------- 3 files changed, 216 insertions(+), 67 deletions(-) create mode 100644 common/src/ipc/error.rs diff --git a/common/src/ipc/error.rs b/common/src/ipc/error.rs new file mode 100644 index 0000000..602671a --- /dev/null +++ b/common/src/ipc/error.rs @@ -0,0 +1,78 @@ +use std::error::Error; +use std::fmt; + +use rustix::io::Errno; + +/// Failiures if IPC with added context +#[derive(Debug)] +pub struct IpcError { + err: Errno, + kind: IpcErrorKind, +} + +impl IpcError { + pub(crate) fn new(kind: IpcErrorKind, err: Errno) -> Self { + Self { err, kind } + } +} + +#[derive(Debug)] +pub enum IpcErrorKind { + /// Failed to create file descriptor + Socket, + /// Failed to connect to socket + Connect, + /// Binding on socket failed + Bind, + /// Listening on socket failed + Listen, + /// Socket file wasn't found + NoSocketFile, + /// Socket timeout couldn't be set + SetTimeout, +} + +impl IpcErrorKind { + fn description(&self) -> &'static str { + match self { + Self::Socket => "failed to create socket file descriptor", + Self::Connect => "failed to connect to socket", + Self::Bind => "failed to bind to socket", + Self::Listen => "failed to listen on socket", + Self::NoSocketFile => "Socket file not found. Are you sure swww-daemon is running?", + Self::SetTimeout => "failed to set read timeout for socket", + } + } +} + +impl fmt::Display for IpcError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.kind.description()) + } +} + +impl Error for IpcError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + Some(&self.err) + } +} + +/// Simplify generating [`IpcError`]s from [`Errno`] +pub(crate) trait ErrnoExt { + type Output; + fn context(self, kind: IpcErrorKind) -> Self::Output; +} + +impl ErrnoExt for Errno { + type Output = IpcError; + fn context(self, kind: IpcErrorKind) -> Self::Output { + IpcError::new(kind, self) + } +} + +impl ErrnoExt for Result { + type Output = Result; + fn context(self, kind: IpcErrorKind) -> Self::Output { + self.map_err(|error| error.context(kind)) + } +} diff --git a/common/src/ipc/mod.rs b/common/src/ipc/mod.rs index 7c8d936..d45174d 100644 --- a/common/src/ipc/mod.rs +++ b/common/src/ipc/mod.rs @@ -2,11 +2,13 @@ use std::path::PathBuf; use rustix::fd::OwnedFd; +mod error; mod mmap; mod socket; mod types; use crate::cache; +pub use error::*; pub use mmap::*; pub use socket::*; pub use types::*; diff --git a/common/src/ipc/socket.rs b/common/src/ipc/socket.rs index 7935218..6b2e699 100644 --- a/common/src/ipc/socket.rs +++ b/common/src/ipc/socket.rs @@ -1,10 +1,17 @@ -use std::{path::PathBuf, time::Duration}; - -use rustix::{ - fd::OwnedFd, - net::{self, RecvFlags}, -}; - +use std::env; +use std::marker::PhantomData; +use std::path::PathBuf; +use std::sync::OnceLock; +use std::time::Duration; + +use rustix::fd::OwnedFd; +use rustix::io::Errno; +use rustix::net; +use rustix::net::RecvFlags; + +use super::ErrnoExt; +use super::IpcError; +use super::IpcErrorKind; use super::Mmap; pub struct SocketMsg { @@ -12,6 +19,123 @@ pub struct SocketMsg { pub(super) shm: Option, } +/// Represents client in IPC communication, via typestate pattern in [`IpcSocket`] +pub struct Client; +/// Represents server in IPC communication, via typestate pattern in [`IpcSocket`] +pub struct Server; + +/// Typesafe handle for socket facilitating communication between [`Client`] and [`Server`] +pub struct IpcSocket { + fd: OwnedFd, + phantom: PhantomData, +} + +impl IpcSocket { + /// Creates new [`IpcSocket`] from provided [`OwnedFd`] + /// + /// TODO: remove external ability to construct [`Self`] from random file descriptors + pub fn new(fd: OwnedFd) -> Self { + Self { + fd, + phantom: PhantomData, + } + } + + fn socket_file() -> String { + let runtime = env::var("XDG_RUNTIME_DIR"); + let display = env::var("WAYLAND_DISPLAY"); + + let runtime = runtime.as_deref().unwrap_or("/tmp/swww"); + let display = display.as_deref().unwrap_or("wayland-0"); + + format!("{runtime}/swww-{display}.socket") + } + + /// Retreives path to socket file + /// + /// To treat this as filesystem path, wrap it in [`Path`]. + /// If you get errors with missing generics, you can shove any type as `T`, but + /// [`Client`] or [`Server`] are recommended. + /// + /// [`Path`]: std::path::Path + #[must_use] + pub fn path() -> &'static str { + static PATH: OnceLock = OnceLock::new(); + PATH.get_or_init(Self::socket_file) + } + + #[must_use] + pub fn as_fd(&self) -> &OwnedFd { + &self.fd + } +} + +impl IpcSocket { + /// Connects to already running `Daemon`, if there is one. + pub fn connect() -> Result { + // these were hardcoded everywhere, no point in passing them around + let tries = 5; + let interval = 100; + + let socket = net::socket_with( + net::AddressFamily::UNIX, + net::SocketType::STREAM, + net::SocketFlags::CLOEXEC, + None, + ) + .context(IpcErrorKind::Socket)?; + + let addr = net::SocketAddrUnix::new(Self::path()).expect("addr is correct"); + + // this will be overwriten, Rust just doesn't know it + let mut error = Errno::INVAL; + for _ in 0..tries { + match net::connect_unix(&socket, &addr) { + Ok(()) => { + #[cfg(debug_assertions)] + let timeout = Duration::from_secs(30); //Some operations take a while to respond in debug mode + #[cfg(not(debug_assertions))] + let timeout = Duration::from_secs(5); + return net::sockopt::set_socket_timeout( + &socket, + net::sockopt::Timeout::Recv, + Some(timeout), + ) + .context(IpcErrorKind::SetTimeout) + .map(|()| Self::new(socket)); + } + Err(e) => error = e, + } + std::thread::sleep(Duration::from_millis(interval)); + } + + let kind = if error.kind() == std::io::ErrorKind::NotFound { + IpcErrorKind::NoSocketFile + } else { + IpcErrorKind::Connect + }; + + Err(error.context(kind)) + } +} + +impl IpcSocket { + /// Creates [`IpcSocket`] for use in server (i.e `Daemon`) + pub fn server() -> Result { + let addr = net::SocketAddrUnix::new(Self::path()).expect("addr is correct"); + let socket = net::socket_with( + net::AddressFamily::UNIX, + net::SocketType::STREAM, + net::SocketFlags::CLOEXEC.union(rustix::net::SocketFlags::NONBLOCK), + None, + ) + .context(IpcErrorKind::Socket)?; + net::bind_unix(&socket, &addr).context(IpcErrorKind::Bind)?; + net::listen(&socket, 0).context(IpcErrorKind::Listen)?; + Ok(Self::new(socket)) + } +} + pub fn read_socket(stream: &OwnedFd) -> Result { let mut buf = [0u8; 16]; let mut ancillary_buf = [0u8; rustix::cmsg_space!(ScmRights(1))]; @@ -72,70 +196,15 @@ pub(super) fn send_socket_msg( #[must_use] pub fn get_socket_path() -> PathBuf { - let runtime_dir = if let Ok(dir) = std::env::var("XDG_RUNTIME_DIR") { - dir - } else { - "/tmp/swww".to_string() - }; - - let mut socket_path = PathBuf::new(); - socket_path.push(runtime_dir); - - let mut socket_name = String::new(); - socket_name.push_str("swww-"); - if let Ok(socket) = std::env::var("WAYLAND_DISPLAY") { - socket_name.push_str(socket.as_str()); - } else { - socket_name.push_str("wayland-0") - } - socket_name.push_str(".socket"); - - socket_path.push(socket_name); - - socket_path + IpcSocket::::path().into() } /// We make sure the Stream is always set to blocking mode /// /// * `tries` - how many times to attempt the connection /// * `interval` - how long to wait between attempts, in milliseconds -pub fn connect_to_socket(addr: &PathBuf, tries: u8, interval: u64) -> Result { - let socket = rustix::net::socket_with( - rustix::net::AddressFamily::UNIX, - rustix::net::SocketType::STREAM, - rustix::net::SocketFlags::CLOEXEC, - None, - ) - .expect("failed to create socket file descriptor"); - let addr = net::SocketAddrUnix::new(addr).unwrap(); - //Make sure we try at least once - let tries = if tries == 0 { 1 } else { tries }; - let mut error = None; - for _ in 0..tries { - match net::connect_unix(&socket, &addr) { - Ok(()) => { - #[cfg(debug_assertions)] - let timeout = Duration::from_secs(30); //Some operations take a while to respond in debug mode - #[cfg(not(debug_assertions))] - let timeout = Duration::from_secs(5); - if let Err(e) = net::sockopt::set_socket_timeout( - &socket, - net::sockopt::Timeout::Recv, - Some(timeout), - ) { - return Err(format!("failed to set read timeout for socket: {e}")); - } - - return Ok(socket); - } - Err(e) => error = Some(e), - } - std::thread::sleep(Duration::from_millis(interval)); - } - let error = error.unwrap(); - if error.kind() == std::io::ErrorKind::NotFound { - return Err("Socket file not found. Are you sure swww-daemon is running?".to_string()); - } - - Err(format!("Failed to connect to socket: {error}")) +pub fn connect_to_socket(_: &PathBuf, _: u8, _: u64) -> Result { + IpcSocket::connect() + .map(|socket| socket.fd) + .map_err(|err| err.to_string()) } From 64bd080a2578fe9472939b65bcf588eabdf7c754 Mon Sep 17 00:00:00 2001 From: rkuklik Date: Tue, 18 Jun 2024 23:00:37 +0200 Subject: [PATCH 4/8] refactor(ipc): replace old methods with new --- client/src/main.rs | 33 ++++++++++++------------ common/src/ipc/socket.rs | 20 +++------------ daemon/src/main.rs | 55 +++++++++++++++------------------------- 3 files changed, 41 insertions(+), 67 deletions(-) diff --git a/client/src/main.rs b/client/src/main.rs index 5d68af5..633c5be 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -1,9 +1,9 @@ -use clap::Parser; -use std::time::Duration; +use std::{path::Path, time::Duration}; +use clap::Parser; use common::{ cache, - ipc::{self, connect_to_socket, get_socket_path, read_socket, Answer, RequestSend}, + ipc::{self, read_socket, Answer, Client, IpcSocket, RequestSend}, }; mod imgproc; @@ -19,10 +19,10 @@ fn main() -> Result<(), String> { return cache::clean().map_err(|e| format!("failed to clean the cache: {e}")); } + let socket = IpcSocket::connect().map_err(|err| err.to_string())?; loop { - let socket = connect_to_socket(&get_socket_path(), 5, 100)?; - RequestSend::Ping.send(&socket)?; - let bytes = read_socket(&socket)?; + RequestSend::Ping.send(socket.as_fd())?; + let bytes = read_socket(socket.as_fd())?; let answer = Answer::receive(bytes); if let Answer::Ping(configured) = answer { if configured { @@ -42,9 +42,9 @@ fn process_swww_args(args: &Swww) -> Result<(), String> { Some(request) => request, None => return Ok(()), }; - let socket = connect_to_socket(&get_socket_path(), 5, 100)?; - request.send(&socket)?; - let bytes = read_socket(&socket)?; + let socket = IpcSocket::connect().map_err(|err| err.to_string())?; + request.send(socket.as_fd())?; + let bytes = read_socket(socket.as_fd())?; drop(socket); match Answer::receive(bytes) { Answer::Info(info) => info.iter().for_each(|i| println!("{}", i)), @@ -54,16 +54,15 @@ fn process_swww_args(args: &Swww) -> Result<(), String> { let tries = 20; #[cfg(not(debug_assertions))] let tries = 10; - let socket_path = get_socket_path(); + let path = IpcSocket::::path(); + let path = Path::new(path); for _ in 0..tries { - if !socket_path.exists() { + if !path.exists() { return Ok(()); } std::thread::sleep(Duration::from_millis(100)); } - return Err(format!( - "Could not confirm socket deletion at: {socket_path:?}" - )); + return Err(format!("Could not confirm socket deletion at: {path:?}")); } } Answer::Ping(_) => { @@ -213,9 +212,9 @@ fn get_format_dims_and_outputs( let mut dims: Vec<(u32, u32)> = Vec::new(); let mut imgs: Vec = Vec::new(); - let socket = connect_to_socket(&get_socket_path(), 5, 100)?; - RequestSend::Query.send(&socket)?; - let bytes = read_socket(&socket)?; + let socket = IpcSocket::connect().map_err(|err| err.to_string())?; + RequestSend::Query.send(socket.as_fd())?; + let bytes = read_socket(socket.as_fd())?; drop(socket); let answer = Answer::receive(bytes); match answer { diff --git a/common/src/ipc/socket.rs b/common/src/ipc/socket.rs index 6b2e699..74271ef 100644 --- a/common/src/ipc/socket.rs +++ b/common/src/ipc/socket.rs @@ -1,6 +1,5 @@ use std::env; use std::marker::PhantomData; -use std::path::PathBuf; use std::sync::OnceLock; use std::time::Duration; @@ -41,6 +40,10 @@ impl IpcSocket { } } + pub fn to_fd(self) -> OwnedFd { + self.fd + } + fn socket_file() -> String { let runtime = env::var("XDG_RUNTIME_DIR"); let display = env::var("WAYLAND_DISPLAY"); @@ -193,18 +196,3 @@ pub(super) fn send_socket_msg( net::sendmsg(stream, &[iov], &mut ancillary, net::SendFlags::empty()) .map(|written| written == socket_msg.len()) } - -#[must_use] -pub fn get_socket_path() -> PathBuf { - IpcSocket::::path().into() -} - -/// We make sure the Stream is always set to blocking mode -/// -/// * `tries` - how many times to attempt the connection -/// * `interval` - how long to wait between attempts, in milliseconds -pub fn connect_to_socket(_: &PathBuf, _: u8, _: u64) -> Result { - IpcSocket::connect() - .map(|socket| socket.fd) - .map_err(|err| err.to_string()) -} diff --git a/daemon/src/main.rs b/daemon/src/main.rs index 47f1a25..ccf4dd9 100644 --- a/daemon/src/main.rs +++ b/daemon/src/main.rs @@ -23,7 +23,7 @@ use std::{ fs, io::{IsTerminal, Write}, num::{NonZeroI32, NonZeroU32}, - path::PathBuf, + path::Path, sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -31,8 +31,8 @@ use std::{ }; use common::ipc::{ - connect_to_socket, get_socket_path, read_socket, Answer, BgInfo, ImageReq, MmappedStr, - RequestRecv, RequestSend, Scale, + read_socket, Answer, BgInfo, ImageReq, IpcSocket, MmappedStr, RequestRecv, RequestSend, Scale, + Server, }; use animations::Animator; @@ -525,25 +525,26 @@ fn setup_signals() { struct SocketWrapper(OwnedFd); impl SocketWrapper { fn new() -> Result { - let socket_addr = get_socket_path(); + let addr = IpcSocket::::path(); + let addr = Path::new(addr); - if socket_addr.exists() { - if is_daemon_running(&socket_addr)? { + if addr.exists() { + if is_daemon_running()? { return Err( "There is an swww-daemon instance already running on this socket!".to_string(), ); } else { warn!( "socket file {} was not deleted when the previous daemon exited", - socket_addr.to_string_lossy() + addr.to_string_lossy() ); - if let Err(e) = std::fs::remove_file(&socket_addr) { + if let Err(e) = std::fs::remove_file(addr) { return Err(format!("failed to delete previous socket: {e}")); } } } - let runtime_dir = match socket_addr.parent() { + let runtime_dir = match addr.parent() { Some(path) => path, None => return Err("couldn't find a valid runtime directory".to_owned()), }; @@ -555,34 +556,20 @@ impl SocketWrapper { } } - let socket = rustix::net::socket_with( - rustix::net::AddressFamily::UNIX, - rustix::net::SocketType::STREAM, - rustix::net::SocketFlags::CLOEXEC.union(rustix::net::SocketFlags::NONBLOCK), - None, - ) - .expect("failed to create socket file descriptor"); - - rustix::net::bind_unix( - &socket, - &rustix::net::SocketAddrUnix::new(&socket_addr).unwrap(), - ) - .unwrap(); - - rustix::net::listen(&socket, 0).unwrap(); + let socket = IpcSocket::server().map_err(|err| err.to_string())?; - debug!("Created socket in {:?}", socket_addr); - Ok(Self(socket)) + debug!("Created socket in {:?}", addr); + Ok(Self(socket.to_fd())) } } impl Drop for SocketWrapper { fn drop(&mut self) { - let socket_addr = get_socket_path(); - if let Err(e) = fs::remove_file(&socket_addr) { - error!("Failed to remove socket at {socket_addr:?}: {e}"); + let addr = IpcSocket::::path(); + if let Err(e) = fs::remove_file(Path::new(addr)) { + error!("Failed to remove socket at {addr}: {e}"); } - info!("Removed socket at {:?}", socket_addr); + info!("Removed socket at {addr}"); } } @@ -648,16 +635,16 @@ fn make_logger(quiet: bool) { .unwrap(); } -pub fn is_daemon_running(addr: &PathBuf) -> Result { - let sock = match connect_to_socket(addr, 5, 100) { +pub fn is_daemon_running() -> Result { + let sock = match IpcSocket::connect() { Ok(s) => s, // likely a connection refused; either way, this is a reliable signal there's no surviving // daemon. Err(_) => return Ok(false), }; - RequestSend::Ping.send(&sock)?; - let answer = Answer::receive(read_socket(&sock)?); + RequestSend::Ping.send(sock.as_fd())?; + let answer = Answer::receive(read_socket(sock.as_fd())?); match answer { Answer::Ping(_) => Ok(true), _ => Err("Daemon did not return Answer::Ping, as expected".to_string()), From b28f3bf7ee262aaefaeccd1b92503cca0cedd4d6 Mon Sep 17 00:00:00 2001 From: rkuklik Date: Wed, 19 Jun 2024 07:28:23 +0200 Subject: [PATCH 5/8] refactor(ipc): rework messaging format v1 --- client/src/main.rs | 14 +- common/src/ipc/error.rs | 9 ++ common/src/ipc/mod.rs | 155 ++------------------ common/src/ipc/socket.rs | 65 --------- common/src/ipc/transmit.rs | 292 +++++++++++++++++++++++++++++++++++++ daemon/src/main.rs | 14 +- 6 files changed, 328 insertions(+), 221 deletions(-) create mode 100644 common/src/ipc/transmit.rs diff --git a/client/src/main.rs b/client/src/main.rs index 633c5be..fdb96f7 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -3,7 +3,7 @@ use std::{path::Path, time::Duration}; use clap::Parser; use common::{ cache, - ipc::{self, read_socket, Answer, Client, IpcSocket, RequestSend}, + ipc::{self, Answer, Client, IpcSocket, RequestSend}, }; mod imgproc; @@ -21,8 +21,8 @@ fn main() -> Result<(), String> { let socket = IpcSocket::connect().map_err(|err| err.to_string())?; loop { - RequestSend::Ping.send(socket.as_fd())?; - let bytes = read_socket(socket.as_fd())?; + RequestSend::Ping.send(&socket)?; + let bytes = socket.recv().map_err(|err| err.to_string())?; let answer = Answer::receive(bytes); if let Answer::Ping(configured) = answer { if configured { @@ -43,8 +43,8 @@ fn process_swww_args(args: &Swww) -> Result<(), String> { None => return Ok(()), }; let socket = IpcSocket::connect().map_err(|err| err.to_string())?; - request.send(socket.as_fd())?; - let bytes = read_socket(socket.as_fd())?; + request.send(&socket)?; + let bytes = socket.recv().map_err(|err| err.to_string())?; drop(socket); match Answer::receive(bytes) { Answer::Info(info) => info.iter().for_each(|i| println!("{}", i)), @@ -213,8 +213,8 @@ fn get_format_dims_and_outputs( let mut imgs: Vec = Vec::new(); let socket = IpcSocket::connect().map_err(|err| err.to_string())?; - RequestSend::Query.send(socket.as_fd())?; - let bytes = read_socket(socket.as_fd())?; + RequestSend::Query.send(&socket)?; + let bytes = socket.recv().map_err(|err| err.to_string())?; drop(socket); let answer = Answer::receive(bytes); match answer { diff --git a/common/src/ipc/error.rs b/common/src/ipc/error.rs index 602671a..247be93 100644 --- a/common/src/ipc/error.rs +++ b/common/src/ipc/error.rs @@ -30,6 +30,12 @@ pub enum IpcErrorKind { NoSocketFile, /// Socket timeout couldn't be set SetTimeout, + /// IPC contained invalid identification code + BadCode, + /// IPC payload was broken + MalformedMsg, + /// Reading socket failed + Read, } impl IpcErrorKind { @@ -41,6 +47,9 @@ impl IpcErrorKind { Self::Listen => "failed to listen on socket", Self::NoSocketFile => "Socket file not found. Are you sure swww-daemon is running?", Self::SetTimeout => "failed to set read timeout for socket", + Self::BadCode => "invalid message code", + Self::MalformedMsg => "malformed ancillary message", + Self::Read => "failed to receive message", } } } diff --git a/common/src/ipc/mod.rs b/common/src/ipc/mod.rs index d45174d..2401ef5 100644 --- a/common/src/ipc/mod.rs +++ b/common/src/ipc/mod.rs @@ -1,10 +1,11 @@ use std::path::PathBuf; -use rustix::fd::OwnedFd; +use transmit::RawMsg; mod error; mod mmap; mod socket; +mod transmit; mod types; use crate::cache; @@ -136,104 +137,20 @@ pub enum RequestRecv { } impl RequestSend { - pub fn send(&self, stream: &OwnedFd) -> Result<(), String> { - let mut socket_msg = [0u8; 16]; - socket_msg[0..8].copy_from_slice(&match self { - Self::Ping => 0u64.to_ne_bytes(), - Self::Query => 1u64.to_ne_bytes(), - Self::Clear(_) => 2u64.to_ne_bytes(), - Self::Img(_) => 3u64.to_ne_bytes(), - Self::Kill => 4u64.to_ne_bytes(), - }); - - let mmap = match self { - Self::Clear(clear) => Some(clear), - Self::Img(img) => Some(img), - _ => None, - }; - - match send_socket_msg(stream, &mut socket_msg, mmap) { - Ok(true) => (), - Ok(false) => return Err("failed to send full length of message in socket!".to_string()), - Err(e) => return Err(format!("failed to write serialized request: {e}")), + pub fn send(self, stream: &IpcSocket) -> Result<(), String> { + match stream.send(self.into()) { + Ok(true) => Ok(()), + Ok(false) => Err("failed to send full length of message in socket!".to_string()), + Err(e) => Err(format!("failed to write serialized request: {e}")), } - - Ok(()) } } impl RequestRecv { #[must_use] #[inline] - pub fn receive(socket_msg: SocketMsg) -> Self { - let ret = match socket_msg.code { - 0 => Self::Ping, - 1 => Self::Query, - 2 => { - let mmap = socket_msg.shm.unwrap(); - let bytes = mmap.slice(); - let len = bytes[0] as usize; - let mut outputs = Vec::with_capacity(len); - let mut i = 1; - for _ in 0..len { - let output = MmappedStr::new(&mmap, &bytes[i..]); - i += 4 + output.str().len(); - outputs.push(output); - } - let color = [bytes[i], bytes[i + 1], bytes[i + 2]]; - Self::Clear(ClearReq { - color, - outputs: outputs.into(), - }) - } - 3 => { - let mmap = socket_msg.shm.unwrap(); - let bytes = mmap.slice(); - let transition = Transition::deserialize(&bytes[0..]); - let len = bytes[51] as usize; - - let mut imgs = Vec::with_capacity(len); - let mut outputs = Vec::with_capacity(len); - let mut animations = Vec::with_capacity(len); - - let mut i = 52; - for _ in 0..len { - let (img, offset) = ImgReq::deserialize(&mmap, &bytes[i..]); - i += offset; - imgs.push(img); - - let n_outputs = bytes[i] as usize; - i += 1; - let mut out = Vec::with_capacity(n_outputs); - for _ in 0..n_outputs { - let output = MmappedStr::new(&mmap, &bytes[i..]); - i += 4 + output.str().len(); - out.push(output); - } - outputs.push(out.into()); - - if bytes[i] == 1 { - let (animation, offset) = Animation::deserialize(&mmap, &bytes[i + 1..]); - i += offset; - animations.push(animation); - } - i += 1; - } - - Self::Img(ImageReq { - transition, - imgs: imgs.into(), - outputs: outputs.into(), - animations: if animations.is_empty() { - None - } else { - Some(animations.into()) - }, - }) - } - _ => Self::Kill, - }; - ret + pub fn receive(msg: RawMsg) -> Self { + msg.into() } } @@ -244,34 +161,8 @@ pub enum Answer { } impl Answer { - pub fn send(&self, stream: &OwnedFd) -> Result<(), String> { - let mut socket_msg = [0u8; 16]; - socket_msg[0..8].copy_from_slice(&match self { - Self::Ok => 0u64.to_ne_bytes(), - Self::Ping(true) => 1u64.to_ne_bytes(), - Self::Ping(false) => 2u64.to_ne_bytes(), - Self::Info(_) => 3u64.to_ne_bytes(), - }); - - let mmap = match self { - Self::Info(infos) => { - let len = 1 + infos.iter().map(|i| i.serialized_size()).sum::(); - let mut mmap = Mmap::create(len); - let bytes = mmap.slice_mut(); - - bytes[0] = infos.len() as u8; - let mut i = 1; - - for info in infos.iter() { - i += info.serialize(&mut bytes[i..]); - } - - Some(mmap) - } - _ => None, - }; - - match send_socket_msg(stream, &mut socket_msg, mmap.as_ref()) { + pub fn send(self, stream: &IpcSocket) -> Result<(), String> { + match stream.send(self.into()) { Ok(true) => Ok(()), Ok(false) => Err("failed to send full length of message in socket!".to_string()), Err(e) => Err(format!("failed to write serialized request: {e}")), @@ -280,27 +171,7 @@ impl Answer { #[must_use] #[inline] - pub fn receive(socket_msg: SocketMsg) -> Self { - match socket_msg.code { - 0 => Self::Ok, - 1 => Self::Ping(true), - 2 => Self::Ping(false), - 3 => { - let mmap = socket_msg.shm.unwrap(); - let bytes = mmap.slice(); - let len = bytes[0] as usize; - let mut bg_infos = Vec::with_capacity(len); - - let mut i = 1; - for _ in 0..len { - let (info, offset) = BgInfo::deserialize(&bytes[i..]); - i += offset; - bg_infos.push(info); - } - - Self::Info(bg_infos.into()) - } - _ => panic!("Received malformed answer from daemon"), - } + pub fn receive(msg: RawMsg) -> Self { + msg.into() } } diff --git a/common/src/ipc/socket.rs b/common/src/ipc/socket.rs index 74271ef..0de2959 100644 --- a/common/src/ipc/socket.rs +++ b/common/src/ipc/socket.rs @@ -6,17 +6,10 @@ use std::time::Duration; use rustix::fd::OwnedFd; use rustix::io::Errno; use rustix::net; -use rustix::net::RecvFlags; use super::ErrnoExt; use super::IpcError; use super::IpcErrorKind; -use super::Mmap; - -pub struct SocketMsg { - pub(super) code: u8, - pub(super) shm: Option, -} /// Represents client in IPC communication, via typestate pattern in [`IpcSocket`] pub struct Client; @@ -138,61 +131,3 @@ impl IpcSocket { Ok(Self::new(socket)) } } - -pub fn read_socket(stream: &OwnedFd) -> Result { - let mut buf = [0u8; 16]; - let mut ancillary_buf = [0u8; rustix::cmsg_space!(ScmRights(1))]; - - let mut control = net::RecvAncillaryBuffer::new(&mut ancillary_buf); - - let mut tries = 0; - loop { - let iov = rustix::io::IoSliceMut::new(&mut buf); - match net::recvmsg(stream, &mut [iov], &mut control, RecvFlags::WAITALL) { - Ok(_) => break, - Err(e) => { - if e.kind() == std::io::ErrorKind::WouldBlock && tries < 5 { - std::thread::sleep(Duration::from_millis(1)); - } else { - return Err(format!("failed to read serialized length: {e}")); - } - } - } - tries += 1; - } - - let code = u64::from_ne_bytes(buf[0..8].try_into().unwrap()) as u8; - let len = u64::from_ne_bytes(buf[8..16].try_into().unwrap()) as usize; - - let shm = if len == 0 { - None - } else { - let shm_file = match control.drain().next().unwrap() { - net::RecvAncillaryMessage::ScmRights(mut iter) => iter.next().unwrap(), - _ => panic!("malformed ancillary message"), - }; - Some(Mmap::from_fd(shm_file, len)) - }; - Ok(SocketMsg { code, shm }) -} - -pub(super) fn send_socket_msg( - stream: &OwnedFd, - socket_msg: &mut [u8; 16], - mmap: Option<&Mmap>, -) -> rustix::io::Result { - let mut ancillary_buf = [0u8; rustix::cmsg_space!(ScmRights(1))]; - let mut ancillary = net::SendAncillaryBuffer::new(&mut ancillary_buf); - - let msg_buf; - if let Some(mmap) = mmap.as_ref() { - socket_msg[8..].copy_from_slice(&(mmap.len() as u64).to_ne_bytes()); - msg_buf = [mmap.fd()]; - let msg = net::SendAncillaryMessage::ScmRights(&msg_buf); - ancillary.push(msg); - } - - let iov = rustix::io::IoSlice::new(&socket_msg[..]); - net::sendmsg(stream, &[iov], &mut ancillary, net::SendFlags::empty()) - .map(|written| written == socket_msg.len()) -} diff --git a/common/src/ipc/transmit.rs b/common/src/ipc/transmit.rs new file mode 100644 index 0000000..e8b03c2 --- /dev/null +++ b/common/src/ipc/transmit.rs @@ -0,0 +1,292 @@ +use std::thread; +use std::time::Duration; + +use rustix::io; +use rustix::io::Errno; +use rustix::net; +use rustix::net::RecvFlags; + +use super::Animation; +use super::Answer; +use super::BgInfo; +use super::ClearReq; +use super::ErrnoExt; +use super::ImageReq; +use super::ImgReq; +use super::IpcError; +use super::IpcErrorKind; +use super::IpcSocket; +use super::Mmap; +use super::MmappedStr; +use super::RequestRecv; +use super::RequestSend; +use super::Transition; + +// could be enum +pub struct RawMsg { + code: Code, + shm: Option, +} + +impl From for RawMsg { + fn from(value: RequestSend) -> Self { + let code = match value { + RequestSend::Ping => Code::ReqPing, + RequestSend::Query => Code::ReqQuery, + RequestSend::Clear(_) => Code::ReqClear, + RequestSend::Img(_) => Code::ReqImg, + RequestSend::Kill => Code::ReqKill, + }; + + let shm = match value { + RequestSend::Clear(mem) | RequestSend::Img(mem) => Some(mem), + _ => None, + }; + + Self { code, shm } + } +} + +impl From for RawMsg { + fn from(value: Answer) -> Self { + let code = match value { + Answer::Ok => Code::ResOk, + Answer::Ping(true) => Code::ResPingTrue, + Answer::Ping(false) => Code::ResPingFalse, + Answer::Info(_) => Code::ResInfo, + }; + + let shm = if let Answer::Info(infos) = value { + let len = 1 + infos + .iter() + .map(|info| info.serialized_size()) + .sum::(); + let mut mmap = Mmap::create(len); + let bytes = mmap.slice_mut(); + + bytes[0] = infos.len() as u8; + let mut i = 1; + + for info in infos.iter() { + i += info.serialize(&mut bytes[i..]); + } + + Some(mmap) + } else { + None + }; + + Self { code, shm } + } +} + +// TODO: remove this ugly mess +impl From for RequestRecv { + fn from(value: RawMsg) -> Self { + match value.code { + Code::ReqPing => Self::Ping, + Code::ReqQuery => Self::Query, + Code::ReqClear => { + let mmap = value.shm.unwrap(); + let bytes = mmap.slice(); + let len = bytes[0] as usize; + let mut outputs = Vec::with_capacity(len); + let mut i = 1; + for _ in 0..len { + let output = MmappedStr::new(&mmap, &bytes[i..]); + i += 4 + output.str().len(); + outputs.push(output); + } + let color = [bytes[i], bytes[i + 1], bytes[i + 2]]; + Self::Clear(ClearReq { + color, + outputs: outputs.into(), + }) + } + Code::ReqImg => { + let mmap = value.shm.unwrap(); + let bytes = mmap.slice(); + let transition = Transition::deserialize(&bytes[0..]); + let len = bytes[51] as usize; + + let mut imgs = Vec::with_capacity(len); + let mut outputs = Vec::with_capacity(len); + let mut animations = Vec::with_capacity(len); + + let mut i = 52; + for _ in 0..len { + let (img, offset) = ImgReq::deserialize(&mmap, &bytes[i..]); + i += offset; + imgs.push(img); + + let n_outputs = bytes[i] as usize; + i += 1; + let mut out = Vec::with_capacity(n_outputs); + for _ in 0..n_outputs { + let output = MmappedStr::new(&mmap, &bytes[i..]); + i += 4 + output.str().len(); + out.push(output); + } + outputs.push(out.into()); + + if bytes[i] == 1 { + let (animation, offset) = Animation::deserialize(&mmap, &bytes[i + 1..]); + i += offset; + animations.push(animation); + } + i += 1; + } + + Self::Img(ImageReq { + transition, + imgs: imgs.into(), + outputs: outputs.into(), + animations: if animations.is_empty() { + None + } else { + Some(animations.into()) + }, + }) + } + Code::ReqKill => Self::Kill, + _ => Self::Kill, + } + } +} + +impl From for Answer { + fn from(value: RawMsg) -> Self { + match value.code { + Code::ResOk => Self::Ok, + Code::ResPingTrue => Self::Ping(true), + Code::ResPingFalse => Self::Ping(false), + Code::ResInfo => { + let mmap = value.shm.unwrap(); + let bytes = mmap.slice(); + let len = bytes[0] as usize; + let mut bg_infos = Vec::with_capacity(len); + + let mut i = 1; + for _ in 0..len { + let (info, offset) = BgInfo::deserialize(&bytes[i..]); + i += offset; + bg_infos.push(info); + } + + Self::Info(bg_infos.into()) + } + _ => panic!("Received malformed answer from daemon"), + } + } +} +// TODO: end remove ugly mess block + +macro_rules! code { + ($($name:ident $num:literal),* $(,)?) => { + pub enum Code { + $($name,)* + } + + impl Code { + const fn into(self) -> u64 { + match self { + $(Self::$name => $num,)* + } + } + + const fn from(num: u64) -> Option { + match num { + $($num => Some(Self::$name),)* + _ => None + } + } + } + + }; +} + +code! { + ReqPing 0, + ReqQuery 1, + ReqClear 2, + ReqImg 3, + ReqKill 4, + ResOk 5, + ResPingTrue 6, + ResPingFalse 7, + ResInfo 8, +} + +impl TryFrom for Code { + type Error = IpcError; + fn try_from(value: u64) -> Result { + Self::from(value).ok_or(IpcError::new(IpcErrorKind::BadCode, Errno::DOM)) + } +} + +// TODO: this along with `RawMsg` should be implementation detail +impl IpcSocket { + pub fn send(&self, msg: RawMsg) -> io::Result { + let mut payload = [0u8; 16]; + payload[0..8].copy_from_slice(&msg.code.into().to_ne_bytes()); + + let mut ancillary_buf = [0u8; rustix::cmsg_space!(ScmRights(1))]; + let mut ancillary = net::SendAncillaryBuffer::new(&mut ancillary_buf); + + let fd; + if let Some(ref mmap) = msg.shm { + payload[8..].copy_from_slice(&(mmap.len() as u64).to_ne_bytes()); + fd = [mmap.fd()]; + let msg = net::SendAncillaryMessage::ScmRights(&fd); + ancillary.push(msg); + } + + let iov = io::IoSlice::new(&payload[..]); + net::sendmsg( + self.as_fd(), + &[iov], + &mut ancillary, + net::SendFlags::empty(), + ) + .map(|written| written == payload.len()) + } + + pub fn recv(&self) -> Result { + let mut buf = [0u8; 16]; + let mut ancillary_buf = [0u8; rustix::cmsg_space!(ScmRights(1))]; + + let mut control = net::RecvAncillaryBuffer::new(&mut ancillary_buf); + + for _ in 0..5 { + let iov = io::IoSliceMut::new(&mut buf); + match net::recvmsg(self.as_fd(), &mut [iov], &mut control, RecvFlags::WAITALL) { + Ok(_) => break, + Err(Errno::WOULDBLOCK | Errno::INTR) => thread::sleep(Duration::from_millis(1)), + Err(err) => return Err(err).context(IpcErrorKind::Read), + } + } + + let code = u64::from_ne_bytes(buf[0..8].try_into().unwrap()).try_into()?; + let len = u64::from_ne_bytes(buf[8..16].try_into().unwrap()) as usize; + + let shm = if len == 0 { + debug_assert!(matches!( + code, + Code::ReqClear | Code::ReqImg | Code::ResInfo + )); + None + } else { + let file = control + .drain() + .next() + .and_then(|msg| match msg { + net::RecvAncillaryMessage::ScmRights(mut iter) => iter.next(), + _ => None, + }) + .ok_or(Errno::BADMSG) + .context(IpcErrorKind::MalformedMsg)?; + Some(Mmap::from_fd(file, len)) + }; + Ok(RawMsg { code, shm }) + } +} diff --git a/daemon/src/main.rs b/daemon/src/main.rs index ccf4dd9..e551a88 100644 --- a/daemon/src/main.rs +++ b/daemon/src/main.rs @@ -31,8 +31,7 @@ use std::{ }; use common::ipc::{ - read_socket, Answer, BgInfo, ImageReq, IpcSocket, MmappedStr, RequestRecv, RequestSend, Scale, - Server, + Answer, BgInfo, ImageReq, IpcSocket, MmappedStr, RequestRecv, RequestSend, Scale, Server, }; use animations::Animator; @@ -124,8 +123,8 @@ impl Daemon { ))); } - fn recv_socket_msg(&mut self, stream: OwnedFd) { - let bytes = match common::ipc::read_socket(&stream) { + fn recv_socket_msg(&mut self, stream: IpcSocket) { + let bytes = match stream.recv() { Ok(bytes) => bytes, Err(e) => { error!("FATAL: cannot read socket: {e}. Exiting..."); @@ -470,7 +469,8 @@ fn main() -> Result<(), String> { if !fds[1].revents().is_empty() { match rustix::net::accept(&listener.0) { - Ok(stream) => daemon.recv_socket_msg(stream), + // TODO: abstract away explicit socket creation + Ok(stream) => daemon.recv_socket_msg(IpcSocket::new(stream)), Err(rustix::io::Errno::INTR | rustix::io::Errno::WOULDBLOCK) => continue, Err(e) => return Err(format!("failed to accept incoming connection: {e}")), } @@ -643,8 +643,8 @@ pub fn is_daemon_running() -> Result { Err(_) => return Ok(false), }; - RequestSend::Ping.send(sock.as_fd())?; - let answer = Answer::receive(read_socket(sock.as_fd())?); + RequestSend::Ping.send(&sock)?; + let answer = Answer::receive(sock.recv().map_err(|err| err.to_string())?); match answer { Answer::Ping(_) => Ok(true), _ => Err("Daemon did not return Answer::Ping, as expected".to_string()), From e8ddf48b06fd5b7ba9824f8e0b8cf0dec31fe03b Mon Sep 17 00:00:00 2001 From: rkuklik Date: Wed, 19 Jun 2024 09:43:08 +0200 Subject: [PATCH 6/8] refactor(mmap): extract into top level module --- client/src/main.rs | 9 ++++----- common/src/cache.rs | 4 +++- common/src/compression/mod.rs | 6 +++++- common/src/ipc/mod.rs | 3 +-- common/src/ipc/transmit.rs | 4 ++-- common/src/ipc/types.rs | 5 ++++- common/src/lib.rs | 1 + common/src/{ipc => }/mmap.rs | 16 ++++++++++------ daemon/src/main.rs | 5 ++--- daemon/src/wayland/bump_pool.rs | 2 +- 10 files changed, 33 insertions(+), 22 deletions(-) rename common/src/{ipc => }/mmap.rs (97%) diff --git a/client/src/main.rs b/client/src/main.rs index fdb96f7..4a525a0 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -1,10 +1,9 @@ use std::{path::Path, time::Duration}; use clap::Parser; -use common::{ - cache, - ipc::{self, Answer, Client, IpcSocket, RequestSend}, -}; +use common::cache; +use common::ipc::{self, Answer, Client, IpcSocket, RequestSend}; +use common::mmap::Mmap; mod imgproc; use imgproc::*; @@ -111,7 +110,7 @@ fn make_img_request( dims: &[(u32, u32)], pixel_format: ipc::PixelFormat, outputs: &[Vec], -) -> Result { +) -> Result { let transition = make_transition(img); let mut img_req_builder = ipc::ImageRequestBuilder::new(transition); diff --git a/common/src/cache.rs b/common/src/cache.rs index 493e0de..e384f35 100644 --- a/common/src/cache.rs +++ b/common/src/cache.rs @@ -10,7 +10,9 @@ use std::{ path::{Path, PathBuf}, }; -use crate::ipc::{Animation, Mmap, PixelFormat}; +use crate::ipc::Animation; +use crate::ipc::PixelFormat; +use crate::mmap::Mmap; pub(crate) fn store(output_name: &str, img_path: &str) -> io::Result<()> { let mut filepath = cache_dir()?; diff --git a/common/src/compression/mod.rs b/common/src/compression/mod.rs index 369b759..eb27346 100644 --- a/common/src/compression/mod.rs +++ b/common/src/compression/mod.rs @@ -6,7 +6,11 @@ use comp::pack_bytes; use decomp::{unpack_bytes_3channels, unpack_bytes_4channels}; use std::ffi::{c_char, c_int}; -use crate::ipc::{ImageRequestBuilder, Mmap, MmappedBytes, PixelFormat}; +use crate::ipc::ImageRequestBuilder; +use crate::ipc::PixelFormat; +use crate::mmap::Mmap; +use crate::mmap::MmappedBytes; + mod comp; mod cpu; mod decomp; diff --git a/common/src/ipc/mod.rs b/common/src/ipc/mod.rs index 2401ef5..6565261 100644 --- a/common/src/ipc/mod.rs +++ b/common/src/ipc/mod.rs @@ -3,14 +3,13 @@ use std::path::PathBuf; use transmit::RawMsg; mod error; -mod mmap; mod socket; mod transmit; mod types; use crate::cache; +use crate::mmap::Mmap; pub use error::*; -pub use mmap::*; pub use socket::*; pub use types::*; diff --git a/common/src/ipc/transmit.rs b/common/src/ipc/transmit.rs index e8b03c2..8e40bc3 100644 --- a/common/src/ipc/transmit.rs +++ b/common/src/ipc/transmit.rs @@ -16,11 +16,11 @@ use super::ImgReq; use super::IpcError; use super::IpcErrorKind; use super::IpcSocket; -use super::Mmap; -use super::MmappedStr; use super::RequestRecv; use super::RequestSend; use super::Transition; +use crate::mmap::Mmap; +use crate::mmap::MmappedStr; // could be enum pub struct RawMsg { diff --git a/common/src/ipc/types.rs b/common/src/ipc/types.rs index 9de514a..5bc60c5 100644 --- a/common/src/ipc/types.rs +++ b/common/src/ipc/types.rs @@ -5,8 +5,11 @@ use std::{ }; use crate::compression::BitPack; +use crate::mmap::Mmap; +use crate::mmap::MmappedBytes; +use crate::mmap::MmappedStr; -use super::{ImageRequestBuilder, Mmap, MmappedBytes, MmappedStr}; +use super::ImageRequestBuilder; #[derive(Clone, PartialEq)] pub enum Coord { diff --git a/common/src/lib.rs b/common/src/lib.rs index ddc7a86..a7c8271 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -1,3 +1,4 @@ pub mod cache; pub mod compression; pub mod ipc; +pub mod mmap; diff --git a/common/src/ipc/mmap.rs b/common/src/mmap.rs similarity index 97% rename from common/src/ipc/mmap.rs rename to common/src/mmap.rs index ad4083c..9d2f579 100644 --- a/common/src/ipc/mmap.rs +++ b/common/src/mmap.rs @@ -1,11 +1,15 @@ use std::ptr::NonNull; -use rustix::{ - fd::{AsFd, BorrowedFd, OwnedFd}, - io::Errno, - mm::{mmap, munmap, MapFlags, ProtFlags}, - shm::{Mode, ShmOFlags}, -}; +use rustix::fd::AsFd; +use rustix::fd::BorrowedFd; +use rustix::fd::OwnedFd; +use rustix::io::Errno; +use rustix::mm::mmap; +use rustix::mm::munmap; +use rustix::mm::MapFlags; +use rustix::mm::ProtFlags; +use rustix::shm::Mode; +use rustix::shm::ShmOFlags; #[derive(Debug)] pub struct Mmap { diff --git a/daemon/src/main.rs b/daemon/src/main.rs index e551a88..91575b7 100644 --- a/daemon/src/main.rs +++ b/daemon/src/main.rs @@ -30,9 +30,8 @@ use std::{ }, }; -use common::ipc::{ - Answer, BgInfo, ImageReq, IpcSocket, MmappedStr, RequestRecv, RequestSend, Scale, Server, -}; +use common::ipc::{Answer, BgInfo, ImageReq, IpcSocket, RequestRecv, RequestSend, Scale, Server}; +use common::mmap::MmappedStr; use animations::Animator; diff --git a/daemon/src/wayland/bump_pool.rs b/daemon/src/wayland/bump_pool.rs index 14a2956..a38f99a 100644 --- a/daemon/src/wayland/bump_pool.rs +++ b/daemon/src/wayland/bump_pool.rs @@ -1,6 +1,6 @@ use std::sync::atomic::{AtomicBool, Ordering}; -use common::ipc::Mmap; +use common::mmap::Mmap; use super::{globals, ObjectId}; From f666f3799857e5fcd3d780ddc3867410eed13ee1 Mon Sep 17 00:00:00 2001 From: rkuklik Date: Wed, 19 Jun 2024 11:57:30 +0200 Subject: [PATCH 7/8] refactor(mmap): unify `Mapped'thing'` with generic bool --- common/src/mmap.rs | 97 ++++++++-------------------------------------- 1 file changed, 17 insertions(+), 80 deletions(-) diff --git a/common/src/mmap.rs b/common/src/mmap.rs index 9d2f579..b145be3 100644 --- a/common/src/mmap.rs +++ b/common/src/mmap.rs @@ -251,40 +251,23 @@ fn create_memfd() -> rustix::io::Result { } } -pub struct MmappedBytes { +pub struct Mmapped { base_ptr: NonNull, ptr: NonNull, len: usize, } -impl MmappedBytes { +pub type MmappedBytes = Mmapped; +pub type MmappedStr = Mmapped; + +impl Mmapped { const PROT: ProtFlags = ProtFlags::READ; const FLAGS: MapFlags = MapFlags::SHARED; pub(crate) fn new(map: &Mmap, bytes: &[u8]) -> Self { let len = u32::from_ne_bytes(bytes[0..4].try_into().unwrap()) as usize; - let offset = 4 + bytes.as_ptr() as usize - map.ptr.as_ptr() as usize; - let page_size = rustix::param::page_size(); - let page_offset = offset - offset % page_size; - - let base_ptr = unsafe { - let ptr = mmap( - std::ptr::null_mut(), - len + (offset - page_offset), - Self::PROT, - Self::FLAGS, - &map.fd, - page_offset as u64, - ) - .unwrap(); - // SAFETY: the function above will never return a null pointer if it succeeds - // POSIX says that the implementation will never select an address at 0 - NonNull::new_unchecked(ptr) - }; - let ptr = - unsafe { NonNull::new_unchecked(base_ptr.as_ptr().byte_add(offset - page_offset)) }; - - Self { base_ptr, ptr, len } + let bytes = &bytes[4..]; + Self::new_with_len(map, bytes, len) } pub(crate) fn new_with_len(map: &Mmap, bytes: &[u8], len: usize) -> Self { @@ -309,6 +292,12 @@ impl MmappedBytes { let ptr = unsafe { NonNull::new_unchecked(base_ptr.as_ptr().byte_add(offset - page_offset)) }; + if UTF8 { + // try to parse, panicking if we fail + let s = unsafe { std::slice::from_raw_parts(ptr.as_ptr().cast(), len) }; + let _s = std::str::from_utf8(s).expect("received a non utf8 string from socket"); + } + Self { base_ptr, ptr, len } } @@ -319,59 +308,7 @@ impl MmappedBytes { } } -impl Drop for MmappedBytes { - #[inline] - fn drop(&mut self) { - let len = self.len + self.ptr.as_ptr() as usize - self.base_ptr.as_ptr() as usize; - if let Err(e) = unsafe { munmap(self.base_ptr.as_ptr(), len) } { - eprintln!("ERROR WHEN UNMAPPING MEMORY: {e}"); - } - } -} - -unsafe impl Send for MmappedBytes {} -unsafe impl Sync for MmappedBytes {} - -pub struct MmappedStr { - base_ptr: NonNull, - ptr: NonNull, - len: usize, -} - impl MmappedStr { - const PROT: ProtFlags = ProtFlags::READ; - const FLAGS: MapFlags = MapFlags::SHARED; - - pub(crate) fn new(map: &Mmap, bytes: &[u8]) -> Self { - let len = u32::from_ne_bytes(bytes[0..4].try_into().unwrap()) as usize; - let offset = 4 + bytes.as_ptr() as usize - map.ptr.as_ptr() as usize; - let page_size = rustix::param::page_size(); - let page_offset = offset - offset % page_size; - - let base_ptr = unsafe { - let ptr = mmap( - std::ptr::null_mut(), - len + (offset - page_offset), - Self::PROT, - Self::FLAGS, - &map.fd, - page_offset as u64, - ) - .unwrap(); - // SAFETY: the function above will never return a null pointer if it succeeds - // POSIX says that the implementation will never select an address at 0 - NonNull::new_unchecked(ptr) - }; - let ptr = - unsafe { NonNull::new_unchecked(base_ptr.as_ptr().byte_add(offset - page_offset)) }; - - // try to parse, panicking if we fail - let s = unsafe { std::slice::from_raw_parts(ptr.as_ptr().cast(), len) }; - let _s = std::str::from_utf8(s).expect("received a non utf8 string from socket"); - - Self { base_ptr, ptr, len } - } - #[inline] #[must_use] pub fn str(&self) -> &str { @@ -380,10 +317,7 @@ impl MmappedStr { } } -unsafe impl Send for MmappedStr {} -unsafe impl Sync for MmappedStr {} - -impl Drop for MmappedStr { +impl Drop for Mmapped { #[inline] fn drop(&mut self) { let len = self.len + self.ptr.as_ptr() as usize - self.base_ptr.as_ptr() as usize; @@ -392,3 +326,6 @@ impl Drop for MmappedStr { } } } + +unsafe impl Send for Mmapped {} +unsafe impl Sync for Mmapped {} From 74f5d59d4c8213dd9ee8144b4c692f833f134da7 Mon Sep 17 00:00:00 2001 From: rkuklik Date: Wed, 19 Jun 2024 14:58:29 +0200 Subject: [PATCH 8/8] refactor(mmap): minor fonction movement --- common/src/ipc/transmit.rs | 13 ++-- common/src/mmap.rs | 153 +++++++++++++++++-------------------- 2 files changed, 77 insertions(+), 89 deletions(-) diff --git a/common/src/ipc/transmit.rs b/common/src/ipc/transmit.rs index 8e40bc3..47ecb8a 100644 --- a/common/src/ipc/transmit.rs +++ b/common/src/ipc/transmit.rs @@ -51,8 +51,8 @@ impl From for RawMsg { fn from(value: Answer) -> Self { let code = match value { Answer::Ok => Code::ResOk, - Answer::Ping(true) => Code::ResPingTrue, - Answer::Ping(false) => Code::ResPingFalse, + Answer::Ping(true) => Code::ResConfigured, + Answer::Ping(false) => Code::ResAwait, Answer::Info(_) => Code::ResInfo, }; @@ -158,8 +158,8 @@ impl From for Answer { fn from(value: RawMsg) -> Self { match value.code { Code::ResOk => Self::Ok, - Code::ResPingTrue => Self::Ping(true), - Code::ResPingFalse => Self::Ping(false), + Code::ResConfigured => Self::Ping(true), + Code::ResAwait => Self::Ping(false), Code::ResInfo => { let mmap = value.shm.unwrap(); let bytes = mmap.slice(); @@ -211,9 +211,10 @@ code! { ReqClear 2, ReqImg 3, ReqKill 4, + ResOk 5, - ResPingTrue 6, - ResPingFalse 7, + ResConfigured 6, + ResAwait 7, ResInfo 8, } diff --git a/common/src/mmap.rs b/common/src/mmap.rs index b145be3..87696fa 100644 --- a/common/src/mmap.rs +++ b/common/src/mmap.rs @@ -1,13 +1,19 @@ +use std::iter::repeat_with; use std::ptr::NonNull; +use std::time::SystemTime; +use std::time::UNIX_EPOCH; use rustix::fd::AsFd; use rustix::fd::BorrowedFd; use rustix::fd::OwnedFd; +use rustix::fs; +use rustix::io; use rustix::io::Errno; use rustix::mm::mmap; use rustix::mm::munmap; use rustix::mm::MapFlags; use rustix::mm::ProtFlags; +use rustix::shm; use rustix::shm::Mode; use rustix::shm::ShmOFlags; @@ -26,7 +32,7 @@ impl Mmap { #[inline] #[must_use] pub fn create(len: usize) -> Self { - let fd = create_shm_fd().unwrap(); + let fd = Self::mmap_fd().unwrap(); rustix::io::retry_on_intr(|| rustix::fs::ftruncate(&fd, len as u64)).unwrap(); let ptr = unsafe { @@ -43,6 +49,61 @@ impl Mmap { } } + #[cfg(target_os = "linux")] + fn mmap_fd() -> io::Result { + match Self::memfd() { + Ok(fd) => Ok(fd), + // Not supported, use fallback. + Err(Errno::NOSYS) => Self::shm(), + Err(err) => Err(err), + } + } + + #[cfg(not(target_os = "linux"))] + fn mmap_fd() -> io::Result { + Self::shm() + } + + fn shm() -> io::Result { + let mut filenames = repeat_with(SystemTime::now) + .map(|time| time.duration_since(UNIX_EPOCH).unwrap().subsec_nanos()) + .map(|stamp| format!("/swww-ipc-{stamp}",)); + + let flags = ShmOFlags::CREATE | ShmOFlags::EXCL | ShmOFlags::RDWR; + let mode = Mode::RUSR | Mode::WUSR; + + loop { + let filename = filenames.next().expect("infinite generator"); + match shm::shm_open(filename.as_str(), flags, mode) { + Ok(fd) => return shm::shm_unlink(filename.as_str()).map(|()| fd), + Err(Errno::EXIST | Errno::INTR) => continue, + Err(err) => return Err(err), + } + } + } + + #[cfg(target_os = "linux")] + fn memfd() -> io::Result { + use rustix::fs::MemfdFlags; + use rustix::fs::SealFlags; + use std::ffi::CStr; + + let name = CStr::from_bytes_with_nul(b"swww-ipc\0").unwrap(); + let flags = MemfdFlags::ALLOW_SEALING | MemfdFlags::CLOEXEC; + + loop { + match fs::memfd_create(name, flags) { + Ok(fd) => { + // We only need to seal for the purposes of optimization, ignore the errors. + let _ = fs::fcntl_add_seals(&fd, SealFlags::SHRINK | SealFlags::SEAL); + return Ok(fd); + } + Err(Errno::INTR) => continue, + Err(err) => return Err(err), + } + } + } + #[inline] /// Unmaps without destroying the file descriptor /// @@ -80,32 +141,28 @@ impl Mmap { } #[inline] - pub fn remap(&mut self, new_len: usize) { - rustix::io::retry_on_intr(|| rustix::fs::ftruncate(&self.fd, new_len as u64)).unwrap(); + pub fn remap(&mut self, new: usize) { + io::retry_on_intr(|| fs::ftruncate(&self.fd, new as u64)).unwrap(); #[cfg(target_os = "linux")] { - let result = unsafe { - rustix::mm::mremap( - self.ptr.as_ptr(), - self.len, - new_len, - rustix::mm::MremapFlags::MAYMOVE, - ) - }; + use rustix::mm; + + let result = + unsafe { mm::mremap(self.ptr.as_ptr(), self.len, new, mm::MremapFlags::MAYMOVE) }; if let Ok(ptr) = result { // SAFETY: the mremap above will never return a null pointer if it succeeds let ptr = unsafe { NonNull::new_unchecked(ptr) }; self.ptr = ptr; - self.len = new_len; + self.len = new; return; } } self.unmap(); - self.len = new_len; + self.len = new; self.ptr = unsafe { let ptr = mmap( std::ptr::null_mut(), @@ -181,76 +238,6 @@ impl Drop for Mmap { } } -fn create_shm_fd() -> std::io::Result { - #[cfg(target_os = "linux")] - { - match create_memfd() { - Ok(fd) => return Ok(fd), - // Not supported, use fallback. - Err(Errno::NOSYS) => (), - Err(err) => return Err(err.into()), - }; - } - - let time = std::time::SystemTime::now(); - let mut mem_file_handle = format!( - "/swww-ipc-{}", - time.duration_since(std::time::UNIX_EPOCH) - .unwrap() - .subsec_nanos() - ); - - let flags = ShmOFlags::CREATE | ShmOFlags::EXCL | ShmOFlags::RDWR; - let mode = Mode::RUSR | Mode::WUSR; - loop { - match rustix::shm::shm_open(mem_file_handle.as_str(), flags, mode) { - Ok(fd) => match rustix::shm::shm_unlink(mem_file_handle.as_str()) { - Ok(_) => return Ok(fd), - - Err(errno) => { - return Err(errno.into()); - } - }, - Err(Errno::EXIST) => { - // Change the handle if we happen to be duplicate. - let time = std::time::SystemTime::now(); - - mem_file_handle = format!( - "/swww-ipc-{}", - time.duration_since(std::time::UNIX_EPOCH) - .unwrap() - .subsec_nanos() - ); - - continue; - } - Err(Errno::INTR) => continue, - Err(err) => return Err(err.into()), - } - } -} - -#[cfg(target_os = "linux")] -fn create_memfd() -> rustix::io::Result { - use rustix::fs::{MemfdFlags, SealFlags}; - use std::ffi::CStr; - - let name = CStr::from_bytes_with_nul(b"swww-ipc\0").unwrap(); - let flags = MemfdFlags::ALLOW_SEALING | MemfdFlags::CLOEXEC; - - loop { - match rustix::fs::memfd_create(name, flags) { - Ok(fd) => { - // We only need to seal for the purposes of optimization, ignore the errors. - let _ = rustix::fs::fcntl_add_seals(&fd, SealFlags::SHRINK | SealFlags::SEAL); - return Ok(fd); - } - Err(Errno::INTR) => continue, - Err(err) => return Err(err), - } - } -} - pub struct Mmapped { base_ptr: NonNull, ptr: NonNull,