In [1]:
import numpy as np
import math
from hashlib import sha256
from Crypto.Cipher import AES
from Crypto import Random
from Crypto.Util import Counter
import secrets

In [2]:
# This mapping is used to convert a command to an operation
opMap = {
    "add": lambda x, y: x + y,
    "subtract": lambda x, y: x - y,
    "divide": lambda x, y: np.divide(x, y),
    "mul": lambda x, y: x * y,
    "matmul": lambda x, y: np.dot(x, y)
}

# This class represents a node in the DAG
class Node:
    def __init__(self, symbol, data, left=None, right=None, elementCount = 1):
        self.symbol = symbol
        self.data = data
        self.dataList = list(map(lambda x: int(x), data.reshape(-1, )))
        self.size = len(self.dataList) * elementCount
        self.left = left
        self.right = right
        self.elementCount = elementCount # This variable is used to specify the number of operations one must perform to calculate an element in this node

# Convert our inputs into nodes for the DAG
def processInputs(input_data, dag):
    for k, v in input_data.items():
        dag[k] = Node(k, v)

# Solve the DAG. This involves creating nodes and connecting them, and then solving them.
def solveDAG(command, dag, outputs):
    if command.startswith("output"): 
        res, sym = command.split(" ")
        outputs.add(sym)
    else:
        res, op, x, y = command.split(" ")
        dag[res] = Node(res, opMap[op](dag[x].data, dag[y].data), dag[x], dag[y], dag[x].data.shape[1] * 2 - 1 if op == "matmul" else 1)

# operationIdx is the number of operations we've done into the round so far
# roundNum is the number of the round we're currently on
# layerSize is the number of operations we can perform in any given round at most
# rounds is the dictionary that contains the data for each round
# outputs is the set of symbols that are outputs
# root is the root node of the DAG that we are currently solving
def calculateRoundParams(root, roundNum, operationsIdx, rounds, layerSize, outputs):
    if root == None: return [roundNum, operationsIdx]
    if root.left == None and root.right == None: return [roundNum, operationsIdx]

    # we first process the left and right nodes
    roundNum, operationsIdx = calculateRoundParams(root.left, roundNum, operationsIdx, rounds, layerSize, outputs)
    roundNum, operationsIdx = calculateRoundParams(root.right, roundNum, operationsIdx, rounds, layerSize, outputs)

    rootIdx = 0
    if not roundNum in rounds: rounds[roundNum] = {}

    # This loop is used to create the data for each round. We keep looping whilst we are processing this element
    while rootIdx < len(root.dataList):
        newRootIdx = min(math.ceil((layerSize - operationsIdx) / root.elementCount) + rootIdx, len(root.dataList)) # either this is more than len(root.dataList) or it is less than len(root.dataList)
        rounds[roundNum][root.symbol] = root.dataList[rootIdx:newRootIdx].copy()
        operationsIdx += (newRootIdx - rootIdx) * root.elementCount # keeping count of the number of operations we've done
        rootIdx = newRootIdx

        # We remove symbols from the round if they are not the "round roots"
        if root.left.symbol in rounds[roundNum] and not root.left.symbol in outputs:
            del rounds[roundNum][root.left.symbol]
        if root.right.symbol in rounds[roundNum] and not root.right.symbol in outputs:
            del rounds[roundNum][root.right.symbol] 

        if operationsIdx >= layerSize:
            operationsIdx = 0
            roundNum += 1
            rounds[roundNum] = {}
    return [roundNum, operationsIdx]

In [3]:
# This is an example run demonstrating how the protocol works
dag = {}
outputs = set()
input_data = {
    'input_var_1': np.array([[1, 2], [3, 4]]),
    'input_var_2': np.array([[5, 6], [7, 8]]),
    'identity': np.array([[1,0], [0,1]])
}
processInputs(input_data, dag)
test_operations = [
    "operation_1 add input_var_1 input_var_2",
    "operation_2 subtract input_var_1 operation_1",
    "operation_3 mul operation_2 input_var_1",
    "operation_4 matmul operation_3 identity",
    #"output operation_1"
]
for i in test_operations:
    solveDAG(i, dag, outputs)
rounds = {}
calculateRoundParams(dag['operation_4'], 1, 0, rounds, 4, outputs)
print(rounds)

{1: {'operation_1': [6, 8, 10, 12]}, 2: {'operation_2': [-5, -6, -7, -8]}, 3: {'operation_3': [-5, -12, -21, -32]}, 4: {'operation_4': [-5, -12]}, 5: {'operation_4': [-21, -32]}, 6: {}}


In [4]:
# Function to convert integers to 32-byte representation
def int_to_32bytes(i):
    # Handle negative integers using two's complement
    if i < 0:
        i = (1 << 256) + i  # Convert to two's complement 256-bit integer
    return i.to_bytes(32, byteorder='big')    
# Function to convert 32-byte representation back to integer
def bytes_to_int(b):
    i = int.from_bytes(b, byteorder='big')
    # Check if the integer is negative by checking the most significant bit (MSB)
    if i >= (1 << 255):  # If the MSB is set
        i -= (1 << 256)  # Convert back to a negative number using two's complement
    return i

# Function to encrypt data using AES-256 in CTR mode
def encrypt_round_data(data, key, ctr):
    # Derive a 256-bit AES key from the given key
    private_key = sha256(key.encode('utf-8')).digest()
    # Initialize AES cipher in CTR mode
    cipher = AES.new(private_key, AES.MODE_CTR, counter=ctr)
    # Encrypt the data
    encrypted_data = cipher.encrypt(data)
    return encrypted_data

# Example of how to process rounds and encrypt them
def encrypt_rounds(rounds, key, output):
    encrypted_rounds = {}
    for round_num, round_data in rounds.items():
        ctr = Counter.new(128)
        encrypted_rounds[round_num] = {}
        for symbol in sorted(round_data.keys()):
            encrypted_rounds[round_num][symbol] = encrypt_round_data(b''.join(int_to_32bytes(i) for i in round_data[symbol]), key, ctr)
    return encrypted_rounds

# Function to decrypt data using AES-256 in CTR mode
def decrypt_round_data(encrypted_data, key, ctr):
    # Derive a 256-bit AES key from the given key
    private_key = sha256(key.encode('utf-8')).digest()
    # Initialize AES cipher in CTR mode with the same counter used during encryption
    cipher = AES.new(private_key, AES.MODE_CTR, counter=ctr)
    # Decrypt the data
    decrypted_data = cipher.decrypt(encrypted_data)
    return decrypted_data

# Example of how to process decrypted rounds
def decrypt_rounds(encrypted_rounds, key):
    decrypted_rounds = {}
    for round_num, encrypted_data in encrypted_rounds.items():
        ctr = Counter.new(128)
        decrypted_rounds[round_num] = {}
        for symbol in sorted(encrypted_rounds[round_num].keys()):
            decrypted_data = decrypt_round_data(encrypted_data[symbol], key, ctr)
            decrypted_rounds[round_num][symbol] = [bytes_to_int(decrypted_data[i:i+32]) for i in range(0, len(decrypted_data), 32)]
    return decrypted_rounds

In [5]:
# Generate a random 256-bit number as our key
random_256bit_number = secrets.randbits(256)

# Convert to a 32-byte (256-bit) hexadecimal string if needed
random_256bit_hex = hex(random_256bit_number)[2:]  # Remove the '0x' prefix
# Ensure it's 64 hex characters (32 bytes) long by padding with zeros if necessary
random_256bit_hex = random_256bit_hex.zfill(64)
print(f"Random 256-bit number that is our key: {random_256bit_hex}")
key = random_256bit_hex
encrypted_rounds = encrypt_rounds(rounds, key, outputs)
print(encrypted_rounds)

Random 256-bit number that is our key: a234e5a7bfe43098b375181a0bd2eee989848807a4b838c27e8a904340d22923
{1: {'operation_1': b'\xf3\xb9Nl[\xe4\x98q\xe1\t\xf7\xec\xe3\xae\xf6\xb5\x96\x0bNg\xe1\xd3\xcc\xe1\x05g\xfc\xf3\x10\x97\xb10\xcd\x8e\xd9S\xedJ\xcd\x89\xb1oE\xd3\x85\xbc\xf7\x11\xb7\xfc\xc9,\xe2\x9eG\xc0\xb9\x8f[\xe4\xf3\x96\x9f\xd9q\x8a\x0b\xe0\x08\x87\x98]\x0e\xaep\xc1\xcf6\xfc\xad{\xce\xd0\x8e\xf1\xb5\x91`%\xbcl+\xaf\xd8hw\xee,\x93p\xc5\xa7\xc9\xca9\xafB\xb1)\xecO\r\x82IJ\xc7x({\xfdI\xbe@\xac\xdbq\xd6\x94'}, 2: {'operation_2': b"\x0cF\xb1\x93\xa4\x1bg\x8e\x1e\xf6\x08\x13\x1cQ\tJi\xf4\xb1\x98\x1e,3\x1e\xfa\x98\x03\x0c\xefhN\xcd2q&\xac\x12\xb52vN\x90\xba,zC\x08\xeeH\x036\xd3\x1da\xb8?Fp\xa4\x1b\x0ci`+\x8eu\xf4\x1f\xf7xg\xa2\xf1Q\x8f>0\xc9\x03R\x841/q\x0eJn\x9f\xdaC\x93\xd4P'\x97\x84\x11\xd3l\x8f:X65\xc6P\xbdN\xd6\x13\xb0\xf2}\xb6\xb58\x87\xd7\x84\x02\xb6A\xbfS$\x8e)`"}, 3: {'operation_3': b"\x0cF\xb1\x93\xa4\x1bg\x8e\x1e\xf6\x08\x13\x1cQ\tJi\xf4\xb1\x98\x1e,3\x1e\xfa\x98\x03\x0c\xefh

In [6]:
# Example usage
# Assuming `encrypted_rounds` is the result of the encryption process
decrypted_rounds = decrypt_rounds(encrypted_rounds, key)

# Print the decrypted round data
for round_num, decrypted_data in decrypted_rounds.items():
    print(f"Round {round_num}: {decrypted_data}")

# Print the original round data
for round_num, round_data in rounds.items():
    print(f"Round {round_num}: {round_data}")

Round 1: {'operation_1': [6, 8, 10, 12]}
Round 2: {'operation_2': [-5, -6, -7, -8]}
Round 3: {'operation_3': [-5, -12, -21, -32]}
Round 4: {'operation_4': [-5, -12]}
Round 5: {'operation_4': [-21, -32]}
Round 6: {}
Round 1: {'operation_1': [6, 8, 10, 12]}
Round 2: {'operation_2': [-5, -6, -7, -8]}
Round 3: {'operation_3': [-5, -12, -21, -32]}
Round 4: {'operation_4': [-5, -12]}
Round 5: {'operation_4': [-21, -32]}
Round 6: {}
