Skip to content

Commit

Permalink
Cairo0 hint uint256 mul div mod (#314)
Browse files Browse the repository at this point in the history
* Implement Uint256MulDivMod method, add test case

* Implement test case, debug hint

* Debug hint method

* Added new test cases, refactored code

* Merge main and refactor code

* Lint code

* Added test case for high quotient for MulDivMod

* Fixed comment issue
  • Loading branch information
MaksymMalicki committed Mar 20, 2024
1 parent 9ddf3a6 commit 41bd195
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 1 deletion.
1 change: 0 additions & 1 deletion pkg/hintrunner/hinter/operand.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,6 @@ func GetConsecutiveValues(vm *VM.VirtualMachine, ref ResOperander, size int16) (
}

func WriteToNthStructField(vm *VM.VirtualMachine, addr mem.MemoryAddress, value mem.MemoryValue, field int16) error {

nAddr, err := addr.AddOffset(field)
if err != nil {
return err
Expand Down
2 changes: 2 additions & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ const (
uint256SignedNNCode string = "memory[ap] = 1 if 0 <= (ids.a.high % PRIME) < 2 ** 127 else 0"
uint256UnsignedDivRemCode string = "a = (ids.a.high << 128) + ids.a.low\ndiv = (ids.div.high << 128) + ids.div.low\nquotient, remainder = divmod(a, div)\nids.quotient.low = quotient & ((1 << 128) - 1)\nids.quotient.high = quotient >> 128\nids.remainder.low = remainder & ((1 << 128) - 1)\nids.remainder.high = remainder >> 128"
uint256SqrtCode string = "from starkware.python.math_utils import isqrt\nn = (ids.n.high << 128) + ids.n.low\nroot = isqrt(n)\nassert 0 <= root < 2 ** 128\nids.root.low = root\nids.root.high = 0"
uint256MulDivModCode string = "a = (ids.a.high << 128) + ids.a.low/n b = (ids.b.high << 128) + ids.b.low/n div = (ids.div.high << 128) + ids.div.low/n quotient, remainder = divmod(a * b, div)/n ids.quotient_low.low = quotient & ((1 << 128) - 1)/n ids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1)/n ids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1)/n ids.quotient_high.high = quotient >> 384/n ids.remainder.low = remainder & ((1 << 128) - 1)/n ids.remainder.high = remainder >> 128"

// ------ Usort hints related code ------

// ------ Elliptic Curve hints related code ------
Expand Down
2 changes: 2 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createUint256UnsignedDivRemHinter(resolver)
case uint256SqrtCode:
return createUint256SqrtHinter(resolver)
case uint256MulDivModCode:
return createUint256MulDivModHinter(resolver)
case sqrtCode:
return createSqrtHinter(resolver)
default:
Expand Down
106 changes: 106 additions & 0 deletions pkg/hintrunner/zero/zerohint_uint256.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,109 @@ func createUint256UnsignedDivRemHinter(resolver hintReferenceResolver) (hinter.H
}
return newUint256UnsignedDivRemHint(a, div, quotient, remainder), nil
}

func newUint256MulDivModHint(a, b, div, quotientLow, quotientHigh, remainder hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "Uint256MulDivMod",
Op: func(vm *VM.VirtualMachine, _ *hinter.HintRunnerContext) error {

//> a = (ids.a.high << 128) + ids.a.low
// b = (ids.b.high << 128) + ids.b.low
// div = (ids.div.high << 128) + ids.div.low
// quotient, remainder = divmod(a * b, div)

// ids.quotient_low.low = quotient & ((1 << 128) - 1)
// ids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1)
// ids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1)
// ids.quotient_high.high = quotient >> 384
// ids.remainder.low = remainder & ((1 << 128) - 1)
// ids.remainder.high = remainder >> 128
aLow, aHigh, err := GetUint256AsFelts(vm, a)
if err != nil {
return err
}
var aLowBig big.Int
aLow.BigInt(&aLowBig)
var aHighBig big.Int
aHigh.BigInt(&aHighBig)
bLow, bHigh, err := GetUint256AsFelts(vm, b)
if err != nil {
return err
}
var bLowBig big.Int
bLow.BigInt(&bLowBig)
var bHighBig big.Int
bHigh.BigInt(&bHighBig)
divLow, divHigh, err := GetUint256AsFelts(vm, div)
if err != nil {
return err
}
var divLowBig big.Int
divLow.BigInt(&divLowBig)
var divHighBig big.Int
divHigh.BigInt(&divHighBig)
a := new(big.Int).Add(new(big.Int).Lsh(&aHighBig, 128), &aLowBig)
b := new(big.Int).Add(new(big.Int).Lsh(&bHighBig, 128), &bLowBig)
div := new(big.Int).Add(new(big.Int).Lsh(&divHighBig, 128), &divLowBig)
quot := new(big.Int).Div(new(big.Int).Mul(a, b), div)
rem := new(big.Int).Mod(new(big.Int).Mul(a, b), div)
mask := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 128), big.NewInt(1))
lowQuotLow := new(fp.Element).SetBigInt(new(big.Int).And(quot, mask))
lowQuotHigh := new(fp.Element).SetBigInt(new(big.Int).And(new(big.Int).Rsh(quot, 128), mask))
highQuotLow := new(fp.Element).SetBigInt(new(big.Int).And(new(big.Int).Rsh(quot, 256), mask))
highQuotHigh := new(fp.Element).SetBigInt(new(big.Int).Rsh(quot, 384))
lowRem := new(fp.Element).SetBigInt(new(big.Int).And(rem, mask))
highRem := new(fp.Element).SetBigInt(new(big.Int).Rsh(rem, 128))
quotientLowAddr, err := quotientLow.GetAddress(vm)
if err != nil {
return err
}
err = hinter.WriteUint256ToAddress(vm, quotientLowAddr, lowQuotLow, lowQuotHigh)
if err != nil {
return err
}
quotientHighAddr, err := quotientHigh.GetAddress(vm)
if err != nil {
return err
}
err = hinter.WriteUint256ToAddress(vm, quotientHighAddr, highQuotLow, highQuotHigh)
if err != nil {
return err
}
remainderAddr, err := remainder.GetAddress(vm)
if err != nil {
return err
}
return hinter.WriteUint256ToAddress(vm, remainderAddr, lowRem, highRem)
},
}

}

func createUint256MulDivModHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
a, err := resolver.GetResOperander("a")
if err != nil {
return nil, err
}
b, err := resolver.GetResOperander("b")
if err != nil {
return nil, err
}
div, err := resolver.GetResOperander("div")
if err != nil {
return nil, err
}
quotientLow, err := resolver.GetResOperander("quotient_low")
if err != nil {
return nil, err
}
quotientHigh, err := resolver.GetResOperander("quotient_high")
if err != nil {
return nil, err
}
remainder, err := resolver.GetResOperander("remainder")
if err != nil {
return nil, err
}
return newUint256MulDivModHint(a, b, div, quotientLow, quotientHigh, remainder), nil
}
84 changes: 84 additions & 0 deletions pkg/hintrunner/zero/zerohint_uint256_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package zero

import (
"math/big"
"testing"

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
Expand Down Expand Up @@ -276,5 +277,88 @@ func TestZeroHintUint256(t *testing.T) {
}),
},
},
"Uint256MulDivMod": {
{
operanders: []*hintOperander{
{Name: "a.low", Kind: apRelative, Value: feltUint64(6)},
{Name: "a.high", Kind: apRelative, Value: feltUint64(0)},
{Name: "b.low", Kind: apRelative, Value: feltUint64(6)},
{Name: "b.high", Kind: apRelative, Value: feltUint64(0)},
{Name: "div.low", Kind: apRelative, Value: feltUint64(2)},
{Name: "div.high", Kind: apRelative, Value: feltUint64(0)},
{Name: "quotient_low.low", Kind: uninitialized},
{Name: "quotient_low.high", Kind: uninitialized},
{Name: "quotient_high.low", Kind: uninitialized},
{Name: "quotient_high.high", Kind: uninitialized},
{Name: "remainder.low", Kind: uninitialized},
{Name: "remainder.high", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newUint256MulDivModHint(ctx.operanders["a.low"], ctx.operanders["b.low"], ctx.operanders["div.low"], ctx.operanders["quotient_low.low"], ctx.operanders["quotient_high.low"], ctx.operanders["remainder.low"])
},
check: allVarValueEquals(map[string]*fp.Element{
"quotient_low.low": feltUint64(18),
"quotient_low.high": feltUint64(0),
"quotient_high.low": feltUint64(0),
"quotient_high.high": feltUint64(0),
"remainder.low": feltUint64(0),
"remainder.high": feltUint64(0),
}),
},
{
operanders: []*hintOperander{
{Name: "a.low", Kind: apRelative, Value: &utils.FeltZero},
{Name: "a.high", Kind: apRelative, Value: feltString("2")},
{Name: "b.low", Kind: apRelative, Value: &utils.FeltZero},
{Name: "b.high", Kind: apRelative, Value: feltString("3")},
{Name: "div.low", Kind: apRelative, Value: &utils.FeltZero},
{Name: "div.high", Kind: apRelative, Value: feltString("2")},
{Name: "quotient_low.low", Kind: uninitialized},
{Name: "quotient_low.high", Kind: uninitialized},
{Name: "quotient_high.low", Kind: uninitialized},
{Name: "quotient_high.high", Kind: uninitialized},
{Name: "remainder.low", Kind: uninitialized},
{Name: "remainder.high", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newUint256MulDivModHint(ctx.operanders["a.low"], ctx.operanders["b.low"], ctx.operanders["div.low"], ctx.operanders["quotient_low.low"], ctx.operanders["quotient_high.low"], ctx.operanders["remainder.low"])
},
check: allVarValueEquals(map[string]*fp.Element{
"quotient_low.low": &utils.FeltZero,
"quotient_low.high": feltUint64(3),
"quotient_high.low": &utils.FeltZero,
"quotient_high.high": &utils.FeltZero,
"remainder.low": &utils.FeltZero,
"remainder.high": &utils.FeltZero,
}),
},
{
operanders: []*hintOperander{
{Name: "a.low", Kind: apRelative, Value: &utils.FeltZero},
{Name: "a.high", Kind: apRelative, Value: new(fp.Element).SetBigInt(new(big.Int).Lsh(big.NewInt(1), 127))},
{Name: "b.low", Kind: apRelative, Value: &utils.FeltZero},
{Name: "b.high", Kind: apRelative, Value: new(fp.Element).SetBigInt(new(big.Int).Lsh(big.NewInt(1), 127))},
{Name: "div.low", Kind: apRelative, Value: new(fp.Element).SetBigInt(new(big.Int).Lsh(big.NewInt(1), 126))},
{Name: "div.high", Kind: apRelative, Value: &utils.FeltZero},
{Name: "quotient_low.low", Kind: uninitialized},
{Name: "quotient_low.high", Kind: uninitialized},
{Name: "quotient_high.low", Kind: uninitialized},
{Name: "quotient_high.high", Kind: uninitialized},
{Name: "remainder.low", Kind: uninitialized},
{Name: "remainder.high", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newUint256MulDivModHint(ctx.operanders["a.low"], ctx.operanders["b.low"], ctx.operanders["div.low"], ctx.operanders["quotient_low.low"], ctx.operanders["quotient_high.low"], ctx.operanders["remainder.low"])
},
check: allVarValueEquals(map[string]*fp.Element{
"quotient_low.low": &utils.FeltZero,
"quotient_low.high": &utils.FeltZero,
"quotient_high.low": &utils.FeltZero,
"quotient_high.high": feltInt64(1),
"remainder.low": &utils.FeltZero,
"remainder.high": &utils.FeltZero,
}),
},
},
})
}

0 comments on commit 41bd195

Please sign in to comment.