Skip to content

Commit

Permalink
database fairing callback
Browse files Browse the repository at this point in the history
This doesn't quite solve the issue of transactional testing :(
  • Loading branch information
ELD committed Jan 9, 2020
1 parent 2d8bdd4 commit f13a56b
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 3 deletions.
129 changes: 129 additions & 0 deletions contrib/codegen/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,132 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result<TokenStrea
}
}.into())
}

pub fn test_database_attr(attr: TokenStream, input: TokenStream) -> Result<TokenStream> {
let invocation = parse_invocation(attr, input)?;

let conn_type = &invocation.connection_type;
let name = &invocation.db_name;
let guard_type = &invocation.type_name;
let vis = &invocation.visibility;
let pool_type = Ident::new(&format!("{}Pool", guard_type), guard_type.span());
let fairing_name = format!("'{}' Database Pool", name);
let span = conn_type.span().into();

let databases = quote_spanned!(span => ::rocket_contrib::databases);
let Poolable = quote_spanned!(span => #databases::Poolable);
let r2d2 = quote_spanned!(span => #databases::r2d2);
let request = quote!(::rocket::request);

let generated_types = quote_spanned! { span =>
/// The request guard type.
#vis struct #guard_type(pub #r2d2::PooledConnection<<#conn_type as #Poolable>::Manager>);

/// The pool type.
#vis struct #pool_type {
#vis pool: #r2d2::Pool<<#conn_type as #Poolable>::Manager>,
#vis callback: Box<dyn Fn(&mut #conn_type) + Send + Sync + 'static>,
}
};

Ok(quote! {
#generated_types

impl #guard_type {
/// Returns a fairing that initializes the associated database
/// connection pool.
pub fn fairing() -> impl ::rocket::fairing::Fairing {
use #databases::Poolable;

::rocket::fairing::AdHoc::on_attach(#fairing_name, |rocket| {
let pool = #databases::database_config(#name, rocket.config())
.map(<#conn_type>::pool);

match pool {
Ok(Ok(p)) => Ok(rocket.manage(#pool_type { pool: p, callback: Box::new(|_| ()) })),
Err(config_error) => {
::rocket::logger::error(
&format!("Database configuration failure: '{}'", #name));
::rocket::logger::error_(&format!("{}", config_error));
Err(rocket)
},
Ok(Err(pool_error)) => {
::rocket::logger::error(
&format!("Failed to initialize pool for '{}'", #name));
::rocket::logger::error_(&format!("{:?}", pool_error));
Err(rocket)
},
}
})
}

pub fn fairing_with_callback<F>(callback: F) -> impl ::rocket::fairing::Fairing
where F: Fn(&mut #conn_type) + Send + Sync + 'static {
use #databases::Poolable;

::rocket::fairing::AdHoc::on_attach(#fairing_name, |rocket| {
let pool = #databases::database_config(#name, rocket.config())
.map(<#conn_type>::pool);

match pool {
Ok(Ok(p)) => Ok(rocket.manage(#pool_type { pool: p, callback: Box::new(callback) })),
Err(config_error) => {
::rocket::logger::error(
&format!("Database configuration failure: '{}'", #name));
::rocket::logger::error_(&format!("{}", config_error));
Err(rocket)
},
Ok(Err(pool_error)) => {
::rocket::logger::error(
&format!("Failed to initialize pool for '{}'", #name));
::rocket::logger::error_(&format!("{:?}", pool_error));
Err(rocket)
},
}
})
}

/// Retrieves a connection of type `Self` from the `rocket`
/// instance. Returns `Some` as long as `Self::fairing()` has been
/// attached and there is at least one connection in the pool.
pub fn get_one(rocket: &::rocket::Rocket) -> Option<Self> {
rocket.state::<#pool_type>()
.and_then(|pool| pool.pool.get().ok())
.map(#guard_type)
}
}

impl ::std::ops::Deref for #guard_type {
type Target = #conn_type;

#[inline(always)]
fn deref(&self) -> &Self::Target {
&self.0
}
}

impl ::std::ops::DerefMut for #guard_type {
#[inline(always)]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

impl<'a, 'r> #request::FromRequest<'a, 'r> for #guard_type {
type Error = ();

fn from_request(request: &'a #request::Request<'r>) -> #request::Outcome<Self, ()> {
use ::rocket::{Outcome, http::Status};
let pool = ::rocket::try_outcome!(request.guard::<::rocket::State<#pool_type>>());

match pool.inner().pool.get() {
Ok(mut conn) => {
(pool.inner().callback)(&mut *conn);
Outcome::Success(#guard_type(conn))
},
Err(_) => Outcome::Failure((Status::ServiceUnavailable, ())),
}
}
}
}.into())
}
9 changes: 9 additions & 0 deletions contrib/codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,12 @@ pub fn database(attr: TokenStream, input: TokenStream) -> TokenStream {
TokenStream::new()
})
}

#[cfg(feature = "database_attribute")]
#[proc_macro_attribute]
pub fn test_database(attr: TokenStream, input: TokenStream) -> TokenStream {
crate::database::test_database_attr(attr, input).unwrap_or_else(|diag| {
diag.emit();
TokenStream::new()
})
}
23 changes: 20 additions & 3 deletions examples/todo/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use rocket::fairing::AdHoc;
use rocket::request::{Form, FlashMessage};
use rocket::response::{Flash, Redirect};
use rocket_contrib::{templates::Template, serve::StaticFiles};
use diesel::SqliteConnection;
use diesel::{SqliteConnection, Connection};

use crate::task::{Task, Todo};

Expand All @@ -24,6 +24,11 @@ use crate::task::{Task, Todo};
// tested without any outside setup of the database.
embed_migrations!();

#[cfg(test)]
#[test_database("sqlite_database")]
pub struct DbConn(SqliteConnection);

#[cfg(not(test))]
#[database("sqlite_database")]
pub struct DbConn(SqliteConnection);

Expand Down Expand Up @@ -90,13 +95,25 @@ fn run_db_migrations(rocket: Rocket) -> Result<Rocket, Rocket> {
}

fn rocket() -> Rocket {
rocket::ignite()
#[cfg(test)]
let rocket = rocket::ignite()
.attach(DbConn::fairing_with_callback(|conn: &mut diesel::SqliteConnection| {conn.begin_test_transaction();}))
.attach(AdHoc::on_attach("Database Migrations", run_db_migrations))
.mount("/", StaticFiles::from("static/"))
.mount("/", routes![index])
.mount("/todo", routes![new, toggle, delete])
.attach(Template::fairing());

#[cfg(not(test))]
let rocket = rocket::ignite()
.attach(DbConn::fairing())
.attach(AdHoc::on_attach("Database Migrations", run_db_migrations))
.mount("/", StaticFiles::from("static/"))
.mount("/", routes![index])
.mount("/todo", routes![new, toggle, delete])
.attach(Template::fairing())
.attach(Template::fairing());

rocket
}

fn main() {
Expand Down

0 comments on commit f13a56b

Please sign in to comment.