In [None]:
# imports
import numpy as np


In [None]:
def f(x, a, k):
    """
    Evaluate f(x) = (1 + (-1)^(a . x mod k)) / 2 on a binary string x.

    x: string of '0'/'1' characters, interpreted as an integer vector
    a: sequence of integer coefficients, one per character of x
    k: integer modulus (assumed positive)
    returns 1 when (a . x mod k) is even and 0 when it is odd
    """
    assert isinstance(x, str)
    bits = np.array(list(map(int, x)))
    exponent = np.dot(a, bits) % k
    # (-1)^even = 1 -> (1 + 1) // 2 = 1 ; (-1)^odd = -1 -> (1 - 1) // 2 = 0
    return (1 + (-1) ** exponent) // 2

def get_uf(f, n):
    """
    Build the 2^(n+1) x 2^(n+1) matrix of the oracle U_f |x>|z> = |x>|z XOR f(x)>.

    The basis index of |x>|z> is 2 * int(x, 2) + z: the n input qubits are the
    most significant bits and the output qubit z is the least significant one.
    U_f is therefore block diagonal with one 2x2 block per input x -- the
    identity when f(x) is falsy, NOT (Pauli-X) when f(x) is truthy.

    f: callable taking an n-character binary string x and returning a truthy or falsy value
    n: number of input qubits
    returns: integer numpy array of shape (2^(n+1), 2^(n+1))
    """
    dim = 2**(n + 1)  # n input qubits + z
    identity_2x2 = np.array([[1, 0], [0, 1]])
    pauli_x = np.array([[0, 1], [1, 0]])

    uf = np.zeros((dim, dim), dtype=int)
    for idx in range(dim // 2):
        lo, hi = 2 * idx, 2 * idx + 2
        uf[lo:hi, lo:hi] = pauli_x if f(f"{idx:0{n}b}") else identity_2x2
    return uf

def get_hadamard(n):
    """
    returns the 2^n x 2^n matrix representation of the Hadamard gate applied to n qubits,
    i.e. the n-fold Kronecker (tensor) power of the single-qubit Hadamard H.

    n: number of qubits, a non-negative int. n = 0 yields the 1x1 identity [[1.0]]
       (the empty tensor product).
    raises ValueError if n is negative
    """
    if n < 0:
        raise ValueError(f"number of qubits must be non-negative, got {n}")

    hadamard = 1 / np.sqrt(2) * np.array([[1, 1], [1, -1]])

    # start from the 1x1 identity so that exactly n factors of H are tensored in
    matrix = np.eye(1)
    for _ in range(n):
        matrix = np.kron(matrix, hadamard)
    return matrix

def get_input_state(n):
    """
    Prepare the circuit's initial state H^(n+1) (|0...0> tensor |1>).

    n: number of input qubits; the extra output qubit is the least significant one
    returns: real state vector of length 2^(n+1)
    """
    register = np.zeros(2**n)
    register[0] = 1.0            # |0...0> on the n input qubits
    ancilla = np.array([0, 1])   # |1> on the output qubit
    joint = np.kron(register, ancilla)

    # Hadamard on every qubit, inputs and output alike
    return get_hadamard(n + 1) @ joint

def inverse_kron(C, B):
    """
    Suppose C = np.kron(A, B) for some matrices A and B. This function returns A given C and B.
    I credit ChatGPT with the implementation of this function, as I could not find it implemented in numpy already.

    Each entry A[i, j] is the least-squares scale factor mapping B onto the
    corresponding (m_B x n_B) block of C:
        A[i, j] = sum(block_ij * conj(B)) / sum(B * conj(B))
    No check is made that C is exactly a Kronecker product; for a general C this
    returns the block-wise projection onto B.

    C: array_like, 0-D, 1D or 2D
    B: array_like, 0-D, 1D or 2D, non-empty
    returns: A as an ndarray -- 1D when C was passed as a 0-D/1D array, otherwise
        2D of shape (m_C // m_B, n_C // n_B). Integer inputs are not truncated:
        the dtype follows true division (float, or complex for complex inputs).
    raises ValueError if B is empty, if C's shape is not a multiple of B's shape,
        or if B is (numerically) zero while C is not.
    """
    # remember whether the caller passed a vector *before* atleast_2d hides it,
    # so that genuine 2D row/column matrices are not flattened by mistake
    vector_input = np.ndim(C) <= 1

    # make C and B 2D (to handle the case of vector 1D inputs)
    C = np.atleast_2d(C)
    B = np.atleast_2d(B)

    m_C, n_C = C.shape
    m_B, n_B = B.shape

    if B.size == 0:
        raise ValueError("B must be non-empty for inverse Kronecker operation.")

    # verify dimensions
    if m_C % m_B != 0 or n_C % n_B != 0:
        raise ValueError("Dimensions of C are not compatible with B for inverse Kronecker operation.")

    # dimensions of A
    m_A, n_A = m_C // m_B, n_C // n_B

    if np.allclose(B, 0):
        # every block must then be zero too, otherwise no A can satisfy C = kron(A, B)
        if not np.allclose(C, 0):
            raise ValueError("B is zero but corresponding block in C is not zero.")
        A = np.zeros((m_A, n_A), dtype=np.result_type(C.dtype, B.dtype, np.float64))
    else:
        # blocks[i, k, j, l] == C[i*m_B + k, j*n_B + l], i.e. block (i, j) of C
        blocks = C.reshape(m_A, m_B, n_A, n_B)
        numerator = np.einsum("ikjl,kl->ij", blocks, np.conj(B))
        A = numerator / np.sum(B * np.conj(B))

    return A.flatten() if vector_input else A


In [157]:
n = 4

def a(x):
    """Oracle f with coefficients (1, 2, 3, 4), k = 2: 1 when x_1 + x_3 is even, else 0."""
    return f(x, a=[1, 2, 3, 4], k=2)

def b(x):
    """Oracle f with coefficients (2, 4, 6, 8), k = 2: every weighted sum is even, so constantly 1."""
    return f(x, a=[2, 4, 6, 8], k=2)

def c(x):
    """Oracle f with coefficients (1, 1, 1, 1), k = 3: 0 when the Hamming weight of x is 1 or 4, else 1 (neither constant nor balanced)."""
    return f(x, a=[1, 1, 1, 1], k=3)


In [None]:
def run_balance_test(f, n):
    """
    Simulate the oracle circuit for f on n input qubits and return the input-register amplitudes.

    Steps: prepare H^(n+1)(|0...0>|1>), apply U_f, divide out the output qubit
    (expected to remain in H|1>), then apply H^n to the input register.

    f: callable on n-character binary strings, returning a truthy/falsy value
    n: number of input qubits
    returns: 1D array of 2^n amplitudes of the input register just before measurement
    """
    prepared = get_input_state(n)
    after_oracle = get_uf(f, n) @ prepared

    # the output qubit should still be in H|1>, so factor it out of the joint state
    ancilla = get_hadamard(1) @ np.array([0, 1])
    register = inverse_kron(after_oracle, ancilla)

    # final Hadamard layer on the input register before measurement
    return get_hadamard(n) @ register

In [165]:
def _print_latex_probabilities_table(amplitudes):
    # print start of table
    print("""\\begin{table}
    \\begin{tabular}{|c|c|}
        \\hline
        $x_1$ $x_2$ $x_3$ $x_4$ & $P(x)$ \\\\
        \\hline""")
    for i in range(16):
        binary = format(i, '04b')
        probability = np.abs(amplitudes[i])**2
        print(f"        {binary} & {probability} \\\\")

    # print end of table
    print("""        \\hline
    \\end{tabular}
\\end{table}""")

def print_probabilities_table(amplitudes, latex=False):
    """
    Print measurement probabilities |amplitude|^2 for every basis state.

    amplitudes: sequence/array of (possibly complex) amplitudes indexed by basis state;
        its length must be a power of two and determines the number of qubits
        (16 amplitudes -> 4 qubits, printed as columns x1 .. x4)
    latex: when True, delegate to _print_latex_probabilities_table instead
    raises ValueError (plain-text mode) if len(amplitudes) is not a positive power of two
    """
    if latex:
        return _print_latex_probabilities_table(amplitudes)

    size = len(amplitudes)
    if size < 1 or size & (size - 1):
        raise ValueError(f"expected a power-of-two number of amplitudes, got {size}")
    num_bits = size.bit_length() - 1

    print(" ".join(f"x{k}" for k in range(1, num_bits + 1)), "| probability")
    for i in range(size):
        binary = format(i, "b").zfill(num_bits)
        probability = np.abs(amplitudes[i])**2
        print(binary, "      |", probability)

In [166]:
print_probabilities_table(run_balance_test(a, n), latex=True)

\begin{table}
    \begin{tabular}{|c|c|}
        \hline
        $x_1$ $x_2$ $x_3$ $x_4$ & $P(x)$ \\
        \hline
        0000 & 0.0 \\
        0001 & 0.0 \\
        0010 & 3.4184043884735275e-64 \\
        0011 & 0.0 \\
        0100 & 0.0 \\
        0101 & 0.0 \\
        0110 & 3.4184043884735275e-64 \\
        0111 & 0.0 \\
        1000 & 0.0 \\
        1001 & 0.0 \\
        1010 & 0.9999999999999989 \\
        1011 & 0.0 \\
        1100 & 0.0 \\
        1101 & 0.0 \\
        1110 & 3.4184043884735275e-64 \\
        1111 & 0.0 \\
        \hline
    \end{tabular}
\end{table}


In [167]:
print_probabilities_table(run_balance_test(b, n), latex=True)

\begin{table}
    \begin{tabular}{|c|c|}
        \hline
        $x_1$ $x_2$ $x_3$ $x_4$ & $P(x)$ \\
        \hline
        0000 & 0.9999999999999989 \\
        0001 & 0.0 \\
        0010 & 0.0 \\
        0011 & 0.0 \\
        0100 & 3.4184043884735275e-64 \\
        0101 & 0.0 \\
        0110 & 0.0 \\
        0111 & 0.0 \\
        1000 & 3.4184043884735275e-64 \\
        1001 & 0.0 \\
        1010 & 0.0 \\
        1011 & 0.0 \\
        1100 & 3.4184043884735275e-64 \\
        1101 & 0.0 \\
        1110 & 0.0 \\
        1111 & 0.0 \\
        \hline
    \end{tabular}
\end{table}


In [168]:
print_probabilities_table(run_balance_test(c, n), latex=True)

\begin{table}
    \begin{tabular}{|c|c|}
        \hline
        $x_1$ $x_2$ $x_3$ $x_4$ & $P(x)$ \\
        \hline
        0000 & 0.14062499999999983 \\
        0001 & 0.015624999999999976 \\
        0010 & 0.01562499999999998 \\
        0011 & 0.015624999999999986 \\
        0100 & 0.015624999999999983 \\
        0101 & 0.015624999999999983 \\
        0110 & 0.015624999999999983 \\
        0111 & 0.14062499999999983 \\
        1000 & 0.015624999999999983 \\
        1001 & 0.015624999999999983 \\
        1010 & 0.015624999999999983 \\
        1011 & 0.14062499999999983 \\
        1100 & 0.015624999999999972 \\
        1101 & 0.14062499999999983 \\
        1110 & 0.14062499999999983 \\
        1111 & 0.14062499999999983 \\
        \hline
    \end{tabular}
\end{table}
