Commit

Fix memory read for stack allocated buffers.
JonathanHenson committed Feb 8, 2024
1 parent 95bd82e commit 5983351
Showing 1 changed file with 118 additions and 149 deletions.
267 changes: 118 additions & 149 deletions source/arm/crc64_arm.c
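What changed: the PMULL path previously special-cased inputs shorter than 16 bytes by rounding the pointer down to a 16-byte boundary and loading a full NEON register (load_u8(input - alignment)), masking off the bytes it did not own. That aligned load can touch memory outside a small stack-allocated buffer, which is what this commit fixes; the new code instead hands any input shorter than 16 bytes to the portable aws_checksums_crc64xz_sw routine. The sketch below is a hypothetical driver, not part of the commit, showing the kind of call that exercises the fixed path; it assumes the internal aws_checksums_crc64xz_arm_pmull symbol is visible to the test program.

#include <stdint.h>
#include <stdio.h>

/* Declaration copied from the signature in the diff; in the real project it
 * would come from the library's private checksums header. */
uint64_t aws_checksums_crc64xz_arm_pmull(const uint8_t *input, int length, uint64_t previous_crc64);

int main(void) {
    /* A 5-byte stack buffer: before this commit the PMULL path could issue a
     * 16-byte aligned NEON load overlapping memory outside msg; after it,
     * any length < 16 is computed by the software CRC instead. */
    uint8_t msg[5] = {'h', 'e', 'l', 'l', 'o'};
    uint64_t crc = aws_checksums_crc64xz_arm_pmull(msg, (int)sizeof(msg), 0);
    printf("crc64xz = %016llx\n", (unsigned long long)crc);
    return 0;
}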
@@ -49,176 +49,145 @@
 # endif // defined(__ARM_FEATURE_SHA3)
 
 /** Compute CRC64XZ using ARMv8 NEON +crypto/pmull64 instructions. */
-uint64_t aws_checksums_crc64xz_arm_pmull(const uint8_t *input, int length, const uint64_t previousCrc64) {
+uint64_t aws_checksums_crc64xz_arm_pmull(const uint8_t *input, int length, const uint64_t previous_crc64) {
     if (!input || length <= 0) {
-        return previousCrc64;
+        return previous_crc64;
     }
 
+    // the amount of complexity required to handle vector instructions on
+    // memory regions smaller than an xmm register does not justify the very negligible performance gains
+    // we would get for using it on an input this small.
+    if (length < 16) {
+        return aws_checksums_crc64xz_sw(input, length, previous_crc64);
+    }
+
     // Invert the previous crc bits and load into the lower half of a neon register
-    poly64x2_t a1 = vreinterpretq_p64_u64(vcombine_u64(vcreate_u64(~previousCrc64), vcreate_u64(0)));
+    poly64x2_t a1 = vreinterpretq_p64_u64(vcombine_u64(vcreate_u64(~previous_crc64), vcreate_u64(0)));
 
     // Load the x^128 and x^192 constants - they'll (very likely) be needed
     const poly64x2_t x128 = load_p64(aws_checksums_crc64xz_constants.x128);
 
-    if (length < 16) {
-        // Neon register loads are 16 bytes at once, so for lengths less than 16 we need to
-        // carefully load from memory to prevent reading beyond the end of the input buffer
-        int alignment = (intptr_t)input & 15;
-        if (alignment + length <= 16) {
-            // The input falls in a single 16 byte segment so we load from a 16 byte aligned address
-            // The input data will be loaded "into the middle" of the neon register
-            // Right shift the input data register to eliminate any leading bytes and move the data to the least
-            // significant bytes, then mask out the most significant bytes that may contain garbage
-            uint8x16_t load = mask_low_u8(right_shift_u8(load_u8(input - alignment), alignment), length);
-            // XOR the masked input data with the previous crc
-            a1 = xor_p64(a1, vreinterpretq_p64_u8(load));
-        } else {
-            // The input spans two 16 byte segments so it's safe to load the input from its actual starting address
-            // The input data will be in the least significant bytes of the neon register
-            // Mask out the most significant bytes that may contain garbage
-            uint8x16_t load = mask_low_u8(load_u8(input), length);
-            // XOR the masked input data with the previous crc
-            a1 = xor_p64(a1, vreinterpretq_p64_u8(load));
-        }
-
-        if (length <= 8) {
-            // For 8 or less bytes of input just left shift to effectively multiply by x^64
-            a1 = left_shift_p64(a1, 8 - length);
-        } else {
-            // For 8-15 bytes of input we can't left shift without losing the most significant bytes
-            // We need to fold the lower half of the crc register into the upper half
-            a1 = left_shift_p64(a1, 16 - length);
-            // Multiply the lower half of the crc register by x^128 (swapping upper and lower halves)
-            // XOR the result with the right shifted upper half of the crc
-            a1 = xor_p64(right_shift_imm_p64(a1, 8), pmull_lo(a1, vextq_p64(x128, x128, 1)));
-        }
-
-        // Fall through to Barrett modular reduction
-
-    } else { // There are 16 or more bytes of input
-
-        // Load the next 16 bytes of input and XOR with the previous crc
-        a1 = xor_p64(a1, load_p64_u8(input));
-        input += 16;
-        length -= 16;
-
-        if (length < 112) {
-
-            const poly64x2_t x256 = load_p64(aws_checksums_crc64xz_constants.x256);
-
-            if (length & 64) {
-                // Fold the current crc register with 64 bytes of input by multiplying 64-bit chunks by x^576 through
-                // x^128
-                const poly64x2_t x512 = load_p64(aws_checksums_crc64xz_constants.x512);
-                const poly64x2_t x384 = load_p64(aws_checksums_crc64xz_constants.x384);
-                poly64x2_t b1 = load_p64_u8(input + 0);
-                poly64x2_t c1 = load_p64_u8(input + 16);
-                poly64x2_t d1 = load_p64_u8(input + 32);
-                poly64x2_t e1 = load_p64_u8(input + 48);
-                a1 = xor3_p64(pmull_lo(x512, a1), pmull_hi(x512, a1), pmull_lo(x384, b1));
-                b1 = xor3_p64(pmull_hi(x384, b1), pmull_lo(x256, c1), pmull_hi(x256, c1));
-                c1 = xor3_p64(pmull_lo(x128, d1), pmull_hi(x128, d1), e1);
-                a1 = xor3_p64(a1, b1, c1);
-                input += 64;
-            }
-
-            if (length & 32) {
-                // Fold the current running value with 32 bytes of input by multiplying 64-bit chunks by x^320 through
-                // x^128
-                poly64x2_t b1 = load_p64_u8(input + 0);
-                poly64x2_t c1 = load_p64_u8(input + 16);
-                a1 = xor3_p64(c1, pmull_lo(x256, a1), pmull_hi(x256, a1));
-                a1 = xor3_p64(a1, pmull_lo(x128, b1), pmull_hi(x128, b1));
-                input += 32;
-            }
-        } else { // There are 112 or more bytes of input
-
-            const poly64x2_t x1024 = load_p64(aws_checksums_crc64xz_constants.x1024);
-
-            // Load another 112 bytes of input
-            poly64x2_t b1 = load_p64_u8(input + 0);
-            poly64x2_t c1 = load_p64_u8(input + 16);
-            poly64x2_t d1 = load_p64_u8(input + 32);
-            poly64x2_t e1 = load_p64_u8(input + 48);
-            poly64x2_t f1 = load_p64_u8(input + 64);
-            poly64x2_t g1 = load_p64_u8(input + 80);
-            poly64x2_t h1 = load_p64_u8(input + 96);
-            input += 112;
-            length -= 112;
-
-            // Spin through additional chunks of 128 bytes, if any
-            int loops = length / 128;
-            while (loops--) {
-                // Fold input values in parallel by multiplying by x^1088 and x^1024 constants
-                a1 = xor3_p64(pmull_lo(x1024, a1), pmull_hi(x1024, a1), load_p64_u8(input + 0));
-                b1 = xor3_p64(pmull_lo(x1024, b1), pmull_hi(x1024, b1), load_p64_u8(input + 16));
-                c1 = xor3_p64(pmull_lo(x1024, c1), pmull_hi(x1024, c1), load_p64_u8(input + 32));
-                d1 = xor3_p64(pmull_lo(x1024, d1), pmull_hi(x1024, d1), load_p64_u8(input + 48));
-                e1 = xor3_p64(pmull_lo(x1024, e1), pmull_hi(x1024, e1), load_p64_u8(input + 64));
-                f1 = xor3_p64(pmull_lo(x1024, f1), pmull_hi(x1024, f1), load_p64_u8(input + 80));
-                g1 = xor3_p64(pmull_lo(x1024, g1), pmull_hi(x1024, g1), load_p64_u8(input + 96));
-                h1 = xor3_p64(pmull_lo(x1024, h1), pmull_hi(x1024, h1), load_p64_u8(input + 112));
-                input += 128;
-            }
-
-            // Fold 128 bytes down to 64 bytes by multiplying by the x^576 and x^512 constants
-            const poly64x2_t x512 = load_p64(aws_checksums_crc64xz_constants.x512);
-            a1 = xor3_p64(e1, pmull_lo(x512, a1), pmull_hi(x512, a1));
-            b1 = xor3_p64(f1, pmull_lo(x512, b1), pmull_hi(x512, b1));
-            c1 = xor3_p64(g1, pmull_lo(x512, c1), pmull_hi(x512, c1));
-            d1 = xor3_p64(h1, pmull_lo(x512, d1), pmull_hi(x512, d1));
-
-            if (length & 64) {
-                // Fold the current 64 bytes with 64 bytes of input by multiplying by x^576 and x^512 constants
-                a1 = xor3_p64(pmull_lo(x512, a1), pmull_hi(x512, a1), load_p64_u8(input + 0));
-                b1 = xor3_p64(pmull_lo(x512, b1), pmull_hi(x512, b1), load_p64_u8(input + 16));
-                c1 = xor3_p64(pmull_lo(x512, c1), pmull_hi(x512, c1), load_p64_u8(input + 32));
-                d1 = xor3_p64(pmull_lo(x512, d1), pmull_hi(x512, d1), load_p64_u8(input + 48));
-                input += 64;
-            }
-
-            // Fold 64 bytes down to 32 bytes by multiplying by the x^320 and x^256 constants
-            const poly64x2_t x256 = load_p64(aws_checksums_crc64xz_constants.x256);
-            a1 = xor3_p64(c1, pmull_lo(x256, a1), pmull_hi(x256, a1));
-            b1 = xor3_p64(d1, pmull_lo(x256, b1), pmull_hi(x256, b1));
-
-            if (length & 32) {
-                // Fold the current running value with 32 bytes of input by multiplying by x^320 and x^256 constants
-                a1 = xor3_p64(pmull_lo(x256, a1), pmull_hi(x256, a1), load_p64_u8(input + 0));
-                b1 = xor3_p64(pmull_lo(x256, b1), pmull_hi(x256, b1), load_p64_u8(input + 16));
-                input += 32;
-            }
-
-            // Fold 32 bytes down to 16 bytes by multiplying by x^192 and x^128 constants
-            a1 = xor3_p64(b1, pmull_lo(x128, a1), pmull_hi(x128, a1));
-        }
-
-        if (length & 16) {
-            // Fold the current 16 bytes with 16 bytes of input by multiplying by x^192 and x^128 constants
-            a1 = xor3_p64(pmull_lo(x128, a1), pmull_hi(x128, a1), load_p64_u8(input + 0));
-            input += 16;
-        }
-
-        // There must only be 0-15 bytes of input left
-        length &= 15;
-
-        if (length == 0) {
-            // Multiply the lower half of the crc register by x^128 (swapping upper and lower halves)
-            poly64x2_t mul_by_x128 = pmull_lo(a1, vextq_p64(x128, x128, 1));
-            // XOR the result with the right shifted upper half of the crc
-            a1 = xor_p64(right_shift_imm_p64(a1, 8), mul_by_x128);
-        } else {
-            // Handle any trailing input from 1-15 bytes
-            const poly64x2_t trailing_constants = load_p64(aws_checksums_crc64xz_constants.trailing[length - 1]);
-            // Multiply the crc by a pair of trailing length constants in order to fold it into the trailing input
-            a1 = xor_p64(pmull_lo(a1, trailing_constants), pmull_hi(a1, trailing_constants));
-            // Safely load ending at the last byte of trailing input and mask out any leading garbage
-            poly64x2_t trailing_input = mask_high_p64(load_p64_u8(input + length - 16), length);
-            // Multiply the lower half of the trailing input register by x^128 (swapping x^192 and x^128 halves)
-            poly64x2_t mul_by_x128 = pmull_lo(trailing_input, vextq_p64(x128, x128, 1));
-            // XOR the results with the right shifted upper half of the trailing input
-            a1 = xor3_p64(a1, right_shift_imm_p64(trailing_input, 8), mul_by_x128);
-        }
-    }
+    // Load the next 16 bytes of input and XOR with the previous crc
+    a1 = xor_p64(a1, load_p64_u8(input));
+    input += 16;
+    length -= 16;
+
+    if (length < 112) {
+
+        const poly64x2_t x256 = load_p64(aws_checksums_crc64xz_constants.x256);
+
+        if (length & 64) {
+            // Fold the current crc register with 64 bytes of input by multiplying 64-bit chunks by x^576 through
+            // x^128
+            const poly64x2_t x512 = load_p64(aws_checksums_crc64xz_constants.x512);
+            const poly64x2_t x384 = load_p64(aws_checksums_crc64xz_constants.x384);
+            poly64x2_t b1 = load_p64_u8(input + 0);
+            poly64x2_t c1 = load_p64_u8(input + 16);
+            poly64x2_t d1 = load_p64_u8(input + 32);
+            poly64x2_t e1 = load_p64_u8(input + 48);
+            a1 = xor3_p64(pmull_lo(x512, a1), pmull_hi(x512, a1), pmull_lo(x384, b1));
+            b1 = xor3_p64(pmull_hi(x384, b1), pmull_lo(x256, c1), pmull_hi(x256, c1));
+            c1 = xor3_p64(pmull_lo(x128, d1), pmull_hi(x128, d1), e1);
+            a1 = xor3_p64(a1, b1, c1);
+            input += 64;
+        }
+
+        if (length & 32) {
+            // Fold the current running value with 32 bytes of input by multiplying 64-bit chunks by x^320 through
+            // x^128
+            poly64x2_t b1 = load_p64_u8(input + 0);
+            poly64x2_t c1 = load_p64_u8(input + 16);
+            a1 = xor3_p64(c1, pmull_lo(x256, a1), pmull_hi(x256, a1));
+            a1 = xor3_p64(a1, pmull_lo(x128, b1), pmull_hi(x128, b1));
+            input += 32;
+        }
+    } else { // There are 112 or more bytes of input
+
+        const poly64x2_t x1024 = load_p64(aws_checksums_crc64xz_constants.x1024);
+
+        // Load another 112 bytes of input
+        poly64x2_t b1 = load_p64_u8(input + 0);
+        poly64x2_t c1 = load_p64_u8(input + 16);
+        poly64x2_t d1 = load_p64_u8(input + 32);
+        poly64x2_t e1 = load_p64_u8(input + 48);
+        poly64x2_t f1 = load_p64_u8(input + 64);
+        poly64x2_t g1 = load_p64_u8(input + 80);
+        poly64x2_t h1 = load_p64_u8(input + 96);
+        input += 112;
+        length -= 112;
+
+        // Spin through additional chunks of 128 bytes, if any
+        int loops = length / 128;
+        while (loops--) {
+            // Fold input values in parallel by multiplying by x^1088 and x^1024 constants
+            a1 = xor3_p64(pmull_lo(x1024, a1), pmull_hi(x1024, a1), load_p64_u8(input + 0));
+            b1 = xor3_p64(pmull_lo(x1024, b1), pmull_hi(x1024, b1), load_p64_u8(input + 16));
+            c1 = xor3_p64(pmull_lo(x1024, c1), pmull_hi(x1024, c1), load_p64_u8(input + 32));
+            d1 = xor3_p64(pmull_lo(x1024, d1), pmull_hi(x1024, d1), load_p64_u8(input + 48));
+            e1 = xor3_p64(pmull_lo(x1024, e1), pmull_hi(x1024, e1), load_p64_u8(input + 64));
+            f1 = xor3_p64(pmull_lo(x1024, f1), pmull_hi(x1024, f1), load_p64_u8(input + 80));
+            g1 = xor3_p64(pmull_lo(x1024, g1), pmull_hi(x1024, g1), load_p64_u8(input + 96));
+            h1 = xor3_p64(pmull_lo(x1024, h1), pmull_hi(x1024, h1), load_p64_u8(input + 112));
+            input += 128;
+        }
+
+        // Fold 128 bytes down to 64 bytes by multiplying by the x^576 and x^512 constants
+        const poly64x2_t x512 = load_p64(aws_checksums_crc64xz_constants.x512);
+        a1 = xor3_p64(e1, pmull_lo(x512, a1), pmull_hi(x512, a1));
+        b1 = xor3_p64(f1, pmull_lo(x512, b1), pmull_hi(x512, b1));
+        c1 = xor3_p64(g1, pmull_lo(x512, c1), pmull_hi(x512, c1));
+        d1 = xor3_p64(h1, pmull_lo(x512, d1), pmull_hi(x512, d1));
+
+        if (length & 64) {
+            // Fold the current 64 bytes with 64 bytes of input by multiplying by x^576 and x^512 constants
+            a1 = xor3_p64(pmull_lo(x512, a1), pmull_hi(x512, a1), load_p64_u8(input + 0));
+            b1 = xor3_p64(pmull_lo(x512, b1), pmull_hi(x512, b1), load_p64_u8(input + 16));
+            c1 = xor3_p64(pmull_lo(x512, c1), pmull_hi(x512, c1), load_p64_u8(input + 32));
+            d1 = xor3_p64(pmull_lo(x512, d1), pmull_hi(x512, d1), load_p64_u8(input + 48));
+            input += 64;
+        }
+
+        // Fold 64 bytes down to 32 bytes by multiplying by the x^320 and x^256 constants
+        const poly64x2_t x256 = load_p64(aws_checksums_crc64xz_constants.x256);
+        a1 = xor3_p64(c1, pmull_lo(x256, a1), pmull_hi(x256, a1));
+        b1 = xor3_p64(d1, pmull_lo(x256, b1), pmull_hi(x256, b1));
+
+        if (length & 32) {
+            // Fold the current running value with 32 bytes of input by multiplying by x^320 and x^256 constants
+            a1 = xor3_p64(pmull_lo(x256, a1), pmull_hi(x256, a1), load_p64_u8(input + 0));
+            b1 = xor3_p64(pmull_lo(x256, b1), pmull_hi(x256, b1), load_p64_u8(input + 16));
+            input += 32;
+        }
+
+        // Fold 32 bytes down to 16 bytes by multiplying by x^192 and x^128 constants
+        a1 = xor3_p64(b1, pmull_lo(x128, a1), pmull_hi(x128, a1));
+    }
+
+    if (length & 16) {
+        // Fold the current 16 bytes with 16 bytes of input by multiplying by x^192 and x^128 constants
+        a1 = xor3_p64(pmull_lo(x128, a1), pmull_hi(x128, a1), load_p64_u8(input + 0));
+        input += 16;
+    }
+
+    // There must only be 0-15 bytes of input left
+    length &= 15;
+
+    if (length == 0) {
+        // Multiply the lower half of the crc register by x^128 (swapping upper and lower halves)
+        poly64x2_t mul_by_x128 = pmull_lo(a1, vextq_p64(x128, x128, 1));
+        // XOR the result with the right shifted upper half of the crc
+        a1 = xor_p64(right_shift_imm_p64(a1, 8), mul_by_x128);
+    } else {
+        // Handle any trailing input from 1-15 bytes
+        const poly64x2_t trailing_constants = load_p64(aws_checksums_crc64xz_constants.trailing[length - 1]);
+        // Multiply the crc by a pair of trailing length constants in order to fold it into the trailing input
+        a1 = xor_p64(pmull_lo(a1, trailing_constants), pmull_hi(a1, trailing_constants));
+        // Safely load ending at the last byte of trailing input and mask out any leading garbage
+        poly64x2_t trailing_input = mask_high_p64(load_p64_u8(input + length - 16), length);
+        // Multiply the lower half of the trailing input register by x^128 (swapping x^192 and x^128 halves)
+        poly64x2_t mul_by_x128 = pmull_lo(trailing_input, vextq_p64(x128, x128, 1));
+        // XOR the results with the right shifted upper half of the trailing input
+        a1 = xor3_p64(a1, right_shift_imm_p64(trailing_input, 8), mul_by_x128);
+    }
 
     // Barrett modular reduction
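For context, the early-out added above delegates to the library's software CRC-64/XZ routine. The checksum itself uses the reflected ECMA-182 polynomial with an inverted initial value and a final XOR; a minimal bit-at-a-time reference is sketched below. This is not the library's aws_checksums_crc64xz_sw implementation, only a compact statement of the value that every path, including the folded PMULL code in the diff, has to reproduce.

#include <stddef.h>
#include <stdint.h>

/* Reference CRC-64/XZ: reflected form of the ECMA-182 polynomial
 * 0x42F0E1EBA9EA3693 (bit-reversed constant 0xC96C5795D7870F42), with the
 * initial value and the final result both inverted. */
static uint64_t crc64xz_reference(const uint8_t *buf, size_t len, uint64_t previous_crc64) {
    uint64_t crc = ~previous_crc64;
    for (size_t i = 0; i < len; ++i) {
        crc ^= buf[i];
        for (int bit = 0; bit < 8; ++bit) {
            crc = (crc & 1) ? (crc >> 1) ^ 0xC96C5795D7870F42ULL : crc >> 1;
        }
    }
    return ~crc;
}

Seen against that reference, the folding in the diff is the standard carry-less-multiply technique: each pmull_lo/pmull_hi pair multiplies a 128-bit chunk of running state by a precomputed power of x modulo the polynomial (the x128 through x1024 constants), so widely separated chunks can be advanced independently and combined with XOR, and the closing Barrett modular reduction collapses the final 128-bit remainder back to the 64-bit CRC.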
