Skip to content

Commit

Permalink
feat: per-IP rate limiting (#327)
Browse files Browse the repository at this point in the history
  • Loading branch information
geekbrother committed May 16, 2024
1 parent 1abce43 commit 70e3002
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 3 deletions.
33 changes: 33 additions & 0 deletions integration/integration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,37 @@ describe('Echo Server', () => {
expect(status).toBe(200)
})
})
describe('Middlewares', () => {
const httpClient = axios.create({
validateStatus: (_status) => true,
})

// Simulate flood of requests and check for rate-limited responses
it('Rate limiting', async () => {
const url = `${BASE_URL}/health`
const requests_to_send = 100;
const promises = [];
for (let i = 0; i < requests_to_send; i++) {
promises.push(
httpClient.get(url)
);
}
const results = await Promise.allSettled(promises);

let ok_statuses_counter = 0;
let rate_limited_statuses_counter = 0;
results.forEach((result) => {
if (result.status === 'fulfilled' && result.value.status === 429) {
rate_limited_statuses_counter++;
}else if (result.status === 'fulfilled' && result.value.status === 200) {
ok_statuses_counter++;
}
});

console.log(`➜ Rate limited statuses: ${rate_limited_statuses_counter} out of ${requests_to_send} total requests.`);
// Check if there are any successful and rate limited statuses
expect(ok_statuses_counter).toBeGreaterThan(0);
expect(rate_limited_statuses_counter).toBeGreaterThan(0);
})
})
})
11 changes: 9 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use {
axum_client_ip::SecureClientIpSource,
config::Config,
hyper::http::Method,
middleware::rate_limit::rate_limit_middleware,
opentelemetry::{sdk::Resource, KeyValue},
sqlx::{
postgres::{PgConnectOptions, PgPoolOptions},
Expand Down Expand Up @@ -256,7 +257,10 @@ pub async fn bootstap(mut shutdown: broadcast::Receiver<()>, config: Config) ->
} else {
app
};

let app = app.route_layer(axum::middleware::from_fn_with_state(
state_arc.clone(),
rate_limit_middleware,
));
app.with_state(state_arc.clone())
};

Expand All @@ -281,7 +285,10 @@ pub async fn bootstap(mut shutdown: broadcast::Receiver<()>, config: Config) ->
} else {
app
};

let app = app.route_layer(axum::middleware::from_fn_with_state(
state_arc.clone(),
rate_limit_middleware,
));
let app = app.with_state(state_arc.clone());

let private_app = Router::new()
Expand Down
1 change: 1 addition & 0 deletions src/middleware/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod rate_limit;
pub mod validate_signature;
63 changes: 63 additions & 0 deletions src/middleware/rate_limit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use crate::{networking, state::AppState};
use axum::{
extract::Request,
extract::State,
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use moka::future::Cache;
use std::{net::IpAddr, sync::Arc};
use tokio::time::Duration;
use tracing::error;

pub const MAX_REQUESTS_PER_SEC: u32 = 10;

#[derive(Clone)]
pub struct RateLimiter {
cache: Arc<Cache<IpAddr, u32>>,
max_requests: u32,
}

impl RateLimiter {
pub fn new(max_requests: u32, window: Duration) -> Self {
Self {
cache: Arc::new(Cache::builder().time_to_live(window).build()),
max_requests,
}
}
}

/// Rate limit middleware that limits the number of requests per second from a single IP address and
/// uses in-memory caching to store the number of requests.
pub async fn rate_limit_middleware(
State(state): State<Arc<AppState>>,
req: Request,
next: Next,
) -> Response {
let headers = req.headers().clone();
let client_ip = match networking::get_forwarded_ip(headers.clone()) {
Some(ip) => ip,
None => {
error!(
"Failed to get forwarded IP from request in rate limiting middleware. Skipping the \
rate-limiting."
);
// We are skipping the drop to the connect info IP address here, because we are
// using the Load Balancer and if any issues with the X-Forwarded-IP header, we
// will rate-limit the LB IP address.
return next.run(req).await;
}
};

let rate_limiter = &state.rate_limit;
let mut rate_limit = rate_limiter.cache.get_with(client_ip, async { 0 }).await;

if rate_limit < rate_limiter.max_requests {
rate_limit += 1;
rate_limiter.cache.insert(client_ip, rate_limit).await;
next.run(req).await
} else {
(StatusCode::TOO_MANY_REQUESTS, "Too many requests").into_response()
}
}
10 changes: 9 additions & 1 deletion src/networking.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use {ipnet::IpNet, std::net::IpAddr};
use {axum::http::HeaderMap, ipnet::IpNet, std::net::IpAddr};

#[derive(thiserror::Error, Debug)]
pub enum NetworkInterfaceError {
Expand Down Expand Up @@ -65,3 +65,11 @@ fn is_public_ip_addr(addr: IpAddr) -> bool {

RESERVED_NETWORKS.iter().all(|range| !range.contains(&addr))
}

pub fn get_forwarded_ip(headers: HeaderMap) -> Option<IpAddr> {
headers
.get("X-Forwarded-For")
.and_then(|header| header.to_str().ok())
.and_then(|header| header.split(',').next())
.and_then(|client_ip| client_ip.trim().parse::<IpAddr>().ok())
}
7 changes: 7 additions & 0 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use {
crate::{
config::Config,
metrics::Metrics,
middleware::rate_limit,
networking,
providers::Provider,
relay::RelayClient,
Expand All @@ -10,6 +11,7 @@ use {
build_info::BuildInfo,
moka::future::Cache,
std::{net::IpAddr, sync::Arc},
tokio::time::Duration,
wc::geoip::{block::middleware::GeoBlockLayer, MaxMindResolver},
};

Expand Down Expand Up @@ -55,6 +57,7 @@ pub struct AppState {
pub uptime: std::time::Instant,
pub http_client: reqwest::Client,
pub provider_cache: Cache<String, Provider>,
pub rate_limit: rate_limit::RateLimiter,
}

build_info::build_info!(fn build_info);
Expand Down Expand Up @@ -101,6 +104,10 @@ pub fn new_state(
uptime: std::time::Instant::now(),
http_client: reqwest::Client::new(),
provider_cache: Cache::new(100),
rate_limit: rate_limit::RateLimiter::new(
rate_limit::MAX_REQUESTS_PER_SEC,
Duration::from_secs(1),
),
})
}

Expand Down

0 comments on commit 70e3002

Please sign in to comment.