In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys; sys.path.append('..')
import random, math, os
import pyzx as zx
from fractions import Fraction
import numpy as np
%config InlineBackend.figure_format = 'svg'
zx.quantomatic.quantomatic_location = r'C:\Users\John\Desktop\Quantomatic.jar'
zx.tikz.tikzit_location = r'C:\Users\John\Documents\tikzit\tikzit.exe'

In [3]:
def generate_clifford_circuit(qubits, depth, p_cnot=0.3, p_t=0):
    p_s = 0.5*(1.0-p_cnot-p_t)
    p_had = 0.5*(1.0-p_cnot-p_t)
    c = zx.Circuit(qubits)
    for _ in range(depth):
        r = random.random()
        if r > 1-p_had:
            c.add_gate("HAD",random.randrange(qubits))
        elif r > 1-p_had-p_s:
            c.add_gate("S",random.randrange(qubits))
        elif r > 1-p_had-p_s-p_t:
            c.add_gate("T",random.randrange(qubits))
        else:
            tgt = random.randrange(qubits)
            while True:
                ctrl = random.randrange(qubits)
                if ctrl!=tgt: break
            c.add_gate("CNOT",tgt,ctrl)
    return c

In [6]:
seed = 1342
random.seed(seed)
reps = 1
qubits = 20
depth = 800

method1 = 0
method2 = 0
method3 = 0

for i in range(1,reps+1):
    if i%10 == 0: print(i, end='.')
    #c = generate_clifford_circuit(qubits, depth, p_cnot=0.3, p_t=0.2)
    c = zx.Circuit.load(r'..\circuits\Fast\gf2^4_mult_before')
    g = c.to_graph()
    zx.full_reduce(g)
    g.normalise()
    g2 = g.copy()
    #c2 = zx.extract.streaming_extract(g2,quiet=True).to_basic_gates()
    #c2 = zx.optimize.basic_optimization(c2.to_basic_gates()).to_basic_gates()
    #g2 = g.copy()
    #c3 = zx.extract.modified_extract(g2).to_basic_gates()
    #c3 = zx.optimize.basic_optimization(c3.to_basic_gates()).to_basic_gates()
    g2 = g.copy()
    c2 = extract_better(g2,optimize_czs=True,optimize_cnots=1,quiet=False).to_basic_gates()
    #c2 = zx.optimize.basic_optimization(c2.to_basic_gates()).to_basic_gates()
    g2 = g.copy()
    c3 = extract_better(g2,optimize_czs=True,optimize_cnots=2,quiet=False).to_basic_gates()
    #c3 = zx.optimize.basic_optimization(c3.to_basic_gates()).to_basic_gates()
    g2 = g.copy()
    c4 = extract_better(g2,optimize_czs=True,optimize_cnots=3,quiet=False).to_basic_gates()
    #c4 = zx.optimize.basic_optimization(c4.to_basic_gates()).to_basic_gates()
    method1 += c2.twoqubitcount()
    method2 += c3.twoqubitcount()
    method3 += c4.twoqubitcount()

print(c.to_basic_gates().stats())
print(method1/reps, method2/reps, method3/reps)

Simple vertex
Vertices extracted: 9
Simple vertex
Vertices extracted: 1
Simple vertex
Vertices extracted: 2
Simple vertex
Vertices extracted: 1
Simple vertex
Vertices extracted: 1
Simple vertex
Vertices extracted: 3
Simple vertex
Vertices extracted: 5
Simple vertex
Vertices extracted: 2
Simple vertex
Vertices extracted: 3
Simple vertex
Vertices extracted: 2
Simple vertex
Vertices extracted: 2
Simple vertex
Vertices extracted: 2
Simple vertex
Vertices extracted: 2
Simple vertex
Vertices extracted: 3
Simple vertex
Vertices extracted: 3
Simple vertex
Vertices extracted: 3
Simple vertex
Vertices extracted: 1
Simple vertex
Vertices extracted: 3
Simple vertex
Vertices extracted: 4
Simple vertex
Vertices extracted: 1
Simple vertex
Vertices extracted: 3
Simple vertex
Vertices extracted: 5
Simple vertex
Vertices extracted: 4
Simple vertex
Vertices extracted: 2
Simple vertex
Vertices extracted: 3
Simple vertex
Vertices extracted: 3
Simple vertex
Vertices extracted: 3
Simple vertex
Vertices extra

In [8]:
zx.d3.draw(c2)

In [4]:
def column_optimal_swap(m):
    r, c = m.rows(), m.cols()
    connections = {i: set() for i in range(r)}
    connectionsr= {j: set() for j in range(c)}

    for i in range(r):
            for j in range(c):
                if m.data[i][j]: 
                    connections[i].add(j)
                    connectionsr[j].add(i)

    target = _find_targets(connections, connectionsr)
    if not target: target = dict()
    #target = {v:k for k,v in target.items()}
    left = list(set(range(c)).difference(target.values()))
    right = list(set(range(c)).difference(target.keys()))
    for i in range(len(left)):
        target[right[i]] = left[i]
    return target

def _find_targets(conn, connr, target={}):
    target = target.copy()
    r = len(conn)
    c = len(connr)
    
    claimedcols = set(target.keys())
    claimedrows = set(target.values())
    
    while True:
        min_index = -1
        min_options = set(range(1000))
        for i in range(r):
            if i in claimedrows: continue
            s = conn[i] - claimedcols # The free columns
            if len(s) == 1:
                j = s.pop()
                target[j] = i
                claimedcols.add(j)
                claimedrows.add(i)
                break
            if len(s) == 0: return None # contradiction
            found_col = False
            for j in s:
                t = connr[j] - claimedrows
                if len(t) == 1: # j can only be connected to i
                    target[j] = i
                    claimedcols.add(j)
                    claimedrows.add(i)
                    found_col = True
                    break
            if found_col: break
            if len(s) < len(min_options):
                min_index = i
                min_options = s
        else: # Didn't find any forced choices
            if not (conn.keys() - claimedrows): # we are done
                return target
            if min_index == -1: raise ValueError("This shouldn't happen ever")
            # Start depth-first search
            tgt = target.copy()
            #print("backtracking on", min_index)
            for i2 in min_options:
                #print("trying option", i2)
                tgt[i2] = min_index
                r = _find_targets(conn, connr, tgt)
                if r: return r
            #print("Unsuccessful")
            return target

In [5]:
from pyzx.extract import max_overlap, bi_adj, connectivity_from_biadj, apply_rule, pivot, permutation_as_swaps
from pyzx.circuit import Circuit, CNOT
from pyzx.simplify import id_simp
from pyzx.linalg import Mat2, greedy_reduction

def extract_better(g, optimize_czs=True, optimize_cnots=2, quiet=True):
    """Given a graph put into semi-normal form by :func:`simplify.full_reduce`, 
    it extracts its equivalent set of gates into an instance of :class:`circuit.Circuit`.
    """
    #g.normalise()
    qs = g.qubits() # We are assuming that these are objects that update...
    rs = g.rows()   # ...to reflect changes to the graph, so that when...
    ty = g.types()  # ... g.set_row/g.set_qubit is called, these things update directly to reflect that
    phases = g.phases()
    c = Circuit(g.qubit_count())

    gadgets = {}
    for v in g.vertices():
        if g.vertex_degree(v) == 1 and v not in g.inputs and v not in g.outputs:
            n = list(g.neighbours(v))[0]
            gadgets[n] = v
    
    qubit_map = dict()
    frontier = []
    for o in g.outputs:
        v = list(g.neighbours(o))[0]
        if v in g.inputs: continue
        frontier.append(v)
        qubit_map[v] = qs[o]
        
    czs_saved = 0
    
    while True:
        # preprocessing
        for v in frontier: # First removing single qubit gates
            q = qubit_map[v]
            b = [w for w in g.neighbours(v) if w in g.outputs][0]
            e = g.edge(v,b)
            if g.edge_type(e) == 2: # Hadamard edge
                c.add_gate("HAD",q)
                g.set_edge_type(e,1)
            if phases[v]: 
                c.add_gate("ZPhase", q, phases[v])
                g.set_phase(v,0)
        # And now on to CZ gates
        cz_mat = Mat2([[0 for i in range(g.qubit_count())] for j in range(g.qubit_count())])
        for v in frontier:
            for w in list(g.neighbours(v)):
                if w in frontier:
                    cz_mat.data[qubit_map[v]][qubit_map[w]] = 1
                    cz_mat.data[qubit_map[w]][qubit_map[v]] = 1
                    g.remove_edge(g.edge(v,w))
        
        if optimize_czs:
            overlap_data = max_overlap(cz_mat)
            while len(overlap_data[1]) > 2: #there are enough common qubits to be worth optimising
                i,j = overlap_data[0][0], overlap_data[0][1]
                czs_saved += len(overlap_data[1])-2
                c.add_gate("CNOT",i,j)
                for qb in overlap_data[1]:
                    c.add_gate("CZ",j,qb)
                    cz_mat.data[i][qb]=0
                    cz_mat.data[j][qb]=0
                    cz_mat.data[qb][i]=0
                    cz_mat.data[qb][j]=0
                c.add_gate("CNOT",i,j)
                overlap_data = max_overlap(cz_mat)

        for i in range(g.qubit_count()):
            for j in range(i+1,g.qubit_count()):
                if cz_mat.data[i][j]==1:
                    c.add_gate("CZ",i,j)
        
        # Now we can proceed with the actual extraction
        # First make sure that frontier is connected in correct way to inputs
        neighbours = set()
        for v in frontier.copy():
            d = [w for w in g.neighbours(v) if w not in g.outputs]
            if any(w in g.inputs for w in d): #frontier vertex v is connected to an input
                if len(d) == 1: # Only connected to input, remove from frontier
                    frontier.remove(v)
                    continue
                # We disconnect v from the input b via a new spider
                b = [w for w in d if w in g.inputs][0]
                q = qs[b]
                r = rs[b]
                w = g.add_vertex(1,q,r+1)
                e = g.edge(v,b)
                et = g.edge_type(e)
                g.remove_edge(e)
                g.add_edge((v,w),2)
                g.add_edge((w,b),3-et)
                d.remove(b)
                d.append(w)
            neighbours.update(d)
        if not frontier: break # We are done
        
        # First we check if there is a phase gadget in the way
        removed_gadget = False
        for w in neighbours:
            if w not in gadgets: continue
            for v in g.neighbours(w):
                if v in frontier:
                    apply_rule(g,pivot,[(w,v,[],[o for o in g.neighbours(v) if o in g.outputs])])
                    frontier.remove(v)
                    del gadgets[w]
                    frontier.append(w)
                    qubit_map[w] = qubit_map[v]
                    removed_gadget = True
                    break
        if removed_gadget: # There was indeed a gadget in the way. Go back to the top
            continue
            
        neighbours = list(neighbours)
        m = bi_adj(g,neighbours,frontier)
        #print(m)
        #print(m)
        if all(sum(row)!=1 for row in m.data): # No easy vertex
            if optimize_cnots>1:
                 greedy = greedy_reduction(m)
            else: greedy = None
            if greedy:
                greedy = [CNOT(target,control) for control,target in greedy]
                if (len(greedy)==1 or optimize_cnots<3) and not quiet: print("Found greedy reduction with", len(greedy), "CNOT")
                cnots = greedy
            if not greedy or (optimize_cnots == 3 and len(greedy)>1):
                perm = column_optimal_swap(m)
                #print(perm)
                perm = {v:k for k,v in perm.items()}
                neighbours2 = [neighbours[perm[i]] for i in range(len(neighbours))]
                m2 = bi_adj(g, neighbours2, frontier)
                if optimize_cnots > 0:
                    cnots = m2.to_cnots(optimize=True)
                else:
                    cnots = m2.to_cnots(optimize=False)
                m3 = m2.copy()
                for cnot in cnots:
                    m3.row_add(cnot.target,cnot.control)
                reductions = sum(1 for row in m3.data if sum(row)==1)
                if greedy and (len(cnots)/reductions > len(greedy)-0.1):
                    if not quiet: print("Found greedy reduction with", len(greedy), "CNOTs")
                    cnots = greedy
                else:
                    neighbours = neighbours2
                    m = m2
                    if not quiet: print("Gaussian elimination with", len(cnots), "CNOTs")
            for cnot in cnots:
                m.row_add(cnot.target,cnot.control)
                c.add_gate("CNOT",qubit_map[frontier[cnot.control]],qubit_map[frontier[cnot.target]])
            connectivity_from_biadj(g,m,neighbours,frontier)
        else:
            if not quiet: print("Simple vertex")
        good_verts = dict()
        for i, row in enumerate(m.data):
            if sum(row) == 1:
                v = frontier[i]
                w = neighbours[[j for j in range(len(row)) if row[j]][0]]
                good_verts[v] = w
        if not good_verts: raise Exception("No extractable vertex found. Something went wrong")
        for v,w in good_verts.items():
            c.add_gate("HAD",qubit_map[v])
            qubit_map[w] = qubit_map[v]
            b = [o for o in g.neighbours(v) if o in g.outputs][0]
            g.remove_vertex(v)
            g.add_edge((w,b))
            frontier.remove(v)
            frontier.append(w)
        if not quiet: print("Vertices extracted:", len(good_verts))
        #continue
            
    if optimize_czs:
        if not quiet: print("CZ gates saved:", czs_saved)
    # Outside of loop. Finish up the permutation
    id_simp(g,quiet=True) # Now the graph should only contain inputs and outputs
    swap_map = {}
    leftover_swaps = False
    for v in g.outputs: # Finally, check for the last layer of Hadamards, and see if swap gates need to be applied.
        q = qs[v]
        i = list(g.neighbours(v))[0]
        if i not in g.inputs: 
            raise TypeError("Algorithm failed: Not fully reducable")
            return c
        if g.edge_type(g.edge(v,i)) == 2:
            c.add_gate("HAD", q)
            g.set_edge_type(g.edge(v,i),1)
        if qs[i] != q: leftover_swaps = True
        swap_map[q] = qs[i]
    if leftover_swaps: 
        for t1, t2 in permutation_as_swaps(swap_map):
            c.add_gate("SWAP", t1, t2)
    # Since we were extracting from right to left, we reverse the order of the gates
    c.gates = list(reversed(c.gates))
    return c