<a href="https://colab.research.google.com/github/BankNatchapol/Comparison-of-Quantum-Gradient/blob/main/concept_implementation/parameter_shift_rule_visualization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pennylane

In [73]:
import pennylane as qml
from pennylane import numpy as np
from matplotlib import pyplot as plt
from matplotlib.widgets import Slider


In [167]:
dev = qml.device("default.qubit", wires=1)  # one-qubit simulator used by the QNode below


@qml.qnode(dev)
def function(var):
    """Apply RX(var) to qubit 0 and return the expectation value of Pauli-Z on it.

    Starting from |0>, this expectation equals cos(var); the visualization cell
    differentiates it with the parameter-shift rule.
    """
    qml.RX(var, wires=0)
    z_expectation = qml.expval(qml.PauliZ(wires=0))
    return z_expectation

In [168]:
#@title Parameter Shift Rule Visualization
import plotly.graph_objects as go
import numpy as np

# Parameter-shift rule for RX:  f'(x) = [f(x + pi/2) - f(x - pi/2)] / 2.
# The secant through the two shifted points has slope
#     m = [f(x + pi/2) - f(x - pi/2)] / pi,
# so m * pi/2 equals the parameter-shift derivative. Each slider step shows the
# origin (green dot), the secant (blue) and a line through the origin with slope
# m * pi/2 (dashed green), which is tangent to the curve there.

SHIFT = np.pi / 2                             # parameter-shift offset
ORIGINS = np.arange(-1.2, 1.2, 0.03)          # one slider step per origin
TANGENT_OFFSETS = np.linspace(-1.3, 1.3, 10)  # x-extent of the tangent around the origin
TRACES_PER_STEP = 3                           # origin marker, secant, tangent
ACTIVE_STEP = min(10, len(ORIGINS) - 1)       # slider step shown on first render


def _expval(v):
    """Evaluate the QNode ``function`` at ``v`` and return a plain Python float."""
    return float(function(v))


def _fmt(value):
    """Two-decimal string right-aligned in a 5-character field (keeps legend entries aligned)."""
    return format(value, "5.2f")


def _title(origin):
    """Figure title for the slider step at ``origin``."""
    return f"Parameter-shift rule at origin = {origin:.2f}"


def _step_traces(origin, visible):
    """Return the (origin marker, secant, tangent) traces for one slider step."""
    f_origin = _expval(origin)
    points_x = [origin + SHIFT, origin - SHIFT]
    points_y = [_expval(v) for v in points_x]
    slope = (points_y[0] - points_y[1]) / (points_x[0] - points_x[1])
    derivative = slope * SHIFT  # parameter-shift estimate of f'(origin)
    return (
        go.Scatter(
            visible=visible,
            marker=dict(color="green", size=10),
            name="origin = " + _fmt(origin),
            x=[origin],
            y=[f_origin],
        ),
        go.Scatter(
            visible=visible,
            line=dict(color="blue", width=2),
            name="m = " + _fmt(slope),
            x=points_x,
            y=points_y,
        ),
        go.Scatter(
            visible=visible,
            mode="lines",
            line=dict(color="green", width=2, dash="dash"),
            name="m*pi/2 = " + _fmt(derivative),
            x=TANGENT_OFFSETS + origin,
            y=f_origin + derivative * TANGENT_OFFSETS,
        ),
    )


def _visibility(step, n_traces):
    """Visibility mask: the curve (trace 0) plus the traces belonging to slider ``step``."""
    first = 1 + TRACES_PER_STEP * step
    return [i == 0 or first <= i < first + TRACES_PER_STEP for i in range(n_traces)]


fig = go.Figure()

# Trace 0: the expectation-value curve, shown at every slider step.
x = np.linspace(-np.pi, np.pi, 100)
y = [_expval(v) for v in x]
fig.add_trace(go.Scatter(line=dict(color="red", width=2), name="f(x) = <Z>", x=x, y=y))

# Traces 1..: TRACES_PER_STEP per origin. Only the active step starts visible, so the
# initial plot agrees with the slider's starting position.
for step, origin in enumerate(ORIGINS):
    for trace in _step_traces(origin, visible=(step == ACTIVE_STEP)):
        fig.add_trace(trace)

# One slider step per origin: show the curve plus that origin's three traces.
steps = [
    dict(
        method="update",
        label=f"{origin:.2f}",
        args=[
            {"visible": _visibility(step, len(fig.data))},
            {"title.text": _title(origin)},
        ],
    )
    for step, origin in enumerate(ORIGINS)
]

sliders = [dict(
    active=ACTIVE_STEP,
    currentvalue={"prefix": "origin = "},
    pad={"t": 50},
    steps=steps,
)]

fig.update_layout(
    autosize=False,
    width=1000,
    height=500,
    sliders=sliders,
    title_text=_title(ORIGINS[ACTIVE_STEP]),
    xaxis_title="x (RX angle)",
    yaxis_title="<Z>",
)
fig.update_yaxes(range=[-1.5, 1.5])
fig.show()
