In [2]:
from collections import deque

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from alphazero import PolicyValueNet, ChessBoard
from alphazero.self_play_dataset import SelfPlayDataSet, SelfPlayData
from alphazero.train import PolicyValueLoss


In [3]:
class GameDataset(Dataset):
    """ 自我博弈数据集类，每个样本为元组 `(feature_planes, pi, z)` """

    def __init__(self, data_list):
        super().__init__()
        self.__data_deque = deque(data_list)

    def __len__(self):
        return len(self.__data_deque)

    def __getitem__(self, index):
        return self.__data_deque[index]

    def clear(self):
        """ 清空数据集 """
        self.__data_deque.clear()


In [ ]:
device = torch.device('cuda:0')
chess_board = ChessBoard(board_len=7)

# 创建数网络
policy_value_net = PolicyValueNet(board_len=7, n_feature_planes=13, policy_output_dim=100, is_use_gpu=True)

# 创建优化器和损失函数
optimizer = Adam(policy_value_net.parameters(), lr=1e-2, weight_decay=1e-4)
criterion = PolicyValueLoss()

# self.lr_scheduler = MultiStepLR(self.optimizer, [1500, 2500], gamma=0.1)
lr_scheduler = ExponentialLR(optimizer, gamma=0.998)  # 0.998 ** 1000 = 0.135



In [6]:
data_list  = torch.load("./data/data_deque.pth")

In [11]:
dataset = GameDataset(data_list)
print(len(dataset))


609070


In [12]:
data_loader = DataLoader(dataset, batch_size=100, shuffle=True, drop_last=False)

In [13]:
policy_value_net.train()

loss_history = []

for epoch in range(20):
    p_bar = tqdm(enumerate(data_loader, 0), ncols=80)
    for i, data in p_bar:
        p_bar.set_description(f"Epoch {epoch + 1}: Batch {i + 1}")

        feature_planes, pi, z = data
        feature_planes, pi, z = feature_planes.to(device), pi.to(device), z.to(device)

        # 前馈
        p_hat, value = policy_value_net(feature_planes)
        # 梯度清零
        optimizer.zero_grad()
        # 计算损失
        loss = criterion(p_hat.float(), pi.float(), value.flatten().float(), z.float())
        # 误差反向传播
        loss.backward()
        # 更新参数
        optimizer.step()
        # 学习率退火
        lr_scheduler.step()

    print(f"Epoch {epoch + 1} Loss: {loss.item():.4f}")
    loss_history.append(loss.item())


Epoch 1


Batch 6091: : 6091it [02:01, 50.12it/s]


Epoch 1 Loss: 2.7633
Epoch 2


Batch 6091: : 6091it [01:47, 56.77it/s]


Epoch 2 Loss: 2.9978
Epoch 3


Batch 6091: : 6091it [01:49, 55.64it/s]


Epoch 3 Loss: 2.8063
Epoch 4


Batch 6091: : 6091it [01:49, 55.55it/s]


Epoch 4 Loss: 3.1907
Epoch 5


Batch 6091: : 6091it [01:59, 51.14it/s]


Epoch 5 Loss: 2.8691
Epoch 6


Batch 6091: : 6091it [01:55, 52.67it/s]


Epoch 6 Loss: 2.6527
Epoch 7


Batch 6091: : 6091it [01:46, 57.41it/s]


Epoch 7 Loss: 2.7245
Epoch 8


Batch 6091: : 6091it [01:49, 55.82it/s]


Epoch 8 Loss: 3.1121
Epoch 9


Batch 6091: : 6091it [01:44, 58.47it/s]


Epoch 9 Loss: 3.0050
Epoch 10


Batch 6091: : 6091it [01:44, 58.13it/s]


Epoch 10 Loss: 2.9269
Epoch 11


Batch 6091: : 6091it [01:39, 60.91it/s]


Epoch 11 Loss: 3.1434
Epoch 12


Batch 6091: : 6091it [01:42, 59.39it/s]


Epoch 12 Loss: 2.6415
Epoch 13


Batch 6091: : 6091it [01:43, 58.87it/s]


Epoch 13 Loss: 3.1103
Epoch 14


Batch 6091: : 6091it [01:44, 58.03it/s]


Epoch 14 Loss: 2.7369
Epoch 15


Batch 6091: : 6091it [01:44, 58.41it/s]


Epoch 15 Loss: 3.0235
Epoch 16


Batch 6091: : 6091it [01:42, 59.19it/s]


Epoch 16 Loss: 2.9932
Epoch 17


Batch 6091: : 6091it [01:42, 59.29it/s]


Epoch 17 Loss: 2.9419
Epoch 18


Batch 6091: : 6091it [01:42, 59.25it/s]


Epoch 18 Loss: 3.4173
Epoch 19


Batch 6091: : 6091it [01:43, 58.70it/s]


Epoch 19 Loss: 2.8582
Epoch 20


Batch 6091: : 6091it [01:44, 58.21it/s]


Epoch 20 Loss: 3.1229


In [14]:
torch.save(policy_value_net, "./data/policy_value_net_20.pth")