In [1]:
import os

import torch
import torch.optim as optim
import torch.nn.functional as F

import Othello

from tqdm.auto import trange

from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# Policy Gradient

$$
\max J(\pi)=\int_\tau p(\tau \mid \pi)R(\tau) =\mathbb E_{\tau\sim\pi}[R(\tau)]
$$

$$
\begin{align}\nabla_\theta J(\pi_\theta) &= \int_\tau \nabla_\theta p(\tau \mid \pi_\theta)R(\tau) \\
&= \int_\tau  p(\tau \mid \pi_\theta)\frac {\nabla_\theta p(\tau \mid \pi_\theta)}{ p(\tau \mid \pi_\theta)}R(\tau) \\
&= \mathbb E_{\tau\sim\pi_\theta}\left[\nabla_\theta \ln p(\tau \mid \pi_\theta)R(\tau)\right] \\ 
&= \mathbb E_{\tau\sim\pi_\theta}\left[\nabla_\theta \left(\ln \rho_0(S_0) + \sum_{t=0}^T\ln p(S_{t+1} \mid S_t,A_t) + \sum_{t=0}^T\ln \pi_\theta(A_t \mid S_t)\right)R(\tau)\right] \\ 
\nabla_\theta J(\pi_\theta) &= \mathbb E_{\tau\sim\pi_\theta}\left[ \sum_{t=0}^T \nabla_\theta \ln \pi_\theta(A_t \mid S_t)R(\tau)\right] \\\end{align}
$$
					

In [2]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- game settings ---
n = 8  # board size handed to Othello.Envs / Othello.Actor
aim = "win"  # training objective passed to Loss(); the alternative is "loss"

# --- model settings ---
bone = "ResNet"  # backbone name handed to Othello.Actor

# --- training hyper-parameters ---
n_epoch = 100000  # number of optimisation steps (one batch of games each)
batch_size = 32  # games played in parallel by the environment
lr = 1e-3  # Adam learning rate

# --- policy-gradient coefficients ---
beta = 0.5  # weight of the KL term in the loss

In [3]:
# Batched game environment built from the board size, batch size and device
# chosen in the configuration cell.
envs = Othello.Envs(n, batch_size, device)

# Policy network: build it first, then move it to the target device.
actor = Othello.Actor(n, bone)
actor = actor.to(device)

# Adam over every parameter of the policy network.
opt = optim.Adam(params=actor.parameters(), lr=lr)

In [7]:
# Start time of this run; used to keep log and checkpoint folders of different
# runs apart (dashes instead of colons keep it usable in a path).
TIMESTAMP = format(datetime.now(), "%Y-%m-%dT%H-%M-%S")

# TensorBoard writer: one folder per configuration, one sub-folder per run.
writer = SummaryWriter(
    log_dir=f"tf_dir/n={n}_beta={beta}_bone={bone}_{aim}/{TIMESTAMP}/"
)

In [8]:
def Loss(agents, envs, n, beta=0.1, aim="win"):
    """Play one batch of games and build the policy-gradient loss with a KL term.

    Each of the two players moves ``n * (n + 1) // 2`` times. On every move the
    mover's output is multiplied by the valid-move mask ``obs[-1]``, one action
    per game is sampled from the masked weights, and the (masked) probability of
    the sampled action is stored so it can be differentiated later.

    Args:
        agents: either one callable used for both players (self-play) or a list
            of exactly two callables; index 0 is the player reported as "black"
            in the returned details. Each is called as ``agent(obs)`` and must
            return one non-negative weight per action, shape
            ``(batch, n * n + 1)``.
        envs: batched environment exposing ``batch_size``, ``reset()`` and
            ``step(one_hot_action, restart=False)``. Both calls return a 5-tuple
            whose first three items are ``(obs, reward, terminated)``; ``obs[-1]``
            is the valid-move mask.
        n: board size; the action space has ``n * n + 1`` entries.
        beta: weight of the KL term in the total loss.
        aim: ``"win"`` to increase the probability of moves followed by positive
            return, ``"loss"`` to decrease it.

    Returns:
        ``(loss, detail)`` where ``loss`` is a scalar tensor (policy-gradient term
        plus ``beta`` times the KL term, summed over batch, players and moves)
        and ``detail`` is ``{"PGloss": float, "KLloss": float}`` holding the
        per-game averages of both terms for player 0 only.

    Raises:
        ValueError: if ``aim`` is neither ``"win"`` nor ``"loss"``. This is
            checked before any game is played.
    """
    # Fail fast: no point in playing a full batch of games for an invalid aim.
    if aim not in ("win", "loss"):
        raise ValueError("aim must be 'win' or 'loss'")

    if isinstance(agents, list):
        assert len(agents) == 2, "Othello requires two player"
    else:
        agents = [agents, agents]  # self-playing game

    # Buffers live on the device of the (last) agent that is an nn.Module with
    # parameters; agents that are plain callables contribute nothing here.
    device = None
    for agent in agents:
        if isinstance(agent, torch.nn.Module):
            first_param = next(agent.parameters(), None)
            if first_param is not None:
                device = first_param.device

    batch_size = envs.batch_size
    rounds = n * (n + 1)  # total number of moves played per game (both players)
    turns = rounds // 2  # moves per player
    n_actions = n * n + 1

    obs, reward, terminated, _, _ = envs.reset()

    # No parametrised module among the agents: use the device of the observations.
    if device is None:
        device = obs[-1].device

    rewards = torch.zeros(batch_size, 2, turns, device=device)
    terminateds = torch.zeros(batch_size, 2, turns, device=device)
    probs = torch.zeros(batch_size, 2, turns, device=device)  # prob of the chosen action
    KLs = torch.zeros(batch_size, 2, turns, device=device)  # KL to the reference distribution

    for i in range(turns):
        for p in range(2):

            valid = obs[-1]

            # Zero out invalid moves (the result is NOT re-normalised).
            prob = agents[p](obs).float() * valid

            # Sample one action per game from the masked weights and keep the
            # probability of exactly that action.
            sampled = torch.multinomial(prob, num_samples=1)  # (batch, 1)
            probs[:, p, i] = prob.gather(-1, sampled).view(-1)

            # Reward / terminated flag as observed by this player before moving.
            rewards[:, p, i] = reward
            terminateds[:, p, i] = terminated

            # Reference distribution: (almost) uniform over the valid moves and
            # 1e-6 on every invalid move so that its log stays finite.
            PriDis = ((1 - (valid == 0).sum(1) * 1e-6) / valid.sum(1))[:, None] * valid + 1e-6 * (valid == 0)
            # KL(prob || PriDis); 1e-6 inside the log prevents -inf at prob == 0.
            KLs[:, p, i] = (prob * (torch.log(prob + 1e-6) - torch.log(PriDis))).sum(1)

            # Advance the environment with the one-hot encoded action.
            act = F.one_hot(sampled.view(-1), num_classes=n_actions)
            obs, reward, terminated, _, _ = envs.step(act, restart=False)

    # Undiscounted reward-to-go, computed backwards for both players at once:
    # R[i] = reward[i] + (1 - terminated[i]) * R[i + 1].
    R = torch.zeros(batch_size, 2, turns, device=device)
    future = torch.zeros(batch_size, 2, device=device)
    for i in range(turns - 1, -1, -1):
        R[:, :, i] = rewards[:, :, i] + (1 - terminateds[:, :, i]) * future
        future = R[:, :, i]

    # Policy-gradient term; the sign decides whether high return is sought or avoided.
    sign = -1.0 if aim == "win" else 1.0
    PGs = sign * R.detach() * torch.log(probs)

    total = PGs.sum() + beta * KLs.sum()

    # Total loss plus the per-game average of both terms for player 0 (black).
    detail = {
        "PGloss": PGs[:, 0, :].sum().item() / batch_size,
        "KLloss": KLs[:, 0, :].sum().item() / batch_size,
    }
    return total, detail

In [9]:
@torch.no_grad()
def Test(agents, envs, n, times=10):
    """Estimate win/tie rates by letting two agents play greedily.

    Plays ``times`` batches of games. On every move the mover's output is
    multiplied by the valid-move mask ``obs[-1]`` and the highest-weighted
    action is played (no sampling). A player counts as the winner of a game
    when the rewards it observed over that game sum to a positive value.

    Args:
        agents: either one callable used for both players (self-play) or a list
            of exactly two callables ``[black, white]``. Each is called as
            ``agent(obs)`` and must return one weight per action, shape
            ``(batch, n * n + 1)``.
        envs: batched environment exposing ``batch_size``, ``reset()`` and
            ``step(one_hot_action, restart=False)``. Both calls return a 5-tuple
            whose first two items are ``(obs, reward)``; ``obs[-1]`` is the
            valid-move mask.
        n: board size; the action space has ``n * n + 1`` entries.
        times: number of batches to play; must be at least 1.

    Returns:
        Dict with the float entries ``"Prob of black win"``,
        ``"Prob of white win"`` and ``"Prob of tie"`` (which sum to 1).

    Raises:
        ValueError: if ``times`` is smaller than 1.
    """
    if times < 1:
        raise ValueError("times must be at least 1")

    if isinstance(agents, list):
        assert len(agents) == 2, "Othello requires two player"
    else:
        agents = [agents, agents]  # self-playing game

    # Buffers live on the device of the (last) agent that is an nn.Module with
    # parameters; agents that are plain callables contribute nothing here.
    device = None
    for agent in agents:
        if isinstance(agent, torch.nn.Module):
            first_param = next(agent.parameters(), None)
            if first_param is not None:
                device = first_param.device

    batch_size = envs.batch_size
    rounds = n * (n + 1)  # total number of moves played per game (both players)
    turns = rounds // 2  # moves per player

    blackWin = 0
    whiteWin = 0

    for k in range(times):

        obs, reward, _, _, _ = envs.reset()

        # No parametrised module among the agents (e.g. random vs. random):
        # use the device of the observations.
        if device is None:
            device = obs[-1].device

        rewards = torch.zeros(batch_size, 2, turns, device=device)

        for i in range(turns):
            for p in range(2):

                valid = obs[-1]

                # Zero out invalid moves.
                prob = agents[p](obs).float() * valid

                # Greedy play: take the move with the largest masked weight.
                _, best = prob.max(-1)

                # Reward as observed by this player before moving.
                rewards[:, p, i] = reward

                # Advance the environment with the one-hot encoded action.
                act = F.one_hot(best.view(-1), num_classes=n * n + 1)
                obs, reward, _, _, _ = envs.step(act, restart=False)

        # Fraction of games in this batch with a positive reward sum per player.
        black_win = (rewards[:, 0].sum(-1) > 0).sum() / batch_size
        white_win = (rewards[:, 1].sum(-1) > 0).sum() / batch_size

        blackWin += black_win / times
        whiteWin += white_win / times

    tie = 1 - blackWin - whiteWin

    return {"Prob of black win": blackWin.item(), "Prob of white win": whiteWin.item(), "Prob of tie": tie.item()}

In [10]:
# Running sums of the training losses since the last TensorBoard flush, and the
# number of optimisation steps they cover (so the logged value is a true mean).
res_train = {'PGloss': 0, 'KLloss': 0}
n_accumulated = 0

# Checkpoint 40 times per run; never let the interval drop to 0 for short runs
# (n_epoch < 40 would otherwise divide by zero in the modulo below).
save_every = max(1, n_epoch // 40)
ckpt_dir = f"model/n={n}_beta={beta}_bone={bone}/{TIMESTAMP}"

for k in trange(n_epoch):

    # --- one self-play optimisation step ---
    actor = actor.train()

    loss, detail = Loss(actor, envs, n, beta=beta, aim=aim)

    opt.zero_grad()
    loss.backward()
    opt.step()

    res_train['PGloss'] += detail['PGloss']
    res_train['KLloss'] += detail['KLloss']
    n_accumulated += 1

    # --- every 100 steps: evaluate against the random player and log ---
    if k % 100 == 0 and k > 0:

        actor = actor.eval()

        # (headline, [black, white], TensorBoard tag, result entry to log)
        matchups = (
            ('Black Agent, White Random:', [actor, envs.randomAction],
             "Agent Black Win", "Prob of black win"),
            ('Black Random, White Agent:', [envs.randomAction, actor],
             "Agent White Win", "Prob of white win"),
        )

        for headline, players, tag, logged_entry in matchups:
            res = Test(players, envs, n, times=10)
            print(headline)
            for key, value in res.items():
                print(f"{key}: {value:7.2%}", end='; ')
            print('\n')

            writer.add_scalar(tag, res[logged_entry], k)

        # Mean training losses over the steps since the previous flush.
        for key, value in res_train.items():
            writer.add_scalar(key, value / n_accumulated, k)
            res_train[key] = 0
        n_accumulated = 0

    # --- periodic checkpoint ---
    # NOTE(review): the last step (k == n_epoch - 1) is generally not a multiple
    # of save_every, so the final weights are not written by this loop.
    if k % save_every == 0:
        os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(actor.state_dict(), f"{ckpt_dir}/epoch{k}.pth")

# Make sure the last scalars reach disk even if the kernel is stopped afterwards.
writer.flush()

  0%|          | 0/100000 [00:00<?, ?it/s]

Black Agent, White Random:
Prob of black win:  85.00%; Prob of white win:  13.13%; Prob of tie:   1.88%; 

Black Random, White Agent:
Prob of black win:  10.63%; Prob of white win:  86.88%; Prob of tie:   2.50%; 

Black Agent, White Random:
Prob of black win:  85.63%; Prob of white win:  12.81%; Prob of tie:   1.56%; 

Black Random, White Agent:
Prob of black win:  11.56%; Prob of white win:  85.94%; Prob of tie:   2.50%; 

Black Agent, White Random:
Prob of black win:  84.38%; Prob of white win:  12.19%; Prob of tie:   3.44%; 

Black Random, White Agent:
Prob of black win:   9.06%; Prob of white win:  88.75%; Prob of tie:   2.19%; 

Black Agent, White Random:
Prob of black win:  88.12%; Prob of white win:   9.38%; Prob of tie:   2.50%; 

Black Random, White Agent:
Prob of black win:   6.25%; Prob of white win:  92.19%; Prob of tie:   1.56%; 

Black Agent, White Random:
Prob of black win:  84.69%; Prob of white win:  12.81%; Prob of tie:   2.50%; 

Black Random, White Agent:
Prob of bl