diff --git a/dubbo/Cargo.toml b/dubbo/Cargo.toml index 5f79fe8d..8b4b1aad 100644 --- a/dubbo/Cargo.toml +++ b/dubbo/Cargo.toml @@ -16,8 +16,12 @@ tower-service.workspace = true http-body = "0.4.4" tower = { workspace = true, features = ["timeout"] } futures-util = "0.3.23" +futures-core ="0.3.23" +argh = "0.1" +rustls-pemfile = "1.0.0" +tokio-rustls="0.23.4" +tokio = { version = "1.0", features = [ "rt-multi-thread", "time", "fs", "macros", "net", "signal", "full" ] } futures-core = "0.3.23" -tokio = { workspace = true, features = ["rt-multi-thread", "time", "fs", "macros", "net", "signal"] } prost = "0.10.4" async-trait = "0.1.56" tower-layer.workspace = true diff --git a/dubbo/src/triple/server/builder.rs b/dubbo/src/triple/server/builder.rs index 15a7e935..c85f83b3 100644 --- a/dubbo/src/triple/server/builder.rs +++ b/dubbo/src/triple/server/builder.rs @@ -17,6 +17,7 @@ use std::{ net::{SocketAddr, ToSocketAddrs}, + path::Path, str::FromStr, }; @@ -25,13 +26,17 @@ use dubbo_logger::tracing; use http::{Request, Response, Uri}; use hyper::body::Body; use tower_service::Service; +use tokio_rustls::rustls::{Certificate, PrivateKey}; -use crate::{triple::transport::DubboServer, BoxBody}; +use crate::{common::url::Url, triple::transport::DubboServer}; +use crate::{utils, BoxBody}; #[derive(Clone, Default, Debug)] pub struct ServerBuilder { pub listener: String, pub addr: Option, + pub certs: Vec, + pub keys: Vec, pub service_names: Vec, server: DubboServer, } @@ -45,6 +50,26 @@ impl ServerBuilder { Self { listener, ..self } } + pub fn with_tls(self, certs: &str, keys: &str) -> ServerBuilder { + Self { + certs: match utils::tls::load_certs(Path::new(certs)) { + Ok(v) => v, + Err(err) => { + tracing::error!("error loading tls certs {:?}", err); + Vec::new() + } + }, + keys: match utils::tls::load_keys(Path::new(keys)) { + Ok(v) => v, + Err(err) => { + tracing::error!("error loading tls keys {:?}", err); + Vec::new() + } + }, + ..self + } + } + pub fn with_addr(self, addr: &'static str) -> ServerBuilder { Self { addr: addr.to_socket_addrs().unwrap().next(), @@ -61,6 +86,13 @@ impl ServerBuilder { pub fn build(self) -> Self { let mut server = self.server.with_listener(self.listener.clone()); + + { + if self.certs.len() != 0 && self.keys.len() != 0 { + server = server.with_tls(self.certs.clone(), self.keys.clone()); + } + } + { let lock = crate::protocol::triple::TRIPLE_SERVICES.read().unwrap(); for name in self.service_names.iter() { @@ -73,6 +105,8 @@ impl ServerBuilder { server = server.add_service(name.clone(), svc.clone()); } } + + {} Self { server, ..self } } @@ -114,6 +148,8 @@ impl From for ServerBuilder { addr: authority.to_string().to_socket_addrs().unwrap().next(), service_names: vec![u.service_name], server: DubboServer::default(), + certs: Vec::new(), + keys: Vec::new(), } } } diff --git a/dubbo/src/triple/transport/service.rs b/dubbo/src/triple/transport/service.rs index 8e861006..14afabf0 100644 --- a/dubbo/src/triple/transport/service.rs +++ b/dubbo/src/triple/transport/service.rs @@ -15,7 +15,9 @@ * limitations under the License. */ +use std::io; use std::net::SocketAddr; +use std::sync::Arc; use dubbo_logger::tracing; use futures_core::Future; @@ -23,8 +25,12 @@ use http::{Request, Response}; use hyper::body::Body; use tokio::time::Duration; use tower_service::Service; +use tokio_rustls::rustls::{Certificate, PrivateKey}; +use tokio_rustls::{rustls, TlsAcceptor}; -use super::{listener::get_listener, router::DubboRouter}; +use super::listener::get_listener; +use super::router::DubboRouter; +use crate::triple::transport::io::BoxIO; use crate::BoxBody; #[derive(Default, Clone, Debug)] @@ -38,6 +44,8 @@ pub struct DubboServer { http2_keepalive_timeout: Option, router: DubboRouter, listener: Option, + certs: Vec, + keys: Vec, } impl DubboServer { @@ -93,6 +101,14 @@ impl DubboServer { ..self } } + + pub fn with_tls(self, certs: Vec, keys: Vec) -> Self { + Self { + certs: certs, + keys: keys, + ..self + } + } } impl DubboServer { @@ -107,6 +123,8 @@ impl DubboServer { max_frame_size: None, router: DubboRouter::new(), listener: None, + certs: Vec::new(), + keys: Vec::new(), } } } @@ -147,10 +165,25 @@ impl DubboServer { None => { return Err(Box::new(crate::status::DubboError::new( "listener name is empty".to_string(), - ))) + ))); } }; + let acceptor: Option; + if self.certs.len() != 0 && !self.keys.len() != 0 { + let mut keys = self.keys; + + let config = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(self.certs, keys.remove(0)) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; + + acceptor = Some(TlsAcceptor::from(Arc::new(config))); + } else { + acceptor = None; + } + let listener = match get_listener(name, addr).await { Ok(v) => v, Err(err) => return Err(err), @@ -166,6 +199,14 @@ impl DubboServer { match res { Ok(conn) => { let (io, local_addr) = conn; + let b :BoxIO; + + if !acceptor.is_none() { + b = BoxIO::new(acceptor.as_ref().unwrap().clone().accept(io).await?); + } else { + b = io; + } + tracing::debug!("hyper serve, local address: {:?}", local_addr); let c = hyper::server::conn::Http::new() .http2_only(self.accept_http2) @@ -175,10 +216,9 @@ impl DubboServer { .http2_keep_alive_interval(self.http2_keepalive_interval) .http2_keep_alive_timeout(http2_keepalive_timeout) .http2_max_frame_size(self.max_frame_size) - .serve_connection(io, svc.clone()).with_upgrades(); + .serve_connection(b,svc.clone()).with_upgrades(); tokio::spawn(c); - }, Err(err) => tracing::error!("hyper serve, err: {:?}", err), } diff --git a/dubbo/src/utils/mod.rs b/dubbo/src/utils/mod.rs index f088d725..e885a967 100644 --- a/dubbo/src/utils/mod.rs +++ b/dubbo/src/utils/mod.rs @@ -17,3 +17,4 @@ pub mod boxed; pub mod boxed_clone; +pub mod tls; diff --git a/dubbo/src/utils/tls.rs b/dubbo/src/utils/tls.rs new file mode 100644 index 00000000..0072bf24 --- /dev/null +++ b/dubbo/src/utils/tls.rs @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use rustls_pemfile::{certs, rsa_private_keys}; +use std::{ + fs::File, + io::{self, BufReader}, + path::Path, +}; +use tokio_rustls::rustls::{Certificate, PrivateKey}; + +pub fn load_certs(path: &Path) -> io::Result> { + certs(&mut BufReader::new(File::open(path)?)) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert")) + .map(|mut certs| certs.drain(..).map(Certificate).collect()) +} + +pub fn load_keys(path: &Path) -> io::Result> { + rsa_private_keys(&mut BufReader::new(File::open(path)?)) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key")) + .map(|mut keys| keys.drain(..).map(PrivateKey).collect()) +}