In [5]:
from __future__ import annotations

import random
import secrets
import time
from typing import List, Tuple

from mpyc.runtime import mpc
from py_ecc.optimized_bls12_381 import G1, G2, curve_order, add, multiply, pairing, final_exponentiate, FQ12
from py_ecc.optimized_bls12_381.optimized_curve import neg

In [6]:
class BBSPlus:
    """Baseline BBS+ signatures over BLS12-381 using one scalar multiplication per term.

    Parameters are sampled fresh per instance:

    * ``H[0]`` is the generator paired with the blinding factor ``s``;
      ``H[1..ell]`` are the per-message generators.
    * ``x`` is the secret key and ``X = x * G2`` the public key.

    A signature on ``messages`` is ``(A, e, s)`` with
    ``A = B / (x + e)`` where ``B = g1 + s*H[0] + sum_i m_i*H[i+1]``;
    verification checks ``e(X + e*g2, A) == e(g2, B)``.

    NOTE(review): ``H`` is built as random multiples of ``G1``. The scalars are
    discarded, but a production scheme should derive generators with
    hash-to-curve so that nobody could ever have known their discrete logs.
    """

    def __init__(self, ell: int):
        self.ell = ell
        self.g1 = G1
        self.g2 = G2
        # ell message generators plus one generator for the blinding factor s.
        self.H = [multiply(G1, secrets.randbelow(curve_order)) for _ in range(ell + 1)]
        self.x = secrets.randbelow(curve_order)
        self.X = multiply(G2, self.x)

    def _commitment(self, messages: list[int], s: int):
        """Return B = g1 + s*H[0] + sum_i messages[i]*H[i+1] (naive, term by term).

        All scalars are reduced mod ``curve_order`` first, so out-of-range or
        negative values coming from an untrusted signature are normalized
        before reaching ``multiply``.
        """
        B = add(self.g1, multiply(self.H[0], s % curve_order))
        for mi, Hi in zip(messages, self.H[1:]):
            B = add(B, multiply(Hi, mi % curve_order))
        return B

    def sign(self, messages: list[int]) -> tuple:
        """Sign exactly ``ell`` messages and return ``(A, e, s)``.

        Raises:
            ValueError: if ``len(messages) != ell``.
        """
        if len(messages) != self.ell:
            raise ValueError(f"Expected {self.ell} messages, got {len(messages)}")
        # Resample e in the (negligible) event x + e == 0 mod r, which has no inverse.
        e = secrets.randbelow(curve_order)
        while (self.x + e) % curve_order == 0:
            e = secrets.randbelow(curve_order)
        s = secrets.randbelow(curve_order)
        B = self._commitment(messages, s)
        inv = pow((self.x + e) % curve_order, -1, curve_order)
        A = multiply(B, inv)
        return (A, e, s)

    def verify(self, messages: list[int], signature: tuple) -> bool:
        """Return True iff ``signature = (A, e, s)`` is valid for ``messages``.

        Returns False (instead of raising) for a wrong message count or for a
        signature that does not unpack into exactly three components.
        """
        if len(messages) != self.ell:
            return False
        try:
            A, e, s = signature
        except (TypeError, ValueError):
            return False
        B = self._commitment(messages, s)
        left_input = add(self.X, multiply(self.g2, e % curve_order))  # X + e*g2, a G2 element
        left = pairing(left_input, A)
        right = pairing(self.g2, B)
        return left == right

In [7]:
def _ec_add(P, Q):
    """Add two G1 points, with None as the point-at-infinity."""
    if P is None:
        return Q
    if Q is None:
        return P
    return add(P, Q)

def msm_g1(bases: List[Tuple], scalars: List[int], *, window: int = 4, g1=G1):
    """Windowed bucket-method (Pippenger-style) multi-scalar multiplication on G1.

    Returns ``sum_i scalars[i] * bases[i]``. Scalars are reduced modulo
    ``curve_order`` up front, so negative or oversized scalars give the
    mathematically correct result. This is a compact educational MSM; for
    production use the blst/relic/mcl MSM APIs.

    Args:
        bases: G1 points (py_ecc optimized-curve tuples).
        scalars: integer scalars, one per base.
        window: bits per window (>= 1); each window uses ``2**window - 1`` buckets.
        g1: a G1 point used only to construct the point at infinity.

    Returns:
        The G1 point ``sum_i scalars[i] * bases[i]``; the point at infinity
        when ``bases`` is empty or every scalar is 0 mod ``curve_order``.

    Raises:
        ValueError: if ``len(bases) != len(scalars)`` or ``window < 1``.
    """
    if len(bases) != len(scalars):
        raise ValueError(f"bases and scalars differ in length: {len(bases)} != {len(scalars)}")
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    # multiply(P, 0) yields the point at infinity; the previous
    # add(g1, multiply(g1, 0)) evaluated to g1 itself, not the identity.
    infinity = multiply(g1, 0)
    if not bases:
        return infinity

    # Reduce once so the bit-length estimate and the digit extraction agree.
    reduced = [int(s) % curve_order for s in scalars]
    w = window
    n_buckets = 1 << w
    mask = n_buckets - 1
    max_bits = max(s.bit_length() for s in reduced) or 1
    n_windows = (max_bits + w - 1) // w

    R = None  # accumulator; None is the point at infinity
    for j in range(n_windows - 1, -1, -1):
        # R <<= w, i.e. R = 2**w * R.
        if R is not None:
            R = multiply(R, n_buckets)
        # Put each base into the bucket indexed by its j-th window digit.
        buckets = [None] * n_buckets
        shift = j * w
        for base, s in zip(bases, reduced):
            d = (s >> shift) & mask
            if d:
                buckets[d] = _ec_add(buckets[d], base)
        # Running-sum trick: sum_b b*bucket[b] == sum_b (sum_{k>=b} bucket[k]).
        # The running sum must be added for EVERY b, including empty buckets;
        # skipping them under-weights every digit above an empty bucket.
        running = None
        win_sum = None
        for b in range(n_buckets - 1, 0, -1):
            if buckets[b] is not None:
                running = _ec_add(running, buckets[b])
            if running is not None:
                win_sum = _ec_add(win_sum, running)
        R = _ec_add(R, win_sum)
    return infinity if R is None else R


def compute_B(g1, H: List[Tuple], messages: List[int], s: int):
    """Compute the BBS+ commitment B = g1 + s*H[0] + sum_i messages[i]*H[i+1] with one MSM.

    ``g1`` is paired with the constant scalar 1, ``H[0]`` with the blinding
    factor ``s`` and ``H[1:]`` with the messages; every variable scalar is
    reduced mod ``curve_order`` before the MSM.
    """
    reduced_msgs = [m % curve_order for m in messages]
    coeffs = [1, s % curve_order] + reduced_msgs
    points = [g1] + H
    return msm_g1(points, coeffs, window=4, g1=g1)

class BBSPlusOptimized:
    """BBS+ signatures over BLS12-381 using a single MSM for B and one shared final exponentiation.

    Same parameter layout as the baseline scheme: ``H[0]`` is paired with the
    blinding factor ``s``, ``H[1..ell]`` with the ``ell`` messages, ``x`` is the
    secret key and ``X = x * G2`` the public key.
    """

    def __init__(self, ell: int):
        self.ell = ell
        self.g1, self.g2 = G1, G2
        # ell message generators plus one generator for the blinding factor s.
        self.H = [multiply(G1, secrets.randbelow(curve_order)) for _ in range(ell + 1)]
        self.x = secrets.randbelow(curve_order)
        self.X = multiply(G2, self.x)

    def sign(self, messages: List[int]):
        """Sign exactly ``ell`` messages; returns ``(A, e, s)`` with ``A = B / (x + e)``.

        Raises:
            ValueError: if ``len(messages) != ell``.
        """
        if len(messages) != self.ell:
            raise ValueError(f"Expected {self.ell} messages, got {len(messages)}")
        e = secrets.randbelow(curve_order)
        s = secrets.randbelow(curve_order)
        commitment = compute_B(self.g1, self.H, messages, s)
        scale = pow((self.x + e) % curve_order, -1, curve_order)
        return (multiply(commitment, scale), e, s)

    def verify(self, messages: List[int], signature: Tuple) -> bool:
        """Check ``e(X + e*g2, A) == e(g2, B)`` as ``e(X + e*g2, A) * e(g2, -B) == 1``.

        Both Miller loops run without final exponentiation; their product is
        exponentiated once and compared with the FQ12 identity. Returns False
        when the message count does not match ``ell``.
        """
        A, e, s = signature
        if len(messages) != self.ell:
            return False
        commitment = compute_B(self.g1, self.H, messages, s)
        key_side = add(self.X, multiply(self.g2, e % curve_order))
        lhs_loop = pairing(key_side, A, final_exponentiate=False)
        rhs_loop = pairing(self.g2, neg(commitment), final_exponentiate=False)
        return final_exponentiate(lhs_loop * rhs_loop) == FQ12.one()

In [8]:
def bench(ell, trials=10):
    """Time baseline vs. optimized BBS+ sign/verify for ``ell`` messages and print per-call means.

    Before timing, the warm-up doubles as a correctness check: each scheme
    must accept its own signature, and a baseline verifier holding the
    optimized instance's public parameters must accept the optimized
    signature. The cross-check catches an optimized commitment that is
    self-consistent (sign and verify agree) but differs from the textbook B.

    Args:
        ell: number of messages per signature.
        trials: timed repetitions per operation (must be >= 1).

    Raises:
        ValueError: if ``trials < 1``.
        AssertionError: if any correctness check fails.
    """
    if trials < 1:
        raise ValueError(f"trials must be >= 1, got {trials}")
    msgs = [random.randrange(curve_order) for _ in range(ell)]
    base = BBSPlus(ell)
    opt = BBSPlusOptimized(ell)

    # Warm-up + correctness: timings are meaningless if verification fails.
    sig0 = base.sign(msgs)
    sig1 = opt.sign(msgs)
    assert base.verify(msgs, sig0), "baseline signature rejected by baseline verify"
    assert opt.verify(msgs, sig1), "optimized signature rejected by optimized verify"
    # Baseline verifier with opt's public parameters: the MSM-based B must equal the naive B.
    reference = BBSPlus(ell)
    reference.H, reference.X = opt.H, opt.X
    assert reference.verify(msgs, sig1), "optimized signature rejected by textbook verification"

    def per_call(fn, *args):
        """Mean wall-clock seconds of ``fn(*args)`` over ``trials`` calls."""
        start = time.perf_counter()
        for _ in range(trials):
            fn(*args)
        return (time.perf_counter() - start) / trials

    sign_base = per_call(base.sign, msgs)
    sign_opt = per_call(opt.sign, msgs)
    verify_base = per_call(base.verify, msgs, sig0)
    verify_opt = per_call(opt.verify, msgs, sig1)

    print(f"ell={ell}, trials={trials}")
    print(f"  sign:  base={sign_base:.4f}s  opt={sign_opt:.4f}s")
    print(f"  verify:base={verify_base:.4f}s  opt={verify_opt:.4f}s")

# Sweep the message count; each bench() call prints baseline vs. optimized timings.
for ell in (3, 8, 16, 32, 64):
    bench(ell, trials=5)

ell=3, trials=5
  sign:  base=0.0326s  opt=0.0173s
  verify:base=0.8048s  opt=0.2187s
ell=8, trials=5
  sign:  base=0.0654s  opt=0.0263s
  verify:base=0.8480s  opt=0.2306s
ell=16, trials=5
  sign:  base=0.1214s  opt=0.0383s
  verify:base=0.8917s  opt=0.2475s
ell=32, trials=5
  sign:  base=0.2226s  opt=0.0590s
  verify:base=1.0005s  opt=0.2626s
ell=64, trials=5
  sign:  base=0.4296s  opt=0.0970s
  verify:base=1.2127s  opt=0.3000s
