Add tensor parallelism support for HF wrapper forward and lm_eval integration #340
Conversation
Commits:
- …aster according to torch src
- …e to right gpu after scatter
```python
    " If not set, it is inferred from the Fast-LLM model config or tokenizer.",
)

communication_timeout_sec: float = Field(
```
Unnecessarily long timeouts are often bad, so I recommend making the `timeout` optional (default `None`) and enabling it only as needed.
Context

Conceptually, wait primitives in places like `worker_forward` or the data-parallel worker should only exit under three conditions:
- They receive work
- They receive a finish message
- The connection with peers/coordinator is lost (after some timeout)

However, this is not how `torch.distributed` works. It is designed for more or less synchronous communication, while here we are trying to adapt it for asynchronous communication.
Problem

If we set the default timeout to `None`, users will end up seeing random timeouts in different places.
Discussion

A better long-term solution would be to use a distributed messaging framework that is more appropriate for sending work and finish messages. However, introducing another communication layer into `fast_llm` is likely outside the scope of this PR.
Proposal

- Keep the default timeout as it is, applied only to these entry points; reset the timeout back to the default of 60 sec after each wait operation.
- Clarify the naming/description to avoid confusion.
- Add a TODO to revisit this later with a more suitable communication framework.
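The proposal could look roughly like this; a minimal sketch with a plain dataclass, where `CommConfig`, `effective_wait_timeout`, and the 60 s constant are illustrative names, not Fast-LLM's actual `Field` API:

```python
from dataclasses import dataclass
from typing import Optional

# Default short timeout restored after each wait operation (assumed value).
DEFAULT_TIMEOUT_SEC = 60.0


@dataclass
class CommConfig:
    # Long timeout applied only while workers block waiting for work or a
    # finish message; None means "use the default timeout everywhere".
    communication_timeout_sec: Optional[float] = None

    def effective_wait_timeout(self) -> float:
        # Fall back to the short default when no long wait timeout is set.
        if self.communication_timeout_sec is None:
            return DEFAULT_TIMEOUT_SEC
        return self.communication_timeout_sec
```

With this shape, the long timeout is opt-in: leaving the field unset keeps the ordinary 60 s behavior everywhere.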
```python
# Meant to be overridden in derived classes
raise NotImplementedError()

def forward(
```
This doesn't seem relevant outside lm_eval. Any way to move it there?
I initially thought about handling this differently, but since each subclass of the model has its own class, the only practical way I found was to use a dynamic class that constructs itself on the fly with `type`. This lets us encapsulate `forward` of the Fast-LLM Hugging Face class and then pass it to `generate`.
Something like:

```python
def wrap_hf_model(model):
    inner_forward = get_bounded_method(model.forward, model)
    wrapper_class = get_new_type(
        model.__class__,
        {
            "inner_forward": inner_forward,
            "forward": coordinator_forward,
            "worker_forward": worker_forward,
        },
    )
    model.__class__ = wrapper_class
    return model
```
Another option would be to create a static wrapper class, but that would require exposing and forwarding a lot of functionality that `generate` expects.
So instead, I decided to implement this in our HF wrapper, since it sits before any class specialization.
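For illustration, the dynamic-class trick can be reduced to a self-contained toy; `BaseModel` and the `* 10` "orchestration" are made up stand-ins, and the real code would patch the HF specialization instead:

```python
class BaseModel:
    # Stand-in for the HF model class being wrapped.
    def forward(self, x):
        return x + 1


def wrap_model(model):
    # Capture the original (unbound) forward so the wrapper can delegate to it.
    inner_forward = model.__class__.forward

    def coordinator_forward(self, x):
        # Stand-in for TP orchestration: scatter/gather would happen here,
        # then we delegate to the original forward.
        return inner_forward(self, x) * 10

    # Build a subclass on the fly and swap it in as the instance's class, so
    # anything inherited from the base (e.g. generate) keeps working unchanged.
    wrapper_class = type(
        model.__class__.__name__ + "Wrapped",
        (model.__class__,),
        {"inner_forward": inner_forward, "forward": coordinator_forward},
    )
    model.__class__ = wrapper_class
    return model


model = wrap_model(BaseModel())
```

Because `wrapper_class` subclasses the original class, `isinstance` checks and inherited methods continue to work on the patched instance.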
I'm not sure I'm following here. From what I understand, these methods are called by `FastLLMLmEvalWrapper` above, which we are free to adjust as we want, and there isn't any dependency on the HF model, so moving should be easy. Or are there some calls in the lm_eval base class that absolutely enforce this structure?
This isn't present in typical HF models, so I'd prefer to avoid it.
No, `lm_eval` calls `FastLLMLmEvalWrapper`, which then calls our `model.generate`, which in turn makes multiple calls to `forward`; all of this runs entirely on TP rank 0. Our `forward` must be overridden to handle data distribution across all TP ranks.
```
----------------------------------------------------------------------------------------------------
|                                        Tensor Parallel Setup                                     |
----------------------------------------------------------------------------------------------------
[ lm_eval ]                 [ forward_worker ]       [ forward_worker ]  ... more
     |                              |                        |
     v                              v                        v
+-------------------------+        |                        |
| HF generate mixin       |        |                        |
| model must be HF        |        |                        |
| (model.generate)        |        |                        |
| - runs only on TP rank0 |        |                        |
| - does multiple forward |        |                        |
|   calls + sampling      |        |                        |
|   beam search, etc.     |        |                        |
+-------------------------+        |                        |
     |                             |                        |
     v                             |                        |
+-------------------------+        |                        |
| model.forward()         |--------+------------------------+--> [wait for data]
| must be overridden      |        |                        |    [long timeout]
| - orchestrates TP calls |        |                        |
+-------------------------+        |                        |
     |                             v                        v
+---------------------+   +---------------------+   +---------------------+  ... TP N-1
| TP rank 0           |   | TP rank 1           |   | TP rank 2           |
| model.forward_inner |   | model.forward_inner |   | model.forward_inner |
+---------------------+   +---------------------+   +---------------------+
```
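The coordinator/worker protocol in the diagram can be emulated with threads and queues standing in for `torch.distributed`; `coordinator_forward`, `worker_forward`, and the `* 2` "shard computation" below are toy names, not the PR's actual implementation:

```python
import queue
import threading

# Sentinel finish message: workers exit their wait loop when they receive it.
FINISH = object()


def worker_forward(inbox, results):
    # Worker loop: block until work or a finish message arrives
    # (this is the long-timeout wait in the real code).
    while True:
        msg = inbox.get()
        if msg is FINISH:
            break
        results.append(msg * 2)  # stand-in for the TP shard's forward pass


inboxes = [queue.Queue() for _ in range(2)]  # one "rank" per queue
results = []
threads = [
    threading.Thread(target=worker_forward, args=(q, results)) for q in inboxes
]
for t in threads:
    t.start()


def coordinator_forward(x):
    # Rank 0 scatters the work to all TP workers.
    for q in inboxes:
        q.put(x)


coordinator_forward(3)
for q in inboxes:
    q.put(FINISH)  # tell workers to exit their wait loop
for t in threads:
    t.join()
```

The point of the sketch is the shape of the protocol: workers exit only on work, a finish message, or (in the real distributed case) a timeout.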
Alternatives I have considered:

1. Composition (wrapper model around the HF specialization)
   - Define a class `OrchestratorModel` that looks like an HF model (has `forward`, `generate`, etc.).
   - It contains an inner `HuggingfaceBaseModelForCausalLM` (or subclass) that runs on workers.
   - The orchestrator's `forward` does the TP dispatch/gathering, then calls into the inner worker models as needed. `generate` (from HF) runs on this outer orchestrator class, which works because it just calls `self.forward`.

   This is clean, explicit, and stable, but involves boilerplate to replicate the HF interface.

2. Dynamic class injection (multiple inheritance / runtime patching)
   - Build a class at runtime that inherits both the HF specialization (`HuggingfaceGPTModelForCausalLM`, etc.) and an `OrchestratorMixin` (which overrides `forward`).
   - Register that as the actual model class for the model object. `generate` is inherited unmodified from the HF mixin, but calls `OrchestratorMixin.forward`.

   This avoids extra wrapper code, but is "hackier" and could break with HF updates.
That's why I dismissed the other options, but if we really want to keep `HuggingfaceBaseModelForCausalLM` completely unmodified, option 1 (composition) is likely the safer and more maintainable approach for our use case.
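A minimal sketch of the composition alternative (option 1); `ToyWorkerModel`, `OrchestratorModel`, and the `steps` parameter are illustrative, and a real version would have to forward the rest of the HF `generate` surface:

```python
class ToyWorkerModel:
    # Stand-in for HuggingfaceBaseModelForCausalLM running on a worker.
    def forward(self, x):
        return x + 1


class OrchestratorModel:
    """Composition wrapper: exposes an HF-like surface (forward/generate)
    and delegates the actual compute to the inner worker model."""

    def __init__(self, inner):
        self.inner = inner

    def forward(self, x):
        # Real code would dispatch to all TP ranks and gather results here.
        return self.inner.forward(x)

    def generate(self, x, steps=3):
        # HF-style generate loop: repeated forward calls on the orchestrator,
        # so the TP dispatch in forward() is exercised on every step.
        for _ in range(steps):
            x = self.forward(x)
        return x


model = OrchestratorModel(ToyWorkerModel())
```

The boilerplate cost mentioned above is exactly the set of methods and attributes, beyond `forward` and `generate`, that HF's `generate` machinery expects the outer object to expose.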
✨ Description
Add tensor parallelism support for HF wrapper forward and lm_eval integration
Closes #334
🔍 Type of change
Select all that apply:
📝 Changes
Key updates introduced in this PR:
- `_object_to_tensor` for faster performance (following PyTorch sources).
- `forward`.
- `generate` to run only on data-parallel leader ranks while tensor-parallel workers participate through `worker_forward`.
- `lm_eval` wrapper.
- `lm_eval` tasks or when batches are incomplete and some data-parallel ranks have no data.

🗒️ Notes and Known Issues