diff --git a/backend/src/api/mod.rs b/backend/src/api/mod.rs index 35c8173..601d301 100644 --- a/backend/src/api/mod.rs +++ b/backend/src/api/mod.rs @@ -1,39 +1,59 @@ use serde::Deserialize; -use std::{collections::HashMap, error::Error, println, string::String, vec::Vec}; +use core::prelude::v1::derive; +use std::{collections::HashMap, error::Error, fs, iter::{IntoIterator, Iterator}, option::Option::{self, None, Some}, path::PathBuf, string::String, vec::Vec}; -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Deserialize)] pub struct Problem { pub id: String, pub contest_id: String, pub name: String, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Deserialize)] +pub struct ProblemModelRaw { + pub difficulty: Option, +} + +#[derive(Debug, Clone)] pub struct ProblemModel { - pub difficulty: Option, + pub difficulty: Option, +} + +fn adjust_difficulty(difficulty: Option) -> Option { + match difficulty { + Some(d) if d >= 400 => Some(d as f64), + Some(d) => Some(400.0 / f64::exp(1.0 - d as f64 / 400.0)), + None => None, + } } pub async fn fetch_problem() -> Result<(Vec, HashMap), Box> { - // let problems_url = "https://kenkoooo.com/atcoder/resources/problems.json"; - // let problem_models_url = "https://kenkoooo.com/atcoder/resources/problem-models.json"; - let problems_url = "https://github.com/Twil3akine/atcoder-random-picker/blob/master/.gitconfig"; - let problem_models_url = "https://github.com/Twil3akine/atcoder-random-picker/blob/master/.gitconfig"; + // カレントディレクトリ基準で src/data を指す + let mut problems_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + problems_path.push("src/data/problems.json"); - let client = reqwest::Client::builder() - .user_agent("atcoder-random-picker/0.1 (twil3; contact: twil3akine@gmail.com)") - .build()?; + let mut problem_models_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + problem_models_path.push("src/data/problem-models.json"); - // 問題一覧 - let problems_text = client.get(problems_url).send().await?.text().await?; + // ファイル読み込み + let problems_text = fs::read_to_string(problems_path)?; let problems: Vec = serde_json::from_str(&problems_text)?; - // 問題モデル - let problem_models_text = client.get(problem_models_url).send().await?.text().await?; - let problem_models: HashMap = serde_json::from_str(&problem_models_text)?; - - for (id, model) in &problem_models { - println!("{}: {:?}", id, model.difficulty); - } + let problem_models_text = fs::read_to_string(problem_models_path)?; + let raw_models: HashMap = serde_json::from_str(&problem_models_text)?; + + // 補正式を適用させる + let problem_models: HashMap = raw_models + .into_iter() + .map(|(id, raw)| { + ( + id, + ProblemModel { + difficulty: adjust_difficulty(raw.difficulty), + }, + ) + }) + .collect(); Ok((problems, problem_models)) -} \ No newline at end of file +} diff --git a/backend/src/lib.rs b/backend/src/lib.rs index 0f0334b..6232d93 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -1,4 +1,5 @@ pub mod routing; pub mod api; -pub use routing::router; \ No newline at end of file +pub use routing::{router, AppState}; +pub use api::fetch_problem; \ No newline at end of file diff --git a/backend/src/main.rs b/backend/src/main.rs index 2f76d00..72bf1c3 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -1,32 +1,40 @@ +use backend::{api, router, AppState}; use hyper::service::{make_service_fn, service_fn}; -use hyper::{Server}; +use hyper::Server; use std::net::SocketAddr; -use std::convert::Infallible; +use std::convert::{From, Infallible}; use std::{println}; -use std::result::Result::Err; - -use backend::{api, router}; +use std::result::Result::{Err, Ok}; +use std::sync::Arc; #[tokio::main] async fn main() { match api::fetch_problem().await { - Ok(_ps) => println!("Successed to fetch problems"), - Err(e) => eprintln!("Failed to failed problems: {}", e), - } - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + Ok((problems, problem_models)) => { + println!("Succeeded to fetch problems"); - let make_svc = make_service_fn(|_conn| async { - Ok::<_, Infallible>(service_fn(router)) - }); + let state = Arc::new(AppState { problems, problem_models }); + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + let make_svc = make_service_fn(move |_conn| { + let state = state.clone(); + async move { + Ok::<_, Infallible>(service_fn(move |req| { + let state = state.clone(); + router(req, state) + })) + } + }); - let server = Server::bind(&addr).serve(make_svc); + let server = Server::bind(&addr).serve(make_svc); - println!("Running server on : http://{}.", addr); + println!("Running server on http://{}.", addr); - if let Err(e) = server.await { - eprintln!("error on {}.", e); - } else { - println!("error was shut down."); + if let Err(e) = server.await { + eprintln!("error on {}.", e); + } else { + println!("server shut down."); + } + } + Err(e) => eprintln!("Failed to fetch problems: {}", e), } } \ No newline at end of file diff --git a/backend/src/routing/mod.rs b/backend/src/routing/mod.rs index 505452f..4ff14aa 100644 --- a/backend/src/routing/mod.rs +++ b/backend/src/routing/mod.rs @@ -1,25 +1,35 @@ use hyper::{Body, Request, Response, StatusCode, header}; -use std::convert::{From, Infallible}; -use std::collections::HashMap; -use std::format; use std::option::Option::None; use std::result::Result::Ok; -use url::form_urlencoded; +use std::{collections::HashMap}; +use std::convert::{From, Infallible}; +use std::sync::Arc; use chrono::{DateTime, Local}; -use rand::Rng; +use rand; +use rand::seq::IteratorRandom; +use serde::Serialize; -async fn get_parameter(req: Request) -> HashMap { - let query = req.uri().query().unwrap_or(""); - let params: HashMap = form_urlencoded::parse(query.as_bytes()) - .into_owned() - .collect(); +use crate::api::{Problem, ProblemModel}; - params +#[derive(Clone)] +pub struct AppState { + pub problems: Vec, + pub problem_models: HashMap, } -fn get_random_number(under: u32, over: u32) -> u32 { - let mut rng = rand::thread_rng(); - rng.gen_range(under..=over) +#[derive(Serialize)] +struct ProblemResponse { + id: String, + contest_id: String, + name: String, + difficulty: Option, +} + +async fn get_parameter(req: &Request) -> HashMap { + let query = req.uri().query().unwrap_or(""); + url::form_urlencoded::parse(query.as_bytes()) + .into_owned() + .collect() } fn log(now: DateTime, method: &str, path: &str, status: StatusCode) { @@ -42,9 +52,8 @@ fn with_cors_headers(mut res: Response) -> Response { res } -pub async fn router(req: Request) -> Result, Infallible> { +pub async fn router(req: Request, state: Arc) -> Result, Infallible> { let now= Local::now(); - let path = req.uri().path().to_string(); let method = req.method().to_string(); @@ -55,17 +64,10 @@ pub async fn router(req: Request) -> Result, Infallible> { } (&hyper::Method::GET, "/") => { - let params: HashMap = get_parameter(req).await; + let params: HashMap = get_parameter(&req).await; - let under: u32 = match params.get("under").and_then(|s| s.parse().ok()) { - Some(v) => v, - None => 0, - }; - - let over: u32 = match params.get("over").and_then(|s| s.parse().ok()) { - Some(v) => v, - None => 3854, - }; + let under: f64 = params.get("under").and_then(|s| s.parse().ok()).unwrap_or(0.0); + let over: f64 = params.get("over").and_then(|s| s.parse().ok()).unwrap_or(3854.0); if under > over { let mut bad_request = Response::new(Body::from("'under' cannot bt greater than 'over'.")); @@ -73,9 +75,35 @@ pub async fn router(req: Request) -> Result, Infallible> { return Ok(bad_request); } - let random_number = get_random_number(under, over); - let response_body = format!("{}", random_number); - Ok(with_cors_headers(Response::new(Body::from(response_body)))) + let mut rng = rand::thread_rng(); + let selected = state.problems.iter().filter_map(|p| { + state.problem_models.get(&p.id).and_then(|m| { + m.difficulty.and_then(|diff| { + if under <= diff && diff <= over { + Some(ProblemResponse { + id: p.id.clone(), + contest_id: p.contest_id.clone(), + name: p.name.clone(), + difficulty: Some(diff), + }) + } else { + None + } + }) + }) + }).choose(&mut rng); + + match selected { + Some(problem) => { + let body = serde_json::to_string(&problem).unwrap(); + Ok(with_cors_headers(Response::new(Body::from(body)))) + } + None => { + let mut not_found = Response::new(Body::from("No problem found in given range.")); + *not_found.status_mut() = StatusCode::NOT_FOUND; + Ok(with_cors_headers(not_found)) + } + } } _ => { diff --git a/backend/tests/api_test.rs b/backend/tests/api_test.rs new file mode 100644 index 0000000..a2bd3ee --- /dev/null +++ b/backend/tests/api_test.rs @@ -0,0 +1,22 @@ +#![allow(unused)] + +use backend::api::{fetch_problem, Problem, ProblemModel}; +use std::collections::HashMap; + +#[tokio::test] +async fn test_fetch_problem_returns_data() { + let result = fetch_problem().await; + + assert!(result.is_ok()); + let (problems, problem_models): (Vec, HashMap) = result.unwrap(); + + // 空でないことを確認 + assert!(!problems.is_empty()); + assert!(!problem_models.is_empty()); + + // Problem の中身確認 + let first_problem = &problems[0]; + assert!(!first_problem.id.is_empty()); + assert!(!first_problem.contest_id.is_empty()); + assert!(!first_problem.name.is_empty()); +} diff --git a/backend/tests/routing_test.rs b/backend/tests/routing_test.rs index 3a41713..f61bebd 100644 --- a/backend/tests/routing_test.rs +++ b/backend/tests/routing_test.rs @@ -1,76 +1,86 @@ -#![allow(unused_comparisons)] +#![allow(unused)] -use backend::router; +use backend::routing::{router, AppState}; +use backend::api::{Problem, ProblemModel}; use hyper::{Body, Method, Request, StatusCode}; +use std::assert; +use std::collections::HashMap; +use std::sync::Arc; + +fn build_test_state() -> Arc { + let mut problem_models = HashMap::new(); + problem_models.insert( + "abc001_a".to_string(), + ProblemModel { difficulty: Some(1000.0) }, + ); + + let problems = vec![ + Problem { + id: "abc001_a".to_string(), + contest_id: "abc001".to_string(), + name: "A - Test Problem".to_string(), + } + ]; + + Arc::new(AppState { + problems, + problem_models, + }) +} async fn build_and_send(method: Method, path: &str) -> (StatusCode, String) { - let req = Request::builder() - .method(method) - .uri(path) - .body(Body::empty()) - .unwrap(); + let req = Request::builder() + .method(method) + .uri(path) + .body(Body::empty()) + .unwrap(); - let res = router(req).await.unwrap(); + let state = build_test_state(); + let res = router(req, state).await.unwrap(); - let status = res.status(); - let body_bytes = hyper::body::to_bytes(res.into_body()).await.unwrap(); - let body_string = String::from_utf8(body_bytes.to_vec()).unwrap(); + let status = res.status(); + let body_bytes = hyper::body::to_bytes(res.into_body()).await.unwrap(); + let body_string = String::from_utf8(body_bytes.to_vec()).unwrap(); - (status, body_string) + (status, body_string) } - - - - #[tokio::test] async fn test_not_found_path() { - let (status, body) = build_and_send(Method::GET, "/test").await; - - assert_eq!(status, StatusCode::NOT_FOUND); - assert_eq!(body, "404 Not Found"); + let (status, body) = build_and_send(Method::GET, "/test").await; + assert_eq!(status, StatusCode::NOT_FOUND); + assert_eq!(body, "404 Not Found"); } #[tokio::test] -async fn test_calc_with_both_params() { - let (status, body) = build_and_send(Method::GET, "/?under=10&over=20").await; - - assert_eq!(status, StatusCode::OK); - let random_number: u32 = body.parse().unwrap(); - assert!(10 <= random_number && random_number <= 20); -} - -#[tokio::test] -async fn test_calc_without_under() { - let (status, body) = build_and_send(Method::GET, "/?over=5").await; - - assert_eq!(status, StatusCode::OK); - let random_number: u32 = body.parse().unwrap(); - assert!(0 <= random_number && random_number <= 5); +async fn test_not_found_problem() { + let (status, body) = build_and_send(Method::GET, "/?under=0&over=500").await; + assert_eq!(status, StatusCode::NOT_FOUND); + assert_eq!(body, "No problem found in given range."); } #[tokio::test] -async fn test_calc_without_over() { - let (status, body) = build_and_send(Method::GET, "/?under=100").await; - - assert_eq!(status, StatusCode::OK); - let random_number: u32 = body.parse().unwrap(); - assert!(100 <= random_number && random_number <= 3854); +async fn test_under_greater_than_over() { + let (status, body) = build_and_send(Method::GET, "/?under=1500&over=500").await; + assert_eq!(status, StatusCode::BAD_REQUEST); + assert_eq!(body, "'under' cannot bt greater than 'over'."); } #[tokio::test] -async fn test_calc_without_any_params() { - let (status, body) = build_and_send(Method::GET, "/").await; - +async fn test_random_range() { + let (status, body) = build_and_send(Method::GET, "/?under=500&over=1500").await; assert_eq!(status, StatusCode::OK); - let random_number: u32 = body.parse().unwrap(); - assert!(0 <= random_number && random_number <= 3854); -} - -#[tokio::test] -async fn test_calc_under_greater_than_over() { - let (status, body) = build_and_send(Method::GET, "/?under=50&over=10").await; - - assert_eq!(status, StatusCode::BAD_REQUEST); - assert_eq!(body, "'under' cannot bt greater than 'over'."); + + #[derive(serde::Deserialize)] + struct ProblemResponse { + id: String, + contest_id: String, + name: String, + difficulty: f64, + } + + let problem: ProblemResponse = serde_json::from_str(&body).unwrap(); + let diff = problem.difficulty; + + assert!(500.0 <= diff && diff <= 1500.0); } \ No newline at end of file