In [1]:
import os

import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import Othello

from tqdm.auto import trange

from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

from collections import deque

# Proximal Policy Optimization

$$
\begin{align}
\nabla_\theta J(\pi_\theta) 
&= \int_\tau \nabla_\theta p(\tau \mid \pi_\theta)R(\tau) \\
%
&= \int_\tau \frac{p(\tau \mid \pi_{\theta_{old}})}{p(\tau \mid \pi_{\theta_{old}})} \nabla_\theta p(\tau \mid \pi_\theta)R(\tau) \\
%
&= \int_\tau p(\tau \mid \pi_{\theta_{old}})\frac{\nabla_\theta p(\tau \mid \pi_\theta)}{p(\tau \mid \pi_{\theta_{old}})} R(\tau) \\
&= E_{\tau\sim\pi_{\theta_{old}}}[\frac{\nabla_\theta p(\tau \mid \pi_\theta)}{p(\tau \mid \pi_{\theta_{old}})}R(\tau)] \\
&= E_{\tau\sim\pi_{\theta_{old}}}[\frac{\nabla_\theta\rho_0(S_0) \prod_{t=0}^Tp(S_{t+1} \mid S_t,A_t)\pi_\theta(A_t \mid S_t)}{\rho_0(S_0) \prod_{t=0}^Tp(S_{t+1} \mid S_t,A_t)\pi_{\theta_{old}}(A_t \mid S_t)}R(\tau)] \\
&= E_{\tau\sim\pi_{\theta_{old}}}[\frac{ \nabla_\theta\prod_{t=0}^T\pi_\theta(A_t \mid S_t)}{ \prod_{t=0}^T\pi_{\theta_{old}}(A_t \mid S_t)}R(\tau)] \\
\nabla_\theta J(\pi_\theta)&\approx E_{\tau\sim\pi_{\theta_{old}}}[\sum_{t=0}^T \frac{ \nabla_\theta\pi_\theta(A_t \mid S_t)}{\pi_{\theta_{old}}(A_t \mid S_t)}R(\tau)]
\end{align}
$$

Note the final $\approx$: replacing the ratio of products with a sum of per-step ratios is only accurate when $\pi_{\theta}(A_t \mid S_t) \approx \pi_{\theta_{old}}(A_t \mid S_t)$. PPO must therefore keep the new policy within a small distance of the old policy.

The PPO paper tried two ways of enforcing this: a KL-divergence penalty, and direct clipping of the probability ratio. The latter proved simpler and more effective.

$$
\nabla_\theta J^{CLIP}(\pi_\theta) = \begin{cases}
E_{\tau\sim\pi_{\theta_{old}}}[\sum_{t=0}^T \nabla_\theta \min(r_t(\theta), 1+\epsilon) R(\tau)], & \text{if }R(\tau)\ge0\\
E_{\tau\sim\pi_{\theta_{old}}}[\sum_{t=0}^T \nabla_\theta \max(r_t(\theta), 1-\epsilon)R(\tau)], & \text{if }R(\tau)<0\\
\end{cases}
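$$

where $r_t(\theta) = \dfrac{\pi_\theta(A_t \mid S_t)}{\pi_{\theta_{old}}(A_t \mid S_t)}$ is the probability ratio and $\epsilon$ is the clipping threshold. The implementation below departs from this formula in two ways. It uses each move's reward-to-go in place of $R(\tau)$. It also adds a KL penalty, weighted by $\beta$, towards a uniform distribution over the legal moves.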


In [2]:
# Run on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed PyTorch's RNGs (all devices) so self-play sampling and weight
# initialisation follow the same random stream on every Restart & Run All
# (cuDNN kernels may still introduce some non-determinism).
seed = 0
torch.manual_seed(seed)

# the setting of game
n = 8  # board size: boards are n x n
aim = "win"  # "win": train the agent to win; "loss": train it to lose

# model setting
bone = "ResNet"  # ResNet for lightweight model, ViT for large model
config = {"dims": [1, 64, 128, 256, 128, 64, 1], "kernel_size": 5}

# train parameters (no critic / value network is trained in this notebook)
n_epoch = 40000        # number of collect-then-optimise iterations
batch_size_envs = 128  # games played in parallel per data-collection round
batch_size_ppo = 2048  # minibatch size for the PPO update
lr_actor = 1e-4

# PPO parameters
reuse_time = 2  # passes over each freshly collected buffer
epsilon = 0.1   # the thresholds of clip
beta = 0.5      # the weight of KL

path = f"PPO_n={n}_epsilon={epsilon}_beta={beta}_bone={bone}_{aim}"

In [3]:
# Batched Othello environments, shared by data collection and evaluation.
envs = Othello.Envs(n, batch_size_envs, device)

# The policy network optimised with PPO. No critic / value network is trained:
# the PPO loss weighs actions by the returns stored in the replay buffer.
actor = Othello.Actor(n, bone, **config)
actor = actor.to(device)

opt = optim.Adam(params=actor.parameters(), lr=lr_actor)

In [4]:
# Pre-trained reference opponent used only for evaluation.
# - .eval(): the baseline is never trained here, so it should always run in
#   inference mode (matters if Othello.Actor contains BatchNorm/Dropout layers,
#   whose training-mode behaviour would otherwise apply during evaluation games).
# - map_location: lets a checkpoint saved on a GPU load on a CPU-only machine.
# load_state_dict stays the last expression so the cell still displays its result.
baseline = Othello.Actor(n, 'ResNet').to(device).eval()
baseline.load_state_dict(
        torch.load(
            f"model/n={n}_beta=0.5_bone=ResNet_{aim}/epoch_bestVsRandom.pth",
            map_location=device,
        )
    )

<All keys matched successfully>

In [5]:
# Timestamped run directory, so repeated runs with the same configuration
# write their TensorBoard logs to separate folders.
TIMESTAMP = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
writer = SummaryWriter(log_dir=f"tf_dir/{path}/{TIMESTAMP}/")

In [6]:
class ReplayBuffer:
    """Bounded buffer of self-play (or two-policy) Othello experience for PPO.

    Every collection round plays ``n * (n + 1)`` alternating moves in each of the
    ``envs.batch_size`` environments and appends one
    ``(board, valid, action, prob, return)`` tuple per move to ``self.data``.
    The deque keeps exactly the entries of the most recent ``batch_num`` rounds.

    Args:
        envs: batched environment exposing ``n``, ``batch_size``, ``reset()`` and
            ``step(one_hot_action, restart=False)``; both calls return
            ``(obs, reward, terminated, _, _)`` with ``obs == (boards, valids)``.
        agents: a single policy used for both players (self-play) or a list of
            two policies ``[first_player, second_player]``. A policy is any
            callable mapping ``obs`` to per-move probabilities; ``nn.Module``
            policies are switched to eval mode before each collection round.
        batch_num: number of collection rounds kept; that many rounds are also
            collected immediately on construction.
    """

    def __init__(self, envs, agents, batch_num):
        self.envs = envs
        self.n = envs.n
        self.batch_size = envs.batch_size

        self.agents = agents
        if isinstance(agents, list):
            assert len(agents) == 2, "Othello requires two players"
        else:
            self.agents = [agents, agents]  # self-playing game

        # Buffers are allocated on the device of the (last) nn.Module agent.
        # When no agent is a module, the observations' device is used instead.
        self.device = None
        for agent in self.agents:
            if isinstance(agent, torch.nn.Module):
                self.device = next(agent.parameters()).device

        self.rounds = self.n * (self.n + 1)  # moves simulated per round (both players)
        self.batch_num = batch_num

        # One entry per (game, player, move); sized to hold exactly `batch_num` rounds.
        self.data = deque(maxlen=batch_num * self.batch_size * self.rounds)  # S A P R

        for _ in range(batch_num):
            self.__update()

    @torch.no_grad()
    def __update(self):
        """Play one round of games and append every move to ``self.data``."""
        boards, valids, actions, probs, rewards, terminateds = self._rollout()
        returns = self._returns_to_go(rewards, terminateds)

        columns = [
            t.flatten(0, 2).detach().cpu()
            for t in (boards, valids, actions, probs, returns)
        ]
        # Iterating a tensor yields its rows, so this appends one
        # (board, valid, action, prob, return) tuple per move.
        self.data.extend(zip(*columns))

    @torch.no_grad()
    def _rollout(self):
        """Play ``rounds`` alternating moves in every environment of the batch.

        Returns six tensors indexed ``[game, player, move, ...]``: boards, valid
        masks, sampled actions, the policy probability of each sampled action,
        and the reward / terminated flag the environment reported just *before*
        that move was played.
        """
        for agent in self.agents:
            # Plain callables (e.g. a random policy) have no train/eval mode.
            if isinstance(agent, torch.nn.Module):
                agent.eval()

        batch_size = self.batch_size
        half = self.rounds // 2
        num_moves = self.n * self.n + 1  # one per square plus one extra action (presumably "pass")

        obs, reward, terminated, _, _ = self.envs.reset()
        device = self.device if self.device is not None else obs[0].device

        boards = torch.zeros((batch_size, 2, half, self.n, self.n), device=device)
        valids = torch.zeros((batch_size, 2, half, num_moves), device=device)  # boards + valids == S
        actions = torch.zeros(batch_size, 2, half, device=device, dtype=torch.int64)
        rewards = torch.zeros(batch_size, 2, half, device=device)
        terminateds = torch.zeros(batch_size, 2, half, device=device)
        probs = torch.zeros(batch_size, 2, half, device=device)

        for i in range(half):
            for p in range(2):
                boards[:, p, i] = obs[0]
                valids[:, p, i] = obs[1]

                # Zero the probability of invalid moves, then sample from the rest
                # (torch.multinomial accepts unnormalised non-negative weights).
                prob = self.agents[p](obs).float() * obs[1]
                action = torch.multinomial(prob, num_samples=1).view(-1)
                actions[:, p, i] = action
                probs[:, p, i] = prob.gather(1, action.view(-1, 1)).view(-1)

                # Reward / termination observed before this move.
                rewards[:, p, i] = reward
                terminateds[:, p, i] = terminated

                act = F.one_hot(action, num_classes=num_moves)
                obs, reward, terminated, _, _ = self.envs.step(act, restart=False)

        return boards, valids, actions, probs, rewards, terminateds

    @staticmethod
    def _returns_to_go(rewards, terminateds):
        """Undiscounted reward-to-go per player, truncated at terminated steps.

        ``returns[:, p, i] = rewards[:, p, i] + (1 - terminateds[:, p, i]) * returns[:, p, i + 1]``
        """
        returns = torch.zeros_like(rewards)
        running = torch.zeros_like(rewards[:, :, 0])
        for i in range(rewards.shape[2] - 1, -1, -1):
            returns[:, :, i] = rewards[:, :, i] + (1 - terminateds[:, :, i]) * running
            running = returns[:, :, i]
        return returns

    def update(self, n=1):
        """Collect ``n`` more rounds; the oldest entries are evicted by the deque."""
        for _ in range(n):
            self.__update()

    def getDataset(self):
        """Return a map-style ``Dataset`` over a snapshot of the buffer.

        Items are ``((board, valid), action, prob, return)``. The deque is copied
        into a list because ``deque`` indexing is O(n) away from its ends, which
        made shuffled ``DataLoader`` access quadratic in the buffer size.
        """

        class ArchiveDataset(Dataset):
            def __init__(self, data):
                self.data = data

            def __len__(self):
                return len(self.data)

            def __getitem__(self, idx):
                b, v, A, P, R = self.data[idx]
                return ((b, v), A, P, R)

        return ArchiveDataset(list(self.data))

In [7]:
def PPOloss(agent, data, epsilon, beta, aim=None):
    """Clipped PPO surrogate loss plus a KL penalty towards a uniform prior over valid moves.

    Args:
        agent: policy module; ``agent((boards, valids))`` returns per-move
            probabilities of shape ``(batch, num_moves)``.
        data: one DataLoader batch ``((boards, valids), actions, old_probs, returns)``
            as produced by ``ReplayBuffer.getDataset``.
        epsilon: clipping threshold; the ratio is clipped to ``[1 - epsilon, 1 + epsilon]``.
        beta: weight of the KL penalty.
        aim: ``"win"`` to maximise the return, ``"loss"`` to minimise it. Defaults
            to the notebook-level ``aim`` setting (``"win"`` if it is not defined).

    Returns:
        ``(loss, details)``: a scalar to minimise, and ``{"KLloss": mean KL}``.

    Raises:
        ValueError: if ``aim`` is neither ``"win"`` nor ``"loss"``.
    """
    if aim is None:
        aim = globals().get("aim", "win")
    if aim == "win":
        sign = 1.0
    elif aim == "loss":
        sign = -1.0
    else:
        raise ValueError(f"aim must be 'win' or 'loss', got {aim!r}")

    device = next(agent.parameters()).device

    S, A, P, R = data
    S = (S[0].to(device), S[1].to(device))
    A, P, R = A.to(device), P.to(device), R.to(device)

    probs = agent(S)
    prob = torch.gather(probs, dim=1, index=A.reshape(-1, 1)).reshape(-1)
    ratio = prob / P  # pi_theta(A|S) / pi_theta_old(A|S)

    # Playing to lose means maximising -R, so flip the advantage sign *before*
    # clipping. Negating the clipped objective afterwards would clip the wrong side.
    adv = sign * R
    surrogate = torch.min(ratio * adv, torch.clip(ratio, 1 - epsilon, 1 + epsilon) * adv)
    L = -surrogate  # PPO maximises the surrogate; the optimiser minimises

    # KL(pi || prior), where the prior is uniform over valid moves and puts
    # 1e-6 mass on each invalid move.
    valid = S[-1]
    PriDis = ((1 - (valid == 0).sum(1) * 1e-6) / valid.sum(1))[:, None] * valid + 1e-6 * (valid == 0)
    KLs = (probs * (torch.log(probs + 1e-6) - torch.log(PriDis))).sum(1)  # +1e-6 prevents log(0)

    kl = KLs.mean()
    return L.mean() + beta * kl, {"KLloss": kl.item()}
    

In [8]:
@torch.no_grad()
def Test(agents, envs, n, times=10, strategy='biggest'):
    """Play ``times`` batches of games between two policies and report outcome rates.

    Args:
        agents: a single policy (plays both colours) or ``[first_player, second_player]``;
            a policy maps an observation ``obs`` to per-move probabilities.
        envs: batched environment exposing ``batch_size``, ``reset()`` and
            ``step(one_hot_action, restart=False)``.
        n: board size; each game runs ``n * (n + 1)`` moves in total.
        times: number of batches to play (must be >= 1).
        strategy: ``"biggest"`` plays the most probable valid move;
            ``"random"`` samples from the masked distribution.

    Returns:
        dict of Python floats: the fraction of games whose summed reward is positive
        for the first player ("Prob of black win") and for the second player
        ("Prob of white win"), plus ``1 - both`` ("Prob of tie").

    Raises:
        ValueError: on an unknown ``strategy`` or ``times < 1``.
    """
    if strategy not in ("biggest", "random"):
        raise ValueError(f"strategy must be 'biggest' or 'random', got {strategy!r}")
    if times < 1:
        raise ValueError(f"times must be >= 1, got {times}")

    if isinstance(agents, list):
        assert len(agents) == 2, "Othello requires two players"
    else:
        agents = [agents, agents]  # self-playing game

    # Prefer the device of an nn.Module agent. Otherwise (e.g. two plain callables)
    # fall back to the observations' device below instead of leaving it unbound.
    device = None
    for agent in agents:
        if isinstance(agent, torch.nn.Module):
            device = next(agent.parameters()).device

    batch_size = envs.batch_size
    half = n * (n + 1) // 2  # moves per player
    num_moves = n * n + 1

    black_wins = 0.0
    white_wins = 0.0

    for _ in range(times):
        obs, reward, _, _, _ = envs.reset()
        buf_device = device if device is not None else obs[-1].device
        rewards = torch.zeros(batch_size, 2, half, device=buf_device)

        for i in range(half):
            for p in range(2):
                valid = obs[-1]
                prob = agents[p](obs).float() * valid  # change the prob of invalid moves to 0

                if strategy == "biggest":
                    action = prob.argmax(-1)
                else:
                    action = torch.multinomial(prob, num_samples=1).view(-1)

                rewards[:, p, i] = reward  # reward observed before this player's move

                act = F.one_hot(action, num_classes=num_moves)
                obs, reward, _, _, _ = envs.step(act, restart=False)

        black_wins += (rewards[:, 0].sum(-1) > 0).float().mean().item()
        white_wins += (rewards[:, 1].sum(-1) > 0).float().mean().item()

    blackWin = black_wins / times
    whiteWin = white_wins / times
    tie = 1 - blackWin - whiteWin

    return {"Prob of black win": blackWin, "Prob of white win": whiteWin, "Prob of tie": tie}

In [9]:
# Alternate between collecting fresh self-play games and PPO updates on them.
# Periodically evaluate against a random player and the pre-trained baseline,
# and save checkpoints.
eval_every = max(1, n_epoch // 1000)  # max(1, ...) avoids a modulus of 0 when n_epoch < 1000
save_every = max(1, n_epoch // 40)    # likewise for n_epoch < 40


def report_matchup(title, res):
    """Print the outcome probabilities of one evaluation matchup."""
    print(f"{title}:")
    for key, value in res.items():
        print(f"{key}: {value:7.2%}", end='; ')
    print('\n')


cnt = 0
res_train = {'KLloss': 0}

archiver = ReplayBuffer(envs, actor, batch_num=1)

for i in trange(n_epoch):

    # Collecting data switches the actor to eval mode; switch back before optimising.
    archiver.update()
    actor.train()

    dataset = archiver.getDataset()
    dataloader = DataLoader(dataset, batch_size=batch_size_ppo, shuffle=True)

    for _ in range(reuse_time):
        for data in dataloader:
            loss, detail = PPOloss(actor, data, epsilon=epsilon, beta=beta)

            opt.zero_grad()
            loss.backward()
            opt.step()

            cnt += 1
            res_train['KLloss'] += detail['KLloss']

    if i % eval_every == eval_every - 1 or i == 0:

        actor.eval()

        # (printed title, [first player, second player], Test strategy, TensorBoard tag, result key)
        matchups = [
            ('Black Agent, White Random', [actor, envs.randomAction], 'biggest',
             "Agent Black Win", "Prob of black win"),
            ('Black Random, White Agent', [envs.randomAction, actor], 'biggest',
             "Agent White Win", "Prob of white win"),
            ('Black Agent, White Baseline', [actor, baseline], 'random',
             "Agent Black Win (Baseline)", "Prob of black win"),
            ('Black Baseline, White Agent', [baseline, actor], 'random',
             "Agent White Win (Baseline)", "Prob of white win"),
        ]
        for title, players, strategy, tag, result_key in matchups:
            res = Test(players, envs, n, times=10, strategy=strategy)
            report_matchup(title, res)
            writer.add_scalar(tag, res[result_key], i)

        # Mean training statistics since the previous evaluation.
        for key, value in res_train.items():
            writer.add_scalar(key, value / max(cnt, 1), i)
            res_train[key] = 0
        cnt = 0

    if i % save_every == save_every - 1:
        ckpt_dir = f"model/{path}/{TIMESTAMP}"
        os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(actor.state_dict(), f"{ckpt_dir}/epoch{i}.pth")

  0%|          | 0/40000 [00:00<?, ?it/s]

Black Agent, White Random:
Prob of black win:  46.56%; Prob of white win:  50.00%; Prob of tie:   3.44%; 

Black Random, White Agent:
Prob of black win:  50.47%; Prob of white win:  45.55%; Prob of tie:   3.98%; 

Black Agent, White Baseline:
Prob of black win:  25.55%; Prob of white win:  71.09%; Prob of tie:   3.36%; 

Black Baseline, White Agent:
Prob of black win:  68.67%; Prob of white win:  28.28%; Prob of tie:   3.05%; 

Black Agent, White Random:
Prob of black win:  78.44%; Prob of white win:  17.66%; Prob of tie:   3.91%; 

Black Random, White Agent:
Prob of black win:  12.73%; Prob of white win:  84.61%; Prob of tie:   2.66%; 

Black Agent, White Baseline:
Prob of black win:  31.17%; Prob of white win:  65.08%; Prob of tie:   3.75%; 

Black Baseline, White Agent:
Prob of black win:  61.64%; Prob of white win:  35.08%; Prob of tie:   3.28%; 

Black Agent, White Random:
Prob of black win:  85.08%; Prob of white win:  12.34%; Prob of tie:   2.58%; 

Black Random, White Agent:
Pr