-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.rs
260 lines (217 loc) · 7.19 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
extern crate ansi_term;
extern crate clap;
extern crate fern;
#[macro_use] extern crate log;
#[macro_use] extern crate mioco;
extern crate time;
use std::collections::HashMap;
use std::io::{self, Read, Write};
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;
use clap::{App, Arg};
use mioco::tcp::{TcpListener, TcpStream};
mod detect;
mod logger;
/// Listen address used when `--listen` is not given on the command line.
const DEFAULT_LISTEN_ADDR: &str = "127.0.0.1:5555";
/// Runtime configuration for the demultiplexer, built once from the command
/// line and shared read-only between connection handlers via `Arc`.
#[derive(Debug)]
struct Config {
    /// Address the TCP listener binds to.
    listen_addr: SocketAddr,
    /// Upstream address for each protocol name that has one configured.
    upstreams: HashMap<&'static str, SocketAddr>,
    /// Protocol-detection timeout, in milliseconds.
    timeout: i64,
}

impl Config {
    /// Returns the upstream address configured for protocol `proto`, or
    /// `None` if no upstream was given for it.
    fn upstream_for(&self, proto: &str) -> Option<SocketAddr> {
        // `SocketAddr` is `Copy`; `cloned()` turns `Option<&SocketAddr>`
        // into `Option<SocketAddr>` without the `and_then(Some(..))` detour.
        self.upstreams.get(proto).cloned()
    }
}
fn handle_proxy(
mut client_conn: TcpStream,
proto: &'static str,
initial_buf: &[u8],
config: Arc<Config>
) -> io::Result<()> {
// If we don't have an upstream, we just skip it.
// TODO: fallback?
let addr = match config.upstream_for(proto) {
Some(us) => us,
None => {
warn!("No upstream for protocol '{}', dropping connection...", proto);
return Ok(());
},
};
let mut server_conn = try!(TcpStream::connect(&addr));
// Send the initial buffer (the bits we used for protocol detection).
try!(server_conn.write_all(initial_buf));
let mut buf = [0u8; 16 * 1024];
loop {
select!(
client_conn:r => {
let n = try!(client_conn.read(&mut buf));
if n == 0 {
break;
}
trace!("copying {} bytes from client --> server", n);
try!(server_conn.write_all(&buf[..n]));
},
server_conn:r => {
let n = try!(server_conn.read(&mut buf));
if n == 0 {
break;
}
trace!("copying {} bytes from server --> client", n);
try!(client_conn.write_all(&buf[..n]));
},
);
}
Ok(())
}
/// Reads from a freshly accepted connection until its protocol is recognised,
/// then hands the connection to [`handle_proxy`].
///
/// Incoming bytes accumulate in a fixed 1024-byte buffer, and
/// `detect::detect` is run on everything buffered so far after each read.
/// The connection is abandoned when:
/// * the detection timer (`config.timeout`) fires: the socket is shut down
///   and `Ok(())` is returned;
/// * the peer reaches EOF or the buffer fills without a match: a warning is
///   logged and `Ok(())` is returned.
///
/// # Errors
/// Propagates read errors on `conn` and any error from `handle_proxy`.
fn handle_connection(mut conn: TcpStream, config: Arc<Config>) -> io::Result<()> {
    let mut buf = [0u8; 1024];
    let mut filled = 0usize;

    let peer = conn.peer_addr()
        .map(|addr| addr.to_string())
        .unwrap_or_else(|_| "<unknown>".to_string());

    // A single deadline covers the whole detection phase; it is armed once
    // here and never reset between reads.
    let mut timer = mioco::timer::Timer::new();
    timer.set_timeout(config.timeout);

    // Keep reading while there is still room left in the buffer.
    while filled < buf.len() {
        select!(
            conn:r => {
                let got = try!(conn.read(&mut buf[filled..]));
                if got == 0 {
                    // EOF
                    break;
                }
                filled += got;
            },
            timer:r => {
                // Timeout :-(
                trace!("Timing out connection from: {}", peer);
                let _ = conn.shutdown(mioco::tcp::Shutdown::Both);
                return Ok(());
            },
        );

        if let Some(protocol) = detect::detect(&buf[..filled]) {
            debug!("Got protocol: {}", protocol);
            return handle_proxy(conn, protocol, &buf[..filled], config);
        }
    }

    // One last detection attempt on whatever was buffered (e.g. when the
    // peer hit EOF before any data was read).
    match detect::detect(&buf[..filled]) {
        Some(protocol) => {
            debug!("Got protocol: {}", protocol);
            handle_proxy(conn, protocol, &buf[..filled], config)
        },
        None => {
            // TODO: default / fallback?
            warn!("Don't know how to handle connection from: {}", peer);
            Ok(())
        },
    }
}
fn main() {
// Convert the protocols into a tuple of:
// (proto, argument name, help string)
let arg_names = detect::protocol_names().into_iter()
.map(|p| {
let arg_name = format!("{}-upstream", p);
let help = format!("Sets the upstream address for the protocol '{}'", p);
(p, arg_name, help)
})
.collect::<Vec<_>>();
let mut config = App::new("demuxrs")
.version("0.0.1")
.author("Andrew Dunham <andrew@du.nham.ca>")
.about("Simple protocol demultiplexer implemented in Rust")
.arg(Arg::with_name("debug")
.short("d")
.multiple(true)
.help("Sets the level of debugging information"))
.arg(Arg::with_name("timeout")
.short("t")
.long("timeout")
.help("Timeout (in milliseconds) for protocol detection"))
.arg(Arg::with_name("listen")
.short("l")
.long("listen")
.takes_value(true)
.help("The listen address in host:port form (default: localhost:5555)"));
// Manually build up the arguments list for each protocol.
for &(_, ref arg_name, ref help) in arg_names.iter() {
config = config.arg(
Arg::with_name(&*arg_name)
.long(&*arg_name)
.takes_value(true)
.help(&*help)
);
}
// Actually parse
let matches = config.get_matches();
logger::init_logger_config(&matches);
// Parse listen address.
let listen_addr = {
let s = matches.value_of("listen").unwrap_or(DEFAULT_LISTEN_ADDR);
match FromStr::from_str(s) {
Ok(a) => a,
Err(e) => {
error!("Invalid listen address '{}': {}", s, e);
return;
},
}
};
// Parse timeout
let timeout = {
let s = matches.value_of("timeout").unwrap_or("1000");
match FromStr::from_str(s) {
Ok(v) => v,
Err(e) => {
error!("Invalid timeout '{}': {}", s, e);
return;
},
}
};
// Parse the upstreams into SocketAddrs.
let mut config = Config {
listen_addr: listen_addr,
upstreams: HashMap::new(),
timeout: timeout,
};
for &(proto, ref arg_name, _) in arg_names.iter() {
let saddr = match matches.value_of(&*arg_name) {
Some(v) => v,
None => continue,
};
let addr: SocketAddr = match FromStr::from_str(saddr) {
Ok(a) => a,
Err(e) => {
error!("Invalid upstream address for protocol '{}': {}", proto, e);
continue;
},
};
debug!("Upstream address for protocol '{}': {}", proto, addr);
config.upstreams.insert(proto, addr);
}
mioco::start(move || {
let config = Arc::new(config);
let listener = match TcpListener::bind(&config.listen_addr) {
Ok(l) => l,
Err(e) => {
error!("Could not bind TCP listener to '{}': {}", config.listen_addr, e);
return Err(e);
},
};
info!("Starting demux server on {:?}", listener.local_addr().unwrap());
loop {
let conn = try!(listener.accept());
let c = config.clone();
mioco::spawn(move || {
handle_connection(conn, c)
});
}
});
}