Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merging develop into master #7

Merged
merged 9 commits into from
Mar 2, 2020
112 changes: 36 additions & 76 deletions cs/fft/fft.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package fft

import (
"math/bits"
"runtime"
"sync"

"github.com/consensys/gnark/cs/internal/curve"
Expand All @@ -28,120 +29,79 @@ import (
// The result is in bit-reversed order.
// len(a) must be a power of 2, and w must be a len(a)th root of unity in field F.
// The algorithm is recursive, decimation-in-frequency. [cite]
func FFT(a []curve.Element, w curve.Element) {
func FFT(a []curve.Element, w curve.Element, numCPU ...uint) {
var wg sync.WaitGroup
asyncFFT(a, w, &wg)
asyncFFT(a, w, &wg, 1)
wg.Wait()
bitReverse(a)
}

// Coset Evaluation on ker(X^n+1)
func Coset(a []curve.Element, w curve.Element, wSqrt curve.Element) {
wSqrtCopy := wSqrt
for i := 1; i < len(a); i++ {
a[i].MulAssign(&wSqrtCopy)
wSqrtCopy.MulAssign(&wSqrt)
}

FFT(a, w)
BitReverse(a)
}

// InvCoset Get back polynomial from its values on ker X^n+1
func InvCoset(a []curve.Element, w curve.Element, wSqrt curve.Element) {

var wInv, wSqrtInv curve.Element
wInv.Inverse(&w)
wSqrtInv.Inverse(&wSqrt)
wsqrtInvCpy := wSqrtInv

Inv(a, wInv)
BitReverse(a)

for i := 1; i < len(a); i++ {
a[i].MulAssign(&wSqrtInv)
wSqrtInv.MulAssign(&wsqrtInvCpy)
}
}

func asyncFFT(a []curve.Element, w curve.Element, wg *sync.WaitGroup) {
func asyncFFT(a []curve.Element, w curve.Element, wg *sync.WaitGroup, splits uint) {
n := len(a)
if n == 1 {
return
}
m := n / 2
m := n >> 1

// wPow == w^1
wPow := w

// i == 0
tmp := a[0]
t := a[0]
a[0].AddAssign(&a[m])
a[m].Sub(&tmp, &a[m])
a[m].Sub(&t, &a[m])

for i := 1; i < m; i++ {
tmp = a[i]
t = a[i]
a[i].AddAssign(&a[i+m])

a[i+m].
Sub(&tmp, &a[i+m]).
Sub(&t, &a[i+m]).
MulAssign(&wPow)

wPow.MulAssign(&w)
}

// if m == 1, then next iteration ends, no need to call 2 extra functions for that
if m == 1 {
return
}

// note: w is passed by value
w.Square(&w)

if m < 20 {
asyncFFT(a[0:m], w, nil)
asyncFFT(a[m:n], w, nil)
const parallelThreshold = 64
serial := splits > uint(runtime.NumCPU()) || m <= parallelThreshold

if serial {
asyncFFT(a[0:m], w, nil, splits)
asyncFFT(a[m:n], w, nil, splits)
} else {
wg.Add(2)
pool.Push(func() {
asyncFFT(a[0:m], w, wg)
wg.Done()
}, true)
splits <<= 1
wg.Add(1)
pool.Push(func() {
asyncFFT(a[m:n], w, wg)
asyncFFT(a[m:n], w, wg, splits)
wg.Done()
}, true)
// TODO fixme that seems risky behavior and could starve the thread pool
// we may want to push that as a taks in the pool too?.
asyncFFT(a[0:m], w, wg, splits)
}
}

// Inv computes the inverse discrete Fourier transform of a and stores the result in a.
// See FFT for more info.
func Inv(a []curve.Element, wInv curve.Element) {
var wg sync.WaitGroup
asyncFFT(a, wInv, &wg)
wg.Wait()

// scale by inverse of n
var nInv curve.Element
nInv.SetUint64(uint64(len(a)))
nInv.Inverse(&nInv)

for i := 0; i < len(a); i++ {
a[i].MulAssign(&nInv)
}
}

// BitReverse applies the bit-reversal permutation to a.
// bitReverse applies the bit-reversal permutation to a.
// len(a) must be a power of 2 (as in every single function in this file)
func BitReverse(a []curve.Element) {
l := uint(len(a))
n := uint(bits.UintSize - bits.TrailingZeros(l))
func bitReverse(a []curve.Element) {
n := uint(len(a))
nn := uint(bits.UintSize - bits.TrailingZeros(n))

var tmp curve.Element
for i := uint(0); i < l; i++ {
irev := bits.Reverse(i) >> n
var tReverse curve.Element
for i := uint(0); i < n; i++ {
irev := bits.Reverse(i) >> nn
if irev > i {
tmp = a[i]
tReverse = a[i]
a[i] = a[irev]
a[irev] = tmp
a[irev] = tReverse
}
}
}

func reverse(x, n int) int {
return int(bits.Reverse(uint(x)) >> (bits.UintSize - uint(n)))
}
66 changes: 1 addition & 65 deletions cs/fft/fft_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,76 +49,12 @@ func TestFFT(t *testing.T) {
fftExpected[2].SetString("8444461749428370424248824938781546531375899335154063827935233455917408882477")
fftExpected[3].SetString("8444461749428193732362745809358310391547622204027831664851124434067521319365")
FFT(poly, w)
BitReverse(poly)

for i := 0; i < 4; i++ {
if !poly[i].Equal(&fftExpected[i]) {
t.Fatal("Error fft")
}
}

Inv(fftExpected, winv)
BitReverse(fftExpected)
for i := 0; i < 4; i++ {
if !polyCpy[i].Equal(&fftExpected[i]) {
t.Fatal("Error inv fft")
}
}

}

func TestFFTCoset(t *testing.T) {

var wsqrt, w Element
// primitive 8-th root of 1
wsqrt.SetString("3279917132858342911831074864712036382710139745724269329239664300762234227201")

// primitive 4-th root of 1
w.SetString("880904806456922042258150504921383618666682042621506879489")

poly := make([]Element, 4)
poly[0].SetString("1223")
poly[1].SetString("9283")
poly[2].SetString("2323")
poly[3].SetString("29832")
polyCpy := make([]Element, 4)
copy(polyCpy[:], poly[:])

polyCoset := make([]Element, 4)
polyCoset[0].SetString("6744231264996566884193988396561893970787357999391009292610442572606065589798")
polyCoset[1].SetString("117515726529979382411741906321656162865657092943595752906312217939318191217")
polyCoset[2].SetString("1700230484431807632738567341079460891955787200511346860729560902832305757583")
polyCoset[3].SetString("8326946022898386949153352233600082037142996377462175749624151218457128944376")

Coset(poly, w, wsqrt)

for i := 0; i < 4; i++ {
if !poly[i].Equal(&polyCoset[i]) {
t.Fatal("Error FFT coset")
}
}

InvCoset(polyCoset, w, wsqrt)

for i := 0; i < 4; i++ {
if !polyCoset[i].Equal(&polyCpy[i]) {
t.Fatal("Error Inv FFT coset")
}
}
}

func TestReverse(t *testing.T) {

got := [8]int{0, 1, 2, 3, 4, 5, 6, 7}
want := [8]int{0, 4, 2, 6, 1, 5, 3, 7}

for i := range got {
got[i] = reverse(got[i], 3)
}

if got != want {
t.Error("expected:", want, "received:", got)
}
}

func TestBitReverse(t *testing.T) {
Expand All @@ -133,7 +69,7 @@ func TestBitReverse(t *testing.T) {
got[6].SetUint64(7)
got[7].SetUint64(8)

BitReverse(got[:])
bitReverse(got[:])

var want [8]Element // not in Mongomery form
want[0].SetUint64(1)
Expand Down
88 changes: 88 additions & 0 deletions cs/groth16/computeh_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// +build bls377 !bn256,!bls381

package groth16

import (
"testing"

"github.com/consensys/gnark/cs/internal/curve"
)

func TestComputeH(t *testing.T) {
const n = 10
A := make([]curve.Element, n)
B := make([]curve.Element, n)
C := make([]curve.Element, n)
expectedH := make([]curve.Element, 16)
A[0].SetString("3978970733340151541160460328870183152087563023891174578722607319103291602338")
B[0].SetString("5519280666570261192572043782616932099550224742860556770901989646757503432308")
C[0].SetString("6924800278067507252485509571380959703366298839796086977356452915404508079610")
A[1].SetString("7488318508387618945409062109743609505795627605956394772982737920747334223949")
B[1].SetString("8344843539136514282835507380553483054180871257261777237158217935405423515807")
C[1].SetString("3559085210450098678950834866439716591786887598440943937813201271498457354927")
A[2].SetString("2420147635162995286688788630254374502635517119709704618441221192289805138677")
B[2].SetString("2093851841774026535211048833225589090274546321368120036475536362766848069838")
C[2].SetString("8242202799862391990981745519417743168238212536863449080586424409246524448574")
A[3].SetString("6227086318545100046816294321804286844155281348233488397533092282104036301817")
B[3].SetString("7548613128257608189511512397331405541277652678737187577586327641991194308266")
C[3].SetString("4712036668525324077096858567740257232563542247254604171851057954830752600814")
A[4].SetString("1767010469588943385772232782098207436123315739970815871255308076402460522716")
B[4].SetString("5268185495508136628608211189303431587883163031828168456149730213611887667026")
C[4].SetString("6916385412012417303128337989972639505148468126118315980201834164352691671944")
A[5].SetString("7012618530441313385856791555752908929193146678873667252815093157281979030668")
B[5].SetString("4613857718522236696579576340633930511200370490657902545003758129208235025491")
C[5].SetString("5710843409139585140307434852998554460427461639665464769969561314658262741132")
A[6].SetString("5335789919045347874047089146053320712611529972859803732422946582463661809041")
B[6].SetString("3427309601388481707169300492532148271131969221969230339157463986353574082441")
C[6].SetString("4063889622970899690944300311723166625169306585034435683821734007120798971335")
A[7].SetString("2783394209260384458035044348517864917631460882587334312970627663656034678301")
B[7].SetString("2343136976545616164994616713640929188970757024547416192265132760761852966246")
C[7].SetString("483081301362591834919594860128879516660765040402111809154071487826063748557")
A[8].SetString("4156793256373737655400392381211896434026931710499794180622414734063356840123")
B[8].SetString("7432467422285966377084439669797524986662665362270033039778175315423267847940")
C[8].SetString("2874703728404464659469597888806870849452756237706576809296293810563730519986")
A[9].SetString("8199927226660795908332543744006148752900427840674135842753812812653758490806")
B[9].SetString("7066001929692942188517204166766548971985451139710210313151579366973338638446")
C[9].SetString("8321403640356687097026315142151258279240079298980314559771806416409465014477")
expectedH[0].SetString("7782398469555828294647752617114515444715099999479061815771163450140531726624")
expectedH[1].SetString("4701689288062574458155554348017820064085128967868041479068223690213808159153")
expectedH[2].SetString("5773240734661054821105058335364507879215455884736431041075930770153196827417")
expectedH[3].SetString("4430008663075340055570681260271475508744453011644659436594186985777875239909")
expectedH[4].SetString("5293159083690165590347068946093606540443898801048204205045300446831727744815")
expectedH[5].SetString("272177142499885579493716010452276371584654739076175567328412592813313406252")
expectedH[6].SetString("5995984606756942227830675909734452947413901427664788436135418074562426104275")
expectedH[7].SetString("6808098446036842299283970629392738739278889336880949126600163314411314787657")
expectedH[8].SetString("48857087950413550241731164205916036550846184126544125671045818548092037649")
expectedH[9].SetString("8229757359290435701536684615691652512640106506069606636148399560514233124480")
expectedH[10].SetString("7486536221068617665728847675314458702282316190531463791154512620664800577830")
expectedH[11].SetString("7799869307295829397869081481914591556840218077620501375141006871652243359988")
expectedH[12].SetString("5572487871835836445627620749639442731687958077058639237703284282228178065342")
expectedH[13].SetString("3031952431061688597913922713543624406818694349737393722449744612497314671271")
expectedH[14].SetString("5823956675647904867599193233987686189497459491984483713315208960253899989011")
expectedH[15].SetString("3920524502188845982638454764913867261210845354831386460279061209446868519983")
h := <-computeH(A, B, C, n+3)
for i := 0; i < len(h); i++ {
if !h[i].Equal(&expectedH[i]) {
t.Fatal("incorrect result")
}
}

}

func BenchmarkComputeH(b *testing.B) {
const n = 1000000
A := make([]curve.Element, n)
B := make([]curve.Element, n)
C := make([]curve.Element, n)
for i := 0; i < n; i++ {
A[i].SetRandom()
B[i].SetRandom()
C[i].SetRandom()
}

b.ResetTimer()
for i := 0; i < b.N; i++ {
<-computeH(A, B, C, n+3)
}

}
Loading