In [6]:
!pipreqs /home/nthuuser/下載/Andy/train.ipynb

/bin/sh: 1: pipreqs: not found


In [1]:
import numpy as np
import chess
import chess.svg
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from torchsummary import summary
from collections import deque
import random
import os
import datetime
from IPython.display import SVG, display
import ftplib
import hashlib
import json
import math
import multiprocessing as mp
from collections import defaultdict

In [2]:
pieces_order = 'KQRBNPkqrbnp'  # 12x8x8
castling_order = 'KQkq'
ind = {pieces_order[i]: i for i in range(12)}

def alg_to_coord(alg):
    rank = 8 - int(alg[1])        # 0-7
    file = ord(alg[0]) - ord('a')  # 0-7
    return rank, file


def coord_to_alg(coord):
    letter = chr(ord('a') + coord[1])
    number = str(8 - coord[0])
    return letter + number


def to_planes(fen):
    board_state = replace_tags_board(fen)
    pieces_both = np.zeros(shape=(12, 8, 8), dtype=np.float32)
    for rank in range(8):
        for file in range(8):
            v = board_state[rank * 8 + file]
            if v.isalpha():
                pieces_both[ind[v]][rank][file] = 1
    assert pieces_both.shape == (12, 8, 8)
    return pieces_both


def replace_tags_board(board_san):
    board_san = board_san.split(" ")[0]
    board_san = board_san.replace("2", "11")
    board_san = board_san.replace("3", "111")
    board_san = board_san.replace("4", "1111")
    board_san = board_san.replace("5", "11111")
    board_san = board_san.replace("6", "111111")
    board_san = board_san.replace("7", "1111111")
    board_san = board_san.replace("8", "11111111")
    return board_san.replace("/", "")


def is_black_turn(fen):
    return fen.split(" ")[1] == 'b'

def check_current_planes(realfen, planes):
    cur = planes[0:12]
    assert cur.shape == (12, 8, 8)
    fakefen = ["1"] * 64
    for i in range(12):
        for rank in range(8):
            for file in range(8):
                if cur[i][rank][file] == 1:
                    assert fakefen[rank * 8 + file] == '1'
                    fakefen[rank * 8 + file] = pieces_order[i]

    castling = planes[12:16]
    fiftymove = planes[16][0][0]
    ep = planes[17]

    castlingstring = ""
    for i in range(4):
        if castling[i][0][0] == 1:
            castlingstring += castling_order[i]

    if len(castlingstring) == 0:
        castlingstring = '-'

    epstr = "-"
    for rank in range(8):
        for file in range(8):
            if ep[rank][file] == 1:
                epstr = coord_to_alg((rank, file))

    # realfen = maybe_flip_fen(realfen, flip=is_black_turn(realfen))
    realparts = realfen.split(' ')
    assert realparts[1] == 'w'
    assert realparts[2] == castlingstring
    assert realparts[3] == epstr
    assert int(realparts[4]) == fiftymove
    # realparts[5] is the fifty-move clock, discard that
    return "".join(fakefen) == replace_tags_board(realfen)


def canon_input_planes(fen):
    """
    :param fen:
    :return : (18, 8, 8) representation of the game state
    """
    fen = maybe_flip_fen(fen, is_black_turn(fen))
    return all_input_planes(fen)


def all_input_planes(fen):
    current_aux_planes = aux_planes(fen)

    history_both = to_planes(fen)

    ret = np.vstack((history_both, current_aux_planes))
    assert ret.shape == (18, 8, 8)
    return ret


def maybe_flip_fen(fen, flip=False):
    if not flip:
        return fen
    foo = fen.split(' ')
    rows = foo[0].split('/')

    def swapcase(a):
        if a.isalpha():
            return a.lower() if a.isupper() else a.upper()
        return a

    def swapall(aa):
        return "".join([swapcase(a) for a in aa])
    return "/".join([swapall(row) for row in reversed(rows)]) \
        + " " + ('w' if foo[1] == 'b' else 'b') \
        + " " + "".join(sorted(swapall(foo[2]))) \
        + " " + foo[3] + " " + foo[4] + " " + foo[5]


def aux_planes(fen):
    foo = fen.split(' ')

    en_passant = np.zeros((8, 8), dtype=np.float32)
    if foo[3] != '-':
        eps = alg_to_coord(foo[3])
        en_passant[eps[0]][eps[1]] = 1

    fifty_move_count = int(foo[4])
    fifty_move = np.full((8, 8), fifty_move_count, dtype=np.float32)

    castling = foo[2]
    auxiliary_planes = [np.full((8, 8), int('K' in castling), dtype=np.float32),
                        np.full((8, 8), int('Q' in castling),
                                dtype=np.float32),
                        np.full((8, 8), int('k' in castling),
                                dtype=np.float32),
                        np.full((8, 8), int('q' in castling),
                                dtype=np.float32),
                        fifty_move,
                        en_passant]

    ret = np.asarray(auxiliary_planes, dtype=np.float32)
    assert ret.shape == (6, 8, 8)
    return ret

In [3]:
from glob import glob
from collections import deque
from concurrent.futures import ProcessPoolExecutor
# dataset = deque(), deque(), deque()
# data_dir = '/content/drive/MyDrive/Colab_Notebooks/rl_chess'
model_dir = 'models'
play_data_dir = 'stockfish_data'
play_data_filename_tmpl = "best_dataset_2_%s.json"
# "%s.json"

def find_pgn_files(directory, pattern='*.pgn'):
  dir_pattern = os.path.join(directory, pattern)
  files = list(sorted(glob(dir_pattern)))
  return files

def read_game_data_from_file(path):
  try:
    with open(path, "rt") as f:
        return json.load(f)
  except Exception as e:
    print(e)

def get_game_data_filenames(play_data_dir,play_data_filename_tmpl):
  pattern = os.path.join(play_data_dir, play_data_filename_tmpl % "*")
  files = list(sorted(glob(pattern)))
  return files

def testeval(fen, absolute=False) -> float:
    # somehow it doesn't know how to keep its queen
    piece_vals = {'K': 3, 'Q': 14, 'R': 5, 'B': 3.25, 'N': 3, 'P': 1}
    ans = 0.0
    tot = 0
    for c in fen.split(' ')[0]:
        if not c.isalpha():
            continue

        if c.isupper():
            ans += piece_vals[c]
            tot += piece_vals[c]
        else:
            ans -= piece_vals[c.upper()]
            tot += piece_vals[c.upper()]
    v = ans/tot
    if not absolute and is_black_turn(fen):
        v = -v
    assert abs(v) < 1
    return np.tanh(v * 3)  # arbitrary

def create_uci_labels():
    """
    Creates the labels for the universal chess interface into an array and returns them
    :return:
    """
    labels_array = []
    letters = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
    numbers = ['1', '2', '3', '4', '5', '6', '7', '8']
    promoted_to = ['q', 'r', 'b', 'n']

    for l1 in range(8):
        for n1 in range(8):
            destinations = [(t, n1) for t in range(8)] + \
                           [(l1, t) for t in range(8)] + \
                           [(l1 + t, n1 + t) for t in range(-7, 8)] + \
                           [(l1 + t, n1 - t) for t in range(-7, 8)] + \
                           [(l1 + a, n1 + b) for (a, b) in
                            [(-2, -1), (-1, -2), (-2, 1), (1, -2), (2, -1), (-1, 2), (2, 1), (1, 2)]]
            for (l2, n2) in destinations:
                if (l1, n1) != (l2, n2) and l2 in range(8) and n2 in range(8):
                    move = letters[l1] + numbers[n1] + \
                        letters[l2] + numbers[n2]
                    labels_array.append(move)
    for l1 in range(8):
        l = letters[l1]
        for p in promoted_to:
            labels_array.append(l + '2' + l + '1' + p)
            labels_array.append(l + '7' + l + '8' + p)
            if l1 > 0:
                l_l = letters[l1 - 1]
                labels_array.append(l + '2' + l_l + '1' + p)
                labels_array.append(l + '7' + l_l + '8' + p)
            if l1 < 7:
                l_r = letters[l1 + 1]
                labels_array.append(l + '2' + l_r + '1' + p)
                labels_array.append(l + '7' + l_r + '8' + p)
    return labels_array

def flipped_uci_labels():
    """
    Seems to somehow transform the labels used for describing the universal chess interface format, putting
    them into a returned list.
    :return:
    """
    def repl(x):
        return "".join([(str(9 - int(a)) if a.isdigit() else a) for a in x])

    return [repl(x) for x in create_uci_labels()]


def flip_policy(pol):
    """
    :param pol policy to flip:
    :return: the policy, flipped (for switching between black and white it seems)
    """
    return np.asarray([pol[ind] for ind in unflipped_index])


def convert_to_cheating_data(data):
    """
    :param data: format is SelfPlayWorker.buffer
    :return:
    """
    state_list = []
    policy_list = []
    value_list = []
    for state_fen, policy, value in data:

        state_planes = canon_input_planes(state_fen)
        policy = all_moves2index_dict[policy]
        temp = np.zeros((1968,))
        temp[policy] = 1
        policy = temp
        del temp
        if is_black_turn(state_fen):
            policy = flip_policy(policy)

        # move_number = int(state_fen.split(' ')[5])
        # # reduces the noise of the opening... plz train faster
        # value_certainty = min(10, move_number)/10
        # _value = value*value_certainty + testeval(state_fen, False)*(1-value_certainty)

        state_list.append(state_planes)
        policy_list.append(policy)
        value_list.append(value)

    return np.array(state_list, dtype=np.float32), np.array(policy_list, dtype=np.float32), np.array(value_list, dtype=np.float32)


def load_data_from_file(filename):
    data = read_game_data_from_file(filename)
    return convert_to_cheating_data(data)


labels = create_uci_labels()
n_labels = int(len(labels))
flipped_labels = flipped_uci_labels()
unflipped_index = [labels.index(x) for x in flipped_labels]
all_moves2index_dict = {move: i for i, move in enumerate(labels)}
int_to_move = {v: k for k, v in all_moves2index_dict.items()}

In [4]:
class ConvBlock(nn.Module):
    def __init__(self):
        super(ConvBlock, self).__init__()
        self.action_size = 8*8*73
        self.conv1 = nn.Conv2d(18, 256, 5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(256)

    def forward(self, s):
        s = s.view(-1, 18, 8, 8)  # batch_size x channels x board_x x board_y
        s = F.relu(self.bn1(self.conv1(s)))
        return s

class ResBlock(nn.Module):
    def __init__(self, inplanes=256, planes=256, stride=1):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = F.relu(out)
        return out

class OutBlock(nn.Module):
    def __init__(self):
        super(OutBlock, self).__init__()
        self.conv = nn.Conv2d(256, 4, kernel_size=1) # value head
        self.bn = nn.BatchNorm2d(4)
        self.fc1 = nn.Linear(8*8*4, 256)
        self.fc2 = nn.Linear(256, 1)

        self.conv1 = nn.Conv2d(256, 2, kernel_size=1) # policy head
        self.bn1 = nn.BatchNorm2d(2)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.fc = nn.Linear(8*8*2, 1968)

    def forward(self, s):
        v = F.relu(self.bn(self.conv(s))) # value head
        v = v.view(-1, 4*8*8)
        v = F.relu(self.fc1(v))
        v = F.tanh(self.fc2(v))

        p = F.relu(self.bn1(self.conv1(s))) # policy head
        p = p.view(-1, 2*8*8)
        p = self.fc(p)
        p = self.logsoftmax(p).exp()
        return p, v

class ChessNet(nn.Module):
    def __init__(self):
        super(ChessNet, self).__init__()
        self.conv = ConvBlock()
        for block in range(7):
            setattr(self, f"res_{block}", ResBlock())
        self.outblock = OutBlock()

    def forward(self, s):
        s = self.conv(s)
        for block in range(7):
            s = getattr(self, f"res_{block}")(s)
        return self.outblock(s)



In [5]:
def dataloader(state_ary, policy_ary, value_ary,batch_size = 128):
  states = []
  policies = []
  values = []
  n = state_ary.shape[0]
  for i in range(n//batch_size):
    states.append(torch.tensor(state_ary[i*batch_size:(i+1)*batch_size],dtype=torch.float32).to(device))
    policies.append(torch.tensor(policy_ary[i*batch_size:(i+1)*batch_size],dtype=torch.float32).to(device))
    values.append(torch.tensor(value_ary[i*batch_size:(i+1)*batch_size],dtype=torch.float32).to(device))
  if n%batch_size != 0:
    states.append(torch.tensor(state_ary[n//batch_size*batch_size:],dtype=torch.float32).to(device))
    policies.append(torch.tensor(policy_ary[n//batch_size*batch_size:],dtype=torch.float32).to(device))
    values.append(torch.tensor(value_ary[n//batch_size*batch_size:],dtype=torch.float32).to(device))
  return states,policies,values

In [6]:
class AlphaLoss(torch.nn.Module):
    def __init__(self):
        super(AlphaLoss, self).__init__()

    def forward(self, y_value, value, y_policy, policy):
        value_error = (value - y_value) ** 2
        policy_error = torch.sum((-policy * (1e-6 + y_policy.float()).float().log()), 1)
        total_error = value_error.squeeze()*100 + policy_error
        return total_error.mean()

In [7]:
class ChessAgent:
    def __init__(self, name, device, path=None):
        self.name = name
        self.device = device
        self.model = ChessNet().to(device)

        self.optimizer = optim.Adam(self.model.parameters(), lr=0.0001,weight_decay=1e-5)
        self.criterion = AlphaLoss()

        self.batch_size = 512
        # self.train_size = 10000
        self.tau_decay_rate = 0.99
        if path:
          self.load_model(model_dir,path)
          self.model.to(device)
        else:
          pass

    def load_model(self, dir, filename):
        """Load a model from disk."""
        if os.path.exists(os.path.join(dir,filename)):
            # save_data = torch.load(os.path.join(dir,filename))
            save_data = torch.load(os.path.join(dir,filename),map_location=device)
            self.model.load_state_dict(save_data['model_state'])
            print(f"Model loaded from {filename}")
            return True
        return False

    def save_model(self,dir, filename):
        """Save a model to disk."""
        save_data = {
            'model_state': self.model.state_dict(),
        }
        path = os.path.join(dir,filename)
        torch.save(save_data, path)
        print(f"Model saved to {path}")

    def choose_action(self, board,max_depth = 15,n_simulations=1000,mode = 'tree'):
        """Choose action by probability."""
        if mode == 'mcts':
            n_move = int(board.fen().split(' ')[-1])
            # move = mcts(board,self.model,n_simulations=n_simulations,temperature=(1/n_move)*5)
            move = mcts(board,self.model,n_simulations=n_simulations,temperature = 1.0)
            if move is None:
                move = random.choice(list(board.legal_moves))
            return move.uci()
        elif mode == 'tree':
            tree = MCTS(self.model, max_depth=max_depth)
            root = tree.search(board, n_simulations=n_simulations)
            move = root.best_action(temperature = 1.0)
            if move is None:
                move = random.choice(list(board.legal_moves))
            return move.uci()

    def apply_temperature(self, policy, turn):
        """
        Applies a random fluctuation to probability of choosing various actions
        :param policy: list of probabilities of taking each action
        :param turn: number of turns that have occurred in the game so far
        :return: policy, randomly perturbed based on the temperature. High temp = more perturbation. Low temp
            = less.
        """
        tau = np.power(self.tau_decay_rate, int(turn) + 1)
        if tau < 0.1:
            tau = 0
        if tau == 0:
            action = np.argmax(policy)
            ret = np.zeros(1968)
            ret[action] = 1.0
            return ret
        else:
            ret = np.power(policy, 1/tau)
            ret /= np.sum(ret)
            return ret

    def get_best_move(self, board):
        policy,value = self.model(torch.FloatTensor(np.array(canon_input_planes(board.fen()))).to(device))
        policy = policy.squeeze().cpu().detach().numpy()
        value = value.squeeze().cpu().detach().numpy()
        if is_black_turn(board.fen()):
            policy = flip_policy(policy)
        legal_moves = list(board.legal_moves)
        legal_moves_uci = [move.uci() for move in legal_moves]
        sorted_indices = np.argsort(policy)[::-1]
        for move_index in sorted_indices:
            move = int_to_move[move_index]
            if move in legal_moves_uci:
                return move
        return None
# agent = ChessAgent('test',device,'')

In [10]:
def train_from_data(agent):
  filenames = get_game_data_filenames(play_data_dir,play_data_filename_tmpl)
  filenames = filenames
  print(filenames)
  print(len(filenames))
  epochs = 50
  # agent = self.white_agent
  steps = 0
  batch_size = 2048
  print(f'Training agent {agent.name}')
  while len(filenames)>0:
    states,policies,values=([],[],[])
    for _ in range(10):
      if not filenames:
        break
      filename = filenames.pop()
      print(f'load data from {filename}')
      state_ary, policy_ary, value_ary = load_data_from_file(filename)

      temp = dataloader(state_ary, policy_ary, value_ary,batch_size)
      states,policies,values = states+temp[0],policies+temp[1],values+temp[2]
      del temp, state_ary, policy_ary, value_ary
    print(f'num_data {len(values)}')
    for epoch in range(epochs):
      running_loss=0
      for state,policy,value in zip(states,policies,values):
        # print(state.size(),policy.size(),value.size())

        predicted_policies, predicted_values = agent.model(state)
        # print(predicted_values[-1],value[-1])
        # print(predicted_policies.size(),predicted_values.size())

        loss = agent.criterion(predicted_values, value, predicted_policies, policy)

        agent.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(agent.model.parameters(), max_norm=1.0)
        agent.optimizer.step()
        running_loss += loss.item()
      if (epoch+1)%10==0:
          print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(values) :.4f}')
      steps+=1
    agent.save_model(model_dir,f'train_1216_{steps}.pt')

In [11]:
device = torch.device('cuda')
model_path = ''
agent = ChessAgent("white",device,model_path)

In [12]:
train_from_data(agent)

['stockfish_data/best_dataset_2_167.json', 'stockfish_data/best_dataset_2_168.json', 'stockfish_data/best_dataset_2_169.json', 'stockfish_data/best_dataset_2_170.json', 'stockfish_data/best_dataset_2_171.json', 'stockfish_data/best_dataset_2_172.json', 'stockfish_data/best_dataset_2_173.json', 'stockfish_data/best_dataset_2_174.json', 'stockfish_data/best_dataset_2_175.json', 'stockfish_data/best_dataset_2_176.json', 'stockfish_data/best_dataset_2_177.json', 'stockfish_data/best_dataset_2_178.json', 'stockfish_data/best_dataset_2_179.json', 'stockfish_data/best_dataset_2_180.json', 'stockfish_data/best_dataset_2_181.json', 'stockfish_data/best_dataset_2_182.json', 'stockfish_data/best_dataset_2_183.json', 'stockfish_data/best_dataset_2_184.json', 'stockfish_data/best_dataset_2_185.json', 'stockfish_data/best_dataset_2_186.json', 'stockfish_data/best_dataset_2_187.json', 'stockfish_data/best_dataset_2_188.json', 'stockfish_data/best_dataset_2_189.json', 'stockfish_data/best_dataset_2_19