In [1]:
# %matplotlib inline
# from plot_braket_circuit import *

In [2]:
import itertools

from braket.devices import LocalSimulator
from braket.circuits import Circuit

In [3]:
def generate_oracle(secret_s):
    """Build the Simon's-problem oracle circuit for a secret bit string.

    The circuit acts on 2n qubits, where n = len(secret_s). Qubits 0..n-1 are
    the input register x and qubits n..2n-1 the output register. Let j be the
    index of the first '1' in secret_s. The circuit XORs f(x) into the output
    register, where f(x) is x with:
      * bit j forced to 0, and
      * every other bit k with secret_s[k] == '1' replaced by x_k XOR x_j.
    Flipping x by secret_s leaves f unchanged, so f(x) == f(x XOR s).

    Args:
        secret_s: sequence of '0'/'1' characters containing at least one '1'.

    Returns:
        A braket ``Circuit`` implementing the oracle. The input register is
        restored to x at the end.

    Raises:
        ValueError: if secret_s contains a character other than '0'/'1', or
            contains no '1' (this includes the empty string).
    """
    # Validate the secret and record where its '1' bits are.
    one_positions = []
    for index, bit_value in enumerate(secret_s):
        if bit_value not in ('0', '1'):
            raise ValueError(
                'Incorrect char {!r} in secret string S: {}'.format(bit_value, secret_s))
        if bit_value == '1':
            one_positions.append(index)
    if not one_positions:
        raise ValueError('All 0 in secret string S: {!r}'.format(secret_s))

    first_1_bit_location = one_positions[0]
    other_1_bit_location_list = one_positions[1:]
    # Length of the sequence itself. len(str(...)) would miscount for
    # non-str sequences such as ['1', '0'].
    n = len(secret_s)

    oracle_circuit = Circuit()

    # Output bit n+j receives x_j here and again in the copy loop below, so it
    # nets to 0.
    # NOTE(review): the two CNOTs cancel. They are kept, presumably so that
    # qubit n+j is still referenced by the circuit.
    oracle_circuit.cnot(first_1_bit_location, first_1_bit_location + n)

    # Temporarily fold x_j into every other input bit where the secret has a '1'.
    for other_1_bit_location in other_1_bit_location_list:
        oracle_circuit.cnot(first_1_bit_location, other_1_bit_location)

    # Copy the (temporarily modified) input register into the output register.
    for i in range(n):
        oracle_circuit.cnot(i, n + i)

    # Uncompute the fold so the input register is restored to x.
    for other_1_bit_location in other_1_bit_location_list:
        oracle_circuit.cnot(first_1_bit_location, other_1_bit_location)

    return oracle_circuit

In [4]:
# Braket's local simulator backend. The driver cell submits circuits via device.run(...).
device = LocalSimulator()

In [5]:
def generate_input_circuit(source_list):
    """Build one computational-basis preparation circuit per bit string.

    For each string in source_list, qubit k gets an X gate if character k is
    '1' and an identity gate if it is '0'.
    NOTE(review): the identity gates look intended to make every qubit index
    appear in the circuit. Confirm before removing them.

    Args:
        source_list: iterable of strings made of '0'/'1' characters.

    Returns:
        List of braket ``Circuit`` objects, one per input string, in order.

    Raises:
        ValueError: if a string contains a character other than '0'/'1'.
    """
    input_circuit_list = []

    for digit_string in source_list:
        # Validate the whole string before building anything for it.
        for digit_value in digit_string:
            if digit_value not in ('0', '1'):
                raise ValueError(
                    'incorrect input value: \'' + digit_value + '\' in: ' + digit_string)

        cur_circuit = Circuit()
        for reg_index, digit_value in enumerate(digit_string):
            if digit_value == '1':
                cur_circuit.x(reg_index)
            else:
                cur_circuit.i(reg_index)

        input_circuit_list.append(cur_circuit)

    return input_circuit_list

In [6]:
def generate_full_bit_string(bit_number):
    """Return every bit string of length bit_number in ascending binary order.

    Example: generate_full_bit_string(2) -> ['00', '01', '10', '11'].
    bit_number == 0 yields [''] (the single empty string). The previous
    slicing implementation returned ['0'] here because ``s[-0:]`` is the
    whole string.

    Args:
        bit_number: non-negative string width.

    Returns:
        List of 2**bit_number strings of '0'/'1' characters.
    """
    # product('01', ...) enumerates in lexicographic order, which for
    # fixed-width bit strings equals counting 0 .. 2**bit_number - 1.
    return [''.join(bits) for bits in itertools.product('01', repeat=bit_number)]

In [7]:
def generate_simon_input_string(bit_number):
    """List the 2n-character input strings for an n-bit Simon experiment.

    Each entry is one of the bit strings from generate_full_bit_string(bit_number),
    in the same order, followed by bit_number '0' characters for the output
    register.
    """
    output_register = '0' * bit_number
    return [input_bits + output_register
            for input_bits in generate_full_bit_string(bit_number)]

In [8]:
def generate_simon_input_circuit(bit_number):
    """Return one basis-state preparation circuit per Simon input string.

    Composes generate_simon_input_string (which strings) with
    generate_input_circuit (how to prepare each one).
    """
    return generate_input_circuit(generate_simon_input_string(bit_number))
        

In [9]:
def grouping(measure_result_list):
    """Group measured input values by the output value they produced.

    Each measurement string has 2n characters. The first n are the input
    register and the last n the output register.

    Args:
        measure_result_list: non-empty list of equal-length, even-length strings.

    Returns:
        dict mapping each output string to the list of input strings that
        produced it, in the order they were seen.

    Raises:
        ValueError: if the list is empty, the strings have odd length, or the
            strings do not all have the same length.
    """
    if not measure_result_list:
        raise ValueError('Empty measurement result list.')

    total_length = len(measure_result_list[0])
    if total_length % 2 == 1:
        raise ValueError('Measurement result with odd number of digit.')

    bit_number = total_length // 2
    result_group = {}

    for each_result in measure_result_list:
        # A differently-sized string would be split at the wrong position.
        if len(each_result) != total_length:
            raise ValueError(
                'Measurement results have inconsistent lengths: ' + repr(each_result))
        input_value = each_result[:bit_number]
        output_value = each_result[bit_number:]
        result_group.setdefault(output_value, []).append(input_value)

    return result_group
        

In [10]:
def verify_result(result_group, secret_s, print_detail=True):
    """Check that grouped oracle results satisfy Simon's promise for secret_s.

    Every output must have been produced by exactly two inputs x1 and x2, each
    of length len(secret_s), with x1 XOR x2 == secret_s.

    Args:
        result_group: dict mapping output bit string -> list of input bit
            strings (the shape produced by grouping()).
        secret_s: expected secret as a string of '0'/'1' characters.
        print_detail: if True, print the reason for the first failed check.

    Returns:
        True if every group passes, otherwise False at the first failure.
        An empty result_group returns True.
    """
    bit_number = len(secret_s)

    for output_value, input_list in result_group.items():
        if len(output_value) != bit_number:
            if print_detail:
                print('Incorrect key:' + output_value)
            return False

        # Simon's promise: exactly two inputs per output (fewer is also a failure).
        if len(input_list) != 2:
            if print_detail:
                print('Expected exactly two inputs for output ' + output_value
                      + ', got ' + str(len(input_list)))
            return False

        first_input, second_input = input_list
        # Guard against IndexError in the per-bit comparison below.
        if len(first_input) != bit_number or len(second_input) != bit_number:
            if print_detail:
                print('Input length mismatch for output ' + output_value + ': '
                      + first_input + ', ' + second_input)
            return False

        for i in range(bit_number):
            input_bit_1 = bool(int(first_input[i]))
            input_bit_2 = bool(int(second_input[i]))
            secret_bit = bool(int(secret_s[i]))

            if (input_bit_1 ^ input_bit_2) != secret_bit:
                if print_detail:
                    print('bit ' + str(i) + ': input ' + first_input + ' xor '
                          + secret_s + ' != ' + second_input)
                return False

    return True

In [12]:
# Driver: for every non-zero secret of MIN_BITS..MAX_BITS bits, run the oracle
# on every basis-state input and check Simon's two-to-one promise.
MIN_BITS = 2
MAX_BITS = 7   # inclusive; same range as the original range(2, 8)
SHOTS = 500    # basis-state inputs should give a single outcome per run

for secret_s_bit_number in range(MIN_BITS, MAX_BITS + 1):
    # [1:] drops the all-zero string, which generate_oracle rejects.
    secret_s_list = generate_full_bit_string(secret_s_bit_number)[1:]

    print('checking bit number:' + str(secret_s_bit_number))

    # The inputs depend only on the width, not on the secret, so build them
    # once per width.
    # NOTE(review): this relies on Braket's `Circuit.__add__` returning a new
    # circuit rather than mutating cur_circuit. Confirm against the SDK docs.
    input_circuit_list = generate_simon_input_circuit(secret_s_bit_number)

    for secret_s in secret_s_list:
        oracle_circuit = generate_oracle(secret_s)

        measure_result_list = []
        for cur_circuit in input_circuit_list:
            task = device.run(cur_circuit + oracle_circuit, shots=SHOTS)
            measurement_counts = task.result().measurement_counts

            # A classical (basis-state) input must map to exactly one outcome.
            if len(measurement_counts) != 1:
                raise RuntimeError('Multiple classical result generated!')

            measure_result_list.append(next(iter(measurement_counts)))

        result_group = grouping(measure_result_list)
        is_correct = verify_result(result_group, secret_s)

        if is_correct:
            print('For secret string:' + secret_s + '. Answers are correct!')
        else:
            print('For secret string:' + secret_s + '. Answers are INCORRECT!')

        print('------------------')
    
    

checking bit number:2
For secret string:01. Answers are correct!
------------------
For secret string:10. Answers are correct!
------------------
For secret string:11. Answers are correct!
------------------
checking bit number:3
For secret string:001. Answers are correct!
------------------
For secret string:010. Answers are correct!
------------------
For secret string:011. Answers are correct!
------------------
For secret string:100. Answers are correct!
------------------
For secret string:101. Answers are correct!
------------------
For secret string:110. Answers are correct!
------------------
For secret string:111. Answers are correct!
------------------
checking bit number:4
For secret string:0001. Answers are correct!
------------------
For secret string:0010. Answers are correct!
------------------
For secret string:0011. Answers are correct!
------------------
For secret string:0100. Answers are correct!
------------------
For secret string:0101. Answers are correct!
------

For secret string:0000110. Answers are correct!
------------------
For secret string:0000111. Answers are correct!
------------------
For secret string:0001000. Answers are correct!
------------------
For secret string:0001001. Answers are correct!
------------------
For secret string:0001010. Answers are correct!
------------------
For secret string:0001011. Answers are correct!
------------------
For secret string:0001100. Answers are correct!
------------------
For secret string:0001101. Answers are correct!
------------------
For secret string:0001110. Answers are correct!
------------------
For secret string:0001111. Answers are correct!
------------------
For secret string:0010000. Answers are correct!
------------------
For secret string:0010001. Answers are correct!
------------------
For secret string:0010010. Answers are correct!
------------------
For secret string:0010011. Answers are correct!
------------------
For secret string:0010100. Answers are correct!
--------------