Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ keywords = ["testing", "database", "postgres"]
maintenance = { status = "experimental" }

[dependencies]
ctor = "0.2"
glob = "0.3"
nix = "0.26"
reflink-copy = "0.1"
tempfile = "3"
thiserror = "1.0"
tokio = { version = "1.8", features = ["parking_lot", "rt", "sync", "io-util", "process", "macros", "fs"], default-features = false, optional = true }
tracing = "0.1"
which = "4.0"
once_cell = "1"
reflink-copy = "0.1"

[dev-dependencies]
test-log = { version = "0.2", default-features = false, features = ["trace"] }
Expand Down
20 changes: 16 additions & 4 deletions src/asynchronous.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use nix::unistd::Uid;
use nix::unistd::{Uid, User};
use std::path::Path;
use std::process::Stdio;
use std::sync::Arc;
Expand Down Expand Up @@ -153,7 +153,13 @@ pub(crate) async fn chown_to_non_root(dir: &Path) -> TmpPostgrustResult<()> {
return Ok(());
}

let (uid, gid) = &*POSTGRES_UID_GID;
let (uid, gid) = POSTGRES_UID_GID.get_or_init(|| {
User::from_name("postgres")
.ok()
.flatten()
.map(|u| (u.uid, u.gid))
.expect("no user `postgres` found is system")
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
.expect("no user `postgres` found is system")
.expect("no user `postgres` found in system")

Probably also want to factor this out into a new function

});
let mut cmd = Command::new("chown");
cmd.arg("-R").arg(format!("{uid}:{gid}")).arg(dir);
exec_process(&mut cmd, TmpPostgrustError::UpdatingPermissionsFailed).await
Expand Down Expand Up @@ -196,7 +202,13 @@ fn cmd_as_non_root(command: &mut Command) {
let current_uid = Uid::effective();
if current_uid.is_root() {
// PostgreSQL cannot be run as root, so change to default user
let (user_id, group_id) = &*POSTGRES_UID_GID;
command.uid(user_id.as_raw()).gid(group_id.as_raw());
let (uid, gid) = POSTGRES_UID_GID.get_or_init(|| {
User::from_name("postgres")
.ok()
.flatten()
.map(|u| (u.uid, u.gid))
.expect("no user `postgres` found is system")
});
command.uid(uid.as_raw()).gid(gid.as_raw());
}
}
161 changes: 50 additions & 111 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,38 @@ use std::fs::{metadata, set_permissions};
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::sync::atomic::AtomicU32;
use std::sync::Arc;
use std::sync::{Arc, Mutex, OnceLock};
use std::{fs::File, io::Write};

use nix::unistd::{Gid, Uid, User};
use once_cell::sync::Lazy;
use ctor::dtor;
use nix::unistd::{Gid, Uid};
use tempfile::{Builder, TempDir};
use tracing::{debug, info, instrument};

use crate::errors::{TmpPostgrustError, TmpPostgrustResult};

pub(crate) static POSTGRES_UID_GID: Lazy<(Uid, Gid)> = Lazy::new(|| {
User::from_name("postgres")
.ok()
.flatten()
.map(|u| (u.uid, u.gid))
.expect("no user `postgres` found is system")
});
pub(crate) static POSTGRES_UID_GID: OnceLock<(Uid, Gid)> = OnceLock::new();

/// As the static variables declared by this crate contain values that
/// need to be dropped at program exit to clean up resources, we use a
/// `#[dtor]` hack to drop the variables if they have been initialized.
#[dtor]
fn cleanup_static() {
#[cfg(feature = "tokio-process")]
if let Some(factory_mutex) = TOKIO_POSTGRES_FACTORY.get() {
let mut guard = factory_mutex.blocking_lock();
drop(guard.take());
}

if let Some(factory_mutex) = DEFAULT_POSTGRES_FACTORY.get() {
let mut guard = factory_mutex
.lock()
.expect("Failed to lock default factory mutex.");
drop(guard.take());
}
}

static DEFAULT_POSTGRES_FACTORY: OnceLock<Mutex<Option<TmpPostgrustFactory>>> = OnceLock::new();

/// Create a new default instance, initializing the `DEFAULT_POSTGRES_FACTORY` if it
/// does not already exist.
Expand All @@ -48,15 +63,25 @@ pub(crate) static POSTGRES_UID_GID: Lazy<(Uid, Gid)> = Lazy::new(|| {
///
/// Will return `Err` if postgres is not installed on system
pub fn new_default_process() -> TmpPostgrustResult<synchronous::ProcessGuard> {
static DEFAULT_POSTGRES_FACTORY: Lazy<TmpPostgrustFactory> =
Lazy::new(|| TmpPostgrustFactory::try_new().unwrap());
DEFAULT_POSTGRES_FACTORY.new_instance()
let factory_mutex = DEFAULT_POSTGRES_FACTORY.get_or_init(|| {
Mutex::new(Some(
TmpPostgrustFactory::try_new().expect("Failed to initialize default postgres factory."),
))
});
let guard = factory_mutex
.lock()
.expect("Failed to lock default factory mutex.");
let factory = guard
.as_ref()
.expect("Default factory is uninitialized or has been dropped.");
factory.new_instance()
}

/// Static factory that can be re-used between tests.
#[cfg(feature = "tokio-process")]
static TOKIO_POSTGRES_FACTORY: tokio::sync::OnceCell<TmpPostgrustFactory> =
tokio::sync::OnceCell::const_new();
static TOKIO_POSTGRES_FACTORY: tokio::sync::OnceCell<
tokio::sync::Mutex<Option<TmpPostgrustFactory>>,
> = tokio::sync::OnceCell::const_new();

/// Create a new default instance, initializing the `TOKIO_POSTGRES_FACTORY` if it
/// does not already exist.
Expand All @@ -66,9 +91,17 @@ static TOKIO_POSTGRES_FACTORY: tokio::sync::OnceCell<TmpPostgrustFactory> =
/// Will return `Err` if postgres is not installed on system
#[cfg(feature = "tokio-process")]
pub async fn new_default_process_async() -> TmpPostgrustResult<asynchronous::ProcessGuard> {
let factory = TOKIO_POSTGRES_FACTORY
.get_or_try_init(TmpPostgrustFactory::try_new_async)
let factory_mutex = TOKIO_POSTGRES_FACTORY
.get_or_try_init(|| async {
TmpPostgrustFactory::try_new_async()
.await
.map(|factory| tokio::sync::Mutex::new(Some(factory)))
})
.await?;
let guard = factory_mutex.lock().await;
let factory = guard
.as_ref()
.expect("Default tokio factory is uninitialized or has been dropped.");
factory.new_instance_async().await
}

Expand Down Expand Up @@ -458,100 +491,6 @@ mod tests {
client2.query("SELECT 1;", &[]).await.unwrap();
}

#[cfg(feature = "tokio-process")]
static FACTORY: tokio::sync::OnceCell<TmpPostgrustFactory> = tokio::sync::OnceCell::const_new();

#[cfg(feature = "tokio-process")]
#[test(tokio::test)]
async fn static_oncecell() {
let factory = FACTORY
.get_or_try_init(TmpPostgrustFactory::try_new_async)
.await
.unwrap();
let proc1 = factory.new_instance_async().await.unwrap();

let (client1, conn1) = tokio_postgres::connect(&proc1.connection_string, NoTls)
.await
.unwrap();

tokio::spawn(async move {
if let Err(e) = conn1.await {
error!("connection error: {}", e);
}
});

let factory = FACTORY
.get_or_try_init(TmpPostgrustFactory::try_new_async)
.await
.unwrap();
let proc2 = factory.new_instance_async().await.unwrap();

let (client2, conn2) = tokio_postgres::connect(&proc2.connection_string, NoTls)
.await
.unwrap();

tokio::spawn(async move {
if let Err(e) = conn2.await {
error!("connection error: {}", e);
}
});

// Shouldn't be able to do this if they are both the same database.
client1.execute("CREATE TABLE lock ();", &[]).await.unwrap();
client2.execute("CREATE TABLE lock ();", &[]).await.unwrap();
}

// Test that a OnceCell can be used in two async tests.
#[cfg(feature = "tokio-process")]
static SHARED_FACTORY: tokio::sync::OnceCell<TmpPostgrustFactory> =
tokio::sync::OnceCell::const_new();

#[cfg(feature = "tokio-process")]
#[test(tokio::test)]
async fn static_oncecell_shared_1() {
let factory = SHARED_FACTORY
.get_or_try_init(TmpPostgrustFactory::try_new_async)
.await
.unwrap();
let proc = factory.new_instance_async().await.unwrap();

let (client, conn) = tokio_postgres::connect(&proc.connection_string, NoTls)
.await
.unwrap();

tokio::spawn(async move {
if let Err(e) = conn.await {
error!("connection error: {}", e);
}
});

// Chance to catch concurrent tests or database that have already been used.
client.execute("CREATE TABLE lock ();", &[]).await.unwrap();
}

#[cfg(feature = "tokio-process")]
#[test(tokio::test)]
async fn static_oncecell_shared_2() {
let factory = SHARED_FACTORY
.get_or_try_init(TmpPostgrustFactory::try_new_async)
.await
.unwrap();
let proc = factory.new_instance_async().await.unwrap();

let (client, conn) = tokio_postgres::connect(&proc.connection_string, NoTls)
.await
.unwrap();

tokio::spawn(async move {
if let Err(e) = conn.await {
error!("connection error: {}", e);
}
});

// Chance to catch concurrent tests or database that have already been used.
client.execute("CREATE TABLE lock ();", &[]).await.unwrap();
}

#[cfg(feature = "tokio-process")]
#[test(tokio::test)]
async fn default_process_factory_1() {
Expand Down
20 changes: 17 additions & 3 deletions src/synchronous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use std::sync::Arc;

use nix::sys::signal;
use nix::sys::signal::Signal;
use nix::unistd::User;
use nix::unistd::{Pid, Uid};
use tempfile::TempDir;
use tracing::{debug, instrument};
Expand Down Expand Up @@ -109,7 +110,13 @@ pub(crate) fn chown_to_non_root(dir: &Path) -> TmpPostgrustResult<()> {
return Ok(());
}

let (uid, gid) = &*POSTGRES_UID_GID;
let (uid, gid) = POSTGRES_UID_GID.get_or_init(|| {
User::from_name("postgres")
.ok()
.flatten()
.map(|u| (u.uid, u.gid))
.expect("no user `postgres` found is system")
});
let mut cmd = Command::new("chown");
cmd.arg("-R").arg(format!("{uid}:{gid}")).arg(dir);
exec_process(&mut cmd, TmpPostgrustError::UpdatingPermissionsFailed)?;
Expand Down Expand Up @@ -191,8 +198,15 @@ impl Drop for ProcessGuard {
fn cmd_as_non_root(command: &mut Command) {
let current_uid = Uid::effective();
if current_uid.is_root() {
let (user_id, group_id) = &*POSTGRES_UID_GID;
let (uid, gid) = POSTGRES_UID_GID.get_or_init(|| {
User::from_name("postgres")
.ok()
.flatten()
.map(|u| (u.uid, u.gid))
.expect("no user `postgres` found is system")
});
command.uid(uid.as_raw()).gid(gid.as_raw());
// PostgreSQL cannot be run as root, so change to default user
command.uid(user_id.as_raw()).gid(group_id.as_raw());
command.uid(uid.as_raw()).gid(gid.as_raw());
}
}