In [41]:
import numpy as np

# --- Protocol parameters -----------------------------------------------------
INT_BITS = 8  # ternary weights packed into one key integer (encode() spends 2 bits per weight)
KEY_SIZE = 10  # integers that make up one key
g = 7  # generator: base of the very first public keys
p = 2 ** (2 * INT_BITS) + 1  # prime modulus of the exchange (65537 when INT_BITS == 8)
N = KEY_SIZE * INT_BITS  # ternary weights carried by one full key
NUM_CLIENTS = 7  # real participants
BIT_NUM_CLIENTS = int(np.ceil(np.log2(NUM_CLIENTS)))  # bits needed to index every client
MAX_NUM_CLIENTS = 1 << BIT_NUM_CLIENTS  # NUM_CLIENTS rounded up to a power of two

# --- Per-client state --------------------------------------------------------
# Weights are drawn uniformly from {-1, 0, 1}; one row per real client.
local_w = np.random.randint(-1, 2, size=(NUM_CLIENTS, N)).astype(int)
local_w_shared = np.zeros((NUM_CLIENTS, N), dtype=int)

# Key tables have one row per slot of the power-of-two sized group. Secret keys
# start at 1 so that slots without a real client still hold a valid exponent.
share_key = np.zeros((MAX_NUM_CLIENTS, KEY_SIZE), dtype=int)
secret_key = np.ones((MAX_NUM_CLIENTS, KEY_SIZE), dtype=int)
public_key = np.zeros((MAX_NUM_CLIENTS, KEY_SIZE), dtype=int)
public_key_cache = np.zeros((MAX_NUM_CLIENTS, KEY_SIZE), dtype=int)

In [28]:
def element_wise_pow(a, b: np.ndarray, p: int) -> np.ndarray:
    """Compute ``a ** b mod p`` element by element.

    Python's three-argument ``pow`` is applied to plain ``int`` values, so the
    exponentiation is done with arbitrary-precision integers instead of
    fixed-width NumPy arithmetic.

    Parameters
    ----------
    a : scalar or np.ndarray
        The base. An array supplies one base per output element (``a[i]`` is
        paired with ``b[i]``); any other value is treated as a single base
        shared by every exponent.
    b : np.ndarray
        The exponents.
    p : int
        The modulus.

    Returns
    -------
    np.ndarray
        One result per element of ``a`` when ``a`` is an array, otherwise one
        result per element of ``b``.
    """
    if isinstance(a, np.ndarray):
        return np.array([pow(int(a[i]), int(b[i]), p) for i in range(len(a))])

    # Scalar base: the output length follows the exponent vector itself rather
    # than a module-level key size, so every exponent is used exactly once.
    base = int(a)
    return np.array([pow(base, int(exponent), p) for exponent in b])

In [39]:
def multi_DH(n, round=0, prefix=0, current_total_round=0, total_rounds=0, verbose=False):
    '''
    Execute ONE step of the recursive multi-party Diffie-Hellman key exchange.

    The function walks the whole recursive schedule on every call, counting
    steps in `total_rounds`, but only the step whose counter equals
    `current_total_round` actually does any work. The notebook driver therefore
    calls it once for every value of `current_total_round` in
    range(MAX_NUM_CLIENTS).

    All state lives in module-level arrays: it reads `secret_key`, `g` and `p`,
    and writes `public_key`, `public_key_cache` and `share_key` in place.
    Nothing is returned.

    n: number of parties in the current subset (power of 2)
    round: the current round inside the current subset of size n
           (note: this parameter name shadows the builtin `round` in this scope)
    prefix: index of the first party of the current subset; the recursion starts
            at 0 and hands `prefix` / `prefix + n // 2` to the two halves, so it
            is always a multiple of n
    current_total_round: the global step that this call should execute
    total_rounds: number of steps walked so far in the recursion (callers leave
                  it at its default of 0)
    verbose: when True, print which parties were updated in the executed step
    '''
    # Exit condition (n == 1): a single party is left, compute its shared key
    # from the public key of its neighbour (prefix ^ 1 flips the lowest index bit).
    if n == 1:
        if current_total_round == total_rounds:
            share_key[prefix] = element_wise_pow(public_key[prefix ^ 1], secret_key[prefix], p)
            if verbose:
                print(f"Round {total_rounds}: {prefix} <- {prefix ^ 1} Shared")
        return
    
    # round 0 in subset n: the first party of each half starts a new chain, either
    # from the generator g (very first step) or from the cached public key.
    if round == 0:
        if total_rounds == 0:
            if current_total_round == total_rounds:
                public_key[prefix] = element_wise_pow(g, secret_key[prefix], p)
                public_key[prefix + n // 2] = element_wise_pow(g, secret_key[prefix + n // 2], p)
                if verbose:
                    print(f"Round {total_rounds}: {prefix}, {prefix + n // 2} <- g")
                return
        else:
            if current_total_round == total_rounds:
                # `+` binds tighter than `^`, so the index is prefix ^ (2 * n - 1).
                # Because prefix is a multiple of n, this equals (prefix ^ n) + (n - 1):
                # the last party of the sibling subset, which is the slot the parent
                # level filled in public_key_cache.
                public_key[prefix] = element_wise_pow(public_key_cache[prefix ^ n + (n - 1)], secret_key[prefix], p)
                public_key[prefix + n // 2] = element_wise_pow(public_key_cache[prefix ^ n + (n - 1)], secret_key[prefix + n // 2], p)
                if verbose:
                    print(f"Round {total_rounds}: {prefix}, {prefix + n // 2} <- {prefix ^ n + (n - 1)}")
                return
        # Not the requested step: advance to the next round of this subset.
        multi_DH(n, round + 1, prefix, current_total_round, total_rounds + 1, verbose)
        return
        
    # round < n // 2: in both halves, party `round` raises the public key of the
    # previous party (round - 1) to its own secret key.
    elif round < n // 2:
        if current_total_round == total_rounds:
            public_key[prefix + round] = element_wise_pow(public_key[prefix + round - 1], secret_key[prefix + round], p)
            public_key[prefix + round + n // 2] = element_wise_pow(public_key[prefix + round + n // 2 - 1], secret_key[prefix + round + n // 2], p)
            if verbose:
                    print(f"Round {total_rounds}: {prefix + round}, {prefix + round + n // 2} <- {prefix + round - 1}, {prefix + round + n // 2 - 1}")
            return
        # Not the requested step: advance to the next round of this subset.
        multi_DH(n, round + 1, prefix, current_total_round, total_rounds + 1, verbose)
        return
    
    # round == n // 2: both chains are complete. Cache the public key of the last
    # party of each half for the exchange between the two halves. There is no
    # return here on purpose: control falls through to the recursion below.
    elif round == n // 2:
        if current_total_round == total_rounds:
            public_key_cache[prefix + n // 2 - 1] = public_key[prefix + n // 2 - 1]
            public_key_cache[prefix + n - 1] = public_key[prefix + n - 1] 

    # Recurse into the two halves. `round - n // 2` is 0 when arriving from the
    # branch above, and total_rounds is passed on unchanged, so the first step of
    # each half shares its step number with the caching step.
    multi_DH(n // 2, round - n // 2, prefix, current_total_round, total_rounds, verbose)
    multi_DH(n // 2, round - n // 2, prefix + n // 2, current_total_round, total_rounds, verbose)

In [30]:
def binary_to_decimal(vector):
    """Fold a sequence of bits, most significant first, into one integer.

    An empty sequence yields 0.
    """
    value = 0
    for current_bit in vector:
        # Make room for the next bit, then merge it in.
        value <<= 1
        value |= current_bit
    return value

def decimal_to_binary(decimal, num_bits=INT_BITS * 2):
    """Return the bits of `decimal`, most significant first, as a list of ints.

    The result is left-padded with zeros up to `num_bits` entries. Padding never
    truncates: a value that needs more than `num_bits` bits produces a longer list.
    """
    digits = bin(decimal)[2:]  # drop the "0b" prefix
    padded = digits.zfill(num_bits)
    return list(map(int, padded))

In [31]:
def encode(x: np.ndarray):
    """Pack a vector of weights into integers, INT_BITS weights per integer.

    Every weight becomes a 2-bit code (high bit first). For weights in
    {-1, 0, 1} the codes are: -1 -> 01, +1 -> 11, and 0 -> 00 when it sits in
    an even position of its group or 10 when it sits in an odd position.
    The caller's array is left untouched.

    Returns a 1-D array with ``len(x) // INT_BITS`` packed integers.
    """
    work = x.copy()  # never modify the caller's data
    groups = len(work) // INT_BITS
    grid = np.reshape(work, (groups, INT_BITS))

    # Zeros in even columns are moved to -2 so that the +2 shift below maps
    # them to code 0 instead of code 2.
    even_columns = grid[:, ::2]  # a view: writes land in `grid`
    even_columns[even_columns == 0] = -2
    grid += 2

    high_bits = (grid // 2 % 2).flatten()
    low_bits = (grid % 2).flatten()

    # One (high, low) pair per weight, then regroup so that each row holds all
    # the bits of one packed integer.
    bit_pairs = np.stack([high_bits, low_bits], axis=1)
    bit_rows = np.reshape(bit_pairs, (groups, -1))
    return np.apply_along_axis(binary_to_decimal, axis=1, arr=bit_rows)

In [32]:
def decode(x: np.ndarray):
    """Unpack integers into a flat vector of weights (inverse of the 2-bit packing).

    Each integer is expanded to its bits with decimal_to_binary, consecutive
    bit pairs are read as a 2-bit code, and the codes are mapped back:
    0 -> 0, 1 -> -1, 2 -> 0, 3 -> 1.
    """
    packed = x.copy()  # Make a copy of the array
    bits = np.array([decimal_to_binary(value) for value in packed]).reshape(-1)

    high_bits, low_bits = bits[::2], bits[1::2]
    weights = high_bits * 2 + low_bits
    weights -= 2
    weights[weights == -2] = 0  # code 0 also stands for a zero weight
    return weights

In [42]:
print(local_w)

# Initialize the secret keys: each real client's key is its encoded weight vector
for client, weights in enumerate(local_w):
    secret_key[client] = encode(weights)

print(secret_key)

# Run the multi-party Diffie-Hellman key exchange protocol, one step per call
for step in range(MAX_NUM_CLIENTS):
    multi_DH(MAX_NUM_CLIENTS, current_total_round=step, verbose=True)
    print(public_key)

print(share_key)

# Reconstruction: decode every real client's shared key back into weights
for client in range(NUM_CLIENTS):
    local_w_shared[client] = decode(share_key[client])

print(local_w_shared)

[[-1  0  1  0  1 -1 -1 -1 -1 -1  0  0  0  1  1  1  0  0  1  0 -1 -1 -1 -1
   0 -1  1  1 -1  1  0  0  0  1  0  1 -1 -1 -1  0  0  0 -1 -1 -1  1 -1 -1
   1  1  0  1  1  1 -1  1 -1 -1 -1 -1  1  0  0  0  0  0 -1  1  0  1  0  1
   1  0  1  0 -1  0 -1  1]
 [ 0  0  0  0  0  0 -1  1 -1 -1 -1  1 -1 -1 -1  1  0  0  0  1  1  0  1 -1
  -1  0  0  0  1  1  0  0  1 -1 -1 -1 -1  0  0 -1 -1  1 -1 -1 -1  1 -1 -1
   1  1 -1  0  1 -1 -1  0 -1 -1  0  1 -1  1  0  0  0  0 -1  1 -1  0  1 -1
   1  1 -1  1  1  0 -1 -1]
 [-1 -1  1 -1 -1  1  0  0 -1  0 -1  0  0  0  1  0 -1  0  0  1  0 -1 -1  1
   1  0  0 -1  1 -1  0 -1 -1 -1  0 -1  1  1  1 -1  1  0  1  0 -1 -1  1 -1
   0  1 -1  0  1  1  1  0 -1  0  0  0 -1  1  0 -1  1  1 -1  1  0 -1  0 -1
  -1 -1  1  1  1  0  0 -1]
 [ 1 -1 -1  0  0  0  0  0  0  0 -1 -1 -1  0 -1  0 -1 -1  1  1 -1  1  1 -1
   0  0  0  0  0 -1  0 -1  1 -1  0 -1  0  0  1  0  0  0  1  0  1 -1  0  1
   1  1  0 -1 -1  1 -1 -1  1 -1 -1 -1  0  0  0 -1  0  1 -1  0 -1 -1  0 -1
   1 -1 -1  1  1  0  0  1]
 [-1