diff --git a/Cargo.toml b/Cargo.toml index 12d1ec0..780be0d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,14 +24,11 @@ memcached = ["r2d2-memcache", "backoff"] [dependencies] log = "0.4.11" -actix-web = {version = "3.3.2"} -actix-http = {version = "2.2.0", features=["actors"]} -actix = "0.10" +actix = "0.13.0" +actix-web = {version = "4.2.1"} futures = "0.3.8" failure = "0.1.8" - dashmap = {version = "4.0.1", optional = true} - redis_rs = {version = "0.15.1", optional = true, package= "redis"} backoff = {version = "0.2.1", optional = true} r2d2-memcache = { version = "0.6", optional = true } diff --git a/src/errors.rs b/src/errors.rs index cd65a3c..9320c07 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,8 +1,10 @@ //! Errors that can occur during middleware processing stage -use actix_web::error::Error as AWError; -use actix_web::web::HttpResponse; +use actix_web::body::BoxBody; +use actix_web::error::ResponseError; +use actix_web::http::StatusCode; +use actix_web::HttpResponse; use failure::{self, Fail}; -use log::*; +use std::time::Duration; /// Custom error type. Useful for logging and debugging different kinds of errors. /// This type can be converted to Actix Error, which defaults to @@ -29,11 +31,33 @@ pub enum ARError { /// Identifier error #[fail(display = "client identification failed")] IdentificationError, + + /// Rate limited error + #[fail(display = "rate limit failed")] + RateLimitError { + max_requests: usize, + c: usize, + reset: Duration, + }, } -impl From for AWError { - fn from(err: ARError) -> AWError { - error!("{}", &err); - HttpResponse::InternalServerError().into() +impl ResponseError for ARError { + fn status_code(&self) -> StatusCode { + StatusCode::INTERNAL_SERVER_ERROR + } + + fn error_response(&self) -> HttpResponse { + match *self { + Self::RateLimitError { + max_requests, + c, + reset, + } => HttpResponse::TooManyRequests() + .insert_header(("x-ratelimit-limit", max_requests.to_string())) + .insert_header(("x-ratelimit-remaining", c.to_string())) + .insert_header(("x-ratelimit-reset", reset.as_secs().to_string())) + .finish(), + _ => HttpResponse::InternalServerError().finish(), + } } } diff --git a/src/lib.rs b/src/lib.rs index efb651e..78021cc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -190,12 +190,12 @@ pub mod stores; use errors::ARError; pub use middleware::RateLimiter; +#[cfg(feature = "memcached")] +pub use stores::memcached::{MemcacheStore, MemcacheStoreActor}; #[cfg(feature = "memory")] pub use stores::memory::{MemoryStore, MemoryStoreActor}; #[cfg(feature = "redis-store")] pub use stores::redis::{RedisStore, RedisStoreActor}; -#[cfg(feature = "memcached")] -pub use stores::memcached::{MemcacheStore, MemcacheStoreActor}; use std::future::Future; use std::marker::Send; @@ -248,9 +248,9 @@ where A: Actor, M: actix::Message, { - fn handle>(self, _: &mut A::Context, tx: Option) { + fn handle(self, _: &mut A::Context, tx: Option>) { if let Some(tx) = tx { - tx.send(self); + tx.send(self).ok(); } } } diff --git a/src/middleware.rs b/src/middleware.rs index 59f38e9..d46e317 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -2,9 +2,8 @@ use actix::dev::*; use actix_web::{ dev::{Service, ServiceRequest, ServiceResponse, Transform}, - error::Error as AWError, - http::{HeaderName, HeaderValue}, - HttpResponse, + error::{Error as AWError, ErrorInternalServerError}, + http::header::{HeaderName, HeaderValue}, }; use futures::future::{ok, Ready}; use log::*; @@ -50,7 +49,7 @@ where interval: Duration, max_requests: usize, store: Addr, - identifier: Rc Result>>, + identifier: Rc Result>>, } impl RateLimiter @@ -63,8 +62,9 @@ where let identifier = |req: &ServiceRequest| { let connection_info = req.connection_info(); let ip = connection_info - .remote_addr() - .ok_or(ARError::IdentificationError)?; + .peer_addr() + .ok_or(ARError::IdentificationError) + .map_err(ErrorInternalServerError)?; Ok(String::from(ip)) }; RateLimiter { @@ -88,7 +88,7 @@ where } /// Function to get the identifier for the client request - pub fn with_identifier Result + 'static>( + pub fn with_identifier Result + 'static>( mut self, identifier: F, ) -> Self { @@ -97,15 +97,14 @@ where } } -impl Transform for RateLimiter +impl Transform for RateLimiter where T: Handler + Send + Sync + 'static, T::Context: ToEnvelope, - S: Service, Error = AWError> + 'static, + S: Service, Error = AWError> + 'static, S::Future: 'static, B: 'static, { - type Request = ServiceRequest; type Response = ServiceResponse; type Error = S::Error; type InitError = (); @@ -134,37 +133,37 @@ where // Exists here for the sole purpose of knowing the max_requests and interval from RateLimiter max_requests: usize, interval: u64, - identifier: Rc Result + 'static>>, + identifier: Rc Result + 'static>>, } -impl Service for RateLimitMiddleware +impl Service for RateLimitMiddleware where T: Handler + 'static, - S: Service, Error = AWError> + 'static, + S: Service, Error = AWError> + 'static, S::Future: 'static, B: 'static, T::Context: ToEnvelope, { - type Request = ServiceRequest; type Response = ServiceResponse; type Error = S::Error; type Future = Pin>>>; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { self.service.borrow_mut().poll_ready(cx) } - fn call(&mut self, req: ServiceRequest) -> Self::Future { + fn call(&self, req: ServiceRequest) -> Self::Future { let store = self.store.clone(); - let mut srv = self.service.clone(); + let srv = self.service.clone(); let max_requests = self.max_requests; let interval = Duration::from_secs(self.interval); let identifier = self.identifier.clone(); Box::pin(async move { - let identifier: String = (identifier)(&req)?; + let identifier: String = (identifier)(&req).map_err(ErrorInternalServerError)?; let remaining: ActorResponse = store .send(ActorMessage::Get(String::from(&identifier))) - .await?; + .await + .map_err(ErrorInternalServerError)?; match remaining { ActorResponse::Get(opt) => { let opt = opt.await?; @@ -172,19 +171,21 @@ where // Existing entry in store let expiry = store .send(ActorMessage::Expire(String::from(&identifier))) - .await?; + .await + .map_err(ErrorInternalServerError)?; let reset: Duration = match expiry { ActorResponse::Expire(dur) => dur.await?, _ => unreachable!(), }; if c == 0 { info!("Limit exceeded for client: {}", &identifier); - let mut response = HttpResponse::TooManyRequests(); - // let mut response = (error_callback)(&mut response); - response.set_header("x-ratelimit-limit", max_requests.to_string()); - response.set_header("x-ratelimit-remaining", c.to_string()); - response.set_header("x-ratelimit-reset", reset.as_secs().to_string()); - Err(response.into()) + + Err(ARError::RateLimitError { + max_requests, + c, + reset, + } + .into()) } else { // Decrement value let res: ActorResponse = store @@ -192,27 +193,31 @@ where key: identifier, value: 1, }) - .await?; + .await + .map_err(ErrorInternalServerError)?; let updated_value: usize = match res { - ActorResponse::Update(c) => c.await?, + ActorResponse::Update(c) => { + c.await.map_err(ErrorInternalServerError)? + } _ => unreachable!(), }; // Execute the request let fut = srv.call(req); - let mut res = fut.await?; + let mut res = fut.await.map_err(ErrorInternalServerError)?; let headers = res.headers_mut(); // Safe unwraps, since usize is always convertible to string headers.insert( HeaderName::from_static("x-ratelimit-limit"), - HeaderValue::from_str(max_requests.to_string().as_str())?, + HeaderValue::from_str(max_requests.to_string().as_str()).unwrap(), ); headers.insert( HeaderName::from_static("x-ratelimit-remaining"), - HeaderValue::from_str(updated_value.to_string().as_str())?, + HeaderValue::from_str(updated_value.to_string().as_str()).unwrap(), ); headers.insert( HeaderName::from_static("x-ratelimit-reset"), - HeaderValue::from_str(reset.as_secs().to_string().as_str())?, + HeaderValue::from_str(reset.as_secs().to_string().as_str()) + .unwrap(), ); Ok(res) } @@ -225,13 +230,16 @@ where value: current_value, expiry: interval, }) - .await?; + .await + .map_err(ErrorInternalServerError)?; match res { - ActorResponse::Set(c) => c.await?, + ActorResponse::Set(c) => { + c.await.map_err(|err| ErrorInternalServerError(err))? + } _ => unreachable!(), } let fut = srv.call(req); - let mut res = fut.await?; + let mut res = fut.await.map_err(|err| ErrorInternalServerError(err))?; let headers = res.headers_mut(); // Safe unwraps, since usize is always convertible to string headers.insert( diff --git a/src/stores/memcached.rs b/src/stores/memcached.rs index 1f8a430..7fbaa25 100644 --- a/src/stores/memcached.rs +++ b/src/stores/memcached.rs @@ -204,7 +204,7 @@ impl Handler for MemcacheStoreActor { Ok(c) => match c { Some(v) => Ok(Some(v as usize)), None => Ok(None), - } + }, Err(e) => Err(ARError::ReadWriteError(format!("{:?}", &e))), } })), @@ -247,10 +247,6 @@ impl Handler for MemcacheStoreActor { } } - - - - #[cfg(test)] mod tests { use super::*; @@ -315,7 +311,7 @@ mod tests { _ => panic!("Shouldn't happen!"), }; } - + #[actix_rt::test] async fn test_expiry() { init(); @@ -357,6 +353,5 @@ mod tests { }, _ => panic!("Shouldn't happen!"), }; - } } diff --git a/tests/version-numbers.rs b/tests/version-numbers.rs index b2c668d..288592d 100644 --- a/tests/version-numbers.rs +++ b/tests/version-numbers.rs @@ -7,4 +7,3 @@ fn test_readme_deps() { fn test_html_root_url() { version_sync::assert_html_root_url_updated!("src/lib.rs"); } -