
Commit 45887d3
Update the class names and comments on GPT2 to HF/transformer texts in preparation for generalization to more models.

PiperOrigin-RevId: 615558614
bdu91 authored and LIT team committed Mar 13, 2024
1 parent f0d07c5 commit 45887d3
Showing 3 changed files with 28 additions and 24 deletions.
6 changes: 3 additions & 3 deletions lit_nlp/examples/lm_salience_demo.py
@@ -223,14 +223,14 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
     model_name, path = model_string.split(":", 1)
     logging.info("Loading model '%s' from '%s'", model_name, path)
     if model_name.startswith("gpt2") or model_name in ["distilgpt2"]:
-      models[model_name] = pretrained_lms.GPT2GenerativeModel(path)
+      models[model_name] = pretrained_lms.HFGenerativeModel(path)
       # Salience wrapper, using same underlying Keras models so as not to
       # load the weights twice.
       models[f"_{model_name}_salience"] = (
-          pretrained_lms.GPT2SalienceModel.from_loaded(models[model_name])
+          pretrained_lms.HFSalienceModel.from_loaded(models[model_name])
       )
       models[f"_{model_name}_tokenizer"] = (
-          pretrained_lms.GPT2TokenizerModel.from_loaded(models[model_name])
+          pretrained_lms.HFTokenizerModel.from_loaded(models[model_name])
       )
     elif model_name.startswith("gemma"):
       path = file_cache.cached_path(
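The "so as not to load the weights twice" comment above describes the from_loaded() pattern this diff renames: one wrapper owns the loaded model, and the salience and tokenizer wrappers borrow its model and tokenizer objects. A minimal sketch of the renamed API, assuming lit_nlp is installed; the "gpt2" checkpoint name is purely an example:

  from lit_nlp.examples.models import pretrained_lms

  # Loads the HF model and tokenizer once.
  gen = pretrained_lms.HFGenerativeModel("gpt2")
  # from_loaded() hands gen's model and tokenizer to the new wrappers,
  # so no additional weights are loaded.
  salience = pretrained_lms.HFSalienceModel.from_loaded(gen)
  tokens = pretrained_lms.HFTokenizerModel.from_loaded(gen)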
44 changes: 24 additions & 20 deletions lit_nlp/examples/models/pretrained_lms.py
@@ -149,7 +149,7 @@ def output_spec(self):
     }
 
 
-# TODO(lit-dev): merge with below, inherit from GPT2BaseModel.
+# TODO(lit-dev): merge with below, inherit from HFBaseModel.
 class GPT2LanguageModel(lit_model.BatchedModel):
   """Wrapper for a Huggingface Transformers GPT-2 model.
@@ -329,8 +329,8 @@ def output_spec(self):
     return spec
 
 
-class GPT2BaseModel(lit_model.BatchedModel):
-  """Base class for GPT2 model wrappers."""
+class HFBaseModel(lit_model.BatchedModel):
+  """Base class for HF generative, salience, tokenizer model wrappers."""
 
   @property
   def num_layers(self):
@@ -350,17 +350,20 @@ def __init__(
       model=None,
       tokenizer=None,
   ):
-    """Constructor for GPT2 model wrappers.
+    """Constructor for HF base model wrappers.
 
     Note: args "model" and "tokenizer" take priority if both are specified.
     Otherwise, "model_name_or_path" is used to initialize the model and
     tokenizer.
 
+    This class supports common HF transformer models such as GPT2, Llama,
+    Mistral, etc.
+
     Args:
       model_name_or_path: gpt2, gpt2-medium, gpt2-large, distilgpt2, etc.
       batch_size: the number of items to process per `predict_minibatch` call.
-      model: an initialized transformers.TFGPT2LMHeadModel.
-      tokenizer: an initialized GPT2 tokenizer.
+      model: an initialized transformer model.
+      tokenizer: an initialized tokenizer.
     """
     super().__init__()
 
@@ -377,7 +380,7 @@ def __init__(
 
     # Note: we need to left-pad for generation to work properly.
     # Other modes such as scoring and salience should handle this as well;
-    # see example in GPT2SalienceModel._postprocess().
+    # see example in HFSalienceModel._postprocess().
     self.tokenizer = transformers.AutoTokenizer.from_pretrained(
         model_name_or_path,
         use_fast=False,
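The left-padding comment above is the key constraint for batched generation: with right padding, a causal LM's final positions are pad tokens, so decoding would continue from padding rather than from the prompt. A standalone sketch of the same setup, assuming a recent transformers release with TF weights available for the example checkpoint:

  import transformers

  tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2", use_fast=False)
  # Reuse the EOS token for padding rather than adding a new token, which
  # would exceed the embedding table (see the comment in the next hunk).
  tokenizer.pad_token = tokenizer.eos_token
  tokenizer.padding_side = "left"  # prompts end at the last position

  model = transformers.TFAutoModelForCausalLM.from_pretrained("gpt2")
  batch = tokenizer(["Today is", "What is the color of"],
                    return_tensors="tf", padding="longest")
  out = model.generate(batch["input_ids"],
                       attention_mask=batch["attention_mask"],
                       max_new_tokens=10)
  print(tokenizer.batch_decode(out, skip_special_tokens=True))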
@@ -387,7 +390,7 @@ def __init__(
     # AutoTokenizer.from_pretrained() above it will create a new token with
     # with id = max_vocab_length and cause out-of-bounds errors in
     # the embedding lookup.
-    self.model = transformers.TFGPT2LMHeadModel.from_pretrained(
+    self.model = transformers.TFAutoModelForCausalLM.from_pretrained(
         model_name_or_path, output_hidden_states=True, output_attentions=False
     )
 
@@ -399,11 +402,12 @@ def pad_left(self):
     return self.tokenizer.padding_side == "left"
 
   @classmethod
-  def from_loaded(cls, existing: "GPT2BaseModel", *args, **kw):
-    """Share weights and underlying Keras model with another instance."""
+  def from_loaded(cls, existing: "HFBaseModel", *args, **kw):
+    """Share weights and underlying HF model with another instance."""
     return cls(model=existing.model, tokenizer=existing.tokenizer, *args, **kw)
 
   def clean_bpe_token(self, tok):
+    # For GPT2 tokenizer.
     tok = tok.replace("Ċ", "\n")  # newlines
     tok = tok.replace("Ġ", "▁")  # start of word -> magic underscore
     return tok
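For readers unfamiliar with GPT-2's byte-level BPE: the tokenizer marks a leading space with "Ġ" and a newline with "Ċ", so clean_bpe_token() maps these to the SentencePiece-style magic underscore and a literal newline for display. A quick illustration, assuming the GPT-2 tokenizer (other tokenizers use different conventions, hence the new "For GPT2 tokenizer" comment above):

  import transformers

  tok = transformers.AutoTokenizer.from_pretrained("gpt2", use_fast=False)
  print(tok.tokenize("Hello world"))  # ['Hello', 'Ġworld']
  # After clean_bpe_token: ['Hello', '▁world']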
@@ -424,8 +428,8 @@ def input_spec(self):
     }
 
 
-class GPT2GenerativeModel(GPT2BaseModel):
-  """Wrapper for a Huggingface Transformers GPT-2 model.
+class HFGenerativeModel(HFBaseModel):
+  """Wrapper for a HF Transformer model that generates texts.
 
   This class loads a tokenizer and model using the Huggingface library and
   provides the LIT-required functions to generate text responses given input
@@ -442,12 +446,12 @@ def init_spec(cls) -> lit_model.Spec:
     }
 
   def __init__(self, *args, max_new_tokens=50, **kw):
-    """Constructor for GPT2LanguageModel.
+    """Constructor for HFGenerativeModel.
 
     Args:
-      *args: as to GPT2BaseModel.__init__
+      *args: as to HFBaseModel.__init__
       max_new_tokens: the maximum number of new tokens to generate.
-      **kw: as to GPT2BaseModel.__init__
+      **kw: as to HFBaseModel.__init__
     """
     super().__init__(*args, **kw)
     self.max_new_tokens = max_new_tokens
@@ -513,8 +517,8 @@ def output_spec(self) -> lit_types.Spec:
     }
 
 
-class GPT2SalienceModel(GPT2BaseModel):
-  """Wrapper for GPT-2 input (token) salience."""
+class HFSalienceModel(HFBaseModel):
+  """Wrapper for a HF Transformer model that computes input (token) salience."""
 
   def _pred(self, encoded_inputs, target_masks):
     """Predicts one batch of tokenized text.
@@ -601,7 +605,7 @@ def _pred(self, encoded_inputs, target_masks):
 
   def _postprocess(self, preds):
     """Post-process single-example preds. Operates on numpy arrays."""
-    # Be sure to cast to bool, otherwise this will select intger positions 0, 1
+    # Be sure to cast to bool, otherwise this will select integer positions 0, 1
     # rather than acting as a boolean mask.
     mask = preds.pop("attention_mask").astype(bool)
     ids = preds.pop("input_ids")[mask]
@@ -649,10 +653,10 @@ def output_spec(self) -> lit_types.Spec:
     }
 
 
-class GPT2TokenizerModel(GPT2BaseModel):
+class HFTokenizerModel(HFBaseModel):
   """Wrapper to run only the tokenizer.
 
-  Should exactly match tokens from GPT2SalienceModel.
+  Should exactly match tokens from HFSalienceModel.
   """
 
   def _postprocess(self, preds):
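The switch from TFGPT2LMHeadModel to TFAutoModelForCausalLM is what enables the generalization named in the commit message: the Auto class dispatches on each checkpoint's config, so one wrapper can serve any causal LM with TF weights. A small sketch of that dispatch; the checkpoint names are illustrative, and non-TF checkpoints would need from_pt=True:

  import transformers

  for name in ["gpt2", "distilgpt2"]:
    model = transformers.TFAutoModelForCausalLM.from_pretrained(
        name, output_hidden_states=True, output_attentions=False
    )
    # The Auto class resolves the concrete architecture per checkpoint,
    # e.g. TFGPT2LMHeadModel for both names here.
    print(name, type(model).__name__)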
2 changes: 1 addition & 1 deletion lit_nlp/examples/models/pretrained_lms_int_test.py
@@ -34,7 +34,7 @@ def test_gpt2(self):
   def test_gpt2_generation(self):
     # Run prediction to ensure no failure.
     model_path = "https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2.tar.gz"
-    model = pretrained_lms.GPT2GenerativeModel(model_name_or_path=model_path)
+    model = pretrained_lms.HFGenerativeModel(model_name_or_path=model_path)
     model_in = [{"prompt": "Today is"}, {"prompt": "What is the color of"}]
     model_out = list(model.predict(model_in))

