Skip to content

Commit

Permalink
perf(2-chain): handle edge cases in varScalarMul
Browse files Browse the repository at this point in the history
  • Loading branch information
yelhousni committed Mar 15, 2024
1 parent 9bc2788 commit 0457871
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 13 deletions.
12 changes: 10 additions & 2 deletions std/algebra/native/sw_bls12377/g1.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variabl
// the hints allow to decompose the scalar s into s1 and s2 such that
// s1 + λ * s2 == s mod r,
// where λ is third root of one in 𝔽_r.
sd, err := api.Compiler().NewHint(decomposeScalarG1, 2, s)
sd, err := api.Compiler().NewHint(decomposeScalarG1Simple, 2, s)
if err != nil {
// err is non-nil only for invalid number of inputs
panic(err)
Expand Down Expand Up @@ -304,7 +304,15 @@ func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variabl
// subtract [2^nbits]G since we added G at the beginning
B.X = points.G1m[nbits-1][0]
B.Y = api.Neg(points.G1m[nbits-1][1])
Acc.AddAssign(api, B)
if cfg.CompleteArithmetic {
Acc.AddUnified(api, B)
} else {
Acc.AddAssign(api, B)
}

if cfg.CompleteArithmetic {
Acc.Select(api, selector, G1Affine{X: 0, Y: 0}, Acc)
}

P.X = Acc.X
P.Y = Acc.Y
Expand Down
56 changes: 45 additions & 11 deletions std/algebra/native/sw_bls12377/hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
func GetHints() []solver.Hint {
return []solver.Hint{
decomposeScalarG1,
decomposeScalarG1Simple,
decomposeScalarG2,
}
}
Expand All @@ -19,7 +20,7 @@ func init() {
solver.RegisterHint(GetHints()...)
}

func decomposeScalarG1(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error {
func decomposeScalarG1Simple(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error {
if len(inputs) != 1 {
return fmt.Errorf("expecting one input")
}
Expand All @@ -34,23 +35,56 @@ func decomposeScalarG1(scalarField *big.Int, inputs []*big.Int, outputs []*big.I
return nil
}

func decomposeScalarG2(scalarField *big.Int, inputs []*big.Int, res []*big.Int) error {
func decomposeScalarG1(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error {
if len(inputs) != 1 {
return fmt.Errorf("expecting one input")
}
if len(outputs) != 3 {
return fmt.Errorf("expecting three outputs")
}
cc := getInnerCurveConfig(scalarField)
sp := ecc.SplitScalar(inputs[0], cc.glvBasis)
res[0].Set(&(sp[0]))
res[1].Set(&(sp[1]))
outputs[0].Set(&(sp[0]))
outputs[1].Set(&(sp[1]))
one := big.NewInt(1)
// add (lambda+1, lambda) until scalar compostion is over Fr to ensure that
// the high bits are set in decomposition.
for outputs[0].Cmp(cc.lambda) < 1 && outputs[1].Cmp(cc.lambda) < 1 {
outputs[0].Add(outputs[0], cc.lambda)
outputs[0].Add(outputs[0], one)
outputs[1].Add(outputs[1], cc.lambda)
}
// figure out how many times we have overflowed
outputs[2].Mul(outputs[1], cc.lambda).Add(outputs[2], outputs[0])
outputs[2].Sub(outputs[2], inputs[0])
outputs[2].Div(outputs[2], cc.fr)

return nil
}

func decomposeScalarG2(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error {
if len(inputs) != 1 {
return fmt.Errorf("expecting one input")
}
if len(outputs) != 3 {
return fmt.Errorf("expecting three outputs")
}
cc := getInnerCurveConfig(scalarField)
sp := ecc.SplitScalar(inputs[0], cc.glvBasis)
outputs[0].Set(&(sp[0]))
outputs[1].Set(&(sp[1]))
one := big.NewInt(1)
// add (lambda+1, lambda) until scalar compostion is over Fr to ensure that
// the high bits are set in decomposition.
for res[0].Cmp(cc.lambda) < 1 && res[1].Cmp(cc.lambda) < 1 {
res[0].Add(res[0], cc.lambda)
res[0].Add(res[0], one)
res[1].Add(res[1], cc.lambda)
for outputs[0].Cmp(cc.lambda) < 1 && outputs[1].Cmp(cc.lambda) < 1 {
outputs[0].Add(outputs[0], cc.lambda)
outputs[0].Add(outputs[0], one)
outputs[1].Add(outputs[1], cc.lambda)
}
// figure out how many times we have overflowed
res[2].Mul(res[1], cc.lambda).Add(res[2], res[0])
res[2].Sub(res[2], inputs[0])
res[2].Div(res[2], cc.fr)
outputs[2].Mul(outputs[1], cc.lambda).Add(outputs[2], outputs[0])
outputs[2].Sub(outputs[2], inputs[0])
outputs[2].Div(outputs[2], cc.fr)

return nil
}

0 comments on commit 0457871

Please sign in to comment.