diff --git a/src/libAtomVM/bif.c b/src/libAtomVM/bif.c index 8b2334759..46e37b5d6 100644 --- a/src/libAtomVM/bif.c +++ b/src/libAtomVM/bif.c @@ -806,6 +806,7 @@ term bif_erlang_sub_2(Context *ctx, uint32_t fail_label, int live, term arg1, te } } +// this function assumes that bigres_len is always <= bigres buffer capacity static term make_bigint(Context *ctx, uint32_t fail_label, uint32_t live, const intn_digit_t bigres[], size_t bigres_len, intn_integer_sign_t sign) { @@ -1716,6 +1717,10 @@ term bif_erlang_bsl_2(Context *ctx, uint32_t fail_label, int live, term arg1, te intn_digit_t bigres[INTN_MAX_RES_LEN]; size_t bigres_len = intn_bsl(m, m_len, b, bigres); + // this check is required in order to avoid out-of-bounds read in make_bigint + if (UNLIKELY(bigres_len > INTN_MAX_RES_LEN)) { + RAISE_ERROR_BIF(fail_label, OVERFLOW_ATOM); + } return make_bigint(ctx, fail_label, live, bigres, bigres_len, m_sign); diff --git a/src/libAtomVM/externalterm.c b/src/libAtomVM/externalterm.c index b1bb45129..cbbf2b30a 100644 --- a/src/libAtomVM/externalterm.c +++ b/src/libAtomVM/externalterm.c @@ -1011,9 +1011,10 @@ static int calculate_heap_usage(const uint8_t *external_term_buf, size_t remaini } // num_bytes > 8 bytes || uint64_does_overflow_int64 + size_t required_digits = intn_required_digits_for_unsigned_integer(num_bytes); size_t data_size; size_t unused_rounded_len; - term_intn_to_term_size(num_bytes, &data_size, &unused_rounded_len); + term_intn_to_term_size(required_digits, &data_size, &unused_rounded_len); return BOXED_INTN_SIZE(data_size); } diff --git a/src/libAtomVM/intn.c b/src/libAtomVM/intn.c index 11c395623..ff58f2cc7 100644 --- a/src/libAtomVM/intn.c +++ b/src/libAtomVM/intn.c @@ -61,11 +61,6 @@ static inline size_t pad_uint16_to_digits(uint16_t n16[], size_t n16_len) return n16_len; } -static inline size_t size_round_to(size_t n, size_t round_to) -{ - return (n + (round_to - 1)) & ~(round_to - 1); -} - /* * Multiplication */ @@ -439,18 +434,19 @@ size_t intn_divmnu(const intn_digit_t m[], size_t m_len, const intn_digit_t n[], return padded_q_len / UINT16_IN_A_DIGIT; } -// This function assumes no leading zeros (lenght is used in comparison) -// Caller must ensure this precondition int intn_cmp(const intn_digit_t a[], size_t a_len, const intn_digit_t b[], size_t b_len) { - if (a_len > b_len) { + size_t normal_a_len = intn_count_digits(a, a_len); + size_t normal_b_len = intn_count_digits(b, b_len); + + if (normal_a_len > normal_b_len) { return 1; } - if (a_len < b_len) { + if (normal_a_len < normal_b_len) { return -1; } - for (size_t i = a_len; i > 0; i--) { + for (size_t i = normal_a_len; i > 0; i--) { if (a[i - 1] > b[i - 1]) { return 1; } @@ -791,23 +787,21 @@ size_t intn_bnot(const intn_digit_t m[], size_t m_len, intn_integer_sign_t m_sig size_t intn_bsl(const intn_digit_t num[], size_t len, size_t n, intn_digit_t *out) { - size_t digit_bit_size = sizeof(uint32_t) * 8; - size_t digit_left_bit_shift = n % 32; size_t right_shift_n = (32 - digit_left_bit_shift); size_t counted_digits = intn_count_digits(num, len); size_t ms_digit_bits = 32 - uint32_nlz(num[counted_digits - 1]); - size_t effective_bits_len = (counted_digits - 1) * digit_bit_size + ms_digit_bits; - size_t new_bits_len = size_round_to(effective_bits_len + n, digit_bit_size); + size_t effective_bits_len = (counted_digits - 1) * INTN_DIGIT_BITS + ms_digit_bits; + size_t new_bits_len = size_align_up_pow2(effective_bits_len + n, INTN_DIGIT_BITS); - size_t new_digits_count = new_bits_len / digit_bit_size; + size_t new_digits_count = new_bits_len / INTN_DIGIT_BITS; if (new_digits_count > INTN_BSL_MAX_RES_LEN) { return new_digits_count; } - size_t initial_zeros = MIN(n / digit_bit_size, INTN_BSL_MAX_RES_LEN); + size_t initial_zeros = MIN(n / INTN_DIGIT_BITS, INTN_BSL_MAX_RES_LEN); memset(out, 0, initial_zeros * sizeof(uint32_t)); if (right_shift_n == 32) { @@ -837,15 +831,14 @@ size_t intn_bsl(const intn_digit_t num[], size_t len, size_t n, intn_digit_t *ou void bsru( const uint32_t num[], size_t effective_bits_len, size_t n, uint32_t last_digit, uint32_t *out) { - size_t digit_bit_size = sizeof(uint32_t) * 8; // 32 - - size_t digit_right_bit_shift = n % digit_bit_size; - size_t left_shift_n = (digit_bit_size - digit_right_bit_shift); + size_t digit_right_bit_shift = n % INTN_DIGIT_BITS; + size_t left_shift_n = (INTN_DIGIT_BITS - digit_right_bit_shift); - size_t len_in_digits = size_round_to(effective_bits_len, digit_bit_size) / digit_bit_size; + size_t len_in_digits + = size_align_up_pow2(effective_bits_len, INTN_DIGIT_BITS) / INTN_DIGIT_BITS; // caller makes sure that discarded < len_in_digits - size_t discarded = n / digit_bit_size; + size_t discarded = n / INTN_DIGIT_BITS; if (left_shift_n == 32) { memcpy(out, num + discarded, (len_in_digits - discarded) * sizeof(uint32_t)); @@ -868,17 +861,17 @@ void bsru( size_t intn_bsr( const intn_digit_t num[], size_t len, intn_integer_sign_t num_sign, size_t n, intn_digit_t *out) { - size_t digit_bit_size = sizeof(uint32_t) * 8; size_t counted_digits = intn_count_digits(num, len); size_t ms_digit_bits = 32 - uint32_nlz(num[counted_digits - 1]); - size_t effective_bits_len = (counted_digits - 1) * digit_bit_size + ms_digit_bits; + size_t effective_bits_len = (counted_digits - 1) * INTN_DIGIT_BITS + ms_digit_bits; if (n >= effective_bits_len) { out[0] = (num_sign == IntNPositiveInteger) ? 0 : 1; return 1; } - size_t shifted_len = size_round_to(effective_bits_len - n, digit_bit_size) / digit_bit_size; + size_t shifted_len + = size_align_up_pow2(effective_bits_len - n, INTN_DIGIT_BITS) / INTN_DIGIT_BITS; if (num_sign == IntNPositiveInteger) { bsru(num, effective_bits_len, n, 0, out); diff --git a/src/libAtomVM/intn.h b/src/libAtomVM/intn.h index 1722342d2..34c0a73e1 100644 --- a/src/libAtomVM/intn.h +++ b/src/libAtomVM/intn.h @@ -52,6 +52,7 @@ #define INTN_DIV_OUT_LEN(m, n) ((m) - (n) + 1 + 1) #define INTN_ABS_OUT_LEN(m) ((m) + 1) +#define INTN_DIGIT_BITS 32 #define INTN_MAX_UNSIGNED_BYTES_SIZE 32 #define INTN_MAX_UNSIGNED_BITS_SIZE 256 @@ -160,6 +161,11 @@ int intn_to_integer_bytes(const intn_digit_t in[], size_t in_len, intn_integer_s size_t intn_required_unsigned_integer_bytes(const intn_digit_t in[], size_t in_len); +static inline size_t intn_required_digits_for_unsigned_integer(size_t size_in_bytes) +{ + return size_align_up_pow2(size_in_bytes, sizeof(intn_digit_t)) / sizeof(intn_digit_t); +} + static inline intn_integer_sign_t intn_negate_sign(intn_integer_sign_t sign) { return (sign == IntNPositiveInteger) ? IntNNegativeInteger : IntNPositiveInteger; diff --git a/src/libAtomVM/utils.h b/src/libAtomVM/utils.h index 0265f3e04..d81b0d1cb 100644 --- a/src/libAtomVM/utils.h +++ b/src/libAtomVM/utils.h @@ -359,6 +359,96 @@ static inline __attribute__((always_inline)) func_ptr_t cast_void_to_func_ptr(vo #define MAXI(A, B) ((A > B) ? (A) : (B)) #define MINI(A, B) ((A > B) ? (B) : (A)) +/** + * @brief Align size up to power-of-2 boundary + * + * Rounds up a size value to the next multiple of a power-of-2 alignment. + * This function uses bit manipulation for efficient alignment calculation + * and is faster than the general-purpose \c size_align_up(). + * + * @param n Size value to align + * @param align Power-of-2 alignment boundary + * @return Size rounded up to next multiple of align + * + * @pre align must be a power of 2 (e.g., 2, 4, 8, 16, 32, ...) + * @warning Undefined behavior if align is not a power of 2 + * @warning Undefined behavior if align is 0 + * + * @note Result is always >= n + * + * @code + * size_t aligned = size_align_up_pow2(17, 8); // Returns 24 + * size_t aligned = size_align_up_pow2(16, 8); // Returns 16 (already aligned) + * @endcode + * + * @see size_align_up() for arbitrary alignment values + */ +static inline size_t size_align_up_pow2(size_t n, size_t align) +{ + return (n + (align - 1)) & ~(align - 1); +} + +/** + * @brief Align size up to arbitrary boundary + * + * Rounds up a size value to the next multiple of an alignment boundary. + * Works with any alignment value, not just powers of 2. + * + * @param n Size value to align + * @param align Alignment boundary (any positive value, or 0) + * @return Size rounded up to next multiple of align, or n if align is 0 + * + * @note Returns n unchanged if align is 0 (no alignment) + * @note Result is always >= n + * @note For power-of-2 alignments, \c size_align_up_pow2() is more efficient + * + * @code + * size_t aligned = size_align_up(17, 10); // Returns 20 + * size_t aligned = size_align_up(20, 10); // Returns 20 (already aligned) + * size_t aligned = size_align_up(17, 0); // Returns 17 (no alignment) + * @endcode + * + * @see size_align_up_pow2() for optimized power-of-2 alignment + * @see size_align_down() for rounding down instead of up + */ +static inline size_t size_align_up(size_t n, size_t align) +{ + if (align == 0) { + return n; + } + return ((n + align - 1) / align) * align; +} + +/** + * @brief Align size down to arbitrary boundary + * + * Rounds down a size value to the previous multiple of an alignment boundary. + * Works with any alignment value, not just powers of 2. + * + * @param n Size value to align + * @param align Alignment boundary (any positive value, or 0) + * @return Size rounded down to previous multiple of align, or n if align is 0 + * + * @note Returns n unchanged if align is 0 (no alignment) + * @note Result is always <= n + * @note Commonly used for finding aligned base addresses within buffers + * + * @code + * size_t aligned = size_align_down(17, 10); // Returns 10 + * size_t aligned = size_align_down(20, 10); // Returns 20 (already aligned) + * size_t aligned = size_align_down(7, 10); // Returns 0 + * @endcode + * + * @see size_align_up() for rounding up instead of down + */ +static inline size_t size_align_down(size_t n, size_t align) +{ + if (align == 0) { + return n; + } + return (n / align) * align; +} + /** * @brief Negate unsigned 32-bit value (\c uint32_t) to signed integer (\c int32_t) *