In [1]:
import sys
sys.path.append("../src/")

In [2]:
from plugin_write_and_run import *

In [3]:
%%write_and_run ../src/replay_buffer.py
import numpy as np
import json
import os
import sys
sys.path.append("../src")
from config import *

In [4]:
%%write_and_run -a ../src/replay_buffer.py

class ReplayBuffer():
    """
    Fixed-capacity circular replay buffer for multi-agent training.

    For every transition it stores:
      * the joint state / next state (concatenation of all agents' observations,
        length ``critic_dimension``), per-agent rewards and per-agent done flags;
      * per-agent observations, next observations and actions.

    Records are written at ``buffer_counter % buffer_capacity``, so once the
    buffer is full each new record overwrites the oldest one.
    """

    def __init__(self, env, buffer_capacity=BUFFER_CAPACITY, batch_size=BATCH_SIZE, min_size_buffer=MIN_SIZE_BUFFER):
        """
        Allocate zero-filled storage sized from the environment's spaces.

        Args:
            env: multi-agent environment exposing ``n`` (number of agents),
                ``observation_space[i].shape[0]`` and ``action_space[i].n``.
            buffer_capacity: maximum number of transitions kept.
            batch_size: number of transitions returned by ``get_minibatch``.
            min_size_buffer: stored-transition count required (together with
                ``batch_size``) before ``check_buffer_size`` returns True.
        """
        self.buffer_capacity = buffer_capacity
        self.batch_size = batch_size
        self.min_size_buffer = min_size_buffer
        # Total number of records ever added; keeps growing past capacity.
        self.buffer_counter = 0
        self.n_games = 0
        self.n_agents = env.n
        self.list_actors_dimension = [env.observation_space[index].shape[0] for index in range(self.n_agents)]
        # The joint state is the concatenation of every agent's observation.
        self.critic_dimension = sum(self.list_actors_dimension)
        self.list_actor_n_actions = [env.action_space[index].n for index in range(self.n_agents)]

        self.states = np.zeros((self.buffer_capacity, self.critic_dimension))
        self.rewards = np.zeros((self.buffer_capacity, self.n_agents))
        self.next_states = np.zeros((self.buffer_capacity, self.critic_dimension))
        self.dones = np.zeros((self.buffer_capacity, self.n_agents), dtype=bool)

        # One array per agent, since agents can have different observation/action sizes.
        self.list_actors_states = []
        self.list_actors_next_states = []
        self.list_actors_actions = []

        for n in range(self.n_agents):
            self.list_actors_states.append(np.zeros((self.buffer_capacity, self.list_actors_dimension[n])))
            self.list_actors_next_states.append(np.zeros((self.buffer_capacity, self.list_actors_dimension[n])))
            self.list_actors_actions.append(np.zeros((self.buffer_capacity, self.list_actor_n_actions[n])))

    def __len__(self):
        """Return the total number of records ever added (not capped at capacity)."""
        return self.buffer_counter

    def check_buffer_size(self):
        """Return True once enough records exist to sample a batch and the warm-up minimum is reached."""
        return self.buffer_counter >= self.batch_size and self.buffer_counter >= self.min_size_buffer

    def update_n_games(self):
        """Increment the count of finished games (persisted by ``save``)."""
        self.n_games += 1

    def add_record(self, actor_states, actor_next_states, actions, state, next_state, reward, done):
        """
        Store one transition, overwriting the oldest record once the buffer is full.

        Args:
            actor_states, actor_next_states, actions: sequences indexed by agent;
                entry ``i`` must fit agent ``i``'s observation / action row.
            state, next_state: joint state vectors of length ``critic_dimension``.
            reward, done: one value per agent (a scalar is broadcast to all agents).
        """
        # Circular write position.
        index = self.buffer_counter % self.buffer_capacity

        for agent_index in range(self.n_agents):
            self.list_actors_states[agent_index][index] = actor_states[agent_index]
            self.list_actors_next_states[agent_index][index] = actor_next_states[agent_index]
            self.list_actors_actions[agent_index][index] = actions[agent_index]

        self.states[index] = state
        self.next_states[index] = next_state
        self.rewards[index] = reward
        self.dones[index] = done

        self.buffer_counter += 1

    def get_minibatch(self):
        """
        Sample ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            (state, reward, next_state, done, actors_state, actors_next_state, actors_action)
            The first four are arrays with one row per sampled transition; the last
            three are lists holding one such array per agent.

        Raises:
            ValueError: if fewer than ``batch_size`` transitions are stored.
        """
        # Before the buffer wraps, only the first ``buffer_counter`` rows hold data
        # (the rest are zeros). After it wraps, every row is a valid record, and the
        # counter can no longer be used as an index bound because older records have
        # been overwritten to make space for new ones.
        buffer_range = min(self.buffer_counter, self.buffer_capacity)
        if buffer_range < self.batch_size:
            # Sampling without replacement needs at least batch_size stored records;
            # fail with a clear message instead of numpy's generic one.
            raise ValueError(
                "Cannot sample {} transitions: only {} stored (check check_buffer_size() first)".format(
                    self.batch_size, buffer_range))

        batch_index = np.random.choice(buffer_range, self.batch_size, replace=False)

        state = self.states[batch_index]
        reward = self.rewards[batch_index]
        next_state = self.next_states[batch_index]
        done = self.dones[batch_index]

        actors_state = [self.list_actors_states[index][batch_index] for index in range(self.n_agents)]
        actors_next_state = [self.list_actors_next_states[index][batch_index] for index in range(self.n_agents)]
        actors_action = [self.list_actors_actions[index][batch_index] for index in range(self.n_agents)]

        return state, reward, next_state, done, actors_state, actors_next_state, actors_action

    def save(self, folder_path):
        """
        Save the replay buffer.

        Writes every storage array as a ``.npy`` file and the counters
        (``buffer_counter``, ``n_games``) to ``dict_info.json`` inside
        ``folder_path``, creating the folder and any missing parents.
        """
        # makedirs also creates missing parents (e.g. "../model" for "../model/test").
        os.makedirs(folder_path, exist_ok=True)

        np.save(os.path.join(folder_path, 'states.npy'), self.states)
        np.save(os.path.join(folder_path, 'rewards.npy'), self.rewards)
        np.save(os.path.join(folder_path, 'next_states.npy'), self.next_states)
        np.save(os.path.join(folder_path, 'dones.npy'), self.dones)

        for index in range(self.n_agents):
            np.save(os.path.join(folder_path, 'states_actor_{}.npy'.format(index)), self.list_actors_states[index])
            np.save(os.path.join(folder_path, 'next_states_actor_{}.npy'.format(index)), self.list_actors_next_states[index])
            np.save(os.path.join(folder_path, 'actions_actor_{}.npy'.format(index)), self.list_actors_actions[index])

        dict_info = {"buffer_counter": self.buffer_counter, "n_games": self.n_games}

        with open(os.path.join(folder_path, 'dict_info.json'), 'w') as f:
            json.dump(dict_info, f)

    def load(self, folder_path):
        """
        Restore a buffer previously written by ``save`` from ``folder_path``.

        The loaded arrays replace the current storage. ``buffer_capacity`` is
        reset to the number of rows actually loaded. This keeps the circular
        write index in ``add_record`` and the sampling range in ``get_minibatch``
        inside the loaded arrays, even if this instance was constructed with a
        different capacity.
        """
        self.states = np.load(os.path.join(folder_path, 'states.npy'))
        self.rewards = np.load(os.path.join(folder_path, 'rewards.npy'))
        self.next_states = np.load(os.path.join(folder_path, 'next_states.npy'))
        self.dones = np.load(os.path.join(folder_path, 'dones.npy'))

        self.list_actors_states = [np.load(os.path.join(folder_path, 'states_actor_{}.npy'.format(index))) for index in range(self.n_agents)]
        self.list_actors_next_states = [np.load(os.path.join(folder_path, 'next_states_actor_{}.npy'.format(index))) for index in range(self.n_agents)]
        self.list_actors_actions = [np.load(os.path.join(folder_path, 'actions_actor_{}.npy'.format(index))) for index in range(self.n_agents)]

        # Capacity must match the arrays just loaded, not the constructor argument.
        self.buffer_capacity = self.states.shape[0]

        with open(os.path.join(folder_path, 'dict_info.json'), 'r') as f:
            dict_info = json.load(f)
        self.buffer_counter = dict_info["buffer_counter"]
        self.n_games = dict_info["n_games"]

In [5]:
from make_env import *

In [6]:
# Build the multi-agent environment. ENV_NAME is presumably provided by the `config` star import — confirm.
env = make_env(ENV_NAME)

In [7]:
# Buffer sized from env's spaces, using the default BUFFER_CAPACITY / BATCH_SIZE / MIN_SIZE_BUFFER.
rb = ReplayBuffer(env)

In [8]:
# Per-agent action storage: one zero-initialized array per agent.
# NOTE(review): this displays every array; `[a.shape for a in rb.list_actors_actions]` is a lighter check.
rb.list_actors_actions

[array([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        ...,
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]]),
 array([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        ...,
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]]),
 array([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        ...,
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])]

In [9]:
# Agent 0's action array: one row per buffer slot, one column per action.
rb.list_actors_actions[0].shape

(1000000, 5)

In [10]:
# The environment's agent objects (the output shows three agents).
env.agents

[<multiagent.core.Agent at 0x7f99c5180640>,
 <multiagent.core.Agent at 0x7f99c52572e0>,
 <multiagent.core.Agent at 0x7f99c5257af0>]

In [11]:
# Agent 0's observation storage: one row per buffer slot, one column per observation feature.
rb.list_actors_states[0].shape

(1000000, 8)

In [12]:
# Agent 0's next-observation storage has the same shape as its observation storage.
rb.list_actors_next_states[0].shape

(1000000, 8)

In [13]:
# Step once with a 5-element action list for each of the three agents.
# Despite its name, `state` here is the list of per-agent observations (see the output below),
# not the concatenated joint state stored in the buffer.
state, reward, done, info = env.step([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])

In [14]:
# Per-agent observations returned by env.step (one array per agent).
state

[array([-0.89836314, -0.904225  ,  0.43635738, -1.30697775,  0.64801724,
        -1.32998244,  0.22531173, -0.84901717]),
 array([-0.21165986,  0.02300469, -1.54638038,  0.42575743, -0.21165986,
         0.02300469, -0.64801724,  1.32998244, -0.42270551,  0.48096527]),
 array([ 0.21104565, -0.45796058, -1.12367487, -0.05520783,  0.21104565,
        -0.45796058, -0.22531173,  0.84901717,  0.42270551, -0.48096527])]

In [15]:
# Flatten the per-agent observations into one vector: the joint-state layout the buffer
# stores in `states` (length = sum of per-agent observation sizes, 8 + 10 + 10 = 28 here).
np.concatenate(state)

array([-0.89836314, -0.904225  ,  0.43635738, -1.30697775,  0.64801724,
       -1.32998244,  0.22531173, -0.84901717, -0.21165986,  0.02300469,
       -1.54638038,  0.42575743, -0.21165986,  0.02300469, -0.64801724,
        1.32998244, -0.42270551,  0.48096527,  0.21104565, -0.45796058,
       -1.12367487, -0.05520783,  0.21104565, -0.45796058, -0.22531173,
        0.84901717,  0.42270551, -0.48096527])

In [16]:
# Observation size of agent 0.
state[0].shape

(8,)

In [17]:
# Observation size of agent 1.
state[1].shape

(10,)

In [18]:
# Observation size of agent 2.
state[2].shape

(10,)

In [19]:
# Per-agent observation sizes read from the env's spaces; these should match the shapes above.
# Iterate over env.n (as ReplayBuffer does) instead of hard-coding the agent count.
[env.observation_space[i].shape[0] for i in range(env.n)]

[8, 10, 10]

In [20]:
# Fill the buffer with N_DEMO_STEPS transitions so get_minibatch() has data to sample.
# Each agent is given the action list [1, ..., n_actions] (here [1, 2, 3, 4, 5]); the lists
# are rebuilt on every call so env.step and add_record never share list objects.
# The same observations are stored as both state and next_state: this only exercises
# the buffer's bookkeeping, it is not a real rollout.
N_DEMO_STEPS = 100

def make_demo_actions():
    """Return one fresh [1, ..., n_actions] list per agent of `env`."""
    return [list(range(1, env.action_space[k].n + 1)) for k in range(env.n)]

for _ in range(N_DEMO_STEPS):
    actors_state, reward, done, info = env.step(make_demo_actions())
    state = np.concatenate(actors_state)
    rb.add_record(actors_state, actors_state, make_demo_actions(), state, state, reward, done)

In [21]:
# Draw one random minibatch from the demo records. Returns the tuple
# (state, reward, next_state, done, actors_state, actors_next_state, actors_action).
# Sampling is without replacement, so at least BATCH_SIZE records must already be stored.
rb.get_minibatch()

(array([[ 6.05164759,  6.04578572,  7.3863681 , ...,  0.84901717,
          0.42270551, -0.48096527],
        [ 4.45174401,  4.44588215,  5.78646453, ...,  0.84901717,
          0.42270551, -0.48096527],
        [ 7.65163793,  7.64577607,  8.98635845, ...,  0.84901717,
          0.42270551, -0.48096527],
        ...,
        [18.25163686, 18.245775  , 19.58635738, ...,  0.84901717,
          0.42270551, -0.48096527],
        [16.85163686, 16.845775  , 18.18635738, ...,  0.84901717,
          0.42270551, -0.48096527],
        [ 2.65306391,  2.64720204,  3.98778442, ...,  0.84901717,
          0.42270551, -0.48096527]]),
 array([[-8.64022550e+01, -3.70505943e-01, -3.70505943e-01],
        [-4.98300671e+01, -3.47679661e-01, -3.47679661e-01],
        [-1.33216025e+02, -3.84393248e-01, -3.84393248e-01],
        [-4.70599057e+02, -4.11035771e-01, -4.11035771e-01],
        [-4.22878970e+01, -3.39489176e-01, -3.39489176e-01],
        [-1.61771274e+00,  4.59332503e-01,  4.59332503e-01],
       

In [22]:
# Write the buffer's arrays and counters under ../model/test (relative to the notebook's working directory).
rb.save("../model/test")