diff --git a/hf_olmo/add_hf_config_to_olmo_checkpoint.py b/hf_olmo/convert_olmo_to_hf.py
similarity index 54%
rename from hf_olmo/add_hf_config_to_olmo_checkpoint.py
rename to hf_olmo/convert_olmo_to_hf.py
index a7663058a..b4600794d 100644
--- a/hf_olmo/add_hf_config_to_olmo_checkpoint.py
+++ b/hf_olmo/convert_olmo_to_hf.py
@@ -5,7 +5,7 @@
 from hf_olmo.configuration_olmo import OLMoConfig
 from hf_olmo.tokenization_olmo_fast import OLMoTokenizerFast
-from olmo import ModelConfig
+from olmo import ModelConfig, Tokenizer
 
 logger = logging.getLogger(__name__)
 
 
@@ -25,11 +25,34 @@ def write_config(checkpoint_dir: str):
     logger.info(f"Saving HF-compatible config to {os.path.join(checkpoint_dir, 'config.json')}")
     config.save_pretrained(checkpoint_dir)
 
-    tokenizer = OLMoTokenizerFast.from_pretrained(checkpoint_dir)
+
+def write_model(checkpoint_dir: str, soft_link: bool = True):
+    if soft_link:
+        try:
+            os.symlink("model.pt", os.path.join(checkpoint_dir, "pytorch_model.bin"))
+        except FileExistsError:
+            pass
+    else:
+        if not os.path.exists(os.path.join(checkpoint_dir, "pytorch_model.bin")):
+            os.rename(os.path.join(checkpoint_dir, "model.pt"), os.path.join(checkpoint_dir, "pytorch_model.bin"))
+
+
+def write_tokenizer(checkpoint_dir: str):
+    tokenizer_raw = Tokenizer.from_checkpoint(checkpoint_dir)
+    tokenizer = OLMoTokenizerFast(
+        tokenizer_object=tokenizer_raw.base_tokenizer,
+        truncation=tokenizer_raw.truncate_direction,
+        max_length=tokenizer_raw.truncate_to,
+        eos_token=tokenizer_raw.decode([tokenizer_raw.eos_token_id], skip_special_tokens=False),
+    )
+    tokenizer.model_input_names = ["input_ids", "attention_mask"]
+    tokenizer.pad_token_id = tokenizer_raw.pad_token_id
+    tokenizer.eos_token_id = tokenizer_raw.eos_token_id
+    tokenizer.save_pretrained(checkpoint_dir)
 
 
-def download_remote_checkpoint_and_add_hf_config(checkpoint_dir: str, local_dir: str):
+def download_remote_checkpoint_and_convert_to_hf(checkpoint_dir: str, local_dir: str):
     from cached_path import cached_path
 
     model_name = os.path.basename(checkpoint_dir)
@@ -49,20 +72,32 @@ def download_remote_checkpoint_and_add_hf_config(checkpoint_dir: str, local_dir:
             logger.info(f"File already present at {final_location}")
 
     write_config(local_model_path)
+    write_model(local_model_path, soft_link=False)
+    write_tokenizer(local_model_path)
 
     return local_model_path
 
 
 def main():
     parser = argparse.ArgumentParser(
-        description="Adds a config.json to the checkpoint directory, making it easier to load weights as HF models."
+        description="Adds a config.json to the checkpoint directory, and creates pytorch_model.bin, "
+        "making it easier to load weights as HF models."
     )
     parser.add_argument(
         "--checkpoint-dir",
         help="Location of OLMo checkpoint.",
     )
+    parser.add_argument(
+        "--ignore-olmo-compatibility",
+        action="store_true",
+        help="Ignore compatibility with the olmo codebase. "
+        "This will rename model.pt --> pytorch_model.bin instead of creating a symlink.",
+    )
+
     args = parser.parse_args()
     write_config(checkpoint_dir=args.checkpoint_dir)
+    write_model(checkpoint_dir=args.checkpoint_dir, soft_link=not args.ignore_olmo_compatibility)
+    write_tokenizer(checkpoint_dir=args.checkpoint_dir)
 
 
 if __name__ == "__main__":
diff --git a/hf_olmo/modeling_olmo.py b/hf_olmo/modeling_olmo.py
index da1ea9f44..33512c9b8 100644
--- a/hf_olmo/modeling_olmo.py
+++ b/hf_olmo/modeling_olmo.py
@@ -1,16 +1,7 @@
-import os
-
-# import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
 from transformers import PreTrainedModel
-
-# from transformers.generation.utils import (  # BaseStreamer,
-#     GenerateOutput,
-#     LogitsProcessorList,
-#     StoppingCriteriaList,
-# )
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.auto import AutoModelForCausalLM
 
@@ -19,8 +10,6 @@
 
 from .configuration_olmo import OLMoConfig
 
-# from typing import Callable, Sequence
-
 
 def create_model_config_from_pretrained_config(config: OLMoConfig):
     """
@@ -62,6 +51,7 @@ class OLMoForCausalLM(PreTrainedModel):
     """
 
     config_class = OLMoConfig
+    base_model_prefix = "model"
 
     def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None):
         super().__init__(config)
@@ -72,11 +62,6 @@ def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None):
         else:
             self.model = model
 
-    # def forward(self, *args, **kwargs):
-    #     # use_cache = self.config.use_cache or kwargs.pop("use_cache", False)
-    #     kwargs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
-    #     return self.model.forward(*args, **kwargs)
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -129,61 +114,6 @@ def forward(
     def can_generate(self) -> bool:
         return True
 
-    # Note (akshitab): This model does not use OLMo's generate() function as it does not support all the
-    # bells and whistles that HF's generation-compatible models do, such as `StoppingCriteria` or top-p sampling, etc.
-    # Instead, the model sets `can_generate` to True, and relies on HF's default `.generate()`, and implements
-    # supporting functions like `prepare_inputs_for_generation()`. This allows us to use HF's various generation
-    # options.
-
-    # def generate(
-    #     self,
-    #     input_ids: Optional[torch.Tensor] = None,
-    #     max_length: int = 20,
-    #     max_new_tokens: Optional[int] = None,
-    #     logits_processor: Optional[LogitsProcessorList] = None,
-    #     stopping_criteria: Optional[StoppingCriteriaList] = None,
-    #     prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
-    #     synced_gpus: Optional[bool] = None,
-    #     assistant_model: Optional["PreTrainedModel"] = None,
-    #     streamer: Optional["BaseStreamer"] = None,
-    #     **kwargs,
-    # ) -> Union[GenerateOutput, torch.LongTensor]:
-    #
-    #     assert input_ids is not None
-    #
-    #     # TODO: use stopping_criteria, since it's being used by instruct-eval
-    #     if stopping_criteria is not None:
-    #         warnings.warn(
-    #             "OLMo's generate() function does not currently support `stopping_criteria`. "
-    #             "This will likely result in worse performance on tasks."
-    #         )
-    #
-    #     max_steps = max_new_tokens or max_length - input_ids.shape[1]
-    #     result = self.model.generate(
-    #         input_ids,
-    #         max_steps=max_steps,
-    #         beam_size=1,
-    #         **kwargs,
-    #     )
-    #
-    #     return torch.cat((input_ids, result.token_ids[:, 0]), dim=1)
-
-    @classmethod
-    def from_pretrained(
-        cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs
-    ):
-        assert pretrained_model_name_or_path is not None
-        if kwargs.get("device_map", "auto") == "auto":
-            device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        else:
-            device = "cpu"
-        model = Olmo.from_checkpoint(pretrained_model_name_or_path, device=device)
-        try:
-            config = OLMoConfig.from_pretrained(pretrained_model_name_or_path)
-        except FileNotFoundError:
-            config = OLMoConfig(use_cache=True, **model.config.asdict())
-        return cls(config, model)
-
     def prepare_inputs_for_generation(
         self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
     ):
@@ -206,6 +136,28 @@ def prepare_inputs_for_generation(
     # def _reorder_cache(self, past_key_values, beam_idx):
     #     pass
 
+    def get_input_embeddings(self) -> torch.nn.Module:
+        return self.model.transformer.wte
+
+    def set_input_embeddings(self, value: torch.nn.Module):
+        self.model.transformer.wte = value
+
+    def get_output_embeddings(self):
+        if self.config.weight_tying:
+            return self.model.transformer.wte
+        else:
+            return self.model.transformer.ff_out
+
+    def set_output_embeddings(self, value: torch.nn.Module):
+        if self.config.weight_tying:
+            self.model.transformer.wte = value
+        else:
+            self.model.transformer.ff_out = value
+
+    def tie_weights(self):
+        if self.config.weight_tying:
+            self.model.transformer.ff_out = self.model.transformer.wte
+
 
 # Register the model so that it is available for transformer pipelines, auto-loading, etc.
 AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
diff --git a/hf_olmo/tokenization_olmo_fast.py b/hf_olmo/tokenization_olmo_fast.py
index 948192ab8..e2bd665d1 100644
--- a/hf_olmo/tokenization_olmo_fast.py
+++ b/hf_olmo/tokenization_olmo_fast.py
@@ -1,31 +1,16 @@
-import os
-from typing import Union
-
 from transformers import AutoTokenizer, PreTrainedTokenizerFast
 
 from hf_olmo.configuration_olmo import OLMoConfig
-from olmo import Tokenizer
 
 
 class OLMoTokenizerFast(PreTrainedTokenizerFast):
     # Note: Olmo's tokenizer is already a wrapper around huggingface. This is potentially unnecessary.
+    pass
 
     # def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
     #     # This is required to make the implementation complete.
     #     pass
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
-        tokenizer_raw = Tokenizer.from_checkpoint(pretrained_model_name_or_path)
-        tokenizer = cls(
-            tokenizer_object=tokenizer_raw.base_tokenizer,
-            truncation=tokenizer_raw.truncate_direction,
-            max_length=tokenizer_raw.truncate_to,
-            eos_token=tokenizer_raw.decode([tokenizer_raw.eos_token_id], skip_special_tokens=False),
-        )
-        tokenizer.model_input_names = ["input_ids", "attention_mask"]
-        return tokenizer
-
 
 # Register the tokenizer class so that it is available for transformer pipelines, auto-loading etc.
 AutoTokenizer.register(OLMoConfig, fast_tokenizer_class=OLMoTokenizerFast)
diff --git a/test_fixtures/test-olmo-model/config.json b/test_fixtures/test-olmo-model/config.json
index 10aca15c7..71e7b981e 100644
--- a/test_fixtures/test-olmo-model/config.json
+++ b/test_fixtures/test-olmo-model/config.json
@@ -2,8 +2,14 @@
   "activation_type": "swiglu",
   "alibi": false,
   "alibi_bias_max": 8.0,
+  "architectures": [
+    "OlmoModelForCausalLM"
+  ],
   "attention_dropout": 0.1,
   "attention_layer_norm": false,
+  "attention_layer_norm_with_affine": true,
+  "bias_for_layer_norm": null,
+  "block_group_size": 1,
   "block_type": "sequential",
   "d_model": 32,
   "embedding_dropout": 0.1,
@@ -11,10 +17,14 @@
   "eos_token_id": 50256,
   "flash_attention": false,
   "include_bias": true,
+  "init_cutoff_factor": null,
   "init_device": null,
+  "init_fn": "normal",
   "init_std": 0.02,
   "layer_norm_type": "default",
+  "layer_norm_with_affine": true,
   "max_sequence_length": 1024,
+  "mlp_hidden_size": null,
   "mlp_ratio": 4,
   "model_type": "olmo",
   "multi_query_attention": false,
@@ -24,7 +34,10 @@
   "precision": null,
   "residual_dropout": 0.1,
   "rope": false,
+  "rope_full_precision": true,
+  "scale_logits": false,
   "transformers_version": "4.29.0",
   "use_cache": true,
-  "vocab_size": 50257
+  "vocab_size": 50257,
+  "weight_tying": true
 }
diff --git a/test_fixtures/test-olmo-model/pytorch_model.bin b/test_fixtures/test-olmo-model/pytorch_model.bin
new file mode 120000
index 000000000..6702297cd
--- /dev/null
+++ b/test_fixtures/test-olmo-model/pytorch_model.bin
@@ -0,0 +1 @@
+model.pt
\ No newline at end of file
diff --git a/test_fixtures/test-olmo-model/special_tokens_map.json b/test_fixtures/test-olmo-model/special_tokens_map.json
index a84b18f72..2a4ce0a7e 100644
--- a/test_fixtures/test-olmo-model/special_tokens_map.json
+++ b/test_fixtures/test-olmo-model/special_tokens_map.json
@@ -1,3 +1,4 @@
 {
-  "eos_token": "<|endoftext|>"
+  "eos_token": "<|endoftext|>",
+  "pad_token": "<|endoftext|>"
 }
diff --git a/tests/hf_olmo/hf_olmo_test.py b/tests/hf_olmo/hf_olmo_test.py
index 947a63b36..8ec9499a6 100644
--- a/tests/hf_olmo/hf_olmo_test.py
+++ b/tests/hf_olmo/hf_olmo_test.py
@@ -4,17 +4,20 @@
 from olmo import BlockType, Tokenizer, TrainConfig
 from olmo.data import DataCollator
 from olmo.model import Olmo
+from olmo.torch_util import seed_all
 
 
 def test_auto_hf_classes(model_path: str):
     from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
     from hf_olmo import OLMoConfig, OLMoForCausalLM, OLMoTokenizerFast
-    from hf_olmo.add_hf_config_to_olmo_checkpoint import write_config
+    from hf_olmo.convert_olmo_to_hf import write_config, write_model, write_tokenizer
 
     # model_path is an OLMo checkpoint.
     # Creates HF-compatible config.json
     write_config(model_path)
+    write_tokenizer(model_path)
+    write_model(model_path)
 
     config = AutoConfig.from_pretrained(model_path)
     assert isinstance(config, OLMoConfig)
@@ -184,10 +187,13 @@ def test_forward(
 
     use_amp = dtype in {torch.float16, torch.bfloat16}
 
+    seed_all(1234)
     model = Olmo(train_config.model).eval()
 
     hf_config = OLMoConfig(**model.config.asdict())
-    hf_model = OLMoForCausalLM(hf_config, model=model)
+
+    seed_all(1234)
+    hf_model = OLMoForCausalLM(hf_config).eval()
 
     input1 = tokenizer.encode("My name is OLMo!")
     input2 = tokenizer.encode("I'm a delightful large open language model :)")
diff --git a/tests/hf_olmo/modeling_olmo_test.py b/tests/hf_olmo/modeling_olmo_test.py
index e3bdf0b93..5ae19f6a7 100644
--- a/tests/hf_olmo/modeling_olmo_test.py
+++ b/tests/hf_olmo/modeling_olmo_test.py
@@ -1,3 +1,5 @@
+import tempfile
+
 import torch
 
 from olmo.model import Olmo
@@ -6,7 +8,7 @@
 def test_olmo_model(model_path: str):
     from transformers import AutoModelForCausalLM, AutoTokenizer
 
-    from hf_olmo import OLMoForCausalLM  # noqa: F401
+    from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast  # noqa: F401
 
     model = Olmo.from_checkpoint(model_path)
     hf_model = AutoModelForCausalLM.from_pretrained(model_path)
@@ -19,3 +21,24 @@ def test_olmo_model(model_path: str):
     hf_output = hf_model(input_tensor)
 
     torch.testing.assert_allclose(output.logits, hf_output.logits)
+
+
+def test_save_pretrained(model_path: str):
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast  # noqa: F401
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    input = tokenizer.encode("My name is OLMo!")
+    input_tensor = torch.tensor(input).unsqueeze(0)
+
+    hf_model = AutoModelForCausalLM.from_pretrained(model_path)
+    hf_output = hf_model(input_tensor)
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        hf_model.save_pretrained(tmp_dir)
+
+        saved_hf_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
+        saved_hf_output = saved_hf_model(input_tensor)
+
+        torch.testing.assert_allclose(saved_hf_output.logits, hf_output.logits)
diff --git a/tests/hf_olmo/tokenization_olmo_fast_test.py b/tests/hf_olmo/tokenization_olmo_fast_test.py
index 0c6c8c78a..10bb4f7dd 100644
--- a/tests/hf_olmo/tokenization_olmo_fast_test.py
+++ b/tests/hf_olmo/tokenization_olmo_fast_test.py
@@ -1,3 +1,5 @@
+import tempfile
+
 from olmo.tokenizer import Tokenizer
 
 
@@ -16,3 +18,29 @@ def test_olmo_tokenizer(model_path: str):
     hf_tokenized = hf_tok.encode(input_str)
 
     assert tokenized == hf_tokenized
+
+    # tokenized = tok([input_str], return_tensors="pt", max_length=5, truncation=True)
+    hf_tokenized = hf_tok([input_str], return_tensors="pt", max_length=5, truncation=True)
+
+    print(hf_tokenized)
+
+
+def test_save_pretrained(model_path: str):
+    from transformers import AutoTokenizer
+
+    from hf_olmo import OLMoTokenizerFast  # noqa: F401
+
+    hf_tok = AutoTokenizer.from_pretrained(model_path)
+
+    input_str = "Hello, this is a test!"
+
+    # Note: our tokenizer adds eos token by default, HF doesn't.
+    hf_tokenized = hf_tok.encode(input_str)
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        hf_tok.save_pretrained(tmp_dir)
+
+        saved_hf_tok = AutoTokenizer.from_pretrained(tmp_dir)
+        saved_hf_tokenized = saved_hf_tok.encode(input_str)
+
+        assert hf_tokenized == saved_hf_tokenized
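Below is a minimal usage sketch (not part of the patch itself) of the workflow this change enables: convert an OLMo checkpoint with the renamed script, then load it through the standard Hugging Face auto classes that importing hf_olmo registers. The checkpoint path is a placeholder; the load-and-forward calls mirror what the tests above exercise.

# 1) Convert the checkpoint in place (writes config.json, pytorch_model.bin, and tokenizer files):
#      python hf_olmo/convert_olmo_to_hf.py --checkpoint-dir /path/to/olmo/checkpoint
# 2) Load it with the HF auto classes:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast  # noqa: F401  -- importing hf_olmo registers the OLMo classes

checkpoint_dir = "/path/to/olmo/checkpoint"  # placeholder path, assumed to be an already-converted checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)

input_ids = torch.tensor(tokenizer.encode("My name is OLMo!")).unsqueeze(0)
logits = model(input_ids).logits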