From e2b856333dd879da4e8a39967e87897098d640ee Mon Sep 17 00:00:00 2001 From: Eragon Date: Tue, 30 Jan 2024 00:36:11 +0100 Subject: [PATCH] wip: At least it runs now --- Cargo.lock | 80 + Cargo.toml | 3 + Makefile | 14 +- tokenserver-db-common/Cargo.toml | 31 + .../src/error.rs | 22 +- tokenserver-db-common/src/lib.rs | 1 + tokenserver-db-mysql/Cargo.toml | 32 + .../2021-07-16-001122_init/down.sql | 0 .../migrations/2021-07-16-001122_init/up.sql | 0 .../down.sql | 0 .../up.sql | 0 .../down.sql | 0 .../up.sql | 0 .../down.sql | 0 .../up.sql | 0 .../2021-09-30-142746_add_indexes/down.sql | 0 .../2021-09-30-142746_add_indexes/up.sql | 0 .../down.sql | 0 .../up.sql | 0 .../down.sql | 0 .../up.sql | 0 .../down.sql | 0 .../2021-12-22-160451_remove_services/up.sql | 0 tokenserver-db-mysql/src/lib.rs | 6 + tokenserver-db-mysql/src/models.rs | 2092 +++++++++++++++++ tokenserver-db-mysql/src/pool.rs | 20 + tokenserver-db-sqlite/Cargo.toml | 32 + .../2024-01-28-211312_init/down.sql | 3 + .../migrations/2024-01-28-211312_init/up.sql | 34 + tokenserver-db-sqlite/src/lib.rs | 5 + tokenserver-db-sqlite/src/pool.rs | 17 + tokenserver-db/Cargo.toml | 7 + tokenserver-db/src/lib.rs | 15 +- tokenserver-db/src/mock.rs | 2 +- tokenserver-db/src/models.rs | 26 +- tokenserver-db/src/pool.rs | 32 +- 36 files changed, 2414 insertions(+), 60 deletions(-) create mode 100644 tokenserver-db-common/Cargo.toml rename {tokenserver-db => tokenserver-db-common}/src/error.rs (81%) create mode 100644 tokenserver-db-common/src/lib.rs create mode 100644 tokenserver-db-mysql/Cargo.toml rename {tokenserver-db => tokenserver-db-mysql}/migrations/2021-07-16-001122_init/down.sql (100%) rename {tokenserver-db => tokenserver-db-mysql}/migrations/2021-07-16-001122_init/up.sql (100%) rename {tokenserver-db => tokenserver-db-mysql}/migrations/2021-08-03-234845_populate_services/down.sql (100%) rename {tokenserver-db => tokenserver-db-mysql}/migrations/2021-08-03-234845_populate_services/up.sql (100%) rename {tokenserver-db => tokenserver-db-mysql}/migrations/2021-09-30-142643_remove_foreign_key_constraints/down.sql (100%) rename {tokenserver-db => tokenserver-db-mysql}/migrations/2021-09-30-142643_remove_foreign_key_constraints/up.sql (100%) rename {tokenserver-db => tokenserver-db-mysql}/migrations/2021-09-30-142654_remove_node_defaults/down.sql (100%) rename {tokenserver-db => tokenserver-db-mysql}/migrations/2021-09-30-142654_remove_node_defaults/up.sql (100%) rename {tokenserver-db => tokenserver-db-mysql}/migrations/2021-09-30-142746_add_indexes/down.sql (100%) rename {tokenserver-db => tokenserver-db-mysql}/migrations/2021-09-30-142746_add_indexes/up.sql (100%) rename {tokenserver-db => tokenserver-db-mysql}/migrations/2021-09-30-144043_remove_nodes_service_key/down.sql (100%) rename {tokenserver-db => tokenserver-db-mysql}/migrations/2021-09-30-144043_remove_nodes_service_key/up.sql (100%) rename {tokenserver-db => tokenserver-db-mysql}/migrations/2021-09-30-144225_remove_users_nodeid_key/down.sql (100%) rename {tokenserver-db => tokenserver-db-mysql}/migrations/2021-09-30-144225_remove_users_nodeid_key/up.sql (100%) rename {tokenserver-db => tokenserver-db-mysql}/migrations/2021-12-22-160451_remove_services/down.sql (100%) rename {tokenserver-db => tokenserver-db-mysql}/migrations/2021-12-22-160451_remove_services/up.sql (100%) create mode 100644 tokenserver-db-mysql/src/lib.rs create mode 100644 tokenserver-db-mysql/src/models.rs create mode 100644 tokenserver-db-mysql/src/pool.rs create mode 100644 
tokenserver-db-sqlite/Cargo.toml create mode 100644 tokenserver-db-sqlite/migrations/2024-01-28-211312_init/down.sql create mode 100644 tokenserver-db-sqlite/migrations/2024-01-28-211312_init/up.sql create mode 100644 tokenserver-db-sqlite/src/lib.rs create mode 100644 tokenserver-db-sqlite/src/pool.rs diff --git a/Cargo.lock b/Cargo.lock index 0f5a7764d4..b7521db854 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3007,6 +3007,86 @@ dependencies = [ "syncserver-settings", "thiserror", "tokenserver-common", + "tokenserver-db-common", + "tokenserver-db-mysql", + "tokenserver-db-sqlite", + "tokenserver-settings", + "tokio", +] + +[[package]] +name = "tokenserver-db-common" +version = "0.14.4" +dependencies = [ + "async-trait", + "backtrace", + "diesel", + "diesel_logger", + "diesel_migrations", + "env_logger 0.10.2", + "futures 0.3.30", + "http", + "serde 1.0.196", + "serde_derive", + "serde_json", + "slog-scope", + "syncserver-common", + "syncserver-db-common", + "syncserver-settings", + "thiserror", + "tokenserver-common", + "tokenserver-settings", + "tokio", +] + +[[package]] +name = "tokenserver-db-mysql" +version = "0.14.4" +dependencies = [ + "async-trait", + "backtrace", + "diesel", + "diesel_logger", + "diesel_migrations", + "env_logger 0.10.2", + "futures 0.3.30", + "http", + "serde 1.0.196", + "serde_derive", + "serde_json", + "slog-scope", + "syncserver-common", + "syncserver-db-common", + "syncserver-settings", + "thiserror", + "tokenserver-common", + "tokenserver-db-common", + "tokenserver-settings", + "tokio", +] + +[[package]] +name = "tokenserver-db-sqlite" +version = "0.14.4" +dependencies = [ + "async-trait", + "backtrace", + "diesel", + "diesel_logger", + "diesel_migrations", + "env_logger 0.10.2", + "futures 0.3.30", + "http", + "serde 1.0.196", + "serde_derive", + "serde_json", + "slog-scope", + "syncserver-common", + "syncserver-db-common", + "syncserver-settings", + "thiserror", + "tokenserver-common", + "tokenserver-db-common", "tokenserver-settings", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 8739d2539f..1e70ab6984 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,9 @@ members = [ "tokenserver-auth", "tokenserver-common", "tokenserver-db", + "tokenserver-db-common", + "tokenserver-db-mysql", + "tokenserver-db-sqlite", "tokenserver-settings", "syncserver", ] diff --git a/Makefile b/Makefile index 9c15a50320..69643e5921 100644 --- a/Makefile +++ b/Makefile @@ -15,15 +15,15 @@ PYTHON_SITE_PACKGES = $(shell $(SRC_ROOT)/venv/bin/python -c "from distutils.sys clippy_sqlite: # Matches what's run in circleci - cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/sqlite -- -D warnings + cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/sqlite,tokenserver-db/sqlite --features=py_verifier -- -D warnings clippy_mysql: # Matches what's run in circleci - cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/mysql --features=py_verifier -- -D warnings + cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/mysql,tokenserver-db/mysql --features=py_verifier -- -D warnings clippy_spanner: # Matches what's run in circleci - cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/spanner --features=py_verifier -- -D warnings + cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/spanner,tokenserver-db/mysql --features=py_verifier -- -D warnings clean: cargo clean @@ -57,9 +57,9 
@@ run_mysql: python # See https://github.com/PyO3/pyo3/issues/1741 for discussion re: why we need to set the # below env var PYTHONPATH=$(PYTHON_SITE_PACKGES) \ - RUST_LOG=debug \ + RUST_LOG=debug \ RUST_BACKTRACE=full \ - cargo run --no-default-features --features=syncstorage-db/mysql --features=py_verifier -- --config config/local.toml + cargo run --no-default-features --features=syncstorage-db/mysql,tokenserver-db/mysql --features=py_verifier -- --config config/local.toml run_sqlite: python PATH="./venv/bin:$(PATH)" \ @@ -68,7 +68,7 @@ run_sqlite: python PYTHONPATH=$(PYTHON_SITE_PACKGES) \ RUST_LOG=debug \ RUST_BACKTRACE=full \ - cargo run --no-default-features --features=syncstorage-db/sqlite -- --config config/local.toml + cargo run --no-default-features --features=syncstorage-db/sqlite,tokenserver-db/sqlite --features=py_verifier -- --config config/local.toml run_spanner: python GOOGLE_APPLICATION_CREDENTIALS=$(PATH_TO_SYNC_SPANNER_KEYS) \ @@ -79,7 +79,7 @@ run_spanner: python PATH="./venv/bin:$(PATH)" \ RUST_LOG=debug \ RUST_BACKTRACE=full \ - cargo run --no-default-features --features=syncstorage-db/spanner --features=py_verifier -- --config config/local.toml + cargo run --no-default-features --features=syncstorage-db/spanner,tokenserver-db/mysql --features=py_verifier -- --config config/local.toml test_mysql: SYNC_SYNCSTORAGE__DATABASE_URL=mysql://sample_user:sample_password@localhost/syncstorage_rs \ diff --git a/tokenserver-db-common/Cargo.toml b/tokenserver-db-common/Cargo.toml new file mode 100644 index 0000000000..47ea6b0e8b --- /dev/null +++ b/tokenserver-db-common/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "tokenserver-db-common" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +backtrace.workspace = true +futures.workspace = true +http.workspace = true +serde.workspace = true +serde_derive.workspace = true +serde_json.workspace = true +slog-scope.workspace = true + +async-trait = "0.1.40" +diesel = { version = "1.4", features = ["mysql", "r2d2"] } +diesel_logger = "0.1.1" +diesel_migrations = { version = "1.4.0", features = ["mysql"] } +syncserver-common = { path = "../syncserver-common" } +syncserver-db-common = { path = "../syncserver-db-common" } +thiserror = "1.0.26" +tokenserver-common = { path = "../tokenserver-common" } +tokenserver-settings = { path = "../tokenserver-settings" } +tokio = { workspace = true, features = ["macros", "sync"] } + +[dev-dependencies] +env_logger.workspace = true + +syncserver-settings = { path = "../syncserver-settings" } diff --git a/tokenserver-db/src/error.rs b/tokenserver-db-common/src/error.rs similarity index 81% rename from tokenserver-db/src/error.rs rename to tokenserver-db-common/src/error.rs index b0c78c433d..0110cf964e 100644 --- a/tokenserver-db/src/error.rs +++ b/tokenserver-db-common/src/error.rs @@ -7,8 +7,8 @@ use syncserver_db_common::error::SqlError; use thiserror::Error; use tokenserver_common::TokenserverError; -pub(crate) type DbFuture<'a, T> = syncserver_db_common::DbFuture<'a, T, DbError>; -pub(crate) type DbResult<T> = Result<T, DbError>; +pub type DbFuture<'a, T> = syncserver_db_common::DbFuture<'a, T, DbError>; +pub type DbResult<T> = Result<T, DbError>; /// An error type that represents any database-related errors that may occur while processing a /// tokenserver request.
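With `DbResult` and `DbError::internal` now public, downstream backend crates can name the shared result type directly. A minimal usage sketch, assuming only the items shown in this hunk (the helper function itself is hypothetical):

```rust
// Hypothetical helper in a backend crate, built only on the now-public
// tokenserver-db-common error API from the hunk above.
use tokenserver_db_common::error::{DbError, DbResult};

fn require_positive(n: i64) -> DbResult<i64> {
    if n > 0 {
        Ok(n)
    } else {
        // DbError::internal wraps the message in DbErrorKind::Internal.
        Err(DbError::internal(format!("expected a positive value, got {}", n)))
    }
}
```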
@@ -20,7 +20,7 @@ pub struct DbError { } impl DbError { - pub(crate) fn internal(msg: String) -> Self { + pub fn internal(msg: String) -> Self { DbErrorKind::Internal(msg).into() } } @@ -28,7 +28,7 @@ impl DbError { #[derive(Debug, Error)] enum DbErrorKind { #[error("{}", _0)] - Mysql(SqlError), + SqlError(SqlError), #[error("Unexpected error: {}", _0)] Internal(String), @@ -37,9 +37,9 @@ enum DbErrorKind { impl From<DbErrorKind> for DbError { fn from(kind: DbErrorKind) -> Self { match kind { - DbErrorKind::Mysql(ref mysql_error) => Self { - status: mysql_error.status, - backtrace: Box::new(mysql_error.backtrace.clone()), + DbErrorKind::SqlError(ref sql_error) => Self { + status: sql_error.status, + backtrace: Box::new(sql_error.backtrace.clone()), kind, }, DbErrorKind::Internal(_) => Self { @@ -81,24 +81,24 @@ impl_fmt_display!(DbError, DbErrorKind); from_error!( diesel::result::Error, DbError, - |error: diesel::result::Error| DbError::from(DbErrorKind::Mysql(SqlError::from(error))) + |error: diesel::result::Error| DbError::from(DbErrorKind::SqlError(SqlError::from(error))) ); from_error!( diesel::result::ConnectionError, DbError, - |error: diesel::result::ConnectionError| DbError::from(DbErrorKind::Mysql(SqlError::from( + |error: diesel::result::ConnectionError| DbError::from(DbErrorKind::SqlError(SqlError::from( error ))) ); from_error!( diesel::r2d2::PoolError, DbError, - |error: diesel::r2d2::PoolError| DbError::from(DbErrorKind::Mysql(SqlError::from(error))) + |error: diesel::r2d2::PoolError| DbError::from(DbErrorKind::SqlError(SqlError::from(error))) ); from_error!( diesel_migrations::RunMigrationsError, DbError, - |error: diesel_migrations::RunMigrationsError| DbError::from(DbErrorKind::Mysql( + |error: diesel_migrations::RunMigrationsError| DbError::from(DbErrorKind::SqlError( SqlError::from(error) )) ); diff --git a/tokenserver-db-common/src/lib.rs b/tokenserver-db-common/src/lib.rs new file mode 100644 index 0000000000..a91e735174 --- /dev/null +++ b/tokenserver-db-common/src/lib.rs @@ -0,0 +1 @@ +pub mod error; diff --git a/tokenserver-db-mysql/Cargo.toml b/tokenserver-db-mysql/Cargo.toml new file mode 100644 index 0000000000..8a1054cb07 --- /dev/null +++ b/tokenserver-db-mysql/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "tokenserver-db-mysql" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +backtrace.workspace = true +futures.workspace = true +http.workspace = true +serde.workspace = true +serde_derive.workspace = true +serde_json.workspace = true +slog-scope.workspace = true + +async-trait = "0.1.40" +diesel = { version = "1.4", features = ["mysql", "r2d2"] } +diesel_logger = "0.1.1" +diesel_migrations = { version = "1.4.0", features = ["mysql"] } +syncserver-common = { path = "../syncserver-common" } +syncserver-db-common = { path = "../syncserver-db-common" } +thiserror = "1.0.26" +tokenserver-common = { path = "../tokenserver-common" } +tokenserver-db-common = { path = "../tokenserver-db-common" } +tokenserver-settings = { path = "../tokenserver-settings" } +tokio = { workspace = true, features = ["macros", "sync"] } + +[dev-dependencies] +env_logger.workspace = true + +syncserver-settings = { path = "../syncserver-settings" } diff --git a/tokenserver-db/migrations/2021-07-16-001122_init/down.sql b/tokenserver-db-mysql/migrations/2021-07-16-001122_init/down.sql similarity index 100% rename from tokenserver-db/migrations/2021-07-16-001122_init/down.sql rename to
tokenserver-db-mysql/migrations/2021-07-16-001122_init/down.sql diff --git a/tokenserver-db/migrations/2021-07-16-001122_init/up.sql b/tokenserver-db-mysql/migrations/2021-07-16-001122_init/up.sql similarity index 100% rename from tokenserver-db/migrations/2021-07-16-001122_init/up.sql rename to tokenserver-db-mysql/migrations/2021-07-16-001122_init/up.sql diff --git a/tokenserver-db/migrations/2021-08-03-234845_populate_services/down.sql b/tokenserver-db-mysql/migrations/2021-08-03-234845_populate_services/down.sql similarity index 100% rename from tokenserver-db/migrations/2021-08-03-234845_populate_services/down.sql rename to tokenserver-db-mysql/migrations/2021-08-03-234845_populate_services/down.sql diff --git a/tokenserver-db/migrations/2021-08-03-234845_populate_services/up.sql b/tokenserver-db-mysql/migrations/2021-08-03-234845_populate_services/up.sql similarity index 100% rename from tokenserver-db/migrations/2021-08-03-234845_populate_services/up.sql rename to tokenserver-db-mysql/migrations/2021-08-03-234845_populate_services/up.sql diff --git a/tokenserver-db/migrations/2021-09-30-142643_remove_foreign_key_constraints/down.sql b/tokenserver-db-mysql/migrations/2021-09-30-142643_remove_foreign_key_constraints/down.sql similarity index 100% rename from tokenserver-db/migrations/2021-09-30-142643_remove_foreign_key_constraints/down.sql rename to tokenserver-db-mysql/migrations/2021-09-30-142643_remove_foreign_key_constraints/down.sql diff --git a/tokenserver-db/migrations/2021-09-30-142643_remove_foreign_key_constraints/up.sql b/tokenserver-db-mysql/migrations/2021-09-30-142643_remove_foreign_key_constraints/up.sql similarity index 100% rename from tokenserver-db/migrations/2021-09-30-142643_remove_foreign_key_constraints/up.sql rename to tokenserver-db-mysql/migrations/2021-09-30-142643_remove_foreign_key_constraints/up.sql diff --git a/tokenserver-db/migrations/2021-09-30-142654_remove_node_defaults/down.sql b/tokenserver-db-mysql/migrations/2021-09-30-142654_remove_node_defaults/down.sql similarity index 100% rename from tokenserver-db/migrations/2021-09-30-142654_remove_node_defaults/down.sql rename to tokenserver-db-mysql/migrations/2021-09-30-142654_remove_node_defaults/down.sql diff --git a/tokenserver-db/migrations/2021-09-30-142654_remove_node_defaults/up.sql b/tokenserver-db-mysql/migrations/2021-09-30-142654_remove_node_defaults/up.sql similarity index 100% rename from tokenserver-db/migrations/2021-09-30-142654_remove_node_defaults/up.sql rename to tokenserver-db-mysql/migrations/2021-09-30-142654_remove_node_defaults/up.sql diff --git a/tokenserver-db/migrations/2021-09-30-142746_add_indexes/down.sql b/tokenserver-db-mysql/migrations/2021-09-30-142746_add_indexes/down.sql similarity index 100% rename from tokenserver-db/migrations/2021-09-30-142746_add_indexes/down.sql rename to tokenserver-db-mysql/migrations/2021-09-30-142746_add_indexes/down.sql diff --git a/tokenserver-db/migrations/2021-09-30-142746_add_indexes/up.sql b/tokenserver-db-mysql/migrations/2021-09-30-142746_add_indexes/up.sql similarity index 100% rename from tokenserver-db/migrations/2021-09-30-142746_add_indexes/up.sql rename to tokenserver-db-mysql/migrations/2021-09-30-142746_add_indexes/up.sql diff --git a/tokenserver-db/migrations/2021-09-30-144043_remove_nodes_service_key/down.sql b/tokenserver-db-mysql/migrations/2021-09-30-144043_remove_nodes_service_key/down.sql similarity index 100% rename from tokenserver-db/migrations/2021-09-30-144043_remove_nodes_service_key/down.sql rename to 
tokenserver-db-mysql/migrations/2021-09-30-144043_remove_nodes_service_key/down.sql diff --git a/tokenserver-db/migrations/2021-09-30-144043_remove_nodes_service_key/up.sql b/tokenserver-db-mysql/migrations/2021-09-30-144043_remove_nodes_service_key/up.sql similarity index 100% rename from tokenserver-db/migrations/2021-09-30-144043_remove_nodes_service_key/up.sql rename to tokenserver-db-mysql/migrations/2021-09-30-144043_remove_nodes_service_key/up.sql diff --git a/tokenserver-db/migrations/2021-09-30-144225_remove_users_nodeid_key/down.sql b/tokenserver-db-mysql/migrations/2021-09-30-144225_remove_users_nodeid_key/down.sql similarity index 100% rename from tokenserver-db/migrations/2021-09-30-144225_remove_users_nodeid_key/down.sql rename to tokenserver-db-mysql/migrations/2021-09-30-144225_remove_users_nodeid_key/down.sql diff --git a/tokenserver-db/migrations/2021-09-30-144225_remove_users_nodeid_key/up.sql b/tokenserver-db-mysql/migrations/2021-09-30-144225_remove_users_nodeid_key/up.sql similarity index 100% rename from tokenserver-db/migrations/2021-09-30-144225_remove_users_nodeid_key/up.sql rename to tokenserver-db-mysql/migrations/2021-09-30-144225_remove_users_nodeid_key/up.sql diff --git a/tokenserver-db/migrations/2021-12-22-160451_remove_services/down.sql b/tokenserver-db-mysql/migrations/2021-12-22-160451_remove_services/down.sql similarity index 100% rename from tokenserver-db/migrations/2021-12-22-160451_remove_services/down.sql rename to tokenserver-db-mysql/migrations/2021-12-22-160451_remove_services/down.sql diff --git a/tokenserver-db/migrations/2021-12-22-160451_remove_services/up.sql b/tokenserver-db-mysql/migrations/2021-12-22-160451_remove_services/up.sql similarity index 100% rename from tokenserver-db/migrations/2021-12-22-160451_remove_services/up.sql rename to tokenserver-db-mysql/migrations/2021-12-22-160451_remove_services/up.sql diff --git a/tokenserver-db-mysql/src/lib.rs b/tokenserver-db-mysql/src/lib.rs new file mode 100644 index 0000000000..df255d7a77 --- /dev/null +++ b/tokenserver-db-mysql/src/lib.rs @@ -0,0 +1,6 @@ +extern crate diesel; +#[macro_use] +extern crate diesel_migrations; + +//pub mod models; +pub mod pool; diff --git a/tokenserver-db-mysql/src/models.rs b/tokenserver-db-mysql/src/models.rs new file mode 100644 index 0000000000..1e51862882 --- /dev/null +++ b/tokenserver-db-mysql/src/models.rs @@ -0,0 +1,2092 @@ +use diesel::{ + mysql::MysqlConnection, + r2d2::{ConnectionManager, PooledConnection}, + sql_types::{Bigint, Float, Integer, Nullable, Text}, + OptionalExtension, RunQueryDsl, +}; +#[cfg(test)] +use diesel_logger::LoggingConnection; +use http::StatusCode; +use syncserver_common::{BlockingThreadpool, Metrics}; +use syncserver_db_common::{sync_db_method, DbFuture}; + +use std::{ + sync::Arc, + time::{SystemTime, UNIX_EPOCH}, +}; + +use super::{ + error::{DbError, DbResult}, + params, + results, +}; + +/// The maximum possible generation number. Used as a tombstone to mark users that have been +/// "retired" from the db. +const MAX_GENERATION: i64 = i64::MAX; + +type Conn = PooledConnection<ConnectionManager<MysqlConnection>>; + +#[derive(Clone)] +pub struct TokenserverDb { + /// Synchronous Diesel calls are executed on a blocking threadpool to satisfy + /// the Db trait's asynchronous interface. + /// + /// Arc<DbInner> provides a Clone impl utilized for safely moving to + /// the thread pool but does not provide Send as the underlying db + /// conn. structs are !Sync (Arc requires both for Send). See the Send impl + /// below.
+ inner: Arc<DbInner>, + metrics: Metrics, + service_id: Option<i32>, + spanner_node_id: Option<i32>, + blocking_threadpool: Arc<BlockingThreadpool>, +} + +/// Despite the db conn structs being !Sync (see Arc<DbInner> above) we +/// don't spawn multiple MysqlDb calls at a time in the thread pool. Calls are +/// queued to the thread pool via Futures, naturally serialized. +unsafe impl Send for TokenserverDb {} + +struct DbInner { + #[cfg(not(test))] + pub(super) conn: Conn, + #[cfg(test)] + pub(super) conn: LoggingConnection<Conn>, // display SQL when RUST_LOG="diesel_logger=trace" +} + +impl TokenserverDb { + // Note that this only works because an instance of `TokenserverDb` has *exclusive access* to + // a connection from the r2d2 pool for its lifetime. `LAST_INSERT_ID()` returns the ID of the + // most recently-inserted record *for a given connection*. If connections were shared across + // requests, using this function would introduce a race condition, as we could potentially + // get IDs from records created during other requests. + const LAST_INSERT_ID_QUERY: &'static str = "SELECT LAST_INSERT_ID() AS id"; + + pub fn new( + conn: Conn, + metrics: &Metrics, + service_id: Option<i32>, + spanner_node_id: Option<i32>, + blocking_threadpool: Arc<BlockingThreadpool>, + ) -> Self { + let inner = DbInner { + #[cfg(not(test))] + conn, + #[cfg(test)] + conn: LoggingConnection::new(conn), + }; + + // https://github.com/mozilla-services/syncstorage-rs/issues/1480 + #[allow(clippy::arc_with_non_send_sync)] + Self { + inner: Arc::new(inner), + metrics: metrics.clone(), + service_id, + spanner_node_id, + blocking_threadpool, + } + } + + fn get_node_id_sync(&self, params: params::GetNodeId) -> DbResult<results::GetNodeId> { + const QUERY: &str = r#" + SELECT id + FROM nodes + WHERE service = ? + AND node = ? + "#; + + if let Some(id) = self.spanner_node_id { + Ok(results::GetNodeId { id: id as i64 }) + } else { + let mut metrics = self.metrics.clone(); + metrics.start_timer("storage.get_node_id", None); + + diesel::sql_query(QUERY) + .bind::<Integer, _>(params.service_id) + .bind::<Text, _>(&params.node) + .get_result(&self.inner.conn) + .map_err(Into::into) + } + } + + /// Mark users matching the given email and service ID as replaced. + fn replace_users_sync(&self, params: params::ReplaceUsers) -> DbResult<results::ReplaceUsers> { + const QUERY: &str = r#" + UPDATE users + SET replaced_at = ? + WHERE service = ? + AND email = ? + AND replaced_at IS NULL + AND created_at < ? + "#; + + let mut metrics = self.metrics.clone(); + metrics.start_timer("storage.replace_users", None); + + diesel::sql_query(QUERY) + .bind::<Bigint, _>(params.replaced_at) + .bind::<Integer, _>(&params.service_id) + .bind::<Text, _>(&params.email) + .bind::<Bigint, _>(params.replaced_at) + .execute(&self.inner.conn) + .map(|_| ()) + .map_err(Into::into) + } + + /// Mark the user with the given uid and service ID as being replaced. + fn replace_user_sync(&self, params: params::ReplaceUser) -> DbResult<results::ReplaceUser> { + const QUERY: &str = r#" + UPDATE users + SET replaced_at = ? + WHERE service = ? + AND uid = ? + "#; + + diesel::sql_query(QUERY) + .bind::<Bigint, _>(params.replaced_at) + .bind::<Integer, _>(params.service_id) + .bind::<Bigint, _>(params.uid) + .execute(&self.inner.conn) + .map(|_| ()) + .map_err(Into::into) + } + + /// Update the user with the given email and service ID with the given `generation` and + /// `keys_changed_at`. + fn put_user_sync(&self, params: params::PutUser) -> DbResult<results::PutUser> { + // The `where` clause on this statement is designed as an extra layer of + // protection, to ensure that concurrent updates don't accidentally move + // timestamp fields backwards in time.
The handling of `keys_changed_at` + // is additionally weird because we want to treat the default `NULL` value + // as zero. + const QUERY: &str = r#" + UPDATE users + SET generation = ?, + keys_changed_at = ? + WHERE service = ? + AND email = ? + AND generation <= ? + AND COALESCE(keys_changed_at, 0) <= COALESCE(?, keys_changed_at, 0) + AND replaced_at IS NULL + "#; + + let mut metrics = self.metrics.clone(); + metrics.start_timer("storage.put_user", None); + + diesel::sql_query(QUERY) + .bind::<Bigint, _>(params.generation) + .bind::<Nullable<Bigint>, _>(params.keys_changed_at) + .bind::<Integer, _>(&params.service_id) + .bind::<Text, _>(&params.email) + .bind::<Bigint, _>(params.generation) + .bind::<Nullable<Bigint>, _>(params.keys_changed_at) + .execute(&self.inner.conn) + .map(|_| ()) + .map_err(Into::into) + } + + /// Create a new user. + fn post_user_sync(&self, user: params::PostUser) -> DbResult<results::PostUser> { + const QUERY: &str = r#" + INSERT INTO users (service, email, generation, client_state, created_at, nodeid, keys_changed_at, replaced_at) + VALUES (?, ?, ?, ?, ?, ?, ?, NULL); + "#; + + let mut metrics = self.metrics.clone(); + metrics.start_timer("storage.post_user", None); + + diesel::sql_query(QUERY) + .bind::<Integer, _>(user.service_id) + .bind::<Text, _>(&user.email) + .bind::<Bigint, _>(user.generation) + .bind::<Text, _>(&user.client_state) + .bind::<Bigint, _>(user.created_at) + .bind::<Bigint, _>(user.node_id) + .bind::<Nullable<Bigint>, _>(user.keys_changed_at) + .execute(&self.inner.conn)?; + + diesel::sql_query(Self::LAST_INSERT_ID_QUERY) + .bind::<Text, _>(&user.email) + .get_result::<results::PostUser>(&self.inner.conn) + .map_err(Into::into) + } + + fn check_sync(&self) -> DbResult<results::Check> { + // has the database been up for more than 0 seconds? + let result = diesel::sql_query("SHOW STATUS LIKE \"Uptime\"").execute(&self.inner.conn)?; + Ok(result as u64 > 0) + } + + /// Gets the least-loaded node that has available slots. + fn get_best_node_sync(&self, params: params::GetBestNode) -> DbResult<results::GetBestNode> { + const DEFAULT_CAPACITY_RELEASE_RATE: f32 = 0.1; + const GET_BEST_NODE_QUERY: &str = r#" + SELECT id, node + FROM nodes + WHERE service = ? + AND available > 0 + AND capacity > current_load + AND downed = 0 + AND backoff = 0 + ORDER BY LOG(current_load) / LOG(capacity) + LIMIT 1 + "#; + const RELEASE_CAPACITY_QUERY: &str = r#" + UPDATE nodes + SET available = LEAST(capacity * ?, capacity - current_load) + WHERE service = ? + AND available <= 0 + AND capacity > current_load + AND downed = 0 + "#; + const SPANNER_QUERY: &str = r#" + SELECT id, node + FROM nodes + WHERE id = ? + LIMIT 1 + "#; + + let mut metrics = self.metrics.clone(); + metrics.start_timer("storage.get_best_node", None); + + if let Some(spanner_node_id) = self.spanner_node_id { + diesel::sql_query(SPANNER_QUERY) + .bind::<Integer, _>(spanner_node_id) + .get_result::<results::GetBestNode>(&self.inner.conn) + .map_err(|e| { + let mut db_error = + DbError::internal(format!("unable to get Spanner node: {}", e)); + db_error.status = StatusCode::SERVICE_UNAVAILABLE; + db_error + }) + } else { + // We may have to retry the query if we need to release more capacity. This loop allows + // a maximum of five retries before bailing out. + for _ in 0..5 { + let maybe_result = diesel::sql_query(GET_BEST_NODE_QUERY) + .bind::<Integer, _>(params.service_id) + .get_result::<results::GetBestNode>(&self.inner.conn) + .optional()?; + + if let Some(result) = maybe_result { + return Ok(result); + } + + // There were no available nodes. Try to release additional capacity from any nodes + // that are not fully occupied.
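// (Annotation, not part of the original patch: with the figures from
// test_gradual_release_of_node_capacity below — capacity = 8, current_load = 5,
// default rate 0.1 — this UPDATE sets available = LEAST(8 * 0.1, 8 - 5) = 0.8,
// which the integer `available` column stores as a single freed slot, so each
// pass of the retry loop releases roughly 10% of a node's capacity.)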
+ let affected_rows = diesel::sql_query(RELEASE_CAPACITY_QUERY) + .bind::<Float, _>( + params + .capacity_release_rate + .unwrap_or(DEFAULT_CAPACITY_RELEASE_RATE), + ) + .bind::<Integer, _>(params.service_id) + .execute(&self.inner.conn)?; + + // If no nodes were affected by the last query, give up. + if affected_rows == 0 { + break; + } + } + + let mut db_error = DbError::internal("unable to get a node".to_owned()); + db_error.status = StatusCode::SERVICE_UNAVAILABLE; + Err(db_error) + } + } + + fn add_user_to_node_sync( + &self, + params: params::AddUserToNode, + ) -> DbResult<results::AddUserToNode> { + let mut metrics = self.metrics.clone(); + metrics.start_timer("storage.add_user_to_node", None); + + const QUERY: &str = r#" + UPDATE nodes + SET current_load = current_load + 1, + available = GREATEST(available - 1, 0) + WHERE service = ? + AND node = ? + "#; + const SPANNER_QUERY: &str = r#" + UPDATE nodes + SET current_load = current_load + 1 + WHERE service = ? + AND node = ? + "#; + + let query = if self.spanner_node_id.is_some() { + SPANNER_QUERY + } else { + QUERY + }; + + diesel::sql_query(query) + .bind::<Integer, _>(params.service_id) + .bind::<Text, _>(&params.node) + .execute(&self.inner.conn) + .map(|_| ()) + .map_err(Into::into) + } + + fn get_users_sync(&self, params: params::GetUsers) -> DbResult<results::GetUsers> { + let mut metrics = self.metrics.clone(); + metrics.start_timer("storage.get_users", None); + + const QUERY: &str = r#" + SELECT uid, nodes.node, generation, keys_changed_at, client_state, created_at, + replaced_at + FROM users + LEFT OUTER JOIN nodes ON users.nodeid = nodes.id + WHERE email = ? + AND users.service = ? + ORDER BY created_at DESC, uid DESC + LIMIT 20 + "#; + + diesel::sql_query(QUERY) + .bind::<Text, _>(&params.email) + .bind::<Integer, _>(params.service_id) + .load::<results::GetRawUser>(&self.inner.conn) + .map_err(Into::into) + } + + /// Gets the user with the given email and service ID, or if one doesn't exist, allocates a new + /// user. + fn get_or_create_user_sync( + &self, + params: params::GetOrCreateUser, + ) -> DbResult<results::GetOrCreateUser> { + let mut raw_users = self.get_users_sync(params::GetUsers { + service_id: params.service_id, + email: params.email.clone(), + })?; + + if raw_users.is_empty() { + // There are no users in the database with the given email and service ID, so + // allocate a new one. + let allocate_user_result = + self.allocate_user_sync(params.clone() as params::AllocateUser)?; + + Ok(results::GetOrCreateUser { + uid: allocate_user_result.uid, + email: params.email, + client_state: params.client_state, + generation: params.generation, + node: allocate_user_result.node, + keys_changed_at: params.keys_changed_at, + created_at: allocate_user_result.created_at, + replaced_at: None, + first_seen_at: allocate_user_result.created_at, + old_client_states: vec![], + }) + } else { + raw_users.sort_by_key(|raw_user| (raw_user.generation, raw_user.created_at)); + raw_users.reverse(); + + // The user with the greatest `generation` and `created_at` is the current user + let raw_user = raw_users[0].clone(); + + // Collect any old client states that differ from the current client state + let old_client_states = { + raw_users[1..] + .iter() + .map(|user| user.client_state.clone()) + .filter(|client_state| client_state != &raw_user.client_state) + .collect() + }; + + // Make sure every old row is marked as replaced. They might not be, due to races in row + // creation. + for old_user in &raw_users[1..]
{ + if old_user.replaced_at.is_none() { + let params = params::ReplaceUser { + uid: old_user.uid, + service_id: params.service_id, + replaced_at: raw_user.created_at, + }; + + self.replace_user_sync(params)?; + } + } + + let first_seen_at = raw_users[raw_users.len() - 1].created_at; + + match (raw_user.replaced_at, raw_user.node) { + // If the most up-to-date user is marked as replaced or does not have a node + // assignment, allocate a new user. Note that, if the current user is marked + // as replaced, we do not want to create a new user with the account metadata + // in the parameters to this method. Rather, we want to create a duplicate of + // the replaced user assigned to a new node. This distinction is important + // because the account metadata in the parameters to this method may not match + // that currently stored on the most up-to-date user and may be invalid. + (Some(_), _) | (_, None) if raw_user.generation < MAX_GENERATION => { + let allocate_user_result = { + self.allocate_user_sync(params::AllocateUser { + service_id: params.service_id, + email: params.email.clone(), + generation: raw_user.generation, + client_state: raw_user.client_state.clone(), + keys_changed_at: raw_user.keys_changed_at, + capacity_release_rate: params.capacity_release_rate, + })? + }; + + Ok(results::GetOrCreateUser { + uid: allocate_user_result.uid, + email: params.email, + client_state: raw_user.client_state, + generation: raw_user.generation, + node: allocate_user_result.node, + keys_changed_at: raw_user.keys_changed_at, + created_at: allocate_user_result.created_at, + replaced_at: None, + first_seen_at, + old_client_states, + }) + } + // The most up-to-date user has a node. Note that this user may be retired or + // replaced. + (_, Some(node)) => Ok(results::GetOrCreateUser { + uid: raw_user.uid, + email: params.email, + client_state: raw_user.client_state, + generation: raw_user.generation, + node, + keys_changed_at: raw_user.keys_changed_at, + created_at: raw_user.created_at, + replaced_at: None, + first_seen_at, + old_client_states, + }), + // The most up-to-date user doesn't have a node and is retired. This is an internal + // service error for compatibility reasons (the legacy Tokenserver returned an + // internal service error in this situation). + (_, None) => Err(DbError::internal("Tokenserver user retired".to_owned())), + } + } + } + + /// Creates a new user and assigns them to a node. + fn allocate_user_sync(&self, params: params::AllocateUser) -> DbResult<results::AllocateUser> { + let mut metrics = self.metrics.clone(); + metrics.start_timer("storage.allocate_user", None); + + // Get the least-loaded node + let node = self.get_best_node_sync(params::GetBestNode { + service_id: params.service_id, + capacity_release_rate: params.capacity_release_rate, + })?; + + // Decrement `available` and increment `current_load` on the node assigned to the user. + self.add_user_to_node_sync(params::AddUserToNode { + service_id: params.service_id, + node: node.node.clone(), + })?; + + let created_at = { + let start = SystemTime::now(); + start.duration_since(UNIX_EPOCH).unwrap().as_millis() as i64 + }; + let uid = self + .post_user_sync(params::PostUser { + service_id: params.service_id, + email: params.email.clone(), + generation: params.generation, + client_state: params.client_state.clone(), + created_at, + node_id: node.id, + keys_changed_at: params.keys_changed_at, + })?
+ .id; + + Ok(results::AllocateUser { + uid, + node: node.node, + created_at, + }) + } + + pub fn get_service_id_sync( + &self, + params: params::GetServiceId, + ) -> DbResult<results::GetServiceId> { + const QUERY: &str = r#" + SELECT id + FROM services + WHERE service = ? + "#; + + if let Some(id) = self.service_id { + Ok(results::GetServiceId { id }) + } else { + diesel::sql_query(QUERY) + .bind::<Text, _>(params.service) + .get_result::<results::GetServiceId>(&self.inner.conn) + .map_err(Into::into) + } + } + + #[cfg(test)] + fn set_user_created_at_sync( + &self, + params: params::SetUserCreatedAt, + ) -> DbResult<results::SetUserCreatedAt> { + const QUERY: &str = r#" + UPDATE users + SET created_at = ? + WHERE uid = ? + "#; + diesel::sql_query(QUERY) + .bind::<Bigint, _>(params.created_at) + .bind::<Bigint, _>(&params.uid) + .execute(&self.inner.conn) + .map(|_| ()) + .map_err(Into::into) + } + + #[cfg(test)] + fn set_user_replaced_at_sync( + &self, + params: params::SetUserReplacedAt, + ) -> DbResult<results::SetUserReplacedAt> { + const QUERY: &str = r#" + UPDATE users + SET replaced_at = ? + WHERE uid = ? + "#; + diesel::sql_query(QUERY) + .bind::<Bigint, _>(params.replaced_at) + .bind::<Bigint, _>(&params.uid) + .execute(&self.inner.conn) + .map(|_| ()) + .map_err(Into::into) + } + + #[cfg(test)] + fn get_user_sync(&self, params: params::GetUser) -> DbResult<results::GetUser> { + const QUERY: &str = r#" + SELECT service, email, generation, client_state, replaced_at, nodeid, keys_changed_at + FROM users + WHERE uid = ? + "#; + + diesel::sql_query(QUERY) + .bind::<Bigint, _>(params.id) + .get_result::<results::GetUser>(&self.inner.conn) + .map_err(Into::into) + } + + #[cfg(test)] + fn post_node_sync(&self, params: params::PostNode) -> DbResult<results::PostNode> { + const QUERY: &str = r#" + INSERT INTO nodes (service, node, available, current_load, capacity, downed, backoff) + VALUES (?, ?, ?, ?, ?, ?, ?) + "#; + diesel::sql_query(QUERY) + .bind::<Integer, _>(params.service_id) + .bind::<Text, _>(&params.node) + .bind::<Integer, _>(params.available) + .bind::<Integer, _>(params.current_load) + .bind::<Integer, _>(params.capacity) + .bind::<Integer, _>(params.downed) + .bind::<Integer, _>(params.backoff) + .execute(&self.inner.conn)?; + + diesel::sql_query(Self::LAST_INSERT_ID_QUERY) + .get_result::<results::PostNode>(&self.inner.conn) + .map_err(Into::into) + } + + #[cfg(test)] + fn get_node_sync(&self, params: params::GetNode) -> DbResult<results::GetNode> { + const QUERY: &str = r#" + SELECT * + FROM nodes + WHERE id = ? + "#; + + diesel::sql_query(QUERY) + .bind::<Bigint, _>(params.id) + .get_result::<results::GetNode>(&self.inner.conn) + .map_err(Into::into) + } + + #[cfg(test)] + fn unassign_node_sync(&self, params: params::UnassignNode) -> DbResult<results::UnassignNode> { + const QUERY: &str = r#" + UPDATE users + SET replaced_at = ? + WHERE nodeid = ? + "#; + + let current_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis() as i64; + + diesel::sql_query(QUERY) + .bind::<Bigint, _>(current_time) + .bind::<Bigint, _>(params.node_id) + .execute(&self.inner.conn) + .map(|_| ()) + .map_err(Into::into) + } + + #[cfg(test)] + fn remove_node_sync(&self, params: params::RemoveNode) -> DbResult<results::RemoveNode> { + const QUERY: &str = "DELETE FROM nodes WHERE id = ?"; + + diesel::sql_query(QUERY) + .bind::<Bigint, _>(params.node_id) + .execute(&self.inner.conn) + .map(|_| ()) + .map_err(Into::into) + } + + #[cfg(test)] + fn post_service_sync(&self, params: params::PostService) -> DbResult<results::PostService> { + const INSERT_SERVICE_QUERY: &str = r#" + INSERT INTO services (service, pattern) + VALUES (?, ?)
+ "#; + + diesel::sql_query(INSERT_SERVICE_QUERY) + .bind::<Text, _>(&params.service) + .bind::<Text, _>(&params.pattern) + .execute(&self.inner.conn)?; + + diesel::sql_query(Self::LAST_INSERT_ID_QUERY) + .get_result::<results::LastInsertId>(&self.inner.conn) + .map(|result| results::PostService { + id: result.id as i32, + }) + .map_err(Into::into) + } +} + +impl Db for TokenserverDb { + sync_db_method!(replace_user, replace_user_sync, ReplaceUser); + sync_db_method!(replace_users, replace_users_sync, ReplaceUsers); + sync_db_method!(post_user, post_user_sync, PostUser); + + sync_db_method!(put_user, put_user_sync, PutUser); + sync_db_method!(get_node_id, get_node_id_sync, GetNodeId); + sync_db_method!(get_best_node, get_best_node_sync, GetBestNode); + sync_db_method!(add_user_to_node, add_user_to_node_sync, AddUserToNode); + sync_db_method!(get_users, get_users_sync, GetUsers); + sync_db_method!(get_or_create_user, get_or_create_user_sync, GetOrCreateUser); + sync_db_method!(get_service_id, get_service_id_sync, GetServiceId); + + #[cfg(test)] + sync_db_method!(get_user, get_user_sync, GetUser); + + fn check(&self) -> DbFuture<'_, results::Check, DbError> { + let db = self.clone(); + Box::pin(self.blocking_threadpool.spawn(move || db.check_sync())) + } + + #[cfg(test)] + sync_db_method!( + set_user_created_at, + set_user_created_at_sync, + SetUserCreatedAt + ); + + #[cfg(test)] + sync_db_method!( + set_user_replaced_at, + set_user_replaced_at_sync, + SetUserReplacedAt + ); + + #[cfg(test)] + sync_db_method!(post_node, post_node_sync, PostNode); + + #[cfg(test)] + sync_db_method!(get_node, get_node_sync, GetNode); + + #[cfg(test)] + sync_db_method!(unassign_node, unassign_node_sync, UnassignNode); + + #[cfg(test)] + sync_db_method!(remove_node, remove_node_sync, RemoveNode); + + #[cfg(test)] + sync_db_method!(post_service, post_service_sync, PostService); +} + +pub trait Db { + fn replace_user( + &self, + params: params::ReplaceUser, + ) -> DbFuture<'_, results::ReplaceUser, DbError>; + + fn replace_users( + &self, + params: params::ReplaceUsers, + ) -> DbFuture<'_, results::ReplaceUsers, DbError>; + + fn post_user(&self, params: params::PostUser) -> DbFuture<'_, results::PostUser, DbError>; + + fn put_user(&self, params: params::PutUser) -> DbFuture<'_, results::PutUser, DbError>; + + fn check(&self) -> DbFuture<'_, results::Check, DbError>; + + fn get_node_id(&self, params: params::GetNodeId) -> DbFuture<'_, results::GetNodeId, DbError>; + + fn get_best_node( + &self, + params: params::GetBestNode, + ) -> DbFuture<'_, results::GetBestNode, DbError>; + + fn add_user_to_node( + &self, + params: params::AddUserToNode, + ) -> DbFuture<'_, results::AddUserToNode, DbError>; + + fn get_users(&self, params: params::GetUsers) -> DbFuture<'_, results::GetUsers, DbError>; + + fn get_or_create_user( + &self, + params: params::GetOrCreateUser, + ) -> DbFuture<'_, results::GetOrCreateUser, DbError>; + + fn get_service_id( + &self, + params: params::GetServiceId, + ) -> DbFuture<'_, results::GetServiceId, DbError>; + + #[cfg(test)] + fn set_user_created_at( + &self, + params: params::SetUserCreatedAt, + ) -> DbFuture<'_, results::SetUserCreatedAt, DbError>; + + #[cfg(test)] + fn set_user_replaced_at( + &self, + params: params::SetUserReplacedAt, + ) -> DbFuture<'_, results::SetUserReplacedAt, DbError>; + + #[cfg(test)] + fn get_user(&self, params: params::GetUser) -> DbFuture<'_, results::GetUser, DbError>; + + #[cfg(test)] + fn post_node(&self, params: params::PostNode) -> DbFuture<'_, results::PostNode, DbError>; + + #[cfg(test)] + fn
get_node(&self, params: params::GetNode) -> DbFuture<'_, results::GetNode, DbError>; + + #[cfg(test)] + fn unassign_node( + &self, + params: params::UnassignNode, + ) -> DbFuture<'_, results::UnassignNode, DbError>; + + #[cfg(test)] + fn remove_node(&self, params: params::RemoveNode) + -> DbFuture<'_, results::RemoveNode, DbError>; + + #[cfg(test)] + fn post_service( + &self, + params: params::PostService, + ) -> DbFuture<'_, results::PostService, DbError>; +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::thread; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + use syncserver_settings::Settings; + + use crate::pool::{DbPool, TokenserverPool}; + + #[tokio::test] + async fn test_update_generation() -> DbResult<()> { + let pool = db_pool().await?; + let db = pool.get().await?; + + // Add a service + let service_id = db + .post_service(params::PostService { + service: "sync-1.5".to_owned(), + pattern: "{node}/1.5/{uid}".to_owned(), + }) + .await? + .id; + + // Add a node + let node_id = db + .post_node(params::PostNode { + service_id, + node: "https://node1".to_owned(), + ..Default::default() + }) + .await? + .id; + + // Add a user + let email = "test_user"; + let uid = db + .post_user(params::PostUser { + service_id, + node_id, + email: email.to_owned(), + ..Default::default() + }) + .await? + .id; + + let user = db.get_user(params::GetUser { id: uid }).await?; + + assert_eq!(user.generation, 0); + assert_eq!(user.client_state, ""); + + // Changing generation should leave other properties unchanged. + db.put_user(params::PutUser { + email: email.to_owned(), + service_id, + generation: 42, + keys_changed_at: user.keys_changed_at, + }) + .await?; + + let user = db.get_user(params::GetUser { id: uid }).await?; + + assert_eq!(user.node_id, node_id); + assert_eq!(user.generation, 42); + assert_eq!(user.client_state, ""); + + // It's not possible to move the generation number backwards. + db.put_user(params::PutUser { + email: email.to_owned(), + service_id, + generation: 17, + keys_changed_at: user.keys_changed_at, + }) + .await?; + + let user = db.get_user(params::GetUser { id: uid }).await?; + + assert_eq!(user.node_id, node_id); + assert_eq!(user.generation, 42); + assert_eq!(user.client_state, ""); + + Ok(()) + } + + #[tokio::test] + async fn test_update_keys_changed_at() -> DbResult<()> { + let pool = db_pool().await?; + let db = pool.get().await?; + + // Add a service + let service_id = db + .post_service(params::PostService { + service: "sync-1.5".to_owned(), + pattern: "{node}/1.5/{uid}".to_owned(), + }) + .await? + .id; + + // Add a node + let node_id = db + .post_node(params::PostNode { + service_id, + node: "https://node".to_owned(), + ..Default::default() + }) + .await? + .id; + + // Add a user + let email = "test_user"; + let uid = db + .post_user(params::PostUser { + service_id, + node_id, + email: email.to_owned(), + ..Default::default() + }) + .await? + .id; + + let user = db.get_user(params::GetUser { id: uid }).await?; + + assert_eq!(user.keys_changed_at, None); + assert_eq!(user.client_state, ""); + + // Changing keys_changed_at should leave other properties unchanged. + db.put_user(params::PutUser { + email: email.to_owned(), + service_id, + generation: user.generation, + keys_changed_at: Some(42), + }) + .await?; + + let user = db.get_user(params::GetUser { id: uid }).await?; + + assert_eq!(user.node_id, node_id); + assert_eq!(user.keys_changed_at, Some(42)); + assert_eq!(user.client_state, ""); + + // It's not possible to move keys_changed_at backwards. 
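// (Annotation, not part of the original patch: this is enforced by the
// COALESCE guard in put_user_sync's UPDATE — a row only matches when
// COALESCE(keys_changed_at, 0) <= COALESCE(?, keys_changed_at, 0) — so the
// Some(17) write below matches no rows and is silently dropped, which is
// what the assertions that follow verify.)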
+ db.put_user(params::PutUser { + email: email.to_owned(), + service_id, + generation: user.generation, + keys_changed_at: Some(17), + }) + .await?; + + let user = db.get_user(params::GetUser { id: uid }).await?; + + assert_eq!(user.node_id, node_id); + assert_eq!(user.keys_changed_at, Some(42)); + assert_eq!(user.client_state, ""); + + Ok(()) + } + + #[tokio::test] + async fn replace_users() -> DbResult<()> { + const MILLISECONDS_IN_A_MINUTE: i64 = 60 * 1000; + const MILLISECONDS_IN_AN_HOUR: i64 = MILLISECONDS_IN_A_MINUTE * 60; + + let pool = db_pool().await?; + let db = pool.get().await?; + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis() as i64; + let an_hour_ago = now - MILLISECONDS_IN_AN_HOUR; + + // Add a service + let service_id = db + .post_service(params::PostService { + service: "sync-1.5".to_owned(), + pattern: "{node}/1.5/{uid}".to_owned(), + }) + .await? + .id; + + // Add a node + let node_id = db + .post_node(params::PostNode { + service_id, + ..Default::default() + }) + .await?; + + // Add a user to be updated + let email1 = "test_user_1"; + let uid1 = { + // Set created_at to be an hour ago + let uid = db + .post_user(params::PostUser { + service_id, + node_id: node_id.id, + email: email1.to_owned(), + ..Default::default() + }) + .await? + .id; + + db.set_user_created_at(params::SetUserCreatedAt { + created_at: an_hour_ago, + uid, + }) + .await?; + + uid + }; + + // Add a user that has already been replaced + let uid2 = { + // Set created_at to be an hour ago + let uid = db + .post_user(params::PostUser { + service_id, + node_id: node_id.id, + email: email1.to_owned(), + ..Default::default() + }) + .await? + .id; + + db.set_user_replaced_at(params::SetUserReplacedAt { + replaced_at: an_hour_ago + MILLISECONDS_IN_A_MINUTE, + uid, + }) + .await?; + + db.set_user_created_at(params::SetUserCreatedAt { + created_at: an_hour_ago, + uid, + }) + .await?; + + uid + }; + + // Add a user created too recently + { + let uid = db + .post_user(params::PostUser { + service_id, + node_id: node_id.id, + email: email1.to_owned(), + ..Default::default() + }) + .await? + .id; + + db.set_user_created_at(params::SetUserCreatedAt { + created_at: now + MILLISECONDS_IN_AN_HOUR, + uid, + }) + .await?; + } + + // Add a user with the wrong email address + let email2 = "test_user_2"; + { + // Set created_at to be an hour ago + let uid = db + .post_user(params::PostUser { + service_id, + node_id: node_id.id, + email: email2.to_owned(), + ..Default::default() + }) + .await? + .id; + + db.set_user_created_at(params::SetUserCreatedAt { + created_at: an_hour_ago, + uid, + }) + .await?; + } + + // Add a user with the wrong service + { + let uid = db + .post_user(params::PostUser { + service_id: service_id + 1, + node_id: node_id.id, + email: email1.to_owned(), + ..Default::default() + }) + .await? 
+ .id; + + // Set created_at to be an hour ago + db.set_user_created_at(params::SetUserCreatedAt { + created_at: an_hour_ago, + uid, + }) + .await?; + } + + // Perform the bulk update + db.replace_users(params::ReplaceUsers { + service_id, + email: email1.to_owned(), + replaced_at: now, + }) + .await?; + + // Get all of the users + let users = { + let mut users1 = db + .get_users(params::GetUsers { + email: email1.to_owned(), + service_id, + }) + .await?; + let mut users2 = db + .get_users(params::GetUsers { + email: email2.to_owned(), + service_id, + }) + .await?; + users1.append(&mut users2); + + users1 + }; + + let mut users_with_replaced_at_uids: Vec<i64> = users + .iter() + .filter(|user| user.replaced_at.is_some()) + .map(|user| user.uid) + .collect(); + + users_with_replaced_at_uids.sort_unstable(); + + // The users with replaced_at timestamps should have the expected uids + let mut expected_user_uids = vec![uid1, uid2]; + expected_user_uids.sort_unstable(); + assert_eq!(users_with_replaced_at_uids, expected_user_uids); + + Ok(()) + } + + #[tokio::test] + async fn post_user() -> DbResult<()> { + let pool = db_pool().await?; + let db = pool.get().await?; + + // Add a service + let service_id = db + .post_service(params::PostService { + service: "sync-1.5".to_owned(), + pattern: "{node}/1.5/{uid}".to_owned(), + }) + .await? + .id; + + // Add a node + let post_node_params = params::PostNode { + service_id, + ..Default::default() + }; + let node_id = db.post_node(post_node_params.clone()).await?.id; + + // Add a user + let email1 = "test_user_1"; + let post_user_params1 = params::PostUser { + service_id, + email: email1.to_owned(), + generation: 1, + client_state: "aaaa".to_owned(), + created_at: 2, + node_id, + keys_changed_at: Some(3), + }; + let uid1 = db.post_user(post_user_params1.clone()).await?.id; + + // Add another user + let email2 = "test_user_2"; + let post_user_params2 = params::PostUser { + service_id, + node_id, + email: email2.to_owned(), + ..Default::default() + }; + let uid2 = db.post_user(post_user_params2).await?.id; + + // Ensure that two separate users were created + assert_ne!(uid1, uid2); + + // Get a user + let user = db.get_user(params::GetUser { id: uid1 }).await?; + + // Ensure the user has the expected values + let expected_get_user = results::GetUser { + service_id, + email: email1.to_owned(), + generation: 1, + client_state: "aaaa".to_owned(), + replaced_at: None, + node_id, + keys_changed_at: Some(3), + }; + + assert_eq!(user, expected_get_user); + + Ok(()) + } + + #[tokio::test] + async fn get_node_id() -> DbResult<()> { + let pool = db_pool().await?; + let db = pool.get().await?; + + // Add a service + let service_id = db + .post_service(params::PostService { + service: "sync-1.5".to_owned(), + pattern: "{node}/1.5/{uid}".to_owned(), + }) + .await? + .id; + + // Add a node + let node_id1 = db + .post_node(params::PostNode { + service_id, + node: "https://node1".to_owned(), + ..Default::default() + }) + .await? + .id; + + // Add another node + db.post_node(params::PostNode { + service_id, + node: "https://node2".to_owned(), + ..Default::default() + }) + .await?; + + // Get the ID of the first node + let id = db + .get_node_id(params::GetNodeId { + service_id, + node: "https://node1".to_owned(), + }) + .await?
+ .id; + + // The ID should match that of the first node + assert_eq!(node_id1, id); + + Ok(()) + } + + #[tokio::test] + async fn test_node_allocation() -> DbResult<()> { + let pool = db_pool().await?; + let db = pool.get_tokenserver_db().await?; + + // Add a service + let service_id = db + .post_service(params::PostService { + service: "sync-1.5".to_owned(), + pattern: "{node}/1.5/{uid}".to_owned(), + }) + .await? + .id; + + // Add a node + let node_id = db + .post_node(params::PostNode { + service_id, + node: "https://node1".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + ..Default::default() + }) + .await? + .id; + + // Allocating a user assigns it to the node + let user = db.allocate_user_sync(params::AllocateUser { + service_id, + generation: 1234, + email: "test@test.com".to_owned(), + client_state: "aaaa".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + })?; + assert_eq!(user.node, "https://node1"); + + // Getting the user from the database does not affect node assignment + let user = db.get_user(params::GetUser { id: user.uid }).await?; + assert_eq!(user.node_id, node_id); + + Ok(()) + } + + #[tokio::test] + async fn test_allocation_to_least_loaded_node() -> DbResult<()> { + let pool = db_pool().await?; + let db = pool.get_tokenserver_db().await?; + + // Add a service + let service_id = db + .post_service(params::PostService { + service: "sync-1.5".to_owned(), + pattern: "{node}/1.5/{uid}".to_owned(), + }) + .await? + .id; + + // Add two nodes + db.post_node(params::PostNode { + service_id, + node: "https://node1".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + ..Default::default() + }) + .await?; + + db.post_node(params::PostNode { + service_id, + node: "https://node2".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + ..Default::default() + }) + .await?; + + // Allocate two users + let user1 = db.allocate_user_sync(params::AllocateUser { + service_id, + generation: 1234, + email: "test1@test.com".to_owned(), + client_state: "aaaa".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + })?; + + let user2 = db.allocate_user_sync(params::AllocateUser { + service_id, + generation: 1234, + email: "test2@test.com".to_owned(), + client_state: "aaaa".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + })?; + + // Because users are always assigned to the least-loaded node, the users should have been + // assigned to different nodes + assert_ne!(user1.node, user2.node); + + Ok(()) + } + + #[tokio::test] + async fn test_allocation_is_not_allowed_to_downed_nodes() -> DbResult<()> { + let pool = db_pool().await?; + let db = pool.get_tokenserver_db().await?; + + // Add a service + let service_id = db + .post_service(params::PostService { + service: "sync-1.5".to_owned(), + pattern: "{node}/1.5/{uid}".to_owned(), + }) + .await? 
+ .id; + + // Add a downed node + db.post_node(params::PostNode { + service_id, + node: "https://node1".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + downed: 1, + ..Default::default() + }) + .await?; + + // User allocation fails because allocation is not allowed to downed nodes + let result = db.allocate_user_sync(params::AllocateUser { + service_id, + generation: 1234, + email: "test@test.com".to_owned(), + client_state: "aaaa".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }); + let error = result.unwrap_err(); + assert_eq!(error.to_string(), "Unexpected error: unable to get a node"); + + Ok(()) + } + + #[tokio::test] + async fn test_allocation_is_not_allowed_to_backoff_nodes() -> DbResult<()> { + let pool = db_pool().await?; + let db = pool.get_tokenserver_db().await?; + + // Add a service + let service_id = db + .post_service(params::PostService { + service: "sync-1.5".to_owned(), + pattern: "{node}/1.5/{uid}".to_owned(), + }) + .await? + .id; + + // Add a backoff node + db.post_node(params::PostNode { + service_id, + node: "https://node1".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + backoff: 1, + ..Default::default() + }) + .await?; + + // User allocation fails because allocation is not allowed to backoff nodes + let result = db.allocate_user_sync(params::AllocateUser { + service_id, + generation: 1234, + email: "test@test.com".to_owned(), + client_state: "aaaa".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }); + let error = result.unwrap_err(); + assert_eq!(error.to_string(), "Unexpected error: unable to get a node"); + + Ok(()) + } + + #[tokio::test] + async fn test_node_reassignment_when_records_are_replaced() -> DbResult<()> { + let pool = db_pool().await?; + let db = pool.get_tokenserver_db().await?; + + // Add a service + let service_id = db + .post_service(params::PostService { + service: "sync-1.5".to_owned(), + pattern: "{node}/1.5/{uid}".to_owned(), + }) + .await? 
+ .id; + + // Add a node + db.post_node(params::PostNode { + service_id, + node: "https://node1".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + ..Default::default() + }) + .await?; + + // Allocate a user + let allocate_user_result = db.allocate_user_sync(params::AllocateUser { + service_id, + generation: 1234, + email: "test@test.com".to_owned(), + client_state: "aaaa".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + })?; + let user1 = db + .get_user(params::GetUser { + id: allocate_user_result.uid, + }) + .await?; + + // Mark the user as replaced + db.replace_user(params::ReplaceUser { + uid: allocate_user_result.uid, + service_id, + replaced_at: 1234, + }) + .await?; + + let user2 = db + .get_or_create_user(params::GetOrCreateUser { + email: "test@test.com".to_owned(), + service_id, + generation: 1235, + client_state: "bbbb".to_owned(), + keys_changed_at: Some(1235), + capacity_release_rate: None, + }) + .await?; + + // Calling get_or_create_user() results in the creation of a new user record, since the + // previous record was marked as replaced + assert_ne!(allocate_user_result.uid, user2.uid); + + // The account metadata should match that of the original user and *not* that in the + // method parameters + assert_eq!(user1.generation, user2.generation); + assert_eq!(user1.keys_changed_at, user2.keys_changed_at); + assert_eq!(user1.client_state, user2.client_state); + + Ok(()) + } + + #[tokio::test] + async fn test_node_reassignment_not_done_for_retired_users() -> DbResult<()> { + let pool = db_pool().await?; + let db = pool.get().await?; + + // Add a service + let service_id = db + .post_service(params::PostService { + service: "sync-1.5".to_owned(), + pattern: "{node}/1.5/{uid}".to_owned(), + }) + .await? + .id; + + // Add a node + db.post_node(params::PostNode { + service_id, + node: "https://node1".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + ..Default::default() + }) + .await?; + + // Add a retired user + let user1 = db + .get_or_create_user(params::GetOrCreateUser { + service_id, + generation: MAX_GENERATION, + email: "test@test.com".to_owned(), + client_state: "aaaa".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + let user2 = db + .get_or_create_user(params::GetOrCreateUser { + service_id, + generation: 1234, + email: "test@test.com".to_owned(), + client_state: "aaaa".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + // Calling get_or_create_user() does not update the user's node + assert_eq!(user1.uid, user2.uid); + assert_eq!(user2.generation, MAX_GENERATION); + assert_eq!(user1.client_state, user2.client_state); + + Ok(()) + } + + #[tokio::test] + async fn test_node_reassignment_and_removal() -> DbResult<()> { + let pool = db_pool().await?; + let db = pool.get().await?; + + // Add a service + let service_id = db + .post_service(params::PostService { + service: "sync-1.5".to_owned(), + pattern: "{node}/1.5/{uid}".to_owned(), + }) + .await? + .id; + + // Add two nodes + let node1_id = db + .post_node(params::PostNode { + service_id, + node: "https://node1".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + ..Default::default() + }) + .await? + .id; + + let node2_id = db + .post_node(params::PostNode { + service_id, + node: "https://node2".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + ..Default::default() + }) + .await? + .id; + + // Create four users. We should get two on each node. 
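// (Annotation, not part of the original patch: the even split falls out of
// GET_BEST_NODE_QUERY ordering candidates by LOG(current_load) / LOG(capacity);
// each allocation bumps the chosen node's load, so the other node presents the
// lower load ratio on the next pick.)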
+        let user1 = db
+            .get_or_create_user(params::GetOrCreateUser {
+                service_id,
+                generation: 1234,
+                email: "test1@test.com".to_owned(),
+                client_state: "aaaa".to_owned(),
+                keys_changed_at: Some(1234),
+                capacity_release_rate: None,
+            })
+            .await?;
+
+        let user2 = db
+            .get_or_create_user(params::GetOrCreateUser {
+                service_id,
+                generation: 1234,
+                email: "test2@test.com".to_owned(),
+                client_state: "aaaa".to_owned(),
+                keys_changed_at: Some(1234),
+                capacity_release_rate: None,
+            })
+            .await?;
+
+        let user3 = db
+            .get_or_create_user(params::GetOrCreateUser {
+                service_id,
+                generation: 1234,
+                email: "test3@test.com".to_owned(),
+                client_state: "aaaa".to_owned(),
+                keys_changed_at: Some(1234),
+                capacity_release_rate: None,
+            })
+            .await?;
+
+        let user4 = db
+            .get_or_create_user(params::GetOrCreateUser {
+                service_id,
+                generation: 1234,
+                email: "test4@test.com".to_owned(),
+                client_state: "aaaa".to_owned(),
+                keys_changed_at: Some(1234),
+                capacity_release_rate: None,
+            })
+            .await?;
+
+        let node1_count = [&user1, &user2, &user3, &user4]
+            .iter()
+            .filter(|user| user.node == "https://node1")
+            .count();
+        assert_eq!(node1_count, 2);
+        let node2_count = [&user1, &user2, &user3, &user4]
+            .iter()
+            .filter(|user| user.node == "https://node2")
+            .count();
+        assert_eq!(node2_count, 2);
+
+        // Clear the assignments on the first node.
+        db.unassign_node(params::UnassignNode { node_id: node1_id })
+            .await?;
+
+        // The users previously on the first node should balance across both nodes,
+        // giving 1 on the first node and 3 on the second node.
+        let mut node1_count = 0;
+        let mut node2_count = 0;
+
+        for user in [&user1, &user2, &user3, &user4] {
+            let new_user = db
+                .get_or_create_user(params::GetOrCreateUser {
+                    service_id,
+                    email: user.email.clone(),
+                    generation: user.generation,
+                    client_state: user.client_state.clone(),
+                    keys_changed_at: user.keys_changed_at,
+                    capacity_release_rate: None,
+                })
+                .await?;
+
+            if new_user.node == "https://node1" {
+                node1_count += 1;
+            } else {
+                assert_eq!(new_user.node, "https://node2");
+
+                node2_count += 1;
+            }
+        }
+
+        assert_eq!(node1_count, 1);
+        assert_eq!(node2_count, 3);
+
+        // Remove the second node. Everyone should end up on the first node.
+        db.remove_node(params::RemoveNode { node_id: node2_id })
+            .await?;
+
+        // Every user should be on the first node now.
+        for user in [&user1, &user2, &user3, &user4] {
+            let new_user = db
+                .get_or_create_user(params::GetOrCreateUser {
+                    service_id,
+                    email: user.email.clone(),
+                    generation: user.generation,
+                    client_state: user.client_state.clone(),
+                    keys_changed_at: user.keys_changed_at,
+                    capacity_release_rate: None,
+                })
+                .await?;
+
+            assert_eq!(new_user.node, "https://node1");
+        }
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_gradual_release_of_node_capacity() -> DbResult<()> {
+        let pool = db_pool().await?;
+        let db = pool.get().await?;
+
+        // Add a service
+        let service_id = db
+            .post_service(params::PostService {
+                service: "sync-1.5".to_owned(),
+                pattern: "{node}/1.5/{uid}".to_owned(),
+            })
+            .await?
+            .id;
+
+        // Add two nodes
+        let node1_id = db
+            .post_node(params::PostNode {
+                service_id,
+                node: "https://node1".to_owned(),
+                current_load: 4,
+                capacity: 8,
+                available: 1,
+                ..Default::default()
+            })
+            .await?
+            .id;
+
+        let node2_id = db
+            .post_node(params::PostNode {
+                service_id,
+                node: "https://node2".to_owned(),
+                current_load: 4,
+                capacity: 6,
+                available: 1,
+                ..Default::default()
+            })
+            .await?
+            .id;
+
+        // Two user creations should succeed without releasing capacity on either of the nodes.
+        // The users should be assigned to different nodes.
+        let user = db
+            .get_or_create_user(params::GetOrCreateUser {
+                service_id,
+                generation: 1234,
+                email: "test1@test.com".to_owned(),
+                client_state: "aaaa".to_owned(),
+                keys_changed_at: Some(1234),
+                capacity_release_rate: None,
+            })
+            .await?;
+
+        assert_eq!(user.node, "https://node1");
+        let node = db.get_node(params::GetNode { id: node1_id }).await?;
+        assert_eq!(node.current_load, 5);
+        assert_eq!(node.capacity, 8);
+        assert_eq!(node.available, 0);
+
+        let user = db
+            .get_or_create_user(params::GetOrCreateUser {
+                service_id,
+                generation: 1234,
+                email: "test2@test.com".to_owned(),
+                client_state: "aaaa".to_owned(),
+                keys_changed_at: Some(1234),
+                capacity_release_rate: None,
+            })
+            .await?;
+
+        assert_eq!(user.node, "https://node2");
+        let node = db.get_node(params::GetNode { id: node2_id }).await?;
+        assert_eq!(node.current_load, 5);
+        assert_eq!(node.capacity, 6);
+        assert_eq!(node.available, 0);
+
+        // The next allocation attempt will release 10% more capacity, which is one more slot for
+        // each node.
+        let user = db
+            .get_or_create_user(params::GetOrCreateUser {
+                service_id,
+                generation: 1234,
+                email: "test3@test.com".to_owned(),
+                client_state: "aaaa".to_owned(),
+                keys_changed_at: Some(1234),
+                capacity_release_rate: None,
+            })
+            .await?;
+
+        assert_eq!(user.node, "https://node1");
+        let node = db.get_node(params::GetNode { id: node1_id }).await?;
+        assert_eq!(node.current_load, 6);
+        assert_eq!(node.capacity, 8);
+        assert_eq!(node.available, 0);
+
+        let user = db
+            .get_or_create_user(params::GetOrCreateUser {
+                service_id,
+                generation: 1234,
+                email: "test4@test.com".to_owned(),
+                client_state: "aaaa".to_owned(),
+                keys_changed_at: Some(1234),
+                capacity_release_rate: None,
+            })
+            .await?;
+
+        assert_eq!(user.node, "https://node2");
+        let node = db.get_node(params::GetNode { id: node2_id }).await?;
+        assert_eq!(node.current_load, 6);
+        assert_eq!(node.capacity, 6);
+        assert_eq!(node.available, 0);
+
+        // Now that node2 is full, further allocations will go to node1.
+        let user = db
+            .get_or_create_user(params::GetOrCreateUser {
+                service_id,
+                generation: 1234,
+                email: "test5@test.com".to_owned(),
+                client_state: "aaaa".to_owned(),
+                keys_changed_at: Some(1234),
+                capacity_release_rate: None,
+            })
+            .await?;
+
+        assert_eq!(user.node, "https://node1");
+        let node = db.get_node(params::GetNode { id: node1_id }).await?;
+        assert_eq!(node.current_load, 7);
+        assert_eq!(node.capacity, 8);
+        assert_eq!(node.available, 0);
+
+        let user = db
+            .get_or_create_user(params::GetOrCreateUser {
+                service_id,
+                generation: 1234,
+                email: "test6@test.com".to_owned(),
+                client_state: "aaaa".to_owned(),
+                keys_changed_at: Some(1234),
+                capacity_release_rate: None,
+            })
+            .await?;
+
+        assert_eq!(user.node, "https://node1");
+        let node = db.get_node(params::GetNode { id: node1_id }).await?;
+        assert_eq!(node.current_load, 8);
+        assert_eq!(node.capacity, 8);
+        assert_eq!(node.available, 0);
+
+        // Once the capacity is reached, further user allocations will result in an error.
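+        // (At this point both nodes report current_load == capacity and available == 0,
+        // so there is no slot left to release and the allocation below must fail.)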
+        let result = db
+            .get_or_create_user(params::GetOrCreateUser {
+                service_id,
+                generation: 1234,
+                email: "test7@test.com".to_owned(),
+                client_state: "aaaa".to_owned(),
+                keys_changed_at: Some(1234),
+                capacity_release_rate: None,
+            })
+            .await;
+
+        assert_eq!(
+            result.unwrap_err().to_string(),
+            "Unexpected error: unable to get a node"
+        );
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_correct_created_at_used_during_node_reassignment() -> DbResult<()> {
+        let pool = db_pool().await?;
+        let db = pool.get().await?;
+
+        // Add a service
+        let service_id = db
+            .post_service(params::PostService {
+                service: "sync-1.5".to_owned(),
+                pattern: "{node}/1.5/{uid}".to_owned(),
+            })
+            .await?
+            .id;
+
+        // Add a node
+        let node_id = db
+            .post_node(params::PostNode {
+                service_id,
+                node: "https://node1".to_owned(),
+                current_load: 4,
+                capacity: 8,
+                available: 1,
+                ..Default::default()
+            })
+            .await?
+            .id;
+
+        // Create a user
+        let user1 = db
+            .get_or_create_user(params::GetOrCreateUser {
+                service_id,
+                generation: 1234,
+                email: "test4@test.com".to_owned(),
+                client_state: "aaaa".to_owned(),
+                keys_changed_at: Some(1234),
+                capacity_release_rate: None,
+            })
+            .await?;
+
+        // Clear the user's node
+        db.unassign_node(params::UnassignNode { node_id }).await?;
+
+        // Sleep very briefly to ensure the timestamp created during node reassignment is greater
+        // than the timestamp created during user creation
+        thread::sleep(Duration::from_millis(5));
+
+        // Get the user, prompting the user's reassignment to the same node
+        let user2 = db
+            .get_or_create_user(params::GetOrCreateUser {
+                service_id,
+                generation: 1234,
+                email: "test4@test.com".to_owned(),
+                client_state: "aaaa".to_owned(),
+                keys_changed_at: Some(1234),
+                capacity_release_rate: None,
+            })
+            .await?;
+
+        // The user's timestamp should be updated since a new user record was created.
+        assert!(user2.created_at > user1.created_at);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_correct_created_at_used_during_user_retrieval() -> DbResult<()> {
+        let pool = db_pool().await?;
+        let db = pool.get().await?;
+
+        // Add a service
+        let service_id = db
+            .post_service(params::PostService {
+                service: "sync-1.5".to_owned(),
+                pattern: "{node}/1.5/{uid}".to_owned(),
+            })
+            .await?
+            .id;
+
+        // Add a node
+        db.post_node(params::PostNode {
+            service_id,
+            node: "https://node1".to_owned(),
+            current_load: 4,
+            capacity: 8,
+            available: 1,
+            ..Default::default()
+        })
+        .await?;
+
+        // Create a user
+        let user1 = db
+            .get_or_create_user(params::GetOrCreateUser {
+                service_id,
+                generation: 1234,
+                email: "test4@test.com".to_owned(),
+                client_state: "aaaa".to_owned(),
+                keys_changed_at: Some(1234),
+                capacity_release_rate: None,
+            })
+            .await?;
+
+        // Sleep very briefly to ensure that any timestamp that might be created below is greater
+        // than the timestamp created during user creation
+        thread::sleep(Duration::from_millis(5));
+
+        // Get the user
+        let user2 = db
+            .get_or_create_user(params::GetOrCreateUser {
+                service_id,
+                generation: 1234,
+                email: "test4@test.com".to_owned(),
+                client_state: "aaaa".to_owned(),
+                keys_changed_at: Some(1234),
+                capacity_release_rate: None,
+            })
+            .await?;
+
+        // The user's timestamp should be equal to the one generated when the user was created
+        assert_eq!(user1.created_at, user2.created_at);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_get_spanner_node() -> DbResult<()> {
+        let pool = db_pool().await?;
+        let mut db = pool.get_tokenserver_db().await?;
+
+        // Add a service
+        let service_id = db
+            .post_service(params::PostService {
+                service: "sync-1.5".to_owned(),
+                pattern: "{node}/1.5/{uid}".to_owned(),
+            })
+            .await?
+            .id;
+
+        // Add a node with capacity and available set to 0
+        let spanner_node_id = db
+            .post_node(params::PostNode {
+                service_id,
+                node: "https://spanner_node".to_owned(),
+                current_load: 1000,
+                capacity: 0,
+                available: 0,
+                ..Default::default()
+            })
+            .await?
+            .id;
+
+        // Add another node with available capacity
+        db.post_node(params::PostNode {
+            service_id,
+            node: "https://another_node".to_owned(),
+            current_load: 0,
+            capacity: 1000,
+            available: 1000,
+            ..Default::default()
+        })
+        .await?;
+
+        // Ensure the node with available capacity is selected if the Spanner node ID is not
+        // cached
+        assert_ne!(
+            db.get_best_node(params::GetBestNode {
+                service_id,
+                capacity_release_rate: None,
+            })
+            .await?
+            .id,
+            spanner_node_id
+        );
+
+        // Ensure the Spanner node is selected if the Spanner node ID is cached
+        db.spanner_node_id = Some(spanner_node_id as i32);
+
+        assert_eq!(
+            db.get_best_node(params::GetBestNode {
+                service_id,
+                capacity_release_rate: None,
+            })
+            .await?
+            .id,
+            spanner_node_id
+        );
+
+        Ok(())
+    }
+
+    async fn db_pool() -> DbResult<TokenserverPool> {
+        let _ = env_logger::try_init();
+
+        let mut settings = Settings::test_settings().tokenserver;
+        settings.run_migrations = true;
+        let use_test_transactions = true;
+
+        TokenserverPool::new(
+            &settings,
+            &Metrics::noop(),
+            Arc::new(BlockingThreadpool::default()),
+            use_test_transactions,
+        )
+    }
+}
diff --git a/tokenserver-db-mysql/src/pool.rs b/tokenserver-db-mysql/src/pool.rs
new file mode 100644
index 0000000000..1729c617ce
--- /dev/null
+++ b/tokenserver-db-mysql/src/pool.rs
@@ -0,0 +1,20 @@
+use diesel::{
+    mysql::MysqlConnection,
+    Connection,
+};
+use diesel_logger::LoggingConnection;
+use tokenserver_db_common::error::DbResult;
+
+embed_migrations!();
+
+/// Run the diesel embedded migrations
+///
+/// MySQL DDL statements implicitly commit, which could disrupt MysqlPool's
+/// begin_test_transaction during tests. So this runs on its own separate conn.
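+///
+/// An illustrative call (a sketch only; the connection URL below is a
+/// placeholder, not a real deployment):
+///
+/// ```ignore
+/// run_embedded_migrations("mysql://user:password@localhost/tokenserver")
+///     .expect("failed to run tokenserver migrations");
+/// ```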
+pub fn run_embedded_migrations(database_url: &str) -> DbResult<()> {
+    let conn = MysqlConnection::establish(database_url)?;
+
+    embedded_migrations::run(&LoggingConnection::new(conn))?;
+
+    Ok(())
+}
diff --git a/tokenserver-db-sqlite/Cargo.toml b/tokenserver-db-sqlite/Cargo.toml
new file mode 100644
index 0000000000..7293b14aa9
--- /dev/null
+++ b/tokenserver-db-sqlite/Cargo.toml
@@ -0,0 +1,32 @@
+[package]
+name = "tokenserver-db-sqlite"
+version.workspace = true
+authors.workspace = true
+edition.workspace = true
+license.workspace = true
+
+[dependencies]
+backtrace.workspace = true
+futures.workspace = true
+http.workspace = true
+serde.workspace = true
+serde_derive.workspace = true
+serde_json.workspace = true
+slog-scope.workspace = true
+
+async-trait = "0.1.40"
+diesel = { version = "1.4", features = ["sqlite", "r2d2"] }
+diesel_logger = "0.1.1"
+diesel_migrations = { version = "1.4.0", features = ["sqlite"] }
+syncserver-common = { path = "../syncserver-common" }
+syncserver-db-common = { path = "../syncserver-db-common" }
+thiserror = "1.0.26"
+tokenserver-common = { path = "../tokenserver-common" }
+tokenserver-db-common = { path = "../tokenserver-db-common" }
+tokenserver-settings = { path = "../tokenserver-settings" }
+tokio = { workspace = true, features = ["macros", "sync"] }
+
+[dev-dependencies]
+env_logger.workspace = true
+
+syncserver-settings = { path = "../syncserver-settings" }
diff --git a/tokenserver-db-sqlite/migrations/2024-01-28-211312_init/down.sql b/tokenserver-db-sqlite/migrations/2024-01-28-211312_init/down.sql
new file mode 100644
index 0000000000..da49bf74a9
--- /dev/null
+++ b/tokenserver-db-sqlite/migrations/2024-01-28-211312_init/down.sql
@@ -0,0 +1,3 @@
+DROP TABLE IF EXISTS `users`;
+DROP TABLE IF EXISTS `nodes`;
+DROP TABLE IF EXISTS `services`;
diff --git a/tokenserver-db-sqlite/migrations/2024-01-28-211312_init/up.sql b/tokenserver-db-sqlite/migrations/2024-01-28-211312_init/up.sql
new file mode 100644
index 0000000000..7ab7c97b8b
--- /dev/null
+++ b/tokenserver-db-sqlite/migrations/2024-01-28-211312_init/up.sql
@@ -0,0 +1,34 @@
+CREATE TABLE IF NOT EXISTS `services` (
+    `id` int PRIMARY KEY,
+    `service` varchar(30) DEFAULT NULL UNIQUE,
+    `pattern` varchar(128) DEFAULT NULL
+);
+
+CREATE TABLE IF NOT EXISTS `nodes` (
+    `id` bigint PRIMARY KEY,
+    `service` int NOT NULL,
+    `node` varchar(64) NOT NULL,
+    `available` int NOT NULL,
+    `current_load` int NOT NULL,
+    `capacity` int NOT NULL,
+    `downed` int NOT NULL,
+    `backoff` int NOT NULL
+);
+
+CREATE UNIQUE INDEX `unique_idx` ON `nodes` (`service`, `node`);
+
+CREATE TABLE IF NOT EXISTS `users` (
+    `uid` integer PRIMARY KEY,
+    `service` int NOT NULL,
+    `email` varchar(255) NOT NULL,
+    `generation` bigint NOT NULL,
+    `client_state` varchar(32) NOT NULL,
+    `created_at` bigint NOT NULL,
+    `replaced_at` bigint DEFAULT NULL,
+    `nodeid` bigint NOT NULL,
+    `keys_changed_at` bigint DEFAULT NULL
+);
+
+CREATE INDEX `lookup_idx` ON `users` (`email`, `service`, `created_at`);
+CREATE INDEX `replaced_at_idx` ON `users` (`service`, `replaced_at`);
+CREATE INDEX `node_idx` ON `users` (`nodeid`);
diff --git a/tokenserver-db-sqlite/src/lib.rs b/tokenserver-db-sqlite/src/lib.rs
new file mode 100644
index 0000000000..30caa23e1a
--- /dev/null
+++ b/tokenserver-db-sqlite/src/lib.rs
@@ -0,0 +1,5 @@
+extern crate diesel;
+#[macro_use]
+extern crate diesel_migrations;
+
+pub mod pool;
diff --git a/tokenserver-db-sqlite/src/pool.rs b/tokenserver-db-sqlite/src/pool.rs
new file mode 100644
index 0000000000..2edd1f38ab
--- /dev/null
+++ b/tokenserver-db-sqlite/src/pool.rs
@@ -0,0 +1,17 @@
+use diesel::{
+    sqlite::SqliteConnection,
+    Connection,
+};
+use diesel_logger::LoggingConnection;
+use tokenserver_db_common::error::DbResult;
+
+embed_migrations!();
+
+/// Run the diesel embedded migrations
+pub fn run_embedded_migrations(database_url: &str) -> DbResult<()> {
+    let conn = SqliteConnection::establish(database_url)?;
+
+    embedded_migrations::run(&LoggingConnection::new(conn))?;
+
+    Ok(())
+}
diff --git a/tokenserver-db/Cargo.toml b/tokenserver-db/Cargo.toml
index 98d9fe3804..eebf91fcbb 100644
--- a/tokenserver-db/Cargo.toml
+++ b/tokenserver-db/Cargo.toml
@@ -23,9 +23,16 @@ syncserver-common = { path = "../syncserver-common" }
 syncserver-db-common = { path = "../syncserver-db-common" }
 tokenserver-common = { path = "../tokenserver-common" }
 tokenserver-settings = { path = "../tokenserver-settings" }
+tokenserver-db-common = { path = "../tokenserver-db-common" }
+tokenserver-db-mysql = { path = "../tokenserver-db-mysql", optional = true }
+tokenserver-db-sqlite = { path = "../tokenserver-db-sqlite", optional = true }
 tokio = { workspace = true, features = ["macros", "sync"] }
 
 [dev-dependencies]
 env_logger.workspace = true
 
 syncserver-settings = { path = "../syncserver-settings" }
+
+[features]
+mysql = ["tokenserver-db-mysql"]
+sqlite = ["tokenserver-db-sqlite"]
diff --git a/tokenserver-db/src/lib.rs b/tokenserver-db/src/lib.rs
index 1b9f86c623..6d1e4169eb 100644
--- a/tokenserver-db/src/lib.rs
+++ b/tokenserver-db/src/lib.rs
@@ -1,8 +1,9 @@
-extern crate diesel;
-#[macro_use]
-extern crate diesel_migrations;
+use diesel::r2d2::{ConnectionManager, PooledConnection};
+#[cfg(feature = "mysql")]
+use diesel::MysqlConnection;
+#[cfg(feature = "sqlite")]
+use diesel::SqliteConnection;
 
-mod error;
 pub mod mock;
 mod models;
 pub mod params;
@@ -11,3 +12,9 @@ pub mod results;
 
 pub use models::{Db, TokenserverDb};
 pub use pool::{DbPool, TokenserverPool};
+
+#[cfg(feature = "mysql")]
+type Conn = MysqlConnection;
+#[cfg(feature = "sqlite")]
+type Conn = SqliteConnection;
+type PooledConn = PooledConnection<ConnectionManager<Conn>>;
diff --git a/tokenserver-db/src/mock.rs b/tokenserver-db/src/mock.rs
index 29041091d7..1e2b783f3e 100644
--- a/tokenserver-db/src/mock.rs
+++ b/tokenserver-db/src/mock.rs
@@ -3,8 +3,8 @@ use async_trait::async_trait;
 use futures::future;
 use syncserver_db_common::{GetPoolState, PoolState};
+use tokenserver_db_common::error::{DbError, DbFuture};
 
-use super::error::{DbError, DbFuture};
 use super::models::Db;
 use super::params;
 use super::pool::DbPool;
diff --git a/tokenserver-db/src/models.rs b/tokenserver-db/src/models.rs
index e78328319a..a026afb90e 100644
--- a/tokenserver-db/src/models.rs
+++ b/tokenserver-db/src/models.rs
@@ -1,6 +1,9 @@
+use std::{
+    sync::Arc,
+    time::{SystemTime, UNIX_EPOCH},
+};
+
 use diesel::{
-    mysql::MysqlConnection,
-    r2d2::{ConnectionManager, PooledConnection},
     sql_types::{Bigint, Float, Integer, Nullable, Text},
     OptionalExtension, RunQueryDsl,
 };
@@ -9,23 +12,18 @@
 use diesel_logger::LoggingConnection;
 use http::StatusCode;
 use syncserver_common::{BlockingThreadpool, Metrics};
 use syncserver_db_common::{sync_db_method, DbFuture};
-
-use std::{
-    sync::Arc,
-    time::{SystemTime, UNIX_EPOCH},
-};
+use tokenserver_db_common::error::{DbError, DbResult};
 
 use super::{
-    error::{DbError, DbResult},
-    params, results,
+    params,
+    results,
+    PooledConn,
 };
 
 /// The maximum possible generation number. Used as a tombstone to mark users that have been
 /// "retired" from the db.
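+///
+/// Since no live client can report a generation above i64::MAX, a tombstoned row
+/// always compares as at least as new as any incoming generation and is never
+/// overwritten by a regular update.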
 const MAX_GENERATION: i64 = i64::MAX;
 
-type Conn = PooledConnection<ConnectionManager<MysqlConnection>>;
-
 #[derive(Clone)]
 pub struct TokenserverDb {
     /// Synchronous Diesel calls are executed on a blocking threadpool to satisfy
@@ -49,9 +47,9 @@ unsafe impl Send for TokenserverDb {}
 
 struct DbInner {
     #[cfg(not(test))]
-    pub(super) conn: Conn,
+    pub(super) conn: PooledConn,
     #[cfg(test)]
-    pub(super) conn: LoggingConnection<Conn>, // display SQL when RUST_LOG="diesel_logger=trace"
+    pub(super) conn: LoggingConnection<PooledConn>, // display SQL when RUST_LOG="diesel_logger=trace"
 }
 
 impl TokenserverDb {
@@ -63,7 +61,7 @@ impl TokenserverDb {
     const LAST_INSERT_ID_QUERY: &'static str = "SELECT LAST_INSERT_ID() AS id";
 
     pub fn new(
-        conn: Conn,
+        conn: PooledConn,
         metrics: &Metrics,
         service_id: Option<i32>,
         spanner_node_id: Option<i32>,
diff --git a/tokenserver-db/src/pool.rs b/tokenserver-db/src/pool.rs
index dd100abb4c..9311993cbe 100644
--- a/tokenserver-db/src/pool.rs
+++ b/tokenserver-db/src/pool.rs
@@ -1,41 +1,27 @@
 use std::{sync::Arc, time::Duration};
 
 use async_trait::async_trait;
-use diesel::{
-    mysql::MysqlConnection,
-    r2d2::{ConnectionManager, Pool},
-    Connection,
-};
-use diesel_logger::LoggingConnection;
+use diesel::r2d2::{ConnectionManager, Pool};
 use syncserver_common::{BlockingThreadpool, Metrics};
 #[cfg(debug_assertions)]
 use syncserver_db_common::test::TestTransactionCustomizer;
 use syncserver_db_common::{GetPoolState, PoolState};
 use tokenserver_settings::Settings;
+use tokenserver_db_common::error::{DbError, DbResult};
+#[cfg(feature = "mysql")]
+use tokenserver_db_mysql::pool::run_embedded_migrations;
+#[cfg(feature = "sqlite")]
+use tokenserver_db_sqlite::pool::run_embedded_migrations;
 
 use super::{
-    error::{DbError, DbResult},
     models::{Db, TokenserverDb},
+    Conn,
 };
 
-embed_migrations!();
-
-/// Run the diesel embedded migrations
-///
-/// Mysql DDL statements implicitly commit which could disrupt MysqlPool's
-/// begin_test_transaction during tests. So this runs on its own separate conn.
-fn run_embedded_migrations(database_url: &str) -> DbResult<()> {
-    let conn = MysqlConnection::establish(database_url)?;
-
-    embedded_migrations::run(&LoggingConnection::new(conn))?;
-
-    Ok(())
-}
-
 #[derive(Clone)]
 pub struct TokenserverPool {
     /// Pool of db connections
-    inner: Pool<ConnectionManager<MysqlConnection>>,
+    inner: Pool<ConnectionManager<Conn>>,
     metrics: Metrics,
     // This field is public so the service ID can be set after the pool is created
     pub service_id: Option<i32>,
@@ -54,7 +40,7 @@ impl TokenserverPool {
             run_embedded_migrations(&settings.database_url)?;
         }
 
-        let manager = ConnectionManager::<MysqlConnection>::new(settings.database_url.clone());
+        let manager = ConnectionManager::<Conn>::new(settings.database_url.clone());
         let builder = Pool::builder()
             .max_size(settings.database_pool_max_size)
             .connection_timeout(Duration::from_secs(