
Write custom serialisation for experience replay #8

Open
Kaixhin opened this issue Apr 26, 2019 · 9 comments
Labels
enhancement New feature or request

Comments

@Kaixhin
Owner

Kaixhin commented Apr 26, 2019

Although it is now possible to save/load experience replay memories (#3), naively using torch.save fails with large memories. Dealing with this would require custom serialisation code.

@Kaixhin Kaixhin added the enhancement New feature or request label Apr 27, 2019
@alec-tschantz

alec-tschantz commented Dec 16, 2019

The way I have been getting around this is to save each part of the memory as a .npz file:

import numpy as np
import torch

# mostly-empty buffer of uint8 observations
D = np.empty((10 ** 5, 3, 64, 64), dtype=np.uint8)

# previous method (file size: 1.23 GB)
torch.save(D, "torch_save.pth")

# npz method (file size: 1.2 MB)
np.savez_compressed("numpy_save.npz", data=D)

x = np.load("numpy_save.npz")
print(x["data"].shape)

y = torch.load("torch_save.pth")
print(y.shape)

As you can see, the file size is reduced by roughly three orders of magnitude. If you like, I can provide a pull request that adds this save & load functionality to the memory class.
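For reference, a minimal sketch of what such save/load methods might look like. The class and field names here are illustrative stand-ins, not the actual ExperienceReplay API:

```python
import numpy as np


class Memory:
    """Toy stand-in for the replay memory; names are illustrative."""

    def __init__(self, size=1000):
        self.data = np.zeros((size, 3, 64, 64), dtype=np.uint8)
        self.idx = 0

    def save(self, path):
        # savez_compressed stores each keyword argument as a named array;
        # a mostly-empty uint8 buffer compresses extremely well
        np.savez_compressed(path, data=self.data, idx=np.array(self.idx))

    def load(self, path):
        with np.load(path) as f:
            self.data = f["data"]
            self.idx = int(f["idx"])


mem = Memory()
mem.idx = 42
mem.save("memory.npz")

restored = Memory()
restored.load("memory.npz")
```

Storing the current index alongside the data means the whole memory round-trips from a single file.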

@alec-tschantz

The savings are far smaller when the arrays are full:

import time

import numpy as np
import torch

# make sure the array is full (random float64 data compresses poorly)
D = np.random.rand(10 ** 4, 3, 64, 64)

# torch method
start_time = time.time()
torch.save(D, "torch_save.pth")
print("Torch save time: {:.2f}".format(time.time() - start_time))

# numpy method
start_time = time.time()
np.savez_compressed("numpy_save.npz", data=D)
print("Numpy save time: {:.2f}".format(time.time() - start_time))

# torch method (file size: 1.4 GB)
start_time = time.time()
x = torch.load("torch_save.pth")
print("Torch load time: {:.2f}".format(time.time() - start_time))

# numpy method (file size: 0.9 GB)
start_time = time.time()
x = np.load("numpy_save.npz")
print("Numpy load time: {:.2f}".format(time.time() - start_time))

Output:

Torch save time: 10.19
Numpy save time: 38.53
Torch load time: 6.40
Numpy load time: 0.05

So you get a reasonable reduction in disk usage, but a large increase in the time taken to write.

Any further suggestions are most welcome, as this makes it hard to continue runs across sessions.
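One possible middle ground, not benchmarked in this thread, just a hedged sketch: plain np.save writes the raw buffer uncompressed, so write speed is bounded by disk bandwidth rather than the compressor, at the cost of the full on-disk footprint:

```python
import time

import numpy as np

D = np.random.rand(10 ** 3, 3, 64, 64)  # smaller array so the demo runs quickly

start_time = time.time()
np.save("numpy_save.npy", D)  # uncompressed: no compressor overhead on write
print("Numpy (uncompressed) save time: {:.2f}".format(time.time() - start_time))

start_time = time.time()
x = np.load("numpy_save.npy")
print("Numpy (uncompressed) load time: {:.2f}".format(time.time() - start_time))
```

For full random arrays, where compression buys little anyway, this trades a slightly larger file for fast writes.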

@Kaixhin
Owner Author

Kaixhin commented Dec 16, 2019

I suspect one solution might be to use memory-mapped numpy arrays? I've never looked into them, though.

@alec-tschantz

alec-tschantz commented Dec 16, 2019

Oh nice, this absolutely seems viable. A (10 ** 6, 3, 64, 64) uint8 array is ~12 GB on the hard drive, but reading and writing are mostly instant. As a minimal working example:

import numpy as np


class Buffer(object):
    def __init__(self, buffer_size=10 ** 6):
        self.buffer_size = buffer_size
        self.obs_shape = (buffer_size, 3, 64, 64)
        self.obs_path = "obs.dat"
        self.idx = 0
        self.create_new_file()

    def create_new_file(self):
        # mode="w+" creates (or overwrites) the backing file, zero-filled
        obs_f = np.memmap(
            self.obs_path, dtype=np.uint8, mode="w+", shape=self.obs_shape
        )
        del obs_f  # deleting the memmap flushes it to disk

    def add(self, obs):
        # mode="r+" updates the existing file in place; using "w+" here
        # would wipe everything written so far on every call
        obs_f = np.memmap(
            self.obs_path, dtype=np.uint8, mode="r+", shape=self.obs_shape
        )
        obs_f[self.idx] = obs
        del obs_f
        self.idx = (self.idx + 1) % self.buffer_size

    def sample(self, idxs):
        obs_f = np.memmap(self.obs_path, dtype=np.uint8, mode="r", shape=self.obs_shape)
        data = obs_f[idxs]
        del obs_f
        return data


if __name__ == "__main__":
    buffer = Buffer()
    for i in range(10 ** 6):
        obs = (np.random.rand(3, 64, 64) * 255).astype(np.uint8)
        buffer.add(obs)
    idxs = np.random.randint(0, 10 ** 4, size=100)
    data = buffer.sample(idxs)

I'll get round to integrating it with your buffer at some point - if you think it'd be worthwhile.

@Kaixhin
Owner Author

Kaixhin commented Dec 16, 2019

Nice! I guess it'd be good to have this as a sort of helper class, so that the serialisation of the buffer can include other bits, like the current index.

@alec-tschantz

I think I follow.

I was imagining that the buffer was serialised in a standard fashion and contained metadata about things like .dat file locations, the current index, data shapes, etc. The problem with memmap is that you need both the data type and the array shape whenever you invoke it; these can't be serialised into the .dat file itself.

Unless I've misunderstood your suggestion.
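To make this concrete, a sketch of one possible scheme (the field names are hypothetical, not the actual buffer's attributes): pickle a small metadata dict next to the .dat file, then use it to reopen the memmap with the right dtype and shape on restore:

```python
import pickle

import numpy as np

# Hypothetical metadata record; a real buffer would save its own attributes.
meta = {"obs_path": "obs.dat", "dtype": "uint8",
        "shape": (1000, 3, 64, 64), "idx": 123}

# Create the zero-filled backing .dat file and persist the metadata.
obs_f = np.memmap(meta["obs_path"], dtype=meta["dtype"], mode="w+",
                  shape=meta["shape"])
del obs_f  # flush to disk
with open("buffer_meta.pkl", "wb") as f:
    pickle.dump(meta, f)

# Later / in another session: restore metadata, then reopen the memmap.
with open("buffer_meta.pkl", "rb") as f:
    restored = pickle.load(f)
obs = np.memmap(restored["obs_path"], dtype=restored["dtype"],
                mode="r", shape=restored["shape"])
```

As long as the .dat file stays where the metadata says it is, the whole buffer state can be recovered this way.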

@Kaixhin
Owner Author

Kaixhin commented Dec 16, 2019

I was just talking about https://github.com/Kaixhin/PlaNet/blob/master/memory.py#L15-L18, but yes, anything that needs to be stored to recover the whole ExperienceReplay class.

@alec-tschantz

Yeah, that's how I have it set up - you can serialise the Replay class, and as long as the .dat files have not changed location, it works just fine (and has no computational overhead, as far as I can tell):

https://github.com/alec-tschantz/planet/blob/master/planet/training/_buffer.py

I can make this consistent with your repo if you'd accept a pull request.

@Kaixhin
Owner Author

Kaixhin commented Dec 16, 2019

At a quick glance looks good, but please put in a PR so I can have a proper review. Thanks!
