Commit acf372e

Merge pull request #220 from allenai/hf-integration
Lightweight HF integration
AkshitaB committed Jun 26, 2023
2 parents 43c29d9 + 1a4ba1b commit acf372e
Showing 19 changed files with 835 additions and 2 deletions.
4 changes: 3 additions & 1 deletion .github/actions/setup-venv/action.yml
@@ -30,7 +30,7 @@ runs:
id: virtualenv-cache
with:
path: .venv
-key: ${{ inputs.cache-prefix }}-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('requirements.txt', 'dev-requirements.txt') }}
+key: ${{ inputs.cache-prefix }}-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('requirements.txt', 'dev-requirements.txt', 'hf_olmo/requirements.txt') }}

- if: steps.virtualenv-cache.outputs.cache-hit != 'true'
shell: bash
Expand All @@ -40,13 +40,15 @@ runs:
. .venv/bin/activate
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e .[dev]
+pip install -e hf_olmo
- if: steps.virtualenv-cache.outputs.cache-hit == 'true'
shell: bash
run: |
# Set up virtual environment from cache hit.
. .venv/bin/activate
pip install --no-deps -e .[dev]
+pip install --no-deps -e hf_olmo
- shell: bash
run: |
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -152,7 +152,7 @@ jobs:
value: ":16:8"
- name: TOKENIZERS_PARALLELISM
value: "false"
command: ["/entrypoint.sh", "pytest", "-v", "-m", "gpu", "tests/"]
command: ["/entrypoint.sh", "pytest", "-v", "-m", "gpu", "tests/", "-k", "not hf_olmo"]
result:
path: /unused
token: ${{ env.BEAKER_TOKEN }}
5 changes: 5 additions & 0 deletions conftest.py
@@ -89,3 +89,8 @@ def lorem_ipsum() -> str:
@pytest.fixture(scope="module")
def lorem_ipsum_docs() -> List[str]:
return [text.replace("\n", " ").strip() for text in (LOREM_IPSUM_1, LOREM_IPSUM_2)]


@pytest.fixture(scope="function")
def model_path() -> str:
return "test_fixtures/test-olmo-model"
3 changes: 3 additions & 0 deletions hf_olmo/__init__.py
@@ -0,0 +1,3 @@
from .configuration_olmo import OLMoConfig
from .modeling_olmo import OLMoForCausalLM
from .tokenization_olmo_fast import OLMoTokenizerFast
38 changes: 38 additions & 0 deletions hf_olmo/add_hf_config_to_olmo_checkpoint.py
@@ -0,0 +1,38 @@
import argparse
import logging
import os

from hf_olmo.configuration_olmo import OLMoConfig
from olmo import Olmo

logger = logging.getLogger(__name__)


def write_config(checkpoint_dir: str):
# save config as HF config
logger.info(f"Loading checkpoint from {checkpoint_dir}")
model = Olmo.from_checkpoint(checkpoint_dir)

config_kwargs = model.config.asdict()
config_kwargs["use_cache"] = True
config = OLMoConfig(**config_kwargs)

logger.info(f"Saving HF-compatible config to {os.path.join(checkpoint_dir, 'config.json')}")
config.save_pretrained(checkpoint_dir)


def main():
parser = argparse.ArgumentParser(
description="Adds a config.json to the checkpoint directory, making it easier to load weights as HF models."
)
parser.add_argument(
"--checkpoint-dir",
help="Location of OLMo checkpoint.",
)

args = parser.parse_args()
write_config(checkpoint_dir=args.checkpoint_dir)


if __name__ == "__main__":
main()
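
A hedged example invocation of the script above (the checkpoint path is illustrative):

python hf_olmo/add_hf_config_to_olmo_checkpoint.py --checkpoint-dir /path/to/olmo-checkpoint

This loads the native OLMo checkpoint once and writes an HF-style config.json into the same directory, which is what lets the transformers auto classes recognize it.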
26 changes: 26 additions & 0 deletions hf_olmo/configuration_olmo.py
@@ -0,0 +1,26 @@
"""
OLMo configuration
"""

from transformers import AutoConfig, PretrainedConfig
from transformers.utils import logging

from olmo.config import ModelConfig

logger = logging.get_logger(__name__)


class OLMoConfig(PretrainedConfig):
model_type = "olmo"
keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm

def __init__(self, use_cache: bool = False, **kwargs):
model_config = ModelConfig()
all_kwargs = model_config.asdict()
all_kwargs.update(kwargs)
all_kwargs.update({"use_cache": use_cache})
super().__init__(**all_kwargs)


# Register the config class so that it is available for transformer pipelines, auto-loading etc.
AutoConfig.register("olmo", OLMoConfig)
211 changes: 211 additions & 0 deletions hf_olmo/modeling_olmo.py
@@ -0,0 +1,211 @@
import os

# import warnings
from typing import List, Optional, Tuple, Union

import torch
from transformers import PreTrainedModel

# from transformers.generation.utils import ( # BaseStreamer,
# GenerateOutput,
# LogitsProcessorList,
# StoppingCriteriaList,
# )
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto import AutoModelForCausalLM

from olmo.config import ModelConfig
from olmo.model import Olmo

from .configuration_olmo import OLMoConfig

# from typing import Callable, Sequence


def create_model_config_from_pretrained_config(config: OLMoConfig):
"""
Utility function
"""
model_config = ModelConfig(
d_model=config.d_model,
n_heads=config.n_heads,
n_layers=config.n_layers,
mlp_ratio=config.mlp_ratio,
activation_type=config.activation_type,
block_type=config.block_type,
alibi=config.alibi,
alibi_bias_max=config.alibi_bias_max,
rope=config.rope,
flash_attention=config.flash_attention,
attention_dropout=config.attention_dropout,
attention_layer_norm=config.attention_layer_norm,
multi_query_attention=config.multi_query_attention,
residual_dropout=config.residual_dropout,
embedding_dropout=config.embedding_dropout,
layer_norm_type=config.layer_norm_type,
max_sequence_length=config.max_sequence_length,
include_bias=config.include_bias,
vocab_size=config.vocab_size,
embedding_size=config.embedding_size,
eos_token_id=config.eos_token_id,
pad_token_id=config.pad_token_id,
init_device=config.init_device,
init_std=config.init_std,
precision=config.precision,
)
return model_config


class OLMoForCausalLM(PreTrainedModel):
"""
Extremely barebones HF model wrapper.
"""

config_class = OLMoConfig

def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None):
super().__init__(config)

if not model:
model_config = create_model_config_from_pretrained_config(config)
self.model = Olmo(model_config, init_params=True)
else:
self.model = model

# def forward(self, *args, **kwargs):
# # use_cache = self.config.use_cache or kwargs.pop("use_cache", False)
# kwargs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
# return self.model.forward(*args, **kwargs)

def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
if use_cache is None:
use_cache = self.config.use_cache

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
)

logits = outputs.logits

loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.embedding_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)

if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output

return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.attn_key_values,
)

def can_generate(self) -> bool:
return True

# Note (akshitab): This model does not use OLMo's generate() function, as that function does not support all the
# bells and whistles that HF's generation-compatible models do, such as `StoppingCriteria` or top-p sampling.
# Instead, the model sets `can_generate` to True, relies on HF's default `.generate()`, and implements
# supporting functions like `prepare_inputs_for_generation()`. This gives us access to HF's various generation
# options.

# def generate(
# self,
# input_ids: Optional[torch.Tensor] = None,
# max_length: int = 20,
# max_new_tokens: Optional[int] = None,
# logits_processor: Optional[LogitsProcessorList] = None,
# stopping_criteria: Optional[StoppingCriteriaList] = None,
# prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
# synced_gpus: Optional[bool] = None,
# assistant_model: Optional["PreTrainedModel"] = None,
# streamer: Optional["BaseStreamer"] = None,
# **kwargs,
# ) -> Union[GenerateOutput, torch.LongTensor]:
#
# assert input_ids is not None
#
# # TODO: use stopping_criteria, since it's being used by instruct-eval
# if stopping_criteria is not None:
# warnings.warn(
# "OLMo's generate() function does not currently support `stopping_criteria`. "
# "This will likely result in worse performance on tasks."
# )
#
# max_steps = max_new_tokens or max_length - input_ids.shape[1]
# result = self.model.generate(
# input_ids,
# max_steps=max_steps,
# beam_size=1,
# **kwargs,
# )
#
# return torch.cat((input_ids, result.token_ids[:, 0]), dim=1)

@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs
):
assert pretrained_model_name_or_path is not None
if kwargs.get("device_map", "auto") == "auto":
device = "cuda:0" if torch.cuda.is_available() else "cpu"
else:
device = "cpu"
model = Olmo.from_checkpoint(pretrained_model_name_or_path, device=device)
try:
config = OLMoConfig.from_pretrained(pretrained_model_name_or_path)
except FileNotFoundError:
config = OLMoConfig(use_cache=True, **model.config.asdict())
return cls(config, model)

def prepare_inputs_for_generation(
self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
):
if past_key_values:
# This is because we want the model to only process the last generated token.
input_ids = input_ids[:, -1:]
model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}

model_inputs.update(kwargs)
model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
return model_inputs

# TODO: these are required to make the implementation complete.
# def resize_position_embeddings(self, new_num_position_embeddings: int):
# pass
#
# def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
# pass
#
# def _reorder_cache(self, past_key_values, beam_idx):
# pass


# Register the model so that it is available for transformer pipelines, auto-loading, etc.
AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
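
Per the note in modeling_olmo.py, generation runs through HF's stock .generate(). An end-to-end sketch, assuming a checkpoint directory already prepared with add_hf_config_to_olmo_checkpoint.py and containing tokenizer files (the path and prompt are illustrative):

import hf_olmo  # noqa: F401 — registers OLMoConfig and OLMoForCausalLM with the auto classes
from hf_olmo import OLMoTokenizerFast
from transformers import AutoModelForCausalLM

checkpoint_dir = "/path/to/olmo-checkpoint"  # illustrative path
tokenizer = OLMoTokenizerFast.from_pretrained(checkpoint_dir)
model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)

inputs = tokenizer("Language modeling is", return_tensors="pt")
output = model.generate(inputs.input_ids, max_new_tokens=20, do_sample=True, top_p=0.95)
print(tokenizer.decode(output[0], skip_special_tokens=True))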