# Simon's Algorithm
## Problem statement:

We have a 2:1 function, i.e., a function that maps exactly two distinct inputs to each output.  
The function $f: \{0,1\}^n \to \{0,1\}^n$ takes an $n$-bit string as input and produces an $n$-bit string as output, such that $f(x) = f(y)$ if and only if $y = x$ or $y = x \oplus s$ ($\oplus$ denotes bitwise XOR), where $s \in \{0,1\}^n$ is a fixed nonzero bitstring.  
We need to find the hidden bitstring $s$.
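
To make the promise concrete, here is a minimal classical sketch that builds a random function with this property and checks it. It assumes bitstrings are encoded as Python integers; the helper names `random_simon_function` and `is_simon_function` are hypothetical and not part of `qckt`.

```python
import random
from collections import Counter

def random_simon_function(s: int, n: int, seed: int = 0) -> dict:
    """Hypothetical helper: a random f on n-bit integers with f(x) == f(x ^ s)."""
    rng = random.Random(seed)
    labels = rng.sample(range(2 ** n), 2 ** n)  # a shuffled pool of distinct output values
    f = {}
    for x in range(2 ** n):
        if x not in f:
            # x and x ^ s share one fresh output value (they are the same x when s == 0)
            label = labels.pop()
            f[x] = label
            f[x ^ s] = label
    return f

def is_simon_function(f: dict, s: int, n: int) -> bool:
    """Check f(x) == f(x ^ s) for all x and that every output has exactly 2 preimages."""
    if any(f[x] != f[x ^ s] for x in range(2 ** n)):
        return False
    preimage_counts = Counter(f.values())
    return all(count == 2 for count in preimage_counts.values())

n_demo, s_demo = 3, 0b101
f_demo = random_simon_function(s_demo, n_demo)
print(is_simon_function(f_demo, s_demo, n_demo))  # True
```

Because every output value has exactly two preimages and $x$, $x \oplus s$ always share one, those two are the only preimages, which is the "if and only if" condition above.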

In [None]:
import qckt
import qckt.backend as bknd
import numpy as np

## Oracle function

In [None]:
def simons_oracle(secret_code: str):
    """
    Build a quantum circuit implementing one possible oracle for Simon's problem.

    Args:
        secret_code (str): the secret bitstring s we wish to find, e.g. '01101'

    Returns:
        qckt.QCkt: a 2n-qubit circuit mapping |x>|0> to |x>|f(x)>, with f(x) = f(x + s)

    Steps:
    1. Pick one position in s whose bit is 1; call the corresponding input qubit the flag qubit.
    2. Copy x to the output register using CNOTs.
    3. If the flag qubit is 1 in x, XOR the output register with s, again using CNOTs.

    Discussion:
    Let's take two input states |x> and |x + s> ('+' denotes addition modulo 2 here, i.e., XOR).
    Since s has a 1 at the flag position, the flag bit is 1 in one of them and 0 in the other.
    case 1: flag is 0 in |x> and 1 in |x + s>
        by construction, if the flag is 0, f(x) = x, i.e., the same as the input.
        f(x + s) is also x: the flag is 1 for x + s, so the output is XOR'ed with s,
        i.e., the output is x + s + s = x.
        So f(x) = f(x + s).
    case 2: flag is 1 in |x> and 0 in |x + s>
        by construction, if the flag is 1, f(x) = x + s.
        f(x + s) is also x + s, since for flag = 0 the output is the same as the input.
        So f(x) = f(x + s).
    """

    n = len(secret_code)

    # Validate the secret string up front, before any gates are added
    for bit_value in secret_code:
        if bit_value not in ('0', '1'):
            raise ValueError(f"Incorrect char '{bit_value}' in secret string s: {secret_code}")

    # Let's use the first 1 in s as the flag bit
    flag_bit = secret_code.find('1')

    # and start constructing the oracle circuit
    circ = qckt.QCkt(2*n)   # n qubits for input and n qubits for output

    # First copy x to the output register, so |x>|0> -> |x>|x>
    for i in range(n):
        circ.CX(i, i+n)

    # If flag_bit == -1, s is the all-zeros string, so x + s equals x and the condition
    # f(x) = f(x + s) is trivially satisfied; we do nothing more.
    # Otherwise, XOR the output with s, controlled by the flag qubit.
    if flag_bit != -1:
        # XOR'ing with s controlled by the flag qubit means: for each bit of s that is 1,
        # apply a CNOT to the corresponding output qubit, with the flag as the control.
        for index, bit_value in enumerate(secret_code):
            if bit_value == '1':
                circ.CX(flag_bit, index+n)
    return circ

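The case analysis in the docstring can be checked exhaustively without a quantum simulator: on computational basis states a CNOT simply flips its target bit when its control bit is 1, so the oracle acts as a classical reversible map. The minimal sketch below replays the same CX gates that `simons_oracle` adds, on plain integers. It assumes (as the equation-printing cell at the end of this notebook does) that bit i of an integer corresponds to qubit i and to character i of the secret string; the helper names are hypothetical.

```python
def oracle_cx_pairs(secret_code: str) -> list:
    """(control, target) pairs in the order simons_oracle applies its CX gates."""
    n = len(secret_code)
    pairs = [(i, i + n) for i in range(n)]  # copy x into the output register
    flag_bit = secret_code.find('1')
    if flag_bit != -1:
        # XOR s into the output register, controlled by the flag qubit
        pairs += [(flag_bit, i + n) for i, b in enumerate(secret_code) if b == '1']
    return pairs

def apply_cx(state: int, pairs: list) -> int:
    """Apply CX gates to a basis state: flip bit `target` whenever bit `control` is 1."""
    for control, target in pairs:
        if (state >> control) & 1:
            state ^= 1 << target
    return state

def oracle_output(x: int, secret_code: str) -> int:
    """Value of the output register after the oracle acts on |x>|0>."""
    n = len(secret_code)
    return apply_cx(x, oracle_cx_pairs(secret_code)) >> n

s = '01101'
n = len(s)
s_int = sum(1 << i for i, c in enumerate(s) if c == '1')  # bit i <-> character i
outputs = [oracle_output(x, s) for x in range(2 ** n)]
assert all(outputs[x] == outputs[x ^ s_int] for x in range(2 ** n))  # f(x) = f(x + s)
assert len(set(outputs)) == 2 ** (n - 1)                              # f is 2:1
```

Since every control is an input qubit, the input register is never modified, which is why only the output register needs to be read back.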
## Simon's algorithm circuit
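
As a reading aid for the circuit in the next cell, here is the standard derivation, written assuming $s \neq 0$ and with $x \cdot y = \sum_i x_i y_i \bmod 2$. Starting with both registers in $|0\rangle^{\otimes n}$:

$$
|0\rangle^{\otimes n}|0\rangle^{\otimes n}
\xrightarrow{H^{\otimes n}\otimes I} \frac{1}{\sqrt{2^n}}\sum_{x}|x\rangle|0\rangle^{\otimes n}
\xrightarrow{U_f} \frac{1}{\sqrt{2^n}}\sum_{x}|x\rangle|f(x)\rangle
\xrightarrow{H^{\otimes n}\otimes I} \frac{1}{2^n}\sum_{y}\sum_{x}(-1)^{x\cdot y}|y\rangle|f(x)\rangle
$$

Since $f(x) = f(x \oplus s)$ and $(x \oplus s)\cdot y = x\cdot y \oplus s\cdot y$, the two terms sharing the output $f(x)$ combine into the amplitude

$$
\frac{1}{2^n}\left[(-1)^{x\cdot y} + (-1)^{(x\oplus s)\cdot y}\right]
= \frac{(-1)^{x\cdot y}}{2^n}\left[1 + (-1)^{s\cdot y}\right],
$$

which vanishes whenever $s \cdot y = 1$. Hence every measured $y$ satisfies $y \cdot s = 0 \pmod 2$, and each such $y$ occurs with probability $2^{-(n-1)}$.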

In [None]:
# secret code s
s = '01101'
print(f'Secret code is {s}')

# form the input and output registers
regsz = len(s)
inpreg = [i for i in range(regsz)]
outreg = [i+regsz for i in range(regsz)]

# construct the circuit
ckt = qckt.QCkt(regsz*2,regsz*2)
# full superposition of input
ckt.H(inpreg)

# insert the oracle function
ckt.Border()
ora = simons_oracle(s)
ckt = ckt.append(ora)
ckt.Border()
# ckt.Probe(header='f(x) applied to a full superposition of x')

# ... and then the rest of the circuit
# ckt.M(outreg)  # measuring the output register helps in understanding the algorithm, but is redundant: by the principle of deferred measurement, it does not change the statistics of the final input-register measurement
ckt.H(inpreg)
# ckt.Probe(header='state before final measurement')
ckt.M(inpreg)
ckt.draw()

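A quick numerical check of what the final measurement should produce, independent of `qckt`: the minimal NumPy sketch below computes the exact probability of each input-register outcome $y$ after H, oracle, H, using the closed form of the oracle (the output is $x$, XOR'ed with $s$ when the flag bit of $x$ is 1). It assumes the same bit ordering as this notebook's post-processing (bit i of an integer corresponds to character i of `s`); the function names are hypothetical.

```python
import numpy as np

def oracle_f(x: int, s: str) -> int:
    """Closed form of simons_oracle: x, XOR'ed with s when the flag bit of x is 1."""
    s_int = sum(1 << i for i, c in enumerate(s) if c == '1')
    flag = s.find('1')
    if flag != -1 and (x >> flag) & 1:
        return x ^ s_int
    return x

def measurement_probabilities(s: str) -> np.ndarray:
    """Exact probability of measuring each y on the input register after H, oracle, H."""
    n = len(s)
    N = 2 ** n
    # amp[y, z]: amplitude of |y>|z> = (1/N) * sum over x with f(x) = z of (-1)^(x . y)
    amp = np.zeros((N, N))
    for x in range(N):
        z = oracle_f(x, s)
        for y in range(N):
            sign = -1.0 if bin(x & y).count('1') % 2 else 1.0
            amp[y, z] += sign / N
    return (amp ** 2).sum(axis=1)  # marginalize over the unmeasured output register

s = '01101'
s_int = sum(1 << i for i, c in enumerate(s) if c == '1')
probs = measurement_probabilities(s)
for y, p in enumerate(probs):
    if p > 1e-12:
        # every y with nonzero probability should satisfy y . s = 0 (mod 2)
        print(f'y = {y:05b}  p = {p:.4f}  y.s mod 2 = {bin(y & s_int).count("1") % 2}')
```

For `s = '01101'` this should list 16 values of $y$, each with probability $1/16$, all with $y \cdot s = 0$.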
## Run the circuit repeatedly to collect a set of distinct outputs

In [None]:

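num_tries = 0
y_vals = []
# Each run yields one y with y . s = 0 (mod 2); keep going until we have n-1 distinct nonzero values.
# Note: distinct values are not guaranteed to be linearly independent (mod 2).
while len(y_vals) < regsz-1:
    job = qckt.Job(ckt, shots=1)
    bk = bknd.Qdeb()
    bk.runjob(job)
    creg_counts = job.get_counts()

    # with a single shot, the index of the only nonzero count is the measured classical register value
    creg_max = np.argmax(np.array(creg_counts))
    y = creg_max & (2**regsz - 1)   # keep only the bits measured from the input register
    print(f'{creg_max:0{2*regsz}b}  {y:0{regsz}b}')
    # y = 0 gives the trivial equation 0 = 0, so skip it along with repeats
    if y != 0 and y not in y_vals:
        y_vals.append(y)
    num_tries += 1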


print(f'After {num_tries} tries')
print('y values:', [f'{val:0{regsz}b}' for val in y_vals])

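The loop above stops after collecting $n-1$ distinct nonzero values of $y$, but distinct is not the same as linearly independent over GF(2), which is what the classical step needs to pin down a unique $s$. A minimal sketch of an independence check, assuming the y values are plain integers as in `y_vals`; `gf2_rank` is a hypothetical helper, not a `qckt` function:

```python
def gf2_rank(vectors) -> int:
    """Rank over GF(2) of bit vectors encoded as non-negative integers."""
    basis = []
    for v in vectors:
        # min(v, v ^ b) clears b's highest set bit in v whenever that bit is set
        for b in basis:
            v = min(v, v ^ b)
        if v:
            basis.append(v)  # v now has a highest set bit that no basis vector shares
    return len(basis)

print(gf2_rank([0b00110, 0b01010, 0b01100]))           # 2, since 00110 ^ 01010 == 01100
print(gf2_rank([0b00001, 0b01000, 0b00110, 0b10100]))  # 4
```

In the notebook, `gf2_rank(y_vals) == regsz - 1` would confirm that the equations determine a unique nonzero $s$; the same test could serve as the loop condition instead of `len(y_vals)`.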
## Use the collected outputs in a classical algorithm to solve the system of simultaneous equations
Left as an exercise to the reader :-)
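
For readers who want to check their answer, one standard approach is Gauss-Jordan elimination over GF(2). The minimal sketch below returns a basis of all $s$ satisfying $y \cdot s = 0 \pmod 2$ for every given $y$. It assumes, as the equation-printing code in the next cell does, that bit $i$ of each integer $y$ is the coefficient of $s_i$ (character $i$ of the secret string); the function names and the example input are hypothetical.

```python
def gf2_nullspace(rows, n: int) -> list:
    """Basis of {s : popcount(y & s) is even for every y in rows}, via Gauss-Jordan over GF(2)."""
    rows = list(rows)
    pivot_cols = []
    r = 0
    for col in range(n):
        # find a remaining row with a 1 in this column
        pivot = next((i for i in range(r, len(rows)) if (rows[i] >> col) & 1), None)
        if pivot is None:
            continue  # no pivot: s_col is a free variable
        rows[r], rows[pivot] = rows[pivot], rows[r]
        # eliminate this column from every other row (reduced row echelon form)
        for i in range(len(rows)):
            if i != r and (rows[i] >> col) & 1:
                rows[i] ^= rows[r]
        pivot_cols.append(col)
        r += 1
    basis = []
    for free in (c for c in range(n) if c not in pivot_cols):
        s = 1 << free
        # row k reads s_pivot + (sum of its free variables) = 0, so s_pivot copies the free bit
        for k, pc in enumerate(pivot_cols):
            if (rows[k] >> free) & 1:
                s |= 1 << pc
        basis.append(s)
    return basis

def to_bitstring(s: int, n: int) -> str:
    """Character i is bit i, matching the s0, s1, ... labels in the printed equations."""
    return ''.join('1' if (s >> i) & 1 else '0' for i in range(n))

# Example input (hand-picked, not measured): y values with y . s = 0 for s = '01101'
example_y_vals = [0b00001, 0b01000, 0b00110, 0b10100]
solutions = gf2_nullspace(example_y_vals, 5)
print([to_bitstring(s, 5) for s in solutions])  # ['01101']
```

With $n-1$ independent equations the basis has exactly one element, the hidden $s$; more than one element means more samples are needed, and an empty result means only $s = 0$ fits.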

In [None]:
print('Now solve these simultaneous linear equations (mod 2) to get the value of s')
print('Note: in some runs, the collected y values may not form a linearly independent set of equations')
for y in y_vals:
    # bit i of y is the coefficient of s_i (character i of the secret string)
    eqn = ' + '.join(f's{i}' for i in range(regsz) if y & (1 << i))
    print(eqn, '= 0 (mod 2)')