From 6749486f99d506c5636783d652e711661ec49f86 Mon Sep 17 00:00:00 2001 From: Ondrej Perutka Date: Tue, 29 Dec 2015 17:28:04 +0100 Subject: [PATCH] initial commit --- .gitignore | 2 + Cargo.toml | 33 + LICENSE | 202 ++++ README.md | 145 +++ build.rs | 21 + ca.pem | 30 + rtsp-paths | 206 ++++ src/main.rs | 503 ++++++++++ src/net/arrow/error.rs | 91 ++ src/net/arrow/mod.rs | 1400 +++++++++++++++++++++++++++ src/net/arrow/protocol/control.rs | 537 ++++++++++ src/net/arrow/protocol/mod.rs | 346 +++++++ src/net/arrow/protocol/svc_table.rs | 500 ++++++++++ src/net/discovery.rs | 277 ++++++ src/net/mod.rs | 23 + src/net/raw/arp.rs | 330 +++++++ src/net/raw/devices.c | 189 ++++ src/net/raw/devices.rs | 108 +++ src/net/raw/ether.rs | 363 +++++++ src/net/raw/ip.rs | 366 +++++++ src/net/raw/mod.rs | 23 + src/net/raw/pcap.rs | 407 ++++++++ src/net/raw/tcp.rs | 517 ++++++++++ src/net/raw/utils.rs | 139 +++ src/net/rtsp.rs | 790 +++++++++++++++ src/net/utils.rs | 183 ++++ src/utils/config.rs | 283 ++++++ src/utils/logger/mod.rs | 127 +++ src/utils/logger/syslog.rs | 91 ++ src/utils/mod.rs | 300 ++++++ 30 files changed, 8532 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 LICENSE create mode 100644 README.md create mode 100644 build.rs create mode 100644 ca.pem create mode 100644 rtsp-paths create mode 100644 src/main.rs create mode 100644 src/net/arrow/error.rs create mode 100644 src/net/arrow/mod.rs create mode 100644 src/net/arrow/protocol/control.rs create mode 100644 src/net/arrow/protocol/mod.rs create mode 100644 src/net/arrow/protocol/svc_table.rs create mode 100644 src/net/discovery.rs create mode 100644 src/net/mod.rs create mode 100644 src/net/raw/arp.rs create mode 100644 src/net/raw/devices.c create mode 100644 src/net/raw/devices.rs create mode 100644 src/net/raw/ether.rs create mode 100644 src/net/raw/ip.rs create mode 100644 src/net/raw/mod.rs create mode 100644 src/net/raw/pcap.rs create mode 100644 src/net/raw/tcp.rs create mode 100644 src/net/raw/utils.rs create mode 100644 src/net/rtsp.rs create mode 100644 src/net/utils.rs create mode 100644 src/utils/config.rs create mode 100644 src/utils/logger/mod.rs create mode 100644 src/utils/logger/syslog.rs create mode 100644 src/utils/mod.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a9d37c5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +target +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..d8c7b4e --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "arrow-client" +version = "0.3.0" +authors = ["Ondrej Perutka "] +license = "Apache-2.0" +readme = "README.md" +build = "build.rs" + +[features] +discovery = [] + +[dependencies] +libc = "0.2" +regex = "0.1" +mio = "0.5" +uuid = "0.1" +time = "0.1" +rustc-serialize = "0.3" + +[dependencies.openssl] +version = "0.7.3" +features = ["tlsv1_2"] + +[build-dependencies] +gcc = "0.3" + +[profile.dev] +opt-level = 0 +debug = true + +[profile.release] +opt-level = 3 +debug = false diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..fbbc376 --- /dev/null +++ b/README.md @@ -0,0 +1,145 @@ +# Arrow Client + +Arrow Client is an application used to simplify the process of connecting IP +cameras to the Angelcam cloud services. The Arrow Client is meant to be shipped +together with IP camera firmwares or as a standalone application for various +devices such as Raspberry Pi. It can connect to RTSP services passed as +command line arguments and optionally, it can be also compiled with a network +scanning feature. This feature allows the client to scan all its network +interfaces and find local RTSP services automatically. All such RTSP services +can be registered within Angelcam cloud services under your Angelcam account +without the need of a public IP address and setting up your firewall. + +## Features + +- Automatic RTSP service discovery +- Zero-conf IP camera connection +- Connection to Angelcam cloud services secured using TLS v1.2 +- Secure pairing with your Angelcam account + +## Usage + +The application requires `/etc/arrow` directory for storing its configuration +file generated on the first start. The directory must also contain a file named +`rtsp-paths` in case the network scanning feature is enabled. The file should +contain RTSP paths that will be checked on an RTSP service discovery. You can +use the `rtsp-paths` file from this repository. In order to start the +application, you have to pass address of the Angelcam Arrow Service and +certificate file for the service verification. Address of the Angelcam Arrow +Service is: + +``` +arr-rs.angelcam.com:8900 +``` + +Currently, a self-signed certificate is used. You can find the certificate in +this repository (file `ca.pem`). The certificate will be later replaced by +a proper CA certificate. + +Here is an example of starting the Arrow Client with one fixed RTSP service and +with network scanning enabled: + +```bash +arrow-client arr-rs.angelcam.com:8900 ca.pem -d -r "rtsp://localhost:8554/stream.sdp?prof=baseline&res=low" +``` + +Note that the application requires root privileges for direct access to local +network interfaces. + +## Dependencies + +This application requires the following native libraries: + +- OpenSSL +- libpcap (this dependency is optional, it is required only when the network + scanning feature is enabled) + +and the following Rust libraries (downloaded automatically on build): + +- libc +- regex +- mio +- uuid +- time +- rustc-serialize +- openssl + +## Compilation + +Arrow Client compilation is currently supported for x86, x86\_64 and ARM +platforms. Compilation for x86 and x86\_64 can be done directly on +a particular machine. ARM binaries can be compiled using a cross-compiler or +directly in the target or in a virtualized environment (e.g. QEMU). + +### Direct compilation on x86, x86\_64 or ARM + +- Download and install Rust build environment (in case you do not already + have one) from https://www.rust-lang.org/install.html. +- Use the following commands to build the Arrow Client binary: + +```bash +# to build Arrow Client: +cargo build --release +# to build Arrow Client with network scanning feature: +cargo build --release --features "discovery" +``` + +- You will find the binary in the `target/release/` subdir. +- Run the application without any arguments to see its usage. + +### Cross-compilation + +First of all, you will need gcc cross-compiler. In case of ARM, you can +download it from https://github.com/raspberrypi/tools. You will also have to +add some essential libraries for the target architecture (e.g. OpenSSL). + +Now, you need to get your copy of Rust: + +```bash +git clone https://github.com/rust-lang/rust.git +``` + +Add the gcc cross-compiler into the PATH env. variable, e.g.: + +```bash +export PATH=~/pi-tools/arm-bcm2708/gcc-linaro-arm-linux-gnueabihf-raspbian-x64/bin:$PATH +``` + +Then build the Rust cross-compiler, e.g.: + +```bash +./configure --target=arm-unknown-linux-gnueabihf && make && make install +``` + +When the Rust cross-compiler is ready, you will have to modify your +cargo configuration in order to tell the Rust compiler the name of your gcc +linker for the target architecture. For `arm-unknown-linux-gnueabihf` target, +insert the following configuration into your `~/.cargo/config`: + +```toml +[target.arm-unknown-linux-gnueabihf] +ar = "arm-linux-gnueabihf-ar" +linker = "arm-linux-gnueabihf-gcc" +``` + +Finally, set CC env. variable to the path of your gcc cross-compiler, e.g.: + +```bash +export CC=~/pi-tools/arm-bcm2708/gcc-linaro-arm-linux-gnueabihf-raspbian-x64/bin/arm-linux-gnueabihf-gcc +``` + +and add libraries for the target architecture into the LD_LIBRARY_PATH env. +variable, e.g.: + +```bash +export LD_LIBRARY_PATH=~/pi-tools/arm-bcm2708/gcc-linaro-arm-linux-gnueabihf-raspbian-x64/lib:$LD_LIBRARY_PATH +``` + +Now, you are ready to build the Arrow Client for the target architecture, e.g.: + +```bash +# to build Arrow Client: +cargo build --target=arm-unknown-linux-gnueabihf --release +# to build Arrow Client with network scanning feature: +cargo build --target=arm-unknown-linux-gnueabihf --release --features "discovery" +``` diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..5e0ac85 --- /dev/null +++ b/build.rs @@ -0,0 +1,21 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +extern crate gcc; + +fn main() { + gcc::compile_library("libnet_devices.a", + &["src/net/raw/devices.c"]); +} + diff --git a/ca.pem b/ca.pem new file mode 100644 index 0000000..c2a041f --- /dev/null +++ b/ca.pem @@ -0,0 +1,30 @@ +-----BEGIN CERTIFICATE----- +MIIFFDCCAvwCCQD+OG2n4YvuhzANBgkqhkiG9w0BAQUFADBMMQswCQYDVQQGEwJD +WjETMBEGA1UECAwKU29tZS1TdGF0ZTEPMA0GA1UEBwwGUHJhZ3VlMRcwFQYDVQQK +DA5hbmdlbGNhbSwgaW5jLjAeFw0xNTAzMDUxMDIxMTBaFw0yNTAzMDIxMDIxMTBa +MEwxCzAJBgNVBAYTAkNaMRMwEQYDVQQIDApTb21lLVN0YXRlMQ8wDQYDVQQHDAZQ +cmFndWUxFzAVBgNVBAoMDmFuZ2VsY2FtLCBpbmMuMIICIjANBgkqhkiG9w0BAQEF +AAOCAg8AMIICCgKCAgEA7nhHZxTAevpmPMVJYtmtec+mjLom4IoXgLnNj2vX59th +RBPUEHBTMEozU4A3u0lZXp5nGLi/S4yVQ+I2JXCLPivh76sqUaW3aaDBVoN9B0G6 +lZWCCR3T9bVZSimnFn/ybnqCqpCCFK3uE5a0hBSbqwm+s0XekRkfoapOoxpwcKAZ +yN8dVOXX4gcr7ZJOMJwMDBjCooUBG43zVCwGPMcbDg6tWSLTVY0cPUmlffDG+OvL +GXdwsI1wT0JsHEYe3XCxl/mn8FxFInL24aiHA/XsKOuAj6P9b9Iz6kcaJYPBc0xy +ecsGwFO3Cv+KbdyIFsnw+xm3J7IBHlOTeJnv6rdD3R9MOZy8bFZcFDY13B4Ikzf9 +7odcWVMnM5w4uWexkOPbjfMqOqhfmFHNvL4iuoc6kYbWjNOcqj4e0O0SmQcX0oDM +dZkhOAD087lNif+TNHI33KA0eCUPO65rknWiI0jT+DNuHYbKcdrIdIs+ndnPA9KB +U3FIa2COwmh1Zhy/gduluthf5t8ndja6jvOEeP4HujSQumt6iCmQsKZ16ftX5nob +bbHEY7BKEtMbBASTJnZ3rvBzqOmuyg+T1x2tRmTVFkyvdyaGaV++DbkHvGszw0rG +BhAR0M3/HI0Ew+yg4SROGtKihQ3DXwXndADkyPespJkaMmcUERukt2KuJXRpDCkC +AwEAATANBgkqhkiG9w0BAQUFAAOCAgEACupy0wOMJT2vzQgh30Cd36sH25PimImX +V54nU2HbVXSXQ4w88cKkS88NeDuqvcf+n1uKxC2fnUMor0JAy0irXyddUT5SZkn3 +mTheasPzDG96SRVx25yIqmzHlfrDE0prangh3oBD3vrPyIWPAIVWVaZ56b3p8MTY +ZdcRVQkp1fVyfLU89EBpVAQDV5roic4UD1IBHjvYRpycL31oi8FXeCTMgwhPE3GA +M/EPD3NBX5wDtFWkJKrhbU+IFeDYldKzeBd1TrbfmifkoFkddSOxf89EQpnfPgxs +9XRuwZiPLuakuGvnw7FZYyzv5FoettDePuxlNsKuThe/eYTo+dL+1+OvpFm4epeI +7HPazSmSaLDYqnOakiSRyUVQEC7cSODBuwVG9hb/5NCwO1bsp/sAe2TZechV2E1p +v+2XAPCWLtsNFK48OEvQ7sP/7YyOUuHP0hEgbWAe5R795MT5b1GJCikN/JNe8omk +3evPhqJg3p3439Lnr6Fcfgbj2ZSSasWcxd3XleZGweoUB6VNFBbjSjf5W573C3/o +SH19jbTHRM+Q9ELqjvR/JNJp3BYUr8g+6HZoKKjSvRye2orvlyH0rg0Q0puXdrv1 +keEsNQUlw5QbZhNEWn7PHiRNcQkwXbRCJD8NvBKL5tdQ0fOWTk5RBnYlAbMDUd4l +u8c1VUFo3So= +-----END CERTIFICATE----- diff --git a/rtsp-paths b/rtsp-paths new file mode 100644 index 0000000..1cfe218 --- /dev/null +++ b/rtsp-paths @@ -0,0 +1,206 @@ +/videoMain +/videoSub +/play1.sdp +/mpeg4 +/axis-media/media.amp +/video.h264 +/Streaming/Channels/2 +/Streaming/Channels/1 +/Streaming/Channels/103 +/Streaming/Channels/102 +/play2sdp +/live1.sdp +/live2.sdp +/cam/realmonitor +/live/ch00_0 +/live.sdp +/11 +/12 +/live/h264 +/live/h264/ch1 +/cam1/h264 +/0/av0 +/profile2/media.smp +/profile1/media.smp +/h264unicast +/H264/media.smp +/MediaInput/h264 +/media/video1 +/media/video2 +/live_h264.sdp +/0 +/1 +/ipcam_h264.sdp +/ipcam.sdp +/video.mp4 +/media/media.amp?videocodec=h264&streamprofile=Profile3 +/video.pro4 +/video.pro3 +/video.pro2 +/ +/stream1 +/stream2 +/h264 +/0/video1 +/nph-h264.cgi +/img/video.sav +/ch001.sdp +/ch002.sdp +/media/media.amp?videocodec=h264 +/live/0/h264.sdp +/h264.sdp?res=half&x0=800&y0=400&x1=1600&y1=1200&qp=10&ssn=5 +/onvif-media/media.amp +/live/ch00_1 +/ch0.h264 +/live +/video +/rtsp_tunnel +/ch1-s1 +/img/media.sdp +/medias1 +/video0.sdp +/video.pro1 +/stander/livestream/0/0 +/ip_adx +/play3.sdp +/img/media.sav +/channel1/stream1 +/v02 +/PSIA/Streaming/channels/1 +/ch01_sub.264 +/channel1 +/channel2 +/live0.264 +/mpeg4/1/media.amp +/ch0_0.h264 +/cam1/mpeg4 +/stream0/Channel=0 +/Live/Channel=1 +/video1+audio1 +/ch0 +/VideoInput/1/h264/1 +/rtsph2641080p +/rtsph264720p +/H264 +/cam1/onvif-h264 +/media.amp +/0/1:1/main +/video1 +/main +/sub +/av0_0 +/av0_1 +/webcam +/cam0_0 +/PSIA/Streaming/channels/0 +/cam0_1 +/snl/live/1/1 +/media +/ioImage/1 +/avn=2 +/1.AMP +/defaultPrimary?streamtype=u +/v1 +/v2 +/video1enc1 +/tcp/av0_0 +/live1.264 +/live/stream3 +/1/h264major +/live/mpeg4 +/rtsph264 +/multicaststream +/h264/media.amp +/profile1 +/stream/profile0=r +/profile1=u +/h264_2 +/h264_1 +/ch1/stream0 +/streaming/video1 +/rtsp/profile1 +/rtsp_live0 +/MainStream +/live/av0 +/h264Preview_01_main +/AVStream1_1 +/v01 +/defaultPrimary?streamType=u +/profile2 +/live/myStream +/ch1.h264 +/streaming/channels/0 +/cgi-bin/rtspStreamOvf/1 +/cgi-bin/rtspStreamOvf/0 +/profile1=r +/h264_4 +/h264_3 +/vis +/videoinput_1:0/h264_1/onvif.stm +/video1.sdp +/v03 +/udp/av0_0 +/track1 +/streaming/video0 +/Streaming/Channels/3 +/stream0 +/stream/profile1=r +/stream.sdp1 +/stander/livestream/0/1 +/snl/live/1/2 +/s1 +/rtsph264vga +/rtsph2641024p +/rtsp/stream1 +/rtsp/record.sdp +/rtpvideo1.sdp +/PSIA/Streaming/channels/h264 +/ProfileToken_1_4 +/profile/profile01 +/play2.sdp +/OVProfile02 +/OVProfile00 +/medias2 +/media_ch1 +/media_ch0 +/Master-0 +/LowResolutionVideo +/lowQ.sdp +/live4.sdp +/live/stream1 +/live/second +/live/main +/live/camera1 +/live/0/onvif.sdp +/live/0/MAIN +/live_st2 +/live_h264_1.sdp +/id=0 +/hiQ.sdp +/HighResolutionVideo +/h264main +/H264/sub +/h264/ch1/sub/ +/H264/ch1/sub +/H264/ch1 +/h264_vga.sdp +/h264_stream +/gnz_media/second +/gnz_media/main +/Channel0/Sub2 +/Channel0/Sub1 +/ch2 +/ch1.sdp +/ch0_unicast_secondstream +/ch0_unicast_firststream +/ch0_1.H264 +/cgi-bin/rtspStream/1 +/camera.stm +/cam1/h264/multicast +/cam +/av2 +/2 +/1/stream1 +/1/1:1/main +/0/888888:888888/sub +/live/h264/VGA diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..56d0fff --- /dev/null +++ b/src/main.rs @@ -0,0 +1,503 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Arrow Client definitions. + +extern crate mio; +extern crate libc; +extern crate regex; +extern crate openssl; +extern crate time; +extern crate uuid; +extern crate rustc_serialize; + +#[macro_use] +mod utils; + +pub mod net; + +use std::env; +use std::process; +use std::thread; + +use std::sync::Arc; +use std::fmt::Debug; +use std::error::Error; +use std::str::FromStr; +use std::thread::JoinHandle; +use std::net::{SocketAddr, ToSocketAddrs}; + +use utils::logger::syslog; + +use utils::{Shared, RuntimeError}; +use utils::logger::{Logger, Severity}; +use utils::config::{ArrowConfig, AppContext}; + +#[cfg(feature = "discovery")] +use net::discovery; + +use net::raw::ether::MacAddr; +use net::raw::devices::EthernetDevice; +use net::arrow::error::ArrowError; +use net::arrow::{ArrowClient, Sender, Command}; +use net::arrow::protocol::{Service, ServiceTable}; + +use openssl::x509::X509FileType; +use openssl::ssl::error::SslError; +use openssl::ssl::{IntoSsl, SslContext, SslMethod}; + +use mio::{EventLoop, Handler, NotifyError}; + +use regex::Regex; + +/// Network scan period. +const NETWORK_SCAN_PERIOD: u64 = 300000; + +/// Connectionn retry timeout. +const RETRY_TIMEOUT: f64 = 60.0; + +/// Arrow Client configuration file. +static CONFIG_FILE: &'static str = "/etc/arrow/config.json"; + +/// Get socket address from a given argument. +fn get_socket_address(s: T) -> Result + where T: ToSocketAddrs { + let mut addrs = try!(s.to_socket_addrs() + .or(Err(RuntimeError::from("unable get socket address")))); + + match addrs.next() { + Some(addr) => Ok(addr), + _ => Err(RuntimeError::from("unable get socket address")) + } +} + +/// Get MAC address of the first configured ethernet device. +fn get_first_mac() -> Result { + let mut devices = EthernetDevice::list() + .into_iter(); + + match devices.next() { + Some(dev) => Ok(dev.mac_addr), + None => Err(RuntimeError::from("there is no configured ethernet device")) + } +} + +/// Unwrap a given result (if possible) or print the error message and exit +/// the process printing application usage. +fn result_or_usage(res: Result) -> T + where E: Error + Debug { + match res { + Ok(res) => res, + Err(err) => { + println!("ERROR: {}\n", err.description()); + usage(1); + } + } +} + +/// Parse a given RTSP URL and return Service::RTSP, Service::LockedRTSP or +/// an error. +fn parse_rtsp_url(url: &str) -> Result { + let res = r"^rtsp://([^/]+@)?([^/@:]+|\[[0-9a-fA-F:.]+\])(:(\d+))?(/.*)?$"; + let re = Regex::new(res).unwrap(); + + if let Some(caps) = re.captures(url) { + // we don't care about the actual MAC address + let mac = MacAddr::new(0, 0, 0, 0, 0, 0); + let host = caps.at(2).unwrap(); + let path = caps.at(5).unwrap(); + let port = match caps.at(4) { + Some(port_str) => u16::from_str(port_str).unwrap(), + _ => 554 + }; + + let socket_addr = try!(get_socket_address((host, port)) + .or(Err(RuntimeError::from( + "unable to resolve RTSP service address")))); + + // note: we do not want to probe the service here as it might not be + // available on app startup + match caps.at(1) { + Some(_) => Ok(Service::LockedRTSP(mac, socket_addr)), + None => Ok(Service::RTSP(mac, socket_addr, path.to_string())) + } + } else { + Err(RuntimeError::from("invalid RTSP URL given")) + } +} + +/// Print usage and exit the process with a given exit code. +fn usage(exit_code: i32) -> ! { + println!("USAGE: arrow-client arr-host[:arr-port] ca-cert [OPTIONS]\n"); + println!(" arr-host Angelcam Arrow Service host"); + println!(" arr-port Angelcam Arrow Service port\n"); + println!(" ca-cert CA certificate in PEM format for Arrow Service identity"); + println!(" verification\n"); + println!("OPTIONS:\n"); + if cfg!(feature = "discovery") { + println!(" -d automatic service discovery"); + } + println!(" -r URL local RTSP service URL"); + println!(" -v enable debug logs\n"); + process::exit(exit_code); +} + +/// Initialize SSL context. +fn init_ssl(ca_file: &str) -> Result { + let mut ssl_context = try!(SslContext::new(SslMethod::Tlsv1_2)); + try!(ssl_context.set_certificate_file(ca_file, X509FileType::PEM)); + try!(ssl_context.set_cipher_list("HIGH:!aNULL:!kRSA:!PSK:!MD5:!RC4")); + Ok(ssl_context) +} + +/// Spawn a new Arrow Client thread. +fn spawn_arrow_thread( + logger: L, + ssl_context: Arc, + cmd_sender: CommandSender, + addr: &str, + arrow_mac: &MacAddr, + app_context: &Shared) { + let addr = addr.to_string(); + let arrow_mac = arrow_mac.clone(); + let app_context = app_context.clone(); + + thread::spawn(move || arrow_thread(logger, ssl_context, cmd_sender, + &addr, &arrow_mac, app_context)); +} + +/// Arrow Client main thread. +/// +/// This function ensures maintaining connection with a remote Arrow Service. +fn arrow_thread + Clone>( + mut logger: L, + ssl_context: Arc, + cmd_sender: Q, + addr: &str, + arrow_mac: &MacAddr, + app_context: Shared) { + let mut last_error = time::precise_time_s(); + let mut cur_addr = addr.to_string(); + + loop { + log_info!(logger, &format!("connecting to remote Arrow Service {}", cur_addr)); + + let lgr = logger.clone(); + let ctx = app_context.clone(); + + let res = match utils::result_or_log(&mut logger, Severity::WARN, + connect(lgr, &*ssl_context, cmd_sender.clone(), + &cur_addr, arrow_mac, ctx)) { + Some(addr) => Ok(addr), + None => Err(time::precise_time_s()) + }; + + match res { + Ok(addr) => cur_addr = addr, + Err(t) => { + if (last_error + RETRY_TIMEOUT - 0.5) > t { + let retry = RETRY_TIMEOUT + last_error - t; + log_info!(logger, &format!("retrying in {:.3} seconds", retry)); + thread::sleep_ms((retry * 1000.0) as u32); + } + + cur_addr = addr.to_string(); + last_error = t; + } + } + } +} + +/// Connect to a given Arrow Service. +fn connect>( + logger: L, + s: S, + cmd_sender: Q, + addr: &str, + arrow_mac: &MacAddr, + app_context: Shared) -> Result { + let addr = try!(get_socket_address(addr) + .or(Err(ArrowError::from(format!("failed to lookup Arrow Service {} address information", addr))))); + + match ArrowClient::new(logger, s, cmd_sender, + &addr, arrow_mac, app_context) { + Err(err) => Err(ArrowError::from(format!("unable to connect to remote Arrow Service {} ({})", addr, err.description()))), + Ok(mut client) => client.event_loop() + } +} + +#[cfg(feature = "discovery")] +/// Run device discovery and update a given service table. +fn network_scanner_thread( + mut logger: L, + app_context: Shared) { + log_info!(logger, "looking for local services..."); + let services = utils::result_or_log(&mut logger, Severity::WARN, + discovery::find_rtsp_streams()); + + if let Some(services) = services { + let mut app_context = app_context.lock() + .unwrap(); + let config = &mut app_context.config; + let count = services.len(); + + let bump = services.into_iter() + .fold(false, |b, svc| { + config.add(svc) + .is_some() | b + }); + + if bump { + config.bump_version(); + } + + log_info!(logger, &format!("{} services found, current service table: {}", count, config)); + utils::result_or_log(&mut logger, Severity::WARN, + config.save(CONFIG_FILE)); + } +} + +#[cfg(not(feature = "discovery"))] +/// Dummy scanner. +fn network_scanner_thread(_: L, _: Shared) { +} + +/// Periodical event types. +#[derive(Debug, Copy, Clone)] +enum TimerEvent { + ScanNetwork +} + +/// Arrow Command wrapper/extender. +#[derive(Debug, Copy, Clone)] +enum CommandWrapper { + Wrapped(Command), + ScanCompleted +} + +/// Command channel. +#[derive(Debug, Clone)] +struct CommandSender { + sender: mio::Sender, +} + +impl CommandSender { + /// Crate a new channel for sending Arrow Commands. + fn new(sender: mio::Sender) -> CommandSender { + CommandSender { + sender: sender + } + } +} + +impl Sender for CommandSender { + fn send(&self, cmd: Command) -> Result<(), Command> { + match self.sender.send(CommandWrapper::Wrapped(cmd)) { + Ok(_) => Ok(()), + Err(err) => match err { + NotifyError::Closed(None) => Ok(()), + _ => Err(cmd) + } + } + } +} + +/// Arrow command handler. +struct CommandHandler { + logger: L, + default_svc_table: ServiceTable, + app_context: Shared, + scanner: Option>, + discovery: bool, +} + +impl CommandHandler { + /// Create a new Arrow Command handler. + fn new( + logger: L, + default_svc_table: ServiceTable, + app_context: Shared, + discovery: bool) -> CommandHandler { + CommandHandler { + logger: logger, + default_svc_table: default_svc_table, + app_context: app_context, + scanner: None, + discovery: discovery + } + } + + /// Scan the local network for new services and schedule the next network + /// scanning event. + fn periodical_network_scan(&mut self, event_loop: &mut EventLoop) { + self.scan_network(event_loop); + + event_loop.timeout_ms(TimerEvent::ScanNetwork, NETWORK_SCAN_PERIOD) + .unwrap(); + } + + /// Spawn a new network scanner thread (if it is not already running) and + /// save its join handle. + fn scan_network(&mut self, event_loop: &mut EventLoop) { + // check if the discovery is enabled and if there is another scanner + // running + if self.discovery && self.scanner.is_none() { + let mut app_context = self.app_context.lock() + .unwrap(); + + app_context.scanning = true; + + let logger = self.logger.clone(); + let app_context = self.app_context.clone(); + let sender = event_loop.channel(); + let handle = thread::spawn(move || { + network_scanner_thread(logger, app_context); + sender.send(CommandWrapper::ScanCompleted) + .unwrap(); + }); + + self.scanner = Some(handle); + } + } + + /// Called upon network scanner thread completion. + fn scan_completed(&mut self) { + let res = match self.scanner.take() { + Some(handle) => handle.join(), + _ => Ok(()), + }; + + let mut app_context = self.app_context.lock() + .unwrap(); + + app_context.scanning = false; + + if res.is_err() { + log_warn!(self.logger, "network scanner thread panicked"); + } + } + + /// Reinitialize the shared config with the default service table. + fn reset_svc_table(&mut self) { + let mut app_context = self.app_context.lock() + .unwrap(); + let config = &mut app_context.config; + let table = &self.default_svc_table; + + config.reinit(table.clone()); + config.bump_version(); + + utils::result_or_log(&mut self.logger, Severity::WARN, + config.save(CONFIG_FILE)); + } +} + +impl Handler for CommandHandler { + type Timeout = TimerEvent; + type Message = CommandWrapper; + + fn timeout( + &mut self, + event_loop: &mut EventLoop, + event: TimerEvent) { + match event { + TimerEvent::ScanNetwork => self.periodical_network_scan(event_loop) + } + } + + fn notify( + &mut self, + event_loop: &mut EventLoop, + cmd: CommandWrapper) { + match cmd { + CommandWrapper::ScanCompleted => self.scan_completed(), + CommandWrapper::Wrapped(cmd) => match cmd { + Command::ResetServiceTable => self.reset_svc_table(), + Command::ScanNetwork => self.scan_network(event_loop) + } + } + } +} + +/// Arrow Client main function. +fn main() { + let mut logger = syslog::new(); + let args = env::args() + .collect::>(); + + if args.len() < 3 { + usage(1); + } else { + let arrow_mac = utils::result_or_error(get_first_mac(), 2); + let arrow_addr = &args[1]; + let ca_file = &args[2]; + + let mut discovery = false; + + let mut i = 3; + + let mut config = ArrowConfig::load(CONFIG_FILE) + .unwrap_or(ArrowConfig::new()); + + while i < args.len() { + match &args[i] as &str { + "-d" if cfg!(feature = "discovery") => { discovery = true; }, + "-r" => { + let service = parse_rtsp_url(&args[i + 1]); + let service = result_or_usage(service); + config.add(service); + i += 1; + }, + "-v" => { logger.set_level(Severity::DEBUG); }, + _ => { + println!("unknown argument: {}\n", &args[i]); + usage(1); + } + } + + i += 1; + } + + utils::result_or_error(config.save(CONFIG_FILE), 3); + + let ssl_context = Arc::new( + utils::result_or_error(init_ssl(ca_file), 4)); + + log_info!(logger, &format!("application started (uuid: {}, mac: {})", + config.uuid_string(), arrow_mac)); + + let default_svc_table = config.service_table(); + let app_context = Shared::new(AppContext::new(config)); + + let mut event_loop = EventLoop::new() + .unwrap(); + + let mut cmd_handler = CommandHandler::new( + logger.clone(), + default_svc_table, + app_context.clone(), + discovery); + + let cmd_sender = CommandSender::new(event_loop.channel()); + + spawn_arrow_thread(logger, ssl_context, cmd_sender, + arrow_addr, &arrow_mac, &app_context); + + event_loop.timeout_ms(TimerEvent::ScanNetwork, 0) + .unwrap(); + + event_loop.run(&mut cmd_handler) + .unwrap(); + } +} diff --git a/src/net/arrow/error.rs b/src/net/arrow/error.rs new file mode 100644 index 0000000..b9ea838 --- /dev/null +++ b/src/net/arrow/error.rs @@ -0,0 +1,91 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Definitions of ArrowError which may be returned by Arrow client. + +use std::io; +use std::fmt; +use std::result; + +use std::error::Error; +use std::fmt::{Display, Formatter}; + +use mio::TimerError; + +use openssl::ssl::error::{SslError, NonblockingSslError}; + +/// Type alias for Result with ArrowError. +pub type Result = result::Result; + +/// Arrow error (it may be returned by Arrow client). +#[derive(Debug, Clone)] +pub struct ArrowError { + msg: String, +} + +impl Error for ArrowError { + /// Get error description. + fn description(&self) -> &str { + &self.msg + } +} + +impl Display for ArrowError { + /// Format error message. + fn fmt(&self, f: &mut Formatter) -> result::Result<(), fmt::Error> { + f.write_str(&self.msg) + } +} + +impl From for ArrowError { + /// Create a new ArrowError from a given error string. + fn from(msg: String) -> ArrowError { + ArrowError { msg: msg } + } +} + +impl<'a> From<&'a str> for ArrowError { + /// Create a new ArrowError from a given error string. + fn from(msg: &'a str) -> ArrowError { + ArrowError::from(msg.to_string()) + } +} + +impl From for ArrowError { + /// Create a new ArrowError from a given IO error. + fn from(err: io::Error) -> ArrowError { + ArrowError::from(format!("IO error: {}", err.description())) + } +} + +impl From for ArrowError { + /// Create a new ArrowError for a given timer error. + fn from(_: TimerError) -> ArrowError { + ArrowError::from("timer error") + } +} + +impl From for ArrowError { + /// Create a new ArrowError from a given SSL error. + fn from(err: SslError) -> ArrowError { + ArrowError::from(format!("OpenSSL error: {}", err.description())) + } +} + +impl From for ArrowError { + /// Create a new ArrowError from a given SSL error. + fn from(err: NonblockingSslError) -> ArrowError { + ArrowError::from(format!("OpenSSL error: {}", err.description())) + } +} diff --git a/src/net/arrow/mod.rs b/src/net/arrow/mod.rs new file mode 100644 index 0000000..4290443 --- /dev/null +++ b/src/net/arrow/mod.rs @@ -0,0 +1,1400 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Arrow Protocol implementation. + +pub mod error; +pub mod protocol; + +use std::io; +use std::cmp; +use std::mem; +use std::result; + +use std::ffi::CStr; +use std::error::Error; +use std::collections::VecDeque; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::io::{Read, Write, ErrorKind}; + +use utils; + +use net::raw::ether::MacAddr; +use net::utils::{Timeout, WriteBuffer}; + +use utils::logger::Logger; +use utils::config::AppContext; +use utils::{Shared, Serialize}; + +use self::protocol::*; +use self::error::{Result, ArrowError}; + +use mio::tcp::TcpStream; +use mio::{EventLoop, EventSet, Token, PollOpt, Handler}; + +use openssl::ssl::{NonblockingSslStream, IntoSsl}; +use openssl::ssl::error::NonblockingSslError; + +/// Register a given TCP stream in a given event loop. +fn register_socket( + token_id: usize, + stream: &TcpStream, + readable: bool, + writable: bool, + event_loop: &mut EventLoop) { + let poll = PollOpt::level(); + let mut events = EventSet::all(); + + if !readable { + events.remove(EventSet::readable()); + } + + if !writable { + events.remove(EventSet::writable()); + } + + event_loop.register(stream, Token(token_id), events, poll) + .unwrap(); +} + +/// Re-register a given TCP stream in a given event loop. +fn reregister_socket( + token_id: usize, + stream: &TcpStream, + readable: bool, + writable: bool, + event_loop: &mut EventLoop) { + let poll = PollOpt::level(); + let mut events = EventSet::all(); + + if !readable { + events.remove(EventSet::readable()); + } + + if !writable { + events.remove(EventSet::writable()); + } + + event_loop.reregister(stream, Token(token_id), events, poll) + .unwrap(); +} + +/// Deregister a given socket. +fn deregister_socket( + stream: &TcpStream, + event_loop: &mut EventLoop) { + event_loop.deregister(stream) + .unwrap(); +} + +/// Commands that might be sent by the Arrow Client into a given mpsc queue. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum Command { + ResetServiceTable, + ScanNetwork, +} + +/// Common trait for various implementations of command senders. +pub trait Sender { + /// Send a given command or return the command back if the send operation + /// failed. + fn send(&self, cmd: C) -> result::Result<(), C>; +} + +/// ArrowStream states. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +enum ArrowStreamState { + Ok, + ReaderWantRead, + ReaderWantWrite, + WriterWantRead, + WriterWantWrite, +} + +/// Abstraction over the Arrow SSL stream. +struct ArrowStream { + stream: NonblockingSslStream, + state: ArrowStreamState, + token_id: usize, +} + +impl ArrowStream { + /// Create a new ArrowStream instance and register the underlaying socket + /// within a given event loop. + fn connect( + s: S, + arrow_addr: &SocketAddr, + token_id: usize, + event_loop: &mut EventLoop) -> Result { + let tcp_stream = try!(TcpStream::connect(arrow_addr)); + let ssl_stream = try!(NonblockingSslStream::connect(s, tcp_stream)); + + register_socket(token_id, ssl_stream.get_ref(), + true, true, event_loop); + + let res = ArrowStream { + stream: ssl_stream, + state: ArrowStreamState::Ok, + token_id: token_id + }; + + Ok(res) + } + + /// Enable receiving writable events for the underlaying TCP socket. + fn enable_socket_events( + &mut self, + readable: bool, + writable: bool, + event_loop: &mut EventLoop) { + reregister_socket(self.token_id, self.stream.get_ref(), + readable, writable, event_loop); + } + + /// Read available data from the underlaying SSL stream into a given + /// buffer. + fn read( + &mut self, + buf: &mut [u8], + event_loop: &mut EventLoop) -> Result { + match self.stream.read(buf) { + Err(NonblockingSslError::WantRead) => { + self.state = ArrowStreamState::ReaderWantRead; + self.enable_socket_events(true, false, event_loop); + Ok(0) + }, + Err(NonblockingSslError::WantWrite) => { + self.state = ArrowStreamState::ReaderWantWrite; + self.enable_socket_events(false, true, event_loop); + Ok(0) + }, + other => { + self.state = ArrowStreamState::Ok; + self.enable_socket_events(true, true, event_loop); + Ok(try!(other)) + } + } + } + + /// Write given data using the underlaying SSL stream. + fn write( + &mut self, + data: &[u8], + event_loop: &mut EventLoop) -> Result { + match self.stream.write(data) { + Err(NonblockingSslError::WantRead) => { + self.state = ArrowStreamState::WriterWantRead; + self.enable_socket_events(true, false, event_loop); + Ok(0) + }, + Err(NonblockingSslError::WantWrite) => { + self.state = ArrowStreamState::WriterWantWrite; + self.enable_socket_events(false, true, event_loop); + Ok(0) + }, + other => { + self.state = ArrowStreamState::Ok; + self.enable_socket_events(true, true, event_loop); + Ok(try!(other)) + } + } + } + + /// Check if the underlaying socket is ready to read. + fn can_read(&self, event_set: EventSet) -> bool { + match self.state { + ArrowStreamState::Ok => event_set.is_readable(), + ArrowStreamState::ReaderWantRead => event_set.is_readable(), + ArrowStreamState::ReaderWantWrite => event_set.is_writable(), + _ => false + } + } + + /// Check if the underlaying socket is ready to write. + fn can_write(&self, event_set: EventSet) -> bool { + match self.state { + ArrowStreamState::Ok => event_set.is_writable(), + ArrowStreamState::WriterWantRead => event_set.is_readable(), + ArrowStreamState::WriterWantWrite => event_set.is_writable(), + _ => false + } + } + + fn take_socket_error(&self) -> io::Result<()> { + self.stream.get_ref() + .take_socket_error() + } +} + +/// TCP stream abstraction for ignoring EWOULDBLOCKs. +struct ServiceStream { + /// TCP stream. + stream: TcpStream, +} + +impl ServiceStream { + /// Connect to a given TCP socket address. + fn connect(addr: &SocketAddr) -> io::Result { + let stream = try!(TcpStream::connect(addr)); + let res = ServiceStream { + stream: stream + }; + + Ok(res) + } + + /// Get reference to the underlaying TCP stream. + fn get_ref(&self) -> &TcpStream { + &self.stream + } + + /// Take error from the underlaying TCP stream. + fn take_socket_error(&self) -> io::Result<()> { + self.stream.take_socket_error() + } +} + +impl Read for ServiceStream { + /// Read data from the underlaying socket (EWOULDBLOCK is silently + /// ignored). + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self.stream.read(buf) { + Err(ref err) if err.kind() == ErrorKind::WouldBlock => Ok(0), + other => other + } + } +} + +impl Write for ServiceStream { + /// Write data into the underlaying socket (EWOULDBLOCK is silently + /// ignored). + fn write(&mut self, buf: &[u8]) -> io::Result { + match self.stream.write(buf) { + Err(ref err) if err.kind() == ErrorKind::WouldBlock => Ok(0), + other => other + } + } + + /// Flush buffered data into the underlaying socket (EWOULDBLOCK is not + /// ignored in this case). + fn flush(&mut self) -> io::Result<()> { + self.stream.flush() + } +} + +/// External service session context. +/// +/// This struct holds connection to an external service (e.g. RTSP) and +/// its I/O buffers. +struct SessionContext { + /// Logger. + #[allow(dead_code)] + logger: L, + /// Service ID. + service_id: u16, + /// Session ID. + session_id: u32, + /// TCP stream. + stream: ServiceStream, + /// Input buffer. + input_buffer: WriteBuffer, + /// Output buffer. + output_buffer: WriteBuffer, + /// Read buffer. + read_buffer: Box<[u8]>, + /// Write timeout. + write_tout: Timeout, +} + +impl SessionContext { + /// Create a new session context for a given session ID and service + /// address. + fn new( + logger: L, + service_id: u16, + session_id: u32, + addr: &SocketAddr, + event_loop: &mut EventLoop) -> Result> { + let stream = try!(ServiceStream::connect(addr)); + + register_socket(session2token(session_id), stream.get_ref(), + true, true, event_loop); + + let res = SessionContext { + logger: logger, + service_id: service_id, + session_id: session_id, + stream: stream, + input_buffer: WriteBuffer::new(256 * 1024), + output_buffer: WriteBuffer::new(0), + read_buffer: Box::new([0u8; 32768]), + write_tout: Timeout::new() + }; + + Ok(res) + } + + /// Dispose resources held by this object. + fn dispose(&self, event_loop: &mut EventLoop) { + deregister_socket(self.stream.get_ref(), event_loop); + } + + /// Enable/disable notifications for the underlaying socket. + fn update_socket_events( + &mut self, + event_loop: &mut EventLoop) { + let readable = !self.input_buffer.is_full(); + let writable = !self.output_buffer.is_empty(); + reregister_socket( + session2token(self.session_id), + self.stream.get_ref(), + readable, writable, event_loop); + } + + /// Process a given set of socket events and return size of the input + /// buffer or None in case the connection has been closed. + fn socket_ready( + &mut self, + event_loop: &mut EventLoop, + event_set: EventSet) -> Result> { + try!(self.check_read_event(event_loop, event_set)); + try!(self.check_write_event(event_loop, event_set)); + + if event_set.is_error() { + let err = self.get_socket_error() + .ok_or(ArrowError::from("socket error expected")); + Err(try!(err)) + } else if event_set.is_hup() { + Ok(None) + } else { + Ok(Some(self.input_buffer.buffered())) + } + } + + /// Read a message if the underlaying socket is readable and the input + /// buffer is not already full. + fn check_read_event( + &mut self, + event_loop: &mut EventLoop, + event_set: EventSet) -> Result<()> { + if event_set.is_readable() { + if self.input_buffer.is_full() { + self.update_socket_events(event_loop); + } else { + let buffer = &mut *self.read_buffer; + let len = try!(self.stream.read(buffer)); + self.input_buffer.write_all(&buffer[..len]) + .unwrap(); + + //log_debug!(self.logger, &format!("{} bytes read from session socket {:08x} (buffer size: {})", len, self.session_id, self.input_buffer.buffered())); + } + } + + Ok(()) + } + + /// Write data from the output buffer into the underlaying socket if the + /// socket is writable. + fn check_write_event( + &mut self, + event_loop: &mut EventLoop, + event_set: EventSet) -> Result<()> { + if event_set.is_writable() { + if self.output_buffer.is_empty() { + self.update_socket_events(event_loop); + self.write_tout.clear(); + } else { + let len = try!(self.stream.write( + self.output_buffer.as_bytes())); + + if len > 0 { + //log_debug!(self.logger, &format!("{} bytes written into session socket {:08x} (buffer size: {})", len, self.session_id, self.output_buffer.buffered())); + self.output_buffer.drop(len); + self.write_tout.set(CONNECTION_TIMEOUT); + } + } + } + + Ok(()) + } + + /// Get socket error. + fn get_socket_error(&self) -> Option { + let err = self.stream.take_socket_error(); + match err.err() { + Some(err) => Some(ArrowError::from(err)), + None => None + } + } + + /// Check if there are some data in the input buffer. + fn input_ready(&self) -> bool { + !self.input_buffer.is_empty() + } + + /// Get buffered input data. + fn input_buffer(&self) -> &[u8] { + self.input_buffer.as_bytes() + } + + /// Drop a given number of bytes from the input buffer. + fn drop_input_bytes( + &mut self, + count: usize, + event_loop: &mut EventLoop) { + let was_full = self.input_buffer.is_full(); + + self.input_buffer.drop(count); + + if was_full && !self.input_buffer.is_full() { + self.update_socket_events(event_loop); + } + } + + /// Send a given message. + fn send_message( + &mut self, + data: &[u8], + event_loop: &mut EventLoop) { + let was_empty = self.output_buffer.is_empty(); + + self.output_buffer.write_all(data) + .unwrap(); + + if was_empty { + self.write_tout.set(CONNECTION_TIMEOUT); + self.update_socket_events(event_loop); + } + } +} + +/// Convert a given session ID into a token (socket) ID. +fn session2token(session_id: u32) -> usize { + assert!(mem::size_of::() >= 4); + (session_id as usize) | (1 << 24) +} + +/// Convert a given token (socket) ID into a session ID. +fn token2session(token_id: usize) -> u32 { + assert!(mem::size_of::() >= 4); + let mask = ((1 as usize) << 24) - 1; + assert!((token_id & !mask) == (1 << 24)); + (token_id & mask) as u32 +} + +/// Arrow Protocol states. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +enum ProtocolState { + Handshake, + Established +} + +type SocketEventResult = Result>; + +const UPDATE_CHECK_PERIOD: u64 = 5000; +const TIMEOUT_CHECK_PERIOD: u64 = 1000; +const PING_PERIOD: u64 = 60000; + +const CONNECTION_TIMEOUT: u64 = 20000; + +/// Arrow client connection handler. +struct ConnectionHandler> { + /// Application logger. + logger: L, + /// Shared application context. + app_context: Shared, + /// Channel for sending Arrow Commands. + cmd_sender: Q, + /// SSL/TLS connection to a remote Arrow Service. + stream: ArrowStream, + /// Session contexts. + sessions: HashMap>, + /// Session read queue. + session_queue: VecDeque, + /// Buffer for reading Arrow Protocol requests. + read_buffer: Box<[u8]>, + /// Buffer for writing Arrow Protocol responses. + write_buffer: Box<[u8]>, + /// Parser for requests received from Arrow Service. + req_parser: ArrowMessageParser, + /// Output buffer for messages to be passed to Arrow Service. + output_buffer: WriteBuffer, + /// Arrow Client result returned after the connection shut down. + result: Option>, + /// Protocol state. + state: ProtocolState, + /// Version of the last sent service table. + last_update: Option, + /// Write timeout. + write_tout: Timeout, + /// ACK timeout. + ack_tout: Timeout, + /// Current Control Message ID. + msg_id: u16, + /// Expected ACKs. + expected_acks: VecDeque, +} + +impl> ConnectionHandler { + /// Create a new connection handler. + fn new( + logger: L, + s: S, + cmd_sender: Q, + addr: &SocketAddr, + arrow_mac: &MacAddr, + app_context: Shared, + event_loop: &mut EventLoop) -> Result { + let stream = try!(ArrowStream::connect(s, addr, 0, event_loop)); + + let mut res = ConnectionHandler { + logger: logger, + app_context: app_context, + cmd_sender: cmd_sender, + stream: stream, + sessions: HashMap::new(), + session_queue: VecDeque::new(), + read_buffer: Box::new([0u8; 32768]), + write_buffer: Box::new([0u8; 16384]), + req_parser: ArrowMessageParser::new(), + output_buffer: WriteBuffer::new(256 * 1024), + result: None, + state: ProtocolState::Handshake, + last_update: None, + write_tout: Timeout::new(), + ack_tout: Timeout::new(), + msg_id: 0, + expected_acks: VecDeque::new() + }; + + res.create_register_request(arrow_mac, event_loop); + + // start timeout checker: + event_loop.timeout_ms(TimerEvent::TimeoutCheck(0), TIMEOUT_CHECK_PERIOD) + .unwrap(); + + Ok(res) + } + + /// Get session context for a given session ID. + fn get_session_context( + &self, + session_id: u32) -> Option<&SessionContext> { + self.sessions.get(&session_id) + } + + /// Get session context for a given session ID. + fn get_session_context_mut( + &mut self, + session_id: u32) -> Option<&mut SessionContext> { + self.sessions.get_mut(&session_id) + } + + /// Create a new session context for a given service and session IDs. + fn create_session_context( + &mut self, + service_id: u16, + session_id: u32, + event_loop: &mut EventLoop) -> Option<&mut SessionContext> { + if !self.sessions.contains_key(&session_id) { + let app_context = self.app_context.lock() + .unwrap(); + let config = &app_context.config; + if let Some(svc) = config.get(service_id) { + if let Some(addr) = svc.address() { + log_info!(self.logger, &format!("connecting to remote service: {}, session ID: {:08x}", addr, session_id)); + match SessionContext::new(self.logger.clone(), + service_id, session_id, addr, event_loop) { + Err(err) => log_warn!(self.logger, &format!("unable to open connection to a remote service: {}", err.description())), + Ok(ctx) => { + let token_id = session2token(session_id); + let tevent = TimerEvent::TimeoutCheck(token_id); + self.sessions.insert(session_id, ctx); + self.session_queue.push_back(session_id); + event_loop.timeout_ms(tevent, TIMEOUT_CHECK_PERIOD) + .unwrap(); + } + } + } else { + log_warn!(self.logger, "requested service ID belongs to a Control Protocol service"); + } + } else { + log_warn!(self.logger, &format!("non-existing service requested (service ID: {})", service_id)); + } + } + + self.sessions.get_mut(&session_id) + } + + /// Remove session context with a given session ID. + fn remove_session_context( + &mut self, + session_id: u32, + event_loop: &mut EventLoop) { + if let Some(ctx) = self.sessions.remove(&session_id) { + ctx.dispose(event_loop); + } + } + + /// Create a new REGISTER request. + fn create_register_request( + &mut self, + arrow_mac: &MacAddr, + event_loop: &mut EventLoop) { + let control_msg = { + let app_context = self.app_context.lock() + .unwrap(); + let config = &app_context.config; + let msg = RegisterMessage::new( + config.uuid(), + arrow_mac.octets(), + config.password(), + config.service_table()); + let control_msg = control::create_register_message(self.msg_id, + msg); + self.last_update = Some(config.version()); + self.msg_id += 1; + control_msg + }; + + log_debug!(self.logger, "sending REGISTER request..."); + + self.send_unconfirmed_control_message(control_msg, event_loop); + } + + /// Send an update message (if needed) and schedule the next update event. + fn send_update_message( + &mut self, + svc_table: ServiceTable, + event_loop: &mut EventLoop) { + let control_msg = control::create_update_message(self.msg_id, + svc_table); + + self.msg_id += 1; + + log_debug!(self.logger, "sending an UPDATE message..."); + + self.send_control_message(control_msg, event_loop); + } + + /// Send the PING message and schedule the next PING event. + fn send_ping_message(&mut self, event_loop: &mut EventLoop) { + let control_msg = control::create_ping_message(self.msg_id); + + self.msg_id += 1; + + log_debug!(self.logger, "sending a PING message..."); + + self.send_unconfirmed_control_message(control_msg, event_loop); + } + + /// Send HUP message for a given session ID. + fn send_hup_message( + &mut self, + session_id: u32, + error_code: u32, + event_loop: &mut EventLoop) { + let control_msg = control::create_hup_message(self.msg_id, + session_id, error_code); + + self.msg_id += 1; + + log_debug!(self.logger, "sending a HUP message..."); + + self.send_control_message(control_msg, event_loop); + } + + /// Send status message for a given request ID. + fn send_status( + &mut self, + request_id: u16, + event_loop: &mut EventLoop) { + let active_sessions = self.sessions.len() as u32; + let mut status_flags = 0; + + { + let app_context = self.app_context.lock() + .unwrap(); + + if app_context.scanning { + status_flags |= control::STATUS_FLAG_SCAN; + } + } + + let status_msg = StatusMessage::new(request_id, + status_flags, active_sessions); + let control_msg = control::create_status_message(self.msg_id, + status_msg); + + self.msg_id += 1; + + log_debug!(self.logger, "sending a STATUS message..."); + + self.send_control_message(control_msg, event_loop); + } + + /// Send ACK message with a given message id and error code. + fn send_ack_message( + &mut self, + msg_id: u16, + error_code: u32, + event_loop: &mut EventLoop) { + let control_msg = control::create_ack_message(msg_id, error_code); + + log_debug!(self.logger, "sending and ACK message..."); + + self.send_control_message(control_msg, event_loop); + } + + /// Send a given Control protocol message. + fn send_control_message( + &mut self, + control_msg: ControlMessage, + event_loop: &mut EventLoop) { + let arrow_msg = ArrowMessage::new(0, 0, control_msg); + self.send_message(&arrow_msg, event_loop); + } + + /// Send a given Control Protocol message which needs to be confirmed by + // ACK. + fn send_unconfirmed_control_message( + &mut self, + control_msg: ControlMessage, + event_loop: &mut EventLoop) { + if self.expected_acks.is_empty() { + self.ack_tout.set(CONNECTION_TIMEOUT); + } + + let msg_id = control_msg.header() + .msg_id; + + self.expected_acks.push_back(msg_id); + + self.send_control_message(control_msg, event_loop); + } + + /// Send a given Arrow Message. + fn send_message( + &mut self, + arrow_msg: &ArrowMessage, + event_loop: &mut EventLoop) { + if self.output_buffer.is_empty() { + self.write_tout.set(CONNECTION_TIMEOUT); + } + + arrow_msg.serialize(&mut self.output_buffer) + .unwrap(); + + self.stream.enable_socket_events(true, true, event_loop); + } + + /// Check if the service table has been updated and send an UPDATE message + /// if needed. + fn check_update(&mut self, event_loop: &mut EventLoop) { + let cur_version; + let svc_table; + + { + let app_context = self.app_context.lock() + .unwrap(); + let config = &app_context.config; + cur_version = config.version(); + svc_table = config.service_table(); + } + + let send_update = match self.last_update { + Some(sent_version) => cur_version > sent_version, + None => true + }; + + if send_update { + self.send_update_message(svc_table, event_loop); + self.last_update = Some(cur_version); + } + } + + /// Check if the service table has been updated and send an UPDATE message + /// if needed. + fn te_check_update( + &mut self, + event_loop: &mut EventLoop) -> Result<()> { + self.check_update(event_loop); + + event_loop.timeout_ms(TimerEvent::Update, UPDATE_CHECK_PERIOD) + .unwrap(); + + Ok(()) + } + + /// Periodical connection check. + fn te_check_connection( + &mut self, + event_loop: &mut EventLoop) -> Result<()> { + self.send_ping_message(event_loop); + + event_loop.timeout_ms(TimerEvent::Ping, PING_PERIOD) + .unwrap(); + + Ok(()) + } + + /// Check connection timeout. + fn te_check_timeout( + &mut self, + token: usize, + event_loop: &mut EventLoop) -> Result<()> { + match token { + 0 => self.check_arrow_timeout(event_loop), + t => self.check_session_timeout(token2session(t), event_loop) + } + } + + /// Check connection timeout of the underlaying Arrow socket. + fn check_arrow_timeout( + &mut self, + event_loop: &mut EventLoop) -> Result<()> { + if !self.write_tout.check() || !self.ack_tout.check() { + Err(ArrowError::from("Arrow Service connection timeout")) + } else { + event_loop.timeout_ms(TimerEvent::TimeoutCheck(0), + TIMEOUT_CHECK_PERIOD).unwrap(); + + Ok(()) + } + } + + /// Check session communication timeout. + fn check_session_timeout( + &mut self, + session_id: u32, + event_loop: &mut EventLoop) -> Result<()> { + let mut timeout = false; + + if let Some(ctx) = self.get_session_context(session_id) { + timeout = !ctx.write_tout.check(); + } + + if timeout { + log_warn!(self.logger, &format!("session {} connection timeout", session_id)); + self.send_hup_message(session_id, 0, event_loop); + self.remove_session_context(session_id, event_loop); + } else { + event_loop.timeout_ms( + TimerEvent::TimeoutCheck(session2token(session_id)), + TIMEOUT_CHECK_PERIOD).unwrap(); + } + + Ok(()) + } + + /// Process all notifications for the underlaying TLS socket. + fn arrow_socket_ready( + &mut self, + event_loop: &mut EventLoop, + event_set: EventSet) -> SocketEventResult { + let res = try!(self.check_arrow_read_event(event_loop, event_set)); + if res.is_some() { + return Ok(res); + } + + let res = try!(self.check_arrow_write_event(event_loop, event_set)); + if res.is_some() { + return Ok(res); + } + + if event_set.is_error() { + let socket_err = self.stream.take_socket_error(); + Err(ArrowError::from(socket_err.unwrap_err())) + } else if event_set.is_hup() { + Err(ArrowError::from("connection to Arrow Service lost")) + } else { + Ok(None) + } + } + + /// Read a request/response chunk if the underlaying TLS socket is + /// readable. + fn check_arrow_read_event( + &mut self, + event_loop: &mut EventLoop, + event_set: EventSet) -> SocketEventResult { + if self.stream.can_read(event_set) { + self.read_request(event_loop) + } else { + Ok(None) + } + } + + /// Write a request/response chunk if the underlaying TLS socket is + /// writable. + fn check_arrow_write_event( + &mut self, + event_loop: &mut EventLoop, + event_set: EventSet) -> SocketEventResult { + if self.stream.can_write(event_set) { + self.send_response(event_loop) + } else { + Ok(None) + } + } + + /// Read request data from the underlaying TLS socket. + fn read_request( + &mut self, + event_loop: &mut EventLoop) -> SocketEventResult { + let mut consumed = 0; + + let len = try!(self.stream.read(&mut *self.read_buffer, event_loop)); + + //log_debug!(self.logger, &format!("{} bytes read from the Arrow socket", len)); + + while consumed < len { + consumed += try!(self.req_parser.add( + &self.read_buffer[consumed..len])); + if self.req_parser.is_complete() { + let redirect = try!(self.process_request(event_loop)); + if redirect.is_some() { + return Ok(redirect); + } + } + } + + Ok(None) + } + + /// Parse the last complete request. + /// + /// # Panics + /// If the last request has not been completed yet. + fn process_request( + &mut self, + event_loop: &mut EventLoop) -> SocketEventResult { + let service_id; + let session_id; + + if let Some(header) = self.req_parser.header() { + service_id = header.service; + session_id = header.session; + } else { + panic!("incomplete message") + } + + match service_id { + 0 => self.process_control_message(event_loop), + _ => self.process_service_request(service_id, session_id, + event_loop) + } + } + + /// Process a Control Protocol message. + fn process_control_message( + &mut self, + event_loop: &mut EventLoop) -> SocketEventResult { + let (header, body) = try!(self.parse_control_message()); + + log_debug!(self.logger, &format!("received control message: {:?}", header.message_type())); + + let res = match header.message_type() { + ControlMessageType::ACK => + self.process_ack_message(header.msg_id, &body, event_loop), + ControlMessageType::PING => + self.process_ping_message(header.msg_id, event_loop), + ControlMessageType::REDIRECT => + self.process_redirect_message(&body), + ControlMessageType::HUP => + self.process_hup_message(&body, event_loop), + ControlMessageType::RESET_SVC_TABLE => + self.process_command(Command::ResetServiceTable), + ControlMessageType::SCAN_NETWORK => + self.process_command(Command::ScanNetwork), + ControlMessageType::GET_STATUS => + self.process_status_request(header.msg_id, event_loop), + mt => Err(ArrowError::from(format!("cannot handle Control Protocol message type: {:?}", mt))) + }; + + self.req_parser.clear(); + + res + } + + /// Parse a Control Protocol message from the underlaying Arrow Message + /// parser. + fn parse_control_message(&self) -> Result<(ControlMessageHeader, Vec)> { + if let Some(body) = self.req_parser.body() { + let mut parser = ControlMessageParser::new(); + try!(parser.process(body)); + let header = parser.header(); + let body = parser.body(); + if header.message_type() == ControlMessageType::UNKNOWN { + Err(ArrowError::from("unknown Control Protocol message type")) + } else { + Ok((header.clone(), body.to_vec())) + } + } else { + panic!("incomplete message"); + } + } + + /// Process a Control Protocol ACK message. + fn process_ack_message( + &mut self, + msg_id: u16, + msg: &[u8], + event_loop: &mut EventLoop) -> SocketEventResult { + let expected_ack = self.expected_acks.pop_front(); + + if self.expected_acks.is_empty() { + self.ack_tout.clear(); + } else { + self.ack_tout.set(CONNECTION_TIMEOUT); + } + + if let Some(expected_ack) = expected_ack { + if msg_id == expected_ack { + if self.state == ProtocolState::Handshake { + self.process_handshake_ack(msg, event_loop) + } else { + Ok(None) + } + } else { + Err(ArrowError::from("unexpected ACK message ID")) + } + } else { + Err(ArrowError::from("no ACK message expected")) + } + } + + /// Process ACK response for the REGISTER command. + fn process_handshake_ack( + &mut self, + msg: &[u8], + event_loop: &mut EventLoop) -> SocketEventResult { + if self.state == ProtocolState::Handshake { + let ack = try!(control::parse_ack_message(msg)); + if ack == 0 { + // switch the protocol state into normal operation + self.state = ProtocolState::Established; + // start sending update messages + event_loop.timeout_ms(TimerEvent::Update, + UPDATE_CHECK_PERIOD).unwrap(); + // start sending PING messages + event_loop.timeout_ms(TimerEvent::Ping, + PING_PERIOD).unwrap(); + + Ok(None) + } else { + Err(ArrowError::from("Arrow REGISTER failed")) + } + } else { + panic!("unexpected protocol state"); + } + } + + /// Process a Control Protocol PING message. + fn process_ping_message( + &mut self, + msg_id: u16, + event_loop: &mut EventLoop) -> SocketEventResult { + if self.state == ProtocolState::Established { + self.send_ack_message(msg_id, 0, event_loop); + Ok(None) + } else { + Err(ArrowError::from("cannot handle PING message in the Handshake state")) + } + } + + /// Process a Control Protocol REDIRECT message. + fn process_redirect_message(&mut self, msg: &[u8]) -> SocketEventResult { + if self.state == ProtocolState::Established { + let ptr = msg.as_ptr(); + let cstr = unsafe { + CStr::from_ptr(ptr as *const _) + }; + + let addr = String::from_utf8_lossy(cstr.to_bytes()); + + Ok(Some(addr.to_string())) + } else { + Err(ArrowError::from("cannot handle REDIRECT message in the Handshake state")) + } + } + + /// Process a Control Protocol HUP message. + fn process_hup_message( + &mut self, + msg: &[u8], + event_loop: &mut EventLoop) -> SocketEventResult { + if self.state == ProtocolState::Established { + let msg = try!(HupMessage::from_bytes(msg)); + let session_id = msg.session_id; + // XXX: the HUP error code should be processed here + log_info!(self.logger, &format!("session {:08x} closed", session_id)); + self.remove_session_context(session_id, event_loop); + Ok(None) + } else { + Err(ArrowError::from("cannot handle HUP message in the Handshake state")) + } + } + + /// Send command using the underlaying command channel. + fn process_command(&mut self, cmd: Command) -> SocketEventResult { + match self.cmd_sender.send(cmd) { + Err(cmd) => log_warn!(self.logger, &format!("unable to process command {:?}", cmd)), + _ => () + } + + Ok(None) + } + + /// Process status request (GET_STATUS message) with a given ID. + fn process_status_request( + &mut self, + msg_id: u16, + event_loop: &mut EventLoop) -> SocketEventResult { + self.send_status(msg_id, event_loop); + Ok(None) + } + + /// Process request for a remote service. + fn process_service_request( + &mut self, + service_id: u16, + session_id: u32, + event_loop: &mut EventLoop) -> SocketEventResult { + if self.state == ProtocolState::Established { + let request = match self.req_parser.body() { + Some(body) => body.to_vec(), + None => panic!("incomplete message") + }; + + self.req_parser.clear(); + + let send_hup = match self.create_session_context( + service_id, session_id, event_loop) { + None => true, + Some(ctx) => { + ctx.send_message(&request, event_loop); + false + } + }; + + if send_hup { + self.send_hup_message(session_id, 1, event_loop); + } + + Ok(None) + } else { + Err(ArrowError::from("cannot handle service requests in the Handshake state")) + } + } + + /// Fill the Arrow Protocol output buffer with data from session input + /// buffers. + fn fill_output_buffer(&mut self, event_loop: &mut EventLoop) { + // using round robin alg. here in order to avoid session read + // starvation + let mut queue_size = self.session_queue.len(); + while queue_size > 0 && !self.output_buffer.is_full() { + if let Some(session_id) = self.session_queue.pop_front() { + if let Some(ctx) = self.sessions.get_mut(&session_id) { + // avoid sending empty packets + let len = if ctx.input_ready() { + let data = ctx.input_buffer(); + let len = cmp::min(32768, data.len()); + let arrow_msg = ArrowMessage::new( + ctx.service_id, ctx.session_id, + &data[..len]); + + if self.output_buffer.is_empty() { + self.write_tout.set(CONNECTION_TIMEOUT); + } + + arrow_msg.serialize(&mut self.output_buffer) + .unwrap(); + + len + } else { + 0 + }; + + ctx.drop_input_bytes(len, event_loop); + + self.session_queue.push_back(session_id); + + //log_debug!(self.logger, &format!("{} bytes moved from session {:08x} input buffer into the Arrow output buffer", len, session_id)); + } + } + + queue_size -= 1; + } + } + + /// Send response data using the underlaying TLS socket. + fn send_response( + &mut self, + event_loop: &mut EventLoop) -> SocketEventResult { + self.fill_output_buffer(event_loop); + + if self.output_buffer.is_empty() { + self.stream.enable_socket_events(true, false, event_loop); + self.write_tout.clear(); + } else { + let len = { + let data = self.output_buffer.as_bytes(); + let len = cmp::min(data.len(), self.write_buffer.len()); + let buffer = &mut self.write_buffer[..len]; + utils::memcpy(buffer, &data[..len]); + try!(self.stream.write(buffer, event_loop)) + }; + + if len > 0 { + //log_debug!(self.logger, &format!("{} bytes written into the Arrow socket", len)); + self.write_tout.set(CONNECTION_TIMEOUT); + self.output_buffer.drop(len); + } + } + + Ok(None) + } + + /// Process all notifications for a given remote session socket. + fn session_socket_ready( + &mut self, + session_id: u32, + event_loop: &mut EventLoop, + event_set: EventSet) -> SocketEventResult { + let res = match self.get_session_context_mut(session_id) { + Some(ctx) => ctx.socket_ready(event_loop, event_set), + None => Ok(Some(0)) + }; + + match res { + Err(err) => { + log_warn!(self.logger, &format!("service connection error: {}", err.description())); + self.send_hup_message(session_id, 2, event_loop); + self.remove_session_context(session_id, event_loop); + }, + Ok(None) => { + log_info!(self.logger, "service connection closed"); + self.send_hup_message(session_id, 0, event_loop); + self.remove_session_context(session_id, event_loop); + }, + Ok(Some(size)) if size > 0 => { + self.stream.enable_socket_events(true, true, event_loop); + }, + _ => () + } + + Ok(None) + } +} + +/// Types of epoll() timer events. +#[derive(Debug, Copy, Clone)] +enum TimerEvent { + Update, + Ping, + TimeoutCheck(usize), +} + +impl Handler for ConnectionHandler + where L: Logger + Clone, + Q: Sender { + type Timeout = TimerEvent; + type Message = (); + + /// Event loop handler method. + fn ready( + &mut self, + event_loop: &mut EventLoop, + token: Token, + event_set: EventSet) { + let res = match token { + Token(0) => self.arrow_socket_ready(event_loop, event_set), + Token(id) => self.session_socket_ready(token2session(id), + event_loop, event_set) + }; + + match res { + Ok(None) => (), + Ok(Some(redirect)) => self.result = Some(Ok(redirect)), + Err(err) => self.result = Some(Err(err)) + } + + if self.result.is_some() { + event_loop.shutdown(); + } + } + + /// Timer handler method. + fn timeout(&mut self, event_loop: &mut EventLoop, token: TimerEvent) { + let res = match token { + TimerEvent::Update => self.te_check_update(event_loop), + TimerEvent::Ping => self.te_check_connection(event_loop), + TimerEvent::TimeoutCheck(token) => + self.te_check_timeout(token, event_loop) + }; + + match res { + Err(err) => self.result = Some(Err(err)), + _ => () + } + + if self.result.is_some() { + event_loop.shutdown(); + } + } +} + +/// Arrow client. +pub struct ArrowClient> { + connection: ConnectionHandler, + event_loop: EventLoop>, +} + +impl> ArrowClient { + /// Create a new Arrow client. + pub fn new( + logger: L, + s: S, + cmd_sender: Q, + addr: &SocketAddr, + arrow_mac: &MacAddr, + app_context: Shared) -> Result { + let mut event_loop = try!(EventLoop::new()); + let connection = try!(ConnectionHandler::new( + logger, s, cmd_sender, + addr, arrow_mac, app_context, + &mut event_loop)); + + let res = ArrowClient { + connection: connection, + event_loop: event_loop + }; + + Ok(res) + } + + /// Connect to the remote Arrow Service and start listening for incoming + /// requests. Return error or redirect address in case the connection has + /// been shut down. + pub fn event_loop(&mut self) -> Result { + try!(self.event_loop.run(&mut self.connection)); + match self.connection.result { + Some(ref res) => res.clone(), + _ => panic!("result expected") + } + } +} diff --git a/src/net/arrow/protocol/control.rs b/src/net/arrow/protocol/control.rs new file mode 100644 index 0000000..72c625a --- /dev/null +++ b/src/net/arrow/protocol/control.rs @@ -0,0 +1,537 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Common Arrow Control Protocol definitions. + +use std::io; +use std::mem; + +use std::io::Write; + +use utils; + +use utils::Serialize; +use net::arrow::protocol::ArrowMessageBody; +use net::arrow::protocol::svc_table::ServiceTable; +use net::arrow::error::{ArrowError, Result}; + +/// Arrow Control Protocol message types. +#[allow(non_camel_case_types)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum ControlMessageType { + ACK, + PING, + REGISTER, + REDIRECT, + UPDATE, + HUP, + RESET_SVC_TABLE, + SCAN_NETWORK, + GET_STATUS, + STATUS, + UNKNOWN, +} + +// message type constants +const CMSG_ACK: u16 = 0x0000; +const CMSG_PING: u16 = 0x0001; +const CMSG_REGISTER: u16 = 0x0002; +const CMSG_REDIRECT: u16 = 0x0003; +const CMSG_UPDATE: u16 = 0x0004; +const CMSG_HUP: u16 = 0x0005; +const CMSG_RESET_SVC_TABLE: u16 = 0x0006; +const CMSG_SCAN_NETWORK: u16 = 0x0007; +const CMSG_GET_STATUS: u16 = 0x0008; +const CMSG_STATUS: u16 = 0x0009; + +/// Common trait for Control Protocol payload types. +pub trait ControlMessageBody : Serialize { + /// Get body size in bytes. + fn len(&self) -> usize; +} + +/// Arrow Control Protocol message header. +#[derive(Debug, Copy, Clone)] +#[repr(packed)] +pub struct ControlMessageHeader { + /// Message ID. + pub msg_id: u16, + /// Message type. + msg_type: u16, +} + +impl ControlMessageHeader { + /// Create a new Control Protocol message with a given message ID and + /// message type. + fn new(msg_id: u16, msg_type: u16) -> ControlMessageHeader { + ControlMessageHeader { + msg_id: msg_id, + msg_type: msg_type + } + } + + /// Deserialize a Control Message header. + fn from_bytes(data: &[u8]) -> ControlMessageHeader { + assert_eq!(data.len(), mem::size_of::()); + let ptr = data.as_ptr() as *const ControlMessageHeader; + let header = unsafe { &*ptr }; + + ControlMessageHeader { + msg_id: u16::from_be(header.msg_id), + msg_type: u16::from_be(header.msg_type) + } + } + + /// Get message type. + pub fn message_type(&self) -> ControlMessageType { + match self.msg_type { + CMSG_ACK => ControlMessageType::ACK, + CMSG_PING => ControlMessageType::PING, + CMSG_REGISTER => ControlMessageType::REGISTER, + CMSG_REDIRECT => ControlMessageType::REDIRECT, + CMSG_UPDATE => ControlMessageType::UPDATE, + CMSG_HUP => ControlMessageType::HUP, + CMSG_RESET_SVC_TABLE => ControlMessageType::RESET_SVC_TABLE, + CMSG_SCAN_NETWORK => ControlMessageType::SCAN_NETWORK, + CMSG_GET_STATUS => ControlMessageType::GET_STATUS, + CMSG_STATUS => ControlMessageType::STATUS, + _ => ControlMessageType::UNKNOWN + } + } +} + +impl Serialize for ControlMessageHeader { + fn serialize(&self, w: &mut W) -> io::Result<()> { + let be_header = ControlMessageHeader { + msg_id: self.msg_id.to_be(), + msg_type: self.msg_type.to_be() + }; + + w.write_all(utils::as_bytes(&be_header)) + } +} + +/// Arrow Control protocol message. +#[derive(Debug, Clone)] +pub struct ControlMessage { + /// Message header. + header: ControlMessageHeader, + /// Message payload. + body: B, +} + +impl ControlMessage { + /// Create a new Control Protocol message with a given message ID, message + /// type and payload. + pub fn new(msg_id: u16, msg_type: u16, body: B) -> ControlMessage { + ControlMessage { + header: ControlMessageHeader::new(msg_id, msg_type), + body: body + } + } + + /// Get message header. + pub fn header(&self) -> &ControlMessageHeader { + &self.header + } +} + +impl Serialize for ControlMessage { + fn serialize(&self, w: &mut W) -> io::Result<()> { + try!(self.header.serialize(w)); + self.body.serialize(w) + } +} + +impl ArrowMessageBody for ControlMessage { + fn len(&self) -> usize { + mem::size_of::() + self.body.len() + } +} + +/// Create a new ACK message with a given message ID and error code. +pub fn create_ack_message(msg_id: u16, err: u32) -> ControlMessage { + ControlMessage::new(msg_id, CMSG_ACK, err) +} + +/// Create a new PING message with a given message ID. +pub fn create_ping_message(msg_id: u16) -> ControlMessage { + ControlMessage::new(msg_id, CMSG_PING, EmptyBody) +} + +/// Create a new REGISTER message for a given message ID and message body. +pub fn create_register_message( + msg_id: u16, + body: RegisterMessage) -> ControlMessage { + ControlMessage::new(msg_id, CMSG_REGISTER, body) +} + +/// Create a new UPDATE message for a given message ID and service table. +pub fn create_update_message( + msg_id: u16, + svc_table: ServiceTable) -> ControlMessage { + ControlMessage::new(msg_id, CMSG_UPDATE, svc_table.clone()) +} + +/// Create a new HUP message for a given message ID, session ID and error code. +pub fn create_hup_message( + msg_id: u16, + session_id: u32, + error_code: u32) -> ControlMessage { + ControlMessage::new(msg_id, CMSG_HUP, + HupMessage::new(session_id, error_code)) +} + +/// Create a new STATUS control message for a given message ID and message +/// body. +pub fn create_status_message( + msg_id: u16, + status_msg: StatusMessage) -> ControlMessage { + ControlMessage::new(msg_id, CMSG_STATUS, status_msg) +} + +/// Arrow Control Protocol message parser. +pub struct ControlMessageParser<'a> { + header: Option, + body: Option<&'a [u8]>, +} + +impl<'a> ControlMessageParser<'a> { + /// Create a new Control Protocol message parser. + pub fn new() -> ControlMessageParser<'a> { + ControlMessageParser { + header: None, + body: None + } + } + + /// Process given message data. + pub fn process(&mut self, data: &'a [u8]) -> Result<()> { + let header_size = mem::size_of::(); + if data.len() < header_size { + return Err(ArrowError::from("not enough data to parse an Arrow Control Protocol message")); + } + + let header_data = &data[..header_size]; + let body_data = &data[header_size..]; + let header = ControlMessageHeader::from_bytes(header_data); + + self.header = Some(header); + self.body = Some(body_data); + + Ok(()) + } + + /// Get message header of the last successfully parsed message. + pub fn header(&self) -> &ControlMessageHeader { + match self.header { + Some(ref header) => header, + None => panic!("no Control Protocol message has been processed yet") + } + } + + /// Get message body of the last successfully parsed message. + pub fn body(&self) -> &[u8] { + match self.body { + Some(ref body) => body, + None => panic!("no Control Protocol message has been processed yet") + } + } +} + +impl ControlMessageBody for u32 { + fn len(&self) -> usize { + mem::size_of::() + } +} + +/// Dummy type representing empty payload. +#[derive(Debug, Copy, Clone)] +pub struct EmptyBody; + +impl Serialize for EmptyBody { + fn serialize(&self, _: &mut W) -> io::Result<()> { + Ok(()) + } +} + +impl ControlMessageBody for EmptyBody { + fn len(&self) -> usize { + 0 + } +} + +/// REGISTER message header. +#[derive(Debug, Copy, Clone)] +#[repr(packed)] +pub struct RegisterMessageHeader { + /// Client identifier. + pub uuid: [u8; 16], + /// Client MAC address. + pub mac_addr: [u8; 6], + /// Client passphrase. + pub passwd: [u8; 16], +} + +impl RegisterMessageHeader { + /// Create a new REGISTER message header. + fn new( + uuid: [u8; 16], + mac_addr: [u8; 6], + passwd: [u8; 16]) -> RegisterMessageHeader { + RegisterMessageHeader { + uuid: uuid, + mac_addr: mac_addr, + passwd: passwd + } + } +} + +impl Serialize for RegisterMessageHeader { + fn serialize(&self, w: &mut W) -> io::Result<()> { + w.write_all(utils::as_bytes(self)) + } +} + +/// REGISTER message. +#[derive(Debug, Clone)] +pub struct RegisterMessage { + /// Message header. + header: RegisterMessageHeader, + /// Service table. + table: ServiceTable, +} + +impl RegisterMessage { + /// Create a new REGISTER message. + pub fn new( + uuid: [u8; 16], + mac_addr: [u8; 6], + passwd: [u8; 16], + svc_table: ServiceTable) -> RegisterMessage { + RegisterMessage { + header: RegisterMessageHeader::new(uuid, mac_addr, passwd), + table: svc_table + } + } + + /// Get message header. + pub fn header(&self) -> &RegisterMessageHeader { + &self.header + } + + /// Get service table. + pub fn service_table(&self) -> &ServiceTable { + &self.table + } +} + +impl Serialize for RegisterMessage { + fn serialize(&self, w: &mut W) -> io::Result<()> { + try!(self.header.serialize(w)); + self.table.serialize(w) + } +} + +impl ControlMessageBody for RegisterMessage { + fn len(&self) -> usize { + mem::size_of::() + self.table.len() + } +} + +/// HUP message. +#[derive(Debug, Copy, Clone)] +#[repr(packed)] +pub struct HupMessage { + /// Session ID (note: the upper 8 bits are reserved). + pub session_id: u32, + /// Error code. + pub error_code: u32, +} + +impl HupMessage { + /// Create a new HUP message for a given session ID and error code. + fn new(session_id: u32, error_code: u32) -> HupMessage { + HupMessage { + session_id: session_id & ((1 << 24) - 1), + error_code: error_code + } + } + + /// Parse a HUP message. + pub fn from_bytes(data: &[u8]) -> Result { + let msg_size = mem::size_of::(); + if data.len() != msg_size { + return Err(ArrowError::from("invalid size of an Arrow Control Protocol HUP message")); + } + + let ptr = data.as_ptr() as *const HupMessage; + let msg = unsafe { &*ptr }; + let res = HupMessage { + session_id: u32::from_be(msg.session_id), + error_code: u32::from_be(msg.error_code) + }; + + Ok(res) + } +} + +impl Serialize for HupMessage { + fn serialize(&self, w: &mut W) -> io::Result<()> { + let be_msg = HupMessage { + session_id: self.session_id.to_be(), + error_code: self.error_code.to_be() + }; + + w.write_all(utils::as_bytes(&be_msg)) + } +} + +impl ControlMessageBody for HupMessage { + fn len(&self) -> usize { + mem::size_of::() + } +} + +/// Status flag indicating that there is a network scan currently in progress. +pub const STATUS_FLAG_SCAN: u32 = 0x00000001; + +/// Status message. +#[derive(Debug, Copy, Clone)] +#[repr(packed)] +pub struct StatusMessage { + request_id: u16, + status_flags: u32, + active_sessions: u32, +} + +impl StatusMessage { + pub fn new( + request_id: u16, + status_flags: u32, + active_sessions: u32) -> StatusMessage { + StatusMessage { + request_id: request_id, + status_flags: status_flags, + active_sessions: active_sessions + } + } +} + +impl Serialize for StatusMessage { + fn serialize(&self, w: &mut W) -> io::Result<()> { + let be_msg = StatusMessage { + request_id: self.request_id.to_be(), + status_flags: self.status_flags.to_be(), + active_sessions: self.active_sessions.to_be() + }; + + w.write_all(utils::as_bytes(&be_msg)) + } +} + +impl ControlMessageBody for StatusMessage { + fn len(&self) -> usize { + mem::size_of::() + } +} + +/// Parse a given ACK message body and return the error code. +pub fn parse_ack_message(msg: &[u8]) -> Result { + if msg.len() == mem::size_of::() { + let ptr = msg.as_ptr() as *const u32; + let ack = unsafe { + u32::from_be(*ptr) + }; + + Ok(ack) + } else { + Err(ArrowError::from("incorrect Control Protocol ACK message length")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use utils::Serialize; + use net::utils::WriteBuffer; + use net::arrow::protocol::svc_table::ServiceTable; + + #[test] + fn test_control_msg_serialization() { + let ack_data = [0x56, 0x78, 0x00, 0x00, 0xab, 0xcd, 0xef, 0x00]; + let ping_data = [0x12, 0x34, 0x00, 0x01]; + let ack = create_ack_message(0x5678, 0xabcdef00); + let ping = create_ping_message(0x1234); + + let mut buf = WriteBuffer::new(0); + + ack.serialize(&mut buf).unwrap(); + + assert_eq!(&ack_data, buf.as_bytes()); + + buf.clear(); + + ping.serialize(&mut buf).unwrap(); + + assert_eq!(&ping_data, buf.as_bytes()); + } + + #[test] + fn test_control_msg_deserialization() { + let data = [0x56, 0x78, 0x00, 0x00, 0xab, 0xcd, 0xef, 0x00]; + let mut parser = ControlMessageParser::new(); + + parser.process(&data).unwrap(); + + let header = parser.header(); + + assert_eq!(header.msg_id, 0x5678); + assert_eq!(header.message_type(), ControlMessageType::ACK); + + let body = parser.body(); + + assert_eq!(body, &data[4..]); + } + + #[test] + fn test_register_msg_serialization() { + let data = [ + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 0, 0, + 0, 0, + 0, 0, 0, 0, 0, 0, + 4, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, + 0]; + + let svc_table = ServiceTable::new(); + let register = RegisterMessage::new( + [1u8; 16], + [2u8; 6], + [3u8; 16], + svc_table); + + let mut buf = WriteBuffer::new(0); + + register.serialize(&mut buf).unwrap(); + + let data_bytes: &[u8] = &data; + + assert_eq!(data_bytes, buf.as_bytes()); + } +} diff --git a/src/net/arrow/protocol/mod.rs b/src/net/arrow/protocol/mod.rs new file mode 100644 index 0000000..04acae6 --- /dev/null +++ b/src/net/arrow/protocol/mod.rs @@ -0,0 +1,346 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Common Arrow Protocol definitions. + +pub mod control; +pub mod svc_table; + +pub use self::control::ControlMessage; +pub use self::control::ControlMessageHeader; +pub use self::control::ControlMessageBody; +pub use self::control::ControlMessageParser; +pub use self::control::ControlMessageType; + +pub use self::control::EmptyBody; + +pub use self::control::RegisterMessage; +pub use self::control::RegisterMessageHeader; + +pub use self::control::HupMessage; + +pub use self::control::StatusMessage; + +pub use self::svc_table::Service; +pub use self::svc_table::ServiceTable; + +use std::io; +use std::mem; + +use std::io::Write; + +use utils; + +use utils::Serialize; +use net::arrow::error::{Result, ArrowError}; + +/// Common trait for Arrow Message payload types. +pub trait ArrowMessageBody : Serialize { + /// Get body size in bytes. + fn len(&self) -> usize; +} + +/// Arrow Message header. +#[derive(Debug, Copy, Clone)] +#[repr(packed)] +pub struct ArrowMessageHeader { + /// Arrow Protocol major version. + pub version: u8, + /// Service ID. + pub service: u16, + /// Session ID (note: the upper 8 bits are reserved). + pub session: u32, + /// Payload size. + size: u32, +} + +impl ArrowMessageHeader { + /// Create a new Arrow Message header with a given service ID, session ID + /// and payload size. + fn new(service: u16, session: u32, size: u32) -> ArrowMessageHeader { + ArrowMessageHeader { + version: 0, + service: service, + session: session & ((1 << 24) - 1), + size: size + } + } + + /// Deserialize an Arrow Message header. + fn from_bytes(slice: &[u8]) -> Result { + assert_eq!(slice.len(), mem::size_of::()); + let ptr = slice.as_ptr() as *const ArrowMessageHeader; + let header = unsafe { &*ptr }; + + let res = ArrowMessageHeader { + version: header.version, + service: u16::from_be(header.service), + session: u32::from_be(header.session) & ((1 << 24) - 1), + size: u32::from_be(header.size) + }; + + if res.version == 0 { + Ok(res) + } else { + Err(ArrowError::from("unsupported Arrow Protocol version")) + } + } +} + +impl Serialize for ArrowMessageHeader { + fn serialize(&self, w: &mut W) -> io::Result<()> { + let be_header = ArrowMessageHeader { + version: self.version, + service: self.service.to_be(), + session: self.session.to_be(), + size: self.size.to_be() + }; + + w.write_all(utils::as_bytes(&be_header)) + } +} + +/// Arrow Message envelope. +#[derive(Debug, Clone)] +pub struct ArrowMessage { + /// Message header. + header: ArrowMessageHeader, + /// Payload. + body: B, +} + +impl ArrowMessage { + /// Create a new Arrow Message with a given service ID, session ID and + /// payload. + pub fn new(service: u16, session: u32, body: B) -> ArrowMessage { + ArrowMessage { + header: ArrowMessageHeader::new(service, session, 0), + body: body + } + } + + /// Get message header. + pub fn header(&self) -> &ArrowMessageHeader { + &self.header + } +} + +impl Serialize for ArrowMessage { + fn serialize(&self, w: &mut W) -> io::Result<()> { + let header = ArrowMessageHeader::new( + self.header.service, + self.header.session, + self.body.len() as u32); + + try!(header.serialize(w)); + + self.body.serialize(w) + } +} + +/// Arrow Message parser. +/// +/// This structure allows to read Arrow Messages from continuous streams. +pub struct ArrowMessageParser { + header: Option, + buffer: Vec, + expected: usize, +} + +impl ArrowMessageParser { + /// Create a new Arrow Message parser. + pub fn new() -> ArrowMessageParser { + ArrowMessageParser { + header: None, + buffer: Vec::new(), + expected: 0 + } + } + + /// Check if the last message is complete. + pub fn is_complete(&self) -> bool { + self.header.is_some() && self.expected == 0 + } + + /// Process a new chunk of data and return the number of bytes used. + pub fn add(&mut self, data: &[u8]) -> Result { + let mut consumed = 0; + + if self.header.is_none() { + consumed += try!(self.read_header(data)); + if let Some(header) = self.header { + self.expected = header.size as usize; + } + } + + if self.header.is_some() { + consumed += self.read_body(&data[consumed..]); + } + + Ok(consumed) + } + + /// Clear the last message and prepare the parser for a new one. + pub fn clear(&mut self) { + self.buffer.clear(); + + self.expected = 0; + self.header = None; + } + + /// Get last message header. + pub fn header(&self) -> Option<&ArrowMessageHeader> { + match self.header { + Some(ref header) => Some(header), + None => None + } + } + + /// Get last message body. + pub fn body(&self) -> Option<&[u8]> { + let header_size = mem::size_of::(); + if self.is_complete() { + Some(&self.buffer[header_size..]) + } else { + None + } + } + + /// Read header chunk. + fn read_header(&mut self, data: &[u8]) -> Result { + let size = mem::size_of::(); + let mut consumed = size - self.buffer.len(); + + if consumed > data.len() { + consumed = data.len(); + } + + let data = &data[..consumed]; + + self.buffer.extend(data.iter()); + + if size == self.buffer.len() { + self.header = Some(try!( + ArrowMessageHeader::from_bytes(&self.buffer))); + } + + Ok(consumed) + } + + /// Read body chunk. + fn read_body(&mut self, data: &[u8]) -> usize { + let mut consumed = self.expected; + + if consumed > data.len() { + consumed = data.len(); + } + + let data = &data[..consumed]; + + self.buffer.extend(data.iter()); + self.expected -= consumed; + + consumed + } +} + +impl ArrowMessageBody for Vec { + fn len(&self) -> usize { + Vec::::len(self) + } +} + +impl<'a> ArrowMessageBody for &'a [u8] { + fn len(&self) -> usize { + <[u8]>::len(self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use utils::Serialize; + use net::utils::WriteBuffer; + + #[test] + fn test_message_serialization() { + let msg_data = [0x00, // version + 0x10, 0x22, // svc_id + 0x00, 0x34, 0x56, 0x78, // session_id + 0x00, 0x00, 0x00, 0x02, // body_size + 0xab, 0xcd]; // body + + let message = ArrowMessage::new(0x1022, 0x12345678, vec![0xab, 0xcd]); + + let mut buf = WriteBuffer::new(0); + + message.serialize(&mut buf).unwrap(); + + assert_eq!(&msg_data, buf.as_bytes()); + } + + #[test] + fn test_message_deserialization() { + let mut parser = ArrowMessageParser::new(); + let msg = [0x00, // version + 0x10, 0x22, // svc_id + 0x12, 0x34, 0x56, 0x78, // session_id + 0x00, 0x00, 0x00, 0x02, // body_size + 0xab, 0xcd]; // body + + assert_eq!(parser.is_complete(), false); + assert!(parser.header().is_none()); + assert!(parser.body().is_none()); + + assert_eq!(parser.add(&msg).unwrap(), msg.len()); + + assert_eq!(parser.is_complete(), true); + assert!(parser.header().is_some()); + assert!(parser.body().is_some()); + + { + let header = parser.header().unwrap(); + + assert_eq!(header.version, 0); + assert_eq!(header.service, 0x1022); + assert_eq!(header.session, 0x00345678); + } + + { + let body = parser.body().unwrap(); + + assert_eq!(body, &[0xab, 0xcd]); + } + + assert_eq!(parser.add(&msg).unwrap(), 0); + + parser.clear(); + + assert_eq!(parser.is_complete(), false); + assert!(parser.header().is_none()); + assert!(parser.body().is_none()); + + assert_eq!(parser.add(&msg[..11]).unwrap(), 11); + + assert_eq!(parser.is_complete(), false); + assert!(parser.header().is_some()); + assert!(parser.body().is_none()); + + assert_eq!(parser.add(&msg[11..]).unwrap(), 2); + + assert_eq!(parser.is_complete(), true); + assert!(parser.header().is_some()); + assert!(parser.body().is_some()); + } +} diff --git a/src/net/arrow/protocol/svc_table.rs b/src/net/arrow/protocol/svc_table.rs new file mode 100644 index 0000000..466aa1b --- /dev/null +++ b/src/net/arrow/protocol/svc_table.rs @@ -0,0 +1,500 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Service table definitions. + +use std::io; +use std::mem; + +use std::io::Write; +use std::str::FromStr; +use std::error::Error; +use std::collections::HashSet; +use std::net::{ToSocketAddrs, SocketAddr, SocketAddrV4, Ipv4Addr, Ipv6Addr}; + +use utils; + +use utils::Serialize; +use utils::config::ConfigError; +use net::raw::ether::MacAddr; +use net::arrow::protocol::control::ControlMessageBody; + +use rustc_serialize::{Decodable, Decoder, Encodable, Encoder}; + +const SVC_TYPE_CONTROL_PROTOCOL: u16 = 0x0000; +const SVC_TYPE_RTSP: u16 = 0x0001; +const SVC_TYPE_LOCKED_RTSP: u16 = 0x0002; + +/// Service Table item header. +#[derive(Debug, Copy, Clone)] +#[repr(packed)] +struct ServiceHeader { + svc_id: u16, + svc_type: u16, + mac_addr: [u8; 6], + ip_version: u8, + ip_addr: [u8; 16], + port: u16, +} + +impl ServiceHeader { + /// Create a new item header. + fn new( + svc_id: u16, + svc_type: u16, + haddr: &MacAddr, + saddr: &SocketAddr) -> ServiceHeader { + let ip_version = match saddr { + &SocketAddr::V4(_) => 4, + &SocketAddr::V6(_) => 6 + }; + + let ip_bytes = match saddr { + &SocketAddr::V4(ref addr) => Self::ipv4_bytes(addr.ip()), + &SocketAddr::V6(ref addr) => Self::ipv6_bytes(addr.ip()) + }; + + ServiceHeader { + svc_id: svc_id, + svc_type: svc_type, + mac_addr: haddr.octets(), + ip_version: ip_version, + ip_addr: ip_bytes, + port: saddr.port(), + } + } + + /// Get IPv6 bytes. + fn ipv6_bytes(addr: &Ipv6Addr) -> [u8; 16] { + let segments = addr.segments(); + let mut res = [0u8; 16]; + + for i in 0..segments.len() { + let segment = segments[i]; + let j = i << 1; + res[j] = (segment >> 8) as u8; + res[j + 1] = (segment & 0xff) as u8; + } + + res + } + + /// Get IPv4 bytes left-aligned and padded to 16 bytes. + fn ipv4_bytes(addr: &Ipv4Addr) -> [u8; 16] { + let octets = addr.octets(); + let mut res = [0u8; 16]; + + for i in 0..octets.len() { + res[i] = octets[i]; + } + + res + } +} + +impl Serialize for ServiceHeader { + fn serialize(&self, w: &mut W) -> io::Result<()> { + let be_header = ServiceHeader { + svc_id: self.svc_id.to_be(), + svc_type: self.svc_type.to_be(), + mac_addr: self.mac_addr, + ip_version: self.ip_version, + ip_addr: self.ip_addr, + port: self.port.to_be(), + }; + + w.write_all(utils::as_bytes(&be_header)) + } +} + +/// Service Table item. +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub enum Service { + /// Control Protocol service. + ControlProtocol, + /// Remote RTSP service (mac, addr, path). + RTSP(MacAddr, SocketAddr, String), + /// Remote RTSP service requiring authorization (mac, addr). + LockedRTSP(MacAddr, SocketAddr), +} + +impl Service { + /// Get service address (in case it is a remote service). + pub fn address(&self) -> Option<&SocketAddr> { + match self { + &Service::ControlProtocol => None, + &Service::RTSP(_, ref addr, _) => Some(addr), + &Service::LockedRTSP(_, ref addr) => Some(addr) + } + } + + /// Serialize this Service Table item in-place. + fn serialize(&self, w: &mut W, id: u16) -> io::Result<()> { + match self { + &Service::ControlProtocol => Self::serialize_cp(w, id), + &Service::RTSP(ref mac, ref addr, ref path) => + Self::serialize_svc(w, id, SVC_TYPE_RTSP, mac, addr, path), + &Service::LockedRTSP(ref mac, ref addr) => + Self::serialize_svc(w, id, SVC_TYPE_LOCKED_RTSP, mac, addr, "") + } + } + + /// Serialize a control protocol service item in-place. + fn serialize_cp(w: &mut W, id: u16) -> io::Result<()> { + let saddr = SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(0, 0, 0, 0), 0)); + let haddr = MacAddr::new(0, 0, 0, 0, 0, 0); + let header = ServiceHeader::new( + id, SVC_TYPE_CONTROL_PROTOCOL, &haddr, &saddr); + + try!(header.serialize(w)); + + w.write_all(&[0u8]) + } + + /// Serialize a remote service item in-place. + fn serialize_svc( + w: &mut W, + svc_id: u16, + svc_type: u16, + haddr: &MacAddr, + saddr: &SocketAddr, + path: &str) -> io::Result<()> { + let header = ServiceHeader::new(svc_id, svc_type, haddr, saddr); + + try!(header.serialize(w)); + + try!(w.write_all(path.as_bytes())); + + w.write_all(&[0u8]) + } + + /// Get size of this Service Table item in bytes. + fn len(&self) -> usize { + match self { + &Service::ControlProtocol => Self::cp_len(), + &Service::RTSP(_, _, ref path) => Self::svc_len(path), + &Service::LockedRTSP(_, _) => Self::svc_len("") + } + } + + /// Get size of a control protocol service item in bytes. + fn cp_len() -> usize { + mem::size_of::() + 1 + } + + /// Get size of a remote service item in bytes. + fn svc_len(path: &str) -> usize { + let path_bytes = path.as_bytes(); + mem::size_of::() + path_bytes.len() + 1 + } +} + +/// JSON mapping for a service. +#[derive(Debug, Clone, RustcDecodable, RustcEncodable)] +struct JsonService { + svc_type: u16, + mac: String, + address: String, + path: String, +} + +impl JsonService { + /// Create a new JsonService instance. + fn new( + svc_type: u16, + mac: String, + address: String, + path: String) -> JsonService { + JsonService { + svc_type: svc_type, + mac: mac, + address: address, + path: path + } + } + + /// Transform this service description into a service object. + fn into_service(self) -> Result { + match self.svc_type { + SVC_TYPE_CONTROL_PROTOCOL => Ok(Service::ControlProtocol), + SVC_TYPE_RTSP => Ok(Service::RTSP( + try!(MacAddr::from_str(&self.mac)), + try!(parse_socket_addr(&self.address)), self.path)), + SVC_TYPE_LOCKED_RTSP => Ok(Service::LockedRTSP( + try!(MacAddr::from_str(&self.mac)), + try!(parse_socket_addr(&self.address)))), + _ => Err(ConfigError::from("unknown service type")) + } + } +} + +impl<'a> From<&'a Service> for JsonService { + fn from(svc: &Service) -> JsonService { + match svc { + &Service::ControlProtocol => JsonService::new( + SVC_TYPE_CONTROL_PROTOCOL, + String::new(), String::new(), String::new()), + &Service::RTSP(ref mac, ref addr, ref path) => JsonService::new( + SVC_TYPE_RTSP, + format!("{}", mac), format!("{}", addr), path.clone()), + &Service::LockedRTSP(ref mac, ref addr) => JsonService::new( + SVC_TYPE_LOCKED_RTSP, + format!("{}", mac), format!("{}", addr), String::new()) + } + } +} + +/// Service Table. +#[derive(Debug, Clone)] +pub struct ServiceTable { + services: Vec, + set: HashSet, +} + +impl ServiceTable { + /// Create a new Service Table containing only a single Control Protocol + /// service. + pub fn new() -> ServiceTable { + ServiceTable { + services: Vec::new(), + set: HashSet::new() + } + } + + /// Check if there is a given service in the table. + pub fn contains(&self, svc: &Service) -> bool { + match svc { + &Service::ControlProtocol => true, + svc => self.set.contains(svc) + } + } + + /// Get service according to its ID. + pub fn get(&self, id: u16) -> Option { + if id == 0 { + Some(Service::ControlProtocol) + } else { + match self.services.get((id - 1) as usize) { + Some(svc) => Some(svc.clone()), + None => None + } + } + } + + /// Add a given service into the table in case it is not already there and + /// return the service ID, otherwise return None. + pub fn add(&mut self, svc: Service) -> Option { + if self.contains(&svc) { + None + } else { + self.services.push(svc.clone()); + self.set.insert(svc); + Some(self.services.len() as u16) + } + } + + /// Get vector of remote services in this configuration (i.e. without the + /// implicit Control Protocol service). + /// + /// The result is a vector of pairs. The first element is service ID, + /// the second element is the service itself. + pub fn services(&self) -> Vec<(u16, Service)> { + let mut res = Vec::new(); + for i in 0..self.services.len() { + let svc = &self.services[i]; + let id = i + 1; + res.push((id as u16, svc.clone())); + } + + res + } +} + +impl Serialize for ServiceTable { + fn serialize(&self, w: &mut W) -> io::Result<()> { + for i in 0..self.services.len() { + let svc = &self.services[i]; + let id = i + 1; + try!(svc.serialize(w, id as u16)); + } + + let cp_svc = Service::ControlProtocol; + + cp_svc.serialize(w, 0) + } +} + +impl ControlMessageBody for ServiceTable { + fn len(&self) -> usize { + let cp_svc = Service::ControlProtocol; + cp_svc.len() + self.services.iter() + .fold(0, |sum, svc| sum + svc.len()) + } +} + +impl Decodable for ServiceTable { + fn decode(d: &mut D) -> Result { + let table = try!(JsonServiceTable::decode(d)); + match table.into_service_table() { + Err(err) => Err(d.error(err.description())), + Ok(table) => Ok(table) + } + } +} + +impl Encodable for ServiceTable { + fn encode(&self, s: &mut S) -> Result<(), S::Error> { + let mut table = JsonServiceTable::new(); + for svc in &self.services { + table.add(JsonService::from(svc)); + } + + table.encode(s) + } +} + +/// JSON mapping for the ServiceTable. +#[derive(Debug, Clone, RustcDecodable, RustcEncodable)] +struct JsonServiceTable { + services: Vec, +} + +impl JsonServiceTable { + /// Create a new JsonServiceTable instance. + fn new() -> JsonServiceTable { + JsonServiceTable { + services: Vec::new() + } + } + + /// Add a new configuration entry. + fn add(&mut self, svc: JsonService) -> &mut Self { + self.services.push(svc); + self + } + + /// Transform this service table representation into a real service table. + fn into_service_table(self) -> Result { + let mut res = ServiceTable::new(); + for svc in self.services { + res.add(try!(svc.into_service())); + } + + Ok(res) + } +} + +/// Parse a socket address. +fn parse_socket_addr(addr: &str) -> Result { + try!(addr.to_socket_addrs()) + .next() + .ok_or(ConfigError::from("no socket address given")) +} + +#[cfg(test)] +mod tests { + use super::*; + use utils::Serialize; + use rustc_serialize::json; + use net::utils::WriteBuffer; + use net::raw::ether::MacAddr; + use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; + use net::arrow::protocol::control::ControlMessageBody; + + #[test] + fn test_service_table() { + let mac = MacAddr::new(0, 0, 0, 0, 0, 0); + let addr = SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(1, 2, 3, 4), 5)); + let rtsp = Service::RTSP( + mac.clone(), addr.clone(), "/foo".to_string()); + let lrtsp = Service::LockedRTSP( + mac.clone(), addr.clone()); + let mut table = ServiceTable::new(); + + assert!(table.contains(&Service::ControlProtocol)); + assert!(!table.contains(&rtsp)); + assert!(!table.contains(&lrtsp)); + + assert_eq!(table.add(rtsp.clone()), Some(1)); + assert_eq!(table.add(lrtsp.clone()), Some(2)); + + assert!(table.contains(&rtsp)); + assert!(table.contains(&lrtsp)); + } + + #[test] + fn test_service_table_serialization() { + let data = [ + 0, 1, 0, 1, + 0, 0, 0, 0, 0, 0, + 4, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, + 47, 102, 111, 111, 0, + 0, 2, 0, 2, + 0, 0, 0, 0, 0, 0, + 4, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, + 0, + 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0]; + + let mac = MacAddr::new(0, 0, 0, 0, 0, 0); + let addr = SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(1, 2, 3, 4), 5)); + let rtsp = Service::RTSP( + mac.clone(), addr.clone(), "/foo".to_string()); + let lrtsp = Service::LockedRTSP( + mac.clone(), addr.clone()); + let mut table = ServiceTable::new(); + + table.add(rtsp); + table.add(lrtsp); + + let mut buf = WriteBuffer::new(0); + + table.serialize(&mut buf).unwrap(); + + let data_bytes: &[u8] = &data; + + assert_eq!(data_bytes, buf.as_bytes()); + } + + #[test] + fn test_service_table_json_serialization() { + let mac = MacAddr::new(0, 0, 0, 0, 0, 0); + let addr = SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(1, 2, 3, 4), 5)); + let rtsp = Service::RTSP( + mac.clone(), addr.clone(), "/foo".to_string()); + let lrtsp = Service::LockedRTSP( + mac.clone(), addr.clone()); + let mut table = ServiceTable::new(); + + table.add(rtsp.clone()); + table.add(lrtsp.clone()); + + let json = json::encode(&table).unwrap(); + let table = json::decode::(&json).unwrap(); + + assert!(table.contains(&rtsp)); + assert!(table.contains(&lrtsp)); + assert!(table.contains(&Service::ControlProtocol)); + + let services = table.services(); + + assert_eq!(services.len(), 2); + } +} diff --git a/src/net/discovery.rs b/src/net/discovery.rs new file mode 100644 index 0000000..d74c32e --- /dev/null +++ b/src/net/discovery.rs @@ -0,0 +1,277 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +///! Network scanner for RTSP streams. + +use std::io; +use std::fmt; +use std::thread; +use std::result; + +use std::fs::File; +use std::sync::Arc; +use std::error::Error; +use std::io::{BufReader, BufRead}; +use std::fmt::{Display, Formatter}; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; + +use net::rtsp; +use net::raw::pcap; + +use net::rtsp::Client as RtspClient; +use net::raw::devices::EthernetDevice; +use net::raw::ether::MacAddr; +use net::arrow::protocol::Service; +use net::raw::arp::scanner::Ipv4ArpScanner; +use net::raw::tcp::scanner::{TcpPortScanner, PortCollection}; + +static RTSP_PATH_FILE: &'static str = "/etc/arrow/rtsp-paths"; + +/// Discovery error. +#[derive(Debug, Clone)] +pub struct DiscoveryError { + msg: String, +} + +impl Error for DiscoveryError { + fn description(&self) -> &str { + &self.msg + } +} + +impl Display for DiscoveryError { + fn fmt(&self, f: &mut Formatter) -> result::Result<(), fmt::Error> { + f.write_str(&self.msg) + } +} + +impl From for DiscoveryError { + fn from(msg: String) -> DiscoveryError { + DiscoveryError { msg: msg } + } +} + +impl<'a> From<&'a str> for DiscoveryError { + fn from(msg: &'a str) -> DiscoveryError { + DiscoveryError { msg: msg.to_string() } + } +} + +impl From for DiscoveryError { + fn from(err: rtsp::RtspError) -> DiscoveryError { + DiscoveryError::from(format!("RTSP client error: {}", + err.description())) + } +} + +impl From for DiscoveryError { + fn from(err: pcap::PcapError) -> DiscoveryError { + DiscoveryError::from(format!("pcap error: {}", err.description())) + } +} + +impl From for DiscoveryError { + fn from(err: io::Error) -> DiscoveryError { + DiscoveryError::from(format!("IO error: {}", err.description())) + } +} + +/// Discovery result type alias. +pub type Result = result::Result; + +/// Discovery host type alias. +type Host = (MacAddr, Ipv4Addr); +/// Discovery service type alias. +type Socket = (MacAddr, SocketAddrV4); + +/// Find all RTSP streams in all local networks. +pub fn find_rtsp_streams() -> Result> { + let tc = pcap::new_threading_context(); + let devices = EthernetDevice::list(); + + let port_candidates = PortCollection::new() + .add(1..1024) + .add(8000..9000); + + let mut threads = Vec::new(); + + for dev in devices { + let pc = port_candidates.clone(); + let tc = tc.clone(); + let handle = thread::spawn(move || { + find_services(tc, &dev, &pc) + }); + + threads.push(handle); + } + + let mut services = Vec::new(); + + for handle in threads { + match handle.join() { + Err(_) => return Err(DiscoveryError::from("port scanner thread panicked")), + Ok(res) => services.extend(try!(res)) + } + } + + let rtsp_services = try!(find_rtsp_services(&services)); + + let mut threads = Vec::new(); + let paths = Arc::new(try!(load_rtsp_paths(RTSP_PATH_FILE))); + + for (mac, addr) in rtsp_services { + let addr = SocketAddr::V4(addr); + let paths = paths.clone(); + let handle = thread::spawn(move || { + find_rtsp_paths(mac, addr, &paths) + }); + threads.push(handle); + } + + let mut res = Vec::new(); + + for handle in threads { + match handle.join() { + Err(_) => return Err(DiscoveryError::from( + "path testing thread panicked")), + Ok(svc) => res.extend(try!(svc)) + } + } + + Ok(res) +} + +/// Load all known RTSP path variants from a given file. +fn load_rtsp_paths(file: &str) -> Result> { + let file = try!(File::open(file)); + let breader = BufReader::new(file); + let mut paths = Vec::new(); + + for line in breader.lines() { + let path = try!(line); + if !path.starts_with('#') { + paths.push(path); + } + } + + Ok(paths) +} + +/// Check if a given service is an RTSP service. +fn is_rtsp_service(addr: SocketAddr) -> Result { + let mut client = try!(RtspClient::new(addr)); + client.set_timeout(Some(1000)); + Ok(client.options().is_ok()) +} + +/// Get describe status code for a given RTSP service and path. +fn get_describe_status(addr: SocketAddr, path: &str) -> Result> { + let mut client = try!(RtspClient::new(addr)); + client.set_timeout(Some(1000)); + if let Ok(response) = client.describe(path) { + let header = response.header; + let hipcam = match header.get_str("Server") { + Some("HiIpcam/V100R003 VodServer/1.0.0") => true, + Some("Hipcam RealServer/V1.0") => true, + _ => false + }; + + if hipcam && path != "/11" && path != "/12" { + Ok(None) + } else { + Ok(Some(header.code)) + } + } else { + Ok(None) + } +} + +/// Find open ports on all available hosts within a given network and port +/// range. +fn find_services( + pc: pcap::ThreadingContext, + device: &EthernetDevice, + ports: &PortCollection) -> Result> { + let hosts = try!(Ipv4ArpScanner::scan_device(pc.clone(), device)); + let res = try!(find_open_ports(pc, device, + hosts.into_iter(), ports)); + + Ok(res) +} + +/// Check if any of given TCP ports is open on on any host from a given set. +fn find_open_ports>( + pc: pcap::ThreadingContext, + device: &EthernetDevice, + hosts: HI, + ports: &PortCollection) -> Result> { + let res = try!(TcpPortScanner::scan_ipv4_hosts(pc, device, hosts, ports)) + .into_iter() + .map(|(mac, ip, p)| (mac, SocketAddrV4::new(ip, p))) + .collect::>(); + + Ok(res) +} + +/// Find all RTSP services among a given set of sockets. +fn find_rtsp_services(sockets: &[Socket]) -> Result> { + let mut threads = Vec::new(); + let mut res = Vec::new(); + + for &(mac, addr) in sockets { + let handle = thread::spawn(move || { + (mac, addr, is_rtsp_service(SocketAddr::V4(addr))) + }); + threads.push(handle); + } + + for handle in threads { + match handle.join() { + Err(_) => return Err(DiscoveryError::from("RTSP service testing thread panicked")), + Ok((mac, addr, rtsp)) => { + if try!(rtsp) { + res.push((mac, addr)) + } + } + } + } + + Ok(res) +} + +/// Find all available RTSP paths for a given RTSP service. +fn find_rtsp_paths( + mac: MacAddr, + addr: SocketAddr, + paths: &[String]) -> Result> { + let mut res = Vec::new(); + let mut locked = false; + for path in paths { + match try!(get_describe_status(addr, path)) { + Some(200) => res.push(path.to_string()), + Some(401) => locked = true, + _ => () + } + } + + let res = if locked { + vec![Service::LockedRTSP(mac, addr)] + } else { + res.into_iter() + .map(|path| Service::RTSP(mac, addr, path)) + .collect::>() + }; + + Ok(res) +} diff --git a/src/net/mod.rs b/src/net/mod.rs new file mode 100644 index 0000000..60dc6e5 --- /dev/null +++ b/src/net/mod.rs @@ -0,0 +1,23 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[cfg(feature = "discovery")] +pub mod rtsp; + +#[cfg(feature = "discovery")] +pub mod discovery; + +pub mod raw; +pub mod arrow; +pub mod utils; diff --git a/src/net/raw/arp.rs b/src/net/raw/arp.rs new file mode 100644 index 0000000..0ca94c5 --- /dev/null +++ b/src/net/raw/arp.rs @@ -0,0 +1,330 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! ARP packet definitions. + +use std::io; +use std::mem; + +use utils; + +use std::io::Write; +use std::net::Ipv4Addr; + +use net::raw::ether::MacAddr; +use net::raw::ether::{Result, PacketParseError}; +use net::raw::ether::{EtherPacketHeader, EtherPacketBody, EtherPacketType}; + +/// ARP packet. +#[derive(Debug, Clone)] +pub struct ArpPacket { + pub htype: u16, + pub ptype: u16, + pub hlen: u8, + pub plen: u8, + pub oper: ArpOperation, + pub sha: Vec, + pub spa: Vec, + pub tha: Vec, + pub tpa: Vec, +} + +/// ARP operation. +#[derive(Debug, Copy, Clone)] +pub enum ArpOperation { + REQUEST = 1, + REPLY = 2, +} + +impl From for ArpOperation { + fn from(v: u16) -> ArpOperation { + match v { + 1 => ArpOperation::REQUEST, + 2 => ArpOperation::REPLY, + _ => panic!("illegal value passed as ARP operation") + } + } +} + +const ARP_HTYPE_EHER: u16 = 0x0001; +const ARP_PTYPE_IPV4: u16 = 0x0800; + +impl ArpPacket { + /// Create a new ARP packet for IPv4 over Ethernet. + pub fn ipv4_over_ethernet( + oper: ArpOperation, + sha: &MacAddr, + spa: &Ipv4Addr, + tha: &MacAddr, + tpa: &Ipv4Addr) -> ArpPacket { + ArpPacket { + htype: ARP_HTYPE_EHER, + ptype: ARP_PTYPE_IPV4, + hlen: 6, + plen: 4, + oper: oper, + sha: sha.octets().to_vec(), + spa: spa.octets().to_vec(), + tha: tha.octets().to_vec(), + tpa: tpa.octets().to_vec() + } + } +} + +impl EtherPacketBody for ArpPacket { + fn parse(data: &[u8]) -> Result { + let size = mem::size_of::(); + if data.len() < size { + Err(PacketParseError::from("unable to parse ARP packet, not enough data")) + } else { + let ptr = data.as_ptr(); + let ptr = ptr as *const RawArpPacketHeader; + let rh = unsafe { + &*ptr + }; + + let hlen = rh.hlen as usize; + let plen = rh.plen as usize; + let required = size + + (hlen << 1) + + (plen << 1); + + if data.len() < required { + Err(PacketParseError::from("unable to parse ARP packet, not enough data")) + } else { + let offset_1 = size; + let offset_2 = offset_1 + hlen; + let offset_3 = offset_2 + plen; + let offset_4 = offset_3 + hlen; + + let res = ArpPacket { + htype: u16::from_be(rh.htype), + ptype: u16::from_be(rh.ptype), + hlen: rh.hlen, + plen: rh.plen, + oper: ArpOperation::from(u16::from_be(rh.oper)), + sha: data[offset_1..offset_1+hlen].to_vec(), + spa: data[offset_2..offset_2+plen].to_vec(), + tha: data[offset_3..offset_3+hlen].to_vec(), + tpa: data[offset_4..offset_4+plen].to_vec() + }; + + Ok(res) + } + } + } + + fn serialize( + &self, + _: &EtherPacketHeader, + w: &mut W) -> io::Result<()> { + let rh = RawArpPacketHeader::new(self); + try!(w.write_all(utils::as_bytes(&rh))); + try!(w.write_all(&self.sha)); + try!(w.write_all(&self.spa)); + try!(w.write_all(&self.tha)); + w.write_all(&self.tpa) + } + + fn packet_type(&self) -> EtherPacketType { + EtherPacketType::ARP + } +} + +/// Packed representation of ARP packet header. +#[repr(packed)] +#[derive(Debug, Copy, Clone)] +struct RawArpPacketHeader { + htype: u16, + ptype: u16, + hlen: u8, + plen: u8, + oper: u16, +} + +impl RawArpPacketHeader { + /// Create a new raw ARP packet header. + fn new(arp: &ArpPacket) -> RawArpPacketHeader { + RawArpPacketHeader { + htype: arp.htype.to_be(), + ptype: arp.ptype.to_be(), + hlen: arp.hlen, + plen: arp.plen, + oper: (arp.oper as u16).to_be() + } + } +} + +#[cfg(feature = "discovery")] +pub mod scanner { + use super::*; + + use net::raw; + use net::raw::pcap; + + use std::net::Ipv4Addr; + + use utils::Serialize; + use net::utils::WriteBuffer; + use net::raw::pcap::ThreadingContext; + use net::raw::devices::EthernetDevice; + use net::raw::ether::{MacAddr, EtherPacket}; + use net::raw::pcap::{Scanner, PacketGenerator}; + + /// IPv4 ARP scanner. + pub struct Ipv4ArpScanner { + device: EthernetDevice, + scanner: Scanner, + } + + impl Ipv4ArpScanner { + /// Scan a given device and return list of all active hosts. + pub fn scan_device( + tc: ThreadingContext, + device: &EthernetDevice) -> pcap::Result> { + Ipv4ArpScanner::new(tc, device).scan() + } + + /// Create a new scanner instance. + fn new( + tc: ThreadingContext, + device: &EthernetDevice) -> Ipv4ArpScanner { + Ipv4ArpScanner { + device: device.clone(), + scanner: Scanner::new(tc, &device.name) + } + } + + /// Scan a given device and return list of all active hosts. + fn scan(&mut self) -> pcap::Result> { + let mut gen = Ipv4ArpScannerPacketGenerator::new(&self.device); + let filter = format!("arp and ether dst {}", + self.device.mac_addr); + let packets = try!(self.scanner.sr(&filter, + &mut gen, 1000000000)); + let mut hosts = Vec::new(); + + for p in packets { + if let Ok(ep) = EtherPacket::::parse(&p) { + let sha = MacAddr::from_slice(&ep.body.sha); + let spa = raw::utils::slice_to_ipv4addr(&ep.body.spa); + hosts.push((sha, spa)); + } + } + + Ok(hosts) + } + } + + /// Packet generator for the IPv4 ARP scanner. + struct Ipv4ArpScannerPacketGenerator { + device: EthernetDevice, + hdst: MacAddr, + bcast: MacAddr, + current: u32, + last: u32, + buffer: WriteBuffer, + } + + impl Ipv4ArpScannerPacketGenerator { + /// Create a new packet generator. + fn new(device: &EthernetDevice) -> Ipv4ArpScannerPacketGenerator { + let bcast = MacAddr::new(0xff, 0xff, 0xff, 0xff, 0xff, 0xff); + let hdst = MacAddr::new(0x00, 0x00, 0x00, 0x00, 0x00, 0x00); + let mask: u32 = raw::utils::ipv4addr_to_u32(&device.netmask); + let addr: u32 = raw::utils::ipv4addr_to_u32(&device.ip_addr); + let mut current = addr & mask; + let last = current | !mask; + + current += 1; + + Ipv4ArpScannerPacketGenerator { + device: device.clone(), + hdst: hdst, + bcast: bcast, + current: current, + last: last, + buffer: WriteBuffer::new(0) + } + } + } + + impl PacketGenerator for Ipv4ArpScannerPacketGenerator { + fn next<'a>(&'a mut self) -> Option<&'a [u8]> { + if self.current < self.last { + let pdst = Ipv4Addr::from(self.current); + let arpp = ArpPacket::ipv4_over_ethernet(ArpOperation::REQUEST, + &self.device.mac_addr, &self.device.ip_addr, + &self.hdst, &pdst); + let pkt = EtherPacket::create( + self.device.mac_addr, self.bcast, arpp); + + self.buffer.clear(); + + pkt.serialize(&mut self.buffer) + .unwrap(); + + self.current += 1; + + Some(self.buffer.as_bytes()) + } else { + None + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::net::Ipv4Addr; + + use utils::Serialize; + use net::utils::WriteBuffer; + use net::raw::ether::{MacAddr, EtherPacket}; + + #[test] + fn test_arp_packet() { + let sip = Ipv4Addr::new(192, 168, 3, 7); + let smac = MacAddr::new(1, 2, 3, 4, 5, 6); + let dip = Ipv4Addr::new(192, 168, 8, 1); + let dmac = MacAddr::new(6, 5, 4, 3, 2, 1); + + let arp = ArpPacket::ipv4_over_ethernet(ArpOperation::REQUEST, + &smac, &sip, &dmac, &dip); + let pkt = EtherPacket::create(smac, dmac, arp); + + let mut buf = WriteBuffer::new(0); + + pkt.serialize(&mut buf) + .unwrap(); + + let ep2 = EtherPacket::::parse(buf.as_bytes()) + .unwrap(); + + let arp = &pkt.body; + let arp2 = &ep2.body; + + assert_eq!(arp.htype, arp2.htype); + assert_eq!(arp.ptype, arp2.ptype); + assert_eq!(arp.hlen, arp2.hlen); + assert_eq!(arp.plen, arp2.plen); + assert_eq!(arp.oper as i32, arp2.oper as i32); + assert_eq!(arp.sha, arp2.sha); + assert_eq!(arp.spa, arp2.spa); + assert_eq!(arp.tha, arp2.tha); + assert_eq!(arp.tpa, arp2.tpa); + } +} diff --git a/src/net/raw/devices.c b/src/net/raw/devices.c new file mode 100644 index 0000000..84f7094 --- /dev/null +++ b/src/net/raw/devices.c @@ -0,0 +1,189 @@ +/* + * Copyright 2015 click2stream, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#define MAC_ADDR_SIZE 6 +#define IPV4_ADDR_SIZE 4 + +typedef struct net_device { + char* name; + unsigned char ipv4_address[IPV4_ADDR_SIZE]; + unsigned char ipv4_netmask[IPV4_ADDR_SIZE]; + unsigned char mac_address[MAC_ADDR_SIZE]; + struct net_device* next; +} net_device; + +static char * string_dup(const char* str) { + if (!str) + return NULL; + + size_t len = strlen(str); + char* result = malloc(len + 1); + if (!result) + return NULL; + + memcpy(result, str, len); + result[len] = 0; + + return result; +} + +static int get_mac_address(int fd, const char* dname, unsigned char* buffer) { + struct ifreq dconf; + + memset(&dconf, 0, sizeof(dconf)); + strncpy(dconf.ifr_name, dname, IFNAMSIZ); + + if (ioctl(fd, SIOCGIFHWADDR, &dconf) != 0) + return -1; + if (dconf.ifr_hwaddr.sa_family != ARPHRD_ETHER) + return -2; + + memcpy(buffer, dconf.ifr_hwaddr.sa_data, MAC_ADDR_SIZE); + + return 0; +} + +static int get_ipv4_record(int fd, unsigned long addr_type, + const char* dname, unsigned char* buffer) { + struct sockaddr_in* inet_addr; + struct ifreq dconf; + + memset(&dconf, 0, sizeof(dconf)); + strncpy(dconf.ifr_name, dname, IFNAMSIZ); + + if (ioctl(fd, addr_type, &dconf) != 0) + return -1; + if (dconf.ifr_hwaddr.sa_family != AF_INET) + return -2; + + inet_addr = (struct sockaddr_in*)&dconf.ifr_addr; + memcpy(buffer, &inet_addr->sin_addr, IPV4_ADDR_SIZE); + + return 0; +} + +static int get_ipv4_address(int fd, const char* dname, unsigned char* buffer) { + return get_ipv4_record(fd, SIOCGIFADDR, dname, buffer); +} + +static int get_ipv4_netmask(int fd, const char* dname, unsigned char* buffer) { + return get_ipv4_record(fd, SIOCGIFNETMASK, dname, buffer); +} + +void net_free_device_list(struct net_device* dev) { + struct net_device* tmp; + while (dev) { + tmp = dev; + dev = dev->next; + free(tmp->name); + free(tmp); + } +} + +static struct net_device * get_device_info(int fd, const char* name) { + struct net_device* result; + + result = malloc(sizeof(net_device)); + if (!result) + return NULL; + + memset(result, 0, sizeof(net_device)); + + if (!(result->name = string_dup(name))) { + free(result); + return NULL; + } + + if (get_mac_address(fd, name, result->mac_address) != 0) + goto err; + if (get_ipv4_address(fd, name, result->ipv4_address) != 0) + goto err; + if (get_ipv4_netmask(fd, name, result->ipv4_netmask) != 0) + goto err; + + return result; + +err: + net_free_device_list(result); + + return NULL; +} + +struct net_device * net_find_devices() { + struct net_device* result = NULL; + struct net_device* tmp; + struct ifreq dconf; + int fd, ret, i = 0; + + fd = socket(AF_INET, SOCK_DGRAM, 0); + if (fd < 0) + return NULL; + + memset(&dconf, 0, sizeof(dconf)); + dconf.ifr_ifindex = ++i; + + while ((ret = ioctl(fd, SIOCGIFNAME, &dconf)) == 0) { + tmp = get_device_info(fd, dconf.ifr_name); + if (tmp) { + tmp->next = result; + result = tmp; + } + + memset(&dconf, 0, sizeof(dconf)); + dconf.ifr_ifindex = ++i; + } + + close(fd); + + return result; +} + +const char * net_get_name(const struct net_device* dev) { + return dev->name; +} + +const unsigned char * net_get_ipv4_address(const struct net_device* dev) { + return dev->ipv4_address; +} + +const unsigned char * net_get_ipv4_netmask(const struct net_device* dev) { + return dev->ipv4_netmask; +} + +const unsigned char * net_get_mac_address(const struct net_device* dev) { + return dev->mac_address; +} + +const struct net_device * net_get_next_device(const struct net_device* dev) { + return dev->next; +} + +const size_t net_get_mac_addr_size() { + return MAC_ADDR_SIZE; +} + +const size_t net_get_ipv4_addr_size() { + return IPV4_ADDR_SIZE; +} + diff --git a/src/net/raw/devices.rs b/src/net/raw/devices.rs new file mode 100644 index 0000000..7aaad7b --- /dev/null +++ b/src/net/raw/devices.rs @@ -0,0 +1,108 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Ethernet device definitions. + +use std::slice; + +use std::net::Ipv4Addr; + +use utils; + +use net::raw::ether::MacAddr; + +use libc::{c_char, c_void, size_t}; + +#[allow(non_camel_case_types)] +type net_device = *mut c_void; + +#[link(name = "net_devices")] +extern "C" { + fn net_find_devices() -> net_device; + fn net_free_device_list(dev: net_device) -> c_void; + fn net_get_name(dev: net_device) -> *const c_char; + fn net_get_ipv4_address(dev: net_device) -> *const c_char; + fn net_get_ipv4_netmask(dev: net_device) -> *const c_char; + fn net_get_mac_address(dev: net_device) -> *const c_char; + fn net_get_next_device(dev: net_device) -> net_device; + fn net_get_mac_addr_size() -> size_t; + fn net_get_ipv4_addr_size() -> size_t; +} + +/// Ethernet device. +#[derive(Clone, Debug)] +pub struct EthernetDevice { + pub name: String, + pub mac_addr: MacAddr, + pub ip_addr: Ipv4Addr, + pub netmask: Ipv4Addr, +} + +impl EthernetDevice { + /// List all configured IPv4 network devices. + pub fn list() -> Vec { + let mut result = Vec::new(); + unsafe { + let devices = net_find_devices(); + let mut device = devices.clone(); + while !device.is_null() { + result.push(EthernetDevice::new(device)); + device = net_get_next_device(device); + } + net_free_device_list(devices); + } + + result + } + + /// Create a new ethernet device instance from its raw counterpart. + unsafe fn new(dev: net_device) -> EthernetDevice { + EthernetDevice { + name: get_name(dev), + mac_addr: get_mac_addr(dev), + ip_addr: get_ipv4_addr(dev), + netmask: get_ipv4_mask(dev) + } + } +} + +/// Get device name. +unsafe fn get_name(dev: net_device) -> String { + utils::cstr_to_string(net_get_name(dev) as *const i8) +} + +/// Get device MAC address. +unsafe fn get_mac_addr(dev: net_device) -> MacAddr { + let addr = net_get_mac_address(dev) as *const c_void; + let bytes = ptr_to_bytes(addr, net_get_mac_addr_size() as usize); + MacAddr::new(bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5]) +} + +/// Get device IPv4 address. +unsafe fn get_ipv4_addr(dev: net_device) -> Ipv4Addr { + let addr = net_get_ipv4_address(dev) as *const c_void; + let bytes = ptr_to_bytes(addr, net_get_ipv4_addr_size() as usize); + Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3]) +} + +/// Get device IPv4 mask. +unsafe fn get_ipv4_mask(dev: net_device) -> Ipv4Addr { + let addr = net_get_ipv4_netmask(dev) as *const c_void; + let bytes = ptr_to_bytes(addr, net_get_ipv4_addr_size() as usize); + Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3]) +} + +unsafe fn ptr_to_bytes<'a>(ptr: *const c_void, len: usize) -> &'a [u8] { + slice::from_raw_parts(ptr as *const u8, len) +} diff --git a/src/net/raw/ether.rs b/src/net/raw/ether.rs new file mode 100644 index 0000000..2fc48f4 --- /dev/null +++ b/src/net/raw/ether.rs @@ -0,0 +1,363 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Ethernet packet definitions. + +use std::io; +use std::mem; +use std::fmt; +use std::result; + +use utils; + +use std::io::Write; +use std::str::FromStr; +use std::error::Error; +use std::fmt::{Display, Formatter}; + +use utils::Serialize; +use net::raw::ip::Ipv4PacketBody; + +/// MacAddr parse error. +#[derive(Debug, Clone)] +pub struct AddrParseError { + msg: String, +} + +impl Error for AddrParseError { + fn description(&self) -> &str { + &self.msg + } +} + +impl Display for AddrParseError { + fn fmt(&self, f: &mut Formatter) -> result::Result<(), fmt::Error> { + f.write_str(self.description()) + } +} + +impl<'a> From<&'a str> for AddrParseError { + fn from(msg: &'a str) -> AddrParseError { + AddrParseError { msg: msg.to_string() } + } +} + +/// MAC address type. +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +pub struct MacAddr { + bytes: [u8; 6], +} + +impl MacAddr { + /// Create a new MAC address. + pub fn new(a: u8, b: u8, c: u8, d: u8, e: u8, f: u8) -> MacAddr { + MacAddr { bytes: [a, b, c, d, e, f] } + } + + /// Get address octets. + pub fn octets(&self) -> [u8; 6] { + self.bytes + } + + /// Crete address from slice. + pub fn from_slice(bytes: &[u8]) -> MacAddr { + assert_eq!(bytes.len(), 6); + MacAddr::new(bytes[0], bytes[1], bytes[2], + bytes[3], bytes[4], bytes[5]) + } +} + +impl Display for MacAddr { + fn fmt(&self, f: &mut Formatter) -> result::Result<(), fmt::Error> { + f.write_str(&format!("{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}", + self.bytes[0], self.bytes[1], self.bytes[2], + self.bytes[3], self.bytes[4], self.bytes[5])) + } +} + +impl FromStr for MacAddr { + type Err = AddrParseError; + + fn from_str(s: &str) -> result::Result { + let octets = s.split(':') + .map(|x| u8::from_str_radix(x, 16) + .or(Err(AddrParseError::from("unable to parse a MAC address, invalid octet")))) + .collect::>(); + if octets.len() == 6 { + Ok(MacAddr::new( + try!(octets[0].clone()), + try!(octets[1].clone()), + try!(octets[2].clone()), + try!(octets[3].clone()), + try!(octets[4].clone()), + try!(octets[5].clone()))) + } else { + Err(AddrParseError::from("unable to parse a MAC address, invalid number of octets")) + } + } +} + +/// Packet parser error. +#[derive(Debug, Clone)] +pub struct PacketParseError { + msg: String, +} + +impl Error for PacketParseError { + fn description(&self) -> &str { + &self.msg + } +} + +impl Display for PacketParseError { + fn fmt(&self, f: &mut Formatter) -> result::Result<(), fmt::Error> { + f.write_str(&self.msg) + } +} + +impl<'a> From<&'a str> for PacketParseError { + fn from(msg: &'a str) -> PacketParseError { + PacketParseError { msg: msg.to_string() } + } +} + +/// Type alias for parser results. +pub type Result = result::Result; + +pub const ETYPE_ARP: u16 = 0x0806; +pub const ETYPE_IPV4: u16 = 0x0800; + +/// Ethernet packet header. +#[derive(Debug, Copy, Clone)] +pub struct EtherPacketHeader { + pub src: MacAddr, + pub dst: MacAddr, + pub etype: u16, +} + +impl EtherPacketHeader { + /// Create a new ethernet packet header. + pub fn new(src: MacAddr, dst: MacAddr, etype: u16) -> EtherPacketHeader { + EtherPacketHeader { + src: src, + dst: dst, + etype: etype + } + } + + /// Get packet type. + pub fn packet_type(&self) -> EtherPacketType { + EtherPacketType::from(self.etype) + } + + /// Get raw header. + fn raw_header(&self) -> RawEtherPacketHeader { + RawEtherPacketHeader { + src: self.src.octets(), + dst: self.dst.octets(), + etype: self.etype.to_be() + } + } + + /// Read header from a given raw representation. + fn parse(data: &[u8]) -> EtherPacketHeader { + assert_eq!(data.len(), mem::size_of::()); + let ptr = data.as_ptr(); + let ptr = ptr as *const RawEtherPacketHeader; + let rh = unsafe { + &*ptr + }; + + EtherPacketHeader { + src: MacAddr::from_slice(&rh.src), + dst: MacAddr::from_slice(&rh.dst), + etype: u16::from_be(rh.etype) + } + } +} + +impl Serialize for EtherPacketHeader { + fn serialize(&self, w: &mut W) -> io::Result<()> { + w.write_all(utils::as_bytes(&self.raw_header())) + } +} + +/// Packed representation of the Ethernet packet header. +#[repr(packed)] +#[derive(Debug, Copy, Clone)] +struct RawEtherPacketHeader { + dst: [u8; 6], + src: [u8; 6], + etype: u16, +} + +/// Ethernet packet types. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum EtherPacketType { + ARP, + IPv4, + UNKNOWN +} + +impl EtherPacketType { + /// Get system code of this packet type. + pub fn code(self) -> u16 { + match self { + EtherPacketType::ARP => ETYPE_ARP, + EtherPacketType::IPv4 => ETYPE_IPV4, + _ => panic!("no etype code for unknown packet type") + } + } +} + +impl From for EtherPacketType { + /// Get ethernet packet type from a given code. + fn from(code: u16) -> EtherPacketType { + match code { + ETYPE_ARP => EtherPacketType::ARP, + ETYPE_IPV4 => EtherPacketType::IPv4, + _ => EtherPacketType::UNKNOWN + } + } +} + +/// Common trait for ethernet packet body implementations. +pub trait EtherPacketBody : Sized { + /// Parse body from its raw representation. + fn parse(data: &[u8]) -> Result; + + /// Serialize the packet body in-place using a given writer. + fn serialize( + &self, + eh: &EtherPacketHeader, + w: &mut W) -> io::Result<()>; + + /// Get type of this body. + fn packet_type(&self) -> EtherPacketType; +} + +impl EtherPacketBody for Vec { + fn parse(data: &[u8]) -> Result> { + Ok(data.to_vec()) + } + + fn serialize( + &self, + _: &EtherPacketHeader, + w: &mut W) -> io::Result<()> { + w.write_all(self) + } + + fn packet_type(&self) -> EtherPacketType { + EtherPacketType::UNKNOWN + } +} + +/// Ethernet packet. +#[derive(Debug, Clone)] +pub struct EtherPacket { + pub header: EtherPacketHeader, + pub body: B, +} + +impl EtherPacket { + /// Create a new ethernet packet. + pub fn new(header: EtherPacketHeader, body: B) -> EtherPacket { + EtherPacket { + header: header, + body: body + } + } + + /// Create a new ethernet packet. + pub fn create( + src: MacAddr, + dst: MacAddr, + body: B) -> EtherPacket { + let pt = body.packet_type(); + let header = EtherPacketHeader::new(src, dst, pt.code()); + EtherPacket::new(header, body) + } + + /// Parse a given ethernet packet. + pub fn parse(data: &[u8]) -> Result> { + let hsize = mem::size_of::(); + if data.len() < hsize { + Err(PacketParseError::from("unable to parse ethernet packet, not enough data")) + } else { + let header = EtherPacketHeader::parse(&data[..hsize]); + let body = try!(B::parse(&data[hsize..])); + let btype = body.packet_type(); + if btype == EtherPacketType::UNKNOWN || + btype == EtherPacketType::from(header.etype) { + Ok(EtherPacket::new(header, body)) + } else { + Err(PacketParseError::from("expect and actual ethernet packet types do not match")) + } + } + } +} + +impl Serialize for EtherPacket { + fn serialize(&self, w: &mut W) -> io::Result<()> { + try!(self.header.serialize(w)); + self.body.serialize(&self.header, w) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use net::raw::arp::*; + use utils::Serialize; + use net::utils::WriteBuffer; + + use std::net::Ipv4Addr; + + #[test] + fn test_mac_addr() { + let addr = MacAddr::new(1, 2, 3, 4, 5, 6); + let octets = addr.octets(); + + assert_eq!([1, 2, 3, 4, 5, 6], octets); + + let addr2 = MacAddr::from_slice(&octets); + + assert_eq!(octets, addr2.octets()); + } + + #[test] + fn test_ether_packet() { + let src = MacAddr::new(1, 2, 3, 4, 5, 6); + let dst = MacAddr::new(6, 5, 4, 3, 2, 1); + let sip = Ipv4Addr::new(192, 168, 3, 7); + let dip = Ipv4Addr::new(192, 168, 8, 1); + let arp = ArpPacket::ipv4_over_ethernet(ArpOperation::REQUEST, + &src, &sip, &dst, &dip); + let pkt = EtherPacket::create(src, dst, arp); + + let mut buf = WriteBuffer::new(0); + + pkt.serialize(&mut buf) + .unwrap(); + + let ep2 = EtherPacket::::parse(buf.as_bytes()) + .unwrap(); + + assert_eq!(pkt.header.src.octets(), ep2.header.src.octets()); + assert_eq!(pkt.header.dst.octets(), ep2.header.dst.octets()); + assert_eq!(pkt.header.etype, ep2.header.etype); + } +} diff --git a/src/net/raw/ip.rs b/src/net/raw/ip.rs new file mode 100644 index 0000000..183ed34 --- /dev/null +++ b/src/net/raw/ip.rs @@ -0,0 +1,366 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! IP packet definitions. + +use std::io; +use std::mem; + +use utils; +use net::raw; + +use std::io::Write; +use std::net::Ipv4Addr; + +use utils::Serialize; +use net::raw::ether::{Result, PacketParseError}; +use net::raw::ether::{EtherPacketHeader, EtherPacketBody, EtherPacketType}; + +pub const IP_PROTO_ICMP: u8 = 0x01; +pub const IP_PROTO_TCP: u8 = 0x06; +pub const IP_PROTO_UDP: u8 = 0x11; + +/// IPv4 packet header. +#[derive(Clone, Debug)] +pub struct Ipv4PacketHeader { + pub version: u8, + pub dscp: u8, + pub ecn: u8, + pub ident: u16, + pub flags: u8, + pub foffset: u16, + pub ttl: u8, + pub protocol: u8, + pub src: Ipv4Addr, + pub dst: Ipv4Addr, + pub options: Vec, + length: usize, +} + +impl Ipv4PacketHeader { + /// Create a new IPv4 header. + pub fn new( + src: Ipv4Addr, + dst: Ipv4Addr, + protocol: u8, + ttl: u8) -> Ipv4PacketHeader { + Ipv4PacketHeader { + version: 4, + dscp: 0, + ecn: 0, + ident: 0, + flags: 0, + foffset: 0, + ttl: ttl, + protocol: protocol, + src: src, + dst: dst, + options: Vec::new(), + length: 0 + } + } + + /// Serialize header in-place using a given writer. + fn serialize(&self, dlen: usize, w: &mut W) -> io::Result<()> { + let rh = RawIpv4PacketHeader::new(self, dlen); + try!(w.write_all(utils::as_bytes(&rh))); + w.write_all(utils::slice_as_bytes(&self.options)) + } + + /// Read header from given raw representation. + fn parse(data: &[u8]) -> Result { + let size = mem::size_of::(); + if data.len() < size { + Err(PacketParseError::from("unable to parse IPv4 packet, not enough data")) + } else { + let ptr = data.as_ptr(); + let ptr = ptr as *const RawIpv4PacketHeader; + let rh = unsafe { + &*ptr + }; + + let flags_foffset = u16::from_be(rh.flags_foffset); + let ihl = rh.vihl & 0x0f; + let options_len = ihl as usize - (size >> 2); + let offset_1 = size as isize; + + if data.len() < (size + (options_len << 2)) { + Err(PacketParseError::from("unable to parse IPv4 packet, not enough data")) + } else { + let options = unsafe { + utils::vec_from_raw_parts( + ptr.offset(offset_1) as *const u32, + options_len) + }; + + let res = Ipv4PacketHeader { + version: rh.vihl >> 4, + dscp: rh.dscp_ecn >> 2, + ecn: rh.dscp_ecn & 0x03, + ident: u16::from_be(rh.ident), + flags: (flags_foffset >> 13) as u8, + foffset: flags_foffset & 0x1fff, + ttl: rh.ttl, + protocol: rh.protocol, + src: raw::utils::slice_to_ipv4addr(&rh.src), + dst: raw::utils::slice_to_ipv4addr(&rh.dst), + options: options, + length: u16::from_be(rh.length) as usize + }; + + Ok(res) + } + } + } +} + +/// Packed representation of the IPv4 packet header. +#[repr(packed)] +#[allow(dead_code)] +#[derive(Debug, Copy, Clone)] +struct RawIpv4PacketHeader { + vihl: u8, + dscp_ecn: u8, + length: u16, + ident: u16, + flags_foffset: u16, + ttl: u8, + protocol: u8, + checksum: u16, + src: [u8; 4], + dst: [u8; 4], +} + +impl RawIpv4PacketHeader { + /// Create a new raw IPv4 packet header. + fn new(ip: &Ipv4PacketHeader, dlen: usize) -> RawIpv4PacketHeader { + let size = mem::size_of::(); + let length = size + (ip.options.len() << 2) + dlen; + let ihl = 5 + ip.options.len() as u8; + let flags_foffset = ((ip.flags as u16) << 13) | (ip.foffset & 0x1fff); + let mut rh = RawIpv4PacketHeader { + vihl: (ip.version << 4) | (ihl & 0x0f), + dscp_ecn: (ip.dscp << 2) | (ip.ecn & 0x03), + length: (length as u16).to_be(), + ident: ip.ident.to_be(), + flags_foffset: flags_foffset.to_be(), + ttl: ip.ttl, + protocol: ip.protocol, + checksum: 0, + src: ip.src.octets(), + dst: ip.dst.octets() + }; + + let mut sum = raw::utils::sum_type(&rh); + sum += raw::utils::sum_slice(&ip.options); + + rh.checksum = raw::utils::sum_to_checksum(sum).to_be(); + + rh + } +} + +/// IPv4 packet types. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum Ipv4PacketType { + ICMP, + TCP, + UDP, + UNKNOWN +} + +impl Ipv4PacketType { + /// Get protocol code of this packet type. + pub fn code(self) -> u8 { + match self { + Ipv4PacketType::ICMP => IP_PROTO_ICMP, + Ipv4PacketType::TCP => IP_PROTO_TCP, + Ipv4PacketType::UDP => IP_PROTO_UDP, + _ => panic!("no protocol code for unknown IPv4 packet type") + } + } +} + +impl From for Ipv4PacketType { + /// Get IPv4 packet type from a given code. + fn from(code: u8) -> Ipv4PacketType { + match code { + IP_PROTO_ICMP => Ipv4PacketType::ICMP, + IP_PROTO_TCP => Ipv4PacketType::TCP, + IP_PROTO_UDP => Ipv4PacketType::UDP, + _ => Ipv4PacketType::UNKNOWN + } + } +} + +/// Common trait for IPv4 body implementations. +pub trait Ipv4PacketBody : Sized { + /// Parse body from its raw representation. + fn parse(data: &[u8]) -> Result; + + /// Serialize the packet body in-place using a given writer. + fn serialize( + &self, + iph: &Ipv4PacketHeader, + w: &mut W) -> io::Result<()>; + + /// Get IPv4 packet type of this body. + fn packet_type(&self) -> Ipv4PacketType; + + /// Get body length. + fn len(&self) -> usize; +} + +impl Ipv4PacketBody for Vec { + fn parse(data: &[u8]) -> Result> { + Ok(data.to_vec()) + } + + fn serialize( + &self, + _: &Ipv4PacketHeader, + w: &mut W) -> io::Result<()> { + w.write_all(self) + } + + fn packet_type(&self) -> Ipv4PacketType { + Ipv4PacketType::UNKNOWN + } + + fn len(&self) -> usize { + Vec::::len(self) + } +} + +/// IPv4 packet. +#[derive(Debug, Clone)] +pub struct Ipv4Packet { + pub header: Ipv4PacketHeader, + pub body: B, +} + +impl Ipv4Packet { + /// Create a new IPv4 packet. + pub fn new(header: Ipv4PacketHeader, body: B) -> Ipv4Packet { + Ipv4Packet { + header: header, + body: body + } + } + + /// Create a new IPv4 packet. + pub fn create( + saddr: Ipv4Addr, + daddr: Ipv4Addr, + ttl: u8, + body: B) -> Ipv4Packet { + let pt = body.packet_type(); + let header = Ipv4PacketHeader::new(saddr, daddr, pt.code(), ttl); + Ipv4Packet::new(header, body) + } +} + +impl EtherPacketBody for Ipv4Packet { + fn parse(data: &[u8]) -> Result> { + let hsize = mem::size_of::(); + if data.len() < hsize { + Err(PacketParseError::from("unable to parse IPv4 packet, not enough data")) + } else { + let header = try!(Ipv4PacketHeader::parse(data)); + let offset = hsize + (header.options.len() << 2); + let body = try!(B::parse(&data[offset..])); + let btype = body.packet_type(); + if btype == Ipv4PacketType::UNKNOWN || + btype == Ipv4PacketType::from(header.protocol) { + Ok(Ipv4Packet::new(header, body)) + } else { + Err(PacketParseError::from("expected and actual IPv4 packet types do not match")) + } + } + } + + fn serialize( + &self, + _: &EtherPacketHeader, + w: &mut W) -> io::Result<()> { + let dlen = self.body.len(); + try!(self.header.serialize(dlen, w)); + self.body.serialize(&self.header, w) + } + + fn packet_type(&self) -> EtherPacketType { + EtherPacketType::IPv4 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use net::raw::tcp::*; + use utils::Serialize; + use net::utils::WriteBuffer; + use net::raw::ether::{MacAddr, EtherPacket}; + + use std::net::Ipv4Addr; + + #[test] + fn test_ip_packet() { + let sip = Ipv4Addr::new(192, 168, 3, 7); + let dip = Ipv4Addr::new(192, 168, 8, 1); + let mac = MacAddr::new(0, 0, 0, 0, 0, 0); + + let data = [1, 2, 3]; + + let tcp = TcpPacket::new(10, 20, TCP_FLAG_FIN | TCP_FLAG_SYN, &data); + let ip = Ipv4Packet::create(sip, dip, 64, tcp); + let pkt = EtherPacket::create(mac, mac, ip); + + let mut buf = WriteBuffer::new(0); + + pkt.serialize(&mut buf) + .unwrap(); + + let ep2 = EtherPacket::>::parse(buf.as_bytes()) + .unwrap(); + + let iph = &pkt.body.header; + let iph2 = &ep2.body.header; + + assert_eq!(iph.version, iph2.version); + assert_eq!(iph.dscp, iph2.dscp); + assert_eq!(iph.ecn, iph2.ecn); + assert_eq!(iph.ident, iph2.ident); + assert_eq!(iph.flags, iph2.flags); + assert_eq!(iph.foffset, iph2.foffset); + assert_eq!(iph.ttl, iph2.ttl); + assert_eq!(iph.protocol, iph2.protocol); + assert_eq!(iph.src, iph2.src); + assert_eq!(iph.dst, iph2.dst); + assert_eq!(iph.options, iph2.options); + + let tcp = &pkt.body.body; + let tcp2 = &ep2.body.body; + + assert_eq!(tcp.sport, tcp2.sport); + assert_eq!(tcp.dport, tcp2.dport); + assert_eq!(tcp.seq, tcp2.seq); + assert_eq!(tcp.ack, tcp2.ack); + assert_eq!(tcp.flags, tcp2.flags); + assert_eq!(tcp.wsize, tcp2.wsize); + assert_eq!(tcp.uptr, tcp2.uptr); + assert_eq!(tcp.options, tcp2.options); + assert_eq!(tcp.data, tcp2.data); + } +} diff --git a/src/net/raw/mod.rs b/src/net/raw/mod.rs new file mode 100644 index 0000000..d4a7eed --- /dev/null +++ b/src/net/raw/mod.rs @@ -0,0 +1,23 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[cfg(feature = "discovery")] +pub mod pcap; + +pub mod devices; +pub mod ether; +pub mod ip; +pub mod arp; +pub mod tcp; +pub mod utils; diff --git a/src/net/raw/pcap.rs b/src/net/raw/pcap.rs new file mode 100644 index 0000000..0612a8a --- /dev/null +++ b/src/net/raw/pcap.rs @@ -0,0 +1,407 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! PCAP network scanner definitions. + +use std::ptr; +use std::fmt; +use std::thread; +use std::result; + +use std::error::Error; +use std::thread::JoinHandle; +use std::sync::{Arc, Mutex}; +use std::ffi::CString; +use std::fmt::{Display, Formatter}; + +use utils; + +use time; + +use libc::{c_int, c_uint, c_long, c_char, c_uchar, c_void, size_t}; + +/// PCAP module error. +#[derive(Debug)] +pub struct PcapError { + msg: String, +} + +impl PcapError { + unsafe fn from_cstr(msg: *const c_char) -> PcapError { + PcapError { msg: utils::cstr_to_string(msg as *const _) } + } + + fn from_pcap(p: pcap_t) -> PcapError { + unsafe { + let cstr = pcap_geterr(p); + Self::from_cstr(cstr) + } + } +} + +impl Error for PcapError { + fn description(&self) -> &str { + &self.msg + } +} + +impl Display for PcapError { + fn fmt(&self, f: &mut Formatter) -> result::Result<(), fmt::Error> { + f.write_str(&self.msg) + } +} + +impl<'a> From<&'a str> for PcapError { + fn from(msg: &'a str) -> PcapError { + PcapError { msg: msg.to_string() } + } +} + +pub type Result = result::Result; + +#[allow(non_camel_case_types)] +type pcap_t = *mut c_void; +#[allow(non_camel_case_types)] +type bpf_u_int32 = c_uint; +#[allow(non_camel_case_types)] +type time_t = c_long; +#[allow(non_camel_case_types)] +type suseconds_t = c_long; + +#[repr(C)] +#[allow(non_camel_case_types)] +struct timeval { + tv_sec: time_t, + tv_usec: suseconds_t, +} + +#[repr(C)] +#[allow(non_camel_case_types)] +struct pcap_pkthdr { + ts: timeval, + caplen: bpf_u_int32, + len: bpf_u_int32, +} + +#[repr(C)] +#[allow(non_camel_case_types)] +struct bpf_program { + bf_len: c_uint, + bf_insns: *mut c_void, +} + +impl bpf_program { + fn new() -> bpf_program { + bpf_program { + bf_len: 0, + bf_insns: ptr::null_mut() + } + } +} + +#[link(name = "pcap")] +extern "C" { + fn pcap_create(source: *const c_char, errbuf: *mut c_char) -> pcap_t; + fn pcap_activate(p: pcap_t) -> c_int; + fn pcap_close(p: pcap_t) -> c_void; + fn pcap_geterr(p: pcap_t) -> *const c_char; + fn pcap_set_promisc(p: pcap_t, promisc: c_int) -> c_int; + fn pcap_set_timeout(p: pcap_t, ms: c_int) -> c_int; + fn pcap_next_ex( + p: pcap_t, + header: *mut *mut pcap_pkthdr, + data: *mut *const c_uchar) -> c_int; + fn pcap_compile( + p: pcap_t, + prog: *mut bpf_program, + pstr: *const c_char, + optimize: c_int, + netmask: bpf_u_int32) -> c_int; + fn pcap_freecode(prog: *mut bpf_program) -> c_void; + fn pcap_setfilter(p: pcap_t, prog: *mut bpf_program) -> c_int; + fn pcap_inject(p: pcap_t, buf: *const c_void, size: size_t) -> c_int; +} + +/// PCAP context for synchronizing thread unsafe calls. +pub struct Context; + +/// PCAP context for synchronizing thread unsafe calls. +pub type ThreadingContext = Arc>; + +/// Create a new PCAP context for synchronizing thread unsafe calls. +pub fn new_threading_context() -> ThreadingContext { + Arc::new(Mutex::new(Context)) +} + +/// PCAP Capture builder. +pub struct CaptureBuilder { + capture: Capture, +} + +impl CaptureBuilder { + /// Create a new CaptureBuilder for a given device. + pub fn new(pc: ThreadingContext, dname: &str) -> Result { + let mut result = CaptureBuilder { + capture: Capture { + pc: pc, + errbuf: Box::new([0; 4096]), + h: ptr::null_mut() + } + }; + + let dname_cstr = CString::new(dname) + .unwrap() + .as_ptr() as *const c_char; + let errbuf_ptr = result.capture.errbuf.as_mut_ptr(); + result.capture.h = unsafe { + pcap_create(dname_cstr, errbuf_ptr as *mut c_char) + }; + + if result.capture.h.is_null() { + Err(unsafe { PcapError::from_cstr(errbuf_ptr as *const c_char) }) + } else { + Ok(result) + } + } + + /// Set promiscuous mode. + pub fn promisc(self, p: bool) -> CaptureBuilder { + unsafe { pcap_set_promisc(self.capture.h, p as c_int); } + self + } + + /// Set timeout. + pub fn timeout(self, ms: i32) -> CaptureBuilder { + unsafe { pcap_set_timeout(self.capture.h, ms as c_int); } + self + } + + /// Activate. + pub fn activate(self) -> Result { + let ret = unsafe { pcap_activate(self.capture.h) }; + match ret { + 0 => Ok(self.capture), + _ => Err(PcapError::from_pcap(self.capture.h)) + } + } +} + +/// Capture result. +pub type CaptureResult = Result>>; + +/// PCAP capture. +pub struct Capture { + pc: ThreadingContext, + errbuf: Box<[i8; 4096]>, + h: pcap_t, +} + +impl Capture { + /// Capture next packet. + pub fn next(&mut self) -> CaptureResult { + let mut header: *mut pcap_pkthdr = ptr::null_mut(); + let mut data: *const c_uchar = ptr::null(); + + unsafe { + match pcap_next_ex(self.h, &mut header, &mut data) { + 1 => { + let href = &*header; + let vec = utils::vec_from_raw_parts( + data, href.caplen as usize); + Ok(Some(vec)) + }, + 0 => Ok(None), + _ => Err(PcapError::from_pcap(self.h)) + } + } + } + + /// Set packet filter. + pub fn filter(&mut self, f: &str) -> Result<()> { + unsafe { + let mut prog = try!(self.compile_filter(f)); + let ret = pcap_setfilter(self.h, &mut prog); + + pcap_freecode(&mut prog); + + match ret { + 0 => Ok(()), + _ => Err(PcapError::from_pcap(self.h)) + } + } + } + + /// Inject a given raw packet. + pub fn inject(&mut self, data: &[u8]) -> Result { + let ptr = data.as_ptr() as *const c_void; + let ret = unsafe { + pcap_inject(self.h, ptr, data.len() as size_t) + }; + + if ret < 0 { + Err(PcapError::from_pcap(self.h)) + } else { + Ok(ret as usize) + } + } + + /// Compile a given filter string. + unsafe fn compile_filter(&mut self, f: &str) -> Result { + let _lock = self.pc.lock() + .unwrap(); + + let f_cstr = CString::new(f) + .unwrap() + .as_ptr() as *const c_char; + + let mut prog = bpf_program::new(); + + match pcap_compile(self.h, &mut prog, f_cstr, 0, 0) { + 0 => Ok(prog), + _ => Err(PcapError::from_pcap(self.h)) + } + } +} + +impl Drop for Capture { + fn drop(&mut self) { + unsafe { pcap_close(self.h); } + } +} + +unsafe impl Send for Capture { +} + +/// Common trait for packet generators which may be used in combination with +/// the PCAP packet scanner. +pub trait PacketGenerator { + /// Get next packet. + fn next<'a>(&'a mut self) -> Option<&'a [u8]>; +} + +/// PCAP packet scanner (implementation of a send-receive service). +pub struct Scanner { + pc: ThreadingContext, + device: String, + end_indicator: Arc> +} + +impl Scanner { + /// Create a new PCAP scanner for a given device. + pub fn new(pc: ThreadingContext, device: &str) -> Scanner { + Scanner { + pc: pc, + device: device.to_string(), + end_indicator: Arc::new(Mutex::new(false)) + } + } + + /// Send all packets from a given iterator and receive all packets + /// according to a given filter. + pub fn sr( + &mut self, + filter: &str, + gen: &mut G, + timeout: u64) -> Result>> { + self.set_end_indicator(false); + + let thread = try!(self.start_listener(filter, timeout)); + + try!(self.send_requests(gen)); + + self.set_end_indicator(true); + + match thread.join() { + Err(_) => Err(PcapError::from("listener thread panicked")), + Ok(res) => Ok(res) + } + } + + /// Start packet listener thread. + fn start_listener( + &mut self, + filter: &str, + timeout: u64) -> Result>>> { + let ei = self.end_indicator.clone(); + + let cap = try!(CaptureBuilder::new(self.pc.clone(), &self.device)) + .timeout((timeout / 1000000) as i32) + .promisc(true); + + let mut cap = try!(cap.activate()); + + try!(cap.filter(filter)); + + let handle = thread::spawn(move || { + Self::packet_listener(cap, ei, timeout) + }); + + Ok(handle) + } + + /// Packet listener thread. + fn packet_listener( + mut cap: Capture, + shared_end_indicator: Arc>, + timeout: u64) -> Vec> { + let mut vec = Vec::new(); + let mut t = time::precise_time_ns(); + let mut end = false; + + while !end || (time::precise_time_ns() - t) < timeout { + match cap.next() { + Ok(Some(data)) => vec.push(data), + Err(error) => panic!(error), + _ => (), + } + + if !end && Self::get_end_indicator_value(&shared_end_indicator) { + t = time::precise_time_ns(); + end = true; + } + } + + vec + } + + /// Send all pending packets. + fn send_requests( + &mut self, + gen: &mut G) -> Result<()> { + let cap = try!(CaptureBuilder::new(self.pc.clone(), &self.device)); + let mut cap = try!(cap.activate()); + + while let Some(pkt) = gen.next() { + try!(cap.inject(pkt)); + } + + Ok(()) + } + + /// Set listener end indicator. + fn set_end_indicator(&mut self, val: bool) { + let mut end_indicator = self.end_indicator.lock() + .unwrap(); + + *end_indicator = val; + } + + /// Get end indicator value. + fn get_end_indicator_value(end_indicator: &Arc>) -> bool { + let ei = end_indicator.lock() + .unwrap(); + + *ei + } +} diff --git a/src/net/raw/tcp.rs b/src/net/raw/tcp.rs new file mode 100644 index 0000000..1a3fc63 --- /dev/null +++ b/src/net/raw/tcp.rs @@ -0,0 +1,517 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! TCP packet definitions. + +use std::io; +use std::mem; + +use utils; +use net::raw; + +use std::io::Write; + +use net::raw::ether::{Result, PacketParseError}; +use net::raw::ip::{Ipv4PacketHeader, Ipv4PacketBody, Ipv4PacketType}; + +pub const TCP_FLAG_NS: u16 = 1 << 8; +pub const TCP_FLAG_CWR: u16 = 1 << 7; +pub const TCP_FLAG_ECE: u16 = 1 << 6; +pub const TCP_FLAG_URG: u16 = 1 << 5; +pub const TCP_FLAG_ACK: u16 = 1 << 4; +pub const TCP_FLAG_PSH: u16 = 1 << 3; +pub const TCP_FLAG_RST: u16 = 1 << 2; +pub const TCP_FLAG_SYN: u16 = 1 << 1; +pub const TCP_FLAG_FIN: u16 = 1; + +/// TCP packet. +#[derive(Clone, Debug)] +pub struct TcpPacket { + pub sport: u16, + pub dport: u16, + pub seq: u32, + pub ack: u32, + pub flags: u16, + pub wsize: u16, + pub uptr: u16, + pub options: Vec, + pub data: Vec, +} + +impl TcpPacket { + /// Create a new TCP packet. + pub fn new(sport: u16, dport: u16, flags: u16, data: &[u8]) -> TcpPacket { + TcpPacket { + sport: sport, + dport: dport, + seq: 0, + ack: 0, + flags: flags, + wsize: 8192, + uptr: 0, + options: Vec::new(), + data: data.to_vec() + } + } +} + +impl Ipv4PacketBody for TcpPacket { + fn parse(data: &[u8]) -> Result { + let size = mem::size_of::(); + if data.len() < size { + Err(PacketParseError::from("unable to parse TCP packet, not enough data")) + } else { + let ptr = data.as_ptr(); + let ptr = ptr as *const RawTcpPacketHeader; + let rh = unsafe { + &*ptr + }; + + let doffset_flags = u16::from_be(rh.doffset_flags); + let doffset = doffset_flags >> 12; + let options_len = doffset as usize - (size >> 2); + + let offset_1 = size; + let offset_2 = offset_1 + (options_len << 2); + + if offset_2 > data.len() { + Err(PacketParseError::from("unable to parse TCP packet, not enough data")) + } else { + let options = unsafe { + utils::vec_from_raw_parts( + ptr.offset(offset_1 as isize) as *const u32, + options_len) + }; + + let res = TcpPacket { + sport: u16::from_be(rh.sport), + dport: u16::from_be(rh.dport), + seq: u32::from_be(rh.seq), + ack: u32::from_be(rh.ack), + flags: doffset_flags & 0x01ff, + wsize: u16::from_be(rh.wsize), + uptr: u16::from_be(rh.uptr), + options: options, + data: data[offset_2..].to_vec() + }; + + Ok(res) + } + } + } + + fn serialize( + &self, + iph: &Ipv4PacketHeader, + w: &mut W) -> io::Result<()> { + let rh = RawTcpPacketHeader::new(iph, self); + try!(w.write_all(utils::as_bytes(&rh))); + try!(w.write_all(utils::slice_as_bytes(&self.options))); + w.write_all(&self.data) + } + + fn packet_type(&self) -> Ipv4PacketType { + Ipv4PacketType::TCP + } + + fn len(&self) -> usize { + let header_size = mem::size_of::(); + let option_size = mem::size_of::(); + header_size + + option_size * self.options.len() + + self.data.len() + } +} + +/// Packed representation of the TCP packet header. +#[repr(packed)] +#[derive(Debug, Copy, Clone)] +struct RawTcpPacketHeader { + sport: u16, + dport: u16, + seq: u32, + ack: u32, + doffset_flags: u16, + wsize: u16, + checksum: u16, + uptr: u16, +} + +impl RawTcpPacketHeader { + /// Create a new raw TCP packet header. + fn new(iph: &Ipv4PacketHeader, tcp: &TcpPacket) -> RawTcpPacketHeader { + let mut ph = PseudoIpv4PacketHeader::new(iph); + let doffset = 5 + tcp.options.len() as u16; + let doffset_flags = (doffset << 12) | (tcp.flags & 0x01ff); + let tcp_len = (doffset << 2) + tcp.data.len() as u16; + let mut rh = RawTcpPacketHeader { + sport: tcp.sport.to_be(), + dport: tcp.dport.to_be(), + seq: tcp.seq.to_be(), + ack: tcp.ack.to_be(), + doffset_flags: doffset_flags.to_be(), + wsize: tcp.wsize.to_be(), + checksum: 0, + uptr: 0 + }; + + ph.tcp_len = tcp_len.to_be(); + + let mut sum = raw::utils::sum_type(&ph); + sum += raw::utils::sum_type(&rh); + sum += raw::utils::sum_slice(&tcp.options); + sum += raw::utils::sum_slice(&tcp.data); + + rh.checksum = raw::utils::sum_to_checksum(sum).to_be(); + + rh + } +} + +/// Pseudo IPv4 packet header for TCP checksum computation. +#[repr(packed)] +#[allow(dead_code)] +#[derive(Debug, Copy, Clone)] +struct PseudoIpv4PacketHeader { + src: [u8; 4], + dst: [u8; 4], + res: u8, + protocol: u8, + tcp_len: u16, +} + +impl PseudoIpv4PacketHeader { + /// Create a new pseudo IPv4 packet header. + fn new(iph: &Ipv4PacketHeader) -> PseudoIpv4PacketHeader { + PseudoIpv4PacketHeader { + src: iph.src.octets(), + dst: iph.dst.octets(), + res: 0, + protocol: iph.protocol, + tcp_len: 0 + } + } +} + +#[cfg(feature = "discovery")] +pub mod scanner { + use super::*; + + use std::slice; + + use net::raw::pcap; + + use std::ops::Range; + use std::net::Ipv4Addr; + + use utils::Serialize; + use net::utils::WriteBuffer; + use net::raw::ip::Ipv4Packet; + use net::raw::pcap::ThreadingContext; + use net::raw::devices::EthernetDevice; + use net::raw::ether::{MacAddr, EtherPacket}; + use net::raw::pcap::{Scanner, PacketGenerator}; + + /// TCP port range. + #[derive(Debug, Clone, Eq, PartialEq)] + pub enum PortRange { + Single(u16), + Range(Range), + } + + impl PortRange { + /// Convert TCP port range into a Range instance. + fn to_range(&self) -> Range { + match self { + &PortRange::Range(ref r) => r.clone(), + &PortRange::Single(p) => p..(p + 1), + } + } + } + + impl From for PortRange { + fn from(p: u16) -> PortRange { + PortRange::Single(p) + } + } + + impl From> for PortRange { + fn from(r: Range) -> PortRange { + PortRange::Range(r) + } + } + + /// Collection of ports for PortScanner. (This collection does not handle + /// port overlaps.) + #[derive(Debug, Clone)] + pub struct PortCollection { + ranges: Vec, + } + + impl PortCollection { + /// Create a new empty collection of ports. + pub fn new() -> PortCollection { + PortCollection { + ranges: Vec::new() + } + } + + /// Add a single port or a range. + pub fn add(mut self, v: T) -> Self + where PortRange: From { + self.ranges.push(PortRange::from(v)); + self + } + + /// Get port collection iterator. + pub fn iter<'a>(&'a self) -> PortCollectionIterator<'a> { + PortCollectionIterator::new(self.ranges.iter()) + } + } + + /// Port collection iterator. + #[derive(Clone)] + pub struct PortCollectionIterator<'a> { + iter: slice::Iter<'a, PortRange>, + last: u16, + port: u16, + } + + impl<'a> PortCollectionIterator<'a> { + fn new( + iter: slice::Iter<'a, PortRange>) -> PortCollectionIterator<'a> { + PortCollectionIterator { + iter: iter, + last: 0, + port: 0 + } + } + } + + impl<'a> Iterator for PortCollectionIterator<'a> { + type Item = u16; + + fn next(&mut self) -> Option { + if self.port >= self.last { + if let Some(r) = self.iter.next() { + let r = r.to_range(); + self.port = r.start; + self.last = r.end; + } + } + + if self.port < self.last { + let res = self.port; + self.port += 1; + Some(res) + } else { + None + } + } + } + + type Host = (MacAddr, Ipv4Addr); + type Service = (MacAddr, Ipv4Addr, u16); + + /// TCP port scanner. + pub struct TcpPortScanner { + device: EthernetDevice, + scanner: Scanner, + } + + impl TcpPortScanner { + /// Scan given IPv4 hosts for open ports from a given collection of + /// ports. (It's expected the hosts are accessible through a local + /// Ethernet network, the EthernetDevice and the MAC address must + /// be also specified.) + pub fn scan_ipv4_hosts>( + tc: ThreadingContext, + device: &EthernetDevice, + hosts: HI, + endpoints: &PortCollection) -> pcap::Result> { + TcpPortScanner::new(tc, device) + .scan(hosts, endpoints) + } + + /// Create a new port scanner. + fn new( + tc: ThreadingContext, + device: &EthernetDevice) -> TcpPortScanner { + TcpPortScanner { + device: device.clone(), + scanner: Scanner::new(tc, &device.name) + } + } + + /// Scan a given IPv4 hosts for open ports from a given collection of + /// ports. + fn scan>( + &mut self, + hosts: HI, + endpoints: &PortCollection) -> pcap::Result> { + let sport = 61234; + let mut gen = TcpPortScannerPacketGenerator::new( + &self.device, hosts, sport, endpoints); + let filter = format!("tcp and dst host {} and dst port {} and \ + tcp[tcpflags] & tcp-syn != 0 and \ + tcp[tcpflags] & tcp-ack != 0", + self.device.ip_addr, sport); + let packets = try!(self.scanner.sr(&filter, + &mut gen, 1000000000)); + + let mut services = Vec::new(); + + for p in packets { + if let Ok(ep) = + EtherPacket::>::parse(&p) { + let ipp = &ep.body; + let tcpp = &ipp.body; + let hsrc = ep.header.src; + let psrc = ipp.header.src; + services.push((hsrc, psrc, tcpp.sport)); + } + } + + Ok(services) + } + } + + /// Packet generator for the TCP port scanner. + struct TcpPortScannerPacketGenerator<'a, HI: Iterator> { + device: EthernetDevice, + hosts: HI, + sport: u16, + endpoints: &'a PortCollection, + host: Option, + ports: PortCollectionIterator<'a>, + buffer: WriteBuffer, + } + + impl<'a, HI: Iterator> TcpPortScannerPacketGenerator<'a, HI> { + /// Create a new packet generator. + fn new( + device: &EthernetDevice, + mut hosts: HI, + sport: u16, + endpoints: &'a PortCollection) -> TcpPortScannerPacketGenerator<'a, HI> { + let host = hosts.next(); + let ports = endpoints.iter(); + TcpPortScannerPacketGenerator { + device: device.clone(), + hosts: hosts, + sport: sport, + endpoints: endpoints, + host: host, + ports: ports, + buffer: WriteBuffer::new(0) + } + } + } + + impl<'a, HI> PacketGenerator for TcpPortScannerPacketGenerator<'a, HI> + where HI: Iterator { + fn next<'b>(&'b mut self) -> Option<&'b [u8]> { + if let Some((hdst, pdst)) = self.host { + if let Some(port) = self.ports.next() { + let tcpp = TcpPacket::new( + self.sport, port, TCP_FLAG_SYN, &[]); + let ipp = Ipv4Packet::create( + self.device.ip_addr, pdst, 64, tcpp); + let pkt = EtherPacket::create( + self.device.mac_addr, hdst, ipp); + + self.buffer.clear(); + + pkt.serialize(&mut self.buffer) + .unwrap(); + + Some(self.buffer.as_bytes()) + } else { + self.host = self.hosts.next(); + self.ports = self.endpoints.iter(); + self.next() + } + } else { + None + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "discovery")] + use super::scanner::PortCollection; + + use net::raw::ip::*; + use utils::Serialize; + use net::utils::WriteBuffer; + use net::raw::ether::{MacAddr, EtherPacket}; + + use std::net::Ipv4Addr; + + #[test] + #[cfg(feature = "discovery")] + fn test_port_collection() { + let col = PortCollection::new() + .add(3) + .add(5) + .add(10..15) + .add(100); + + let mut iter = col.iter(); + + let ports = vec![3, 5, 10, 11, 12, 13, 14, 100]; + + for p in ports { + assert_eq!(p, iter.next().unwrap()); + } + } + + #[test] + fn test_tcp_packet() { + let sip = Ipv4Addr::new(192, 168, 3, 7); + let dip = Ipv4Addr::new(192, 168, 8, 1); + let mac = MacAddr::new(0, 0, 0, 0, 0, 0); + + let data = [1, 2, 3]; + + let tcp = TcpPacket::new(10, 20, TCP_FLAG_FIN | TCP_FLAG_SYN, &data); + let ip = Ipv4Packet::create(sip, dip, 64, tcp); + let pkt = EtherPacket::create(mac, mac, ip); + + let mut buf = WriteBuffer::new(0); + + pkt.serialize(&mut buf) + .unwrap(); + + let ep2 = EtherPacket::>::parse(buf.as_bytes()) + .unwrap(); + + let p1 = &pkt.body.body; + let p2 = &ep2.body.body; + + assert_eq!(p1.sport, p2.sport); + assert_eq!(p1.dport, p2.dport); + assert_eq!(p1.seq, p2.seq); + assert_eq!(p1.ack, p2.ack); + assert_eq!(p1.flags, p2.flags); + assert_eq!(p1.wsize, p2.wsize); + assert_eq!(p1.uptr, p2.uptr); + assert_eq!(p1.options, p2.options); + assert_eq!(p1.data, p2.data); + } +} diff --git a/src/net/raw/utils.rs b/src/net/raw/utils.rs new file mode 100644 index 0000000..01cce00 --- /dev/null +++ b/src/net/raw/utils.rs @@ -0,0 +1,139 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Common functions used throughout the `net::raw::*` modules. + +use std::mem; +use std::slice; + +use std::net::Ipv4Addr; + +/// Sum a given Sized type instance as 16-bit unsigned big endian numbers. +pub fn sum_type(data: &T) -> u32 { + let size = mem::size_of::(); + let ptr = data as *const T; + unsafe { + sum_raw_be(ptr as *const u8, size) + } +} + +/// Sum a given slice of Sized type instances as 16-bit unsigned big endian +/// numbers. +pub fn sum_slice(data: &[T]) -> u32 { + let size = mem::size_of::(); + let ptr = data.as_ptr(); + unsafe { + sum_raw_be(ptr as *const u8, size * data.len()) + } +} + +/// Sum given raw data as 16-bit unsigned big endian numbers. +pub unsafe fn sum_raw_be(data: *const u8, size: usize) -> u32 { + let sdata = slice::from_raw_parts(data as *const u16, size >> 1); + let slice = slice::from_raw_parts(data, size); + let mut sum: u32 = 0; + for w in sdata { + sum += u16::from_be(*w) as u32; + } + + if (size & 0x01) != 0 { + sum + ((slice[size - 1] as u32) << 8) + } else { + sum + } +} + +/// Convert given 32-bit unsigned sum into 16-bit unsigned checksum. +pub fn sum_to_checksum(sum: u32) -> u16 { + let mut checksum = sum; + while (checksum & 0xffff0000) != 0 { + let hw = checksum >> 16; + let lw = checksum & 0xffff; + checksum = lw + hw; + } + + !checksum as u16 +} + +/// Convert a given slice of bytes into IPv4 address. +pub fn slice_to_ipv4addr(slice: &[u8]) -> Ipv4Addr { + if slice.len() < 4 { + panic!("slice is too short"); + } else { + let ptr = slice.as_ptr() as *const u32; + let addr = unsafe { u32::from_be(*ptr) }; + Ipv4Addr::from(addr) + } +} + +/// Convert a given IPv4 address into big endian 32-bit unsigned number. +pub fn ipv4addr_to_u32(addr: &Ipv4Addr) -> u32 { + let octets = addr.octets(); + let nr: u32 = unsafe { mem::transmute(octets) }; + nr.to_be() +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::net::Ipv4Addr; + + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + #[repr(packed)] + struct TestType { + b1: u8, + b2: u8, + } + + #[test] + fn test_sum_type() { + let val = TestType { b1: 1, b2: 2 }; + assert_eq!(0x0102, sum_type(&val)); + } + + #[test] + fn test_sum_slice() { + let val = TestType { b1: 1, b2: 2 }; + let vec = vec![val, val]; + assert_eq!(0x0204, sum_slice(&vec)); + } + + #[test] + fn test_sum_to_checksum() { + assert_eq!(!0x00003333, sum_to_checksum(0x11112222)); + } + + #[test] + #[should_panic(expected = "slice is too short")] + fn test_slice_to_ipv4addr_1() { + let buffer = [0u8; 3]; + slice_to_ipv4addr(&buffer); + } + + #[test] + fn test_slice_to_ipv4addr_2() { + let buffer = [192, 168, 2, 3]; + let addr = slice_to_ipv4addr(&buffer); + + assert_eq!(buffer, addr.octets()); + } + + #[test] + fn test_ipv4addr_to_u32() { + let addr = Ipv4Addr::new(192, 168, 2, 5); + assert_eq!(0xc0a80205, ipv4addr_to_u32(&addr)); + } +} + diff --git a/src/net/rtsp.rs b/src/net/rtsp.rs new file mode 100644 index 0000000..115c13f --- /dev/null +++ b/src/net/rtsp.rs @@ -0,0 +1,790 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! RTSP client definitions (only OPTIONS and DESCRIBE methods are currently +//! implemented. + +use std::io; +use std::fmt; +use std::num; +use std::result; +use std::str; + +use std::error::Error; +use std::str::FromStr; +use std::net::SocketAddr; +use std::collections::HashMap; +use std::io::{Read, Write}; +use std::fmt::{Display, Formatter, Debug}; + +use mio; + +use regex::Regex; + +use mio::{EventLoop, Handler, Token, EventSet, PollOpt}; +use mio::tcp::TcpStream; + +/// Error returned by RTSP client. +#[derive(Debug, Clone)] +pub struct RtspError { + msg: String, +} + +impl Error for RtspError { + fn description(&self) -> &str { + &self.msg + } +} + +impl Display for RtspError { + fn fmt(&self, f: &mut Formatter) -> result::Result<(), fmt::Error> { + f.write_str(&self.msg) + } +} + +impl From for RtspError { + fn from(msg: String) -> RtspError { + RtspError { msg: msg } + } +} + +impl<'a> From<&'a str> for RtspError { + fn from(msg: &'a str) -> RtspError { + RtspError { msg: msg.to_string() } + } +} + +impl From for RtspError { + fn from(err: io::Error) -> RtspError { + RtspError::from(format!("IO error: {}", err.description())) + } +} + +impl From for RtspError { + fn from(_: mio::TimerError) -> RtspError { + RtspError::from(format!("timer error")) + } +} + +impl From for RtspError { + fn from(_: num::ParseIntError) -> RtspError { + RtspError::from("integer parsing error") + } +} + +impl From for RtspError { + fn from(_: str::Utf8Error) -> RtspError { + RtspError::from("UTF-8 parsing error") + } +} + +/// RTSP client result type. +pub type Result = result::Result; + +/// Header field type alias. +type Header = (String, String); + +/// RTSP method. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +enum Method { + OPTIONS, + DESCRIBE, +} + +impl Method { + /// Get method name. + fn name(self) -> &'static str { + match self { + Method::OPTIONS => "OPTIONS", + Method::DESCRIBE => "DESCRIBE", + } + } +} + +/// RTSP request. +struct Request { + method: Method, + host: SocketAddr, + path: String, + headers: Vec
, +} + +impl Request { + /// Create a new request. + fn new(method: Method, host: &SocketAddr, path: &str) -> Request { + Request { + method: method, + host: host.clone(), + path: path.to_string(), + headers: Vec::new() + } + } + + /// Add a new header field into the request. + fn add_header(mut self, header: (N, V)) -> Request + where N: ToString, V: ToString { + let (name, value) = header; + self.headers.push((name.to_string(), value.to_string())); + self + } +} + +impl Display for Request { + fn fmt(&self, f: &mut Formatter) -> result::Result<(), fmt::Error> { + try!(f.write_str(&format!("{} rtsp://{}{} RTSP/1.0\r\n", + self.method.name(), &self.host, &self.path))); + for &(ref name, ref val) in &self.headers { + try!(f.write_str(&format!("{}: {}\r\n", name, val))); + } + f.write_str("\r\n") + } +} + +/// RTSP response. +#[derive(Debug, Clone)] +pub struct Response { + pub header: ResponseHeader, + pub body: Vec, +} + +impl Response { + /// Create a new RTSP response. + fn new(header: ResponseHeader, body: Vec) -> Response { + Response { + header: header, + body: body, + } + } +} + +impl Display for Response { + fn fmt(&self, f: &mut Formatter) -> result::Result<(), fmt::Error> { + let body = String::from_utf8_lossy(&self.body); + try!(Display::fmt(&self.header, f)); + f.write_str(&body) + } +} + +/// RTSP client response header. +#[derive(Debug, Clone)] +pub struct ResponseHeader { + pub code: i32, + pub line: String, + headers: Vec
, + header_map: HashMap, +} + +impl ResponseHeader { + /// Create a new RTSP response header. + fn new( + code: i32, + line: String, + headers: Vec
) -> ResponseHeader { + let mut res = ResponseHeader { + code: code, + line: line, + headers: headers, + header_map: HashMap::new() + }; + + for i in 0..res.headers.len() { + let &(ref name, _) = res.headers.get(i).unwrap(); + res.header_map.insert(name.to_lowercase(), i); + } + + res + } + + /// Get response header value. + pub fn get(&self, name: &str) -> Option + where T::Err: Debug { + let key = name.to_lowercase(); + if let Some(i) = self.header_map.get(&key) { + let &(_, ref val) = self.headers.get(*i).unwrap(); + let res = T::from_str(val); + Some(res.unwrap()) + } else { + None + } + } + + /// Get response header value string without copying it. + pub fn get_str<'a>(&'a self, name: &str) -> Option<&'a str> { + let key = name.to_lowercase(); + if let Some(i) = self.header_map.get(&key) { + let &(_, ref val) = self.headers.get(*i).unwrap(); + Some(val) + } else { + None + } + } +} + +impl Display for ResponseHeader { + fn fmt(&self, f: &mut Formatter) -> result::Result<(), fmt::Error> { + try!(f.write_str(&format!("RTSP/1.0 {} {}\r\n", + self.code, &self.line))); + for &(ref name, ref value) in &self.headers { + try!(f.write_str(&format!("{}: {}\r\n", name, value))); + } + + f.write_str("\r\n") + } +} + +/// Header or continuation (convenience enum for header field parsing). +enum HeaderCont { + Header(Header), + Cont(String), + Empty, +} + +/// RTSP response header parser. +struct ResponseHeaderParser { + status_re: Regex, + header_re: Regex, + cont_re: Regex, +} + +impl ResponseHeaderParser { + /// Create a new response header parser. + fn new() -> ResponseHeaderParser { + ResponseHeaderParser { + status_re: Regex::new(r"^RTSP/1.0 (\d+) (.*)$").unwrap(), + header_re: Regex::new(r"^([^ :]+):\s*(.*)$").unwrap(), + cont_re: Regex::new(r"^\s+(.*)$").unwrap(), + } + } + + /// Parse a given response header string. + fn parse(&self, s: &str) -> Result { + let mut lines = s.split("\r\n"); + + let mut headers = Vec::new(); + let status_code; + let status; + + if let Some(line) = lines.next() { + let (sc, s) = try!(self.parse_status_line(line)); + status_code = sc; + status = s; + } else { + return Err(RtspError::from("RTSP status line is missing")); + } + + for line in lines { + match try!(self.parse_header_line(line)) { + HeaderCont::Empty => break, + HeaderCont::Header(h) => headers.push(h), + HeaderCont::Cont(c) => { + if let Some((name, val)) = headers.pop() { + headers.push((name, val + &c)); + } else { + return Err(RtspError::from( + "first RTSP header cannot be continuation")); + } + }, + } + } + + Ok(ResponseHeader::new(status_code, status, headers)) + } + + /// Parse RTSP status line. + fn parse_status_line(&self, line: &str) -> Result<(i32, String)> { + if let Some(caps) = self.status_re.captures(line) { + let status_code = caps.at(1).unwrap(); + let status = caps.at(2).unwrap(); + let sc_int = try!(i32::from_str(status_code)); + Ok((sc_int, status.to_string())) + } else { + Err(RtspError::from("invalid RTSP status line")) + } + } + + /// Parse RTSP header line. + fn parse_header_line(&self, line: &str) -> Result { + if line.is_empty() { + Ok(HeaderCont::Empty) + } else if let Some(caps) = self.header_re.captures(line) { + let name = caps.at(1).unwrap(); + let value = caps.at(2).unwrap(); + Ok(HeaderCont::Header((name.to_string(), value.to_string()))) + } else if let Some(caps) = self.cont_re.captures(line) { + let value = caps.at(1).unwrap(); + Ok(HeaderCont::Cont(value.to_string())) + } else { + Err(RtspError::from("invalid RTSP header line")) + } + } +} + +/// RTSP response parser. +struct ResponseParser { + header_parser: ResponseHeaderParser, + buffer: Vec, + buffer_limit: usize, + last_line: usize, + header: Option, + header_len: usize, + expected: usize, +} + +impl ResponseParser { + /// Create a new response parser. + fn new(buffer_limit: usize) -> ResponseParser { + ResponseParser { + header_parser: ResponseHeaderParser::new(), + buffer: Vec::new(), + buffer_limit: buffer_limit, + last_line: 0, + header: None, + header_len: 0, + expected: 0 + } + } + + /// Check if the last message is complete. + fn is_complete(&self) -> bool { + self.header.is_some() && self.expected == 0 + } + + /// Get last response. + fn response(&self) -> Option { + let header = if let Some(ref header) = self.header { + header.clone() + } else { + return None; + }; + + let body = if self.is_complete() { + self.buffer[self.header_len..].to_vec() + } else { + return None; + }; + + Some(Response::new(header, body)) + } + + /// Clear the current message. + fn clear(&mut self) { + self.buffer.clear(); + self.last_line = 0; + self.header = None; + self.header_len = 0; + self.expected = 0; + } + + /// Process a given chunk of data and return the number of bytes used. + fn add(&mut self, chunk: &[u8]) -> Result { + let mut pos = 0; + + while pos < chunk.len() && + (self.header.is_none() || self.expected > 0) { + if self.header.is_none() { + pos += try!(self.read_header(&chunk[pos..])); + if let Some(ref header) = self.header { + if let Some(len) = header.get::("Content-Length") { + self.expected = len; + } else { + self.expected = 0; + } + } + } else if self.expected > 0 { + let end = if (pos + self.expected) > chunk.len() { + chunk.len() + } else { + pos + self.expected + }; + // TODO: use resize (as soon as it is available) and memcpy + // here as it is more effective + self.buffer.extend(chunk[pos..end].iter()); + self.expected -= end - pos; + pos = end; + } + } + + Ok(pos) + } + + /// Read RTSP header. + fn read_header(&mut self, chunk: &[u8]) -> Result { + let mut pos = 0; + + while self.header.is_none() && pos < chunk.len() { + let (complete, used) = try!(self.read_line(&chunk[pos..])); + + pos += used; + + if complete { + let line_len = self.buffer.len() - self.last_line; + self.last_line = self.buffer.len(); + + if line_len == 2 { + let header_str = try!(str::from_utf8(&self.buffer)); + let header = try!(self.header_parser.parse(header_str)); + self.header = Some(header); + } + } + } + + Ok(pos) + } + + /// Read next line. + fn read_line(&mut self, chunk: &[u8]) -> Result<(bool, usize)> { + let mut complete = false; + let mut pos = 0; + + let mut last = match self.buffer[self.last_line..].last() { + Some(c) => Some(*c), + None => None + }; + + while !complete && pos < chunk.len() { + if self.buffer.len() >= self.buffer_limit { + return Err(RtspError::from( + "unable to parse RTSP response, buffer limit exceeded")); + } + + let c = chunk[pos]; + self.buffer.push(c); + pos += 1; + + if let Some(last) = last { + if last == 0x0d && c == 0x0a { + complete = true; + } + } + + last = Some(c); + } + + Ok((complete, pos)) + } +} + +/// RTSP client. +pub struct Client { + connection: ClientHandler, + event_loop: EventLoop, + endpoint: SocketAddr, +} + +impl Client { + /// Create a new RTSP client for a given remote service. + pub fn new(addr: SocketAddr) -> Result { + let stream = try!(TcpStream::connect(&addr)); + let mut event_loop = try!(EventLoop::new()); + let connection = try!(ClientHandler::new(stream, &mut event_loop)); + let client = Client { + connection: connection, + event_loop: event_loop, + endpoint: addr + }; + + Ok(client) + } + + /// Set timeout for read and write operations. + pub fn set_timeout(&mut self, ms: Option) { + self.connection.set_timeout(ms) + } + + /// Send OPTIONS command. + pub fn options(&mut self) -> Result { + let request = Request::new(Method::OPTIONS, &self.endpoint, "/") + .add_header(("CSeq", 1)); + + try!(self.connection.send(&request, &mut self.event_loop)); + + self.connection.read(&mut self.event_loop) + } + + /// Send DESCRIBE command. + pub fn describe(&mut self, path: &str) -> Result { + let request = Request::new(Method::DESCRIBE, &self.endpoint, path) + .add_header(("CSeq", 1)); + + try!(self.connection.send(&request, &mut self.event_loop)); + + self.connection.read(&mut self.event_loop) + } +} + +/// RTSP client connection handler. +struct ClientHandler { + stream: TcpStream, + timeout: Option, + buffer: Box<[u8]>, + buffered: usize, + read: usize, + parser: ResponseParser, + request: Option>, + sent: usize, + err: Option, +} + +impl ClientHandler { + /// Create a new connection handler. + fn new( + stream: TcpStream, + event_loop: &mut EventLoop) -> Result { + let mut events = EventSet::all(); + events.remove(EventSet::readable()); + events.remove(EventSet::writable()); + try!(event_loop.register(&stream, Token(0), + events, PollOpt::level())); + + let res = ClientHandler { + stream: stream, + timeout: None, + buffer: Box::new([0u8; 4096]), + buffered: 0, + read: 0, + parser: ResponseParser::new(4096), + request: None, + sent: 0, + err: None + }; + + Ok(res) + } + + /// Set send/receive timeout. + fn set_timeout(&mut self, ms: Option) { + self.timeout = ms; + } + + /// Send a given request. + fn send( + &mut self, + request: &Request, + event_loop: &mut EventLoop) -> Result<()> { + self.init(Some(request)); + + let mut events = EventSet::all(); + events.remove(EventSet::readable()); + try!(event_loop.reregister(&self.stream, Token(0), + events, PollOpt::level())); + + let timeout = match self.timeout { + Some(ms) => Some(try!(event_loop.timeout_ms(0, ms))), + None => None + }; + + try!(event_loop.run(self)); + + if let Some(timeout) = timeout { + event_loop.clear_timeout(timeout); + } + + if let Some(ref err) = self.err { + Err(err.clone()) + } else { + Ok(()) + } + } + + /// Read RTSP response. + fn read(&mut self, event_loop: &mut EventLoop) -> Result { + self.init(None); + + let mut events = EventSet::all(); + events.remove(EventSet::writable()); + try!(event_loop.reregister(&self.stream, Token(0), + events, PollOpt::level())); + + let timeout = match self.timeout { + Some(ms) => Some(try!(event_loop.timeout_ms(0, ms))), + None => None + }; + + try!(event_loop.run(self)); + + if let Some(timeout) = timeout { + event_loop.clear_timeout(timeout); + } + + if let Some(ref err) = self.err { + Err(err.clone()) + } else if let Some(response) = self.parser.response() { + Ok(response) + } else { + Err(RtspError::from("unable to get server response")) + } + } + + /// Initialize handler. + fn init(&mut self, request: Option<&Request>) { + self.parser.clear(); + + self.sent = 0; + self.err = None; + + self.request = match request { + None => None, + Some(request) => { + let request_data = format!("{}", request); + Some(request_data.into_bytes()) + } + }; + } + + /// Check socket events. + fn socket_ready(&mut self, event_set: EventSet) -> Result { + let read_res = if self.request.is_none() && event_set.is_readable() { + try!(self.read_ready()) + } else { + false + }; + + let write_res = if self.request.is_some() && event_set.is_writable() { + try!(self.write_ready()) + } else { + false + }; + + if event_set.is_error() { + let socket_err = self.stream.take_socket_error(); + Err(RtspError::from(socket_err.unwrap_err())) + } else if event_set.is_hup() { + Ok(false) + } else { + Ok(read_res || write_res) + } + } + + /// Check read event. + fn read_ready(&mut self) -> Result { + // process any leftovers + let read = try!(self.process_buffer()); + + // check if we still need to read anything + if read { + self.buffered = try!(self.stream.read(&mut *self.buffer)); + self.read = 0; + + Ok(try!(self.process_buffer())) + } else { + Ok(false) + } + } + + /// Process buffered data. + fn process_buffer(&mut self) -> Result { + while self.read < self.buffered && !self.parser.is_complete() { + let buffer = &self.buffer[self.read..self.buffered]; + self.read += try!(self.parser.add(buffer)); + } + + Ok(!self.parser.is_complete()) + } + + /// Check write event. + fn write_ready(&mut self) -> Result { + let mut discard = false; + + if let Some(ref request) = self.request { + self.sent += try!(self.stream.write(&request[self.sent..])); + if self.sent >= request.len() { + discard = true; + } + } + + if discard { + self.request = None; + } + + Ok(self.request.is_some()) + } +} + +impl Handler for ClientHandler { + type Timeout = u32; + type Message = (); + + fn ready( + &mut self, + event_loop: &mut EventLoop, + _: Token, + event_set: EventSet) { + match self.socket_ready(event_set) { + Ok(true) => (), + Ok(false) => event_loop.shutdown(), + Err(err) => { + self.err = Some(err); + event_loop.shutdown(); + } + } + } + + fn timeout(&mut self, event_loop: &mut EventLoop, _: u32) { + self.err = Some(RtspError::from("connection timeout")); + event_loop.shutdown(); + } +} + +#[cfg(test)] +use std::net::ToSocketAddrs; + +#[cfg(test)] +#[test] +fn test_rtsp_request() { + let addr = "127.0.0.1:554".to_socket_addrs() + .unwrap() + .next() + .unwrap(); + + let request = Request::new(Method::DESCRIBE, &addr, "/foo") + .add_header(("CSeq", 1)) + .add_header(("Connection", "close")); + + let expected = "DESCRIBE rtsp://127.0.0.1:554/foo RTSP/1.0\r\n".to_string() + + "CSeq: 1\r\n" + + "Connection: close\r\n" + + "\r\n"; + + let msg = format!("{}", request); + + assert_eq!(expected, msg); +} + +#[cfg(test)] +#[test] +fn test_rtsp_response() { + let mut header_fields = Vec::new(); + header_fields.push(("CSeq".to_string(), "1".to_string())); + + let header = ResponseHeader::new(200, "OK".to_string(), header_fields); + let body = "hello".as_bytes().to_vec(); + + let response = Response::new(header, body); + + let expected = "RTSP/1.0 200 OK\r\n".to_string() + + "CSeq: 1\r\n" + + "\r\n" + + "hello"; + + let msg = format!("{}", response); + + assert_eq!(expected, msg); + + let parser = ResponseHeaderParser::new(); + let response = parser.parse(&expected).unwrap(); + + assert_eq!(response.code, 200); + assert_eq!(response.line, "OK"); + assert_eq!(response.get("cseq"), Some(1)); +} diff --git a/src/net/utils.rs b/src/net/utils.rs new file mode 100644 index 0000000..d382815 --- /dev/null +++ b/src/net/utils.rs @@ -0,0 +1,183 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Common networking utils. + +use std::io; +use std::ptr; + +use std::io::Write; + +use time; + +/// Timeout provider for various network protocols. +#[derive(Debug)] +pub struct Timeout { + timeout: Option, +} + +impl Timeout { + /// Create a new instance of Timeout. The initial state is reset. + pub fn new() -> Timeout { + Timeout { + timeout: None + } + } + + /// Clear the timeout (i.e. the check() method will always return true + /// until the timeout is set). + pub fn clear(&mut self) -> &mut Self { + self.timeout = None; + self + } + + /// Set the timeout. + /// + /// The timeout will expire after a specified delay in miliseconds. + pub fn set(&mut self, delay_ms: u64) -> &mut Self { + self.timeout = Some(time::precise_time_ns() + delay_ms * 1000000); + self + } + + /// Check if the timeout has already expired. + /// + /// The method returns false if the timeout has already expired, otherwise + /// true is returned. + pub fn check(&self) -> bool { + match self.timeout { + Some(t) => time::precise_time_ns() <= t, + None => true + } + } +} + +/// Writer that can be used for buffering data. +pub struct WriteBuffer { + buffer: Vec, + capacity: usize, + offset: usize, + used: usize, +} + +impl WriteBuffer { + /// Create a new buffer with a given capacity. Note that the capacity is + /// only a soft limit. The buffer will always allow you to write more than + /// its capacity. + pub fn new(capacity: usize) -> WriteBuffer { + let mut res = WriteBuffer { + buffer: Vec::with_capacity(capacity), + capacity: capacity, + offset: 0, + used: 0 + }; + + // TODO: replace this with resize (after it's stabilized) + let buf_capacity = res.buffer.capacity(); + unsafe { + res.buffer.set_len(buf_capacity); + } + + res + } + + /// Check if the buffer is full. + pub fn is_full(&self) -> bool { + self.used >= self.capacity + } + + /// Check if the buffer is empty. + pub fn is_empty(&self) -> bool { + self.used == 0 + } + + /// Get number of bytes availalbe until the soft limit is reached. + pub fn available(&self) -> usize { + if self.is_full() { + 0 + } else { + self.capacity - self.used + } + } + + /// Get number of buffered bytes. + pub fn buffered(&self) -> usize { + self.used + } + + /// Get slice of bytes of the currently buffered data. + pub fn as_bytes(&self) -> &[u8] { + let start = self.offset; + let end = start + self.used; + &self.buffer[start..end] + } + + /// Drop a given number of bytes from the buffer. + pub fn drop(&mut self, count: usize) { + if count > self.used { + self.offset += self.used; + self.used = 0; + } else { + self.offset += count; + self.used -= count; + } + } + + /// Drop all buffered data. + pub fn clear(&mut self) { + self.offset += self.used; + self.used = 0; + } +} + +impl Write for WriteBuffer { + /// Write given data into the buffer. + fn write(&mut self, data: &[u8]) -> io::Result { + // expand buffer if needed + let buf_capacity = self.buffer.capacity(); + if (self.used + data.len()) > buf_capacity { + // TODO: replace this with resize (after it's stabilized) + self.buffer.reserve(self.used + data.len() - buf_capacity); + let buf_capacity = self.buffer.capacity(); + unsafe { + self.buffer.set_len(buf_capacity); + } + } + + // shift the buffered data to the left if needed + let buf_capacity = self.buffer.capacity(); + if (self.offset + self.used + data.len()) > buf_capacity { + let dst = self.buffer.as_mut_ptr(); + unsafe { + let src = dst.offset(self.offset as isize); + ptr::copy(src, dst, self.used); + } + self.offset = 0; + } + + // write given data + let offset = self.offset + self.used; + let mut buffer = &mut self.buffer[offset..]; + buffer.write_all(data) + .unwrap(); + + self.used += data.len(); + + Ok(data.len()) + } + + /// Do nothing. + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} diff --git a/src/utils/config.rs b/src/utils/config.rs new file mode 100644 index 0000000..e668ac0 --- /dev/null +++ b/src/utils/config.rs @@ -0,0 +1,283 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Arrow Box config definitions. + +use std::io; +use std::fmt; +use std::result; + +use std::fs::File; +use std::borrow::Cow; +use std::error::Error; +use std::io::{BufReader, BufWriter, Read, Write}; +use std::fmt::{Display, Formatter}; + +use utils; +use net::raw::ether; + +use net::arrow::protocol::{Service, ServiceTable}; + +use uuid; + +use uuid::Uuid; + +use rustc_serialize::json; + +/// Arrow configuration loading/parsing/saving error. +#[derive(Debug, Clone)] +pub struct ConfigError { + msg: String, +} + +impl Error for ConfigError { + fn description(&self) -> &str { + &self.msg + } +} + +impl Display for ConfigError { + fn fmt(&self, f: &mut Formatter) -> result::Result<(), fmt::Error> { + f.write_str(self.description()) + } +} + +impl<'a> From<&'a str> for ConfigError { + fn from(msg: &'a str) -> ConfigError { + ConfigError { msg: msg.to_string() } + } +} + +impl From for ConfigError { + fn from(err: io::Error) -> ConfigError { + ConfigError::from(err.description()) + } +} + +impl From for ConfigError { + fn from(err: json::DecoderError) -> ConfigError { + ConfigError::from(err.description()) + } +} + +impl From for ConfigError { + fn from(err: json::EncoderError) -> ConfigError { + ConfigError::from(err.description()) + } +} + +impl From for ConfigError { + fn from(err: uuid::ParseError) -> ConfigError { + ConfigError { msg: format!("{}", err) } + } +} + +impl From for ConfigError { + fn from(err: ether::AddrParseError) -> ConfigError { + ConfigError::from(err.description()) + } +} + +/// Type alias for Arrow configuration results. +pub type Result = result::Result; + +/// JSON mapping for the Arrow client configuration. +#[derive(Debug, Clone, RustcDecodable, RustcEncodable)] +struct JsonConfig<'a> { + uuid: String, + passwd: String, + version: usize, + svc_table: Cow<'a, ServiceTable>, +} + +impl<'a> JsonConfig<'a> { + /// Create a new JsonConfig instance. + fn new( + uuid: String, + passwd: String, + version: usize, + svc_table: &'a ServiceTable) -> JsonConfig<'a> { + JsonConfig { + uuid: uuid, + passwd: passwd, + version: version, + svc_table: Cow::Borrowed(svc_table) + } + } + + /// Load configuration from a given file. + fn load(file: &str) -> Result> { + let mut content = String::new(); + let file = try!(File::open(file)); + let mut breader = BufReader::new(file); + + try!(breader.read_to_string(&mut content)); + + Ok(try!(json::decode(&content))) + } + + /// Save configuration into a given file. + fn save(&self, file: &str) -> Result<()> { + let content = try!(json::encode(self)); + let file = try!(File::create(file)); + let mut bwriter = BufWriter::new(file); + + try!(bwriter.write(content.as_bytes())); + + Ok(()) + } +} + +impl<'a> Display for JsonConfig<'a> { + fn fmt(&self, f: &mut Formatter) -> result::Result<(), fmt::Error> { + let content = try!(json::encode(self) + .or(Err(fmt::Error))); + f.write_str(&content) + } +} + +/// Arrow configuration. +#[derive(Debug, Clone)] +pub struct ArrowConfig { + uuid: Uuid, + passwd: Uuid, + version: usize, + svc_table: ServiceTable, +} + +impl ArrowConfig { + /// Create a new empty Arrow configuration. + pub fn new() -> ArrowConfig { + ArrowConfig { + uuid: Uuid::new_v4(), + passwd: Uuid::new_v4(), + version: 0, + svc_table: ServiceTable::new() + } + } + + /// Get Arrow Client UUID. + pub fn uuid(&self) -> [u8; 16] { + uuid_to_bytes(&self.uuid) + } + + /// Get formatted Arrow Client UUID. + pub fn uuid_string(&self) -> String { + self.uuid.to_hyphenated_string() + } + + /// Get Arrow Client password. + pub fn password(&self) -> [u8; 16] { + uuid_to_bytes(&self.passwd) + } + + /// Get current configuration version. + pub fn version(&self) -> usize { + self.version + } + + /// Get service according to its ID from the underlaying service table. + pub fn get(&self, id: u16) -> Option { + self.svc_table.get(id) + } + + /// Add a new service into the underlaying service table. + pub fn add(&mut self, svc: Service) -> Option { + self.svc_table.add(svc) + } + + /// Increment version of this config. + pub fn bump_version(&mut self) { + self.version += 1; + } + + /// Get a copy of the underlaying service table. + pub fn service_table(&self) -> ServiceTable { + self.svc_table.clone() + } + + /// Set contents of the service table to a given value. + pub fn reinit(&mut self, svc_table: ServiceTable) { + self.svc_table = svc_table + } + + /// Load configuration from a given file. + pub fn load(file: &str) -> Result { + let json = try!(JsonConfig::load(file)); + let uuid = try!(Uuid::parse_str(&json.uuid)); + let passwd = try!(Uuid::parse_str(&json.passwd)); + let svc_table = json.svc_table.into_owned(); + + let res = ArrowConfig { + uuid: uuid, + passwd: passwd, + version: json.version, + svc_table: svc_table + }; + + Ok(res) + } + + /// Save configuration into a given file. + pub fn save(&self, file: &str) -> Result<()> { + let json = JsonConfig::new( + self.uuid.to_hyphenated_string(), + self.passwd.to_hyphenated_string(), + self.version, + &self.svc_table); + + json.save(file) + } +} + +impl Display for ArrowConfig { + fn fmt(&self, f: &mut Formatter) -> result::Result<(), fmt::Error> { + let json = JsonConfig::new( + self.uuid.to_hyphenated_string(), + self.passwd.to_hyphenated_string(), + self.version, + &self.svc_table); + + json.fmt(f) + } +} + +/// Application context. +#[derive(Debug, Clone)] +pub struct AppContext { + pub config: ArrowConfig, + pub scanning: bool, +} + +impl AppContext { + /// Create a new application context. + pub fn new(config: ArrowConfig) -> AppContext { + AppContext { + config: config, + scanning: false + } + } +} + +/// Transform a given UUID into an array of 16 bytes. +fn uuid_to_bytes(uuid: &Uuid) -> [u8; 16] { + let bytes = uuid.as_bytes(); + let mut res = [0u8; 16]; + + assert_eq!(bytes.len(), res.len()); + + utils::memcpy(&mut res, bytes); + + res +} diff --git a/src/utils/logger/mod.rs b/src/utils/logger/mod.rs new file mode 100644 index 0000000..e44fe67 --- /dev/null +++ b/src/utils/logger/mod.rs @@ -0,0 +1,127 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Logger definitions. + +macro_rules! log { + ($logger:expr, $severity:expr, $msg:expr) => { + $logger.log(file!(), line!(), $severity, $msg) + }; +} + +macro_rules! log_debug { + ($logger:expr, $msg:expr) => { + $logger.debug(file!(), line!(), $msg) + }; +} + +macro_rules! log_info { + ($logger:expr, $msg:expr) => { + $logger.info(file!(), line!(), $msg) + }; +} + +macro_rules! log_warn { + ($logger:expr, $msg:expr) => { + $logger.warn(file!(), line!(), $msg) + }; +} + +macro_rules! log_error { + ($logger:expr, $msg:expr) => { + $logger.error(file!(), line!(), $msg) + }; +} + +pub mod syslog; + +/// Log message severity. +#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] +pub enum Severity { + DEBUG = 0, + INFO = 1, + WARN = 2, + ERROR = 3 +} + +const DEBUG: Severity = Severity::DEBUG; +const INFO: Severity = Severity::INFO; +const WARN: Severity = Severity::WARN; +const ERROR: Severity = Severity::ERROR; + +/// Common trait for application loggers. +pub trait Logger { + /// Log a given message with a given severity. + fn log(&mut self, file: &str, line: u32, s: Severity, msg: &str); + + /// Set minimum log level. + /// + /// Messages with lover level will be discarded. + fn set_level(&mut self, s: Severity) -> &mut Self; + + /// Get minimum log level. + fn get_level(&self) -> Severity; + + /// Log a given debug message. + fn debug(&mut self, file: &str, line: u32, msg: &str) { + self.log(file, line, DEBUG, msg) + } + + /// Log a given info message. + fn info(&mut self, file: &str, line: u32, msg: &str) { + self.log(file, line, INFO, msg) + } + + /// Log a given warning message. + fn warn(&mut self, file: &str, line: u32, msg: &str) { + self.log(file, line, WARN, msg) + } + + /// Log a given error message. + fn error(&mut self, file: &str, line: u32, msg: &str) { + self.log(file, line, ERROR, msg) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct TestLogger { + last_severity: Severity, + } + + impl Logger for TestLogger { + fn log(&mut self, _: &str, _: u32, s: Severity, _: &str) { + self.last_severity = s; + } + + fn set_level(&mut self, _: Severity) -> &mut Self { self } + fn get_level(&self) -> Severity { Severity::DEBUG } + } + + #[test] + fn test_logger() { + let mut logger = TestLogger { last_severity: Severity::DEBUG }; + + log_error!(logger, "msg"); + assert_eq!(Severity::ERROR, logger.last_severity); + log_warn!(logger, "msg"); + assert_eq!(Severity::WARN, logger.last_severity); + log_info!(logger, "msg"); + assert_eq!(Severity::INFO, logger.last_severity); + log_debug!(logger, "msg"); + assert_eq!(Severity::DEBUG, logger.last_severity); + } +} diff --git a/src/utils/logger/syslog.rs b/src/utils/logger/syslog.rs new file mode 100644 index 0000000..371cc13 --- /dev/null +++ b/src/utils/logger/syslog.rs @@ -0,0 +1,91 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Syslog definitions. + +use std::ptr; + +use std::ffi::CString; +use std::sync::{Once, ONCE_INIT}; + +use utils::logger::{Logger, Severity}; + +use libc::{c_char, c_int, c_void}; + +const LOG_PID: c_int = 0x01; +const LOG_CONS: c_int = 0x02; + +const LOG_USER: c_int = 0x08; + +const LOG_ERR: c_int = 3; +const LOG_WARNING: c_int = 4; +const LOG_INFO: c_int = 6; +const LOG_DEBUG: c_int = 7; + +static SYSLOG_INIT: Once = ONCE_INIT; + +#[link(name = "c")] +extern "C" { + fn openlog(ident: *const c_char, option: c_int, facility: c_int) -> c_void; + fn syslog(priority: c_int, format: *const c_char, ...) -> c_void; +} + +/// Syslog logger structure. +#[derive(Debug, Clone)] +pub struct Syslog { + level: Severity, +} + +/// Create a new syslog logger with log level set to INFO. +pub fn new() -> Syslog { + SYSLOG_INIT.call_once(|| unsafe { + openlog(ptr::null(), LOG_CONS | LOG_PID, LOG_USER); + }); + + Syslog { + level: Severity::INFO + } +} + +impl Logger for Syslog { + fn log(&mut self, file: &str, line: u32, s: Severity, msg: &str) { + let msg = format!("[{}:{}] {}", file, line, msg); + let cstr_fmt = CString::new("%s").unwrap(); + let cstr_msg = CString::new(msg).unwrap(); + let fmt_ptr = cstr_fmt.as_ptr() as *const c_char; + let msg_ptr = cstr_msg.as_ptr() as *const c_char; + + if s >= self.level { + unsafe { + match s { + Severity::DEBUG => syslog(LOG_DEBUG, fmt_ptr, msg_ptr), + Severity::INFO => syslog(LOG_INFO, fmt_ptr, msg_ptr), + Severity::WARN => syslog(LOG_WARNING, fmt_ptr, msg_ptr), + Severity::ERROR => syslog(LOG_ERR, fmt_ptr, msg_ptr) + } + }; + } + } + + fn set_level(&mut self, s: Severity) -> &mut Self { + self.level = s; + self + } + + fn get_level(&self) -> Severity { + self.level + } +} + +unsafe impl Send for Syslog { } diff --git a/src/utils/mod.rs b/src/utils/mod.rs new file mode 100644 index 0000000..bcf9c71 --- /dev/null +++ b/src/utils/mod.rs @@ -0,0 +1,300 @@ +// Copyright 2015 click2stream, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Common util functions. + +#[macro_use] +pub mod logger; + +pub mod config; + +use std::io; +use std::ptr; +use std::mem; +use std::fmt; +use std::slice; +use std::process; + +use std::ffi::CStr; +use std::error::Error; +use std::ops::Deref; +use std::io::Write; +use std::sync::{Arc, Mutex}; +use std::fmt::{Debug, Display, Formatter}; + +use utils::logger::{Logger, Severity}; + +/// General purpose runtime error. +#[derive(Debug, Clone)] +pub struct RuntimeError { + msg: String, +} + +impl Error for RuntimeError { + fn description(&self) -> &str { + &self.msg + } +} + +impl Display for RuntimeError { + fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> { + f.write_str(&self.msg) + } +} + +impl<'a> From<&'a str> for RuntimeError { + fn from(msg: &'a str) -> RuntimeError { + RuntimeError { msg: msg.to_string() } + } +} + +impl From for RuntimeError { + fn from(msg: String) -> RuntimeError { + RuntimeError { msg: msg } + } +} + +/// Arc> shorthand. +#[derive(Clone)] +pub struct Shared { + object: Arc>, +} + +impl Shared { + /// Create a new shared object. + pub fn new(obj: T) -> Shared { + Shared { + object: Arc::new(Mutex::new(obj)) + } + } +} + +impl Deref for Shared { + type Target = Mutex; + + fn deref(&self) -> &Mutex { + self.object.deref() + } +} + +unsafe impl Send for Shared { } + +/// Common trait for serializable objects. +pub trait Serialize { + /// Serialize this object using a given writer. + fn serialize(&self, w: &mut W) -> io::Result<()>; +} + +impl Serialize for u8 { + fn serialize(&self, w: &mut W) -> io::Result<()> { + w.write_all(as_bytes(&self.to_be())) + } +} + +impl Serialize for i8 { + fn serialize(&self, w: &mut W) -> io::Result<()> { + w.write_all(as_bytes(&self.to_be())) + } +} + +impl Serialize for u16 { + fn serialize(&self, w: &mut W) -> io::Result<()> { + w.write_all(as_bytes(&self.to_be())) + } +} + +impl Serialize for i16 { + fn serialize(&self, w: &mut W) -> io::Result<()> { + w.write_all(as_bytes(&self.to_be())) + } +} + +impl Serialize for u32 { + fn serialize(&self, w: &mut W) -> io::Result<()> { + w.write_all(as_bytes(&self.to_be())) + } +} + +impl Serialize for i32 { + fn serialize(&self, w: &mut W) -> io::Result<()> { + w.write_all(as_bytes(&self.to_be())) + } +} + +impl Serialize for u64 { + fn serialize(&self, w: &mut W) -> io::Result<()> { + w.write_all(as_bytes(&self.to_be())) + } +} + +impl Serialize for i64 { + fn serialize(&self, w: &mut W) -> io::Result<()> { + w.write_all(as_bytes(&self.to_be())) + } +} + +impl Serialize for usize { + fn serialize(&self, w: &mut W) -> io::Result<()> { + w.write_all(as_bytes(&self.to_be())) + } +} + +impl Serialize for isize { + fn serialize(&self, w: &mut W) -> io::Result<()> { + w.write_all(as_bytes(&self.to_be())) + } +} + +impl Serialize for Vec { + fn serialize(&self, w: &mut W) -> io::Result<()> { + w.write_all(self) + } +} + +impl<'a> Serialize for &'a [u8] { + fn serialize(&self, w: &mut W) -> io::Result<()> { + w.write_all(self) + } +} + +/// Efficient function for copying data from one slice to another. +/// +/// It copies all data from the src slice into the dst slice. +/// +/// # Panics +/// The function panics when src.len() > dst.len() +pub fn memcpy(dst: &mut [T], src: &[T]) { + assert!(src.len() <= dst.len()); + unsafe { + ptr::copy(src.as_ptr(), dst.as_mut_ptr(), src.len()); + } +} + +/// Get slice of bytes representing a given object. +pub fn as_bytes<'a, T: Sized>(val: &'a T) -> &'a [u8] { + let ptr = val as *const T; + let size = mem::size_of::(); + unsafe { + slice::from_raw_parts(ptr as *const u8, size) + } +} + +/// Convert a given slice of Sized type instances to a slice of bytes. +pub fn slice_as_bytes<'a, T: Sized>(data: &'a [T]) -> &'a [u8] { + let ptr = data.as_ptr(); + let size = mem::size_of::(); + unsafe { + slice::from_raw_parts(ptr as *const u8, size * data.len()) + } +} + +/// Convert a given typed pointer into a new vector (copying the dats). +pub unsafe fn vec_from_raw_parts( + ptr: *const T, + len: usize) -> Vec { + slice::from_raw_parts(ptr, len) + .to_vec() +} + +/// Convert a given C-string pointer to a new instance of String. +pub unsafe fn cstr_to_string(ptr: *const i8) -> String { + let cstr = CStr::from_ptr(ptr); + let slice = String::from_utf8_lossy(cstr.to_bytes()); + slice.to_string() +} + +/// Exit application printing a given error. +pub fn error(err: T, exit_code: i32) -> ! { + println!("ERROR: {}", err.description()); + process::exit(exit_code); +} + +/// Unwrap a given result or exit the process printing the error. +pub fn result_or_error(res: Result, exit_code: i32) -> T + where E: Error + Debug { + match res { + Ok(res) => res, + Err(err) => error(err, exit_code) + } +} + +/// Unwrap a given result or log an error with a given severity and return None. +pub fn result_or_log( + logger: &mut L, + severity: Severity, + res: Result) -> Option + where E: Error + Debug, + L: Logger { + match res { + Err(err) => { + log!(logger, severity, err.description()); + None + }, + Ok(res) => Some(res) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::ffi::CString; + use utils::logger::*; + + struct DummyLogger; + + impl Logger for DummyLogger { + fn log(&mut self, _: &str, _: u32, _: Severity, _: &str) { } + fn set_level(&mut self, _: Severity) -> &mut Self { self } + fn get_level(&self) -> Severity { Severity::DEBUG } + } + + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + #[repr(packed)] + struct TestType { + b1: u8, + b2: u8, + } + + #[test] + fn test_vec_from_raw_parts() { + let val = TestType { b1: 1, b2: 2 }; + let vec = vec![val, val]; + let ptr = vec.as_ptr(); + let vec2 = unsafe { vec_from_raw_parts(ptr, vec.len()) }; + + assert_eq!(vec, vec2); + } + + #[test] + fn test_cstr_to_string() { + let cstr = CString::new("hello").unwrap(); + unsafe { + assert!("hello" == &cstr_to_string(cstr.as_ptr() as *const i8)); + assert!("world" != &cstr_to_string(cstr.as_ptr() as *const i8)); + } + } + + #[test] + fn test_result_or_error() { + assert_eq!(1, result_or_error::(Ok(1), 0)); + } + + #[test] + fn test_result_or_log() { + assert_eq!(Some(1), result_or_log::( + &mut DummyLogger, Severity::WARN, Ok(1))); + assert_eq!(None, result_or_log::( + &mut DummyLogger, Severity::WARN, Err(RuntimeError::from("foo")))); + } +}