-
Notifications
You must be signed in to change notification settings - Fork 75
feat: reuse pin_memory when registering checkpoint #56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR implements a reusable pin memory buffer mechanism for checkpoints to reduce memory allocation overhead. The shared memory pool allows checkpoints to reuse the same pinned memory buffers sequentially, with the constraint that only one checkpoint can use the shared pool at a time and the buffer shape is fixed on first allocation.
Key changes:
- Added
use_shared_memory_poolparameter toregister_checkpoint()for opt-in shared memory usage - Introduced tracking of current shared memory pool user via
_current_shared_memory_pool_user - Modified
_register_checkpoint()to accept and reuse existing pin memory buffers when provided
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| checkpoint_engine/ps.py | Implements shared pin memory pool infrastructure with registration/unregistration logic and helper method for memory pool access |
| tests/test_pin_memory.py | Adds test coverage for shared memory pool registration, unregistration, and conflict handling scenarios |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
6071dc8 to
71af910
Compare
014a03f to
302bafd
Compare
fc3f7fb to
9dd7e3f
Compare
9dd7e3f to
b976022
Compare
|
cc @weixiao-huang I'm not sure if this is necessary. Instead, maybe making use of torch's own cached allocator would be sufficient. @specture724 and I discussed about this, he would experiment about simply removing the |
I've done testing registering and unregistering 10 times with "baseline", "simply remove
Noticed that in "remove |
blahgeek
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
其他没问题,我先merge了,上面那个comment明天可以顺便改一下
| metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)] # type: ignore | ||
| try: | ||
| memory_pool = self._get_memory_pool(checkpoint_name) | ||
| except RuntimeError: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里直接忽略RuntimeError不太好,因为这是个比较general的异常。最好是类似dict.get的做法,参数里加个allow_not_found=True或者default=None什么的,然后返回None来判断
resolve #23
Add a reusable pin memory buffer. In the current implementation, Only one checkpoint is able to use the shared pin memory at the same time. And the pin memory buffer shape is fixed to the shape when it is used for the first time, which cannot be modified latter.