### Imports and helper functions

In [64]:
IPYNB_PATH = './'
%cd IPYNB_PATH

[Errno 2] No such file or directory: 'IPYNB_PATH'
/home/smallfish/repo/MasterTheGameOfConnect4


In [65]:
from mcts_with_simulation import MCTS, NN3DConnect4, State
import os
import time
import torch
import json
from typing import List
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

def get_filenames(dir_path) -> List[str]:
    """Return the base names of every file found under ``dir_path``.

    The tree is traversed recursively with ``os.walk``. Only file names are
    returned (no directory component), in the order ``os.walk`` yields them.

    :param dir_path: Root directory to scan.
    :return: List of file names; empty when nothing is found.
    """
    return [name
            for _root, _subdirs, names in os.walk(dir_path)
            for name in names]


### Hyper-parameters and Constants

In [66]:
# NOTE(review): CONFIG_PATH is not read anywhere in the visible notebook code.
CONFIG_PATH = 'config.yaml'

CONFIG = {
    # Directories (relative to the working directory).
    'path': {
        'models_dir': 'models/',  # checkpoints named <5-digit number>.pt
        'not_trained_trajectories': 'not_trained_trajectories/',  # self-play JSON dumps
        'trained_trajectories': 'trained_trajectories/',  # NOTE(review): unused in visible code
        'logs_dir': 'logs/',  # TensorBoard log directory
    },
    # Arguments forwarded to NN3DConnect4.
    'model': {
        'channels': 16,
        'blocks': 4,
    },
    # Self-play (MCTS) settings.
    'self_play': {
        'temperature_drop': 30,  # MCTS gets not_greedy=True while hands <= this value
        'mcts_time_limit': 99999,  # passed as max_time_sec
        'mcts_max_simulation_cnt': 3,  # passed as max_simulation_cnt
    },
    # Training settings.
    'train': {
        'batch_size': 64,
        'frequency': 128,  # must be a multiple of batch_size (asserted by the users of this value)
        'reuse': 1,  # multiplier on the number of batches per training call
        'buf_size': 6400,  # maximum number of samples kept in the replay buffer
        'learning_rate': 0.08,
    },
}

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

def get_latest_model_path(return_number: bool = False):
    """Locate the checkpoint with the highest number in the models directory.

    Model files follow the rule ``<number, zero-padded to 5 digits>.pt``. Files
    that do not look like ``<digits>.pt`` (e.g. ``.gitkeep``, editor backups) are
    ignored instead of crashing the integer parse.

    :param return_number: When True, return ``(path, number)`` instead of just the path.
    :return: The path of the newest checkpoint, or ``None`` when no checkpoint
        exists. With ``return_number=True`` a ``(path, number)`` tuple, or
        ``(None, None)`` when no checkpoint exists.
    """
    models_dir = CONFIG['path']['models_dir']
    file_numbers = []
    for filename in get_filenames(models_dir):
        stem, extension = os.path.splitext(filename)
        # Only names made purely of decimal digits are valid checkpoint numbers.
        if extension == '.pt' and stem.isdecimal():
            file_numbers.append(int(stem))

    if not file_numbers:
        return (None, None) if return_number else None

    latest_number = max(file_numbers)
    latest_path = os.path.join(models_dir, '{:05d}'.format(latest_number) + '.pt')
    return (latest_path, latest_number) if return_number else latest_path


### Learner

In [67]:
class ReplayBuffer(Dataset):
    """Bounded first-in-first-out store of training samples, usable as a map-style ``Dataset``.

    Every sample is a tuple
    ``(input_tensor_3d, input_tensor_scalar, policy_target_tensor_2d, reward_tensor_1d)``.
    """

    def __init__(self, buf_size, frequency):
        """
        :param buf_size: Maximum number of samples kept; the oldest are dropped first.
        :param frequency: Stored on the instance; not read by any method of this class.
        """
        self.buf_size = buf_size
        self.frequency = frequency
        self.data = []

    def __len__(self):
        """Return the number of samples currently stored."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample tuple stored at ``index``."""
        return self.data[index]

    @staticmethod
    def _reward_tensor(winner, hand):
        """Build the value target for one recorded state.

        ``winner`` is 0 for a draw, 1 or 2 for a decided game. A state whose
        ``hand`` count is even is credited to winner 1, an odd one to winner 2:
        the reward is 1.0 when the state belongs to the winning side and 0.0
        otherwise (draws are always 0.0).

        The tensor is always float32. Plain Python ints would produce int64
        targets for decided games and float32 ones for draws, giving the loss
        inconsistent (and non-floating) target dtypes.

        :raises ValueError: If ``winner`` is not 0, 1 or 2.
        """
        if winner == 0:
            reward = 0.
        elif winner in (1, 2):
            even_hand = hand % 2 == 0
            reward = 1. if even_hand == (winner == 1) else 0.
        else:
            raise ValueError('Invalid winner: {}'.format(winner))
        return torch.tensor([reward], dtype=torch.float32)

    def add_trajectories(self, trajectories: List[dict]):
        """Convert recorded games into samples and append them to the buffer.

        :param trajectories: Dicts with the keys ``'boards'``, ``'properties'``,
            ``'hands'``, ``'distributions_2d'`` (parallel per-state lists) and
            ``'winner'``.
        :raises ValueError: If a trajectory with at least one state has an invalid winner.
        """
        for trajectory in trajectories:
            boards = trajectory['boards']
            properties = trajectory['properties']
            hands = trajectory['hands']
            distributions_2d = trajectory['distributions_2d']
            winner = trajectory['winner']

            # Get each state's complete information
            for idx in range(len(boards)):
                input_tensor_3d, input_tensor_scalar = MCTS.observation_tensors(
                    board=boards[idx], properties=properties[idx], hands=hands[idx])
                policy_target_tensor_2d = torch.tensor(distributions_2d[idx])
                reward_tensor_1d = self._reward_tensor(winner, hands[idx])
                self.data.append((input_tensor_3d, input_tensor_scalar,
                                  policy_target_tensor_2d, reward_tensor_1d))

            # Full: drop the oldest samples in one slice delete rather than
            # popping from the front of the list once per appended sample.
            overflow = len(self.data) - self.buf_size
            if overflow > 0:
                del self.data[:overflow]

In [68]:
class Learner:
    """Trains the policy/value network on samples drawn from a replay buffer.

    On construction the newest checkpoint in ``CONFIG['path']['models_dir']`` is
    loaded. When none exists, the freshly initialised network is saved so that
    a checkpoint is present on disk.
    """

    def __init__(self, replay_buffer):
        """
        :param replay_buffer: ``ReplayBuffer`` that ``train`` draws batches from.
        """
        # Find the latest pt file
        latest_model_path, latest_number = get_latest_model_path(return_number=True)
        # NOTE(review): a latest_number of 0 is falsy, so a lone 00000.pt resumes
        # at iteration 0 while any other checkpoint N resumes at N + 1 — confirm
        # that this asymmetry is intended.
        self.iteration = (latest_number + 1) if latest_number else 0
        self.model = NN3DConnect4(features_3d_in_channels=4,
                                  features_scalar_in_channels=4,
                                  channels=CONFIG['model']['channels'],
                                  blocks=CONFIG['model']['blocks'])
        if not latest_model_path:
            self.save_model()
        else:
            # map_location lets a checkpoint written on another device (e.g. a
            # GPU) be loaded on the device available here.
            self.model.load_state_dict(torch.load(latest_model_path, map_location=DEVICE))
        self.model.to(DEVICE)
        self.optimizer = torch.optim.SGD(params=self.model.parameters(),
                                         lr=CONFIG['train']['learning_rate'],
                                         momentum=0.9,
                                         weight_decay=0.0001,
                                         nesterov=True)
        self.replay_buffer: ReplayBuffer = replay_buffer
        self.summary_writer = SummaryWriter(log_dir=CONFIG['path']['logs_dir'], purge_step=self.iteration)

    def save_model(self):
        """Write the model's state dict to ``<models_dir>/<iteration, 5 digits>.pt``."""
        models_dir = CONFIG['path']['models_dir']
        # A fresh checkout may not contain the directory yet.
        os.makedirs(models_dir, exist_ok=True)
        filename = '{:05d}.pt'.format(self.iteration)
        torch.save(self.model.state_dict(), f=os.path.join(models_dir, filename))

    def train(self):
        """Run one training iteration, log it to TensorBoard and save a checkpoint.

        The iteration counter is incremented first. Shuffled batches are then
        drawn from the replay buffer until ``frequency * reuse // batch_size``
        optimizer steps were taken (or the buffer is exhausted). Scalars are
        logged only when that step count is reached, using the last batch's losses.
        """
        train_cfg = CONFIG['train']
        batch_size = train_cfg['batch_size']
        assert train_cfg['frequency'] % batch_size == 0
        steps_per_iteration = train_cfg['frequency'] * train_cfg['reuse'] // batch_size

        self.iteration += 1
        data_loader = DataLoader(dataset=self.replay_buffer,
                                 batch_size=batch_size,
                                 shuffle=True)
        self.model.train()
        for i, data in enumerate(data_loader):
            input_tensor_3d, input_tensor_scalar, policy_target, value_target = data
            # Everything fed to the model or compared with its output must live
            # on the model's device.
            input_tensor_3d = input_tensor_3d.float().to(DEVICE)
            input_tensor_scalar = input_tensor_scalar.float().to(DEVICE)
            policy_target = policy_target.to(DEVICE)
            # MSE needs a floating-point target.
            value_target = value_target.float().to(DEVICE)

            self.optimizer.zero_grad()
            policy, value = self.model(input_tensor_3d, input_tensor_scalar)

            # Cross-entropy between the target distribution and the predicted
            # policy; the 1e-8 keeps log() finite for zero probabilities.
            policy_loss = (-policy_target.view(-1, State.HEIGHT * State.WIDTH) * (1e-8 + policy).log()).sum(dim=1).mean()
            value_loss = torch.nn.MSELoss()(value, value_target)
            loss = policy_loss + value_loss

            print("Iteration {} loss: {}".format(self.iteration, loss.item()))

            loss.backward()
            self.optimizer.step()

            if (i + 1) >= steps_per_iteration:
                scalars = {
                    'train_param/lr': self.optimizer.param_groups[0]['lr'],
                    'train_param/batch_size': batch_size,
                    'train_param/frequency': train_cfg['frequency'],
                    'train_param/reuse': train_cfg['reuse'],
                    'train_param/buf_size': train_cfg['buf_size'],
                    'model/channels': CONFIG['model']['channels'],
                    'model/blocks': CONFIG['model']['blocks'],
                    'self_play/temperature_drop': CONFIG['self_play']['temperature_drop'],
                    'self_play/mcts_time_limit': CONFIG['self_play']['mcts_time_limit'],
                    'self_play/mcts_max_simulation_cnt': CONFIG['self_play']['mcts_max_simulation_cnt'],
                    'loss/total_loss': loss.item(),
                    'loss/policy_loss': policy_loss.item(),
                    'loss/value_loss': value_loss.item(),
                }
                for tag, scalar in scalars.items():
                    self.summary_writer.add_scalar(tag, scalar, self.iteration)
                break
        self.save_model()


### Actor
- A game contains 64 states.
- Number of games the actor plays per call: frequency / 64.
- Each training run feeds a total of "frequency" samples, split into batches. One such pass is called a step.
- If the buffer is full when a new game finishes, the oldest data is removed from the buffer. (It remains on the filesystem.)

In [69]:
# A game is played until the node's hand counter reaches this value, so every
# finished game contributes this many recorded states.
_STATES_PER_GAME = 64


def _load_latest_model():
    """Build the network and block until the newest checkpoint has been loaded into it."""
    model = NN3DConnect4(features_3d_in_channels=4,
                         features_scalar_in_channels=4,
                         channels=CONFIG['model']['channels'],
                         blocks=CONFIG['model']['blocks'])
    # Find the .pt file
    while True:
        latest_model_path = get_latest_model_path()
        if latest_model_path and os.path.exists(latest_model_path):
            # map_location lets a checkpoint saved on another device be loaded here.
            model.load_state_dict(torch.load(latest_model_path, map_location=DEVICE))
            return model
        # Fetch again after a while
        print('Find no latest pt model file...(wait for a while)')
        time.sleep(5)


def _self_play_game(model) -> dict:
    """Play one game with MCTS guided by ``model`` and return its trajectory dict."""
    cur_node = MCTS.get_init_node()
    boards = []
    properties = []
    hands = []
    distributions_2d = []
    while cur_node.hands < _STATES_PER_GAME:
        # Record
        hands.append(cur_node.hands)
        boards.append(cur_node.board.tolist())
        properties.append(cur_node.properties)
        # Self-play
        mcts = MCTS(root_node=cur_node,
                    eval_func='nn',
                    model=model,
                    device=DEVICE,
                    max_time_sec=CONFIG['self_play']['mcts_time_limit'],
                    max_simulation_cnt=CONFIG['self_play']['mcts_max_simulation_cnt'],
                    not_greedy=bool(cur_node.hands <= CONFIG['self_play']['temperature_drop']))
        move, _sim_cnt, time_used = mcts.run(return_simulation_cnt=True, return_time_used=True)
        distributions_2d.append(mcts.get_root_child_distribution_2d(normalize=True).tolist())
        print('{}({:.2f}sec)'.format(cur_node.hands, time_used), end=' ')
        cur_node = cur_node.get_node_after_playing(move=move)

    # Terminal: check who win. 0 means draw; 1 means black win; 2 means white win
    winner = 0
    if cur_node.properties[0] > cur_node.properties[1]:
        winner = 1
    elif cur_node.properties[0] < cur_node.properties[1]:
        winner = 2
    return {'boards': boards,
            'properties': properties,
            'hands': hands,
            'distributions_2d': distributions_2d,
            'winner': winner}


def _save_trajectory(trajectory: dict) -> str:
    """Dump ``trajectory`` as JSON into the not-trained directory and return the path."""
    out_dir = CONFIG['path']['not_trained_trajectories']
    os.makedirs(out_dir, exist_ok=True)
    # Use current time to be part of the filename. The timestamp only has
    # second resolution, so append a counter when the name is already taken
    # instead of overwriting an earlier game.
    stem = datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
    output_path = os.path.join(out_dir, stem + '.json')
    attempt = 0
    while os.path.exists(output_path):
        attempt += 1
        output_path = os.path.join(out_dir, '{}_{}.json'.format(stem, attempt))
    with open(output_path, 'w') as file:
        json.dump(trajectory, file)
    return output_path


def actor() -> List[dict]:
    """Generate self-play games with the newest model.

    Waits until a checkpoint is available, then plays enough games to produce
    ``CONFIG['train']['frequency']`` states (``frequency // 64`` games). The
    previous code divided by the batch size, which matched only because the
    batch size is also 64. Each trajectory is written to disk as JSON.

    :return: Trajectories (List of trajectory dicts)
    """
    assert CONFIG['train']['frequency'] % CONFIG['train']['batch_size'] == 0
    model = _load_latest_model()
    trajectories = []
    for _ in range(CONFIG['train']['frequency'] // _STATES_PER_GAME):
        trajectory = _self_play_game(model)
        _save_trajectory(trajectory)
        trajectories.append(trajectory)
    return trajectories


### Main Controller

In [70]:
# Alternate forever between generating self-play games and training on them.
# The loop only ends when the kernel is interrupted.
learner = Learner(replay_buffer=ReplayBuffer(buf_size=CONFIG['train']['buf_size'], frequency=CONFIG['train']['frequency']))

try:
    while True:
        # Self-play with the newest checkpoint and feed the games into the buffer
        self_play_trajectories = actor()
        learner.replay_buffer.add_trajectories(trajectories=self_play_trajectories)
        # One training iteration (saves a new checkpoint)
        learner.train()
finally:
    # Flush pending TensorBoard events even when the loop is interrupted.
    learner.summary_writer.close()



0(0.21sec) 1(0.22sec) 2(0.13sec) 3(0.14sec) 4(0.04sec) 5(0.06sec) 6(0.18sec) 7(0.12sec) 8(0.13sec) 9(0.07sec) 10(0.07sec) 11(0.05sec) 12(0.10sec) 13(0.04sec) 14(0.03sec) 15(0.03sec) 16(0.04sec) 17(0.03sec) 18(0.03sec) 19(0.03sec) 20(0.06sec) 21(0.78sec) 22(0.11sec) 23(0.10sec) 24(0.05sec) 25(0.03sec) 26(0.03sec) 27(0.03sec) 28(0.03sec) 29(0.03sec) 30(0.06sec) 31(0.17sec) 32(0.12sec) 33(0.12sec) 34(0.05sec) 35(0.05sec) 36(0.03sec) 37(0.13sec) 38(0.13sec) 39(0.09sec) 40(0.14sec) 41(0.10sec) 42(0.04sec) 43(0.11sec) 44(0.10sec) 45(0.07sec) 46(0.08sec) 47(0.09sec) 48(0.02sec) 49(0.03sec) 50(0.03sec) 51(0.03sec) 52(0.04sec) 53(0.02sec) 54(0.02sec) 55(0.04sec) 56(0.03sec) 57(0.02sec) 58(0.02sec) 59(0.02sec) 60(0.02sec) 61(0.02sec) 62(0.02sec) 63(0.02sec) 0(0.03sec) 1(0.02sec) 2(0.10sec) 3(0.06sec) 4(0.05sec) 5(0.05sec) 6(0.03sec) 7(0.05sec) 8(0.17sec) 9(0.03sec) 10(0.02sec) 11(0.16sec) 12(0.12sec) 13(0.06sec) 14(0.06sec) 15(0.04sec) 16(0.05sec) 17(0.15sec) 18(0.03sec) 19(0.02sec) 20(0.03sec) 

KeyboardInterrupt: 

In [None]:
# NOTE(review): leftover scratch arithmetic (evaluates to 2304); its purpose is
# not evident from this notebook — confirm or remove.
6 * 6 * 64