Quantum Merkle Trees - Arya Grayeli and John Dunbar

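We took a classical Merkle tree and created a quantum version of it. A Merkle tree is a hash tree, which allows efficient and secure verification of the contents of a large data structure (https://en.wikipedia.org/wiki/Merkle_tree). Each leaf holds the hash of a data block and each internal node holds the hash of its two children, so the single root hash commits to the entire dataset.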
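
For comparison, here is a minimal sketch of the classical construction. It assumes SHA-256 as the hash and uses the same array layout as the Prover class later in this notebook: root at index 1, children of node u at 2u and 2u+1, and leaves at indices l to 2l-1. The helpers `build_tree`, `auth_path` and `verify` are hypothetical names for illustration, not part of the notebook.

```python
import hashlib

# Hypothetical helpers for illustration only (classical, SHA-256-based)
def sha256(data: bytes) -> bytes:
    return hashlib.sha256(data).digest()

def build_tree(leaves):
    """Return the node array: nodes[1] is the root, node u has children 2u and 2u+1,
    and the hashes of the l leaves sit at indices l..2l-1 (index 0 is unused)."""
    l = len(leaves)
    assert l > 0 and l & (l - 1) == 0, "this sketch assumes a power-of-2 number of leaves"
    nodes = [b""] * (2 * l)
    for j, leaf in enumerate(leaves):
        nodes[l + j] = sha256(leaf)
    for u in range(l - 1, 0, -1):  # bottom-up: each parent hashes its two children
        nodes[u] = sha256(nodes[2 * u] + nodes[2 * u + 1])
    return nodes

def auth_path(nodes, j):
    """Sibling hashes needed to recompute the root from leaf j, listed bottom-up."""
    i = len(nodes) // 2 + j
    path = []
    while i > 1:
        path.append(nodes[i ^ 1])  # i ^ 1 is the sibling of node i
        i //= 2
    return path

def verify(root, leaf, j, path, l):
    """Recompute the root from one leaf and its authentication path."""
    i = l + j
    h = sha256(leaf)
    for sibling in path:
        # even index = left child, odd index = right child
        h = sha256(h + sibling) if i % 2 == 0 else sha256(sibling + h)
        i //= 2
    return h == root

leaves = [b"block0", b"block1", b"block2", b"block3"]
nodes = build_tree(leaves)
root = nodes[1]
print(verify(root, b"block2", 2, auth_path(nodes, 2), len(leaves)))    # True
print(verify(root, b"tampered", 2, auth_path(nodes, 2), len(leaves)))  # False
```

Checking one leaf needs only the root and log2(l) sibling hashes rather than the whole dataset.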

We used the paper https://arxiv.org/pdf/2112.14317.pdf as the reference for our Quantum Merkle Tree implementation; as far as we know, no open-source implementation of it existed before this one.

Quantum Merkle Trees could allow faster commitment verification and open up the possibility of zero-knowledge verification using quantum computers.

In [1]:
# Import libraries

import numpy as np
from numpy.linalg import qr

# Standard Qiskit imports (pre-1.0 API; in Qiskit >= 1.0, Aer comes from qiskit_aer
# and UnitaryGate from qiskit.circuit.library)
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister, Aer
from qiskit.extensions import UnitaryGate



In [2]:
# Defining quantum functions

def qr_haar(N): # N = 2**(3*b), the dimension of a 3b-qubit register
    """Generates the matrix for the Quantum Haar Random Oracle Model (QHROM):
    a Haar-random N x N unitary built from the QR decomposition of a complex Gaussian matrix"""

    A, B = np.random.normal(size=(N, N)), np.random.normal(size=(N, N))
    Z = A + 1j * B

    Q, R = qr(Z)
    # Fix the phase of each column of Q so that the result is Haar-distributed
    Lambda = np.diag([R[i, i] / np.abs(R[i, i]) for i in range(N)])

    return np.dot(Q, Lambda)

def pool_qubits(counts):
    """Pool the measurement results of a finished job into a list of output qubits"""
    bits = max(counts, key=counts.get) # bitstring with the most occurrences
    out = []
    # Qiskit prints qubit 0 as the rightmost bit, so read the string right-to-left
    for i in range(len(bits)-1, -1, -1):
        if bits[i] == '0':
            out.append([1,0])
        else:
            out.append([0,1])
    return out

# initialize backend
backend = Aer.get_backend('statevector_simulator')
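
As a quick sanity check of the two helpers, the sketch below re-implements them with NumPy and plain Python (the optional `rng` argument is an addition for reproducibility, not part of the notebook) and confirms that the generated matrix is unitary. The phase correction matters: `np.linalg.qr` only fixes each column of Q up to a phase, so Q on its own is not Haar-distributed; multiplying column k by the phase of R[k, k] removes that ambiguity.

```python
import numpy as np

# Re-implementations of the notebook's helpers, for checking only
def qr_haar(N, rng=None):
    """Same recipe as the notebook's qr_haar, with an optional seeded generator."""
    rng = np.random.default_rng() if rng is None else rng
    Z = rng.normal(size=(N, N)) + 1j * rng.normal(size=(N, N))
    Q, R = np.linalg.qr(Z)
    d = np.diag(R)
    # Multiply column k of Q by the phase of R[k, k] (same as Q @ diag(d / |d|))
    return Q * (d / np.abs(d))

def pool_qubits(counts):
    """Same behaviour as the notebook's pool_qubits, written as a comprehension."""
    bits = max(counts, key=counts.get)  # most frequent measured bitstring
    # Qiskit writes qubit 0 as the rightmost character, so read right-to-left
    return [[1, 0] if c == "0" else [0, 1] for c in reversed(bits)]

b = 1
G = qr_haar(2 ** (3 * b), np.random.default_rng(0))
print(G.shape)                                 # (8, 8)
print(np.allclose(G.conj().T @ G, np.eye(8)))  # True
print(pool_qubits({"011": 3, "100": 1}))       # [[0, 1], [0, 1], [1, 0]]
```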

With the libraries imported and some basic utility functions set up for the quantum circuits, we can now define the two parties that interact with the Merkle tree.

Prover: the party with access to, and control over, the entire database. In a classical Merkle tree it would store the merged hashes; here it stores the merged qubits, combining each pair of child nodes into their parent with the QHROM unitary.

Verifier: the party that wants to make sure the Prover's database is valid and has not changed, and that wants to query it for data. It requests the relevant nodes from the Prover, unwinds the Prover's merge operations with the inverse QHROM, and obtains both the queried value and an indication of whether the data has been modified.

See the paper for the reasoning behind the structure of this implementation.
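
The Verifier can unwind the Prover's work because the QHROM matrix G is unitary, so its conjugate transpose undoes a merge exactly, provided the merged state is passed along without being measured. The minimal sketch below shows one merge/unmerge round trip, assuming b = 1 and plain NumPy state-vector arithmetic in place of Qiskit circuits. `haar_unitary` and `basis_state` are hypothetical helpers, and the qubit order in the vector is a plain concatenation (left child, right child, parent) rather than Qiskit's little-endian order; the round trip does not depend on that choice.

```python
import numpy as np

# Hypothetical helpers for illustration only
def haar_unitary(N, rng):
    """Haar-random N x N unitary (same QR-plus-phase recipe as qr_haar)."""
    Z = rng.normal(size=(N, N)) + 1j * rng.normal(size=(N, N))
    Q, R = np.linalg.qr(Z)
    d = np.diag(R)
    return Q * (d / np.abs(d))

def basis_state(bits):
    """State vector of the computational basis state |bits>, e.g. [1, 0, 0] -> |100>."""
    v = np.zeros(2 ** len(bits), dtype=complex)
    v[int("".join(str(x) for x in bits), 2)] = 1.0
    return v

b = 1
rng = np.random.default_rng(1)
G = haar_unitary(2 ** (3 * b), rng)

# Two children (b qubits each) and a parent initialized to |0...0>
left, right, parent = [1], [0], [0]
psi = basis_state(left + right + parent)

merged = G @ psi                # Prover's merge, kept as a full state vector
unmerged = G.conj().T @ merged  # Verifier's unmerge with the conjugate transpose of G
print(np.allclose(unmerged, psi))  # True

# If the merged state is measured first, only one basis outcome survives, and
# unmerging that outcome almost surely does not give psi back.
k = int(np.argmax(np.abs(merged) ** 2))  # most likely outcome
collapsed = basis_state([int(c) for c in format(k, f"0{3 * b}b")])
print(np.allclose(G.conj().T @ collapsed, psi))  # False (except with probability zero over G)
```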

In [3]:
class Prover:
    """The entity storing all the data"""
    def __init__(self, b, l):
        self.b = b  # qubits per data block
        self.l = l  # number of data blocks (leaves)
        # Heap layout: root at index 1, children of node u at 2u and 2u+1,
        # leaves at indices l..2l-1; index 0 is unused
        self.states = [None for _ in range(2*l)]
        self.G = qr_haar(2**(3*b))  # QHROM unitary on 3b qubits (two children + parent)

    def buildMerkleTree(self, sigma):
        """Build the Merkle tree from the l input blocks and return the root state"""
        assert len(sigma) == self.l

        for j in range(self.l):
            assert len(sigma[j]) == self.b
            self.states[j] = [[1,0] for _ in range(self.b)]  # internal nodes start as |0...0>
            self.states[j+self.l] = sigma[j]  # leaves hold the input data

        # Merge bottom-up: parent u absorbs its children 2u and 2u+1
        for u in range(self.l-1, 0, -1):
            self.haar_oracle(u)

        return self.get_root()

    def get_R(self, i): # R is a dictionary of node index -> state in the prover
        """Return the states of node i, its sibling, and every ancestor of i below the
        root together with that ancestor's sibling"""
        assert 0 < i < 2*self.l

        R = dict()
        while i > 1:
            i //= 2
            R[2*i] = self.states[2*i]
            R[2*i+1] = self.states[2*i+1]
        return R

    def get_root(self):
        return self.states[1]

    def haar_oracle(self, u):
        """Applies QHROM to the states of nodes 2u, 2u+1 and u"""
        qreg = QuantumRegister(3*self.b)
        circuit = QuantumCircuit(qreg)
        circuit.append(UnitaryGate(self.G), list(range(3*self.b)))

        creg = ClassicalRegister(3*self.b)
        circuit.add_register(creg)

        # Qubits 0..b-1 hold node 2u, b..2b-1 node 2u+1, and 2b..3b-1 node u.
        # Note: initialize() resets each qubit to |0> before preparing it, so with this
        # ordering the unitary appended above is discarded and the measurement returns
        # the initialized states unchanged.
        for i in range(self.b):
            circuit.initialize(self.states[2*u][i], i)
            circuit.initialize(self.states[2*u+1][i], i+self.b)
            circuit.initialize(self.states[u][i], i+2*self.b)

        circuit.measure(qreg, creg)

        # Keep the most frequent measurement outcome as the new node states
        out = pool_qubits(backend.run(circuit).result().get_counts())

        for i in range(self.b):
            self.states[2*u][i] = out[i]
            self.states[2*u+1][i] = out[i+self.b]
            self.states[u][i] = out[i+2*self.b]


class Verifier:
    """The entity communicating with the prover to query data and make sure it is consistent"""
    def __init__(self, p):
        self.p = p  # Prover
        self.b = p.b
        self.l = p.l
        self.root = None
        self.states = [None for _ in range(2*self.l)]

    def build(self, inputs):
        """Ask the Prover to build the tree and remember the root it returns"""
        self.root = self.p.buildMerkleTree(inputs)

    def getBlockI(self, i):
        """Get the value of input @ index i; also checks whether the tree has been tampered with"""
        if self.root != self.p.get_root():
            print("Merkle tree has changed, rebuild")
            return None

        assert 0 <= i < self.l

        i += self.l  # heap index of leaf i
        R = self.p.get_R(i)
        R[1] = self.root
        # Use the Prover's states for nodes in R and |0...0> everywhere else
        for j in range(2*self.l):
            if j in R:
                self.states[j] = R[j]
            else:
                self.states[j] = [[1,0] for _ in range(self.b)]

        # Unmerge top-down, only for nodes the Prover sent
        for u in range(1, self.l):
            if u not in R:
                continue
            self.inverse_haar_oracle(u)  # acts on 2u, 2u+1, u

        # Every ancestor of the leaf must be back to |0...0>
        u = i//2
        while u > 0:
            for x in self.states[u]:
                if x != [1,0]:
                    return None
            u //= 2

        return self.states[i]

    def inverse_haar_oracle(self, u):
        """Applies the inverse QHROM to the states of nodes 2u, 2u+1 and u to unmerge the tree"""
        qreg = QuantumRegister(3*self.b)
        circuit = QuantumCircuit(qreg)
        circuit.append(UnitaryGate(self.p.G), list(range(3*self.b)))
        circuit = circuit.inverse()

        creg = ClassicalRegister(3*self.b)
        circuit.add_register(creg)

        # Same ordering as in Prover.haar_oracle: initialize() resets the qubits after
        # the inverse unitary, so the measurement returns the initialized states unchanged.
        for i in range(self.b):
            circuit.initialize(self.states[2*u][i], i)
            circuit.initialize(self.states[2*u+1][i], i+self.b)
            circuit.initialize(self.states[u][i], i+2*self.b)

        circuit.measure(qreg, creg)

        out = pool_qubits(backend.run(circuit).result().get_counts())

        for i in range(self.b):
            self.states[2*u][i] = out[i]
            self.states[2*u+1][i] = out[i+self.b]
            self.states[u][i] = out[i+2*self.b]
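
Both classes rely on heap-style indexing. The hypothetical helpers below (plain Python, no Qiskit) list the node indices the code above touches for l = 4: the order in which `buildMerkleTree` merges nodes, the nodes `get_R` hands to the Verifier for a given leaf, and the ancestors `getBlockI` checks afterwards.

```python
# Hypothetical helpers that only reproduce the index arithmetic of Prover and Verifier
def merge_schedule(l):
    """(child 2u, child 2u+1, parent u) triples, in the order buildMerkleTree merges them."""
    return [(2 * u, 2 * u + 1, u) for u in range(l - 1, 0, -1)]

def r_nodes(l, j):
    """Node indices returned by get_R for leaf j (heap index l + j), bottom-up."""
    i = l + j
    nodes = []
    while i > 1:
        i //= 2
        nodes += [2 * i, 2 * i + 1]
    return nodes

def checked_ancestors(l, j):
    """Ancestors of leaf j that getBlockI requires to be |0...0> after unmerging."""
    u = (l + j) // 2
    ancestors = []
    while u > 0:
        ancestors.append(u)
        u //= 2
    return ancestors

print(merge_schedule(4))        # [(6, 7, 3), (4, 5, 2), (2, 3, 1)]
print(r_nodes(4, 0))            # [4, 5, 2, 3]
print(checked_ancestors(4, 0))  # [2, 1]
```

The Verifier adds the root (index 1) to these nodes itself before unmerging.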

Now that the Prover and Verifier are set up, we can test to see if this actually works!

We set b and l to small values to keep the test fast, and use an arbitrary input made of distinct qubit combinations. The inputs don't need to start out as qubits, though: they could just as well be text converted to qubits, or any other data you want to store.
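
One possible way to turn text into inputs of this format is sketched below, under the assumptions of UTF-8 encoding, most-significant-bit-first order and zero padding. `text_to_blocks` and `blocks_to_text` are hypothetical helpers, not part of the notebook.

```python
# Hypothetical helpers: UTF-8 bytes, most significant bit first, zero padding
def text_to_blocks(text, b):
    """Encode text as a list of blocks, each holding b single-qubit basis states."""
    bits = "".join(f"{byte:08b}" for byte in text.encode("utf-8"))
    bits += "0" * (-len(bits) % b)  # pad with |0> so the last block has b qubits
    # bit 0 -> [1, 0] (|0>), bit 1 -> [0, 1] (|1>); fresh lists for every qubit
    return [[[1, 0] if c == "0" else [0, 1] for c in bits[k:k + b]]
            for k in range(0, len(bits), b)]

def blocks_to_text(blocks, n_bytes):
    """Decode the first n_bytes bytes back to text, ignoring any padding."""
    bits = "".join("0" if q == [1, 0] else "1" for block in blocks for q in block)
    data = bytes(int(bits[8 * k:8 * k + 8], 2) for k in range(n_bytes))
    return data.decode("utf-8")

blocks = text_to_blocks("hi", 2)
print(len(blocks))                # 8 (16 bits in blocks of 2 qubits)
print(blocks[0])                  # [[1, 0], [0, 1]]
print(blocks_to_text(blocks, 2))  # hi
```

Before using the result as `inputs`, the number of blocks may need padding to a power of 2 so the tree stays balanced.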

These quick tests check 3 things:

    1. Querying a block returns the right value (and not the value of the next block)
    2. If the stored data changes slightly (here, one qubit of the root is flipped), the Verifier rejects the Prover's output
    3. If the inputs are changed and the tree is rebuilt, the Verifier detects that the root has been modified
    
As the cell output below shows, all the tests pass, which is consistent with the Quantum Merkle Tree behaving as intended on this small example.

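Feel free to play around with the inputs; just keep the same format (a list of blocks, each a list of b single-qubit states such as [1,0] or [0,1]). Also be aware that the runtime grows quickly: the QHROM unitary acts on 3b qubits, so its matrix has 2^(3b) rows and columns and grows exponentially with b, while each tree build runs l - 1 merge circuits.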
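
To put rough numbers on the growth, the sketch below computes the size of the QHROM matrix for a few values of b. `oracle_cost` is a hypothetical helper, and the figures are memory sizes derived from the matrix dimensions, not measured runtimes. G acts on 3b qubits, so it is a 2^(3b) x 2^(3b) complex matrix: its memory grows as 16 * 64^b bytes, its QR factorization in `qr_haar` costs on the order of 2^(9b) operations, and l only changes the number of merges per build (l - 1).

```python
# Hypothetical helper: size figures for the notebook's construction (not timings)
def oracle_cost(b, l):
    n_qubits = 3 * b            # two children + parent, b qubits each
    dim = 2 ** n_qubits         # G is dim x dim
    g_bytes = dim * dim * 16    # complex128 entries, 16 bytes each
    merges_per_build = l - 1    # one haar_oracle call per internal node
    return n_qubits, dim, g_bytes, merges_per_build

for b in range(1, 6):
    q, dim, g_bytes, _ = oracle_cost(b, 4)
    print(f"b={b}: {q} qubits, G is {dim}x{dim}, {g_bytes // 1024} KiB")
# b=1: 3 qubits, G is 8x8, 1 KiB
# b=2: 6 qubits, G is 64x64, 64 KiB
# b=3: 9 qubits, G is 512x512, 4096 KiB
# b=4: 12 qubits, G is 4096x4096, 262144 KiB
# b=5: 15 qubits, G is 32768x32768, 16777216 KiB
```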

In [4]:
b = 2

inputs = [[[1,0],[1,0]], [[1,0],[0,1]], [[0,1],[1,0]], [[0,1],[0,1]]]

# ideally l is a power of 2 to get a perfectly balanced tree, so pad with dummy |0...0> blocks
target = 1 << (len(inputs) - 1).bit_length() # smallest power of 2 >= len(inputs)
inputs += [[[1,0] for _ in range(b)] for _ in range(target - len(inputs))]

l = len(inputs)
p = Prover(b, l)
v = Verifier(p)
print("Models Setup")

v.build(inputs)
print("Merkle Tree Built")

# Test 1: every block is returned correctly and differs from the next block
for i, val in enumerate(inputs):
    assert v.getBlockI(i) == val
    if i + 1 == l: continue
    assert v.getBlockI(i+1) != val
print("Checked querying properly")

# Test 2: tamper with the stored root by flipping its first qubit from |0> to |1>
p.states[1][0][0], p.states[1][0][1] = p.states[1][0][1], p.states[1][0][0]
assert v.getBlockI(0) is None
print("Checked ability to catch minor changes in the database")

# Test 3: change one qubit of the first input and have the Prover rebuild the tree
inputs[0][0] = [0,1]
p.buildMerkleTree(inputs)
assert v.getBlockI(0) is None
print("Checked ability to catch entire database being reloaded")

Models Setup
Merkle Tree Built
Checked querying properly
Checked ability to catch minor changes in the database
Merkle tree has changed, rebuild
Checked ability to catch entire database being reloaded
