Refactor loading weights #1603
Conversation
Force-pushed from 53fe863 to d7138f1.
@zhulinJulia24 Hi, could you start a full-scope test of all PyTorch engine models using …
```python
            rank=rank,
            world_size=world_size,
            prefix='query_key_value')
        rowwise_parallelize_linear(self.dense,
```
Much better than the previous version.
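
For reference, a minimal sketch of what row-wise vs. column-wise sharding of a linear weight means under tensor parallelism. Only `rowwise_parallelize_linear` and the `query_key_value` prefix appear in the diff; the slicing helpers below are an illustration under that assumption, not lmdeploy's implementation.

```python
# Minimal sketch (not lmdeploy's implementation) of row-wise vs. column-wise
# tensor-parallel sharding of a linear layer's weight.
import torch
import torch.nn as nn


def rowwise_shard(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Row-wise parallelism splits along the input dimension.

    For y = x @ W.T, each rank holds a slice of W's columns and the partial
    outputs are summed across ranks with an all-reduce.
    """
    in_features = weight.shape[1]
    assert in_features % world_size == 0
    chunk = in_features // world_size
    return weight[:, rank * chunk:(rank + 1) * chunk].contiguous()


def colwise_shard(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Column-wise parallelism splits along the output dimension."""
    out_features = weight.shape[0]
    assert out_features % world_size == 0
    chunk = out_features // world_size
    return weight[rank * chunk:(rank + 1) * chunk, :].contiguous()


if __name__ == '__main__':
    full = nn.Linear(8, 4, bias=False)
    print(rowwise_shard(full.weight.data, rank=0, world_size=2).shape)  # [4, 4]
    print(colwise_shard(full.weight.data, rank=1, world_size=2).shape)  # [2, 8]
```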
```python
logger = get_logger('lmdeploy')


def _get_weight_type(model_path: str, use_safetensors: bool = None):
```
`use_safetensors` can be {True, False, None}. Why not `True` or `False`?
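
Presumably `None` is the auto-detect case. As a hypothetical sketch (not the PR's actual `_get_weight_type` body), a tri-state flag could be resolved like this:

```python
# Hypothetical sketch: None = auto-detect, True/False = force a format.
import glob
import os


def _get_weight_type(model_path: str, use_safetensors: bool = None) -> str:
    has_safetensors = bool(glob.glob(os.path.join(model_path, '*.safetensors')))
    has_bin = bool(glob.glob(os.path.join(model_path, '*.bin')))

    if use_safetensors is None:
        # auto mode: prefer safetensors when such files exist
        use_safetensors = has_safetensors
    if use_safetensors:
        if not has_safetensors:
            raise FileNotFoundError(f'no *.safetensors files in {model_path}')
        return 'safetensors'
    if not has_bin:
        raise FileNotFoundError(f'no *.bin files in {model_path}')
    return 'pytorch'
```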
```python
    for name, param in mod.named_parameters(recurse=False):
        dtype = param.dtype
        if not loader.has(name):
            logger.debug(f'rank [{rank}]'
```
How can this condition be triggered?
Some models share the token-embedding weight with other modules, so they do not save the redundant weight in the checkpoint.
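
A rough sketch of that case, assuming tied `embed_tokens`/`lm_head` weights; the parameter names and fallback below are illustrative, not lmdeploy's loader:

```python
# Illustrative only: a parameter can be absent from the checkpoint because it
# is tied to another tensor (e.g. lm_head shares the token-embedding weight).
import torch


def load_with_tied_fallback(state_dict: dict, model: torch.nn.Module):
    for name, param in model.named_parameters():
        if name in state_dict:
            param.data.copy_(state_dict[name])
        elif name == 'lm_head.weight' and 'embed_tokens.weight' in state_dict:
            # weight tying: reuse the embedding tensor instead of failing
            param.data.copy_(state_dict['embed_tokens.weight'])
        else:
            print(f'checkpoint has no tensor for {name}')  # mirrors the debug log
```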
```
@@ -160,204 +157,3 @@ def sync_qparam_to_context(context: Any, layer_id: str, qparams: dict):
        context.set_output(layer_id, last_qparam)
    else:
        context.set_output(layer_id, qparams)


@torch.no_grad()
```
Was this used anywhere before?
Almost never.
LGTM
Optimize TP model loading.
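
A rough sketch of the overall idea as I read the diff: each rank walks the modules and pulls only the tensors it needs from a checkpoint loader, instead of materializing the full state dict on every rank. Only `loader.has` appears in the diff; the `CheckpointLoader` class and `load_module` helper below are assumptions for illustration.

```python
# Assumed sketch, not the PR's code: per-module, per-rank weight loading.
import torch


class CheckpointLoader:
    """Toy loader over an in-memory state dict; a real one would read lazily."""

    def __init__(self, state_dict: dict):
        self._sd = state_dict

    def has(self, name: str) -> bool:
        return name in self._sd

    def pop(self, name: str) -> torch.Tensor:
        return self._sd.pop(name)


@torch.no_grad()
def load_module(mod: torch.nn.Module, loader: CheckpointLoader, prefix: str = ''):
    for name, param in mod.named_parameters(recurse=False):
        full_name = prefix + name
        if not loader.has(full_name):
            continue  # e.g. tied weights not stored in the checkpoint
        param.data.copy_(loader.pop(full_name).to(param.dtype))
    for child_name, child in mod.named_children():
        load_module(child, loader, prefix=f'{prefix}{child_name}.')
```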