The following work is inspired by research done in https://arxiv.org/pdf/0810.3828.pdf by Daoyi Dong, Chunlin Chen, Hanxiong Li, and Tzyh-Jong Tarn.

Additional inspiration was taken from Reinforcement Learning by R. S. Sutton and Andrew G. Barto.

TODO: Fix issue with the grover iteration step. Solution is not converging to optimal policy. 




In [62]:
# First let's import the necessary libraries
from qiskit import *
from qiskit.tools.monitor import job_monitor
import numpy as np

In [55]:
# Now let's model our state and action space
class GridWorld:
    """A 3x3 grid world with two obstacle cells and a single goal cell.

    States are ``(row, column)`` tuples in matrix notation: ``i`` indexes
    rows (bounded by ``height``) and ``j`` indexes columns (bounded by
    ``width``). Every move yields ``reward`` except the move that enters
    ``goal_state``, which yields ``terminal_reward``.
    """

    # (row delta, column delta) for each supported action name.
    _MOVES = {
        "up": (-1, 0),
        "down": (1, 0),
        "left": (0, -1),
        "right": (0, 1),
    }

    def __init__(self):
        self.width = 3
        self.height = 3
        self.start_state = (0, 0)
        self.goal_state = (0, 2)
        self.obstacles = [
            (0, 1), (1, 1)
        ]
        self.actions = ["up", "down", "left", "right"]
        # Two-bit measurement outcome -> action name.
        self.actions_binary = {"00": "up", "01": "down", "10": "left", "11": "right"}
        self.num_actions = len(self.actions)
        # Rows are bounded by height and columns by width, matching step().
        self.state_space = [
            (i, j) for i in range(self.height) for j in range(self.width)
        ]
        self.num_states = len(self.state_space)
        self.reward = -1
        self.terminal_reward = 10
        self.current_state = self.start_state

    def reset(self):
        """Move the agent back to ``start_state`` and return that state."""
        self.current_state = self.start_state
        return self.current_state

    def step(self, action):
        """Apply ``action`` to the current state.

        Moves that would leave the grid are clamped to the border, and moves
        into an obstacle leave the agent where it is.

        Args:
            action: one of "up", "down", "left", "right".

        Returns:
            ``(next_state, reward)``.

        Raises:
            ValueError: if ``action`` is not a known action name.
        """
        try:
            d_row, d_col = self._MOVES[action]
        except (KeyError, TypeError):
            # TypeError covers unhashable inputs such as lists.
            raise ValueError(f"Unknown action entered: {action!r}") from None

        i, j = self.current_state
        next_state = (
            min(max(i + d_row, 0), self.height - 1),
            min(max(j + d_col, 0), self.width - 1),
        )

        reward = self.reward
        if next_state in self.obstacles:
            # Blocked: stay put and pay the ordinary step cost.
            next_state = self.current_state
        elif next_state == self.goal_state:
            reward = self.terminal_reward

        self.current_state = next_state
        return next_state, reward

In [63]:
# Let's now model our Quantum TD algorithm that uses Grover's algorithm to update our P
class Quantum_TD:
    """TD(0) agent whose per-state action choice is stored in qubit pairs.

    Each state owns two qubits of one shared circuit (qubits ``2*s`` and
    ``2*s + 1``). Measuring that pair yields a two-bit string which
    ``env.actions_binary`` maps to an action; Grover iterations amplify the
    amplitude of an action that turned out to be rewarding.
    """

    def __init__(self, environment, episodes=10, alpha = 0.5, gamma=0.9, epsilon=0.1, td_error=0.01):
        """Store the hyper-parameters and allocate the value tables.

        Args:
            environment: object exposing ``state_space``, ``num_states``,
                ``num_actions``, ``actions_binary``, ``goal_state``,
                ``reset()`` and ``step(action)``.
            episodes: number of episodes run by ``perform_q_rl``.
            alpha: TD learning rate.
            gamma: discount factor.
            epsilon: stored on the instance; not read by this class.
            td_error: stored on the instance; not read by this class.
        """
        self.env = environment
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.episodes = episodes
        # Parameter L refers to the number of grover iterations.
        # NOTE(review): not read anywhere in this class; perform_q_rl caps
        # the iteration count at 1 instead.
        self.L = 2
        self.state_values = np.zeros(self.env.num_states)
        self.P = np.ones((self.env.num_states, self.env.num_actions))/self.env.num_actions
        self.steps_per_episode = np.zeros(episodes)
        self.td_error = td_error

    def state_td_update(self, state, reward, next_state):
        """Return the TD(0) target-adjusted value for ``state``.

        Computes ``V(s) + alpha * (reward + gamma * V(s') - V(s))`` from the
        current ``state_values``; the table itself is not modified here.
        """
        states = self.env.state_space
        V = self.state_values[states.index(state)]
        V_prime = self.state_values[states.index(next_state)]

        return V + self.alpha * (reward + self.gamma * V_prime - V)

    @staticmethod
    def _apply_cz(circuit, control, target):
        """Append a controlled-Z built from H, CX, H on ``target``."""
        circuit.h(target)
        circuit.cx(control, target)
        circuit.h(target)

    def _grover_iteration(self, circuit, state_index, target_bits):
        """Append one Grover iteration (oracle + diffusion) for one state.

        Args:
            circuit: circuit exposing ``qubits``, ``h``, ``x`` and ``cx``.
            state_index: index of the state whose qubit pair is updated.
            target_bits: two-character bit string naming the marked
                outcome, written the way measurement counts report it.

        Returns:
            The same circuit object, for chaining.
        """
        qubits = circuit.qubits
        q1 = qubits[2 * state_index]
        q2 = qubits[2 * state_index + 1]

        # get_action measures q1 into classical bit 0 and q2 into classical
        # bit 1, and Qiskit prints classical bit 0 rightmost. So the first
        # character of the key belongs to q2 and the second to q1.
        flip = [
            qubit
            for qubit, bit in ((q2, target_bits[0]), (q1, target_bits[1]))
            if bit == "0"
        ]

        # Oracle: map the marked basis state onto |11>, phase-flip it with
        # CZ, then map back. Only the marked state picks up the -1.
        for qubit in flip:
            circuit.x(qubit)
        self._apply_cz(circuit, q1, q2)
        for qubit in flip:
            circuit.x(qubit)

        # Diffusion (inversion about the mean), up to a global phase.
        circuit.h(q1)
        circuit.h(q2)
        circuit.x(q1)
        circuit.x(q2)
        self._apply_cz(circuit, q1, q2)
        circuit.x(q1)
        circuit.x(q2)
        circuit.h(q1)
        circuit.h(q2)
        return circuit

    def groverIteration00(self, circuit, state_index):
        """Append one Grover iteration amplifying outcome "00"."""
        return self._grover_iteration(circuit, state_index, "00")

    def groverIteration01(self, circuit, state_index):
        """Append one Grover iteration amplifying outcome "01"."""
        return self._grover_iteration(circuit, state_index, "01")

    def groverIteration10(self, circuit, state_index):
        """Append one Grover iteration amplifying outcome "10"."""
        return self._grover_iteration(circuit, state_index, "10")

    def groverIteration11(self, circuit, state_index):
        """Append one Grover iteration amplifying outcome "11"."""
        return self._grover_iteration(circuit, state_index, "11")

    def p_update(self, circuit, action, state_index, iterations):
        """Amplify ``action`` for one state by ``iterations`` Grover steps.

        An unknown action, or ``iterations <= 0``, leaves the circuit
        untouched. Returns the circuit.
        """
        dispatch = {
            "up": self.groverIteration00,
            "down": self.groverIteration01,
            "left": self.groverIteration10,
            "right": self.groverIteration11,
        }
        grover_step = dispatch.get(action)
        if grover_step is None:
            return circuit
        for _ in range(iterations):
            circuit = grover_step(circuit, state_index)
        return circuit

    def run_job(self, circuit):
        """Run ``circuit`` once on the qasm simulator and return its counts.

        Relies on ``Aer`` and ``execute`` being in scope through the file's
        ``from qiskit import *``.
        """
        backend = Aer.get_backend("qasm_simulator")
        job = execute(circuit, backend, shots=1)
        count = job.result().get_counts()
        return count

    def get_action(self, state, circuit):
        """Sample an action for ``state`` by measuring its qubit pair.

        The measurement is appended to a copy, so ``circuit`` keeps only the
        amplitude-shaping gates. Measuring the shared circuit directly would
        put every later Grover iteration behind a collapsed register and
        make the circuit grow on every step.
        """
        state_index = self.env.state_space.index(state)
        measured = circuit.copy()
        measured.measure([state_index*2, state_index*2 + 1], [0, 1])
        count = self.run_job(measured)
        # A single shot is taken, so this is simply the observed outcome.
        max_action_binary = max(count, key=count.get)
        return self.env.actions_binary[max_action_binary]

    def perform_q_rl(self, max_steps=None):
        """Run ``self.episodes`` episodes of quantum TD learning.

        Args:
            max_steps: optional cap on the number of steps in one episode.
                ``None`` (the default) means an episode only ends at the
                goal. A Grover iteration can push one action to probability
                1, so an uncapped episode is not guaranteed to terminate.

        Returns:
            The ``state_values`` array, updated in place.
        """
        state_values = self.state_values
        env = self.env
        action_q_registers = QuantumRegister(env.num_states * 2)
        # for observed action: either 00, 01, 10, 11
        action_c_registers = ClassicalRegister(2)
        circuit = QuantumCircuit(action_q_registers, action_c_registers)
        for indx, _ in enumerate(env.state_space):
            # All action registers are in this state: 1/2(|0>+|1>)(|0>+|1>)
            circuit.h(2*indx)
            circuit.h(2*indx+1)

        for i in range(self.episodes):
            curr_state = env.reset()
            done = False
            while not done:
                state_index = env.state_space.index(curr_state)
                action = self.get_action(curr_state, circuit)
                next_state, reward = env.step(action)
                state_values[state_index] = self.state_td_update(curr_state, reward, next_state)
                # update probability amplitudes
                # With four actions one Grover iteration already moves a
                # uniform superposition fully onto the marked action, hence
                # the cap of 1; non-positive scores mean no amplification.
                score = 0.3*(reward + state_values[env.state_space.index(next_state)])
                iterations = max(0, min(int(score), 1))
                circuit = self.p_update(circuit, action, state_index, iterations)
                curr_state = next_state
                self.steps_per_episode[i] += 1
                if curr_state == env.goal_state:
                    done = True
                elif max_steps is not None and self.steps_per_episode[i] >= max_steps:
                    break

        return state_values
                
        
        

In [64]:
# Build the grid world and an agent with the default hyper-parameters.
environment = GridWorld()
agent = Quantum_TD(environment)
# Train for the configured number of episodes; returns the state-value table.
final_values = agent.perform_q_rl()
# One value per entry of environment.state_space.
print(final_values)
# Number of environment steps taken in each episode.
print(agent.steps_per_episode)

[-9.06776152  0.          0.         -8.75046992  0.          6.19060373
 -8.013426   -6.95073109 -2.15027035]
[ 55. 102.  50.  41.  36.  75.  63.  95. 132.  61.]
