Permalink
Browse files

Initialize ciphers in their constructors for all platforms

1. In order to ensure every cipher is initialized even though it's never used.
2. Fix a bug in BCrypto/CryptoImpl.cpp for VS2013: initialize pointer m_keyHandle as nullptr.
3. Use const reference to extend lifetime of temporaries.
  • Loading branch information...
wps132230 committed Apr 25, 2018
1 parent 21741fd commit ded84836cd7bf15aa2375a6c1f7143f34d985df1
@@ -103,6 +103,55 @@ class MockSymmetricCipher : public SymmetricCipher
size_t m_finalizeDecryptionCalledCount;
};

class RealSymmetricCipher : public SymmetricCipher
{
public:
RealSymmetricCipher(std::shared_ptr<SymmetricCipher>&& cipher) : SymmetricCipher(),
m_resetCalledCount(0), m_encryptCalledCount(0), m_decryptCalledCount(0),
m_finalizeEncryptionCalledCount(0), m_finalizeDecryptionCalledCount(0), m_cipher(std::move(cipher))
{
}

CryptoBuffer EncryptBuffer(const CryptoBuffer& unEncryptedData) override
{
m_encryptCalledCount++;
return m_cipher->EncryptBuffer(unEncryptedData);
}

CryptoBuffer FinalizeEncryption() override
{
m_finalizeEncryptionCalledCount++;
return m_cipher->FinalizeEncryption();
}

CryptoBuffer DecryptBuffer(const CryptoBuffer& encryptedData) override
{
m_decryptCalledCount++;
return m_cipher->DecryptBuffer(encryptedData);
}

CryptoBuffer FinalizeDecryption() override
{
m_finalizeDecryptionCalledCount++;
return m_cipher->FinalizeDecryption();
}

void Reset() override
{
m_resetCalledCount++;
m_cipher->Reset();
}

size_t m_resetCalledCount;
size_t m_encryptCalledCount;
size_t m_decryptCalledCount;
size_t m_finalizeEncryptionCalledCount;
size_t m_finalizeDecryptionCalledCount;

private:
std::shared_ptr<SymmetricCipher> m_cipher;
};

static const char* TEST_RESPONSE_1 = "BLAH_1_BLAH_1_BLAH_1_BLAH_1_BLAH_1_B";
static const char* TEST_RESPONSE_2 = "BLAH_2_BLAH_2_BLAH_2_BLAH_2_BLAH_2_B";
static const char* TEST_RESPONSE_FINAL = "BLAH_FIN_BLAH_FIN_BLAH_FIN_BLAH_FIN";
@@ -134,6 +183,119 @@ static Aws::String ComputePartialOutput()
return str;
}

using CipherCreateImplementationFunction = std::shared_ptr<SymmetricCipher>(*)(const CryptoBuffer&);

static void TestCiphersNeverUsedSrc(const CipherCreateImplementationFunction& createCipherFunction, const CryptoBuffer& key, CipherMode cipherMode)
{
std::istringstream is;

auto cipher = RealSymmetricCipher(createCipherFunction(key));

SymmetricCryptoStream stream(is, cipherMode, cipher, Aws::Utils::Crypto::DEFAULT_BUF_SIZE);

ASSERT_EQ(0u, cipher.m_encryptCalledCount);
ASSERT_EQ(0u, cipher.m_decryptCalledCount);
ASSERT_EQ(0u, cipher.m_finalizeEncryptionCalledCount);
ASSERT_EQ(0u, cipher.m_finalizeDecryptionCalledCount);
}

static void TestCiphersNeverUsedSinkDestructorFinalizes(const CipherCreateImplementationFunction& createCipherFunction, const CryptoBuffer& key, CipherMode cipherMode)
{
std::ostringstream os;

auto cipher = RealSymmetricCipher(createCipherFunction(key));

{
SymmetricCryptoStream stream(os, cipherMode, cipher, Aws::Utils::Crypto::DEFAULT_BUF_SIZE);
}

if (cipherMode == CipherMode::Encrypt)
{
ASSERT_EQ(1u, cipher.m_finalizeEncryptionCalledCount);
ASSERT_EQ(0u, cipher.m_finalizeDecryptionCalledCount);
}
else
{
ASSERT_EQ(0u, cipher.m_finalizeEncryptionCalledCount);
ASSERT_EQ(1u, cipher.m_finalizeDecryptionCalledCount);
}
}

static void TestCiphersNeverUsedSinkExplicitFinalize(const CipherCreateImplementationFunction& createCipherFunction, const CryptoBuffer& key, CipherMode cipherMode)
{
std::ostringstream os;

auto cipher = RealSymmetricCipher(createCipherFunction(key));

SymmetricCryptoStream stream(os, cipherMode, cipher, Aws::Utils::Crypto::DEFAULT_BUF_SIZE);
stream.Finalize();

if (cipherMode == CipherMode::Encrypt)
{
ASSERT_EQ(1u, cipher.m_finalizeEncryptionCalledCount);
ASSERT_EQ(0u, cipher.m_finalizeDecryptionCalledCount);
}
else
{
ASSERT_EQ(0u, cipher.m_finalizeEncryptionCalledCount);
ASSERT_EQ(1u, cipher.m_finalizeDecryptionCalledCount);
}
}

TEST(CryptoStreamsTest, TestCiphersNeverUsedSrc)
{
CryptoBuffer key = SymmetricCipher::GenerateKey();

TestCiphersNeverUsedSrc(CipherCreateImplementationFunction(CreateAES_CBCImplementation), key, CipherMode::Encrypt);
TestCiphersNeverUsedSrc(CipherCreateImplementationFunction(CreateAES_CBCImplementation), key, CipherMode::Decrypt);
TestCiphersNeverUsedSrc(CipherCreateImplementationFunction(CreateAES_CTRImplementation), key, CipherMode::Encrypt);
TestCiphersNeverUsedSrc(CipherCreateImplementationFunction(CreateAES_CTRImplementation), key, CipherMode::Decrypt);
#ifndef ENABLE_COMMONCRYPTO_ENCRYPTION
TestCiphersNeverUsedSrc(CipherCreateImplementationFunction(CreateAES_GCMImplementation), key, CipherMode::Encrypt);
TestCiphersNeverUsedSrc(CipherCreateImplementationFunction(CreateAES_GCMImplementation), key, CipherMode::Decrypt);
#endif
Aws::String kek = "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F";
CryptoBuffer kek_raw = HashingUtils::HexDecode(kek);
TestCiphersNeverUsedSrc(CipherCreateImplementationFunction(CreateAES_KeyWrapImplementation), kek_raw, CipherMode::Encrypt);
TestCiphersNeverUsedSrc(CipherCreateImplementationFunction(CreateAES_KeyWrapImplementation), kek_raw, CipherMode::Decrypt);
}

TEST(CryptoStreamsTest, TestCiphersNeverUsedSinkDestructorFinalizes)
{
CryptoBuffer key = SymmetricCipher::GenerateKey();

TestCiphersNeverUsedSinkDestructorFinalizes(CipherCreateImplementationFunction(CreateAES_CBCImplementation), key, CipherMode::Encrypt);
TestCiphersNeverUsedSinkDestructorFinalizes(CipherCreateImplementationFunction(CreateAES_CBCImplementation), key, CipherMode::Decrypt);
TestCiphersNeverUsedSinkDestructorFinalizes(CipherCreateImplementationFunction(CreateAES_CTRImplementation), key, CipherMode::Encrypt);
TestCiphersNeverUsedSinkDestructorFinalizes(CipherCreateImplementationFunction(CreateAES_CTRImplementation), key, CipherMode::Decrypt);
#ifndef ENABLE_COMMONCRYPTO_ENCRYPTION
TestCiphersNeverUsedSinkDestructorFinalizes(CipherCreateImplementationFunction(CreateAES_GCMImplementation), key, CipherMode::Encrypt);
TestCiphersNeverUsedSinkDestructorFinalizes(CipherCreateImplementationFunction(CreateAES_GCMImplementation), key, CipherMode::Decrypt);
#endif
Aws::String kek = "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F";
CryptoBuffer kek_raw = HashingUtils::HexDecode(kek);
TestCiphersNeverUsedSinkDestructorFinalizes(CipherCreateImplementationFunction(CreateAES_KeyWrapImplementation), kek_raw, CipherMode::Encrypt);
TestCiphersNeverUsedSinkDestructorFinalizes(CipherCreateImplementationFunction(CreateAES_KeyWrapImplementation), kek_raw, CipherMode::Decrypt);
}

TEST(CryptoStreamsTest, TestUninitializedCiphersSinkExplicitFinalize)
{
CryptoBuffer key = SymmetricCipher::GenerateKey();

TestCiphersNeverUsedSinkExplicitFinalize(CipherCreateImplementationFunction(CreateAES_CBCImplementation), key, CipherMode::Encrypt);
TestCiphersNeverUsedSinkExplicitFinalize(CipherCreateImplementationFunction(CreateAES_CBCImplementation), key, CipherMode::Decrypt);
TestCiphersNeverUsedSinkExplicitFinalize(CipherCreateImplementationFunction(CreateAES_CTRImplementation), key, CipherMode::Encrypt);
TestCiphersNeverUsedSinkExplicitFinalize(CipherCreateImplementationFunction(CreateAES_CTRImplementation), key, CipherMode::Decrypt);
#ifndef ENABLE_COMMONCRYPTO_ENCRYPTION
TestCiphersNeverUsedSinkExplicitFinalize(CipherCreateImplementationFunction(CreateAES_GCMImplementation), key, CipherMode::Encrypt);
TestCiphersNeverUsedSinkExplicitFinalize(CipherCreateImplementationFunction(CreateAES_GCMImplementation), key, CipherMode::Decrypt);
#endif
Aws::String kek = "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F";
CryptoBuffer kek_raw = HashingUtils::HexDecode(kek);
TestCiphersNeverUsedSinkExplicitFinalize(CipherCreateImplementationFunction(CreateAES_KeyWrapImplementation), kek_raw, CipherMode::Encrypt);
TestCiphersNeverUsedSinkExplicitFinalize(CipherCreateImplementationFunction(CreateAES_KeyWrapImplementation), kek_raw, CipherMode::Decrypt);
}

TEST(CryptoStreamsTest, TestEncryptSrcStreamEvenBoundaries)
{
std::istringstream is(ORIGINAL_SRC);
@@ -42,19 +42,36 @@ namespace Aws
* ivGenerationInCtrMode, if true, initializes the iv with a 4 byte counter at the end.
*/
SymmetricCipher(const CryptoBuffer& key, size_t ivSize, bool ivGenerationInCtrMode = false) :
m_key(key), m_initializationVector(ivSize > 0 ? GenerateIV(ivSize, ivGenerationInCtrMode) : 0), m_failure(false) { Validate(); }
m_key(key),
m_initializationVector(ivSize > 0 ? GenerateIV(ivSize, ivGenerationInCtrMode) : 0),
m_failure(false)
{
Validate();
}

/**
* Initialize with key and initializationVector, set tag for decryption of authenticated modes (makes copies of the buffers)
*/
SymmetricCipher(const CryptoBuffer& key, const CryptoBuffer& initializationVector, const CryptoBuffer& tag = CryptoBuffer(0)) :
m_key(key), m_initializationVector(initializationVector), m_tag(tag), m_failure(false) { Validate(); }
m_key(key),
m_initializationVector(initializationVector),
m_tag(tag),
m_failure(false)
{
Validate();
}

/**
* Initialize with key and initializationVector, set tag for decryption of authenticated modes (move the buffers)
*/
SymmetricCipher(CryptoBuffer&& key, CryptoBuffer&& initializationVector, CryptoBuffer&& tag = CryptoBuffer(0)) :
m_key(std::move(key)), m_initializationVector(std::move(initializationVector)), m_tag(std::move(tag)), m_failure(false) { Validate(); }
m_key(std::move(key)),
m_initializationVector(std::move(initializationVector)),
m_tag(std::move(tag)),
m_failure(false)
{
Validate();
}

SymmetricCipher(const SymmetricCipher& other) = delete;
SymmetricCipher& operator=(const SymmetricCipher& other) = delete;
@@ -232,31 +232,20 @@ namespace Aws
void Reset() override;

protected:
/**
* Algorithm/Mode level config for the BCRYPT_ALG_HANDLE and BCRYPT_KEY_HANDLE
*/
virtual void InitEncryptor_Internal() = 0;
virtual void InitDecryptor_Internal() = 0;
void InitKey();
virtual size_t GetBlockSizeBytes() const = 0;
virtual size_t GetKeyLengthBits() const = 0;

void CheckInitEncryptor();
void CheckInitDecryptor();

BCRYPT_ALG_HANDLE m_algHandle;
BCRYPT_KEY_HANDLE m_keyHandle;
DWORD m_flags;
CryptoBuffer m_workingIv;
PBCRYPT_AUTHENTICATED_CIPHER_MODE_INFO m_authInfoPtr;
bool m_encDecInitialized;
bool m_encryptionMode;
bool m_decryptionMode;

static BCRYPT_KEY_HANDLE ImportKeyBlob(BCRYPT_ALG_HANDLE handle, CryptoBuffer& key);

private:
void Init();
void InitKey();
void Cleanup();
};

@@ -295,12 +284,11 @@ namespace Aws
void Reset() override;

protected:
void InitEncryptor_Internal() override;
void InitDecryptor_Internal() override;
size_t GetBlockSizeBytes() const override;
size_t GetKeyLengthBits() const override;

private:
void InitCipher();
CryptoBuffer FillInOverflow(const CryptoBuffer& buffer);

CryptoBuffer m_blockOverflow;
@@ -345,13 +333,12 @@ namespace Aws
void Reset() override;

protected:
void InitEncryptor_Internal() override;
void InitDecryptor_Internal() override;

size_t GetBlockSizeBytes() const override;
size_t GetKeyLengthBits() const override;

private:
void InitCipher();

static void InitBuffersToNull(Aws::Vector<ByteBuffer*>& initBuffers);
static void CleanupBuffers(Aws::Vector<ByteBuffer*>& cleanupBuffers);

@@ -400,9 +387,6 @@ namespace Aws
void Reset() override;

protected:
void InitEncryptor_Internal() override;
void InitDecryptor_Internal() override;

size_t GetBlockSizeBytes() const override;
size_t GetKeyLengthBits() const override;
size_t GetTagLengthBytes() const;
@@ -447,13 +431,12 @@ namespace Aws
void Reset() override;

protected:
void InitEncryptor_Internal() override;
void InitDecryptor_Internal() override;

size_t GetBlockSizeBytes() const override;
size_t GetKeyLengthBits() const override;

private:
void InitCipher();

static size_t BlockSizeBytes;
static size_t KeyLengthBits;

@@ -164,32 +164,15 @@ namespace Aws
void Reset() override;

protected:
/**
* Algorithm/Mode level config for the EVP_CIPHER_CTX
*/
virtual void InitEncryptor_Internal() = 0;

/**
* Algorithm/Mode level config for the EVP_CIPHER_CTX
*/
virtual void InitDecryptor_Internal() = 0;

virtual size_t GetBlockSizeBytes() const = 0;

virtual size_t GetKeyLengthBits() const = 0;

void CheckInitEncryptor();

void CheckInitDecryptor();

_CCCryptor* m_cryptoHandle;
_CCCryptor* m_encryptorHandle;
_CCCryptor* m_decryptorHandle;

private:
void Init();

bool m_encDecInitialized;
bool m_encryptionMode;
bool m_decryptionMode;
};

/**
@@ -220,15 +203,13 @@ namespace Aws
AES_CBC_Cipher_CommonCrypto(AES_CBC_Cipher_CommonCrypto&& toMove) = default;

protected:
void InitEncryptor_Internal() override;

void InitDecryptor_Internal() override;

size_t GetBlockSizeBytes() const override;

size_t GetKeyLengthBits() const override;

private:
void InitCipher();

static size_t BlockSizeBytes;
static size_t KeyLengthBits;
};
@@ -262,15 +243,13 @@ namespace Aws
AES_CTR_Cipher_CommonCrypto(AES_CTR_Cipher_CommonCrypto&& toMove) = default;

protected:
void InitEncryptor_Internal() override;

void InitDecryptor_Internal() override;

size_t GetBlockSizeBytes() const override;

size_t GetKeyLengthBits() const override;

private:
void InitCipher();

static size_t BlockSizeBytes;
static size_t KeyLengthBits;
};
@@ -304,10 +283,6 @@ namespace Aws
void Reset() override;

protected:
void InitEncryptor_Internal() override {};

void InitDecryptor_Internal() override {};

inline size_t GetBlockSizeBytes() const override { return BlockSizeBytes; }

inline size_t GetKeyLengthBits() const override { return KeyLengthBits; }
Oops, something went wrong.

0 comments on commit ded8483

Please sign in to comment.