In [1]:
import os
import random
from datetime import datetime
from functools import partial
from typing import Any, Dict, Tuple, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from tqdm.notebook import tqdm

from game_mechanics import GoEnv, choose_move_randomly, load_pkl, play_go, save_pkl

In [3]:
def choose_move(observation, legal_moves, neural_network: nn.Module) -> int:
    """Sample one move from the policy output of ``neural_network``.

    The move is drawn at random according to the returned probabilities; it is
    not an argmax.

    Args:
        observation: batched board input, forwarded to the network unchanged.
            Only the first row of the returned probabilities is used.
        legal_moves: forwarded to the network unchanged (the network in this
            notebook expects one collection of legal move indices per row).
        neural_network: module whose call returns ``(probs, value)`` where
            ``probs`` has shape ``(batch, n_moves)``.

    Returns:
        The sampled move index as a plain Python ``int``.
    """
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        probs, _value = neural_network(observation, legal_moves)
    # detach() is kept so outputs that still require grad (e.g. raw parameters)
    # can be converted to numpy.
    probs = probs[0].detach().cpu().numpy()
    # Size the sampling range from the network output instead of assuming 82.
    return int(np.random.choice(len(probs), p=probs))


def choose_move_human(observation: np.ndarray, legal_moves:np.ndarray, neural_network: nn.Module) -> int:
    """Ask a human for a move on stdin and convert it to a flat board index.

    Prints ``observation`` and then reads one line holding two whitespace-
    separated, 1-based integers ``row col``.

    Args:
        observation: board to show to the player (printed as-is).
        legal_moves: unused; the entered move is not validated against it.
        neural_network: unused; present so the signature matches ``choose_move``.

    Returns:
        ``(row - 1) * 9 + (col - 1)``, i.e. the 0-based index on a 9-wide board.
        There is no way to enter a move outside this row/column scheme.

    Raises:
        ValueError: if the line does not contain exactly two integers.
    """
    print(observation)
    # split() with no argument tolerates repeated, leading and trailing
    # whitespace; split(" ") produced empty tokens and crashed in int("").
    row, col = (int(token) for token in input().split())
    return (row - 1) * 9 + (col - 1)

def random_move(observation, legal_moves):
    """Pick one entry of ``legal_moves`` uniformly at random.

    ``observation`` is accepted only so the signature matches the other
    move-choosing callables; it is not inspected.
    """
    picked = random.choice(legal_moves)
    return picked

def choose_move_no_network_human(observation: np.ndarray, legal_moves: np.ndarray) -> int:
    """Two-argument adapter around ``choose_move_human``.

    Some callers only pass the observation and the legal moves; this wrapper
    fills in the third argument with the module-level ``my_network``, which
    must already be defined when the function is called.
    """
    return choose_move_human(
        observation=observation,
        legal_moves=legal_moves,
        neural_network=my_network,
    )

In [4]:
# class alpha_go_zero(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.head1 = nn.Linear(81, 82)
#         self.head2 = nn.Linear(81, 1)        


#     def forward(self, x, legal_moves):
#         illegal_moves = [i for i in range(81) if i not in legal_moves]

#         x = rearrange(x, 'w h -> (w h)') #self.flatten(x)
#         x1 = self.head1(x)
#         x1[illegal_moves] = -torch.inf
#         x1 = F.softmax(x1, dim=-1)
#         x2 = self.head2(x)
#         x2 = F.tanh(x2)                
#         return x1, x2
    
    
class alpha_go_zero_batch(nn.Module):
    """Small fully-connected policy/value network for batches of 9x9 boards.

    A shared two-layer stem feeds two towers:

    * ``tower1`` -- policy logits over 82 move indices, masked to the legal
      moves and turned into probabilities with a softmax;
    * ``tower2`` -- a single scalar per board, squashed by ``tanh`` into
      ``(-1, 1)``.
    """

    def __init__(self):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Linear(81, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
        )

        self.tower1 = nn.Sequential(
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 82),
        )

        # No activation after the last Linear: a trailing ReLU clipped the
        # pre-tanh output at 0, so the value could never be negative. ReLU has
        # no parameters, so the state_dict keys are unchanged.
        self.tower2 = nn.Sequential(
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 1),
        )

    def forward(self, x, legal_moves):
        """Return ``(move_probs, value)`` for a batch of boards.

        Args:
            x: float tensor of shape ``(batch, 9, 9)`` (anything that flattens
                to 81 features per row is accepted).
            legal_moves: one collection of legal move indices per batch row.
                A row with no legal move yields NaN probabilities because every
                logit is masked out.

        Returns:
            ``move_probs`` of shape ``(batch, 82)`` with zero probability on
            moves missing from ``legal_moves``, and ``value`` of shape
            ``(batch, 1)``.
        """
        # Same result as einops 'b w h -> b (w h)' for 3-D input.
        features = self.stem(x.flatten(start_dim=1))

        logits = self.tower1(features)
        num_moves = logits.shape[-1]
        # True marks an illegal move. The mask is built on the logits' device
        # so the module also works when it has been moved off the CPU.
        illegal = torch.tensor(
            [[move not in legal for move in range(num_moves)] for legal in legal_moves],
            dtype=torch.bool,
            device=logits.device,
        )
        move_probs = F.softmax(logits.masked_fill(illegal, -torch.inf), dim=-1)

        value = torch.tanh(self.tower2(features))
        return move_probs, value

In [3]:
class Reservoir:
    """In-memory buffer of transition tuples that are consumed when sampled.

    Each entry is ``(observation, old_value, reward, done, legal_moves,
    chosen_move)``. ``sample_pop`` removes the entries it returns.
    """

    def __init__(self):
        self.data = []

    def append(self, observation, old_value, reward, done, legal_moves, chosen_move):
        """Store one transition.

        ``observation`` must be a tensor (entries are combined with
        ``torch.stack``); ``old_value``, ``reward`` and ``done`` must each be
        convertible with ``float()``.
        """
        self.data.append((observation, old_value, reward, done, legal_moves, chosen_move))

    def sample_pop(self, size):
        """Remove up to ``size`` random entries and return them batched.

        The request is clamped to the number of stored entries. See ``stack``
        for the layout of the returned tuple.
        """
        size = min(size, len(self.data))
        random.shuffle(self.data)
        sample, self.data = self.data[:size], self.data[size:]
        return self.stack(sample)

    def stack(self, data):
        """Batch a list of stored entries.

        Returns:
            ``(observations, old_values, rewards, dones, legal_moves,
            chosen_moves)``: the observations stacked along a new leading
            dimension, three float32 tensors of shape ``(len(data),)``, and two
            plain Python lists.
        """
        observations = torch.stack([d[0] for d in data])
        # float() collapses Python numbers, bools and single-element tensors to
        # scalars. Handing the raw mix to torch.as_tensor gave a shape that
        # depended on which entry came first, or failed outright.
        old_values = torch.as_tensor([float(d[1]) for d in data], dtype=torch.float32)
        rewards = torch.as_tensor([float(d[2]) for d in data], dtype=torch.float32)
        dones = torch.as_tensor([float(d[3]) for d in data], dtype=torch.float32)
        legal_moves = [d[4] for d in data]
        chosen_moves = [d[5] for d in data]
        return observations, old_values, rewards, dones, legal_moves, chosen_moves

    def __len__(self):
        return len(self.data)

In [2]:
def normalize(observation: np.ndarray) -> torch.Tensor:
    """Convert a board observation into a float32 tensor.

    Only the container and dtype change; the values are not rescaled.
    """
    target_dtype = torch.float32
    as_float_tensor = torch.as_tensor(observation, dtype=target_dtype)
    return as_float_tensor

In [7]:
def play_episode(network, env):
    """Play one complete game on ``env`` with moves sampled from ``network``.

    Args:
        network: policy/value module handed to ``choose_move``.
        env: environment whose ``reset()`` and ``step(move)`` both return
            ``(observation, reward, done, info)``, with the legal moves under
            ``info['legal_moves']``.

    Returns:
        The reward from the last ``step`` call, or the reward reported by
        ``reset()`` if the episode was already done.
    """
    observation, reward, done, info = env.reset()
    # Evaluation only: no gradients are needed while rolling out the game.
    with torch.no_grad():
        while not done:
            # The network expects a batch, so wrap the single board and its
            # legal moves into a batch of one.
            board = rearrange(normalize(observation), 'w h -> 1 w h')
            network_move = choose_move(board, [info['legal_moves']], network)
            observation, reward, done, info = env.step(network_move)
    return reward

In [14]:
# v3 - loss_v = (old_values - reward - values*gamma*(1-dones))**2
# This way the value should not oscillate between 0.999 one turn before the end and 0 at the end.

In [15]:
# set up network & env:
experiment_name = 'Baseline_batch2000_L_v3'

# Seed every RNG this cell draws from so a Restart & Run All is repeatable.
# NOTE(review): GoEnv may keep RNG state of its own that is not seeded here.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

my_network = alpha_go_zero_batch()
opponent_choose_move = random_move
game_speed_multiplier = 1000000
render = False
verbose = False
env = GoEnv(
    opponent_choose_move,
    verbose=verbose,
    render=render,
    game_speed_multiplier=game_speed_multiplier,
)

optimizer = torch.optim.AdamW(my_network.parameters(), lr=0.001)
metrics = []
test_eval_size = []

num_episodes = 10_000
num_test_episodes = 25
block_train_episodes = 100
gamma = 0.98
total_score = 0
total_played = 0
train_rewards = []
train_losses = {
    'policy': [],
    'value': []
}

# The metrics CSVs are written below 'logs/'; create it up front so the first
# evaluation does not fail on a fresh checkout.
os.makedirs('logs', exist_ok=True)

reservoir = Reservoir()
reservoir_size = 2000
batch_size = 100
for episode in tqdm(range(num_episodes)):
    observation, reward, done, info = env.reset()
    observation = normalize(observation)

    while not done:
        legal_moves = info['legal_moves']

        # Acting needs no gradients; the update below recomputes the forward
        # pass on the sampled batch.
        with torch.no_grad():
            probs, _ = my_network(rearrange(observation, 'w h -> 1 w h'), legal_moves=[legal_moves])
        chosen_move = int(np.random.choice(probs.shape[-1], p=probs[0].numpy()))

        next_observation, reward, done, info = env.step(chosen_move)
        next_observation = normalize(next_observation)

        # Bootstrap value of the successor state, fixed at collection time.
        # A terminal successor is worth 0 by definition.
        if done:
            next_value = 0.0
        else:
            with torch.no_grad():
                _, bootstrap = my_network(
                    rearrange(next_observation, 'w h -> 1 w h'),
                    legal_moves=[info['legal_moves']],
                )
            next_value = bootstrap.item()

        # Store the state the move was chosen in, together with that state's
        # legal moves and the move itself. The reservoir's `old_value` slot
        # carries the successor's bootstrap value as a plain float.
        reservoir.append(observation, next_value, reward, done, legal_moves, chosen_move)
        observation = next_observation

        if len(reservoir) >= reservoir_size:
            (batch_observations, batch_next_values, batch_rewards,
             batch_dones, batch_legal_moves, batch_moves) = reservoir.sample_pop(batch_size)

            batch_probs, batch_values = my_network(batch_observations, batch_legal_moves)
            # (B, 1) -> (B,) so nothing below broadcasts to (B, B).
            batch_values = batch_values.squeeze(-1)
            batch_next_values = batch_next_values.reshape(-1)

            # One-step TD target; it holds no gradient because the bootstrap
            # values were stored as floats.
            targets = batch_rewards + gamma * (1 - batch_dones) * batch_next_values
            # The advantage is detached so the policy term does not push
            # gradients into the value head.
            advantages = (targets - batch_values).detach()

            rows = torch.arange(len(batch_moves))
            cols = torch.as_tensor(batch_moves, dtype=torch.long)
            # clamp_min guards log(0) if a stored move has zero probability.
            log_probs = torch.log(batch_probs[rows, cols].clamp_min(1e-8))

            loss_v = (targets - batch_values) ** 2
            loss_policy = -log_probs * advantages
            loss = (loss_v + loss_policy).sum()

            # Without zero_grad/step the gradients only accumulate and the
            # weights never change.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_losses['policy'].append(loss_policy.detach().mean().item())
            train_losses['value'].append(loss_v.detach().mean().item())

    train_rewards.append(reward)
    # Count the finished episode before logging so the totals cover the same
    # episodes as the train_* statistics.
    total_score += reward
    total_played += 1

    if episode % block_train_episodes == 0:
        test_env = GoEnv(
            opponent_choose_move,
            verbose=verbose,
            render=render,
            game_speed_multiplier=game_speed_multiplier,
        )
        # leave=False removes each finished inner bar instead of stacking one
        # per evaluation in the cell output.
        test_rewards = [
            play_episode(my_network, test_env)
            for _ in tqdm(range(num_test_episodes), leave=False)
        ]
        test_wr = sum(r == 1 for r in test_rewards) / num_test_episodes
        test_score = sum(test_rewards) / num_test_episodes
        test_ties = sum(r == 0 for r in test_rewards) / num_test_episodes

        # Divide by the episodes actually collected: the first block holds a
        # single episode, not block_train_episodes.
        num_train = len(train_rewards)
        train_wr = sum(r == 1 for r in train_rewards) / num_train
        train_score = sum(train_rewards) / num_train
        train_ties = sum(r == 0 for r in train_rewards) / num_train
        train_rewards = []

        # NaN when no update ran in this block (the reservoir was still filling).
        mean_loss_policy = float(np.mean(train_losses['policy'])) if train_losses['policy'] else float('nan')
        mean_loss_value = float(np.mean(train_losses['value'])) if train_losses['value'] else float('nan')

        metrics.append({'test_win_rate': test_wr,
                        'test_score': test_score,
                        'test_ties': test_ties,
                        'train_win_rate': train_wr,
                        'train_score': train_score,
                        'train_ties': train_ties,
                        'episode': episode,
                        'total_score': total_score,
                        'total_played': total_played,
                        'train_loss_policy': mean_loss_policy,
                        'train_loss_value': mean_loss_value,
                       })

        pd.DataFrame(metrics).to_csv(f'logs/{experiment_name}_{episode}.csv')
        train_losses = {
            'policy': [],
            'value': []
        }

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

KeyboardInterrupt: 