In [1]:
import numpy as np
import FIR_Env
import matplotlib.pyplot as plt
import torch as t
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import random
import tqdm
import os
import time

pygame 2.1.0 (SDL 2.0.16, Python 3.9.16)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
env = FIR_Env.FIR()
train_EPOCH = 10
total_EPOCH = 1000
Mcts_Sim = 200
BATCH_SIZE = 128
num_collect = 10
LR = 5e-4
C = 2
use_cuda = t.cuda.is_available()

In [3]:
class Residual(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=False):
        super(Residual, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)
        if downsample:
            self.downsample_ = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(num_features=out_channels)
            )
        else:
            self.downsample_ = None

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        if self.downsample_:
            residual = self.downsample_(residual)

        out += residual
        out = self.relu(out)

        return out


def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    if first_block:
        assert in_channels == out_channels

    layers = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            layers.append(Residual(in_channels, out_channels, stride=2, downsample=True))
        else:
            layers.append(Residual(out_channels, out_channels))

    return nn.Sequential(*layers)


class ResNet(nn.Module):
    def __init__(self, in_channels=4):
        super(ResNet, self).__init__()

        self.board_size = env.observation_space.shape[0]

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(num_features=64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        self.layer1 = resnet_block(64, 64, 2, first_block=True)
        self.layer2 = resnet_block(64, 128, 2)
        self.layer3 = resnet_block(128, 256, 2)
        self.layer4 = resnet_block(256, 512, 2)
        # x.shape = (batch_size, 512, 3, 3)

        # action policy layers
        self.action_fc1 = nn.Linear(2048, 256)
        self.action_fc2 = nn.Linear(256, env.action_space.n)

        # state value layers
        self.value_fc1 = nn.Linear(2048, 256)
        self.value_fc2 = nn.Linear(256, 1)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = x.view(x.size(0), -1)

        # action policy layers
        action_x = self.relu(self.action_fc1(x))
        action_x = F.softmax(self.action_fc2(action_x), dim=1)

        # state value layers
        value_x = self.relu(self.value_fc1(x))
        value_x = t.tanh(self.value_fc2(value_x))
        
        return action_x, value_x

In [4]:
class MCTS(object):
    def __init__(self, num_MCTS_Sim, net, c_puct=C):
        self.Qsa = {}  # stores Q values for s,a (as defined in the paper)
        self.Nsa = {}  # stores #times edge s,a was visited
        self.Ns = {}  # stores #times board s was visited
        self.Ps = {}  # stores initial policy (returned by neural net)
        self.Es = {}  # stores game.getGameEnded ended for board s
        self.Vs = {}  # stores game.getValidMoves for board s
        self.env = FIR_Env.FIR()
        self.net = net
        if use_cuda:
            self.net.cuda()
        self.num_MCTS_Sim = num_MCTS_Sim
        self.c_puct = c_puct
    
    def swap_state(self, state):
        state = np.array(state)
        temp = np.zeros_like(state)
        temp[0] = state[1]
        temp[1] = state[0]
        temp[2] = state[3]
        temp[3] = state[2]
        return temp
    
    def getActionProb(self, state, temp=1, player=1):
        if player == -1:
            state = self.swap_state(state)
        tt = tqdm.tqdm(range(self.num_MCTS_Sim), desc='MCTS')
        for _ in tt:
            self.search(state, player)
        s = self.env.state2str(state)
        counts = [self.Nsa[(s, a)] if (s, a) in self.Nsa else 0 for a in range(self.env.action_space.n)]
        if temp == 0:
            bestA = np.argmax(counts)
            probs = [0] * len(counts)
            probs[bestA] = 1
            return probs
        counts = [x ** (1. / temp) for x in counts]
        probs = [x / float(sum(counts)) for x in counts]
        return probs    
        
    def search(self, state, player=1):
        s = self.env.state2str(state)
        if s not in self.Es:
            self.Es[s] = self.env.getGameEnded(state, player)
        if self.Es[s] != 0:
            return -self.Es[s]
        if s not in self.Ps:
            if use_cuda:
                state = t.cuda.FloatTensor(state).unsqueeze(0)
            else:
                state = t.FloatTensor(state).unsqueeze(0)
            self.Ps[s], v = self.net(state)
            valids = self.env.get_valid_actions(s)
            self.Ps[s] = self.Ps[s].detach().cpu().numpy()[0] * valids
            sum_Ps_s = np.sum(self.Ps[s])
            if sum_Ps_s > 0:
                self.Ps[s] /= sum_Ps_s
            else:
                self.Ps[s] = self.Ps[s] + valids
                self.Ps[s] /= np.sum(self.Ps[s])
            self.Vs[s] = valids
            self.Ns[s] = 0
            return -v
        valids = self.Vs[s]
        cur_best = -float('inf')
        best_act = -1
        for a in range(self.env.action_space.n):
            if valids[a]:
                if (s, a) in self.Qsa:
                    u = self.Qsa[(s, a)] + self.c_puct * self.Ps[s][a] * np.sqrt(self.Ns[s]) / (1 + self.Nsa[(s, a)])
                else:
                    u = self.c_puct * self.Ps[s][a] * np.sqrt(self.Ns[s] + 1e-8)
                if u > cur_best:
                    cur_best = u
                    best_act = a
        a = best_act
        next_state, next_player = self.env.getNextState(state, a, player)
        v = self.search(next_state, next_player)
        if (s, a) in self.Qsa:
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1
        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1
        self.Ns[s] += 1
        return -v

In [5]:
class Net():
    def __init__(self):
        self.nnet = ResNet()
        if use_cuda:
            self.nnet.cuda()
        self.board_x = env.observation_space.shape[0]
        self.board_y = env.observation_space.shape[1]
        self.action_size = env.action_space.n
        
    
    def train(self, histories):
        optimizer = optim.Adam(self.nnet.parameters(), lr=LR)
        for i in range(train_EPOCH):
            print('epoch: %d' % i)
            self.nnet.train()
            pi_losses = []
            v_losses = []
            batch_cnt = int(len(histories) / BATCH_SIZE)
            tq = tqdm.tqdm(range(batch_cnt), desc='Training Net')
            
            for _ in tq:
                sample_ids = np.random.randint(len(histories), size=BATCH_SIZE)
                states, probs, winner = list(zip(*[histories[i] for i in sample_ids]))
                if use_cuda:
                    states = t.cuda.FloatTensor(states)
                    probs = t.cuda.FloatTensor(probs)
                    winner = t.cuda.FloatTensor(winner)
                else:
                    states = t.FloatTensor(states)
                    probs = t.FloatTensor(probs)
                    winner = t.FloatTensor(winner)
                    

                optimizer.zero_grad()

                p, v = self.nnet(states)

                pi_loss = -t.sum(probs * t.log(p)) / probs.size()[0]
                v_loss = t.sum((v.view(-1) - winner) ** 2) / winner.size()[0]
                
                loss = pi_loss + v_loss
                
                pi_losses.append(pi_loss.item())
                v_losses.append(v_loss.item())
                loss.backward()
                optimizer.step()
                tq.set_postfix(pi_loss=np.mean(pi_losses), v_loss=np.mean(v_losses))
            print('epoch: %d pi_loss: %f, v_loss: %f' % (i+1, np.mean(pi_losses), np.mean(v_losses)))
    
    def save(self, path):
        t.save(self.nnet.state_dict(), path)
        print('save model to %s' % path)
        
        
    def load(self, path):
        if os.path.exists(path):
            self.nnet.load_state_dict(t.load(path))
            print('load model from %s' % path)
        else:
            print('model not exists')

In [6]:
class Coach(object):
    def __init__(self, net):
        self.net = net
        self.mcts = MCTS(Mcts_Sim, self.net.nnet)
        self.train_history = deque(maxlen=5000)
        self.env = FIR_Env.FIR()
    
    def swap_state(self, state):
        state = np.array(state)
        temp = np.zeros_like(state)
        temp[0] = state[1]
        temp[1] = state[0]
        temp[2] = state[3]
        temp[3] = state[2]
        return temp
        
    def execute_episode(self):
        state = self.env.reset()
        self.game_history = []
        cur_player = 1
        while True:
            probs = self.mcts.getActionProb(state, temp=1, player=cur_player)
            action = np.random.choice(len(probs), p=probs)
            next_state, cur_player = self.env.getNextState(state, action, cur_player)
            reward = self.env.getGameEnded(next_state, cur_player)
            temp = self.env.getSymmetries(state, probs)
            for s, p in temp:
                self.game_history.append([s, p])
                
            state = next_state
            if reward != 0:
                for i in range(len(self.game_history)):
                    if i % 2 == 0:
                        self.game_history[i].append(reward)
                    else:
                        self.game_history[i].append(-reward)
                        self.game_history[i][0] = self.swap_state(self.game_history[i][0])
                    
                break
        return self.game_history
    
    
    def learn(self):
        for i in range(total_EPOCH):
            for j in range(num_collect):
                print('epoch: %d, collect: %d' % (i+1, j+1))
                self.mcts = MCTS(Mcts_Sim, self.net.nnet)
                self.train_history.extend(self.execute_episode())
                # 清除mcts
                self.mcts = None
                print('history size: %d' % len(self.train_history))
            
            random.shuffle(self.train_history)
            
            self.net.train(self.train_history)
            # 保存模型，三个轮流
            if i % 3 == 0:
                self.net.save('saved_model1.pth')
            elif i % 3 == 1:
                self.net.save('saved_model2.pth')
            else:
                self.net.save('saved_model3.pth')
            
            
            

In [None]:
net = Net()
# 加载模型
load_model = True
if load_model:
    newest_time = 0
    newest_model = ""
    for i in range(1, 4):
        if os.path.exists('saved_model%d.pth' % i):
            ti = os.path.getmtime('saved_model%d.pth' % i)
            if ti > newest_time:
                newest_time = ti
                newest_model = 'saved_model%d.pth' % i
    if newest_model != "":
        net.load(newest_model)
coach = Coach(net)
coach.learn()

epoch: 1, collect: 1


MCTS: 100%|██████████| 200/200 [00:02<00:00, 66.86it/s] 
MCTS: 100%|██████████| 200/200 [00:01<00:00, 106.69it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 96.55it/s] 
MCTS: 100%|██████████| 200/200 [00:01<00:00, 104.86it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 95.14it/s]
MCTS: 100%|██████████| 200/200 [00:01<00:00, 101.32it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 91.72it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 88.54it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 92.33it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 93.05it/s] 
MCTS: 100%|██████████| 200/200 [00:02<00:00, 70.09it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 68.98it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 62.54it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 65.32it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 69.85it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 69.42it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 69.03it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:0

history size: 560
epoch: 1, collect: 2


MCTS: 100%|██████████| 200/200 [00:04<00:00, 47.70it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 60.39it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 54.13it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 61.77it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 60.02it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 63.32it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 60.91it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 59.27it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 64.74it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 65.10it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 65.67it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 66.51it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 59.34it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 65.27it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 56.47it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 67.38it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 60.17it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 62.

history size: 1208
epoch: 1, collect: 3


MCTS: 100%|██████████| 200/200 [00:03<00:00, 52.80it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 55.92it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 61.68it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 57.72it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 61.32it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 62.24it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 57.93it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 57.90it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 54.24it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 58.02it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 57.98it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 56.19it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 53.83it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 53.33it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 57.00it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 57.61it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 54.99it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 56.

history size: 1464
epoch: 1, collect: 4


MCTS: 100%|██████████| 200/200 [00:03<00:00, 56.30it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 54.04it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 54.12it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 55.12it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 57.72it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 54.07it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 54.68it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 54.59it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 56.97it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 54.19it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 60.20it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 54.68it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 55.99it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 57.37it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 58.06it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 59.49it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 60.23it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 60.

history size: 1944
epoch: 1, collect: 5


MCTS: 100%|██████████| 200/200 [00:03<00:00, 55.01it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 58.62it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 54.88it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 54.03it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 58.21it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 60.55it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 60.16it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 58.05it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 60.40it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 56.58it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 52.76it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 58.53it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 54.88it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 56.51it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 54.09it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 56.72it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 54.89it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 55.

history size: 2536
epoch: 1, collect: 6


MCTS: 100%|██████████| 200/200 [00:03<00:00, 60.83it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 66.12it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 58.77it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 68.94it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 70.19it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 69.99it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 71.04it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 69.53it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 65.31it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 65.81it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 67.34it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 67.06it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 64.70it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 65.30it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 64.84it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 62.55it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 58.66it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 57.

history size: 3000
epoch: 1, collect: 7


MCTS: 100%|██████████| 200/200 [00:03<00:00, 55.51it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 62.20it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 67.61it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 61.29it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 69.22it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 55.77it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 65.68it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 60.30it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 64.14it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 59.84it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 54.44it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 56.25it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 56.43it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 55.75it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 55.89it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 58.64it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 60.23it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 62.

history size: 3560
epoch: 1, collect: 8


MCTS: 100%|██████████| 200/200 [00:03<00:00, 57.58it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 61.73it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 60.32it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 65.45it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 65.62it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 65.19it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 60.32it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 63.82it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 63.05it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 64.89it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 57.37it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 58.20it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 57.05it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 59.18it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 57.11it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 57.92it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 53.17it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 58.

history size: 3864
epoch: 1, collect: 9


MCTS: 100%|██████████| 200/200 [00:03<00:00, 66.33it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 70.35it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 67.29it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 63.51it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 62.96it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 59.03it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 74.97it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 73.68it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 70.98it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 65.41it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 58.47it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 62.27it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 71.31it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 70.86it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 70.30it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 64.77it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 69.83it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 71.

history size: 4408
epoch: 1, collect: 10


MCTS: 100%|██████████| 200/200 [00:02<00:00, 69.17it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 72.22it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 71.31it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 75.86it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 71.25it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 80.57it/s]
MCTS: 100%|██████████| 200/200 [00:03<00:00, 66.66it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 70.38it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 72.89it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 75.44it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 69.37it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 69.92it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 70.14it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 70.68it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 75.17it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 74.27it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 69.23it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 69.

history size: 4872
epoch: 0


  states = t.cuda.FloatTensor(states)
Training Net: 100%|██████████| 38/38 [00:14<00:00,  2.64it/s, pi_loss=4.4, v_loss=1.97] 


epoch: 1 pi_loss: 4.395950, v_loss: 1.972605
epoch: 1


Training Net: 100%|██████████| 38/38 [00:13<00:00,  2.82it/s, pi_loss=4.37, v_loss=1.99]


epoch: 2 pi_loss: 4.366600, v_loss: 1.993421
epoch: 2


Training Net: 100%|██████████| 38/38 [00:13<00:00,  2.91it/s, pi_loss=4.32, v_loss=2.01]


epoch: 3 pi_loss: 4.321607, v_loss: 2.008224
epoch: 3


Training Net: 100%|██████████| 38/38 [00:13<00:00,  2.92it/s, pi_loss=4.28, v_loss=1.99]


epoch: 4 pi_loss: 4.275049, v_loss: 1.988487
epoch: 4


Training Net: 100%|██████████| 38/38 [00:13<00:00,  2.90it/s, pi_loss=4.23, v_loss=1.97]


epoch: 5 pi_loss: 4.231813, v_loss: 1.966283
epoch: 5


Training Net: 100%|██████████| 38/38 [00:13<00:00,  2.91it/s, pi_loss=4.17, v_loss=2]   


epoch: 6 pi_loss: 4.165147, v_loss: 1.995066
epoch: 6


Training Net: 100%|██████████| 38/38 [00:13<00:00,  2.86it/s, pi_loss=4.11, v_loss=2.03]


epoch: 7 pi_loss: 4.108300, v_loss: 2.032895
epoch: 7


Training Net: 100%|██████████| 38/38 [00:13<00:00,  2.85it/s, pi_loss=4.04, v_loss=1.98]


epoch: 8 pi_loss: 4.038033, v_loss: 1.979441
epoch: 8


Training Net: 100%|██████████| 38/38 [00:13<00:00,  2.91it/s, pi_loss=3.95, v_loss=2.02]


epoch: 9 pi_loss: 3.952951, v_loss: 2.018914
epoch: 9


Training Net: 100%|██████████| 38/38 [00:13<00:00,  2.88it/s, pi_loss=3.86, v_loss=2.04]


epoch: 10 pi_loss: 3.858413, v_loss: 2.040296
save model to saved_model1.pth
epoch: 2, collect: 1


MCTS: 100%|██████████| 200/200 [00:02<00:00, 81.95it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 79.50it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 85.73it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 85.64it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 84.69it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 86.90it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 88.58it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 83.56it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 80.05it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 82.35it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 76.95it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 77.78it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 77.48it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 75.09it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 76.80it/s]
MCTS: 100%|██████████| 200/200 [00:00<00:00, 338.65it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 79.28it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 80

history size: 5000
epoch: 2, collect: 2


MCTS: 100%|██████████| 200/200 [00:02<00:00, 85.92it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 85.46it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 83.26it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 92.15it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 90.23it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 84.34it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 78.60it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 82.24it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 71.05it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 82.50it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 79.90it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 75.81it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 81.40it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 83.82it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 80.48it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 81.51it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 78.12it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 78.

history size: 5000
epoch: 2, collect: 3


MCTS: 100%|██████████| 200/200 [00:02<00:00, 77.41it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 86.94it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 85.71it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 83.90it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 89.78it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 87.78it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 89.74it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 89.69it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 86.22it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 84.81it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 81.61it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 81.31it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 77.85it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 78.39it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 80.69it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 82.78it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 79.39it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 81.

history size: 5000
epoch: 2, collect: 4


MCTS: 100%|██████████| 200/200 [00:02<00:00, 83.17it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 83.02it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 80.88it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 85.36it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 91.07it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 83.99it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 83.22it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 81.14it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 84.84it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 82.05it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 88.93it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 73.08it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 74.02it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 80.04it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 77.06it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 75.94it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 75.32it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 76.

history size: 5000
epoch: 2, collect: 5


MCTS: 100%|██████████| 200/200 [00:02<00:00, 79.29it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 85.00it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 84.61it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 80.64it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 80.35it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 81.50it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 83.64it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 84.14it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 88.90it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 90.11it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 87.59it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 82.35it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 83.08it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 78.85it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 77.58it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 77.28it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 76.70it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 82.

history size: 5000
epoch: 2, collect: 6


MCTS: 100%|██████████| 200/200 [00:02<00:00, 81.32it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 85.06it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 85.32it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 79.67it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 79.59it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 82.66it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 88.51it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 82.68it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 85.02it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 79.60it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 73.89it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 80.60it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 80.38it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 80.74it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 77.93it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 75.98it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 77.24it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 69.

history size: 5000
epoch: 2, collect: 7


MCTS: 100%|██████████| 200/200 [00:02<00:00, 77.56it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 80.18it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 83.59it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 73.57it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 74.10it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 74.40it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 74.18it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 75.33it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 76.42it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 73.08it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 72.82it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 71.52it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 73.35it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 70.76it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 72.40it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 67.41it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 69.78it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 69.

history size: 5000
epoch: 2, collect: 8


MCTS: 100%|██████████| 200/200 [00:02<00:00, 70.40it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 78.33it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 82.39it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 76.15it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 78.37it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 75.55it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 78.62it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 77.85it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 79.28it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 81.04it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 75.24it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 75.20it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 78.17it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 67.92it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 68.68it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 73.09it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 74.64it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 73.

history size: 5000
epoch: 2, collect: 9


MCTS: 100%|██████████| 200/200 [00:02<00:00, 79.49it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 78.96it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 73.94it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 69.91it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 70.44it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 73.92it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 78.03it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 81.09it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 76.25it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 79.97it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 81.96it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 78.02it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 80.19it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 76.90it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 72.98it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 76.71it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 69.90it/s]
MCTS: 100%|██████████| 200/200 [00:02<00:00, 70.

In [None]:
import pygame

net = Net()
newest_time = 0
newest_model = ""
for i in range(1, 4):
    if os.path.exists('saved_model%d.pth' % i):
        ti = os.path.getmtime('saved_model%d.pth' % i)
        if ti > newest_time:
            newest_time = ti
            newest_model = 'saved_model%d.pth' % i
if newest_model != "":
    net.load(newest_model)
# 测试模型，与真人对战
mcts = MCTS(800, net.nnet)
env = FIR_Env.FIR()
state = env.reset()
cur_player = 1
while True:
    env.render(state)
    if cur_player == 1: # 人类玩家
        action = env.mouse_action(state)
    else:
        probs = mcts.getActionProb(state, temp=0, player=cur_player)
        action = np.argmax(probs)
    next_state, cur_player = env.getNextState(state, action, cur_player)
    env.render(next_state)
    reward = env.getGameEnded(next_state, cur_player)
    state = next_state
    if reward != 0:
        pygame.time.delay(1000)
        pygame.quit()
        print('winner: %d' % reward)
        break