In [15]:
import hashlib

In [16]:
def init_schnorr_group(bits=128):
    """Generate random Schnorr-group parameters.

    Repeatedly draws a prime ``q < 2**bits`` and a cofactor ``r`` in
    ``[1, 2**bits)`` until ``p = q*r + 1`` is prime. It then picks ``h`` in
    ``[2, p)`` with ``h^r != 1 (mod p)`` and sets ``g = h^r mod p``. Since
    ``g^q = h^(p-1) = 1`` and ``g != 1``, ``g`` generates the subgroup of
    prime order ``q`` in ``Z_p^*``.

    Args:
        bits: size bound (in bits) for ``q`` and ``r``. Defaults to the
            previously hard-coded 128.

    Returns:
        Tuple ``(q, r, p, h, g)``. ``g`` is whatever ``pow(h, r, p)`` returns.
        Under SageMath this is presumably an element of ``Zmod(p)``, which
        later cells rely on for modular ``g^e``; confirm in your kernel.
    """
    bound = 2**bits
    while True:
        q = random_prime(bound)
        r = randrange(1, bound)
        p = q * r + 1
        # random_prime already yields a prime q; the check is a cheap guard.
        if is_prime(q) and is_prime(p):
            break

    while True:
        h = randrange(2, p)
        g = pow(h, r, p)
        # h^r == 1 would give the trivial element, not a generator of order q.
        if g != 1:
            return q, r, p, h, g

In [17]:
# Generate fresh Schnorr-group parameters and display them.
(q, r, p, h, g) = init_schnorr_group()
(q, r, p, h, g)

(100046075239591942434489303188176921637,
 242424725939097071700735228314386621910,
 24253642371240362049549125591238689522767925366065780820002063680879217266671,
 23194359962280232696940270728863367210233245942733215658717754642643137352960,
 2196910842516590547964977442169540993197785195830411487960553675044397505600)

In [167]:
# c: number of bits used to encode each price; n: number of users.
c, n = 10, 3

In [168]:
class User:
    """A bidder holding a secret price and its per-bit protocol secrets.

    Args:
        uid: identifier the user posts under on the bulletin board.
        c: number of bits used to encode the price.
    """

    def __init__(self, uid, c):
        self.uid = uid
        self.c = c

    def set_price(self, price):
        """Store ``price`` and split it into ``c`` bits, LSB first.

        ``price_bits[0]`` is the LSB and ``price_bits[c - 1]`` is the MSB.
        Each entry is a dict that starts as ``{"bit": int}``. It is later
        extended with:

        * ``"alpha"``/``"beta"`` by ``commit``
        * ``"x"``/``"r"`` by ``round1``
        * ``"d"`` by ``round2_add_d``

        Raises:
            ValueError: if ``price`` is outside ``[0, 2**c)``. Such a price
                would need more than ``c`` bits (or a sign), so it would not
                fit the ``c`` per-bit slots used everywhere else.
        """
        if not 0 <= price < 2**self.c:
            raise ValueError(f"price must be in [0, 2**{self.c}), got {price}")
        self.price = price
        # format() yields MSB-first text, so reverse it to index from the LSB.
        self.price_bits = tuple(
            {"bit": int(b)} for b in reversed(format(price, f"0{self.c}b"))
        )

    def commit(self, bit_idx, alpha, beta):
        """Remember the commitment secrets ``alpha``, ``beta`` for ``bit_idx``."""
        self.price_bits[bit_idx]['alpha'] = alpha
        self.price_bits[bit_idx]['beta'] = beta

    def round1(self, bit_idx, x, r):
        """Remember the round-1 secrets ``x``, ``r`` for ``bit_idx``."""
        self.price_bits[bit_idx]['x'] = x
        self.price_bits[bit_idx]['r'] = r

    def round2_add_d(self, bit_idx, d):
        """Remember the round-2 decision bit ``d`` for ``bit_idx``."""
        self.price_bits[bit_idx]['d'] = d

    def get_price_bit(self, bit_idx):
        """Return the price bit at ``bit_idx`` (0 is the LSB)."""
        return self.price_bits[bit_idx]['bit']

    def get_price_d(self, bit_idx):
        """Return the ``d`` stored by ``round2_add_d`` for ``bit_idx``."""
        return self.price_bits[bit_idx]['d']

    def get_x(self, bit_idx):
        """Return the ``x`` stored by ``round1`` for ``bit_idx``."""
        return self.price_bits[bit_idx]['x']

    def get_r(self, bit_idx):
        """Return the ``r`` stored by ``round1`` for ``bit_idx``."""
        return self.price_bits[bit_idx]['r']

In [169]:
class BulletinBoard:
    """Public board on which every user posts its per-bit protocol values.

    ``data`` maps each ``uid`` to a tuple of ``c`` dicts, one per price bit
    (index 0 is the LSB, matching ``User.price_bits``). The per-bit dicts are
    filled by:

    * ``commit``: ``"epsilon"``, the commitment tuple for that bit
    * ``round1``: ``"X"`` and ``"R"``, the values published in round 1
    * ``round2``: ``"b"``, the round-2 cryptogram
    """

    def __init__(self, c):
        self.c = c
        # Number of registered users; kept equal to len(self.data).
        self.n = 0
        self.data = {}

    def join(self, uid):
        """Register ``uid`` with ``c`` empty per-bit slots (re-joining resets them)."""
        self.data[uid] = tuple({} for _ in range(self.c))
        # Derive the count from the dict so a repeated join cannot over-count.
        self.n = len(self.data)

    def commit(self, uid, bit_idx, epsilon):
        """Post the commitment ``epsilon`` of ``uid`` for ``bit_idx``."""
        self.data[uid][bit_idx]["epsilon"] = epsilon

    def round1(self, uid, bit_idx, X, R):
        """Post the round-1 values ``X`` and ``R`` of ``uid`` for ``bit_idx``."""
        self.data[uid][bit_idx]["X"] = X
        self.data[uid][bit_idx]["R"] = R

    def round2(self, uid, bit_idx, b):
        """Post the round-2 cryptogram ``b`` of ``uid`` for ``bit_idx``."""
        self.data[uid][bit_idx]["b"] = b

    def get_g_x_all(self, bit_idx):
        """Return ``{uid: X}`` for ``bit_idx`` over all registered users.

        Raises KeyError if some user has not posted round-1 values yet.
        """
        return {uid: bits[bit_idx]["X"] for uid, bits in self.data.items()}

    def get_b_all(self, bit_idx):
        """Return ``{uid: b}`` for ``bit_idx`` over all registered users.

        Raises KeyError if some user has not posted a cryptogram yet.
        """
        return {uid: bits[bit_idx]["b"] for uid, bits in self.data.items()}

In [170]:
pub_board = BulletinBoard(c)
users = {}

for uid in range(n):
    user = User(uid, c)
    # randrange's upper bound is exclusive: 2**c admits every c-bit price,
    # including the maximum 2**c - 1 that randrange(0, 2**c - 1) never drew.
    user.set_price(randrange(0, 2**c))
    pub_board.join(uid)
    users[uid] = user
    print(user.uid, user.price, user.price_bits)

0 627 ({'bit': 1}, {'bit': 1}, {'bit': 0}, {'bit': 0}, {'bit': 1}, {'bit': 1}, {'bit': 1}, {'bit': 0}, {'bit': 0}, {'bit': 1})
1 417 ({'bit': 1}, {'bit': 0}, {'bit': 0}, {'bit': 0}, {'bit': 0}, {'bit': 1}, {'bit': 0}, {'bit': 1}, {'bit': 1}, {'bit': 0})
2 561 ({'bit': 1}, {'bit': 0}, {'bit': 0}, {'bit': 0}, {'bit': 1}, {'bit': 1}, {'bit': 0}, {'bit': 0}, {'bit': 0}, {'bit': 1})


In [171]:
def commit(user):
    """Commit every price bit of ``user`` to the global ``pub_board``.

    For each bit, fresh secret exponents ``alpha`` and ``beta`` are stored on
    the user. The commitment ``gen_epsilon(alpha, beta, bit)`` is then posted
    under the user's own uid. Relies on the notebook globals ``q`` (order of
    ``g``) and ``pub_board``.
    """
    print(f"commit uid {user.uid} price: {user.price} to the bulletin board")

    # iterate price bit array from LSB
    for bit_idx, price_bit in enumerate(user.price_bits):
        # alpha, beta are exponents of g, whose order is q, so draw them from
        # Z_q rather than taking group elements g^k and reusing those as exponents.
        alpha = randrange(1, q)
        beta = randrange(1, q)
        user.commit(bit_idx, alpha, beta)

        # generate epsilon
        epsilon = gen_epsilon(alpha, beta, price_bit['bit'])

        # push to bulletin board under this user's uid, not a global loop variable
        pub_board.commit(user.uid, bit_idx, epsilon)

def gen_epsilon(alpha, beta, price_bit):
    """Return the bit commitment ``(g^(alpha*beta + bit), g^alpha, g^beta)``.

    ``alpha`` and ``beta`` are lifted to plain integers, and the exponent
    product is reduced modulo the group order ``q``. This makes the first
    component equal ``(g^alpha)^beta * g^bit``. Multiplying them as ``Zmod(p)``
    elements would reduce mod ``p`` instead and break that relation.
    Relies on the notebook globals ``g`` and ``q``.
    """
    a, b = int(alpha), int(beta)
    return (g**((a * b + price_bit) % q), g**a, g**b)

# Commit every registered user's price bits, in the users' insertion order.
for uid in users:
    u = users[uid]
    commit(u)

# TODO: NIZK proof of well formedness

commit uid 0 price: 627 to the bulletin board
commit uid 1 price: 417 to the bulletin board
commit uid 2 price: 561 to the bulletin board


In [173]:
# Index of the most recent bit whose verification gave T == 1 (None until the
# first such bit). It is updated by the bit loop below and read by round2().
junction_idx = None

def round1(user, bit_idx):
    """Round 1 for one bit: draw secret exponents and publish their public keys.

    Stores ``x`` and ``r`` on ``user`` and posts ``X = g^x`` and ``R = g^r`` to
    the global ``pub_board`` under ``user.uid``. Relies on the notebook globals
    ``g`` and ``q`` (order of ``g``).
    """
    # x, r are secret exponents of g, so draw them from Z_q
    x = randrange(1, q)
    r = randrange(1, q)
    user.round1(bit_idx, x, r)

    # publish public keys under this user's uid, not a global loop variable
    pub_board.round1(user.uid, bit_idx, g**x, g**r)

    # TODO: publish NIZK proofs of knowledge of x = log_g X and r = log_g R

"""
g_x: {
    uid: int
}
"""
def round2_calc_g_y(g_x_all: dict, uid: int):
    g_x_lt_uid = Zmod(p)(1)
    g_x_gt_uid = Zmod(p)(1)

    for i, g_xi in g_x_all.items():
        if i < uid:
            g_x_lt_uid *= g_xi
        elif i > uid:
            g_x_gt_uid *= g_xi

    return g_x_lt_uid * g_x_gt_uid^(-1)

def round2(user: User, bit_idx):
    """Round 2 for one bit: derive ``d`` and publish the matching cryptogram.

    ``d`` depends on whether a junction has been reached yet:

    * Before the first junction, ``d`` is the user's price bit.
    * Afterwards, ``d`` is the price bit ANDed with the user's ``d`` at the
      global ``junction_idx``. A user whose ``d`` was 0 there keeps ``d = 0``.

    The cryptogram is ``g_y^x`` when ``d == 0`` and ``g^(x*r)`` when
    ``d == 1``. It is posted to ``pub_board`` under ``user.uid``. Reads the
    notebook globals ``junction_idx``, ``pub_board`` and ``g``.
    """
    price_bit = user.get_price_bit(bit_idx)
    x, r = user.get_x(bit_idx), user.get_r(bit_idx)

    # store d
    if junction_idx is None:
        price_d = price_bit
    else:
        price_d = user.get_price_d(junction_idx) & price_bit
    user.round2_add_d(bit_idx, price_d)

    # generate cryptogram (price_d is always 0 or 1 here)
    if price_d == 0:  # 0-cryptogram
        g_x_all = pub_board.get_g_x_all(bit_idx)
        g_y = round2_calc_g_y(g_x_all, user.uid)
        b = g_y**x
    else:  # 1-cryptogram
        # NOTE(review): anonymous-veto style protocols typically use g_y^r here;
        # g^(x*r) is kept as originally written. Confirm against the spec.
        b = g**(x * r)

    # publish under this user's uid, not a global loop variable
    pub_board.round2(user.uid, bit_idx, b)


# Process bits from the MSB (c - 1) down to the LSB (0).
for bit_idx in range(c - 1, -1, -1):
    # Round 1: every user publishes fresh public keys for this bit.
    for uid, u in users.items():
        round1(u, bit_idx)

    # Round 2: every user publishes its cryptogram for this bit.
    for uid, u in users.items():
        round2(u, bit_idx)

    # verify: multiply all cryptograms. The product is 1 when every user sent
    # a 0-cryptogram, and (with overwhelming probability) differs from 1 otherwise.
    b_all = pub_board.get_b_all(bit_idx)
    tmp = Zmod(p)(1)
    for uid in b_all:
        b = b_all[uid]
        tmp *= b
    T = 0 if tmp == 1 else 1

    # A bit on which somebody still had d == 1 becomes the new junction.
    if T == 1:
        junction_idx = bit_idx

    print(bit_idx, T)

9 1
8 0
7 0
6 1
5 1
4 1
3 0
2 0
1 1
0 1
