Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions backend/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1 @@
pub mod routing;
pub mod api;

pub use routing::{router, AppState};
pub use api::fetch_problem;
pub mod utils;
31 changes: 23 additions & 8 deletions backend/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use backend::{api, router, AppState};
use backend::utils::api;
use backend::utils::ratelimiter::RateLimiter;
use backend::utils::routing::{AppState, router};

use hyper::service::{make_service_fn, service_fn};
use hyper::Server;
use std::net::SocketAddr;
use std::convert::{From, Infallible};
use std::{println};
use std::result::Result::{Err, Ok};
use hyper::{Server, Body, Response, StatusCode};
use std::sync::Arc;
use std::net::SocketAddr;
use std::convert::Infallible;

#[tokio::main]
async fn main() {
Expand All @@ -14,6 +15,7 @@ async fn main() {
println!("Succeeded to fetch problems");

let state = Arc::new(AppState { problems, problem_models });
let limiter = RateLimiter::new();

// Fly.io 環境変数 PORT を使用
let port: u16 = std::env::var("PORT")
Expand All @@ -24,12 +26,25 @@ async fn main() {
// 0.0.0.0 でバインド
let addr = SocketAddr::from(([0, 0, 0, 0], port));

let make_svc = make_service_fn(move |_conn| {
let make_svc = make_service_fn(move |conn: &hyper::server::conn::AddrStream| {
let remote_addr = conn.remote_addr().ip();
let state = state.clone();
let limiter = limiter.clone();

async move {
Ok::<_, Infallible>(service_fn(move |req| {
let state = state.clone();
router(req, state)
let limiter = limiter.clone();
let ip = remote_addr;

async move {
if !limiter.check(ip).await {
let mut res = Response::new(Body::from("Too Many Requests"));
*res.status_mut() = StatusCode::TOO_MANY_REQUESTS;
return Ok::<_, Infallible>(res)
}
router(req, state).await
}
}))
}
});
Expand Down
File renamed without changes.
3 changes: 3 additions & 0 deletions backend/src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pub mod api;
pub mod ratelimiter;
pub mod routing;
37 changes: 37 additions & 0 deletions backend/src/utils/ratelimiter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use std::collections::HashMap;
use std::time::{Duration, Instant};
use std::net::IpAddr;
use std::sync::Arc;
use tokio::sync::Mutex;

#[derive(Clone)]
pub struct RateLimiter {
pub last_request: Arc<Mutex<HashMap<IpAddr, Instant>>>,
pub ttl: Duration,
}

impl RateLimiter {
pub fn new() -> Self {
Self {
last_request: Arc::new(Mutex::new(HashMap::new())),
ttl: Duration::from_secs(600),
}
}

pub async fn check(&self, ip: IpAddr) -> bool {
let mut map = self.last_request.lock().await;
let now = Instant::now();

map.retain(|_, &mut last| now.duration_since(last) < self.ttl);

match map.get(&ip) {
Some(&last) if now.duration_since(last) < Duration::from_secs(1) => {
false
},
_ => {
map.insert(ip, now);
true
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rand;
use rand::seq::IteratorRandom;
use serde::Serialize;

use crate::api::{Problem, ProblemModel};
use crate::utils::api::{Problem, ProblemModel};

#[derive(Clone)]
pub struct AppState {
Expand Down
2 changes: 1 addition & 1 deletion backend/tests/api_test.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use backend::api::{Problem, ProblemModel};
use backend::utils::api::{Problem, ProblemModel};
use std::collections::HashMap;

/// fetch_problem のモック版
Expand Down
4 changes: 2 additions & 2 deletions backend/tests/routing_test.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#![allow(unused)]

use backend::routing::{router, AppState};
use backend::api::{Problem, ProblemModel};
use backend::utils::routing::{router, AppState};
use backend::utils::api::{Problem, ProblemModel};
use hyper::{Body, Method, Request, StatusCode};
use std::assert;
use std::collections::HashMap;
Expand Down