Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: per-IP rate limiting #327

Merged
merged 2 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions integration/integration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,37 @@ describe('Echo Server', () => {
expect(status).toBe(200)
})
})
describe('Middlewares', () => {
const httpClient = axios.create({
validateStatus: (_status) => true,
})

// Simulate flood of requests and check for rate-limited responses
it('Rate limiting', async () => {
const url = `${BASE_URL}/health`
const requests_to_send = 100;
const promises = [];
for (let i = 0; i < requests_to_send; i++) {
promises.push(
httpClient.get(url)
);
}
const results = await Promise.allSettled(promises);

let ok_statuses_counter = 0;
let rate_limited_statuses_counter = 0;
results.forEach((result) => {
if (result.status === 'fulfilled' && result.value.status === 429) {
rate_limited_statuses_counter++;
}else if (result.status === 'fulfilled' && result.value.status === 200) {
ok_statuses_counter++;
}
});

console.log(`➜ Rate limited statuses: ${rate_limited_statuses_counter} out of ${requests_to_send} total requests.`);
// Check if there are any successful and rate limited statuses
expect(ok_statuses_counter).toBeGreaterThan(0);
expect(rate_limited_statuses_counter).toBeGreaterThan(0);
})
})
})
11 changes: 9 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use {
axum_client_ip::SecureClientIpSource,
config::Config,
hyper::http::Method,
middleware::rate_limit::rate_limit_middleware,
opentelemetry::{sdk::Resource, KeyValue},
sqlx::{
postgres::{PgConnectOptions, PgPoolOptions},
Expand Down Expand Up @@ -256,7 +257,10 @@ pub async fn bootstap(mut shutdown: broadcast::Receiver<()>, config: Config) ->
} else {
app
};

let app = app.route_layer(axum::middleware::from_fn_with_state(
state_arc.clone(),
rate_limit_middleware,
));
app.with_state(state_arc.clone())
};

Expand All @@ -281,7 +285,10 @@ pub async fn bootstap(mut shutdown: broadcast::Receiver<()>, config: Config) ->
} else {
app
};

let app = app.route_layer(axum::middleware::from_fn_with_state(
state_arc.clone(),
rate_limit_middleware,
));
let app = app.with_state(state_arc.clone());

let private_app = Router::new()
Expand Down
1 change: 1 addition & 0 deletions src/middleware/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod rate_limit;
pub mod validate_signature;
63 changes: 63 additions & 0 deletions src/middleware/rate_limit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use crate::{networking, state::AppState};
use axum::{
extract::Request,
extract::State,
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use moka::future::Cache;
use std::{net::IpAddr, sync::Arc};
use tokio::time::Duration;
use tracing::error;

pub const MAX_REQUESTS_PER_SEC: u32 = 10;

#[derive(Clone)]
pub struct RateLimiter {
cache: Arc<Cache<IpAddr, u32>>,
max_requests: u32,
}

impl RateLimiter {
pub fn new(max_requests: u32, window: Duration) -> Self {
Self {
cache: Arc::new(Cache::builder().time_to_live(window).build()),
max_requests,
}
}
}

/// Rate limit middleware that limits the number of requests per second from a single IP address and
/// uses in-memory caching to store the number of requests.
pub async fn rate_limit_middleware(
State(state): State<Arc<AppState>>,
req: Request,
next: Next,
) -> Response {
let headers = req.headers().clone();
let client_ip = match networking::get_forwarded_ip(headers.clone()) {
Some(ip) => ip,
None => {
error!(
"Failed to get forwarded IP from request in rate limiting middleware. Skipping the \
rate-limiting."
);
// We are skipping the drop to the connect info IP address here, because we are
// using the Load Balancer and if any issues with the X-Forwarded-IP header, we
// will rate-limit the LB IP address.
return next.run(req).await;
}
};

let rate_limiter = &state.rate_limit;
let mut rate_limit = rate_limiter.cache.get_with(client_ip, async { 0 }).await;

if rate_limit < rate_limiter.max_requests {
rate_limit += 1;
rate_limiter.cache.insert(client_ip, rate_limit).await;
next.run(req).await
} else {
(StatusCode::TOO_MANY_REQUESTS, "Too many requests").into_response()
}
}
10 changes: 9 additions & 1 deletion src/networking.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use {ipnet::IpNet, std::net::IpAddr};
use {axum::http::HeaderMap, ipnet::IpNet, std::net::IpAddr};

#[derive(thiserror::Error, Debug)]
pub enum NetworkInterfaceError {
Expand Down Expand Up @@ -65,3 +65,11 @@ fn is_public_ip_addr(addr: IpAddr) -> bool {

RESERVED_NETWORKS.iter().all(|range| !range.contains(&addr))
}

pub fn get_forwarded_ip(headers: HeaderMap) -> Option<IpAddr> {
headers
.get("X-Forwarded-For")
.and_then(|header| header.to_str().ok())
.and_then(|header| header.split(',').next())
.and_then(|client_ip| client_ip.trim().parse::<IpAddr>().ok())
}
7 changes: 7 additions & 0 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use {
crate::{
config::Config,
metrics::Metrics,
middleware::rate_limit,
networking,
providers::Provider,
relay::RelayClient,
Expand All @@ -10,6 +11,7 @@ use {
build_info::BuildInfo,
moka::future::Cache,
std::{net::IpAddr, sync::Arc},
tokio::time::Duration,
wc::geoip::{block::middleware::GeoBlockLayer, MaxMindResolver},
};

Expand Down Expand Up @@ -55,6 +57,7 @@ pub struct AppState {
pub uptime: std::time::Instant,
pub http_client: reqwest::Client,
pub provider_cache: Cache<String, Provider>,
pub rate_limit: rate_limit::RateLimiter,
}

build_info::build_info!(fn build_info);
Expand Down Expand Up @@ -101,6 +104,10 @@ pub fn new_state(
uptime: std::time::Instant::now(),
http_client: reqwest::Client::new(),
provider_cache: Cache::new(100),
rate_limit: rate_limit::RateLimiter::new(
rate_limit::MAX_REQUESTS_PER_SEC,
Duration::from_secs(1),
),
})
}

Expand Down
Loading