Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

axum6 with typesafe state #674

Merged
merged 1 commit into from
Feb 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 33 additions & 9 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions atuin-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ config = { version = "0.13", default-features = false, features = ["toml"] }
serde = { version = "1.0.145", features = ["derive"] }
serde_json = "1.0.86"
sodiumoxide = "0.2.6"
base64 = "0.20.0"
base64 = "0.21.0"
rand = "0.8.4"
tokio = { version = "1", features = ["full"] }
sqlx = { version = "0.6", features = [
Expand All @@ -29,7 +29,7 @@ sqlx = { version = "0.6", features = [
"postgres",
] }
async-trait = "0.1.58"
axum = "0.5"
axum = "0.6.4"
http = "0.2"
fs-err = "2.7"
chronoutil = "0.2.3"
Expand Down
21 changes: 13 additions & 8 deletions atuin-server/src/handlers/history.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
use std::collections::HashMap;

use axum::{
extract::{Path, Query},
Extension, Json,
extract::{Path, Query, State},
Json,
};
use http::StatusCode;
use tracing::{debug, error, instrument};

use super::{ErrorResponse, ErrorResponseStatus, RespExt};
use crate::{
calendar::{TimePeriod, TimePeriodInfo},
database::{Database, Postgres},
database::Database,
models::{NewHistory, User},
router::AppState,
};

use atuin_common::api::*;

#[instrument(skip_all, fields(user.id = user.id))]
pub async fn count(
user: User,
db: Extension<Postgres>,
state: State<AppState>,
) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> {
let db = &state.0.postgres;
match db.count_history_cached(&user).await {
// By default read out the cached value
Ok(count) => Ok(Json(CountResponse { count })),
Expand All @@ -39,8 +41,9 @@ pub async fn count(
pub async fn list(
req: Query<SyncHistoryRequest>,
user: User,
db: Extension<Postgres>,
state: State<AppState>,
) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> {
let db = &state.0.postgres;
let history = db
.list_history(
&user,
Expand Down Expand Up @@ -73,9 +76,9 @@ pub async fn list(

#[instrument(skip_all, fields(user.id = user.id))]
pub async fn add(
Json(req): Json<Vec<AddHistoryRequest>>,
user: User,
db: Extension<Postgres>,
state: State<AppState>,
Json(req): Json<Vec<AddHistoryRequest>>,
) -> Result<(), ErrorResponseStatus<'static>> {
debug!("request to add {} history items", req.len());

Expand All @@ -90,6 +93,7 @@ pub async fn add(
})
.collect();

let db = &state.0.postgres;
if let Err(e) = db.add_history(&history).await {
error!("failed to add history: {}", e);

Expand All @@ -105,13 +109,14 @@ pub async fn calendar(
Path(focus): Path<String>,
Query(params): Query<HashMap<String, u64>>,
user: User,
db: Extension<Postgres>,
state: State<AppState>,
) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> {
let focus = focus.as_str();

let year = params.get("year").unwrap_or(&0);
let month = params.get("month").unwrap_or(&1);

let db = &state.0.postgres;
let focus = match focus {
"year" => db
.calendar(&user, TimePeriod::YEAR, *year, *month)
Expand Down
19 changes: 13 additions & 6 deletions atuin-server/src/handlers/user.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
use std::borrow::Borrow;

use axum::{extract::Path, Extension, Json};
use axum::{
extract::{Path, State},
Extension, Json,
};
use http::StatusCode;
use sodiumoxide::crypto::pwhash::argon2id13;
use tracing::{debug, error, instrument};
use uuid::Uuid;

use super::{ErrorResponse, ErrorResponseStatus, RespExt};
use crate::{
database::{Database, Postgres},
database::Database,
models::{NewSession, NewUser},
router::AppState,
settings::Settings,
};

Expand All @@ -32,8 +36,9 @@ pub fn verify_str(secret: &str, verify: &str) -> bool {
#[instrument(skip_all, fields(user.username = username.as_str()))]
pub async fn get(
Path(username): Path<String>,
db: Extension<Postgres>,
state: State<AppState>,
) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> {
let db = &state.0.postgres;
let user = match db.get_user(username.as_ref()).await {
Ok(user) => user,
Err(sqlx::Error::RowNotFound) => {
Expand All @@ -54,9 +59,9 @@ pub async fn get(

#[instrument(skip_all)]
pub async fn register(
Json(register): Json<RegisterRequest>,
settings: Extension<Settings>,
db: Extension<Postgres>,
state: State<AppState>,
Json(register): Json<RegisterRequest>,
) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> {
if !settings.open_registration {
return Err(
Expand All @@ -73,6 +78,7 @@ pub async fn register(
password: hashed,
};

let db = &state.0.postgres;
let user_id = match db.add_user(&new_user).await {
Ok(id) => id,
Err(e) => {
Expand Down Expand Up @@ -102,9 +108,10 @@ pub async fn register(

#[instrument(skip_all, fields(user.username = login.username.as_str()))]
pub async fn login(
state: State<AppState>,
login: Json<LoginRequest>,
db: Extension<Postgres>,
) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> {
let db = &state.0.postgres;
let user = match db.get_user(login.username.borrow()).await {
Ok(u) => u,
Err(sqlx::Error::RowNotFound) => {
Expand Down
43 changes: 21 additions & 22 deletions atuin-server/src/router.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use async_trait::async_trait;
use axum::{
extract::{FromRequest, RequestParts},
handler::Handler,
extract::FromRequestParts,
response::IntoResponse,
routing::{get, post},
Extension, Router,
Router,
};
use eyre::Result;
use http::request::Parts;
use tower::ServiceBuilder;
use tower_http::trace::TraceLayer;

Expand All @@ -17,20 +17,15 @@ use super::{
use crate::{models::User, settings::Settings};

#[async_trait]
impl<B> FromRequest<B> for User
where
B: Send,
{
impl FromRequestParts<AppState> for User {
type Rejection = http::StatusCode;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let postgres = req
.extensions()
.get::<Postgres>()
.ok_or(http::StatusCode::INTERNAL_SERVER_ERROR)?;

async fn from_request_parts(
req: &mut Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
let auth_header = req
.headers()
.headers
.get(http::header::AUTHORIZATION)
.ok_or(http::StatusCode::FORBIDDEN)?;
let auth_header = auth_header
Expand All @@ -44,7 +39,8 @@ where
return Err(http::StatusCode::FORBIDDEN);
}

let user = postgres
let user = state
.postgres
.get_session_user(token)
.await
.map_err(|_| http::StatusCode::FORBIDDEN)?;
Expand All @@ -56,6 +52,13 @@ where
async fn teapot() -> impl IntoResponse {
(http::StatusCode::IM_A_TEAPOT, "☕")
}

#[derive(Clone)]
pub struct AppState {
pub postgres: Postgres,
pub settings: Settings,
}

pub fn router(postgres: Postgres, settings: Settings) -> Router {
let routes = Router::new()
.route("/", get(handlers::index))
Expand All @@ -73,11 +76,7 @@ pub fn router(postgres: Postgres, settings: Settings) -> Router {
} else {
Router::new().nest(path, routes)
}
.fallback(teapot.into_service())
.layer(
ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(Extension(postgres))
.layer(Extension(settings)),
)
.fallback(teapot)
.with_state(AppState { postgres, settings })
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()))
}