In [1]:
import numpy as np
from gym import spaces
from itertools import permutations
from qiskit import (QuantumCircuit, execute, BasicAer)
from qiskit.circuit.library import (HGate, XGate, CXGate)

from math import ceil
from qiskit import *
from qiskit.circuit.library import GroverOperator
from qiskit.quantum_info import Statevector

In [2]:
class QTicTacToeEnv:
    
    def __init__(self, grid_size):
        
        """
        Inits the QTTT environment
        grid_size: linear size of the board
        """
        # select simulators
        self.simulator = BasicAer.get_backend('qasm_simulator')
        self.statevec_sim = BasicAer.get_backend('statevector_simulator')
        # one qubit for each element of the board
        self.qnum = grid_size ** 2
        # init board circuit
        self.circuit = QuantumCircuit(self.qnum)
        # init moves dictionary
        self.moves = self._init_moves_dict()
        # init action space as a gym space obj, so that agents can interpret it
        self.action_space = spaces.Discrete(len(self.moves))
        # init dictionary of possible final board configs
        self.endings_lookuptable = self._init_outcomes_dict()
        self.status_id = ""
    
    def _init_moves_dict(self):
        
        """
        Generates a dictionary with all possible moves.
        Possible moves are: place H or X on a chosen qubit; apply a CNOT at chosen quibits pair
        :return: a dict with int keys and tuples of (qubits, qiskit gates) as values
        """
        mvs_dict = {}
        mv_indx = 0
        for q in range(self.qnum):
            mvs_dict[mv_indx] = ([q], HGate())
            mv_indx += 1
            mvs_dict[mv_indx] = ([q], XGate())
            mv_indx += 1
        for (c, t) in permutations(list(range(self.qnum)), 2):
            mvs_dict[mv_indx] = ([c, t], CXGate())
            mv_indx += 1
        return mvs_dict
    
    def _win_check(self, board):
        
        """
        Checks for game result
        board: string representing the final state of the board
        return: the winner (player 1 or 2) or draw flag (0)
        """
        # transofrm board string to rows, columns and diagonals
        d = int(np.sqrt(self.qnum))
        rows = [board[i*d:(i+1)*d] for i in range(d)]
        cols = ["".join([rows[i][j] for i in range(d)]) for j in range(d)]
        diags = ["".join([rows[i][i] for i in range(d)]), "".join([rows[i][d-i-1] for i in range(d)])]
        winner = 0
        # winning conditions for players 1 and 2
        cond_1 = bin(0)[2:].zfill(d)
        cond_2 = bin(2**d - 1)[2:].zfill(d)
        # check each line and exit if both player win
        for line in [*rows, *cols, *diags]:
            if line == cond_1:
                if winner == 0 or winner == 1:
                    winner = 1
                elif winner == 2:
                    return 0  # because both players won
            elif line == cond_2:
                if winner == 0 or winner == 2:
                    winner = 2
                elif winner == 1:
                    return 0  # because both players won
        return winner
    
    def _init_outcomes_dict(self):
        """
        Inits a dictionary with all possible solutions 
        return: a dict whose keys are the winning player or a draw flag (0) and whose associated values
        are all the final board configs leading to such outcome
        """
        out_dict = {1: [], 2: [], 0: []}
        # init all possible observed board states
        all_states = [bin(x)[2:].zfill(self.qnum) for x in range(2**self.qnum)]
        for state in all_states:
            print('sadassadsadsadsadasd', state)
            winner = self._win_check(state)
            out_dict[winner].append(int(state, 2))
        
        return out_dict
    
    def move(self, action):
        """
        Take the action by appending its gate to the board circ.
        param action: int, key of the moves dict
        """
        self.status_id += "{}-".format(action)
        self.circuit.append(self.moves[action][1], self.moves[action][0])
    
    def _get_statevec(self):
        """
        quantumly observe the board by returning the "percept" as the statevector of the board circuit
        return: rounded state vec of the board
        """
        job = execute(self.circuit, self.statevec_sim)
        result = job.result()
        output_state = result.get_statevector()
        return np.around(output_state, decimals=2)
    
    def collapse_board(self):
        """
        Final move, measure the board and observe final state
        return: final state of the board
        """

        self.circuit.measure_all()
        print(self.circuit)
        
        job = execute(self.circuit, backend=self.simulator, shots=1)
        res = job.result()
        counts = res.get_counts()
        collapsed_state = int(list(counts.keys())[0][:self.qnum], 2)
        return collapsed_state
    
    def check_end(self, board_state):
        """
        Check for ending
        :param board_state: classical board state after collapse
        :return: winning player (1 or 2) or draw flag (0)
        """
        if board_state in self.endings_lookuptable[1]:
            print('ahmad', self.endings_lookuptable[1])
            print("\nPlayer 1 wins!!!\n")
            return 1
        elif board_state in self.endings_lookuptable[2]:
            print('ahmad', self.endings_lookuptable[2])
            print("\nPlayer 2 wins!!!\n")
            return 2
        else:
            print("\nIt's a draw!\n")
            return 0
        
    def step(self, action):
        """
        Perform the chosen action on the board
        param action: int representing the chosen action
        return: new_state of the board, reward (static), done=False
        """
        self.move(action)
        new_state = self._get_statevec()
        reward = -0.1
        return new_state, reward, False
    
    def reset(self):
        """
        Resets the board
        return:
        """
        self.circuit = QuantumCircuit(self.qnum, self.qnum)
        self.circuit.h(list(range(self.qnum)))
        self.status_id = ""
        return self._get_statevec()
    
    def render(self):
        # TODO: devise a render function
        return 0

In [3]:
class GroverQuantumBoardLearner:
    """
    Chosen environment must be discrete!
    """
    def __init__(self, env):
        self.env = env
        # we do not know in advance how many possible states in the env,
        # they will be added during training
        self.obs_dim = 1
        # number of possible actions extracted from env
        self.acts_dim = self.env.action_space.n
        # evaluate number of needed qubits to encode actions
        self.acts_reg_dim = ceil(np.log2(self.acts_dim))
        # evaluate maximum number of grover steps
        self.max_grover_steps = int(round(np.pi/(4*np.arcsin(1./np.sqrt(2**self.acts_reg_dim))) - 0.5))
        # state variable
        self.state = self.env.reset()
        # action variable
        self.action = 0
        # init dictionary of quality values, str(state) is used for better comparison
        self.state_vals = {str(self.state): 0.}
        # init dictionary of grover steps for each state-action pair
        self.grover_steps = {str(self.state): np.zeros(self.acts_dim, dtype=int)}
        # init dictionary of flags to stop grover amplification
        self.grover_steps_flag = {str(self.state): np.zeros(self.acts_dim, dtype=bool)}
        # learner hyperparameters
        self.hyperparams = {'k': -1, 'alpha': 0.05, 'gamma': 0.99}
        # grover oracles
        self.grover_ops = self._init_grover_ops()
        # state-action circuits
        self.acts_circs = self._init_acts_circs()
        self.SIM = BasicAer.get_backend('qasm_simulator')
        
    def set_hyperparams(self, hyperdict):
        # Set new values for learner's hyperparameters
        self.hyperparams = hyperdict
        
    def _new_state_check(self, newstate):
        # Checks if newstate was already observed
        if str(newstate) in self.state_vals.keys():
            return
        else:
            self.state_vals[str(newstate)] = 0.
            self.grover_steps[str(newstate)] = np.zeros(self.acts_dim, dtype=int)
            self.grover_steps_flag[str(newstate)] = np.zeros(self.acts_dim, dtype=bool)
            self._append_new_circ(newstate)
            
    def _init_acts_circs(self):
        """
        Creates the state-action circuits and inits them in full superposition
        return: a dict of circuits, keys are strings of state vectors
        """
        circs = {str(self.state): QuantumCircuit(self.acts_reg_dim)}
        for _, c in circs.items():
            c.h(list(range(self.acts_reg_dim)))

        return circs
    
    def _append_new_circ(self, state):
        #Inits a new state-action circuit
        self.acts_circs[str(state)] = QuantumCircuit(self.acts_reg_dim)
        self.acts_circs[str(state)].h(list(range(self.acts_reg_dim)))
    
    def _update_statevals(self, reward, new_state):
        """
        Bellman equation to update state values
        param reward: the reward received by the agent
        param new_state: the new state visited by the agent
        """
        self.state_vals[str(self.state)] += self.hyperparams['alpha']\
        * (reward + self.hyperparams['gamma']*self.state_vals[str(new_state)]
           - self.state_vals[str(self.state)])
    
    def _eval_grover_steps(self, reward, new_state):
        """
        Choose how many grover steps to take based on the reward and the value of the new state
        param reward: the reward received by the agent
        param new_state: the new state visited by the agent
        return: number of grover steps to be taken,
        if it exceeds the theoretical optimal number the latter is returned instead
        """
        steps_num = int(self.hyperparams['k']*(reward + self.state_vals[str(new_state)]))
        return min(steps_num, self.max_grover_steps)
    
    def _init_grover_ops(self):
        """
        Inits grover oracles for the actions set
        return: a list of qiskit instructions ready to be appended to circuit
        """
        states_binars = [format(i, '0{}b'.format(self.acts_reg_dim)) for i in range(self.acts_dim)]
        targ_states = [Statevector.from_label(s) for s in states_binars]
        grops = [GroverOperator(oracle=ts) for ts in targ_states]
        return [g.to_instruction() for g in grops]
    
    def _run_grover(self):
        # Deploy grover ops on acts_circs
        gsteps = self.grover_steps[str(self.state)][self.action]
        circ = self.acts_circs[str(self.state)]
        op = self.grover_ops[self.action]
        for _ in range(gsteps):
            circ.append(op, list(range(self.acts_reg_dim)))
        self.acts_circs[str(self.state)] = circ
        
    def _run_grover_bool(self):
    
        # Update state-action circuits based on evaluated steps
        flag = self.grover_steps_flag[str(self.state)]
        gsteps = self.grover_steps[str(self.state)][self.action]
        circ = self.acts_circs[str(self.state)]
        op = self.grover_ops[self.action]
        if not flag.any():
            for _ in range(gsteps):
                circ.append(op, list(range(self.acts_reg_dim)))
                
        if gsteps >= self.max_grover_steps and not flag.any():
            self.grover_steps_flag[str(self.state)][self.action] = True
        self.acts_circs[str(self.state)] = circ

    def _take_action(self):
        """
        Measures state-action circuit and chooses which action to take
        return: int, chosen action
        """
        circ = self.acts_circs[str(self.state)]
        action = self.acts_dim + 1
        while action >= self.acts_dim:
            circ = self.acts_circs[str(self.state)]
            circ_tomeasure = circ.copy()
            circ_tomeasure.measure_all()
            job = execute(circ_tomeasure, backend=self.SIM, shots=1)
            result = job.result()
            counts = result.get_counts()
            action = int((list(counts.keys()))[0], 2)
        return action

In [4]:
def train(env, pl1, pl2, hyperparams):
    traj_dict = {}
    stats = {"Pl1 wins": [], "Pl2 wins": [], "Draws": []}
    # set initial max_steps
    gamelen = hyperparams['game_length']
        
    for epoch in range(hyperparams['max_epochs']):
        if epoch % 10 == 0:
            print("Processing epoch {} ...".format(epoch))
        # reset env
        state = env.reset()
        # init list for traj
        traj = [state]
            
        if hyperparams['graphics']:
            env.render()
            
        for step in range(gamelen):
            print('\rTurn {0}/{1}'.format(step, gamelen))
            # pl1 goes first, then pl2
                
            for player in (pl1, pl2):
                player._new_state_check(state)
                player.state = state
                # Select action
                action = player._take_action()  #self._run_grover_bool()
                player.action = action
                # take action
                new_state, reward, done = env.step(action)
                player._new_state_check(new_state)
                player.state = state
                # print('REWARD: ', reward)
                # update statevals and grover steps
                player._update_statevals(reward, new_state)
                player.grover_steps[str(state)][action] = player._eval_grover_steps(reward, new_state)
                # amplify amplitudes with zio grover
                # player._run_grover()
                player._run_grover_bool()
                # render if curious
                if hyperparams['graphics']:
                    env.render()
                # save transition
                traj.append(new_state)
                state = new_state
            
        # measure and observe outcome
        final = env.collapse_board()
        print("Observed board state: ", final)
        winner = env.check_end(final)
        if winner == 1:
            stats["Pl1 wins"].append(epoch)
            pl1._new_state_check(state)
            pl1._update_statevals(100, state)
            pl2._new_state_check(state)
            pl2._update_statevals(-10, state)
            
        elif winner == 2:
            stats["Pl2 wins"].append(epoch)
            pl2._new_state_check(state)
            pl2._update_statevals(100, state)
            pl1._new_state_check(state)
            pl1._update_statevals(-10, state)
            
        else:
            stats["Draws"].append(epoch)
            pl1._new_state_check(state)
            pl1._update_statevals(-5, state)
            pl2._new_state_check(state)
            pl2._update_statevals(-5, state)
                
        traj_dict['epoch_{}'.format(epoch)] = traj
        
    # return trajectories
    return traj_dict, stats

In [5]:
board_dim = 3

env = QTicTacToeEnv(board_dim)

player_1 = GroverQuantumBoardLearner(env)
player_2 = GroverQuantumBoardLearner(env)

game_hyperparms = {'max_epochs': 1, 'game_length': 4, 'graphics': False}

player_hyperparms = {'k': 0.1, 'alpha': 0.05, 'gamma': 0.99}
player_1.set_hyperparams(player_hyperparms)
player_2.set_hyperparams(player_hyperparms)

game_trajectories, game_stats = train(env, player_1, player_2, game_hyperparms)
print(game_trajectories)

print(game_stats)
print(player_1.state_vals)
#print(player_1.grover_steps)

sadassadsadsadsadasd 000000000
sadassadsadsadsadasd 000000001
sadassadsadsadsadasd 000000010
sadassadsadsadsadasd 000000011
sadassadsadsadsadasd 000000100
sadassadsadsadsadasd 000000101
sadassadsadsadsadasd 000000110
sadassadsadsadsadasd 000000111
sadassadsadsadsadasd 000001000
sadassadsadsadsadasd 000001001
sadassadsadsadsadasd 000001010
sadassadsadsadsadasd 000001011
sadassadsadsadsadasd 000001100
sadassadsadsadsadasd 000001101
sadassadsadsadsadasd 000001110
sadassadsadsadsadasd 000001111
sadassadsadsadsadasd 000010000
sadassadsadsadsadasd 000010001
sadassadsadsadsadasd 000010010
sadassadsadsadsadasd 000010011
sadassadsadsadsadasd 000010100
sadassadsadsadsadasd 000010101
sadassadsadsadsadasd 000010110
sadassadsadsadsadasd 000010111
sadassadsadsadsadasd 000011000
sadassadsadsadsadasd 000011001
sadassadsadsadsadasd 000011010
sadassadsadsadsadasd 000011011
sadassadsadsadsadasd 000011100
sadassadsadsadsadasd 000011101
sadassadsadsadsadasd 000011110
sadassadsadsadsadasd 000011111
sadassad

Processing epoch 0 ...
Turn 0/4
Turn 1/4
Turn 2/4
Turn 3/4
        ┌───┐                               ░ ┌─┐                        
   q_0: ┤ H ├─────────────────■─────────────░─┤M├────────────────────────
        ├───┤                 │             ░ └╥┘┌─┐                     
   q_1: ┤ H ├─────────────────┼─────────────░──╫─┤M├─────────────────────
        ├───┤                 │  ┌───┐      ░  ║ └╥┘┌─┐                  
   q_2: ┤ H ├───────■─────────┼──┤ X ├──────░──╫──╫─┤M├──────────────────
        ├───┤┌───┐  │  ┌───┐┌─┴─┐└─┬─┘      ░  ║  ║ └╥┘┌─┐               
   q_3: ┤ H ├┤ X ├──┼──┤ X ├┤ X ├──┼────■───░──╫──╫──╫─┤M├───────────────
        ├───┤└─┬─┘  │  └─┬─┘├───┤  │    │   ░  ║  ║  ║ └╥┘┌─┐            
   q_4: ┤ H ├──┼────┼────┼──┤ X ├──┼────┼───░──╫──╫──╫──╫─┤M├────────────
        ├───┤  │  ┌─┴─┐  │  └─┬─┘  │    │   ░  ║  ║  ║  ║ └╥┘┌─┐         
   q_5: ┤ H ├──┼──┤ X ├──┼────┼────■────┼───░──╫──╫──╫──╫──╫─┤M├─────────
        ├───┤  │  └───┘  │    │       ┌─┴─┐ ░  ║  ║  