diff --git a/Cargo.lock b/Cargo.lock index d244c7c..23db3b0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -98,6 +98,12 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +[[package]] +name = "bytes" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" + [[package]] name = "cc" version = "1.0.90" @@ -327,6 +333,26 @@ dependencies = [ "memchr", ] +[[package]] +name = "pin-project" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pin-project-lite" version = "0.2.14" @@ -375,6 +401,16 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "socket2" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "strsim" version = "0.11.1" @@ -399,9 +435,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" dependencies = [ "backtrace", + "bytes", + "libc", + "mio", "num_cpus", "pin-project-lite", + "socket2", "tokio-macros", + "windows-sys 0.48.0", ] [[package]] @@ -434,6 +475,7 @@ dependencies = [ "anyhow", "clap", "notify", + "pin-project", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 9111933..06a51b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,5 +22,6 @@ path = "src/bin/main.rs" [dependencies] anyhow = "1.0.81" clap = { version = "4.5.4", features = ["std", "derive", "env"] } -tokio = { version = "1.37.0", features = ["rt-multi-thread", "macros"] } +pin-project = "1.1.5" +tokio = { version = "1.37.0", features = ["io-util", "rt-multi-thread", "macros", "net"] } notify = "6.1.1" diff --git a/README.md b/README.md index 61b5f96..45f421e 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,12 @@ cargo install wait-on wait-on file /path/to/file ``` +### Wait for a Socket to be available using TCP Protocol + +```bash +wait-on socket -i 127.0.0.1 -p 8080 +``` + ## License This project is licensed under the MIT license and the Apache License 2.0. diff --git a/src/bin/command/mod.rs b/src/bin/command/mod.rs index 2e172cd..0850384 100644 --- a/src/bin/command/mod.rs +++ b/src/bin/command/mod.rs @@ -1 +1,2 @@ pub mod file; +pub mod socket; diff --git a/src/bin/command/socket.rs b/src/bin/command/socket.rs new file mode 100644 index 0000000..e931525 --- /dev/null +++ b/src/bin/command/socket.rs @@ -0,0 +1,22 @@ +use std::net::IpAddr; + +use anyhow::Result; +use clap::Args; + +use wait_on::resource::socket::SocketWaiter; +use wait_on::{WaitOptions, Waitable}; + +#[derive(Args, Debug)] +pub struct SocketOpt { + #[clap(short = 'p', long = "port")] + pub port: u16, + #[clap(short = 'i', long = "ip", default_value = "127.0.0.1")] + pub addr: IpAddr, +} + +impl SocketOpt { + pub async fn exec(&self) -> Result<()> { + let waiter = SocketWaiter::new(self.addr, self.port); + waiter.wait(WaitOptions::default()).await + } +} diff --git a/src/bin/main.rs b/src/bin/main.rs index 6e9dde9..c8de2a6 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -4,6 +4,7 @@ use anyhow::Result; use clap::Parser; use self::command::file::FileOpt; +use self::command::socket::SocketOpt; #[derive(Debug, Parser)] #[command( @@ -13,7 +14,10 @@ use self::command::file::FileOpt; next_line_help = true )] pub enum Command { + /// Wait on a file to be available File(FileOpt), + /// Wait on a socket to be available using the TCP Protocol + Socket(SocketOpt), } #[derive(Debug, Parser)] @@ -28,5 +32,6 @@ async fn main() -> Result<()> { match args.command { Command::File(opt) => opt.exec().await, + Command::Socket(opt) => opt.exec().await, } } diff --git a/src/resource/mod.rs b/src/resource/mod.rs index c3f5ff6..435a2cb 100644 --- a/src/resource/mod.rs +++ b/src/resource/mod.rs @@ -2,21 +2,25 @@ //! own configuration based on the protocols used. pub mod file; +pub mod socket; use anyhow::Result; use crate::{WaitOptions, Waitable}; use self::file::FileWaiter; +use self::socket::SocketWaiter; pub enum Resource { File(FileWaiter), + Socket(SocketWaiter), } impl Waitable for Resource { async fn wait(self, options: WaitOptions) -> Result<()> { match self { Resource::File(file) => file.wait(options).await, + Resource::Socket(socket) => socket.wait(options).await, } } } diff --git a/src/resource/socket.rs b/src/resource/socket.rs new file mode 100644 index 0000000..daee01f --- /dev/null +++ b/src/resource/socket.rs @@ -0,0 +1,127 @@ +use std::net::{IpAddr, SocketAddr}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use anyhow::{Error, Result}; +use pin_project::pin_project; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; +use tokio::net::{TcpListener, TcpStream}; + +use crate::{WaitOptions, Waitable}; + +/// Listens on a specific IP Address and Port using TCP protocol +pub struct SocketWaiter { + pub addr: IpAddr, + pub port: u16, +} + +impl SocketWaiter { + pub fn new(addr: IpAddr, port: u16) -> Self { + Self { addr, port } + } + + pub fn socket(&self) -> SocketAddr { + SocketAddr::new(self.addr, self.port) + } +} + +#[pin_project] +pub struct PacketExtractor { + pub header: [u8; B], + pub forwarded: usize, + #[pin] + pub socket: TcpStream, +} + +impl PacketExtractor { + pub async fn read(socket: TcpStream) -> Result { + let mut extractor = Self { + header: [0; B], + forwarded: 0, + socket, + }; + + extractor.socket.read_exact(&mut extractor.header).await?; + + Ok(extractor) + } + + pub fn get_header(&mut self) -> &[u8; B] { + &self.header + } +} + +impl AsyncRead for PacketExtractor { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buff: &mut ReadBuf<'_>, + ) -> Poll> { + let extractor = self.project(); + + if *extractor.forwarded < extractor.header.len() { + let leftover = &extractor.header[*extractor.forwarded..]; + let num_forward_now = leftover.len().min(buff.remaining()); + let forward = &leftover[..num_forward_now]; + + buff.put_slice(forward); + *extractor.forwarded += num_forward_now; + + return Poll::Ready(Ok(())); + } + + extractor.socket.poll_read(cx, buff) + } +} + +impl AsyncWrite for PacketExtractor { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buff: &[u8], + ) -> Poll> { + let extractor = self.project(); + extractor.socket.poll_write(cx, buff) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let extractor = self.project(); + extractor.socket.poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let extractor = self.project(); + extractor.socket.poll_shutdown(cx) + } +} + +impl Waitable for SocketWaiter { + async fn wait(self, _: WaitOptions) -> Result<()> { + let tcp_listener = TcpListener::bind(self.socket()).await?; + let (socket, _) = tcp_listener.accept().await?; + let mut socket = PacketExtractor::<8>::read(socket).await?; + + tokio::spawn(async move { + let mut buf = vec![0; 1024]; + + loop { + let n = socket + .read(&mut buf) + .await + .expect("failed to read data from socket"); + + if n == 0 { + // socket closed + return; + } + } + }) + .await + .map_err(|err| Error::msg(err.to_string()))?; + + Ok(()) + } +}