Skip to content

Commit

Permalink
Reuse db connection in request
Browse files Browse the repository at this point in the history
  • Loading branch information
sdankel committed May 2, 2024
1 parent dc7d19d commit 7597850
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 48 deletions.
21 changes: 8 additions & 13 deletions src/db/api_token.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::error::DatabaseError;
use super::{models, schema};
use super::{models, schema, DbConn};
use super::{string_to_uuid, Database};
use diesel::prelude::*;

Expand All @@ -13,15 +13,13 @@ use uuid::Uuid;
const TOKEN_PREFIX: &str = "pub_";
const TOKEN_LENGTH: usize = 32;

impl Database {
impl DbConn {
/// Creates an API token for the user and returns the token.
pub fn new_token(
&self,
&mut self,
user_id: Uuid,
friendly_name: String,
) -> Result<(models::Token, String), DatabaseError> {
let connection = &mut self.connection();

let plain_token = generate_token();
let token = Sha256::digest(plain_token.as_bytes()).as_slice().to_vec();

Expand All @@ -36,36 +34,33 @@ impl Database {
let saved_token = diesel::insert_into(schema::api_tokens::table)
.values(&new_token)
.returning(models::Token::as_returning())
.get_result(connection)
.get_result(self.inner())
.map_err(|_| DatabaseError::InsertTokenFailed(user_id.to_string()))?;

Ok((saved_token, plain_token))
}

/// Deletes an API token for the user.
pub fn delete_token(&self, user_id: Uuid, token_id: String) -> Result<(), DatabaseError> {
let connection = &mut self.connection();

pub fn delete_token(&mut self, user_id: Uuid, token_id: String) -> Result<(), DatabaseError> {
let token_uuid = string_to_uuid(token_id.clone())?;

diesel::delete(
schema::api_tokens::table
.filter(schema::api_tokens::id.eq(token_uuid))
.filter(schema::api_tokens::user_id.eq(user_id)),
)
.execute(connection)
.execute(self.inner())
.map_err(|_| DatabaseError::NotFound(token_id))?;

Ok(())
}

/// Fetch all tokens for the given user ID.
pub fn get_tokens_for_user(&self, user_id: Uuid) -> Result<Vec<models::Token>, DatabaseError> {
let connection = &mut self.connection();
pub fn get_tokens_for_user(&mut self, user_id: Uuid) -> Result<Vec<models::Token>, DatabaseError> {
schema::api_tokens::table
.filter(schema::api_tokens::user_id.eq(user_id))
.select(models::Token::as_returning())
.load(connection)
.load(self.inner())
.map_err(|_| DatabaseError::NotFound(user_id.to_string()))
}
}
Expand Down
11 changes: 9 additions & 2 deletions src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ impl Default for Database {
}
}

pub struct DbConn(DbConnection);
impl DbConn {
pub fn inner(&mut self) -> &mut PgConnection {
&mut self.0
}
}

impl Database {
pub fn new() -> Self {
// Create a connection pool
Expand All @@ -44,8 +51,8 @@ impl Database {
}

/// Get a connection from the pool.
pub fn connection(&self) -> DbConnection {
self.pool.get().expect("db connection")
pub fn conn(&self) -> DbConn {
DbConn(self.pool.get().expect("db connection"))
}
}

Expand Down
32 changes: 13 additions & 19 deletions src/db/user_session.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
use super::error::DatabaseError;
use super::{api, models, schema};
use super::{api, models, schema, DbConn};
use super::{string_to_uuid, Database};
use diesel::prelude::*;
use diesel::upsert::excluded;
use std::time::{Duration, SystemTime};
use uuid::Uuid;

impl Database {
impl DbConn {
/// Insert a user session into the database and return the session ID.
/// If the user doesn't exist, insert the user as well.
/// If the user does exist, update the user's full name and avatar URL if they have changed.
pub fn insert_user_session(
&self,
&mut self,
user: &api::auth::User,
expires_in: u32,
) -> Result<models::Session, DatabaseError> {
let connection = &mut self.connection();

// Insert or update a user
let new_user = models::NewUser {
full_name: user.full_name.clone(),
Expand All @@ -36,7 +34,7 @@ impl Database {
schema::users::full_name.eq(excluded(schema::users::full_name)),
schema::users::avatar_url.eq(excluded(schema::users::avatar_url)),
))
.get_result(connection)
.get_result(self.inner())
.map_err(|_| DatabaseError::InsertUserFailed(user.github_login.clone()))?;

let new_session = models::NewSession {
Expand All @@ -48,49 +46,45 @@ impl Database {
let saved_session = diesel::insert_into(schema::sessions::table)
.values(&new_session)
.returning(models::Session::as_returning())
.get_result(connection)
.get_result(self.inner())
.map_err(|_| DatabaseError::InsertSessionFailed(user.github_login.clone()))?;

Ok(saved_session)
}

/// Fetch a user given the user ID.
pub fn get_user(&self, user_id: Uuid) -> Result<models::User, DatabaseError> {
let connection = &mut self.connection();
pub fn get_user(&mut self, user_id: Uuid) -> Result<models::User, DatabaseError> {
schema::users::table
.filter(schema::users::id.eq(user_id))
.select(models::User::as_returning())
.first::<models::User>(connection)
.first::<models::User>(self.inner())
.map_err(|_| DatabaseError::NotFound(user_id.to_string()))
}

/// Fetch a user given the user ID.
pub fn get_session(&self, session_id: Uuid) -> Result<models::Session, DatabaseError> {
let connection = &mut self.connection();
pub fn get_session(&mut self, session_id: Uuid) -> Result<models::Session, DatabaseError> {
schema::sessions::table
.filter(schema::sessions::id.eq(session_id))
.select(models::Session::as_returning())
.first::<models::Session>(connection)
.first::<models::Session>(self.inner())
.map_err(|_| DatabaseError::NotFound(session_id.to_string()))
}

/// Fetch a user from the database for a given session ID.
pub fn get_user_for_session(&self, session_id: String) -> Result<models::User, DatabaseError> {
pub fn get_user_for_session(&mut self, session_id: String) -> Result<models::User, DatabaseError> {
let session_uuid = string_to_uuid(session_id.clone())?;
let connection = &mut self.connection();
schema::sessions::table
.inner_join(schema::users::table)
.filter(schema::sessions::id.eq(session_uuid))
.select(models::User::as_returning())
.first::<models::User>(connection)
.first::<models::User>(self.inner())
.map_err(|_| DatabaseError::NotFound(session_id))
}

/// Delete a session given its ID.
pub fn delete_session(&self, session_id: Uuid) -> Result<(), DatabaseError> {
let connection = &mut self.connection();
pub fn delete_session(&mut self, session_id: Uuid) -> Result<(), DatabaseError> {
diesel::delete(schema::sessions::table.filter(schema::sessions::id.eq(session_id)))
.execute(connection)
.execute(self.inner())
.map_err(|_| DatabaseError::NotFound(session_id.to_string()))?;
Ok(())
}
Expand Down
10 changes: 5 additions & 5 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async fn login(
request: Json<LoginRequest>,
) -> ApiResult<LoginResponse> {
let (user, expires_in) = handle_login(request.code.clone()).await?;
let session = db.insert_user_session(&user, expires_in)?;
let session = db.conn().insert_user_session(&user, expires_in)?;
let session_id = session.id.to_string();
cookies.add(Cookie::build(SESSION_COOKIE_NAME, session_id.clone()).finish());
Ok(Json(LoginResponse { user, session_id }))
Expand All @@ -41,7 +41,7 @@ async fn login(
#[post("/logout")]
async fn logout(db: &State<Database>, auth: SessionAuth) -> ApiResult<EmptyResponse> {
let session_id = auth.session_id;
let _ = db.delete_session(session_id)?;
let _ = db.conn().delete_session(session_id)?;
Ok(Json(EmptyResponse))
}

Expand All @@ -60,7 +60,7 @@ fn new_token(
request: Json<CreateTokenRequest>,
) -> ApiResult<CreateTokenResponse> {
let user = auth.user;
let (token, plain_token) = db.new_token(user.id, request.name.clone())?;
let (token, plain_token) = db.conn().new_token(user.id, request.name.clone())?;
Ok(Json(CreateTokenResponse {
token: Token {
// The only time we return the plain token is when it's created.
Expand All @@ -73,14 +73,14 @@ fn new_token(
#[delete("/token/<id>")]
fn delete_token(db: &State<Database>, auth: SessionAuth, id: String) -> ApiResult<EmptyResponse> {
let user_id = auth.user.id;
let _ = db.delete_token(user_id, id.clone())?;
let _ = db.conn().delete_token(user_id, id.clone())?;
Ok(Json(EmptyResponse))
}

#[get("/tokens")]
fn tokens(db: &State<Database>, auth: SessionAuth) -> ApiResult<TokensResponse> {
let user_id = auth.user.id;
let tokens = db.get_tokens_for_user(user_id)?;
let tokens = db.conn().get_tokens_for_user(user_id)?;
Ok(Json(TokensResponse {
tokens: tokens.into_iter().map(|t| t.into()).collect(),
}))
Expand Down
10 changes: 8 additions & 2 deletions src/middleware/session_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,23 @@ pub struct SessionAuth {
pub enum SessionAuthError {
Missing,
Invalid,
DatabaseConnection,
}

#[rocket::async_trait]
impl<'r> FromRequest<'r> for SessionAuth {
type Error = SessionAuthError;

async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
// TODO: use fairing?
// TODO: use fairing for db connection?
// let db = try_outcome!(request.guard::<Database>().await);

let db = request.rocket().state::<Database>().unwrap();
let mut db = match request.rocket().state::<Database>() {
Some(db) => {
db.conn()
},
None => return Outcome::Failure((Status::InternalServerError, SessionAuthError::DatabaseConnection)),
};
if let Some(Some(session_id)) = request
.cookies()
.get(SESSION_COOKIE_NAME)
Expand Down
16 changes: 9 additions & 7 deletions tests/db_integration.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use diesel::RunQueryDsl as _;
use forc_pub::api;
use forc_pub::db::Database;
use forc_pub::db::{Database, DbConn};

/// Note: Integration tests for the database module assume that the database is running and that the DATABASE_URL environment variable is set.
/// This should be done by running `./scripts/start_local_db.sh` before running the tests.
Expand All @@ -12,13 +12,15 @@ const TEST_URL_1: &str = "url1.url";
const TEST_URL_2: &str = "url2.url";
const TEST_LOGIN_2: &str = "foobar";

fn clear_tables(db: &Database) {
let connection = &mut db.connection();
fn clear_tables(db: &mut DbConn) {
diesel::delete(forc_pub::schema::api_tokens::table)
.execute(db.inner())
.expect("clear api_tokens table");
diesel::delete(forc_pub::schema::sessions::table)
.execute(connection)
.execute(db.inner())
.expect("clear sessions table");
diesel::delete(forc_pub::schema::users::table)
.execute(connection)
.execute(db.inner())
.expect("clear users table");
}

Expand All @@ -42,7 +44,7 @@ fn mock_user_2() -> api::auth::User {

#[test]
fn test_multiple_user_sessions() {
let db = Database::default();
let db = &mut Database::default().conn();

let user1 = mock_user_1();
let user2 = mock_user_2();
Expand Down Expand Up @@ -75,5 +77,5 @@ fn test_multiple_user_sessions() {
.expect("result is ok");
assert_eq!(result.github_login, TEST_LOGIN_2);

clear_tables(&db);
clear_tables(db);
}

0 comments on commit 7597850

Please sign in to comment.