Dynamic sequence length for Keras LM wrappers.
PiperOrigin-RevId: 607214308
iftenney authored and LIT team committed Feb 15, 2024
1 parent b6ab352 commit c97a710
Showing 1 changed file with 49 additions and 11 deletions.
lit_nlp/examples/models/instrumented_keras_lms.py (60 changes: 49 additions & 11 deletions)
@@ -5,6 +5,7 @@
 import types
 from typing import Sequence
 
+from absl import logging
 from lit_nlp.api import model as lit_model
 from lit_nlp.api import types as lit_types
 from lit_nlp.lib import utils as lit_utils
@@ -37,6 +38,7 @@ def __init__(
       self,
       model,
       max_length: int = _DEFAULT_MAX_LENGTH,
+      dynamic_sequence_length: bool = True,
       batch_size: int = 16,
   ):
     """Base wrapper for a Keras/TF2 LM supporting the layer_intercept_fn API.
@@ -54,15 +56,17 @@ def __init__(
     Args:
       model: pre-loaded Keras LM using the TF backend
       max_length: max sequence length
+      dynamic_sequence_length: if true, will trim padding to the length of the
+        longest sequence in a batch. Recommended for CPU and GPU usage, but may
+        be disabled for compilation where a fixed shape is required.
       batch_size: batch size
     """
     super().__init__()
 
     self.model = model
     self.batch_size = batch_size
     self.max_length = max_length
-
-    self.encode_inputs = self.model.preprocessor.generate_preprocess
+    self.dynamic_sequence_length = dynamic_sequence_length
 
     self.ids_to_tokens = np.vectorize(
         self.model.preprocessor.tokenizer.id_to_token
@@ -72,6 +76,46 @@ def __init__(
     # to embs: <tf.float>[batch_size, num_tokens, emb_dim]
     self.embedder = self.model.backbone.token_embedding
 
+  def encode_inputs(self, texts: Sequence[str]):
+    """Encode inputs, with optional dynamic trimming.
+
+    By default, the model's generate_preprocess() pads to a fixed sequence
+    length, either specified as sequence_length= or using an internal default.
+
+    Here, we optionally trim this to remove extraneous padding positions based
+    on the actual contents of the minibatch. This can greatly speed up
+    performance when running on CPU or GPU.
+
+    Args:
+      texts: list of input strings
+
+    Returns:
+      encoded_inputs compatible with model.score() or other functions
+    """
+    # First: pack to max_length
+    encoded_inputs = self.model.preprocessor.generate_preprocess(
+        texts, sequence_length=self.max_length
+    )
+    if not self.dynamic_sequence_length:
+      return encoded_inputs
+
+    # Trim to the maximum length needed to contain any non-padding tokens.
+    mask = encoded_inputs["padding_mask"]
+    # Find position of last 'True' in each row.
+    seq_ends: Sequence[int] = [
+        1 + tf.reduce_max(tf.where(mask[i])).numpy().tolist()
+        for i in range(mask.shape[0])
+    ]
+    trimmed_length = max(seq_ends)
+    # TODO(lit-dev): remove this line, or make it logging.debug ?
+    logging.info(
+        "Trimming batch to trimmed_length = %d based on sequence ends %s",
+        trimmed_length,
+        seq_ends,
+    )
+    # Actually trim the input tensors.
+    return {k: v[:, :trimmed_length] for k, v in encoded_inputs.items()}
+
   @classmethod
   def from_loaded(cls, existing: "_KerasBaseModel", *args, **kw):
     """Share weights and underlying Keras model with another instance."""
@@ -111,9 +155,7 @@ def __init__(self, *args, output_embeddings=True, **kw):
     self.output_embeddings = output_embeddings
 
   def embed_texts(self, texts: Sequence[str]):
-    processed_inputs = self.encode_inputs(
-        texts, sequence_length=self.max_length
-    )
+    processed_inputs = self.encode_inputs(texts)
     # <tf.float>[batch_size, num_tokens, emb_dim]
     embs = self.embedder(processed_inputs["token_ids"])
     # <tf.bool>[batch_size, num_tokens]
@@ -289,9 +331,7 @@ def predict_minibatch(self, inputs):
     texts: Sequence[str] = [
         ex[FieldNames.PROMPT] + ex.get(FieldNames.TARGET, "") for ex in inputs
     ]
-    preprocessed_texts = self.encode_inputs(
-        texts, sequence_length=self.max_length
-    )
+    preprocessed_texts = self.encode_inputs(texts)
     sequence_ids = preprocessed_texts["token_ids"]
     padding_mask = preprocessed_texts["padding_mask"]
 
@@ -342,9 +382,7 @@ def predict_minibatch(self, inputs):
     texts: Sequence[str] = [
         ex[FieldNames.PROMPT] + ex.get(FieldNames.TARGET, "") for ex in inputs
     ]
-    preprocessed_texts = self.encode_inputs(
-        texts, sequence_length=self.max_length
-    )
+    preprocessed_texts = self.encode_inputs(texts)
     batched_outputs = {
         "token_ids": preprocessed_texts["token_ids"],
         "padding_mask": preprocessed_texts["padding_mask"],
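For context, here is a minimal, self-contained sketch (not part of the commit) of what the new dynamic trimming does to a padded minibatch. The toy padding_mask and token_ids tensors below are invented stand-ins for the dict returned by generate_preprocess(); the trimming math mirrors the seq_ends / trimmed_length computation added in encode_inputs() above.

import tensorflow as tf

# Toy stand-in for the "padding_mask" produced by generate_preprocess():
# two sequences packed to max_length=8, with 3 and 5 real tokens respectively.
padding_mask = tf.constant([
    [True, True, True, False, False, False, False, False],
    [True, True, True, True, True, False, False, False],
])
token_ids = tf.constant([
    [2, 15, 27, 0, 0, 0, 0, 0],
    [2, 8, 99, 41, 3, 0, 0, 0],
])
encoded_inputs = {"token_ids": token_ids, "padding_mask": padding_mask}

# Same per-row logic as encode_inputs(): one past the index of the last True.
seq_ends = [
    1 + tf.reduce_max(tf.where(padding_mask[i])).numpy().tolist()
    for i in range(padding_mask.shape[0])
]
trimmed_length = max(seq_ends)  # 5 for this batch

# Trim every feature tensor to the longest real sequence in the batch.
trimmed = {k: v[:, :trimmed_length] for k, v in encoded_inputs.items()}
print(seq_ends)                    # [3, 5]
print(trimmed["token_ids"].shape)  # (2, 5)

With dynamic_sequence_length=False, the batch stays at the full max_length (8 here), which keeps shapes static for compilation at the cost of extra compute on padding positions.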
