Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,12 @@ the tokenizer that we will use.
Please sign agreement on Huggingface website to access Gemma checkpoints. Download Gemma PyTorch checkpoint using huggingface-cli. Gemma Tokenizer is included in the checkpoint.

```bash
# Install huggingface-cli and login if it's not set up.
pip install -U "huggingface_hub[cli]"
huggingface-cli login
huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
```

Need to manually modify the `config.json` in the checkpoint folder to make it a valid JSON file. (Replace `'` with `"`, remove the excessive `,` after the last item in the JSON object)

## Mixtral
### Get Mixtral Checkpoint from HuggingFace

Expand Down
9 changes: 1 addition & 8 deletions convert_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,21 +428,14 @@ def _get_llama_state_dict(input_ckpt_dir):
return state_dict, params


def fix_json(text):
text = text.replace("'", '"')
lines = text.split("\n")
lines[-3] = lines[-3].replace(",", "")
return "\n".join(lines)


def _get_gemma_state_dict(input_ckpt_dir):
ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
assert len(ckpt_file) == 1, "only expect 1 ckpt file for Gemma model."
ckpt_file = ckpt_file[0]
state_dict = torch.load(str(ckpt_file), map_location=torch.device("cpu"))[
"model_state_dict"
]
config_text = fix_json((input_ckpt_dir / "config.json").read_text())
config_text = (input_ckpt_dir / "config.json").read_text()
model_config = json.loads(config_text)
for key in list(state_dict.keys()):
if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value:
Expand Down