In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import random
from mcts import MCT, Node
from net import PolicyValueNet
from src_net import init_state
from collections import deque

In [2]:
# Board side length N: sizes the (N, N) move-probability grid, is passed to
# PolicyValueNet / init_state, and is embedded in the checkpoint file name.
CHESSBOARD_SIZE = 6

In [3]:
def preprocess_data(data, board_size=None):
    """Convert raw self-play samples into dense, symmetry-augmented samples.

    Every input sample must unpack as ``(state, (actions, probs), v)``.
    The sparse ``(actions, probs)`` pair is scattered into a dense
    ``(board_size, board_size)`` float grid, then the state and the grid are
    rotated and mirrored together so each input sample yields 8 outputs.

    The input samples are left untouched (no in-place rewrite of the
    caller's data), so tuple samples are accepted and the function can be
    called repeatedly on the same data.

    Args:
        data: iterable of ``(state, (actions, probs), v)`` samples. ``state``
            needs at least two dimensions; its last two axes are rotated.
            Each entry of ``actions`` is used directly as an index into the
            grid -- presumably a ``(row, col)`` tuple; TODO confirm against
            ``MCT.self_play``.
        board_size: side length of the probability grid. Defaults to the
            module-level ``CHESSBOARD_SIZE`` when omitted.

    Returns:
        A list of ``[state, action_prob, v]`` lists. For each input sample
        and each k in 1..4 quarter turns (k=4 is the identity) it holds the
        rotated pair followed by that pair mirrored along the last axis.
    """
    if board_size is None:
        board_size = CHESSBOARD_SIZE

    augment_data = []
    for state, (actions, probs), v in data:
        # Scatter the sparse move probabilities into a dense grid.
        action_prob = np.zeros((board_size, board_size), dtype=float)
        for action, prob in zip(actions, probs):
            action_prob[action] = prob

        for k in (1, 2, 3, 4):
            equal_state = np.rot90(state, k, axes=(-2, -1))
            equal_action_prob = np.rot90(action_prob, k, axes=(-2, -1))

            augment_data.append([equal_state, equal_action_prob, v])
            augment_data.append([
                np.flip(equal_state, axis=-1),
                np.flip(equal_action_prob, axis=-1),
                v,
            ])
    return augment_data

In [4]:
from torch.utils.data import Dataset, DataLoader
from d2l import torch as d2l

class MyDataset(Dataset):
    """In-memory dataset of ``(state, (policy_target, value_target))`` samples."""

    def __init__(self, data):
        """Materialise all samples as float32 tensors.

        Args:
            data: iterable of samples indexable as ``sample[0]`` (state),
                ``sample[1]`` (policy target) and ``sample[2]`` (value target).
        """
        states = []
        targets = []
        for sample in data:
            states.append(sample[0])
            policy_target = torch.tensor(np.array(sample[1])).float()
            value_target = torch.tensor(np.array(sample[2])).float()
            targets.append((policy_target, value_target))

        # States are stacked into one tensor; targets stay a list of pairs.
        self.X = torch.tensor(np.array(states)).float()
        self.y = targets

    def __getitem__(self, index):
        """Return ``(state, (policy_target, value_target))`` at ``index``."""
        features = self.X[index]
        labels = self.y[index]
        return (features, labels)

    def __len__(self):
        """Number of stored samples."""
        return len(self.y)
    


def train_epoch(net, train_iter, loss, updater):
    """Run one optimisation pass over ``train_iter`` and report mean losses.

    Args:
        net: module called as ``net(X)``; its output is indexable and
            ``y_pred[0].shape[0]`` is taken as the batch size.
        train_iter: iterable of ``(X, y)`` batches.
        loss: callable ``loss(y_pred, y)`` returning a
            ``(policy_loss, value_loss)`` pair of tensors.
        updater: optimiser exposing ``zero_grad()`` and ``step()``.

    Returns:
        ``(policy_loss, value_loss)`` as Python floats: the summed loss of
        each kind divided by the number of samples seen. Both are NaN when
        ``train_iter`` yields no batches (previously a ZeroDivisionError).
    """
    net.train()

    # Plain running sums; no external accumulator helper is needed for this.
    policy_total = 0.0
    value_total = 0.0
    num_samples = 0
    for X, y in train_iter:
        y_pred = net(X)
        policy_loss, value_loss = loss(y_pred, y)
        batch_loss = (policy_loss + value_loss).mean()
        updater.zero_grad()
        batch_loss.backward()
        updater.step()

        policy_total += float(policy_loss.sum())
        value_total += float(value_loss.sum())
        num_samples += y_pred[0].shape[0]

    if num_samples == 0:
        # Nothing was trained on: report "no measurement" rather than crash.
        return float('nan'), float('nan')
    return policy_total / num_samples, value_total / num_samples


def train(net, loss, updater, data, batch_size, num_epochs):
    """Train ``net`` on ``data`` for ``num_epochs`` shuffled passes.

    Args:
        net: network to optimise; switched to training mode.
        loss: callable forwarded to ``train_epoch``.
        updater: optimiser forwarded to ``train_epoch``.
        data: iterable of samples accepted by ``MyDataset``.
        batch_size: mini-batch size for the ``DataLoader``.
        num_epochs: number of passes over ``data``.

    Returns:
        The ``(policy_loss, value_loss)`` pair reported by the last epoch.
        Both are NaN when ``num_epochs <= 0`` (previously this raised
        UnboundLocalError because no epoch ever assigned them).
    """
    train_dataset = MyDataset(data)
    train_iter = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

    net.train()

    # Defined up front so a zero-epoch call still has something to return.
    policy_loss = value_loss = float('nan')
    for _ in range(num_epochs):
        policy_loss, value_loss = train_epoch(net, train_iter, loss, updater)

    return policy_loss, value_loss

In [5]:
class PolicyValueLoss(nn.Module):
    """Per-sample policy (cross-entropy) and value (squared-error) losses."""

    def __init__(self):
        super().__init__()
        # reduction='none' keeps one loss value per sample; the caller
        # decides how to reduce (mean for the gradient, sum for logging).
        self.policy_loss = nn.CrossEntropyLoss(reduction='none')
        self.value_loss = nn.MSELoss(reduction='none')

    def forward(self, y_pred, y):
        """Compute both losses for one batch.

        Args:
            y_pred: ``(policy_pred, value_pred)``; ``policy_pred.shape[0]``
                is the batch size B and ``value_pred`` holds exactly one
                value per sample (shape ``(B,)`` or ``(B, 1)``).
            y: ``(policy_target, value_target)`` with matching sizes.

        Returns:
            ``(policy_loss, value_loss)``, each of shape ``(B,)``.

        NOTE(review): with probability targets, ``CrossEntropyLoss`` applies
        log-softmax to its input, so ``policy_pred`` should be raw logits --
        confirm the network does not already apply softmax.
        """
        batch_size = y_pred[0].shape[0]
        policy_loss = self.policy_loss(
            y_pred[0].reshape(batch_size, -1),
            y[0].reshape(batch_size, -1),
        )
        # Flatten both sides to (B,). Without this a (B, 1) prediction and a
        # (B,) target silently broadcast to a (B, B) loss matrix.
        value_loss = self.value_loss(
            y_pred[1].reshape(batch_size),
            y[1].reshape(batch_size),
        )
        return policy_loss, value_loss

In [6]:
# Training configuration.
buffer_size = 12000   # maximum number of samples kept in the replay buffer
batch_size = 512      # mini-batch size used by train()
num_epochs = 2        # passes over the buffer per training round
num_games = 40        # self-play games generated per training round

lr = 2e-3
params = 'reversi'+str(CHESSBOARD_SIZE)+'.params'   # checkpoint path (load and save)

data_buffer = deque(maxlen=buffer_size)
net = PolicyValueNet(CHESSBOARD_SIZE)
try:
    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only
    # machine; load_state_dict copies values into net's existing parameters.
    net.load_state_dict(torch.load(params, map_location='cpu'))
except FileNotFoundError:
    # First run: no checkpoint has been saved yet, keep the fresh weights.
    print('checkpoint', params, 'not found; starting from freshly initialised weights')
policy_value_loss = PolicyValueLoss()

cnt = 0   # number of completed self-play/training rounds

In [7]:
# Build the optimiser once so Adam's moment estimates and step count carry
# over between rounds instead of being reset on every iteration.
updater = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=1e-4)
# updater = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)

# Runs until the kernel is interrupted; each round generates num_games
# self-play games and then trains on the replay buffer.
while True:
    cnt += 1
    mct = MCT(root=Node(state_now=init_state(CHESSBOARD_SIZE)), policy_value_fn=net)

    # Modules start in training mode and train() leaves the network there.
    # Self-play only needs forward passes, so switch to eval mode and skip
    # autograd bookkeeping; train() turns training mode back on.
    net.eval()
    with torch.no_grad():
        for _ in range(num_games):
            data_buffer.extend(preprocess_data(mct.self_play(num=400)))
            # Start the next game from a fresh root.
            mct.root = Node(state_now=init_state(CHESSBOARD_SIZE))

    policy_loss, value_loss = train(net, policy_value_loss, updater, data_buffer, batch_size, num_epochs)

    print('step:', cnt, 'loss:', policy_loss+value_loss, 'entropy_loss:', policy_loss)
        

  input = module(input)


In [None]:
# Save to the same path the weights were loaded from, so the load and save
# locations cannot drift apart.
torch.save(net.state_dict(), params)