diff --git a/Cargo.toml b/Cargo.toml index 1442da9..2afeb77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,15 +13,15 @@ keywords = ["testing", "database", "postgres"] maintenance = { status = "experimental" } [dependencies] +ctor = "0.2" glob = "0.3" nix = "0.26" +reflink-copy = "0.1" tempfile = "3" thiserror = "1.0" tokio = { version = "1.8", features = ["parking_lot", "rt", "sync", "io-util", "process", "macros", "fs"], default-features = false, optional = true } tracing = "0.1" which = "4.0" -once_cell = "1" -reflink-copy = "0.1" [dev-dependencies] test-log = { version = "0.2", default-features = false, features = ["trace"] } diff --git a/src/asynchronous.rs b/src/asynchronous.rs index 0b00b51..5adc9ef 100644 --- a/src/asynchronous.rs +++ b/src/asynchronous.rs @@ -1,4 +1,4 @@ -use nix::unistd::Uid; +use nix::unistd::{Uid, User}; use std::path::Path; use std::process::Stdio; use std::sync::Arc; @@ -153,7 +153,13 @@ pub(crate) async fn chown_to_non_root(dir: &Path) -> TmpPostgrustResult<()> { return Ok(()); } - let (uid, gid) = &*POSTGRES_UID_GID; + let (uid, gid) = POSTGRES_UID_GID.get_or_init(|| { + User::from_name("postgres") + .ok() + .flatten() + .map(|u| (u.uid, u.gid)) + .expect("no user `postgres` found is system") + }); let mut cmd = Command::new("chown"); cmd.arg("-R").arg(format!("{uid}:{gid}")).arg(dir); exec_process(&mut cmd, TmpPostgrustError::UpdatingPermissionsFailed).await @@ -196,7 +202,13 @@ fn cmd_as_non_root(command: &mut Command) { let current_uid = Uid::effective(); if current_uid.is_root() { // PostgreSQL cannot be run as root, so change to default user - let (user_id, group_id) = &*POSTGRES_UID_GID; - command.uid(user_id.as_raw()).gid(group_id.as_raw()); + let (uid, gid) = POSTGRES_UID_GID.get_or_init(|| { + User::from_name("postgres") + .ok() + .flatten() + .map(|u| (u.uid, u.gid)) + .expect("no user `postgres` found is system") + }); + command.uid(uid.as_raw()).gid(gid.as_raw()); } } diff --git a/src/lib.rs b/src/lib.rs index a4b6c79..f4ab3ff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,23 +23,38 @@ use std::fs::{metadata, set_permissions}; use std::io::{BufRead, BufReader}; use std::path::Path; use std::sync::atomic::AtomicU32; -use std::sync::Arc; +use std::sync::{Arc, Mutex, OnceLock}; use std::{fs::File, io::Write}; -use nix::unistd::{Gid, Uid, User}; -use once_cell::sync::Lazy; +use ctor::dtor; +use nix::unistd::{Gid, Uid}; use tempfile::{Builder, TempDir}; use tracing::{debug, info, instrument}; use crate::errors::{TmpPostgrustError, TmpPostgrustResult}; -pub(crate) static POSTGRES_UID_GID: Lazy<(Uid, Gid)> = Lazy::new(|| { - User::from_name("postgres") - .ok() - .flatten() - .map(|u| (u.uid, u.gid)) - .expect("no user `postgres` found is system") -}); +pub(crate) static POSTGRES_UID_GID: OnceLock<(Uid, Gid)> = OnceLock::new(); + +/// As the static variables declared by this crate contain values that +/// need to be dropped at program exit to clean up resources, we use a +/// `#[dtor]` hack to drop the variables if they have been initialized. +#[dtor] +fn cleanup_static() { + #[cfg(feature = "tokio-process")] + if let Some(factory_mutex) = TOKIO_POSTGRES_FACTORY.get() { + let mut guard = factory_mutex.blocking_lock(); + drop(guard.take()); + } + + if let Some(factory_mutex) = DEFAULT_POSTGRES_FACTORY.get() { + let mut guard = factory_mutex + .lock() + .expect("Failed to lock default factory mutex."); + drop(guard.take()); + } +} + +static DEFAULT_POSTGRES_FACTORY: OnceLock>> = OnceLock::new(); /// Create a new default instance, initializing the `DEFAULT_POSTGRES_FACTORY` if it /// does not already exist. @@ -48,15 +63,25 @@ pub(crate) static POSTGRES_UID_GID: Lazy<(Uid, Gid)> = Lazy::new(|| { /// /// Will return `Err` if postgres is not installed on system pub fn new_default_process() -> TmpPostgrustResult { - static DEFAULT_POSTGRES_FACTORY: Lazy = - Lazy::new(|| TmpPostgrustFactory::try_new().unwrap()); - DEFAULT_POSTGRES_FACTORY.new_instance() + let factory_mutex = DEFAULT_POSTGRES_FACTORY.get_or_init(|| { + Mutex::new(Some( + TmpPostgrustFactory::try_new().expect("Failed to initialize default postgres factory."), + )) + }); + let guard = factory_mutex + .lock() + .expect("Failed to lock default factory mutex."); + let factory = guard + .as_ref() + .expect("Default factory is uninitialized or has been dropped."); + factory.new_instance() } /// Static factory that can be re-used between tests. #[cfg(feature = "tokio-process")] -static TOKIO_POSTGRES_FACTORY: tokio::sync::OnceCell = - tokio::sync::OnceCell::const_new(); +static TOKIO_POSTGRES_FACTORY: tokio::sync::OnceCell< + tokio::sync::Mutex>, +> = tokio::sync::OnceCell::const_new(); /// Create a new default instance, initializing the `TOKIO_POSTGRES_FACTORY` if it /// does not already exist. @@ -66,9 +91,17 @@ static TOKIO_POSTGRES_FACTORY: tokio::sync::OnceCell = /// Will return `Err` if postgres is not installed on system #[cfg(feature = "tokio-process")] pub async fn new_default_process_async() -> TmpPostgrustResult { - let factory = TOKIO_POSTGRES_FACTORY - .get_or_try_init(TmpPostgrustFactory::try_new_async) + let factory_mutex = TOKIO_POSTGRES_FACTORY + .get_or_try_init(|| async { + TmpPostgrustFactory::try_new_async() + .await + .map(|factory| tokio::sync::Mutex::new(Some(factory))) + }) .await?; + let guard = factory_mutex.lock().await; + let factory = guard + .as_ref() + .expect("Default tokio factory is uninitialized or has been dropped."); factory.new_instance_async().await } @@ -458,100 +491,6 @@ mod tests { client2.query("SELECT 1;", &[]).await.unwrap(); } - #[cfg(feature = "tokio-process")] - static FACTORY: tokio::sync::OnceCell = tokio::sync::OnceCell::const_new(); - - #[cfg(feature = "tokio-process")] - #[test(tokio::test)] - async fn static_oncecell() { - let factory = FACTORY - .get_or_try_init(TmpPostgrustFactory::try_new_async) - .await - .unwrap(); - let proc1 = factory.new_instance_async().await.unwrap(); - - let (client1, conn1) = tokio_postgres::connect(&proc1.connection_string, NoTls) - .await - .unwrap(); - - tokio::spawn(async move { - if let Err(e) = conn1.await { - error!("connection error: {}", e); - } - }); - - let factory = FACTORY - .get_or_try_init(TmpPostgrustFactory::try_new_async) - .await - .unwrap(); - let proc2 = factory.new_instance_async().await.unwrap(); - - let (client2, conn2) = tokio_postgres::connect(&proc2.connection_string, NoTls) - .await - .unwrap(); - - tokio::spawn(async move { - if let Err(e) = conn2.await { - error!("connection error: {}", e); - } - }); - - // Shouldn't be able to do this if they are both the same database. - client1.execute("CREATE TABLE lock ();", &[]).await.unwrap(); - client2.execute("CREATE TABLE lock ();", &[]).await.unwrap(); - } - - // Test that a OnceCell can be used in two async tests. - #[cfg(feature = "tokio-process")] - static SHARED_FACTORY: tokio::sync::OnceCell = - tokio::sync::OnceCell::const_new(); - - #[cfg(feature = "tokio-process")] - #[test(tokio::test)] - async fn static_oncecell_shared_1() { - let factory = SHARED_FACTORY - .get_or_try_init(TmpPostgrustFactory::try_new_async) - .await - .unwrap(); - let proc = factory.new_instance_async().await.unwrap(); - - let (client, conn) = tokio_postgres::connect(&proc.connection_string, NoTls) - .await - .unwrap(); - - tokio::spawn(async move { - if let Err(e) = conn.await { - error!("connection error: {}", e); - } - }); - - // Chance to catch concurrent tests or database that have already been used. - client.execute("CREATE TABLE lock ();", &[]).await.unwrap(); - } - - #[cfg(feature = "tokio-process")] - #[test(tokio::test)] - async fn static_oncecell_shared_2() { - let factory = SHARED_FACTORY - .get_or_try_init(TmpPostgrustFactory::try_new_async) - .await - .unwrap(); - let proc = factory.new_instance_async().await.unwrap(); - - let (client, conn) = tokio_postgres::connect(&proc.connection_string, NoTls) - .await - .unwrap(); - - tokio::spawn(async move { - if let Err(e) = conn.await { - error!("connection error: {}", e); - } - }); - - // Chance to catch concurrent tests or database that have already been used. - client.execute("CREATE TABLE lock ();", &[]).await.unwrap(); - } - #[cfg(feature = "tokio-process")] #[test(tokio::test)] async fn default_process_factory_1() { diff --git a/src/synchronous.rs b/src/synchronous.rs index f329660..20d8793 100644 --- a/src/synchronous.rs +++ b/src/synchronous.rs @@ -13,6 +13,7 @@ use std::sync::Arc; use nix::sys::signal; use nix::sys::signal::Signal; +use nix::unistd::User; use nix::unistd::{Pid, Uid}; use tempfile::TempDir; use tracing::{debug, instrument}; @@ -109,7 +110,13 @@ pub(crate) fn chown_to_non_root(dir: &Path) -> TmpPostgrustResult<()> { return Ok(()); } - let (uid, gid) = &*POSTGRES_UID_GID; + let (uid, gid) = POSTGRES_UID_GID.get_or_init(|| { + User::from_name("postgres") + .ok() + .flatten() + .map(|u| (u.uid, u.gid)) + .expect("no user `postgres` found is system") + }); let mut cmd = Command::new("chown"); cmd.arg("-R").arg(format!("{uid}:{gid}")).arg(dir); exec_process(&mut cmd, TmpPostgrustError::UpdatingPermissionsFailed)?; @@ -191,8 +198,15 @@ impl Drop for ProcessGuard { fn cmd_as_non_root(command: &mut Command) { let current_uid = Uid::effective(); if current_uid.is_root() { - let (user_id, group_id) = &*POSTGRES_UID_GID; + let (uid, gid) = POSTGRES_UID_GID.get_or_init(|| { + User::from_name("postgres") + .ok() + .flatten() + .map(|u| (u.uid, u.gid)) + .expect("no user `postgres` found is system") + }); + command.uid(uid.as_raw()).gid(gid.as_raw()); // PostgreSQL cannot be run as root, so change to default user - command.uid(user_id.as_raw()).gid(group_id.as_raw()); + command.uid(uid.as_raw()).gid(gid.as_raw()); } }