# Optimization of Nonlocality against the CHSH Inequality

### Goal: Find a set of states and measurements that optimally violate the CHSH inequality.

In [1]:
import pennylane as qml
from pennylane import numpy as np

### Set Up the Quantum Circuits

In [2]:
# Simulator device with two named wires: "a" carries Alice's qubit and "b"
# carries Bob's. Two qubits are all the CHSH scenario requires.
dev = qml.device("default.qubit", wires=["a", "b"])

# Gradient-descent optimizer reused by both optimization loops below. It is
# constructed without arguments, so the library's default settings apply.
opt = qml.GradientDescentOptimizer()

@qml.qnode(dev)
def real_chsh_circuit(a_state, b_state, a_measurement, b_measurement):
    """
    Ansatz circuit in which states and measurements are free to rotate about
    the y-axis only.

    The two qubits are first entangled with a Hadamard followed by a CNOT.
    Each party then applies two RY rotations in sequence: one that
    parameterizes the state and one that parameterizes the measurement.
    In this model, all quantum states are real.

    Args:
        a_state (float) : local rotation angle on Alice's state
        b_state (float) : local rotation angle on Bob's state
        a_measurement (float) : rotation angle of Alice's measurement
        b_measurement (float) : rotation angle of Bob's measurement

    Returns:
        The expectation value of PauliZ on wire "a" tensored with PauliZ on
        wire "b".
    """
    # Entangling layer.
    qml.Hadamard(wires="a")
    qml.CNOT(wires=["a", "b"])

    # State rotations come first, then measurement rotations; within each
    # layer Alice's wire is rotated before Bob's.
    rotation_layers = (
        (a_state, b_state),
        (a_measurement, b_measurement),
    )
    for alice_angle, bob_angle in rotation_layers:
        qml.RY(alice_angle, wires="a")
        qml.RY(bob_angle, wires="b")

    zz_observable = qml.PauliZ("a") @ qml.PauliZ("b")
    return qml.expval(zz_observable)

@qml.qnode(dev)
def arbitrary_chsh_circuit(state_settings, measurement_a_settings, measurement_b_settings):
    """
    Ansatz circuit built entirely from `ArbitraryUnitary` templates.

    A two-qubit `ArbitraryUnitary` acting on wires "a" and "b" prepares the
    state; a one-qubit `ArbitraryUnitary` on each wire then sets that party's
    local measurement.

    Args:
        state_settings (list[float]) : 15 angle parameters for the arbitrary two-qubit state.
        measurement_a_settings (list[float]) : 3 angle parameters for Alice's local measurement.
        measurement_b_settings (list[float]) : 3 angle parameters for Bob's local measurement.

    Returns:
        The expectation value of PauliZ on wire "a" tensored with PauliZ on
        wire "b".
    """
    arbitrary_unitary = qml.templates.subroutines.ArbitraryUnitary

    # State preparation on both wires.
    arbitrary_unitary(state_settings, ["a", "b"])

    # Local measurement unitaries, Alice's wire first and then Bob's.
    local_measurements = (
        (measurement_a_settings, "a"),
        (measurement_b_settings, "b"),
    )
    for local_settings, wire in local_measurements:
        arbitrary_unitary(local_settings, [wire])

    zz_observable = qml.PauliZ("a") @ qml.PauliZ("b")
    return qml.expval(zz_observable)

### Set up methods to randomize initial conditions

In [3]:
def real_chsh_rand_settings():
    """
    Draw random starting parameters for the real CHSH ansatz.

    Returns a (3,2) tensor of angles obtained by rescaling samples from
    `np.random.random` from [0, 1) onto the interval from -pi to pi. The
    layout matches what `real_chsh_cost()` indexes: row 0 holds the two state
    angles, row 1 Alice's two measurement angles, and row 2 Bob's two
    measurement angles.
    """
    unit_samples = np.random.random((3, 2))
    full_turn = 2*np.pi
    return full_turn*unit_samples - np.pi

def arbitrary_chsh_rand_settings():
    """
    Draw random starting parameters for the arbitrary-unitary CHSH ansatz.

    Returns a 27-element array of angles obtained by rescaling samples from
    `np.random.random` from [0, 1) onto the interval from -pi to pi. The
    layout matches what `arbitrary_chsh_cost()` slices: 15 state parameters
    followed by four groups of 3 measurement parameters.
    """
    unit_samples = np.random.random(27)
    full_turn = 2*np.pi
    return full_turn*unit_samples - np.pi

### Set Up the Cost Function

In [4]:
def real_chsh_cost(real_settings):
    """
    CHSH cost function evaluated with `real_chsh_circuit()`.

    Args:
        real_settings : indexable as [row, col] with rows 0..2 and cols 0..1.
            Row 0 holds the state angles for Alice and Bob, row 1 holds
            Alice's two measurement angles, and row 2 holds Bob's two
            measurement angles.

    Returns:
        The negated CHSH combination E(0,0) + E(0,1) + E(1,0) - E(1,1), where
        E(x,y) is the circuit output for Alice's setting x and Bob's setting
        y. Minimizing this cost therefore maximizes the CHSH value.
    """
    # Evaluate the four correlators in the order (0,0), (0,1), (1,0), (1,1).
    # The settings are indexed afresh for every circuit call.
    corr_00, corr_01, corr_10, corr_11 = [
        real_chsh_circuit(
            real_settings[0, 0],
            real_settings[0, 1],
            real_settings[1, alice_choice],
            real_settings[2, bob_choice],
        )
        for alice_choice in (0, 1)
        for bob_choice in (0, 1)
    ]
    return -(corr_00 + corr_01 + corr_10 - corr_11)
    
def arbitrary_chsh_cost(arb_settings):
    """
    CHSH cost function evaluated with `arbitrary_chsh_circuit()`.

    Args:
        arb_settings : sliceable sequence of 27 parameters. Entries 0-14 are
            the state parameters, 15-17 and 18-20 are Alice's two measurement
            settings, and 21-23 and 24-26 are Bob's two measurement settings.

    Returns:
        The negated CHSH combination E(0,0) + E(0,1) + E(1,0) - E(1,1), where
        E(x,y) is the circuit output for Alice's setting x and Bob's setting
        y. Minimizing this cost therefore maximizes the CHSH value.
    """
    state_part = slice(0, 15)
    alice_parts = (slice(15, 18), slice(18, 21))
    bob_parts = (slice(21, 24), slice(24, 27))

    # Evaluate the four correlators in the order (0,0), (0,1), (1,0), (1,1).
    # The parameter vector is sliced afresh for every circuit call.
    corr_00, corr_01, corr_10, corr_11 = [
        arbitrary_chsh_circuit(
            arb_settings[state_part],
            arb_settings[alice_part],
            arb_settings[bob_part],
        )
        for alice_part in alice_parts
        for bob_part in bob_parts
    ]
    return -(corr_00 + corr_01 + corr_10 - corr_11)

### Optimizing the CHSH circuit over the real parameter space

In [5]:
# Start from a random point in the real parameter space.
real_settings = real_chsh_rand_settings()

# Run 500 gradient-descent steps, reporting progress on every 50th iteration.
for i in range(500):
    real_settings = opt.step(real_chsh_cost, real_settings)

    if i%50 != 0:
        continue

    real_progress_cost = real_chsh_cost(real_settings)
    print("iteration : ",i, ", cost : ", real_progress_cost)
    print("settings :\n", real_settings, "\n")

# Report the cost and settings reached after the last step.
real_final_cost = real_chsh_cost(real_settings)
print("final cost : ", real_final_cost,",")
print("final settings : ", real_settings)

iteration :  0 , cost :  1.219826597924631
settings :
 [[-1.28273668  2.98479755]
 [-3.13141847 -0.15085304]
 [ 2.46867867  0.98293043]] 

iteration :  50 , cost :  -2.387072015937195
settings :
 [[-0.65006262  2.35212349]
 [-2.64373421 -0.00586324]
 [ 2.39189478  0.42704026]] 

iteration :  100 , cost :  -2.698061121788894
settings :
 [[-0.58859778  2.29065865]
 [-2.35934461 -0.22878799]
 [ 2.29051207  0.46695813]] 

iteration :  150 , cost :  -2.79479555250338
settings :
 [[-0.575122    2.27718287]
 [-2.21259303 -0.3620638 ]
 [ 2.22342431  0.52057012]] 

iteration :  200 , cost :  -2.820187104297066
settings :
 [[-0.5716382   2.27369907]
 [-2.13992466 -0.43124837]
 [ 2.18855186  0.55195877]] 

iteration :  250 , cost :  -2.8264352877793613
settings :
 [[-0.57078089  2.27284176]
 [-2.10441588 -0.46589983]
 [ 2.17142496  0.56822835]] 

iteration :  300 , cost :  -2.8279472342867376
settings :
 [[-0.57057341  2.27263428]
 [-2.08706807 -0.48304017]
 [ 2.16307583  0.57637   ]] 

iteration

### Optimizing the CHSH circuit over general parameter space

In [6]:
# Start from a random point in the arbitrary-unitary parameter space.
arbitrary_settings = arbitrary_chsh_rand_settings()

# Run 500 gradient-descent steps, reporting progress on every 50th iteration.
for i in range(500):
    arbitrary_settings = opt.step(arbitrary_chsh_cost, arbitrary_settings)

    if i%50 != 0:
        continue

    arbitrary_progress_cost = arbitrary_chsh_cost(arbitrary_settings)
    print("iteration : ",i, ", cost : ", arbitrary_progress_cost)
    print("settings :\n", arbitrary_settings, "\n")

# Report the cost and settings reached after the last step.
arbitrary_final_cost = arbitrary_chsh_cost(arbitrary_settings)
print("final cost : ", arbitrary_final_cost,",")
print("final settings : ", arbitrary_settings)

iteration :  0 , cost :  -0.41830496390513955
settings :
 [ 1.65708772  2.17850395  2.52975817 -2.88317924 -0.05337802  1.75985253
 -1.08340749  2.07161189  2.15473777 -0.28739921 -2.56199967  2.7176077
  2.38016049  2.46136541 -1.84620093 -0.73159535  2.23414258 -2.56023066
  0.56775649 -1.33244215 -1.60740881  0.07449446  2.00865129 -1.29035447
  0.26473917  1.97206009  1.72452452] 

iteration :  50 , cost :  -2.1505564797060868
settings :
 [ 1.88117465  2.21609692  2.46480395 -3.10855822 -0.37120278  2.03094602
 -1.04386472  1.95575106  2.42520974 -0.53032759 -2.82180553  2.5083555
  2.12832195  2.40532368 -1.93281918 -0.77284081  2.33382532 -2.56023066
  0.57350769 -1.27983489 -1.60740881  0.13652049  1.79821062 -1.29035447
  0.20526995  2.16846988  1.72452452] 

iteration :  100 , cost :  -2.472228809004104
settings :
 [ 1.93997326  2.23648851  2.45307532 -3.1703855  -0.39407608  2.05528016
 -1.02581995  1.99259807  2.51094597 -0.57585452 -2.86544942  2.53289284
  2.10646638  2.45