Lightweight HF integration #220
Merged
Commits (28 total, all authored by AkshitaB; the diff below reflects the state after 18 of them, before the "rename to hf_olmo" commit):

442f86c  first commit
c7c170d  second pass, textgen pipeline works
2befdc5  add tests
b5da6b4  tokenizer
6935e5f  make auto compatible
5b623fc  tests, cleanup
bdbf041  pipeline test
e6df083  add requirement file
6db6980  pyproject
c3611d7  get tests to work
71abd14  rename for consistency
c92f25e  add test fixture
e0fd0a7  ignore hf integration tests on gpu
bfa379f  fix
754eff1  move imports
9e7b57c  use_cache config to arg
6159097  update forward
126e5a5  style
7c8243d  add missing kwargs, fix from_pretrained to use device_map
7abd90f  use HF's default generation tools
e4c14a6  style
e5f91a3  ensure that HF generation uses cache
d5006a5  update comment
72248f0  rename to hf_olmo
89aaf97  fix
9659955  fix github actions
833e4f5  fix finally
1a4ba1b  add reqs to cache
File: hf_integration/__init__.py (new file, +3 lines)

from .configuration_olmo import OLMoConfig
from .modeling_olmo import OLMoForCausalLM
from .tokenization_olmo_fast import OLMoTokenizerFast
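Importing the package is all that is needed to make OLMo visible to transformers' auto classes, since each submodule runs its registration call at import time. A minimal sketch of that side effect:

import hf_integration  # noqa: F401
# The import above executes the AutoConfig.register / AutoModelForCausalLM.register /
# AutoTokenizer.register calls in the submodules, so the "olmo" model type now resolves.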
File: checkpoint-conversion script (filename not shown in this view; new file, +38 lines)
import argparse
import logging
import os

from hf_integration.configuration_olmo import OLMoConfig
from olmo import Olmo

logger = logging.getLogger(__name__)


def write_config(checkpoint_dir: str):
    # save config as HF config
    logger.info(f"Loading checkpoint from {checkpoint_dir}")
    model = Olmo.from_checkpoint(checkpoint_dir)

    config = OLMoConfig(**model.config.asdict())

    logger.info(f"Saving HF-compatible config to {os.path.join(checkpoint_dir, 'config.json')}")
    config.save_pretrained(checkpoint_dir)


def main():
    parser = argparse.ArgumentParser(
        description="Adds a config.json to the checkpoint directory, making it easier to load weights as HF models."
    )
    parser.add_argument(
        "--checkpoint-dir",
        help="Location of OLMo checkpoint.",
    )

    args = parser.parse_args()
    write_config(
        checkpoint_dir=args.checkpoint_dir,
    )


if __name__ == "__main__":
    main()
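Once the script has written config.json into the checkpoint directory, that directory resolves through AutoConfig. A hedged usage sketch (the checkpoint path is a placeholder):

from transformers import AutoConfig

from hf_integration.configuration_olmo import OLMoConfig  # side effect: registers "olmo"

# Placeholder path; point this at a checkpoint processed by the script above.
config = AutoConfig.from_pretrained("/path/to/olmo/checkpoint")
assert isinstance(config, OLMoConfig)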
File: hf_integration/configuration_olmo.py (new file, +26 lines)
""" | ||
OLMo configuration | ||
""" | ||
|
||
from transformers import AutoConfig, PretrainedConfig | ||
from transformers.utils import logging | ||
|
||
from olmo.config import ModelConfig | ||
|
||
logger = logging.get_logger(__name__) | ||
|
||
|
||
class OLMoConfig(PretrainedConfig): | ||
model_type = "olmo" | ||
keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm | ||
|
||
def __init__(self, use_cache: bool = False, **kwargs): | ||
model_config = ModelConfig() | ||
all_kwargs = model_config.asdict() | ||
all_kwargs.update(kwargs) | ||
all_kwargs.update({"use_cache": use_cache}) | ||
super().__init__(**all_kwargs) | ||
|
||
|
||
# Register the config class so that it is available for transformer pipelines, auto-loading etc. | ||
AutoConfig.register("olmo", OLMoConfig) |
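Because __init__ seeds every field from a fresh ModelConfig before applying overrides, a config can be built by specifying only the fields to change. A small sketch (override values are illustrative; assumes OLMoConfig is imported as above):

# Unspecified fields fall back to olmo.config.ModelConfig defaults.
config = OLMoConfig(use_cache=True, max_sequence_length=2048)

# Standard PretrainedConfig machinery then applies, e.g. round-tripping to disk:
config.save_pretrained("/tmp/olmo-config-demo")  # writes config.json (path illustrative)
reloaded = OLMoConfig.from_pretrained("/tmp/olmo-config-demo")
assert reloaded.use_cache and reloaded.max_sequence_length == 2048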
File: hf_integration/modeling_olmo.py (new file, +149 lines)
import os
from typing import List, Optional, Tuple, Union

import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto import AutoModelForCausalLM

from olmo.config import ModelConfig
from olmo.model import Olmo

from .configuration_olmo import OLMoConfig


def create_model_config_from_pretrained_config(config: OLMoConfig):
    """
    Utility function: map an HF-style OLMoConfig back to OLMo's native ModelConfig.
    """
    model_config = ModelConfig(
        d_model=config.d_model,
        n_heads=config.n_heads,
        n_layers=config.n_layers,
        mlp_ratio=config.mlp_ratio,
        activation_type=config.activation_type,
        block_type=config.block_type,
        alibi=config.alibi,
        alibi_bias_max=config.alibi_bias_max,
        rope=config.rope,
        flash_attention=config.flash_attention,
        attention_dropout=config.attention_dropout,
        attention_layer_norm=config.attention_layer_norm,
        multi_query_attention=config.multi_query_attention,
        residual_dropout=config.residual_dropout,
        embedding_dropout=config.embedding_dropout,
        layer_norm_type=config.layer_norm_type,
        max_sequence_length=config.max_sequence_length,
        include_bias=config.include_bias,
        vocab_size=config.vocab_size,
        embedding_size=config.embedding_size,
        eos_token_id=config.eos_token_id,
        pad_token_id=config.pad_token_id,
        init_device=config.init_device,
        init_std=config.init_std,
        precision=config.precision,
    )
    return model_config


class OLMoForCausalLM(PreTrainedModel):
    """
    Extremely barebones HF model wrapper.
    """

    config_class = OLMoConfig

    def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None):
        super().__init__(config)

        if not model:
            model_config = create_model_config_from_pretrained_config(config)
            self.model = Olmo(model_config, init_params=True)
        else:
            self.model = model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if use_cache is None:
            use_cache = self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.embedding_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.attn_key_values,
        )

    def generate(self, input_ids, *args, **kwargs):
        with torch.no_grad():
            res = self.model.generate(input_ids, **kwargs)
        # Add back input_ids to top beam output since this is what's expected for AutoModelForCausalLM
        return torch.cat((input_ids, res.token_ids[:, 0]), dim=1)

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs
    ):
        assert pretrained_model_name_or_path is not None
        model = Olmo.from_checkpoint(pretrained_model_name_or_path)
        config = OLMoConfig(**model.config.asdict())
        return cls(config, model)

    # TODO: these 4 are required to make the implementation complete.
    # def resize_position_embeddings(self, new_num_position_embeddings: int):
    #     pass
    #
    # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
    #     pass
    #
    # def prepare_inputs_for_generation(self, *args, **kwargs):
    #     pass
    #
    # def _reorder_cache(self, past_key_values, beam_idx):
    #     pass


# Register the model so that it is available for transformer pipelines, auto-loading, etc.
AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
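With config, model, and tokenizer all registered, the standard text-generation pipeline path should work end to end, which is what the "textgen pipeline works" and "pipeline test" commits exercise. A hedged sketch (checkpoint path and prompt are placeholders):

import hf_integration  # noqa: F401  (side effect: registers OLMo with the auto classes)
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

ckpt = "/path/to/olmo/checkpoint"  # placeholder; must already contain the HF config.json
model = AutoModelForCausalLM.from_pretrained(ckpt)  # dispatches to OLMoForCausalLM.from_pretrained
tokenizer = AutoTokenizer.from_pretrained(ckpt)

generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(generator("OLMo is", max_new_tokens=20))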
File: pyproject.toml (new file, +100 lines)
[project]
name = "olmo_hf_integration"
version = "0.0.1"
description = "Lightweight HF-OLMo"
authors = [
    {name = "Akshita Bhagia", email = "akshitab@allenai.org" }
]
license = {text = "Apache-2.0"}
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
    "transformers>=4.27",
    "tokenizers",
]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Typing :: Typed",
]

[project.urls]
Homepage = "https://github.com/allenai/llm"

[project.optional-dependencies]
dev = [
    "black>=22.6.0",
    "isort>=5.10.1",
    "mypy>=0.971",
    "pytest>=5.2",
    "ipython>=8.4.0",
    "autopep8>=1.7.0",
    "flake8>=5.0",
    "ipdb>=0.13.0",
    "flake8-pyi>=22.8.1",
    "Flake8-pyproject>=1.1.0",
]

[build-system]
build-backend = "setuptools.build_meta"
requires = [
    "setuptools >= 61.0.0",
    "wheel"
]

[tool.setuptools.package-data]
olmo_hf_integration = ["py.typed"]

[tool.black]
line-length = 115
include = '\.pyi?$'
exclude = '''
(
      __pycache__
    | \.git
    | \.mypy_cache
    | \.pytest_cache
    | \.vscode
    | \.venv
    | \bdist\b
    | \bdoc\b
)
'''

[tool.setuptools]
py-modules = []

[tool.isort]
profile = "black"
multi_line_output = 3

[tool.autopep8]
max_line_length = 79
in-place = true
recursive = true
aggressive = 3

[tool.mypy]
python_version = "3.8"  # mypy requires a string here when configured via pyproject.toml
ignore_missing_imports = true
no_site_packages = true
allow_redefinition = false
warn_unused_configs = true
warn_unused_ignores = true
warn_no_return = true
warn_return_any = false
warn_unreachable = true
show_error_codes = true
pretty = true

# Per-module mypy options in pyproject.toml use [[tool.mypy.overrides]]; the original
# [tool.mypy-tests] section is not valid syntax. "tests.*" is the presumed target.
[[tool.mypy.overrides]]
module = "tests.*"
strict_optional = false

[tool.flake8]
per-file-ignores = [
    '__init__.py:F401',
    '*.pyi:E302,E305',
    '*.py:E203'
]
File: requirements file (name not shown in this view; new file, +1 line)

transformers
File: hf_integration/tokenization_olmo_fast.py (new file, +31 lines)
import os
from typing import Union

from transformers import AutoTokenizer, PreTrainedTokenizerFast

from hf_integration.configuration_olmo import OLMoConfig
from olmo import Tokenizer


class OLMoTokenizerFast(PreTrainedTokenizerFast):
    # Note: Olmo's tokenizer is already a wrapper around huggingface. This is potentially unnecessary.

    # def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    #     # This is required to make the implementation complete.
    #     pass

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
        tokenizer_raw = Tokenizer.from_checkpoint(pretrained_model_name_or_path)
        tokenizer = cls(
            tokenizer_object=tokenizer_raw.base_tokenizer,
            truncation=tokenizer_raw.truncate_direction,
            max_length=tokenizer_raw.truncate_to,
            eos_token=tokenizer_raw.decode([tokenizer_raw.eos_token_id], skip_special_tokens=False),
        )
        tokenizer.model_input_names = ["input_ids", "attention_mask"]
        return tokenizer


# Register the tokenizer class so that it is available for transformer pipelines, auto-loading etc.
AutoTokenizer.register(OLMoConfig, fast_tokenizer_class=OLMoTokenizerFast)
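A short sketch of the tokenizer on its own. Since from_pretrained builds the fast tokenizer directly from the checkpoint's native OLMo Tokenizer, no separate tokenizer files are required; this assumes the checkpoint directory already carries the HF config.json (path is a placeholder):

import hf_integration  # noqa: F401  (registers OLMoTokenizerFast for OLMoConfig)
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("/path/to/olmo/checkpoint")
batch = tok(["Hello OLMo"], return_tensors="pt")
print(batch["input_ids"].shape)           # (1, sequence_length)
print(tok.decode(batch["input_ids"][0]))  # decodes back to text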
Review discussion:

Comment (on the generate override): Do we need to implement this, or can we get this for free using HF's built-in generate functionality?

Reply: We can, now. Needed a couple of small methods implemented.
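For context: the "couple of small methods" HF's GenerationMixin needs are among the TODOs left commented out in the model file. A hedged sketch of what such implementations typically look like (bodies are illustrative, not the PR's final code):

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
    # Once the cache is populated, only the newest token needs a forward pass.
    if past_key_values is not None:
        input_ids = input_ids[:, -1:]
    return {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache", True),
    }

def _reorder_cache(self, past_key_values, beam_idx):
    # Keep each layer's cached keys/values aligned with the surviving beams.
    return tuple(
        tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
        for layer_past in past_key_values
    )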