# どうぶつサッカーの環境

In [None]:
from enum import Enum, auto
import random
import numpy as np
import copy

BOARD_HEIGHT = 5
BOARD_WIDTH = 3


class PlayPos(Enum):
    """The two sides of the game."""
    FRONTPLAYER = 1
    BACKPLAYER = 2


def playpos_opponent(playpos):
    """Return the side opposite to *playpos*; anything else yields None."""
    if playpos == PlayPos.BACKPLAYER:
        return PlayPos.FRONTPLAYER
    if playpos == PlayPos.FRONTPLAYER:
        return PlayPos.BACKPLAYER
    return None


class Act():
    """One candidate action: a piece step, optionally followed by a kick.

    move_command: (row, col, identity) -- destination cell plus the label of
        the moving piece, e.g. (1, 1, "さ").
    kick_command: None for a plain step; otherwise the sequence of cells the
        ball visits, e.g. ((3, 1), (5, 1)), whose last entry is where it stops.
    """

    def __init__(self, move, kick):
        self.move_command = move
        self.kick_command = kick

    def is_same(self, aite):
        """Return True when *aite* (the other Act) has the same move and kick."""
        same_move = self.move_command == aite.move_command
        return bool(same_move and self.kick_command == aite.kick_command)


def list_to_tuple(_list):
    """Recursively convert *_list* and every nested list into tuples.

    Only ``list`` instances are converted; any other element (including a
    tuple that itself contains lists) is kept unchanged.
    """
    return tuple(
        list_to_tuple(item) if isinstance(item, list) else item
        for item in _list
    )


def printn(inp):
    """Print *inp* to stdout without a trailing newline."""
    print(inp, end="", flush=False)


class PieceID(Enum):
    """Kinds of object that can occupy a board cell."""
    SARU_ID = 1     # monkey
    RISU_ID = 2     # squirrel
    USAGI_ID = 3    # rabbit
    OYASARU_ID = 4  # "parent monkey": a saru after reaching the far row
    BALL_ID = 5


class Piece():
    """One object on the board: an animal piece or the ball.

    Attributes:
        spiece: the PieceID kind.
        power: owning side (a PlayPos); None for the ball.
        kick_to: (animals only) numpy array of (row, col) kick offsets as
            seen by FRONTPLAYER; Board.piece_can_kick negates them when
            BACKPLAYER is to move.
        identity: one-character label -- hiragana when owned by
            FRONTPLAYER, katakana otherwise, "ボ" for the ball.
    """

    def __init__(self, spiece, powe=None):
        self.spiece = spiece
        self.power = powe

        monkey_kicks = [(1, 1), (1, -1), (-1, 0), (2, 2), (2, -2), (-2, 0)]
        # (kind, kick offsets, FRONTPLAYER label, other-side label); the
        # offset order matters because it fixes the order of generated kicks.
        animals = (
            (PieceID.SARU_ID, monkey_kicks, "さ", "サ"),
            (PieceID.OYASARU_ID, monkey_kicks, "お", "オ"),
            (PieceID.RISU_ID,
             [(1, 0), (2, 0), (-1, 0), (-2, 0), (0, 1), (0, 2), (0, -1),
              (0, -2)],
             "り", "リ"),
            (PieceID.USAGI_ID,
             [(1, 0), (2, 0), (0, 1), (0, 2), (0, -1), (0, -2), (1, 1),
              (2, 2), (1, -1), (2, -2)],
             "う", "ウ"),
        )
        for kind, offsets, front_label, back_label in animals:
            if self.spiece == kind:
                # A fresh array per instance, as every piece owns its offsets.
                self.kick_to = np.array(offsets)
                if self.power == PlayPos.FRONTPLAYER:
                    self.identity = front_label
                else:
                    self.identity = back_label
                return
        if self.spiece == PieceID.BALL_ID:
            self.identity = "ボ"


class Board:
    """Animal-Soccer game state: the piece grid plus the side to move.

    ``cells[row][col]`` (BOARD_HEIGHT rows x BOARD_WIDTH columns) holds one
    of this board's own Piece instances or None.  After ``reset`` the
    FRONTPLAYER's pieces start on rows 0-1, the ball on row 2 and the
    BACKPLAYER's pieces on rows 3-4.  Kick paths may also end on row -1 or
    row 5 (the goals): the ball stopping on row -1 wins for BACKPLAYER, on
    row 5 for FRONTPLAYER (see ``action_parser``).
    """

    def __init__(self, tn=PlayPos.FRONTPLAYER, cell=None):
        """Create a board.

        Args:
            tn: side to move.  Only kept when ``cell`` is given -- a fresh
                board picks a random first player in ``reset``.
            cell: existing grid to adopt (``clone`` passes a deep copy).  The
                list is used as-is; its pieces are replaced by the equivalent
                Piece instances owned by this board.
        """
        self.turn = tn

        # Fresh board -> starting layout; adopted grid -> only re-bind pieces.
        reset_flag = cell is None
        if reset_flag:
            self.cells = [[None] * BOARD_WIDTH for _ in range(BOARD_HEIGHT)]
        else:
            self.cells = cell

        # One Piece instance per kind and side.  where_you compares cells with
        # == (Piece defines no __eq__, so that is identity), hence the grid
        # must only ever hold these objects.
        self.S_ball = Piece(PieceID.BALL_ID)
        self.S_oyasaru_f = Piece(PieceID.OYASARU_ID, PlayPos.FRONTPLAYER)
        self.S_saru_f = Piece(PieceID.SARU_ID, PlayPos.FRONTPLAYER)
        self.S_risu_f = Piece(PieceID.RISU_ID, PlayPos.FRONTPLAYER)
        self.S_usa_f = Piece(PieceID.USAGI_ID, PlayPos.FRONTPLAYER)
        self.front_piece = {
            "さ": self.S_saru_f,
            "り": self.S_risu_f,
            "う": self.S_usa_f,
            "お": self.S_oyasaru_f
        }

        self.S_oyasaru_s = Piece(PieceID.OYASARU_ID, PlayPos.BACKPLAYER)
        self.S_saru_s = Piece(PieceID.SARU_ID, PlayPos.BACKPLAYER)
        self.S_risu_s = Piece(PieceID.RISU_ID, PlayPos.BACKPLAYER)
        self.S_usa_s = Piece(PieceID.USAGI_ID, PlayPos.BACKPLAYER)
        self.back_piece = {
            "サ": self.S_saru_s,
            "リ": self.S_risu_s,
            "ウ": self.S_usa_s,
            "オ": self.S_oyasaru_s
        }
        self.reset(reset_flag)

    def reset(self, reset_flag=True):
        """Reset the game, or re-bind the pieces of an adopted grid.

        With ``reset_flag`` True: clear the grid, place the starting layout
        and pick the first player at random.  Otherwise keep the layout but
        replace every piece in ``cells`` by this board's piece of the same
        kind and side, so identity-based lookups work on a copied grid.
        """
        if reset_flag is True:
            for i in range(BOARD_HEIGHT):
                for j in range(BOARD_WIDTH):
                    self.cells[i][j] = None
            self.cells[1][1] = self.S_saru_f
            self.cells[0][0] = self.S_usa_f
            self.cells[0][2] = self.S_risu_f
            self.cells[2][1] = self.S_ball
            self.cells[3][1] = self.S_saru_s
            self.cells[4][0] = self.S_risu_s
            self.cells[4][2] = self.S_usa_s
            self.turn = (PlayPos.FRONTPLAYER if random.random() >= 0.5
                         else PlayPos.BACKPLAYER)
        else:
            for i in range(BOARD_HEIGHT):
                for j in range(BOARD_WIDTH):
                    obj = self.cells[i][j]
                    if obj is not None:
                        if obj.spiece == PieceID.SARU_ID:
                            if obj.power == PlayPos.FRONTPLAYER:
                                self.cells[i][j] = self.S_saru_f
                            else:
                                self.cells[i][j] = self.S_saru_s
                        elif obj.spiece == PieceID.USAGI_ID:
                            if obj.power == PlayPos.FRONTPLAYER:
                                self.cells[i][j] = self.S_usa_f
                            else:
                                self.cells[i][j] = self.S_usa_s
                        elif obj.spiece == PieceID.RISU_ID:
                            if obj.power == PlayPos.FRONTPLAYER:
                                self.cells[i][j] = self.S_risu_f
                            else:
                                self.cells[i][j] = self.S_risu_s
                        elif obj.spiece == PieceID.OYASARU_ID:
                            if obj.power == PlayPos.FRONTPLAYER:
                                self.cells[i][j] = self.S_oyasaru_f
                            else:
                                self.cells[i][j] = self.S_oyasaru_s
                        elif obj.spiece == PieceID.BALL_ID:
                            self.cells[i][j] = self.S_ball

    def clone(self):
        """Return an independent Board with the same layout and side to move."""
        return Board(self.turn, copy.deepcopy(self.cells))

    def display(self):
        """Print the grid, highest row first, with row and column numbers."""
        nums = BOARD_HEIGHT - 1
        for side in self.cells[::-1]:
            printn("{}|".format(nums))
            for cel in side:
                if cel is not None:
                    printn(cel.identity)
                else:
                    printn("　")
                printn("|")
            print()
            nums -= 1
        print(".  0  1  2")

    def where_you(self, piece):
        """Return ``piece``'s (row, col) as numpy ints, or () if it is not on the board."""
        arr = np.array(self.cells)
        pos = []
        for nd in np.where(arr == piece):
            if len(nd) != 0:
                pos.append(nd[0])
        return tuple(pos)

    def parse_board_cells(self):
        """Return the grid as identity labels seen from the side to move.

        For BACKPLAYER the board is rotated 180 degrees and the hiragana /
        katakana labels are swapped, so the side to move is always at the
        bottom with hiragana pieces.  Empty cells are None.
        """
        new_cells = []
        if self.turn == PlayPos.BACKPLAYER:
            piece_dict = {
                "サ": "さ",
                "リ": "り",
                "ウ": "う",
                "オ": "お",
                "さ": "サ",
                "り": "リ",
                "う": "ウ",
                "お": "オ",
                "ボ": "ボ"
            }
            for i in range(BOARD_HEIGHT):
                yokocell = self.cells[BOARD_HEIGHT - 1 - i]
                newyokocell = []
                for nakami in yokocell[::-1]:
                    if nakami is None:
                        newyokocell.append(None)
                    else:
                        newyokocell.append(piece_dict[nakami.identity])
                new_cells.append(newyokocell)
        else:
            for i in range(BOARD_HEIGHT):
                newyokocell = []
                for j in range(BOARD_WIDTH):
                    if self.cells[i][j] is None:
                        newyokocell.append(None)
                    else:
                        newyokocell.append(self.cells[i][j].identity)
                new_cells.append(newyokocell)
        return new_cells

    def tensor_state_parsed(self):
        """Encode ``parse_board_cells()`` as 9 one-hot planes of shape (9, 5, 3).

        Channels: own さ, り, う, お; opponent サ, リ, ウ, オ; ball ボ.
        """
        parsed_cells = self.parse_board_cells()
        tensored_board = np.empty((9, 5, 3), dtype=np.float32)
        piece_list = ["さ", "り", "う", "お", "サ", "リ", "ウ", "オ", "ボ"]
        for i in range(9):
            for y in range(BOARD_HEIGHT):
                for x in range(BOARD_WIDTH):
                    if parsed_cells[y][x] == piece_list[i]:
                        tensored_board[i][y][x] = 1
                    else:
                        tensored_board[i][y][x] = 0
        return tensored_board

    def piece_can_move(self, piece):
        """List the cells ``piece`` may step to and whether the ball is in reach.

        Returns:
            (legal_l, ball_legal): ``legal_l`` holds the empty (row, col)
            cells within reach -- 2 squares in every direction for an
            OYASARU, 1 otherwise; ``ball_legal`` is the ball's (row, col) if
            it is within reach, else None.  (None, None) when ``piece`` is
            not on the board.
        """
        my_place = self.where_you(piece)
        if len(my_place) == 0:
            return None, None
        reach = 2 if piece.spiece == PieceID.OYASARU_ID else 1
        legal_l = []
        ball_legal = None
        for dx in range(-reach, reach + 1):
            for dy in range(-reach, reach + 1):
                new_x = my_place[0] + dx
                new_y = my_place[1] + dy
                if (0 <= new_x <= BOARD_HEIGHT - 1
                        and 0 <= new_y <= BOARD_WIDTH - 1):
                    target = self.cells[new_x][new_y]
                    if target is None:
                        legal_l.append((new_x, new_y))
                    elif target.spiece == PieceID.BALL_ID:
                        ball_legal = (new_x, new_y)
        return legal_l, ball_legal

    def piece_can_kick(self, piece, fromhere, tempboard):
        """Enumerate kick paths for a ball kicked by ``piece`` from ``fromhere``.

        Each of ``piece.kick_to`` (negated when BACKPLAYER is to move) may
        send the ball into a goal row (-1 or 5), onto an empty cell, or to a
        teammate who relays it by recursively kicking with its own offsets;
        opponents block.  ``tempboard`` is a scratch grid: the value 1 marks
        the ball's origin and teammates already used as relays.  Relay marks
        are not undone, so each teammate relays at most once per search.

        Returns:
            list of paths; each path is a list of [row, col] stops ending
            where the ball comes to rest.
        """
        kick_to = piece.kick_to
        fromhere = np.array(fromhere)
        kick_l = []
        if self.turn == PlayPos.BACKPLAYER:
            kick_to = -1 * kick_to
        for kick in kick_to:
            dist = fromhere + kick
            if -1 <= dist[0] <= 5 and 0 <= dist[1] <= 2:
                if dist[0] == -1 or dist[0] == 5:
                    # Goal row: the ball leaves the board here.
                    kick_l.append([list(dist)])
                elif tempboard[dist[0]][dist[1]] is None:
                    # Empty cell: the ball stops here.
                    kick_l.append([list(dist)])
                elif tempboard[dist[0]][dist[1]] == 1:
                    # Already used in this search.
                    pass
                elif tempboard[dist[0]][dist[1]].power == self.turn:
                    # Teammate: mark it, then prefix its relay paths with
                    # this stop.
                    temppiece = tempboard[dist[0]][dist[1]]
                    tempboard[dist[0]][dist[1]] = 1
                    for oup in self.piece_can_kick(temppiece, tuple(dist),
                                                   tempboard):
                        kick_l.append([list(dist)] + oup)
        return kick_l

    def piece_legal_move(self, piece):
        """Return the legal Acts of ``piece``: plain steps first, then kicks.

        Steps become ``Act((row, col, identity), None)``.  If the ball is in
        reach the piece moves onto the ball's cell and each kick path becomes
        ``Act((ball_row, ball_col, identity), path_as_tuples)``.
        """
        acts = []
        legal_l, ball_legal = self.piece_can_move(piece)
        if legal_l is not None:
            acts = [Act(spot + (piece.identity, ), None) for spot in legal_l]

        if ball_legal is not None:
            # Scratch grid for the kick search.  piece_can_kick only
            # overwrites cells and never mutates a Piece, so copying the rows
            # suffices; a deepcopy would also clone every Piece and its array.
            tempcell = [row[:] for row in self.cells]
            kicker_place = self.where_you(piece)
            # The kicker steps onto the ball's cell, freeing its old square.
            tempcell[kicker_place[0]][kicker_place[1]] = None
            tempcell[ball_legal[0]][ball_legal[1]] = 1
            kicks = self.piece_can_kick(piece, ball_legal, tempcell)
            ball_legal = ball_legal + (piece.identity, )
            for kick in kicks:
                acts.append(Act(ball_legal, list_to_tuple(kick)))
        return acts

    def legal_moves(self):
        """Return every legal Act of the side to move, piece by piece."""
        piece_dict = self.front_piece if self.turn == PlayPos.FRONTPLAYER else self.back_piece
        acts = []
        for pie in piece_dict.values():
            acts.extend(self.piece_legal_move(pie))
        return acts

    def parse_legal_moves(self):
        """Return ``legal_moves()`` expressed in the side-to-move's frame.

        For BACKPLAYER every coordinate is rotated 180 degrees and piece
        labels become hiragana (commands become lists); for FRONTPLAYER the
        Acts are returned unchanged.
        """
        parsed_legalmoves = []
        if self.turn == PlayPos.BACKPLAYER:
            piece_dict = {"サ": "さ", "リ": "り", "ウ": "う", "オ": "お"}
            for action in self.legal_moves():
                new_move_command = [
                    BOARD_HEIGHT - 1 - action.move_command[0],
                    BOARD_WIDTH - 1 - action.move_command[1],
                    piece_dict[action.move_command[2]]
                ]
                new_kick_commands = None
                if action.kick_command is not None:
                    new_kick_commands = []
                    for kick_to in action.kick_command:
                        if kick_to is not None:
                            new_kick_to = [
                                BOARD_HEIGHT - 1 - kick_to[0],
                                BOARD_WIDTH - 1 - kick_to[1]
                            ]
                        else:
                            new_kick_to = None
                        new_kick_commands.append(new_kick_to)
                new_action = Act(new_move_command, new_kick_commands)
                parsed_legalmoves.append(new_action)
        else:
            parsed_legalmoves = self.legal_moves()
        return parsed_legalmoves

    # The action number is computed in the side-to-move's frame; the Act
    # stored as the value is left in board coordinates.
    def legalmoves_to_num_parsed(self):
        """Map encoded action numbers to the legal Acts of the side to move.

        number = from * 288 + to * 18 + kick, where ``from``/``to`` are cell
        indices row * 3 + col (0..14) and ``kick`` is the ball's final cell
        (0..14), 15 for no kick, 16 for the mover's own goal and 17 for the
        opponent's goal.  For BACKPLAYER cell indices are mirrored (14 - n).
        When several Acts share a number only the first is kept.  With no
        legal move the result is ``{4605: None}`` (pass).
        """
        legalmoves_num_act_dict = {}
        legalmoves = self.legal_moves()
        if not legalmoves:
            old_place_num = 15
            move_to_num = 15
            kick_to_num = 15
            action_num = old_place_num * 288 + move_to_num * 18 + kick_to_num
            legalmoves_num_act_dict[action_num] = None
            return legalmoves_num_act_dict
        piece_dict = self.front_piece if self.turn == PlayPos.FRONTPLAYER else self.back_piece
        for action in legalmoves:
            S_move_ready = piece_dict[action.move_command[2]]
            old_place = self.where_you(S_move_ready)
            old_place_num = old_place[0] * 3 + old_place[1]
            move_to = (action.move_command[0], action.move_command[1])
            move_to_num = move_to[0] * 3 + move_to[1]
            if action.kick_command is None:
                kick_to_num = 15
            else:
                kick_to = action.kick_command[-1]
                if kick_to[0] == -1:
                    kick_to_num = 16 if self.turn == PlayPos.FRONTPLAYER else 17
                elif kick_to[0] == 5:
                    kick_to_num = 17 if self.turn == PlayPos.FRONTPLAYER else 16
                else:
                    kick_to_num = kick_to[0] * 3 + kick_to[1]
            if self.turn == PlayPos.BACKPLAYER:
                old_place_num = 14 - old_place_num if 0 <= old_place_num <= 14 else old_place_num
                move_to_num = 14 - move_to_num if 0 <= move_to_num <= 14 else move_to_num
                kick_to_num = 14 - kick_to_num if 0 <= kick_to_num <= 14 else kick_to_num
            action_num = old_place_num * 288 + move_to_num * 18 + kick_to_num
            if action_num not in legalmoves_num_act_dict:
                legalmoves_num_act_dict[action_num] = action
        return legalmoves_num_act_dict

    def action_parser(self, action):
        """Validate and apply ``action`` for the side to move.

        ``None`` means pass: only the turn changes.  Returns
        ``(success, winner)``:
          * (False, None) if ``action`` is not among ``legal_moves()``;
            the board is untouched.
          * (True, side) when the kick ends in a goal; the board and turn
            are then left as they were.
          * (True, None) otherwise, after moving the piece (and the ball),
            turning a saru that reached the far row into an oyasaru, and
            passing the turn.
        """
        if action is None:
            self.turn = playpos_opponent(self.turn)
            return True, None
        legalmoves = self.legal_moves()
        if not any(legal.is_same(action) for legal in legalmoves):
            return False, None
        piece_dict = self.front_piece if self.turn == PlayPos.FRONTPLAYER else self.back_piece
        S_move_ready = piece_dict[action.move_command[2]]
        move_to = (action.move_command[0], action.move_command[1])

        # The move itself is carried out here, inside action_parser.
        if action.kick_command is None:
            old_place = self.where_you(S_move_ready)
            self.cells[old_place[0]][old_place[1]] = None
            self.cells[move_to[0]][move_to[1]] = S_move_ready
        else:
            last_stop = action.kick_command[-1]
            if last_stop[0] == -1:
                return True, PlayPos.BACKPLAYER
            elif last_stop[0] == 5:
                return True, PlayPos.FRONTPLAYER
            old_place = self.where_you(S_move_ready)
            self.cells[old_place[0]][old_place[1]] = None
            self.cells[move_to[0]][move_to[1]] = S_move_ready
            self.cells[last_stop[0]][last_stop[1]] = self.S_ball

        # Promotion: a saru on the opponent's back row becomes an oyasaru.
        Sarufpos = self.where_you(self.S_saru_f)
        if len(Sarufpos) != 0 and Sarufpos[0] == 4:
            self.cells[Sarufpos[0]][Sarufpos[1]] = self.S_oyasaru_f
        Saruspos = self.where_you(self.S_saru_s)
        if len(Saruspos) != 0 and Saruspos[0] == 0:
            self.cells[Saruspos[0]][Saruspos[1]] = self.S_oyasaru_s

        self.turn = playpos_opponent(self.turn)

        return True, None


class DQNenv:
    """Single-board environment used for DQN self-play training.

    Actions are the integers produced by ``Board.legalmoves_to_num_parsed``;
    observations are (9, 5, 3) float32 planes with the side to move at the
    bottom.
    """

    def __init__(self):
        self.Board = Board()
        self.Board.reset()
        self.legalmoves_num_act_dict = self.Board.legalmoves_to_num_parsed()

    def reset(self):
        """Start a new game and refresh the legal-action table."""
        self.Board.reset()
        self.legalmoves_num_act_dict = self.Board.legalmoves_to_num_parsed()

    # Encode board labels as a numpy array.
    # board cells (identity characters) -> np.ndarray
    # The bottom is always the side to move, the top the opponent.
    # ch0: own さ, ch1: own り, ch2: own う, ch3: own お,
    # ch4: opp サ, ch5: opp リ, ch6: opp ウ, ch7: opp オ, ch8: ball ボ
    def tensor_state(self, board_cells):
        """Return one-hot planes of shape (9, 5, 3) for a grid of labels."""
        tensored_board = np.empty((9, 5, 3), dtype=np.float32)
        piece_list = ["さ", "り", "う", "お", "サ", "リ", "ウ", "オ", "ボ"]
        for i in range(9):
            for y in range(BOARD_HEIGHT):
                for x in range(BOARD_WIDTH):
                    if board_cells[y][x] == piece_list[i]:
                        tensored_board[i][y][x] = 1
                    else:
                        tensored_board[i][y][x] = 0
        return tensored_board

    def tensor_state_parsed(self):
        """Return the current position as planes seen by the side to move."""
        # Board.parse_board_cells performs the 180-degree rotation and label
        # swap for BACKPLAYER that this method used to duplicate inline.
        return self.tensor_state(self.Board.parse_board_cells())

    def legalmoves(self):
        """Return the {action number: Act or None} table for the side to move."""
        return self.legalmoves_num_act_dict

    def step(self, tensored_action):
        """Play action number ``tensored_action`` for the side to move.

        Returns:
            (next_state, next_actions, reward, done, skip_turn).  ``reward``
            is from the mover's view: +1 if the move scores for the mover,
            -1 if it scores for the opponent, 0 otherwise.  On a finished
            game next_state/next_actions are None.  ``skip_turn`` is True
            when the new side to move can only pass.  An unknown or illegal
            action prints a message and returns five Nones, so the result
            can still be unpacked.
        """
        now_turn = self.Board.turn
        if len(self.legalmoves_num_act_dict) == 0:
            action = None
        elif tensored_action in self.legalmoves_num_act_dict:
            action = self.legalmoves_num_act_dict[tensored_action]
        else:
            print("Irritating input!")
            return None, None, None, None, None
        success, winner = self.Board.action_parser(action)
        if success is False:
            print("Irritating input!")
            return None, None, None, None, None
        if winner is not None:
            if now_turn == winner:
                reward = 1
            else:
                reward = -1
            next_state = None
            next_action = None
            done = True
            skip_turn = False
        else:
            self.legalmoves_num_act_dict = self.Board.legalmoves_to_num_parsed(
            )
            reward = 0
            next_state = self.tensor_state_parsed()
            next_action = list(self.legalmoves_num_act_dict.keys())
            done = False
            # With no legal move the table is {pass_number: None}, never empty,
            # so detect "only passing is possible" from the values.
            skip_turn = all(act is None
                            for act in self.legalmoves_num_act_dict.values())

        return next_state, next_action, reward, done, skip_turn


class BattleEnv:
    """Referee that plays one game between two agents.

    Each agent exposes ``action(board)`` returning an Act (or None to pass).
    """

    def __init__(self, frontman, backman):
        self.Board = Board()
        self.Board.reset()
        self.front = frontman
        self.back = backman

    def _play_turn(self, mover):
        """Ask ``mover`` until the board accepts its action; return the winner or None."""
        while True:
            if len(self.Board.legal_moves()) == 0:
                proposal = None
            else:
                proposal = mover.action(self.Board)
            accepted, winner = self.Board.action_parser(proposal)
            if accepted is True:
                return winner

    def progress(self):
        """Alternate turns until a goal is scored; announce and return the winner."""
        while True:
            turn = self.Board.turn
            if turn == PlayPos.FRONTPLAYER:
                mover = self.front
            elif turn == PlayPos.BACKPLAYER:
                mover = self.back
            winner = self._play_turn(mover)
            if winner is None:
                continue
            if winner == PlayPos.FRONTPLAYER:
                print("Front Player wins!")
            else:
                print("Back Player wins!")
            return winner

    def reset(self):
        """Start a fresh game on the same board object."""
        self.Board.reset()

### 必要なI/O
環境からの盤面inputは9x5x3であり、必ず下側のプレイヤーが現在の手番プレイヤーとして価値を計算する。

outputは4608通りの手(行動)それぞれの価値

# DQNの実装

ライブラリのインポートと環境の生成

In [None]:
import os
import datetime
import math
from collections import namedtuple
from itertools import count
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Environment shared by the training cells below.
env = DQNenv()

学習用のリプレイメモリ

In [None]:
######################################################################
# Replay Memory

# One experience record.  next_state / next_actions are None when the game
# ended with this move.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'next_actions', 'reward'))


class ReplayMemory(object):
    """Fixed-capacity ring buffer of Transitions with uniform sampling."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        # Slot the next push writes to; wraps around once the buffer is full.
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        still_growing = len(self.memory) < self.capacity
        if still_growing:
            # Reserve the slot first; it is filled just below.
            self.memory.append(None)
        slot = self.position
        self.memory[slot] = Transition(*args)
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

ニューラルネットワークの定義。10層の畳み込み層(各層にバッチ正規化とReLU)の後に全結合層を2層接続。出力の活性化関数はtanhとし、行動価値を-1～1の範囲で出力する。

In:9ch x 5 x 3

Out:16x16x18=4608通りの行動の価値(n // 288:移動元, (n%288) // 18:移動先, (n%288) % 18:蹴り先。移動元・移動先は0~14:盤面, 15:パス(何もしない)。蹴り先は0~14:盤面, 15:蹴らない・パス, 16:自分のゴール, 17:相手のゴール)

In [None]:
######################################################################
# DQN

# Number of channels of every convolution layer.
k = 192
# Width of the hidden fully connected layer (an earlier setting was 256).
fcl_units = 4608 * 4


class DQN(nn.Module):
    """Q-network: 10 x (conv3x3 -> BatchNorm -> ReLU), two linear layers, tanh.

    Input:  (N, 9, 5, 3) float planes (see ``DQNenv.tensor_state``).
    Output: (N, 4608) action values in [-1, 1], one per encoded action
            number (from * 288 + to * 18 + kick).

    ``channels`` / ``hidden`` default to the module-level ``k`` /
    ``fcl_units``.  Layers keep the names ``conv1``..``conv10``,
    ``bn1``..``bn10``, ``fcl1``, ``fcl2`` (registered in that order), so
    existing checkpoints still load.
    """

    IN_PLANES = 9
    N_CELLS = 5 * 3
    N_ACTIONS = 4608
    N_CONV = 10

    def __init__(self, channels=k, hidden=fcl_units):
        super(DQN, self).__init__()
        self.channels = channels
        in_ch = self.IN_PLANES
        for i in range(1, self.N_CONV + 1):
            # setattr registers the submodule under the original attribute name.
            setattr(self, f"conv{i}",
                    nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
            setattr(self, f"bn{i}", nn.BatchNorm2d(channels))
            in_ch = channels
        self.fcl1 = nn.Linear(channels * self.N_CELLS, hidden)
        self.fcl2 = nn.Linear(hidden, self.N_ACTIONS)

    def forward(self, x):
        for i in range(1, self.N_CONV + 1):
            conv = getattr(self, f"conv{i}")
            bn = getattr(self, f"bn{i}")
            x = F.relu(bn(conv(x)))
        x = F.relu(self.fcl1(x.view(-1, self.channels * self.N_CELLS)))
        x = self.fcl2(x)
        return x.tanh()

環境から盤面を受け取りtensorに変換する関数を以下に定義する。どこまで環境側に任せ、どこまでこの関数で変換するかは諸説ある。 ... **環境**

In [None]:
def get_state(env):
    """Return the current position as a (1, 9, 5, 3) float32 tensor on ``device``."""
    planes = env.tensor_state_parsed()
    # Add the batch axis; astype copies, so the tensor owns fresh memory.
    batch = np.expand_dims(planes, 0).astype(np.float32)
    return torch.from_numpy(batch).to(device)

訓練に使用するハイパーパラメータを設定し、ニューラルネットワーク、オプティマイザ、リプレイメモリを初期化。

また、𝜀グリーディー方策で手を選ぶ関数を定義。手は𝜀の確率で、合法手からランダムに選択。それ以外では、方策ネットワーク(policy_net)で行動価値が最大となる手を選択。εは漸減させる。

またここで環境に対する入力(=アクション)を決定する。0~4607の整数。

合法手リストは[?, ?,..., ?]とする。 ... **環境**

In [None]:
######################################################################
# Training

# Hyperparameters
BATCH_SIZE = 256            # transitions per optimize_model() call
GAMMA = 0.99                # discount factor
EPS_START = 0.9             # initial exploration rate
EPS_END = 0.05              # asymptotic exploration rate
EPS_DECAY = 2000            # decay constant of epsilon, in episodes
OPTIMIZE_PER_EPISODES = 16  # episodes between optimisation steps
TARGET_UPDATE = 4           # optimisation steps between target-net syncs

# Online network, and a copy that only changes through periodic syncs and
# supplies the bootstrap targets.
policy_net = DQN().to(device)
target_net = DQN().to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.RMSprop(policy_net.parameters(), lr=1e-5)

# Replay buffer capacity: 2**17 = 131072 transitions.
memory = ReplayMemory(2 ** 17)

def epsilon_greedy(state, legal_moves):
    """Pick an index into ``legal_moves`` with an epsilon-greedy policy.

    epsilon decays exponentially from EPS_START towards EPS_END with decay
    constant EPS_DECAY, measured in ``episodes_done``.  With probability
    1 - epsilon the legal action with the largest policy_net value is
    chosen, otherwise a uniformly random legal one.

    Returns:
        int index into ``legal_moves``.
    """
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * episodes_done / EPS_DECAY)

    if sample > eps_threshold:
        # Evaluate in eval mode: in train mode BatchNorm would normalise this
        # single position with its own statistics and fold them into the
        # running averages that are used at inference time.
        was_training = policy_net.training
        policy_net.eval()
        try:
            with torch.no_grad():
                q = policy_net(state)
        finally:
            policy_net.train(was_training)
        _, select = q[0, legal_moves].max(0)
        return select.item()
    return random.randrange(len(legal_moves))

def select_action(state, env):
    """Choose a legal action for the side to move.

    Returns:
        (action number for ``env.step``, the same number as a (1, 1) long
        tensor for the replay memory).
    """
    candidates = list(env.legalmoves())
    chosen = candidates[epsilon_greedy(state, candidates)]
    return chosen, torch.tensor([[chosen]], device=device, dtype=torch.long)

学習部分

In [None]:
######################################################################
# Training loop

losses = []  # loss of every optimisation step that actually ran
temp = 0     # number of optimize_model() calls, including skipped ones


def optimize_model():
    """Run one DQN update on a random minibatch from ``memory``.

    Does nothing until the memory holds BATCH_SIZE transitions.  The target
    of a non-terminal transition is
    ``reward + GAMMA * (-max over legal a' of target_net(s')[a'])``; the
    value is negated because s' is evaluated for the opponent, who moves
    next.  Terminal transitions use the reward alone.  The Huber loss is
    appended to ``losses``.
    """
    global temp
    temp += 1
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043):
    # a list of Transitions becomes one Transition of per-field tuples.
    batch = Transition(*zip(*transitions))

    # True where the game continued after this transition.
    non_final_flags = [s is not None for s in batch.next_state]
    non_final_mask = torch.tensor(non_final_flags, device=device,
                                  dtype=torch.bool)

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q(s_t, a_t) for the actions actually taken, from policy_net.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # V(s_{t+1}); stays 0 for terminal transitions.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    if any(non_final_flags):
        non_final_next_states = torch.cat([s for s in batch.next_state
                                           if s is not None])
        legal_lists = [a for a in batch.next_actions if a is not None]
        # Pad each legal-action list to the batch's longest one by repeating
        # its first entry (a duplicate cannot change the max).  A fixed width
        # would produce ragged rows for any list longer than that width.
        width = max(len(a) for a in legal_lists)
        padded = [a + [a[0]] * (width - len(a)) for a in legal_lists]
        non_final_next_actions = torch.tensor(padded, device=device,
                                              dtype=torch.long)
        # target_net is never trained: build no autograd graph for it.
        with torch.no_grad():
            target_q = target_net(non_final_next_states)
            # Max over legal actions only, negated (opponent's turn).
            next_state_values[non_final_mask] = -target_q.gather(
                1, non_final_next_actions).max(1)[0]
    # Compute the expected Q values
    expected_state_action_values = next_state_values * GAMMA + reward_batch

    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values,
                            expected_state_action_values.unsqueeze(1))

    losses.append(loss.item())

    # Optimize the model, clipping each gradient element to [-1, 1].
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

訓練ループ。𝜀グリーディー方策で手を選択し終局するまで対局を進めます。終局したらまた初期局面から対局を行います。

1手進めるごとに経験データをリプレイメモリに格納します。

エピソードが終了したタイミングで、OPTIMIZE_PER_EPISODES回に1回の間隔でパラメータの更新を行います。

指定のエピソード数に達したら、モデルを保存して終了します。

In [None]:
######################################################################
# main training loop

num_episodes = 10000
episodes_done = 0  # read by epsilon_greedy() to decay epsilon
pbar = tqdm(total=num_episodes)
for i_episode in range(num_episodes):
    # Initialize the environment and state
    env.reset()
    state = get_state(env)

    for t in count():
        # Select and perform an action
        move, action = select_action(state, env)
        next_state, next_actions, reward, done, skip_turn = env.step(move)

        # Float, so the reward batch has the same dtype as the Q-values.
        reward = torch.tensor([reward], device=device, dtype=torch.float32)

        # Observe new state
        if done:
            next_state = None
            next_actions = None
        else:
            next_state = get_state(env)

        # Store the transition in memory
        memory.push(state, action, next_state, next_actions, reward)

        if done:
            break

        # Move to the next state
        state = next_state

    episodes_done += 1
    pbar.update()

    if i_episode % OPTIMIZE_PER_EPISODES == OPTIMIZE_PER_EPISODES - 1:
        # One gradient step on policy_net every OPTIMIZE_PER_EPISODES episodes.
        optimize_model()

        # Copy policy_net's weights into target_net on every TARGET_UPDATE-th
        # optimisation step (starting with the first one).
        if i_episode // OPTIMIZE_PER_EPISODES % TARGET_UPDATE == 0:
            target_net.load_state_dict(policy_net.state_dict())
pbar.close()

# Saves target_net, which equals policy_net as of its most recent sync.
modelfile = '/content/drive/My Drive/Colab Notebooks/model.pt'
print('save {}'.format(modelfile))
torch.save({'state_dict': target_net.state_dict(), 'optimizer': optimizer.state_dict()}, modelfile)

print('Complete')

In [None]:
# Training-loss curve: one point per optimize_model() call that ran an update.
plt.plot(losses)

# 対局

### Deep Learning

In [None]:
class GreedyPlayer:
    """Player that always plays the legal move with the highest Q-value.

    The model is a ``DQN`` restored from a checkpoint dict containing a
    ``'state_dict'`` entry (the format written by the training cell). It is
    evaluated in eval mode without gradients.

    Args:
        device: torch device for the model and its input tensors.
        modelfile: checkpoint path to load. Defaults to the path this
            notebook originally hard-coded.
    """

    DEFAULT_MODELFILE = "/content/drive/My Drive/Colab Notebooks/model100000.pt"

    def __init__(self, device, modelfile=DEFAULT_MODELFILE):
        self.device = device
        self.model = DQN().to(device)
        # map_location remaps stored tensors onto `device`, so a checkpoint
        # saved on a GPU can also be loaded on a CPU-only runtime.
        checkpoint = torch.load(modelfile, map_location=device)
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.eval()
        # Reused batch-of-one input buffer. Board.tensor_state_parsed() must
        # return something assignable into a (9, 5, 3) float32 slot.
        self.features = np.empty((1, 9, 5, 3), np.float32)

    def action(self, Board):
        """Return the legal move whose Q-value is largest.

        ``Board.legalmoves_to_num_parsed()`` maps Q-output indices to the
        move objects returned to the caller. Only those indices are compared,
        and ties go to the first index in the dict's order.
        """
        with torch.no_grad():
            self.features[0] = Board.tensor_state_parsed()
            state = torch.from_numpy(self.features).to(self.device)
            q = self.model(state)
            # Restrict the Q-values to the legal moves.
            legalmoves_num_act_dict = Board.legalmoves_to_num_parsed()
            legal_moves = list(legalmoves_num_act_dict.keys())
            next_actions = torch.tensor([legal_moves],
                                        device=self.device,
                                        dtype=torch.long)
            legal_q = q.gather(1, next_actions)
            best = legal_q.argmax(dim=1).item()
            return legalmoves_num_act_dict[legal_moves[best]]

### ランダムプレイヤー

In [None]:
class RandomPlayer:
    """Baseline player that moves at random but avoids kicks ending on row -1.

    The i-th entry of ``Board.parse_legal_moves()`` is assumed to describe the
    i-th entry of ``Board.legal_moves()`` (same order). This is not checked.
    """

    def __init__(self):
        pass

    def action(self, Board):
        """Return one element of ``Board.legal_moves()``.

        * A move whose kick ends on row 5 (BOARD_HEIGHT, just past the board
          edge) is treated as winning and returned immediately.
        * Moves whose kick ends on row -1 are treated as losing and avoided.
        * One of the remaining moves is chosen uniformly at random. If every
          move is a losing kick, one of those is chosen at random instead of
          failing on an empty candidate list.

        Raises:
            ValueError: if there are no legal moves at all.
        """
        legalmoves = Board.legal_moves()
        legalmoves_p = Board.parse_legal_moves()
        not_lose = []
        for i, action in enumerate(legalmoves_p):
            kick = action.kick_command
            if kick is not None:
                last_row = kick[-1][0]  # row where the kicked ball stops
                if last_row == 5:
                    return legalmoves[i]
                if last_row == -1:
                    continue
            not_lose.append(i)
        # When every move loses, still play one instead of calling
        # random.randrange(0), which raises ValueError.
        candidates = not_lose or list(range(len(legalmoves)))
        return legalmoves[candidates[random.randrange(len(candidates))]]

### ヒューマン

In [None]:
class Human:
    """Console player: lists the legal moves and reads the chosen index.

    Args:
        input_fn: zero-argument callable that returns one line of user input.
            Defaults to the builtin ``input``. Any other line source can be
            plugged in instead.
    """

    def __init__(self, input_fn=input):
        self.input_fn = input_fn

    def action(self, Board):
        """Prompt until a valid move index is entered, then return that move.

        Only indices ``0 .. len(moves) - 1`` are accepted. Negative numbers
        are rejected instead of silently indexing from the end of the list.
        Non-integer or out-of-range input re-prompts. Errors raised by the
        input function itself, such as EOFError on closed stdin, propagate
        instead of looping forever.
        """
        legalmoves = Board.legal_moves()
        while True:
            print("合法手")
            for i, action in enumerate(legalmoves):
                print(f"{i}):{action.move_command} {action.kick_command}")
            tmp = self.input_fn()
            try:
                inp = int(tmp)
            except ValueError:
                continue
            if 0 <= inp < len(legalmoves):
                return legalmoves[inp]

### モンテカルロ

In [None]:
class MonteCarlo:
    """Flat Monte Carlo player.

    Each legal move is scored by the average result of ``n_trials`` random
    playouts (see ``trial``), and the first move with the highest average is
    played.

    Args:
        n_trials: playouts per candidate move (previously hard-coded as 50).

    Raises:
        ValueError: if ``n_trials`` is less than 1.
    """

    def __init__(self, n_trials=50):
        if n_trials < 1:
            raise ValueError("n_trials must be at least 1")
        self.n_trials = n_trials

    def policy(self, Board):
        """Playout policy: pick a random move that avoids kicks ending on row -1.

        A move whose kick ends on row 5 is returned immediately as winning.
        If every move ends on row -1, the first legal move is played, so a
        playout never fails on an empty candidate list.
        ``Board.legal_moves()`` and ``Board.parse_legal_moves()`` are assumed
        to list the same moves in the same order.
        """
        legalmoves = Board.legal_moves()
        legalmoves_p = Board.parse_legal_moves()
        not_lose = []
        for i, action in enumerate(legalmoves_p):
            kick = action.kick_command
            if kick is not None:
                last_row = kick[-1][0]  # row where the kicked ball stops
                if last_row == 5:
                    return legalmoves[i]
                if last_row == -1:
                    continue
            not_lose.append(i)
        if not_lose:
            return legalmoves[not_lose[random.randrange(len(not_lose))]]
        return legalmoves[0]

    def trial(self, Board, act):
        """Score ``act`` with one random playout on a clone of ``Board``.

        Returns:
            +10 / -10 if ``act`` itself ends the game in favour of / against
            the side to move in ``Board``. Otherwise +1 / -1 according to who
            wins when both sides continue with ``policy``. There is no move
            limit: the playout runs until ``action_parser`` reports a winner.
        """
        tempboard = Board.clone()
        myturn = tempboard.turn
        success, winner = tempboard.action_parser(act)
        if not success:
            print("act is not received")
        if winner is not None:
            return 10 if winner == myturn else -10
        while True:
            winner = self._play_random_move(tempboard)
            if winner is not None:
                return 1 if winner == myturn else -1

    def _play_random_move(self, tempboard):
        """Play one playout move for the side to move and return the winner.

        When there are no legal moves, ``None`` is passed to
        ``action_parser``. A rejected random move is retried with a fresh
        policy choice. A rejected ``None`` would be rejected forever, so it
        raises instead.
        """
        has_moves = len(tempboard.legal_moves()) != 0
        while True:
            action = self.policy(tempboard) if has_moves else None
            success, winner = tempboard.action_parser(action)
            if success:
                return winner
            if action is None:
                raise RuntimeError(
                    "playout stuck: no legal moves and action_parser(None) failed")
            print("playout move rejected by action_parser; retrying")

    def action(self, Board):
        """Return the legal move with the best average playout score.

        Ties go to the earliest move in ``Board.legal_moves()``.

        Raises:
            ValueError: if there are no legal moves.
        """
        legalmoves = Board.legal_moves()
        scores = [
            sum(self.trial(Board, act) for _ in range(self.n_trials)) / self.n_trials
            for act in legalmoves
        ]
        # max() keeps the first of equal maxima, matching the old scan order.
        best = max(range(len(scores)), key=scores.__getitem__)
        return legalmoves[best]


### ランダムプレイヤーと対戦

In [None]:
# Evaluate the trained model against RandomPlayer. Play batches of `n_games`
# games and repeat until more than `target_wins` of a batch are won. Wins are
# counted for PlayPos.FRONTPLAYER - presumably deepl, the first player passed
# to BattleEnv; confirm. With n_games = 100 the printed count is a percentage.
rand = RandomPlayer()
deepl = GreedyPlayer(device)

battle = BattleEnv(deepl, rand)

n_games = 100
target_wins = 80
while True:
    # Named `wins`, not `count`: rebinding `count` would shadow
    # itertools.count for every other cell in this notebook.
    wins = 0
    for _ in range(n_games):
        winner = battle.progress()
        if winner == PlayPos.FRONTPLAYER:
            wins += 1
        battle.reset()
    print("winrate:{}".format(wins), flush=True)
    if wins > target_wins:
        break

### 人間と対戦

In [None]:
# Interactive game: the trained model (GreedyPlayer) plays one game against a
# human entering move indices at the console. BattleEnv receives deepl first;
# presumably that makes it PlayPos.FRONTPLAYER - confirm.
human = Human()
deepl = GreedyPlayer(device)

battle = BattleEnv(deepl, human)
winner = battle.progress()

### モンテカルロと対戦

In [None]:
# Evaluate the trained model against the Monte Carlo player. Play batches of
# `n_games` games (a single game here, since each Monte Carlo move runs many
# playouts) until more than `target_wins` of a batch are won. Wins are counted
# for PlayPos.FRONTPLAYER - presumably deepl, the first player passed to
# BattleEnv; confirm.
monte = MonteCarlo()
deepl = GreedyPlayer(device)

battle = BattleEnv(deepl, monte)

n_games = 1
target_wins = 0
while True:
    # Named `wins`, not `count`: rebinding `count` would shadow
    # itertools.count for every other cell in this notebook.
    wins = 0
    for _ in range(n_games):
        winner = battle.progress()
        if winner == PlayPos.FRONTPLAYER:
            wins += 1
        battle.reset()
    print("winrate:{}".format(wins), flush=True)
    if wins > target_wins:
        break