Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow brillig to read arrays directly from memory #4460

Merged
merged 2 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,15 @@ struct BrilligInputs {
static Array bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Single, Array> value;
struct MemoryArray {
Circuit::BlockId value;

friend bool operator==(const MemoryArray&, const MemoryArray&);
std::vector<uint8_t> bincodeSerialize() const;
static MemoryArray bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Single, Array, MemoryArray> value;

friend bool operator==(const BrilligInputs&, const BrilligInputs&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -4923,6 +4931,53 @@ Circuit::BrilligInputs::Array serde::Deserializable<Circuit::BrilligInputs::Arra

namespace Circuit {

inline bool operator==(const BrilligInputs::MemoryArray& lhs, const BrilligInputs::MemoryArray& rhs)
{
if (!(lhs.value == rhs.value)) {
return false;
}
return true;
}

inline std::vector<uint8_t> BrilligInputs::MemoryArray::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BrilligInputs::MemoryArray>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BrilligInputs::MemoryArray BrilligInputs::MemoryArray::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BrilligInputs::MemoryArray>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BrilligInputs::MemoryArray>::serialize(const Circuit::BrilligInputs::MemoryArray& obj,
Serializer& serializer)
{
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
}

template <>
template <typename Deserializer>
Circuit::BrilligInputs::MemoryArray serde::Deserializable<Circuit::BrilligInputs::MemoryArray>::deserialize(
Deserializer& deserializer)
{
Circuit::BrilligInputs::MemoryArray obj;
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const BrilligOpcode& lhs, const BrilligOpcode& rhs)
{
if (!(lhs.value == rhs.value)) {
Expand Down
48 changes: 47 additions & 1 deletion noir/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,15 @@ namespace Circuit {
static Array bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Single, Array> value;
struct MemoryArray {
Circuit::BlockId value;

friend bool operator==(const MemoryArray&, const MemoryArray&);
std::vector<uint8_t> bincodeSerialize() const;
static MemoryArray bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Single, Array, MemoryArray> value;

friend bool operator==(const BrilligInputs&, const BrilligInputs&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -4090,6 +4098,44 @@ Circuit::BrilligInputs::Array serde::Deserializable<Circuit::BrilligInputs::Arra
return obj;
}

namespace Circuit {

inline bool operator==(const BrilligInputs::MemoryArray &lhs, const BrilligInputs::MemoryArray &rhs) {
if (!(lhs.value == rhs.value)) { return false; }
return true;
}

inline std::vector<uint8_t> BrilligInputs::MemoryArray::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BrilligInputs::MemoryArray>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BrilligInputs::MemoryArray BrilligInputs::MemoryArray::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BrilligInputs::MemoryArray>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BrilligInputs::MemoryArray>::serialize(const Circuit::BrilligInputs::MemoryArray &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
}

template <>
template <typename Deserializer>
Circuit::BrilligInputs::MemoryArray serde::Deserializable<Circuit::BrilligInputs::MemoryArray>::deserialize(Deserializer &deserializer) {
Circuit::BrilligInputs::MemoryArray obj;
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const BrilligOpcode &lhs, const BrilligOpcode &rhs) {
Expand Down
2 changes: 2 additions & 0 deletions noir/acvm-repo/acir/src/circuit/brillig.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use crate::native_types::{Expression, Witness};
use brillig::Opcode as BrilligOpcode;
use serde::{Deserialize, Serialize};
use super::opcodes::BlockId;

/// Inputs for the Brillig VM. These are the initial inputs
/// that the Brillig VM will use to start.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)]
pub enum BrilligInputs {
Single(Expression),
Array(Vec<Expression>),
MemoryArray(BlockId)
}

/// Outputs for the Brillig VM. Once the VM has completed
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::native_types::{Expression, Witness};
use serde::{Deserialize, Serialize};

#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Copy, Default)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Copy, Default)]
pub struct BlockId(pub u32);

/// Operation on a block of memory
Expand Down
15 changes: 12 additions & 3 deletions noir/acvm-repo/acvm/src/pwg/brillig.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::collections::HashMap;

use acir::{
brillig::{ForeignCallParam, ForeignCallResult, Value},
circuit::{
brillig::{Brillig, BrilligInputs, BrilligOutputs},
OpcodeLocation,
brillig::{Brillig, BrilligInputs, BrilligOutputs}, opcodes::BlockId, OpcodeLocation
},
native_types::WitnessMap,
FieldElement,
Expand All @@ -12,7 +13,7 @@ use brillig_vm::{VMStatus, VM};

use crate::{pwg::OpcodeNotSolvable, OpcodeResolutionError};

use super::{get_value, insert_value};
use super::{get_value, insert_value, memory_op::MemoryOpSolver};

#[derive(Debug)]
pub enum BrilligSolverStatus {
Expand Down Expand Up @@ -64,6 +65,7 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> {
/// witness.
pub(super) fn new(
initial_witness: &WitnessMap,
memory: &HashMap<BlockId, MemoryOpSolver>,
brillig: &'b Brillig,
bb_solver: &'b B,
acir_index: usize,
Expand Down Expand Up @@ -96,6 +98,13 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> {
}
}
}
},
BrilligInputs::MemoryArray(block_id) => {
let memory_block = memory.get(block_id).ok_or(OpcodeNotSolvable::MissingMemoryBlock(block_id.0))?;
for memory_index in 0..memory_block.block_len {
let memory_value = memory_block.block_value.get(&memory_index).expect("All memory is initialized on creation");
calldata.push((*memory_value).into());
}
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions noir/acvm-repo/acvm/src/pwg/memory_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ type MemoryIndex = u32;
/// Maintains the state for solving [`MemoryInit`][`acir::circuit::Opcode::MemoryInit`] and [`MemoryOp`][`acir::circuit::Opcode::MemoryOp`] opcodes.
#[derive(Default)]
pub(super) struct MemoryOpSolver {
block_value: HashMap<MemoryIndex, FieldElement>,
block_len: u32,
pub(super) block_value: HashMap<MemoryIndex, FieldElement>,
pub(super) block_len: u32,
}

impl MemoryOpSolver {
Expand Down
6 changes: 4 additions & 2 deletions noir/acvm-repo/acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ pub enum StepResult<'a, B: BlackBoxFunctionSolver> {
pub enum OpcodeNotSolvable {
#[error("missing assignment for witness index {0}")]
MissingAssignment(u32),
#[error("Attempted to load uninitialized memory block")]
MissingMemoryBlock(u32),
#[error("expression has too many unknowns {0}")]
ExpressionHasTooManyUnknowns(Expression),
}
Expand Down Expand Up @@ -336,7 +338,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
// there will be a cached `BrilligSolver` to avoid recomputation.
let mut solver: BrilligSolver<'_, B> = match self.brillig_solver.take() {
Some(solver) => solver,
None => BrilligSolver::new(witness, brillig, self.backend, self.instruction_pointer)?,
None => BrilligSolver::new(witness, &self.block_solvers, brillig, self.backend, self.instruction_pointer)?,
};
match solver.solve()? {
BrilligSolverStatus::ForeignCallWait(foreign_call) => {
Expand Down Expand Up @@ -371,7 +373,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
return StepResult::Status(self.handle_opcode_resolution(resolution));
}

let solver = BrilligSolver::new(witness, brillig, self.backend, self.instruction_pointer);
let solver = BrilligSolver::new(witness, &self.block_solvers, brillig, self.backend, self.instruction_pointer);
match solver {
Ok(solver) => StepResult::IntoBrillig(solver),
Err(..) => StepResult::Status(self.handle_opcode_resolution(solver.map(|_| ()))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1453,10 +1453,8 @@ impl AcirContext {
}
Ok(BrilligInputs::Array(var_expressions))
}
AcirValue::DynamicArray(_) => {
let mut var_expressions = Vec::new();
self.brillig_array_input(&mut var_expressions, i)?;
Ok(BrilligInputs::Array(var_expressions))
AcirValue::DynamicArray(AcirDynamicArray { block_id,.. }) => {
Ok(BrilligInputs::MemoryArray(block_id))
}
}
})?;
Expand Down Expand Up @@ -1870,6 +1868,9 @@ fn execute_brillig(code: &[BrilligOpcode], inputs: &[BrilligInputs]) -> Option<V
calldata.push(expr.to_const()?.into());
}
}
BrilligInputs::MemoryArray(_) => {
return None;
}
}
}

Expand Down
Loading