Skip to content

Commit

Permalink
Get addresses from discovery api
Browse files Browse the repository at this point in the history
  • Loading branch information
Seulgi Kim committed May 7, 2018
1 parent 402d0f1 commit b85f13a
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 14 deletions.
41 changes: 31 additions & 10 deletions network/src/p2p/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ use std::io;
use std::sync::Arc;

use cfinally::finally;
use cio::{IoContext, IoHandler, IoHandlerResult, IoManager, StreamToken, TimerToken};
use cio::{IoChannel, IoContext, IoHandler, IoHandlerResult, IoManager, StreamToken, TimerToken};
use mio::deprecated::EventLoop;
use mio::{PollOpt, Ready, Token};
use parking_lot::{Mutex, RwLock};

use super::super::client::Client;
use super::super::extension::NodeToken;
use super::super::session::Session;
use super::super::session_initiator::Message as SessionMessage;
use super::super::token_generator::TokenGenerator;
use super::super::{DiscoveryApi, SocketAddr};
use super::connection::{Connection, ExtensionCallback as ExtensionChannel};
Expand Down Expand Up @@ -152,8 +153,9 @@ impl Manager {
Ok((token, timer_token))
}

fn register_connection(&mut self, connection: Connection, token: &StreamToken) {
fn register_connection(&mut self, connection: Connection, token: &StreamToken, client: &Client) {
let con = self.connections.insert(*token, connection);
client.on_node_added(token);
debug_assert!(con.is_none());
}

Expand Down Expand Up @@ -185,7 +187,12 @@ impl Manager {
}
}

fn create_connection(&mut self, stream: Stream, session: &Session) -> IoHandlerResult<StreamToken> {
fn create_connection(
&mut self,
stream: Stream,
session: &Session,
client: &Client,
) -> IoHandlerResult<StreamToken> {
let mut connection = Connection::new(stream, session.secret().clone(), session.id().clone());
let nonce = session.id();
connection.enqueue_sync(nonce.clone());
Expand All @@ -195,7 +202,7 @@ impl Manager {
Ok(self.tokens
.gen()
.map(|token| {
self.register_connection(connection, &token);
self.register_connection(connection, &token, client);
token
})
.expect("The number of peers must be checked before"))
Expand All @@ -211,9 +218,14 @@ impl Manager {
}
}

pub fn connect(&mut self, socket_address: &SocketAddr, session: &Session) -> IoHandlerResult<Option<StreamToken>> {
pub fn connect(
&mut self,
socket_address: &SocketAddr,
session: &Session,
client: &Client,
) -> IoHandlerResult<Option<StreamToken>> {
Ok(match Stream::connect(socket_address)? {
Some(stream) => Some(self.create_connection(stream, session)?),
Some(stream) => Some(self.create_connection(stream, session, client)?),
None => None,
})
}
Expand Down Expand Up @@ -300,8 +312,7 @@ impl Manager {
let removed = self.registered_sessions.remove(&nonce);
debug_assert!(removed);

self.register_connection(connection, stream);
client.on_node_added(&stream);
self.register_connection(connection, stream, client);
Ok(false)
}

Expand Down Expand Up @@ -352,6 +363,7 @@ pub struct Handler {
client: Arc<Client>,

discovery: RwLock<Option<Arc<DiscoveryApi>>>,
session_initiator: IoChannel<SessionMessage>,

min_peers: usize,
max_peers: usize,
Expand All @@ -361,6 +373,7 @@ impl Handler {
pub fn try_new(
socket_address: SocketAddr,
client: Arc<Client>,
session_initiator: IoChannel<SessionMessage>,
min_peers: usize,
max_peers: usize,
) -> ::std::result::Result<Self, String> {
Expand All @@ -375,6 +388,7 @@ impl Handler {
client,

discovery: RwLock::new(None),
session_initiator,

min_peers,
max_peers,
Expand Down Expand Up @@ -405,10 +419,16 @@ impl IoHandler<Message> for Handler {

let num_of_requests = self.min_peers - manager.connections.len();
// FIXME: Pick random session
let mut count: usize = 0;
for (_, &(ref session, ref socket_address)) in manager.registered_sessions.iter().take(num_of_requests)
{
count += 1;
io.channel().send(Message::RequestConnection(socket_address.clone(), session.clone()))?;
}
if count + manager.connections.len() < self.min_peers {
let requests = self.min_peers - count - manager.connections.len();
self.session_initiator.send(SessionMessage::RequestSession(requests))?;
}

Ok(())
}
Expand Down Expand Up @@ -436,8 +456,9 @@ impl IoHandler<Message> for Handler {
}

trace!(target: "net", "Connecting to {:?}", socket_address);
let token =
manager.connect(&socket_address, session)?.ok_or(Error::General("Cannot create connection"))?;
let token = manager
.connect(&socket_address, session, &self.client)?
.ok_or(Error::General("Cannot create connection"))?;
io.register_stream(token)?;

if let Some(ref discovery) = *self.discovery.read() {
Expand Down
15 changes: 12 additions & 3 deletions network/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub struct Service {
session_initiator: IoService<session_initiator::Message>,
session_initiator_handler: Arc<session_initiator::Handler>,
_p2p: IoService<p2p::Message>,
p2p_handler: Arc<p2p::Handler>,
timer: IoService<timer::Message>,
client: Arc<Client>,
}
Expand All @@ -41,8 +42,14 @@ impl Service {

let client = Client::new(p2p.channel(), timer.channel());

let p2p_handler = Arc::new(p2p::Handler::try_new(address.clone(), Arc::clone(&client), min_peers, max_peers)?);
p2p.register_handler(p2p_handler)?;
let p2p_handler = Arc::new(p2p::Handler::try_new(
address.clone(),
Arc::clone(&client),
session_initiator.channel(),
min_peers,
max_peers,
)?);
p2p.register_handler(p2p_handler.clone())?;

timer.register_handler(Arc::new(timer::Handler::new(Arc::clone(&client))))?;

Expand All @@ -53,6 +60,7 @@ impl Service {
session_initiator,
session_initiator_handler,
_p2p: p2p,
p2p_handler,
timer,
client,
})
Expand All @@ -71,7 +79,8 @@ impl Service {
}

pub fn set_discovery_api(&self, api: Arc<DiscoveryApi>) {
self.session_initiator_handler.set_discovery_api(api);
self.session_initiator_handler.set_discovery_api(Arc::clone(&api));
self.p2p_handler.set_discovery_api(api);
}

pub fn connect_to(&self, address: SocketAddr) -> Result<(), String> {
Expand Down
28 changes: 27 additions & 1 deletion network/src/session_initiator/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::error;
use std::fmt;
use std::io;
Expand Down Expand Up @@ -50,6 +50,8 @@ struct SessionInitiator {
tmp_nonce_tokens: TokenGenerator,
tmp_nonce_token_to_addr: HashMap<TimerToken, SocketAddr>,
addr_to_tmp_nonce_token: HashMap<SocketAddr, TimerToken>,

session_registered_addresses: HashSet<SocketAddr>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -143,6 +145,7 @@ type Result<T> = ::std::result::Result<T, Error>;
#[derive(Clone, Debug, PartialOrd, PartialEq)]
pub enum Message {
ConnectTo(SocketAddr),
RequestSession(usize),
}

const START_OF_TMP_NONCE_TOKEN: TimerToken = 0;
Expand All @@ -164,6 +167,8 @@ impl SessionInitiator {
tmp_nonce_tokens: TokenGenerator::new(START_OF_TMP_NONCE_TOKEN, NUM_OF_TMP_NONCES),
tmp_nonce_token_to_addr: HashMap::new(),
addr_to_tmp_nonce_token: HashMap::new(),

session_registered_addresses: HashSet::new(),
})
}

Expand Down Expand Up @@ -221,6 +226,7 @@ impl SessionInitiator {

let session = Session::new(*secret, nonce);
channel_to_p2p.send(p2p::Message::RegisterSession(from.clone(), session))?;
self.session_registered_addresses.insert(from.clone());
encrypted_nonce
};

Expand All @@ -236,6 +242,7 @@ impl SessionInitiator {

let session = Session::new(*secret, nonce);
channel_to_p2p.send(p2p::Message::RegisterSession(from.clone(), session))?;
self.session_registered_addresses.insert(from.clone());
Ok(())
}
&message::Body::ConnectionDenied(ref reason) => {
Expand Down Expand Up @@ -379,6 +386,25 @@ impl IoHandler<Message> for Handler {
session_initiator.create_new_connection(&socket_address)?;
io.update_registration(RECEIVE_TOKEN)?;
}
&Message::RequestSession(n) => {
let mut session_initiator = self.session_initiator.lock();
let discovery = self.discovery.read();
if let Some(ref discovery) = *discovery {
let addresses = discovery.get(n);
if !addresses.is_empty() {
let _f = finally(|| {
if let Err(err) = io.update_registration(RECEIVE_TOKEN) {
warn!(target: "net", "Cannot update registration for session_initiator : {:?}", err);
}
});
for address in addresses {
if !session_initiator.session_registered_addresses.contains(&address) {
session_initiator.create_new_connection(&address)?;
}
}
}
}
}
};
Ok(())
}
Expand Down

0 comments on commit b85f13a

Please sign in to comment.