From 5255098b5c1c8804b726aab450c460f531564d7f Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Sat, 9 Dec 2023 23:15:36 -0800 Subject: [PATCH] fix issue of changing config --- hf_olmo/modeling_olmo.py | 33 ++++++--------------------------- 1 file changed, 6 insertions(+), 27 deletions(-) diff --git a/hf_olmo/modeling_olmo.py b/hf_olmo/modeling_olmo.py index 33512c9b8..5abe50325 100644 --- a/hf_olmo/modeling_olmo.py +++ b/hf_olmo/modeling_olmo.py @@ -15,33 +15,12 @@ def create_model_config_from_pretrained_config(config: OLMoConfig): """ Utility function """ - model_config = ModelConfig( - d_model=config.d_model, - n_heads=config.n_heads, - n_layers=config.n_layers, - mlp_ratio=config.mlp_ratio, - activation_type=config.activation_type, - block_type=config.block_type, - alibi=config.alibi, - alibi_bias_max=config.alibi_bias_max, - rope=config.rope, - flash_attention=config.flash_attention, - attention_dropout=config.attention_dropout, - attention_layer_norm=config.attention_layer_norm, - multi_query_attention=config.multi_query_attention, - residual_dropout=config.residual_dropout, - embedding_dropout=config.embedding_dropout, - layer_norm_type=config.layer_norm_type, - max_sequence_length=config.max_sequence_length, - include_bias=config.include_bias, - vocab_size=config.vocab_size, - embedding_size=config.embedding_size, - eos_token_id=config.eos_token_id, - pad_token_id=config.pad_token_id, - init_device=config.init_device, - init_std=config.init_std, - precision=config.precision, - ) + + kwargs = {} + for key in ModelConfig.__match_args__: + kwargs[key] = getattr(config, key) + + model_config = ModelConfig(**kwargs) return model_config