

GPT2 Generative model.
PiperOrigin-RevId: 604654297
bdu91 authored and LIT team committed Feb 6, 2024
1 parent 2138bd9 commit 27e6901
Showing 2 changed files with 131 additions and 0 deletions.
114 changes: 114 additions & 0 deletions lit_nlp/examples/models/pretrained_lms.py
@@ -324,3 +324,117 @@ def output_spec(self):
          align_in="tokens", align_out="tokens")
      spec[f"layer_{i:d}_avg_embedding"] = lit_types.Embeddings()
    return spec


class GPT2GenerativeModel(lit_model.BatchedModel):
  """Wrapper for a Huggingface Transformers GPT-2 model.

  This class loads a tokenizer and model using the Huggingface library and
  provides the LIT-required functions to generate text responses given input
  prompts.

  Note that the default model generation config is used such that the response
  is produced using multinomial sampling.
  """

  @classmethod
  def init_spec(cls) -> lit_model.Spec:
    return {
        "model_name_or_path": lit_types.String(default="gpt2"),
        "max_new_tokens": lit_types.Integer(default=50, min_val=1, max_val=500),
        "batch_size": lit_types.Integer(default=6, min_val=1, max_val=25),
    }

  def __init__(
      self,
      model=None,
      tokenizer=None,
      model_name_or_path="gpt2",
      max_new_tokens=50,
      batch_size=6,
  ):
    """Constructor for GPT2GenerativeModel.

    Note: args "model" and "tokenizer" take priority if both are specified.
    Otherwise, "model_name_or_path" is used to initialize the model and
    tokenizer.

    Args:
      model: an initialized GPT2 model compatible with TensorFlow.
      tokenizer: an initialized GPT2 tokenizer.
      model_name_or_path: gpt2, gpt2-medium, gpt2-large, gpt2-xl, distilgpt2,
        etc.
      max_new_tokens: the maximum number of new tokens to generate.
      batch_size: the number of items to process per `predict_minibatch` call.
    """
    super().__init__()

    if model is not None and tokenizer is not None:
      self.model = model
      self.tokenizer = tokenizer
    else:
      # Normally path is a directory; if it's an archive file, download and
      # extract to the transformers cache.
      if model_name_or_path.endswith(".tar.gz"):
        model_name_or_path = file_cache.cached_path(
            model_name_or_path, extract_compressed_file=True
        )

      self.tokenizer = transformers.AutoTokenizer.from_pretrained(
          model_name_or_path, use_fast=False
      )
      # Set pad_token after init: if pad_token= is passed to
      # AutoTokenizer.from_pretrained() above, it will create a new token with
      # id = max_vocab_length and cause out-of-bounds errors in
      # the embedding lookup.
      self.tokenizer.pad_token = self.tokenizer.eos_token
      self.model = transformers.TFAutoModelForCausalLM.from_pretrained(
          model_name_or_path
      )

    self.max_new_tokens = max_new_tokens
    self.batch_size = batch_size

  ##
  # LIT API implementations
  def max_minibatch_size(self) -> int:
    # The BatchedModel base class handles batching automatically in the
    # implementation of predict(), and uses this value as the batch size.
    return self.batch_size

  def predict_minibatch(self, inputs):
    prompts = [ex["prompt"] for ex in inputs]
    encoded_inputs = self.tokenizer.batch_encode_plus(
        prompts,
        return_tensors="tf",
        add_special_tokens=True,
        padding="longest",
        truncation="longest_first",
    )
    outputs = self.model.generate(
        encoded_inputs["input_ids"],
        max_new_tokens=self.max_new_tokens,
    )
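    # A note on the slicing below (this describes the assumed layout rather
    # than an invariant checked here): generate() returns the padded prompt
    # ids followed by the newly generated ids, so the last max_new_tokens
    # positions of `outputs` hold the response and the earlier positions hold
    # the prompt. The wte embedding lookup is split the same way to produce
    # prompt_embeddings and response_embeddings.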
    responses = self.tokenizer.batch_decode(
        outputs[:, -self.max_new_tokens :], skip_special_tokens=True
    )
    embeddings = self.model.transformer.wte(outputs)
    return [
        {
            "response": responses[i],
            "prompt_embeddings": embeddings[i, : -self.max_new_tokens],
            "response_embeddings": embeddings[i, -self.max_new_tokens :],
        }
        for i in range(len(outputs))
    ]

  def input_spec(self):
    return {
        "prompt": lit_types.TextSegment(),
    }

  def output_spec(self) -> lit_types.Spec:
    return {
        "response": lit_types.GeneratedTextCandidates(),
        "prompt_embeddings": lit_types.Embeddings(required=False),
        "response_embeddings": lit_types.Embeddings(required=False),
    }
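For reference, a minimal usage sketch of the new class (the model name, prompt
text, and parameter values below are illustrative rather than part of this
change):

from lit_nlp.examples.models import pretrained_lms

model = pretrained_lms.GPT2GenerativeModel(
    model_name_or_path="gpt2", max_new_tokens=20, batch_size=4)
# predict() comes from the BatchedModel base class: it splits the inputs into
# minibatches of at most batch_size and calls predict_minibatch() on each.
outputs = list(model.predict([
    {"prompt": "Today is"},
    {"prompt": "What is the color of"},
]))
print(outputs[0]["response"])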
17 changes: 17 additions & 0 deletions lit_nlp/examples/models/pretrained_lms_int_test.py
@@ -31,5 +31,22 @@ def test_gpt2(self):
    for key in model.output_spec().keys():
      self.assertIn(key, model_out[0].keys())

  def test_gpt2_generation(self):
    # Run prediction to ensure no failure.
    model_path = "https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2.tar.gz"
    model = pretrained_lms.GPT2GenerativeModel(model_name_or_path=model_path)
    model_in = [{"prompt": "Today is"}, {"prompt": "What is the color of"}]
    model_out = list(model.predict(model_in))

    # Sanity-check output vs output spec.
    self.assertLen(model_out, 2)
    for key in model.output_spec().keys():
      self.assertIn(key, model_out[0].keys())

    # Check that the embedding dimension is the same for prompt and response.
    self.assertEqual(model_out[0]["prompt_embeddings"].shape[1],
                     model_out[0]["response_embeddings"].shape[1])
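    # (Shape layout assumed here: prompt_embeddings is [num_prompt_tokens,
    # emb_dim] and response_embeddings is [max_new_tokens, emb_dim], so index 1
    # is the shared embedding dimension.)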


if __name__ == "__main__":
  absltest.main()
