In [4]:
import cirq
import numpy as np

In [14]:
def beta(s, j, x):
    """Return the RY angle for block ``j`` at tree level ``s`` of the embedding.

    ``x`` is split into consecutive blocks of length ``2**s``. Block ``j``
    (1-based) is ``x[(j-1)*2**s : j*2**s]``. The returned angle satisfies
    ``sin(angle/2) = ||second half of block j|| / ||block j||``. Applying
    ``ry(angle)`` to |0> therefore moves the squared norm of the block's second
    half into the |1> branch.

    Only magnitudes are used (``abs``), so signs/phases of ``x`` are discarded.

    Args:
        s: log2 of the block length (expected >= 1).
        j: 1-based block index.
        x: sequence of amplitudes; lists and ndarrays are both accepted.

    Returns:
        Angle in [0, pi]; ``0`` when the block has zero norm.
    """
    # Lists don't support abs() on slices; normalise to an ndarray first.
    x = np.asarray(x)
    half = 2**(s - 1)
    index_num = (2*j - 1) * half   # start of the second half of block j
    index_den = (j - 1) * 2**s     # start of block j

    num = np.sqrt(np.sum(np.abs(x[index_num:index_num + half])**2))
    den = np.sqrt(np.sum(np.abs(x[index_den:index_den + 2**s])**2))

    if den == 0:
        return 0
    # Mathematically num <= den, but rounding could push the ratio just above 1,
    # which would make arcsin return NaN.
    return 2*np.arcsin(min(num / den, 1.0))

In [42]:
def locate_x(curr_j, prev_j, length):
    """List the bit positions where two integers differ.

    Both integers are written in binary, zero-padded to ``length`` digits.
    Position 0 is the most significant (leftmost) digit.

    Args:
        curr_j: first integer.
        prev_j: second integer.
        length: minimum number of binary digits to compare.

    Returns:
        Ascending list of differing digit positions.
    """
    curr_bits, prev_bits = (bin(value)[2:].zfill(length) for value in (curr_j, prev_j))
    positions = []
    for pos, (a, b) in enumerate(zip(curr_bits, prev_bits)):
        if a != b:
            positions.append(pos)
    return positions

In [75]:
def amplitude_embedding(x):
    """Build a Cirq circuit whose output state has |amplitude_k| proportional to |x[k]|.

    The circuit is a binary tree of (multi-)controlled RY rotations.

    - Level 0: a single RY on ``qubits[0]`` splits the norm between the two halves of ``x``.
    - Level ``i``: the first ``i`` qubits select one of ``2**i`` blocks, and ``qubits[i]``
      is rotated by ``beta(n - i, j, x)`` for block ``j``.
    - ``qubits[0]`` is the most significant bit of the basis-state index.

    Only magnitudes are encoded (``beta`` uses ``abs``), so signs/phases of ``x`` are
    lost. ``x`` need not be normalised, because each angle depends only on a ratio
    of norms.

    Args:
        x: sequence of amplitudes whose length is a power of two >= 2.

    Returns:
        ``cirq.Circuit`` on ``cirq.GridQubit.rect(1, n)`` with ``n = log2(len(x))``.

    Raises:
        ValueError: if ``len(x)`` is not a power of two >= 2.
    """
    length = len(x)
    if length < 2 or length & (length - 1):
        raise ValueError(f"len(x) must be a power of two >= 2, got {length}")
    n = length.bit_length() - 1
    qubits = cirq.GridQubit.rect(1, n)
    circuit = cirq.Circuit()

    circuit += cirq.ry(beta(n, 1, x))(qubits[0])

    for i in range(1, n):
        # The first i qubits can take 2**i values, one per block of x.
        controls = 2**i
        # i control qubits followed by the target qubits[i].
        control_qubits = qubits[:i + 1]

        # Start with the all-ones prefix, which is block index j = controls.
        circuit += cirq.ControlledGate(sub_gate=cirq.ry(beta(n - i, controls, x)),
                                       num_controls=i)(*control_qubits)

        # Walk the prefixes downward. X gates toggle only the bits that differ from
        # the previous prefix, so the cumulative flips map the current prefix onto
        # all-ones and the all-ones-controlled gate addresses it.
        for j in range(1, controls):
            for loc in locate_x(controls - j - 1, controls - j, i):
                circuit += cirq.X(qubits[loc])
            # Rotate once per prefix, after *all* differing bits have been flipped.
            circuit += cirq.ControlledGate(sub_gate=cirq.ry(beta(n - i, controls - j, x)),
                                           num_controls=i)(*control_qubits)

        # The last prefix is all zeros, so every control is currently flipped; undo that.
        for k in range(i):
            circuit += cirq.X(qubits[k])

    return circuit

In [100]:
# Example input: unnormalised amplitudes, rescaled to unit L2 norm
x = np.asarray([1, 2, 3, 4])
d = np.sqrt(np.square(x).sum())
x = x / d
print('amplitudes:', x)
print('probs:', x**2)
qc = amplitude_embedding(x)

amplitudes: [0.18257419 0.36514837 0.54772256 0.73029674]
probs: [0.03333333 0.13333333 0.3        0.53333333]


In [101]:
# Rebuild from x so that re-running this cell does not append a second
# 'result' measurement. Derive the qubits from the circuit instead of hard-coding
# the count. Sorted GridQubits keep qubits[0] first, which becomes the MSB of the result.
qc = amplitude_embedding(x)
qubits = sorted(qc.all_qubits())
qc += cirq.measure(*qubits, key='result')
qc

In [102]:
SHOTS = 10000
# Fixed seed so the sampled histogram is reproducible on Restart & Run All
simulator = cirq.Simulator(seed=2024)
samples = simulator.run(qc, repetitions=SHOTS)
counts = samples.histogram(key="result")
# Empirical probability per basis-state index (qubits[0] is the MSB), sorted by index
res = {state: counts[state] / SHOTS for state in sorted(counts)}

# Sanity check: frequencies should match |x_k|^2 (normalised) up to sampling noise
expected = np.abs(x)**2 / np.sum(np.abs(x)**2)
observed = np.array([res.get(k, 0.0) for k in range(len(expected))])
assert np.allclose(observed, expected, atol=0.03), (observed, expected)
res

{3: 0.5364, 1: 0.1376, 2: 0.2948, 0: 0.0312}