Skip to content

Commit

Permalink
fix(cli,tcp): replies are sent on the requesting channel
Browse files Browse the repository at this point in the history
Replace the client socket with replies sent on the other side of the
querying stream, for both UDS and TCP clients. This has two results:
unix socket clients such as komorebic no longer race on the socket bind,
fixing the out of order bind error, and the response mixup conditions
that could occur. Queries over TCP now receive replies over TCP, rather
than replies being sent to a future or in-flight command line client.
  • Loading branch information
raggi authored and LGUG2Z committed Feb 17, 2024
1 parent afd93c3 commit c8f6502
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 120 deletions.
112 changes: 33 additions & 79 deletions komorebi/src/process_command.rs
Expand Up @@ -4,7 +4,6 @@ use std::fs::OpenOptions;
use std::io::BufRead;
use std::io::BufReader;
use std::io::Read;
use std::io::Write;
use std::net::TcpListener;
use std::net::TcpStream;
use std::num::NonZeroUsize;
Expand Down Expand Up @@ -60,7 +59,6 @@ use crate::BORDER_OFFSET;
use crate::BORDER_OVERFLOW_IDENTIFIERS;
use crate::BORDER_WIDTH;
use crate::CUSTOM_FFM;
use crate::DATA_DIR;
use crate::DISPLAY_INDEX_PREFERENCES;
use crate::FLOAT_IDENTIFIERS;
use crate::HIDING_BEHAVIOUR;
Expand Down Expand Up @@ -144,8 +142,15 @@ pub fn listen_for_commands_tcp(wm: Arc<Mutex<WindowManager>>, port: usize) {
}

impl WindowManager {
#[tracing::instrument(skip(self))]
pub fn process_command(&mut self, message: SocketMessage) -> Result<()> {
// TODO(raggi): wrap reply in a newtype that can decorate a human friendly
// name for the peer, such as getting the pid of the komorebic process for
// the UDS or the IP:port for TCP.
#[tracing::instrument(skip(self, reply))]
pub fn process_command(
&mut self,
message: SocketMessage,
mut reply: impl std::io::Write,
) -> Result<()> {
if let Some(virtual_desktop_id) = &self.virtual_desktop_id {
if let Some(id) = current_virtual_desktop() {
if id != *virtual_desktop_id {
Expand Down Expand Up @@ -743,15 +748,11 @@ impl WindowManager {
Err(error) => error.to_string(),
};

let socket = DATA_DIR.join("komorebic.sock");
tracing::info!("replying to state");

let mut connected = false;
while !connected {
if let Ok(mut stream) = UnixStream::connect(&socket) {
connected = true;
stream.write_all(state.as_bytes())?;
}
}
reply.write_all(state.as_bytes())?;

tracing::info!("replying to state done");
}
SocketMessage::VisibleWindows => {
let mut monitor_visible_windows = HashMap::new();
Expand All @@ -774,15 +775,7 @@ impl WindowManager {
Err(error) => error.to_string(),
};

let socket = DATA_DIR.join("komorebic.sock");

let mut connected = false;
while !connected {
if let Ok(mut stream) = UnixStream::connect(&socket) {
connected = true;
stream.write_all(visible_windows_state.as_bytes())?;
}
}
reply.write_all(visible_windows_state.as_bytes())?;
}

SocketMessage::Query(query) => {
Expand All @@ -801,15 +794,7 @@ impl WindowManager {
}
.to_string();

let socket = DATA_DIR.join("komorebic.sock");

let mut connected = false;
while !connected {
if let Ok(mut stream) = UnixStream::connect(&socket) {
connected = true;
stream.write_all(response.as_bytes())?;
}
}
reply.write_all(response.as_bytes())?;
}
SocketMessage::ResizeWindowEdge(direction, sizing) => {
self.resize_window(direction, sizing, self.resize_delta, true)?;
Expand Down Expand Up @@ -1275,41 +1260,20 @@ impl WindowManager {
SocketMessage::ApplicationSpecificConfigurationSchema => {
let asc = schema_for!(Vec<ApplicationConfiguration>);
let schema = serde_json::to_string_pretty(&asc)?;
let socket = DATA_DIR.join("komorebic.sock");

let mut connected = false;
while !connected {
if let Ok(mut stream) = UnixStream::connect(&socket) {
connected = true;
stream.write_all(schema.as_bytes())?;
}
}
reply.write_all(schema.as_bytes())?;
}
SocketMessage::NotificationSchema => {
let notification = schema_for!(Notification);
let schema = serde_json::to_string_pretty(&notification)?;
let socket = DATA_DIR.join("komorebic.sock");

let mut connected = false;
while !connected {
if let Ok(mut stream) = UnixStream::connect(&socket) {
connected = true;
stream.write_all(schema.as_bytes())?;
}
}
reply.write_all(schema.as_bytes())?;
}
SocketMessage::SocketSchema => {
let socket_message = schema_for!(SocketMessage);
let schema = serde_json::to_string_pretty(&socket_message)?;
let socket = DATA_DIR.join("komorebic.sock");

let mut connected = false;
while !connected {
if let Ok(mut stream) = UnixStream::connect(&socket) {
connected = true;
stream.write_all(schema.as_bytes())?;
}
}
reply.write_all(schema.as_bytes())?;
}
SocketMessage::StaticConfigSchema => {
let settings = SchemaSettings::default().with(|s| {
Expand All @@ -1321,27 +1285,13 @@ impl WindowManager {
let gen = settings.into_generator();
let socket_message = gen.into_root_schema_for::<StaticConfig>();
let schema = serde_json::to_string_pretty(&socket_message)?;
let socket = DATA_DIR.join("komorebic.sock");

let mut connected = false;
while !connected {
if let Ok(mut stream) = UnixStream::connect(&socket) {
connected = true;
stream.write_all(schema.as_bytes())?;
}
}
reply.write_all(schema.as_bytes())?;
}
SocketMessage::GenerateStaticConfig => {
let config = serde_json::to_string_pretty(&StaticConfig::from(&*self))?;
let socket = DATA_DIR.join("komorebic.sock");

let mut connected = false;
while !connected {
if let Ok(mut stream) = UnixStream::connect(&socket) {
connected = true;
stream.write_all(config.as_bytes())?;
}
}
reply.write_all(config.as_bytes())?;
}
SocketMessage::RemoveTitleBar(_, ref id) => {
let mut identifiers = NO_TITLEBAR.lock();
Expand Down Expand Up @@ -1526,17 +1476,21 @@ impl WindowManager {
}
}

pub fn read_commands_uds(wm: &Arc<Mutex<WindowManager>>, stream: UnixStream) -> Result<()> {
let stream = BufReader::new(stream);
for line in stream.lines() {
pub fn read_commands_uds(wm: &Arc<Mutex<WindowManager>>, mut stream: UnixStream) -> Result<()> {
let reader = BufReader::new(stream.try_clone()?);
// TODO(raggi): while this processes more than one command, if there are
// replies there is no clearly defined protocol for framing yet - it's
// perhaps whole-json objects for now, but termination is signalled by
// socket shutdown.
for line in reader.lines() {
let message = SocketMessage::from_str(&line?)?;

let mut wm = wm.lock();

if wm.is_paused {
return match message {
SocketMessage::TogglePause | SocketMessage::State | SocketMessage::Stop => {
Ok(wm.process_command(message)?)
Ok(wm.process_command(message, &mut stream)?)
}
_ => {
tracing::trace!("ignoring while paused");
Expand All @@ -1545,7 +1499,7 @@ pub fn read_commands_uds(wm: &Arc<Mutex<WindowManager>>, stream: UnixStream) ->
};
}

wm.process_command(message.clone())?;
wm.process_command(message.clone(), &mut stream)?;
notify_subscribers(&serde_json::to_string(&Notification {
event: NotificationEvent::Socket(message.clone()),
state: wm.as_ref().into(),
Expand All @@ -1560,11 +1514,11 @@ pub fn read_commands_tcp(
stream: &mut TcpStream,
addr: &str,
) -> Result<()> {
let mut stream = BufReader::new(stream);
let mut reader = BufReader::new(stream.try_clone()?);

loop {
let mut buf = vec![0; 1024];
match stream.read(&mut buf) {
match reader.read(&mut buf) {
Err(..) => {
tracing::warn!("removing disconnected tcp client: {addr}");
let mut connections = TCP_CONNECTIONS.lock();
Expand All @@ -1585,7 +1539,7 @@ pub fn read_commands_tcp(
if wm.is_paused {
return match message {
SocketMessage::TogglePause | SocketMessage::State | SocketMessage::Stop => {
Ok(wm.process_command(message)?)
Ok(wm.process_command(message, stream)?)
}
_ => {
tracing::trace!("ignoring while paused");
Expand All @@ -1594,7 +1548,7 @@ pub fn read_commands_tcp(
};
}

wm.process_command(message.clone())?;
wm.process_command(message.clone(), &mut *stream)?;
notify_subscribers(&serde_json::to_string(&Notification {
event: NotificationEvent::Socket(message.clone()),
state: wm.as_ref().into(),
Expand Down
67 changes: 26 additions & 41 deletions komorebic/src/main.rs
Expand Up @@ -5,8 +5,9 @@ use std::fs::File;
use std::fs::OpenOptions;
use std::io::BufRead;
use std::io::BufReader;
use std::io::ErrorKind;
use std::io::Read;
use std::io::Write;
use std::net::Shutdown;
use std::path::Path;
use std::path::PathBuf;
use std::process::Command;
Expand All @@ -30,7 +31,6 @@ use miette::Report;
use miette::SourceOffset;
use miette::SourceSpan;
use paste::paste;
use uds_windows::UnixListener;
use uds_windows::UnixStream;
use which::which;
use windows::Win32::Foundation::HWND;
Expand Down Expand Up @@ -1172,35 +1172,26 @@ pub fn send_message(bytes: &[u8]) -> Result<()> {
Ok(())
}

fn with_komorebic_socket<F: Fn() -> Result<()>>(f: F) -> Result<()> {
let socket = DATA_DIR.join("komorebic.sock");
pub fn send_query(bytes: &[u8]) -> Result<String> {
let socket = DATA_DIR.join("komorebi.sock");

match std::fs::remove_file(&socket) {
Ok(()) => {}
Err(error) => match error.kind() {
// Doing this because ::exists() doesn't work reliably on Windows via IntelliJ
ErrorKind::NotFound => {}
_ => {
return Err(error.into());
}
},
};
let mut stream = UnixStream::connect(&socket)?;
stream.write_all(bytes)?;
stream.shutdown(Shutdown::Write)?;

f()?;
let mut reader = BufReader::new(stream);
let mut response = String::new();
reader.read_to_string(&mut response)?;

let listener = UnixListener::bind(socket)?;
match listener.accept() {
Ok(incoming) => {
let stream = BufReader::new(incoming.0);
for line in stream.lines() {
println!("{}", line?);
}
Ok(response)
}

Ok(())
}
Err(error) => {
panic!("{}", error);
}
// print_query is a helper that queries komorebi and prints the response.
// panics on error.
pub fn print_query(bytes: &[u8]) {
match send_query(bytes) {
Ok(response) => println!("{}", response),
Err(error) => panic!("{}", error),
}
}

Expand Down Expand Up @@ -2000,15 +1991,13 @@ Stop-Process -Name:whkd -ErrorAction SilentlyContinue
)?;
}
SubCommand::State => {
with_komorebic_socket(|| send_message(&SocketMessage::State.as_bytes()?))?;
print_query(&SocketMessage::State.as_bytes()?);
}
SubCommand::VisibleWindows => {
with_komorebic_socket(|| send_message(&SocketMessage::VisibleWindows.as_bytes()?))?;
print_query(&SocketMessage::VisibleWindows.as_bytes()?);
}
SubCommand::Query(arg) => {
with_komorebic_socket(|| {
send_message(&SocketMessage::Query(arg.state_query).as_bytes()?)
})?;
print_query(&SocketMessage::Query(arg.state_query).as_bytes()?);
}
SubCommand::RestoreWindows => {
let hwnd_json = DATA_DIR.join("komorebi.hwnd.json");
Expand Down Expand Up @@ -2239,23 +2228,19 @@ Stop-Process -Name:whkd -ErrorAction SilentlyContinue
);
}
SubCommand::ApplicationSpecificConfigurationSchema => {
with_komorebic_socket(|| {
send_message(&SocketMessage::ApplicationSpecificConfigurationSchema.as_bytes()?)
})?;
print_query(&SocketMessage::ApplicationSpecificConfigurationSchema.as_bytes()?);
}
SubCommand::NotificationSchema => {
with_komorebic_socket(|| send_message(&SocketMessage::NotificationSchema.as_bytes()?))?;
print_query(&SocketMessage::NotificationSchema.as_bytes()?);
}
SubCommand::SocketSchema => {
with_komorebic_socket(|| send_message(&SocketMessage::SocketSchema.as_bytes()?))?;
print_query(&SocketMessage::SocketSchema.as_bytes()?);
}
SubCommand::StaticConfigSchema => {
with_komorebic_socket(|| send_message(&SocketMessage::StaticConfigSchema.as_bytes()?))?;
print_query(&SocketMessage::StaticConfigSchema.as_bytes()?);
}
SubCommand::GenerateStaticConfig => {
with_komorebic_socket(|| {
send_message(&SocketMessage::GenerateStaticConfig.as_bytes()?)
})?;
print_query(&SocketMessage::GenerateStaticConfig.as_bytes()?);
}
}

Expand Down

0 comments on commit c8f6502

Please sign in to comment.