diff --git a/crypto/s2n_hmac.c b/crypto/s2n_hmac.c index 7c432d4ab4f..7bfa1d7ef5e 100644 --- a/crypto/s2n_hmac.c +++ b/crypto/s2n_hmac.c @@ -84,8 +84,10 @@ static int s2n_sslv3_mac_digest(struct s2n_hmac_state *state, void *out, uint32_ int s2n_hmac_init(struct s2n_hmac_state *state, s2n_hmac_algorithm alg, const void *key, uint32_t klen) { s2n_hash_algorithm hash_alg = S2N_HASH_NONE; + state->currently_in_hash_block = 0; state->digest_size = 0; state->block_size = 64; + state->hash_block_size = 64; switch (alg) { case S2N_HMAC_NONE: @@ -116,11 +118,13 @@ int s2n_hmac_init(struct s2n_hmac_state *state, s2n_hmac_algorithm alg, const vo hash_alg = S2N_HASH_SHA384; state->digest_size = SHA384_DIGEST_LENGTH; state->block_size = 128; + state->hash_block_size = 128; break; case S2N_HMAC_SHA512: hash_alg = S2N_HASH_SHA512; state->digest_size = SHA512_DIGEST_LENGTH; state->block_size = 128; + state->hash_block_size = 128; break; default: S2N_ERROR(S2N_ERR_HMAC_INVALID_ALGORITHM); @@ -168,6 +172,10 @@ int s2n_hmac_init(struct s2n_hmac_state *state, s2n_hmac_algorithm alg, const vo int s2n_hmac_update(struct s2n_hmac_state *state, const void *in, uint32_t size) { + /* Keep track of how much of the current hash block is full */ + state->currently_in_hash_block += (128000 + size) % state->hash_block_size; + state->currently_in_hash_block %= state->block_size; + return s2n_hash_update(&state->inner, in, size); } @@ -185,6 +193,24 @@ int s2n_hmac_digest(struct s2n_hmac_state *state, void *out, uint32_t size) return s2n_hash_digest(&state->outer, out, size); } +int s2n_hmac_digest_two_compression_rounds(struct s2n_hmac_state *state, void *out, uint32_t size) +{ + GUARD(s2n_hmac_digest(state, out, size)); + + /* If there were 8 or more bytes of space left in the current hash block + * then the serialized length will have fit in that block. If there were + * fewer than 8 then adding the length will have caused an extra compression + * block round. This digest function always does two compression rounds, + * even if there is no need for the second. + */ + if (state->currently_in_hash_block > (state->hash_block_size - 8)) + { + return 0; + } + + return s2n_hash_update(&state->inner, state->xor_pad, state->hash_block_size); +} + int s2n_hmac_reset(struct s2n_hmac_state *state) { memcpy_check(&state->inner, &state->inner_just_key, sizeof(state->inner)); diff --git a/crypto/s2n_hmac.h b/crypto/s2n_hmac.h index 450a93ec47b..f38f912ac16 100644 --- a/crypto/s2n_hmac.h +++ b/crypto/s2n_hmac.h @@ -26,6 +26,8 @@ typedef enum { S2N_HMAC_NONE, S2N_HMAC_MD5, S2N_HMAC_SHA1, S2N_HMAC_SHA224, S2N_ struct s2n_hmac_state { s2n_hmac_algorithm alg; + uint16_t hash_block_size; + uint32_t currently_in_hash_block; uint16_t block_size; uint8_t digest_size; @@ -45,6 +47,7 @@ extern int s2n_hmac_digest_size(s2n_hmac_algorithm alg); extern int s2n_hmac_init(struct s2n_hmac_state *state, s2n_hmac_algorithm alg, const void *key, uint32_t klen); extern int s2n_hmac_update(struct s2n_hmac_state *state, const void *in, uint32_t size); extern int s2n_hmac_digest(struct s2n_hmac_state *state, void *out, uint32_t size); +extern int s2n_hmac_digest_two_compression_rounds(struct s2n_hmac_state *state, void *out, uint32_t size); extern int s2n_hmac_digest_verify(const void *a, uint32_t alen, const void *b, uint32_t blen); extern int s2n_hmac_reset(struct s2n_hmac_state *state); extern int s2n_hmac_copy(struct s2n_hmac_state *to, struct s2n_hmac_state *from); diff --git a/tls/s2n_cbc.c b/tls/s2n_cbc.c index c705c05fcf5..0ed7ad5e3ae 100644 --- a/tls/s2n_cbc.c +++ b/tls/s2n_cbc.c @@ -25,22 +25,6 @@ #include "tls/s2n_record.h" #include "tls/s2n_prf.h" -static uint8_t masks[256][255]; - -int s2n_cbc_masks_init() -{ - /* We have 256 different 255-byte sized masks for checking padding. 0's indicate where we would expect - * payload or MAC data to be. 0xff's indicate where we expected padding bytes, or the padding length - * byte to be. - */ - for (int i = 0; i < 256; i++) { - memset_check(&masks[i][0], 0, 255 - i); - memset_check(&masks[i][255 - i], 0xFF, i); - } - - return 0; -} - /* A TLS CBC record looks like .. * * [ Payload data ] [ HMAC ] [ Padding ] [ Padding length byte ] @@ -81,7 +65,7 @@ int s2n_verify_cbc(struct s2n_connection *conn, struct s2n_hmac_state *hmac, str /* Check the MAC */ uint8_t check_digest[S2N_MAX_DIGEST_LEN]; lte_check(mac_digest_size, sizeof(check_digest)); - GUARD(s2n_hmac_digest(hmac, check_digest, mac_digest_size)); + GUARD(s2n_hmac_digest_two_compression_rounds(hmac, check_digest, mac_digest_size)); int mismatches = s2n_constant_time_equals(decrypted->data + payload_length, check_digest, mac_digest_size) ^ 1; @@ -94,15 +78,15 @@ int s2n_verify_cbc(struct s2n_connection *conn, struct s2n_hmac_state *hmac, str } /* Check the padding */ - uint8_t *mask = masks[ padding_length ]; - int check = 255; if (check > payload_and_padding_size) { check = payload_and_padding_size; } - for (int i = 255 - check, j = decrypted->size - check; i < 255 && j < decrypted->size; i++, j++) { - mismatches |= (decrypted->data[j] ^ padding_length) & mask[i]; + int cutoff = check - padding_length; + for (int i = 0, j = decrypted->size - check; i < check && j < decrypted->size; i++, j++) { + uint8_t mask = ~(0xff << ((i >= cutoff) * 8)); + mismatches |= (decrypted->data[j] ^ padding_length) & mask; } if (mismatches) { diff --git a/tls/s2n_connection.c b/tls/s2n_connection.c index 365516a6a66..0bd2d7d188b 100644 --- a/tls/s2n_connection.c +++ b/tls/s2n_connection.c @@ -220,6 +220,8 @@ int s2n_connection_wipe(struct s2n_connection *conn) GUARD(s2n_hash_init(&conn->handshake.server_md5, S2N_HASH_MD5)); GUARD(s2n_hash_init(&conn->handshake.server_sha1, S2N_HASH_SHA1)); GUARD(s2n_hash_init(&conn->handshake.server_sha256, S2N_HASH_SHA256)); + GUARD(s2n_hmac_init(&conn->client->client_record_mac, S2N_HMAC_NONE, NULL, 0)); + GUARD(s2n_hmac_init(&conn->server->server_record_mac, S2N_HMAC_NONE, NULL, 0)); memcpy_check(&conn->alert_in, &alert_in, sizeof(struct s2n_stuffer)); memcpy_check(&conn->reader_alert_out, &reader_alert_out, sizeof(struct s2n_stuffer)); diff --git a/tls/s2n_record.h b/tls/s2n_record.h index 8c06001f724..8a65da26349 100644 --- a/tls/s2n_record.h +++ b/tls/s2n_record.h @@ -24,6 +24,5 @@ extern int s2n_record_write(struct s2n_connection *conn, uint8_t content_type, s extern int s2n_record_parse(struct s2n_connection *conn); extern int s2n_record_header_parse(struct s2n_connection *conn, uint8_t *content_type, uint16_t *fragment_length); extern int s2n_sslv2_record_header_parse(struct s2n_connection *conn, uint8_t *record_type, uint8_t *client_protocol_version, uint16_t *fragment_length); -extern int s2n_cbc_masks_init(); extern int s2n_verify_cbc(struct s2n_connection *conn, struct s2n_hmac_state *hmac, struct s2n_blob *decrypted); extern int s2n_aead_aad_init(const struct s2n_connection *conn, uint8_t *sequence_number, uint8_t content_type, uint16_t record_length, struct s2n_stuffer *ad); diff --git a/utils/s2n_random.c b/utils/s2n_random.c index 411215072d6..fae92851e77 100644 --- a/utils/s2n_random.c +++ b/utils/s2n_random.c @@ -31,8 +31,6 @@ #include "error/s2n_errno.h" -#include "tls/s2n_record.h" - #include "utils/s2n_safety.h" #include "utils/s2n_random.h" @@ -197,9 +195,6 @@ int s2n_init(void) S2N_ERROR(S2N_ERR_OPEN_RANDOM); } - /* Create the CBC masks */ - GUARD(s2n_cbc_masks_init()); - #if defined(MAP_INHERIT_ZERO) if ((zero_if_forked_ptr = mmap(NULL, sizeof(int), PROT_READ|PROT_WRITE, MAP_ANON|MAP_PRIVATE, -1, 0)) == MAP_FAILED) {