diff --git a/core/src/lib/tls_openssl_private.cc b/core/src/lib/tls_openssl_private.cc index ff72db7afc9..333eec709bd 100644 --- a/core/src/lib/tls_openssl_private.cc +++ b/core/src/lib/tls_openssl_private.cc @@ -37,6 +37,7 @@ /* static private */ std::map TlsOpenSslPrivate::psk_client_credentials_; +std::mutex TlsOpenSslPrivate::psk_client_credentials_mutex_; /* static private */ /* No anonymous ciphers, no <128 bit ciphers, no export ciphers, no MD5 ciphers */ @@ -70,7 +71,9 @@ TlsOpenSslPrivate::~TlsOpenSslPrivate() /* the openssl_ctx object is the factory that creates * openssl objects, so delete this at the end */ if (openssl_ctx_) { + psk_client_credentials_mutex_.lock(); psk_client_credentials_.erase(openssl_ctx_); + psk_client_credentials_mutex_.unlock(); SSL_CTX_free(openssl_ctx_); openssl_ctx_ = nullptr; } @@ -368,8 +371,10 @@ void TlsOpenSslPrivate::ClientContextInsertCredentials(const PskCredentials &cre if (!openssl_ctx_) { /* do not register nullptr */ Dmsg0(100, "Psk Server Callback: No SSL_CTX\n"); } else { + psk_client_credentials_mutex_.lock(); TlsOpenSslPrivate::psk_client_credentials_.insert( std::pair(openssl_ctx_, credentials)); + psk_client_credentials_mutex_.unlock(); } } @@ -431,26 +436,35 @@ unsigned int TlsOpenSslPrivate::psk_client_cb(SSL *ssl, return 0; } - if (psk_client_credentials_.find(openssl_ctx) == psk_client_credentials_.end()) { + PskCredentials credentials; + bool found = false; + + psk_client_credentials_mutex_.lock(); + if (psk_client_credentials_.find(openssl_ctx) != psk_client_credentials_.end()) { + credentials = TlsOpenSslPrivate::psk_client_credentials_.at(openssl_ctx); + found = true; + } + psk_client_credentials_mutex_.unlock(); + + if (!found) { Dmsg0(100, "Error, TLS-PSK CALLBACK not set because SSL_CTX is not registered.\n"); - } else { - const PskCredentials &credentials = TlsOpenSslPrivate::psk_client_credentials_.at(openssl_ctx); - int ret = Bsnprintf(identity, max_identity_len, "%s", credentials.get_identity().c_str()); + return 0; + } - if (ret < 0 || (unsigned int)ret > max_identity_len) { - Dmsg0(100, "Error, identify too long\n"); - return 0; - } - Dmsg1(100, "psk_client_cb. identity: %s.\n", identity); + int ret = Bsnprintf(identity, max_identity_len, "%s", credentials.get_identity().c_str()); - ret = Bsnprintf((char *)psk, max_psk_len, "%s", credentials.get_psk().c_str()); - if (ret < 0 || (unsigned int)ret > max_psk_len) { - Dmsg0(100, "Error, psk too long\n"); - return 0; - } - return ret; - } - return 0; + if (ret < 0 || (unsigned int)ret > max_identity_len) { + Dmsg0(100, "Error, identify too long\n"); + return 0; + } + Dmsg1(100, "psk_client_cb. identity: %s.\n", identity); + + ret = Bsnprintf((char *)psk, max_psk_len, "%s", credentials.get_psk().c_str()); + if (ret < 0 || (unsigned int)ret > max_psk_len) { + Dmsg0(100, "Error, psk too long\n"); + return 0; + } + return ret; } /* diff --git a/core/src/lib/tls_openssl_private.h b/core/src/lib/tls_openssl_private.h index 777c5f2e034..14aa22086d0 100644 --- a/core/src/lib/tls_openssl_private.h +++ b/core/src/lib/tls_openssl_private.h @@ -63,6 +63,7 @@ class TlsOpenSslPrivate { /* PskCredentials lookup map for all connections */ static std::map psk_client_credentials_; + static std::mutex psk_client_credentials_mutex_; /* tls_default_ciphers_ if no user ciphers given */ static const std::string tls_default_ciphers_; diff --git a/core/src/lib/tls_psk_credentials.h b/core/src/lib/tls_psk_credentials.h index 34f87ed54a4..702d9e44332 100644 --- a/core/src/lib/tls_psk_credentials.h +++ b/core/src/lib/tls_psk_credentials.h @@ -45,7 +45,7 @@ class PskCredentials PskCredentials &operator = (const PskCredentials &rhs) { identity_ = rhs.identity_; - psk_ = rhs.identity_; + psk_ = rhs.psk_; return *this; }