In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import random
from einops.layers.torch import Rearrange
from einops import rearrange

from typing import Any, Dict, Tuple, Optional
from game_mechanics import GoEnv, choose_move_randomly, load_pkl, play_go, save_pkl
from tqdm.notebook import tqdm

from functools import partial
import pandas as pd
from datetime import datetime

In [2]:
def normalize(observation: np.ndarray) -> torch.Tensor:
    return torch.as_tensor(observation, dtype=torch.float32)

def random_move(observation, legal_moves):
    return random.choice(legal_moves)

def choose_move(observation, legal_moves, network: nn.Module) -> int:
    probs, value = network(observation, legal_moves)
    probs = probs[0].cpu().detach().numpy()
    move = np.random.choice(range(82), p=probs)
    return move


In [3]:
class alpha_go_zero_batch(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Linear(81,100),
            nn.ReLU(),
            nn.Linear(100,100),
            nn.ReLU()
        )
        
        self.tower1 = nn.Sequential(
            nn.Linear(100,100),
            nn.ReLU(),
            nn.Linear(100,82)
        )
        
        self.tower2 = nn.Sequential(
            nn.Linear(100,100),
            nn.ReLU(),
            nn.Linear(100,1),
            nn.ReLU(),
        )


    def forward(self, x, legal_moves):
        illegal = lambda legal: [move not in legal for move in range(82)]
        mask = torch.stack([torch.as_tensor(illegal(lm)) for lm in legal_moves])        

        
        x = rearrange(x, 'b w h -> b (w h)')
        x = self.stem(x)
        x1 = self.tower1(x)
        x1 = x1.masked_fill(mask, -torch.inf)
        x1 = F.softmax(x1, dim=-1)
        x2 = self.tower2(x)
        x2 = torch.tanh(x2)     
            
        return x1, x2

In [4]:
agzb = alpha_go_zero_batch()

In [5]:
def play_episode(network, env):
    observations = []
    rewards = []
    observation, reward, done, info = env.reset()
    while not done:
        legal_moves = info['legal_moves']
        observation = normalize(observation)
        network_move = choose_move(rearrange(observation, 'w h -> 1 w h'), [legal_moves], network)
        observation, reward, done, info = env.step(network_move)
    return reward

In [6]:
game_speed_multiplier=1000000
render=False
verbose=False
env = GoEnv(
    random_move,
    verbose=verbose,
    render=render,
    game_speed_multiplier=game_speed_multiplier,
)

In [31]:
observations = []
rewards = []
moves = []
values = []
successor_values = []
gamma = 1.0
lamda = 0.5

observation, reward, done, info = env.reset()
while not  done:
    legal_moves = info['legal_moves']
    observation = normalize(observation)
    
    probs, value = agzb(rearrange(observation, 'w h -> 1 w h'), [legal_moves])
    probs = probs[0].cpu().detach().numpy()
    move = np.random.choice(range(82), p=probs)
    
    observation, reward, done, info = env.step(network_move)
    
    observations.append(observation)
    moves.append(network_move)
    rewards.append(reward)
    values.append(value.item())
    
successor_values = values[1:] + [0]

values = torch.as_tensor(values, dtype=torch.float32)
successor_values = torch.as_tensor(successor_values, dtype=torch.float32)
rewards = torch.as_tensor(rewards, dtype=torch.float32)

In [29]:
torch.roll?

In [27]:
deltas

tensor([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -1.])

In [35]:
torch.roll(gamlam_geo_series, shifts=10)

tensor([9.3132e-10, 4.6566e-10, 2.3283e-10, 1.1642e-10, 5.8208e-11, 2.9104e-11,
        1.4552e-11, 7.2760e-12, 3.6380e-12, 1.8190e-12, 1.0000e+00, 5.0000e-01,
        2.5000e-01, 1.2500e-01, 6.2500e-02, 3.1250e-02, 1.5625e-02, 7.8125e-03,
        3.9062e-03, 1.9531e-03, 9.7656e-04, 4.8828e-04, 2.4414e-04, 1.2207e-04,
        6.1035e-05, 3.0518e-05, 1.5259e-05, 7.6294e-06, 3.8147e-06, 1.9073e-06,
        9.5367e-07, 4.7684e-07, 2.3842e-07, 1.1921e-07, 5.9605e-08, 2.9802e-08,
        1.4901e-08, 7.4506e-09, 3.7253e-09, 1.8626e-09])

In [None]:
def calculate_gae(
        rewards: torch.Tensor,
        values: torch.Tensor,
        successor_values: torch.Tensor,
        gamma: float,
        lamda: float,
):
    N = len(rewards)

    delta_terms = rewards + gamma * successor_values - values

    gamlam = gamma * lamda

    gamlam_geo_series = torch.tensor([gamlam ** n for n in range(N)])

    full_gamlam_matrix = torch.stack([torch.roll(gamlam_geo_series, shifts=n) for n in range(N)])

    gamlam_matrix = torch.triu(full_gamlam_matrix)

    return torch.matmul(gamlam_matrix, delta_terms)

In [20]:
lamda = 0.9

In [26]:
sum([(lamda**10)*((1-lamda)*lamda**i) for i in range(1000)])

0.34867844010000004

In [28]:
(lamda**10)

0.3486784401000001

In [36]:
len_episode = 10
normalization_factor = (1-lamda)*(1/(1-lamda**len_episode))
[lamda**i*normalization_factor for i in range(len_episode)]

[0.15353399327876294,
 0.13818059395088664,
 0.12436253455579799,
 0.1119262811002182,
 0.10073365299019636,
 0.09066028769117673,
 0.08159425892205906,
 0.07343483302985317,
 0.06609134972686785,
 0.05948221475418106]

In [22]:
sum([(1 - lamda)*lamda**i for i in range(10)])

0.6513215598999998