Merge pull request #391 from allenai/hf-olmo-new
hf_olmo modeling class should be a true `PreTrainedModel`
AkshitaB committed Dec 9, 2023
2 parents a120ab2 + 68ff059 commit e99dbe5
Showing 9 changed files with 140 additions and 96 deletions.
hf_olmo/convert_olmo_to_hf.py
@@ -5,7 +5,7 @@

from hf_olmo.configuration_olmo import OLMoConfig
from hf_olmo.tokenization_olmo_fast import OLMoTokenizerFast
from olmo import ModelConfig
from olmo import ModelConfig, Tokenizer

logger = logging.getLogger(__name__)

@@ -25,11 +25,34 @@ def write_config(checkpoint_dir: str):
logger.info(f"Saving HF-compatible config to {os.path.join(checkpoint_dir, 'config.json')}")
config.save_pretrained(checkpoint_dir)

tokenizer = OLMoTokenizerFast.from_pretrained(checkpoint_dir)

def write_model(checkpoint_dir: str, soft_link: bool = True):
if soft_link:
try:
os.symlink("model.pt", os.path.join(checkpoint_dir, "pytorch_model.bin"))
except FileExistsError:
pass
else:
if not os.path.exists(os.path.join(checkpoint_dir, "pytorch_model.bin")):
os.rename(os.path.join(checkpoint_dir, "model.pt"), os.path.join(checkpoint_dir, "pytorch_model.bin"))


def write_tokenizer(checkpoint_dir: str):
tokenizer_raw = Tokenizer.from_checkpoint(checkpoint_dir)
tokenizer = OLMoTokenizerFast(
tokenizer_object=tokenizer_raw.base_tokenizer,
truncation=tokenizer_raw.truncate_direction,
max_length=tokenizer_raw.truncate_to,
eos_token=tokenizer_raw.decode([tokenizer_raw.eos_token_id], skip_special_tokens=False),
)
tokenizer.model_input_names = ["input_ids", "attention_mask"]
tokenizer.pad_token_id = tokenizer_raw.pad_token_id
tokenizer.eos_token_id = tokenizer_raw.eos_token_id

tokenizer.save_pretrained(checkpoint_dir)


def download_remote_checkpoint_and_add_hf_config(checkpoint_dir: str, local_dir: str):
def download_remote_checkpoint_and_convert_to_hf(checkpoint_dir: str, local_dir: str):
from cached_path import cached_path

model_name = os.path.basename(checkpoint_dir)
@@ -49,20 +72,32 @@ def download_remote_checkpoint_and_add_hf_config(checkpoint_dir: str, local_dir: str):
logger.info(f"File already present at {final_location}")

write_config(local_model_path)
write_model(local_model_path, soft_link=False)
write_tokenizer(local_model_path)
return local_model_path


def main():
parser = argparse.ArgumentParser(
description="Adds a config.json to the checkpoint directory, making it easier to load weights as HF models."
description="Adds a config.json to the checkpoint directory, and creates pytorch_model.bin, "
"making it easier to load weights as HF models."
)
parser.add_argument(
"--checkpoint-dir",
help="Location of OLMo checkpoint.",
)

parser.add_argument(
"--ignore-olmo-compatibility",
action="store_true",
help="Ignore compatibility with the olmo codebase. "
"This will rename model.pt --> pytorch_model.bin instead of creating a symlink.",
)

args = parser.parse_args()
write_config(checkpoint_dir=args.checkpoint_dir)
write_model(checkpoint_dir=args.checkpoint_dir, soft_link=not args.ignore_olmo_compatibility)
write_tokenizer(checkpoint_dir=args.checkpoint_dir)


if __name__ == "__main__":
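A hedged usage sketch (not part of this commit) of the converted script after this change. The module path and function signatures come from the diff above and the test import further down; the checkpoint path is a placeholder, and the sketch assumes the directory already contains the OLMo config and model.pt:

# Minimal conversion sketch (hypothetical local checkpoint path).
from hf_olmo.convert_olmo_to_hf import write_config, write_model, write_tokenizer

checkpoint_dir = "/path/to/olmo-checkpoint"  # placeholder

write_config(checkpoint_dir)                 # writes an HF-compatible config.json
write_tokenizer(checkpoint_dir)              # saves the HF tokenizer files via save_pretrained()
write_model(checkpoint_dir, soft_link=True)  # symlinks model.pt -> pytorch_model.bin (keeps olmo compatibility)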
94 changes: 23 additions & 71 deletions hf_olmo/modeling_olmo.py
@@ -1,16 +1,7 @@
import os

# import warnings
from typing import List, Optional, Tuple, Union

import torch
from transformers import PreTrainedModel

# from transformers.generation.utils import ( # BaseStreamer,
# GenerateOutput,
# LogitsProcessorList,
# StoppingCriteriaList,
# )
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto import AutoModelForCausalLM

@@ -19,8 +10,6 @@

from .configuration_olmo import OLMoConfig

# from typing import Callable, Sequence


def create_model_config_from_pretrained_config(config: OLMoConfig):
"""
@@ -62,6 +51,7 @@ class OLMoForCausalLM(PreTrainedModel):
"""

config_class = OLMoConfig
base_model_prefix = "model"

def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None):
super().__init__(config)
@@ -72,11 +62,6 @@ def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None):
else:
self.model = model

# def forward(self, *args, **kwargs):
# # use_cache = self.config.use_cache or kwargs.pop("use_cache", False)
# kwargs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
# return self.model.forward(*args, **kwargs)

def forward(
self,
input_ids: torch.LongTensor = None,
@@ -129,61 +114,6 @@ def forward(
def can_generate(self) -> bool:
return True

# Note (akshitab): This model does not use OLMo's generate() function as it does not support all the
# bells and whistles that HF's generation-compatible models do, such as `StoppingCriteria` or top-p sampling, etc.
# Instead, the model sets `can_generate` to True, and relies on HF's default `.generate()`, and implements
# supporting functions like `prepare_inputs_for_generation()`. This allows us to use HF's various generation
# options.

# def generate(
# self,
# input_ids: Optional[torch.Tensor] = None,
# max_length: int = 20,
# max_new_tokens: Optional[int] = None,
# logits_processor: Optional[LogitsProcessorList] = None,
# stopping_criteria: Optional[StoppingCriteriaList] = None,
# prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
# synced_gpus: Optional[bool] = None,
# assistant_model: Optional["PreTrainedModel"] = None,
# streamer: Optional["BaseStreamer"] = None,
# **kwargs,
# ) -> Union[GenerateOutput, torch.LongTensor]:
#
# assert input_ids is not None
#
# # TODO: use stopping_criteria, since it's being used by instruct-eval
# if stopping_criteria is not None:
# warnings.warn(
# "OLMo's generate() function does not currently support `stopping_criteria`. "
# "This will likely result in worse performance on tasks."
# )
#
# max_steps = max_new_tokens or max_length - input_ids.shape[1]
# result = self.model.generate(
# input_ids,
# max_steps=max_steps,
# beam_size=1,
# **kwargs,
# )
#
# return torch.cat((input_ids, result.token_ids[:, 0]), dim=1)

@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs
):
assert pretrained_model_name_or_path is not None
if kwargs.get("device_map", "auto") == "auto":
device = "cuda:0" if torch.cuda.is_available() else "cpu"
else:
device = "cpu"
model = Olmo.from_checkpoint(pretrained_model_name_or_path, device=device)
try:
config = OLMoConfig.from_pretrained(pretrained_model_name_or_path)
except FileNotFoundError:
config = OLMoConfig(use_cache=True, **model.config.asdict())
return cls(config, model)

def prepare_inputs_for_generation(
self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
):
@@ -206,6 +136,28 @@ def prepare_inputs_for_generation(
# def _reorder_cache(self, past_key_values, beam_idx):
# pass

def get_input_embeddings(self) -> torch.nn.Module:
return self.model.transformer.wte

def set_input_embeddings(self, value: torch.nn.Module):
self.model.transformer.wte = value

def get_output_embeddings(self):
if self.config.weight_tying:
return self.model.transformer.wte
else:
return self.model.transformer.ff_out

def set_output_embeddings(self, value: torch.nn.Module):
if self.config.weight_tying:
self.model.transformer.wte = value
else:
self.model.transformer.ff_out = value

def tie_weights(self):
if self.config.weight_tying:
self.model.transformer.ff_out = self.model.transformer.wte


# Register the model so that it is available for transformer pipelines, auto-loading, etc.
AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
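To illustrate the note above, the model now relies on HF's stock `.generate()` plus `prepare_inputs_for_generation()` rather than OLMo's own generate. A hedged end-to-end sketch (not part of the diff), assuming a checkpoint directory already converted with `convert_olmo_to_hf`; the path is a placeholder:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast  # noqa: F401  (importing registers the Auto* classes)

model_path = "/path/to/converted-olmo-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path).eval()

inputs = tokenizer("My name is OLMo!", return_tensors="pt")
with torch.no_grad():
    # HF's generic generation loop calls prepare_inputs_for_generation() on the model,
    # so standard options (sampling, top-p, stopping criteria) should work unchanged.
    output_ids = model.generate(**inputs, max_new_tokens=20, do_sample=True, top_p=0.95)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))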
17 changes: 1 addition & 16 deletions hf_olmo/tokenization_olmo_fast.py
@@ -1,31 +1,16 @@
import os
from typing import Union

from transformers import AutoTokenizer, PreTrainedTokenizerFast

from hf_olmo.configuration_olmo import OLMoConfig
from olmo import Tokenizer


class OLMoTokenizerFast(PreTrainedTokenizerFast):
# Note: Olmo's tokenizer is already a wrapper around huggingface. This is potentially unnecessary.
pass

# def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
# # This is required to make the implementation complete.
# pass

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
tokenizer_raw = Tokenizer.from_checkpoint(pretrained_model_name_or_path)
tokenizer = cls(
tokenizer_object=tokenizer_raw.base_tokenizer,
truncation=tokenizer_raw.truncate_direction,
max_length=tokenizer_raw.truncate_to,
eos_token=tokenizer_raw.decode([tokenizer_raw.eos_token_id], skip_special_tokens=False),
)
tokenizer.model_input_names = ["input_ids", "attention_mask"]
return tokenizer


# Register the tokenizer class so that it is available for transformer pipelines, auto-loading etc.
AutoTokenizer.register(OLMoConfig, fast_tokenizer_class=OLMoTokenizerFast)
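A hedged sketch (not part of the diff) of the tokenizer round trip now that the `from_pretrained` override is gone: the tokenizer is expected to load from the files written by `write_tokenizer()` and resolve to `OLMoTokenizerFast` through the registration above. The path is a placeholder:

from transformers import AutoTokenizer

from hf_olmo import OLMoTokenizerFast  # noqa: F401  (registers the fast tokenizer with AutoTokenizer)

model_path = "/path/to/converted-olmo-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path)

assert isinstance(tokenizer, OLMoTokenizerFast)
# pad/eos ids were set on the tokenizer by write_tokenizer() before saving
print(tokenizer.eos_token, tokenizer.eos_token_id, tokenizer.pad_token_id)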
15 changes: 14 additions & 1 deletion test_fixtures/test-olmo-model/config.json
@@ -2,19 +2,29 @@
"activation_type": "swiglu",
"alibi": false,
"alibi_bias_max": 8.0,
"architectures": [
"OlmoModelForCausalLM"
],
"attention_dropout": 0.1,
"attention_layer_norm": false,
"attention_layer_norm_with_affine": true,
"bias_for_layer_norm": null,
"block_group_size": 1,
"block_type": "sequential",
"d_model": 32,
"embedding_dropout": 0.1,
"embedding_size": 50304,
"eos_token_id": 50256,
"flash_attention": false,
"include_bias": true,
"init_cutoff_factor": null,
"init_device": null,
"init_fn": "normal",
"init_std": 0.02,
"layer_norm_type": "default",
"layer_norm_with_affine": true,
"max_sequence_length": 1024,
"mlp_hidden_size": null,
"mlp_ratio": 4,
"model_type": "olmo",
"multi_query_attention": false,
@@ -24,7 +34,10 @@
"precision": null,
"residual_dropout": 0.1,
"rope": false,
"rope_full_precision": true,
"scale_logits": false,
"transformers_version": "4.29.0",
"use_cache": true,
"vocab_size": 50257
"vocab_size": 50257,
"weight_tying": true
}
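Not part of the commit, but a hedged illustration of how the new `weight_tying` field interacts with the embedding hooks added in `modeling_olmo.py`: when `weight_tying` is true, `get_output_embeddings()` returns the same `wte` module as `get_input_embeddings()`, matching the overridden `tie_weights()` above. A minimal check, assuming the converted test fixture loads as-is:

from transformers import AutoConfig, AutoModelForCausalLM

from hf_olmo import OLMoForCausalLM  # noqa: F401  (registers the model with AutoModelForCausalLM)

model_path = "test_fixtures/test-olmo-model"  # fixture path from this commit; assumed loadable as-is
config = AutoConfig.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

if config.weight_tying:
    # With tied weights, the output projection is the embedding matrix itself.
    assert model.get_output_embeddings() is model.get_input_embeddings()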
1 change: 1 addition & 0 deletions test_fixtures/test-olmo-model/pytorch_model.bin
3 changes: 2 additions & 1 deletion test_fixtures/test-olmo-model/special_tokens_map.json
@@ -1,3 +1,4 @@
{
"eos_token": "<|endoftext|>"
"eos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>"
}
10 changes: 8 additions & 2 deletions tests/hf_olmo/hf_olmo_test.py
@@ -4,17 +4,20 @@
from olmo import BlockType, Tokenizer, TrainConfig
from olmo.data import DataCollator
from olmo.model import Olmo
from olmo.torch_util import seed_all


def test_auto_hf_classes(model_path: str):
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from hf_olmo import OLMoConfig, OLMoForCausalLM, OLMoTokenizerFast
from hf_olmo.add_hf_config_to_olmo_checkpoint import write_config
from hf_olmo.convert_olmo_to_hf import write_config, write_model, write_tokenizer

# model_path is an OLMo checkpoint.
# Creates HF-compatible config.json
write_config(model_path)
write_tokenizer(model_path)
write_model(model_path)

config = AutoConfig.from_pretrained(model_path)
assert isinstance(config, OLMoConfig)
@@ -184,10 +187,13 @@ def test_forward(

use_amp = dtype in {torch.float16, torch.bfloat16}

seed_all(1234)
model = Olmo(train_config.model).eval()

hf_config = OLMoConfig(**model.config.asdict())
hf_model = OLMoForCausalLM(hf_config, model=model)

seed_all(1234)
hf_model = OLMoForCausalLM(hf_config).eval()

input1 = tokenizer.encode("My name is OLMo!")
input2 = tokenizer.encode("I'm a delightful large open language model :)")
25 changes: 24 additions & 1 deletion tests/hf_olmo/modeling_olmo_test.py
@@ -1,3 +1,5 @@
import tempfile

import torch

from olmo.model import Olmo
@@ -6,7 +8,7 @@
def test_olmo_model(model_path: str):
from transformers import AutoModelForCausalLM, AutoTokenizer

from hf_olmo import OLMoForCausalLM # noqa: F401
from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast # noqa: F401

model = Olmo.from_checkpoint(model_path)
hf_model = AutoModelForCausalLM.from_pretrained(model_path)
@@ -19,3 +21,24 @@ def test_olmo_model(model_path: str):
hf_output = hf_model(input_tensor)

torch.testing.assert_allclose(output.logits, hf_output.logits)


def test_save_pretrained(model_path: str):
from transformers import AutoModelForCausalLM, AutoTokenizer

from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast # noqa: F401

tokenizer = AutoTokenizer.from_pretrained(model_path)
input = tokenizer.encode("My name is OLMo!")
input_tensor = torch.tensor(input).unsqueeze(0)

hf_model = AutoModelForCausalLM.from_pretrained(model_path)
hf_output = hf_model(input_tensor)

with tempfile.TemporaryDirectory() as tmp_dir:
hf_model.save_pretrained(tmp_dir)

saved_hf_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
saved_hf_output = saved_hf_model(input_tensor)

torch.testing.assert_allclose(saved_hf_output.logits, hf_output.logits)
