In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import random
from mcts import MCT, Node
from net import PolicyValueNet
from src_net import init_state
from collections import deque

In [2]:
# Side length of the square Reversi board; used for the policy map size,
# the network constructor and the checkpoint filename.
CHESSBOARD_SIZE: int = 6

In [3]:
def preprocess_data(data, board_size=None):
    """Turn raw self-play records into training samples augmented with board symmetries.

    Each input record unpacks as ``(state, (actions, probs), v)``. ``actions[j]`` is used
    directly as an index into a ``(board_size, board_size)`` array, so it is presumably a
    ``(row, col)`` tuple -- TODO confirm against ``MCT.self_play``.

    The sparse policy ``(actions, probs)`` is densified into a board-shaped probability map.
    Every sample is then emitted under the 8 symmetries of the square: quarter turns
    k = 1..4 (k = 4 is the identity), each with and without a left-right mirror. The same
    transform is applied to the last two axes of the state and to the policy map, so they
    stay aligned. The value target ``v`` is copied unchanged.

    Args:
        data: iterable of self-play records as described above. It is NOT modified.
        board_size: side length of the board; defaults to ``CHESSBOARD_SIZE``.

    Returns:
        A list of ``[state, action_prob, v]`` lists, 8 per input record, in the order
        (rot k, rot k mirrored) for k = 1, 2, 3, 4.

    Raises:
        ValueError: if a record's ``actions`` and ``probs`` differ in length.
    """
    if board_size is None:
        board_size = CHESSBOARD_SIZE

    augment_data = []
    for state, (actions, probs), v in data:
        if len(actions) != len(probs):
            raise ValueError(
                f'got {len(actions)} actions but {len(probs)} probabilities')

        # Densify the sparse policy; moves not listed get probability 0.
        # Built as a new array so the caller's record is left untouched.
        action_prob = np.zeros((board_size, board_size), dtype=float)
        for action, prob in zip(actions, probs):
            action_prob[action] = prob

        for k in range(1, 5):
            equal_state = np.rot90(state, k, axes=(-2, -1))
            equal_action_prob = np.rot90(action_prob, k, axes=(-2, -1))
            augment_data.append([equal_state, equal_action_prob, v])
            augment_data.append([np.flip(equal_state, axis=-1),
                                 np.flip(equal_action_prob, axis=-1), v])
    return augment_data

In [4]:
def train(net, loss, updater, data_buffer, batch_size, num_epochs):
    """Run ``num_epochs`` optimisation steps on one random minibatch from the replay buffer.

    Args:
        net: policy-value network; ``net(X)`` must return an indexable pair whose ``[0]`` is
            the policy output and ``[1]`` the value output.
        loss: callable ``loss(y_pred, y) -> (policy_loss, value_loss)`` returning per-sample
            losses; they are summed and averaged here.
        updater: a ``torch.optim`` optimizer over ``net``'s parameters.
        data_buffer: sequence of ``[state, action_prob, value]`` samples.
        batch_size: number of samples drawn (without replacement) from ``data_buffer``.
        num_epochs: number of gradient steps taken on that same batch; must be >= 1.

    Returns:
        ``(policy_loss, value_loss, kl)``. The losses are batch means from the last step,
        computed before that step's update. ``kl`` is the batch-mean KL divergence between
        the policy output before and after all steps.

    Raises:
        ValueError: if ``num_epochs < 1``. ``random.sample`` also raises ValueError when
            ``batch_size`` exceeds ``len(data_buffer)``.
    """
    if num_epochs < 1:
        raise ValueError(f'num_epochs must be >= 1, got {num_epochs}')

    batch = random.sample(data_buffer, batch_size)
    states = torch.tensor(np.array([sample[0] for sample in batch])).float()
    targets = (torch.tensor(np.array([sample[1] for sample in batch])).float(),
               torch.tensor(np.array([sample[2] for sample in batch])).float())

    # Snapshot the policy before training; no_grad avoids building an unused autograd graph.
    with torch.no_grad():
        old_action_prob = net(states)[0].numpy()

    for _ in range(num_epochs):
        y_pred = net(states)
        policy_loss, value_loss = loss(y_pred, targets)
        total_loss = (policy_loss + value_loss).mean()
        updater.zero_grad()
        total_loss.backward()
        updater.step()

    with torch.no_grad():
        new_action_prob = net(states)[0].numpy()

    # KL(old || new), summed over the last two (board) axes and averaged over the batch.
    # NOTE(review): this treats the policy output as probabilities, whereas a
    # CrossEntropyLoss-based `loss` treats it as logits -- confirm which PolicyValueNet emits.
    kl = np.mean(np.sum(old_action_prob * (np.log(old_action_prob + 1e-10)
                                           - np.log(new_action_prob + 1e-10)),
                        axis=(-2, -1)))

    return policy_loss.mean().item(), value_loss.mean().item(), kl

In [5]:
class PolicyValueLoss(nn.Module):
    """Per-sample policy cross-entropy and value MSE, returned as two separate tensors.

    Both criteria use ``reduction='none'``; the caller is responsible for combining and
    averaging them.

    NOTE(review): ``nn.CrossEntropyLoss`` applies log-softmax to its input, so the policy
    output must be unnormalised logits. If PolicyValueNet's policy head already applies
    softmax, this normalises twice -- confirm.
    """

    def __init__(self):
        super().__init__()
        self.policy_loss = nn.CrossEntropyLoss(reduction='none')
        self.value_loss = nn.MSELoss(reduction='none')

    def forward(self, y_pred, y):
        """Compute per-sample losses.

        Args:
            y_pred: ``(policy, value)`` network outputs. The policy is flattened to
                ``(batch, -1)``; the value must hold one scalar per sample.
            y: ``(target_policy, target_value)`` with the same batch size; the target
                policy is a probability distribution (soft labels) per sample.

        Returns:
            ``(policy_loss, value_loss)``, each of shape ``(batch,)``.
        """
        pred_policy, pred_value = y_pred[0], y_pred[1]
        target_policy, target_value = y[0], y[1]
        batch_size = pred_policy.shape[0]
        policy_loss = self.policy_loss(pred_policy.reshape(batch_size, -1),
                                       target_policy.reshape(batch_size, -1))
        # Flatten both value tensors to (batch,): a (batch, 1) prediction against a
        # (batch,) target would otherwise broadcast to a (batch, batch) loss matrix.
        value_loss = self.value_loss(pred_value.reshape(batch_size),
                                     target_value.reshape(batch_size))
        return policy_loss, value_loss

In [6]:
# Replay buffer / optimisation hyper-parameters.
buffer_size = 20000   # max number of (augmented) samples kept; oldest are dropped first
batch_size = 1024     # samples drawn from the buffer per train() call
num_epochs = 5        # gradient steps per train() call, all on the same batch
num_games = 1         # self-play games generated per outer iteration

lr = 2e-3             # base learning rate
lr_mul = 1.0          # multiplier applied to the base learning rate
kl_targ = 0.02        # target KL divergence between successive policies
params = 'reversi'+str(CHESSBOARD_SIZE)+'.params'   # checkpoint path

data_buffer = deque(maxlen=buffer_size)

net = PolicyValueNet(CHESSBOARD_SIZE)
try:
    # map_location='cpu': the notebook converts net outputs with .numpy(), i.e. runs on CPU,
    # so a checkpoint saved from a GPU must be remapped to load here.
    net.load_state_dict(torch.load(params, map_location='cpu'))
except FileNotFoundError:
    # Allow Restart & Run All on a fresh checkout: train from random initial weights.
    print('no checkpoint found at', params, '- starting from random weights')
policy_value_loss = PolicyValueLoss()

cnt = 0   # outer-iteration counter, incremented and printed by the training loop

In [10]:
# Halve the learning-rate multiplier for subsequent training; this overrides lr_mul = 1.0
# from the configuration cell, so the result depends on running this cell after it.
lr_mul = 1 / 2

In [11]:
# Self-play / training loop; interrupt the kernel to stop it.
# The Adam optimizer is created once, outside the loop, so its moment estimates carry over
# between train() calls; only its learning rate is refreshed each iteration from lr * lr_mul.
updater = torch.optim.Adam(net.parameters(), lr=lr*lr_mul, weight_decay=1e-4)

while True:
    cnt += 1

    data = []
    for _ in range(num_games):
        # A fresh search tree per game, rooted at the initial position.
        mct = MCT(root=Node(state_now=init_state(CHESSBOARD_SIZE)), policy_value_fn=net)
        data.extend(mct.self_play(num=1000))
    data_buffer.extend(preprocess_data(data))

    print(cnt, len(data))
    if len(data_buffer) >= batch_size:
        for group in updater.param_groups:
            group['lr'] = lr * lr_mul
        policy_loss, value_loss, kl = train(net, policy_value_loss, updater,
                                            data_buffer, batch_size, num_epochs)

        # Disabled: adapt lr_mul from the measured KL divergence.
        # if(kl>2*kl_targ and lr_mul>0.1):
        #     lr_mul /= 1.5
        # if(kl<kl_targ/2 and lr_mul<10):
        #     lr_mul *= 1.5

        print('loss:', policy_loss+value_loss, 'entropy_loss:', policy_loss, 'lr_mul:', lr_mul)
        

42 32
loss: 3.3689829111099243 entropy_loss: 3.1398017406463623 lr_mul: 0.5
43 32
loss: 3.4024939239025116 entropy_loss: 3.1712169647216797 lr_mul: 0.5
44 32
loss: 3.378569111227989 entropy_loss: 3.1535322666168213 lr_mul: 0.5
45 32
loss: 3.3367519676685333 entropy_loss: 3.118351459503174 lr_mul: 0.5
46 32
loss: 3.324190080165863 entropy_loss: 3.117675304412842 lr_mul: 0.5
47 32
loss: 3.342716693878174 entropy_loss: 3.1409783363342285 lr_mul: 0.5
48 32
loss: 3.3460350334644318 entropy_loss: 3.1309895515441895 lr_mul: 0.5
49 32
loss: 3.3589348047971725 entropy_loss: 3.144484519958496 lr_mul: 0.5
50 32
loss: 3.405283272266388 entropy_loss: 3.159745454788208 lr_mul: 0.5
51 32
loss: 3.3643130362033844 entropy_loss: 3.1395933628082275 lr_mul: 0.5
52 32
loss: 3.3407169729471207 entropy_loss: 3.128758430480957 lr_mul: 0.5
53 32
loss: 3.3318770825862885 entropy_loss: 3.125093936920166 lr_mul: 0.5
54 32
loss: 3.322102725505829 entropy_loss: 3.12302565574646 lr_mul: 0.5


In [None]:
# Persist the trained weights to the same checkpoint path (`params`) the config cell loads from.
torch.save(net.state_dict(), params)