## 引入库

In [1]:
import math
from merlin.merlin_transcript import MerlinTranscript
from merkle import MerkleTree
from typing import List
import numpy as np
from hashlib import sha256
# 导入SageMath库
from sage.all import *
from enum import Enum
from typing import Generic, TypeVar, List

## 定义类

### Domain 类

In [2]:
import numpy as np
from hashlib import sha256
# 导入SageMath库
from sage.all import *


# Evaluation-domain class
class Domain:
    """Evaluation domain kept as an explicit list of elements.

    The scaling methods below build ``backing_domain`` as
    ``offset * root_of_unity ** i``; the constructor itself stores whatever it
    is given and performs no validation.
    """

    def __init__(self, root_of_unity, root_of_unity_inv, offset, backing_domain):
        """Store the generator, its inverse, the coset offset and the element list."""
        self.root_of_unity = root_of_unity
        self.root_of_unity_inv = root_of_unity_inv
        self.offset = offset
        self.backing_domain = backing_domain

    def __repr__(self):
        return (f"Domain(root_of_unity={self.root_of_unity}, "
                f"root_of_unity_inv={self.root_of_unity_inv}, "
                f"offset={self.offset}, "
                f"backing_domain={self.backing_domain})")

    def new(self, degree, log_rho_inv):
        """Placeholder: not implemented, always returns ``None``."""
        pass

    def generate_elements(self):
        """Return ``np.arange(size)``: a simplified integer stand-in for the elements."""
        # ``size`` is a method and must be called; passing the bound method
        # itself to ``np.arange`` raises a TypeError.
        return np.arange(self.size())

    def size(self):
        """Return the number of elements in ``backing_domain``."""
        return len(self.backing_domain)

    # o^power * <w^power>
    def scale_generator_by(self, power):
        """Return the domain ``offset^power * <root_of_unity^power>``.

        The new domain has ``len(backing_domain) // power`` elements.
        """
        root_of_unity = self.root_of_unity ** power
        root_of_unity_inv = self.root_of_unity_inv ** power
        offset = self.offset ** power
        size = len(self.backing_domain) // power
        backing_domain = [offset * (root_of_unity ** i) for i in range(size)]
        return Domain(root_of_unity, root_of_unity_inv, offset, backing_domain)

    # L_0 = o * <w> then L_1 = w * o^power * <w^power>
    def scale_with_offset(self, power):
        """Return the domain ``root_of_unity * offset^power * <root_of_unity^power>``.

        Same as ``scale_generator_by`` except that the new offset is additionally
        multiplied by the current generator.
        """
        root_of_unity = self.root_of_unity ** power
        root_of_unity_inv = self.root_of_unity_inv ** power
        offset = self.root_of_unity * self.offset ** power
        size = len(self.backing_domain) // power
        backing_domain = [offset * (root_of_unity ** i) for i in range(size)]
        return Domain(root_of_unity, root_of_unity_inv, offset, backing_domain)

### 多项式类

In [3]:
# Dense polynomial class
class DensePolynomial:
    """Polynomial stored as a coefficient list, lowest degree first."""

    def __init__(self, coefficients):
        self.coefficients = coefficients

    def evaluate(self, x):
        """Evaluate the polynomial at ``x`` using Horner's rule.

        NOTE: the result is converted with the module-level name ``Fp``, which
        must already be defined when this method is called (it is not a
        parameter here, unlike in ``mul``).
        """
        result = 0
        for coefficient in reversed(self.coefficients):
            result = result * x + coefficient
        return Fp(result)

    def degree(self):
        """Return ``len(coefficients) - 1`` (trailing zeros are not stripped)."""
        return len(self.coefficients) - 1

    def evaluate_over_domain(self, domain):
        """Return the list of evaluations at every point of the iterable ``domain``."""
        return [self.evaluate(point) for point in domain]

    def mul(self, Fp, poly: 'DensePolynomial') -> 'DensePolynomial':
        """Return the product of this polynomial and ``poly`` over ``Fp[X]``."""
        # ``Fp['X']`` is what the Sage preparser expands ``R.<X> = Fp[]`` to;
        # the generator X itself is not needed for a plain product.
        R = Fp['X']
        product = R(self.coefficients) * R(poly.coefficients)
        return DensePolynomial(list(product))

    def __repr__(self) -> str:
        return f"DensePolynomial(coefficients={self.coefficients})"

### Witness 类

In [4]:
# Witness container
class Witness:
    """Prover-side data returned by ``StirProver.commit`` next to the commitment."""

    def __init__(self, domain: Domain, polynomial: DensePolynomial, merkle_tree, folded_evals):
        self.domain, self.polynomial = domain, polynomial
        self.merkle_tree, self.folded_evals = merkle_tree, folded_evals

    def __repr__(self):
        fields = (self.domain, self.polynomial, self.merkle_tree, self.folded_evals)
        # The trailing space after the closing parenthesis is part of the output.
        return ("Witness(domain={}, polynomial={}, "
                "merkle_tree={}, folded_evals={}) ").format(*fields)

In [5]:
# Extended witness container
class WitnessExtended:
    """Witness plus the per-round state the prover carries between rounds."""

    def __init__(self, domain, polynomial: DensePolynomial, merkle_tree, folded_evals, num_round, folding_randomness):
        self.domain, self.polynomial = domain, polynomial
        self.merkle_tree, self.folded_evals = merkle_tree, folded_evals
        # Index of the current round.
        self.num_round = num_round
        # Randomness used to fold the polynomial in the next round.
        self.folding_randomness = folding_randomness

    def __repr__(self):
        template = ("WitnessExtended(domain={}, polynomial={}, "
                    "merkle_tree={}, folded_evals={}, "
                    "num_round={}, folding_randomness={})")
        return template.format(self.domain, self.polynomial,
                               self.merkle_tree, self.folded_evals,
                               self.num_round, self.folding_randomness)

### Commitment 类

In [6]:
class Commitment:
    """Wrapper around the root value that ``StirProver.commit`` publishes."""

    def __init__(self, root):
        self.root = root

    def __repr__(self):
        return "Commitment(root={})".format(self.root)

### Proof 类

In [7]:
class RoundProof:
    """Everything the prover outputs for one round (see ``StirProver.round``)."""

    def __init__(self, g_root, betas, queries_to_prev, ans_polynomial: DensePolynomial, shake_polynomial: DensePolynomial):
        # Merkle root committing to the folded function of this round.
        self.g_root = g_root
        # Evaluations of the folded polynomial at the out-of-domain points.
        self.betas = betas
        # [indexes, answers, authentication paths] opened on the previous oracle.
        self.queries_to_prev = queries_to_prev
        self.ans_polynomial = ans_polynomial
        self.shake_polynomial = shake_polynomial

    def __repr__(self):
        template = ("RoundProof(g_root={}, betas={}, "
                    "queries_to_prev={}, ans_polynomial={}, "
                    "shake_polynomial={})")
        return template.format(self.g_root, self.betas,
                               self.queries_to_prev, self.ans_polynomial,
                               self.shake_polynomial)

class Proof:
    """Complete proof: the per-round proofs, the final polynomial and its openings."""

    def __init__(self, round_proofs: List[RoundProof], final_polynomial, queries_to_final):
        self.round_proofs, self.final_polynomial = round_proofs, final_polynomial
        # [indexes, answers, authentication paths] opened on the last oracle.
        self.queries_to_final = queries_to_final

    def __repr__(self):
        template = ("Proof(round_proofs={}, final_polynomial={}, "
                    "queries_to_final={})")
        return template.format(self.round_proofs, self.final_polynomial, self.queries_to_final)

### 参数类

In [8]:
from enum import Enum
# Enumeration of the supported soundness types
class SoundnessType(Enum):
    """Soundness mode tag.

    ``Parameters.repetitions`` uses a multiplier of 2 for ``Provable`` and 1
    for ``Conjecture`` when computing the number of query repetitions.
    """
    Provable = 1
    Conjecture = 2

In [9]:
import math

# Protocol parameter container
class Parameters:
    """User-supplied protocol parameters; values are stored without validation."""

    def __init__(self,
                 security_level: int,
                 protocol_security_level: int,
                 starting_degree: int,
                 stopping_degree: int,
                 folding_factor: int,
                 starting_rate: int,
                 soundness_type: "SoundnessType",
                 fiat_shamir_config):
        self.security_level = security_level
        self.protocol_security_level = protocol_security_level
        self.starting_degree = starting_degree
        self.stopping_degree = stopping_degree
        self.folding_factor = folding_factor
        # Used as -log2 of the initial rate by FullParameters.from_parameters.
        self.starting_rate = starting_rate
        self.soundness_type = soundness_type
        self.fiat_shamir_config = fiat_shamir_config

    def __repr__(self):
        return (f"Parameters(security_level={self.security_level}, "
                f"protocol_security_level={self.protocol_security_level}, "
                f"starting_degree={self.starting_degree}, "
                f"stopping_degree={self.stopping_degree}, "
                f"folding_factor={self.folding_factor}, "
                f"starting_rate={self.starting_rate}, "
                f"soundness_type={self.soundness_type}, "
                f"fiat_shamir_config={self.fiat_shamir_config})")

    def repetitions(self, log_inv_rate: int) -> int:
        """Return ``ceil(constant * security_level / log_inv_rate)``.

        ``constant`` is 2 for ``SoundnessType.Provable`` and 1 for
        ``SoundnessType.Conjecture``.

        Raises:
            ValueError: if ``soundness_type`` is neither of the two members.
                Previously this case silently produced 0 repetitions, i.e. a
                protocol run with no queries at all.
        """
        if self.soundness_type == SoundnessType.Provable:
            constant = 2
        elif self.soundness_type == SoundnessType.Conjecture:
            constant = 1
        else:
            raise ValueError(f"unknown soundness type: {self.soundness_type!r}")

        return math.ceil(constant * (self.security_level / log_inv_rate))
         


In [10]:
from typing import Generic, TypeVar, List

# F = TypeVar('F')
# MerkleConfig = TypeVar('MerkleConfig')
# FSConfig = TypeVar('FSConfig')

class FullParameters:
    """Per-round schedule (degrees, rates, repetitions) derived from ``Parameters``."""

    def __init__(self,
                 parameters: "Parameters",
                 num_rounds: int,
                 degrees: List[int],
                 rates: List[int],
                 repetitions: List[int],
                 ood_samples: int):
        self.parameters = parameters
        self.num_rounds = num_rounds
        self.rates = rates
        self.repetitions = repetitions
        self.ood_samples = ood_samples
        self.degrees = degrees

    def __repr__(self):
        return (f"FullParameters(parameters={self.parameters}, "
                f"num_rounds={self.num_rounds}, "
                f"rates={self.rates}, "
                f"repetitions={self.repetitions}, "
                f"ood_samples={self.ood_samples}, "
                f"degrees={self.degrees})")

    @classmethod
    def from_parameters(cls, parameters: "Parameters") -> "FullParameters":
        """Derive the round schedule from ``parameters``.

        All derived quantities are integers: they are later used as ``range()``
        bounds and list indexes, so floor division and an integer base-2
        logarithm are used throughout.

        Raises:
            AssertionError: if a degree or the folding factor is not a power of
                two, if the folding factor is 1, if the starting degree does
                not exceed the stopping degree, or if a degree is not divisible
                by the folding factor.
        """
        assert is_power_of_two(parameters.folding_factor)
        assert is_power_of_two(parameters.starting_degree)
        assert is_power_of_two(parameters.stopping_degree)
        # 1 is a power of two, but folding by 1 never reduces the degree and
        # the loop below would not terminate.
        assert parameters.folding_factor > 1, "folding_factor must be at least 2"
        assert parameters.starting_degree > parameters.stopping_degree, \
            "starting_degree must exceed stopping_degree"

        d = parameters.starting_degree
        degrees = [d]
        num_rounds = 0
        while d > parameters.stopping_degree:
            # The degree must divide evenly.
            assert d % parameters.folding_factor == 0
            d //= parameters.folding_factor
            degrees.append(d)
            num_rounds += 1

        # The last fold is handled by the final round, not by a regular round.
        degrees.pop()
        num_rounds -= 1

        # Exact because folding_factor is a power of two.
        log_folding_factor = int(parameters.folding_factor).bit_length() - 1
        # parameters.starting_rate = -log_2(rho_0)
        rates = [parameters.starting_rate]
        for i in range(1, num_rounds + 1):
            rates.append(parameters.starting_rate + i * (log_folding_factor - 1))

        repetitions = [parameters.repetitions(rate) for rate in rates]

        # Cap the query count per round: repetitions[i] = min(repetitions[i],
        # d_i / k - ood_samples - 1).  The final round is not capped.
        ood_samples = 2
        for i in range(0, num_rounds):
            cap = degrees[i] // parameters.folding_factor - ood_samples - 1
            repetitions[i] = min(repetitions[i], cap)

        assert num_rounds + 1 == len(rates)
        assert num_rounds + 1 == len(repetitions)

        return cls(
            parameters=parameters,
            num_rounds=num_rounds,
            degrees=degrees,
            rates=rates,
            repetitions=repetitions,
            ood_samples=ood_samples)

### VerificationState 类

In [11]:
# Enumeration of oracle types
class OracleType(Enum):
    """Tag for the ``oracle`` field of ``VerificationState``."""
    # NOTE(review): presumably Initial marks the originally committed function
    # and Virtual a function derived in a later round -- confirm against the verifier.
    Initial = 1
    Virtual = 2
    
class VerificationState:
    """Plain record of the values a verifier carries from one round to the next."""

    def __init__(self,
                 oracle: "OracleType",
                 domain: "Domain",
                 folding_randomness,
                 num_round: int,
                 comb_randomness,
                 quotient_set,
                 ans_polynomial):
        self.oracle = oracle
        self.domain = domain
        self.folding_randomness = folding_randomness
        self.num_round = num_round  # index of the current round
        self.comb_randomness = comb_randomness
        self.quotient_set = quotient_set
        self.ans_polynomial = ans_polynomial

    def __repr__(self):
        # Must return a single str: a stray comma between the f-string pieces
        # used to turn the return value into a tuple, which makes repr() raise
        # TypeError.
        return (f"VerificationState(oracle={self.oracle}, "
                f"domain={self.domain}, "
                f"folding_randomness={self.folding_randomness}, "
                f"num_round={self.num_round}, "
                f"comb_randomness={self.comb_randomness}, "
                f"quotient_set={self.quotient_set}, "
                f"ans_polynomial={self.ans_polynomial})")

## 工具函数

In [12]:
# Utility helpers
def squeeze_field_elements(Fp, num_elements):
    """Return ``num_elements`` values obtained from successive ``Fp.random_element()`` calls."""
    samples = []
    for _ in range(num_elements):
        samples.append(Fp.random_element())
    return samples

In [13]:
# Group polynomial evaluations
def stack_evaluations(evals, folding_factor):
    """Group ``evals`` into rows of ``folding_factor`` entries.

    With ``stride = len(evals) // folding_factor``, row ``r`` holds
    ``evals[r], evals[r + stride], evals[r + 2 * stride], ...``.
    The length of ``evals`` must be a multiple of ``folding_factor``.
    """
    assert len(evals) % folding_factor == 0
    stride = len(evals) // folding_factor
    return [
        [evals[row + col * stride] for col in range(folding_factor)]
        for row in range(stride)
    ]

In [14]:
# Polynomial folding
def poly_fold(f: "DensePolynomial", folding_randomness, folding_factor) -> "DensePolynomial":
    """Fold ``f`` by ``folding_factor`` using ``folding_randomness``.

    The coefficients ``c`` of ``f`` (lowest degree first) are cut into
    consecutive groups of ``k = folding_factor`` entries, and group ``j``
    collapses to ``sum_i folding_randomness**i * c[j*k + i]``.

    A coefficient list whose length is not a multiple of ``k`` (for example
    because leading zero coefficients were dropped upstream) is treated as if
    it were padded with zeros. Arithmetic is done on the coefficients
    themselves rather than through a numpy array, so large Python integers
    cannot wrap around in a fixed-width dtype.
    """
    coefficients = list(f.coefficients)  # copy: the caller's list is not modified
    shortfall = -len(coefficients) % folding_factor
    if shortfall:
        coefficients.extend([0] * shortfall)

    folded_coefficients = []
    for start in range(0, len(coefficients), folding_factor):
        accumulator = 0
        power = 1
        for coefficient in coefficients[start:start + folding_factor]:
            accumulator += power * coefficient
            power *= folding_randomness
        folded_coefficients.append(accumulator)

    return DensePolynomial(folded_coefficients)

In [15]:
def interpolation(Fp, points):
    n = len(points)

    # 先计算 L_i(X) 的分母
    L_denominator = [1] * n

    for i in range(0, n):
        for j in range(0, n):
            if i != j:
                xi = points[i][0]
                xj = points[j][0]
                # print("xi = ", xi, "xj = ", xj)
                # print(L_denominator[i])
                L_denominator[i] = (xi - xj) * L_denominator[i]
    
    # print("L_denominator = ", L_denominator)

    # 声明多项式自变量 X 在 GF 中
    R.<X> = Fp[]

    # 计算 L_i(X)
    L_X = [1] * n
    for i in range(0, n):
        for j in range(0, n):
            if i != j:
                xj = points[j][0]
                L_X[i] = (X - xj) * L_X[i]
        L_X[i] = Fp(1 / L_denominator[i]) * L_X[i]
    
    # 计算 f(X)
    f_X = 0
    for i in range(0, n):
        yi = points[i][1]
        f_X = f_X + yi * L_X[i]

    print("f_X = ", f_X)

    return DensePolynomial(list(f_X))

In [16]:
def poly_quotient(Fp, poly: DensePolynomial, ans: DensePolynomial, points) -> DensePolynomial:
    R.<X> = Fp[]
    ans_X = R(ans.coefficients)
    poly_X = R(poly.coefficients)
    vanish_X = R(1)
    for point in points:
        vanish_X = vanish_X * (X - point[0])
    
    quotient, remainder = (poly_X - ans_X).quo_rem(vanish_X)
    assert remainder == 0
    
    # return DensePolynomial(quotient.coefficients())
    return DensePolynomial(list(quotient))

In [17]:
def verify_decommitment(leaf_id, leaf_data, decommitment, root):
    """Check a Merkle authentication path for one leaf.

    Args:
        leaf_id: index of the leaf; must satisfy ``0 <= leaf_id < 2**len(decommitment)``.
        leaf_data: leaf content; it is hashed as ``sha256(str(leaf_data))``.
        decommitment: sibling hashes (hex strings) ordered from the level just
            below the root down to the leaf level.
        root: expected root hash (hex string).

    Returns:
        True if hashing up the path reproduces ``root``, otherwise False.
        An out-of-range ``leaf_id`` is rejected: its extra high bits would be
        dropped by ``zip`` and the path would verify as a different, valid leaf.
    """
    leaf_num = 2 ** len(decommitment)
    if not (0 <= leaf_id < leaf_num):
        return False
    node_id = leaf_id + leaf_num
    cur = sha256(str(leaf_data).encode()).hexdigest()
    # bin(node_id)[3:] drops '0b' and the leading 1, leaving the path bits from
    # the root downwards; both sequences are reversed to walk from the leaf up.
    for bit, auth in zip(bin(node_id)[3:][::-1], decommitment[::-1]):
        if bit == '0':
            h = cur + auth  # current node is a left child
        else:
            h = auth + cur  # current node is a right child
        cur = sha256(h.encode()).hexdigest()
    return cur == root

def verify_multi_path(root, queries_to_prev):
    """Verify every opening in ``queries_to_prev`` against ``root``.

    ``queries_to_prev`` is ``[indexes, leaves, authentication paths]``.
    Checking stops at the first opening that fails; an empty index list
    verifies trivially.
    """
    leaf_indexes = queries_to_prev[0]
    leaf_values = queries_to_prev[1]
    auth_paths = queries_to_prev[2]
    any_failed = any(
        verify_decommitment(leaf_index, leaf_values[pos], auth_paths[pos], root) == False
        for pos, leaf_index in enumerate(leaf_indexes)
    )
    return not any_failed

## StirProver

In [43]:
# 定义Prover类
class StirProver:
    def __init__(self, parameters: FullParameters):
        self.parameters = parameters
    
    def commit(self, domain: Domain, witness_polynomial: DensePolynomial):
        # domain = Domain(stir_prover.parameters.parameters.starting_degree, stir_prover.parameters.parameters.starting_rate)
        evals = [witness_polynomial.evaluate(x) for x in domain.backing_domain]
        folded_evals = stack_evaluations(evals, self.parameters.parameters.folding_factor)
        merkle_tree = MerkleTree(folded_evals)
        initial_commitment = merkle_tree.root
        # return initial_commitment, Witness(domain, witness_polynomial, merkle_tree, folded_evals)
        return {
            "commitment": Commitment(initial_commitment),
            "witness": Witness(domain, witness_polynomial, merkle_tree, folded_evals)
        }

    def round(self, Fp, transcript: MerlinTranscript, witness: WitnessExtended, debug=False):
        if debug: print("--------- prove round begin ---------")
        R.<X> = Fp[]
        # if debug: R.<X> = Fp[]
        if debug: print("\n1. Send folded function ")
        if debug: print("Before folded, f(X) = ", R(witness.polynomial.coefficients))
        # folded witness.polynomial
        g_poly = poly_fold(witness.polynomial, witness.folding_randomness, self.parameters.parameters.folding_factor)
        if debug: print("After folded, g(X) = ", R(g_poly.coefficients))

        g_domain = witness.domain.scale_with_offset(2)
        if debug: print("evaluation domain L_i : ", g_domain.backing_domain)

        # get evaluations of polynomial g_poly
        g_evaluations = g_poly.evaluate_over_domain(g_domain.backing_domain)
        # if debug: print("g_domain.backing_domain L_i: ", g_domain.backing_domain)
        if debug: print("evaluations of g(X) on L_i: ", g_evaluations)


        g_folded_evaluations = stack_evaluations(g_evaluations, self.parameters.parameters.folding_factor)
        if debug: print("Group the values of g(X) on L_i: ")
        if debug: print(g_folded_evaluations)

        if debug: print("commit the Merkle root of values of g(X):")
        g_merkle = MerkleTree(g_folded_evaluations)
        g_root = g_merkle.root
        if debug: print("Merkle root :", g_root)

        transcript.append_message(b"merkle_root", g_root.encode('ascii'))

        # Out of domain sample
        # odd_randomness = transcript.challenge_bytes(b"odd_randomness", stir_prover.parameters.ood_samples)
        if debug: print("\n2. Out-of-domain samples")
        odd_randomness = []
        for _ in range(self.parameters.ood_samples):
            random = transcript.challenge_bytes(b"odd_randomness", 1)[0] % Fp.order()
            odd_randomness.append(random)
        odd_randomness = list(set(odd_randomness))
        # odd_randomness = [transcript.challenge_bytes(b"odd_randomness", 1)[0] % Fp.order() for _ in range(self.parameters.ood_samples)]
        # odd_randomness = squeeze_field_elements(Fp, stir_prover.parameters.ood_samples)
        if debug: print("out-of-domain randomness:", odd_randomness)
        
        # Out of domain reply
        if debug: print("\n3. Out-of-domain reply")
        betas = g_poly.evaluate_over_domain(odd_randomness)
        if debug: print("betas: ", betas)
        for i in range(len(betas)):
            transcript.append_message(f'betas_at_{i}'.encode('ascii'), str(betas[i]).encode('ascii'))

        # STIR message
        if debug: print("\n4. STIR message")
        comb_randomness = transcript.challenge_bytes(b"comb_randomness", 1)[0] % Fp.order()
        if debug: print("comb_randomness:", comb_randomness)

        folding_randomness = transcript.challenge_bytes(b"folding_randomness", 1)[0] % Fp.order()
        if debug: print("folding_randomness:", folding_randomness)

        scaling_factor = witness.domain.size() / self.parameters.parameters.folding_factor
        if debug: print("scaling_factor:", scaling_factor)

        num_repetitions = self.parameters.repetitions[witness.num_round]
        stir_randomness_indexes = []
        for i in range(num_repetitions):
            index = transcript.challenge_bytes(b"stir_randomness_indexes", 1)[0] % scaling_factor
            stir_randomness_indexes.append(index)

        # 去重
        if debug: print("remove duplicates from the stir_randomness_indexes")
        stir_randomness_indexes = list(set(stir_randomness_indexes))
        L_k = witness.domain.scale_generator_by(folding_factor)
        for odd_random in odd_randomness:
            for index in stir_randomness_indexes:
                if L_k.backing_domain[index] == odd_random:
                    stir_randomness_indexes.pop(index)
        if debug: print("stir_randomness_index: ", stir_randomness_indexes)

        if debug: print("\nsample shake randomness:")
        _shake_randomness = transcript.challenge_bytes(b"_shake_randomness", 1)[0] % Fp.order()
        if debug: print("shake_randomness: ", _shake_randomness)

        if debug: print("\nquery to prev f value at stir_randomness_indexes:")
        queries_to_prev_index = stir_randomness_indexes
        queries_to_prev_ans = []
        for index in stir_randomness_indexes:
            queries_to_prev_ans.append(witness.folded_evals[index])
        if debug: print("queries_to_prev_index: ", queries_to_prev_index)
        if debug: print("queries_to_prev_ans: ", queries_to_prev_ans)

        queries_to_prev_proof = []
        for index in stir_randomness_indexes:
            queries_to_prev_proof.append(witness.merkle_tree.get_authentication_path(index))
        if debug: print("queries_to_prev_proof: ", queries_to_prev_proof)
        queries_to_prev = [queries_to_prev_index, queries_to_prev_ans, queries_to_prev_proof]
        if debug: print("queries_to_prev: ", queries_to_prev)

        if debug: print("\nfrom L_{i-1}^k to get r_{i}^{shift}:")
        if debug: print("L_{i-1} :", witness.domain.backing_domain)
        # L_k = witness.domain.scale_generator_by(folding_factor)
        if debug: print("L_{i-1}^k :", L_k.backing_domain)
        stir_randomness = [L_k.backing_domain[index] for index in stir_randomness_indexes]
        if debug: print("stir_randomness: ", stir_randomness)

        # compute the set we are quotienting by
        # ODD samples + stir_randomness
        if debug: print("\n5. Define next polynomial")
        quotient_set = odd_randomness + stir_randomness
        if debug: print("quotient set：", quotient_set)

        quotient_answers = [[x, g_poly.evaluate(x)] for x in quotient_set]
        print("values of g(X) on quotient set ", quotient_answers)

        # Ans(X)
        if debug: print("\ncompute Ans(X) by interpolation:")
        ans_polynomial = interpolation(Fp, quotient_answers)
        if debug: print("Ans(X)'s coefficients: ", ans_polynomial.coefficients)

        # compute Shake(X)
        if debug: print("\ncompute Shake(X) (by Ans(X) and quotient set) :")
        R.<X> = Fp[]
        ans_X = R(ans_polynomial.coefficients)
        if debug: print("Ans(X) = ", ans_X)
        shake_X = 0 * X
        for point in quotient_answers:
            num_polynomial = R(ans_X - point[1])
            den_polynomial = R(X - point[0])
            # 进行多项式除法，获取商和余数
            quotient, remainder = num_polynomial.quo_rem(den_polynomial)
            # print("quotient = ", quotient)
            # print("remainder = ", remainder)
            assert remainder == 0 * X
            shake_X = shake_X + quotient
            # print("shake_X = ", shake_X)
        if debug: print("Shake(X) = ", shake_X)

        # print(shake_X.coefficients())
        # shake_polynomial = DensePolynomial(shake_X.coefficients()) # 此方法返回非零系数
        shake_polynomial = DensePolynomial(list(shake_X))


        if debug: print("\ncompute quotient polynomial")
        # compute quotient polynomial
        quotient_polynomial = poly_quotient(Fp, g_poly, ans_polynomial, quotient_answers)
        if debug: print("quotient_polynomial = " , R(quotient_polynomial.coefficients))


        # Deg Correction
        if debug: print("\nDegree Correction ")
        coefficients_vec = []
        for i in range(len(quotient_set) + 1):
            coefficients_vec.append(comb_randomness ** i)

        scaling_polynomial = DensePolynomial(coefficients_vec)
        if debug: print("scaling_polynomial = ", R(scaling_polynomial.coefficients))

        # next round witness polynomial
        witness_polynomial = scaling_polynomial.mul(Fp, quotient_polynomial)
        if debug: print("witness_polynomial = ", R(witness_polynomial.coefficients))

        return {
            "witness_extended": WitnessExtended (
                domain = g_domain,
                polynomial = witness_polynomial,
                merkle_tree = g_merkle,
                folded_evals = g_folded_evaluations,
                num_round = witness.num_round + 1,
                folding_randomness = folding_randomness
            ),
            "round_proof": RoundProof (
                g_root = g_root,
                betas = betas,
                queries_to_prev = queries_to_prev,
                ans_polynomial= ans_polynomial,
                shake_polynomial= shake_polynomial
            )
        }
    
    def prove(self, witness: Witness, debug=False) -> Proof:
       R.<X> = Fp[]
       # 保证多项式的次数小于参数 starting_degree
       assert witness.polynomial.degree() < self.parameters.parameters.starting_degree

       transcript = MerlinTranscript(b"STIR")
       transcript.append_message(b"merkle_root", witness.merkle_tree.root.encode('ascii'))
       folding_randomness = transcript.challenge_bytes(b"folding_randomness", 1)[0] % Fp.order()
       print("folding_randomness: ", folding_randomness)

       witness = WitnessExtended (
              domain= witness.domain,
              polynomial= witness.polynomial,
              merkle_tree= witness.merkle_tree,
              folded_evals= witness.folded_evals,
              num_round= 0,
              folding_randomness= folding_randomness,
       )

       round_proofs = []
       if debug: print("---------------------- begin proof ----------------------")
       if debug: print("num_rounds = ", self.parameters.num_rounds)
       for _ in range(0,self.parameters.num_rounds):
              result = self.round(Fp, transcript, witness, debug)
              witness = result["witness_extended"]
              if debug: print("witness:", witness)
              if debug: print("round_proof:", result["round_proof"])
              round_proofs.append(result["round_proof"])
              if debug: print(" ")
       
       if debug: print("------------- final round ------------------")
       if debug: print("witness: ", witness)
       if debug: print(R(witness.polynomial.coefficients))
       final_polynomial = poly_fold(witness.polynomial, witness.folding_randomness, self.parameters.parameters.folding_factor)
       
       
       final_randomness_indexs = []
       num_rounds = self.parameters.num_rounds
       if debug: print("num_rounds: ", num_rounds)
       repetitions = self.parameters.repetitions
       if debug: print("repetitions: ", repetitions)
       final_repetitions = repetitions[num_rounds]
       scaling_factor = witness.domain.size() / self.parameters.parameters.folding_factor
       for i in range(final_repetitions):
              index = transcript.challenge_bytes(b"final_randomness_index", 1)[0] % scaling_factor
              final_randomness_indexs.append(index)
       final_randomness_indexs = list(set(final_randomness_indexs))

       queries_to_final_index = final_randomness_indexs
       queries_to_final_ans = [witness.folded_evals[index] for index in final_randomness_indexs]
       queries_to_final_proof = [witness.merkle_tree.get_authentication_path(index) for index in final_randomness_indexs]
       queries_to_final = [queries_to_final_index, queries_to_final_ans, queries_to_final_proof]

       if debug: print("round proofs: ", round_proofs)
       if debug: print("final polynomial: ", R(final_polynomial.coefficients))
       if debug: print("queries to final: ", queries_to_final)

       return Proof(
              round_proofs= round_proofs,
              final_polynomial= final_polynomial,
              queries_to_final= queries_to_final
       )

## StirVerifier 类

In [68]:
class StirVerifier:
    """Verifier for the STIR proofs built in this notebook.

    ``verify`` checks the final-polynomial degree bound, checks the Merkle
    authentication paths, replays the Merlin transcript to re-derive the
    challenges, runs ``verifier_round`` for every round proof and finally
    compares the folded answers against ``proof.final_polynomial``.

    Both methods read the module-level field ``Fp``, which therefore has to be
    defined before they are called.
    """

    def __init__(self, parameters: FullParameters):
        """Store the protocol parameters used by every check."""
        self.parameters = parameters

    def verify(self, commitment: Commitment, proof: Proof, domain: Domain, debug=False):
        """Check ``proof`` against ``commitment`` over the starting ``domain``.

        Args:
            commitment: Object whose ``root`` string is the Merkle root of the
                initial oracle; it is ASCII-encoded into the transcript.
            proof: Provides ``round_proofs``, ``final_polynomial`` and
                ``queries_to_final`` (whose element ``[1]`` holds the answers).
            domain: Evaluation domain of the initial oracle.
            debug: When true, print a trace of every step.  When false this
                method prints nothing itself (helpers it calls may still print).

        Returns:
            ``True`` when every check passes, ``False`` otherwise.  The degree
            and Merkle-path checks return early; later failures are recorded
            and the remaining steps still run.
        """
        result = True
        if debug: print("--------- verifier decision phase ---------")
        # Reject at once if deg(final polynomial) + 1 exceeds stopping_degree.
        if debug: print("\n1. chekck final polynomial degree")
        if proof.final_polynomial.degree() + 1 > self.parameters.parameters.stopping_degree:
            result = False
            if debug: print("final polynomial degree not correct", result)
            return False
        if debug: print("check answer: ", result)

        # Check the Merkle authentication paths.  Each round's queries are
        # checked against the root committed before it: commitment.root first,
        # then the g_root of the preceding round.
        current_root = commitment.root
        if debug: print("\n2. verify merkle path")
        for round_proof in proof.round_proofs:
            if debug: print("current root: ", current_root)
            if debug: print("round_proof: ", round_proof)
            if debug: print("queries_to_prev", round_proof.queries_to_prev)
            # Verify once and reuse the outcome for the debug trace.
            paths_ok = verify_multi_path(current_root, round_proof.queries_to_prev)
            if paths_ok == False:
                if debug: print("Merkle Tree is not correct")
                return False
            if debug: print(paths_ok)
            current_root = round_proof.g_root

        # Paths of the final-round queries.
        final_paths_ok = verify_multi_path(current_root, proof.queries_to_final)
        if final_paths_ok == False:
            return False
        if debug: print(final_paths_ok)
        if debug: print("check answer: ", result)

        if debug: print("\n3. recompute randomness")
        transcript = MerlinTranscript(b"STIR")
        transcript.append_message(b"merkle_root", commitment.root.encode('ascii'))
        if debug: print("commitment root: ", commitment.root)
        folding_randomness = transcript.challenge_bytes(b"folding_randomness", 1)[0] % Fp.order()
        if debug: print("folding_randomness: ", folding_randomness)

        verification_state = VerificationState(OracleType.Initial, domain, folding_randomness,0, 0, [],DensePolynomial([]))

        if debug: print("\n4. verify round proof")
        for round_proof in proof.round_proofs:
            result_verifier_round = self.verifier_round(transcript, round_proof, verification_state, debug)
            # A failed round is recorded but does not stop the loop, so the
            # transcript keeps advancing through the remaining rounds.
            if result_verifier_round["result"] == False:
                result = False
            if debug: print("result: ", result_verifier_round["result"])
            verification_state = result_verifier_round["VerificationState"]

        # Final-round check.
        if debug: print("\n5. verify final polynomial")
        final_repetitions = self.parameters.repetitions[self.parameters.num_rounds]
        # NOTE(review): true division is kept on purpose so the indexes are
        # derived exactly as before; consider `//` if this ever runs outside Sage.
        scaling_factor = verification_state.domain.size() / self.parameters.parameters.folding_factor
        if debug: print("self.parameters: ", self.parameters)
        final_randomness_indexs = []
        for _ in range(final_repetitions):
            index = transcript.challenge_bytes(b"final_randomness_index", 1)[0] % scaling_factor
            final_randomness_indexs.append(index)
        final_randomness_indexs = list(set(final_randomness_indexs))
        if debug: print("final_randomness_indexs: ", final_randomness_indexs)

        oracle_answers = proof.queries_to_final[1]
        if debug: print("oracle_answers: ", oracle_answers)

        # compute folded answers
        if debug: print("\ncompute folded answers:")
        folded_answers = []
        L_i_mins_1 = verification_state.domain.backing_domain
        domain_stacks = stack_evaluations(L_i_mins_1, self.parameters.parameters.folding_factor)
        if debug: print("domain_stacks: ", domain_stacks)

        oracle_points = [domain_stacks[index] for index in final_randomness_indexs]
        if debug: print("oracle_points: ", oracle_points)

        # compute random_points
        L_k = verification_state.domain.scale_generator_by(self.parameters.parameters.folding_factor)
        final_randomness = [L_k.backing_domain[index] for index in final_randomness_indexs]
        if debug: print("final_randomness: ", final_randomness)

        # Drop every queried coset that contains a point of the quotient set.
        # The matching entry of final_randomness is dropped together with it so
        # that position i of the three lists below always refers to the same
        # query.
        quotient_set = verification_state.quotient_set
        new_oracle_points = []
        new_oracle_answers = []
        new_final_randomness = []
        for i in range(len(oracle_points)):
            hits_quotient_set = any(
                quotient_point == point
                for point in oracle_points[i]
                for quotient_point in quotient_set
            )
            if not hits_quotient_set:
                new_oracle_points.append(oracle_points[i])
                new_oracle_answers.append(oracle_answers[i])
                new_final_randomness.append(final_randomness[i])

        if debug: print("new_oracle_points: ", new_oracle_points)
        if debug: print("new_oracle_answers: ", new_oracle_answers)

        if debug: print("quotient_set: ", verification_state.quotient_set)
        oracle_answers = compute_oracle_final(verification_state, proof, new_oracle_points, new_oracle_answers)
        if debug: print("oracle_answers: ", oracle_answers)

        if debug: print("folding_randomness: ", verification_state.folding_randomness)

        if debug: print("\ncompute Ans(X):")
        for i in range(len(new_oracle_points)):
            # Interpolate the oracle over one coset and evaluate the result at
            # the folding randomness to obtain the folded value.
            interpolate_points = [[new_oracle_points[i][j], oracle_answers[i][j]] for j in range(len(new_oracle_points[i]))]
            if debug: print("interpolate_points: ", interpolate_points)
            interpolation_polynomial = interpolation(Fp, interpolate_points)
            ans = interpolation_polynomial.evaluate(verification_state.folding_randomness)
            if debug: print("Ans(X): ", ans)
            folded_answers.append([new_final_randomness[i], ans])
        if debug: print("get folded_answers:")
        if debug: print("folded_answers: ", folded_answers)

        # The final polynomial must agree with every folded answer.
        for folded_point, folded_value in folded_answers:
            if proof.final_polynomial.evaluate(folded_point) != folded_value:
                result = False
        return result

    def verifier_round(self, transcript: MerlinTranscript, round_proof: RoundProof, verification_state: VerificationState, debug=False):
        """Verify one round proof and build the state for the next round.

        Re-derives the round challenges from ``transcript``, recomputes the
        folded answers at the queried points, interpolates Ans(X) through the
        out-of-domain answers (``round_proof.betas``) and the folded answers,
        and compares Ans and Shake at the shake randomness with the
        polynomials supplied in ``round_proof``.

        Args:
            transcript: Transcript shared across rounds; it is advanced here.
            round_proof: Provides ``g_root``, ``betas``, ``queries_to_prev``,
                ``ans_polynomial`` and ``shake_polynomial``.
            verification_state: State left by the previous round (or the
                initial state built in ``verify``).
            debug: When true, print a trace of every step.

        Returns:
            A dict with ``"VerificationState"`` (state for the next round) and
            ``"result"`` (``True`` when both the Ans and the Shake check pass).
        """
        if debug: print("--------- begin verifier round ------------")
        flag = True
        if debug: print("\nrecompute randomness:")
        transcript.append_message(b"merkle_root", round_proof.g_root.encode('ascii'))
        if debug: print("round_proof.g_root: ", round_proof.g_root)
        odd_randomness = [transcript.challenge_bytes(b"odd_randomness", 1)[0] % Fp.order() for _ in range(self.parameters.ood_samples)]
        odd_randomness = list(set(odd_randomness))
        if debug: print("odd_randomness: ", odd_randomness)

        for i in range(len(round_proof.betas)):
            transcript.append_message(f'betas_at_{i}'.encode('ascii'), str(round_proof.betas[i]).encode('ascii'))
        comb_randomness = transcript.challenge_bytes(b"comb_randomness", 1)[0] % Fp.order()
        if debug: print("comb_randomness: ", comb_randomness)
        new_folding_randomness = transcript.challenge_bytes(b"folding_randomness", 1)[0] % Fp.order()
        if debug: print("new_folding_randomness: ", new_folding_randomness)

        # NOTE(review): true division is kept on purpose so the indexes are
        # derived exactly as before; consider `//` if this ever runs outside Sage.
        scaling_factor = verification_state.domain.size() / self.parameters.parameters.folding_factor
        if debug: print("scaling_factor: ", scaling_factor)

        num_repetitions = self.parameters.repetitions[verification_state.num_round]
        if debug: print("num_repetitions: ", num_repetitions)
        stir_randomness_indexes = []
        for _ in range(num_repetitions):
            index = transcript.challenge_bytes(b"stir_randomness_indexes", 1)[0] % scaling_factor
            stir_randomness_indexes.append(index)
        if debug: print("stir_randomness_indexes: ", stir_randomness_indexes)

        # Remove duplicates.
        if debug: print("\nremove duplicates from the stir_randomness_indexes: ")
        stir_randomness_indexes = list(set(stir_randomness_indexes))
        L_k = verification_state.domain.scale_generator_by(self.parameters.parameters.folding_factor)
        # Discard every index whose point of L_k equals an out-of-domain sample.
        # The list is rebuilt by value: popping by `index` would remove by list
        # position (a different element, or an IndexError) while the list is
        # being iterated.
        # NOTE(review): the prover has to derive this list the same way.
        stir_randomness_indexes = [
            index for index in stir_randomness_indexes
            if not any(L_k.backing_domain[index] == odd_random for odd_random in odd_randomness)
        ]
        if debug: print("stir_randomness_index: ", stir_randomness_indexes)

        _shake_randomness = transcript.challenge_bytes(b"_shake_randomness", 1)[0] % Fp.order()
        if debug: print("shake_randomness: ", _shake_randomness)

        if debug: print(verification_state.domain.backing_domain)
        L_i_mins_1 = verification_state.domain.backing_domain

        domain_stacks = stack_evaluations(L_i_mins_1, self.parameters.parameters.folding_factor)
        if debug: print("domain_stacks: ", domain_stacks)

        oracle_points = [domain_stacks[index] for index in stir_randomness_indexes]
        if debug: print("oracle_points: ", oracle_points)

        # compute oracle answers
        if debug: print("round_proof.queries_to_prev[1] = ", round_proof.queries_to_prev[1])

        oracle_answers = compute_oracle_answers(verification_state, round_proof, oracle_points)

        if debug: print("\ncompute oracle answers:")
        if debug: print("oracle_answers", oracle_answers)

        # L_k computed above is reused for the query points.
        stir_randomness = [L_k.backing_domain[index] for index in stir_randomness_indexes]
        if debug: print("stir_randomness: ", stir_randomness)

        if debug: print("\ncompute folded answers by oracle answers")
        folded_answers = []
        for i in range(len(oracle_points)):
            # Interpolate over one coset, then evaluate at the folding randomness.
            interpolate_points = [[oracle_points[i][j], oracle_answers[i][j]] for j in range(len(oracle_points[i]))]
            if debug: print("interpolate_points: ", interpolate_points)
            interpolation_polynomial = interpolation(Fp, interpolate_points)
            ans = interpolation_polynomial.evaluate(verification_state.folding_randomness)
            if debug: print("interpolate polynomial: ", ans)
            folded_answers.append([stir_randomness[i], ans])
        if debug: print("folded_answers: ", folded_answers)

        if debug: print("\nverifier get quotient set and quotient answers:")
        # (point, value) pairs: out-of-domain samples with their betas first,
        # then the folded answers.
        quotient_answers = [
            [odd_randomness[i], round_proof.betas[i]]
            for i in range(len(round_proof.betas))
        ]
        quotient_answers.extend(folded_answers)
        if debug: print("quotient_answers: ", quotient_answers)

        quotient_set = list(odd_randomness) + list(stir_randomness)
        if debug: print("quotient_set: ", quotient_set)

        if debug: print("\nverifier evaluate round_proof.ans(_shake_randomness) and round_proof.shake(_shake_randomness):")
        interpolating_polynomial = round_proof.ans_polynomial
        if debug: print("interpolating_polynomial: ", interpolating_polynomial)
        ans_eval = interpolating_polynomial.evaluate(_shake_randomness)
        if debug: print("ans_eval: ", ans_eval)
        shake_eval = round_proof.shake_polynomial.evaluate(_shake_randomness)
        if debug: print("shake_eval: ", shake_eval)

        if debug: print("\nverifier compute Ans(X) and check Ans(shake_randomness) == proof.ans(shake_randomness):")
        ans_polynomial = interpolation(Fp, quotient_answers)
        if debug: print("Ans(X)'s coefficients: ", ans_polynomial.coefficients)
        verifier_compute_ans_eval = ans_polynomial.evaluate(_shake_randomness)
        if ans_eval != verifier_compute_ans_eval: flag = False
        if debug: print("check result: ", flag)

        # compute Shake(X) = sum over (x, y) in quotient_answers of (Ans(X) - y) / (X - x)
        if debug: print("\nnow veriifier compute Shake(X) by own")
        R.<X> = Fp[]
        ans_X = R(ans_polynomial.coefficients)
        if debug: print("ans(X) = ", ans_X)
        shake_X = 0 * X
        for point in quotient_answers:
            num_polynomial = R(ans_X - point[1])
            den_polynomial = R(X - point[0])
            # Polynomial division: quotient and remainder.
            quotient, remainder = num_polynomial.quo_rem(den_polynomial)
            # Ans interpolates (point[0], point[1]), so the division is exact.
            assert remainder == 0 * X
            shake_X = shake_X + quotient
        if debug: print("Shake(X) = ", shake_X)

        # list(shake_X) keeps zero coefficients in place; shake_X.coefficients()
        # would return only the non-zero ones.
        shake_polynomial = DensePolynomial(list(shake_X))

        if debug: print("\nverifier check Shake(_shake_randomness) == proof.shake(_shake_randomness)")
        verifier_compute_shake_eval = shake_polynomial.evaluate(_shake_randomness)
        if debug: print("verifier_compute_shake_eval: ", verifier_compute_shake_eval)
        if verifier_compute_shake_eval != shake_eval: flag = False
        if debug: print("check answer: ", flag)

        return {
            "VerificationState": VerificationState(
            oracle=OracleType.Virtual,
            domain=verification_state.domain.scale_with_offset(2),
            folding_randomness=new_folding_randomness,
            num_round=verification_state.num_round + 1,
            comb_randomness=comb_randomness,
            quotient_set=quotient_set,
            ans_polynomial=ans_polynomial
            ),
            "result": flag
        }

### Verifier 需要的工具函数

In [49]:
def compute_g_prime_evals(Fp, quotient_set, g_hat_evaluations, ans_polynomial: "DensePolynomial", points):
    r'''
    Evaluate the quotient g'(x) = (\hat{g}(x) - Ans(x)) / prod_{g in G}(x - g)
    at every x in ``points``.

    Fp                = field (callable that coerces a value into the field)
    quotient_set      = G_i, the points that are divided out
    g_hat_evaluations = \hat{g_i}, one value per entry of ``points`` (same order)
    ans_polynomial    = object whose ``evaluate(x)`` returns Ans(x)
    points            = evaluation points; none of them may lie in quotient_set

    Returns the list of g'(x) values as elements of ``Fp``.
    Raises ZeroDivisionError if a point lies in ``quotient_set``.
    '''
    g_prime_evaluations = []
    for i, point in enumerate(points):
        # Vanishing polynomial of the quotient set evaluated at `point`.
        product = 1
        for g in quotient_set:
            product *= point - g
        ans_eval_point = ans_polynomial.evaluate(point)
        # Coerce numerator and denominator into the field *before* dividing so
        # the division is a field division whatever numeric type came in.
        numerator = Fp(g_hat_evaluations[i] - ans_eval_point)
        denominator = Fp(product)
        if denominator == 0:
            raise ZeroDivisionError(
                f"point {point} lies in the quotient set; g'(x) is undefined there"
            )
        g_prime_evaluations.append(numerator / denominator)
    return g_prime_evaluations

In [55]:
def compute_oracle_answers(verification_state: VerificationState, round_proof: RoundProof, oracle_points):
    oracle_answers = []
    match verification_state.oracle:
        case OracleType.Initial:
            # print("OracleType.Initial")
            oracle_answers = round_proof.queries_to_prev[1]
            # print("oracle_answers: ", oracle_answers)
        case OracleType.Virtual:
            # print("OracleType.Virtual")
            # queries_to_prev_ans = [56, 149, 106, 255]
            # 先测试一个
            queries_to_prev = round_proof.queries_to_prev[1]
            # print("queries_to_prev: ", queries_to_prev)
            for j in range(len(queries_to_prev)):
                oracle_answer = []
                g_hat_evaluations = round_proof.queries_to_prev[1][j]
                g_prime_evals = compute_g_prime_evals(Fp, verification_state.quotient_set, g_hat_evaluations, verification_state.ans_polynomial, oracle_points[j])
                # queries_to_prev_ans = round_proof.queries_to_prev[1][0]
                print("\ncompute g'(x) evals: ")
                print("g_prime_evals: ", g_prime_evals)
                
                correct_degree = len(verification_state.quotient_set)
                # print("correct_degree: ", correct_degree)
                for i in range(len(oracle_points[j])):
                    point_x = oracle_points[j][i]
                    r = verification_state.comb_randomness
                    # print("combine randomness: ", r)
                    f_x = g_prime_evals[i]
                    if Fp(point_x * r) == Fp(1):
                        oracle_answer.append(f_x*(correct_degree+1))
                    else:
                        common_factor_inverse = Fp(1 - r * point_x).inverse()
                        e_plus_one = correct_degree + 1
                        r_times_x = Fp(r * point_x)
                        a = Fp(f_x) * (Fp(1) - Fp(r_times_x)**Fp(e_plus_one))
                        answer = Fp(a * common_factor_inverse)
                        # print("answer: ", answer)
                        oracle_answer.append(answer)
                oracle_answers.append(oracle_answer)
        case _:
            print("Wrong case!")
    return oracle_answers


In [56]:
def compute_oracle_final(verification_state: VerificationState, proof: Proof, oracle_points, new_oracle_answers):
    """Return the last virtual oracle's values at the final-round query points.

    Each sent value is turned into a quotient value g'(x) by
    ``compute_g_prime_evals`` and then multiplied by the geometric sum
    1 + (r*x) + ... + (r*x)^e, where r is
    ``verification_state.comb_randomness`` and e is the size of
    ``verification_state.quotient_set``.

    Args:
        verification_state: State left by the last round (quotient set, Ans
            polynomial, combination randomness).
        proof: Not read by this function; kept so existing callers keep working.
        oracle_points: One list of points per queried coset.
        new_oracle_answers: One list of sent values per queried coset, in the
            same order as ``oracle_points``.

    Returns:
        A list with one list of answers per queried coset.

    Reads the module-level field ``Fp``.
    """
    oracle_answers = []
    # These do not depend on the queried point, so compute them once.
    correct_degree = len(verification_state.quotient_set)
    e_plus_one = correct_degree + 1
    r = verification_state.comb_randomness
    for j in range(len(new_oracle_answers)):
        oracle_answer = []
        g_hat_evaluations = new_oracle_answers[j]
        g_prime_evals = compute_g_prime_evals(Fp, verification_state.quotient_set, g_hat_evaluations, verification_state.ans_polynomial, oracle_points[j])
        print("\ncompute g'(x) evals: ")
        print("g_prime_evals: ", g_prime_evals)

        for i in range(len(oracle_points[j])):
            point_x = oracle_points[j][i]
            f_x = g_prime_evals[i]
            r_times_x = Fp(r * point_x)
            if r_times_x == Fp(1):
                # Every term of the geometric sum equals 1.
                oracle_answer.append(f_x*(correct_degree+1))
            else:
                # Closed form (1 - (r*x)^(e+1)) / (1 - r*x).
                common_factor_inverse = Fp(1 - r * point_x).inverse()
                # The exponent is an integer, not a field element: coercing it
                # into Fp would reduce it modulo p.
                a = Fp(f_x) * (Fp(1) - r_times_x**e_plus_one)
                answer = Fp(a * common_factor_inverse)
                oracle_answer.append(answer)
        oracle_answers.append(oracle_answer)
    return oracle_answers

## Example 1 ：Test StirProver prove function

In [73]:
# Example setup: build the evaluation domain, the polynomial f0, its Merkle
# commitment and the prover. Every name defined here is notebook-global and is
# read by the following cells; `Fp` is also read by the verifier code.
p = 257
Fp = GF(p)
g = Fp.multiplicative_generator() # generator of the multiplicative group of Fp
order = g.multiplicative_order()  # multiplicative order of that generator
subgroup_generator = g^(order // 64) # generator of a subgroup of size 64
root_of_unity = subgroup_generator
root_of_unity_inv =  subgroup_generator.inverse()
backing_domain =  [subgroup_generator^i for i in range(64)]
offset = 1

domain = Domain(root_of_unity, root_of_unity_inv, offset, backing_domain)
print("domain: ", domain)

# presumably returns 64 field elements, used here as the coefficients of f0 — verify against squeeze_field_elements
coefficients = squeeze_field_elements(Fp, 64)
f0 = DensePolynomial(coefficients)
f0.coefficients

# Evaluate f0 on every point of the domain.
f0_evals = f0.evaluate_over_domain(domain.backing_domain)
print(f0_evals)

# Group the evaluations with stack_evaluations; the groups become the Merkle leaves.
folding_factor = 4
folded_evals = stack_evaluations(f0_evals, folding_factor)
print("folded_evals for f_0 = ", folded_evals)

f0_merkle_tree = MerkleTree(folded_evals)
print(f0_merkle_tree.root)

# Protocol parameters for this small demo.
parameters = Parameters(security_level=10, 
                        protocol_security_level=128,
                        starting_degree=64,
                        stopping_degree=1,
                        folding_factor=4,
                        starting_rate=2,
                        soundness_type=SoundnessType.Conjecture,
                        fiat_shamir_config=MerlinTranscript(b"initial")
                        )
full_parameters = FullParameters.from_parameters(parameters)
print("full_parameters: ", full_parameters)
stir_prover = StirProver(full_parameters)   


# Commit to f0 over the domain; the verifier later receives commitment.root.
ans = stir_prover.commit(domain, f0)
commitment = ans["commitment"]
print(commitment.root)

domain:  Domain(root_of_unity=81, root_of_unity_inv=165, backing_domain=[1, 81, 136, 222, 249, 123, 197, 23, 64, 44, 223, 73, 2, 162, 15, 187, 241, 246, 137, 46, 128, 88, 189, 146, 4, 67, 30, 117, 225, 235, 17, 92, 256, 176, 121, 35, 8, 134, 60, 234, 193, 213, 34, 184, 255, 95, 242, 70, 16, 11, 120, 211, 129, 169, 68, 111, 253, 190, 227, 140, 32, 22, 240, 165])
[256, 62, 51, 219, 101, 243, 28, 248, 16, 132, 40, 189, 240, 211, 111, 236, 208, 122, 148, 229, 32, 71, 161, 48, 162, 206, 80, 225, 153, 207, 17, 89, 199, 115, 173, 118, 199, 83, 25, 179, 57, 215, 208, 161, 91, 170, 213, 91, 61, 178, 177, 225, 169, 150, 241, 168, 33, 163, 196, 166, 175, 72, 96, 80]
folded_evals for f_0 =  [[256, 208, 199, 61], [62, 122, 115, 178], [51, 148, 173, 177], [219, 229, 118, 225], [101, 32, 199, 169], [243, 71, 83, 150], [28, 161, 25, 241], [248, 48, 179, 168], [16, 162, 57, 33], [132, 206, 215, 163], [40, 80, 208, 196], [189, 225, 161, 166], [240, 153, 91, 175], [211, 207, 170, 72], [111, 17, 213, 96],

In [74]:
# Build the prover's witness from the objects of the setup cell and produce
# the proof; the second argument (True) enables the prover's debug trace.
witness = Witness(domain, f0, f0_merkle_tree, folded_evals)
proof = stir_prover.prove(witness, True)

folding_randomness:  207
---------------------- begin proof ----------------------
num_rounds =  2
--------- prove round begin ---------

1. Send folded function 
Before folded, f(X) =  207*X^63 + 146*X^62 + 247*X^61 + 142*X^60 + 118*X^59 + 46*X^58 + 188*X^57 + 43*X^56 + 191*X^55 + 76*X^54 + 24*X^53 + 103*X^52 + 245*X^51 + 140*X^50 + 222*X^49 + 155*X^48 + 123*X^47 + 243*X^46 + 39*X^45 + 80*X^44 + 96*X^43 + 70*X^42 + 169*X^41 + 242*X^40 + 54*X^39 + 75*X^38 + 152*X^37 + 234*X^36 + 30*X^35 + 39*X^34 + 66*X^33 + 218*X^32 + 231*X^31 + 153*X^30 + 124*X^29 + 132*X^28 + 212*X^27 + 222*X^26 + 239*X^25 + 54*X^24 + 91*X^23 + 170*X^22 + 56*X^21 + 32*X^20 + 40*X^19 + 252*X^18 + 22*X^17 + 148*X^16 + 123*X^15 + 149*X^14 + 228*X^13 + 218*X^12 + 17*X^11 + 6*X^10 + 254*X^9 + 36*X^8 + 68*X^7 + 219*X^6 + 25*X^5 + 144*X^4 + 86*X^3 + 225*X^2 + 25*X + 256
After folded, g(X) =  205*X^15 + 17*X^14 + 51*X^13 + 220*X^12 + 163*X^11 + 100*X^10 + 82*X^9 + 243*X^8 + 162*X^7 + 104*X^6 + 58*X^5 + 104*X^4 + X^3 + 156*X

In [75]:
# Verify the proof against the commitment with the same parameters as the
# prover; debug=True prints the verifier's trace.
stir_verifier = StirVerifier(full_parameters)
result = stir_verifier.verify(commitment, proof, domain, debug=True)
print("result: ", result)

--------- verifier decision phase ---------

1. chekck final polynomial degree
check answer:  True

2. verify merkle path
current root:  3e46d093c69bb2e7609ef77fcbba8bb05c65f5002695be1f9a17052aa4306758
round_proof:  RoundProof(g_root=77bb2506b4ce312581093ccbfecb5463556d1eb69ae2d01c42de7dfe16750287, betas=[103, 140], queries_to_prev=[[9, 10, 3], [[132, 206, 215, 163], [40, 80, 208, 196], [219, 229, 118, 225]], [['3776ac6941c04d68a20a23296c46c6bac854a4d5026b2cffcdf5dfbcd4beabe4', '7a46e5632b2d19494d728638f969960a5fe19f9300f74c520afab219902e1571', '29a13f9ff58162a10ce98abe969ad165254e896131f60473fb74f5bf05fcd72a', '1ba0ca310af7ce43fed420982a869da3ed26bc55bad3613b451dd30c6ed88a53'], ['3776ac6941c04d68a20a23296c46c6bac854a4d5026b2cffcdf5dfbcd4beabe4', '7a46e5632b2d19494d728638f969960a5fe19f9300f74c520afab219902e1571', 'e2baac975d0a07ccf475264716e5e8eb73ade4b42f8920eb066496dda70f4067', '3e3eea51a5ad76f1da797791552a0154634735b1af2d21848bf923834cf108ac'], ['f631f187faf0652dc530c82e4023be676237