### Решать задачу Крестики-Нолики будем с использованием одного из табличных алгоритмов SARSA. 
#### Создадим вспомогательный класс

In [88]:
from pettingzoo.classic import tictactoe_v3

import numpy as np
import json
import pickle

In [87]:
class QFunctionClass:
    def __init__(self):
        self.qdict = {}
        self.action_n = 9

    def __getitem__(self, k):
        k = str(k)
        if k in self.qdict:
            return self.qdict[k]
        else:
            self.qdict[k] = np.zeros(self.action_n)
            return self.qdict[k]

    def __setitem__(self, k, v):
        k = str(k)
        self.qdict[k] = v

In [2]:
class RandomAgent:
    def __init__(self):
        pass

    def get_action(self, obsevation):
        mask = observation["action_mask"]
        action = env.action_space('player_2').sample(mask)
        return action

#### Создаем агента работающего по алгоритму SARSA

In [89]:
class SARSAAgent:
    def __init__(self, action_dim=9, side="X", alpha=0.5, episode_n=1000):
        self.qfunction = QFunctionClass()
        self.action_dim = action_dim
        self.side = side
        self.prev_state = ''
        self.prev_action = -1
        self.prev_reward = -1
        self.state = ''
        self.action = -1
        self.reward = -1
        self.episode_n = episode_n
        self.episode = 0.0
        self.epsilon = 1.0
        self.alpha = alpha
        self.gamma = 0.99

    def get_epsilon_greedy_action(self, q_values, epsilon, action_n, mask):
        policy = np.ones(action_n) * epsilon / action_n
        masked_q_values = [-1.0 if mask[i] == 0 else q_values[i] for i in range(len(q_values))]
        max_action = np.argmax(masked_q_values)
        policy[max_action] += 1 - epsilon
        masked_policy = policy * mask
        p = 1.0 - np.sum(masked_policy)
        masked_policy = masked_policy + p / (1 - p) * masked_policy
        return np.random.choice(np.arange(action_n), p=masked_policy)

    def fit(self, reward, done):
        self.prev_reward = self.reward
        self.reward = reward
        if done:
            self.qfunction[self.state][self.action] += self.alpha * (self.reward - self.gamma * self.qfunction[self.state][self.action])
        else:
            if self.prev_state != '':
                self.qfunction[self.prev_state][self.prev_action] +=\
                self.alpha * (self.reward + self.gamma * self.qfunction[self.state][self.action]\
                              - self.qfunction[self.prev_state][self.prev_action])
        self.episode += 1.0
        self.epsilon = max(1.0 - self.episode / self.episode_n, 1e-6)

    def get_action(self, obs):
        mask = observation["action_mask"]               
        obs = observation["observation"]
        flatten_obs = np.array(obs).reshape(-1)
        flatten_obs = [str(i) for i in flatten_obs]
        # закидываем в предыдущее значение
        self.prev_state = self.state
        self.state = ''.join(flatten_obs)
        self.prev_action = self.action
        self.action = self.get_epsilon_greedy_action(self.qfunction[self.state], self.epsilon, self.action_dim, mask)
        return self.action

    def clear_state(self):
        self.prev_state = ''
        self.prev_action = -1
        self.prev_reward = -1
        self.state = ''
        self.action = -1
        self.reward = -1

    def save_model_pickle(self, path='/home/artem/atari_games/sarsa.json'):
        with open(path, 'wb') as f:
            pickle.dump(self.qfunction.qdict, f)

    def load_model_pickle(self, path='/home/artem/atari_games/sarsa.json'):
        with open(path, 'rb') as f:
            self.qfunction.qdict = pickle.load(f)


#### Обучаем стратегию

In [90]:
def print_board(player, obs):
    if player == "player_1":
        obs = obs["observation"]
        x = obs[:,:,0]
        o = obs[:,:,1]
    else:
        obs = obs["observation"]
        x = obs[:,:,1]
        o = obs[:,:,0]

    print("-------")
    for i in range(3):
        print("|", end="")
        for j in range(3):
            if x[i][j] == 1:
                print("x|", end="")
            elif o[i][j] == 1:
                print("o|", end="")
            else:
                print(" |", end="")
        print("\n--------")
            

In [91]:
def learn_policy(env, Agents, episode_n):
    for i in range(episode_n):
        if i % 5_000 == 0:
            print(f'##################\niteration: {i}\n')
            
        Agents['player_1'].clear_state()
        Agents['player_2'].clear_state()
        env.reset()
        for agent_id in env.agent_iter():
            agent = Agents[agent_id]
            
            observation, reward, termination, truncation, info = env.last()
    
            if i % 5_000 == 0:
                print_board(agent_id, observation)
        
            if termination or truncation:
                action = None
                
            else:
                action = agent.get_action(observation)
                
            agent.fit(reward, termination or truncation)
            
            if i % 5_000 == 0:
                print(f"agent_id: {agent_id}, reward: {reward}, qfunction[prev_state]: {agent.qfunction[agent.prev_state]}\n qfunction[state]: {agent.qfunction[agent.state]}, epsilon={ np.round(agent.epsilon, 3)}")
                
            env.step(action)
    return Agents

In [92]:
env = tictactoe_v3.env()

episode_n = 60_000

agent1 = SARSAAgent(alpha=0.95, episode_n=episode_n)
agent2 = SARSAAgent(alpha=0.95, episode_n=episode_n)

Agents = {'player_1': agent1, 'player_2': agent2}

for i in range(episode_n):
    if i % 5_000 == 0:
        print(f'##################\niteration: {i}\n')
    if i == episode_n // 2:
        # смена игроков. Теперь начинает агент 2
        Agents['player_1'] = agent2
        Agents['player_2'] = agent1
        agent1.epsilon = 1.0
        agent2.epsilon = 1.0
        agent1.episode = 0
        agent2.episode = 0
        agent1.episode_n = episode_n // 2
        agent2.episode_n = episode_n // 2
        
    Agents['player_1'].clear_state()
    Agents['player_2'].clear_state()
    env.reset()
    for agent_id in env.agent_iter():
        agent = Agents[agent_id]
        
        observation, reward, termination, truncation, info = env.last()

        if i % 5_000 == 0:
            print_board(agent_id, observation)
    
        if termination or truncation:
            action = None
            
        else:
            action = agent.get_action(observation)
            
        agent.fit(reward, termination or truncation)
        
        if i % 5_000 == 0:
            print(f"agent_id: {agent_id}, reward: {reward}, qfunction[prev_state]: {agent.qfunction[agent.prev_state]}\n qfunction[state]: {agent.qfunction[agent.state]}, epsilon={ np.round(agent.epsilon, 3)}")
            
        env.step(action)

env.close()

##################
iteration: 0

-------
| | | |
--------
| | | |
--------
| | | |
--------
agent_id: player_1, reward: 0, qfunction[prev_state]: [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 qfunction[state]: [0. 0. 0. 0. 0. 0. 0. 0. 0.], epsilon=1.0
-------
| | | |
--------
| | | |
--------
|x| | |
--------
agent_id: player_2, reward: 0, qfunction[prev_state]: [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 qfunction[state]: [0. 0. 0. 0. 0. 0. 0. 0. 0.], epsilon=1.0
-------
| | | |
--------
|o| | |
--------
|x| | |
--------
agent_id: player_1, reward: 0, qfunction[prev_state]: [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 qfunction[state]: [0. 0. 0. 0. 0. 0. 0. 0. 0.], epsilon=1.0
-------
| | |x|
--------
|o| | |
--------
|x| | |
--------
agent_id: player_2, reward: 0, qfunction[prev_state]: [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 qfunction[state]: [0. 0. 0. 0. 0. 0. 0. 0. 0.], epsilon=1.0
-------
| | |x|
--------
|o| | |
--------
|x| |o|
--------
agent_id: player_1, reward: 0, qfunction[prev_state]: [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 qfunction[st

KeyboardInterrupt: 

#### Сохраняем модель 

In [83]:
agent1.save_model_pickle('good_strategy1')
agent2.save_model_pickle('good_strategy2')
# agent1.save_model_pickle()

#### Проверяем в игре против рандомной стратегии

In [86]:
from pettingzoo.classic import tictactoe_v3

env = tictactoe_v3.env()
env.reset(seed=45)

agent1 = SARSAAgent()
agent1.load_model_pickle("good_strategy2")
agent1.epsilon = 0.00001

#agent2 = RandomAgent()

agent2 = SARSAAgent()
agent2.load_model_pickle("good_strategy1")
agent2.epsilon = 0.00001

Agents = {'player_1': agent1, 'player_2': agent2}
games_n = 1000
wins = 0
losses = 0
for i in range(games_n):
    j = 0
    env.reset()
    for agent_id in env.agent_iter():
        agent = Agents[agent_id]
        observation, reward, termination, truncation, info = env.last()
        # print(reward)

        if reward == 1 and j % 2 == 0:
            wins += 1
        elif reward == -1 and j % 2 == 0:
            losses += 1
        j += 1
        if termination or truncation:
            action = None
        else:
            action = agent.get_action(observation)
    
        env.step(action)
        
print(f'player1 wins {wins} lose {losses} games from {games_n} games')       
env.close()

player1 wins 0 lose 0 games from 1000 games


## Задача решена