Lightweight HF integration #220

Merged
merged 28 commits on Jun 26, 2023
2 changes: 2 additions & 0 deletions .github/actions/setup-venv/action.yml
@@ -40,13 +40,15 @@ runs:
. .venv/bin/activate
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e .[dev]
pip install -e hf_olmo

- if: steps.virtualenv-cache.outputs.cache-hit == 'true'
shell: bash
run: |
# Set up virtual environment from cache hit.
. .venv/bin/activate
pip install --no-deps -e .[dev]
pip install -e hf_olmo

- shell: bash
run: |
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -152,7 +152,7 @@ jobs:
value: ":16:8"
- name: TOKENIZERS_PARALLELISM
value: "false"
command: ["/entrypoint.sh", "pytest", "-v", "-m", "gpu", "tests/"]
command: ["/entrypoint.sh", "pytest", "-v", "-m", "gpu", "tests/", "-k", "not hf_olmo"]
Member

Should we add another job to run the HF tests?

Contributor Author

Discussed offline: since this requires updating the beaker image on which we run the GPU tests, and since we expect to reconfigure this at some point, it's not worth the effort now.

I've confirmed using instruct-eval that the HF integration runs on GPU.

result:
path: /unused
token: ${{ env.BEAKER_TOKEN }}
5 changes: 5 additions & 0 deletions conftest.py
@@ -89,3 +89,8 @@ def lorem_ipsum() -> str:
@pytest.fixture(scope="module")
def lorem_ipsum_docs() -> List[str]:
    return [text.replace("\n", " ").strip() for text in (LOREM_IPSUM_1, LOREM_IPSUM_2)]


@pytest.fixture(scope="function")
def model_path() -> str:
    return "test_fixtures/test-olmo-model"
3 changes: 3 additions & 0 deletions hf_olmo/__init__.py
@@ -0,0 +1,3 @@
from .configuration_olmo import OLMoConfig
from .modeling_olmo import OLMoForCausalLM
from .tokenization_olmo_fast import OLMoTokenizerFast
38 changes: 38 additions & 0 deletions hf_olmo/add_hf_config_to_olmo_checkpoint.py
@@ -0,0 +1,38 @@
import argparse
import logging
import os

from hf_olmo.configuration_olmo import OLMoConfig
from olmo import Olmo

logger = logging.getLogger(__name__)


def write_config(checkpoint_dir: str):
    # save config as HF config
    logger.info(f"Loading checkpoint from {checkpoint_dir}")
    model = Olmo.from_checkpoint(checkpoint_dir)

    config_kwargs = model.config.asdict()
    config_kwargs["use_cache"] = True
    config = OLMoConfig(**config_kwargs)

    logger.info(f"Saving HF-compatible config to {os.path.join(checkpoint_dir, 'config.json')}")
    config.save_pretrained(checkpoint_dir)


def main():
    parser = argparse.ArgumentParser(
        description="Adds a config.json to the checkpoint directory, making it easier to load weights as HF models."
    )
    parser.add_argument(
        "--checkpoint-dir",
        help="Location of OLMo checkpoint.",
    )

    args = parser.parse_args()
    write_config(checkpoint_dir=args.checkpoint_dir)


if __name__ == "__main__":
    main()
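
For reference, a minimal sketch of how this script's `write_config` helper could be invoked from Python; the checkpoint path below is a placeholder, not something defined in this PR (the equivalent CLI entry point is the `--checkpoint-dir` argument handled by `main()`):

import logging

from hf_olmo.add_hf_config_to_olmo_checkpoint import write_config

logging.basicConfig(level=logging.INFO)

# Hypothetical local directory holding a native OLMo checkpoint.
checkpoint_dir = "/path/to/olmo-checkpoint"

# Writes an HF-compatible config.json next to the native OLMo config so the
# checkpoint can later be loaded through the hf_olmo wrapper classes.
write_config(checkpoint_dir=checkpoint_dir)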
26 changes: 26 additions & 0 deletions hf_olmo/configuration_olmo.py
@@ -0,0 +1,26 @@
"""
OLMo configuration
"""

from transformers import AutoConfig, PretrainedConfig
from transformers.utils import logging

from olmo.config import ModelConfig

logger = logging.get_logger(__name__)


class OLMoConfig(PretrainedConfig):
    model_type = "olmo"
    keys_to_ignore_at_inference = ["past_key_values"]  # TODO: confirm

    def __init__(self, use_cache: bool = False, **kwargs):
        model_config = ModelConfig()
        all_kwargs = model_config.asdict()
        all_kwargs.update(kwargs)
        all_kwargs.update({"use_cache": use_cache})
        super().__init__(**all_kwargs)


# Register the config class so that it is available for transformer pipelines, auto-loading etc.
AutoConfig.register("olmo", OLMoConfig)
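
Because the config class is registered under the "olmo" model type, `AutoConfig` can resolve it once `hf_olmo` has been imported. A minimal sketch, assuming a checkpoint directory that already contains the HF config.json written by the conversion script above (the path is hypothetical):

import hf_olmo  # noqa: F401 - importing the package runs AutoConfig.register("olmo", OLMoConfig)
from transformers import AutoConfig

# Hypothetical checkpoint directory containing an HF-style config.json with model_type "olmo".
config = AutoConfig.from_pretrained("/path/to/olmo-checkpoint")
print(type(config).__name__)  # OLMoConfig
print(config.use_cache)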
211 changes: 211 additions & 0 deletions hf_olmo/modeling_olmo.py
@@ -0,0 +1,211 @@
import os

# import warnings
from typing import List, Optional, Tuple, Union

import torch
from transformers import PreTrainedModel

# from transformers.generation.utils import ( # BaseStreamer,
# GenerateOutput,
# LogitsProcessorList,
# StoppingCriteriaList,
# )
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto import AutoModelForCausalLM

from olmo.config import ModelConfig
from olmo.model import Olmo

from .configuration_olmo import OLMoConfig

# from typing import Callable, Sequence


def create_model_config_from_pretrained_config(config: OLMoConfig):
    """
    Utility function that converts an HF `OLMoConfig` into a native OLMo `ModelConfig`.
    """
    model_config = ModelConfig(
        d_model=config.d_model,
        n_heads=config.n_heads,
        n_layers=config.n_layers,
        mlp_ratio=config.mlp_ratio,
        activation_type=config.activation_type,
        block_type=config.block_type,
        alibi=config.alibi,
        alibi_bias_max=config.alibi_bias_max,
        rope=config.rope,
        flash_attention=config.flash_attention,
        attention_dropout=config.attention_dropout,
        attention_layer_norm=config.attention_layer_norm,
        multi_query_attention=config.multi_query_attention,
        residual_dropout=config.residual_dropout,
        embedding_dropout=config.embedding_dropout,
        layer_norm_type=config.layer_norm_type,
        max_sequence_length=config.max_sequence_length,
        include_bias=config.include_bias,
        vocab_size=config.vocab_size,
        embedding_size=config.embedding_size,
        eos_token_id=config.eos_token_id,
        pad_token_id=config.pad_token_id,
        init_device=config.init_device,
        init_std=config.init_std,
        precision=config.precision,
    )
    return model_config


class OLMoForCausalLM(PreTrainedModel):
    """
    Extremely barebones HF model wrapper.
    """

    config_class = OLMoConfig

    def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None):
        super().__init__(config)

        if not model:
            model_config = create_model_config_from_pretrained_config(config)
            self.model = Olmo(model_config, init_params=True)
        else:
            self.model = model

    # def forward(self, *args, **kwargs):
    #     # use_cache = self.config.use_cache or kwargs.pop("use_cache", False)
    #     kwargs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
    #     return self.model.forward(*args, **kwargs)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if use_cache is None:
            use_cache = self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.embedding_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.attn_key_values,
        )

    def can_generate(self) -> bool:
        return True

    # Note (akshitab): This model does not use OLMo's generate() function, as it does not support all the
    # bells and whistles that HF's generation-compatible models do (e.g. `StoppingCriteria`, top-p sampling).
    # Instead, the model sets `can_generate` to True, relies on HF's default `.generate()`, and implements
    # supporting functions like `prepare_inputs_for_generation()`. This lets us use HF's various generation
    # options.

    # def generate(
    #     self,
    #     input_ids: Optional[torch.Tensor] = None,
    #     max_length: int = 20,
    #     max_new_tokens: Optional[int] = None,
    #     logits_processor: Optional[LogitsProcessorList] = None,
    #     stopping_criteria: Optional[StoppingCriteriaList] = None,
    #     prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    #     synced_gpus: Optional[bool] = None,
    #     assistant_model: Optional["PreTrainedModel"] = None,
    #     streamer: Optional["BaseStreamer"] = None,
    #     **kwargs,
    # ) -> Union[GenerateOutput, torch.LongTensor]:
    #
    #     assert input_ids is not None
    #
    #     # TODO: use stopping_criteria, since it's being used by instruct-eval
    #     if stopping_criteria is not None:
    #         warnings.warn(
    #             "OLMo's generate() function does not currently support `stopping_criteria`. "
    #             "This will likely result in worse performance on tasks."
    #         )
    #
    #     max_steps = max_new_tokens or max_length - input_ids.shape[1]
    #     result = self.model.generate(
    #         input_ids,
    #         max_steps=max_steps,
    #         beam_size=1,
    #         **kwargs,
    #     )
    #
    #     return torch.cat((input_ids, result.token_ids[:, 0]), dim=1)

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs
    ):
        assert pretrained_model_name_or_path is not None
        if kwargs.get("device_map", "auto") == "auto":
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        else:
            device = "cpu"
        model = Olmo.from_checkpoint(pretrained_model_name_or_path, device=device)
        try:
            config = OLMoConfig.from_pretrained(pretrained_model_name_or_path)
        except FileNotFoundError:
            config = OLMoConfig(use_cache=True, **model.config.asdict())
        return cls(config, model)

    def prepare_inputs_for_generation(
        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
    ):
        if past_key_values:
            # This is because we want the model to only process the last generated token.
            input_ids = input_ids[:, -1:]
        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}

        model_inputs.update(kwargs)
        model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
        return model_inputs

    # TODO: these are required to make the implementation complete.
    # def resize_position_embeddings(self, new_num_position_embeddings: int):
    #     pass
    #
    # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
    #     pass
    #
    # def _reorder_cache(self, past_key_values, beam_idx):
    #     pass


# Register the model so that it is available for transformer pipelines, auto-loading, etc.
AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
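
To illustrate how the wrapper is meant to be driven by HF's stock generation loop (via `can_generate()` and `prepare_inputs_for_generation()`), here is a minimal end-to-end sketch; the checkpoint path is a placeholder, and it assumes the directory also contains tokenizer files usable by `OLMoTokenizerFast`, neither of which is pinned down by this diff:

import torch

from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast

checkpoint_dir = "/path/to/olmo-checkpoint"  # hypothetical path

tokenizer = OLMoTokenizerFast.from_pretrained(checkpoint_dir)
model = OLMoForCausalLM.from_pretrained(checkpoint_dir)

inputs = tokenizer("Language modeling is", return_tensors="pt")

# HF's default .generate() drives greedy decoding here; once past_key_values are
# populated, prepare_inputs_for_generation() feeds only the newest token back in.
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))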