Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "kithara/model/maxtext/maxtext"]
path = kithara/model/maxtext/maxtext
url = https://github.com/google/maxtext
[submodule "kithara/model/maxtext/JetStream"]
path = kithara/model/maxtext/JetStream
url = https://github.com/AI-Hypercomputer/JetStream.git
13 changes: 8 additions & 5 deletions docs/source/api/kithara.model_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,13 @@ save_in_hf_format
generate
^^^^^^^

.. py:method:: generate(inputs: Union[str | List[str] | Dict[str, np.ndarray]], max_length: int = 100, stop_token_ids: Union[str | List[int]] = "auto", strip_prompt: bool = False, tokenizer: Optional[AutoTokenizer] = None, tokenizer_handle: Optional[str] = None, return_decoded: bool = True, skip_special_tokens: bool = True, **kwargs) -> Union[List[str] | Dict[str, np.ndarray]]
.. py:method:: generate(inputs: Union[str | List[str] | List[int] | np.ndarray | List[np.ndarray] | List[List[int]]], max_length: int = 100, stop_token_ids: Union[str | List[int]] = "auto", strip_prompt: bool = False, tokenizer: Optional[AutoTokenizer] = None, tokenizer_handle: Optional[str] = None, return_decoded: bool = True, skip_special_tokens: bool = True, **kwargs) -> Union[List[str] | Dict[str, np.ndarray]]

Generate text tokens using the model.

:param inputs: A single string, a list of strings, or a dictionary with tokens as expected by the underlying model during the forward pass. If strings are provided, one of `tokenizer` and `tokenizer_handle` must be provided.
:param inputs: Inputs can be either string or integer tokens. String inputs can be a single string,
or a list of strings. Token inputs can be a numpy array, a list of numpy arrays, an integer array,
or a list of integer arrays. If strings are provided, one of `tokenizer` and `tokenizer_handle` must be provided.
:param max_length: Maximum total sequence length (prompt + generated tokens). If `tokenizer` and `tokenizer_handle` are `None`, `inputs` should be padded to the desired maximum length and this argument will be ignored. When `inputs` is string, this value must be provided. (default: 100)
:param stop_token_ids: List of token IDs that stop generation. Defaults to "auto", which extracts the end token id from the tokenizer.
:param strip_prompt: If True, returns only the generated tokens without the input prompt. If False, returns the full sequence including the prompt. (default: False)
Expand Down Expand Up @@ -183,12 +185,13 @@ save_in_hf_format

generate
^^^^^^^

.. py:method:: generate(inputs: Union[str | List[str] | Dict[str, np.ndarray]], max_length: int = 100, stop_token_ids: Union[str | List[int]] = "auto", strip_prompt: bool = False, tokenizer: Optional[AutoTokenizer] = None, tokenizer_handle: Optional[str] = None, return_decoded: bool = True, skip_special_tokens: bool = True, **kwargs) -> Union[List[str] | Dict[str, np.ndarray]]
.. py:method:: generate(inputs: Union[str | List[str] | List[int] | np.ndarray | List[np.ndarray] | List[List[int]]], max_length: int = 100, stop_token_ids: Union[str | List[int]] = "auto", strip_prompt: bool = False, tokenizer: Optional[AutoTokenizer] = None, tokenizer_handle: Optional[str] = None, return_decoded: bool = True, skip_special_tokens: bool = True, **kwargs) -> Union[List[str] | Dict[str, np.ndarray]]

Generate text tokens using the model.

:param inputs: A single string, a list of strings, or a dictionary with tokens as expected by the underlying model during the forward pass. If strings are provided, one of `tokenizer` and `tokenizer_handle` must be provided.
:param inputs: Inputs can be either string or integer tokens. String inputs can be a single string,
or a list of strings. Token inputs can be a numpy array, a list of numpy arrays, an integer array,
or a list of integer arrays. If strings are provided, one of `tokenizer` and `tokenizer_handle` must be provided.
:param max_length: Maximum total sequence length (prompt + generated tokens). If `tokenizer` and `tokenizer_handle` are `None`, `inputs` should be padded to the desired maximum length and this argument will be ignored. When `inputs` is string, this value must be provided. (default: 100)
:param stop_token_ids: List of token IDs that stop generation. Defaults to "auto", which extracts the end token id from the tokenizer.
:param strip_prompt: If True, returns only the generated tokens without the input prompt. If False, returns the full sequence including the prompt. (default: False)
Expand Down
2 changes: 1 addition & 1 deletion docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ We support all safetensor formatted models on `HuggingFace Hub <https://huggingf
- 2B, 9B, 27B
- google/gemma-2-2b
* - Llama 3.1
- 8B, 27B, 405B
- 8B, 70B, 405B
- meta-llama/Llama-3.1-8B
34 changes: 31 additions & 3 deletions kithara/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,41 @@ def _install_maxtext():
except Exception as e:
print(f"Failed to install maxtext: {e}")

maxtext_dir = Path(
os.path.join(os.path.dirname(Path(__file__)), "model/maxtext/maxtext/MaxText")
)
maxtext_dir = Path(__file__).parent / "model/maxtext/maxtext/MaxText"
sys.path.append(str(maxtext_dir))

def _install_jetstream():
try:
importlib.metadata.version("google-jetstream")
except importlib.metadata.PackageNotFoundError:
try:
print(
"Installing JetStream... This should only happen once when Kithara is first initiated."
)
jetstream_path = Path(
os.path.join(os.path.dirname(Path(__file__)), "model/maxtext/JetStream")
)
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"install",
"-e",
str(jetstream_path),
"--no-deps",
]
)
print("JetStream installed successfully")
except Exception as e:
print(f"Failed to install JetStream: {e}")

jetstream_dir = Path(__file__).parent / "model/maxtext/JetStream/jetstream"
sys.path.append(str(jetstream_dir))

_install_maxtext()
_install_jetstream()


from kithara.dataset import Dataloader, SFTDataset, TextCompletionDataset
from kithara.trainer import Trainer
Expand Down
109 changes: 94 additions & 15 deletions kithara/model/kerashub/keras_hub_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.
"""

from typing import Optional, Dict
from typing import Optional, Dict, Union, List
import numpy as np
from keras_hub.models import CausalLM
from kithara.distributed.sharding import ShardingStrategy, PredefinedShardingStrategy
Expand Down Expand Up @@ -110,31 +110,110 @@ def from_preset(
lora_rank=lora_rank,
)

def _pad_tokens_to_max_length(self, tokens, max_length):
"""
Pad each sequence in the list of token sequences to max_length.

Args:
tokens: List of numpy arrays, where each array is a sequence of token IDs
max_length: The target length to pad sequences to

Returns:
Dict containing padded token_ids and corresponding padding_mask
"""
# Initialize arrays for padded tokens and attention masks
batch_size = len(tokens)
padded_tokens = np.zeros((batch_size, max_length), dtype=np.int64)
padding_mask = np.zeros((batch_size, max_length), dtype=np.int64)

# Fill the arrays with the tokens and create corresponding masks
for i, seq in enumerate(tokens):
seq_len = min(len(seq), max_length)
padded_tokens[i, :seq_len] = seq[:seq_len]
padding_mask[i, :seq_len] = 1

return {
"token_ids": padded_tokens,
"padding_mask": padding_mask,
}

def _convert_text_input_to_model_input(
self,
prompts: Union[str | List[str]],
tokenizer:"AutoTokenizer",
max_length: int,
):
assert (
max_length is not None
), "max_length must be provided to generate() when inputs are strings."

tokens: Dict[str, np.ndarray] = tokenizer(
prompts,
max_length=max_length,
padding="max_length",
padding_side="right",
truncation=True,
return_tensors="np",
)
input_ids = tokens["input_ids"]
attention_mask = tokens["attention_mask"]
return {
"token_ids": input_ids,
"padding_mask": attention_mask,
}

def _generate(
self,
model_input,
stop_token_ids=None,
strip_prompt=False,
inputs: Union[List[str], List[np.ndarray]],
max_length: int = None,
stop_token_ids: Optional[List] = None,
strip_prompt: str = False,
tokenizer: Optional["AutoTokenizer"] = None,
**kwargs,
) -> Dict[str, np.ndarray]:
"""Fall back to https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/causal_lm.py"""
) -> List[List[int]]:
"""Generate tokens using the model. This function falls back to KerasHub model's
native generation function:
https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/causal_lm.py

Args:
inputs (list[str]|list[np.ndarray]): A list of strings, or a list
of numpy arrays containing integer token ids.
max_length (int, optional): Maximum total sequence length
(prompt + generated tokens).
stop_token_ids (List[int], optional): List of token IDs that stop
generation.
strip_prompt (bool, optional): If True, returns only the generated
tokens without the input prompt tokens. If False, returns all
tokens, including the prompt tokens. Defaults to False.
tokenizer (AutoTokenizer, optional): A HuggingFace AutoTokenizer instance.
This is guaranteed to be provided when inputs are strings.

Returns:
list[np.ndarray]: Generated token IDs (numpy.ndarray) for each prompt
"""

if isinstance(inputs[0], str):
inputs = self._convert_text_input_to_model_input(
inputs, tokenizer, max_length
)
else:
inputs = self._pad_tokens_to_max_length(inputs, max_length)

# stop_token_ids cannot be an empty list
stop_token_ids = stop_token_ids if stop_token_ids else None

tokens = self.model.generate(
model_input,
inputs,
stop_token_ids=stop_token_ids,
strip_prompt=strip_prompt,
)

# Return output but first stripped prompt
is_token = tokens["padding_mask"] == True
B, _ = tokens["token_ids"].shape
return {
"token_ids": tokens["token_ids"][is_token][None, :].reshape(B, -1),
"padding_mask": tokens["padding_mask"][is_token][None, :].reshape(B, -1),
}

results = []
for idx, _ in enumerate(inputs["token_ids"]):
is_token = tokens["padding_mask"][idx, :] == True
generated_tokens = tokens["token_ids"][idx, :][is_token]
results.append(generated_tokens.tolist())
return results

def save_in_hf_format(
self,
Expand Down
1 change: 1 addition & 0 deletions kithara/model/maxtext/JetStream
Submodule JetStream added at 982569
35 changes: 29 additions & 6 deletions kithara/model/maxtext/conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,7 @@ def get_maxtext_pyconfig(
argv += [f"model_name={model_name}"]
if maxtext_config is not None:
argv += maxtext_config.split(" ")
# pyconfig.initialize must be called before
# any JAX computations are executed.
pyconfig.initialize(argv)
config = pyconfig.config
config = pyconfig.initialize(argv)
return config

@staticmethod
Expand All @@ -178,6 +175,7 @@ def initialize_random_maxtext_model(
weight_dtype: str,
activation_dtype: str,
scan_layers: bool,
max_prefill_predict_length: int,
maxtext_config_args: Optional[str] = None,
) -> tuple[ShardingStrategy, keras.Model]:
"""Initialize a random MaxText model with the input configuration.
Expand Down Expand Up @@ -221,7 +219,11 @@ def initialize_random_maxtext_model(
maxtext_config_args["per_device_batch_size"] = per_device_batch_size
assert "max_target_length" not in maxtext_config_args
maxtext_config_args["max_target_length"] = seq_len

assert "max_prefill_predict_length" not in maxtext_config_args
maxtext_config_args["max_prefill_predict_length"] = max_prefill_predict_length
if "enable_checkpointing" not in maxtext_config_args:
maxtext_config_args["enable_checkpointing"] = False

maxtext_config_args = " ".join(
[f"{key}={value}" for key, value in maxtext_config_args.items()]
)
Expand Down Expand Up @@ -263,4 +265,25 @@ def init_initial_state(model, rng):
)

print(f"✅ Successfully initialized a MaxText {model_name} model in {time.time() - start_time:.3f}s...")
return sharding_strategy, model
return maxtext_config, sharding_strategy, model

def get_maxtext_params(model: "kithara.MaxTextModel") -> dict:
"""Convert Kithara variables (flat list format) into
MaxText variables (nested dict format). This function
is a simple format converter. It is used for inserting
the current model parameters into MaxText's inference
engine.
"""
params = {}
for v in model.variables:
variable_name = v.path
nest_keys = variable_name.split("/")[-1].split("-")
if nest_keys[0] != "params":
continue

current = params
for key in nest_keys[:-1]: # All keys except the last one
current = current.setdefault(key, {})
current[nest_keys[-1]] = v.value

return params
68 changes: 68 additions & 0 deletions kithara/model/maxtext/inference_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from kithara.model.maxtext.maxtext.MaxText.maxengine import MaxEngine
import jax
import max_utils

class MaxtextInferenceEngine(MaxEngine):
"""This is a patched version of MaxEngine

Changes:
- Added the `load_existing_params` function to allow
the engine to run with an exisitng model instead of initiating
a new one.
"""
def load_existing_params(self, params, rng=None) -> "Params":

if rng is None:
rng = jax.random.PRNGKey(0)

self.abstract_params = jax.tree_util.tree_map(
lambda x: (
jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
if isinstance(x, jax.Array)
else None
),
params,
)

self.prefill_kv_cache_annotations = max_utils.get_prefill_kv_cache_annotations(
self.model, self.config, rng, self._mesh
)
self.prefill_kv_cache_shardings = jax.tree_util.tree_map(
lambda x: jax.sharding.NamedSharding(self._mesh, x),
self.prefill_kv_cache_annotations,
)

if self.config.stack_prefill_result_cache:
# Add extra axis for the axis generated by the stack.
self.prefill_kv_cache_shardings = jax.tree_util.tree_map(
lambda x: jax.sharding.NamedSharding(
self._mesh, jax.sharding.PartitionSpec(None, *x.spec)
),
self.prefill_kv_cache_shardings,
)
self.prefill_kv_cache_shardings = self.prefill_kv_cache_shardings[
"decoder"
]["layers_0"]

self.kv_cache_annotations = max_utils.get_kv_cache_annotations(
self.model, self.config, rng, self._mesh
)
self.kv_cache_shardings = jax.tree_util.tree_map(
lambda x: jax.sharding.NamedSharding(self._mesh, x),
self.kv_cache_annotations,
)
return params
2 changes: 1 addition & 1 deletion kithara/model/maxtext/maxtext
Loading