Skip to content

Commit

Permalink
anemo-tower: add a RequireAuthorization middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
bmwill committed Oct 19, 2022
1 parent c2c91bd commit 7da7c9a
Show file tree
Hide file tree
Showing 5 changed files with 341 additions and 0 deletions.
64 changes: 64 additions & 0 deletions crates/anemo-tower/src/auth/future.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use anemo::Response;
use bytes::Bytes;
use pin_project_lite::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};

pin_project! {
/// Response future for [`RequireAuthorization`].
///
/// [`RequireAuthorization`]: super::RequireAuthorization
pub struct ResponseFuture<F> {
#[pin]
kind: Kind<F>,
}
}

impl<F> ResponseFuture<F> {
pub(super) fn future(future: F) -> Self {
Self {
kind: Kind::Future { future },
}
}

pub(super) fn invalid_auth(response: Response<Bytes>) -> Self {
Self {
kind: Kind::Error {
response: Some(response),
},
}
}
}

pin_project! {
#[project = KindProj]
enum Kind<F> {
Future {
#[pin]
future: F,
},
Error {
response: Option<Response<Bytes>>,
},
}
}

impl<F, E> Future for ResponseFuture<F>
where
F: Future<Output = Result<Response<Bytes>, E>>,
{
type Output = F::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().kind.project() {
KindProj::Future { future } => future.poll(cx),
KindProj::Error { response } => {
let response = response.take().unwrap();
Poll::Ready(Ok(response))
}
}
}
}
37 changes: 37 additions & 0 deletions crates/anemo-tower/src/auth/layer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use super::{AuthorizeRequest, RequireAuthorization};
use tower::Layer;

/// [`Layer`] that adds authorization to a [`Service`].
///
/// See the [module docs](crate::auth) for more details.
///
/// [`Layer`]: tower::layer::Layer
/// [`Service`]: tower::Service
#[derive(Debug, Copy, Clone)]
pub struct RequireAuthorizationLayer<A> {
pub(super) auth: A,
}

impl<A> RequireAuthorizationLayer<A> {
/// Create a new [`RequireAuthorizationLayer`] using the given [`AuthorizeRequest`].
pub fn new(auth: A) -> Self
where
A: AuthorizeRequest,
{
Self { auth }
}
}

impl<S, A> Layer<S> for RequireAuthorizationLayer<A>
where
A: Clone,
{
type Service = RequireAuthorization<S, A>;

fn layer(&self, inner: S) -> Self::Service {
RequireAuthorization {
inner,
auth: self.auth.clone(),
}
}
}
168 changes: 168 additions & 0 deletions crates/anemo-tower/src/auth/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
use anemo::{Request, Response};
use bytes::Bytes;

pub use self::{
future::ResponseFuture, layer::RequireAuthorizationLayer, service::RequireAuthorization,
};

mod future;
mod layer;
mod service;

/// Trait for authorizing requests.
pub trait AuthorizeRequest {
/// Authorize the request.
///
/// If `Ok(())` is returned then the request is allowed through, otherwise not.
fn authorize(&self, request: &mut Request<Bytes>) -> Result<(), Response<Bytes>>;
}

impl<F> AuthorizeRequest for F
where
F: Fn(&mut Request<Bytes>) -> Result<(), Response<Bytes>>,
{
fn authorize(&self, request: &mut Request<Bytes>) -> Result<(), Response<Bytes>> {
self(request)
}
}

#[derive(Clone, Debug)]
pub struct AllowedPeers {
allowed_peers: std::collections::HashSet<anemo::PeerId>,
}

impl AllowedPeers {
pub fn new<P>(peers: P) -> Self
where
P: IntoIterator<Item = anemo::PeerId>,
{
Self {
allowed_peers: peers.into_iter().collect(),
}
}
}

impl AuthorizeRequest for AllowedPeers {
fn authorize(&self, request: &mut Request<Bytes>) -> Result<(), Response<Bytes>> {
use anemo::types::response::IntoResponse;
use anemo::types::response::StatusCode;

let peer_id = request
.peer_id()
.ok_or_else(|| StatusCode::InternalServerError.into_response())?;

if self.allowed_peers.contains(peer_id) {
Ok(())
} else {
Err(StatusCode::NotFound.into_response())
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use anemo::types::response::IntoResponse;
use anemo::types::response::StatusCode;
use anemo::PeerId;
use anemo::Request;
use anemo::Response;
use bytes::Bytes;
use tower::{BoxError, Service, ServiceBuilder, ServiceExt};

#[tokio::test]
async fn authorize_request_fn() {
const AUTH_HEADER: &str = "authorize";

// Authorize requests that have a particular header set
let auth_layer = RequireAuthorizationLayer::new(|request: &mut Request<Bytes>| {
if request.headers().contains_key(AUTH_HEADER) {
Ok(())
} else {
Err(StatusCode::NotFound.into_response())
}
});

let mut svc = ServiceBuilder::new().layer(auth_layer).service_fn(echo);

// Unauthorized Request
let response = svc
.ready()
.await
.unwrap()
.call(Request::new(Bytes::from("foobar")))
.await
.unwrap();
assert_eq!(response.status(), StatusCode::NotFound);

// Authorized Request
let response = svc
.ready()
.await
.unwrap()
.call(Request::new(Bytes::from("foobar")).with_header(AUTH_HEADER, "0"))
.await
.unwrap();
assert_eq!(response.status(), StatusCode::Success);
assert_eq!(response.inner(), "foobar");
}

#[tokio::test]
async fn authorize_request_by_peer_id() {
let allowed_peer_1 = PeerId([42; 32]);
let allowed_peer_2 = PeerId([13; 32]);
let disallowed_peer = PeerId([9; 32]);

// Authorize requests that have a particular header set
let auth_layer =
RequireAuthorizationLayer::new(AllowedPeers::new([allowed_peer_1, allowed_peer_2]));

let mut svc = ServiceBuilder::new().layer(auth_layer).service_fn(echo);

// Unable to query requester's PeerId
let response = svc
.ready()
.await
.unwrap()
.call(Request::new(Bytes::from("foobar")))
.await
.unwrap();
assert_eq!(response.status(), StatusCode::InternalServerError);

// Unauthorized Request
let response = svc
.ready()
.await
.unwrap()
.call(Request::new(Bytes::from("foobar")).with_extension(disallowed_peer))
.await
.unwrap();
assert_eq!(response.status(), StatusCode::NotFound);

// Authorized Request
let response = svc
.ready()
.await
.unwrap()
.call(Request::new(Bytes::from("foobar")).with_extension(allowed_peer_1))
.await
.unwrap();
assert_eq!(response.status(), StatusCode::Success);
assert_eq!(response.inner(), "foobar");

// Authorized Request
let response = svc
.ready()
.await
.unwrap()
.call(Request::new(Bytes::from("bar")).with_extension(allowed_peer_2))
.await
.unwrap();
assert_eq!(response.status(), StatusCode::Success);
assert_eq!(response.inner(), "bar");
}

async fn echo(req: Request<Bytes>) -> Result<Response<Bytes>, BoxError> {
Ok(Response::new(req.into_body()))
}
}
71 changes: 71 additions & 0 deletions crates/anemo-tower/src/auth/service.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use super::{AuthorizeRequest, RequireAuthorizationLayer, ResponseFuture};
use anemo::{Request, Response};
use bytes::Bytes;
use std::task::{Context, Poll};
use tower::Service;

/// Middleware that adds authorization to a [`Service`].
///
/// See the [module docs](crate::auth) for an example.
///
/// [`Service`]: tower::Service
#[derive(Debug, Clone, Copy)]
pub struct RequireAuthorization<S, A> {
pub(crate) inner: S,
pub(crate) auth: A,
}

impl<S, A> RequireAuthorization<S, A> {
/// Create a new [`RequireAuthorization`].
pub fn new(inner: S, auth: A) -> Self {
Self { inner, auth }
}

/// Returns a new [`Layer`] that wraps services with a [`RequireAuthorizationLayer`] middleware.
///
/// [`Layer`]: tower::layer::Layer
pub fn layer(auth: A) -> RequireAuthorizationLayer<A>
where
A: AuthorizeRequest,
{
RequireAuthorizationLayer::new(auth)
}
}

impl<S, A> RequireAuthorization<S, A> {
/// Gets a reference to the underlying service.
pub fn inner(&self) -> &S {
&self.inner
}

/// Gets a mutable reference to the underlying service.
pub fn inner_mut(&mut self) -> &mut S {
&mut self.inner
}

/// Consumes `self`, returning the underlying service.
pub fn into_inner(self) -> S {
self.inner
}
}

impl<S, A> Service<Request<Bytes>> for RequireAuthorization<S, A>
where
S: Service<Request<Bytes>, Response = Response<Bytes>>,
A: AuthorizeRequest,
{
type Response = Response<Bytes>;
type Error = S::Error;
type Future = ResponseFuture<S::Future>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, mut request: Request<Bytes>) -> Self::Future {
match self.auth.authorize(&mut request) {
Ok(()) => ResponseFuture::future(self.inner.call(request)),
Err(response) => ResponseFuture::invalid_auth(response),
}
}
}
1 change: 1 addition & 0 deletions crates/anemo-tower/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod auth;
pub mod callback;
pub mod classify;
pub mod trace;
Expand Down

0 comments on commit 7da7c9a

Please sign in to comment.