In [27]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import random
from mcts import MCT, Node
from net import PolicyValueNet
from src_net import init_state
from collections import deque
import time

In [28]:
# Side length of the square board. Used below for the dense policy-target grid,
# the initial game state, the network constructor and the checkpoint filename.
CHESSBOARD_SIZE = 6

In [29]:
def preprocess_data(data, board_size=None):
    """Convert raw self-play records into dense, symmetry-augmented samples.

    Every input record has the form ``[state, (actions, probs), v]``. The
    sparse ``(actions, probs)`` pair is scattered onto a dense
    ``(board_size, board_size)`` grid, and the record is then expanded into
    the eight symmetries of the square (four rotations, each with and without
    a flip along the last axis). State and policy grid always receive the same
    transform.

    Args:
        data: Iterable of ``[state, (actions, probs), v]`` records. The input
            is NOT modified, so the function can be called repeatedly on the
            same records and also accepts tuple records.
        board_size: Side length of the policy grid. Defaults to the
            module-level ``CHESSBOARD_SIZE``.

    Returns:
        A list of ``[state, action_prob, v]`` lists, eight per input record,
        in the order rot90 x1, its flip, rot90 x2, its flip, ...
        The arrays are NumPy views produced by ``np.rot90`` / ``np.flip``.
    """
    if board_size is None:
        board_size = CHESSBOARD_SIZE

    augment_data = []
    for state, (actions, probs), v in data:
        # Scatter the sparse move probabilities onto a dense board-shaped grid.
        # NOTE(review): assumes each action is a (row, col) tuple usable as a
        # 2-D index -- confirm against what MCT.self_play returns.
        action_prob = np.zeros((board_size, board_size), dtype=float)
        for action, prob in zip(actions, probs):
            action_prob[action] = prob

        # k = 4 is a full turn, i.e. the untransformed sample itself.
        for k in range(1, 5):
            equal_state = np.rot90(state, k, axes=(-2, -1))
            equal_action_prob = np.rot90(action_prob, k, axes=(-2, -1))

            augment_data.append([equal_state, equal_action_prob, v])
            augment_data.append([np.flip(equal_state, axis=-1),
                                 np.flip(equal_action_prob, axis=-1),
                                 v])
    return augment_data

In [30]:
from torch.utils.data import Dataset, DataLoader
from d2l import torch as d2l

class MyDataset(Dataset):
    """In-memory dataset of ``(state, (policy_target, value_target))`` samples.

    ``data`` is a sequence of records whose first three entries are the state,
    the policy target and the value target. Everything is converted to
    float32 tensors up front: the states are stacked into one tensor ``X``,
    the targets are kept as a list ``y`` of per-sample tensor pairs.
    """

    def __init__(self, data):
        stacked_states = np.array([sample[0] for sample in data])
        self.X = torch.tensor(stacked_states).float()

        self.y = []
        for sample in data:
            policy_target = torch.tensor(np.array(sample[1])).float()
            value_target = torch.tensor(np.array(sample[2])).float()
            self.y.append((policy_target, value_target))

    def __getitem__(self, index):
        features = self.X[index]
        targets = self.y[index]
        return (features, targets)

    def __len__(self):
        return len(self.y)
    


def train_epoch(net, train_iter, loss, updater):
    """Run one optimisation pass over ``train_iter``.

    Args:
        net: Model called as ``net(X)``; its output is handed to ``loss`` and
            ``y_pred[0].shape[0]`` is used as the batch size.
        train_iter: Iterable yielding ``(X, y)`` batches.
        loss: Callable ``loss(y_pred, y)`` returning per-sample
            ``(policy_loss, value_loss)`` tensors.
        updater: Optimizer exposing ``zero_grad()`` and ``step()``.

    Returns:
        ``(policy_loss, value_loss)`` as Python floats, each averaged over all
        samples seen during the pass.

    Raises:
        ValueError: If ``train_iter`` yields no samples (previously this
            surfaced as a bare ``ZeroDivisionError``).
    """
    net.train()

    policy_total = 0.0
    value_total = 0.0
    num_samples = 0
    for X, y in train_iter:
        y_pred = net(X)
        policy_loss, value_loss = loss(y_pred, y)

        # Reduce each term separately before adding. Numerically this equals
        # (policy_loss + value_loss).mean(), but it cannot broadcast into a
        # (batch, batch) matrix when the two terms have different shapes.
        batch_loss = policy_loss.mean() + value_loss.mean()
        updater.zero_grad()
        batch_loss.backward()
        updater.step()

        policy_total += float(policy_loss.detach().sum())
        value_total += float(value_loss.detach().sum())
        num_samples += y_pred[0].shape[0]

    if num_samples == 0:
        raise ValueError('train_iter yielded no samples')
    return policy_total/num_samples, value_total/num_samples


def train(net, loss, updater, data, batch_size, num_epochs, sample_size):
    """Train ``net`` on a random subset of the replay data.

    ``sample_size`` records are drawn without replacement from ``data``,
    wrapped in a ``MyDataset`` and iterated ``num_epochs`` times in shuffled
    mini-batches of ``batch_size``.

    Args:
        net: Model to train.
        loss: Callable passed through to ``train_epoch``.
        updater: Optimizer passed through to ``train_epoch``.
        data: Collection of ``[state, action_prob, v]`` samples (the notebook
            passes a ``deque``).
        batch_size: Mini-batch size for the ``DataLoader``.
        num_epochs: Number of passes over the sampled subset; must be >= 1.
        sample_size: Number of samples to draw; ``random.sample`` raises
            ``ValueError`` if it exceeds ``len(data)``.

    Returns:
        The ``(policy_loss, value_loss)`` pair returned by the last epoch.

    Raises:
        ValueError: If ``num_epochs`` is smaller than 1 (previously this ended
            in an ``UnboundLocalError`` at the ``return``).
    """
    if num_epochs < 1:
        raise ValueError('num_epochs must be at least 1, got %r' % (num_epochs,))

    # random.sample indexes the population repeatedly; a deque only has O(1)
    # access at its ends, so sample from a list snapshot instead. The selected
    # elements are the same for a given RNG state.
    population = list(data)
    train_dataset = MyDataset(random.sample(population, sample_size))
    train_iter = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

    net.train()

    for _ in range(num_epochs):
        policy_loss, value_loss = train_epoch(net, train_iter, loss, updater)

    return policy_loss, value_loss

In [31]:
class PolicyValueLoss(nn.Module):
    """Per-sample policy (cross-entropy) and value (squared-error) losses.

    ``forward`` takes ``y_pred = (policy_pred, value_pred)`` and
    ``y = (policy_target, value_target)`` and returns two tensors of shape
    ``(batch,)``; no reduction over the batch is applied.

    The policy target is a probability distribution over board cells, which
    ``nn.CrossEntropyLoss`` accepts as a floating-point target with the same
    shape as the input.
    NOTE(review): ``CrossEntropyLoss`` expects unnormalised logits as input --
    confirm that the policy head of the network does not already apply a
    softmax.
    """

    def __init__(self):
        super().__init__()
        self.policy_loss = nn.CrossEntropyLoss(reduction='none')
        self.value_loss = nn.MSELoss(reduction='none')

    def forward(self, y_pred, y):
        batch_size = y_pred[0].shape[0]

        # Flatten the board so that every cell is one "class".
        policy_loss = self.policy_loss(y_pred[0].reshape(batch_size, -1),
                                       y[0].reshape(batch_size, -1))

        # Bring prediction and target to the same (batch, k) layout first.
        # Without this, a (batch, 1) prediction against a (batch,) target is
        # silently broadcast by MSELoss into a (batch, batch) matrix.
        value_pred = y_pred[1].reshape(batch_size, -1)
        value_target = y[1].reshape(batch_size, -1)
        if value_pred.shape != value_target.shape:
            raise ValueError('value prediction shape %s does not match target shape %s'
                             % (tuple(value_pred.shape), tuple(value_target.shape)))
        value_loss = self.value_loss(value_pred, value_target).mean(dim=1)

        return policy_loss, value_loss

In [32]:
# Replay buffer: a bounded deque, so the oldest samples are discarded
# automatically once `buffer_size` entries are stored.
# NOTE(review): the factor 8 presumably matches the eight symmetries that
# preprocess_data emits per position -- TODO confirm the intended breakdown.
buffer_size = 32*8*300
data_buffer = deque(maxlen=buffer_size)

batch_size = 1024
# Passes over the sampled subset in every training step.
num_epochs = 5
# Self-play games generated per iteration of the outer loop.
num_games = 1
# Not referenced anywhere else in the visible notebook.
time_step = 5*60
def sample_size():
    """Number of buffer samples drawn per training step (one full batch)."""
    return batch_size

lr = 2e-3
# Checkpoint filename, specific to the board size.
params = 'reversi'+str(CHESSBOARD_SIZE)+'.params'

net = PolicyValueNet(CHESSBOARD_SIZE)
# Uncomment to resume from a previously saved checkpoint:
# net.load_state_dict(torch.load(params))
policy_value_loss = PolicyValueLoss()

# Outer-loop iteration counter, printed as 'step' in the training log.
cnt = 0

In [33]:
# Build the optimizer ONCE. Adam keeps per-parameter running moment estimates
# and a step counter; re-creating it inside the loop threw that state away
# before every short training step.
updater = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=1e-4)
# updater = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)

# Runs until the kernel is interrupted; the cell below saves the weights.
while True:
    cnt += 1

    # Self-play is inference: leave the training mode that train() switched on
    # in the previous iteration.
    # NOTE(review): this is a no-op if PolicyValueNet has no dropout or
    # batch-norm layers.
    net.eval()
    mct = MCT(policy_value_fn=net)

    start = time.time()
    for _ in range(num_games):
        mct.set_root(Node(state_now=init_state(CHESSBOARD_SIZE)))
        data_buffer.extend(preprocess_data(mct.self_play(num=400)))

    # Keep collecting games until the buffer can fill one training sample.
    if len(data_buffer) < sample_size():
        continue

    policy_loss, value_loss = train(net, policy_value_loss, updater, data_buffer,
                                    batch_size, num_epochs,
                                    sample_size=sample_size())

    print('step:', cnt, 'cost_time:', time.time()-start,
          'loss:', policy_loss+value_loss, 'entropy_loss:', policy_loss)
        

step: 4 cost_time: 123.63841080665588 loss: 4.372487366199493 entropy_loss: 3.57893967628479
step: 5 cost_time: 129.00777959823608 loss: 4.534661054611206 entropy_loss: 3.5765976905822754
step: 6 cost_time: 122.49259352684021 loss: 4.387017548084259 entropy_loss: 3.572873115539551
step: 7 cost_time: 127.82873725891113 loss: 4.188644826412201 entropy_loss: 3.571941614151001
step: 8 cost_time: 122.00814485549927 loss: 4.096157252788544 entropy_loss: 3.566882848739624
step: 9 cost_time: 125.86330652236938 loss: 4.04937943816185 entropy_loss: 3.561460494995117
step: 10 cost_time: 125.47502589225769 loss: 4.0442891120910645 entropy_loss: 3.558140993118286
step: 11 cost_time: 110.25218081474304 loss: 3.99556365609169 entropy_loss: 3.550175189971924
step: 12 cost_time: 116.82256388664246 loss: 4.065080165863037 entropy_loss: 3.549067497253418
step: 13 cost_time: 119.94454288482666 loss: 4.071060597896576 entropy_loss: 3.5499091148376465
step: 14 cost_time: 116.74259614944458 loss: 4.108671545

In [None]:
# Persist the trained weights to the checkpoint path defined in the
# configuration cell (`params`), instead of rebuilding the filename here.
torch.save(net.state_dict(), params)