# Project "Training a cat to catch a Mouse using Q-learning and Qiskit"

Participants:

## 1. Intoduction 
### 1.1 Idea
The idea of this project is to train a novice(single agent) to catch a mouse in a grid environment 3*3. 

We placed the cat in the lower right corner and the mouse in the upper left corner. Our task is to find the best way for the cat to catch a motionless (and unsuspecting danger) mouse. We want to test whether it is possible to solve such a problem using quantum computing.

<table>
<tr>
    <th>MOUSE</th>
    <td>EMPTY</td>
    <td>EMPTY</td> 
</tr>

<tr>
    <td>EMPTY</td> 
    <td>EMPTY</td> 
    <td>EMPTY</td>
</tr>

<tr>
    <td>EMPTY</td> 
    <td>EMPTY</td> 
    <th>CAT</th>   
</tr>

</table>


### 1.2 Q-Learning
Q-Learning is a reinforcement learning algorithm to learn the value of an action in a particular state.
When Q-learning is performed we create what’s called a **q-table** or matrix that follows the shape of *[state, action]* and we initialize our values to zero. We then update and store our q-values after an episode. This q-table becomes a reference table for our agent to select the best action based on the q-value.


An agent interacts with the environment in 1 of 2 ways. The first is to use the q-table as a reference and view all possible actions for a given state. The agent then selects the action based on the max value of those actions. This is known as **exploiting** since we use the information we have available to us to make a decision.


The second way to take action is to act randomly. This is called **exploring**. Instead of selecting actions based on the max future reward we select an action at random. Acting randomly is important because it allows the agent to explore and discover new states that otherwise may not be selected during the exploitation process. You can balance exploration/exploitation using epsilon (ε) and setting the value of how often you want to explore vs exploit. Here’s some rough code that will depend on how the state and action space are setup.

Q-table is updated in every episode of training via **Bellman Equation**:

\begin{align}
\underbrace{\text{New}Q(s,a)}_{\scriptstyle\text{New Q-Value}}=Q(s,a)+\mkern-34mu\underset{\text{New Q-Value}}{\underset{\Bigl|}{\alpha}}\mkern-30mu \underbrace{R(s,a)}_{\scriptstyle\text{Reward}}+ \mkern-30mu\underset{\text{Discount rate}}{\underset{\Biggl |}{\gamma}}\mkern-75mu * \overbrace{\max Q'(s',a')}^{\scriptstyle\substack{\text{Maximum predicted reward, given} \\ \text{new state and all possible actions}}} * \mkern-45mu-Q(s,a)
\end{align}


### 1.3 Variational Quantum Eigensolver (VQE)

VQE is an application of the variational method of quantum mechanics.

\begin{align}
    \lambda_{min} \le \langle H \rangle_{\psi} = \langle \psi | H | \psi \rangle = \sum_{i = 1}^{N} \lambda_i | \langle \psi_i | \psi\rangle |^2
\end{align}

The above equation is known as the **variational method**.

When the Hamiltonian of a system is described by the Hermitian matrix $H$ the ground state energy of that system, $E_{gs}$, is the smallest eigenvalue associated with $H$. By arbitrarily selecting a wave function $|\psi \rangle$ (called an *ansatz*) as an initial guess approximating $|\psi_{min}\rangle$, calculating its expectation value, $\langle H \rangle_{\psi}$, and iteratively updating the wave function, arbitrarily tight bounds on the ground state energy of a Hamiltonian may be obtained. 

A systematic approach to varying the ansatz is required to implement the variational method on a quantum computer. VQE does so through the use of a parameterized circuit with a fixed form. Such a circuit is often called a *variational form*, and its action may be represented by the linear transformation $U(\theta)$. A variational form is applied to a starting state $|\psi\rangle$ (such as the vacuum state $|0\rangle$, or the Hartree Fock state) and generates an output state $U(\theta)|\psi\rangle\equiv |\psi(\theta)\rangle$. Iterative optimization over $|\psi(\theta)\rangle$ aims to yield an expectation value $\langle \psi(\theta)|H|\psi(\theta)\rangle \approx E_{gs} \equiv \lambda_{min}$. Ideally, $|\psi(\theta)\rangle$ will be close to $|\psi_{min}\rangle$ (where 'closeness' is characterized by either state fidelity, or Manhattan distance) although in practice, useful bounds on $E_{gs}$ can be obtained even if this is not the case.


## 2. Solution scheme via Q-learning & Quantum Calculations

The scheme of our solution is as follows: Each point of our medium of size 3*3 corresponds to a quantum circuit with configurable parameters. Each such scheme consists of several U3-gates, the type of which is shown below.

$$
\begin{align}
    U3(\theta, \phi, \lambda) = \begin{pmatrix}\cos(\frac{\theta}{2}) & -e^{i\lambda}\sin(\frac{\theta}{2}) \\ e^{i\phi}\sin(\frac{\theta}{2}) & e^{i\lambda + i\phi}\cos(\frac{\theta}{2}) \end{pmatrix}
\end{align}
$$

One of the options for implementing such a quantum circuit looks like this:

![Image of the Circuit](https://files.fm/thumb_show.php?i=9mynvq8t9)


In [10]:
from copy import deepcopy
import numpy as np
import random
from qiskit import QuantumCircuit, ClassicalRegister, QuantumRegister
from qiskit import Aer, transpile, assemble
from qiskit.providers import backend
from qiskit.aqua.components.optimizers import COBYLA
import matplotlib.pyplot as plt
import itertools

# Types and classes

In [11]:
# TYPES:
CAT = "c"
# DOG = "d"
MOUSE = "m"
EMPTY = "emp"

# ACTIONS:
UP = "00"
DOWN = "01"
LEFT = "10"
RIGHT = "11"
ACTIONS = [UP, DOWN, LEFT, RIGHT]

# random seed
random.seed(10)
np.random.seed(10)

In [12]:
# state of cat
class State:
    def __init__(self, catP):
        #self.row = catP[0]
        #self.column = catP[1]
        self.catP = catP

    def __eq__(self, other):
        return isinstance(other, State) and self.catP == other.catP

    def __hash__(self):
        return hash(str(self.catP))

    def __str__(self):
        return f"State(cat_pos={self.catP})"

In [13]:
# GridWorld
# e.g.
#  MOUSE | EMPTY | EMPTY
#  EMPTY | EMPTY | EMPTY
#  EMPTY | EMPTY | CAT
class GridWorld:
    def __init__(self, s, catP, mouseP):
        self.numRows = s[0]
        self.numColumns = s[1]
        self.catP = catP
        self.mouseP = mouseP
        # self.dogP = dogP
        assert(not self.compaireList(self.catP, self.mouseP))
    
    def getItem(self, p):
        if p[0]>=self.numRows or p[0]<0:
            return None
        if p[1]>=self.numColumns or p[1]<0:
            return None
        if self.compaireList(p, catP):
            return CAT
        elif self.compaireList(p, mouseP):
            return MOUSE
        # elif self.compaireList(p, DOG):
        #     return DOG
        else:
            return EMPTY

    def compaireList(self, l1,l2):
        for i, j in zip(l1, l2):
            if i!=j:
                return False
        return True

    def getNumRows(self):
        return self.numRows

    def getNumColumns(self):
        return self.numColumns

    def getMouse(self):
        return self.mouse
    
    def getCatP(self):
        return self.catP

    def setCatP(self, p):
        self.catP = p
        
    def setMouseP(self, p):
        self.mouseP = p
    
    def initCatState(self, rd = False):
        # init cat position
        if not rd:
            catP = [self.getNumRows() - 1, self.getNumColumns() - 1]
        else:
            catP = [random.randint(0, self.getNumRows()), random.randint(0, self.getNumColumns())]
            while self.getItem(catP) != EMPTY and self.getItem(catP) != CAT:
                catP = [random.randint(0, self.getNumRows()), random.randint(0, self.getNumColumns())]
        self.setCatP(catP)
        return State(catP)
    
    def show(self):
        output = ""
        for i in range(self.numRows):
            for j in range(self.numColumns):
                if self.compaireList([i,j], self.catP):
                    output += CAT + " "
                if self.compaireList([i,j], self.mouseP):
                    output += MOUSE + " "
                if not self.compaireList([i,j], self.catP) and not self.compaireList([i,j], self.mouseP):
                    output += EMPTY + " "
            output += "\n"
        print(output)

In [14]:
# QNet
class QNet:
    
    def __init__(self, qTable, gridWorld:GridWorld, alpha=0.1, gamma=1.0, eps=0.2, actions=[UP, DOWN, LEFT, RIGHT], numParams=6):
        self.gw = gridWorld
        self.qt = qTable
        self.eps = eps
        self.backend = Aer.get_backend("qasm_simulator")
        self.NUM_SHOTS = 1000 # number of measurements 
        self.optimizer = COBYLA(maxiter=500, tol=0.0001) # off the shelf
        self.gamma = gamma
        self.alpha = alpha
        self.ACTIONS = actions

        # self.rets = {(0,0):([0,..,0],0.0,0), ...}
        self.rets = dict() # resulting parameters after optimization for all points in the grid

        self.state = None
        
        for i in range(self.gw.getNumRows()):
            for j in range(self.gw.getNumColumns()):
                self.rets[i, j] = (np.random.rand(numParams), 0.0, 0) 
    
    def qcMaker(self, params):
        qr = QuantumRegister(2, name="q")
        cr = ClassicalRegister(2, name="c")
        qc = QuantumCircuit(qr, cr)
        qc.u3(params[0], params[1], params[2], qr[0])
        qc.u3(params[3], params[4], params[5], qr[1])
        # qc.cx(qr[0], qr[1])
        qc.measure(qr, cr)
        return qc

    def newPosition(self, state, action):
        p = deepcopy(state.catP)
        if action == UP:
            p[0] = max(0, p[0] - 1)
        elif action == DOWN:
            p[0] = min(self.gw.getNumRows() - 1, p[0]+1)
        elif action == LEFT:
            p[1] = max(0, p[1] - 1)
        elif action == RIGHT:
            p[1] = min(self.gw.getNumColumns() - 1, p[1] + 1)
        else:
            raise ValueError(f"Unkown action {action}")
        return p
        
    def getReward(self, p):
        grid = self.gw.getItem(p)
        if grid == EMPTY:
            reward = -1
        # elif grid == DOG:
        #     reward = -1000
        elif grid == MOUSE:
            reward = 1000
        elif grid == CAT:
            reward = -1 # (maybe less than reward of empty)
        else:
            raise ValueError(f"Unknown grid item {grid}")
        return reward
    
    def selectAction(self, state, training):
        if random.uniform(0, 1) < self.eps:
            return random.choice(self.ACTIONS)
        else:
            if training:
                self.state = deepcopy(state)
                self.updateCircuit(state)
            return self.ACTIONS[np.argmax(self.qt[self.state.catP[0], self.state.catP[1]])]
        
    def lossFunction(self, params):
        action = ""
        qc = self.qcMaker(params=params)
        t_qc = transpile(qc, self.backend)
        job = assemble(t_qc, shots=self.NUM_SHOTS)
        rlt = self.backend.run(job).result()
        counts = rlt.get_counts(qc)
        # speedup training, cross the ravine
        if random.uniform(0, 1) < self.eps:
            action = random.choice(self.ACTIONS)
        else:
            action = max(counts, key = counts.get)
        
        nextPosition = self.newPosition(self.state, action) # handle the 
        reward = self.getReward(nextPosition)
        # update q-table(but not very sure, update only for this action or for all actions)
        targetQvalue = reward + self.gamma *  np.max(self.qt[nextPosition[0],nextPosition[1]])
        predictedQvalue = self.calculateQvalue(action, nextPosition, reward, targetQvalue, self.state)
        
        # update q-table
        self.updateQtable(predictedQvalue, action)
        return targetQvalue - self.qt[self.state.catP[0],self.state.catP[1]][int(action,2)]
    
    def updateQtable(self, predictedQvalue, action):
        if self.qt[(self.state.catP[0],self.state.catP[1])][int(action,2)] < predictedQvalue:
            self.qt[self.state.catP[0],self.state.catP[1]][int(action,2)] = predictedQvalue

    def calculateQvalue(self, action, nextPosition, reward, targetQvalue, state:State):
        targetQvalue = reward + self.gamma *  np.max(self.qt[nextPosition[0],nextPosition[1]])
        return self.qt[state.catP[0], state.catP[1]][int(action,2)] + self.alpha * (targetQvalue - self.qt[state.catP[0],state.catP[1]][int(action,2)]) # update q-table

    def updateCircuit(self, state:State):
        self.rets[state.catP[0], state.catP[1]] = self.optimizer.optimize(num_vars=6, objective_function=self.lossFunction, initial_point=self.rets[state.catP[0], state.catP[1]][0])

    def setAlpha(self, alpha):
        self.alpha = alpha

    def drawVectors(self, hasdiagonals):
        # Draw vectors representing the cat's desired direction for each place in the grid based on the Qtable
        x = np.linspace(0, self.gw.getNumColumns()-1, self.gw.getNumColumns())
        y = np.linspace(0, self.gw.getNumColumns()-1, self.gw.getNumColumns())
        pts = itertools.product(x, y)
        plt.scatter(*zip(*pts), marker='o', s=30, color='red')
        X, Y = np.meshgrid(x, y)
        vecx = self.qt[(X, Y)][3]-self.qt[(X, Y)][2]
        vecy = self.qt[(X, Y)][0]-self.qt[(X, Y)][1]
        norm = vecx**2 + vecy**2
        QP = plt.quiver(X, Y, vecx/norm, vecy/norm)
        plt.grid()
        plt.show()

In [15]:
# agent: cat
class Cat:
    def __init__(self, qNet: QNet, training=True, eps = 0.2, actions = [UP, DOWN, LEFT, RIGHT]):
        self.eps = eps
        self.training = training
        self.qNet = qNet
        self.ACTIONS = actions
        self.state = None

    def newPosition(self, state, action):
            p = deepcopy(state.catP)
            if action == UP:
                p[0] = max(0, p[0] - 1)
            elif action == DOWN:
                p[0] = min(self.qNet.gw.getNumRows() - 1, p[0] + 1)
            elif action == LEFT:
                p[1] = max(0, p[1] - 1)
            elif action == RIGHT:
                p[1] = min(self.qNet.gw.getNumColumns() - 1, p[1] + 1)
            else:
                raise ValueError(f"Unkown action {self.ACTIONS[action]}")
            return p

    def getReward(self, p):
        grid = self.qNet.gw.getItem(p)
        if grid == MOUSE:
            reward = 1000
            end = True
        # elif grid == DOG:
        #     reward = -100
        #     end = True
        #     self.qNet.gw.setCatP(p)
        elif grid == EMPTY:
            reward = -1
            end = False
        elif grid == CAT:
            reward = -1 # (maybe less than reward of empty)
            end = False
        else:
            raise ValueError(f"Unknown grid item {grid}")
        return reward, end

    def act(self, state, action):
        p = self.newPosition(state, action)
        reward, end = self.getReward(p)
        return p, reward, end
    
    def updateQtable(self, action, p, reward, state):
        pqv = self.qNet.calculateQvalue(action, p, reward, state)
        self.qNet.updateQtable(pqv, action)

    def setTraining(self, training):
        self.Training = training

In [16]:
# The pet school
class PetSchool:
    def __init__(self, cat:Cat, numEpisodes, maxEpisodeSteps, training=True, minAlpha = 0.02, eps = 0.2):
        self.cat = cat
        self.training = training
        self.NUM_EPISODES = numEpisodes
        self.MAX_EPISODE_STEPS = maxEpisodeSteps
        self.alphas = np.linspace(1.0, minAlpha, self.NUM_EPISODES)
        self.eps = eps

    def train(self):
        counter = 0
        for e in range(self.NUM_EPISODES): #  episode: a rund for agent
            print("episode: ", e)
            state = self.cat.qNet.gw.initCatState(rd=True) # default is rd = False
            self.cat.qNet.setAlpha(self.alphas[e])
            total_reward  = 0
            step = 0
            end = False
            for _ in range(self.MAX_EPISODE_STEPS): # step: a time step for agent
                action = self.cat.qNet.selectAction(deepcopy(state), self.training)
                p, reward, end = self.cat.act(state, action)
                self.catMoveTo(p)
                # self.cat.updateQtable(action, p, reward, state) # speedup learning
                total_reward += reward
                step += 1
                counter += 1
                if end:
                    print("catch the mouse!!!")
                    print("total reward: ", total_reward, "steps: ", step)
                    break
        print("counter: ", counter)

    def catMoveTo(self, p):
        self.cat.qNet.gw.setCatP(p)

    def show(self):
        self.cat.qNet.gw.show()
        print("qTable: ", self.cat.qNet.qt)
        print("\nparams: ", self.cat.qNet.rets)

    def initqTable(self, actions, size):
        d = {}
        for i in range(size[0]):
            for j in range(size[1]):
                d[i,j] = np.zeros(len(actions))
        return d
        
    def mouseMove(p,oldPos): # goal (mouse) moves randomly with prob p every time the cat moves
        side = 2 # Number of cells per side of the grid
        if np.random.random() < p:
            n = np.random.random()
            if n < 0.25:
                newPos = (max(0, oldPos[0]-1),oldPos[1])
            elif n < 0.5:
                newPos = (min(side - 1, oldPos[0]+1),oldPos[1])
            elif n < 0.75:
                newPos = (oldPos[0],max(0, oldPos[1]-1))
            else:
                newPos = (oldPos[0],min(side - 1, oldPos[1]+1))
        else:
            newPos = oldPos
        return newPos

In [17]:
# super parameter
gridSize = [3, 3]
catP = [gridSize[0]-1, gridSize[0]-1]
mouseP = [0, 0]
EPS = 20
MAX_EPS_STEP = 30
sizeOfParams = 6
gamma = 0.98

In [18]:
def initqTable(size, actions=[UP, DOWN, LEFT, RIGHT]):
    d = {}
    for i in range(size[0]):
        for j in range(size[1]):
            d[i,j] = np.zeros(len(actions))
    return d

# initGridWorld
gridWorld = GridWorld(gridSize, catP=catP, mouseP=mouseP)
# init q Table
qt = initqTable(gridSize)
# init q Circuit
qNet = QNet(qt, gridWorld, initialParameters, gamma=gamma)
# init cat
cat = Cat(qNet=qNet)
# init pet school
petSchool = PetSchool(cat, EPS, MAX_EPS_STEP)
# start training
petSchool.train()
# show what have been learned
petSchool.show()

episode:  0
catch the mouse!!!
total reward:  1000 steps:  1
episode:  1
catch the mouse!!!
total reward:  1000 steps:  1
episode:  2
episode:  3
episode:  4
episode:  5
episode:  6
episode:  7
episode:  8
episode:  9
episode:  10
catch the mouse!!!
total reward:  999 steps:  2
episode:  11
episode:  12
catch the mouse!!!
total reward:  996 steps:  5
episode:  13
episode:  14
episode:  15
episode:  16
episode:  17
episode:  18
episode:  19
counter:  489
m emp emp 
emp c emp 
emp emp emp 

qTable:  {(0, 0): array([0., 0., 0., 0.]), (0, 1): array([975.90194637, 945.64269226, 996.83885331, 915.58309722]), (0, 2): array([918.486568, 938.2516  ,   0.      , 918.486568]), (1, 0): array([1000.        ,    0.        ,  979.        ,  703.44311136]), (1, 1): array([975.90207546, 938.25158487, 979.        , 938.2516    ]), (1, 2): array([918.48656799, 918.486567  , 958.42      , 938.2516    ]), (2, 0): array([979.    , 958.42  , 958.42  , 938.2516]), (2, 1): array([958.42    , 938.2516  ,   0.  