In [5]:
import numpy as np
from numpy.linalg import qr

# Importing standard Qiskit libraries
from qiskit import QuantumCircuit, transpile, Aer, IBMQ, QuantumRegister, ClassicalRegister
from qiskit.tools.jupyter import *
from qiskit.visualization import *
from ibm_quantum_widgets import *
from qiskit.providers.aer import QasmSimulator
from qiskit.extensions import UnitaryGate

# Loading your IBM Quantum account(s)
provider = IBMQ.load_account()



In [9]:
def qr_haar(N): # N = 3*b
    A, B = np.random.normal(size=(N, N)), np.random.normal(size=(N, N))
    Z = A + 1j * B

    Q, R = qr(Z)
    Lambda = np.diag([R[i, i] / np.abs(R[i, i]) for i in range(N)])
    
    return np.dot(Q, Lambda)

def random_qubit():
    if np.random.random() < 0.5:
        return [1,0]
    return [0,1]

def pool_qubits(counts):
    bits = max(counts, key=counts.get)
    out = []
    for i in range(len(bits)-1,-1,-1):
        if bits[i] == '0':
            out.append([1,0])
        else:
            out.append([0,1])
    return out
    

backend = Aer.get_backend('statevector_simulator')

In [20]:
class Prover:
    def __init__(self, b, l):
        self.b = b
        self.l = l
        self.states = [None for j in range(2*l)]
        self.G = qr_haar(2**(3*b))
        x = UnitaryGate(self.G)
        
    def buildMerkleTree(self, sigma):
        assert len(sigma) == self.l

        for j in range(l):
            assert len(sigma[j]) == self.b
            self.states[j] = [[1,0] for i in range(self.b)]
            self.states[j+l] = sigma[j]

        for u in range(self.l-1,0,-1):
            self.haar_oracle(u) # 2u, 2u+1, u

        return self.get_root()
    
    def get_R(self, i): # R is dictionary of node : state in prover
        assert i > 0 and i < 2*l
        
        R = dict()
        while(i > 1):
            i//=2
            R[2*i] = self.states[2*i]
            R[2*i+1] = self.states[2*i+1]
        return R
    
    def get_root(self):
        return self.states[1]
    
    def haar_oracle(self, u):
        qr = QuantumRegister(3*self.b)
        circuit = QuantumCircuit(qr)
        circuit.append(UnitaryGate(self.G), [i for i in range(3*self.b)])
        
        cr = ClassicalRegister(3*self.b)
        circuit.add_register(cr)
        
        for i in range(self.b):
            circuit.initialize(self.states[2*u][i], i)
            circuit.initialize(self.states[2*u+1][i], i+self.b)
            circuit.initialize(self.states[u][i], i+2*self.b)
            
        circuit.measure(qr, cr)
        
        out = pool_qubits(backend.run(circuit).result().get_counts())
        
        for i in range(self.b):
            self.states[2*u][i] = out[i]
            self.states[2*u+1][i] = out[i+self.b]
            self.states[u][i] = out[i+2*self.b]
        
    
class Verifier:
    def __init__(self, p):
        self.p = p
        self.b = p.b
        self.l = p.l
        self.root = None
        self.states = [None for j in range(2*l)]
        
    def build(self, inputs):
        self.root = self.p.buildMerkleTree(inputs)
    
    def getBlockI(self, i):
        if self.root != p.get_root():
            print("Merkle tree has changed, rebuild")
            return None
        
        assert i >= 0 and i < l

        i+=l
        R = self.p.get_R(i)
        R[1] = self.root
        for j in range(2*l):
            if j in R:
                self.states[j] = R[j]
            else:
                self.states[j] = [[1,0] for i in range(self.b)]

        for u in range(1,l):
            if u not in R:
                continue
            self.inverse_haar_oracle(u) # 2u, 2u+1, u
            
        u = i//2
        while i > 0:
            z = self.states[u]
            for x in z:
                if x != [1,0]:
                    return None
            i//=2

        return self.states[i]
    
    def inverse_haar_oracle(self, u):
        qr = QuantumRegister(3*self.b)
        circuit = QuantumCircuit(qr)
        circuit.append(UnitaryGate(self.p.G), [i for i in range(3*self.b)])
        circuit = circuit.inverse()
        
        cr = ClassicalRegister(3*self.b)
        circuit.add_register(cr)
        
        for i in range(self.b):
            circuit.initialize(self.states[2*u][i], i)
            circuit.initialize(self.states[2*u+1][i], i+self.b)
            circuit.initialize(self.states[u][i], i+2*self.b)
            
        circuit.measure(qr, cr)
        
        out = pool_qubits(backend.run(circuit).result().get_counts())
        
        for i in range(self.b):
            self.states[2*u][i] = out[i]
            self.states[2*u+1][i] = out[i+self.b]
            self.states[u][i] = out[i+2*self.b]

In [23]:
b = 2
l = 4

# random_inputs = [[random_qubit() for i in range(b)] for j in range(l)]
inputs = [[[1,0],[1,0]], [[1,0],[0,1]], [[0,1],[1,0]], [[0,1],[0,1]]]

diff = 1<<(len(inputs)-1).bit_length() # get smallest power of 2 >= len(inputs)
inputs+=[[[1,0] for i in range(b)] for j in range(diff)]
print("data")
l = len(inputs)
p = Prover(b,l)
print("Prover")
v = Verifier(p)
print("verify")

print("Setup")
v.build(inputs)
print("Built")

for i,val in enumerate(inputs):
    assert v.getBlockI(i) == val
    assert v.getBlockI(i+1) != val
    
print("Checked 1")
    
p.states[l][0] = 1 - p.states[l][0]
assert v.getBlockI(0) == None

inputs[0][0] = 1 - inputs[0][0]
p.buildMerkleTree(inputs)
assert v.getBlockI(0) == None

data
Prover
verify
Setup
Built
[[1, 0], [1, 0]] [[1, 0], [1, 0]] [[1, 0], [1, 0]]


AssertionError: 