Generic instrumented Keras LM wrapper for LM salience
PiperOrigin-RevId: 606810304
iftenney authored and LIT team committed Feb 14, 2024
1 parent 80cf699 commit 1df3ba8
Showing 1 changed file with 375 additions and 0 deletions.
lit_nlp/examples/models/instrumented_keras_lms.py (+375, -0)

@@ -0,0 +1,375 @@
"""LIT model wrappers for generic instrumented Keras LMs."""

import functools
import inspect
import types
from typing import Sequence

from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.lib import utils as lit_utils
import numpy as np
import tensorflow as tf


_DEFAULT_MAX_LENGTH = 1024


class FieldNames(types.SimpleNamespace):
PROMPT = "prompt"
RESPONSE = "response"
PROMPT_EMBEDDINGS = "prompt_embeddings"
RESPONSE_EMBEDDINGS = "response_embeddings"
TARGET = "target"
TOKENS = "tokens"
TARGET_MASK = "target_mask"
GRAD_DOT_INPUT = "grad_dot_input"
GRAD_NORM = "grad_l2"
TOKEN_LOSS = "token_loss"


class _KerasBaseModel(lit_model.BatchedModel):
"""Base LIT model wrapper class for Keras on TensorFlow."""

# TODO(lit-dev): pytype annotations for model= ?
# Should be keras_nlp.models.generative_task.GenerativeTask
def __init__(
self,
model,
max_length: int = _DEFAULT_MAX_LENGTH,
batch_size: int = 16,
):
"""Base wrapper for a Keras/TF2 LM supporting the layer_intercept_fn API.
Model should support the following methods:
- .generate()
- .score()*
- .preprocessor.generate_preprocess()
. .preprocessor.tokenizer.id_to_token()
. .backbone.token_embedding()
* The score function should accept layer_intercept_fn= as a way to intercept
and manipulate activations between layers. We use this for salience, below.
Args:
model: pre-loaded Keras LM using the TF backend
max_length: max sequence length
batch_size: batch size
"""
super().__init__()

self.model = model
self.batch_size = batch_size
self.max_length = max_length

self.encode_inputs = self.model.preprocessor.generate_preprocess

self.ids_to_tokens = np.vectorize(
self.model.preprocessor.tokenizer.id_to_token
)

# map ids: <tf.int64>[batch_size, num_tokens]
# to embs: <tf.float32>[batch_size, num_tokens, emb_dim]
self.embedder = self.model.backbone.token_embedding

@classmethod
def from_loaded(cls, existing: "_KerasBaseModel", *args, **kw):
"""Share weights and underlying Keras model with another instance."""
return cls(model=existing.model, *args, **kw)

def max_minibatch_size(self) -> int:
return self.batch_size

@classmethod
def init_spec(cls):
# Cannot initialize from spec, because we need a Keras model object.
return None

def input_spec(self):
return {
FieldNames.PROMPT: lit_types.TextSegment(),
FieldNames.TARGET: lit_types.TextSegment(required=False),
}


class KerasGenerationModel(_KerasBaseModel):
"""LIT model wrapper for generating text with Keras on TensorFlow.
This class accepts a loaded model and provides the LIT-required functions plus
additional helper functions for generation tasks.
This class supports generation and pass-through modes. If a dataset provides a
pre-populated 'response' column then this model will return that text instead
of generating new text from the 'prompt'. This allows the same model wrapper
to be efficiently used to examine saved results from bulk-inference pipelines
and new generations from, e.g., counterfactually generated examples, or novel
evaluation datasets.
"""

def __init__(self, *args, output_embeddings=True, **kw):
super().__init__(*args, **kw)
self.output_embeddings = output_embeddings

def embed_texts(self, texts: Sequence[str]):
processed_inputs = self.encode_inputs(
texts, sequence_length=self.max_length
)
# <tf.float32>[batch_size, num_tokens, emb_dim]
embs = self.embedder(processed_inputs["token_ids"])
# <tf.bool>[batch_size, num_tokens]
mask = processed_inputs["padding_mask"]
return embs, mask

def embed_and_mean_pool(self, texts: Sequence[str]):
"""Return a single vector for each text."""
embs, mask = self.embed_texts(texts)
# <tf.float32>[batch_size, num_tokens, 1]
mask = tf.expand_dims(tf.cast(mask, dtype=tf.float32), axis=2)
# <tf.float32>[batch_size, 1, emb_dim]
pooled_embs = tf.reduce_sum(
mask * embs, axis=1, keepdims=True
) / tf.reduce_sum(mask, axis=1, keepdims=True)
# <tf.float32>[batch_size, emb_dim]
return tf.squeeze(pooled_embs, axis=1)

def predict_minibatch(
self,
inputs: list[lit_types.JsonDict],
) -> list[lit_types.JsonDict]:
prompts: Sequence[str] = [ex[FieldNames.PROMPT] for ex in inputs]

# TODO(lit-dev): support loading cached responses here, since running
# generation can be expensive.
full_responses: Sequence[str] = list(
self.model.generate(prompts, max_length=self.max_length)
)
# Model outputs include the prompt, so trim that off and just return the
# generated portion.
responses: Sequence[str] = [
response[len(prompt) :]
for response, prompt in zip(full_responses, prompts)
]

outputs = [{FieldNames.RESPONSE: response} for response in responses]

if self.output_embeddings:
prompt_embeddings = self.embed_and_mean_pool(prompts)
# TODO(lit-dev): embed prompt + response and trim embedding instead?
# Or just embed full_response.
response_embeddings = self.embed_and_mean_pool(responses)

for i in range(len(inputs)):
outputs[i][FieldNames.PROMPT_EMBEDDINGS] = prompt_embeddings[i].numpy()
outputs[i][FieldNames.RESPONSE_EMBEDDINGS] = response_embeddings[
i
].numpy()

return outputs

def output_spec(self) -> lit_types.Spec:
ret = {
FieldNames.RESPONSE: lit_types.GeneratedText(parent=FieldNames.TARGET)
}
if self.output_embeddings:
return ret | {
FieldNames.PROMPT_EMBEDDINGS: lit_types.Embeddings(),
FieldNames.RESPONSE_EMBEDDINGS: lit_types.Embeddings(),
}
return ret


class KerasSalienceModel(_KerasBaseModel):
"""LIT model wrapper for computing salience with Keras on TensorFlow.
This class accepts a loaded model and provides the LIT-required functions plus
additional helper functions to convert and clean tokens and to compute
sequence salience.
This class does not support generation; use the KerasGenerationModel class to
generate the text for which this class will compute salience.
"""

def __init__(self, *args, **kw):
super().__init__(*args, **kw)

score_fn = getattr(self.model, "score", None)

if score_fn is None or not inspect.ismethod(score_fn):
raise TypeError(
"Salience is computed via a .score() API, which is not supported by "
"all GenerativeTask models in KerasNLP. Please provide a model that "
"supports this API."
)

def _pred(self, input_ids, padding_mask, target_masks):
"""Predict a batch of tokenized text."""
# <tf.int64>[batch_size, num_tokens]; ignore the last position in each row.
target_ids = tf.roll(input_ids, shift=-1, axis=1)

##
# Process target masks

# It doesn't make sense to interpret the first token, since it is never
# predicted. But we need to ensure that mask[0] is zero, so it doesn't cause
# problems when 'rolled' to the last position below.
modified_masks = [[0] + list(mask[1:]) for mask in target_masks]
seq_len = target_ids.shape[1]
pad_fn = functools.partial(
lit_utils.pad1d,
min_len=seq_len,
max_len=seq_len,
pad_val=0,
pad_left=False,
)
padded_target_masks = np.stack(
[pad_fn(mask) for mask in modified_masks],
axis=0,
)

padded_target_masks = tf.constant(padded_target_masks, dtype=tf.float32)
# Shift masks back so they align with target_ids.
loss_mask = tf.roll(padded_target_masks, shift=-1, axis=1)

embeddings = None

with tf.GradientTape(watch_accessed_variables=True) as tape:

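# score() calls layer_intercept_fn(activations, layer_index) at each layer;
# index -1 corresponds to the embedding output, which we capture and watch so
# that gradients can be taken w.r.t. the input embeddings.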
def layer_intercept_fn(x, i):
if i == -1:
nonlocal embeddings, tape
embeddings = x
tape.watch(embeddings)
return x

# <tf.float32>[batch_size, num_tokens]
per_token_loss = self.model.score(
token_ids=input_ids,
padding_mask=padding_mask,
scoring_mode="loss",
layer_intercept_fn=layer_intercept_fn,
target_ids=target_ids,
)
masked_loss = per_token_loss * loss_mask

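# Gradients of the masked loss w.r.t. the captured embeddings give one vector
# per input token; reducing over the embedding dimension yields the scalar
# per-token salience scores below.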
# <tf.float32>[batch_size, num_tokens, hdim]
grads = tape.gradient(masked_loss, embeddings)
# <tf.float32>[batch_size, num_tokens]
grad_l2 = tf.norm(grads, axis=2)
# <tf.float32>[batch_size, num_tokens]
grad_dot_input = tf.reduce_sum(grads * embeddings, axis=2)

batched_outputs = {
"input_ids": input_ids,
"padding_mask": padding_mask,
# Gradients are already aligned to input tokens.
FieldNames.GRAD_NORM: grad_l2,
FieldNames.GRAD_DOT_INPUT: grad_dot_input,
# Shift token loss to align with (input) tokens.
FieldNames.TOKEN_LOSS: tf.roll(per_token_loss, shift=1, axis=1),
}

return batched_outputs

def _postprocess(self, preds):
"""Post-process single-example preds. Operates on numpy arrays."""
mask = preds.pop("padding_mask").astype(bool)
ids = preds.pop("input_ids")[mask]
preds[FieldNames.TOKENS] = self.ids_to_tokens(ids)
for key in lit_utils.find_spec_keys(
self.output_spec(), lit_types.TokenScores
):
preds[key] = preds[key][mask]
# First token (<bos>) is not actually predicted, so return 0 for loss.
preds[FieldNames.TOKEN_LOSS][0] = 0

return preds

def predict_minibatch(self, inputs):
"""Predict on a single minibatch of examples."""
texts: Sequence[str] = [
ex[FieldNames.PROMPT] + ex.get(FieldNames.TARGET, "") for ex in inputs
]
preprocessed_texts = self.encode_inputs(
texts, sequence_length=self.max_length
)
sequence_ids = preprocessed_texts["token_ids"]
padding_mask = preprocessed_texts["padding_mask"]

target_masks = [ex.get(FieldNames.TARGET_MASK, []) for ex in inputs]

# Get the predictions.
batched_outputs = self._pred(sequence_ids, padding_mask, target_masks)
# Convert to numpy for post-processing.
detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()}
# Split up batched outputs, then post-process each example.
unbatched_outputs = lit_utils.unbatch_preds(detached_outputs)
return map(self._postprocess, unbatched_outputs)

def input_spec(self):
return super().input_spec() | {
FieldNames.TARGET_MASK: lit_types.TokenScores(align="", required=False),
}

def output_spec(self) -> lit_types.Spec:
return {
FieldNames.TOKENS: lit_types.Tokens(parent=""), # All tokens.
FieldNames.GRAD_DOT_INPUT: lit_types.TokenScores(
align=FieldNames.TOKENS
),
FieldNames.GRAD_NORM: lit_types.TokenScores(align=FieldNames.TOKENS),
FieldNames.TOKEN_LOSS: lit_types.TokenScores(align=FieldNames.TOKENS),
}


class KerasTokenizerModel(_KerasBaseModel):
"""LIT model wrapper for tokenizing text with Keras on TensorFlow.
This class accepts a loaded model and provides the LIT-required functions plus
additional helper functions to convert and clean tokens.
"""

def _postprocess(self, preds):
"""Post-process single-example preds. Operates on numpy arrays."""
# Be sure to cast to bool, otherwise this will select integer positions 0, 1
# rather than acting as a boolean mask.
mask = preds.pop("padding_mask").astype(bool)
ids = preds.pop("token_ids")[mask]
preds[FieldNames.TOKENS] = self.ids_to_tokens(ids)
return preds

def predict_minibatch(self, inputs):
"""Tokenize a single minibatch of examples."""
texts: Sequence[str] = [
ex[FieldNames.PROMPT] + ex.get(FieldNames.TARGET, "") for ex in inputs
]
preprocessed_texts = self.encode_inputs(
texts, sequence_length=self.max_length
)
batched_outputs = {
"token_ids": preprocessed_texts["token_ids"],
"padding_mask": preprocessed_texts["padding_mask"],
}
# Convert to numpy for post-processing.
detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()}
# Split up batched outputs, then post-process each example.
unbatched_outputs = lit_utils.unbatch_preds(detached_outputs)
return map(self._postprocess, unbatched_outputs)

def output_spec(self) -> lit_types.Spec:
return {
FieldNames.TOKENS: lit_types.Tokens(parent=""), # All tokens.
}


def initialize_model_group_for_salience(
name, *args, **kw
) -> dict[str, lit_model.Model]:
"""Creates '{name}' and '_{name}_salience' and '_{name}_tokenizer'."""
generation_model = KerasGenerationModel(*args, **kw)
salience_model = KerasSalienceModel(*args, **kw)
tokenizer_model = KerasTokenizerModel(*args, **kw)
return {
name: generation_model,
f"_{name}_salience": salience_model,
f"_{name}_tokenizer": tokenizer_model,
}
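
For context, a minimal wiring sketch (not part of this commit): the KerasNLP preset, empty dataset dict, and server flags below are illustrative assumptions, and the chosen model must expose the .score(layer_intercept_fn=...) API described above.

import keras_nlp
from lit_nlp import dev_server
from lit_nlp.examples.models import instrumented_keras_lms as keras_lms

# Requires a TensorFlow-backed Keras LM; the preset name is illustrative.
lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
models = keras_lms.initialize_model_group_for_salience("gpt2", model=lm, batch_size=4)
# "gpt2" serves generation; "_gpt2_salience" and "_gpt2_tokenizer" back the
# LM salience and tokenization views.
lit_demo = dev_server.Server(models, datasets={}, port=5432)
lit_demo.serve()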
