
Conversation

@specture724 (Collaborator)

resolve #23
Add a reusable pin memory buffer. In the current implementation, only one checkpoint can use the shared pin memory at a time, and the pin memory buffer shape is fixed to the shape used the first time it is allocated; it cannot be modified later.

Copilot AI (Contributor) left a comment


Pull Request Overview

This PR implements a reusable pin memory buffer mechanism for checkpoints to reduce memory allocation overhead. The shared memory pool allows checkpoints to reuse the same pinned memory buffers sequentially, with the constraint that only one checkpoint can use the shared pool at a time and the buffer shape is fixed on first allocation.

Key changes:

  • Added use_shared_memory_pool parameter to register_checkpoint() for opt-in shared memory usage
  • Introduced tracking of current shared memory pool user via _current_shared_memory_pool_user
  • Modified _register_checkpoint() to accept and reuse existing pin memory buffers when provided
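
Taken together, the registration flow could look roughly like the sketch below. This is a minimal illustration under assumptions: apart from register_checkpoint(), _register_checkpoint(), use_shared_memory_pool, and _current_shared_memory_pool_user, which the PR names, every identifier here (ParameterServerSketch, _shared_memory_pool, the unregister_checkpoint() signature, the buffer size) is a placeholder rather than code from checkpoint_engine/ps.py.

```python
import torch


class ParameterServerSketch:
    """Minimal illustration of the shared pin memory pool flow; not the real class."""

    def __init__(self) -> None:
        self._shared_memory_pool: list[torch.Tensor] | None = None
        self._current_shared_memory_pool_user: str | None = None

    def register_checkpoint(self, checkpoint_name: str, *, use_shared_memory_pool: bool = False) -> None:
        if not use_shared_memory_pool:
            self._register_checkpoint(checkpoint_name, pin_memory=None)
            return
        if self._current_shared_memory_pool_user is not None:
            # Only one checkpoint may hold the shared pool at a time.
            raise RuntimeError(
                f"shared memory pool already in use by {self._current_shared_memory_pool_user!r}"
            )
        self._current_shared_memory_pool_user = checkpoint_name
        # Reuse previously allocated pinned buffers if any; their shape stays
        # fixed to whatever the first allocation produced.
        self._shared_memory_pool = self._register_checkpoint(
            checkpoint_name, pin_memory=self._shared_memory_pool
        )

    def unregister_checkpoint(self, checkpoint_name: str) -> None:
        if self._current_shared_memory_pool_user == checkpoint_name:
            # Release the "slot" for the next checkpoint, but keep the pinned
            # buffers alive so they can be reused without reallocation.
            self._current_shared_memory_pool_user = None

    def _register_checkpoint(
        self, checkpoint_name: str, *, pin_memory: list[torch.Tensor] | None
    ) -> list[torch.Tensor]:
        if pin_memory is None:
            # Allocate fresh pinned host buffers (size is a placeholder).
            pin_memory = [torch.empty(1 << 20, dtype=torch.uint8, pin_memory=True)]
        # ... stage checkpoint tensors through the pinned buffers ...
        return pin_memory
```

The key design point is that unregistration only clears the current user; the pinned buffers themselves stay allocated so the next checkpoint can reuse them without another pinned host allocation.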

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.

| File | Description |
| --- | --- |
| checkpoint_engine/ps.py | Implements the shared pin memory pool infrastructure with registration/unregistration logic and a helper method for memory pool access |
| tests/test_pin_memory.py | Adds test coverage for shared memory pool registration, unregistration, and conflict handling scenarios |
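
In the same hypothetical spirit, a conflict-handling test against the ParameterServerSketch above might look like the following; the real tests in tests/test_pin_memory.py exercise the actual implementation in checkpoint_engine/ps.py, whose API may differ.

```python
import pytest


def test_shared_pool_conflict_sketch() -> None:
    # Exercises the "one checkpoint at a time" constraint on the sketch class.
    ps = ParameterServerSketch()
    ps.register_checkpoint("ckpt-a", use_shared_memory_pool=True)
    with pytest.raises(RuntimeError):
        # A second checkpoint cannot grab the shared pool while ckpt-a holds it.
        ps.register_checkpoint("ckpt-b", use_shared_memory_pool=True)
    ps.unregister_checkpoint("ckpt-a")
    # After unregistration the pool is free again and its buffers are reused.
    ps.register_checkpoint("ckpt-b", use_shared_memory_pool=True)
```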


@specture724 specture724 force-pushed the feat/reuse_pin_memory branch from 6071dc8 to 71af910 on November 19, 2025 08:21
@specture724 specture724 force-pushed the feat/reuse_pin_memory branch from 014a03f to 302bafd on November 21, 2025 06:10
@specture724 specture724 force-pushed the feat/reuse_pin_memory branch from fc3f7fb to 9dd7e3f on December 1, 2025 05:44
@specture724 specture724 force-pushed the feat/reuse_pin_memory branch from 9dd7e3f to b976022 on December 5, 2025 07:38
@blahgeek (Collaborator)

blahgeek commented Dec 8, 2025

cc @weixiao-huang I'm not sure if this is necessary. Instead, maybe making use of torch's own cached allocator would be sufficient. @specture724 and I discussed this; he will experiment with simply removing the _host_emptyCache call in some cases.

@specture724 (Collaborator, Author)

specture724 commented Dec 10, 2025

> cc @weixiao-huang I'm not sure if this is necessary. Instead, maybe making use of torch's own cached allocator would be sufficient. @specture724 and I discussed this; he will experiment with simply removing the _host_emptyCache call in some cases.

I tested registering and unregistering 10 times with three strategies: the baseline, simply removing _host_emptyCache, and the pin memory reuse strategy implemented in this PR. I used gen_test_tensors in test_update.py to generate the test tensors; they only consume one bucket per rank. The results are below (note: gen_test_tensors generates tensors with random dtype and size, so these numbers are approximate):

| baseline | remove _host_emptyCache | this PR's implementation |
| --- | --- | --- |
| 20s | 4s | 2s |

Note that with "remove _host_emptyCache", a few registrations were very slow (about 1.5s, while the others took about 0.2s), which delayed the whole process. This PR's implementation did not cause any slow registrations.
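
For context, the benchmark loop was along these lines. This is a hypothetical sketch of the methodology only: the argument lists of register_checkpoint()/unregister_checkpoint() and the way gen_test_tensors is wired in are assumptions, not code from tests/test_update.py.

```python
import time


def benchmark_reregistration(ps, gen_test_tensors, rounds: int = 10) -> float:
    """Time `rounds` register/unregister cycles; call signatures are assumed."""
    named_tensors = gen_test_tensors()  # random dtypes and sizes, ~1 bucket per rank
    start = time.perf_counter()
    for i in range(rounds):
        name = f"ckpt-{i}"
        ps.register_checkpoint(name, named_tensors, use_shared_memory_pool=True)
        ps.unregister_checkpoint(name)
    return time.perf_counter() - start
```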

@blahgeek blahgeek (Collaborator) left a comment


Everything else looks good. I'll merge this now; the comment above can be addressed tomorrow.

```python
metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)]  # type: ignore
try:
    memory_pool = self._get_memory_pool(checkpoint_name)
except RuntimeError:
```
Review comment on the snippet above:

Silently swallowing RuntimeError here isn't great, since it's a fairly general exception. A dict.get-style approach would be better: add a parameter like allow_not_found=True or default=None, and have the method return None so the caller can check for it.
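
A sketch of the suggested change, where _memory_pools stands in for whatever internal mapping ps.py actually keeps:

```python
class _MemoryPoolLookupSketch:
    """Illustration of the suggested dict.get-style API; not the real class."""

    def __init__(self) -> None:
        self._memory_pools: dict[str, object] = {}

    def _get_memory_pool(self, checkpoint_name: str, default=None):
        # Return `default` instead of raising a broad RuntimeError when no
        # memory pool is registered for this checkpoint.
        return self._memory_pools.get(checkpoint_name, default)

    def example_call_site(self, checkpoint_name: str) -> None:
        memory_pool = self._get_memory_pool(checkpoint_name)
        if memory_pool is None:
            # Handle the missing-pool case explicitly rather than catching
            # a general exception around the lookup.
            return
        # ... use memory_pool ...
```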

@blahgeek blahgeek merged commit 6b9ffc7 into MoonshotAI:main Dec 10, 2025
2 checks passed


Development

Successfully merging this pull request may close these issues.

Support reusing pin_memory when register_checkpoint
