
Lightweight HF integration #220

Merged
merged 28 commits on Jun 26, 2023
Changes from 18 commits
2 changes: 2 additions & 0 deletions .github/actions/setup-venv/action.yml
@@ -40,13 +40,15 @@ runs:
        . .venv/bin/activate
        pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
        pip install -e .[dev]
        pip install -e hf_integration

    - if: steps.virtualenv-cache.outputs.cache-hit == 'true'
      shell: bash
      run: |
        # Set up virtual environment from cache hit.
        . .venv/bin/activate
        pip install --no-deps -e .[dev]
        pip install -e hf_integration

- shell: bash
run: |
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -152,7 +152,7 @@ jobs:
                  value: ":16:8"
                - name: TOKENIZERS_PARALLELISM
                  value: "false"
              command: ["/entrypoint.sh", "pytest", "-v", "-m", "gpu", "tests/"]
              command: ["/entrypoint.sh", "pytest", "-v", "-m", "gpu", "tests/", "-k", "not hf_integration"]
              result:
                path: /unused
          token: ${{ env.BEAKER_TOKEN }}
5 changes: 5 additions & 0 deletions conftest.py
@@ -89,3 +89,8 @@ def lorem_ipsum() -> str:
@pytest.fixture(scope="module")
def lorem_ipsum_docs() -> List[str]:
    return [text.replace("\n", " ").strip() for text in (LOREM_IPSUM_1, LOREM_IPSUM_2)]


@pytest.fixture(scope="function")
def model_path() -> str:
return "test_fixtures/test-olmo-model"
3 changes: 3 additions & 0 deletions hf_integration/__init__.py
@@ -0,0 +1,3 @@
from .configuration_olmo import OLMoConfig
from .modeling_olmo import OLMoForCausalLM
from .tokenization_olmo_fast import OLMoTokenizerFast
38 changes: 38 additions & 0 deletions hf_integration/add_hf_config_to_olmo_checkpoint.py
@@ -0,0 +1,38 @@
import argparse
import logging
import os

from hf_integration.configuration_olmo import OLMoConfig
from olmo import Olmo

logger = logging.getLogger(__name__)


def write_config(checkpoint_dir: str):
    # save config as HF config
    logger.info(f"Loading checkpoint from {checkpoint_dir}")
    model = Olmo.from_checkpoint(checkpoint_dir)

    config = OLMoConfig(**model.config.asdict())

    logger.info(f"Saving HF-compatible config to {os.path.join(checkpoint_dir, 'config.json')}")
    config.save_pretrained(checkpoint_dir)


def main():
    parser = argparse.ArgumentParser(
        description="Adds a config.json to the checkpoint directory, making it easier to load weights as HF models."
    )
    parser.add_argument(
        "--checkpoint-dir",
        help="Location of OLMo checkpoint.",
    )

    args = parser.parse_args()
    write_config(
        checkpoint_dir=args.checkpoint_dir,
    )


if __name__ == "__main__":
    main()
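
As an illustrative sketch (not part of the changeset): calling the function defined in this script directly. The checkpoint path below is hypothetical.

from hf_integration.add_hf_config_to_olmo_checkpoint import write_config

# Writes an HF-compatible config.json next to the native OLMo checkpoint files,
# so that the transformers auto classes can later resolve the architecture from it.
write_config(checkpoint_dir="/path/to/olmo-checkpoint")  # hypothetical path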
26 changes: 26 additions & 0 deletions hf_integration/configuration_olmo.py
@@ -0,0 +1,26 @@
"""
OLMo configuration
"""

from transformers import AutoConfig, PretrainedConfig
from transformers.utils import logging

from olmo.config import ModelConfig

logger = logging.get_logger(__name__)


class OLMoConfig(PretrainedConfig):
    model_type = "olmo"
    keys_to_ignore_at_inference = ["past_key_values"]  # TODO: confirm

    def __init__(self, use_cache: bool = False, **kwargs):
        model_config = ModelConfig()
        all_kwargs = model_config.asdict()
        all_kwargs.update(kwargs)
        all_kwargs.update({"use_cache": use_cache})
        super().__init__(**all_kwargs)


# Register the config class so that it is available for transformer pipelines, auto-loading etc.
AutoConfig.register("olmo", OLMoConfig)
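
As an illustrative sketch (not part of the changeset): what the registration above enables, assuming a checkpoint directory that already contains a config.json with model_type "olmo" (for example, one produced by add_hf_config_to_olmo_checkpoint.py). The path is hypothetical.

from transformers import AutoConfig

from hf_integration import OLMoConfig  # importing also triggers the registration

config = AutoConfig.from_pretrained("/path/to/olmo-checkpoint")  # hypothetical path
assert isinstance(config, OLMoConfig)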
149 changes: 149 additions & 0 deletions hf_integration/modeling_olmo.py
@@ -0,0 +1,149 @@
import os
from typing import List, Optional, Tuple, Union

import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto import AutoModelForCausalLM

from olmo.config import ModelConfig
from olmo.model import Olmo

from .configuration_olmo import OLMoConfig


def create_model_config_from_pretrained_config(config: OLMoConfig):
    """
    Utility function: build a native OLMo ModelConfig from an HF OLMoConfig.
    """
    model_config = ModelConfig(
        d_model=config.d_model,
        n_heads=config.n_heads,
        n_layers=config.n_layers,
        mlp_ratio=config.mlp_ratio,
        activation_type=config.activation_type,
        block_type=config.block_type,
        alibi=config.alibi,
        alibi_bias_max=config.alibi_bias_max,
        rope=config.rope,
        flash_attention=config.flash_attention,
        attention_dropout=config.attention_dropout,
        attention_layer_norm=config.attention_layer_norm,
        multi_query_attention=config.multi_query_attention,
        residual_dropout=config.residual_dropout,
        embedding_dropout=config.embedding_dropout,
        layer_norm_type=config.layer_norm_type,
        max_sequence_length=config.max_sequence_length,
        include_bias=config.include_bias,
        vocab_size=config.vocab_size,
        embedding_size=config.embedding_size,
        eos_token_id=config.eos_token_id,
        pad_token_id=config.pad_token_id,
        init_device=config.init_device,
        init_std=config.init_std,
        precision=config.precision,
    )
    return model_config


class OLMoForCausalLM(PreTrainedModel):
    """
    Extremely barebones HF model wrapper.
    """

    config_class = OLMoConfig

    def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None):
        super().__init__(config)

        if not model:
            model_config = create_model_config_from_pretrained_config(config)
            self.model = Olmo(model_config, init_params=True)
        else:
            self.model = model

    # def forward(self, *args, **kwargs):
    #     # use_cache = self.config.use_cache or kwargs.pop("use_cache", False)
    #     kwargs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
    #     return self.model.forward(*args, **kwargs)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if use_cache is None:
            use_cache = self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.embedding_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.attn_key_values,
        )

    def generate(self, input_ids, *args, **kwargs):
        with torch.no_grad():
            res = self.model.generate(input_ids, **kwargs)
        # Add back input_ids to top beam output since this is what's expected for AutoModelForCausalLM
        return torch.cat((input_ids, res.token_ids[:, 0]), dim=1)

Review comment (Member): Do we need to implement this or can we get this for free using HF built-in generate functionality?

Reply (Contributor Author): We can, now. Needed a couple small methods implemented.

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs
    ):
        assert pretrained_model_name_or_path is not None
        model = Olmo.from_checkpoint(pretrained_model_name_or_path)
        config = OLMoConfig(**model.config.asdict())
        return cls(config, model)

    # TODO: these 4 are required to make the implementation complete.
    # def resize_position_embeddings(self, new_num_position_embeddings: int):
    #     pass
    #
    # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
    #     pass
    #
    # def prepare_inputs_for_generation(self, *args, **kwargs):
    #     pass
    #
    # def _reorder_cache(self, past_key_values, beam_idx):
    #     pass


# Register the model so that it is available for transformer pipelines, auto-loading, etc.
AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
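
As an illustrative end-to-end sketch (not part of the changeset): how the wrapper is meant to be used. The checkpoint path and prompt are hypothetical; loading delegates to Olmo.from_checkpoint via the from_pretrained override above, and any generation kwargs are forwarded verbatim to Olmo.generate.

import torch

from hf_integration import OLMoForCausalLM, OLMoTokenizerFast

model = OLMoForCausalLM.from_pretrained("/path/to/olmo-checkpoint")  # hypothetical path
tokenizer = OLMoTokenizerFast.from_pretrained("/path/to/olmo-checkpoint")

inputs = tokenizer("Language modeling is", return_tensors="pt")
with torch.no_grad():
    # generate() prepends input_ids to the top-beam continuation, matching AutoModelForCausalLM conventions.
    output_ids = model.generate(inputs["input_ids"])
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))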
100 changes: 100 additions & 0 deletions hf_integration/pyproject.toml
@@ -0,0 +1,100 @@
[project]
name = "olmo_hf_integration"
version = "0.0.1"
description = "Lightweight HF-OLMo"
authors = [
{name = "Akshita Bhagia", email = "akshitab@allenai.org" }
]
license = {text = "Apache-2.0"}
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"transformers>=4.27",
"tokenizers",
]
classifiers = [
"Development Status :: 3 - Alpha",
"Typing :: Typed",
]

[project.urls]
Homepage = "https://github.com/allenai/llm"


[project.optional-dependencies]
dev = [
"black>=22.6.0",
"isort>=5.10.1",
"mypy>=0.971",
"pytest>=5.2",
"ipython>=8.4.0",
"autopep8>=1.7.0",
"flake8>=5.0",
"ipdb>=0.13.0",
"flake8-pyi>=22.8.1",
"Flake8-pyproject>=1.1.0",
]

[build-system]
build-backend = "setuptools.build_meta"
requires = [
"setuptools >= 61.0.0",
"wheel"
]

[tool.setuptools.package-data]
olmo_hf_integration = ["py.typed"]

[tool.black]
line-length = 115

include = '\.pyi?$'

exclude = '''
(
__pycache__
| \.git
| \.mypy_cache
| \.pytest_cache
| \.vscode
| \.venv
| \bdist\b
| \bdoc\b
)
'''

[tool.setuptools]
py-modules = []

[tool.isort]
profile = "black"
multi_line_output = 3

[tool.autopep8]
max_line_length = 79
in-place = true
recursive = true
aggressive = 3

[tool.mypy]
python_version = 3.8
ignore_missing_imports = true
no_site_packages = true
allow_redefinition = false
warn_unused_configs = true
warn_unused_ignores = true
warn_no_return = true
warn_return_any = false
warn_unreachable = true
show_error_codes = true
pretty = true

[tool.mypy-tests]
strict_optional = false

[tool.flake8]
per-file-ignores = [
'__init__.py:F401',
'*.pyi:E302,E305',
'*.py:E203'
]
1 change: 1 addition & 0 deletions hf_integration/requirements.txt
@@ -0,0 +1 @@
transformers
31 changes: 31 additions & 0 deletions hf_integration/tokenization_olmo_fast.py
@@ -0,0 +1,31 @@
import os
from typing import Union

from transformers import AutoTokenizer, PreTrainedTokenizerFast

from hf_integration.configuration_olmo import OLMoConfig
from olmo import Tokenizer


class OLMoTokenizerFast(PreTrainedTokenizerFast):
    # Note: Olmo's tokenizer is already a wrapper around huggingface. This is potentially unnecessary.

    # def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    #     # This is required to make the implementation complete.
    #     pass

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
        tokenizer_raw = Tokenizer.from_checkpoint(pretrained_model_name_or_path)
        tokenizer = cls(
            tokenizer_object=tokenizer_raw.base_tokenizer,
            truncation=tokenizer_raw.truncate_direction,
            max_length=tokenizer_raw.truncate_to,
            eos_token=tokenizer_raw.decode([tokenizer_raw.eos_token_id], skip_special_tokens=False),
        )
        tokenizer.model_input_names = ["input_ids", "attention_mask"]
        return tokenizer


# Register the tokenizer class so that it is available for transformer pipelines, auto-loading etc.
AutoTokenizer.register(OLMoConfig, fast_tokenizer_class=OLMoTokenizerFast)
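
As an illustrative sketch (not part of the changeset): the tokenizer wrapper on its own, with a hypothetical checkpoint path. It relies on Tokenizer.from_checkpoint reading the tokenizer stored alongside the OLMo checkpoint, as in the from_pretrained override above.

from hf_integration import OLMoTokenizerFast

tokenizer = OLMoTokenizerFast.from_pretrained("/path/to/olmo-checkpoint")  # hypothetical path
encoded = tokenizer("Hello OLMo", return_tensors="pt")
print(encoded["input_ids"].shape)  # (1, sequence_length)
print(tokenizer.decode(encoded["input_ids"][0], skip_special_tokens=True))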