diff --git a/Cargo.toml b/Cargo.toml index 2c785ae..4df053b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,9 +7,9 @@ default-members = ["lighthouse-client"] resolver = "2" [workspace.package] -version = "3.4.0" +version = "3.4.1" edition = "2021" license = "MIT" [workspace.dependencies] -lighthouse-protocol = { version = "^3.4.0", path = "lighthouse-protocol" } +lighthouse-protocol = { version = "^3.4.1", path = "lighthouse-protocol" } diff --git a/lighthouse-client/examples/admin_crud.rs b/lighthouse-client/examples/admin_crud.rs index 6831dfc..4b5928d 100644 --- a/lighthouse-client/examples/admin_crud.rs +++ b/lighthouse-client/examples/admin_crud.rs @@ -2,7 +2,7 @@ use clap::Parser; use lighthouse_client::{protocol::Authentication, Error, Lighthouse, Result, TokioWebSocket, LIGHTHOUSE_URL}; use tracing::{info, info_span, Instrument}; -async fn run(mut lh: Lighthouse) -> Result<()> { +async fn run(lh: Lighthouse) -> Result<()> { info!("Connected to the Lighthouse server"); async { diff --git a/lighthouse-client/examples/admin_get_metrics.rs b/lighthouse-client/examples/admin_get_metrics.rs index bba50ff..e3acda6 100644 --- a/lighthouse-client/examples/admin_get_metrics.rs +++ b/lighthouse-client/examples/admin_get_metrics.rs @@ -2,7 +2,7 @@ use clap::Parser; use lighthouse_client::{protocol::Authentication, Lighthouse, Result, TokioWebSocket, LIGHTHOUSE_URL}; use tracing::info; -async fn run(mut lh: Lighthouse) -> Result<()> { +async fn run(lh: Lighthouse) -> Result<()> { info!("Connected to the Lighthouse server"); let metrics = lh.get_laser_metrics().await?.payload; diff --git a/lighthouse-client/examples/admin_list_root.rs b/lighthouse-client/examples/admin_list_root.rs index 6031cc8..42c4ee7 100644 --- a/lighthouse-client/examples/admin_list_root.rs +++ b/lighthouse-client/examples/admin_list_root.rs @@ -1,11 +1,11 @@ use clap::Parser; -use lighthouse_client::{protocol::Authentication, Lighthouse, Result, TokioWebSocket, LIGHTHOUSE_URL}; +use lighthouse_client::{protocol::Authentication, root, Lighthouse, Result, TokioWebSocket, LIGHTHOUSE_URL}; use tracing::info; -async fn run(mut lh: Lighthouse) -> Result<()> { +async fn run(lh: Lighthouse) -> Result<()> { info!("Connected to the Lighthouse server"); - let tree = lh.list(&[]).await?.payload; + let tree = lh.list(root![]).await?.payload; info!("Got {}", tree); Ok(()) diff --git a/lighthouse-client/examples/black.rs b/lighthouse-client/examples/black.rs index 8c9cf0c..fd16b21 100644 --- a/lighthouse-client/examples/black.rs +++ b/lighthouse-client/examples/black.rs @@ -2,7 +2,7 @@ use clap::Parser; use lighthouse_client::{protocol::{Authentication, Color, Frame}, Lighthouse, Result, TokioWebSocket, LIGHTHOUSE_URL}; use tracing::info; -async fn run(mut lh: Lighthouse) -> Result<()> { +async fn run(lh: Lighthouse) -> Result<()> { info!("Connected to the Lighthouse server"); lh.put_model(Frame::fill(Color::BLACK)).await?; diff --git a/lighthouse-client/examples/disco.rs b/lighthouse-client/examples/disco.rs index 5872b62..eb01e76 100644 --- a/lighthouse-client/examples/disco.rs +++ b/lighthouse-client/examples/disco.rs @@ -4,7 +4,7 @@ use tracing::info; use tokio::time; use std::time::Duration; -async fn run(mut lh: Lighthouse) -> Result<()> { +async fn run(lh: Lighthouse) -> Result<()> { info!("Connected to the Lighthouse server"); loop { diff --git a/lighthouse-client/examples/input_events.rs b/lighthouse-client/examples/input_events.rs index 3425709..52720a8 100644 --- a/lighthouse-client/examples/input_events.rs +++ b/lighthouse-client/examples/input_events.rs @@ -4,7 +4,7 @@ use lighthouse_client::{protocol::Authentication, Lighthouse, Result, TokioWebSo use lighthouse_protocol::Model; use tracing::info; -async fn run(mut lh: Lighthouse) -> Result<()> { +async fn run(lh: Lighthouse) -> Result<()> { info!("Connected to the Lighthouse server"); // Stream input events diff --git a/lighthouse-client/examples/snake.rs b/lighthouse-client/examples/snake.rs index 08078a2..81b3d7f 100644 --- a/lighthouse-client/examples/snake.rs +++ b/lighthouse-client/examples/snake.rs @@ -124,7 +124,7 @@ impl State { } } -async fn run_updater(mut lh: Lighthouse, shared_state: Arc>) -> Result<()> { +async fn run_updater(lh: Lighthouse, shared_state: Arc>) -> Result<()> { loop { // Update the snake and render it let frame = { @@ -190,7 +190,7 @@ async fn main() -> Result<()> { let auth = Authentication::new(&args.username, &args.token); let state = Arc::new(Mutex::new(State::new())); - let mut lh = Lighthouse::connect_with_tokio_to(&args.url, auth).await?; + let lh = Lighthouse::connect_with_tokio_to(&args.url, auth).await?; info!("Connected to the Lighthouse server"); let stream = lh.stream_model().await?; diff --git a/lighthouse-client/examples/stress_test.rs b/lighthouse-client/examples/stress_test.rs index df545d8..a5f95d5 100644 --- a/lighthouse-client/examples/stress_test.rs +++ b/lighthouse-client/examples/stress_test.rs @@ -5,7 +5,7 @@ use lighthouse_client::{protocol::{Authentication, Frame}, Lighthouse, Result, T use tokio::time::{self, Instant}; use tracing::info; -async fn run(mut lh: Lighthouse, delay_ms: Option) -> Result<()> { +async fn run(lh: Lighthouse, delay_ms: Option) -> Result<()> { info!("Connected to the Lighthouse server"); let mut last_second = Instant::now(); diff --git a/lighthouse-client/src/lib.rs b/lighthouse-client/src/lib.rs index 3fc6775..f2c225c 100644 --- a/lighthouse-client/src/lib.rs +++ b/lighthouse-client/src/lib.rs @@ -13,3 +13,11 @@ pub use lighthouse::*; pub use spawn::*; pub use lighthouse_protocol as protocol; + +/// Small convenience macro that expresses the root path. +#[macro_export] +macro_rules! root { + () => { + &[] as &[&str] + }; +} diff --git a/lighthouse-client/src/lighthouse.rs b/lighthouse-client/src/lighthouse.rs index d39241c..95bf3f1 100644 --- a/lighthouse-client/src/lighthouse.rs +++ b/lighthouse-client/src/lighthouse.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, sync::Arc, fmt::Debug}; +use std::{collections::HashMap, fmt::Debug, sync::{atomic::{AtomicI32, Ordering}, Arc}}; use async_tungstenite::tungstenite::{Message, self}; use futures::{prelude::*, channel::mpsc::{Sender, self}, stream::{SplitSink, SplitStream}, lock::Mutex}; @@ -11,13 +11,13 @@ use crate::{Check, Error, Result, Spawner}; /// A connection to the lighthouse server for sending requests and receiving events. pub struct Lighthouse { /// The sink-part of the WebSocket connection. - ws_sink: SplitSink, + ws_sink: Arc>>, /// The response/event slots, keyed by request id. slots: Arc>>>>, /// The credentials used to authenticate with the lighthouse. authentication: Authentication, /// The next request id. Incremented on every request. - request_id: i32, + request_id: Arc, } /// A facility for coordinating asynchronous responses to a request between a @@ -47,10 +47,10 @@ impl Lighthouse let (ws_sink, ws_stream) = web_socket.split(); let slots = Arc::new(Mutex::new(HashMap::new())); let lh = Self { - ws_sink, + ws_sink: Arc::new(Mutex::new(ws_sink)), slots: slots.clone(), authentication, - request_id: 0, + request_id: Arc::new(AtomicI32::new(0)), }; W::spawn(Self::run_receive_loop(ws_stream, slots)); Ok(lh) @@ -119,110 +119,131 @@ impl Lighthouse } /// Replaces the user's lighthouse model with the given frame. - pub async fn put_model(&mut self, frame: Frame) -> Result> { + pub async fn put_model(&self, frame: Frame) -> Result> { let username = self.authentication.username.clone(); - self.put(&["user", username.as_str(), "model"], Model::Frame(frame)).await + self.put(&["user".into(), username, "model".into()], Model::Frame(frame)).await } /// Requests a stream of events (including key/controller events) for the user's lighthouse model. - pub async fn stream_model(&mut self) -> Result>>> { + pub async fn stream_model(&self) -> Result>>> { let username = self.authentication.username.clone(); - self.stream(&["user", username.as_str(), "model"], ()).await + self.stream(&["user".into(), username, "model".into()], ()).await } /// Fetches lamp server metrics. - pub async fn get_laser_metrics(&mut self) -> Result> { + pub async fn get_laser_metrics(&self) -> Result> { self.get(&["metrics", "laser"]).await } /// Combines PUT and CREATE. Requires CREATE and WRITE permission. - pub async fn post

(&mut self, path: &[&str], payload: P) -> Result> + pub async fn post

(&self, path: &[impl AsRef + Debug], payload: P) -> Result> where P: Serialize { self.perform(&Verb::Post, path, payload).await } /// Updates the resource at the given path with the given payload. Requires WRITE permission. - pub async fn put

(&mut self, path: &[&str], payload: P) -> Result> + pub async fn put

(&self, path: &[impl AsRef + Debug], payload: P) -> Result> where P: Serialize { self.perform(&Verb::Put, path, payload).await } /// Creates a resource at the given path. Requires CREATE permission. - pub async fn create(&mut self, path: &[&str]) -> Result> { + pub async fn create(&self, path: &[impl AsRef + Debug]) -> Result> { self.perform(&Verb::Create, path, ()).await } /// Deletes a resource at the given path. Requires DELETE permission. - pub async fn delete(&mut self, path: &[&str]) -> Result> { + pub async fn delete(&self, path: &[impl AsRef + Debug]) -> Result> { self.perform(&Verb::Delete, path, ()).await } /// Creates a directory at the given path. Requires CREATE permission. - pub async fn mkdir(&mut self, path: &[&str]) -> Result> { + pub async fn mkdir(&self, path: &[impl AsRef + Debug]) -> Result> { self.perform(&Verb::Mkdir, path, ()).await } /// Lists the directory tree at the given path. Requires READ permission. - pub async fn list(&mut self, path: &[&str]) -> Result> { + pub async fn list(&self, path: &[impl AsRef + Debug]) -> Result> { self.perform(&Verb::List, path, ()).await } /// Gets the resource at the given path. Requires READ permission. - pub async fn get(&mut self, path: &[&str]) -> Result> + pub async fn get(&self, path: &[impl AsRef + Debug]) -> Result> where R: for<'de> Deserialize<'de> { self.perform(&Verb::Get, path, ()).await } /// Links the given source to the given destination path. - pub async fn link(&mut self, src_path: &[&str], dest_path: &[&str]) -> Result> { - self.perform(&Verb::Link, dest_path, src_path).await + pub async fn link(&self, src_path: &[impl AsRef + Debug], dest_path: &[impl AsRef + Debug]) -> Result> { + self.perform(&Verb::Link, dest_path, src_path.iter().map(|s| s.as_ref().to_owned()).collect::>()).await } /// Unlinks the given source from the given destination path. - pub async fn unlink(&mut self, src_path: &[&str], dest_path: &[&str]) -> Result> { - self.perform(&Verb::Unlink, dest_path, src_path).await + pub async fn unlink(&self, src_path: &[impl AsRef + Debug], dest_path: &[impl AsRef + Debug]) -> Result> { + self.perform(&Verb::Unlink, dest_path, src_path.iter().map(|s| s.as_ref().to_owned()).collect::>()).await } - /// Stops the given stream. - pub async fn stop(&mut self, path: &[&str]) -> Result> { - self.perform(&Verb::Stop, path, ()).await + /// Stops the given stream. **Should generally not be called manually**, + /// since streams will automatically be stopped once dropped. + pub async fn stop(&self, request_id: i32, path: &[impl AsRef + Debug]) -> Result> { + self.perform_with_id(request_id, &Verb::Stop, path, ()).await } /// Performs a single request to the given path with the given payload. #[tracing::instrument(skip(self, payload))] - pub async fn perform(&mut self, verb: &Verb, path: &[&str], payload: P) -> Result> + pub async fn perform(&self, verb: &Verb, path: &[impl AsRef + Debug], payload: P) -> Result> + where + P: Serialize, + R: for<'de> Deserialize<'de> { + let request_id = self.next_request_id(); + self.perform_with_id(request_id, verb, path, payload).await + } + + /// Performs a single request to the given path with the given request id. + #[tracing::instrument(skip(self, payload))] + async fn perform_with_id(&self, request_id: i32, verb: &Verb, path: &[impl AsRef + Debug], payload: P) -> Result> where P: Serialize, R: for<'de> Deserialize<'de> { assert_ne!(verb, &Verb::Stream, "Lighthouse::perform may only be used for one-off requests, use Lighthouse::stream for streaming."); - let request_id = self.send_request(verb, path, payload).await?; + self.send_request(request_id, verb, path, payload).await?; let response = self.receive_single(request_id).await?.check()?.decode_payload()?; Ok(response) } /// Performs a STREAM request to the given path with the given payload. + /// Automatically sends a STOP once dropped. #[tracing::instrument(skip(self, payload))] - pub async fn stream(&mut self, path: &[&str], payload: P) -> Result>>> + pub async fn stream(&self, path: &[impl AsRef + Debug], payload: P) -> Result>>> where P: Serialize, R: for<'de> Deserialize<'de> { - let request_id = self.send_request(&Verb::Stream, path, payload).await?; + let request_id = self.next_request_id(); + let path: Vec = path.into_iter().map(|s| s.as_ref().to_string()).collect(); + self.send_request(request_id, &Verb::Stream, &path, payload).await?; let stream = self.receive_streaming(request_id).await?; - // TODO: Send STOP once dropped - Ok(stream) + Ok(stream.guard({ + // Stop the stream on drop + let this = (*self).clone(); + move || { + tokio::spawn(async move { + if let Err(error) = this.stop(request_id, &path).await { + error! { ?path, %error, "Could not STOP stream" }; + } + }); + } + })) } /// Sends a request to the given path with the given payload. - async fn send_request

(&mut self, verb: &Verb, path: &[&str], payload: P) -> Result + async fn send_request

(&self, request_id: i32, verb: &Verb, path: &[impl AsRef + Debug], payload: P) -> Result where P: Serialize { - let path = path.into_iter().map(|s| s.to_string()).collect(); - let request_id = self.request_id; + let path = path.into_iter().map(|s| s.as_ref().to_string()).collect(); debug! { %request_id, "Sending request" }; - self.request_id += 1; self.send_message(&ClientMessage { request_id, authentication: self.authentication.clone(), @@ -235,7 +256,7 @@ impl Lighthouse } /// Sends a generic message to the lighthouse. - async fn send_message

(&mut self, message: &ClientMessage

) -> Result<()> + async fn send_message

(&self, message: &ClientMessage

) -> Result<()> where P: Serialize { self.send_raw(rmp_serde::to_vec_named(message)?).await @@ -291,8 +312,13 @@ impl Lighthouse } /// Sends raw bytes to the lighthouse via the WebSocket connection. - async fn send_raw(&mut self, bytes: impl Into> + Debug) -> Result<()> { - Ok(self.ws_sink.send(Message::Binary(bytes.into())).await?) + async fn send_raw(&self, bytes: impl Into> + Debug) -> Result<()> { + Ok(self.ws_sink.lock().await.send(Message::Binary(bytes.into())).await?) + } + + /// Fetches the next request id. + fn next_request_id(&self) -> i32 { + self.request_id.fetch_add(1, Ordering::Relaxed) } /// Fetches the credentials used to authenticate with the lighthouse. @@ -303,7 +329,22 @@ impl Lighthouse /// Closes the WebSocket connection gracefully with a close message. While /// the server will usually also handle abruptly closed connections /// properly, it is recommended to always close the [``Lighthouse``]. - pub async fn close(&mut self) -> Result<()> { - Ok(self.ws_sink.close().await?) + pub async fn close(&self) -> Result<()> { + Ok(self.ws_sink.lock().await.close().await?) + } +} + +// For some reason `#[derive(Clone)]` adds the trait bound `S: Clone`, despite +// not actually being needed since the WebSocket sink is already wrapped in an +// `Arc`, therefore we implement `Clone` manually. + +impl Clone for Lighthouse { + fn clone(&self) -> Self { + Self { + ws_sink: self.ws_sink.clone(), + slots: self.slots.clone(), + authentication: self.authentication.clone(), + request_id: self.request_id.clone(), + } } }