In [1]:
import torch
import gym as g
from gym import spaces
from connect4 import *
from envs import ConnectNEnv
from networks.architecture import RepresentationNetwork, DynamicsNetwork, PredictionNetwork
import numpy as np

In [2]:
from architecture.engine import *
from architecture.game import *
from architecture.network import MuZeroNetwork

In [3]:
class ReplayBuffer:

    def __init__(self, max_buffer_size, batch_size):
        self.max_buffer_size = max_buffer_size
        self.batch_size = batch_size
        self.buffer = [] # <- A list of games

    def sample_game(self):
        return np.random.choice(self.buffer)
    
    def sample_position(self, game: Game):
        return game.sample_position()

    def sample_batch(self, unroll_steps: int, td_steps: int, ):
        # Get all games
        games : List[Game] = [self.sample_game() for _ in range(self.batch_size)]
        game_position : List[Tuple[Game, int]] = [(game, self.sample_position(game)) for game in games]
        return [(game.make_image(position), 
                 game.action_history[position:position+unroll_steps],
                 game.make_target(position, unroll_steps, td_steps, 0.99))
                 for game, position in game_position]


In [4]:
def create_random_games():
    gs = []
    for _ in range(100):
        g = Game(ConnectNEnv())
        for _ in range(10):
            g.step(np.random.choice(g.legal_actions()))
            if g.terminal(): break
        gs.append(g)
    return gs

rb = ReplayBuffer(10, 10)
rb.buffer = create_random_games()

In [5]:
rb.sample_batch(10, 5)

[(tensor([[[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  1.,  0.,  0.,  0.]],
  
           [[ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  1.,  0.,  0.,  0.]],
  
           [[ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.,  0.,  0.]],
  
           [[ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
        