In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import random
from mcts import MCT, Node
from net import PolicyValueNet
from src_net import init_state
from collections import deque

In [2]:
# Side length of the square board. It sizes the dense policy grid built in
# preprocess_data, is passed to PolicyValueNet and init_state, and is embedded
# in the file name used when the trained parameters are saved.
CHESSBOARD_SIZE = 6

In [3]:
def preprocess_data(data, board_size=None):
    """Build dense policy targets from self-play samples and augment by symmetry.

    Every sample ``(state, (actions, probs), v)`` has its sparse
    ``(actions, probs)`` pair scattered into a ``board_size x board_size``
    array. The sample is then expanded into the 8 symmetries of the square:
    1 to 4 quarter-turn rotations of the last two axes (4 turns is the
    identity), each emitted without and then with a flip of the last axis.
    ``v`` is attached unchanged to every variant.

    ``data`` is never modified, so the function can be called repeatedly on
    the same samples and accepts immutable (tuple) samples as well as lists.

    Args:
        data: iterable of 3-item samples ``(state, (actions, probs), v)``.
            ``state`` must carry the board on its last two axes. Each entry of
            ``actions`` is used directly as a NumPy index into the policy grid
            -- presumably a ``(row, col)`` tuple; TODO confirm against the
            code that produces the samples.
        board_size: side length of the policy grid. ``None`` (the default)
            uses the module-level ``CHESSBOARD_SIZE``.

    Returns:
        A list of ``[state, action_prob, v]`` lists, 8 per input sample, in
        rotation order 1..4 with the un-flipped variant before the flipped one.
    """
    if board_size is None:
        board_size = CHESSBOARD_SIZE

    augment_data = []
    for state, (actions, probs), v in data:
        # Scatter the sparse move probabilities into a dense grid.
        action_prob = np.zeros((board_size, board_size), dtype=float)
        for j, action in enumerate(actions):
            action_prob[action] = probs[j]

        for k in range(1, 5):
            equal_state = np.rot90(state, k, axes=(-2, -1))
            equal_action_prob = np.rot90(action_prob, k, axes=(-2, -1))

            # rot90/flip return views (possibly with negative strides) onto the
            # source arrays. Copy them so stored samples are contiguous and do
            # not alias each other or the caller's state.
            augment_data.append([equal_state.copy(), equal_action_prob.copy(), v])
            augment_data.append([
                np.flip(equal_state, axis=-1).copy(),
                np.flip(equal_action_prob, axis=-1).copy(),
                v,
            ])
    return augment_data

In [4]:
def train(net, loss, updater, data_buffer, batch_size, num_epochs):
    """Run ``num_epochs`` optimiser steps on one mini-batch sampled from the buffer.

    A single batch of ``batch_size`` samples is drawn without replacement and
    reused for every epoch.

    Args:
        net: callable mapping the float32 state batch to the predictions that
            ``loss`` consumes.
        loss: callable ``loss(y_pred, y)`` returning ``(policy_loss, value_loss)``
            tensors; ``y`` is the ``(policy_target, value_target)`` tuple.
        updater: optimiser exposing ``zero_grad()`` and ``step()``.
        data_buffer: sequence of ``[state, action_prob, v]`` samples.
        batch_size: number of samples to draw; must not exceed the buffer size
            (``random.sample`` raises ``ValueError`` otherwise).
        num_epochs: number of gradient steps to take on the sampled batch.

    Returns:
        ``(policy_loss, value_loss)`` as Python floats: the batch means computed
        in the last epoch, i.e. before the final optimiser step was applied.

    Raises:
        ValueError: if ``num_epochs`` is smaller than 1, since no loss would be
            computed and there would be nothing to return.
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be at least 1, got {num_epochs}")

    train_data = random.sample(data_buffer, batch_size)

    def to_tensor(field):
        # Stack through NumPy first: building a tensor straight from a Python
        # list of ndarrays walks every element individually and is very slow.
        # np.array also yields a fresh contiguous block, so samples stored as
        # rotated/flipped views are handled too.
        stacked = np.array([sample[field] for sample in train_data])
        return torch.as_tensor(stacked, dtype=torch.float32)

    X = to_tensor(0)
    y = (to_tensor(1), to_tensor(2))

    for _ in range(num_epochs):
        y_pred = net(X)
        policy_loss, value_loss = loss(y_pred, y)
        total_loss = (policy_loss + value_loss).mean()
        updater.zero_grad()
        total_loss.backward()
        updater.step()

    return policy_loss.mean().item(), value_loss.mean().item()

In [5]:
class PolicyValueLoss(nn.Module):
    """Per-sample policy (cross-entropy) and value (squared-error) losses.

    Both criteria use ``reduction='none'``, so ``forward`` returns one loss
    value per sample and leaves the reduction to the caller.
    """

    def __init__(self):
        super().__init__()
        self.policy_loss = nn.CrossEntropyLoss(reduction='none')
        self.value_loss = nn.MSELoss(reduction='none')

    def forward(self, y_pred, y):
        """Compute the two loss terms for a batch.

        Args:
            y_pred: ``(policy_pred, value_pred)``; the leading axis of
                ``policy_pred`` is the batch.
            y: ``(policy_target, value_target)`` with the same batch size.

        Returns:
            ``(policy_loss, value_loss)``, each of shape ``(batch,)``.
        """
        batch_size = y_pred[0].shape[0]

        # Flatten the board so each sample is one distribution over all cells.
        # A floating-point target of the same shape as the input is treated by
        # CrossEntropyLoss as class probabilities.
        # NOTE(review): CrossEntropyLoss applies log-softmax itself, so the
        # policy prediction should be raw logits -- confirm the network's
        # policy head does not already normalise its output.
        policy_loss = self.policy_loss(
            y_pred[0].reshape(batch_size, -1),
            y[0].reshape(batch_size, -1),
        )

        # Bring prediction and target to the same (batch, k) layout before the
        # element-wise MSE. Without this, a (batch, 1) prediction against a
        # (batch,) target silently broadcasts to (batch, batch) and compares
        # every prediction with every target.
        value_pred = y_pred[1].reshape(batch_size, -1)
        value_target = y[1].reshape(batch_size, -1)
        value_loss = self.value_loss(value_pred, value_target).mean(dim=1)
        return policy_loss, value_loss

In [6]:
# Training hyper-parameters.
num_games = 1        # self_play calls per loop iteration
num_epochs = 5       # optimiser steps taken on each sampled batch
batch_size = 1024    # samples drawn per training call; also the minimum buffer fill
buffer_size = 20_000

# Network being trained and the criterion used to fit it.
net = PolicyValueNet(CHESSBOARD_SIZE)
policy_value_loss = PolicyValueLoss()

# Bounded sample store: once full, appending discards the oldest entries.
data_buffer = deque(maxlen=buffer_size)

# Iteration counter for the self-play / training loop.
cnt = 0

In [7]:
# Optimiser hyper-parameters, named so they can be tuned in one place.
learning_rate = 0.0075
weight_decay = 1e-4

# Build the optimiser once, before the loop. Adam keeps running moment
# estimates and a step count per parameter; constructing a new optimiser on
# every iteration would throw that state away each time.
updater = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
# Alternative optimiser:
# updater = torch.optim.SGD(net.parameters(), lr=0.0001, momentum=0.9, weight_decay=1e-4)

# Alternate self-play data generation and training. The loop has no exit
# condition: it runs until the kernel is interrupted.
while True:
    cnt += 1
    # Fresh search tree rooted at the initial position, guided by the current net.
    mct = MCT(root=Node(state_now=init_state(CHESSBOARD_SIZE)), policy_value_fn=net)

    data = []
    for _ in range(num_games):
        data.extend(mct.self_play(num=300))
    data_buffer.extend(preprocess_data(data))

    print(cnt, len(data))
    # Start training only once the buffer can supply a full batch.
    if len(data_buffer) >= batch_size:
        policy_loss, value_loss = train(net, policy_value_loss, updater, data_buffer, batch_size, num_epochs)
        print('loss:', policy_loss+value_loss, 'entropy_loss:', policy_loss)
        

  input = module(input)


1 32
2 31


In [None]:
# Persist the trained weights; the board size is part of the file name.
params_path = f'reversi{CHESSBOARD_SIZE}.params'
torch.save(net.state_dict(), params_path)