In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from mcts import MCT, Node
from net import PolicyValueNet
from src_net import init_state

In [2]:
# Side length of the square board. Used to size the policy-value network and
# the dense (CHESSBOARD_SIZE x CHESSBOARD_SIZE) policy targets.
CHESSBOARD_SIZE = 6
# NOTE(review): not referenced by any visible cell — presumably meant as a
# replay-buffer capacity (samples are currently discarded every iteration); confirm.
BUFFER_SIZE = 1_000

In [3]:
from torch.utils.data import Dataset

class MyDataset(Dataset):
    """In-memory dataset of self-play samples for the policy-value network.

    Each item is ``(state, (action_prob, value))``:

    * ``state`` is one row of the stacked float32 feature tensor;
    * ``action_prob`` is a dense ``(board_size, board_size)`` float32 tensor
      filled from the sparse ``(actions, probs)`` label (all other cells 0);
    * ``value`` is a float32 scalar tensor.

    Args:
        X: sequence of board-state features of identical shape (e.g. a list of
            numpy arrays); stacked into a single float32 tensor.
        y: sequence of ``((actions, probs), v)`` labels. Each ``actions[i]``
            is used directly as an index into the policy grid — presumably a
            ``(row, col)`` tuple; TODO confirm against ``MCT.self_play``.
        board_size: side length of the policy grid. Defaults to the
            notebook-level ``CHESSBOARD_SIZE``.
    """

    def __init__(self, X, y, board_size=None):
        if board_size is None:
            board_size = CHESSBOARD_SIZE
        # Stack into one ndarray first: torch.tensor() on a Python list of
        # ndarrays is extremely slow and emits a UserWarning.
        self.X = torch.tensor(np.asarray(X), dtype=torch.float32)
        self.y = []
        for (actions, probs), v in y:
            action_prob = torch.zeros((board_size, board_size), dtype=torch.float32)
            # Index-based pairing keeps the IndexError if probs is shorter
            # than actions (zip would silently truncate).
            for i, action in enumerate(actions):
                action_prob[action] = probs[i]
            self.y.append((action_prob, torch.tensor(v).float()))

    def __getitem__(self, index):
        """Return ``(features, (action_prob, value))`` for sample ``index``."""
        return (self.X[index], self.y[index])

    def __len__(self):
        """Number of samples (leading dimension of the feature tensor)."""
        return self.X.shape[0]

In [4]:
from torch.utils.data import DataLoader
from d2l import torch as d2l

def load_data(data, batch_size):
    """Build a shuffling DataLoader over collected self-play samples.

    Args:
        data: dict holding ``'feature'`` (inputs) and ``'label'`` (targets)
            sequences; both are passed unchanged to ``MyDataset``.
        batch_size: number of samples per batch.

    Returns:
        A ``DataLoader`` over the samples with ``shuffle=True``.
    """
    dataset = MyDataset(data['feature'], data['label'])
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)


def train_epoch(net, train_iter, loss, updater):
    """Run one optimisation pass over ``train_iter``; return mean loss per sample.

    Args:
        net: module called as ``net(X)``; its output ``y_hat`` must be
            indexable, with ``y_hat[0]`` carrying the batch size as its
            leading dimension.
        train_iter: iterable of ``(X, y)`` batches.
        loss: callable ``loss(y_hat, y)`` returning a scalar tensor. It is
            expected to be *summed* over the batch, so dividing by the sample
            count below yields a per-sample mean.
        updater: optimizer over ``net``'s parameters.

    Returns:
        Accumulated loss divided by the total number of samples seen.

    Raises:
        ValueError: if ``train_iter`` yields no samples.
    """
    net.train()

    total_loss = 0.0
    num_samples = 0
    for X, y in train_iter:
        y_hat = net(X)
        l = loss(y_hat, y)
        updater.zero_grad()
        l.backward()
        updater.step()
        total_loss += l.item()
        num_samples += y_hat[0].shape[0]

    if num_samples == 0:
        # Give a clear error instead of an opaque ZeroDivisionError.
        raise ValueError("train_iter yielded no samples; cannot compute mean loss")
    return total_loss / num_samples

def train(net, train_iter, loss, num_epochs, updater):
    """Train ``net`` for ``num_epochs`` passes over ``train_iter``.

    Args:
        net, train_iter, loss, updater: forwarded unchanged to ``train_epoch``
            for every epoch.
        num_epochs: number of passes over ``train_iter``.

    Returns:
        ``(first_epoch_loss, last_epoch_loss)`` as returned by
        ``train_epoch`` for the first and last epoch, or ``(None, None)``
        when ``num_epochs`` is 0 (or negative) and no training happens.
    """
    epoch_losses = [train_epoch(net, train_iter, loss, updater) for _ in range(num_epochs)]
    if not epoch_losses:
        # No epochs ran: report no losses rather than failing on an unbound variable.
        return None, None
    return epoch_losses[0], epoch_losses[-1]

In [5]:
def policy_value_loss(y_hat, y):
    """Summed policy cross-entropy plus value MSE for one batch.

    Args:
        y_hat: network output; ``y_hat[0]`` holds policy scores with a leading
            batch dimension (raw logits or log-probabilities — ``cross_entropy``
            applies ``log_softmax``, which leaves log-probabilities unchanged),
            and ``y_hat[1]`` holds value predictions (e.g. ``(B,)`` or ``(B, 1)``).
        y: targets; ``y[0]`` is the move-probability grid (``(B, H, W)`` as
            built by ``MyDataset``) and ``y[1]`` the target values ``(B,)``.

    Returns:
        Scalar tensor: policy loss + value loss, each *summed* over the batch
        (not averaged).
    """
    policy_scores, target_probs = y_hat[0], y[0]
    batch = policy_scores.shape[0]
    # Flatten every non-batch dim so the softmax normalises over all board
    # cells as one move distribution. On a (B, H, W) tensor, cross_entropy
    # would treat dim 1 (rows) as the class axis.
    policy_loss = F.cross_entropy(
        policy_scores.reshape(batch, -1),
        target_probs.reshape(batch, -1),
        reduction='sum',
    )
    # Flatten both sides so a (B, 1) value head is not broadcast against a
    # (B,) target into a (B, B) difference matrix.
    value_loss = F.mse_loss(y_hat[1].reshape(-1), y[1].reshape(-1), reduction='sum')
    return policy_loss + value_loss

In [6]:
# Policy-value network for a CHESSBOARD_SIZE board, and an MCTS player that
# uses it as its policy/value function. The optimizer is created in the
# training-loop cell.
net = PolicyValueNet(CHESSBOARD_SIZE)
mct = MCT(policy_value_fn=net)

In [7]:
# data = {'feature':[], 'label':[]}
# for _ in range(10):
#     root = Node(state_now=init_state(CHESSBOARD_SIZE))  
#     mct.root = root

#     features, labels = mct.self_play()
#     data['feature'].extend(features)
#     data['label'].extend(labels)

In [8]:
# train_iter = load_data(data, 64)
# updater = torch.optim.Adam(net.parameters(), lr=0.01)
# loss_begin, loss_end = train(net, train_iter, policy_value_loss, 5, updater)

In [9]:
# Self-play / training loop. It runs until the kernel is interrupted; each
# iteration appends "<first-epoch loss> <last-epoch loss>" to loss.txt.
# NOTE(review): samples are discarded after each iteration (no replay buffer);
# BUFFER_SIZE is not used here.
GAMES_PER_ITER = 20
BATCH_SIZE = 64
EPOCHS_PER_ITER = 10

# Create the optimizer once so SGD momentum buffers persist across iterations;
# re-creating it inside the loop reset them on every iteration.
updater = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

while True:
    data = {'feature': [], 'label': []}
    for _ in range(GAMES_PER_ITER):
        # Start each self-play game from a fresh initial position.
        root = Node(state_now=init_state(CHESSBOARD_SIZE))
        mct.root = root

        # NOTE(review): 300 is passed straight to MCT.self_play — presumably a
        # per-move search budget; confirm and consider naming it.
        features, labels = mct.self_play(300)
        data['feature'].extend(features)
        data['label'].extend(labels)

    train_iter = load_data(data, BATCH_SIZE)
    loss_begin, loss_end = train(net, train_iter, policy_value_loss, EPOCHS_PER_ITER, updater)
    with open("loss.txt", "a") as f:
        f.write(str(loss_begin) + ' ' + str(loss_end) + '\n')


  input = module(input)
