In [1]:
import torch as th

In [2]:
class Memory():
    """Storage for observation of a DQN agent.

    Observations are stored large continuous tensor.
    The tensor are automatically initialized upon the first call of store().
    Important: all tensors to be stored need to be passed at the first call of
    the store. Also the shape of tensors to be stored needs to be consistent.


    Typical usage:
        mem = Memory(...)
        for episode in range(n_episodes):
            obs = env.init()
            for round in range(n_rounds):
                action = agent(obs)
                next_obs, reward = env.step()
                mem.store(**obs, reward=reward, action=action)
                obs = next_obs
            mem.finish_episode()

            sample = mem.sample()
            update_agents(sample)
    """

    def __init__(
            self, device, size, n_rounds):
        """
            Args:
                device: device for the memory
                size: number of episode to store
                n_rounds; number of rounds to store per episode
        """
        self.memory = None
        self.size = size
        self.n_rounds = n_rounds
        self.device = device
        self.current_row = 0
        self.episodes_stored = 0

    def init_store(self, state):
        """Initialize the memory tensor.
        """
        self.memory = {
            k: th.zeros((self.size, self.n_rounds, *t.shape),
                        dtype=t.dtype, device=self.device)
            for k, t in state.items() if t is not None
        }

    def finish_episode(self):
        """Moves the currently active slice in memory to the next episode.
        """
        self.episodes_stored += 1
        self.current_row = (self.current_row + 1) % self.size

    def store(self, round, **state):
        """Stores multiple tensor in the memory.
        """
        if self.memory is None:
            self.init_store(state)
        for k, t in state.items():
            if t is not None:
                self.memory[k][self.current_row, round] = t.to(self.device)

    def sample(self, batch_size, device, **kwargs):
        """Samples form the memory.

        Returns:
            dict | None: Dict being stored. If the batch size is larger than the number
            of episodes stored 'None' is returned.
        """
        if len(self) < batch_size:
            return None
        random_memory_idx = th.randperm(len(self))[:batch_size]
        print(f'random_memory_idx', random_memory_idx)
        sample = {k: v[random_memory_idx].to(device) for k, v in self.memory.items()}
        return sample

    def __len__(self):
        """The current memory usage, i.e. the number of valid episodes in
        the memory.This increases as episodes are added to the memory until the
        maximum size of the memory is reached.
        """
        return min(self.episodes_stored, self.size)

In [6]:
n_episodes = 6
n_rounds = 3
n_networks = 2
n_nodes = 5
obs_shape = 7
batch_size = 4
device = th.device('cpu')

mem = Memory(device=device, size=5, n_rounds=n_rounds)
for epsiode in range(n_episodes):
    for round in range(n_rounds):
        # mock of the environment
        obs = {
            'mask': th.rand((n_networks,n_nodes)) > 0.8,
            'obs': (th.rand((n_networks,n_nodes, obs_shape)) > 0.9).type(th.int64)
        }
        mem.store(round, **obs)
    mem.finish_episode()
    sample = mem.sample(batch_size, device=device)
    if sample is not None:
        for k,v in sample.items():
            print(k, v.shape)
    else:
        print(f"Skip epsiode {epsiode}")

Skip epsiode 0
Skip epsiode 1
Skip epsiode 2
random_memory_idx tensor([3, 0, 2, 1])
mask torch.Size([4, 3, 2, 5])
obs torch.Size([4, 3, 2, 5, 7])
random_memory_idx tensor([3, 0, 2, 4])
mask torch.Size([4, 3, 2, 5])
obs torch.Size([4, 3, 2, 5, 7])
random_memory_idx tensor([3, 1, 4, 0])
mask torch.Size([4, 3, 2, 5])
obs torch.Size([4, 3, 2, 5, 7])
