Skip to content

Commit

Permalink
Poseidon hash function (#3455)
Browse files Browse the repository at this point in the history
After implementing Rescue, I benchmarked it and found it was really slow

```
┌─────────────────────────────────────────────┬────────────┬──────────┬──────────┬──────────┬────────────┐
│ Name                                        │   Time/Run │  mWd/Run │ mjWd/Run │ Prom/Run │ Percentage │
├─────────────────────────────────────────────┼────────────┼──────────┼──────────┼──────────┼────────────┤
│ [lib/snark_params/snark_params.ml] pedersen │   152.88us │  22.56kw │   1.32kw │   1.32kw │      5.10% │
│ [lib/snark_params/snark_params.ml] rescue   │ 2_995.88us │ 120.05kw │ 118.56kw │ 118.56kw │    100.00% │
│ [lib/snark_params/snark_params.ml] poseidon │   132.82us │   6.03kw │   4.71kw │   4.71kw │      4.43% │
└─────────────────────────────────────────────┴────────────┴──────────┴──────────┴──────────┴────────────┘
```

but as you can see from that table, the [Poseidon](https://eprint.iacr.org/2019/458.pdf) implementation in this PR is way faster! I'm not 100% sure about the number of rounds I chose. It may need to be a bit higher.
  • Loading branch information
imeckler committed Sep 19, 2019
1 parent 29a4841 commit 171be4a
Show file tree
Hide file tree
Showing 12 changed files with 571 additions and 72 deletions.
378 changes: 378 additions & 0 deletions src/lib/crypto_params/rescue_params.ml

Large diffs are not rendered by default.

15 changes: 0 additions & 15 deletions src/lib/rescue/inputs.ml

This file was deleted.

2 changes: 1 addition & 1 deletion src/lib/snark_params/dune
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
(public_name snark_params)
(library_flags -linkall)
(inline_tests)
(libraries rescue group_map fold_lib o1trace coda_digestif tuple_lib bitstring_lib
(libraries sponge group_map fold_lib o1trace coda_digestif tuple_lib bitstring_lib
snarky_group_map
snarky_bowe_gabizon_hash
core_kernel snarky snarky_verifier snarky_field_extensions snarky_curves
Expand Down
19 changes: 18 additions & 1 deletion src/lib/snark_params/snark_params.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ module Tick_backend = Crypto_params.Tick_backend
module Tock_backend = Crypto_params.Tock_backend
module Snarkette_tick = Crypto_params.Snarkette_tick
module Snarkette_tock = Crypto_params.Snarkette_tock
module Rescue = Rescue_inst
module Sponge = Sponge_inst

module Make_snarkable (Impl : Snarky.Snark_intf.S) = struct
open Impl
Expand Down Expand Up @@ -793,3 +793,20 @@ let pending_coinbase_depth =
let target_bit_length = Tick.Field.size_in_bits - 8

module type Snark_intf = Snark_intf.S

let%bench_fun "pedersen" =
let open Tick in
let x = Field.random () |> Field.unpack in
fun () ->
Pedersen.digest_fold (Pedersen.State.create ())
Fold_lib.Fold.(group3 ~default:false (of_list x))

let%bench_fun "rescue" =
let open Tick in
let x = Field.random () in
fun () -> Sponge.hash [|x|]

let%bench_fun "poseidon" =
let open Tick in
let x = Field.random () in
fun () -> Sponge.Poseidon.hash Sponge.params [|x|]
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
open Core

let params : _ Rescue.Params.t =
let params : _ Sponge.Params.t =
let open Crypto_params.Rescue_params in
{mds; round_constants}

Expand All @@ -9,10 +9,34 @@ module Inputs = struct

let to_the_alpha x =
let open Field in
let zero = square in
let one a = square a * x in
let one' = x in
one' |> zero |> one |> one
let res = x + zero in
res *= res ;
(* x^2 *)
res *= res ;
(* x^4 *)
res *= x ;
(* x^5 *)
res *= res ;
(* x^10 *)
res *= x ;
res

module Operations = struct
let apply_matrix rows v =
Array.map rows ~f:(fun row ->
let open Field in
let res = zero + zero in
Array.iteri row ~f:(fun i r -> res += (r * v.(i))) ;
res )

let add_block ~state block =
Array.iteri block ~f:(fun i b ->
let open Field in
state.(i) += b )

(* TODO: Have an explicit function for making a copy of a field element. *)
let copy a = Array.map a ~f:(fun x -> Field.(x + zero))
end

let alphath_root =
let inv_alpha =
Expand Down Expand Up @@ -50,8 +74,8 @@ module Inputs = struct
[%test_eq: Field.t] (to_the_alpha root) x
end

include Rescue.Make (Inputs)
module State = Rescue.State
include Sponge.Make (Sponge.Rescue (Inputs))
module State = Sponge.State

let update ~state = update ~state params

Expand All @@ -78,11 +102,15 @@ module Checked = struct
in
let y10 = y |> square |> square |> ( * ) y |> square in
assert_r1cs y10 y x ; y

let apply_matrix = None

module Operations = Sponge.Make_operations (Field)
end

include Rescue.Make (Inputs)
include Sponge.Make (Sponge.Rescue (Inputs))

let hash = hash (Rescue.Params.map ~f:Field.constant params)
let hash = hash (Sponge.Params.map ~f:Field.constant params)
end

let%test_unit "iterativeness" =
Expand All @@ -109,3 +137,5 @@ let%test_unit "rescue" =
)
(fun (x, y) -> hash [|x; y|])
(x, y)

module Poseidon = Sponge.Make (Sponge.Poseidon (Inputs))
4 changes: 2 additions & 2 deletions src/lib/rescue/dune → src/lib/sponge/dune
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
(library
(name rescue)
(public_name rescue)
(name sponge)
(public_name sponge)
(preprocess (pps ppx_jane ppx_deriving.eq))
(inline_tests)
(libraries
Expand Down
45 changes: 45 additions & 0 deletions src/lib/sponge/intf.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
module type Field = sig
type t

val zero : t

val ( * ) : t -> t -> t

val ( + ) : t -> t -> t
end

module type Operations = sig
module Field : Field

val add_block : state:Field.t array -> Field.t array -> unit

val apply_matrix : Field.t array array -> Field.t array -> Field.t array

val copy : Field.t array -> Field.t array
end

module Inputs = struct
module type Common = sig
module Field : Field

val to_the_alpha : Field.t -> Field.t

module Operations : Operations with module Field := Field
end

module type Rescue = sig
include Common

val alphath_root : Field.t -> Field.t
end
end

module type Permutation = sig
module Field : Field

val add_block : state:Field.t array -> Field.t array -> unit

val copy : Field.t array -> Field.t array

val block_cipher : Field.t Params.t -> Field.t array -> Field.t array
end
8 changes: 8 additions & 0 deletions src/lib/sponge/params.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
open Core_kernel

type 'a t = {mds: 'a array array; round_constants: 'a array array}
[@@deriving bin_io]

let map {mds; round_constants} ~f =
let f = Array.map ~f:(Array.map ~f) in
{mds= f mds; round_constants= f round_constants}
4 changes: 2 additions & 2 deletions src/lib/rescue/params.sage → src/lib/sponge/params.sage
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ def random_value(F, prefix, i):
return F(int(hashlib.sha256('%s%d' % (prefix, i)).hexdigest(), 16))

m = 3
rounds = 11
rounds = 50

prefix = 'CodaRescue'

def round_constants(F):
name = prefix + 'RoundConstants'
return [ [ random_value(F, name, r * m + i) for i in xrange(m) ]
for r in xrange(2 * rounds + 1) ]
for r in xrange( rounds ) ]

def matrix_str(rows):
return '[|' + ';'.join('[|' + ';'.join('Field.of_string "{}"'.format(str(x)) for x in row) + '|]' for row in rows) + '|]'
Expand Down
100 changes: 66 additions & 34 deletions src/lib/rescue/rescue.ml → src/lib/sponge/sponge.ml
Original file line number Diff line number Diff line change
@@ -1,6 +1,28 @@
open Core_kernel
module Params = Params
module State = Array

let for_ n ~init ~f =
let rec go i acc = if Int.(i = n) then acc else go (i + 1) (f i acc) in
go 0 init

(*
module Make_operations (Field : Intf.Field) = struct
let add_block ~state block =
Array.iteri block ~f:(fun i bi -> state.(i) <- Field.( + ) state.(i) bi)

let apply_matrix matrix v =
let dotv row =
Array.reduce_exn (Array.map2_exn v row ~f:Field.( * )) ~f:Field.( + )
in
Array.map matrix ~f:dotv

let copy = Array.copy
end

let m = 3

module Rescue (Inputs : Intf.Inputs.Rescue) = struct
(*
We refer below to this paper: https://eprint.iacr.org/2019/426.pdf.
I arrived at this value for the number of rounds in the following way.
Expand All @@ -24,51 +46,61 @@ NB: As you can see from the analysis this is really specialized to alpha = 11 an
should be higher for smaller alpha.
*)

let rounds = 11
let rounds = 11

let m = 3
open Inputs
include Operations
module Field = Field

module Params = struct
type 'a t = {mds: 'a array array; round_constants: 'a array array}
[@@deriving bin_io]
let sbox0, sbox1 = (alphath_root, to_the_alpha)

let map {mds; round_constants} ~f =
let f = Array.map ~f:(Array.map ~f) in
{mds= f mds; round_constants= f round_constants}
let block_cipher {Params.round_constants; mds} state =
add_block ~state round_constants.(0) ;
for_ (2 * rounds) ~init:state ~f:(fun r state ->
let sbox = if Int.(r mod 2 = 0) then sbox0 else sbox1 in
Array.map_inplace state ~f:sbox ;
let state = apply_matrix mds state in
add_block ~state round_constants.(r + 1) ;
state )
end

module State = Array

module Make (Inputs : Inputs.S) = struct
module Poseidon (Inputs : Intf.Inputs.Common) = struct
open Inputs
include Operations
module Field = Field

let add_block ~state block =
Array.iteri block ~f:(fun i bi -> state.(i) <- Field.( + ) state.(i) bi)
let rounds_full = 8

let sponge perm blocks ~state =
Array.fold ~init:state blocks ~f:(fun state block ->
add_block ~state block ; perm state )
let rounds_partial = 33

let sbox0, sbox1 = (alphath_root, to_the_alpha)
let half_rounds_full = rounds_full / 2

let for_ n ~init ~f =
let rec go i acc = if Int.(i = n) then acc else go (i + 1) (f i acc) in
go 0 init
let%test "rounds_full" = half_rounds_full * 2 = rounds_full

let apply matrix v =
let dotv row =
Array.reduce_exn (Array.map2_exn v row ~f:Field.( * )) ~f:Field.( + )
in
Array.map matrix ~f:dotv
let for_ n init ~f = for_ n ~init ~f

let block_cipher {Params.round_constants; mds} state =
add_block ~state round_constants.(0) ;
for_ (2 * rounds) ~init:state ~f:(fun r state ->
let sbox = if Int.(r mod 2 = 0) then sbox0 else sbox1 in
Array.map_inplace state ~f:sbox ;
let state = apply mds state in
add_block ~state round_constants.(r + 1) ;
state )
let sbox = to_the_alpha in
let full_half start =
for_ half_rounds_full ~f:(fun r state ->
add_block ~state round_constants.(start + r) ;
Array.map_inplace state ~f:sbox ;
apply_matrix mds state )
in
full_half 0 state
|> for_ rounds_partial ~f:(fun r state ->
add_block ~state round_constants.(half_rounds_full + r) ;
state.(0) <- sbox state.(0) ;
apply_matrix mds state )
|> full_half (half_rounds_full + rounds_partial)
end

module Make (P : Intf.Permutation) = struct
open P

let sponge perm blocks ~state =
Array.fold ~init:state blocks ~f:(fun state block ->
add_block ~state block ; perm state )

let to_blocks r a =
let n = Array.length a in
Expand All @@ -88,7 +120,7 @@ module Make (Inputs : Inputs.S) = struct
let r = m - 1

let update params ~state inputs =
let state = Array.copy state in
let state = copy state in
sponge (block_cipher params) (to_blocks r inputs) ~state

let digest state = state.(0)
Expand Down
20 changes: 12 additions & 8 deletions src/lib/rescue/rescue.mli → src/lib/sponge/sponge.mli
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
module Params : sig
type 'a t = {mds: 'a array array; round_constants: 'a array array}
[@@deriving bin_io]

val map : 'a t -> f:('a -> 'b) -> 'b t
end
module Params = Params

module State : sig
type 'a t = 'a array

val map : 'a t -> f:('a -> 'b) -> 'b t
end

module Make (Inputs : Inputs.S) : sig
open Inputs
module Rescue (Inputs : Intf.Inputs.Rescue) :
Intf.Permutation with module Field = Inputs.Field

module Poseidon (Inputs : Intf.Inputs.Common) :
Intf.Permutation with module Field = Inputs.Field

module Make_operations (Field : Intf.Field) :
Intf.Operations with module Field := Field

module Make (P : Intf.Permutation) : sig
open P

val update :
Field.t Params.t
Expand Down
File renamed without changes.

0 comments on commit 171be4a

Please sign in to comment.