diff --git a/Cargo.lock b/Cargo.lock index 7b8166411..73572e3f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1707,6 +1707,7 @@ dependencies = [ name = "nativelink-scheduler" version = "0.2.0" dependencies = [ + "async-lock", "async-trait", "blake3", "futures", diff --git a/nativelink-config/src/schedulers.rs b/nativelink-config/src/schedulers.rs index 149efe483..abc24cf5f 100644 --- a/nativelink-config/src/schedulers.rs +++ b/nativelink-config/src/schedulers.rs @@ -1,4 +1,4 @@ -// Copyright 2023 The Native Link Authors. All rights reserved. +// Copyright 2023-2024 The Native Link Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -134,6 +134,12 @@ pub struct GrpcScheduler { /// Retry configuration to use when a network request fails. #[serde(default)] pub retry: Retry, + + /// Limit the number of simultaneous upstream requests to this many. A + /// value of zero is treated as unlimited. If the limit is reached the + /// request is queued. + #[serde(default)] + pub max_concurrent_requests: usize, } #[derive(Deserialize, Debug)] diff --git a/nativelink-config/src/stores.rs b/nativelink-config/src/stores.rs index eb3f43be1..7c550334f 100644 --- a/nativelink-config/src/stores.rs +++ b/nativelink-config/src/stores.rs @@ -1,4 +1,4 @@ -// Copyright 2023 The Native Link Authors. All rights reserved. +// Copyright 2023-2024 The Native Link Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -546,6 +546,12 @@ pub struct GrpcStore { /// Retry configuration to use when a network request fails. #[serde(default)] pub retry: Retry, + + /// Limit the number of simultaneous upstream requests to this many. A + /// value of zero is treated as unlimited. If the limit is reached the + /// request is queued. + #[serde(default)] + pub max_concurrent_requests: usize, } /// The possible error codes that might occur on an upstream request. diff --git a/nativelink-scheduler/BUILD.bazel b/nativelink-scheduler/BUILD.bazel index 95034e00c..46ca0b964 100644 --- a/nativelink-scheduler/BUILD.bazel +++ b/nativelink-scheduler/BUILD.bazel @@ -30,6 +30,7 @@ rust_library( "//nativelink-proto", "//nativelink-store", "//nativelink-util", + "@crates//:async-lock", "@crates//:blake3", "@crates//:futures", "@crates//:hashbrown", diff --git a/nativelink-scheduler/Cargo.toml b/nativelink-scheduler/Cargo.toml index ed781c9c5..0d57786d1 100644 --- a/nativelink-scheduler/Cargo.toml +++ b/nativelink-scheduler/Cargo.toml @@ -13,6 +13,7 @@ nativelink-proto = { path = "../nativelink-proto" } # files somewhere else. nativelink-store = { path = "../nativelink-store" } +async-lock = "3.2.0" async-trait = "0.1.74" blake3 = "1.5.0" prost = "0.12.3" diff --git a/nativelink-scheduler/src/grpc_scheduler.rs b/nativelink-scheduler/src/grpc_scheduler.rs index 466d68aec..2ae981cfc 100644 --- a/nativelink-scheduler/src/grpc_scheduler.rs +++ b/nativelink-scheduler/src/grpc_scheduler.rs @@ -1,4 +1,4 @@ -// Copyright 2023 The Native Link Authors. All rights reserved. +// Copyright 2023-2024 The Native Link Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ use nativelink_proto::build::bazel::remote::execution::v2::{ }; use nativelink_proto::google::longrunning::Operation; use nativelink_util::action_messages::{ActionInfo, ActionInfoHashKey, ActionState, DEFAULT_EXECUTION_PRIORITY}; +use nativelink_util::grpc_utils::ConnectionManager; use nativelink_util::retry::{Retrier, RetryResult}; use nativelink_util::tls_utils; use parking_lot::Mutex; @@ -36,17 +37,16 @@ use rand::Rng; use tokio::select; use tokio::sync::watch; use tokio::time::sleep; -use tonic::{transport, Request, Streaming}; +use tonic::{Request, Streaming}; use tracing::{error, info, warn}; use crate::action_scheduler::ActionScheduler; use crate::platform_property_manager::PlatformPropertyManager; pub struct GrpcScheduler { - capabilities_client: CapabilitiesClient, - execution_client: ExecutionClient, platform_property_managers: Mutex>>, retrier: Retrier, + connection_manager: ConnectionManager, } impl GrpcScheduler { @@ -69,16 +69,15 @@ impl GrpcScheduler { config: &nativelink_config::schedulers::GrpcScheduler, jitter_fn: Box Duration + Send + Sync>, ) -> Result { - let channel = transport::Channel::balance_list(std::iter::once(tls_utils::endpoint(&config.endpoint)?)); + let endpoint = tls_utils::endpoint(&config.endpoint)?; Ok(Self { - capabilities_client: CapabilitiesClient::new(channel.clone()), - execution_client: ExecutionClient::new(channel), platform_property_managers: Mutex::new(HashMap::new()), retrier: Retrier::new( Arc::new(|duration| Box::pin(sleep(duration))), Arc::new(jitter_fn), config.retry.to_owned(), ), + connection_manager: ConnectionManager::new(std::iter::once(endpoint), config.max_concurrent_requests), }) } @@ -150,14 +149,17 @@ impl ActionScheduler for GrpcScheduler { self.perform_request(instance_name, |instance_name| async move { // Not in the cache, lookup the capabilities with the upstream. - let capabilities = self - .capabilities_client - .clone() + let (connection, channel) = self.connection_manager.get_connection().await; + let capabilities_result = CapabilitiesClient::new(channel) .get_capabilities(GetCapabilitiesRequest { instance_name: instance_name.to_string(), }) - .await? - .into_inner(); + .await + .err_tip(|| "Retrieving upstream GrpcScheduler capabilities"); + if let Err(err) = &capabilities_result { + connection.on_error(err); + } + let capabilities = capabilities_result?.into_inner(); let platform_property_manager = Arc::new(PlatformPropertyManager::new( capabilities .execution_capabilities @@ -195,11 +197,15 @@ impl ActionScheduler for GrpcScheduler { }; let result_stream = self .perform_request(request, |request| async move { - self.execution_client - .clone() + let (connection, channel) = self.connection_manager.get_connection().await; + let result = ExecutionClient::new(channel) .execute(Request::new(request)) .await - .err_tip(|| "Sending action to upstream scheduler") + .err_tip(|| "Sending action to upstream scheduler"); + if let Err(err) = &result { + connection.on_error(err); + } + result }) .await? .into_inner(); @@ -215,11 +221,15 @@ impl ActionScheduler for GrpcScheduler { }; let result_stream = self .perform_request(request, |request| async move { - self.execution_client - .clone() + let (connection, channel) = self.connection_manager.get_connection().await; + let result = ExecutionClient::new(channel) .wait_execution(Request::new(request)) .await - .err_tip(|| "While getting wait_execution stream") + .err_tip(|| "While getting wait_execution stream"); + if let Err(err) = &result { + connection.on_error(err); + } + result }) .and_then(|result_stream| Self::stream_state(result_stream.into_inner())) .await; diff --git a/nativelink-store/src/grpc_store.rs b/nativelink-store/src/grpc_store.rs index e4a507509..6acd46642 100644 --- a/nativelink-store/src/grpc_store.rs +++ b/nativelink-store/src/grpc_store.rs @@ -36,6 +36,7 @@ use nativelink_proto::google::bytestream::{ }; use nativelink_util::buf_channel::{DropCloserReadHalf, DropCloserWriteHalf}; use nativelink_util::common::DigestInfo; +use nativelink_util::grpc_utils::ConnectionManager; use nativelink_util::resource_info::ResourceInfo; use nativelink_util::retry::{Retrier, RetryResult}; use nativelink_util::store_trait::{Store, UploadSizeInfo}; @@ -46,7 +47,7 @@ use prost::Message; use rand::rngs::OsRng; use rand::Rng; use tokio::time::sleep; -use tonic::{transport, IntoRequest, Request, Response, Status, Streaming}; +use tonic::{IntoRequest, Request, Response, Status, Streaming}; use tracing::error; use uuid::Uuid; @@ -57,11 +58,9 @@ use crate::ac_utils::ESTIMATED_DIGEST_SIZE; // underlying data. This might cause issues if embedded in certain stores. pub struct GrpcStore { instance_name: String, - cas_client: ContentAddressableStorageClient, - bytestream_client: ByteStreamClient, - ac_client: ActionCacheClient, store_type: nativelink_config::stores::StoreType, retrier: Retrier, + connection_manager: ConnectionManager, } /// This provides a buffer for the first response from GrpcStore.read in order @@ -201,7 +200,7 @@ where Some(message) } Err(err) => { - error!("{err:?}"); + local_state.read_stream_error = Some(err); None } }, @@ -244,18 +243,15 @@ impl GrpcStore { endpoints.push(endpoint); } - let conn = transport::Channel::balance_list(endpoints.into_iter()); Ok(GrpcStore { instance_name: config.instance_name.clone(), - cas_client: ContentAddressableStorageClient::new(conn.clone()), - bytestream_client: ByteStreamClient::new(conn.clone()), - ac_client: ActionCacheClient::new(conn), store_type: config.store_type, retrier: Retrier::new( Arc::new(|duration| Box::pin(sleep(duration))), Arc::new(jitter_fn), config.retry.to_owned(), ), + connection_manager: ConnectionManager::new(endpoints.into_iter(), config.max_concurrent_requests), }) } @@ -291,11 +287,15 @@ impl GrpcStore { let mut request = grpc_request.into_inner(); request.instance_name = self.instance_name.clone(); self.perform_request(request, |request| async move { - self.cas_client - .clone() + let (connection, channel) = self.connection_manager.get_connection().await; + let result = ContentAddressableStorageClient::new(channel) .find_missing_blobs(Request::new(request)) .await - .err_tip(|| "in GrpcStore::find_missing_blobs") + .err_tip(|| "in GrpcStore::find_missing_blobs"); + if let Err(err) = &result { + connection.on_error(err); + } + result }) .await } @@ -312,11 +312,15 @@ impl GrpcStore { let mut request = grpc_request.into_inner(); request.instance_name = self.instance_name.clone(); self.perform_request(request, |request| async move { - self.cas_client - .clone() + let (connection, channel) = self.connection_manager.get_connection().await; + let result = ContentAddressableStorageClient::new(channel) .batch_update_blobs(Request::new(request)) .await - .err_tip(|| "in GrpcStore::batch_update_blobs") + .err_tip(|| "in GrpcStore::batch_update_blobs"); + if let Err(err) = &result { + connection.on_error(err); + } + result }) .await } @@ -333,11 +337,15 @@ impl GrpcStore { let mut request = grpc_request.into_inner(); request.instance_name = self.instance_name.clone(); self.perform_request(request, |request| async move { - self.cas_client - .clone() + let (connection, channel) = self.connection_manager.get_connection().await; + let result = ContentAddressableStorageClient::new(channel) .batch_read_blobs(Request::new(request)) .await - .err_tip(|| "in GrpcStore::batch_read_blobs") + .err_tip(|| "in GrpcStore::batch_read_blobs"); + if let Err(err) = &result { + connection.on_error(err); + } + result }) .await } @@ -354,11 +362,15 @@ impl GrpcStore { let mut request = grpc_request.into_inner(); request.instance_name = self.instance_name.clone(); self.perform_request(request, |request| async move { - self.cas_client - .clone() + let (connection, channel) = self.connection_manager.get_connection().await; + let result = ContentAddressableStorageClient::new(channel) .get_tree(Request::new(request)) .await - .err_tip(|| "in GrpcStore::get_tree") + .err_tip(|| "in GrpcStore::get_tree"); + if let Err(err) = &result { + connection.on_error(err); + } + result }) .await } @@ -377,13 +389,15 @@ impl GrpcStore { &self, request: ReadRequest, ) -> Result>, Error> { - let mut response = self - .bytestream_client - .clone() + let (connection, channel) = self.connection_manager.get_connection().await; + let result = ByteStreamClient::new(channel) .read(Request::new(request)) .await - .err_tip(|| "in GrpcStore::read")? - .into_inner(); + .err_tip(|| "in GrpcStore::read"); + if let Err(err) = &result { + connection.on_error(err); + } + let mut response = result?.into_inner(); let first_response = response .message() .await @@ -423,13 +437,13 @@ impl GrpcStore { let result = self .retrier .retry(unfold(local_state, move |local_state| async move { - let mut client = self.bytestream_client.clone(); + let (connection, channel) = self.connection_manager.get_connection().await; // The client write may occur on a separate thread and // therefore in order to share the state with it we have to // wrap it in a Mutex and retrieve it after the write // has completed. There is no way to get the value back // from the client. - let result = client + let result = ByteStreamClient::new(channel) .write(WriteStateWrapper { shared_state: local_state.clone(), }) @@ -447,6 +461,7 @@ impl GrpcStore { // On error determine whether it is possible to retry. match result.err_tip(|| "in GrpcStore::write") { Err(err) => { + connection.on_error(&err); if local_state_locked.can_resume() { local_state_locked.resume(); RetryResult::Retry(err) @@ -484,11 +499,15 @@ impl GrpcStore { } self.perform_request(request, |request| async move { - self.bytestream_client - .clone() + let (connection, channel) = self.connection_manager.get_connection().await; + let result = ByteStreamClient::new(channel) .query_write_status(Request::new(request)) .await - .err_tip(|| "in GrpcStore::query_write_status") + .err_tip(|| "in GrpcStore::query_write_status"); + if let Err(err) = &result { + connection.on_error(err); + } + result }) .await } @@ -500,11 +519,15 @@ impl GrpcStore { let mut request = grpc_request.into_inner(); request.instance_name = self.instance_name.clone(); self.perform_request(request, |request| async move { - self.ac_client - .clone() + let (connection, channel) = self.connection_manager.get_connection().await; + let result = ActionCacheClient::new(channel) .get_action_result(Request::new(request)) .await - .err_tip(|| "in GrpcStore::get_action_result") + .err_tip(|| "in GrpcStore::get_action_result"); + if let Err(err) = &result { + connection.on_error(err); + } + result }) .await } @@ -516,11 +539,15 @@ impl GrpcStore { let mut request = grpc_request.into_inner(); request.instance_name = self.instance_name.clone(); self.perform_request(request, |request| async move { - self.ac_client - .clone() + let (connection, channel) = self.connection_manager.get_connection().await; + let result = ActionCacheClient::new(channel) .update_action_result(Request::new(request)) .await - .err_tip(|| "in GrpcStore::update_action_result") + .err_tip(|| "in GrpcStore::update_action_result"); + if let Err(err) = &result { + connection.on_error(err); + } + result }) .await } diff --git a/nativelink-util/BUILD.bazel b/nativelink-util/BUILD.bazel index 7fdb17177..77d5302d0 100644 --- a/nativelink-util/BUILD.bazel +++ b/nativelink-util/BUILD.bazel @@ -16,6 +16,7 @@ rust_library( "src/evicting_map.rs", "src/fastcdc.rs", "src/fs.rs", + "src/grpc_utils.rs", "src/lib.rs", "src/metrics_utils.rs", "src/platform_properties.rs", diff --git a/nativelink-util/src/grpc_utils.rs b/nativelink-util/src/grpc_utils.rs new file mode 100644 index 000000000..0dbf2c65b --- /dev/null +++ b/nativelink-util/src/grpc_utils.rs @@ -0,0 +1,107 @@ +// Copyright 2024 The Native Link Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_lock::{Semaphore, SemaphoreGuard}; +use nativelink_error::{Code, Error}; +use parking_lot::Mutex; +use tonic::transport::{Channel, Endpoint}; + +/// A helper utility that enables management of a suite of connections to an +/// upstream gRPC endpoint using Tonic. +pub struct ConnectionManager { + /// The endpoints to establish Channels for. + endpoints: Vec, + /// A balance channel over the above endpoints which is kept with a + /// monotonic index to ensure we only re-create a channel on the first error + /// received on it. + channel: Mutex<(usize, Channel)>, + /// If a maximum number of upstream requests are allowed at a time, this + /// is a Semaphore to manage that. + request_semaphore: Option, +} + +impl ConnectionManager { + /// Create a connection manager that creates a balance list between a given + /// set of Endpoints. This will restrict the number of concurrent requests + /// assuming that the user of this connection manager uses the connection + /// only once and reports all errors. + pub fn new(endpoints: impl IntoIterator, max_concurrent_requests: usize) -> Self { + let endpoints = Vec::from_iter(endpoints); + let channel = Channel::balance_list(endpoints.iter().cloned()); + Self { + endpoints, + channel: Mutex::new((0, channel)), + request_semaphore: (max_concurrent_requests > 0).then_some(Semaphore::new(max_concurrent_requests)), + } + } + + /// Get a connection slot for an Endpoint, this contains a Channel which + /// should be used once and any errors should be reported back to the + /// on_error method to ensure that the Channel is re-connected on error. + pub async fn get_connection(&self) -> (Connection<'_>, Channel) { + let _permit = if let Some(semaphore) = &self.request_semaphore { + Some(semaphore.acquire().await) + } else { + None + }; + let channel_lock = self.channel.lock(); + ( + Connection { + channel_id: channel_lock.0, + parent: self, + _permit, + }, + channel_lock.1.clone(), + ) + } +} + +/// An instance of this is obtained for every communication with the gGRPC +/// service. This handles the permit for limiting concurrency, and also +/// re-connecting the underlying channel on error. It depends on users +/// reporting all errors. +pub struct Connection<'a> { + channel_id: usize, + parent: &'a ConnectionManager, + _permit: Option>, +} + +impl<'a> Connection<'a> { + pub fn on_error(self, err: &Error) { + // Usually Tonic reconnects on upstream errors (like Unavailable) but + // if there are protocol errors (such as GoAway) then it will not + // attempt to re-connect, and therefore we are forced to manually do + // that. + if err.code != Code::Internal { + return; + } + // Create a new channel for future requests to use upon a new request + // to ConnectionManager::get_connection(). In order to ensure we only + // do this for the first error on a cloned Channel we check the ID + // matches the current ID when we get the lock. + let mut channel_lock = self.parent.channel.lock(); + if channel_lock.0 != self.channel_id { + // The connection was already re-established by another user getting + // and error on a clone of this Channel, so don't make another one. + return; + } + // Create a new channel with a unique ID to track if it gets an error. + // This new Channel will be used when a new request comes into + // ConnectionManager::get_connection() as this request has been and gone + // with an error now and it's up to the user whether they retry by + // getting a new connection. + channel_lock.0 += 1; + channel_lock.1 = Channel::balance_list(self.parent.endpoints.iter().cloned()); + } +} diff --git a/nativelink-util/src/lib.rs b/nativelink-util/src/lib.rs index 3dad898e8..ec56134e5 100644 --- a/nativelink-util/src/lib.rs +++ b/nativelink-util/src/lib.rs @@ -19,6 +19,7 @@ pub mod digest_hasher; pub mod evicting_map; pub mod fastcdc; pub mod fs; +pub mod grpc_utils; pub mod metrics_utils; pub mod platform_properties; pub mod resource_info; diff --git a/nativelink-util/src/retry.rs b/nativelink-util/src/retry.rs index 4752cb49e..c37171d55 100644 --- a/nativelink-util/src/retry.rs +++ b/nativelink-util/src/retry.rs @@ -20,7 +20,7 @@ use futures::future::Future; use futures::stream::StreamExt; use nativelink_config::stores::{ErrorCode, Retry}; use nativelink_error::{make_err, Code, Error}; -use tracing::debug; +use tracing::info; struct ExponentialBackoff { current: Duration, @@ -148,7 +148,7 @@ impl Retrier { Some(RetryResult::Err(e)) => return Err(e.append(format!("On attempt {attempt}"))), Some(RetryResult::Retry(e)) => { if !self.should_retry(&e.code) { - debug!("Not retrying permanent error on attempt {attempt}: {e:?}"); + info!("Not retrying permanent error on attempt {attempt}: {e:?}"); return Err(e); } (self.sleep_fn)(iter.next().ok_or(e.append(format!("On attempt {attempt}")))?).await