In [3]:
import gym
import numpy as np
import torch
from torch import nn, optim
import matplotlib.pyplot as plt
from collections import deque
from gym.spaces.box import Box
from typing import List, Tuple
from deeprl.common.utils import get_gym_space_shape, net_gym_space_dims
from torch.distributions import Categorical
from deeprl.common.base import *
from deeprl.common.utils import *
from deeprl.common.buffers import *

from deeprl.algos.a2c.a2c import A2C

In [7]:
from collections import deque
import random
import torch
from abc import ABC

# TODO: Turn this into an abstract base class.
class Memory:
    """
    Memory provides storage and access for gym transitions.
    This class will store the data as is and then convert to tensors as needed when sampling.

    Each stored transition is expected to be a 5-tuple ordered like ``self.keys``:
    ``(state, action, reward, done, next_state)``.
    """

    def __init__(self, max_len, device):
        """
        Args:
            max_len: Maximum number of transitions kept; once full, the oldest
                transition is dropped when a new one is stored (deque ``maxlen``).
            device: Torch device (or device string) that sampled tensors are moved to.
        """
        self.max_len = max_len
        self.buffer = deque(maxlen=self.max_len)
        self.device = device
        # Order of the components inside every stored transition tuple.
        self.keys = ["states", "actions", "rewards", "dones", "next_states"]

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

    def to_torch(self, x):
        """Stack ``x`` into a float32 tensor living on ``self.device``."""
        arr_x = np.array(x).astype(np.float32)
        return torch.from_numpy(arr_x).to(self.device)

    def reset_buffer(self):
        """Deletes contents of memory where the buffer lives"""
        self.buffer.clear()

    def samples_to_batch(self, samples):
        """Turn rows of transitions into a dict of stacked float tensors.

        Args:
            samples: Iterable of transitions, each with one entry per name in ``self.keys``.

        Returns:
            dict mapping each name in ``self.keys`` to a float32 tensor whose first
            dimension indexes the transitions in ``samples``.

        Raises:
            ValueError: If ``samples`` is empty or the transitions do not have
                exactly ``len(self.keys)`` components.
        """
        # Transpose rows of transitions into one column per component.
        columns = list(zip(*samples))
        if len(columns) != len(self.keys):
            raise ValueError(
                f"Expected non-empty transitions with {len(self.keys)} components "
                f"{self.keys}, got {len(columns)} component(s)"
            )

        # Use the class's own converter so the batch does not depend on a
        # module-level helper being present in the namespace.
        return {key: self.to_torch(column) for key, column in zip(self.keys, columns)}

    def sample_batch(self, num_samples):
        """
        Draw a random batch (without replacement) from the buffer.

        If fewer than ``num_samples`` transitions are stored, all of them are used.

        Returns:
            dict with keys states, actions, rewards, dones, next_states.

        Raises:
            ValueError: If the buffer is empty or ``num_samples`` is not positive.
        """
        # If there aren't enough samples then take what there is
        num_samples = min(num_samples, len(self.buffer))

        # Sample from a list copy: indexing into the middle of a deque is O(n).
        samples = random.sample(list(self.buffer), num_samples)

        return self.samples_to_batch(samples)

    def store(self, transition):
        """Append ``transition`` to the buffer and return the buffer."""
        self.buffer.append(transition)
        return self.buffer


class OnPolicyMemory(Memory):
    """Version of memory where order matters.

    Transitions are returned in the order they were stored instead of being
    randomly sampled.
    """

    def __init__(self, max_len, device):
        super().__init__(max_len, device)

    def sample_batch(self, num_samples=None):
        """Convert buffer to ordered stacks of different components instead of rows of transitions and convert these to float tensors.

        Args:
            num_samples (int, optional): Accepted so the signature stays compatible
                with ``Memory.sample_batch``. When ``None`` (the default) the whole
                buffer is returned; otherwise only the most recent ``num_samples``
                transitions are returned, still in stored order.

        Returns:
            batch (dict[key: tensor]): Batch of experiences where keys correspond to each component recorded at a time step.

        Raises:
            ValueError: If ``num_samples`` is negative or no transitions are selected.
        """
        transitions = list(self.buffer)

        if num_samples is not None:
            if num_samples < 0:
                raise ValueError("num_samples must be non-negative")
            # Slice from an explicit start index: a plain [-num_samples:]
            # would return everything when num_samples == 0.
            start = max(len(transitions) - num_samples, 0)
            transitions = transitions[start:]

        if not transitions:
            raise ValueError("Cannot build a batch from an empty selection of transitions")

        # Reuse the inherited row -> column conversion instead of duplicating it.
        return self.samples_to_batch(transitions)


In [8]:
# Environment the agent and the hand-rolled buffer are exercised against.
env = gym.make("CartPole-v1")

# Network specifications as (layer class, constructor kwargs) pairs.
# NOTE(review): presumably instantiated in order by CategoricalPolicy / Network — confirm.
policy_layers = [
    (nn.Linear, dict(in_features=net_gym_space_dims(env.observation_space), out_features=32)),
    (nn.Tanh, dict()),
    (nn.Linear, dict(in_features=32, out_features=32)),
    (nn.Tanh, dict()),
    (nn.Linear, dict(in_features=32, out_features=net_gym_space_dims(env.action_space))),
]

critic_layers = [
    (nn.Linear, dict(in_features=net_gym_space_dims(env.observation_space), out_features=32)),
    (nn.ReLU, dict()),
    (nn.Linear, dict(in_features=32, out_features=32)),
    (nn.ReLU, dict()),
    (nn.Linear, dict(in_features=32, out_features=1)),
]

# Full configuration handed to the A2C agent.
a2c_args = dict(
    gamma=0.99,
    env=env,
    step_lim=200,
    policy=CategoricalPolicy(policy_layers),
    policy_optimiser=optim.Adam,
    policy_lr=0.001,
    critic=Network(critic_layers),
    critic_lr=0.001,
    critic_optimiser=optim.Adam,
    critic_criterion=nn.MSELoss(),
    device="cpu",
    entropy_coef=0.01,
    batch_size=100,
    num_train_passes=2,
    lam=0.2,
)
agent = A2C(a2c_args)

# Roll the agent out for 200 steps, recording every transition; the buffer
# holds at most 100 entries, so only the latest 100 transitions survive.
buff = OnPolicyMemory(100, "cpu")
s = env.reset()
for _ in range(200):
    a = agent.choose_action(s)
    assert env.action_space.contains(a)
    s_, r, d, _ = env.step(a)
    buff.store((s, a, r, d, s_))
    # Start a fresh episode when the current one finished, otherwise advance.
    s = env.reset() if d else s_
    

In [14]:
# Let the agent collect experience into its own buffer, report how many
# transitions it holds, then display the ordered batch built from it.
agent.generate_experience(agent.batch_size)
print(len(agent.buffer))
out = agent.buffer.sample_batch()
out

10
[(array([-0.08962838, -0.1913191 ,  0.04213118,  0.28329834], dtype=float32), array([-0.09345476, -0.38701585,  0.04779715,  0.5889659 ], dtype=float32), array([-0.10119507, -0.19259465,  0.05957646,  0.31171417], dtype=float32), array([-0.10504697, -0.38851252,  0.06581075,  0.6225747 ], dtype=float32), array([-0.11281722, -0.5844887 ,  0.07826224,  0.935237  ], dtype=float32), array([-0.124507  , -0.3905046 ,  0.09696698,  0.66813713], dtype=float32), array([-0.13231708, -0.19685525,  0.11032972,  0.4074913 ], dtype=float32), array([-0.13625419, -0.39335454,  0.11847955,  0.73281926], dtype=float32), array([-0.14412127, -0.5898969 ,  0.13313593,  1.0603176 ], dtype=float32), array([-0.15591922, -0.78650665,  0.1543423 ,  1.3916489 ], dtype=float32)), (array(0, dtype=int64), array(1, dtype=int64), array(0, dtype=int64), array(0, dtype=int64), array(1, dtype=int64), array(1, dtype=int64), array(0, dtype=int64), array(0, dtype=int64), array(0, dtype=int64), array(1, dtype=int64)), (1

{'states': tensor([[-0.0896, -0.1913,  0.0421,  0.2833],
         [-0.0935, -0.3870,  0.0478,  0.5890],
         [-0.1012, -0.1926,  0.0596,  0.3117],
         [-0.1050, -0.3885,  0.0658,  0.6226],
         [-0.1128, -0.5845,  0.0783,  0.9352],
         [-0.1245, -0.3905,  0.0970,  0.6681],
         [-0.1323, -0.1969,  0.1103,  0.4075],
         [-0.1363, -0.3934,  0.1185,  0.7328],
         [-0.1441, -0.5899,  0.1331,  1.0603],
         [-0.1559, -0.7865,  0.1543,  1.3916]]),
 'actions': tensor([0., 1., 0., 0., 1., 1., 0., 0., 0., 1.]),
 'rewards': tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
 'dones': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'next_states': tensor([[-0.0935, -0.3870,  0.0478,  0.5890],
         [-0.1012, -0.1926,  0.0596,  0.3117],
         [-0.1050, -0.3885,  0.0658,  0.6226],
         [-0.1128, -0.5845,  0.0783,  0.9352],
         [-0.1245, -0.3905,  0.0970,  0.6681],
         [-0.1323, -0.1969,  0.1103,  0.4075],
         [-0.1363, -0.3934,  0.1185, 

In [None]:
# `buff` is an OnPolicyMemory, whose batching method is `sample_batch`
# (there is no `sample` method on it or on its `Memory` base class).
batch = buff.sample_batch()

In [None]:
# Show which components the batch dictionary contains.
batch.keys()

dict_keys(['states', 'actions', 'rewards', 'dones', 'next_states'])

In [33]:
# Check the shape of the `next_states` entry of the batch.
batch['next_states'].shape

torch.Size([100, 4])

In [13]:
l = [1, 2, 3, 4]