In [98]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchinfo import summary, torchinfo

import random
from einops.layers.torch import Rearrange
from einops import rearrange

from typing import Any, Dict, Tuple, Optional
from game_mechanics import GoEnv, choose_move_randomly, load_pkl, play_go, save_pkl
from tqdm import notebook

from functools import partial
import pandas as pd
from datetime import datetime
from tqdm.notebook import tqdm

In [12]:
def choose_move(observation: np.ndarray, legal_moves: np.ndarray, neural_network: nn.Module) -> int:
    observation = normalize(observation)
    with torch.no_grad():
        probs, value = my_network(observation, legal_moves)
    probs = probs[0].cpu().detach().numpy()
    move = np.random.choice(range(82), p=probs)
    return move


def choose_move_human(observation: np.ndarray, legal_moves:np.ndarray, neural_network: nn.Module) -> int:
    i, j = [int(_) for _ in input().split(" ")]

    return (i-1)*9 + j-1

def random_move(observation, legal_moves):
    return random.choice(legal_moves)

In [102]:
class alpha_go_zero_batch(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Sequential(
            Rearrange('b w h -> b 1 w h'),
            nn.Conv2d(1, 100, kernel_size=3,padding=1, stride=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(100, 100, kernel_size=3,padding=1, stride=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(100, 1, kernel_size=1,padding=0, stride=1, bias=False),
            nn.ReLU(),
            Rearrange('b c w h -> b (c w h)')
        )
        
        self.tower1 = nn.Sequential(
            nn.Linear(81,100),
            nn.ReLU(),
            nn.Linear(100,82)
        )
        
        self.tower2 = nn.Sequential(
            nn.Linear(81,100),
            nn.ReLU(),
            nn.Linear(100,1),
            nn.ReLU(),
        )

    def forward(self, x, legal_moves):
        illegal = lambda legal: [move not in legal for move in range(82)]
        mask = torch.stack([torch.as_tensor(illegal(lm)) for lm in legal_moves])        

        
#         x = rearrange(x, 'b w h -> b (w h)')
        x = self.stem(x)
        x1 = self.tower1(x)
        x1 = x1.masked_fill(mask, -torch.inf)
        x1 = F.softmax(x1, dim=-1)
        x2 = self.tower2(x)
        x2 = torch.tanh(x2)     
                     
        return x1, x2

In [22]:
class Reservoir:
    def __init__(self):
        self.data = []
        
    def append(self, observation, next_observation, reward, done, legal_moves, next_legal_moves, chosen_move):
        self.data.append((observation, next_observation, reward, done, legal_moves, next_legal_moves, chosen_move))
    
    def pop(self):
        self.data.pop()
    
    def sample(self, size):
        if size > len(self.data):
            size = len(self.data)
#         random.shuffle(self.data)
        sample = random.sample(self.data, size)
        return self.stack(sample)
    
    def stack(self, data):
        observations = torch.stack([d[0] for d in data])
        next_observations = torch.stack([d[0] for d in data])
        rewards = torch.as_tensor([d[2] for d in data], dtype=torch.float32)
        dones = torch.as_tensor([d[3] for d in data], dtype=torch.float32)
        legal_moves = [d[4] for d in data]
        next_legal_moves = [d[5] for d in data]
        chosen_moves = [d[6] for d in data]
        return observations, next_observations, rewards, dones, legal_moves, next_legal_moves, chosen_moves
    
    def __len__(self):
        return len(self.data)
    
def normalize(observation: np.ndarray) -> torch.Tensor:
    return torch.as_tensor(observation, dtype=torch.float32)

In [10]:
def play_episode(network, env):
    observations = []
    rewards = []
    observation, reward, done, info = env.reset()
    while not done:
        legal_moves = info['legal_moves']
        observation = normalize(observation)
        network_move = choose_move(rearrange(observation, 'w h -> 1 w h'), [legal_moves], network)
        observation, reward, done, info = env.step(network_move)
    return reward

In [97]:
my_network = alpha_go_zero_batch()
summary(my_network)

Layer (type:depth-idx)                   Param #
alpha_go_zero_batch                      --
├─Sequential: 1-1                        --
│    └─Rearrange: 2-1                    --
│    └─Conv2d: 2-2                       900
│    └─ReLU: 2-3                         --
│    └─Conv2d: 2-4                       90,000
│    └─ReLU: 2-5                         --
│    └─Conv2d: 2-6                       100
│    └─ReLU: 2-7                         --
│    └─Rearrange: 2-8                    --
├─Sequential: 1-2                        --
│    └─Linear: 2-9                       10,100
│    └─ReLU: 2-10                        --
│    └─Linear: 2-11                      8,282
├─Sequential: 1-3                        --
│    └─Linear: 2-12                      10,100
│    └─ReLU: 2-13                        --
│    └─Linear: 2-14                      101
│    └─ReLU: 2-15                        --
Total params: 119,583
Trainable params: 119,583
Non-trainable params: 0

In [None]:
## v5_L_conv - added conv layers

## v6 - there was no zero_grad()..., no optimizer.step either.... O M G

## v7: change value loss to smooth huber loss

## v8: PPO

## v9: grad clipping

    INITIALISE policy, value

    INITIALISE memory as empty

    FOR episode_number in number of episodes you'd like to run:	

        Sample starting state

        WHILE episode hasn't terminated

            Sample action from policy(state)

            Sample next_state from transition_function(state, action)

            Sample reward from reward_function(state, action, next_state)

            APPEND state, action, reward, next_state, is_done to memory

                    IF length of memory == N
                        UPDATE value_net using TD-lambda estimation:
                            values = value(states)
                            lambda_returns = calculate_lambda_returns(rewards, values, lambda, is_dones)
                      val_loss = MSE(value_estimates, lambda_returns)
                            UPDATE val_net to minimize val_loss

                    UPDATE policy_net using policy gradient update step:
                      values = value(states)
                            old_pol_probs = policy(states)[actions]
                            lambda_returns = calculate_lambda_returns(rewards, values, lambda, is_dones)
                            gae = lambda_returns - value_estimates

                            FOR epoch in {1, 2, ..., K}:
                                prob_ratio = policy(states)[actions] / old_pol_probs
                                pol_loss = -sum(gae * min(prob_ratio, clip(prob_ratio, 1-eps, 1+eps)))
                                UPDATE policy_net to minimize pol_loss
                            END FOR

                        EMPTY memory

            UPDATE state to next_state

        END WHILE

    END FOR

In [107]:
# set up network & env:
experiment_name = 'Baseline_reservoir_v8'
my_network = alpha_go_zero_batch()
opponent_choose_move = random_move
game_speed_multiplier=1000000
render=False
verbose=False
env = GoEnv(
    opponent_choose_move,
    verbose=verbose,
    render=render,
    game_speed_multiplier=game_speed_multiplier,
)


optimizer = torch.optim.AdamW(my_network.parameters(), lr=0.001)
metrics = []
test_eval_size = []

num_episodes = 10_000
num_test_episodes = 25
block_train_episodes = 100
gamma = 0.98
total_score = 0
total_played = 0
train_rewards = []
train_losses = {
    'policy': [],
    'value': []
}



reservoir = Reservoir()
reservoir_size = 10
batch_size = 3
for episode in tqdm(range(num_episodes)):
    observation, reward, done, info = env.reset()
    observation = normalize(observation)
    old_value = 0 
    reward = 0
    done = 0
    
    while not done:       
        legal_moves = info['legal_moves']
        probs, value = my_network(rearrange(observation, 'w h -> 1 w h'), legal_moves = [legal_moves])
        chosen_move = np.random.choice(range(0,82), p=probs[0].detach().numpy())
        next_observation, reward, done, info = env.step(chosen_move)
        next_observation = normalize(next_observation)
        next_legal_moves = info['legal_moves']

        reservoir.append(observation, next_observation, reward, done, legal_moves, next_legal_moves, chosen_move)
        if len(reservoir) == reservoir_size:
            observations, next_observations, rewards, dones, lm, next_lm, chosen_moves = reservoir.sample(batch_size)
            probs, values = my_network(observations, lm)
            with torch.no_grad():
                next_probs, next_values = my_network(next_observations, next_lm)
            loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))
#             loss_v = (values - reward - next_values*gamma*(1-dones))**2
            loss_policy = -torch.log(probs[range(batch_size), chosen_moves])*values
            loss = (loss_v + loss_policy).sum()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            reservoir.pop()
        observation = next_observation
        legal_moves = next_legal_moves

        
    train_rewards.append(reward)
        
    if episode % block_train_episodes == 0:
        opponent_choose_move = random_move
        test_env = GoEnv(
            opponent_choose_move,
            verbose=verbose,
            render=render,
            game_speed_multiplier=game_speed_multiplier,
        )
        rewards = [play_episode(my_network, test_env) for _ in tqdm(range(num_test_episodes))]
        test_wr = sum([r == 1 for r in rewards])/num_test_episodes
        test_score = sum(rewards)/num_test_episodes
        test_ties = sum([r == 0 for r in rewards])/num_test_episodes
        
        train_wr = sum([r == 1 for r in train_rewards])/block_train_episodes
        train_score = sum(train_rewards)/block_train_episodes
        train_ties = sum([r == 0 for r in train_rewards])/block_train_episodes
        train_rewards = []
        metrics.append({'test_win_rate': test_wr,
                        'test_score': test_score,
                        'test_ties': test_ties,
                        'train_win_rate': train_wr,
                        'train_score': train_score,
                        'train_ties': train_ties,
                        'episode': episode,
                        'total_score': total_score,
                        'total_played': total_played,
                       })

        pd.DataFrame(metrics).to_csv(f'logs/{experiment_name}_{episode}.csv')
        train_losses = {
            'policy': [],
            'value': []
        }

    total_score += reward
    total_played += 1

  0%|          | 0/10000 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

  loss_v = F.smooth_l1_loss(values, reward + next_values*gamma*(1-dones))


  0%|          | 0/25 [00:00<?, ?it/s]

KeyboardInterrupt: 