From 7fbad34c88035bd0e54e3aa65879c734665ea428 Mon Sep 17 00:00:00 2001 From: Max Kalashnikoff Date: Thu, 16 May 2024 00:39:39 +0200 Subject: [PATCH 1/2] feat: per IP rate limiting --- src/lib.rs | 11 +++++-- src/middleware/mod.rs | 1 + src/middleware/rate_limit.rs | 63 ++++++++++++++++++++++++++++++++++++ src/networking.rs | 10 +++++- src/state.rs | 7 ++++ 5 files changed, 89 insertions(+), 3 deletions(-) create mode 100644 src/middleware/rate_limit.rs diff --git a/src/lib.rs b/src/lib.rs index c338dd84..cf1f7769 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ use { axum_client_ip::SecureClientIpSource, config::Config, hyper::http::Method, + middleware::rate_limit::rate_limit_middleware, opentelemetry::{sdk::Resource, KeyValue}, sqlx::{ postgres::{PgConnectOptions, PgPoolOptions}, @@ -256,7 +257,10 @@ pub async fn bootstap(mut shutdown: broadcast::Receiver<()>, config: Config) -> } else { app }; - + let app = app.route_layer(axum::middleware::from_fn_with_state( + state_arc.clone(), + rate_limit_middleware, + )); app.with_state(state_arc.clone()) }; @@ -281,7 +285,10 @@ pub async fn bootstap(mut shutdown: broadcast::Receiver<()>, config: Config) -> } else { app }; - + let app = app.route_layer(axum::middleware::from_fn_with_state( + state_arc.clone(), + rate_limit_middleware, + )); let app = app.with_state(state_arc.clone()); let private_app = Router::new() diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index b59628ab..5fd09d66 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -1 +1,2 @@ +pub mod rate_limit; pub mod validate_signature; diff --git a/src/middleware/rate_limit.rs b/src/middleware/rate_limit.rs new file mode 100644 index 00000000..233af5b1 --- /dev/null +++ b/src/middleware/rate_limit.rs @@ -0,0 +1,63 @@ +use crate::{networking, state::AppState}; +use axum::{ + extract::Request, + extract::State, + http::StatusCode, + middleware::Next, + response::{IntoResponse, Response}, +}; +use moka::future::Cache; +use std::{net::IpAddr, sync::Arc}; +use tokio::time::Duration; +use tracing::error; + +pub const MAX_REQUESTS_PER_SEC: u32 = 10; + +#[derive(Clone)] +pub struct RateLimiter { + cache: Arc>, + max_requests: u32, +} + +impl RateLimiter { + pub fn new(max_requests: u32, window: Duration) -> Self { + Self { + cache: Arc::new(Cache::builder().time_to_live(window).build()), + max_requests, + } + } +} + +/// Rate limit middleware that limits the number of requests per second from a single IP address and +/// uses in-memory caching to store the number of requests. +pub async fn rate_limit_middleware( + State(state): State>, + req: Request, + next: Next, +) -> Response { + let headers = req.headers().clone(); + let client_ip = match networking::get_forwarded_ip(headers.clone()) { + Some(ip) => ip, + None => { + error!( + "Failed to get forwarded IP from request in rate limiting middleware. Skipping the \ + rate-limiting." + ); + // We are skipping the drop to the connect info IP address here, because we are + // using the Load Balancer and if any issues with the X-Forwarded-IP header, we + // will rate-limit the LB IP address. + return next.run(req).await; + } + }; + + let rate_limiter = &state.rate_limit; + let mut rate_limit = rate_limiter.cache.get_with(client_ip, async { 0 }).await; + + if rate_limit < rate_limiter.max_requests { + rate_limit += 1; + rate_limiter.cache.insert(client_ip, rate_limit).await; + next.run(req).await + } else { + (StatusCode::TOO_MANY_REQUESTS, "Too many requests").into_response() + } +} diff --git a/src/networking.rs b/src/networking.rs index febd1da8..1f588958 100644 --- a/src/networking.rs +++ b/src/networking.rs @@ -1,4 +1,4 @@ -use {ipnet::IpNet, std::net::IpAddr}; +use {axum::http::HeaderMap, ipnet::IpNet, std::net::IpAddr}; #[derive(thiserror::Error, Debug)] pub enum NetworkInterfaceError { @@ -65,3 +65,11 @@ fn is_public_ip_addr(addr: IpAddr) -> bool { RESERVED_NETWORKS.iter().all(|range| !range.contains(&addr)) } + +pub fn get_forwarded_ip(headers: HeaderMap) -> Option { + headers + .get("X-Forwarded-For") + .and_then(|header| header.to_str().ok()) + .and_then(|header| header.split(',').next()) + .and_then(|client_ip| client_ip.trim().parse::().ok()) +} diff --git a/src/state.rs b/src/state.rs index 114f184d..1277bfd5 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2,6 +2,7 @@ use { crate::{ config::Config, metrics::Metrics, + middleware::rate_limit, networking, providers::Provider, relay::RelayClient, @@ -10,6 +11,7 @@ use { build_info::BuildInfo, moka::future::Cache, std::{net::IpAddr, sync::Arc}, + tokio::time::Duration, wc::geoip::{block::middleware::GeoBlockLayer, MaxMindResolver}, }; @@ -55,6 +57,7 @@ pub struct AppState { pub uptime: std::time::Instant, pub http_client: reqwest::Client, pub provider_cache: Cache, + pub rate_limit: rate_limit::RateLimiter, } build_info::build_info!(fn build_info); @@ -101,6 +104,10 @@ pub fn new_state( uptime: std::time::Instant::now(), http_client: reqwest::Client::new(), provider_cache: Cache::new(100), + rate_limit: rate_limit::RateLimiter::new( + rate_limit::MAX_REQUESTS_PER_SEC, + Duration::from_secs(1), + ), }) } From 1ea8da794878554c2931bdcd89d5c8b28c615f31 Mon Sep 17 00:00:00 2001 From: Max Kalashnikoff Date: Thu, 16 May 2024 10:02:29 +0200 Subject: [PATCH 2/2] feat: basic flood integration test --- integration/integration.test.ts | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/integration/integration.test.ts b/integration/integration.test.ts index 9f872714..77f4549a 100644 --- a/integration/integration.test.ts +++ b/integration/integration.test.ts @@ -69,4 +69,37 @@ describe('Echo Server', () => { expect(status).toBe(200) }) }) + describe('Middlewares', () => { + const httpClient = axios.create({ + validateStatus: (_status) => true, + }) + + // Simulate flood of requests and check for rate-limited responses + it('Rate limiting', async () => { + const url = `${BASE_URL}/health` + const requests_to_send = 100; + const promises = []; + for (let i = 0; i < requests_to_send; i++) { + promises.push( + httpClient.get(url) + ); + } + const results = await Promise.allSettled(promises); + + let ok_statuses_counter = 0; + let rate_limited_statuses_counter = 0; + results.forEach((result) => { + if (result.status === 'fulfilled' && result.value.status === 429) { + rate_limited_statuses_counter++; + }else if (result.status === 'fulfilled' && result.value.status === 200) { + ok_statuses_counter++; + } + }); + + console.log(`➜ Rate limited statuses: ${rate_limited_statuses_counter} out of ${requests_to_send} total requests.`); + // Check if there are any successful and rate limited statuses + expect(ok_statuses_counter).toBeGreaterThan(0); + expect(rate_limited_statuses_counter).toBeGreaterThan(0); + }) + }) })