Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/gha/tests/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ rec {
};

vmTests = {
inherit (nixosTests) curl-s3-binary-cache-store;
inherit (nixosTests) s3-binary-cache-store;
}
// lib.optionalAttrs (!withSanitizers && !withCoverage) {
# evalNixpkgs uses non-instrumented components from hydraJobs, so only run it
Expand Down
93 changes: 39 additions & 54 deletions src/libstore/aws-creds.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,51 +22,13 @@

namespace nix {

namespace {

// Global credential provider cache using boost's concurrent map
// Key: profile name (empty string for default profile)
using CredentialProviderCache =
boost::concurrent_flat_map<std::string, std::shared_ptr<Aws::Crt::Auth::ICredentialsProvider>>;

static CredentialProviderCache credentialProviderCache;

/**
* Clear all cached credential providers.
* Called automatically by CrtWrapper destructor during static destruction.
*/
static void clearAwsCredentialsCache()
AwsAuthError::AwsAuthError(int errorCode)
: Error("AWS authentication error: '%s' (%d)", aws_error_str(errorCode), errorCode)
, errorCode(errorCode)
{
credentialProviderCache.clear();
}

static void initAwsCrt()
{
struct CrtWrapper
{
Aws::Crt::ApiHandle apiHandle;

CrtWrapper()
{
apiHandle.InitializeLogging(Aws::Crt::LogLevel::Warn, static_cast<FILE *>(nullptr));
}

~CrtWrapper()
{
try {
// CRITICAL: Clear credential provider cache BEFORE AWS CRT shuts down
// This ensures all providers (which hold references to ClientBootstrap)
// are destroyed while AWS CRT is still valid
clearAwsCredentialsCache();
// Now it's safe for ApiHandle destructor to run
} catch (...) {
ignoreExceptionInDestructor();
}
}
};

static CrtWrapper crt;
}
namespace {

static AwsCredentials getCredentialsFromProvider(std::shared_ptr<Aws::Crt::Auth::ICredentialsProvider> provider)
{
Expand All @@ -79,8 +41,7 @@ static AwsCredentials getCredentialsFromProvider(std::shared_ptr<Aws::Crt::Auth:

provider->GetCredentials([prom](std::shared_ptr<Aws::Crt::Auth::Credentials> credentials, int errorCode) {
if (errorCode != 0 || !credentials) {
prom->set_exception(
std::make_exception_ptr(AwsAuthError("Failed to resolve AWS credentials: error code %d", errorCode)));
prom->set_exception(std::make_exception_ptr(AwsAuthError(errorCode)));
} else {
auto accessKeyId = Aws::Crt::ByteCursorToStringView(credentials->GetAccessKeyId());
auto secretAccessKey = Aws::Crt::ByteCursorToStringView(credentials->GetSecretAccessKey());
Expand Down Expand Up @@ -113,7 +74,35 @@ static AwsCredentials getCredentialsFromProvider(std::shared_ptr<Aws::Crt::Auth:

} // anonymous namespace

AwsCredentials getAwsCredentials(const std::string & profile)
class AwsCredentialProviderImpl : public AwsCredentialProvider
{
public:
AwsCredentialProviderImpl()
{
apiHandle.InitializeLogging(Aws::Crt::LogLevel::Warn, static_cast<FILE *>(nullptr));
}

AwsCredentials getCredentialsRaw(const std::string & profile);

AwsCredentials getCredentials(const ParsedS3URL & url) override
{
auto profile = url.profile.value_or("");
try {
return getCredentialsRaw(profile);
} catch (AwsAuthError & e) {
warn("AWS authentication failed for S3 request %s: %s", url.toHttpsUrl(), e.message());
credentialProviderCache.erase(profile);
throw;
}
}

private:
Aws::Crt::ApiHandle apiHandle;
boost::concurrent_flat_map<std::string, std::shared_ptr<Aws::Crt::Auth::ICredentialsProvider>>
credentialProviderCache;
};

AwsCredentials AwsCredentialProviderImpl::getCredentialsRaw(const std::string & profile)
{
// Get or create credential provider with caching
std::shared_ptr<Aws::Crt::Auth::ICredentialsProvider> provider;
Expand All @@ -132,8 +121,6 @@ AwsCredentials getAwsCredentials(const std::string & profile)
profile.empty() ? "(default)" : profile.c_str());

try {
initAwsCrt();

if (profile.empty()) {
Aws::Crt::Auth::CredentialsProviderChainDefaultConfig config;
config.Bootstrap = Aws::Crt::ApiHandle::GetOrCreateStaticDefaultClientBootstrap();
Expand Down Expand Up @@ -173,17 +160,15 @@ AwsCredentials getAwsCredentials(const std::string & profile)
return getCredentialsFromProvider(provider);
}

void invalidateAwsCredentials(const std::string & profile)
ref<AwsCredentialProvider> makeAwsCredentialsProvider()
{
credentialProviderCache.erase(profile);
return make_ref<AwsCredentialProviderImpl>();
}

AwsCredentials preResolveAwsCredentials(const ParsedS3URL & s3Url)
ref<AwsCredentialProvider> getAwsCredentialsProvider()
{
std::string profile = s3Url.profile.value_or("");

// Get credentials (automatically cached)
return getAwsCredentials(profile);
static auto instance = makeAwsCredentialsProvider();
return instance;
}

} // namespace nix
Expand Down
22 changes: 6 additions & 16 deletions src/libstore/filetransfer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -883,22 +883,12 @@ void FileTransferRequest::setupForS3()
if (usernameAuth) {
debug("Using pre-resolved AWS credentials from parent process");
sessionToken = preResolvedAwsSessionToken;
} else {
std::string profile = parsedS3.profile.value_or("");
try {
auto creds = getAwsCredentials(profile);
usernameAuth = UsernameAuth{
.username = creds.accessKeyId,
.password = creds.secretAccessKey,
};
sessionToken = creds.sessionToken;
} catch (const AwsAuthError & e) {
warn("AWS authentication failed for S3 request %s: %s", uri, e.what());
// Invalidate the cached credentials so next request will retry
invalidateAwsCredentials(profile);
// Continue without authentication - might be a public bucket
return;
}
} else if (auto creds = getAwsCredentialsProvider()->maybeGetCredentials(parsedS3)) {
usernameAuth = UsernameAuth{
.username = creds->accessKeyId,
.password = creds->secretAccessKey,
};
sessionToken = creds->sessionToken;
}
if (sessionToken)
headers.emplace_back("x-amz-security-token", *sessionToken);
Expand Down
63 changes: 41 additions & 22 deletions src/libstore/include/nix/store/aws-creds.hh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#if NIX_WITH_AWS_AUTH

# include "nix/store/s3-url.hh"
# include "nix/util/ref.hh"
# include "nix/util/error.hh"

# include <memory>
Expand Down Expand Up @@ -33,35 +34,53 @@ struct AwsCredentials
}
};

/**
* Exception thrown when AWS authentication fails
*/
MakeError(AwsAuthError, Error);
class AwsAuthError : public Error
{
std::optional<int> errorCode;

/**
* Get AWS credentials for the given profile.
* This function automatically caches credential providers to avoid
* creating multiple providers for the same profile.
*
* @param profile The AWS profile name (empty string for default profile)
* @return AWS credentials
* @throws AwsAuthError if credentials cannot be resolved
*/
AwsCredentials getAwsCredentials(const std::string & profile = "");
public:
using Error::Error;
AwsAuthError(int errorCode);

std::optional<int> getErrorCode() const
{
return errorCode;
}
};

class AwsCredentialProvider
{
public:
/**
* Get AWS credentials for the given URL.
*
* @param url The S3 url to get the credentials for
* @return AWS credentials
* @throws AwsAuthError if credentials cannot be resolved
*/
virtual AwsCredentials getCredentials(const ParsedS3URL & url) = 0;

std::optional<AwsCredentials> maybeGetCredentials(const ParsedS3URL & url)
{
try {
return getCredentials(url);
} catch (AwsAuthError & e) {
return std::nullopt;
}
}

virtual ~AwsCredentialProvider() {}
};

/**
* Invalidate cached credentials for a profile (e.g., on authentication failure).
* The next request for this profile will create a new provider.
*
* @param profile The AWS profile name to invalidate
* Create a new instancee of AwsCredentialProvider.
*/
void invalidateAwsCredentials(const std::string & profile);
ref<AwsCredentialProvider> makeAwsCredentialsProvider();

/**
* Pre-resolve AWS credentials for S3 URLs.
* Used to cache credentials in parent process before forking.
* Get a reference to the global AwsCredentialProvider.
*/
AwsCredentials preResolveAwsCredentials(const ParsedS3URL & s3Url);
ref<AwsCredentialProvider> getAwsCredentialsProvider();

} // namespace nix
#endif
2 changes: 2 additions & 0 deletions src/libstore/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ curl_s3_store_opt = get_option('curl-s3-store').require(

if curl_s3_store_opt.enabled()
deps_other += aws_crt_cpp
aws_c_common = cxx.find_library('aws-c-common', required : true)
deps_other += aws_c_common
Comment on lines +161 to +162
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the aws_error_str in AwsAuthError

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, very nice!

endif

configdata_pub.set('NIX_WITH_AWS_AUTH', curl_s3_store_opt.enabled().to_int())
Expand Down
2 changes: 1 addition & 1 deletion src/libstore/unix/build/derivation-builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ std::optional<AwsCredentials> DerivationBuilderImpl::preResolveAwsCredentials()
auto s3Url = ParsedS3URL::parse(parsedUrl);

// Use the preResolveAwsCredentials from aws-creds
auto credentials = nix::preResolveAwsCredentials(s3Url);
auto credentials = getAwsCredentialsProvider()->getCredentials(s3Url);
debug("Successfully pre-resolved AWS credentials in parent process");
return credentials;
}
Expand Down
2 changes: 1 addition & 1 deletion tests/nixos/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ in

user-sandboxing = runNixOSTest ./user-sandboxing;

curl-s3-binary-cache-store = runNixOSTest ./curl-s3-binary-cache-store.nix;
s3-binary-cache-store = runNixOSTest ./s3-binary-cache-store.nix;

fsync = runNixOSTest ./fsync.nix;

Expand Down
Loading