2 changes: 1 addition & 1 deletion kithara/model/hf_compatibility/shape_mapping.py
@@ -92,7 +92,7 @@ def LLAMA31_HF_WEIGHTS_TO_SHAPE_MAPPING(config):
     mapping = {
         "model.embed_tokens.weight": [config["vocab_size"], config["hidden_size"]],
         "model.norm.weight": [config["hidden_size"]],
-        "lm_head.weight": [config["hidden_size"]],
+        "lm_head.weight": [config["vocab_size"], config["hidden_size"]]
     }
     for layer_idx in range(config["num_hidden_layers"]):
         layer_mapping = {
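
The corrected entry reflects that Llama 3.1's HF lm_head.weight is a full [vocab_size, hidden_size] projection matrix rather than a vector. A minimal sketch of how such a shape mapping can be used to validate loaded tensors (the toy config values and check_shapes helper are hypothetical, not Kithara code; the mapping entries mirror the diff above):

# Minimal sketch: validating loaded weights against the expected HF shapes.
# Toy config values and check_shapes are hypothetical illustrations.
import numpy as np

config = {"vocab_size": 32, "hidden_size": 8}

expected = {
    "model.embed_tokens.weight": [config["vocab_size"], config["hidden_size"]],
    "model.norm.weight": [config["hidden_size"]],
    "lm_head.weight": [config["vocab_size"], config["hidden_size"]],  # corrected shape
}

def check_shapes(weights, expected):
    """Raise if any loaded tensor does not match its expected shape."""
    for name, shape in expected.items():
        actual = list(weights[name].shape)
        if actual != shape:
            raise ValueError(f"{name}: expected {shape}, got {actual}")

weights = {name: np.zeros(shape) for name, shape in expected.items()}
check_shapes(weights, expected)  # a [hidden_size]-shaped lm_head would fail here
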
2 changes: 1 addition & 1 deletion kithara/model/hf_compatibility/to_huggingface.py
@@ -48,7 +48,7 @@ def apply_hook_fns(weight, target_shape, hook_fns):
         return weight
     if not isinstance(hook_fns, list):
         hook_fns = [hook_fns]
-    for hook_fn in hook_fns:
+    for hook_fn in hook_fns[::-1]:
         weight = hook_fn(weight, target_shape)
     return weight

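Reversing the hook list on the save path follows from how composed transformations invert: if loading applies hooks [f, g] as g(f(x)), then saving must apply the inverse steps as f_inv(g_inv(y)). A small sketch of that ordering argument (f, g and their inverses are hypothetical stand-ins, not the actual Kithara hooks):

# Sketch: hooks listed as [f, g] run as g(f(x)) on load; the save path must
# run the inverse steps in reverse order to recover the original weight.
import numpy as np

def f(x):
    """Hypothetical first load-time hook step, e.g. a reshape."""
    return x.reshape(4, 3)

def f_inv(y):
    return y.reshape(3, 4)

def g(x):
    """Hypothetical second load-time hook step, e.g. a transpose."""
    return x.T

def g_inv(y):
    return y.T

x = np.arange(12.0).reshape(3, 4)   # pretend HF-layout weight
loaded = g(f(x))                    # hooks applied in the listed order [f, g]
saved = f_inv(g_inv(loaded))        # inverses applied in reverse order [g_inv, f_inv]
assert np.array_equal(saved, x)     # round trip recovers the original weight

In param_mapping.py each hook dispatches on saving_to_hf internally, so the same hook name covers both directions; the [::-1] above supplies the reversed composition order on the save path.
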
24 changes: 10 additions & 14 deletions kithara/model/maxtext/ckpt_compatibility/param_mapping.py
@@ -493,7 +493,9 @@ def LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, saving_to_hf=
         - Keys: MaxText parameter names (str)
         - Values: Either:
             - callable: Single transformation function
-            - list[callable]: List of transformation functions to be applied in sequence
+            - list[callable]: List of transformation functions to be applied in sequence.
+              The order of the functions matters. The ordering specified is applied during
+              model loading, and the reverse order is applied during saving.
 
     Transformation Details:
     The function handles reshaping and Transpose 2d:
@@ -521,29 +523,23 @@ def from_hf():
             return from_hf()
 
     def adjust_rope(input_tensor, target_shape):
-        def unpermute_from_match_maxtext_rope(arr):
+        def from_hf(arr):
             """Convert from HF's concatenated layout to MaxText's interleaved layout"""
             half_dim = arr.shape[-1] // 2
             first_half = arr[..., :half_dim]
             second_half = arr[..., half_dim:]
             return jax.numpy.stack([first_half, second_half], axis=-1).reshape(arr.shape)
 
-        def permute_to_match_maxtext_rope(arr):
+        def to_hf(arr):
             """Convert from MaxText's interleaved layout to HF's concatenated layout"""
-            shape = arr.shape
-            arr = arr.reshape(shape[:-1] + (-1, 2))
-            return np.concatenate([arr[..., 0], arr[..., 1]], axis=-1)
-
-        def to_hf():
-            return permute_to_match_maxtext_rope(input_tensor)
-
-        def from_hf():
-            return unpermute_from_match_maxtext_rope(input_tensor)
+            evens = arr[..., ::2]
+            odds = arr[..., 1::2]
+            return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1)
 
         if saving_to_hf:
-            return to_hf()
+            return to_hf(input_tensor)
         else:
-            return from_hf()
+            return from_hf(input_tensor)


def reshape_kernel(input_tensor, target_shape):
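
The rewritten to_hf branch slices even and odd positions, which is exactly the inverse of the stack-and-reshape used in from_hf. A standalone round-trip check of the two layouts (numpy instead of jax.numpy for brevity; the toy shape is arbitrary):

# Round-trip check for the RoPE layout conversion shown above.
import numpy as np

def from_hf(arr):
    """HF concatenated halves -> interleaved (MaxText-style) layout."""
    half_dim = arr.shape[-1] // 2
    first_half = arr[..., :half_dim]
    second_half = arr[..., half_dim:]
    return np.stack([first_half, second_half], axis=-1).reshape(arr.shape)

def to_hf(arr):
    """Interleaved (MaxText-style) layout -> HF concatenated halves."""
    evens = arr[..., ::2]
    odds = arr[..., 1::2]
    return np.concatenate((evens, odds), axis=arr.ndim - 1)

hf_weight = np.arange(2 * 8, dtype=np.float32).reshape(2, 8)  # toy projection slice
assert np.array_equal(to_hf(from_hf(hf_weight)), hf_weight)
assert np.array_equal(from_hf(to_hf(hf_weight)), hf_weight)
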
66 changes: 0 additions & 66 deletions tests/model/maxtext/ckpt_compatibility/loading_llama31.py

This file was deleted.

108 changes: 0 additions & 108 deletions tests/model/maxtext/ckpt_compatibility/saving_llama31.py

This file was deleted.

6 changes: 3 additions & 3 deletions tests/model/maxtext/ckpt_compatibility/test_saving_models.py
@@ -145,9 +145,9 @@ def test_llama_31_8b_conversion(self):
"""Test conversion for Llama 3.1 8B model."""
self._run_conversion_test(
model_id="meta-llama/Llama-3.1-8B",
weight_tol=0.0001,
logits_tol=2.0,
top1_token_tol=0.1
weight_tol=0.01,
logits_tol=0.1,
top1_token_tol=0.01
)

if __name__ == '__main__':
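
The tightened logits and top-1 tolerances only have meaning relative to how the conversion test compares model outputs. A hedged sketch of that kind of check (compare_outputs and its logic are assumptions for illustration, not the actual _run_conversion_test implementation):

# Hypothetical sketch of a tolerance check like the one these arguments tune.
import numpy as np

def compare_outputs(ref_logits, test_logits, logits_tol, top1_token_tol):
    """ref_logits/test_logits: [batch, seq_len, vocab] arrays from the two models."""
    max_abs_err = np.max(np.abs(ref_logits - test_logits))
    top1_mismatch = np.mean(
        np.argmax(ref_logits, axis=-1) != np.argmax(test_logits, axis=-1)
    )
    assert max_abs_err <= logits_tol, f"logits diverge: {max_abs_err:.4f}"
    assert top1_mismatch <= top1_token_tol, f"top-1 disagreement: {top1_mismatch:.4f}"

# Example usage with toy arrays:
ref = np.random.default_rng(0).normal(size=(1, 4, 16)).astype(np.float32)
compare_outputs(ref, ref + 1e-3, logits_tol=0.1, top1_token_tol=0.01)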