diff --git a/src/modules/surjection/main_impl.h b/src/modules/surjection/main_impl.h index f57ddba1d..c67d4c0d3 100644 --- a/src/modules/surjection/main_impl.h +++ b/src/modules/surjection/main_impl.h @@ -56,7 +56,7 @@ int secp256k1_surjectionproof_parse(const secp256k1_context* ctx, secp256k1_surj } signature_len = 32 * (1 + secp256k1_count_bits_set(&input[2], (n_inputs + 7) / 8)); - if (inputlen < 2 + (n_inputs + 7) / 8 + signature_len) { + if (inputlen != 2 + (n_inputs + 7) / 8 + signature_len) { return 0; } proof->n_inputs = n_inputs; diff --git a/src/modules/surjection/tests_impl.h b/src/modules/surjection/tests_impl.h index 08742e149..a0856e229 100644 --- a/src/modules/surjection/tests_impl.h +++ b/src/modules/surjection/tests_impl.h @@ -331,6 +331,7 @@ static void test_gen_verify(size_t n_inputs, size_t n_used) { unsigned char seed[32]; secp256k1_surjectionproof proof; unsigned char serialized_proof[SECP256K1_SURJECTIONPROOF_SERIALIZATION_BYTES_MAX]; + unsigned char serialized_proof_trailing[SECP256K1_SURJECTIONPROOF_SERIALIZATION_BYTES_MAX + 1]; size_t serialized_len = SECP256K1_SURJECTIONPROOF_SERIALIZATION_BYTES_MAX; secp256k1_fixed_asset_tag fixed_input_tags[1000]; secp256k1_generator ephemeral_input_tags[1000]; @@ -376,6 +377,12 @@ static void test_gen_verify(size_t n_inputs, size_t n_used) { CHECK(secp256k1_surjectionproof_serialize(ctx, serialized_proof, &serialized_len, &proof)); CHECK(serialized_len == secp256k1_surjectionproof_serialized_size(ctx, &proof)); CHECK(serialized_len == SECP256K1_SURJECTIONPROOF_SERIALIZATION_BYTES(n_inputs, n_used)); + + /* trailing garbage */ + memcpy(&serialized_proof_trailing, &serialized_proof, serialized_len); + serialized_proof_trailing[serialized_len] = seed[0]; + CHECK(secp256k1_surjectionproof_parse(ctx, &proof, serialized_proof, serialized_len + 1) == 0); + CHECK(secp256k1_surjectionproof_parse(ctx, &proof, serialized_proof, serialized_len)); result = secp256k1_surjectionproof_verify(ctx, &proof, ephemeral_input_tags, n_inputs, &ephemeral_input_tags[n_inputs]); CHECK(result == 1);