Skip to content

Commit

Permalink
Add H2O Danube2 Checkpoint (#1282)
Browse files Browse the repository at this point in the history
Co-authored-by: Luca Antiga <luca@lightning.ai>
Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>
Co-authored-by: Andrei-Aksionov <aksionau.andrei@gmail.com>
Co-authored-by: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com>
  • Loading branch information
5 people committed May 3, 2024
1 parent d39b26a commit e441c65
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 10 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ LitGPT has 🤯 **custom, from-scratch implementations** of [20+ LLMs](tutorials
|----|----|----|----|
| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| Danube2 | 1.8B | H2O.ai | [H2O.ai](https://h2o.ai/platform/danube-1-8b/) |
| Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |
| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) |
| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
Expand Down
29 changes: 28 additions & 1 deletion litgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ def norm_class(self) -> Type:
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)


###############
# Meta LLaMA 3
Expand Down Expand Up @@ -964,6 +964,33 @@ def norm_class(self) -> Type:
]
configs.extend(codegemma)

################
# H2Oai Danube2
################
danube2 = [
# https://huggingface.co/h2oai/h2o-danube2-1.8b-chat/blob/main/config.json
dict(
name="Danube2-1.8b-chat",
hf_config=dict(org="h2oai", name="h2o-danube2-1.8b-chat"),
vocab_size=32000,
n_layer=24,
n_head=32,
n_embd=2560,
block_size=4096, # should be 8192 but sliding_window mechanism is not implemented
intermediate_size=6912,
padding_multiple=64,
norm_eps=1e-05,
rope_base=10000,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
)
]
configs.extend(danube2)


##########################
# Stability AI FreeWilly2
Expand Down
8 changes: 8 additions & 0 deletions litgpt/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,11 @@ def apply(self, prompt: str, **kwargs: str) -> str:
return f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"


class H2Oai(PromptStyle):
def apply(self, prompt: str, **kwargs: str) -> str:
return f"<|prompt|>{prompt}</s><|answer|>"


# Maps prompt style names to PromptStyle classes
prompt_styles: Dict[str, Type[PromptStyle]] = {
# Dataset-specific prompt styles
Expand All @@ -312,6 +317,7 @@ def apply(self, prompt: str, **kwargs: str) -> str:
"phi-2": Phi2,
"tinyllama": TinyLlama,
"gemma": Gemma,
"h2oai": H2Oai,
}


Expand Down Expand Up @@ -352,6 +358,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
return TinyLlama()
if re.search(r"(Code)?Gemma.*-it", model_name):
return Gemma()
if re.search("Danube2.*-chat", model_name):
return H2Oai()
return Default()


Expand Down
58 changes: 58 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,64 @@ def test_against_hf_mixtral():
torch.testing.assert_close(ours_y, theirs_y)


@torch.inference_mode()
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_hf_h2o_danube(device, dtype):
torch.set_default_dtype(dtype)

ours_config = Config.from_name(
"Danube2-1.8b-chat",
padded_vocab_size=10000,
n_layer=2,
n_embd=16,
n_head=8,
n_query_groups=2,
intermediate_size=43,
)
T = 5
theirs_config = MistralConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size

theirs_model = MistralForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)

# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)


@torch.inference_mode()
@pytest.mark.parametrize(
("device", "dtype"),
Expand Down
24 changes: 15 additions & 9 deletions tutorials/download_model_weights.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

LitGPT supports a variety of LLM architectures with publicly available weights. You can download model weights and access a list of supported models using the LitGPT `download.py` script.


| Model | Model size | Reference |
|----------------------------------------------|-----------------------------------------|--------------------------------------------------------------------------------------------------------------------------|
| CodeGemma by Google | 7B | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
| Code Llama by Meta AI | 7B, 13B, 34B, 70B | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| Danube2 by H2O.ai | 1.8B | [H2O.ai](https://h2o.ai/platform/danube-1-8b/)
| Dolly by Databricks | 3B, 7B, 12B | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |
| Falcon by TII UAE | 7B, 40B, 180B | [TII 2023](https://falconllm.tii.ae) |
| FreeWilly2 (Stable Beluga 2) by Stability AI | 70B | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
Expand All @@ -28,11 +28,9 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
| TinyLlama by Zhang et al. | 1.1B | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) |
| Vicuna by LMSYS | 7B, 13B, 33B | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) |



&nbsp;
## General Instructions

## General Instructions

### 1. List Available Models

Expand Down Expand Up @@ -91,6 +89,7 @@ google/gemma-2b
google/gemma-2b-it
google/gemma-7b
google/gemma-7b-it
h2oai/h2o-danube2-1.8b-chat
lmsys/longchat-13b-16k
lmsys/longchat-7b-16k
lmsys/vicuna-13b-v1.3
Expand Down Expand Up @@ -165,26 +164,28 @@ unsloth/Mistral-7B-v0.2

> [!NOTE]
> If you want to adopt a model variant that is not listed in the table above but has a similar architecture as one of the supported models, you can use this model by by using the `--model_name` argument as shown below:
>
> ```bash
> litgpt download \
> --repo_id NousResearch/Hermes-2-Pro-Mistral-7B \
> --model_name Mistral-7B-v0.1
> ```

&nbsp;

### 2. Download Model Weights

To download the weights for a specific model, use the `--repo_id` argument. Replace `<repo_id>` with the model's repository ID. For example:

```bash
litgpt download --repo_id <repo_id>
```

This command downloads the model checkpoint into the `checkpoints/` directory.

&nbsp;
### 3. Additional Help

### 3. Additional Help

For more options, add the `--help` flag when running the script:

Expand All @@ -193,6 +194,7 @@ litgpt download --help
```

&nbsp;

### 4. Run the Model

After conversion, run the model with the `--checkpoint_dir` flag, adjusting `repo_id` accordingly:
Expand All @@ -202,6 +204,7 @@ litgpt chat --checkpoint_dir checkpoints/<repo_id>
```

&nbsp;

## Tinyllama Example

This section shows a typical end-to-end example for downloading and using TinyLlama:
Expand Down Expand Up @@ -235,7 +238,7 @@ litgpt chat --checkpoint_dir checkpoints/$repo_id

Note that certain models require that you've been granted access to the weights on the Hugging Face Hub.

For example, to get access to the Gemma 2B model, you can do so by following the steps at https://huggingface.co/google/gemma-2b. After access is granted, you can find your HF hub token in https://huggingface.co/settings/tokens.
For example, to get access to the Gemma 2B model, you can do so by following the steps at <https://huggingface.co/google/gemma-2b>. After access is granted, you can find your HF hub token in <https://huggingface.co/settings/tokens>.

Once you've been granted access and obtained the access token you need to pass the additional `--access_token`:

Expand All @@ -246,7 +249,8 @@ litgpt download \
```

&nbsp;
## Finetunes and other model variants

## Finetunes and Other Model Variants

Sometimes you want to download the weights of a finetune of one of the models listed above. To do this, you need to manually specify the `model_name` associated to the config to use. For example:

Expand All @@ -257,11 +261,11 @@ litgpt download \
```

&nbsp;

## Tips for GPU Memory Limitations

The `download.py` script will automatically convert the downloaded model checkpoint into a LitGPT-compatible format. In case this conversion fails due to GPU memory constraints, you can try to reduce the memory requirements by passing the `--dtype bf16-true` flag to convert all parameters into this smaller precision (however, note that most model weights are already in a bfloat16 format, so it may not have any effect):


```bash
litgpt download \
--repo_id <repo_id>
Expand All @@ -271,6 +275,7 @@ litgpt download \
(If your GPU does not support the bfloat16 format, you can also try a regular 16-bit float format via `--dtype 16-true`.)

&nbsp;

## Converting Checkpoints Manually

For development purposes, for example, when adding or experimenting with new model configurations, it may be beneficial to split the weight download and model conversion into two separate steps.
Expand All @@ -291,6 +296,7 @@ litgpt convert to_litgpt \
```

&nbsp;

## Downloading Tokenizers Only

In some cases we don't need the model weight, for example, when we are pretraining a model from scratch instead of finetuning it. For cases like this, you can use the `--tokenizer_only` flag to only download a model's tokenizer, which can then be used in the pretraining scripts:
Expand Down

0 comments on commit e441c65

Please sign in to comment.