In [1]:
from functools import partial
from collections import defaultdict
import random
import numpy as np
import copy
from tqdm import tqdm

from tic_tac_toe import (TicTacToe,
                         get_actions,
                         check_n_win,
                         board_get_hash, 
                         board_print,
                        )

>   ## Часть первая: крестики-нолики при помощи Q-обучения
> В коде, прилагающемся к последней лекции про обучение с подкреплением, реализован Environment для крестиков-ноликов, в котором можно при инициализации указывать разные размеры доски и условия победы, а также функции для рисования, в том числе с указанием оценки различных действий. С этим окружением все задания и связаны.
>    1. Реализуйте обычное (табличное) Q-обучение. Обучите стратегии крестиков и ноликов для доски 3х3.
>    2. Попробуйте обучить стратегии крестиков и ноликов для доски 4х4 и/или 5х5.

In [2]:
# 3x3 board with n_win=3 (the classic tic-tac-toe configuration).
env = TicTacToe(n_rows=3, n_cols=3, n_win=3)
# Win test with n_win fixed to the value of this environment at creation time.
# QLearningAgent reads `check_win` as a module-level global, so it must be
# rebound whenever `env` is replaced by a board with a different n_win.
check_win = partial(check_n_win, n_win=env.n_win)

In [3]:
# Start a new game; the returned state unpacks into the board and the side to move.
state = env.reset()
board, cur_turn = state
# Show the starting position.
board_print(board)

╭───┬───┬───╮
│   │   │   │ 
├───┼───┼───┤
│   │   │   │ 
├───┼───┼───┤
│   │   │   │ 
╰───┴───┴───╯


In [4]:
def play_random(env):
    """Play one game in which every move is drawn uniformly at random.

    After each move a message is printed if the reward signals a win
    (``1`` for crosses, ``-1`` for noughts); once the game is over the
    final board is printed.

    Args:
        env: environment exposing ``reset()``, ``step(action)`` returning
            ``(state, reward, done)``, and a ``board`` attribute.
    """
    win_messages = ((1, "Крестики выиграли!"), (-1, "Нолики выиграли!"))
    state = env.reset()
    while True:
        moves = get_actions(state)
        pick = np.random.randint(len(moves))
        state, reward, done = env.step(moves[pick])
        for winning_reward, message in win_messages:
            if reward == winning_reward:
                print(message)
        if done:
            break
    board_print(env.board)
# Demo: play a single random game on the current environment.
play_random(env)

Крестики выиграли!
╭───┬───┬───╮
│ o │   │ x │ 
├───┼───┼───┤
│ o │   │ x │ 
├───┼───┼───┤
│   │   │ x │ 
╰───┴───┴───╯


In [5]:
def apply_action(state, action):
    """Return the state reached when the side to move marks one cell.

    Args:
        state: ``(board, cur_turn)`` pair; ``board`` must support ``.copy()``
            and two-index assignment.
        action: ``(x, y)`` index pair of the cell to mark.

    Returns:
        ``(new_board, -cur_turn)`` where ``new_board`` is a copy of the input
        board with ``cur_turn`` written at ``[x, y]``. The input board is
        left untouched.
    """
    board, cur_turn = state
    row, col = action
    next_board = board.copy()  # work on a copy so the caller's board is not mutated
    next_board[row, col] = cur_turn
    return next_board, -cur_turn

def state_action_key(state, action):
    """Return the Q-table key for taking ``action`` in ``state``.

    The key is the hash of the board produced by ``apply_action``; the
    side-to-move component of the resulting state is not part of the key.
    """
    resulting_state = apply_action(state, action)
    board_after_move = resulting_state[0]
    return board_get_hash(board_after_move)

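Агента будем обучать игре против самого себя. Ход противника (ноликов) рассчитывает сам агент: он выбирает ход, худший для крестиков, то есть тот, после которого лучшая Q-оценка ответного хода крестиков минимальна. Для ускорения обучения используем прямую проверку выигрыша для возможного хода.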

In [6]:
class QLearningAgent:
    """Tabular Q-learning agent for tic-tac-toe.

    Q-values live in ``self.Q`` and are keyed by ``state_action_key``, i.e. by
    the hash of the board reached after the move. The same table is used to
    pick the first player's moves (``act``) and to simulate the second
    player's replies (``opponent_act``).

    Relies on the module-level ``check_win`` callable for immediate-win
    detection.
    """

    def __init__(self, alpha=0.1, gamma=0.99):
        """Store the learning rate ``alpha`` and discount ``gamma``; start with an empty table."""
        self.alpha = alpha
        self.gamma = gamma
        # Missing keys read as 0 (and are inserted on first read).
        self.Q = defaultdict(int)

    def get_q(self, state, action):
        """Return the Q-value of ``action`` in ``state``.

        On the first lookup of a key, a move that immediately wins for the
        side to move (``state[1]``) is initialised to 1; any other unseen
        move reads as 0.
        """
        key = state_action_key(state, action)
        if key in self.Q:
            return self.Q[key]
        board_after_move = apply_action(state, action)[0]
        if check_win(board_after_move, state[1]):
            self.Q[key] = 1
        return self.Q[key]

    def set_q(self, state, action, q_new):
        """Move the stored Q-value towards ``q_new`` by a fraction ``alpha``.

        Returns:
            The error ``q_new - Q_old`` measured before the update.
        """
        key = state_action_key(state, action)
        td_error = q_new - self.Q[key]
        self.Q[key] = self.Q[key] + self.alpha * td_error
        return td_error

    def update(self, state, action, next_state, reward, done):
        """Apply one Q-learning update for the transition and return its error.

        The target is ``reward`` for a terminal transition, otherwise
        ``reward + gamma * max_a Q(next_state, a)``.
        """
        if done:
            target = reward
        else:
            best_next_q = max(
                self.get_q(next_state, next_action)
                for next_action in get_actions(next_state)
            )
            target = reward + self.gamma * best_next_q
        return self.set_q(state, action, target)

    def opponent_act(self, state_o):
        """Choose the second player's move using the first player's Q-table.

        Order of preference: the only legal move; a move that wins on the
        spot; otherwise a move minimising the best Q-value available to the
        first player afterwards (ties broken uniformly at random).
        """
        candidates = get_actions(state_o)
        if len(candidates) == 1:
            return candidates[0]

        # Positions the first player would face after each candidate reply.
        replies = [apply_action(state_o, move) for move in candidates]

        for move, (board, x_turn) in zip(candidates, replies):
            # -x_turn is the side that has just moved.
            if check_win(board, -x_turn):
                return move

        best_for_x = np.array([
            max(self.get_q(state_x, action_x) for action_x in get_actions(state_x))
            for state_x in replies
        ])
        worst_for_x = np.flatnonzero(best_for_x == best_for_x.min())
        return candidates[np.random.choice(worst_for_x)]

    def act(self, state):
        """Return a greedy action for ``state``, breaking ties uniformly at random."""
        actions = get_actions(state)
        q_values = np.array([self.get_q(state, action) for action in actions])
        best = np.flatnonzero(q_values == q_values.max())
        return actions[np.random.choice(best)]

In [7]:
def play(env, agent, opponent, first_action=None):
    """Play one full game: ``agent`` moves first, ``opponent`` replies.

    Args:
        env: environment exposing ``reset()`` and ``step(action)``; ``step``
            returns ``(state, reward, done)``.
        agent: object whose ``act(state)`` picks the first player's move.
        opponent: object whose ``opponent_act(state)`` picks the reply.
        first_action: optional move forced in the first round instead of
            asking ``agent``.

    Returns:
        ``(reward, step)``: the reward returned by the terminating
        ``env.step`` call and the 1-based round in which the game ended.
        ``None`` if the game is still running after 999 rounds.
    """
    state = env.reset()
    round_no = 0
    while round_no < 999:
        round_no += 1
        forced = round_no == 1 and first_action is not None
        move = first_action if forced else agent.act(state)
        state, reward, finished = env.step(move)
        if not finished:
            reply = opponent.opponent_act(state)
            state, reward, finished = env.step(reply)
        if finished:
            return reward, round_no
    return None

In [8]:
class RandomAgent:
    """Baseline player that picks uniformly among the moves from ``get_actions``.

    Both methods are static, so the class itself can be passed wherever an
    agent or an opponent is expected.
    """

    @staticmethod
    def act(state):
        """Return a uniformly random move for the first player."""
        available = get_actions(state)
        return random.choice(available)

    @staticmethod
    def opponent_act(state):
        """Return a uniformly random move for the second player."""
        available = get_actions(state)
        return random.choice(available)

In [9]:
def mean_play(env, agent1, agent2, n=1):
    """Play ``n`` games of ``agent1`` (moving first) against ``agent2``.

    Returns:
        For ``n == 1`` the raw ``(reward, steps)`` result of a single
        ``play`` call; otherwise ``(mean reward, mean steps)`` over ``n``
        games.
    """
    if n == 1:
        return play(env, agent1, agent2)

    outcomes = []
    for _ in range(n):
        outcomes.append(play(env, agent1, agent2))
    rewards, steps = zip(*outcomes)
    return np.mean(rewards), np.mean(steps)

In [10]:
def evaluate(env, agent1, agent2, n=1):
    """Score ``agent1`` against ``agent2`` from both seats.

    ``agent1`` first moves first (``n`` games), then moves second (``n``
    games).

    Returns:
        ``(reward_first, reward_second, steps)``: the mean reward with
        ``agent1`` moving first, the negated mean reward with ``agent1``
        moving second, and the average of the two mean game lengths.
    """
    reward_as_first, steps_as_first = mean_play(env, agent1, agent2, n)
    reward_as_second, steps_as_second = mean_play(env, agent2, agent1, n)
    mean_steps = (steps_as_first + steps_as_second) / 2
    return reward_as_first, -reward_as_second, mean_steps

In [11]:
# Fresh agent with the default hyperparameters (alpha=0.1, gamma=0.99).
agent = QLearningAgent()

In [12]:
def train(env, agent, n_episodes, n_evaluate, eps=0.1):
    """Train ``agent`` by self-play and print a report every ``n_evaluate`` episodes.

    In each episode the first player opens with a uniformly random move and
    then follows ``agent.act``; the second player follows
    ``agent.opponent_act`` except that, with probability ``eps``, it plays a
    uniformly random move. ``agent.update`` is called once per first-player
    move, after the reply is known (or immediately if the move ends the game).

    Every ``n_evaluate`` episodes the agent is evaluated against
    ``RandomAgent`` (100 games per seat) and against a snapshot of itself
    taken at the previous report (1 game per seat), the results are printed,
    and the final board of one greedy self-play game is shown.

    Args:
        env: environment exposing ``reset()``, ``step(action)`` returning
            ``(state, reward, done)``, and a ``board`` attribute.
        agent: learner exposing ``act``, ``opponent_act``, ``update`` and ``Q``.
        n_episodes: number of training games.
        n_evaluate: reporting period, in episodes.
        eps: probability of a random move by the second player.
    """
    prev_agent = copy.deepcopy(agent)
    err_sum = 0
    for i_episode in tqdm(range(1, n_episodes + 1)):
        state_x = env.reset()
        err_episode = 0
        for step in range(1, 100):
            # Random opening move so that every first move gets explored.
            if step == 1:
                action_x = random.choice(get_actions(state_x))
            else:
                action_x = agent.act(state_x)

            state_o, reward, done = env.step(action_x)
            if done:
                err = agent.update(state_x, action_x, state_o, reward, done)
                err_episode += abs(err)
                break

            # Epsilon-random reply, otherwise the reply computed by the agent itself.
            if eps > random.random():
                action_o = random.choice(get_actions(state_o))
            else:
                action_o = agent.opponent_act(state_o)
            next_state_x, reward, done = env.step(action_o)

            err = agent.update(state_x, action_x, next_state_x, reward, done)
            err_episode += abs(err)
            state_x = next_state_x
            if done:
                break
        # Mean absolute update error per round of this episode.
        err_sum += err_episode / step

        if i_episode % n_evaluate == 0:
            # Evaluate under a fixed `random` seed, then restore the previous
            # RNG state so that seeding for the report does not reset the
            # stream that drives exploration in the following episodes.
            # Note: this seeds only the `random` module, not numpy's RNG.
            rng_state = random.getstate()
            random.seed(0)
            try:
                x_reward_random, o_reward_random, steps_random = evaluate(env, agent, RandomAgent, n=100)
                x_reward_prev, o_reward_prev, steps_prev = evaluate(env, agent, prev_agent, n=1)
            finally:
                random.setstate(rng_state)
            print(f'episode = {i_episode}, q_mean_err = {err_sum / n_evaluate :.2f}, len(Q)={len(agent.Q)}')
            print(f'evaluate with random = {x_reward_random:.2f}, {o_reward_random:.2f}, {steps_random:.2f}')
            print(f'evaluate with prev = {x_reward_prev:.2f}, {o_reward_prev:.2f}, {steps_prev:.2f}')
            prev_agent = copy.deepcopy(agent)
            err_sum = 0
            # Show the final position of one greedy self-play game.
            play(env, agent, agent)
            board_print(env.board)
# Report every 1000 episodes; train for 10 reporting periods in total.
n_evaluate = 1000
train(env, agent, 10 * n_evaluate, n_evaluate)            

 11%|█         | 1058/10000 [00:03<00:36, 247.93it/s]

episode = 1000, q_mean_err = 0.20, len(Q)=2604
evaluate with random = 0.99, 0.77, 3.38
evaluate with prev = 1.00, 1.00, 3.50
╭───┬───┬───╮
│ o │ x │ x │ 
├───┼───┼───┤
│ x │ o │ o │ 
├───┼───┼───┤
│ x │ o │ x │ 
╰───┴───┴───╯


 20%|██        | 2042/10000 [00:06<00:33, 234.48it/s]

episode = 2000, q_mean_err = 0.10, len(Q)=2641
evaluate with random = 1.00, 0.66, 3.47
evaluate with prev = 0.00, 0.00, 5.00
╭───┬───┬───╮
│ x │ o │ x │ 
├───┼───┼───┤
│ o │ x │ x │ 
├───┼───┼───┤
│ o │ x │ o │ 
╰───┴───┴───╯


 31%|███       | 3060/10000 [00:09<00:29, 237.00it/s]

episode = 3000, q_mean_err = 0.10, len(Q)=2651
evaluate with random = 1.00, 0.72, 3.41
evaluate with prev = 0.00, 0.00, 5.00
╭───┬───┬───╮
│ x │ o │ x │ 
├───┼───┼───┤
│ o │ x │ x │ 
├───┼───┼───┤
│ o │ x │ o │ 
╰───┴───┴───╯


 40%|████      | 4049/10000 [00:12<00:25, 230.42it/s]

episode = 4000, q_mean_err = 0.09, len(Q)=2662
evaluate with random = 0.99, 0.66, 3.42
evaluate with prev = 0.00, 0.00, 5.00
╭───┬───┬───╮
│ x │ o │ x │ 
├───┼───┼───┤
│ o │ x │ x │ 
├───┼───┼───┤
│ o │ x │ o │ 
╰───┴───┴───╯


 50%|█████     | 5035/10000 [00:15<00:25, 195.44it/s]

episode = 5000, q_mean_err = 0.10, len(Q)=2680
evaluate with random = 0.99, 0.62, 3.47
evaluate with prev = 0.00, 0.00, 5.00
╭───┬───┬───╮
│ x │ o │ o │ 
├───┼───┼───┤
│ o │ x │ x │ 
├───┼───┼───┤
│ x │ x │ o │ 
╰───┴───┴───╯


 60%|██████    | 6038/10000 [00:18<00:17, 231.66it/s]

episode = 6000, q_mean_err = 0.08, len(Q)=2685
evaluate with random = 0.99, 0.71, 3.39
evaluate with prev = 0.00, 0.00, 5.00
╭───┬───┬───╮
│ x │ o │ o │ 
├───┼───┼───┤
│ o │ x │ x │ 
├───┼───┼───┤
│ x │ x │ o │ 
╰───┴───┴───╯


 70%|███████   | 7039/10000 [00:21<00:12, 235.53it/s]

episode = 7000, q_mean_err = 0.08, len(Q)=2686
evaluate with random = 1.00, 0.64, 3.42
evaluate with prev = 0.00, 0.00, 5.00
╭───┬───┬───╮
│ x │ o │ o │ 
├───┼───┼───┤
│ o │ x │ x │ 
├───┼───┼───┤
│ x │ x │ o │ 
╰───┴───┴───╯


 80%|████████  | 8045/10000 [00:24<00:08, 239.39it/s]

episode = 8000, q_mean_err = 0.09, len(Q)=2686
evaluate with random = 1.00, 0.78, 3.33
evaluate with prev = 0.00, 0.00, 5.00
╭───┬───┬───╮
│ x │ o │ x │ 
├───┼───┼───┤
│ o │ x │ x │ 
├───┼───┼───┤
│ o │ x │ o │ 
╰───┴───┴───╯


 91%|█████████ | 9051/10000 [00:27<00:03, 238.48it/s]

episode = 9000, q_mean_err = 0.09, len(Q)=2686
evaluate with random = 1.00, 0.77, 3.36
evaluate with prev = 0.00, 0.00, 5.00
╭───┬───┬───╮
│ x │ o │ x │ 
├───┼───┼───┤
│ o │ x │ x │ 
├───┼───┼───┤
│ o │ x │ o │ 
╰───┴───┴───╯


100%|██████████| 10000/10000 [00:30<00:00, 332.23it/s]

episode = 10000, q_mean_err = 0.08, len(Q)=2687
evaluate with random = 0.99, 0.79, 3.37
evaluate with prev = 0.00, 0.00, 5.00
╭───┬───┬───╮
│ x │ o │ o │ 
├───┼───┼───┤
│ o │ x │ x │ 
├───┼───┼───┤
│ x │ x │ o │ 
╰───┴───┴───╯





In [13]:
# Baseline for comparison: two random players against each other on the current board.
evaluate(env, RandomAgent, RandomAgent, n=100)

(0.39, -0.29, 4.21)

Для доски 3х3 всего может быть не более 3^9 = 19683 состояний. Мы обучились менее чем на 3000 состояниях. При этом агент с большим перевесом обыгрывает случайную стратегию, а сам с собой играет вничью.

Попробуем обучить агента для доски 4х4

In [14]:
# Switch to a 4x4 board; n_win stays 3.
env = TicTacToe(n_rows=4, n_cols=4, n_win=3)
# Rebind the global win test that QLearningAgent reads, so it matches the new environment.
check_win = partial(check_n_win, n_win=env.n_win)

# Separate agent for the 4x4 board; report every 10_000 episodes, 10 reporting periods in total.
agent4 = QLearningAgent()
n_evaluate = 10000
train(env, agent4, 10 * n_evaluate, n_evaluate)  

 10%|▉         | 9989/100000 [01:20<10:47, 139.08it/s]

episode = 10000, q_mean_err = 0.17, len(Q)=172081
evaluate with random = 1.00, 0.76, 3.35
evaluate with prev = 1.00, 1.00, 3.50


 10%|█         | 10018/100000 [01:21<38:01, 39.44it/s]

╭───┬───┬───┬───╮
│   │   │   │   │ 
├───┼───┼───┼───┤
│   │   │   │ o │ 
├───┼───┼───┼───┤
│ x │ x │ x │   │ 
├───┼───┼───┼───┤
│   │ o │   │   │ 
╰───┴───┴───┴───╯


 20%|█▉        | 19999/100000 [02:38<10:33, 126.25it/s]

episode = 20000, q_mean_err = 0.05, len(Q)=191469
evaluate with random = 1.00, 0.80, 3.42
evaluate with prev = 1.00, -1.00, 3.50


 20%|██        | 20012/100000 [02:40<53:07, 25.09it/s] 

╭───┬───┬───┬───╮
│   │ o │   │ o │ 
├───┼───┼───┼───┤
│   │ x │   │   │ 
├───┼───┼───┼───┤
│ x │ x │ o │   │ 
├───┼───┼───┼───┤
│   │ x │   │   │ 
╰───┴───┴───┴───╯


 30%|██▉       | 29998/100000 [03:58<09:29, 122.99it/s]

episode = 30000, q_mean_err = 0.02, len(Q)=197060
evaluate with random = 1.00, 0.84, 3.38
evaluate with prev = 1.00, -1.00, 4.00


 30%|███       | 30011/100000 [03:59<44:37, 26.14it/s] 

╭───┬───┬───┬───╮
│   │ x │   │   │ 
├───┼───┼───┼───┤
│   │   │ x │   │ 
├───┼───┼───┼───┤
│   │ x │ o │ x │ 
├───┼───┼───┼───┤
│ o │   │   │ o │ 
╰───┴───┴───┴───╯


 40%|███▉      | 39994/100000 [05:18<07:54, 126.45it/s]

episode = 40000, q_mean_err = 0.00, len(Q)=201448
evaluate with random = 1.00, 0.78, 3.42
evaluate with prev = 1.00, -1.00, 4.00


 40%|████      | 40019/100000 [05:20<29:46, 33.58it/s] 

╭───┬───┬───┬───╮
│   │ x │ x │ o │ 
├───┼───┼───┼───┤
│   │ x │   │ o │ 
├───┼───┼───┼───┤
│ o │ x │   │   │ 
├───┼───┼───┼───┤
│   │   │   │   │ 
╰───┴───┴───┴───╯


 50%|████▉     | 49986/100000 [06:40<06:39, 125.29it/s]

episode = 50000, q_mean_err = 0.00, len(Q)=204991
evaluate with random = 1.00, 0.82, 3.38
evaluate with prev = 1.00, -1.00, 4.00


 50%|█████     | 50014/100000 [06:42<22:39, 36.78it/s] 

╭───┬───┬───┬───╮
│   │ x │   │ o │ 
├───┼───┼───┼───┤
│ x │ x │ o │   │ 
├───┼───┼───┼───┤
│   │ x │   │   │ 
├───┼───┼───┼───┤
│ o │   │   │   │ 
╰───┴───┴───┴───╯


 60%|█████▉    | 59995/100000 [08:02<05:11, 128.41it/s]

episode = 60000, q_mean_err = 0.00, len(Q)=208627
evaluate with random = 1.00, 0.76, 3.46
evaluate with prev = 1.00, -1.00, 4.00


 60%|██████    | 60021/100000 [08:04<19:09, 34.78it/s] 

╭───┬───┬───┬───╮
│ o │   │ x │   │ 
├───┼───┼───┼───┤
│ x │ x │   │ o │ 
├───┼───┼───┼───┤
│ x │   │   │   │ 
├───┼───┼───┼───┤
│ o │   │   │   │ 
╰───┴───┴───┴───╯


 70%|██████▉   | 69999/100000 [09:23<03:38, 137.14it/s]

episode = 70000, q_mean_err = 0.00, len(Q)=211130
evaluate with random = 1.00, 0.82, 3.40
evaluate with prev = 1.00, -1.00, 4.00


 70%|███████   | 70013/100000 [09:24<16:24, 30.45it/s] 

╭───┬───┬───┬───╮
│   │ o │   │   │ 
├───┼───┼───┼───┤
│   │ o │ x │   │ 
├───┼───┼───┼───┤
│   │ x │   │   │ 
├───┼───┼───┼───┤
│ x │ x │ o │   │ 
╰───┴───┴───┴───╯


 80%|███████▉  | 79998/100000 [10:43<02:25, 137.62it/s]

episode = 80000, q_mean_err = 0.00, len(Q)=214745
evaluate with random = 1.00, 0.78, 3.44
evaluate with prev = 1.00, -1.00, 4.00


 80%|████████  | 80024/100000 [10:45<09:02, 36.79it/s] 

╭───┬───┬───┬───╮
│   │   │   │   │ 
├───┼───┼───┼───┤
│   │   │ x │   │ 
├───┼───┼───┼───┤
│   │ x │   │ o │ 
├───┼───┼───┼───┤
│ x │ x │ o │ o │ 
╰───┴───┴───┴───╯


 90%|████████▉ | 89989/100000 [12:04<01:13, 137.07it/s]

episode = 90000, q_mean_err = 0.00, len(Q)=217947
evaluate with random = 1.00, 0.70, 3.41
evaluate with prev = 1.00, -1.00, 4.00


 90%|█████████ | 90016/100000 [12:05<04:22, 37.99it/s] 

╭───┬───┬───┬───╮
│ o │   │   │   │ 
├───┼───┼───┼───┤
│   │ x │   │ o │ 
├───┼───┼───┼───┤
│   │ x │   │   │ 
├───┼───┼───┼───┤
│   │ x │ o │ x │ 
╰───┴───┴───┴───╯


100%|█████████▉| 99997/100000 [13:24<00:00, 136.87it/s]

episode = 100000, q_mean_err = 0.00, len(Q)=220828
evaluate with random = 1.00, 0.86, 3.46
evaluate with prev = 1.00, -1.00, 4.00


100%|██████████| 100000/100000 [13:25<00:00, 124.13it/s]

╭───┬───┬───┬───╮
│   │   │ o │   │ 
├───┼───┼───┼───┤
│   │   │ x │   │ 
├───┼───┼───┼───┤
│   │ x │   │   │ 
├───┼───┼───┼───┤
│ x │ x │ o │ o │ 
╰───┴───┴───┴───╯





In [15]:
# Baseline for comparison: two random players against each other on the current (4x4) board.
evaluate(env, RandomAgent, RandomAgent, n=100)

(0.24, -0.12, 5.145)

Для доски 4х4 всего может быть не более 3^16 = 43_046_721 состояний. Мы обучились менее чем на 300 тыс. состояний. В игре выигрывают крестики.