Skip to content

Commit

Permalink
Merge pull request #380 from omerfirmak/elim-pedersen-alloc
Browse files Browse the repository at this point in the history
Precompute point multiplication results in pedersen
  • Loading branch information
yelhousni committed Apr 14, 2023
2 parents ebbf692 + 9c9d107 commit e500f2f
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 51 deletions.
119 changes: 71 additions & 48 deletions ecc/stark-curve/pedersen-hash/pedersen_hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ import (
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)

const nibbleCount = fp.Bits / 4

var (
shiftPoint starkcurve.G1Jac
p0 starkcurve.G1Jac
p1 starkcurve.G1Jac
p2 starkcurve.G1Jac
p3 starkcurve.G1Jac
shiftPoint starkcurve.G1Jac
pointIndexed [4][nibbleCount][16]*starkcurve.G1Jac
p [4]starkcurve.G1Jac
)

func init() {
Expand All @@ -24,66 +24,89 @@ func init() {
shiftPoint.Y.SetString("1713931329540660377023406109199410414810705867260802078187082345529207694986")
shiftPoint.Z.SetOne()

p0.X.SetString("996781205833008774514500082376783249102396023663454813447423147977397232763")
p0.Y.SetString("1668503676786377725805489344771023921079126552019160156920634619255970485781")
p0.Z.SetOne()

p1.X.SetString("2251563274489750535117886426533222435294046428347329203627021249169616184184")
p1.Y.SetString("1798716007562728905295480679789526322175868328062420237419143593021674992973")
p1.Z.SetOne()

p2.X.SetString("2138414695194151160943305727036575959195309218611738193261179310511854807447")
p2.Y.SetString("113410276730064486255102093846540133784865286929052426931474106396135072156")
p2.Z.SetOne()

p3.X.SetString("2379962749567351885752724891227938183011949129833673362440656643086021394946")
p3.Y.SetString("776496453633298175483985398648758586525933812536653089401905292063708816422")
p3.Z.SetOne()
p[0].X.SetString("996781205833008774514500082376783249102396023663454813447423147977397232763")
p[0].Y.SetString("1668503676786377725805489344771023921079126552019160156920634619255970485781")
p[0].Z.SetOne()

p[1].X.SetString("2251563274489750535117886426533222435294046428347329203627021249169616184184")
p[1].Y.SetString("1798716007562728905295480679789526322175868328062420237419143593021674992973")
p[1].Z.SetOne()

p[2].X.SetString("2138414695194151160943305727036575959195309218611738193261179310511854807447")
p[2].Y.SetString("113410276730064486255102093846540133784865286929052426931474106396135072156")
p[2].Z.SetOne()

p[3].X.SetString("2379962749567351885752724891227938183011949129833673362440656643086021394946")
p[3].Y.SetString("776496453633298175483985398648758586525933812536653089401905292063708816422")
p[3].Z.SetOne()

var multiplier big.Int
for pointIndex, point := range p {
var nibbleIndexed [nibbleCount][16]*starkcurve.G1Jac
for nibIndex := uint(0); nibIndex < nibbleCount; nibIndex++ {
var selectorIndexed [16]*starkcurve.G1Jac
for selector := 0; selector < 16; selector++ {
multiplier.SetUint64(uint64(selector))
multiplier.Lsh(&multiplier, nibIndex*4)

res := point
res.ScalarMultiplication(&res, &multiplier)
selectorIndexed[selector] = &res
}
nibbleIndexed[nibIndex] = selectorIndexed
}
pointIndexed[pointIndex] = nibbleIndexed
}
}

// PedersenArray implements [Pedersen array hashing].
//
// [Pedersen array hashing]: https://docs.starknet.io/documentation/develop/Hashing/hash-functions/#array_hashing
func PedersenArray(elems ...*fp.Element) *fp.Element {
d := new(fp.Element)
func PedersenArray(elems ...*fp.Element) fp.Element {
var d fp.Element
for _, e := range elems {
d = Pedersen(d, e)
d = Pedersen(&d, e)
}
return Pedersen(d, new(fp.Element).SetUint64(uint64(len(elems))))
return Pedersen(&d, new(fp.Element).SetUint64(uint64(len(elems))))
}

// Pedersen implements the [Pedersen hash] based on the [reference implementation].
//
// [Pedersen hash]: https://docs.starknet.io/documentation/develop/Hashing/hash-functions/#pedersen_hash
// [reference implementation]: https://github.com/starkware-libs/cairo-lang/blob/de741b92657f245a50caab99cfaef093152fd8be/src/starkware/crypto/signature/fast_pedersen_hash.py
func Pedersen(a *fp.Element, b *fp.Element) *fp.Element {

result := new(starkcurve.G1Jac).Set(&shiftPoint)
func Pedersen(a *fp.Element, b *fp.Element) fp.Element {
acc := shiftPoint
accumulate := func(bytes []byte, nibbleIndexed [nibbleCount][16]*starkcurve.G1Jac) {
for i, val := range bytes {
lowNibble := val & 0x0F
index := len(bytes) - i - 1

if lowNibble > 0 {
lowNibbleIndex := 2 * index
acc.AddAssign(nibbleIndexed[lowNibbleIndex][lowNibble])
}

highNibble := (val & 0xF0) >> 4

if highNibble > 0 {
highNibbleIndex := (2 * index) + 1
acc.AddAssign(nibbleIndexed[highNibbleIndex][highNibble])
}
}
}

var point starkcurve.G1Jac
result.AddAssign(processElement(a, &p0, &p1, &point))
result.AddAssign(processElement(b, &p2, &p3, &point))
aBytes := a.Bytes()
accumulate(aBytes[1:], pointIndexed[0])
accumulate(aBytes[:1], pointIndexed[1])
bBytes := b.Bytes()
accumulate(bBytes[1:], pointIndexed[2])
accumulate(bBytes[:1], pointIndexed[3])

// recover the affine x coordinate
var x fp.Element
x.Inverse(&result.Z).
x.Inverse(&acc.Z).
Square(&x)
x.Mul(&result.X, &x)

return &x
}

func processElement(a *fp.Element, p1 *starkcurve.G1Jac, p2 *starkcurve.G1Jac, res *starkcurve.G1Jac) *starkcurve.G1Jac {
var bigInt big.Int
var aBytes [32]byte
a.BigInt(&bigInt).FillBytes(aBytes[:])

highPart := bigInt.SetUint64(uint64(aBytes[0])) // The top nibble (bits 249-252)
lowPart := aBytes[1:] // Zero-out the top nibble (bits 249-252)

res.ScalarMultiplication(p2, highPart)
x.Mul(&acc.X, &x)

var n starkcurve.G1Jac
n.ScalarMultiplication(p1, bigInt.SetBytes(lowPart))
return res.AddAssign(&n)
return x
}
6 changes: 3 additions & 3 deletions ecc/stark-curve/pedersen-hash/pedersen_hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func TestPedersenArray(t *testing.T) {
}
}

var feltBench *fp.Element
var feltBench fp.Element

// go test -bench=. -run=^# -cpu=1,2,4,8,16
func BenchmarkPedersenArray(b *testing.B) {
Expand All @@ -133,7 +133,7 @@ func BenchmarkPedersenArray(b *testing.B) {

for _, i := range numOfElems {
b.Run(fmt.Sprintf("Number of felts: %d", i), func(b *testing.B) {
var f *fp.Element
var f fp.Element
randomFelts := createRandomFelts(i)
for n := 0; n < b.N; n++ {
f = PedersenArray(randomFelts...)
Expand All @@ -154,7 +154,7 @@ func BenchmarkPedersen(b *testing.B) {
b.Errorf("Error occured %s", err)
}

var f *fp.Element
var f fp.Element
for n := 0; n < b.N; n++ {
f = Pedersen(e0, e1)
}
Expand Down

0 comments on commit e500f2f

Please sign in to comment.