# Training Cat to catch moving mouse
This project "**Training a Cat to Learn Catch a Mouse**" for **Qiskit Fall Fest MUNICH 2021**.

In [None]:
from copy import deepcopy
import numpy as np
import random
from qiskit import QuantumCircuit, ClassicalRegister, QuantumRegister
from qiskit import Aer, transpile, assemble
from qiskit.providers import backend
from qiskit.aqua.components.optimizers import COBYLA
import matplotlib.pyplot as plt
import itertools

In [29]:
# TYPES:
CAT = "c"
MOUSE = "m"
EMPTY = "emp"

# ACTIONS:
UP = "000"
DOWN = "001"
LEFT = "010"
RIGHT = "011"

UPPERLEFT = "100"
UPPERRIGHT = "101"
LOWERLEFT = "110"
LOWERRIGHT = "111"

ACTIONS = [UP, DOWN, LEFT, RIGHT, UPPERLEFT, UPPERRIGHT, LOWERLEFT, LOWERRIGHT]

# random seed
seed = 10
random.seed(seed)
np.random.seed(seed)

In [30]:
# state of cat
class State:
    def __init__(self, catP, mouseP):
        self.catP = catP
        self.mouseP = mouseP

    def __eq__(self, other):
        return isinstance(other, State) and self.mouseP==other.mouseP and self.catP == other.catP

    def __hash__(self):
        return hash(str(self.catP))

    def __str__(self):
        return f"State(cat_pos={self.catP})"

In [31]:
# GridWorld
# e.g.
#  MOUSE | EMPTY | EMPTY
#  EMPTY | EMPTY | EMPTY
#  EMPTY | EMPTY | CAT
class GridWorld:
    def __init__(self, s, catP, mouseP):
        self.numRows = s[0]
        self.numColumns = s[1]
        self.catP = catP
        self.mouseP = mouseP
        # self.dogP = dogP
        assert(not self.compaireList(self.catP, self.mouseP))
    
    def getItem(self, p):
        if p[0]>=self.numRows or p[0]<0:
            return None
        if p[1]>=self.numColumns or p[1]<0:
            return None
        if self.compaireList(p, catP):
            return CAT
        elif self.compaireList(p, mouseP):
            return MOUSE
        else:
            return EMPTY

    def compaireList(self, l1,l2):
        for i, j in zip(l1, l2):
            if i!=j:
                return False
        return True

    def getNumRows(self):
        return self.numRows

    def getNumColumns(self):
        return self.numColumns

    def getMouseP(self):
        return self.mouseP
    
    def getCatP(self):
        return self.catP

    def setCatP(self, p):
        self.catP = p
        
    def setMouseP(self, p):
        self.mouseP = p
    
    def initCatState(self, rd = False):
        # init cat position
        if not rd:
            catP = [self.getNumRows() - 1, self.getNumColumns() - 1]
        else:
            catP = [random.randint(0, self.getNumRows()), random.randint(0, self.getNumColumns())]
            while self.getItem(catP) != EMPTY and self.getItem(catP) != CAT:
                catP = [random.randint(0, self.getNumRows()), random.randint(0, self.getNumColumns())]
        self.setCatP(catP)
        return State(catP, self.getMouseP())
    
    def show(self):
        output = ""
        for i in range(self.numRows):
            for j in range(self.numColumns):
                if self.compaireList([i,j], self.catP):
                    output += CAT + " "
                if self.compaireList([i,j], self.mouseP):
                    output += MOUSE + " "
                if not self.compaireList([i,j], self.catP) and not self.compaireList([i,j], self.mouseP):
                    output += EMPTY + " "
            output += "\n"
        print(output)

In [32]:
# QNet
class QNet:
    
    def __init__(self, qTable, gridWorld:GridWorld, alpha=0.1, gamma=1.0, eps=0.2, actions=[UP, DOWN, LEFT, RIGHT], numParams=9):
        self.gw = gridWorld
        self.qt = qTable
        self.eps = eps
        self.backend = Aer.get_backend("qasm_simulator")
        self.NUM_SHOTS = 1000 # number of measurements 
        self.optimizer = COBYLA(maxiter=500, tol=0.0001) # off the shelf
        self.gamma = gamma
        self.alpha = alpha
        self.ACTIONS = actions

        # self.rets = {(0,0):([0,..,0],0.0,0), ...}
        self.rets = dict() # resulting parameters after optimization for all points in the grid

        self.state = None
        
        for i in range(self.gw.getNumRows()):
            for j in range(self.gw.getNumColumns()):
                for m in range(self.gw.getNumRows()):
                    for n in range(self.gw.getNumColumns()):
                        self.rets[i, j, m, n] = (np.random.rand(numParams), 0.0, 0) 
    
    def qcMaker(self, params):
        qr = QuantumRegister(3, name="q")
        cr = ClassicalRegister(3, name="c")
        qc = QuantumCircuit(qr, cr)
        start = 0
        step = 3
        assert(not len(params)%3)
        # add U3 for all qubits once
        for i in range(int(len(params)/3)): 
            qc.u3(params[start], params[start + 1], params[start + 2], qr[i])
            start += step
        # qc.cx(qr[0], qr[1])
        qc.measure(qr, cr)
        return qc

    def newPosition(self, state, action):
        p = deepcopy(state.catP)
        if action == UP:
            p[0] = max(0, p[0] - 1)
        elif action == DOWN:
            p[0] = min(self.gw.getNumRows() - 1, p[0]+1)
        elif action == LEFT:
            p[1] = max(0, p[1] - 1)
        elif action == RIGHT:
            p[1] = min(self.gw.getNumColumns() - 1, p[1] + 1)
        elif action == UPPERLEFT:
            p[0] = max(0, p[0] - 1)
            p[1] = max(0, p[1] - 1)
        elif action == UPPERRIGHT:
            p[0] = max(0, p[0] - 1)
            p[1] = min(self.gw.getNumColumns() - 1, p[1] + 1)
        elif action == LOWERLEFT:
            p[0] = min(self.gw.getNumRows() - 1, p[0] + 1)
            p[1] = max(0, p[1] - 1)
        elif action == LOWERRIGHT:
            p[0] = min(self.gw.getNumRows() - 1, p[0] + 1)
            p[1] = min(self.gw.getNumColumns() - 1, p[1] + 1)
        else:
            raise ValueError(f"Unkown action {action}")
        return p
        
    def getReward(self, p):
        grid = self.gw.getItem(p)
        if grid == EMPTY:
            reward = -1
        elif grid == MOUSE:
            reward = 1000
        elif grid == CAT:
            reward = -1
        else:
            raise ValueError(f"Unknown grid item {grid}")
        return reward
    
    def selectAction(self, state, training):
        if random.uniform(0, 1) < self.eps:
            return random.choice(self.ACTIONS)
        else:
            if training:
                self.state = state
                self.updateCircuit(state)
            return self.ACTIONS[np.argmax(self.qt[state.catP[0], state.catP[1], state.mouseP[0], state.mouseP[1]])]
        
    def lossFunction(self, params):
        action = ""
        qc = self.qcMaker(params=params)
        t_qc = transpile(qc, self.backend)
        job = assemble(t_qc, shots=self.NUM_SHOTS)
        rlt = self.backend.run(job).result()
        counts = rlt.get_counts(qc)
        # speedup training, cross the ravine
        if random.uniform(0, 1) < self.eps:
            action = random.choice(self.ACTIONS)
        else:
            action = max(counts, key = counts.get)
        nextPosition = self.newPosition(self.state, action) # handle the 
        reward = self.getReward(nextPosition)
        # update q-table(but not very sure, update only for this action or for all actions)
        targetQvalue = reward + self.gamma *  np.max(self.qt[nextPosition[0],nextPosition[1], self.state.mouseP[0], self.state.mouseP[1]])
        predictedQvalue = self.calculateQvalue(action, nextPosition, reward, targetQvalue, self.state)
        
        # update q-table
        self.updateQtable(predictedQvalue, action)
        return targetQvalue - self.qt[self.state.catP[0], self.state.catP[1], self.state.mouseP[0], self.state.mouseP[1]][int(action,2)]
    
    def updateQtable(self, predictedQvalue, action):
        if self.qt[self.state.catP[0], self.state.catP[1], self.state.mouseP[0], self.state.mouseP[1]][int(action,2)] < predictedQvalue:
            self.qt[self.state.catP[0], self.state.catP[1], self.state.mouseP[0], self.state.mouseP[1]][int(action,2)] = predictedQvalue

    def calculateQvalue(self, action, nextPosition, reward, targetQvalue, state:State):
        return self.qt[state.catP[0], state.catP[1], state.mouseP[0], state.mouseP[1]][int(action,2)] + self.alpha * (targetQvalue - self.qt[state.catP[0],state.catP[1],state.mouseP[0], state.mouseP[1]][int(action,2)]) # update q-table

    def updateCircuit(self, state:State):
        self.rets[state.catP[0], state.catP[1], state.mouseP[0], state.mouseP[1]] = self.optimizer.optimize(num_vars=9, objective_function=self.lossFunction, initial_point=self.rets[state.catP[0], state.catP[1], state.mouseP[0], state.mouseP[1]][0])

    def setAlpha(self, alpha):
        self.alpha = alpha

In [33]:
# agent: cat
class Cat:
    def __init__(self, qNet: QNet, training=True, eps = 0.2, actions = [UP, DOWN, LEFT, RIGHT]):
        self.eps = eps
        self.training = training
        self.qNet = qNet
        self.ACTIONS = actions
        self.state = None

    def newPosition(self, state, action):
            p = deepcopy(state.catP)
            if action == UP:
                p[0] = max(0, p[0] - 1)
            elif action == DOWN:
                p[0] = min(self.qNet.gw.getNumRows() - 1, p[0] + 1)
            elif action == LEFT:
                p[1] = max(0, p[1] - 1)
            elif action == RIGHT:
                p[1] = min(self.qNet.gw.getNumColumns() - 1, p[1] + 1)
            # in case of tough boundary then slipe
            elif action == UPPERLEFT:
                p[0] = max(0, p[0] - 1)
                p[1] = max(0, p[1] - 1)
            elif action == UPPERRIGHT:
                p[0] = max(0, p[0] - 1)
                p[1] = min(self.qNet.gw.getNumColumns() - 1, p[1] + 1)
            elif action == LOWERLEFT:
                p[0] = min(self.qNet.gw.getNumRows() - 1, p[0] + 1)
                p[1] = max(0, p[1] - 1)
            elif action == LOWERRIGHT:
                p[0] = min(self.qNet.gw.getNumRows() - 1, p[0] + 1)
                p[1] = min(self.qNet.gw.getNumColumns() - 1, p[1] + 1)
            else:
                raise ValueError(f"Unkown action {self.ACTIONS[action]}")
            return p

    def getReward(self, p):
        grid = self.qNet.gw.getItem(p)
        if grid == MOUSE:
            reward = 1000
            end = True
        elif grid == EMPTY:
            reward = -1
            end = False
        elif grid == CAT:
            reward = -1
            end = False
        else:
            raise ValueError(f"Unknown grid item {grid}")
        return reward, end

    def act(self, state, action):
        p = self.newPosition(state, action)
        reward, end = self.getReward(p)
        return p, reward, end
    
    def updateQtable(self, action, p, reward, state):
        pqv = self.qNet.calculateQvalue(action, p, reward, state)
        self.qNet.updateQtable(pqv, action)

    def setTraining(self, training):
        self.Training = training

In [34]:
# The pet school
class PetSchool:
    def __init__(self, cat:Cat, numEpisodes, maxEpisodeSteps, training=True, minAlpha = 0.02, eps = 0.2):
        self.cat = cat
        self.training = training
        self.NUM_EPISODES = numEpisodes
        self.MAX_EPISODE_STEPS = maxEpisodeSteps
        self.alphas = np.linspace(1.0, minAlpha, self.NUM_EPISODES)
        self.eps = eps

    def train(self):
        counter = 0
        rd = True
        for e in range(self.NUM_EPISODES): #  episode: a rund for agent
            print("episode: ", e)
            if e > int(self.NUM_EPISODES/2):
                rd = False
            state = self.cat.qNet.gw.initCatState(rd=rd)
            self.cat.qNet.setAlpha(self.alphas[e])
            total_reward  = 0
            step = 0
            end = False
            for _ in range(self.MAX_EPISODE_STEPS): # step: a time step for agent
                action = self.cat.qNet.selectAction(deepcopy(state), self.training)
                p, reward, end = self.cat.act(state, action)
                self.catMoveTo(p)
                if not end: # stupid mouse may move to cat
                    end = self.mouseMove(self.cat.qNet.gw.getMouseP())
                    total_reward += 1000 # not update qtable
                total_reward += reward
                step += 1
                counter += 1
                if end:
                    print("catch the mouse!!!")
                    print("total reward: ", total_reward, "steps: ", step)
                    break
        print("counter: ", counter)
    
    def catMoveTo(self, p):
        self.cat.qNet.gw.setCatP(p)

    def show(self):
        self.cat.qNet.gw.show()
        print("qTable: ", self.cat.qNet.qt)
        print("\nparams: ", self.cat.qNet.rets)

    def initqTable(self, actions, size):
        d = {}
        for i in range(size[0]):
            for j in range(size[1]):
                d[i,j] = np.zeros(len(actions))
        return d
    
    # @Daniel-Molpe
    def mouseMove(self, oldPos, p=0.5): # goal (mouse) moves randomly with prob p every time the cat moves
        side = min(self.cat.qNet.gw.getNumColumns(), self.cat.qNet.gw.getNumRows()) # Number of cells per side of the grid
        end = False
        if np.random.random() < p:
            n = np.random.random()
            if n < 0.25:
                newPos = (max(0, oldPos[0]-1), oldPos[1])
            elif n < 0.5:
                newPos = (min(side - 1, oldPos[0]+1),oldPos[1])
            elif n < 0.75:
                newPos = (oldPos[0], max(0, oldPos[1]-1))
            else:
                newPos = (oldPos[0], min(side - 1, oldPos[1]+1))
        else:
            newPos = oldPos
        self.cat.qNet.gw.setMouseP(newPos)
        if self.cat.qNet.gw.getCatP == newPos: # mouse ends the training
            end = True
        return end

In [35]:
# super parameter
gridSize = [3, 3]
catP = [gridSize[0]-1, gridSize[0]-1]
mouseP = [0, 0]
EPS = 20
MAX_EPS_STEP = 30
sizeOfParams = 9
gamma = 0.98

In [36]:
def initqTable(size, actions=[UP, DOWN, LEFT, RIGHT]):
    d = {}
    for i in range(size[0]):
        for j in range(size[1]):
            for m in range(size[0]):
                for n in range(size[1]):
                    d[i, j, m, n] = np.zeros(len(actions))
    return d

# initGridWorld
gridWorld = GridWorld(gridSize, catP=catP, mouseP=mouseP)
# init q Table
qt = initqTable(gridSize, actions=ACTIONS)
# init q Circuit
qNet = QNet(qt, gridWorld, gamma=gamma, actions=ACTIONS)
# init cat
cat = Cat(qNet=qNet)
# init pet school
petSchool = PetSchool(cat, EPS, MAX_EPS_STEP)
# start training
petSchool.train()
# show what have been learned
petSchool.show()

episode:  0
catch the mouse!!!
total reward:  1000 steps:  1
episode:  1
episode:  2
episode:  3
episode:  4
episode:  5
episode:  6
episode:  7
episode:  8
episode:  9
catch the mouse!!!
total reward:  1000 steps:  1
episode:  10
episode:  11
episode:  12
episode:  13
episode:  14
episode:  15
episode:  16
episode:  17
episode:  18
episode:  19
counter:  542
emp m emp 
emp emp c 
emp emp emp 

qTable:  {(0, 0, 0, 0): array([0., 0., 0., 0., 0., 0., 0., 0.]), (0, 0, 0, 1): array([0., 0., 0., 0., 0., 0., 0., 0.]), (0, 0, 0, 2): array([0., 0., 0., 0., 0., 0., 0., 0.]), (0, 0, 1, 0): array([0., 0., 0., 0., 0., 0., 0., 0.]), (0, 0, 1, 1): array([0., 0., 0., 0., 0., 0., 0., 0.]), (0, 0, 1, 2): array([0., 0., 0., 0., 0., 0., 0., 0.]), (0, 0, 2, 0): array([0., 0., 0., 0., 0., 0., 0., 0.]), (0, 0, 2, 1): array([0., 0., 0., 0., 0., 0., 0., 0.]), (0, 0, 2, 2): array([0., 0., 0., 0., 0., 0., 0., 0.]), (0, 1, 0, 0): array([0., 0., 0., 0., 0., 0., 0., 0.]), (0, 1, 0, 1): array([0., 0., 0., 0., 0., 0