In [1]:
import sys
sys.path.append("../src/")

In [2]:
from plugin_write_and_run import *

In [3]:
%%write_and_run ../src/replay_buffer.py
import numpy as np
import json
import os
import sys
sys.path.append("../src")
from config import *
from make_env import *

In [4]:
%%write_and_run -a ../src/replay_buffer.py

class ReplayBuffer:
    """Fixed-capacity circular replay buffer for multi-agent transitions.

    For every stored step the buffer keeps:

    * the joint state and joint next state (width = sum of all agents'
      observation sizes),
    * one reward and one done flag per agent,
    * for each agent separately, its own observation, next observation and
      action vector.

    Once ``buffer_capacity`` records have been written, each new record
    overwrites the oldest slot.
    """

    def __init__(self, env, buffer_capacity=BUFFER_CAPACITY, batch_size=BATCH_SIZE, min_size_buffer=MIN_SIZE_BUFFER):
        """Allocate zero-filled storage sized from ``env``.

        Args:
            env: Environment exposing ``n`` (number of agents) and, per agent,
                ``observation_space[i].shape[0]`` and ``action_space[i].n``.
            buffer_capacity: Maximum number of records kept at once.
            batch_size: Number of records returned by ``get_minibatch``.
            min_size_buffer: Minimum number of added records before
                ``check_buffer_size`` reports the buffer as ready.
        """
        self.buffer_capacity = buffer_capacity
        self.batch_size = batch_size
        self.min_size_buffer = min_size_buffer
        # Total number of records ever added; it keeps growing past the capacity.
        self.buffer_counter = 0
        self.n_games = 0
        self.n_agents = env.n
        self.list_actors_dimension = [env.observation_space[index].shape[0] for index in range(self.n_agents)]
        # The joint ("critic") state is the concatenation of every agent's observation.
        self.critic_dimension = sum(self.list_actors_dimension)
        self.list_actor_n_actions = [env.action_space[index].n for index in range(self.n_agents)]

        self.states = np.zeros((self.buffer_capacity, self.critic_dimension))
        self.rewards = np.zeros((self.buffer_capacity, self.n_agents))
        self.next_states = np.zeros((self.buffer_capacity, self.critic_dimension))
        self.dones = np.zeros((self.buffer_capacity, self.n_agents), dtype=bool)

        # One array per agent, because observation/action sizes may differ between agents.
        self.list_actors_states = []
        self.list_actors_next_states = []
        self.list_actors_actions = []

        for n in range(self.n_agents):
            self.list_actors_states.append(np.zeros((self.buffer_capacity, self.list_actors_dimension[n])))
            self.list_actors_next_states.append(np.zeros((self.buffer_capacity, self.list_actors_dimension[n])))
            self.list_actors_actions.append(np.zeros((self.buffer_capacity, self.list_actor_n_actions[n])))

    def __len__(self):
        """Return the total number of records added so far (may exceed the capacity)."""
        return self.buffer_counter

    def check_buffer_size(self):
        """Return True once enough records were added to sample a minibatch.

        Both ``batch_size`` and ``min_size_buffer`` must be reached.
        """
        return self.buffer_counter >= self.batch_size and self.buffer_counter >= self.min_size_buffer

    def update_n_games(self):
        """Increment the counter of played games."""
        self.n_games += 1

    def add_record(self, actor_states, actor_next_states, actions, state, next_state, reward, done):
        """Store one transition, overwriting the oldest slot when the buffer is full.

        Args:
            actor_states: Per-agent observations, indexable by agent.
            actor_next_states: Per-agent next observations, indexable by agent.
            actions: Per-agent action vectors, indexable by agent.
            state: Joint state (row of width ``critic_dimension``).
            next_state: Joint next state (row of width ``critic_dimension``).
            reward: Per-agent rewards (row of width ``n_agents``).
            done: Per-agent done flags (row of width ``n_agents``).
        """
        # Wrap around so that the write position cycles through the fixed storage.
        index = self.buffer_counter % self.buffer_capacity

        for agent_index in range(self.n_agents):
            self.list_actors_states[agent_index][index] = actor_states[agent_index]
            self.list_actors_next_states[agent_index][index] = actor_next_states[agent_index]
            self.list_actors_actions[agent_index][index] = actions[agent_index]

        self.states[index] = state
        self.next_states[index] = next_state
        self.rewards[index] = reward
        self.dones[index] = done

        self.buffer_counter += 1

    def get_minibatch(self):
        """Sample ``batch_size`` distinct stored records uniformly at random.

        Returns:
            Tuple ``(state, reward, next_state, done, actors_state,
            actors_new_state, actors_action)``; the last three items are lists
            with one array per agent.

        Raises:
            ValueError: If fewer than ``batch_size`` records are stored.
        """
        # While the buffer is still filling, only the first `buffer_counter` rows hold
        # real data (the rest are zeros). Once the counter exceeds the capacity, every
        # row is valid because older records are overwritten in place.
        buffer_range = min(self.buffer_counter, self.buffer_capacity)

        if buffer_range < self.batch_size:
            # numpy would raise ValueError here anyway (sampling without replacement);
            # raise the same exception type with an actionable message instead.
            raise ValueError(
                "Cannot sample a minibatch of {} records from a buffer holding only {}; "
                "call check_buffer_size() before sampling.".format(self.batch_size, buffer_range)
            )

        batch_index = np.random.choice(buffer_range, self.batch_size, replace=False)

        # Gather the sampled rows from every array.
        state = self.states[batch_index]
        reward = self.rewards[batch_index]
        next_state = self.next_states[batch_index]
        done = self.dones[batch_index]

        actors_state = [self.list_actors_states[index][batch_index] for index in range(self.n_agents)]
        actors_new_state = [self.list_actors_next_states[index][batch_index] for index in range(self.n_agents)]
        actors_action = [self.list_actors_actions[index][batch_index] for index in range(self.n_agents)]

        return state, reward, next_state, done, actors_state, actors_new_state, actors_action

    def save(self, folder_path):
        """Save the replay buffer into ``folder_path``.

        Writes one ``.npy`` file per array (joint arrays plus three files per
        agent) and ``dict_info.json`` holding ``buffer_counter`` and
        ``n_games``. The folder, including missing parent folders, is created
        when needed.
        """
        # makedirs also creates missing parents and tolerates an existing folder.
        os.makedirs(folder_path, exist_ok=True)

        np.save(os.path.join(folder_path, 'states.npy'), self.states)
        np.save(os.path.join(folder_path, 'rewards.npy'), self.rewards)
        np.save(os.path.join(folder_path, 'next_states.npy'), self.next_states)
        np.save(os.path.join(folder_path, 'dones.npy'), self.dones)

        for index in range(self.n_agents):
            np.save(os.path.join(folder_path, 'states_actor_{}.npy'.format(index)), self.list_actors_states[index])
            np.save(os.path.join(folder_path, 'next_states_actor_{}.npy'.format(index)), self.list_actors_next_states[index])
            np.save(os.path.join(folder_path, 'actions_actor_{}.npy'.format(index)), self.list_actors_actions[index])

        dict_info = {"buffer_counter": self.buffer_counter, "n_games": self.n_games}

        with open(os.path.join(folder_path, 'dict_info.json'), 'w') as f:
            json.dump(dict_info, f)

    def load(self, folder_path):
        """Restore a buffer previously written by ``save`` from ``folder_path``.

        The number of per-agent files read is taken from ``self.n_agents``, so
        the buffer must have been constructed for an environment with the same
        number of agents as the saved one.
        """
        self.states = np.load(os.path.join(folder_path, 'states.npy'))
        self.rewards = np.load(os.path.join(folder_path, 'rewards.npy'))
        self.next_states = np.load(os.path.join(folder_path, 'next_states.npy'))
        self.dones = np.load(os.path.join(folder_path, 'dones.npy'))

        self.list_actors_states = [np.load(os.path.join(folder_path, 'states_actor_{}.npy'.format(index))) for index in range(self.n_agents)]
        self.list_actors_next_states = [np.load(os.path.join(folder_path, 'next_states_actor_{}.npy'.format(index))) for index in range(self.n_agents)]
        self.list_actors_actions = [np.load(os.path.join(folder_path, 'actions_actor_{}.npy'.format(index))) for index in range(self.n_agents)]

        # Keep the capacity in sync with the arrays just loaded; otherwise the circular
        # write index and the sampling range could point outside (or ignore part of)
        # a buffer that was saved with a different capacity.
        self.buffer_capacity = self.states.shape[0]

        with open(os.path.join(folder_path, 'dict_info.json'), 'r') as f:
            dict_info = json.load(f)
        self.buffer_counter = dict_info["buffer_counter"]
        self.n_games = dict_info["n_games"]

In [5]:
# Create the environment used to size and exercise the replay buffer.
# NOTE(review): make_env and ENV_NAME are presumably provided by the star-imports
# of `make_env` / `config` in the cell written to replay_buffer.py — confirm.
env = make_env(ENV_NAME)

In [6]:
# Build a replay buffer with the default capacity / batch settings; its arrays are
# sized from the environment's agent count and per-agent observation/action spaces.
rb = ReplayBuffer(env)

In [7]:
rb.list_actors_actions

[array([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        ...,
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]]),
 array([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        ...,
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]]),
 array([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        ...,
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])]

In [8]:
rb.list_actors_actions[0].shape

(1000000, 5)

In [9]:
env.agents

[<multiagent.core.Agent at 0x7ff9df1802e0>,
 <multiagent.core.Agent at 0x7ff9df24af40>,
 <multiagent.core.Agent at 0x7ff9df256790>]

In [10]:
rb.list_actors_states[0].shape

(1000000, 8)

In [11]:
rb.list_actors_next_states[0].shape

(1000000, 8)

In [12]:
# Take a single environment step with the same dummy 5-element action for each of the
# three agents, to inspect what the environment returns (four values are unpacked).
state, reward, done, info = env.step([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])

In [13]:
state

[array([-3.38370218e-01, -4.30901985e-01, -1.12620986e+00,  1.11595831e+00,
        -9.55481375e-04,  4.63488114e-01,  2.00651676e-01,  7.81533551e-02]),
 array([-3.37414737e-01, -8.94390099e-01, -3.37414737e-01, -8.94390099e-01,
        -1.12525438e+00,  6.52470194e-01,  9.55481375e-04, -4.63488114e-01,
         2.01607157e-01, -3.85334758e-01]),
 array([-0.53902189, -0.50905534, -0.53902189, -0.50905534, -1.32686153,
         1.03780495, -0.20065168, -0.07815336, -0.20160716,  0.38533476])]

In [14]:
np.concatenate(state)

array([-3.38370218e-01, -4.30901985e-01, -1.12620986e+00,  1.11595831e+00,
       -9.55481375e-04,  4.63488114e-01,  2.00651676e-01,  7.81533551e-02,
       -3.37414737e-01, -8.94390099e-01, -3.37414737e-01, -8.94390099e-01,
       -1.12525438e+00,  6.52470194e-01,  9.55481375e-04, -4.63488114e-01,
        2.01607157e-01, -3.85334758e-01, -5.39021894e-01, -5.09055340e-01,
       -5.39021894e-01, -5.09055340e-01, -1.32686153e+00,  1.03780495e+00,
       -2.00651676e-01, -7.81533551e-02, -2.01607157e-01,  3.85334758e-01])

In [15]:
state[0].shape

(8,)

In [16]:
state[1].shape

(10,)

In [17]:
state[2].shape

(10,)

In [18]:
[env.observation_space[i].shape[0] for i in range(3)]

[8, 10, 10]

In [20]:
# Smoke test: step the environment 100 times with a fixed dummy action for each of
# the three agents and push every result into the replay buffer.
DUMMY_AGENT_ACTION = (1, 2, 3, 4, 5)

for i in range(100):
    actors_state, reward, done, info = env.step([list(DUMMY_AGENT_ACTION) for _ in range(3)])
    # The joint state is the concatenation of all per-agent observations.
    state = np.concatenate(actors_state)
    # The same observations are stored as both "state" and "next state": this cell
    # only exercises the buffer's shapes and indexing, not real transitions.
    rb.add_record(
        actor_states=actors_state,
        actor_next_states=actors_state,
        actions=[list(DUMMY_AGENT_ACTION) for _ in range(3)],
        state=state,
        next_state=state,
        reward=reward,
        done=done,
    )

In [21]:
rb.get_minibatch()

(array([[ 1.82232055,  1.72978878,  1.03448091, ..., -0.07815336,
         -0.20160716,  0.38533476],
        [ 6.61164051,  6.51910874,  5.82380087, ..., -0.07815336,
         -0.20160716,  0.38533476],
        [ 8.21163086,  8.11909909,  7.42379122, ..., -0.07815336,
         -0.20160716,  0.38533476],
        ...,
        [ 8.01163121,  7.91909945,  7.22379157, ..., -0.07815336,
         -0.20160716,  0.38533476],
        [ 0.49172011,  0.39918835, -0.29611952, ..., -0.07815336,
         -0.20160716,  0.38533476],
        [ 6.21164885,  6.11911708,  5.42380921, ..., -0.07815336,
         -0.20160716,  0.38533476]]),
 array([[-6.31302142e+00,  2.92695900e-01,  2.92695900e-01],
        [-8.62125690e+01,  3.18636800e-01,  3.18636800e-01],
        [-1.33350651e+02,  3.20335522e-01,  3.20335522e-01],
        [-1.53268315e+01,  3.06052339e-01,  3.06052339e-01],
        [-7.86610555e-01,  2.04383469e-01,  2.04383469e-01],
        [-4.36036223e+02,  3.23389624e-01,  3.23389624e-01],
       

In [22]:
# Persist the buffer to ../model/test: one .npy file per stored array plus
# dict_info.json with the record and game counters.
rb.save("../model/test")