Merged
12 changes: 8 additions & 4 deletions docs/source/api/kithara.model_api.rst
@@ -73,8 +73,10 @@ save_in_hf_format
- Model weights file (model.safetensors for models smaller than DEFAULT_MAX_SHARD_SIZE, model-x-of-x.safetensors for larger models)
- Safe tensors index file (model.safetensors.index.json)

:param output_dir: Directory path where the model should be saved. Can be local or Google cloud storage path.
Will be created if it doesn't exist.
:param output_dir: Directory path where the model should be saved. Can be a local folder (e.g. "foldername/"),
a HuggingFace Hub repo prefixed with "hf://" (e.g. "hf://your_hf_id/repo_name"), or a
Google Cloud Storage path prefixed with "gs://" (e.g. "gs://your_bucket/folder_name"),
and will be created if it doesn't exist.
:param dtype: Data type for saved weights. Defaults to "auto" which saves the model in its current precision type. (default: "auto")
:param parallel_threads: Number of parallel threads to use for saving (default: 8).
Note: Local system must have at least parallel_threads * DEFAULT_MAX_SHARD_SIZE free disk space,
@@ -155,8 +157,10 @@ save_in_hf_format

Save the model in HuggingFace format, including configuration and weights files.

:param output_dir: Directory path where the model should be saved. Can be local or Google cloud storage path.
Will be created if it doesn't exist.
:param output_dir: Directory path where the model should be saved. Can be a local folder (e.g. "foldername/"),
a HuggingFace Hub repo prefixed with "hf://" (e.g. "hf://your_hf_id/repo_name"), or a
Google Cloud Storage path prefixed with "gs://" (e.g. "gs://your_bucket/folder_name"),
and will be created if it doesn't exist.
:param dtype: Data type for saved weights. Defaults to "auto" which saves the model in its current precision type.
:param only_save_adapters: If True, only adapter weights will be saved. If False, both base model weights and adapter weights will be saved. (default: False)
:param save_adapters_separately: If False, adapter weights will be merged with base model. If True, adapter weights will be saved separately in HuggingFace's peft format. (default: False)
2 changes: 1 addition & 1 deletion docs/source/finetuning_guide.rst
@@ -40,7 +40,7 @@ Quick tips:
- Always start model handles with ``hf://`` when loading from HuggingFace - so we know you are not loading from a local directory 😀
- The default precision ``mixed_bfloat16`` is your friend - it's memory-friendly! It loads model weights in full precision and casts activations to bfloat16.
- Check out our :doc:`model garden <models>` for supported architectures
- Want to save your model? Simply do ``model.save_in_hf_format(local_dir_or_gs_bucket)``
- Want to save your model? Simply do ``model.save_in_hf_format(destination)`` to save it locally, to Google Cloud Storage, or to HuggingFace Hub.
- Check out :doc:`Model API <api/kithara.model_api>` documentation
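The save destinations in the tips above are distinguished purely by their prefix. A minimal sketch of that convention (the helper name is hypothetical, not part of Kithara's API):

```python
def storage_backend(destination: str) -> str:
    """Classify a save destination the way the docs above describe it."""
    if destination.startswith("gs://"):
        return "gcs"    # Google Cloud Storage bucket
    if destination.startswith("hf://"):
        return "hub"    # HuggingFace Hub repo
    return "local"      # plain local folder

print(storage_backend("gs://your_bucket/folder_name"))  # gcs
print(storage_backend("hf://your_hf_id/repo_name"))     # hub
print(storage_backend("model_output/"))                 # local
```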

2. Prepare Your Data
6 changes: 3 additions & 3 deletions docs/source/lora.rst
@@ -35,7 +35,7 @@ You have three options for saving models trained with LoRA:
Since the base model is left unchanged, you can save just the LoRA Adapters::

model.save_in_hf_format(
local_dir_or_gs_bucket,
destination,
only_save_adapters=True
)

@@ -44,7 +44,7 @@ Since the base model is left unchanged, you can save just the LoRA Adapters::
In case you want to save the base model as well. ::

model.save_in_hf_format(
local_dir_or_gs_bucket,
destination,
only_save_adapters=False,
save_adapters_separately=True
)
@@ -54,7 +54,7 @@ In case you want to save the base model as well. ::
Creates a single model combining base weights and adaptations::

model.save_in_hf_format(
local_dir_or_gs_bucket,
destination,
save_adapters_separately=False
)

2 changes: 1 addition & 1 deletion docs/source/models.rst
@@ -14,6 +14,6 @@ We support all safetensor formatted models on `HuggingFace Hub <https://huggingf
* - Gemma 2
- 2B, 9B, 27B
- google/gemma-2-2b
* - Llama 3
* - Llama 3.1
- 8B, 70B, 405B
- meta-llama/Llama-3.1-8B
2 changes: 2 additions & 0 deletions docs/source/pretraining.rst
@@ -168,6 +168,8 @@ Step 7: Save Model
Save the model in the HuggingFace format::

model.save_in_hf_format("gs://my-bucket/models")
# Or, if you prefer saving to HuggingFace Hub
# model.save_in_hf_format("hf://my-hf-id/repo-name")


Notes
11 changes: 7 additions & 4 deletions docs/source/quickstart.rst
@@ -10,7 +10,10 @@ The script can also be found on `Github <https://github.com/AI-Hypercomputer/kit

Setup
-----
Import required packages::
Log into HuggingFace and import required packages::

from huggingface_hub import login
login(token="your_hf_token", add_to_git_credential=False)

import os
os.environ["KERAS_BACKEND"] = "jax"
@@ -23,6 +26,9 @@ Import required packages::
SFTDataset,
)

.. tip::
New to HuggingFace? First create an access token, then `apply for access <https://huggingface.co/google/gemma-2-2b>`_ to the Gemma 2 model, which is used in this example.

Quick Usage
-----------

@@ -33,9 +39,6 @@ Quick Usage
precision="mixed_bfloat16",
lora_rank=4,
)

.. tip::
New to HuggingFace? First create an access token, `apply access <https://huggingface.co/google/gemma-2-2b>`_ to the HuggingFace model, and set the ``HF_TOKEN`` environment variable.

2. Prepare Dataset::

2 changes: 1 addition & 1 deletion docs/source/sft.rst
@@ -172,7 +172,7 @@ Step 6: Save Model
Save the model in the Hugging Face format::

model.save_in_hf_format(
"model_output/", # You can also save the model to a Google Cloud Storage bucket
"model_output/", # You can also save the model to Google Cloud Storage, or directly to HuggingFace Hub
only_save_adapters=True, # You can also save the base model, or merge the base model with the adapters
save_adapters_separately=True
)
92 changes: 77 additions & 15 deletions kithara/model/hf_compatibility/to_huggingface.py
@@ -1,18 +1,18 @@
"""
Copyright 2025 Google LLC
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0
https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
Utilities to convert Kithara model weights to HuggingFace format.
@@ -42,6 +42,8 @@
from concurrent.futures import ThreadPoolExecutor
from peft import LoraConfig, PeftConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from huggingface_hub import HfApi, repo_exists


def apply_hook_fns(weight, target_shape, hook_fns):
if hook_fns is None:
@@ -88,7 +90,7 @@ def process_weight(variable, mappings, debug=False):
variable_path = variable.path
if variable.path.startswith("max_text_layer"):
variable_path = variable.path.split("/")[-1]

hf_paths = mappings["param_mapping"][variable_path]
if isinstance(hf_paths, str):
hf_paths = [hf_paths]
@@ -127,6 +129,12 @@ def save_lora_files(
if lora_config is None:
    print("WARNING: There is no LoRA adapter to be saved.")
    return

if output_dir.startswith("hf://"):
create_huggingface_hub_repo_if_not_exist(
repo_id=output_dir.removeprefix("hf://"), repo_type="model"
)

local_dir = _get_local_directory(output_dir)
# Save adapter_config.json
save_peft_config_file(lora_config, local_dir, output_dir)
@@ -136,9 +144,36 @@
)


def create_huggingface_hub_repo_if_not_exist(repo_id, repo_type):
if not repo_exists(repo_id, repo_type=repo_type):
api = HfApi()
api.create_repo(
repo_id=repo_id,
repo_type=repo_type,
exist_ok=True,
private=True,
)
print(f"\nCreated new HuggingFace Hub {repo_type} repo: {repo_id}.")


def upload_file_to_huggingface_hub(local_path, file_name, repo_id, repo_type):
api = HfApi()
api.upload_file(
path_or_fileobj=local_path,
path_in_repo=file_name,
repo_id=repo_id,
repo_type=repo_type,
)
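Stripping the ``hf://`` prefix with ``str.lstrip`` is a classic pitfall: ``lstrip`` removes any of the listed *characters* from the left, not the literal prefix, so ``str.removeprefix`` (Python 3.9+) is the safe choice when deriving a repo_id. A quick demonstration:

```python
# lstrip("hf://") strips any of the characters h, f, :, / from the left,
# so a repo owner whose name starts with "h" or "f" gets mangled:
mangled = "hf://hf_user/repo".lstrip("hf://")        # -> "_user/repo"

# removeprefix strips the literal prefix exactly once (Python 3.9+):
correct = "hf://hf_user/repo".removeprefix("hf://")  # -> "hf_user/repo"
```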


def save_model_files(weight_arrays: Dict, config, output_dir: str, parallel_threads=8):
"""Saves model files (config and weights) to the specified directory."""
start_time = time.time()

if output_dir.startswith("hf://"):
create_huggingface_hub_repo_if_not_exist(
repo_id=output_dir.removeprefix("hf://"), repo_type="model"
)

print(f"\n-> Saving weights to {output_dir}...")

local_dir = _get_local_directory(output_dir)
@@ -151,11 +186,10 @@ def save_model_files(weight_arrays: Dict, config, output_dir: str, parallel_thre
save_weight_files(shards, index, local_dir, output_dir, parallel_threads)



def _get_local_directory(output_dir: str) -> str:
"""Determines the local directory for saving files."""
local_dir = output_dir
if local_dir.startswith("gs://"):
if local_dir.startswith("gs://") or local_dir.startswith("hf://"):
local_dir = os.path.join(find_cache_root_dir(), "temp_ckpt")
os.makedirs(local_dir, exist_ok=True)
return local_dir
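The staging logic of `_get_local_directory` can be sketched standalone (with a caller-supplied cache root standing in for `find_cache_root_dir`, which is not reproduced here): remote destinations are written to a local temp directory first, then uploaded.

```python
import os

def pick_local_dir(output_dir: str, cache_root: str) -> str:
    """Remote destinations (gs:// or hf://) are staged in a local temp dir first."""
    local_dir = output_dir
    if local_dir.startswith(("gs://", "hf://")):
        local_dir = os.path.join(cache_root, "temp_ckpt")
    os.makedirs(local_dir, exist_ok=True)
    return local_dir
```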
@@ -173,6 +207,13 @@ def save_index_file(index: dict, local_dir: str, output_dir: str, file_name: str
os.path.join(output_dir, file_name),
remove_local_file_after_upload=True,
)
elif output_dir.startswith("hf://"):
upload_file_to_huggingface_hub(
local_path=local_path,
file_name=file_name,
repo_id=output_dir.removeprefix("hf://"),
repo_type="model",
)


def save_config_file(config, local_dir: str, output_dir: str, file_name: str):
@@ -187,6 +228,13 @@ def save_config_file(config, local_dir: str, output_dir: str, file_name: str):
os.path.join(output_dir, file_name),
remove_local_file_after_upload=True,
)
elif output_dir.startswith("hf://"):
upload_file_to_huggingface_hub(
local_path=local_path,
file_name=file_name,
repo_id=output_dir.removeprefix("hf://"),
repo_type="model",
)


def save_peft_config_file(config: PeftConfig, local_dir: str, output_dir: str):
@@ -200,6 +248,13 @@ def save_peft_config_file(config: PeftConfig, local_dir: str, output_dir: str):
os.path.join(output_dir, SAFE_TENSORS_PEFT_CONFIG_FILE),
remove_local_file_after_upload=True,
)
elif output_dir.startswith("hf://"):
upload_file_to_huggingface_hub(
local_path=local_path,
file_name=SAFE_TENSORS_PEFT_CONFIG_FILE,
repo_id=output_dir.removeprefix("hf://"),
repo_type="model",
)


def save_safetensor_file(state_dict, local_dir, output_dir, file_name):
@@ -213,6 +268,13 @@ def save_safetensor_file(state_dict, local_dir, output_dir, file_name):
upload_file_to_gcs(
local_path, cloud_path, remove_local_file_after_upload=True
)
elif output_dir.startswith("hf://"):
upload_file_to_huggingface_hub(
local_path=local_path,
file_name=file_name,
repo_id=output_dir.removeprefix("hf://"),
repo_type="model",
)
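Each of the save helpers above repeats the same gs://-or-hf://-or-local branch. One way to factor that repetition (a sketch only, not Kithara API; names are hypothetical) is a single dispatcher keyed on the destination prefix, which each helper could then consult:

```python
def route_upload(output_dir: str, file_name: str) -> tuple:
    """Return (backend, target) for a file that was just written locally."""
    if output_dir.startswith("gs://"):
        # GCS: upload to the full object path inside the bucket.
        return ("gcs", output_dir.rstrip("/") + "/" + file_name)
    if output_dir.startswith("hf://"):
        # Hub: the repo_id is everything after the literal prefix.
        return ("hub", output_dir.removeprefix("hf://"))
    # Local: the file already sits in its final directory.
    return ("local", output_dir)
```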


def save_weight_files(
6 changes: 4 additions & 2 deletions kithara/model/kerashub/keras_hub_model.py
@@ -151,8 +151,10 @@ def save_in_hf_format(

Args:
output_dir (str): Directory path where the model should be saved.
Directory could be local or a Google cloud storage path, and will be created if
it doesn't exist.
Directory could be a local folder (e.g. "foldername/"),
a HuggingFace Hub repo (e.g. "hf://your_hf_id/repo_name"), or a
Google Cloud Storage path (e.g. "gs://your_bucket/folder_name"),
and will be created if it doesn't exist.
dtype (str, optional): Data type for saved weights. Defaults to "auto".
only_save_adapters (bool): If set to True, only adapter weights will be saved. If
set to False, both base model weights and adapter weights will be saved. Default
6 changes: 4 additions & 2 deletions kithara/model/maxtext/maxtext_model.py
@@ -211,8 +211,10 @@ def save_in_hf_format(

Args:
output_dir (str): Directory path where the model should be saved.
Directory could be local or a Google cloud storage path, and
will be created if it doesn't exist.
Directory could be a local folder (e.g. "foldername/"),
a HuggingFace Hub repo (e.g. "hf://your_hf_id/repo_name"), or a
Google Cloud Storage path (e.g. "gs://your_bucket/folder_name"),
and will be created if it doesn't exist.
dtype (str, optional): Data type for saved weights. Defaults to "auto".
parallel_threads (int, optional): Number of parallel threads to use for saving.
Defaults to 8. Make sure the local system has at least