# In [None]:
from qiskit import QuantumCircuit,transpile     #to create the circuit
from qiskit_aer import AerSimulator   #to simulate the circuit
import qiskit.quantum_info as qi     #to make the custom unitary
import random as r                   #generates random number for Alice and Bob
import numpy as np                   #for numerical calculation
import sympy as sp      
from sympy.physics.quantum import TensorProduct
from IPython.display import display, Math   #to make certain calculations to display neater

from qiskit_ibm_provider import IBMProvider

# Initial two-party states passed to ``qc.initialize`` by the CHSH runners
# (selected as ``opt_states[n]``). Amplitudes use Qiskit's little-endian
# ordering: Alice's digit a lives on the low qubit(s) and Bob's digit b on the
# high ones, so index = a + 2*b for n == 2 and index = a + 4*b for n = 3, 4.
# Every nonzero amplitude sits at an index with (a + b) % n == 0, and each
# vector has unit norm up to float precision (qc.initialize needs that).
# The names suggest these are optimal states for each n; they are not derived
# in this file. NOTE(review): record where these values come from.

# n == 2: two qubits; nonzero only at (a, b) = (0, 0) and (1, 1).
optimal_state_2 = [0.6532814824381883-0.2705980500730985j, 0j, 0j, 0.6532814824381883+0.2705980500730985j]
# n == 3: four qubits (two bits per party); nonzero at indices 0, 6, 9, i.e.
# (a, b) = (0, 0), (2, 1), (1, 2). The unused digit 3 never appears, which
# matches the 3x3 unitaries that extend_unitary pads to 4x4.
optimal_state_3 = [0.474341649025257 - 0.273861278752583j,
 0,
 0,
 0,
 0,
 0,
 0.474341649025257 + 0.273861278752583j,
 0,
 0,
 0.632455532033676,
 0,
 0,
 0,
 0,
 0,
 0]

# n == 4: four qubits; nonzero at indices 0, 7, 10, 13, i.e.
# (a, b) = (0, 0), (3, 1), (2, 2), (1, 3).
optimal_state_4=[0.3801932173932526-0.25403698614397946j,
               0j,
               0j,
               0j,
               0j,
               0j,
               0j,
               0.3801932173932526+0.25403698614397946j,
               0j,
               0j,
               0.5290046385320097+0.10522556500522232j,
               0j,
               0j,
               0.5290046385320097-0.10522556500522232j,
               0j,
               0j]


def getz2(n):
    """Return exp(2*pi*i/n), the primitive n-th root of unity, as a SymPy expression."""
    return sp.exp(sp.I * 2 * sp.pi / n)

def transform_bits_to_C4(input_dict):
    """Re-key a counts dictionary from bit strings to base-4 digit strings.

    Each key is cut into consecutive 2-character chunks from the left, and
    every chunk is replaced by its integer value (so '0110' becomes '12').
    An odd-length key ends with a 1-character chunk. Values are copied
    unchanged; if two keys map to the same digit string, the later one wins.
    """
    def _digits(bits):
        # One base-4 digit per 2-bit chunk, read left to right.
        chunks = (bits[start:start + 2] for start in range(0, len(bits), 2))
        return ''.join(str(int(chunk, 2)) for chunk in chunks)

    return {_digits(key): count for key, count in input_dict.items()}

def dagger(matrix):
    """Return the conjugate transpose (Hermitian adjoint) of ``matrix``."""
    conjugated = matrix.conjugate()
    return conjugated.T

def normalize(vector):
    """Return ``vector`` divided by its Hermitian norm sqrt(sum(c * conj(c)))."""
    squared_norm = sum((entry * sp.conjugate(entry) for entry in vector), 0)
    return vector / sp.sqrt(squared_norm)

def get_eigenvectors(matrix):
    """Display each eigenvalue of ``matrix`` followed by its normalized eigenvectors.

    Output goes through IPython's ``display``/``Math`` as LaTeX; nothing is
    returned.
    """
    for value, _multiplicity, basis in matrix.eigenvects():
        display(Math(f"\\text{{Eigenvalue: }} {sp.latex(value)}"))
        for vec in basis:
            unit_vec = normalize(vec)
            display(Math(f"\\text{{Eigenvector: }} {sp.latex(unit_vec)}"))
    

def order(pair):
    """Sort key: phase angle of ``pair[0]`` (an eigenvalue), mapped into [0, 2*pi).

    ``sp.arg`` yields an angle in (-pi, pi]; negative angles are shifted up
    by a full turn.
    """
    eigenvalue = pair[0]
    theta = sp.arg(eigenvalue)
    return theta + 2 * sp.pi if theta < 0 else theta

def get_sorted_normalized_eigenvectors(matrix):
    """Return the normalized eigenvectors of ``matrix``, ordered by eigenvalue phase.

    Every eigenvector of every eigenspace is listed once, paired with its
    eigenvalue, and stably sorted by ``order`` (phase in [0, 2*pi)).
    Eigenvectors of the same eigenvalue therefore keep SymPy's order.
    """
    pairs = []
    for eigenvalue, _multiplicity, basis in matrix.eigenvects():
        for vec in basis:
            pairs.append((eigenvalue, vec))

    pairs.sort(key=order)
    return [normalize(vec) for _eigenvalue, vec in pairs]

def get_unitary(ordered_list_of_normalized_eigenvectors):
    """Build U = sum_i e_i v_i^dagger from an ordered list of column vectors v_i.

    Row i of U is the conjugate transpose of the i-th vector. For an
    orthonormal list, U therefore maps v_i onto the computational basis
    vector e_i.
    """
    size = len(ordered_list_of_normalized_eigenvectors)

    def outer_term(row, vec):
        # e_row * vec^dagger: a matrix whose only nonzero row is `row`.
        basis = sp.zeros(size, 1)
        basis[row, 0] = 1
        return basis * dagger(vec)

    terms = (outer_term(row, vec)
             for row, vec in enumerate(ordered_list_of_normalized_eigenvectors))
    return sum(terms, sp.zeros(size, size))

def get_unitary_from_observable(observable):
    """Return the unitary that rotates the phase-sorted eigenbasis of ``observable`` onto the computational basis."""
    eigenbasis = get_sorted_normalized_eigenvectors(observable)
    return get_unitary(eigenbasis)


def extend_to_C4(state):
    """Append seven zero entries to the column vector ``state``.

    A 9-entry vector becomes 16 entries. This is presumably meant to embed a
    two-qutrit state in four qubits; TODO confirm the intended input size.
    """
    padding = sp.zeros(7, 1)
    return state.col_join(padding)

def extend_unitary(unitary):
    """Embed a 3x3 matrix as the block-diagonal 4x4 matrix [[U, 0], [0, 1]].

    The extra basis state is left unchanged, so a qutrit unitary can be
    applied on two qubits.
    """
    right_column = sp.Matrix([0, 0, 0])
    bottom_row = sp.Matrix([[0, 0, 0, 1]])
    widened = unitary.row_join(right_column)
    return widened.col_join(bottom_row)


def get_omega(n):
    """Return omega_n = exp(2*pi*i/n) as a SymPy expression."""
    exponent = 2 * sp.pi * sp.I / n
    return sp.exp(exponent)

def check_win(x, y, a, b, n):
    """Score one round: return 1 if outputs (a, b) win for inputs (x, y), else None.

    Here ``w_n = get_omega(n)``. Winning conditions:
      (0, 0): a == b      (0, 1): a * b == 1
      (1, 0): a == b      (1, 1): a * b == w_n
    Comparisons are SymPy structural equality.

    For n == 2 (a, b in {1, -1}), a == b and a * b == 1 coincide, so the rule
    is a * b == w_n ** (x * y).
    NOTE(review): for n > 2, ``a == b`` and ``a * b == 1`` are different
    relations; confirm this mix is the intended game rule.
    """
    w_n = get_omega(n)
    if y == 0 and x in (0, 1):
        return 1 if a == b else None
    if x == 0 and y == 1:
        return 1 if a * b == 1 else None
    if x == 1 and y == 1:
        return 1 if a * b == w_n else None
    return None

def get_observables(n):
    """Build the observables (A0, A1, B0, B1) for Alice and Bob, modulus n.

    For n == 2 these are the Pauli matrices X (A0 = B0) and Y (A1 = B1).

    For other n, with w = getz2(n) = exp(2*pi*i/n):
      * A0 = B0 is the cyclic shift: A0[i, (i-1) % n] = 1, so it sends
        basis vector e_j to e_{(j+1) % n}.
      * A1 is that shift applied to diag(w, ..., w, -w). Its nonzero entries
        are A1[i, (i-1) % n], equal to -w in column n-1 and w elsewhere.
      * B1 shifts the other way. Its nonzero entries are B1[i, (i+1) % n],
        equal to -w in column 1 and w elsewhere.

    Returns:
        Tuple (A0, A1, B0, B1) of SymPy matrices. B0 is the same object as
        A0; for n == 2, B1 is also the same object as A1.

    NOTE(review): the general branch needs n >= 2. For n == 1, ``B1[1]``
    raises IndexError.
    """
    if n == 2:
        A0 = B0 = sp.Matrix([[0,1], [1,0]])

        A1 = B1 = sp.Matrix([[0, -sp.I], [sp.I, 0]])
    else:
        # n x n identity as nested lists.
        A0 = [0 for _ in range(n)]
        for i in range(n):
            A0[i] = [0 for _ in range(n)]
            for j in range(n):
                if i == j:
                    A0[i][j] = 1
        
        # Move the last row to the top: new row i is old row (i-1) % n.
        A0 = sp.Matrix([A0[-1]] + A0[:-1])
        B0 = A0

        # Diagonal of w, with the last diagonal entry negated.
        A1 = []
        for i in range(n):
            A1.append([0 for k in range(n)])
            for j in range(n):
                if j == i:
                    if j < n - 1:
                        A1[i][j] = getz2(n)
                    else:
                        A1[i][j] = -getz2(n)
        
        # Same row rotation as A0.
        A1 = sp.Matrix([A1[-1]] + A1[:-1])
        A1 = sp.Matrix(A1)  # redundant copy; A1 is already a Matrix

        # Diagonal of w, with the entry at index 1 negated.
        B1 = []
        for i in range(n):
            B1.append([0 for k in range(n)])
            for j in range(n):
                if j == i:
                    if j == 1:
                        B1[i][j] = -getz2(n)
                    else:
                        B1[i][j] = getz2(n)
        
        # Move the first row to the bottom: new row i is old row (i+1) % n.
        B1 = sp.Matrix([B1[1]]  + B1[2:] + [B1[0]])

    return (A0,A1,B0,B1)

def convert_modn_to_Zn(player_output, n):
    """Map an outcome k in range(n) to the root of unity exp(2*pi*i/n) ** k.

    Any ``player_output`` that equals none of 0..n-1 maps to 0.
    """
    matching = next((k for k in range(n) if player_output == k), None)
    if matching is None:
        return 0
    return sp.exp(sp.I*2*sp.pi/n)**matching


def chsh_mod_n_noiseless(rounds, A0, A1, B0, B1, n):
    """Play the CHSH mod-n game on an ideal (noise-free) Aer simulator.

    Each round does the following:
      * prepares ``opt_states[n]`` on two equal qubit registers (Alice: low
        half, Bob: high half);
      * draws uniform input bits for both players;
      * applies the unitary that rotates the chosen observable's eigenbasis
        onto the computational basis;
      * measures every qubit once (shots=1) and scores the round with
        ``check_win``.

    Args:
        rounds: Number of rounds to play.
        A0, A1: Alice's SymPy observables for inputs 0 and 1.
        B0, B1: Bob's SymPy observables for inputs 0 and 1.
        n: Game modulus; must be 2, 3 or 4 (the prepared states).

    Returns:
        Percentage of rounds won, ``100 * rounds_won / rounds``.

    Raises:
        KeyError: if no prepared state exists for ``n``.
        ZeroDivisionError: if ``rounds`` is 0.
    """
    opt_states = {
        2: optimal_state_2,
        3: optimal_state_3,
        4: optimal_state_4
    }

    def _to_operator(observable):
        # Qiskit operator mapping the observable's eigenbasis to the
        # computational basis.
        unitary = get_unitary_from_observable(observable)
        if n not in (2, 4):
            # n == 3: pad the 3x3 unitary to 4x4 so it acts on two qubits.
            unitary = extend_unitary(unitary)
        return qi.Operator(unitary.tolist())

    A0_U, A1_U, B0_U, B1_U = [_to_operator(obs) for obs in (A0, A1, B0, B1)]

    # Nothing below changes between rounds, so build it once.
    optimal_state = opt_states[n]
    msize = A0_U.num_qubits * 2
    measure_lst = list(range(msize))
    halfway = msize // 2
    alice_half = list(range(halfway))
    bob_half = list(range(halfway, msize))
    alice_ops = {0: (A0_U, 'A0'), 1: (A1_U, 'A1')}
    bob_ops = {0: (B0_U, 'B0'), 1: (B1_U, 'B1')}
    sim = AerSimulator()

    rounds_won = 0
    for x in range(1, rounds + 1):
        qc = QuantumCircuit(msize, msize)
        qc.initialize(optimal_state, measure_lst)
        qc.barrier()

        # Referee picks 0 or 1 from uniform dist
        alice_input = r.choice([0, 1])
        bob_input = r.choice([0, 1])
        print("Round", x)
        print("Alice Input", alice_input)
        print("Bob Input", bob_input)

        alice_U, alice_label = alice_ops[alice_input]
        qc.unitary(alice_U, alice_half, label=alice_label)
        bob_U, bob_label = bob_ops[bob_input]
        qc.unitary(bob_U, bob_half, label=bob_label)

        qc.measure(measure_lst, measure_lst)

        results = sim.run(qc, shots=1).result()
        data = results.get_counts()
        print("Data from circuit:", data)
        if n == 2:
            bit_string = list(data)[0]
        else:
            C4_data = transform_bits_to_C4(data)
            print("Transformed Data", C4_data)
            bit_string = list(C4_data)[0]
        # Qiskit count keys are little-endian. The rightmost character belongs
        # to qubit 0 (Alice's half) and the leftmost to the top qubit (Bob's
        # half). check_win only uses a == b and a * b, so the score does not
        # depend on this order, but the printed labels do.
        a, b = bit_string[-1], bit_string[0]
        print(a, b)

        alice_output = convert_modn_to_Zn(int(a), n)
        bob_output = convert_modn_to_Zn(int(b), n)
        if check_win(alice_input, bob_input, alice_output, bob_output, n) == 1:
            rounds_won += 1
            print("Win")

    return (rounds_won / rounds) * 100

def chsh_mod_n_noisy(rounds, A0, A1, B0, B1, n, noise):
    """Play the CHSH mod-n game on an Aer simulator using an IBM backend's noise.

    The protocol matches the noiseless runner. The difference is that each
    circuit is transpiled for the backend named ``noise`` and executed on
    ``AerSimulator.from_backend(backend)``.

    Args:
        rounds: Number of rounds to play.
        A0, A1: Alice's SymPy observables for inputs 0 and 1.
        B0, B1: Bob's SymPy observables for inputs 0 and 1.
        n: Game modulus; must be 2, 3 or 4 (the prepared states).
        noise: Backend name, looked up via ``IBMProvider().backends(name=noise)``.

    Returns:
        Percentage of rounds won, ``100 * rounds_won / rounds``. For n == 3,
        rounds in which a party reads out the unused digit 3 count as losses.

    Raises:
        KeyError: if no prepared state exists for ``n``.
        ZeroDivisionError: if ``rounds`` is 0.
    """
    provider = IBMProvider()
    backend = provider.backends(name=noise)[0]
    opt_states = {
        2: optimal_state_2,
        3: optimal_state_3,
        4: optimal_state_4
    }

    optimal_state = opt_states[n]

    def _to_operator(observable):
        # Qiskit operator mapping the observable's eigenbasis to the
        # computational basis.
        unitary = get_unitary_from_observable(observable)
        if n not in (2, 4):
            # n == 3: pad the 3x3 unitary to 4x4 so it acts on two qubits.
            unitary = extend_unitary(unitary)
        return qi.Operator(unitary.tolist())

    A0_U, A1_U, B0_U, B1_U = [_to_operator(obs) for obs in (A0, A1, B0, B1)]

    # Nothing below changes between rounds, so build it once. This includes
    # the noisy simulator (from_backend is a classmethod, so no throwaway
    # instance is needed).
    msize = A0_U.num_qubits * 2
    measure_lst = list(range(msize))
    halfway = msize // 2
    alice_half = list(range(halfway))
    bob_half = list(range(halfway, msize))
    alice_ops = {0: (A0_U, 'A0'), 1: (A1_U, 'A1')}
    bob_ops = {0: (B0_U, 'B0'), 1: (B1_U, 'B1')}
    sim = AerSimulator.from_backend(backend, shots=1)

    rounds_won = 0
    for x in range(1, rounds + 1):
        qc = QuantumCircuit(msize, msize)
        qc.initialize(optimal_state, measure_lst)
        qc.barrier()

        # Referee picks 0 or 1 from uniform dist
        alice_input = r.choice([0, 1])
        bob_input = r.choice([0, 1])
        print("Round", x)
        print("Alice Input", alice_input)
        print("Bob Input", bob_input)

        alice_U, alice_label = alice_ops[alice_input]
        qc.unitary(alice_U, alice_half, label=alice_label)
        bob_U, bob_label = bob_ops[bob_input]
        qc.unitary(bob_U, bob_half, label=bob_label)

        qc.measure(measure_lst, measure_lst)

        transpiled_circuit = transpile(qc, backend=backend)
        results = sim.run(transpiled_circuit).result()
        data = results.get_counts()

        print("Data from circuit:", data)
        if n == 2:
            bit_string = list(data)[0]
        else:
            C4_data = transform_bits_to_C4(data)
            # For n == 3, the prepared state and padded unitaries never use
            # digit 3, but noise can still produce it. Such a round is skipped
            # without counting a win, so it scores as a loss. The check is
            # made fresh every round.
            if n == 3 and any('3' in key for key in C4_data):
                continue
            print("Transformed Data", C4_data)
            bit_string = list(C4_data)[0]
        # Qiskit count keys are little-endian. The rightmost character belongs
        # to qubit 0 (Alice's half) and the leftmost to the top qubit (Bob's
        # half).
        a, b = bit_string[-1], bit_string[0]
        print(a, b)

        alice_output = convert_modn_to_Zn(int(a), n)
        bob_output = convert_modn_to_Zn(int(b), n)
        if check_win(alice_input, bob_input, alice_output, bob_output, n) == 1:
            rounds_won += 1
            print("Win")

    return (rounds_won / rounds) * 100

def chsh_mod_n(n: int, rounds: int, noise=None):
    """Play ``rounds`` rounds of the CHSH mod-n game and return the win percentage.

    Args:
        n: Game modulus; the observables come from ``get_observables(n)``.
        rounds: Number of rounds to play.
        noise: ``None`` (the default) for an ideal simulator. Otherwise, the
            name of an IBM backend whose noise model is simulated.

    Returns:
        Percentage of rounds won.
    """
    A0, A1, B0, B1 = get_observables(n)
    if noise is None:
        return chsh_mod_n_noiseless(rounds, A0, A1, B0, B1, n)
    return chsh_mod_n_noisy(rounds, A0, A1, B0, B1, n, noise)




def main():
    """Run 1000 noiseless rounds of the n = 2 game and print the win percentage."""
    # To add noise, pass the name of the quantum computer as a string in the 3rd argument.
    print(chsh_mod_n(2, 1000, None))


# Guarded so importing this module does not start a 1000-round simulation.
# In a Jupyter kernel __name__ is "__main__", so the cell still runs it.
if __name__ == "__main__":
    main()