# SAT Solver for Quantum Error Correcting Code Distance

Heather Leitch & Alastair Kay

Royal Holloway, University of London

In [16]:
#initial imports
import numpy as np
from IPython.display import display, Math, Latex

# requires ldpc https://github.com/quantumgizmos/ldpc/tree/main
from ldpc.mod2 import *
from ldpc.codes import *

#requires pysat, https://pysathq.github.io/
#pysat not so easy to get running under windows. Might instead use a cloud service such as cocalc.com
from pysat.card import *
from pysat.solvers import Glucose42

#most proof verifiers are command line utilities. Here's one we can instead use from python. https://fairlyaccountable.org/verified_rup/drup.html
import drup

Given a CSS code defined by binary matrices $H_X\in\{0,1\}^{m_X\times n}$ and $H_Z\in\{0,1\}^{m_Z\times n}$, we want to determine the (Z) distance (to find $X$ distance, just switch $H_X$ and $H_Z$). We do this by defining the decision problem:

> Is there a string $x\in\{0,1\}^n$ of weight $w_x\leq w$ for which $H_X\cdot x\equiv 0\text{ mod }2$ (i.e. there's a $Z_x$ operator that commutes with all $X$-type stabilizers) and which cannot be written as a product of $Z$-type stabilizers, i.e. $\nexists u\in\{0,1\}^{m_Z}: H_Z^T\cdot u=x$.

We need to set this up in conjunctive normal form, possibly using some auxiliary variables as well.

> CNF($w_x\leq w$) AND CNF($H_X\cdot x\equiv 0\text{ mod }2$) AND CNF($\nexists u\in\{0,1\}^{m_Z}: H_Z^T\cdot u=x$)

## Second Term

Let's define $pwt(x)$ as the parity of the weight of the bit string $x$.


Remove any linearly dependent rows from $H_X$. Then let $h_i$ be a row of $H_X$.

> $pwt(h_1. x)=0$ AND $pwt(h_2\cdot x)=0$ AND $\ldots$ AND $pwt(h_{m_X}\cdot x)=0$

Here $h\cdot x$ means the binary string $x_1h_1,x_2h_2,\ldots x_nh_n$

## Third Term

Let's take $H_Z$ in standard form
$$
H_Z=[\mathbb{1} | A]
$$
If we divide $x$ into $x_L|x_R$, $x_l$ of length equal to rank of $H_Z$ then $H_Z^Tu=x$ requires $u=x_L$, and hence our condition is reduced to
$$
A^Tx_L\neq x_R.
$$
Let $a_i$ be a row of $A^T$ (column of $A$). We set
$$
y_i=pwt(a_i\cdot x_L|{x_R}_i)
$$
so our condition is just
> $y_1$ OR $y_2$ OR $\ldots$ OR $y_{n-m_Z}$

## First Term

This can be implemented directly in the sat solvers
$$
wt(x)\leq w
$$

In [17]:
def xor(y,x1,x2):
    """
    in: y,x1,x2. three non-zero integers specifying variables
    out: SAT constraints that fix y=x1 oplus x2"""
    assert isinstance(y,int) and isinstance(x1,int) and isinstance(x2,int)
    assert y * x1 * x2 != 0
    return [[-y,-x1,-x2],[-y,x1,x2],[y,x1,-x2],[y,-x1,x2]]

def pwt(h,idx,num_vars):
    """
    Calculate the parity of the weight of the variables in h
    in:
    h: list of non-zero integers. The variables whose weight we want to calculate the parity of
    idx: the variable to store that parity (i.e. idx=pwt(h))
    num_vars: the number of variables we already have (since we'll be adding ancillas off the end of this)
    
    out:
    sat constraints for the calculation
    number of new ancillas
    """
    assert isinstance(h,list)
    assert len(h)>0
    assert all([isinstance(el,int) and h!=0 for el in h])
    assert isinstance(idx,int) and idx>0
    assert isinstance(num_vars,int) and num_vars>0
    
    if len(h)==2:
        new_cons=xor(idx,h[0],h[1])
        num_new_ancilla=0
    elif len(h)==3:
        #only add one new variable
        num_new_ancilla=1
        new_cons=xor(idx,h[2],num_vars+1)
        nc,na=pwt(h[:2],num_vars+1,num_vars+1)
        new_cons+=nc
        num_new_ancilla+=na
    else:
        #add two new variables
        v1,v2=num_vars+1,num_vars+2
        num_new_ancilla=2
        new_cons=xor(idx,v1,v2)
        B = h[:len(h)//2]
        C = h[len(h)//2:]
        nc,na=pwt(B,v1,num_vars+2)
        new_cons+=nc
        num_new_ancilla+=na
        nc,na=pwt(C,v2,num_vars+2+na)
        new_cons+=nc
        num_new_ancilla+=na
    return new_cons,num_new_ancilla
    


def sat_distance(threshold,HX,HZ,proof=False,switch=False):
    """this is the main entry point for testing the Z distance.
    we ask if there's an error with weight <=threshold
    In:
    threshold: positive integer
    HX,HZ: binary matrices (numpy array)
    proof: Boolean
    switch: Boolean. False: Z distance. True: X distance (only alters label when we print the output). You also need to switch order of HX/HZ.
    
    Out:
    returns the output of the sat solver (i.e. a satisfying instance) if it exists. Returns False if fails.
    """

    assert isinstance(HX,np.ndarray) and isinstance(HZ,np.ndarray)
    assert not np.any(HX-HX*HX) # evaluate x(1-x) on each element to ensure binary
    assert not np.any(HZ-HZ*HZ)
    constraints=[] # form a list of constraints
    ancillas=0
    n=np.shape(HX)[1]
    assert n==np.shape(HZ)[1]
    assert not np.any(np.mod(HX @ HZ.T,2))
    
    letter=switch*"X"+(1-switch)*"Z"

    #get the parity check matrices into standard form
    Hx,rkx,row_transform,col_transform=reduced_row_echelon(HX)
    Hx=np.mod(row_transform @ HX,2) # removes any linearly dependent rows without messing up columns
    Hz,rk,row_transform,col_transform=reduced_row_echelon(HZ)
    A=Hz[:rk,rk:].T # extract the non-identity part of Hz and transpose it.
    assert (Hz==np.mod(row_transform @ HZ @ col_transform,2)).all()
    Hx=np.mod(Hx @ col_transform,2) # keep the columns consistent
    #Hz now has leading block of identity

    #do the "not a stabilizer" conditions (third term)
    #introduce n-rk new variables y_i
    ancillas=2*n-rk
    constraints=[list(range(n+1,ancillas+1))] # y_1 or y_2 or y_3....
    for idx in range(n-rk):
        #set y_idx to pwt(A^T_i. x_L|x_R_i)
        row=[i+1 for i,j in enumerate(A[idx]) if j]+[rk+idx+1] # this is A_i.x_L and appropriate bit of x_R
        new_cons,num_new=pwt(row,idx+n+1,ancillas) # set equal to y_i
        constraints+=new_cons
        ancillas+=num_new
        
    #the "commutes with all X stabilizers" conditions (second term)
    #add some new variables to receive the pwt(h_i)
    start_ancillas=ancillas
    for idx in range(rkx):
        constraints+=[[-(start_ancillas+idx+1)]] # y_i=0
        new_cons,num_new=pwt([i+1 for i,j in enumerate(Hx[idx]) if j],start_ancillas+idx+1,ancillas) # y_i=pwt(h_i)
        constraints+=new_cons
        ancillas+=num_new

    #w_x<=threshold (first term)
    cnf = CardEnc.atmost(lits=list(range(1,n+1)),bound=threshold,top_id=ancillas, encoding=EncType.seqcounter)
    constraints+=cnf.clauses

    #send the constraints to the sat solver
    g = Glucose42(bootstrap_with=constraints,with_proof=proof) # if we used cryptominisat, has native "add_xor_clause", but proof not supported
                                                               # if we used gluecard 4,    has native "add_atmost()"  , but proof not supported
                                                               # native functionality should be faster
    #solve!
    solution=g.solve()
    if solution:
        model=g.get_model()
        result=[i>0 for i in model[:n]]
        assert sum(result)<=threshold          # check that the solution has the correctly limited weight
        assert np.sum(np.mod(Hx@ result,2))==0 # check that the solution really is a null vector of HX
        temp=[i+1 for v in model[:n] if v>0 for i,j in enumerate((col_transform.T)[v-1]) if j] # undo the permutation to standard form so that the output is actually meaningful!
        err=[letter+"_{"+str(i)+"}" for i in temp]
        
        x_left=result[:rk]
        x_right=result[rk:]
        assert np.sum(np.mod(A@x_left+x_right,2))>0 # check the "not a stabilizer" condition
        display(Latex("We found an error of weight "+str(sum(result))+" of the form $"+"".join(err)+"$"))
        g.delete()
        return model # remember that the columns may have been permuted, so there's no point in looking at the answer.
    else:
        if proof:
            assert check_proof(g.get_proof(),constraints)
            print("Proof of no solution verified.")
        g.delete()
        return False

    
def check_proof(proof,cnf):
    """given a purported proof and the original CNF formula, verify the proof
    Use a checker of your choice
    """
    print("Proof of no solution:")
    print(proof)
    
    proof2 = [[int(l) for l in c.split(' ')[:-1]] for c in proof if c[0]!="d"]
    drup.check_proof(cnf, proof2, verbose=True)
    
    return True
    
def find_weight(Hx,Hz,upper_bound,both=False,switch=False):
    """find a logical Z operator of the code with a weight less than upper_bound. Continue decreasing until you cannot find a shorter one
    In:
    HX,HZ: numpy matrices for parity checks
    upper_bound: positive integer for which we know a logical operator exists
    both: boolean (False: just find Z distance. True: find X and Z distances)
    switch: boolean. only relevant if both is False. By default we're calculating Z distance, but if you want X distance, set this to True. Don't change order of arguments HX/HZ
    Out:
    code distance"""
    if switch:
        HX,HZ=Hz,Hx
    else:
        HX,HZ=Hx,Hz
    assert isinstance(HX,np.ndarray) and isinstance(HZ,np.ndarray)
    assert not np.any(HX-HX*HX) # evaluate x(1-x) on each element to ensure binary
    assert not np.any(HZ-HZ*HZ)
    
    upper=upper_bound
    bRunning=True
    (x,num)=np.shape(HX)
    while bRunning:
        out=sat_distance(upper-1,HX,HZ,switch=switch)
        if out:
            #process out to find the weight of the error
            n_upper=sum([val>0 for val in out[:num]])
            assert 0<n_upper<upper
            print("Checking if there's anything shorter...")
            upper=n_upper
        else:
            bRunning=False
    if both: # also calculate the X distance (but only care if it's shorter than the Z distance)
        bRunning=True
        (x,num)=np.shape(HZ)
        while bRunning:
            out=sat_distance(upper-1,HZ,HX,switch=not switch)
            if out:
                #process out to find the weight of the error
                n_upper=sum([1 for val in out[:num] if val>0])
                assert 0<n_upper<upper
                print("Checking if there's anything shorter...")
                upper=n_upper
            else:
                bRunning=False
    return upper

## Error Correcting Code Definitions

In [5]:
#interesting error correcting codes
#functions which return H_X,H_Z

#Steane Code
def Steane_parity_checks():
    return QHammingCode(3)

def QHammingCode(n):
    """[[2^n-1,1,3]] codes"""
    HX=hamming_code(n)
    HZ=nullspace(np.vstack([HX,[1]*(2**n-1)]))
    return HX,HZ

# Bivariate bicycle codes: Nature volume 627, pages 778–782 (2024).
def IBM_parity_checks(l,m,powersA,powersB):
    """creates an IBM code using parameters l,m, of form H_X=[A|B], H_Z=[B^T|A^T].
    powersA,powersB: list of length 3, pairs of integers [a,b], first 0 to l-1, second 0 to m-1, defining a term $x^ay^b$"""
    assert isinstance(powersA,list)
    assert len(powersA)>=2
    assert all([0<=a<l and 0<=b<m for a,b in powersA])
    assert isinstance(powersB,list)
    assert len(powersB)>=2
    assert all([0<=a<l and 0<=b<m for a,b in powersB])

    A=sum([np.kron(np.roll(np.eye(l), a, axis=1),np.roll(np.eye(m), b, axis=1)) for a,b in powersA])
    B=sum([np.kron(np.roll(np.eye(l), a, axis=1),np.roll(np.eye(m), b, axis=1)) for a,b in powersB])
    
    HX=np.block([[A,B]])
    HZ=np.block([[B.T,A.T]])
    assert HX.shape==(l*m,l*m*2)
    assert all(np.sum(HX, axis=1)==[len(powersA)*2]*l*m)
    assert all(np.sum(HX, axis=0)==[len(powersA)]*l*m*2)
    return HX,HZ

def Toric_parity_checks(N):
    """create a Toric code of 2N^2 qubits"""
    return IBM_parity_checks(N,N,[[0,0],[1,0]],[[0,0],[0,1]])

## Tests to see if it works

In [8]:
#some functions that let you evaluate the conditional clasues to see if they're doing what you think
def output_truthtable(constraints):
    vars=max([abs(i) for row in constraints for i in row])
    return [bit_string_satisfies(constraints,mod10_to_mod2(idx,length=vars)) for idx in range(2**vars)]

def test_with_unknown_ancillas(constraints,input,output):
    """for a set of constraints (which have been formulated using some ancillas that I don't know how they work), see if the constraints work properly.
    input: the bit values x for a given row of the truth table
    output: the output of the truth table
    returns True/False if answer correct or not.
    if output==0, answer correct if all possible ancilla inputs return 0
    if output==1, answer correct if there exists an ancilla input that returns 1"""
    largest=max([abs(i) for constraint in constraints for i in constraint])
    first_ancilla=len(input)
    num_ancillas=largest+1-first_ancilla
    if output:
        return any([bit_string_satisfies(constraints,input+mod10_to_mod2(idx,length=num_ancillas)) for idx in range(2**num_ancillas)])
    else:
        return all([not bit_string_satisfies(constraints,input+mod10_to_mod2(idx,length=num_ancillas)) for idx in range(2**num_ancillas)])

def bit_string_satisfies(constraints,str):
    """This is purely for testing purposes"""
    return all([bit_string_or(cons,str) for cons in constraints]) # and everything

def bit_string_or(constraint,str):
    return any([test_condition(cons,str) for cons in constraint]) # or everything

def test_condition(constraint,str): # an individual term
    if constraint>0:
        return str[constraint-1]
    elif not constraint:
        return 1
    else:
        return not str[-constraint-1]



#tests to verify that I'm constructing the xor correctly
assert output_truthtable(xor(3,1,2))==[1,0,0,1,0,1,1,0] # y_1=x_1 xor x_2
assert output_truthtable(pwt([1,2,3],4,4)[0])==[1,0,0,0,0,0,1,0,0,0,0,1,0,1,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,1,0] # y_1=x_1 xor x_2 xor x_3 AND y_2=x_1 xor x_2
assert pwt([1,2,3,4,5,6,7],8,8)[1]==5 # I'm not calculating the truth table for this, but can at least predict the number of ancillas required
#test CardEnc:
conds=CardEnc.atmost(lits=[1,2,3,4],bound=2, encoding=EncType.seqcounter).clauses
assert all([test_with_unknown_ancillas(conds,mod10_to_mod2(idx,length=4),sum(mod10_to_mod2(idx,length=4))<=2) for idx in range(2**4)])

#basic test cases for distance calculations

assert sat_distance(3,*Steane_parity_checks())
assert sat_distance(6,*IBM_parity_checks(6,6,[[0,0],[1,0],[2,3]],[[0,0],[0,1],[3,2]]))
assert not sat_distance(2,*Steane_parity_checks())
assert not sat_distance(5,*IBM_parity_checks(6,6,[[0,0],[1,0],[2,3]],[[0,0],[0,1],[3,2]]))

assert find_weight(*Steane_parity_checks(),5)==3
assert find_weight(*IBM_parity_checks(6,6,[[0,0],[1,0],[2,3]],[[0,0],[0,1],[3,2]]),6)==6

assert sat_distance(3,*Toric_parity_checks(3))
assert find_weight(*Toric_parity_checks(3),7)==3
assert sat_distance(5,*Toric_parity_checks(5))
assert not sat_distance(4,*Toric_parity_checks(5))
assert find_weight(*Toric_parity_checks(5),7)==5

assert find_weight(*QHammingCode(4),9)==3
assert find_weight(*QHammingCode(4),9,switch=True)==7
print("All tests passed!!!")

<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

Checking if there's anything shorter...


<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

Checking if there's anything shorter...


<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

Checking if there's anything shorter...


<IPython.core.display.Latex object>

Checking if there's anything shorter...


<IPython.core.display.Latex object>

Checking if there's anything shorter...


<IPython.core.display.Latex object>

Checking if there's anything shorter...
All tests passed!!!


If we enable proofs, then when a solution is not found, we also get given a proof which claims to show that there is no solution. In principle, we can check this proof with an independent proof checker, such as https://pypi.org/project/drup/ (as implemented) or https://www.cs.utexas.edu/~marijn/drat-trim/ (**not tested**). We *believe* we have correctly converted between proof formats, but hard to verify.

In [18]:
sat_distance(2,*Steane_parity_checks(),proof=True)

Proof of no solution:
['d 19 -20 0', 'd -19 20 0', 'd -2 -4 -19 0', 'd 2 4 -19 0', 'd -6 -7 -20 0', 'd 6 7 -20 0', '31 0', '29 15 9 8 0', '3 -10 15 9 16 0', '-7 -29 0', '15 9 8 0', '1 -15 0', '-15 0', '7 1 -9 0', '-7 0', '1 0', '0']
Proof of no solution verified.


False