Skip to content

Commit

Permalink
feat: move to_radix to a blackbox (#6294)
Browse files Browse the repository at this point in the history
This PR moves to_radix to a Brillig-specific blackbox. The AVM won't
easily support field integer division, and the only usecase for field
integer division in regular noir code is to radix / to bits. We extract
to radix to a bb func so it can be directly integrated as a gadget in
the avm.
  • Loading branch information
sirasistant committed May 9, 2024
1 parent 95b499b commit ac27376
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 73 deletions.
74 changes: 69 additions & 5 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,6 @@ struct BlackBoxOp {
Program::HeapVector inputs;
Program::HeapArray iv;
Program::HeapArray key;
Program::MemoryAddress length;
Program::HeapVector outputs;

friend bool operator==(const AES128Encrypt&, const AES128Encrypt&);
Expand Down Expand Up @@ -896,6 +895,16 @@ struct BlackBoxOp {
static Sha256Compression bincodeDeserialize(std::vector<uint8_t>);
};

struct ToRadix {
Program::MemoryAddress input;
uint32_t radix;
Program::HeapArray output;

friend bool operator==(const ToRadix&, const ToRadix&);
std::vector<uint8_t> bincodeSerialize() const;
static ToRadix bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<AES128Encrypt,
Sha256,
Blake2s,
Expand All @@ -916,7 +925,8 @@ struct BlackBoxOp {
BigIntFromLeBytes,
BigIntToLeBytes,
Poseidon2Permutation,
Sha256Compression>
Sha256Compression,
ToRadix>
value;

friend bool operator==(const BlackBoxOp&, const BlackBoxOp&);
Expand Down Expand Up @@ -3939,9 +3949,6 @@ inline bool operator==(const BlackBoxOp::AES128Encrypt& lhs, const BlackBoxOp::A
if (!(lhs.key == rhs.key)) {
return false;
}
if (!(lhs.length == rhs.length)) {
return false;
}
if (!(lhs.outputs == rhs.outputs)) {
return false;
}
Expand Down Expand Up @@ -5141,6 +5148,63 @@ Program::BlackBoxOp::Sha256Compression serde::Deserializable<Program::BlackBoxOp

namespace Program {

inline bool operator==(const BlackBoxOp::ToRadix& lhs, const BlackBoxOp::ToRadix& rhs)
{
if (!(lhs.input == rhs.input)) {
return false;
}
if (!(lhs.radix == rhs.radix)) {
return false;
}
if (!(lhs.output == rhs.output)) {
return false;
}
return true;
}

inline std::vector<uint8_t> BlackBoxOp::ToRadix::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BlackBoxOp::ToRadix>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BlackBoxOp::ToRadix BlackBoxOp::ToRadix::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BlackBoxOp::ToRadix>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::BlackBoxOp::ToRadix>::serialize(const Program::BlackBoxOp::ToRadix& obj,
Serializer& serializer)
{
serde::Serializable<decltype(obj.input)>::serialize(obj.input, serializer);
serde::Serializable<decltype(obj.radix)>::serialize(obj.radix, serializer);
serde::Serializable<decltype(obj.output)>::serialize(obj.output, serializer);
}

template <>
template <typename Deserializer>
Program::BlackBoxOp::ToRadix serde::Deserializable<Program::BlackBoxOp::ToRadix>::deserialize(
Deserializer& deserializer)
{
Program::BlackBoxOp::ToRadix obj;
obj.input = serde::Deserializable<decltype(obj.input)>::deserialize(deserializer);
obj.radix = serde::Deserializable<decltype(obj.radix)>::deserialize(deserializer);
obj.output = serde::Deserializable<decltype(obj.output)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const BlockId& lhs, const BlockId& rhs)
{
if (!(lhs.value == rhs.value)) {
Expand Down
56 changes: 55 additions & 1 deletion noir/noir-repo/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,17 @@ namespace Program {
static Sha256Compression bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<AES128Encrypt, Sha256, Blake2s, Blake3, Keccak256, Keccakf1600, EcdsaSecp256k1, EcdsaSecp256r1, SchnorrVerify, PedersenCommitment, PedersenHash, MultiScalarMul, EmbeddedCurveAdd, BigIntAdd, BigIntSub, BigIntMul, BigIntDiv, BigIntFromLeBytes, BigIntToLeBytes, Poseidon2Permutation, Sha256Compression> value;
struct ToRadix {
Program::MemoryAddress input;
uint32_t radix;
Program::HeapArray output;

friend bool operator==(const ToRadix&, const ToRadix&);
std::vector<uint8_t> bincodeSerialize() const;
static ToRadix bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<AES128Encrypt, Sha256, Blake2s, Blake3, Keccak256, Keccakf1600, EcdsaSecp256k1, EcdsaSecp256r1, SchnorrVerify, PedersenCommitment, PedersenHash, MultiScalarMul, EmbeddedCurveAdd, BigIntAdd, BigIntSub, BigIntMul, BigIntDiv, BigIntFromLeBytes, BigIntToLeBytes, Poseidon2Permutation, Sha256Compression, ToRadix> value;

friend bool operator==(const BlackBoxOp&, const BlackBoxOp&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -4293,6 +4303,50 @@ Program::BlackBoxOp::Sha256Compression serde::Deserializable<Program::BlackBoxOp
return obj;
}

namespace Program {

inline bool operator==(const BlackBoxOp::ToRadix &lhs, const BlackBoxOp::ToRadix &rhs) {
if (!(lhs.input == rhs.input)) { return false; }
if (!(lhs.radix == rhs.radix)) { return false; }
if (!(lhs.output == rhs.output)) { return false; }
return true;
}

inline std::vector<uint8_t> BlackBoxOp::ToRadix::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BlackBoxOp::ToRadix>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BlackBoxOp::ToRadix BlackBoxOp::ToRadix::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BlackBoxOp::ToRadix>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::BlackBoxOp::ToRadix>::serialize(const Program::BlackBoxOp::ToRadix &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.input)>::serialize(obj.input, serializer);
serde::Serializable<decltype(obj.radix)>::serialize(obj.radix, serializer);
serde::Serializable<decltype(obj.output)>::serialize(obj.output, serializer);
}

template <>
template <typename Deserializer>
Program::BlackBoxOp::ToRadix serde::Deserializable<Program::BlackBoxOp::ToRadix>::deserialize(Deserializer &deserializer) {
Program::BlackBoxOp::ToRadix obj;
obj.input = serde::Deserializable<decltype(obj.input)>::deserialize(deserializer);
obj.radix = serde::Deserializable<decltype(obj.radix)>::deserialize(deserializer);
obj.output = serde::Deserializable<decltype(obj.output)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const BlockId &lhs, const BlockId &rhs) {
Expand Down
5 changes: 5 additions & 0 deletions noir/noir-repo/acvm-repo/brillig/src/black_box.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,9 @@ pub enum BlackBoxOp {
hash_values: HeapVector,
output: HeapArray,
},
ToRadix {
input: MemoryAddress,
radix: u32,
output: HeapArray,
},
}
21 changes: 21 additions & 0 deletions noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use acvm_blackbox_solver::{
aes128_encrypt, blake2s, blake3, ecdsa_secp256k1_verify, ecdsa_secp256r1_verify, keccak256,
keccakf1600, sha256, sha256compression, BlackBoxFunctionSolver, BlackBoxResolutionError,
};
use num_bigint::BigUint;

use crate::memory::MemoryValue;
use crate::Memory;
Expand Down Expand Up @@ -295,6 +296,25 @@ pub(crate) fn evaluate_black_box<Solver: BlackBoxFunctionSolver>(
memory.write_slice(memory.read_ref(output.pointer), &state);
Ok(())
}
BlackBoxOp::ToRadix { input, radix, output } => {
let input: FieldElement =
memory.read(*input).try_into().expect("ToRadix input not a field");

let mut input = BigUint::from_bytes_be(&input.to_be_bytes());
let radix = BigUint::from(*radix);

let mut limbs: Vec<MemoryValue> = Vec::with_capacity(output.size);

for _ in 0..output.size {
let limb = &input % &radix;
limbs.push(FieldElement::from_be_bytes_reduce(&limb.to_bytes_be()).into());
input /= &radix;
}

memory.write_slice(memory.read_ref(output.pointer), &limbs);

Ok(())
}
}
}

Expand All @@ -321,6 +341,7 @@ fn black_box_function_from_op(op: &BlackBoxOp) -> BlackBoxFunc {
BlackBoxOp::BigIntToLeBytes { .. } => BlackBoxFunc::BigIntToLeBytes,
BlackBoxOp::Poseidon2Permutation { .. } => BlackBoxFunc::Poseidon2Permutation,
BlackBoxOp::Sha256Compression { .. } => BlackBoxFunc::Sha256Compression,
BlackBoxOp::ToRadix { .. } => unreachable!("ToRadix is not an ACIR BlackBoxFunc"),
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,22 @@ impl<'block> BrilligBlock<'block> {
}
Value::Intrinsic(Intrinsic::ToRadix(endianness)) => {
let source = self.convert_ssa_single_addr_value(arguments[0], dfg);
let radix = self.convert_ssa_single_addr_value(arguments[1], dfg);
let limb_count = self.convert_ssa_single_addr_value(arguments[2], dfg);

let radix: u32 = dfg
.get_numeric_constant(arguments[1])
.expect("Radix should be known")
.try_to_u64()
.expect("Radix should fit in u64")
.try_into()
.expect("Radix should be u32");

let limb_count: usize = dfg
.get_numeric_constant(arguments[2])
.expect("Limb count should be known")
.try_to_u64()
.expect("Limb count should fit in u64")
.try_into()
.expect("Limb count should fit in usize");

let results = dfg.instruction_results(instruction_id);

Expand All @@ -511,7 +525,8 @@ impl<'block> BrilligBlock<'block> {
.extract_vector();

// Update the user-facing slice length
self.brillig_context.cast_instruction(target_len, limb_count);
self.brillig_context
.usize_const_instruction(target_len.address, limb_count.into());

self.brillig_context.codegen_to_radix(
source,
Expand All @@ -524,7 +539,13 @@ impl<'block> BrilligBlock<'block> {
}
Value::Intrinsic(Intrinsic::ToBits(endianness)) => {
let source = self.convert_ssa_single_addr_value(arguments[0], dfg);
let limb_count = self.convert_ssa_single_addr_value(arguments[1], dfg);
let limb_count: usize = dfg
.get_numeric_constant(arguments[1])
.expect("Limb count should be known")
.try_to_u64()
.expect("Limb count should fit in u64")
.try_into()
.expect("Limb count should fit in usize");

let results = dfg.instruction_results(instruction_id);

Expand All @@ -549,21 +570,18 @@ impl<'block> BrilligBlock<'block> {
BrilligVariable::SingleAddr(..) => unreachable!("ICE: ToBits on non-array"),
};

let radix = self.brillig_context.make_constant_instruction(2_usize.into(), 32);

// Update the user-facing slice length
self.brillig_context.cast_instruction(target_len, limb_count);
self.brillig_context
.usize_const_instruction(target_len.address, limb_count.into());

self.brillig_context.codegen_to_radix(
source,
target_vector,
radix,
2,
limb_count,
matches!(endianness, Endian::Big),
1,
);

self.brillig_context.deallocate_single_addr(radix);
}
_ => {
unreachable!("unsupported function call type {:?}", dfg[*func])
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use acvm::FieldElement;

use crate::brillig::brillig_ir::BrilligBinaryOp;
use acvm::{
acir::brillig::{BlackBoxOp, HeapArray},
FieldElement,
};

use super::{
brillig_variable::{BrilligVector, SingleAddrVariable},
Expand Down Expand Up @@ -36,57 +37,46 @@ impl BrilligContext {
&mut self,
source_field: SingleAddrVariable,
target_vector: BrilligVector,
radix: SingleAddrVariable,
limb_count: SingleAddrVariable,
radix: u32,
limb_count: usize,
big_endian: bool,
limb_bit_size: u32,
) {
assert!(source_field.bit_size == FieldElement::max_num_bits());
assert!(radix.bit_size == 32);
assert!(limb_count.bit_size == 32);
let radix_as_field =
SingleAddrVariable::new(self.allocate_register(), FieldElement::max_num_bits());
self.cast_instruction(radix_as_field, radix);

self.cast_instruction(SingleAddrVariable::new_usize(target_vector.size), limb_count);
self.usize_const_instruction(target_vector.size, limb_count.into());
self.usize_const_instruction(target_vector.rc, 1_usize.into());
self.codegen_allocate_array(target_vector.pointer, target_vector.size);

let shifted_field =
SingleAddrVariable::new(self.allocate_register(), FieldElement::max_num_bits());
self.mov_instruction(shifted_field.address, source_field.address);
self.black_box_op_instruction(BlackBoxOp::ToRadix {
input: source_field.address,
radix,
output: HeapArray { pointer: target_vector.pointer, size: limb_count },
});

let limb_field =
SingleAddrVariable::new(self.allocate_register(), FieldElement::max_num_bits());

let limb_casted = SingleAddrVariable::new(self.allocate_register(), limb_bit_size);

self.codegen_loop(target_vector.size, |ctx, iterator_register| {
// Compute the modulus
ctx.binary_instruction(
shifted_field,
radix_as_field,
limb_field,
BrilligBinaryOp::Modulo,
);
// Cast it
ctx.cast_instruction(limb_casted, limb_field);
// Write it
ctx.codegen_array_set(target_vector.pointer, iterator_register, limb_casted.address);
// Integer div the field
ctx.binary_instruction(
shifted_field,
radix_as_field,
shifted_field,
BrilligBinaryOp::UnsignedDiv,
);
});
if limb_bit_size != FieldElement::max_num_bits() {
self.codegen_loop(target_vector.size, |ctx, iterator_register| {
// Read the limb
ctx.codegen_array_get(target_vector.pointer, iterator_register, limb_field.address);
// Cast it
ctx.cast_instruction(limb_casted, limb_field);
// Write it
ctx.codegen_array_set(
target_vector.pointer,
iterator_register,
limb_casted.address,
);
});
}

// Deallocate our temporary registers
self.deallocate_single_addr(shifted_field);
self.deallocate_single_addr(limb_field);
self.deallocate_single_addr(limb_casted);
self.deallocate_single_addr(radix_as_field);

if big_endian {
self.codegen_reverse_vector_in_place(target_vector);
Expand Down
Loading

0 comments on commit ac27376

Please sign in to comment.