From 775a4344402e393235d9043f87b8c36cbe8d1dbf Mon Sep 17 00:00:00 2001 From: Matthew Ryan Date: Fri, 28 Dec 2018 15:18:40 +0000 Subject: [PATCH] Add a ppx for snarky with_label (#1362) * Add extremely basic with_label ppx This expands [%with_label expr] to with_label (expr ^ ": " ^ Pervasives.__LOC__) * Implement snarkydef ppx * Use let%snarkydef in place of toplevel let _ = with_label __LOC__ _ * Revert let%snarkydef change for let (type _) and let (module _) * Add support for let%snarkydef x (type y) ... = ... * Use let%snarkydef in a few more places --- src/lib/blockchain_snark/blockchain_state.ml | 123 +++--- src/lib/blockchain_snark/jbuild | 2 +- src/lib/coda_base/blockchain_state.ml | 24 +- src/lib/coda_base/data_hash.ml | 15 +- src/lib/coda_base/jbuild | 2 +- src/lib/coda_base/ledger_hash.ml | 28 +- src/lib/coda_base/payment_payload.ml | 9 +- src/lib/coda_base/schnorr.ml | 52 ++- src/lib/coda_base/transition_system.ml | 145 ++++--- src/lib/consensus/jbuild | 2 +- src/lib/consensus/proof_of_work.ml.ignore | 3 +- src/lib/non_zero_curve_point/jbuild | 2 +- .../non_zero_curve_point.ml | 7 +- src/lib/ppx_snarky/jbuild | 8 + src/lib/ppx_snarky/ppx_snarky.ml | 1 + src/lib/ppx_snarky/snarkydef.ml | 62 +++ src/lib/signature_lib/checked.ml | 52 ++- src/lib/signature_lib/jbuild | 2 +- src/lib/snark_bits/bits.ml | 27 +- src/lib/snark_bits/jbuild | 2 +- src/lib/snarky/src/curves.ml | 394 +++++++++--------- src/lib/snarky/src/jbuild | 2 +- src/lib/snarky/src/snark0.ml | 26 +- src/lib/transaction_snark/jbuild | 2 +- .../transaction_snark/transaction_snark.ml | 297 +++++++------ src/ppx_snarky.opam | 6 + 26 files changed, 659 insertions(+), 636 deletions(-) create mode 100644 src/lib/ppx_snarky/jbuild create mode 100644 src/lib/ppx_snarky/ppx_snarky.ml create mode 100644 src/lib/ppx_snarky/snarkydef.ml create mode 100644 src/ppx_snarky.opam diff --git a/src/lib/blockchain_snark/blockchain_state.ml b/src/lib/blockchain_snark/blockchain_state.ml index fc771b7fd20..3a1b34fdff0 100644 --- a/src/lib/blockchain_snark/blockchain_state.ml +++ b/src/lib/blockchain_snark/blockchain_state.ml @@ -46,80 +46,75 @@ module Make (Consensus_mechanism : Consensus.Mechanism.S) : transition consensus data is valid new consensus state is a function of the old consensus state *) - let update + let%snarkydef update ((previous_state_hash, previous_state) : State_hash.var * Protocol_state.var) (transition : Snark_transition.var) : ( State_hash.var * Protocol_state.var * [`Success of Boolean.var] , _ ) Tick.Checked.t = - with_label __LOC__ - (let supply_increase = Snark_transition.supply_increase transition in - let%bind `Success updated_consensus_state, consensus_state = - Consensus_mechanism.next_state_checked ~prev_state:previous_state - ~prev_state_hash:previous_state_hash transition supply_increase - in - let%bind success = - let%bind correct_transaction_snark = - T.verify_complete_merge - (Snark_transition.sok_digest transition) - ( previous_state |> Protocol_state.blockchain_state - |> Blockchain_state.ledger_hash ) - ( transition |> Snark_transition.blockchain_state - |> Blockchain_state.ledger_hash ) - supply_increase - (As_prover.return - (Option.value ~default:Tock.Proof.dummy - (Snark_transition.ledger_proof transition))) - and ledger_hash_didn't_change = - Frozen_ledger_hash.equal_var - ( previous_state |> Protocol_state.blockchain_state - |> Blockchain_state.ledger_hash ) - ( transition |> Snark_transition.blockchain_state - |> Blockchain_state.ledger_hash ) - in - let%bind correct_snark = - Boolean.(correct_transaction_snark || ledger_hash_didn't_change) - in - Boolean.(correct_snark && updated_consensus_state) - in - let new_state = - Protocol_state.create_var ~previous_state_hash - ~blockchain_state:(Snark_transition.blockchain_state transition) - ~consensus_state - in - let%bind state_triples = Protocol_state.var_to_triples new_state in - let%bind state_partial = - Pedersen.Checked.Section.extend Pedersen.Checked.Section.empty - ~start:Hash_prefix.length_in_triples state_triples - in - let%map state_hash = - Pedersen.Checked.Section.create - ~acc:(`Value Hash_prefix.protocol_state.acc) - ~support: - (Interval_union.of_interval (0, Hash_prefix.length_in_triples)) - |> Pedersen.Checked.Section.disjoint_union_exn state_partial - >>| Pedersen.Checked.Section.to_initial_segment_digest_exn >>| fst - in - ( State_hash.var_of_hash_packed state_hash - , new_state - , `Success success )) + let supply_increase = Snark_transition.supply_increase transition in + let%bind `Success updated_consensus_state, consensus_state = + Consensus_mechanism.next_state_checked ~prev_state:previous_state + ~prev_state_hash:previous_state_hash transition supply_increase + in + let%bind success = + let%bind correct_transaction_snark = + T.verify_complete_merge + (Snark_transition.sok_digest transition) + ( previous_state |> Protocol_state.blockchain_state + |> Blockchain_state.ledger_hash ) + ( transition |> Snark_transition.blockchain_state + |> Blockchain_state.ledger_hash ) + supply_increase + (As_prover.return + (Option.value ~default:Tock.Proof.dummy + (Snark_transition.ledger_proof transition))) + and ledger_hash_didn't_change = + Frozen_ledger_hash.equal_var + ( previous_state |> Protocol_state.blockchain_state + |> Blockchain_state.ledger_hash ) + ( transition |> Snark_transition.blockchain_state + |> Blockchain_state.ledger_hash ) + in + let%bind correct_snark = + Boolean.(correct_transaction_snark || ledger_hash_didn't_change) + in + Boolean.(correct_snark && updated_consensus_state) + in + let new_state = + Protocol_state.create_var ~previous_state_hash + ~blockchain_state:(Snark_transition.blockchain_state transition) + ~consensus_state + in + let%bind state_triples = Protocol_state.var_to_triples new_state in + let%bind state_partial = + Pedersen.Checked.Section.extend Pedersen.Checked.Section.empty + ~start:Hash_prefix.length_in_triples state_triples + in + let%map state_hash = + Pedersen.Checked.Section.create + ~acc:(`Value Hash_prefix.protocol_state.acc) + ~support: + (Interval_union.of_interval (0, Hash_prefix.length_in_triples)) + |> Pedersen.Checked.Section.disjoint_union_exn state_partial + >>| Pedersen.Checked.Section.to_initial_segment_digest_exn >>| fst + in + (State_hash.var_of_hash_packed state_hash, new_state, `Success success) end end module Checked = struct - let is_base_hash h = - with_label __LOC__ - (Field.Checked.equal - (Field.Checked.constant - ( Protocol_state.hash Consensus_mechanism.genesis_protocol_state - :> Field.t )) - (State_hash.var_to_hash_packed h)) + let%snarkydef is_base_hash h = + Field.Checked.equal + (Field.Checked.constant + ( Protocol_state.hash Consensus_mechanism.genesis_protocol_state + :> Field.t )) + (State_hash.var_to_hash_packed h) - let hash (t : Protocol_state.var) = - with_label __LOC__ - ( Protocol_state.var_to_triples t - >>= Pedersen.Checked.digest_triples ~init:Hash_prefix.protocol_state - >>| State_hash.var_of_hash_packed ) + let%snarkydef hash (t : Protocol_state.var) = + Protocol_state.var_to_triples t + >>= Pedersen.Checked.digest_triples ~init:Hash_prefix.protocol_state + >>| State_hash.var_of_hash_packed end end diff --git a/src/lib/blockchain_snark/jbuild b/src/lib/blockchain_snark/jbuild index 26333fc35b6..121d5874763 100644 --- a/src/lib/blockchain_snark/jbuild +++ b/src/lib/blockchain_snark/jbuild @@ -7,5 +7,5 @@ (library_flags (-linkall)) (libraries (core cached cache_dir protocols snarky snark_params coda_base transaction_snark bignum_bigint consensus)) (inline_tests) - (preprocess (pps (ppx_jane ppx_deriving.eq bisect_ppx -conditional))) + (preprocess (pps (ppx_snarky ppx_jane ppx_deriving.eq bisect_ppx -conditional))) (synopsis "blockchain state transition snarking library"))) diff --git a/src/lib/coda_base/blockchain_state.ml b/src/lib/coda_base/blockchain_state.ml index 24ae942835e..3bccafec40c 100644 --- a/src/lib/coda_base/blockchain_state.ml +++ b/src/lib/coda_base/blockchain_state.ml @@ -185,21 +185,17 @@ end) : S = struct let () = assert Insecure.signature_hash_function - let hash_checked t ~nonce = + let%snarkydef hash_checked t ~nonce = let open Let_syntax in - with_label __LOC__ - (let%bind trips = var_to_triples t in - let%bind hash = - Pedersen.Checked.digest_triples ~init:Hash_prefix.signature - ( trips - @ Fold.(to_list (group3 ~default:Boolean.false_ (of_list nonce))) - ) - in - let%map bs = Pedersen.Checked.Digest.choose_preimage hash in - Bitstring.Lsb_first.of_list - (List.take - (bs :> Boolean.var list) - Inner_curve.Scalar.length_in_bits)) + let%bind trips = var_to_triples t in + let%bind hash = + Pedersen.Checked.digest_triples ~init:Hash_prefix.signature + ( trips + @ Fold.(to_list (group3 ~default:Boolean.false_ (of_list nonce))) ) + in + let%map bs = Pedersen.Checked.Digest.choose_preimage hash in + Bitstring.Lsb_first.of_list + (List.take (bs :> Boolean.var list) Inner_curve.Scalar.length_in_bits) end module Signature = diff --git a/src/lib/coda_base/data_hash.ml b/src/lib/coda_base/data_hash.ml index 26b050c1a34..b2b6795904b 100644 --- a/src/lib/coda_base/data_hash.ml +++ b/src/lib/coda_base/data_hash.ml @@ -136,14 +136,13 @@ struct >>| fun x -> (x :> Boolean.var list) else Field.Checked.unpack ~length:length_in_bits - let var_to_bits t = - with_label __LOC__ - ( match t.bits with - | Some bits -> return (bits :> Boolean.var list) - | None -> - let%map bits = unpack t.digest in - t.bits <- Some (Bitstring.Lsb_first.of_list bits) ; - bits ) + let%snarkydef var_to_bits t = + match t.bits with + | Some bits -> return (bits :> Boolean.var list) + | None -> + let%map bits = unpack t.digest in + t.bits <- Some (Bitstring.Lsb_first.of_list bits) ; + bits let var_to_triples t = var_to_bits t >>| Bitstring.pad_to_triple_list ~default:Boolean.false_ diff --git a/src/lib/coda_base/jbuild b/src/lib/coda_base/jbuild index 516e860201f..98c6ae3c17c 100644 --- a/src/lib/coda_base/jbuild +++ b/src/lib/coda_base/jbuild @@ -37,7 +37,7 @@ yojson codable)) (preprocessor_deps ("../../config.mlh")) - (preprocess (pps (ppx_jane ppx_deriving.eq ppx_deriving.enum ppx_deriving.ord ppx_deriving_yojson bisect_ppx -conditional))) + (preprocess (pps (ppx_snarky ppx_jane ppx_deriving.eq ppx_deriving.enum ppx_deriving.ord ppx_deriving_yojson bisect_ppx -conditional))) (synopsis "Snarks and friends necessary for keypair generation"))) (rule diff --git a/src/lib/coda_base/ledger_hash.ml b/src/lib/coda_base/ledger_hash.ml index f168a0b2012..403dc2687b2 100644 --- a/src/lib/coda_base/ledger_hash.ml +++ b/src/lib/coda_base/ledger_hash.ml @@ -86,20 +86,20 @@ let get t addr = Merkle_tree.get_req ~depth (var_to_hash_packed t) addr - returns a root [t'] of a tree of depth [depth] which is [t] but with the account [f account] at path [addr]. *) -let modify_account t pk ~(filter : Account.var -> ('a, _) Checked.t) ~f = - with_label __LOC__ - (let%bind addr = - request_witness Account.Index.Unpacked.typ - As_prover.( - map (read Public_key.Compressed.typ pk) ~f:(fun s -> Find_index s)) - in - handle - (Merkle_tree.modify_req ~depth (var_to_hash_packed t) addr - ~f:(fun account -> - let%bind x = filter account in - f x account )) - reraise_merkle_requests - >>| var_of_hash_packed) +let%snarkydef modify_account t pk ~(filter : Account.var -> ('a, _) Checked.t) + ~f = + let%bind addr = + request_witness Account.Index.Unpacked.typ + As_prover.( + map (read Public_key.Compressed.typ pk) ~f:(fun s -> Find_index s)) + in + handle + (Merkle_tree.modify_req ~depth (var_to_hash_packed t) addr + ~f:(fun account -> + let%bind x = filter account in + f x account )) + reraise_merkle_requests + >>| var_of_hash_packed (* [modify_account_send t pk ~f] implements the following spec: diff --git a/src/lib/coda_base/payment_payload.ml b/src/lib/coda_base/payment_payload.ml index 693de7854b7..c55d94fbe16 100644 --- a/src/lib/coda_base/payment_payload.ml +++ b/src/lib/coda_base/payment_payload.ml @@ -44,11 +44,10 @@ let fold {receiver; amount} = (* TODO: This could be a bit more efficient by packing across triples, but I think the added confusion-possibility is not worth it. *) -let var_to_triples {receiver; amount} = - with_label __LOC__ - (let%map receiver = Public_key.Compressed.var_to_triples receiver in - let amount = Amount.var_to_triples amount in - receiver @ amount) +let%snarkydef var_to_triples {receiver; amount} = + let%map receiver = Public_key.Compressed.var_to_triples receiver in + let amount = Amount.var_to_triples amount in + receiver @ amount let length_in_triples = Public_key.Compressed.length_in_triples + Amount.length_in_triples diff --git a/src/lib/coda_base/schnorr.ml b/src/lib/coda_base/schnorr.ml index e88f7417fc6..db02c96b3c8 100644 --- a/src/lib/coda_base/schnorr.ml +++ b/src/lib/coda_base/schnorr.ml @@ -29,34 +29,32 @@ module Message = struct let () = assert Insecure.signature_hash_function - let hash_checked t ~nonce = + let%snarkydef hash_checked t ~nonce = let open Let_syntax in - with_label __LOC__ - (let init = - Pedersen.Checked.Section.create - ~acc:(`Value Hash_prefix.signature.acc) - ~support: - (Interval_union.of_interval (0, Hash_prefix.length_in_triples)) - in - let%bind with_t = Pedersen.Checked.Section.disjoint_union_exn init t in - let%bind digest = - let%map final = - Pedersen.Checked.Section.extend with_t - (Bitstring_lib.Bitstring.pad_to_triple_list - ~default:Boolean.false_ nonce) - ~start: - ( Hash_prefix.length_in_triples - + User_command_payload.length_in_triples ) - in - let d, _ = - Pedersen.Checked.Section.to_initial_segment_digest final - |> Or_error.ok_exn - in - d - in - let%bind bs = Pedersen.Checked.Digest.choose_preimage digest in - let%map d = Sha256_lib.Sha256.Checked.digest (bs :> Boolean.var list) in - Bitstring.Lsb_first.of_list (d :> Boolean.var list)) + let init = + Pedersen.Checked.Section.create ~acc:(`Value Hash_prefix.signature.acc) + ~support: + (Interval_union.of_interval (0, Hash_prefix.length_in_triples)) + in + let%bind with_t = Pedersen.Checked.Section.disjoint_union_exn init t in + let%bind digest = + let%map final = + Pedersen.Checked.Section.extend with_t + (Bitstring_lib.Bitstring.pad_to_triple_list ~default:Boolean.false_ + nonce) + ~start: + ( Hash_prefix.length_in_triples + + User_command_payload.length_in_triples ) + in + let d, _ = + Pedersen.Checked.Section.to_initial_segment_digest final + |> Or_error.ok_exn + in + d + in + let%bind bs = Pedersen.Checked.Digest.choose_preimage digest in + let%map d = Sha256_lib.Sha256.Checked.digest (bs :> Boolean.var list) in + Bitstring.Lsb_first.of_list (d :> Boolean.var list) end include Signature_lib.Checked.Schnorr (Tick) (Snark_params.Tick.Inner_curve) diff --git a/src/lib/coda_base/transition_system.ml b/src/lib/coda_base/transition_system.ml index 1e7a3f0fbe0..028a690711a 100644 --- a/src/lib/coda_base/transition_system.ml +++ b/src/lib/coda_base/transition_system.ml @@ -109,69 +109,67 @@ struct >>| Tick.Pedersen.Checked.Section.to_initial_segment_digest >>| Or_error.ok_exn >>| fst - let prev_state_valid wrap_vk_section wrap_vk wrap_vk_data prev_state_hash = + let%snarkydef prev_state_valid wrap_vk_section wrap_vk wrap_vk_data + prev_state_hash = let open Let_syntax in - with_label __LOC__ - (* TODO: Should build compositionally on the prev_state hash (instead of converting to bits) *) - (let%bind prev_state_hash_trips = - State.Hash.var_to_triples prev_state_hash - in - let%bind prev_top_hash = - compute_top_hash wrap_vk_section prev_state_hash_trips - >>= Wrap_input.Checked.tick_field_to_scalars - in - let%bind other_wrap_vk_data, result = - Verifier.All_in_one.check_proof wrap_vk - ~get_vk:As_prover.(map get_state ~f:Prover_state.wrap_vk) - ~get_proof:As_prover.(map get_state ~f:Prover_state.prev_proof) - prev_top_hash - in - let%map () = - Verifier.Verification_key_data.Checked.Assert.equal wrap_vk_data - other_wrap_vk_data - in - result) + (* TODO: Should build compositionally on the prev_state hash (instead of converting to bits) *) + let%bind prev_state_hash_trips = + State.Hash.var_to_triples prev_state_hash + in + let%bind prev_top_hash = + compute_top_hash wrap_vk_section prev_state_hash_trips + >>= Wrap_input.Checked.tick_field_to_scalars + in + let%bind other_wrap_vk_data, result = + Verifier.All_in_one.check_proof wrap_vk + ~get_vk:As_prover.(map get_state ~f:Prover_state.wrap_vk) + ~get_proof:As_prover.(map get_state ~f:Prover_state.prev_proof) + prev_top_hash + in + let%map () = + Verifier.Verification_key_data.Checked.Assert.equal wrap_vk_data + other_wrap_vk_data + in + result let provide_witness' typ ~f = provide_witness typ As_prover.(map get_state ~f) - let main (top_hash : Digest.Tick.Packed.var) = - with_label __LOC__ - (let%bind prev_state = - provide_witness' State.typ ~f:Prover_state.prev_state - and update = provide_witness' Update.typ ~f:Prover_state.update in - let%bind prev_state_hash = State.Checked.hash prev_state in - let%bind next_state_hash, _next_state, `Success success = - with_label __LOC__ - (State.Checked.update (prev_state_hash, prev_state) update) - in - let%bind wrap_vk = - provide_witness' Verifier.Verification_key.typ - ~f:(fun {Prover_state.wrap_vk; _} -> - Verifier.Verification_key.of_verification_key wrap_vk ) - in - let wrap_vk_data = - Verifier.Verification_key.Checked.to_full_data wrap_vk - in - let%bind wrap_vk_section = hash_vk_data wrap_vk_data in - let%bind () = - with_label __LOC__ - (let%bind sh = State.Hash.var_to_triples next_state_hash in - (* We could be reusing the intermediate state of the hash on sh here instead of + let%snarkydef main (top_hash : Digest.Tick.Packed.var) = + let%bind prev_state = + provide_witness' State.typ ~f:Prover_state.prev_state + and update = provide_witness' Update.typ ~f:Prover_state.update in + let%bind prev_state_hash = State.Checked.hash prev_state in + let%bind next_state_hash, _next_state, `Success success = + with_label __LOC__ + (State.Checked.update (prev_state_hash, prev_state) update) + in + let%bind wrap_vk = + provide_witness' Verifier.Verification_key.typ + ~f:(fun {Prover_state.wrap_vk; _} -> + Verifier.Verification_key.of_verification_key wrap_vk ) + in + let wrap_vk_data = + Verifier.Verification_key.Checked.to_full_data wrap_vk + in + let%bind wrap_vk_section = hash_vk_data wrap_vk_data in + let%bind () = + with_label __LOC__ + (let%bind sh = State.Hash.var_to_triples next_state_hash in + (* We could be reusing the intermediate state of the hash on sh here instead of hashing anew *) - compute_top_hash wrap_vk_section sh - >>= Field.Checked.Assert.equal top_hash) - in - let%bind prev_state_valid = - prev_state_valid wrap_vk_section wrap_vk wrap_vk_data - prev_state_hash - in - let%bind inductive_case_passed = - with_label __LOC__ Boolean.(prev_state_valid && success) - in - let%bind is_base_case = State.Checked.is_base_hash next_state_hash in - with_label __LOC__ - (Boolean.Assert.any [is_base_case; inductive_case_passed])) + compute_top_hash wrap_vk_section sh + >>= Field.Checked.Assert.equal top_hash) + in + let%bind prev_state_valid = + prev_state_valid wrap_vk_section wrap_vk wrap_vk_data prev_state_hash + in + let%bind inductive_case_passed = + with_label __LOC__ Boolean.(prev_state_valid && success) + in + let%bind is_base_case = State.Checked.is_base_hash next_state_hash in + with_label __LOC__ + (Boolean.Assert.any [is_base_case; inductive_case_passed]) end module Step (Tick_keypair : Tick_keypair_intf) = struct @@ -201,25 +199,24 @@ struct let step_vk_bits = Verifier.Verification_key_data.to_bits step_vk_data (* TODO: Use an online verifier here *) - let main (input : Wrap_input.var) = + let%snarkydef main (input : Wrap_input.var) = let open Let_syntax in - with_label __LOC__ - (let%bind vk_data, result = - (* The use of choose_preimage here is justified since we feed it to the verifier, which doesn't + let%bind vk_data, result = + (* The use of choose_preimage here is justified since we feed it to the verifier, which doesn't depend on which unpacking is provided. *) - let%bind input = Wrap_input.Checked.to_scalar input in - Verifier.All_in_one.check_proof - Verifier.Verification_key.( - Checked.constant (of_verification_key Step_vk.verification_key)) - ~get_vk:(As_prover.return Step_vk.verification_key) - ~get_proof:As_prover.(map get_state ~f:Prover_state.proof) - [input] - in - let%bind () = - let open Verifier.Verification_key_data.Checked in - Assert.equal vk_data (constant step_vk_data) - in - with_label __LOC__ (Boolean.Assert.is_true result)) + let%bind input = Wrap_input.Checked.to_scalar input in + Verifier.All_in_one.check_proof + Verifier.Verification_key.( + Checked.constant (of_verification_key Step_vk.verification_key)) + ~get_vk:(As_prover.return Step_vk.verification_key) + ~get_proof:As_prover.(map get_state ~f:Prover_state.proof) + [input] + in + let%bind () = + let open Verifier.Verification_key_data.Checked in + Assert.equal vk_data (constant step_vk_data) + in + with_label __LOC__ (Boolean.Assert.is_true result) end module Wrap (Step_vk : Step_vk_intf) (Tock_keypair : Tock_keypair_intf) = diff --git a/src/lib/consensus/jbuild b/src/lib/consensus/jbuild index d1b8d71eb7b..0dd744357c5 100644 --- a/src/lib/consensus/jbuild +++ b/src/lib/consensus/jbuild @@ -18,5 +18,5 @@ global_signer_private_key non_zero_curve_point )) (preprocessor_deps ("../../config.mlh")) - (preprocess (pps (ppx_jane ppx_deriving.eq bisect_ppx -conditional))) + (preprocess (pps (ppx_snarky ppx_jane ppx_deriving.eq bisect_ppx -conditional))) (synopsis "Consensus mechanisms"))) diff --git a/src/lib/consensus/proof_of_work.ml.ignore b/src/lib/consensus/proof_of_work.ml.ignore index a8cc1303354..ee25977191f 100644 --- a/src/lib/consensus/proof_of_work.ml.ignore +++ b/src/lib/consensus/proof_of_work.ml.ignore @@ -134,8 +134,7 @@ module Strength = struct z (* floor(two_to_the bit_length / y) *) - let of_target (y: Target.Packed.var) (y_unpacked: Target.Unpacked.var) = - Tick.with_label __LOC__ + let%snarkydef of_target (y: Target.Packed.var) (y_unpacked: Target.Unpacked.var) = ( if Insecure.strength_calculation then Tick.provide_witness Tick.Typ.field Tick.As_prover.( diff --git a/src/lib/non_zero_curve_point/jbuild b/src/lib/non_zero_curve_point/jbuild index a0fca0cf06a..eebe14bf3e3 100644 --- a/src/lib/non_zero_curve_point/jbuild +++ b/src/lib/non_zero_curve_point/jbuild @@ -6,4 +6,4 @@ (flags (:standard -short-paths)) (library_flags (-linkall)) (libraries (core_kernel snark_params fold_lib base64 codable)) - (preprocess (pps (ppx_jane ppx_deriving.eq))))) + (preprocess (pps (ppx_snarky ppx_jane ppx_deriving.eq))))) diff --git a/src/lib/non_zero_curve_point/non_zero_curve_point.ml b/src/lib/non_zero_curve_point/non_zero_curve_point.ml index 2c039a3b929..feffe038014 100644 --- a/src/lib/non_zero_curve_point/non_zero_curve_point.ml +++ b/src/lib/non_zero_curve_point/non_zero_curve_point.ml @@ -170,10 +170,9 @@ let decompress_var ({x; is_odd} as c : Compressed.var) = let compress : t -> Compressed.t = Compressed.compress -let compress_var ((x, y) : var) : (Compressed.var, _) Checked.t = - with_label __LOC__ - (let%map is_odd = parity_var y in - {Compressed.x; is_odd}) +let%snarkydef compress_var ((x, y) : var) : (Compressed.var, _) Checked.t = + let%map is_odd = parity_var y in + {Compressed.x; is_odd} let of_bigstring bs = let open Or_error.Let_syntax in diff --git a/src/lib/ppx_snarky/jbuild b/src/lib/ppx_snarky/jbuild new file mode 100644 index 00000000000..b0184cc0d9a --- /dev/null +++ b/src/lib/ppx_snarky/jbuild @@ -0,0 +1,8 @@ +(jbuild_version 1) + +(library + ((name ppx_snarky) + (public_name ppx_snarky) + (kind ppx_rewriter) + (libraries (ppxlib)) + (preprocess (pps (ppxlib.metaquot))))) diff --git a/src/lib/ppx_snarky/ppx_snarky.ml b/src/lib/ppx_snarky/ppx_snarky.ml new file mode 100644 index 00000000000..01f7f2f6dcc --- /dev/null +++ b/src/lib/ppx_snarky/ppx_snarky.ml @@ -0,0 +1 @@ +let () = Snarkydef.main () diff --git a/src/lib/ppx_snarky/snarkydef.ml b/src/lib/ppx_snarky/snarkydef.ml new file mode 100644 index 00000000000..d6af103428a --- /dev/null +++ b/src/lib/ppx_snarky/snarkydef.ml @@ -0,0 +1,62 @@ +open Ppxlib +open Ast_helper +open Ast_builder.Default +open Asttypes + +let name = "snarkydef" + +let located_label_expr expr = + let loc = expr.pexp_loc in + [%expr Pervasives.( ^ ) [%e expr] (Pervasives.( ^ ) ": " Pervasives.__LOC__)] + +let located_label_string ~loc str = + [%expr + Pervasives.( ^ ) + [%e Exp.constant ~loc (Const.string (str ^ ": "))] + Pervasives.__LOC__] + +let with_label ~loc exprs = Exp.apply ~loc [%expr with_label] exprs + +let with_label_one ~loc ~path:_ expr = + with_label ~loc [(Nolabel, located_label_expr expr)] + +let rec snarkydef_inject ~loc ~name expr = + match expr.pexp_desc with + | Pexp_fun (lbl, default, pat, body) -> + { expr with + pexp_desc= + Pexp_fun (lbl, default, pat, snarkydef_inject ~loc ~name body) } + | Pexp_newtype (typname, body) -> + { expr with + pexp_desc= Pexp_newtype (typname, snarkydef_inject ~loc ~name body) } + | Pexp_function _ -> + Location.raise_errorf ~loc:expr.pexp_loc + "%%snarkydef currently doesn't support 'function'" + | _ -> + with_label ~loc + [(Nolabel, located_label_string ~loc name); (Nolabel, expr)] + +let snarkydef ~loc ~path:_ name expr = + [%stri + let [%p Pat.var ~loc (Located.mk ~loc name)] = + [%e snarkydef_inject ~loc ~name expr]] + +let with_label_ext = + Extension.declare "with_label" Extension.Context.expression + Ast_pattern.(single_expr_payload __) + with_label_one + +let snarkydef_ext = + Extension.declare "snarkydef" Extension.Context.structure_item + Ast_pattern.( + pstr + ( pstr_value nonrecursive + (value_binding ~pat:(ppat_var __) ~expr:__ ^:: nil) + ^:: nil )) + snarkydef + +let main () = + Driver.register_transformation name + ~rules: + [ Context_free.Rule.extension with_label_ext + ; Context_free.Rule.extension snarkydef_ext ] diff --git a/src/lib/signature_lib/checked.ml b/src/lib/signature_lib/checked.ml index 060eb4f1153..c1eb4b9a6dd 100644 --- a/src/lib/signature_lib/checked.ml +++ b/src/lib/signature_lib/checked.ml @@ -257,37 +257,33 @@ module Schnorr open Impl.Let_syntax - let verification_hash (type s) + let%snarkydef verification_hash (type s) ((module Shifted) as shifted : (module Curve.Checked.Shifted.S with type t = s)) ((s, h) : Signature.var) (public_key : Public_key.var) (m : Message.var) = - with_label __LOC__ - (let%bind pre_r = - (* s * g + h * public_key *) - let%bind s_g = - Curve.Checked.scale_known shifted Curve.one - (Curve.Scalar.Checked.to_bits s) - ~init:Shifted.zero - in - let%bind s_g_h_pk = - Curve.Checked.scale shifted public_key - (Curve.Scalar.Checked.to_bits h) - ~init:s_g - in - Shifted.unshift_nonzero s_g_h_pk - in - let%bind r = compress pre_r in - Message.hash_checked m ~nonce:r) - - let verifies shifted ((_, h) as signature) pk m = - with_label __LOC__ - ( verification_hash shifted signature pk m - >>= Curve.Scalar.Checked.equal h ) - - let assert_verifies shifted ((_, h) as signature) pk m = - with_label __LOC__ - ( verification_hash shifted signature pk m - >>= Curve.Scalar.Checked.Assert.equal h ) + let%bind pre_r = + (* s * g + h * public_key *) + let%bind s_g = + Curve.Checked.scale_known shifted Curve.one + (Curve.Scalar.Checked.to_bits s) + ~init:Shifted.zero + in + let%bind s_g_h_pk = + Curve.Checked.scale shifted public_key + (Curve.Scalar.Checked.to_bits h) + ~init:s_g + in + Shifted.unshift_nonzero s_g_h_pk + in + let%bind r = compress pre_r in + Message.hash_checked m ~nonce:r + + let%snarkydef verifies shifted ((_, h) as signature) pk m = + verification_hash shifted signature pk m >>= Curve.Scalar.Checked.equal h + + let%snarkydef assert_verifies shifted ((_, h) as signature) pk m = + verification_hash shifted signature pk m + >>= Curve.Scalar.Checked.Assert.equal h end end diff --git a/src/lib/signature_lib/jbuild b/src/lib/signature_lib/jbuild index 5b81c04656b..84bd2d19d35 100644 --- a/src/lib/signature_lib/jbuild +++ b/src/lib/signature_lib/jbuild @@ -8,5 +8,5 @@ (libraries ( snarky base64 snark_params core non_zero_curve_point yojson )) (preprocessor_deps ("../../config.mlh")) - (preprocess (pps (ppx_jane ppx_deriving.eq ppx_deriving_yojson))) + (preprocess (pps (ppx_snarky ppx_jane ppx_deriving.eq ppx_deriving_yojson))) (synopsis "Schnorr signatures using the tick and tock curves"))) diff --git a/src/lib/snark_bits/bits.ml b/src/lib/snark_bits/bits.ml index 84f66c18214..5d6ca2fbf47 100644 --- a/src/lib/snark_bits/bits.ml +++ b/src/lib/snark_bits/bits.ml @@ -246,26 +246,23 @@ module Snarkable = struct let compare_var x y = Impl.Field.Checked.compare ~bit_length:V.length (pack_var x) (pack_var y) - let increment_if_var bs (b : Boolean.var) = + let%snarkydef increment_if_var bs (b : Boolean.var) = let open Impl in - with_label __LOC__ - (let v = Field.Checked.pack bs in - let v' = Field.Checked.add v (b :> Field.Checked.t) in - Field.Checked.unpack v' ~length:V.length) + let v = Field.Checked.pack bs in + let v' = Field.Checked.add v (b :> Field.Checked.t) in + Field.Checked.unpack v' ~length:V.length - let increment_var bs = + let%snarkydef increment_var bs = let open Impl in - with_label __LOC__ - (let v = Field.Checked.pack bs in - let v' = Field.Checked.add v (Field.Checked.constant Field.one) in - Field.Checked.unpack v' ~length:V.length) + let v = Field.Checked.pack bs in + let v' = Field.Checked.add v (Field.Checked.constant Field.one) in + Field.Checked.unpack v' ~length:V.length - let equal_var (n : Unpacked.var) (n' : Unpacked.var) = - with_label __LOC__ (Field.Checked.equal (pack_var n) (pack_var n')) + let%snarkydef equal_var (n : Unpacked.var) (n' : Unpacked.var) = + Field.Checked.equal (pack_var n) (pack_var n') - let assert_equal_var (n : Unpacked.var) (n' : Unpacked.var) = - with_label __LOC__ - (Field.Checked.Assert.equal (pack_var n) (pack_var n')) + let%snarkydef assert_equal_var (n : Unpacked.var) (n' : Unpacked.var) = + Field.Checked.Assert.equal (pack_var n) (pack_var n') let if_ (cond : Boolean.var) ~(then_ : Unpacked.var) ~(else_ : Unpacked.var) : (Unpacked.var, _) Checked.t = diff --git a/src/lib/snark_bits/jbuild b/src/lib/snark_bits/jbuild index f60ffa6daff..1e34dd19fee 100644 --- a/src/lib/snark_bits/jbuild +++ b/src/lib/snark_bits/jbuild @@ -8,6 +8,6 @@ (inline_tests) (libraries ( fold_lib core_kernel snarky )) - (preprocess (pps (ppx_jane ppx_deriving.eq bisect_ppx -conditional))) + (preprocess (pps (ppx_snarky ppx_jane ppx_deriving.eq bisect_ppx -conditional))) (synopsis "Snark parameters"))) diff --git a/src/lib/snarky/src/curves.ml b/src/lib/snarky/src/curves.ml index 8ebaeae8ae9..8416734b73f 100644 --- a/src/lib/snarky/src/curves.ml +++ b/src/lib/snarky/src/curves.ml @@ -312,55 +312,48 @@ module Edwards = struct () end - let add_known (x1, y1) (x2, y2) = - with_label __LOC__ - (let x1x2 = Field.Checked.scale x1 x2 - and y1y2 = Field.Checked.scale y1 y2 - and x1y2 = Field.Checked.scale x1 y2 - and y1x2 = Field.Checked.scale y1 x2 in - let%bind p = Field.Checked.mul x1x2 y1y2 in - let open Field.Checked.Infix in - let p = Params.d * p in - let%map a = - Field.Checked.div (x1y2 + y1x2) - (Field.Checked.constant Field.one + p) - and b = - Field.Checked.div (y1y2 - x1x2) - (Field.Checked.constant Field.one - p) - in - (a, b)) + let%snarkydef add_known (x1, y1) (x2, y2) = + let x1x2 = Field.Checked.scale x1 x2 + and y1y2 = Field.Checked.scale y1 y2 + and x1y2 = Field.Checked.scale x1 y2 + and y1x2 = Field.Checked.scale y1 x2 in + let%bind p = Field.Checked.mul x1x2 y1y2 in + let open Field.Checked.Infix in + let p = Params.d * p in + let%map a = + Field.Checked.div (x1y2 + y1x2) (Field.Checked.constant Field.one + p) + and b = + Field.Checked.div (y1y2 - x1x2) (Field.Checked.constant Field.one - p) + in + (a, b) (* TODO: Optimize -- could probably shave off one constraint. *) - let add (x1, y1) (x2, y2) = - with_label __LOC__ - (let%bind x1x2 = Field.Checked.mul x1 x2 - and y1y2 = Field.Checked.mul y1 y2 - and x1y2 = Field.Checked.mul x1 y2 - and x2y1 = Field.Checked.mul x2 y1 in - let%bind p = Field.Checked.mul x1x2 y1y2 in - let open Field.Checked.Infix in - let p = Params.d * p in - let%map a = - Field.Checked.div (x1y2 + x2y1) - (Field.Checked.constant Field.one + p) - and b = - Field.Checked.div (y1y2 - x1x2) - (Field.Checked.constant Field.one - p) - in - (a, b)) + let%snarkydef add (x1, y1) (x2, y2) = + let%bind x1x2 = Field.Checked.mul x1 x2 + and y1y2 = Field.Checked.mul y1 y2 + and x1y2 = Field.Checked.mul x1 y2 + and x2y1 = Field.Checked.mul x2 y1 in + let%bind p = Field.Checked.mul x1x2 y1y2 in + let open Field.Checked.Infix in + let p = Params.d * p in + let%map a = + Field.Checked.div (x1y2 + x2y1) (Field.Checked.constant Field.one + p) + and b = + Field.Checked.div (y1y2 - x1x2) (Field.Checked.constant Field.one - p) + in + (a, b) - let double (x, y) = - with_label __LOC__ - (let%bind xy = Field.Checked.mul x y - and xx = Field.Checked.mul x x - and yy = Field.Checked.mul y y in - let open Field.Checked.Infix in - let two = Field.of_int 2 in - let%map a = Field.Checked.div (two * xy) (xx + yy) - and b = - Field.Checked.div (yy - xx) (Field.Checked.constant two - xx - yy) - in - (a, b)) + let%snarkydef double (x, y) = + let%bind xy = Field.Checked.mul x y + and xx = Field.Checked.mul x x + and yy = Field.Checked.mul y y in + let open Field.Checked.Infix in + let two = Field.of_int 2 in + let%map a = Field.Checked.div (two * xy) (xx + yy) + and b = + Field.Checked.div (yy - xx) (Field.Checked.constant two - xx - yy) + in + (a, b) let if_value (b : Boolean.var) ~then_:(x1, y1) ~else_:(x2, y2) = let not_b = (Boolean.not b :> Field.Checked.t) in @@ -405,79 +398,73 @@ module Edwards = struct in r) - let scale t (c : Scalar.var) = - with_label __LOC__ - (let rec go i acc pt = function - | [] -> return acc - | b :: bs -> - let%bind acc' = - with_label (sprintf "acc_%d" i) - (let%bind add_pt = add acc pt in - let don't_add_pt = acc in - if_ b ~then_:add_pt ~else_:don't_add_pt) - and pt' = double pt in - go (i + 1) acc' pt' bs - in - match c with - | [] -> failwith "Edwards.Checked.scale: Empty bits" - | b :: bs -> - let%bind acc = if_ b ~then_:t ~else_:identity - and pt = double t in - go 1 acc pt bs) + let%snarkydef scale t (c : Scalar.var) = + let rec go i acc pt = function + | [] -> return acc + | b :: bs -> + let%bind acc' = + with_label (sprintf "acc_%d" i) + (let%bind add_pt = add acc pt in + let don't_add_pt = acc in + if_ b ~then_:add_pt ~else_:don't_add_pt) + and pt' = double pt in + go (i + 1) acc' pt' bs + in + match c with + | [] -> failwith "Edwards.Checked.scale: Empty bits" + | b :: bs -> + let%bind acc = if_ b ~then_:t ~else_:identity and pt = double t in + go 1 acc pt bs (* TODO: Unit test *) - let cond_add ((x2, y2) : value) ~to_:((x1, y1) : var) + let%snarkydef cond_add ((x2, y2) : value) ~to_:((x1, y1) : var) ~if_:(b : Boolean.var) : (var, _) Checked.t = - with_label __LOC__ - (let one = Field.Checked.constant Field.one in - let b = (b :> Field.Checked.t) in - let open Let_syntax in - let open Field.Checked.Infix in - let res a1 a3 = - let%bind a = - provide_witness Typ.field - (let open As_prover in - let open As_prover.Let_syntax in - let open Field.Infix in - let%map b = read_var b - and a3 = read_var a3 - and a1 = read_var a1 in - a1 + (b * (a3 - a1))) - in - let%map () = assert_r1cs b (a3 - a1) (a - a1) in - a - in - let%bind beta = Field.Checked.mul x1 y1 in - let p = Field.Infix.(Params.d * x2 * y2) * beta in - let%bind x3 = Field.Checked.div ((y2 * x1) + (x2 * y1)) (one + p) - and y3 = Field.Checked.div ((y2 * y1) - (x2 * x1)) (one - p) in - let%map x_res = res x1 x3 and y_res = res y1 y3 in - (x_res, y_res)) - - let scale_known (t : value) (c : Scalar.var) = - with_label __LOC__ - (let rec go i acc pt = function - | b :: bs -> - let%bind acc' = - with_label (sprintf "acc_%d" i) - (cond_add pt ~to_:acc ~if_:b) - in - go (i + 1) acc' (double_value pt) bs - | [] -> return acc - in - match c with - | [] -> failwith "scale_known: Empty bits" - | b :: bs -> - let acc = - let b = (b :> Field.Checked.t) in - let x_id, y_id = identity_value in - let x_t, y_t = t in - let open Field.Checked.Infix in - ( (Field.Infix.(x_t - x_id) * b) + Field.Checked.constant x_id - , (Field.Infix.(y_t - y_id) * b) + Field.Checked.constant y_id - ) - in - go 1 acc (double_value t) bs) + let one = Field.Checked.constant Field.one in + let b = (b :> Field.Checked.t) in + let open Let_syntax in + let open Field.Checked.Infix in + let res a1 a3 = + let%bind a = + provide_witness Typ.field + (let open As_prover in + let open As_prover.Let_syntax in + let open Field.Infix in + let%map b = read_var b + and a3 = read_var a3 + and a1 = read_var a1 in + a1 + (b * (a3 - a1))) + in + let%map () = assert_r1cs b (a3 - a1) (a - a1) in + a + in + let%bind beta = Field.Checked.mul x1 y1 in + let p = Field.Infix.(Params.d * x2 * y2) * beta in + let%bind x3 = Field.Checked.div ((y2 * x1) + (x2 * y1)) (one + p) + and y3 = Field.Checked.div ((y2 * y1) - (x2 * x1)) (one - p) in + let%map x_res = res x1 x3 and y_res = res y1 y3 in + (x_res, y_res) + + let%snarkydef scale_known (t : value) (c : Scalar.var) = + let rec go i acc pt = function + | b :: bs -> + let%bind acc' = + with_label (sprintf "acc_%d" i) (cond_add pt ~to_:acc ~if_:b) + in + go (i + 1) acc' (double_value pt) bs + | [] -> return acc + in + match c with + | [] -> failwith "scale_known: Empty bits" + | b :: bs -> + let acc = + let b = (b :> Field.Checked.t) in + let x_id, y_id = identity_value in + let x_t, y_t = t in + let open Field.Checked.Infix in + ( (Field.Infix.(x_t - x_id) * b) + Field.Checked.constant x_id + , (Field.Infix.(y_t - y_id) * b) + Field.Checked.constant y_id ) + in + go 1 acc (double_value t) bs end end @@ -625,44 +612,42 @@ module Make_weierstrass_checked let equal = assert_equal end - let add' ~div (ax, ay) (bx, by) = - with_label __LOC__ - (let open Let_syntax in - let%bind lambda = - div (Field.Checked.sub by ay) (Field.Checked.sub bx ax) - in - let%bind cx = - provide_witness Typ.field - (let open As_prover in - let open Let_syntax in - let%map ax = read_var ax - and bx = read_var bx - and lambda = read_var lambda in - Field.(sub (square lambda) (add ax bx))) - in - let%bind () = - (* lambda^2 = cx + ax + bx + let%snarkydef add' ~div (ax, ay) (bx, by) = + let open Let_syntax in + let%bind lambda = + div (Field.Checked.sub by ay) (Field.Checked.sub bx ax) + in + let%bind cx = + provide_witness Typ.field + (let open As_prover in + let open Let_syntax in + let%map ax = read_var ax + and bx = read_var bx + and lambda = read_var lambda in + Field.(sub (square lambda) (add ax bx))) + in + let%bind () = + (* lambda^2 = cx + ax + bx cx = lambda^2 - (ax + bc) *) - assert_ - (Constraint.square ~label:"c1" lambda - Field.Checked.Infix.(cx + ax + bx)) - in - let%bind cy = - provide_witness Typ.field - (let open As_prover in - let open Let_syntax in - let%map ax = read_var ax - and ay = read_var ay - and cx = read_var cx - and lambda = read_var lambda in - Field.(sub (mul lambda (sub ax cx)) ay)) - in - let%map () = - Field.Checked.Infix.( - assert_r1cs ~label:"c2" lambda (ax - cx) (cy + ay)) - in - (cx, cy)) + assert_ + (Constraint.square ~label:"c1" lambda + Field.Checked.Infix.(cx + ax + bx)) + in + let%bind cy = + provide_witness Typ.field + (let open As_prover in + let open Let_syntax in + let%map ax = read_var ax + and ay = read_var ay + and cx = read_var cx + and lambda = read_var lambda in + Field.(sub (mul lambda (sub ax cx)) ay)) + in + let%map () = + Field.Checked.Infix.(assert_r1cs ~label:"c2" lambda (ax - cx) (cy + ay)) + in + (cx, cy) (* This function MUST NOT be called UNLESS you are certain the two points on which it is called are not equal. If it is called on equal points, @@ -767,43 +752,42 @@ module Make_weierstrass_checked (module M : S) end - let double (ax, ay) = - with_label __LOC__ - (let open Let_syntax in - let%bind x_squared = Field.Checked.square ax in - let%bind lambda = - provide_witness Typ.field - As_prover.( - map2 (read_var x_squared) (read_var ay) ~f:(fun x_squared ay -> - let open Field in - let open Infix in - ((of_int 3 * x_squared) + Params.a) * inv (of_int 2 * ay) )) - in - let%bind bx = - provide_witness Typ.field - As_prover.( - map2 (read_var lambda) (read_var ax) ~f:(fun lambda ax -> - let open Field in - Infix.(square lambda - (of_int 2 * ax)) )) - in - let%bind by = - provide_witness Typ.field - (let open As_prover in - let open Let_syntax in - let%map lambda = read_var lambda - and ax = read_var ax - and ay = read_var ay - and bx = read_var bx in - Field.Infix.((lambda * (ax - bx)) - ay)) - in - let two = Field.of_int 2 in - let open Field.Checked.Infix in - let%map () = - assert_r1cs (two * lambda) ay - ((Field.of_int 3 * x_squared) + Field.Checked.constant Params.a) - and () = assert_square lambda (bx + (two * ax)) - and () = assert_r1cs lambda (ax - bx) (by + ay) in - (bx, by)) + let%snarkydef double (ax, ay) = + let open Let_syntax in + let%bind x_squared = Field.Checked.square ax in + let%bind lambda = + provide_witness Typ.field + As_prover.( + map2 (read_var x_squared) (read_var ay) ~f:(fun x_squared ay -> + let open Field in + let open Infix in + ((of_int 3 * x_squared) + Params.a) * inv (of_int 2 * ay) )) + in + let%bind bx = + provide_witness Typ.field + As_prover.( + map2 (read_var lambda) (read_var ax) ~f:(fun lambda ax -> + let open Field in + Infix.(square lambda - (of_int 2 * ax)) )) + in + let%bind by = + provide_witness Typ.field + (let open As_prover in + let open Let_syntax in + let%map lambda = read_var lambda + and ax = read_var ax + and ay = read_var ay + and bx = read_var bx in + Field.Infix.((lambda * (ax - bx)) - ay)) + in + let two = Field.of_int 2 in + let open Field.Checked.Infix in + let%map () = + assert_r1cs (two * lambda) ay + ((Field.of_int 3 * x_squared) + Field.Checked.constant Params.a) + and () = assert_square lambda (bx + (two * ax)) + and () = assert_r1cs lambda (ax - bx) (by + ay) in + (bx, by) let if_value (cond : Boolean.var) ~then_ ~else_ = let x1, y1 = Curve.to_coords then_ in @@ -815,25 +799,25 @@ module Make_weierstrass_checked in (choose x1 x2, choose y1 y2) - let scale (type shifted) (module Shifted : Shifted.S with type t = shifted) t + let%snarkydef scale (type shifted) + (module Shifted : Shifted.S with type t = shifted) t (c : Boolean.var Bitstring_lib.Bitstring.Lsb_first.t) ~(init : shifted) : (shifted, _) Checked.t = let c = Bitstring_lib.Bitstring.Lsb_first.to_list c in - with_label __LOC__ - (let open Let_syntax in - let rec go i bs0 acc pt = - match bs0 with - | [] -> return acc - | b :: bs -> - let%bind acc' = - with_label (sprintf "acc_%d" i) - (let%bind add_pt = Shifted.add acc pt in - let don't_add_pt = acc in - Shifted.if_ b ~then_:add_pt ~else_:don't_add_pt) - and pt' = double pt in - go (i + 1) bs acc' pt' - in - go 0 c init t) + let open Let_syntax in + let rec go i bs0 acc pt = + match bs0 with + | [] -> return acc + | b :: bs -> + let%bind acc' = + with_label (sprintf "acc_%d" i) + (let%bind add_pt = Shifted.add acc pt in + let don't_add_pt = acc in + Shifted.if_ b ~then_:add_pt ~else_:don't_add_pt) + and pt' = double pt in + go (i + 1) bs acc' pt' + in + go 0 c init t (* This 'looks up' a field element from a lookup table of size 2^2 = 4 with a 2 bit index. See https://github.com/zcash/zcash/issues/2234#issuecomment-383736266 for diff --git a/src/lib/snarky/src/jbuild b/src/lib/snarky/src/jbuild index ffdf8681153..78e8c6bb389 100644 --- a/src/lib/snarky/src/jbuild +++ b/src/lib/snarky/src/jbuild @@ -13,4 +13,4 @@ (-I re2_c/libre2) )) (preprocessor_deps ("../../../config.mlh")) - (preprocess (pps (ppx_jane ppx_deriving.enum ppx_deriving.eq bisect_ppx -conditional))))) + (preprocess (pps (ppx_snarky ppx_jane ppx_deriving.enum ppx_deriving.eq bisect_ppx -conditional))))) diff --git a/src/lib/snarky/src/snark0.ml b/src/lib/snarky/src/snark0.ml index 5bf7961ed29..a54383c0f0a 100644 --- a/src/lib/snarky/src/snark0.ml +++ b/src/lib/snarky/src/snark0.ml @@ -1112,11 +1112,10 @@ module Make_basic (Backend : Backend_intf.S) = struct let%bind y_inv = inv y in mul x y_inv) - let assert_non_zero (v : Cvar.t) = - with_label __LOC__ - (let open Let_syntax in - let%map _ = inv v in - ()) + let%snarkydef assert_non_zero (v : Cvar.t) = + let open Let_syntax in + let%map _ = inv v in + () module Boolean = struct type var = Cvar.t @@ -1232,17 +1231,14 @@ module Make_basic (Backend : Backend_intf.S) = struct let is_true (v : var) = assert_equal v true_ - let any (bs : var list) = - with_label __LOC__ (assert_non_zero (Cvar.sum bs)) + let%snarkydef any (bs : var list) = assert_non_zero (Cvar.sum bs) - let all (bs : var list) = - with_label __LOC__ - (assert_equal (Cvar.sum bs) - (Cvar.constant (Field.of_int (List.length bs)))) + let%snarkydef all (bs : var list) = + assert_equal (Cvar.sum bs) + (Cvar.constant (Field.of_int (List.length bs))) - let exactly_one (bs : var list) = - with_label __LOC__ - (assert_equal (Cvar.sum bs) (Cvar.constant Field.one)) + let%snarkydef exactly_one (bs : var list) = + assert_equal (Cvar.sum bs) (Cvar.constant Field.one) end module Expr = struct @@ -1503,7 +1499,7 @@ module Make_basic (Backend : Backend_intf.S) = struct let compare ~bit_length a b = let open Checked in let open Let_syntax in - with_label __LOC__ + [%with_label "compare"] (let alpha_packed = Cvar.Infix.(Cvar.constant (two_to_the bit_length) + b - a) in diff --git a/src/lib/transaction_snark/jbuild b/src/lib/transaction_snark/jbuild index 22fe8f6f31c..b3265f643ed 100644 --- a/src/lib/transaction_snark/jbuild +++ b/src/lib/transaction_snark/jbuild @@ -7,6 +7,6 @@ (library_flags (-linkall)) (inline_tests) (libraries (core cache_dir cached snarky coda_base bignum)) - (preprocess (pps (ppx_jane ppx_deriving.std bisect_ppx -conditional))) + (preprocess (pps (ppx_snarky ppx_jane ppx_deriving.std bisect_ppx -conditional))) (synopsis "Transaction state transition snarking library"))) diff --git a/src/lib/transaction_snark/transaction_snark.ml b/src/lib/transaction_snark/transaction_snark.ml index 0d579e9a251..13ba62e67e8 100644 --- a/src/lib/transaction_snark/transaction_snark.ml +++ b/src/lib/transaction_snark/transaction_snark.ml @@ -172,13 +172,12 @@ module Base = struct open Tick open Let_syntax - let check_signature shifted ~payload_section ~is_user_command ~sender - ~signature = - with_label __LOC__ - (let%bind verifies = - Schnorr.Checked.verifies shifted signature sender payload_section - in - Boolean.Assert.any [Boolean.not is_user_command; verifies]) + let%snarkydef check_signature shifted ~payload_section ~is_user_command + ~sender ~signature = + let%bind verifies = + Schnorr.Checked.verifies shifted signature sender payload_section + in + Boolean.Assert.any [Boolean.not is_user_command; verifies] let chain if_ b ~then_ ~else_ = let%bind then_ = then_ and else_ = else_ in @@ -200,92 +199,90 @@ module Base = struct - fee excess = -(amount + fee) *) (* Nonce should only be incremented if it is a "Normal" transaction. *) - let apply_tagged_transaction (type shifted) + let%snarkydef apply_tagged_transaction (type shifted) (shifted : (module Inner_curve.Checked.Shifted.S with type t = shifted)) root ({sender; signature; payload} : Transaction_union.var) = - with_label __LOC__ - (let nonce = payload.common.nonce in - let tag = payload.body.tag in - let%bind payload_section = Schnorr.Message.var_of_payload payload in - let%bind is_user_command = - Transaction_union.Tag.Checked.is_user_command tag - in - let%bind () = - check_signature shifted ~payload_section ~is_user_command ~sender - ~signature - in - let%bind {excess; sender_delta; supply_increase; receiver_increase} = - Transaction_union_payload.Changes.Checked.of_payload payload - in - let%bind is_stake_delegation = - Transaction_union.Tag.Checked.is_stake_delegation tag - in - let%bind sender_compressed = Public_key.compress_var sender in - let%bind root = - let%bind is_fee_transfer = - Transaction_union.Tag.Checked.is_fee_transfer tag - in - Frozen_ledger_hash.modify_account_send root ~is_fee_transfer - sender_compressed ~f:(fun ~is_empty_and_writeable account -> - with_label __LOC__ - (let%bind next_nonce = - Account.Nonce.increment_if_var account.nonce is_user_command - in - let%bind () = - with_label __LOC__ - (let%bind nonce_matches = - Account.Nonce.equal_var nonce account.nonce - in - Boolean.Assert.any - [Boolean.not is_user_command; nonce_matches]) - in - let%bind receipt_chain_hash = - let current = account.receipt_chain_hash in - let%bind r = - Receipt.Chain_hash.Checked.cons ~payload:payload_section - current + let nonce = payload.common.nonce in + let tag = payload.body.tag in + let%bind payload_section = Schnorr.Message.var_of_payload payload in + let%bind is_user_command = + Transaction_union.Tag.Checked.is_user_command tag + in + let%bind () = + check_signature shifted ~payload_section ~is_user_command ~sender + ~signature + in + let%bind {excess; sender_delta; supply_increase; receiver_increase} = + Transaction_union_payload.Changes.Checked.of_payload payload + in + let%bind is_stake_delegation = + Transaction_union.Tag.Checked.is_stake_delegation tag + in + let%bind sender_compressed = Public_key.compress_var sender in + let%bind root = + let%bind is_fee_transfer = + Transaction_union.Tag.Checked.is_fee_transfer tag + in + Frozen_ledger_hash.modify_account_send root ~is_fee_transfer + sender_compressed ~f:(fun ~is_empty_and_writeable account -> + with_label __LOC__ + (let%bind next_nonce = + Account.Nonce.increment_if_var account.nonce is_user_command + in + let%bind () = + with_label __LOC__ + (let%bind nonce_matches = + Account.Nonce.equal_var nonce account.nonce in - Receipt.Chain_hash.Checked.if_ is_user_command ~then_:r - ~else_:current - in - let%bind delegate = - let if_ = chain Public_key.Compressed.Checked.if_ in - if_ is_empty_and_writeable ~then_:(return sender_compressed) - ~else_: - (if_ is_stake_delegation - ~then_:(return payload.body.public_key) - ~else_:(return account.delegate)) - in - let%map balance = - Balance.Checked.add_signed_amount account.balance - sender_delta - in - { Account.balance - ; public_key= sender_compressed - ; nonce= next_nonce - ; receipt_chain_hash - ; delegate }) ) - in - let%bind receiver = - (* A stake delegation only uses the sender *) - Public_key.Compressed.Checked.if_ is_stake_delegation - ~then_:sender_compressed ~else_:payload.body.public_key - in - (* we explicitly set the public_key because it could be zero if the account is new *) - let%map root = - (* This update should be a no-op in the stake delegation case *) - Frozen_ledger_hash.modify_account_recv root receiver - ~f:(fun ~is_empty_and_writeable account -> + Boolean.Assert.any + [Boolean.not is_user_command; nonce_matches]) + in + let%bind receipt_chain_hash = + let current = account.receipt_chain_hash in + let%bind r = + Receipt.Chain_hash.Checked.cons ~payload:payload_section + current + in + Receipt.Chain_hash.Checked.if_ is_user_command ~then_:r + ~else_:current + in + let%bind delegate = + let if_ = chain Public_key.Compressed.Checked.if_ in + if_ is_empty_and_writeable ~then_:(return sender_compressed) + ~else_: + (if_ is_stake_delegation + ~then_:(return payload.body.public_key) + ~else_:(return account.delegate)) + in let%map balance = - (* receiver_increase will be zero in the stake delegation case *) - Balance.Checked.(account.balance + receiver_increase) - and delegate = - Public_key.Compressed.Checked.if_ is_empty_and_writeable - ~then_:receiver ~else_:account.delegate + Balance.Checked.add_signed_amount account.balance sender_delta in - {account with balance; delegate; public_key= receiver} ) - in - (root, excess, supply_increase)) + { Account.balance + ; public_key= sender_compressed + ; nonce= next_nonce + ; receipt_chain_hash + ; delegate }) ) + in + let%bind receiver = + (* A stake delegation only uses the sender *) + Public_key.Compressed.Checked.if_ is_stake_delegation + ~then_:sender_compressed ~else_:payload.body.public_key + in + (* we explicitly set the public_key because it could be zero if the account is new *) + let%map root = + (* This update should be a no-op in the stake delegation case *) + Frozen_ledger_hash.modify_account_recv root receiver + ~f:(fun ~is_empty_and_writeable account -> + let%map balance = + (* receiver_increase will be zero in the stake delegation case *) + Balance.Checked.(account.balance + receiver_increase) + and delegate = + Public_key.Compressed.Checked.if_ is_empty_and_writeable + ~then_:receiver ~else_:account.delegate + in + {account with balance; delegate; public_key= receiver} ) + in + (root, excess, supply_increase) (* Someday: write the following soundness tests: @@ -314,40 +311,35 @@ module Base = struct such that H(l1, l2, fee_excess, supply_increase) = top_hash, applying [t] to ledger with merkle hash [l1] results in ledger with merkle hash [l2]. *) - let main top_hash = - with_label __LOC__ - (let%bind (module Shifted) = - Tick.Inner_curve.Checked.Shifted.create () - in - let%bind root_before = - provide_witness' Frozen_ledger_hash.typ ~f:Prover_state.state1 - in - let%bind t = - with_label __LOC__ - (provide_witness' Transaction_union.typ ~f:Prover_state.transaction) - in - let%bind root_after, fee_excess, supply_increase = - apply_tagged_transaction (module Shifted) root_before t - in - let%map () = - with_label __LOC__ - (let%bind b1 = Frozen_ledger_hash.var_to_triples root_before - and b2 = Frozen_ledger_hash.var_to_triples root_after - and sok_digest = - provide_witness' Sok_message.Digest.typ - ~f:Prover_state.sok_digest - in - let fee_excess = Amount.Signed.Checked.to_triples fee_excess in - let supply_increase = Amount.var_to_triples supply_increase in - let triples = - Sok_message.Digest.Checked.to_triples sok_digest - @ b1 @ b2 @ supply_increase @ fee_excess - in - Pedersen.Checked.digest_triples ~init:Hash_prefix.base_snark - triples - >>= Field.Checked.Assert.equal top_hash) - in - ()) + let%snarkydef main top_hash = + let%bind (module Shifted) = Tick.Inner_curve.Checked.Shifted.create () in + let%bind root_before = + provide_witness' Frozen_ledger_hash.typ ~f:Prover_state.state1 + in + let%bind t = + with_label __LOC__ + (provide_witness' Transaction_union.typ ~f:Prover_state.transaction) + in + let%bind root_after, fee_excess, supply_increase = + apply_tagged_transaction (module Shifted) root_before t + in + let%map () = + with_label __LOC__ + (let%bind b1 = Frozen_ledger_hash.var_to_triples root_before + and b2 = Frozen_ledger_hash.var_to_triples root_after + and sok_digest = + provide_witness' Sok_message.Digest.typ ~f:Prover_state.sok_digest + in + let fee_excess = Amount.Signed.Checked.to_triples fee_excess in + let supply_increase = Amount.var_to_triples supply_increase in + let triples = + Sok_message.Digest.Checked.to_triples sok_digest + @ b1 @ b2 @ supply_increase @ fee_excess + in + Pedersen.Checked.digest_triples ~init:Hash_prefix.base_snark triples + >>= Field.Checked.Assert.equal top_hash) + in + () let create_keys () = generate_keypair main ~exposing:(tick_input ()) @@ -828,39 +820,38 @@ struct constraints pass iff (b1, b2, .., bn) = unpack input, there is a proof making one of [ base_vk; merge_vk ] accept (b1, b2, .., bn) *) - let main (input : Wrap_input.var) = + let%snarkydef main (input : Wrap_input.var) = let open Let_syntax in - with_label __LOC__ - (let%bind input = Wrap_input.Checked.to_scalar input in - let%bind is_base = - provide_witness' Boolean.typ ~f:(fun {Prover_state.proof_type; _} -> - Proof_type.is_base proof_type ) - in - let verification_key = - Verifier.Verification_key.Checked.if_value is_base ~then_:base_vk - ~else_:merge_vk - in - let%bind vk_data, result = - (* someday: Probably an opportunity for optimization here since + let%bind input = Wrap_input.Checked.to_scalar input in + let%bind is_base = + provide_witness' Boolean.typ ~f:(fun {Prover_state.proof_type; _} -> + Proof_type.is_base proof_type ) + in + let verification_key = + Verifier.Verification_key.Checked.if_value is_base ~then_:base_vk + ~else_:merge_vk + in + let%bind vk_data, result = + (* someday: Probably an opportunity for optimization here since we are passing in one of two known verification keys. *) - with_label __LOC__ - (Verifier.All_in_one.check_proof verification_key - ~get_vk: - As_prover.( - map get_state ~f:(fun {Prover_state.proof_type; _} -> - match proof_type with - | `Base -> Vk.base - | `Merge -> Vk.merge )) - ~get_proof:As_prover.(map get_state ~f:Prover_state.proof) - [input]) - in - let%bind () = - with_label __LOC__ - (Verifier.Verification_key_data.Checked.Assert.equal - (Verifier.Verification_key.Checked.to_full_data verification_key) - vk_data) - in - Boolean.Assert.is_true result) + with_label __LOC__ + (Verifier.All_in_one.check_proof verification_key + ~get_vk: + As_prover.( + map get_state ~f:(fun {Prover_state.proof_type; _} -> + match proof_type with + | `Base -> Vk.base + | `Merge -> Vk.merge )) + ~get_proof:As_prover.(map get_state ~f:Prover_state.proof) + [input]) + in + let%bind () = + with_label __LOC__ + (Verifier.Verification_key_data.Checked.Assert.equal + (Verifier.Verification_key.Checked.to_full_data verification_key) + vk_data) + in + Boolean.Assert.is_true result let create_keys () = generate_keypair ~exposing:wrap_input main diff --git a/src/ppx_snarky.opam b/src/ppx_snarky.opam new file mode 100644 index 00000000000..3f309a68f5b --- /dev/null +++ b/src/ppx_snarky.opam @@ -0,0 +1,6 @@ +opam-version: "1.2" +version: "0.1" +build: [ + ["dune" "build" "--only" "src" "--root" "." "-j" jobs "@install"] +] +