Skip to content

Commit

Permalink
Make PeerId unique per connection
Browse files Browse the repository at this point in the history
- `PeerId` is now unique in that it always increments with each new connection,
  avoiding situations where it could be reused as connection slots get
  released. This is in order to avoid race conditions on the consumer side
  where a client leaves and a new one quickly gets his peer id, before the
  client can react and realize these are different peers with different states.
- Integration test for the above scenario.
  • Loading branch information
alfred-hodler committed Mar 6, 2023
1 parent 9f986da commit 07e9d0f
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 90 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Expand Up @@ -15,6 +15,7 @@ socks = ["dep:socks"]
[dependencies]
bitcoin = "0.29.2"
crossbeam-channel = "0.5.7"
intmap = "2.0.0"
log = "0.4.17"
mio = { version = "0.8.6", features = ["net", "os-poll"] }
slab = "0.4.8"
Expand Down
21 changes: 5 additions & 16 deletions src/peer.rs
@@ -1,26 +1,15 @@
/// Unique peer identifier. The user should assume these can be reused by different peers as
/// peers come and go, i.e. they are not assigned just once for the lifetime of the process.
/// Unique peer identifier. These are unique for the lifetime of the process and strictly
/// incrementing for each new connection. Even if the same peer (in terms of socket address)
/// connects multiple times, a new `PeerId` instance will be issued for each connection.
#[derive(Debug, Clone, Hash, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct PeerId(pub usize);
pub struct PeerId(pub u64);

impl PeerId {
pub fn value(&self) -> usize {
pub fn value(&self) -> u64 {
self.0
}
}

impl From<mio::Token> for PeerId {
fn from(token: mio::Token) -> Self {
Self(token.0)
}
}

impl From<PeerId> for mio::Token {
fn from(id: PeerId) -> Self {
Self(id.0)
}
}

impl std::fmt::Display for PeerId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("{}", self.0))
Expand Down
193 changes: 122 additions & 71 deletions src/reactor.rs
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;
use std::time::{Duration, Instant};

use bitcoin::network::message::RawNetworkMessage;
use intmap::IntMap;
use mio::net::{TcpListener, TcpStream};
use mio::{Events, Interest, Poll, Registry, Token, Waker};
use slab::Slab;
Expand Down Expand Up @@ -35,7 +36,9 @@ pub enum Command {
pub enum Event {
/// The reactor attempted to connect to a remote peer.
ConnectedTo {
/// The socket address that was connected to.
addr: SocketAddr,
/// The result of the connection attempt.
result: io::Result<PeerId>,
},
/// The reactor received a connection from a remote peer.
Expand Down Expand Up @@ -220,6 +223,12 @@ impl Handle {
}
}

/// Contains a stream along with its peer id.
struct Entry {
stream: MessageStream<TcpStream>,
peer_id: PeerId,
}

/// Runs the reactor in a loop until an error is produced or a shutdown command is received.
fn run<C: Connector + Sync + Send + 'static>(
Reactor {
Expand Down Expand Up @@ -248,10 +257,13 @@ fn run<C: Connector + Sync + Send + 'static>(
})
.collect::<std::io::Result<Vec<_>>>()?;

let mut streams: Slab<MessageStream<TcpStream>> = Slab::with_capacity(16);
let mut streams: Slab<Entry> = Slab::with_capacity(16);
let mut events = Events::with_capacity(1024);
let mut read_buf = [0; 1024 * 1024];
let mut last_maintenance = Instant::now();
let mut token_map: IntMap<Token> = IntMap::new();
let mut next_peer_id: u64 = 0;
let mut remove_stale: Vec<PeerId> = Vec::with_capacity(16);

loop {
poll.poll(&mut events, Some(Duration::from_secs(5)))?;
Expand All @@ -275,6 +287,8 @@ fn run<C: Connector + Sync + Send + 'static>(
let peer = add_stream(
poll.registry(),
&mut streams,
&mut token_map,
&mut next_peer_id,
stream,
config.stream_config.clone(),
)?;
Expand All @@ -295,53 +309,59 @@ fn run<C: Connector + Sync + Send + 'static>(
let _ = sender.send(Event::ConnectedTo { addr, result });
}

Command::Disconnect(peer) => match streams.try_remove(peer.value()) {
Some(mut stream) => {
poll.registry().deregister(stream.inner_mut())?;
Command::Disconnect(peer) => {
if token_map.contains_key(peer.value()) {
let mut entry = remove_stream(
poll.registry(),
&mut streams,
&mut token_map,
peer,
)?;

let _ = write(&mut stream, now);
let _ = stream.inner_mut().shutdown(std::net::Shutdown::Both);
let _ = write(&mut entry.stream, now);
let _ =
entry.stream.inner_mut().shutdown(std::net::Shutdown::Both);

log::info!("peer {peer}: disconnected");

let _ = sender.send(Event::Disconnected {
peer,
reason: DisconnectReason::Requested,
});
}
None => {
} else {
let _ = sender.send(Event::NoPeer(peer));
log::warn!("disconnect: peer {} not found", peer.value());
}
},
}

Command::Message(peer, message) => {
match streams.get_mut(peer.value()) {
Some(stream) => {
if stream.queue_message(&message) {
poll.registry().reregister(
stream.inner_mut(),
peer.into(),
Interest::READABLE | Interest::WRITABLE,
)?;
} else {
let _ = sender
.send(Event::SendBufferFull { peer, message });
log::warn!("send buffer for peer {peer} is full");
}
Command::Message(peer, message) => match token_map.get(peer.value()) {
Some(token) => {
let entry = streams.get_mut(token.0).expect("must exist here");

if entry.stream.queue_message(&message) {
poll.registry().reregister(
entry.stream.inner_mut(),
*token,
Interest::READABLE | Interest::WRITABLE,
)?;
} else {
let _ =
sender.send(Event::SendBufferFull { peer, message });
log::warn!("send buffer for peer {peer} is full");
}
}

None => {
let _ = sender.send(Event::NoPeer(peer));
log::warn!("message: peer {} not found", peer.value());
}
None => {
let _ = sender.send(Event::NoPeer(peer));
log::warn!("message: peer {} not found", peer.value());
}
}
},

Command::Shutdown => {
for (id, mut stream) in streams {
let _ = write(&mut stream, now);
let r = stream.inner_mut().shutdown(std::net::Shutdown::Both);
for (id, mut entry) in streams {
let _ = write(&mut entry.stream, now);
let r =
entry.stream.inner_mut().shutdown(std::net::Shutdown::Both);
log::debug!("shut down stream {}: {:?}", id, r);
}

Expand All @@ -365,6 +385,8 @@ fn run<C: Connector + Sync + Send + 'static>(
let peer = add_stream(
poll.registry(),
&mut streams,
&mut token_map,
&mut next_peer_id,
stream,
config.stream_config.clone(),
)?;
Expand All @@ -382,10 +404,10 @@ fn run<C: Connector + Sync + Send + 'static>(
}
}

(token, Some(stream)) => {
let peer = token.into();
(token, Some(entry)) => {
let peer = entry.peer_id;

if !stream.is_ready() {
if !entry.stream.is_ready() {
log::trace!("peer: {peer}: stream not ready");
continue;
}
Expand All @@ -394,10 +416,10 @@ fn run<C: Connector + Sync + Send + 'static>(
log::trace!("peer {peer}: readable");

'read: loop {
let read_result = stream.read(&mut read_buf);
let read_result = entry.stream.read(&mut read_buf);

'decode: loop {
match stream.receive_message() {
match entry.stream.receive_message() {
Ok(message) => {
log::debug!("peer {peer}: rx message: {}", message.cmd());

Expand All @@ -410,7 +432,12 @@ fn run<C: Connector + Sync + Send + 'static>(
) => {
log::info!("peer {peer}: codec violation");

remove_stream(poll.registry(), &mut streams, peer)?;
remove_stream(
poll.registry(),
&mut streams,
&mut token_map,
peer,
)?;

let _ = sender.send(Event::Disconnected {
peer,
Expand All @@ -428,7 +455,12 @@ fn run<C: Connector + Sync + Send + 'static>(
Ok(0) => {
log::debug!("peer {peer}: peer left");

remove_stream(poll.registry(), &mut streams, peer)?;
remove_stream(
poll.registry(),
&mut streams,
&mut token_map,
peer,
)?;

let _ = sender.send(Event::Disconnected {
peer,
Expand All @@ -445,7 +477,12 @@ fn run<C: Connector + Sync + Send + 'static>(
Err(err) => {
log::warn!("peer {peer}: IO error: {err}");

remove_stream(poll.registry(), &mut streams, peer)?;
remove_stream(
poll.registry(),
&mut streams,
&mut token_map,
peer,
)?;

let _ = sender.send(Event::Disconnected {
peer,
Expand All @@ -461,19 +498,22 @@ fn run<C: Connector + Sync + Send + 'static>(
if event.is_writable() {
log::trace!("peer {peer}: writable");

match write(stream, now) {
match write(&mut entry.stream, now) {
Ok(()) => {
let interests = choose_interest(stream);
poll.registry()
.reregister(stream.inner_mut(), token, interests)?;
let interests = choose_interest(&entry.stream);
poll.registry().reregister(
entry.stream.inner_mut(),
token,
interests,
)?;
}

Err(err) if would_block(&err) => {}

Err(err) => {
log::warn!("peer {peer}: IO error: {err}");

remove_stream(poll.registry(), &mut streams, peer)?;
remove_stream(poll.registry(), &mut streams, &mut token_map, peer)?;

let _ = sender.send(Event::Disconnected {
peer,
Expand All @@ -495,28 +535,27 @@ fn run<C: Connector + Sync + Send + 'static>(
}

// stale stream removal
streams.retain(|token, stream| {
if stream.is_write_stale(now) {
let peer = PeerId(token);
log::info!("removing stale peer {peer}");
remove_stale.extend(
streams
.iter()
.filter_map(|(_, entry)| entry.stream.is_write_stale(now).then_some(entry.peer_id)),
);

poll.registry().deregister(stream.inner_mut()).unwrap();
for peer in remove_stale.drain(..) {
log::info!("removing stale peer {peer}");

let _ = sender.send(Event::Disconnected {
peer,
reason: DisconnectReason::WriteStale,
});
remove_stream(poll.registry(), &mut streams, &mut token_map, peer)?;

false
} else {
true
}
});
let _ = sender.send(Event::Disconnected {
peer,
reason: DisconnectReason::WriteStale,
});
}

// periodic buffer resize
if (now - last_maintenance).as_secs() > 30 {
for (_, stream) in &mut streams {
stream.resize_buffers();
for (_, entry) in &mut streams {
entry.stream.resize_buffers();
}

last_maintenance = now;
Expand Down Expand Up @@ -617,29 +656,44 @@ fn write(stream: &mut MessageStream<TcpStream>, now: Instant) -> io::Result<()>
/// Registers a peer with the poll and adds him to the stream list.
fn add_stream(
registry: &Registry,
streams: &mut Slab<MessageStream<TcpStream>>,
streams: &mut Slab<Entry>,
token_map: &mut IntMap<Token>,
next_peer_id: &mut u64,
mut stream: TcpStream,
stream_cfg: message_stream::StreamConfig,
) -> std::io::Result<PeerId> {
let token = Token(streams.vacant_key());
let peer_id = *next_peer_id;

registry.register(&mut stream, token, Interest::READABLE)?;
streams.insert(MessageStream::new(stream, stream_cfg));

Ok(token.into())
let prev_mapping = token_map.insert(peer_id, token);
assert!(prev_mapping.is_none());

streams.insert(Entry {
stream: MessageStream::new(stream, stream_cfg),
peer_id: PeerId(peer_id),
});

*next_peer_id += 1;

Ok(PeerId(peer_id))
}

/// Deregisters a peer from the poll and removes him from the stream list.
fn remove_stream(
registry: &Registry,
streams: &mut Slab<MessageStream<TcpStream>>,
streams: &mut Slab<Entry>,
token_map: &mut IntMap<Token>,
peer: PeerId,
) -> std::io::Result<()> {
let mut stream = streams.remove(peer.value());
) -> std::io::Result<Entry> {
let token = token_map.remove(peer.0).expect("must exist here");

registry.deregister(stream.inner_mut())?;
let mut entry = streams.remove(token.0);

Ok(())
registry.deregister(entry.stream.inner_mut())?;

Ok(entry)
}

/// Checks if the token is associated with the server (connection listener).
Expand Down Expand Up @@ -696,7 +750,4 @@ mod test {
assert!(is_listener(3, Token(WAKE_TOKEN.0 - 3)));
assert!(!is_listener(3, Token(WAKE_TOKEN.0 - 4)));
}

#[test]
fn connection_slot_open() {}
}

0 comments on commit 07e9d0f

Please sign in to comment.