Skip to content

Commit

Permalink
feat: allow exporting memory file and trace file in non proof mode (#476
Browse files Browse the repository at this point in the history
)

* allow exporting memory file and trace file in non proof mode

* fix

* fmt

* fix2

* add comment to RelocateTrace public function

* add collectMemory and collectTrace flags

* rename collectMemory to BuildMemory

* fmt

* fmt

---------

Co-authored-by: Shourya Goel <shouryagoel10000@gmail.com>
  • Loading branch information
TAdev0 and Sh0g0-1758 committed Jul 11, 2024
1 parent c740c91 commit 08d12e6
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 33 deletions.
34 changes: 32 additions & 2 deletions cmd/cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (

func main() {
var proofmode bool
var buildMemory bool
var collectTrace bool
var maxsteps uint64
var entrypointOffset uint64
var traceLocation string
Expand Down Expand Up @@ -49,12 +51,24 @@ func main() {
Value: 0,
Destination: &entrypointOffset,
},
&cli.BoolFlag{
Name: "collect_trace",
Usage: "collects the trace and builds the relocated trace after execution",
Required: false,
Destination: &collectTrace,
},
&cli.StringFlag{
Name: "tracefile",
Usage: "location to store the relocated trace",
Required: false,
Destination: &traceLocation,
},
&cli.BoolFlag{
Name: "build_memory",
Usage: "builds the relocated memory after execution",
Required: false,
Destination: &buildMemory,
},
&cli.StringFlag{
Name: "memoryfile",
Usage: "location to store the relocated memory",
Expand Down Expand Up @@ -82,10 +96,12 @@ func main() {
if err != nil {
return fmt.Errorf("cannot load program: %w", err)
}

cairoZeroJson, err := zero.ZeroProgramFromJSON(content)
if err != nil {
return fmt.Errorf("cannot load program: %w", err)
}

program, err := runnerzero.LoadCairoZeroProgram(cairoZeroJson)
if err != nil {
return fmt.Errorf("cannot load program: %w", err)
Expand All @@ -95,6 +111,7 @@ func main() {
if err != nil {
return fmt.Errorf("cannot create hints: %w", err)
}

fmt.Println("Running....")
runner, err := runnerzero.NewRunner(program, hints, proofmode, maxsteps, layoutName)
if err != nil {
Expand All @@ -117,18 +134,31 @@ func main() {

if proofmode {
runner.EndRun()

if err := runner.FinalizeSegments(); err != nil {
return fmt.Errorf("cannot finalize segments: %w", err)
}
trace, memory, err := runner.BuildProof()
}

if proofmode || collectTrace {
trace, err := runner.BuildTrace()
if err != nil {
return fmt.Errorf("cannot build proof: %w", err)
return fmt.Errorf("cannot build trace: %w", err)
}

if traceLocation != "" {
if err := os.WriteFile(traceLocation, trace, 0644); err != nil {
return fmt.Errorf("cannot write relocated trace: %w", err)
}
}
}

if proofmode || buildMemory {
memory, err := runner.BuildMemory()
if err != nil {
return fmt.Errorf("cannot build memory: %w", err)
}

if memoryLocation != "" {
if err := os.WriteFile(memoryLocation, memory, 0644); err != nil {
return fmt.Errorf("cannot write relocated memory: %w", err)
Expand Down
35 changes: 18 additions & 17 deletions pkg/runners/zero/zero.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins"
mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)

type ZeroRunner struct {
Expand All @@ -21,8 +20,9 @@ type ZeroRunner struct {
vm *vm.VirtualMachine
hintrunner hintrunner.HintRunner
// config
proofmode bool
maxsteps uint64
proofmode bool
collectTrace bool
maxsteps uint64
// auxiliar
runFinished bool
layout builtins.Layout
Expand Down Expand Up @@ -162,7 +162,7 @@ func (runner *ZeroRunner) InitializeMainEntrypoint() (mem.MemoryAddress, error)
}

func (runner *ZeroRunner) initializeEntrypoint(
initialPCOffset uint64, arguments []*f.Element, returnFp *mem.MemoryValue, memory *mem.Memory,
initialPCOffset uint64, arguments []*fp.Element, returnFp *mem.MemoryValue, memory *mem.Memory,
) (mem.MemoryAddress, error) {
stack, err := runner.initializeBuiltins(memory)
if err != nil {
Expand Down Expand Up @@ -223,7 +223,7 @@ func (runner *ZeroRunner) initializeVm(
Pc: *initialPC,
Ap: offset + uint64(len(stack)),
Fp: offset + uint64(len(stack)),
}, memory, vm.VirtualMachineConfig{ProofMode: runner.proofmode})
}, memory, vm.VirtualMachineConfig{ProofMode: runner.proofmode, CollectTrace: runner.collectTrace})
return err
}

Expand Down Expand Up @@ -273,12 +273,10 @@ func (runner *ZeroRunner) RunFor(steps uint64) error {
// Since this vm always finishes the run of the program at the number of steps that is a power of two in the proof mode,
// there is no need to run additional steps before the loop.
func (runner *ZeroRunner) EndRun() {
if runner.proofmode {
for runner.checkUsedCells() != nil {
pow2Steps := utils.NextPowerOfTwo(runner.vm.Step + 1)
if err := runner.RunFor(pow2Steps); err != nil {
panic(err)
}
for runner.checkUsedCells() != nil {
pow2Steps := utils.NextPowerOfTwo(runner.vm.Step + 1)
if err := runner.RunFor(pow2Steps); err != nil {
panic(err)
}
}
}
Expand Down Expand Up @@ -371,13 +369,16 @@ func (runner *ZeroRunner) FinalizeSegments() error {
return nil
}

func (runner *ZeroRunner) BuildProof() ([]byte, []byte, error) {
relocatedTrace, err := runner.vm.ExecutionTrace()
if err != nil {
return nil, nil, err
}
// BuildMemory relocates the memory and returns it
func (runner *ZeroRunner) BuildMemory() ([]byte, error) {
relocatedMemory := runner.vm.RelocateMemory()
return vm.EncodeMemory(relocatedMemory), nil
}

return vm.EncodeTrace(relocatedTrace), vm.EncodeMemory(runner.vm.RelocateMemory()), nil
// BuildMemory relocates the trace and returns it
func (runner *ZeroRunner) BuildTrace() ([]byte, error) {
relocatedTrace := runner.vm.RelocateTrace()
return vm.EncodeTrace(relocatedTrace), nil
}

func (runner *ZeroRunner) pc() mem.MemoryAddress {
Expand Down
25 changes: 11 additions & 14 deletions pkg/vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,11 @@ type Trace struct {

// This type represents the current execution context of the vm
type VirtualMachineConfig struct {
// If true, the vm outputs the trace and the relocated memory at the end of execution
// If true, the vm outputs the trace and the relocated memory at the end of execution and finalize segments
// in order for the prover to create a proof
ProofMode bool
// If true, the vm collects the relocated trace at the end of execution, without finalizing segments
CollectTrace bool
}

type VirtualMachine struct {
Expand All @@ -91,9 +94,10 @@ type VirtualMachine struct {
func NewVirtualMachine(
initialContext Context, memory *mem.Memory, config VirtualMachineConfig,
) (*VirtualMachine, error) {

// Initialize the trace if necesary
var trace []Context
if config.ProofMode {
if config.ProofMode || config.CollectTrace {
trace = make([]Context, 0)
}

Expand Down Expand Up @@ -136,7 +140,7 @@ func (vm *VirtualMachine) RunStep(hintRunner HintRunner) error {
}

// store the trace before state change
if vm.config.ProofMode {
if vm.config.ProofMode || vm.config.CollectTrace {
vm.Trace = append(vm.Trace, vm.Context)
}

Expand Down Expand Up @@ -216,15 +220,6 @@ func (vm *VirtualMachine) RunInstruction(instruction *a.Instruction) error {
return nil
}

// It returns the current trace entry, the public memory, and the occurrence of an error
func (vm *VirtualMachine) ExecutionTrace() ([]Trace, error) {
if !vm.config.ProofMode {
return nil, fmt.Errorf("proof mode is off")
}

return vm.relocateTrace(), nil
}

func (vm *VirtualMachine) getDstAddr(instruction *a.Instruction) (mem.MemoryAddress, error) {
var dstRegister uint64
if instruction.DstRegister == a.Ap {
Expand Down Expand Up @@ -537,8 +532,10 @@ func (vm *VirtualMachine) updateFp(instruction *a.Instruction, dstAddr *mem.Memo
}
}

func (vm *VirtualMachine) relocateTrace() []Trace {
// one is added, because prover expect that the first element to be on
// It returns the trace after relocation, i.e, relocates pc, ap and fp for each step
// to be their real address value
func (vm *VirtualMachine) RelocateTrace() []Trace {
// one is added, because prover expect that the first element to be
// indexed on 1 instead of 0
relocatedTrace := make([]Trace, len(vm.Trace))
totalBytecode := vm.Memory.Segments[ProgramSegment].Len() + 1
Expand Down

0 comments on commit 08d12e6

Please sign in to comment.