diff --git a/contrib/codegen/src/database.rs b/contrib/codegen/src/database.rs index be5a2c1359..67b9facf80 100644 --- a/contrib/codegen/src/database.rs +++ b/contrib/codegen/src/database.rs @@ -155,3 +155,132 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result Result { + let invocation = parse_invocation(attr, input)?; + + let conn_type = &invocation.connection_type; + let name = &invocation.db_name; + let guard_type = &invocation.type_name; + let vis = &invocation.visibility; + let pool_type = Ident::new(&format!("{}Pool", guard_type), guard_type.span()); + let fairing_name = format!("'{}' Database Pool", name); + let span = conn_type.span().into(); + + let databases = quote_spanned!(span => ::rocket_contrib::databases); + let Poolable = quote_spanned!(span => #databases::Poolable); + let r2d2 = quote_spanned!(span => #databases::r2d2); + let request = quote!(::rocket::request); + + let generated_types = quote_spanned! { span => + /// The request guard type. + #vis struct #guard_type(pub #r2d2::PooledConnection<<#conn_type as #Poolable>::Manager>); + + /// The pool type. + #vis struct #pool_type { + #vis pool: #r2d2::Pool<<#conn_type as #Poolable>::Manager>, + #vis callback: Box, + } + }; + + Ok(quote! { + #generated_types + + impl #guard_type { + /// Returns a fairing that initializes the associated database + /// connection pool. + pub fn fairing() -> impl ::rocket::fairing::Fairing { + use #databases::Poolable; + + ::rocket::fairing::AdHoc::on_attach(#fairing_name, |rocket| { + let pool = #databases::database_config(#name, rocket.config()) + .map(<#conn_type>::pool); + + match pool { + Ok(Ok(p)) => Ok(rocket.manage(#pool_type { pool: p, callback: Box::new(|_| ()) })), + Err(config_error) => { + ::rocket::logger::error( + &format!("Database configuration failure: '{}'", #name)); + ::rocket::logger::error_(&format!("{}", config_error)); + Err(rocket) + }, + Ok(Err(pool_error)) => { + ::rocket::logger::error( + &format!("Failed to initialize pool for '{}'", #name)); + ::rocket::logger::error_(&format!("{:?}", pool_error)); + Err(rocket) + }, + } + }) + } + + pub fn fairing_with_callback(callback: F) -> impl ::rocket::fairing::Fairing + where F: Fn(&mut #conn_type) + Send + Sync + 'static { + use #databases::Poolable; + + ::rocket::fairing::AdHoc::on_attach(#fairing_name, |rocket| { + let pool = #databases::database_config(#name, rocket.config()) + .map(<#conn_type>::pool); + + match pool { + Ok(Ok(p)) => Ok(rocket.manage(#pool_type { pool: p, callback: Box::new(callback) })), + Err(config_error) => { + ::rocket::logger::error( + &format!("Database configuration failure: '{}'", #name)); + ::rocket::logger::error_(&format!("{}", config_error)); + Err(rocket) + }, + Ok(Err(pool_error)) => { + ::rocket::logger::error( + &format!("Failed to initialize pool for '{}'", #name)); + ::rocket::logger::error_(&format!("{:?}", pool_error)); + Err(rocket) + }, + } + }) + } + + /// Retrieves a connection of type `Self` from the `rocket` + /// instance. Returns `Some` as long as `Self::fairing()` has been + /// attached and there is at least one connection in the pool. + pub fn get_one(rocket: &::rocket::Rocket) -> Option { + rocket.state::<#pool_type>() + .and_then(|pool| pool.pool.get().ok()) + .map(#guard_type) + } + } + + impl ::std::ops::Deref for #guard_type { + type Target = #conn_type; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl ::std::ops::DerefMut for #guard_type { + #[inline(always)] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + + impl<'a, 'r> #request::FromRequest<'a, 'r> for #guard_type { + type Error = (); + + fn from_request(request: &'a #request::Request<'r>) -> #request::Outcome { + use ::rocket::{Outcome, http::Status}; + let pool = ::rocket::try_outcome!(request.guard::<::rocket::State<#pool_type>>()); + + match pool.inner().pool.get() { + Ok(mut conn) => { + (pool.inner().callback)(&mut *conn); + Outcome::Success(#guard_type(conn)) + }, + Err(_) => Outcome::Failure((Status::ServiceUnavailable, ())), + } + } + } + }.into()) +} diff --git a/contrib/codegen/src/lib.rs b/contrib/codegen/src/lib.rs index 512267b985..703f31c122 100644 --- a/contrib/codegen/src/lib.rs +++ b/contrib/codegen/src/lib.rs @@ -49,3 +49,12 @@ pub fn database(attr: TokenStream, input: TokenStream) -> TokenStream { TokenStream::new() }) } + +#[cfg(feature = "database_attribute")] +#[proc_macro_attribute] +pub fn test_database(attr: TokenStream, input: TokenStream) -> TokenStream { + crate::database::test_database_attr(attr, input).unwrap_or_else(|diag| { + diag.emit(); + TokenStream::new() + }) +} diff --git a/examples/todo/src/main.rs b/examples/todo/src/main.rs index 8e17a21dbf..2193571953 100644 --- a/examples/todo/src/main.rs +++ b/examples/todo/src/main.rs @@ -15,7 +15,7 @@ use rocket::fairing::AdHoc; use rocket::request::{Form, FlashMessage}; use rocket::response::{Flash, Redirect}; use rocket_contrib::{templates::Template, serve::StaticFiles}; -use diesel::SqliteConnection; +use diesel::{SqliteConnection, Connection}; use crate::task::{Task, Todo}; @@ -24,6 +24,11 @@ use crate::task::{Task, Todo}; // tested without any outside setup of the database. embed_migrations!(); +#[cfg(test)] +#[test_database("sqlite_database")] +pub struct DbConn(SqliteConnection); + +#[cfg(not(test))] #[database("sqlite_database")] pub struct DbConn(SqliteConnection); @@ -90,13 +95,25 @@ fn run_db_migrations(rocket: Rocket) -> Result { } fn rocket() -> Rocket { - rocket::ignite() + #[cfg(test)] + let rocket = rocket::ignite() + .attach(DbConn::fairing_with_callback(|conn: &mut diesel::SqliteConnection| {conn.begin_test_transaction();})) + .attach(AdHoc::on_attach("Database Migrations", run_db_migrations)) + .mount("/", StaticFiles::from("static/")) + .mount("/", routes![index]) + .mount("/todo", routes![new, toggle, delete]) + .attach(Template::fairing()); + + #[cfg(not(test))] + let rocket = rocket::ignite() .attach(DbConn::fairing()) .attach(AdHoc::on_attach("Database Migrations", run_db_migrations)) .mount("/", StaticFiles::from("static/")) .mount("/", routes![index]) .mount("/todo", routes![new, toggle, delete]) - .attach(Template::fairing()) + .attach(Template::fairing()); + + rocket } fn main() {