Skip to content

Commit edab3a0

Browse files
committed
perf(FRI): Reduce c-values calculation inner-loop with 3 instructions
Reduces the number of instructions in the inner loop that runs num_rounds * num_collinearity check many times from 68 instructions to 65.
1 parent fa7a2d0 commit edab3a0

File tree

2 files changed

+113
-101
lines changed

2 files changed

+113
-101
lines changed

tasm-lib/benchmarks/tasm_recufier_fri_verify.json

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
[
22
{
33
"name": "tasm_recufier_fri_verify",
4-
"clock_cycle_count": 380008,
5-
"hash_table_height": 14562,
4+
"clock_cycle_count": 379768,
5+
"hash_table_height": 14556,
66
"u32_table_height": 40068,
77
"case": "CommonCase"
88
},
99
{
1010
"name": "tasm_recufier_fri_verify",
11-
"clock_cycle_count": 380008,
12-
"hash_table_height": 14562,
11+
"clock_cycle_count": 379768,
12+
"hash_table_height": 14556,
1313
"u32_table_height": 39865,
1414
"case": "WorstCase"
1515
}

tasm-lib/src/recufier/fri_verify.rs

Lines changed: 109 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ impl BasicSnippet for FriSnippet {
487487

488488
// Loop's end condition is determined by pointer values, so we don't need a loop counter value
489489
// All pointers are traversed from highest address to lowest
490-
// INVARIANT: _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elements *a_indices *b_elements *b_indices
490+
// INVARIANT: _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elements *a_indices *b_elements *b_indices
491491
{compute_c_values_loop}:
492492

493493
// evaluate termination criterion
@@ -498,128 +498,135 @@ impl BasicSnippet for FriSnippet {
498498
skiz return
499499

500500
// Strategy:
501-
// 1. Calculate `a_y`
501+
// 1. Read `a_y`
502502
// 2. Calculate `a_x`
503-
// 3. Calculate `-b_x`
504-
// 4. Calculate `1 / (a_x - b_x)` while preserving `a_x`
505-
// 5. Calculate `b_y`
506-
// 6. Calculate `a_y - b_y`, preserving `a_y`
507-
// 7. Calculate `a_y - b_y / (a_x - b_x)`
508-
// 8. Calculate `c_x - a_x`
509-
// 9. Calculate final `c_y`
503+
// 3. Calculate `[a_y- b_y]`, preserve a[y]
504+
// 4. Calculate `b_x`
505+
// 5. Calculate `-b_x`
506+
// 6. Calculate `1 / (a_x - b_x)` while preserving `a_x`
507+
// 7. Calculate `(a_y - b_y) / (a_x - b_x)`
508+
// 8: Read `[c_x]`
509+
// 9. Calculate `c_x - a_x`
510+
// 10. Calculate final `c_y`
511+
// 11. Write c_y to *c_elem
510512

511513
// _ *c_end_condition g offset r *c_elem *alphas[r] *a_elem *a_index *b_elem *b_index
512514

513-
// 1:
515+
// 1: Read `a_y`
514516
dup 3
515517
read_mem {EXTENSION_DEGREE}
516-
swap 7 pop 1
517-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index *b_elem *b_index [a_y]
518+
swap 7
519+
pop 1
520+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index *b_elem *b_index [a_y]
518521

519-
// 2:
520-
dup 9
522+
// 2: Calculate `a_x`
523+
dup 11
521524
dup 6
522525
read_mem 1
523526
swap 8
524527
pop 1
525-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index [a_y] (1<<round) a_index
528+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index [a_y] (1<<round) a_index
526529

527-
dup 13
530+
dup 12
528531
pow
529-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index [a_y] (1<<round) (g^a_index)
532+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index [a_y] (1<<round) (g^a_index)
530533

531-
dup 12
534+
dup 11
532535
mul
533-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index [a_y] (1<<round) (g^a_index * offset)
536+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index [a_y] (1<<round) (g^a_index * offset)
534537

535538
pow
536-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index [a_y] (g^a_index * offset)^(1<<round)
537-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index [a_y] a_x
539+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index [a_y] (g^a_index * offset)^(1<<round)
540+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index [a_y] a_x
538541

539-
// 3:
540-
dup 10
542+
// 3: Calculate `[a_y- b_y]`, preserve a[y]
541543
dup 5
544+
read_mem {EXTENSION_DEGREE}
545+
swap 9
546+
pop 1
547+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index [a_y] a_x [b_y]
548+
549+
push -1
550+
xbmul
551+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index [a_y] a_x [-b_y]
552+
553+
dup 6
554+
dup 6
555+
dup 6
556+
xxadd
557+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index [a_y] a_x [a_y-b_y]
558+
559+
// 4: Calculate `b_x`
560+
dup 15
561+
dup 15
562+
dup 9
542563
read_mem 1
543-
swap 7
564+
swap 11
544565
pop 1
545-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [a_y] a_x (1<<round) b_index
566+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [a_y] a_x [a_y-b_y] (1<<round) g b_index
546567

547-
dup 14
568+
swap 1
548569
pow
549-
dup 13
570+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [a_y] a_x [a_y-b_y] (1<<round) (g^b_index)
571+
572+
dup 15
550573
mul
574+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [a_y] a_x [a_y-b_y] (1<<round) (g^b_index * offset)
575+
551576
pow
552-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [a_y] a_x b_x
577+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [a_y] a_x [a_y-b_y] (g^b_index * offset)^((1<<round))
578+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [a_y] a_x [a_y-b_y] b_x
553579

580+
// 5: Calculate `-b_x`
554581
push -1
555582
mul
556-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [a_y] a_x (-b_x)
583+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [a_y] a_x [a_y-b_y] (-b_x)
557584

558-
// 4:
559-
dup 1
585+
// 6: Calculate `1 / (a_x - b_x)` while preserving `a_x`
586+
dup 4
560587
add
561588
invert
562-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [a_y] a_x (1 / (a_x-b_x))
589+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [a_y] a_x [a_y-b_y] (1/(a_x-b_x))
563590

564-
// 5:
565-
dup 6
566-
read_mem {EXTENSION_DEGREE}
567-
swap 10
568-
pop 1
569-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem_prev *b_index_prev [a_y] a_x (1/(a_x-b_x)) [b_y]
591+
// 7: Calculate `(a_y - b_y) / (a_x - b_x)`
570592

571-
// 6:
572-
push -1
573593
xbmul
574-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem_prev *b_index_prev [a_y] a_x (1/(a_x-b_x)) [-b_y]
575-
576-
dup 7
577-
dup 7
578-
dup 7
579-
xxadd
580-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem_prev *b_index_prev [a_y] a_x (1/(a_x-b_x)) [a_y-b_y]
581-
582-
// 7:
583-
584-
swap 1 swap 2 swap 3
585-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem_prev *b_index_prev [a_y] a_x [a_y-b_y] (1/(a_x-b_x))
586-
587-
xbmul
588-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem_prev *b_index_prev [a_y] a_x [(a_y-b_y) / (a_x-b_x)]
589-
590-
swap 1 swap 2 swap 3
591-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem_prev *b_index_prev [a_y] [(a_y-b_y) / (a_x-b_x)] a_x
592-
593-
push -1
594-
mul
595-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem_prev *b_index_prev [a_y] [(a_y-b_y) / (a_x-b_x)] (-a_x)
594+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [a_y] a_x [(a_y-b_y)/(a_x-b_x)]
596595

596+
// 8: Read `[c_x]`
597597
dup 11
598598
read_mem {EXTENSION_DEGREE}
599599
pop 1
600-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem_prev *b_index_prev [a_y] [(a_y-b_y) / (a_x-b_x)] (-a_x) [c_x]
601-
602-
swap 1 swap 2 swap 3
603-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem_prev *b_index_prev [a_y] [(a_y-b_y) / (a_x-b_x)] [c_x] (-a_x)
600+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [a_y] a_x [(a_y-b_y)/(a_x-b_x)] [c_x]
604601

602+
// 9: Calculate `c_x - a_x`
603+
dup 6
604+
push -1
605+
mul
605606
add
606-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem_prev *b_index_prev [a_y] [(a_y-b_y) / (a_x-b_x)] [c_x -a_x]
607+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [a_y] a_x [(a_y-b_y)/(a_x-b_x)] [c_x - a_x]
607608

609+
// 10: Calculate final `c_y`
608610
xxmul
611+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [a_y] a_x [(a_y-b_y)/(a_x-b_x) * (c_x -a_x)]
612+
613+
swap 1
614+
swap 2
615+
swap 3
616+
pop 1
617+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [a_y] [(a_y-b_y)/(a_x-b_x) * (c_x -a_x)]
618+
609619
xxadd
610-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem_prev *b_index_prev [c_value]
620+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [(a_y-b_y)/(a_x-b_x) * (c_x -a_x) + a_y]
621+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem *b_index_prev [c_y]
611622

623+
// 11. Write c_y to *c_elem
612624
dup 8
613625
write_mem {EXTENSION_DEGREE}
614-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem_prev *b_index_prev *c_elem_next
615-
616626
push {- 2 * EXTENSION_DEGREE as i32}
617627
add
618-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elem_prev *a_index_prev *b_elem_prev *b_index_prev *c_elem_prev
619-
620628
swap 6
621629
pop 1
622-
// _ *c_end_condition g offset (1<<round) *c_elem_prev *alphas[r] *a_elem_prev *a_index_prev *b_elem_prev *b_index_prev
623630

624631
recurse
625632

@@ -779,68 +786,68 @@ impl BasicSnippet for FriSnippet {
779786
push {-(EXTENSION_DEGREE as i32 - 1)}
780787
add // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition
781788

782-
dup 8
783-
{&domain_generator} // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g
784-
read_mem 1
785-
pop 1
789+
dup 12
790+
push 2
791+
pow // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition (1<<round)
786792

787793
dup 9
788-
{&domain_offset} // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g offset
794+
{&domain_generator} // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition (1<<round) g
789795
read_mem 1
790796
pop 1
791797

792-
dup 14 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g offset r
793-
push 2
794-
pow // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g offset (1<<round)
798+
dup 10
799+
{&domain_offset} // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition (1<<round) g offset
800+
read_mem 1
801+
pop 1
795802

796-
dup 7 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g offset (1<<round) *c_elements
797-
dup 9 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g offset (1<<round) *c_elements *c_indices
798-
read_mem 1 pop 1 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g offset (1<<round) *c_elements c_indices_len
803+
dup 7 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition (1<<round) g offset *c_elements
804+
dup 9 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition (1<<round) g offset*c_elements *c_indices
805+
read_mem 1 pop 1 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition (1<<round) g offset*c_elements c_indices_len
799806

800807
// Write length to *c_elements
801-
dup 0 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g offset (1<<round) *c_elements c_indices_len c_indices_len
802-
swap 2 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g offset (1<<round) c_indices_len c_indices_len *c_elements
808+
dup 0 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition (1<<round) g offset *c_elements c_indices_len c_indices_len
809+
swap 2 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition (1<<round) g offset c_indices_len c_indices_len *c_elements
803810
write_mem 1
804-
swap 1 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g offset (1<<round) (*c_elements + 1) c_indices_len
811+
swap 1 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition (1<<round) g offset (*c_elements + 1) c_indices_len
805812

806813
push -1
807814
add
808815
push {EXTENSION_DEGREE}
809816
mul
810817
add
811-
// _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g offset (1<<round) *c_last_elem_first_word
818+
// _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition (1<<round) g offset *c_last_elem_first_word
812819

813820
dup 14
814-
// _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g offset (1<<round) *c_last_elem_first_word *b_indices
821+
// _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition (1<<round) g offset *c_last_elem_first_word *b_indices
815822
dup 0
816823
read_mem 1
817824
pop 1
818-
add // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g offset (1<<round) *c_last_elem_first_word *b_index_last
825+
add // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition (1<<round) g offset *c_last_elem_first_word *b_index_last
819826

820827
dup 14
821828
dup 10
822829
read_mem 1
823830
pop 1
824-
// _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g offset (1<<round) *c_last_elem_first_word *b_index_last *b_elements c_len
831+
// _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition (1<<round) g offset *c_last_elem_first_word *b_index_last *b_elements c_len
825832

826833
push {EXTENSION_DEGREE}
827834
mul
828835
add
829-
// _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g offset (1<<round) *c_last_elem_first_word *b_index_last *b_elem_last
836+
// _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition (1<<round) g offset *c_last_elem_first_word *b_index_last *b_elem_last
830837

831-
dup 9 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g offset (1<<round) *c_last_elem_first_word *b_index_last *b_elem_last *alphas[r]
838+
dup 9 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition (1<<round) g offset *c_last_elem_first_word *b_index_last *b_elem_last *alphas[r]
832839

833840
swap 2
834-
swap 1 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g offset (1<<round) *c_last_elem_first_word *alphas[r] *b_index_last *b_elem_last
841+
swap 1 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition (1<<round) g offset *c_last_elem_first_word *alphas[r] *b_index_last *b_elem_last
835842

836843
dup 9
837-
dup 9 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition g offset (1<<round) *c_last_elem_first_word *alphas[r] *b_indices *b_elements *a_elements *a_indices
844+
dup 9 // _ ... *alphas[r] *a_elements *a_indices *c_elements_end_condition (1<<round) g offset *c_last_elem_first_word *alphas[r] *b_indices *b_elements *a_elements *a_indices
838845

839846
swap 2
840847
swap 1
841848
swap 3
842849

843-
// _ *c_end_condition g offset (1<<round) *c_elem *alphas[r] *a_elements *a_indices *b_elements *b_indices
850+
// _ *c_end_condition (1<<round) g offset *c_elem *alphas[r] *a_elements *a_indices *b_elements *b_indices
844851
call {compute_c_values_loop}
845852

846853
pop 5
@@ -1636,10 +1643,15 @@ mod test {
16361643
}
16371644

16381645
#[proptest(cases = 3)]
1639-
fn test_shadow(test_case: TestCase) {
1646+
fn test_shadow_prop(test_case: TestCase) {
16401647
assert_behavioral_equivalence_of_fris(test_case);
16411648
}
16421649

1650+
#[test]
1651+
fn test_shadow_small() {
1652+
assert_behavioral_equivalence_of_fris(TestCase::small_case());
1653+
}
1654+
16431655
#[test]
16441656
fn modifying_any_element_in_vm_proof_iter_of_small_test_case_causes_verification_failure() {
16451657
let test_case = TestCase::small_case();

0 commit comments

Comments
 (0)