# BLS Demo

In this notebook we try to illustrate the BLS distributed key generation and verifiable randomness protocol.

In [1]:
import bn256
import random
import os
import binascii
import hashlib
import collections
import ecdsa
import unittest
import numpy as np
from pypoly import Polynomial

### Initialization
 
We introduce $N$ participants and a threshold $t$, where 
$N = 2n + 1 = n + (n + 1)$ and $t = n + 1$; $n$ is the degree of the polynomials, so any $t = n + 1$ shares suffice to interpolate such a polynomial.
  
Throughout, we write elliptic-curve points as scalar multiples of a generator (i.e. $a \cdot G$), which is the usual additive notation for elliptic-curve groups.

In [2]:
n = 10          # degree of every participant's random polynomial
t = n + 1       # threshold: number of shares needed to reconstruct / sign
N = n + t       # number of participants, i.e. N = 2n + 1

### Necessary functions: batch 1

In [3]:
### Following three functions taken straight from the BN256 pairings code

def bls_keygen():
    """Draw a fresh BLS key pair via ``bn256.g2_random()``.

    Returns:
        The two values produced by ``bn256.g2_random()`` as a 2-tuple
        (presumably the private scalar k and the public point k * G2 --
        TODO confirm against the bn256 module).
    """
    secret, public = bn256.g2_random()
    return secret, public

def bls_sign(privkey, msg):
    """Produce a compressed BLS signature share on G1.

    The message is hashed onto G1, asserted to lie on the curve, multiplied
    by the private scalar, and the resulting point is compressed.

    Args:
        privkey: private key scalar.
        msg: message bytes, passed to ``bn256.g1_hash_to_point``.

    Returns:
        ``bn256.g1_compress(privkey * H(msg))``.
    """
    hashed = bn256.g1_hash_to_point(msg)
    assert hashed.is_on_curve()
    signature_point = hashed.scalar_mul(privkey)
    return bn256.g1_compress(signature_point)

def bls_verify(pubkey, msg, csig):
    """Check a compressed BLS signature against a G2 public key.

    Tests the pairing equation e(pubkey, H(msg)) == e(G2, sig) with
    ``bn256.optimal_ate``.

    Args:
        pubkey: public key; must be exactly of type ``bn256.curve_twist``.
        msg: message bytes, hashed onto G1 with ``bn256.g1_hash_to_point``.
        csig: compressed G1 signature, as produced by ``bls_sign``.

    Returns:
        True if both pairings agree, False otherwise.

    Raises:
        AssertionError: if the key or the uncompressed signature has the wrong
        type, or if the hashed message is not on the curve.
    """
    sig = bn256.g1_uncompress(csig)

    assert type(pubkey) is bn256.curve_twist
    assert type(sig) is bn256.curve_point

    hashed = bn256.g1_hash_to_point(msg)
    assert hashed.is_on_curve()

    lhs = bn256.optimal_ate(pubkey, hashed)
    rhs = bn256.optimal_ate(bn256.twist_G, sig)
    return lhs == rhs

####

def mod_q(num): # q = order: order of G1, G2
    """Reduce an integer modulo q = bn256.order, the order of G1 and G2.

    The result lies in [0, q), also for negative inputs.
    """
    return num % bn256.order

####

def generate_random_number(n=32) -> bytes: # Taken straight from someone else's notebook
    """Return ``n`` random bytes read from ``os.urandom``.

    Temporary implementation: the quality of the randomness depends on the
    operating system's implementation of ``os.urandom``.
    """
    random_bytes = os.urandom(n)
    return random_bytes

#####

def rnd_poly_coeffs(d, q=None):
    """Return d + 1 uniformly random polynomial coefficients in [0, q).

    The coefficients a_0, ..., a_d (lowest degree first) define a polynomial
    of degree d over Z_q.  They are drawn uniformly below the group order q
    with the CSPRNG from ``secrets``.  Using raw 32-byte integers instead
    (as before) gives a biased secret once scalar multiplication implicitly
    reduces it mod q, because 2**256 is not a multiple of q.

    Args:
        d: polynomial degree (non-negative int).
        q: modulus to sample below; defaults to ``bn256.order``.

    Returns:
        List of d + 1 ints.
    """
    import secrets

    if q is None:
        q = bn256.order
    return [secrets.randbelow(q) for _ in range(d + 1)]

def poly(coeffs, x): # Calculate individual terms of the form coeff_j * x^j as well as the sum p(x) = coeff_0 + ... coeff_n * x^n
    """Evaluate the polynomial with the given coefficients at ``x``.

    Args:
        coeffs: coefficients, lowest degree first (coeffs[j] multiplies x**j).
        x: evaluation point.

    Returns:
        (terms, total): terms[j] == coeffs[j] * x**j, and total is their sum
        (0 for an empty coefficient list).
    """
    terms = []
    for power, c in enumerate(coeffs):
        terms.append(c * x**power)
    total = sum(terms)
    return terms, total

### Necessary functions: batch 2

In [4]:
def lagrange_coeff(i, lst, t): # Lagrange coefficients, used for Lagrange interpolation
    """Lagrange coefficient lambda_i for interpolation at x = 0, mod q.

    lambda_i = prod_{j in S, j != i} j / (j - i) over participant IDs, where
    S holds the IDs of the signers in ``lst`` (array index + 1) and
    q = bn256.order.  Division is done with Fermat inverses
    (a**(q-2) mod q).

    Args:
        i: array index of the participant whose coefficient is wanted
           (its participant ID is i + 1); must occur in ``lst``.
        lst: array indices of the participating signers.
        t: expected number of signers; must equal len(lst).

    Returns:
        lambda_i reduced into [0, q).

    Raises:
        AssertionError: if t != len(lst).
        ValueError: if i is not in lst.
    """
    assert (t == len(lst))
    q = bn256.order
    own_id = i + 1                          # participant ID of array index i
    other_ids = [idx + 1 for idx in lst]    # IDs of all signers ...
    other_ids.remove(own_id)                # ... except this one
    acc = 1
    for other in other_ids:
        inv_diff = pow(other - own_id, q - 2, q)   # (other - own_id)^{-1} mod q
        acc *= mod_q(other * inv_diff)
    return mod_q(acc)
    
# Introducing the generators: fixed base points of the two pairing groups,
# G1 for EC1 (curve_G) and G2 for EC2 (twist_G).
G1, G2 = bn256.curve_G, bn256.twist_G

def G2mult(coeff): # Multiply generator of curve 2 (G2) by an integer
    """Return coeff * G2 using bn256's fixed-base scalar multiplication."""
    point = bn256.g2_scalar_base_mult(coeff)
    return point

def G2mult2(coeff_list): # Create list made of generator of curve 2 (G2) multiplied by integers in coeff_list
    """Return [c * G2 for c in coeff_list], preserving the input order."""
    return list(map(bn256.g2_scalar_base_mult, coeff_list))

def AnyG1Pointmult(AG1,x): # Multiply general points on curve 1 A*G1 with another integer scalar x 
    """Return x * AG1 for a point AG1 on curve 1 and an integer scalar x.

    A negative scalar is replaced by its representative x mod q in [0, q),
    where q = bn256.order.  For points in the order-q subgroup this gives
    x * AG1 = -(|x| * AG1).  Using q - x (= q + |x|) instead would give
    +|x| * AG1, i.e. the wrong sign.

    Args:
        AG1: curve-1 point exposing ``scalar_mul``.
        x: integer scalar (may be negative).

    Returns:
        The point returned by ``AG1.scalar_mul``.
    """
    if x < 0:
        # q * AG1 is the identity (AG1 assumed to lie in the order-q subgroup),
        # so reducing the scalar mod q preserves the result.
        return AG1.scalar_mul(int(x) % bn256.order)
    return AG1.scalar_mul(x)

def AnyG2Pointmult(AG2,x): # Multiply general points on curve 2 A*G2 with another integer scalar x 
    """Return x * AG2 for a point AG2 on curve 2 and an integer scalar x.

    A negative scalar is replaced by its representative x mod q in [0, q),
    where q = bn256.order.  For points in the order-q subgroup this gives
    x * AG2 = -(|x| * AG2).  Using q - x (= q + |x|) instead would give
    +|x| * AG2, i.e. the wrong sign.

    Args:
        AG2: curve-2 point exposing ``scalar_mul``.
        x: integer scalar (may be negative).

    Returns:
        The point returned by ``AG2.scalar_mul``.
    """
    if x < 0:
        # q * AG2 is the identity (AG2 assumed to lie in the order-q subgroup),
        # so reducing the scalar mod q preserves the result.
        return AG2.scalar_mul(int(x) % bn256.order)
    return AG2.scalar_mul(x)

def recursive_g1_add(lst):  # Addition of points on curve 1. Done in this way because '+' operator doesn't work for EC points (in this package)
    """Sum a non-empty list of curve-1 points with ``bn256.g1_add``.

    Despite the name, the summation is iterative: the two front points are
    repeatedly replaced by their sum, which is appended at the back, until
    at most two points remain.  The caller's list is left unmodified.

    Raises:
        IndexError: if ``lst`` is empty.
    """
    pending = collections.deque(lst)
    while len(pending) > 2:
        first = pending.popleft()
        second = pending.popleft()
        pending.append(bn256.g1_add(first, second))
    if len(pending) == 2:
        return bn256.g1_add(pending[0], pending[1])
    return pending[0]

def recursive_g2_add(lst): # Addition of points on curve 2. Done in this way because '+' operator doesn't work for EC points (in this package)
    """Sum a non-empty list of curve-2 points with ``bn256.g2_add``.

    Despite the name, the summation is iterative: the two front points are
    repeatedly replaced by their sum, which is appended at the back, until
    at most two points remain.  The caller's list is left unmodified.

    Raises:
        IndexError: if ``lst`` is empty.
    """
    pending = collections.deque(lst)
    while len(pending) > 2:
        first = pending.popleft()
        second = pending.popleft()
        pending.append(bn256.g2_add(first, second))
    if len(pending) == 2:
        return bn256.g2_add(pending[0], pending[1])
    return pending[0]

### Generate keys and RND polynomials (coefficients)
Draw random coefficients for each participant's polynomial $f_i(x) = \sum_{k=0}^{n} a_{ik} x^k$: $a_{ik} \leftarrow \text{RND}$, where $1 \leq i \leq N$ and $0 \leq k \leq n$ (i.e. $t = n + 1$ coefficients per polynomial).

In [5]:
# One random polynomial per participant, each with n + 1 = t coefficients (N x t).
P_poly = [rnd_poly_coeffs(n) for _ in range(N)]
    
#print ("P_poly rnd   = ", P_poly)

### Prepare points and shares for exchanging

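Compute the points $A_{ik} = a_{ik} \cdot G_2$ and the shares $s_{ij} = f_i(j) = \sum_{k=0}^{n} a_{ik} \cdot j^k$ to be exchanged. The points $A_{ik}$ are public commitments, while the scalar share $s_{ij}$ is sent privately to participant $P_j$.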

In [6]:
# Each participant i (array index i, participant ID i + 1) publishes commitments
# to its coefficients and computes one share for every participant:
#   AsG[i][k]  = a_ik * G2                        -> N x t (public)
#   ss[i][j-1] = s_ij = f_i(j) = sum_k a_ik j^k   -> N x N (s_ij goes privately to P_j;
#                                                    j runs over all N IDs, including i + 1)
participant_ids = range(1, N + 1)   # loop-invariant: the same IDs for every dealer

AsG = [G2mult2(coeffs) for coeffs in P_poly]
ss = [[poly(coeffs, j)[1] for j in participant_ids] for coeffs in P_poly]
    
#print ("ss  = ", ss)
#print ("AsG = ", AsG)

### Validation
At this stage the points $A_{ik} = a_{ik} \cdot G_2$ have been distributed. Furthermore, participant $P_i$ has sent $s_{ij}$ to participant $P_j$.  
Each participant $P_j$ can now check that $\sum_{k=0}^{n} A_{ik} \cdot j^k = s_{ij} \cdot G_2$. The following two steps perform this check.

In [7]:
# Evaluate every dealer's public commitment polynomial at every participant ID:
#   AsGkterms[i][j-1] = sum_k j^k * A_ik, which should equal s_ij * G2.
participant_ids = range(1, N + 1)
AsGkterms = [
    [
        recursive_g2_add([AnyG2Pointmult(AsG[i][k], j**k) for k in range(t)])
        for j in participant_ids
    ]
    for i in range(N)
]

#print (" A x G x j^k = ", AsGkterms)

Check that $s_{ij} \cdot G_2 == \sum_k a_{ik} \cdot j^k \cdot G_2$:

In [None]:
# For every dealer i, compare the commitment evaluations with s_ij * G2.
for i in range(N):
    expected = AsGkterms[i]          # sum_k j^k * A_ik for every participant j
    from_shares = G2mult2(ss[i])     # s_ij * G2 for every participant j
    print ("Sanity check #", i)
    # NOTE(review): points are compared via their string form; this assumes the
    # bn256 reprs are canonical for equal points -- confirm.
    ok = str(expected) == str(from_shares)
    assert ok

Sanity check # 0
Sanity check # 1
Sanity check # 2
Sanity check # 3
Sanity check # 4
Sanity check # 5
Sanity check # 6


### Calculating PubK and public verification Key

The global public key is $\text{PubK} = \sum_{i \in I_{\text{QUAL}}} A_{i0}$, where $I_{\text{QUAL}} \subseteq \{1, \ldots, N\}$ is the set of qualified participants.  
Next, each participant $j \in I_{\text{QUAL}}$ computes its secret key share $x_j = \sum_{i \in I_{\text{QUAL}}} s_{ij}$, i.e. the sum of the shares sent to it by all qualified participants (including its own share $s_{jj}$).  
Lastly, each participant publishes a verification key $VK_{j} = x_j \cdot G_2$.

In [None]:
# Global public key PK = sum_i A_i0 over QUAL.
# For now QUAL = I = {1, ..., N}: every participant is assumed honest.
# TODO: experiment with a smaller QUAL (e.g. some dishonest dealers).
PKs = [AsG[i][0] for i in range(N)]
PK = recursive_g2_add(PKs)

# Individual secret keys: x_j = sum_i s_ij, summing column j over all dealers
# i in QUAL (here: everyone).
xjs = [sum(ss[i][j] for i in range(N)) for j in range(N)]

# Public verification keys VK_j = x_j * G2.  The column sums are recomputed
# independently so the final assert cross-checks the xjs from the previous step.
VKs = []
testsum = []
for j in range(N):
    column_sum = sum(ss[i][j] for i in range(N))   # sum over QUAL (= everyone here)
    testsum.append(column_sum)
    VKs.append(G2mult(column_sum))

assert testsum == xjs, "recomputed secret keys disagree with xjs"


### The participants create individual signatures

Each participant computes $\sigma_{r,i} = \mathrm{Sign}(r ~||~ \sigma_{r-1}, sk_{G,i})$. To begin with, we create a message $\text{msg}_{\text{init}} = \text{"This is Fetch AI."}$ and hash it: $h = \text{SHA}_{512}(\text{msg}_{\text{init}})$ (the hexadecimal digest, encoded as bytes, is what gets signed).  
Next, each participant signs it with its secret key share: $\sigma_{r,i} = \mathrm{Sign}_{\text{BLS}}(h, x_i)$.  
  
Note that here we calculate $\sigma_{r,i} = \mathrm{Sign}(\sigma_{r-1}, sk_{G,i})$, i.e. we do not concatenate with the round ID $r$.

In [None]:
# Hash the initial message with SHA-512; the hex digest (as UTF-8 bytes) is
# the message every participant signs.
msg_init = "This is Fetch AI.".encode("utf-8")
print("msg_init 1 : ", msg_init)

hash_msg_init = hashlib.sha512(msg_init).hexdigest()
print("msg_init 2 : ", hash_msg_init)

# Signature shares as (array index i, compressed sigma_{r,i}) with secret key x_i = xjs[i].
signed_msg = hash_msg_init.encode("utf-8")
sigs = [(i, bls_sign(xjs[i], signed_msg)) for i in range(N)]

At this stage, two more steps are necessary.  
1) Use public $VK$'s to verify signatures  
2) Compute group signature, combining $t$ of the signature shares

### 1) Verify each signature share

In [None]:
# Spot-check: display the first signature share and its verification key.
print("Sig[0][1]    = ", sigs[0][1])
print("VK_0         = ", VKs[0])

# Every signature share must verify under the matching public verification key.
for i in range(N):
    share_ok = bls_verify(VKs[i], hash_msg_init.encode("utf-8"), sigs[i][1])
    assert share_ok

### 2) Compute group signature, using $t$ signature shares

Taking advantage of the threshold property, the combined signature is calculated from any set $S \subseteq \{1, \ldots, N\}$ of $|S| = t$ signers: $\sigma_{t} = \sum_{i \in S} \lambda_i \cdot \sigma_i$, where $\lambda_i = \prod_{j \in S \setminus \{i\}} \frac{j}{j - i}$ (computed modulo the group order) are the Lagrange coefficients for interpolation at $0$.  
Finally, validity is checked by running $\text{Verify}_{\text{BLS}}(\text{PK}, \text{SHA}_{512}(\text{msg}_{\text{init}}), \sigma_{t})$, where $\text{PK}$ is the global public key.  
NB: the signature shares are stored as compressed curve points, so we uncompress them before applying scalar multiplication.

In [None]:
# Combine t signature shares into the group signature by Lagrange interpolation
# at 0 (lagrange_coeff above), then verify it against the global public key PK.
# The experiment is repeated for three independent random signer sets.

def combine_and_verify(case_no, n_signers=t):
    """Run one threshold-signature test case and return the combined signature.

    Picks ``n_signers`` distinct participants at random, weights each of their
    uncompressed signature shares by its Lagrange coefficient, sums the
    weighted shares into sigma_t and asserts that sigma_t verifies under PK.

    Args:
        case_no: label printed in the "TEST CASE" header.
        n_signers: number of shares to combine; defaults to the threshold t.

    Returns:
        The combined (uncompressed) G1 signature point sigma_t.

    Raises:
        AssertionError: if the combined signature does not verify.
    """
    print("TEST CASE %d" % case_no)
    print("===========")
    signers = random.sample(range(N), n_signers)  # array indices; participant ID = index + 1
    weighted_shares = []
    for i in signers:
        share = bn256.g1_uncompress(sigs[i][1])
        # lagrange_coeff already reduces mod q; compute it and the (costly)
        # scalar multiplication only once per signer.
        lam = lagrange_coeff(i, signers, n_signers)
        weighted = AnyG1Pointmult(share, lam)
        print("i = ", i, ", ID = ", i+1, ", sig = ", share)
        print("sig*lambda = ", weighted)
        weighted_shares.append(weighted)

    combined = recursive_g1_add(weighted_shares)
    print("sigma_t = ", combined)
    ok = bls_verify(PK, hash_msg_init.encode("utf-8"), bn256.g1_compress(combined))
    assert ok
    return combined


for case_no in (1, 2, 3):
    sig_comb = combine_and_verify(case_no)