In [1]:
from game import * #game imported env
from copy import deepcopy

In [2]:
c_puct = 0.25
N = 3
eps = 1e-5

In [79]:
class PUCTnode(object):
    #after state node with PUCT value, child, parent, recorded
    def __init__(self, parent, player, action):
        self.parent = parent
        self.player = player
        self.action = action
        #non zero value for stabilization
        self.N = 2 #visit count
        self.W = 1 #winning count
        self.P = 0
        self.children = list()
        self.cidx = [-1]*(N*N)
        
    def score(self, c = c_puct): #the PUCT score
        return (self.W/self.N) + c*self.P*np.sqrt(self.parent.N)/(1+self.N)
    
    def best_child(self):
        mchildren = None
        mscore = -1 #score>0
        for child in self.children:
            if child.score() - mscore > eps:
                mchildren = [child]
                mscore = child.score()
            elif abs(child.score() - mscore) < eps:
                mchildren.append(child)
        if mchildren == None:
            return None
        mc = np.random.choice(mchildren)
        return mc

In [98]:
class PUCT_Agent(object):
    def __init__(self):
        pass
    def gen_move(self, env, simulation_cnt = 900):
        tree = PUCT(env)
        for ep in range(simulation_cnt):
            tree.run_cycle()
            #print("cycle\n", ep)
        p = tree.getPi(tree.root)
        return np.argmax(p)
        

class PUCT(object):
    def __init__(self, env, N=N):
        self.n = N
        self.n2 = N*N
        self.renv = deepcopy(env)
        self.simu_env = deepcopy(env)
        self.root = PUCTnode(None, -self.renv.player_to_move(), N*N) #
        self.expand(self.root) #watch out hirozen effect
        self.model = Random_Agent()
    
    def reset(self):
        self.simu_env = deepcopy(self.renv)
    
    def clear(self):
        self.renv = deepcopy(env)
        self.simu_env = deepcopy(self.renv)
        self.root = PUCTnode(None, -1, N*N)
    
    def expand(self, node):
        if len(node.children) != 0:
            raise AlgoError("want to expand already expanded node")
            return 
        for action in self.simu_env.valid_actions():
            node.cidx[action] = len(node.children)
            node.children.append(PUCTnode(node, self.simu_env.player_to_move(), action))
        return node.children != None
    
    def simulate(self):
        env = self.simu_env
        while env.status == -2:
            pos = self.model.gen_move(env) #forget to add color in attributes
            env.take_action(pos)
        return env.status()
    
    def run_cycle(self):
        #reset
        self.reset()
        env = self.simu_env
        #selection 
        cur = self.root
        while cur.best_child() != None:
            cur = cur.best_child()
            env.take_action(cur.action)
        #expansion
        if(self.expand(cur)): #if not end
            pos = self.model.gen_move(env) #forget to add color in attributes
            cur = cur.children[cur.cidx[pos]]
            env.take_action(pos)
            
        #    env.take_action(cur.action)
        #simulation
        winner = self.simulate()
        #backpropagation
        while True:
            cur.N += 1
            if cur.player == winner:
                cur.W += 1
            elif cur.player == -winner:
                cur.W -= 1
            if cur.parent == None:
                break
            else:
                cur = cur.parent
    
    def getPi(self, node):
        c_visit = [0.0 if (node.cidx[i] == -1) else node.children[node.cidx[i]].N for i in range(N*N)]
        print(c_visit)
        #if(node.children):
        #    assert(int(sum(c_visit)-len(node.children)*10) == node.N-10 )
        c_p = [x/node.N for x in c_visit]
        return c_p

In [99]:
pp = PUCT_Agent()
pr = Random_Agent()
pr2 = Random_Agent()
pp2 = PUCT_Agent()
pp3 = PUCT_Agent()
ph = Human_Agent()
g = Game(pp, pp2)
#g2 = Game(pr, pr2)
g3 = Game(pp3, ph)

In [100]:
#raise NameError('HiThere')

In [101]:
play = g.play()

IndexError: list index out of range

In [None]:
print(play)


In [102]:
Game.display(play)

[['o' '.' '.']
 ['.' '.' '.']
 ['.' '.' '.']]
[['o' '.' '.']
 ['.' 'x' '.']
 ['.' '.' '.']]
[['o' '.' 'o']
 ['.' 'x' '.']
 ['.' '.' '.']]
[['o' 'x' 'o']
 ['.' 'x' '.']
 ['.' '.' '.']]
[['o' 'x' 'o']
 ['.' 'x' 'o']
 ['.' '.' '.']]
[['o' 'x' 'o']
 ['.' 'x' 'o']
 ['.' '.' 'x']]
[['o' 'x' 'o']
 ['.' 'x' 'o']
 ['o' '.' 'x']]
[['o' 'x' 'o']
 ['x' 'x' 'o']
 ['o' '.' 'x']]
[['o' 'x' 'o']
 ['x' 'x' 'o']
 ['o' 'o' 'x']]


In [None]:
grr = Game(pr,ph)
a = grr.play()

In [None]:
Game.display(a)