diff --git a/Cargo.toml b/Cargo.toml index a547d77c..d16be66b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,6 @@ [workspace] +# cargo complains that this defaults to one in virtual package manifests (for some reason) +resolver = "2" members = ["client", "daemon", "common"] default-members = ["client", "daemon"] diff --git a/client/src/main.rs b/client/src/main.rs index f49f86a2..4a525a02 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -1,10 +1,9 @@ -use clap::Parser; -use std::time::Duration; +use std::{path::Path, time::Duration}; -use common::{ - cache, - ipc::{self, connect_to_socket, get_socket_path, read_socket, Answer, RequestSend}, -}; +use clap::Parser; +use common::cache; +use common::ipc::{self, Answer, Client, IpcSocket, RequestSend}; +use common::mmap::Mmap; mod imgproc; use imgproc::*; @@ -19,10 +18,10 @@ fn main() -> Result<(), String> { return cache::clean().map_err(|e| format!("failed to clean the cache: {e}")); } + let socket = IpcSocket::connect().map_err(|err| err.to_string())?; loop { - let socket = connect_to_socket(&get_socket_path(), 5, 100)?; RequestSend::Ping.send(&socket)?; - let bytes = read_socket(&socket)?; + let bytes = socket.recv().map_err(|err| err.to_string())?; let answer = Answer::receive(bytes); if let Answer::Ping(configured) = answer { if configured { @@ -42,12 +41,11 @@ fn process_swww_args(args: &Swww) -> Result<(), String> { Some(request) => request, None => return Ok(()), }; - let socket = connect_to_socket(&get_socket_path(), 5, 100)?; + let socket = IpcSocket::connect().map_err(|err| err.to_string())?; request.send(&socket)?; - let bytes = read_socket(&socket)?; + let bytes = socket.recv().map_err(|err| err.to_string())?; drop(socket); match Answer::receive(bytes) { - Answer::Err(msg) => return Err(msg.to_string()), Answer::Info(info) => info.iter().for_each(|i| println!("{}", i)), Answer::Ok => { if let Swww::Kill = args { @@ -55,16 +53,15 @@ fn process_swww_args(args: &Swww) -> Result<(), String> { let tries = 20; #[cfg(not(debug_assertions))] let tries = 10; - let socket_path = get_socket_path(); + let path = IpcSocket::::path(); + let path = Path::new(path); for _ in 0..tries { - if !socket_path.exists() { + if !path.exists() { return Ok(()); } std::thread::sleep(Duration::from_millis(100)); } - return Err(format!( - "Could not confirm socket deletion at: {socket_path:?}" - )); + return Err(format!("Could not confirm socket deletion at: {path:?}")); } } Answer::Ping(_) => { @@ -113,7 +110,7 @@ fn make_img_request( dims: &[(u32, u32)], pixel_format: ipc::PixelFormat, outputs: &[Vec], -) -> Result { +) -> Result { let transition = make_transition(img); let mut img_req_builder = ipc::ImageRequestBuilder::new(transition); @@ -214,9 +211,9 @@ fn get_format_dims_and_outputs( let mut dims: Vec<(u32, u32)> = Vec::new(); let mut imgs: Vec = Vec::new(); - let socket = connect_to_socket(&get_socket_path(), 5, 100)?; + let socket = IpcSocket::connect().map_err(|err| err.to_string())?; RequestSend::Query.send(&socket)?; - let bytes = read_socket(&socket)?; + let bytes = socket.recv().map_err(|err| err.to_string())?; drop(socket); let answer = Answer::receive(bytes); match answer { @@ -249,7 +246,6 @@ fn get_format_dims_and_outputs( Ok((format, dims, outputs)) } } - Answer::Err(e) => Err(format!("daemon error when sending query: {e}")), _ => unreachable!(), } } diff --git a/common/src/cache.rs b/common/src/cache.rs index 493e0de7..e384f357 100644 --- a/common/src/cache.rs +++ b/common/src/cache.rs @@ -10,7 +10,9 @@ use std::{ path::{Path, PathBuf}, }; -use crate::ipc::{Animation, Mmap, PixelFormat}; +use crate::ipc::Animation; +use crate::ipc::PixelFormat; +use crate::mmap::Mmap; pub(crate) fn store(output_name: &str, img_path: &str) -> io::Result<()> { let mut filepath = cache_dir()?; diff --git a/common/src/compression/mod.rs b/common/src/compression/mod.rs index 369b7599..eb273465 100644 --- a/common/src/compression/mod.rs +++ b/common/src/compression/mod.rs @@ -6,7 +6,11 @@ use comp::pack_bytes; use decomp::{unpack_bytes_3channels, unpack_bytes_4channels}; use std::ffi::{c_char, c_int}; -use crate::ipc::{ImageRequestBuilder, Mmap, MmappedBytes, PixelFormat}; +use crate::ipc::ImageRequestBuilder; +use crate::ipc::PixelFormat; +use crate::mmap::Mmap; +use crate::mmap::MmappedBytes; + mod comp; mod cpu; mod decomp; diff --git a/common/src/ipc/error.rs b/common/src/ipc/error.rs new file mode 100644 index 00000000..247be93e --- /dev/null +++ b/common/src/ipc/error.rs @@ -0,0 +1,87 @@ +use std::error::Error; +use std::fmt; + +use rustix::io::Errno; + +/// Failiures if IPC with added context +#[derive(Debug)] +pub struct IpcError { + err: Errno, + kind: IpcErrorKind, +} + +impl IpcError { + pub(crate) fn new(kind: IpcErrorKind, err: Errno) -> Self { + Self { err, kind } + } +} + +#[derive(Debug)] +pub enum IpcErrorKind { + /// Failed to create file descriptor + Socket, + /// Failed to connect to socket + Connect, + /// Binding on socket failed + Bind, + /// Listening on socket failed + Listen, + /// Socket file wasn't found + NoSocketFile, + /// Socket timeout couldn't be set + SetTimeout, + /// IPC contained invalid identification code + BadCode, + /// IPC payload was broken + MalformedMsg, + /// Reading socket failed + Read, +} + +impl IpcErrorKind { + fn description(&self) -> &'static str { + match self { + Self::Socket => "failed to create socket file descriptor", + Self::Connect => "failed to connect to socket", + Self::Bind => "failed to bind to socket", + Self::Listen => "failed to listen on socket", + Self::NoSocketFile => "Socket file not found. Are you sure swww-daemon is running?", + Self::SetTimeout => "failed to set read timeout for socket", + Self::BadCode => "invalid message code", + Self::MalformedMsg => "malformed ancillary message", + Self::Read => "failed to receive message", + } + } +} + +impl fmt::Display for IpcError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.kind.description()) + } +} + +impl Error for IpcError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + Some(&self.err) + } +} + +/// Simplify generating [`IpcError`]s from [`Errno`] +pub(crate) trait ErrnoExt { + type Output; + fn context(self, kind: IpcErrorKind) -> Self::Output; +} + +impl ErrnoExt for Errno { + type Output = IpcError; + fn context(self, kind: IpcErrorKind) -> Self::Output { + IpcError::new(kind, self) + } +} + +impl ErrnoExt for Result { + type Output = Result; + fn context(self, kind: IpcErrorKind) -> Self::Output { + self.map_err(|error| error.context(kind)) + } +} diff --git a/common/src/ipc/mod.rs b/common/src/ipc/mod.rs index e8511c5e..65652619 100644 --- a/common/src/ipc/mod.rs +++ b/common/src/ipc/mod.rs @@ -1,13 +1,15 @@ use std::path::PathBuf; -use rustix::fd::OwnedFd; +use transmit::RawMsg; -mod mmap; +mod error; mod socket; +mod transmit; mod types; use crate::cache; -pub use mmap::*; +use crate::mmap::Mmap; +pub use error::*; pub use socket::*; pub use types::*; @@ -134,104 +136,20 @@ pub enum RequestRecv { } impl RequestSend { - pub fn send(&self, stream: &OwnedFd) -> Result<(), String> { - let mut socket_msg = [0u8; 16]; - socket_msg[0..8].copy_from_slice(&match self { - Self::Ping => 0u64.to_ne_bytes(), - Self::Query => 1u64.to_ne_bytes(), - Self::Clear(_) => 2u64.to_ne_bytes(), - Self::Img(_) => 3u64.to_ne_bytes(), - Self::Kill => 4u64.to_ne_bytes(), - }); - - let mmap = match self { - Self::Clear(clear) => Some(clear), - Self::Img(img) => Some(img), - _ => None, - }; - - match send_socket_msg(stream, &mut socket_msg, mmap) { - Ok(true) => (), - Ok(false) => return Err("failed to send full length of message in socket!".to_string()), - Err(e) => return Err(format!("failed to write serialized request: {e}")), + pub fn send(self, stream: &IpcSocket) -> Result<(), String> { + match stream.send(self.into()) { + Ok(true) => Ok(()), + Ok(false) => Err("failed to send full length of message in socket!".to_string()), + Err(e) => Err(format!("failed to write serialized request: {e}")), } - - Ok(()) } } impl RequestRecv { #[must_use] #[inline] - pub fn receive(socket_msg: SocketMsg) -> Self { - let ret = match socket_msg.code { - 0 => Self::Ping, - 1 => Self::Query, - 2 => { - let mmap = socket_msg.shm.unwrap(); - let bytes = mmap.slice(); - let len = bytes[0] as usize; - let mut outputs = Vec::with_capacity(len); - let mut i = 1; - for _ in 0..len { - let output = MmappedStr::new(&mmap, &bytes[i..]); - i += 4 + output.str().len(); - outputs.push(output); - } - let color = [bytes[i], bytes[i + 1], bytes[i + 2]]; - Self::Clear(ClearReq { - color, - outputs: outputs.into(), - }) - } - 3 => { - let mmap = socket_msg.shm.unwrap(); - let bytes = mmap.slice(); - let transition = Transition::deserialize(&bytes[0..]); - let len = bytes[51] as usize; - - let mut imgs = Vec::with_capacity(len); - let mut outputs = Vec::with_capacity(len); - let mut animations = Vec::with_capacity(len); - - let mut i = 52; - for _ in 0..len { - let (img, offset) = ImgReq::deserialize(&mmap, &bytes[i..]); - i += offset; - imgs.push(img); - - let n_outputs = bytes[i] as usize; - i += 1; - let mut out = Vec::with_capacity(n_outputs); - for _ in 0..n_outputs { - let output = MmappedStr::new(&mmap, &bytes[i..]); - i += 4 + output.str().len(); - out.push(output); - } - outputs.push(out.into()); - - if bytes[i] == 1 { - let (animation, offset) = Animation::deserialize(&mmap, &bytes[i + 1..]); - i += offset; - animations.push(animation); - } - i += 1; - } - - Self::Img(ImageReq { - transition, - imgs: imgs.into(), - outputs: outputs.into(), - animations: if animations.is_empty() { - None - } else { - Some(animations.into()) - }, - }) - } - _ => Self::Kill, - }; - ret + pub fn receive(msg: RawMsg) -> Self { + msg.into() } } @@ -239,47 +157,11 @@ pub enum Answer { Ok, Ping(bool), Info(Box<[BgInfo]>), - Err(String), } impl Answer { - pub fn send(&self, stream: &OwnedFd) -> Result<(), String> { - let mut socket_msg = [0u8; 16]; - socket_msg[0..8].copy_from_slice(&match self { - Self::Ok => 0u64.to_ne_bytes(), - Self::Ping(true) => 1u64.to_ne_bytes(), - Self::Ping(false) => 2u64.to_ne_bytes(), - Self::Info(_) => 3u64.to_ne_bytes(), - Self::Err(_) => 4u64.to_ne_bytes(), - }); - - let mmap = match self { - Self::Info(infos) => { - let len = 1 + infos.iter().map(|i| i.serialized_size()).sum::(); - let mut mmap = Mmap::create(len); - let bytes = mmap.slice_mut(); - - bytes[0] = infos.len() as u8; - let mut i = 1; - - for info in infos.iter() { - i += info.serialize(&mut bytes[i..]); - } - - Some(mmap) - } - Self::Err(s) => { - let len = 4 + s.len(); - let mut mmap = Mmap::create(len); - let bytes = mmap.slice_mut(); - bytes[0..4].copy_from_slice(&(s.as_bytes().len() as u32).to_ne_bytes()); - bytes[4..len].copy_from_slice(s.as_bytes()); - Some(mmap) - } - _ => None, - }; - - match send_socket_msg(stream, &mut socket_msg, mmap.as_ref()) { + pub fn send(self, stream: &IpcSocket) -> Result<(), String> { + match stream.send(self.into()) { Ok(true) => Ok(()), Ok(false) => Err("failed to send full length of message in socket!".to_string()), Err(e) => Err(format!("failed to write serialized request: {e}")), @@ -288,36 +170,7 @@ impl Answer { #[must_use] #[inline] - pub fn receive(socket_msg: SocketMsg) -> Self { - match socket_msg.code { - 0 => Self::Ok, - 1 => Self::Ping(true), - 2 => Self::Ping(false), - 3 => { - let mmap = socket_msg.shm.unwrap(); - let bytes = mmap.slice(); - let len = bytes[0] as usize; - let mut bg_infos = Vec::with_capacity(len); - - let mut i = 1; - for _ in 0..len { - let (info, offset) = BgInfo::deserialize(&bytes[i..]); - i += offset; - bg_infos.push(info); - } - - Self::Info(bg_infos.into()) - } - 4 => { - let mmap = socket_msg.shm.unwrap(); - let bytes = mmap.slice(); - let size = u32::from_ne_bytes(bytes[0..4].try_into().unwrap()) as usize; - let s = std::str::from_utf8(&bytes[4..4 + size]) - .expect("received a non utf8 string from socket") - .to_string(); - Self::Err(s) - } - _ => panic!("Received malformed answer from daemon"), - } + pub fn receive(msg: RawMsg) -> Self { + msg.into() } } diff --git a/common/src/ipc/socket.rs b/common/src/ipc/socket.rs index 79352185..0de2959c 100644 --- a/common/src/ipc/socket.rs +++ b/common/src/ipc/socket.rs @@ -1,141 +1,133 @@ -use std::{path::PathBuf, time::Duration}; - -use rustix::{ - fd::OwnedFd, - net::{self, RecvFlags}, -}; - -use super::Mmap; - -pub struct SocketMsg { - pub(super) code: u8, - pub(super) shm: Option, +use std::env; +use std::marker::PhantomData; +use std::sync::OnceLock; +use std::time::Duration; + +use rustix::fd::OwnedFd; +use rustix::io::Errno; +use rustix::net; + +use super::ErrnoExt; +use super::IpcError; +use super::IpcErrorKind; + +/// Represents client in IPC communication, via typestate pattern in [`IpcSocket`] +pub struct Client; +/// Represents server in IPC communication, via typestate pattern in [`IpcSocket`] +pub struct Server; + +/// Typesafe handle for socket facilitating communication between [`Client`] and [`Server`] +pub struct IpcSocket { + fd: OwnedFd, + phantom: PhantomData, } -pub fn read_socket(stream: &OwnedFd) -> Result { - let mut buf = [0u8; 16]; - let mut ancillary_buf = [0u8; rustix::cmsg_space!(ScmRights(1))]; - - let mut control = net::RecvAncillaryBuffer::new(&mut ancillary_buf); - - let mut tries = 0; - loop { - let iov = rustix::io::IoSliceMut::new(&mut buf); - match net::recvmsg(stream, &mut [iov], &mut control, RecvFlags::WAITALL) { - Ok(_) => break, - Err(e) => { - if e.kind() == std::io::ErrorKind::WouldBlock && tries < 5 { - std::thread::sleep(Duration::from_millis(1)); - } else { - return Err(format!("failed to read serialized length: {e}")); - } - } +impl IpcSocket { + /// Creates new [`IpcSocket`] from provided [`OwnedFd`] + /// + /// TODO: remove external ability to construct [`Self`] from random file descriptors + pub fn new(fd: OwnedFd) -> Self { + Self { + fd, + phantom: PhantomData, } - tries += 1; } - let code = u64::from_ne_bytes(buf[0..8].try_into().unwrap()) as u8; - let len = u64::from_ne_bytes(buf[8..16].try_into().unwrap()) as usize; - - let shm = if len == 0 { - None - } else { - let shm_file = match control.drain().next().unwrap() { - net::RecvAncillaryMessage::ScmRights(mut iter) => iter.next().unwrap(), - _ => panic!("malformed ancillary message"), - }; - Some(Mmap::from_fd(shm_file, len)) - }; - Ok(SocketMsg { code, shm }) -} - -pub(super) fn send_socket_msg( - stream: &OwnedFd, - socket_msg: &mut [u8; 16], - mmap: Option<&Mmap>, -) -> rustix::io::Result { - let mut ancillary_buf = [0u8; rustix::cmsg_space!(ScmRights(1))]; - let mut ancillary = net::SendAncillaryBuffer::new(&mut ancillary_buf); - - let msg_buf; - if let Some(mmap) = mmap.as_ref() { - socket_msg[8..].copy_from_slice(&(mmap.len() as u64).to_ne_bytes()); - msg_buf = [mmap.fd()]; - let msg = net::SendAncillaryMessage::ScmRights(&msg_buf); - ancillary.push(msg); + pub fn to_fd(self) -> OwnedFd { + self.fd } - let iov = rustix::io::IoSlice::new(&socket_msg[..]); - net::sendmsg(stream, &[iov], &mut ancillary, net::SendFlags::empty()) - .map(|written| written == socket_msg.len()) -} + fn socket_file() -> String { + let runtime = env::var("XDG_RUNTIME_DIR"); + let display = env::var("WAYLAND_DISPLAY"); -#[must_use] -pub fn get_socket_path() -> PathBuf { - let runtime_dir = if let Ok(dir) = std::env::var("XDG_RUNTIME_DIR") { - dir - } else { - "/tmp/swww".to_string() - }; - - let mut socket_path = PathBuf::new(); - socket_path.push(runtime_dir); - - let mut socket_name = String::new(); - socket_name.push_str("swww-"); - if let Ok(socket) = std::env::var("WAYLAND_DISPLAY") { - socket_name.push_str(socket.as_str()); - } else { - socket_name.push_str("wayland-0") + let runtime = runtime.as_deref().unwrap_or("/tmp/swww"); + let display = display.as_deref().unwrap_or("wayland-0"); + + format!("{runtime}/swww-{display}.socket") } - socket_name.push_str(".socket"); - socket_path.push(socket_name); + /// Retreives path to socket file + /// + /// To treat this as filesystem path, wrap it in [`Path`]. + /// If you get errors with missing generics, you can shove any type as `T`, but + /// [`Client`] or [`Server`] are recommended. + /// + /// [`Path`]: std::path::Path + #[must_use] + pub fn path() -> &'static str { + static PATH: OnceLock = OnceLock::new(); + PATH.get_or_init(Self::socket_file) + } - socket_path + #[must_use] + pub fn as_fd(&self) -> &OwnedFd { + &self.fd + } } -/// We make sure the Stream is always set to blocking mode -/// -/// * `tries` - how many times to attempt the connection -/// * `interval` - how long to wait between attempts, in milliseconds -pub fn connect_to_socket(addr: &PathBuf, tries: u8, interval: u64) -> Result { - let socket = rustix::net::socket_with( - rustix::net::AddressFamily::UNIX, - rustix::net::SocketType::STREAM, - rustix::net::SocketFlags::CLOEXEC, - None, - ) - .expect("failed to create socket file descriptor"); - let addr = net::SocketAddrUnix::new(addr).unwrap(); - //Make sure we try at least once - let tries = if tries == 0 { 1 } else { tries }; - let mut error = None; - for _ in 0..tries { - match net::connect_unix(&socket, &addr) { - Ok(()) => { - #[cfg(debug_assertions)] - let timeout = Duration::from_secs(30); //Some operations take a while to respond in debug mode - #[cfg(not(debug_assertions))] - let timeout = Duration::from_secs(5); - if let Err(e) = net::sockopt::set_socket_timeout( - &socket, - net::sockopt::Timeout::Recv, - Some(timeout), - ) { - return Err(format!("failed to set read timeout for socket: {e}")); +impl IpcSocket { + /// Connects to already running `Daemon`, if there is one. + pub fn connect() -> Result { + // these were hardcoded everywhere, no point in passing them around + let tries = 5; + let interval = 100; + + let socket = net::socket_with( + net::AddressFamily::UNIX, + net::SocketType::STREAM, + net::SocketFlags::CLOEXEC, + None, + ) + .context(IpcErrorKind::Socket)?; + + let addr = net::SocketAddrUnix::new(Self::path()).expect("addr is correct"); + + // this will be overwriten, Rust just doesn't know it + let mut error = Errno::INVAL; + for _ in 0..tries { + match net::connect_unix(&socket, &addr) { + Ok(()) => { + #[cfg(debug_assertions)] + let timeout = Duration::from_secs(30); //Some operations take a while to respond in debug mode + #[cfg(not(debug_assertions))] + let timeout = Duration::from_secs(5); + return net::sockopt::set_socket_timeout( + &socket, + net::sockopt::Timeout::Recv, + Some(timeout), + ) + .context(IpcErrorKind::SetTimeout) + .map(|()| Self::new(socket)); } - - return Ok(socket); + Err(e) => error = e, } - Err(e) => error = Some(e), + std::thread::sleep(Duration::from_millis(interval)); } - std::thread::sleep(Duration::from_millis(interval)); - } - let error = error.unwrap(); - if error.kind() == std::io::ErrorKind::NotFound { - return Err("Socket file not found. Are you sure swww-daemon is running?".to_string()); + + let kind = if error.kind() == std::io::ErrorKind::NotFound { + IpcErrorKind::NoSocketFile + } else { + IpcErrorKind::Connect + }; + + Err(error.context(kind)) } +} - Err(format!("Failed to connect to socket: {error}")) +impl IpcSocket { + /// Creates [`IpcSocket`] for use in server (i.e `Daemon`) + pub fn server() -> Result { + let addr = net::SocketAddrUnix::new(Self::path()).expect("addr is correct"); + let socket = net::socket_with( + net::AddressFamily::UNIX, + net::SocketType::STREAM, + net::SocketFlags::CLOEXEC.union(rustix::net::SocketFlags::NONBLOCK), + None, + ) + .context(IpcErrorKind::Socket)?; + net::bind_unix(&socket, &addr).context(IpcErrorKind::Bind)?; + net::listen(&socket, 0).context(IpcErrorKind::Listen)?; + Ok(Self::new(socket)) + } } diff --git a/common/src/ipc/transmit.rs b/common/src/ipc/transmit.rs new file mode 100644 index 00000000..47ecb8ab --- /dev/null +++ b/common/src/ipc/transmit.rs @@ -0,0 +1,293 @@ +use std::thread; +use std::time::Duration; + +use rustix::io; +use rustix::io::Errno; +use rustix::net; +use rustix::net::RecvFlags; + +use super::Animation; +use super::Answer; +use super::BgInfo; +use super::ClearReq; +use super::ErrnoExt; +use super::ImageReq; +use super::ImgReq; +use super::IpcError; +use super::IpcErrorKind; +use super::IpcSocket; +use super::RequestRecv; +use super::RequestSend; +use super::Transition; +use crate::mmap::Mmap; +use crate::mmap::MmappedStr; + +// could be enum +pub struct RawMsg { + code: Code, + shm: Option, +} + +impl From for RawMsg { + fn from(value: RequestSend) -> Self { + let code = match value { + RequestSend::Ping => Code::ReqPing, + RequestSend::Query => Code::ReqQuery, + RequestSend::Clear(_) => Code::ReqClear, + RequestSend::Img(_) => Code::ReqImg, + RequestSend::Kill => Code::ReqKill, + }; + + let shm = match value { + RequestSend::Clear(mem) | RequestSend::Img(mem) => Some(mem), + _ => None, + }; + + Self { code, shm } + } +} + +impl From for RawMsg { + fn from(value: Answer) -> Self { + let code = match value { + Answer::Ok => Code::ResOk, + Answer::Ping(true) => Code::ResConfigured, + Answer::Ping(false) => Code::ResAwait, + Answer::Info(_) => Code::ResInfo, + }; + + let shm = if let Answer::Info(infos) = value { + let len = 1 + infos + .iter() + .map(|info| info.serialized_size()) + .sum::(); + let mut mmap = Mmap::create(len); + let bytes = mmap.slice_mut(); + + bytes[0] = infos.len() as u8; + let mut i = 1; + + for info in infos.iter() { + i += info.serialize(&mut bytes[i..]); + } + + Some(mmap) + } else { + None + }; + + Self { code, shm } + } +} + +// TODO: remove this ugly mess +impl From for RequestRecv { + fn from(value: RawMsg) -> Self { + match value.code { + Code::ReqPing => Self::Ping, + Code::ReqQuery => Self::Query, + Code::ReqClear => { + let mmap = value.shm.unwrap(); + let bytes = mmap.slice(); + let len = bytes[0] as usize; + let mut outputs = Vec::with_capacity(len); + let mut i = 1; + for _ in 0..len { + let output = MmappedStr::new(&mmap, &bytes[i..]); + i += 4 + output.str().len(); + outputs.push(output); + } + let color = [bytes[i], bytes[i + 1], bytes[i + 2]]; + Self::Clear(ClearReq { + color, + outputs: outputs.into(), + }) + } + Code::ReqImg => { + let mmap = value.shm.unwrap(); + let bytes = mmap.slice(); + let transition = Transition::deserialize(&bytes[0..]); + let len = bytes[51] as usize; + + let mut imgs = Vec::with_capacity(len); + let mut outputs = Vec::with_capacity(len); + let mut animations = Vec::with_capacity(len); + + let mut i = 52; + for _ in 0..len { + let (img, offset) = ImgReq::deserialize(&mmap, &bytes[i..]); + i += offset; + imgs.push(img); + + let n_outputs = bytes[i] as usize; + i += 1; + let mut out = Vec::with_capacity(n_outputs); + for _ in 0..n_outputs { + let output = MmappedStr::new(&mmap, &bytes[i..]); + i += 4 + output.str().len(); + out.push(output); + } + outputs.push(out.into()); + + if bytes[i] == 1 { + let (animation, offset) = Animation::deserialize(&mmap, &bytes[i + 1..]); + i += offset; + animations.push(animation); + } + i += 1; + } + + Self::Img(ImageReq { + transition, + imgs: imgs.into(), + outputs: outputs.into(), + animations: if animations.is_empty() { + None + } else { + Some(animations.into()) + }, + }) + } + Code::ReqKill => Self::Kill, + _ => Self::Kill, + } + } +} + +impl From for Answer { + fn from(value: RawMsg) -> Self { + match value.code { + Code::ResOk => Self::Ok, + Code::ResConfigured => Self::Ping(true), + Code::ResAwait => Self::Ping(false), + Code::ResInfo => { + let mmap = value.shm.unwrap(); + let bytes = mmap.slice(); + let len = bytes[0] as usize; + let mut bg_infos = Vec::with_capacity(len); + + let mut i = 1; + for _ in 0..len { + let (info, offset) = BgInfo::deserialize(&bytes[i..]); + i += offset; + bg_infos.push(info); + } + + Self::Info(bg_infos.into()) + } + _ => panic!("Received malformed answer from daemon"), + } + } +} +// TODO: end remove ugly mess block + +macro_rules! code { + ($($name:ident $num:literal),* $(,)?) => { + pub enum Code { + $($name,)* + } + + impl Code { + const fn into(self) -> u64 { + match self { + $(Self::$name => $num,)* + } + } + + const fn from(num: u64) -> Option { + match num { + $($num => Some(Self::$name),)* + _ => None + } + } + } + + }; +} + +code! { + ReqPing 0, + ReqQuery 1, + ReqClear 2, + ReqImg 3, + ReqKill 4, + + ResOk 5, + ResConfigured 6, + ResAwait 7, + ResInfo 8, +} + +impl TryFrom for Code { + type Error = IpcError; + fn try_from(value: u64) -> Result { + Self::from(value).ok_or(IpcError::new(IpcErrorKind::BadCode, Errno::DOM)) + } +} + +// TODO: this along with `RawMsg` should be implementation detail +impl IpcSocket { + pub fn send(&self, msg: RawMsg) -> io::Result { + let mut payload = [0u8; 16]; + payload[0..8].copy_from_slice(&msg.code.into().to_ne_bytes()); + + let mut ancillary_buf = [0u8; rustix::cmsg_space!(ScmRights(1))]; + let mut ancillary = net::SendAncillaryBuffer::new(&mut ancillary_buf); + + let fd; + if let Some(ref mmap) = msg.shm { + payload[8..].copy_from_slice(&(mmap.len() as u64).to_ne_bytes()); + fd = [mmap.fd()]; + let msg = net::SendAncillaryMessage::ScmRights(&fd); + ancillary.push(msg); + } + + let iov = io::IoSlice::new(&payload[..]); + net::sendmsg( + self.as_fd(), + &[iov], + &mut ancillary, + net::SendFlags::empty(), + ) + .map(|written| written == payload.len()) + } + + pub fn recv(&self) -> Result { + let mut buf = [0u8; 16]; + let mut ancillary_buf = [0u8; rustix::cmsg_space!(ScmRights(1))]; + + let mut control = net::RecvAncillaryBuffer::new(&mut ancillary_buf); + + for _ in 0..5 { + let iov = io::IoSliceMut::new(&mut buf); + match net::recvmsg(self.as_fd(), &mut [iov], &mut control, RecvFlags::WAITALL) { + Ok(_) => break, + Err(Errno::WOULDBLOCK | Errno::INTR) => thread::sleep(Duration::from_millis(1)), + Err(err) => return Err(err).context(IpcErrorKind::Read), + } + } + + let code = u64::from_ne_bytes(buf[0..8].try_into().unwrap()).try_into()?; + let len = u64::from_ne_bytes(buf[8..16].try_into().unwrap()) as usize; + + let shm = if len == 0 { + debug_assert!(matches!( + code, + Code::ReqClear | Code::ReqImg | Code::ResInfo + )); + None + } else { + let file = control + .drain() + .next() + .and_then(|msg| match msg { + net::RecvAncillaryMessage::ScmRights(mut iter) => iter.next(), + _ => None, + }) + .ok_or(Errno::BADMSG) + .context(IpcErrorKind::MalformedMsg)?; + Some(Mmap::from_fd(file, len)) + }; + Ok(RawMsg { code, shm }) + } +} diff --git a/common/src/ipc/types.rs b/common/src/ipc/types.rs index 9de514af..5bc60c5a 100644 --- a/common/src/ipc/types.rs +++ b/common/src/ipc/types.rs @@ -5,8 +5,11 @@ use std::{ }; use crate::compression::BitPack; +use crate::mmap::Mmap; +use crate::mmap::MmappedBytes; +use crate::mmap::MmappedStr; -use super::{ImageRequestBuilder, Mmap, MmappedBytes, MmappedStr}; +use super::ImageRequestBuilder; #[derive(Clone, PartialEq)] pub enum Coord { diff --git a/common/src/lib.rs b/common/src/lib.rs index ddc7a86d..a7c8271e 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -1,3 +1,4 @@ pub mod cache; pub mod compression; pub mod ipc; +pub mod mmap; diff --git a/common/src/ipc/mmap.rs b/common/src/mmap.rs similarity index 54% rename from common/src/ipc/mmap.rs rename to common/src/mmap.rs index ad4083cb..87696fa1 100644 --- a/common/src/ipc/mmap.rs +++ b/common/src/mmap.rs @@ -1,11 +1,21 @@ +use std::iter::repeat_with; use std::ptr::NonNull; - -use rustix::{ - fd::{AsFd, BorrowedFd, OwnedFd}, - io::Errno, - mm::{mmap, munmap, MapFlags, ProtFlags}, - shm::{Mode, ShmOFlags}, -}; +use std::time::SystemTime; +use std::time::UNIX_EPOCH; + +use rustix::fd::AsFd; +use rustix::fd::BorrowedFd; +use rustix::fd::OwnedFd; +use rustix::fs; +use rustix::io; +use rustix::io::Errno; +use rustix::mm::mmap; +use rustix::mm::munmap; +use rustix::mm::MapFlags; +use rustix::mm::ProtFlags; +use rustix::shm; +use rustix::shm::Mode; +use rustix::shm::ShmOFlags; #[derive(Debug)] pub struct Mmap { @@ -22,7 +32,7 @@ impl Mmap { #[inline] #[must_use] pub fn create(len: usize) -> Self { - let fd = create_shm_fd().unwrap(); + let fd = Self::mmap_fd().unwrap(); rustix::io::retry_on_intr(|| rustix::fs::ftruncate(&fd, len as u64)).unwrap(); let ptr = unsafe { @@ -39,6 +49,61 @@ impl Mmap { } } + #[cfg(target_os = "linux")] + fn mmap_fd() -> io::Result { + match Self::memfd() { + Ok(fd) => Ok(fd), + // Not supported, use fallback. + Err(Errno::NOSYS) => Self::shm(), + Err(err) => Err(err), + } + } + + #[cfg(not(target_os = "linux"))] + fn mmap_fd() -> io::Result { + Self::shm() + } + + fn shm() -> io::Result { + let mut filenames = repeat_with(SystemTime::now) + .map(|time| time.duration_since(UNIX_EPOCH).unwrap().subsec_nanos()) + .map(|stamp| format!("/swww-ipc-{stamp}",)); + + let flags = ShmOFlags::CREATE | ShmOFlags::EXCL | ShmOFlags::RDWR; + let mode = Mode::RUSR | Mode::WUSR; + + loop { + let filename = filenames.next().expect("infinite generator"); + match shm::shm_open(filename.as_str(), flags, mode) { + Ok(fd) => return shm::shm_unlink(filename.as_str()).map(|()| fd), + Err(Errno::EXIST | Errno::INTR) => continue, + Err(err) => return Err(err), + } + } + } + + #[cfg(target_os = "linux")] + fn memfd() -> io::Result { + use rustix::fs::MemfdFlags; + use rustix::fs::SealFlags; + use std::ffi::CStr; + + let name = CStr::from_bytes_with_nul(b"swww-ipc\0").unwrap(); + let flags = MemfdFlags::ALLOW_SEALING | MemfdFlags::CLOEXEC; + + loop { + match fs::memfd_create(name, flags) { + Ok(fd) => { + // We only need to seal for the purposes of optimization, ignore the errors. + let _ = fs::fcntl_add_seals(&fd, SealFlags::SHRINK | SealFlags::SEAL); + return Ok(fd); + } + Err(Errno::INTR) => continue, + Err(err) => return Err(err), + } + } + } + #[inline] /// Unmaps without destroying the file descriptor /// @@ -76,32 +141,28 @@ impl Mmap { } #[inline] - pub fn remap(&mut self, new_len: usize) { - rustix::io::retry_on_intr(|| rustix::fs::ftruncate(&self.fd, new_len as u64)).unwrap(); + pub fn remap(&mut self, new: usize) { + io::retry_on_intr(|| fs::ftruncate(&self.fd, new as u64)).unwrap(); #[cfg(target_os = "linux")] { - let result = unsafe { - rustix::mm::mremap( - self.ptr.as_ptr(), - self.len, - new_len, - rustix::mm::MremapFlags::MAYMOVE, - ) - }; + use rustix::mm; + + let result = + unsafe { mm::mremap(self.ptr.as_ptr(), self.len, new, mm::MremapFlags::MAYMOVE) }; if let Ok(ptr) = result { // SAFETY: the mremap above will never return a null pointer if it succeeds let ptr = unsafe { NonNull::new_unchecked(ptr) }; self.ptr = ptr; - self.len = new_len; + self.len = new; return; } } self.unmap(); - self.len = new_len; + self.len = new; self.ptr = unsafe { let ptr = mmap( std::ptr::null_mut(), @@ -177,110 +238,23 @@ impl Drop for Mmap { } } -fn create_shm_fd() -> std::io::Result { - #[cfg(target_os = "linux")] - { - match create_memfd() { - Ok(fd) => return Ok(fd), - // Not supported, use fallback. - Err(Errno::NOSYS) => (), - Err(err) => return Err(err.into()), - }; - } - - let time = std::time::SystemTime::now(); - let mut mem_file_handle = format!( - "/swww-ipc-{}", - time.duration_since(std::time::UNIX_EPOCH) - .unwrap() - .subsec_nanos() - ); - - let flags = ShmOFlags::CREATE | ShmOFlags::EXCL | ShmOFlags::RDWR; - let mode = Mode::RUSR | Mode::WUSR; - loop { - match rustix::shm::shm_open(mem_file_handle.as_str(), flags, mode) { - Ok(fd) => match rustix::shm::shm_unlink(mem_file_handle.as_str()) { - Ok(_) => return Ok(fd), - - Err(errno) => { - return Err(errno.into()); - } - }, - Err(Errno::EXIST) => { - // Change the handle if we happen to be duplicate. - let time = std::time::SystemTime::now(); - - mem_file_handle = format!( - "/swww-ipc-{}", - time.duration_since(std::time::UNIX_EPOCH) - .unwrap() - .subsec_nanos() - ); - - continue; - } - Err(Errno::INTR) => continue, - Err(err) => return Err(err.into()), - } - } -} - -#[cfg(target_os = "linux")] -fn create_memfd() -> rustix::io::Result { - use rustix::fs::{MemfdFlags, SealFlags}; - use std::ffi::CStr; - - let name = CStr::from_bytes_with_nul(b"swww-ipc\0").unwrap(); - let flags = MemfdFlags::ALLOW_SEALING | MemfdFlags::CLOEXEC; - - loop { - match rustix::fs::memfd_create(name, flags) { - Ok(fd) => { - // We only need to seal for the purposes of optimization, ignore the errors. - let _ = rustix::fs::fcntl_add_seals(&fd, SealFlags::SHRINK | SealFlags::SEAL); - return Ok(fd); - } - Err(Errno::INTR) => continue, - Err(err) => return Err(err), - } - } -} - -pub struct MmappedBytes { +pub struct Mmapped { base_ptr: NonNull, ptr: NonNull, len: usize, } -impl MmappedBytes { +pub type MmappedBytes = Mmapped; +pub type MmappedStr = Mmapped; + +impl Mmapped { const PROT: ProtFlags = ProtFlags::READ; const FLAGS: MapFlags = MapFlags::SHARED; pub(crate) fn new(map: &Mmap, bytes: &[u8]) -> Self { let len = u32::from_ne_bytes(bytes[0..4].try_into().unwrap()) as usize; - let offset = 4 + bytes.as_ptr() as usize - map.ptr.as_ptr() as usize; - let page_size = rustix::param::page_size(); - let page_offset = offset - offset % page_size; - - let base_ptr = unsafe { - let ptr = mmap( - std::ptr::null_mut(), - len + (offset - page_offset), - Self::PROT, - Self::FLAGS, - &map.fd, - page_offset as u64, - ) - .unwrap(); - // SAFETY: the function above will never return a null pointer if it succeeds - // POSIX says that the implementation will never select an address at 0 - NonNull::new_unchecked(ptr) - }; - let ptr = - unsafe { NonNull::new_unchecked(base_ptr.as_ptr().byte_add(offset - page_offset)) }; - - Self { base_ptr, ptr, len } + let bytes = &bytes[4..]; + Self::new_with_len(map, bytes, len) } pub(crate) fn new_with_len(map: &Mmap, bytes: &[u8], len: usize) -> Self { @@ -305,6 +279,12 @@ impl MmappedBytes { let ptr = unsafe { NonNull::new_unchecked(base_ptr.as_ptr().byte_add(offset - page_offset)) }; + if UTF8 { + // try to parse, panicking if we fail + let s = unsafe { std::slice::from_raw_parts(ptr.as_ptr().cast(), len) }; + let _s = std::str::from_utf8(s).expect("received a non utf8 string from socket"); + } + Self { base_ptr, ptr, len } } @@ -315,59 +295,7 @@ impl MmappedBytes { } } -impl Drop for MmappedBytes { - #[inline] - fn drop(&mut self) { - let len = self.len + self.ptr.as_ptr() as usize - self.base_ptr.as_ptr() as usize; - if let Err(e) = unsafe { munmap(self.base_ptr.as_ptr(), len) } { - eprintln!("ERROR WHEN UNMAPPING MEMORY: {e}"); - } - } -} - -unsafe impl Send for MmappedBytes {} -unsafe impl Sync for MmappedBytes {} - -pub struct MmappedStr { - base_ptr: NonNull, - ptr: NonNull, - len: usize, -} - impl MmappedStr { - const PROT: ProtFlags = ProtFlags::READ; - const FLAGS: MapFlags = MapFlags::SHARED; - - pub(crate) fn new(map: &Mmap, bytes: &[u8]) -> Self { - let len = u32::from_ne_bytes(bytes[0..4].try_into().unwrap()) as usize; - let offset = 4 + bytes.as_ptr() as usize - map.ptr.as_ptr() as usize; - let page_size = rustix::param::page_size(); - let page_offset = offset - offset % page_size; - - let base_ptr = unsafe { - let ptr = mmap( - std::ptr::null_mut(), - len + (offset - page_offset), - Self::PROT, - Self::FLAGS, - &map.fd, - page_offset as u64, - ) - .unwrap(); - // SAFETY: the function above will never return a null pointer if it succeeds - // POSIX says that the implementation will never select an address at 0 - NonNull::new_unchecked(ptr) - }; - let ptr = - unsafe { NonNull::new_unchecked(base_ptr.as_ptr().byte_add(offset - page_offset)) }; - - // try to parse, panicking if we fail - let s = unsafe { std::slice::from_raw_parts(ptr.as_ptr().cast(), len) }; - let _s = std::str::from_utf8(s).expect("received a non utf8 string from socket"); - - Self { base_ptr, ptr, len } - } - #[inline] #[must_use] pub fn str(&self) -> &str { @@ -376,10 +304,7 @@ impl MmappedStr { } } -unsafe impl Send for MmappedStr {} -unsafe impl Sync for MmappedStr {} - -impl Drop for MmappedStr { +impl Drop for Mmapped { #[inline] fn drop(&mut self) { let len = self.len + self.ptr.as_ptr() as usize - self.base_ptr.as_ptr() as usize; @@ -388,3 +313,6 @@ impl Drop for MmappedStr { } } } + +unsafe impl Send for Mmapped {} +unsafe impl Sync for Mmapped {} diff --git a/daemon/src/main.rs b/daemon/src/main.rs index 47f1a253..91575b75 100644 --- a/daemon/src/main.rs +++ b/daemon/src/main.rs @@ -23,17 +23,15 @@ use std::{ fs, io::{IsTerminal, Write}, num::{NonZeroI32, NonZeroU32}, - path::PathBuf, + path::Path, sync::{ atomic::{AtomicBool, Ordering}, Arc, }, }; -use common::ipc::{ - connect_to_socket, get_socket_path, read_socket, Answer, BgInfo, ImageReq, MmappedStr, - RequestRecv, RequestSend, Scale, -}; +use common::ipc::{Answer, BgInfo, ImageReq, IpcSocket, RequestRecv, RequestSend, Scale, Server}; +use common::mmap::MmappedStr; use animations::Animator; @@ -124,8 +122,8 @@ impl Daemon { ))); } - fn recv_socket_msg(&mut self, stream: OwnedFd) { - let bytes = match common::ipc::read_socket(&stream) { + fn recv_socket_msg(&mut self, stream: IpcSocket) { + let bytes = match stream.recv() { Ok(bytes) => bytes, Err(e) => { error!("FATAL: cannot read socket: {e}. Exiting..."); @@ -470,7 +468,8 @@ fn main() -> Result<(), String> { if !fds[1].revents().is_empty() { match rustix::net::accept(&listener.0) { - Ok(stream) => daemon.recv_socket_msg(stream), + // TODO: abstract away explicit socket creation + Ok(stream) => daemon.recv_socket_msg(IpcSocket::new(stream)), Err(rustix::io::Errno::INTR | rustix::io::Errno::WOULDBLOCK) => continue, Err(e) => return Err(format!("failed to accept incoming connection: {e}")), } @@ -525,25 +524,26 @@ fn setup_signals() { struct SocketWrapper(OwnedFd); impl SocketWrapper { fn new() -> Result { - let socket_addr = get_socket_path(); + let addr = IpcSocket::::path(); + let addr = Path::new(addr); - if socket_addr.exists() { - if is_daemon_running(&socket_addr)? { + if addr.exists() { + if is_daemon_running()? { return Err( "There is an swww-daemon instance already running on this socket!".to_string(), ); } else { warn!( "socket file {} was not deleted when the previous daemon exited", - socket_addr.to_string_lossy() + addr.to_string_lossy() ); - if let Err(e) = std::fs::remove_file(&socket_addr) { + if let Err(e) = std::fs::remove_file(addr) { return Err(format!("failed to delete previous socket: {e}")); } } } - let runtime_dir = match socket_addr.parent() { + let runtime_dir = match addr.parent() { Some(path) => path, None => return Err("couldn't find a valid runtime directory".to_owned()), }; @@ -555,34 +555,20 @@ impl SocketWrapper { } } - let socket = rustix::net::socket_with( - rustix::net::AddressFamily::UNIX, - rustix::net::SocketType::STREAM, - rustix::net::SocketFlags::CLOEXEC.union(rustix::net::SocketFlags::NONBLOCK), - None, - ) - .expect("failed to create socket file descriptor"); - - rustix::net::bind_unix( - &socket, - &rustix::net::SocketAddrUnix::new(&socket_addr).unwrap(), - ) - .unwrap(); - - rustix::net::listen(&socket, 0).unwrap(); + let socket = IpcSocket::server().map_err(|err| err.to_string())?; - debug!("Created socket in {:?}", socket_addr); - Ok(Self(socket)) + debug!("Created socket in {:?}", addr); + Ok(Self(socket.to_fd())) } } impl Drop for SocketWrapper { fn drop(&mut self) { - let socket_addr = get_socket_path(); - if let Err(e) = fs::remove_file(&socket_addr) { - error!("Failed to remove socket at {socket_addr:?}: {e}"); + let addr = IpcSocket::::path(); + if let Err(e) = fs::remove_file(Path::new(addr)) { + error!("Failed to remove socket at {addr}: {e}"); } - info!("Removed socket at {:?}", socket_addr); + info!("Removed socket at {addr}"); } } @@ -648,8 +634,8 @@ fn make_logger(quiet: bool) { .unwrap(); } -pub fn is_daemon_running(addr: &PathBuf) -> Result { - let sock = match connect_to_socket(addr, 5, 100) { +pub fn is_daemon_running() -> Result { + let sock = match IpcSocket::connect() { Ok(s) => s, // likely a connection refused; either way, this is a reliable signal there's no surviving // daemon. @@ -657,7 +643,7 @@ pub fn is_daemon_running(addr: &PathBuf) -> Result { }; RequestSend::Ping.send(&sock)?; - let answer = Answer::receive(read_socket(&sock)?); + let answer = Answer::receive(sock.recv().map_err(|err| err.to_string())?); match answer { Answer::Ping(_) => Ok(true), _ => Err("Daemon did not return Answer::Ping, as expected".to_string()), diff --git a/daemon/src/wayland/bump_pool.rs b/daemon/src/wayland/bump_pool.rs index 14a29563..a38f99af 100644 --- a/daemon/src/wayland/bump_pool.rs +++ b/daemon/src/wayland/bump_pool.rs @@ -1,6 +1,6 @@ use std::sync::atomic::{AtomicBool, Ordering}; -use common::ipc::Mmap; +use common::mmap::Mmap; use super::{globals, ObjectId};