In [1]:
from app.rl.alpha_nine.a9 import A9Model, convert_inputs
from gym_nine_mens_morris.envs.nine_mens_morris_env import NineMensMorrisEnv, Pix
import torch
from torch.nn import functional as F
import matplotlib.pyplot as plt
import numpy as np
import copy

tensor([[1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
torch.Size([1, 24])
torch.Size([1, 24])
torch.Size([1, 24])
tensor([[0.2483, 0.2038, 0.3065, 0.2415]], grad_fn=<SoftmaxBackward>)


In [2]:
# Run on the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Build the policy network, cast its parameters to float64 and move it to `device`.
# NOTE(review): the meaning of the constructor arguments (8, 1, 1, 1) is defined by A9Model,
# which is imported from app.rl.alpha_nine.a9 and not visible here — confirm before tuning.
model = A9Model(8, 1, 1, 1).double().to(device)

In [4]:
def subsample_legal_positions(probs, legal_pos, board_shape=(3, 2, 4)):
    """
    Pick the entries of a flat per-cell vector that belong to the given board positions.

    probs: 1-D tensor with one entry per board cell (24 for the default board shape)
    legal_pos: list of positions, each a sequence with one coordinate per axis of
        `board_shape`. Shape: (n, 3) for the default board
    board_shape: shape used to flatten a position into an index of `probs`
        (row-major, numpy's default for ravel_multi_index)

    Returns a 1-D tensor of length n whose i-th entry is the `probs` value at `legal_pos[i]`.
    """
    # tuple() lets positions arrive as lists as well; int() keeps the index list free of
    # numpy scalar types before it is handed to torch indexing.
    flattened_idx = [int(np.ravel_multi_index(tuple(pos), board_shape)) for pos in legal_pos]
    return probs[flattened_idx]


def sample_action(legal_actions, pos1, pos2, move, kill, is_phase_1, argmax=False):
    """
    Choose one legal (position, move, kill) action using the model's output scores.

    The action is picked in three stages: a position, then a move that is legal for that
    position, then a kill that is legal for that position *and* that move. Filtering the
    kill by the chosen move as well keeps the returned triple inside `legal_actions`.

    legal_actions: non-empty list of (pos, move, kill) triples; `move` and `kill` may be None
    pos1, pos2: per-cell position scores; `pos1` is used when `is_phase_1` is truthy,
        `pos2` otherwise
    move: 1-D tensor of move scores, indexed by the move values found in `legal_actions`
    kill: per-cell kill scores
    is_phase_1: selects which position head to use
    argmax: take the highest-scoring legal option at every stage instead of sampling

    Returns (pos, move, kill, prob): `pos` is a tuple, `move` / `kill` are None when no such
    component is legal, and `prob` is the mean of the scores of the components that were chosen.

    Raises ValueError when `legal_actions` is empty.
    """
    if len(legal_actions) == 0:
        raise ValueError('sample_action needs at least one legal action')

    def pick(scores):
        # Index into `scores`: greedy, or sampled with the scores used as weights.
        if argmax:
            return int(scores.argmax())
        return int(torch.multinomial(scores, 1).squeeze())

    # NOTE(review): the values collected here are the raw model outputs restricted to the
    # legal entries; they are not renormalised over the legal set, although multinomial
    # samples as if they were. Confirm this is what the loss is meant to see.
    probs = []

    # Sorted so that ties under argmax do not depend on set iteration order.
    legal_pos = sorted({tuple(action[0]) for action in legal_actions})  # [(3, 2, 4), ...]
    pos_scores = subsample_legal_positions(pos1 if is_phase_1 else pos2, legal_pos)
    pos_idx = pick(pos_scores)
    pos = legal_pos[pos_idx]  # (3, 2, 4)
    probs.append(pos_scores[pos_idx])

    actions_at_pos = [action for action in legal_actions if tuple(action[0]) == pos]

    # [0, 1, 2, 3]
    legal_moves = sorted({action[1] for action in actions_at_pos if action[1] is not None})
    if len(legal_moves) != 0:
        move_scores = move[legal_moves]
        move_idx = pick(move_scores)
        move_ = legal_moves[move_idx]
        probs.append(move_scores[move_idx])
    else:
        move_ = None

    # Only kills that go with the chosen move are candidates. When no move applies,
    # move_ is None and this matches the actions whose move is None.
    legal_kills = sorted({
        tuple(action[2])
        for action in actions_at_pos
        if action[1] == move_ and action[2] is not None
    })
    if len(legal_kills) != 0:
        kill_scores = subsample_legal_positions(kill, legal_kills)
        kill_idx = pick(kill_scores)
        kill_ = legal_kills[kill_idx]  # (3, 2, 4)
        probs.append(kill_scores[kill_idx])
    else:
        kill_ = None

    return pos, move_, kill_, torch.mean(torch.stack(probs))


def play(player_1, player_2, render=True):
    """
    Play one game on a fresh NineMensMorrisEnv and report who won.

    player_1: callable taking the env and returning (pos, move, kill); moves when it is white's turn
    player_2: same kind of callable; moves on every other turn
    render: draw the board before the first move and after every step

    Returns 1 when the env's final info reports white as the winner, 2 when it reports
    any other winner, and 0 when the final info carries no truthy 'won' entry.
    """
    env = NineMensMorrisEnv()
    env.reset()
    if render:
        env.render()

    info = {}
    finished = False
    while not finished:
        # White is always driven by player_1.
        if env.player == Pix.W:
            mover = player_1
        else:
            mover = player_2

        action_pos, move, kill_pos = mover(env)
        _, _, finished, info = env.step(action_pos, move, kill_pos)
        print(info)
        if render:
            env.render()

    winner = info.get('won')
    if not winner:
        return 0
    if winner == Pix.W.string:
        return 1
    return 2


def random_player(env, legal_actions=None):
    """
    Player that picks uniformly at random among the legal actions.

    env: environment to query for legal actions when none are passed in
    legal_actions: optional precomputed list of actions to choose from

    Prints the chosen action and returns it.
    """
    if legal_actions is None:
        legal_actions = env.get_legal_actions()

    draw = torch.randint(low=0, high=len(legal_actions), size=(1,))
    chosen = legal_actions[int(draw[0])]
    print(chosen)
    return chosen


class AIPlayer:
    """
    Player that asks a model for scores and greedily plays the best legal action.

    model: network called with a batch of converted boards; it must return four tensors
        (phase-1 position, phase-2 position, move, kill scores), each with a leading batch axis
    """

    def __init__(self, model):
        self.model = model

    def __call__(self, env, legal_actions=None):
        """
        Choose an action for the current player of `env`.

        env: environment providing `board`, `player`, `is_phase_1()` and, when
            `legal_actions` is not given, `get_legal_actions()`
        legal_actions: optional precomputed list of legal actions

        Returns (pos, move, kill) as produced by `sample_action` with argmax=True.
        """
        if legal_actions is None:
            legal_actions = env.get_legal_actions()

        xs = torch.stack([convert_inputs(env.board, env.player)]).long().to(device)

        # Inference runs in eval mode without gradients. The finally clause puts the model
        # back into train mode even when the forward pass raises, so a failed call during
        # training cannot leave the shared model stuck in eval mode.
        was_train = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                yh_pos_1, yh_pos_2, yh_move, yh_kill = self.model(xs)
        finally:
            if was_train:
                self.model.train()

        pos, move, kill, _ = sample_action(
            legal_actions, yh_pos_1[0], yh_pos_2[0], yh_move[0], yh_kill[0],
            env.is_phase_1(), argmax=True,
        )

        return pos, move, kill


def get_credits(t, gamma):
    """
    Per-step weights that decay geometrically going back in time from the last step.

    t: number of steps in the episode
    gamma: decay factor applied once per step away from the end

    Returns a float64 tensor of length t on `device`: [gamma**(t-1), ..., gamma, 1].
    """
    credits = []
    credit = 1.0
    for _ in range(t):
        credits.append(credit)
        credit *= gamma
    credits.reverse()
    # Build the tensor as float64 directly: letting torch infer the dtype yields float32,
    # and casting to double afterwards keeps the float32 rounding error.
    return torch.tensor(credits, dtype=torch.float64).to(device)


def get_returns(stats, gamma):
    """
    Discounted return for every step of one episode.

    stats: list of per-step records; the reward is read from index 1 of each record
    gamma: discount factor

    Returns a float64 tensor on `device` where entry i is
    reward[i] + gamma * reward[i+1] + gamma**2 * reward[i+2] + ...
    """
    returns = []
    running = 0.0
    # Walk the episode backwards so each return is built from the one after it.
    for step in reversed(stats):
        running = step[1] + gamma * running
        returns.append(running)
    returns.reverse()
    # float64 from the start: inferring the dtype would round every value to float32 first.
    return torch.tensor(returns, dtype=torch.float64).to(device)


def get_loss(stats, gamma=0.75):
    """
    Policy-gradient loss averaged over a batch of episodes.

    stats: one list per episode; each entry is [prob, reward], where `prob` is the scalar
        tensor returned by `sample_action` for that step
    gamma: discount factor handed to both `get_returns` and `get_credits`

    For every episode the per-step log-probabilities are weighted by the step's credit and
    discounted return, summed, and divided by the episode length. The result is the
    negated mean of these values over `len(stats)` episodes.

    Episodes with no recorded step contribute nothing to the sum but still count in the
    denominator. When every episode is empty the result is a plain float, not a tensor.
    """
    total = 0
    for episode in stats:
        if len(episode) == 0:
            # Nothing to stack for this episode; torch.stack([]) would raise.
            continue

        returns = get_returns(episode, gamma=gamma)
        credits = get_credits(len(episode), gamma=gamma)
        log_probs = torch.log(torch.stack([step[0] for step in episode]))

        total = total + torch.sum(log_probs * credits * returns) / len(episode)
    return -1 * total / len(stats)


def _tally_results(plays):
    """
    Count the outcomes in a list of `play` return codes.

    plays: list of ints as returned by `play` (0 = no winner, 1 = player 1 won, 2 = player 2 won)

    Returns (draws, player_1_wins, player_2_wins).
    """
    winners = torch.tensor(plays).float()
    return tuple(len(torch.nonzero(winners == outcome)) for outcome in (0, 1, 2))


def train(batch_size=4, epochs=100, max_steps=15, lr=1e-4):
    """
    Train the global `model` by self-play against snapshots of itself.

    batch_size: number of environments played in parallel per episode
    epochs: number of episodes (one optimiser step each)
    max_steps: maximum number of env steps per episode; each unfinished environment
        advances by one action (model or opponent) per step
    lr: Adam learning rate

    Every episode the model's sampled actions are recorded as [prob, reward] pairs and
    turned into a loss with `get_loss`. Every 5 episodes the model plays the current
    opponent once from each seat, a snapshot of the model is added to the opponent pool,
    and the opponent is replaced by a randomly chosen snapshot. Every 500 episodes the
    model is evaluated against `random_player`.
    """
    model.train()
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    envs = [NineMensMorrisEnv() for _ in range(batch_size)]
    losses = []
    opponent = AIPlayer(copy.deepcopy(model))
    prev_models = []

    # Seat assignment does not change between episodes: the upper half of the batch
    # plays white, the rest plays black.
    me_idx_bool = torch.arange(batch_size) > (batch_size - 1) / 2
    ai_pieces = [Pix.W if me_idx_bool[i] else Pix.B for i in range(batch_size)]

    for epoch in range(epochs):
        print(f'episode: {epoch}')

        # Before game starts
        stats = [[] for _ in range(batch_size)]
        is_dones = [False for _ in range(batch_size)]
        for env in envs:
            env.reset()

        # Monte Carlo loop
        for _ in range(max_steps):
            legal_actions = [env.get_legal_actions() for env in envs]

            xs = [convert_inputs(envs[i].board, envs[i].player) for i in range(batch_size)]
            xs = torch.stack(xs).long().to(device)

            # One forward pass for the whole batch; rows of environments where it is the
            # opponent's turn (or that are finished) are simply not used.
            yh_pos_1, yh_pos_2, yh_move, yh_kill = model(xs)

            for i in range(batch_size):
                if is_dones[i]:
                    continue

                # Is current player AI or other?
                if envs[i].player == ai_pieces[i]:
                    pos, move, kill, prob = sample_action(
                        legal_actions[i], yh_pos_1[i], yh_pos_2[i], yh_move[i], yh_kill[i],
                        envs[i].is_phase_1(),
                    )
                    _, reward, is_done, _ = envs[i].step(pos, move, kill)
                    stats[i].append([prob, reward])
                else:
                    action = opponent(envs[i], legal_actions[i])
                    _, _, is_done, _ = envs[i].step(*action)

                is_dones[i] = is_done

            if all(is_dones):
                break

        loss = get_loss(stats)

        optim.zero_grad()
        loss.backward()
        optim.step()
        losses.append(loss.item())

        print(loss.item())

        if (epoch + 1) % 5 == 0:
            as_first = [play(AIPlayer(model), opponent) for _ in range(1)]
            as_second = [play(opponent, AIPlayer(model)) for _ in range(1)]

            # `play` reports the winner by seat. When the model sits in seat 2 a
            # "player 1 won" result is a loss for the model, so the two counts are swapped
            # before the series are merged.
            draws_1, wins_1, loses_1 = _tally_results(as_first)
            draws_2, loses_2, wins_2 = _tally_results(as_second)
            draws, wins, loses = draws_1 + draws_2, wins_1 + wins_2, loses_1 + loses_2

            print(f'{epoch}: {np.mean(losses)}, plays: {draws, wins, loses}')

            losses = []
            # Keep the most recent snapshots, add the current model, then pick the next
            # opponent uniformly from the pool.
            prev_models = prev_models[-10:]
            prev_models.append(copy.deepcopy(model))
            opponent.model = prev_models[np.random.choice(len(prev_models), 1)[0]]

        if (epoch + 1) % 500 == 0:
            # Both triples are printed as (draws, player 1 wins, player 2 wins). The model
            # is player 1 in the first series and player 2 in the second.
            plays = [play(AIPlayer(model), random_player) for _ in range(100)]
            print("Against random guy: ", end='')
            print(*_tally_results(plays), end=', ')

            plays = [play(random_player, AIPlayer(model)) for _ in range(100)]
            print(*_tally_results(plays))

            play(random_player, AIPlayer(model), render=True)
            play(AIPlayer(model), random_player, render=True)

In [None]:
# Run the training loop on the global `model`; progress and evaluation games are printed as it goes.
train()


In [None]:
# 100 games with the model as player 1 against the random player.
# `play` returns 0 for no winner, 1 when player 1 wins and 2 when player 2 wins,
# so `wins` counts the model's wins here.
plays = []
for _ in range(100):
    plays.append(play(AIPlayer(model), random_player))
winners = torch.tensor(plays).float()

draws, wins, loses = (len(torch.nonzero(winners == outcome)) for outcome in (0, 1, 2))

draws, wins, loses

In [None]:
# 100 games with the random player as player 1 and the model as player 2.
# `play` returns 0 for no winner, 1 when player 1 wins and 2 when player 2 wins,
# so `wins` counts the random player's wins here and `loses` the model's.
plays = []
for _ in range(100):
    plays.append(play(random_player, AIPlayer(model)))
winners = torch.tensor(plays).float()

draws, wins, loses = (len(torch.nonzero(winners == outcome)) for outcome in (0, 1, 2))

draws, wins, loses

In [None]:
# Show one rendered game: the random player moves as player 1, the model as player 2.
play(random_player, AIPlayer(model), render=True)

In [5]:
# Show one rendered game between two random players; the cell output is the value `play` returns.
play(random_player, random_player, render=True)

Current Player: W
[9 9 0 0]

•-----•-----•
| •---•---• |
| | •-•-• | |
•-•-•   •-•-•
| | •-•-• | |
| •---•---• |
•-----•-----•

((0, 0, 1), None, None)
{'code': 0}
Current Player: B
[8 9 0 0]

•-----•-----W
| •---•---• |
| | •-•-• | |
•-•-•   •-•-•
| | •-•-• | |
| •---•---• |
•-----•-----•

((2, 1, 2), None, None)
{'code': 0}
Current Player: W
[8 8 0 0]

•-----•-----W
| •---•---• |
| | •-•-• | |
•-•-•   •-•-B
| | •-•-• | |
| •---•---• |
•-----•-----•

((1, 0, 0), None, None)
{'code': 0}
Current Player: B
[7 8 0 0]

•-----•-----W
| W---•---• |
| | •-•-• | |
•-•-•   •-•-B
| | •-•-• | |
| •---•---• |
•-----•-----•

((0, 1, 1), None, None)
{'code': 0}
Current Player: W
[7 7 0 0]

•-----B-----W
| W---•---• |
| | •-•-• | |
•-•-•   •-•-B
| | •-•-• | |
| •---•---• |
•-----•-----•

((0, 1, 3), None, None)
{'code': 0}
Current Player: B
[6 7 0 0]

•-----B-----W
| W---•---• |
| | •-•-• | |
•-•-•   •-•-B
| | •-•-• | |
| •---•---• |
•-----W-----•

((0, 0, 2), None, None)
{'code': 0}
Current Player: 

0