Support different float precision for LM salience.
Also set watch_accessed_variables=False, because we don't need it.

PiperOrigin-RevId: 607091706
iftenney authored and LIT team committed Feb 14, 2024
1 parent 1df3ba8 commit b6ab352
Showing 3 changed files with 34 additions and 24 deletions.
10 changes: 10 additions & 0 deletions lit_nlp/examples/lm_salience_demo.py
@@ -2,12 +2,14 @@

from collections.abc import Sequence
import functools
+ import os
import sys
from typing import Optional

from absl import app
from absl import flags
from absl import logging
+ import keras
from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.api import layout
@@ -37,6 +39,10 @@
),
)

+ _KERAS_FLOATX = flags.DEFINE_string(
+ "keras_floatx", "bfloat16", "Floating-point type for Keras models."
+ )

# Custom frontend layout; see api/layout.py
modules = layout.LitModuleName
LM_LAYOUT = layout.LitCanonicalLayout(
@@ -109,6 +115,10 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")

+ # Set Keras backend and floating-point precision.
+ os.environ["KERAS_BACKEND"] = "tensorflow"
+ keras.config.set_floatx(_KERAS_FLOATX.value)

plaintextPrompts = functools.partial( # pylint: disable=invalid-name
lm_data.PlaintextSents, field_name="prompt"
)
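For context on the lines added to main(), here is a minimal, self-contained sketch (not from the repository) of the pattern: choose the Keras backend via the KERAS_BACKEND environment variable and set the global float type before any model is constructed. The Dense layer is only a hypothetical stand-in to show the effect.

```python
# Sketch only: pick the backend and precision before building any Keras model.
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # Keras reads this at import time.

import keras

keras.config.set_floatx("bfloat16")  # or "float16" / "float32"

# Layers created from here on default to the chosen precision.
layer = keras.layers.Dense(4)
print(keras.config.floatx())  # -> "bfloat16"
print(layer.compute_dtype)    # -> "bfloat16"
```

In the demo the same two calls are driven by the new --keras_floatx flag, so passing e.g. --keras_floatx=float32 should keep models in full precision (Keras' usual default), while the flag's default of bfloat16 roughly halves activation memory for the LM salience computation.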
30 changes: 15 additions & 15 deletions lit_nlp/examples/models/instrumented_keras_lms.py
@@ -68,8 +68,8 @@ def __init__(
self.model.preprocessor.tokenizer.id_to_token
)

- # map ids: <tf.int64>[batch_size, num_tokens]
- # to embs: <tf.float32>[batch_size, num_tokens, emb_dim]
+ # map ids: <tf.int>[batch_size, num_tokens]
+ # to embs: <tf.float>[batch_size, num_tokens, emb_dim]
self.embedder = self.model.backbone.token_embedding

@classmethod
@@ -114,7 +114,7 @@ def embed_texts(self, texts: Sequence[str]):
processed_inputs = self.encode_inputs(
texts, sequence_length=self.max_length
)
- # <tf.float32>[batch_size, num_tokens, emb_dim]
+ # <tf.float>[batch_size, num_tokens, emb_dim]
embs = self.embedder(processed_inputs["token_ids"])
# <tf.bool>[batch_size, num_tokens]
mask = processed_inputs["padding_mask"]
@@ -123,13 +123,13 @@ def embed_texts(self, texts: Sequence[str]):
def embed_and_mean_pool(self, texts: Sequence[str]):
"""Return a single vector for each text."""
embs, mask = self.embed_texts(texts)
- # <tf.float32>[batch_size, num_tokens, 1]
- mask = tf.expand_dims(tf.cast(mask, dtype=tf.float32), axis=2)
- # <tf.float32>[batch_size, 1, emb_dim]
+ # <tf.float>[batch_size, num_tokens, 1]
+ mask = tf.expand_dims(tf.cast(mask, dtype=embs.dtype), axis=2)
+ # <tf.float>[batch_size, 1, emb_dim]
pooled_embs = tf.reduce_sum(
mask * embs, axis=1, keepdims=True
) / tf.reduce_sum(mask, axis=1, keepdims=True)
- # <tf.float32>[batch_size, emb_dim]
+ # <tf.float>[batch_size, emb_dim]
return tf.squeeze(pooled_embs, axis=1)

def predict_minibatch(
@@ -203,7 +203,7 @@ def __init__(self, *args, **kw):

def _pred(self, input_ids, padding_mask, target_masks):
"""Predict a batch of tokenized text."""
- # <tf.float32>[batch_size, num_tokens]; ignore the last one in each row.
+ # <tf.int>[batch_size, num_tokens]; ignore the last one in each row.
target_ids = tf.roll(input_ids, shift=-1, axis=1)

##
@@ -226,13 +226,13 @@ def _pred(self, input_ids, padding_mask, target_masks):
axis=0,
)

- padded_target_masks = tf.constant(padded_target_masks, dtype=tf.float32)
+ padded_target_masks = tf.constant(padded_target_masks, dtype=tf.bool)
# Shift masks back so they align with target_ids.
loss_mask = tf.roll(padded_target_masks, shift=-1, axis=1)

embeddings = None

- with tf.GradientTape(watch_accessed_variables=True) as tape:
+ with tf.GradientTape(watch_accessed_variables=False) as tape:

def layer_intercept_fn(x, i):
if i == -1:
@@ -241,21 +241,21 @@ def layer_intercept_fn(x, i):
tape.watch(embeddings)
return x

- # <tf.float32>[batch_size, num_tokens]
+ # <tf.float>[batch_size, num_tokens]
per_token_loss = self.model.score(
token_ids=input_ids,
padding_mask=padding_mask,
scoring_mode="loss",
layer_intercept_fn=layer_intercept_fn,
target_ids=target_ids,
)
- masked_loss = per_token_loss * loss_mask
+ masked_loss = per_token_loss * tf.cast(loss_mask, per_token_loss.dtype)

- # <tf.float32>[batch_size, num_tokens, hdim]
+ # <tf.float>[batch_size, num_tokens, hdim]
grads = tape.gradient(masked_loss, embeddings)
- # <tf.float32>[batch_size, num_tokens]
+ # <tf.float>[batch_size, num_tokens]
grad_l2 = tf.norm(grads, axis=2)
- # <tf.float32>[batch_size, num_tokens]
+ # <tf.float>[batch_size, num_tokens]
grad_dot_input = tf.reduce_sum(grads * embeddings, axis=2)

batched_outputs = {
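To make the tape changes above easier to follow, here is a runnable toy sketch (small stand-in layers, not the instrumented KerasNLP backbone or model.score) of the same salience recipe: with watch_accessed_variables=False the tape tracks nothing by default, and only the explicitly watched embeddings get gradients, which is all the grad-L2 and grad-dot-input outputs need.

```python
import tensorflow as tf

# Hypothetical stand-ins; in instrumented_keras_lms.py these come from the LM.
embedder = tf.keras.layers.Embedding(input_dim=100, output_dim=16)
project = tf.keras.layers.Dense(100)  # embeddings -> vocab logits

input_ids = tf.constant([[1, 2, 3, 4]])
target_ids = tf.roll(input_ids, shift=-1, axis=1)     # next-token targets
loss_mask = tf.constant([[True, True, True, False]])  # which tokens to score

# watch_accessed_variables=False: skip tracking model weights; input salience
# only needs gradients w.r.t. the embeddings we watch explicitly.
with tf.GradientTape(watch_accessed_variables=False) as tape:
  embeddings = embedder(input_ids)        # <float>[batch, tokens, emb_dim]
  tape.watch(embeddings)
  logits = project(embeddings)            # <float>[batch, tokens, vocab]
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction="none"
  )
  per_token_loss = loss_fn(target_ids, logits)  # <float>[batch, tokens]
  # Keep the mask boolean; cast to the loss dtype so the multiply also works
  # when the model runs in bfloat16 or float16.
  masked_loss = per_token_loss * tf.cast(loss_mask, per_token_loss.dtype)

grads = tape.gradient(masked_loss, embeddings)  # <float>[batch, tokens, emb_dim]
grad_l2 = tf.norm(grads, axis=2)                # <float>[batch, tokens]
grad_dot_input = tf.reduce_sum(grads * embeddings, axis=2)  # <float>[batch, tokens]
```

The same dtype-agnostic cast appears in embed_and_mean_pool above, where the padding mask is cast to embs.dtype rather than a hard-coded tf.float32.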
18 changes: 9 additions & 9 deletions lit_nlp/examples/models/pretrained_lms.py
@@ -490,7 +490,7 @@ def predict_minibatch(self, inputs):
responses = self.tokenizer.batch_decode(
outputs[:, -self.max_new_tokens :], skip_special_tokens=True
)
- # Input embeddings: <tf.float32>[batch_size, num_tokens, emb_dim]
+ # Input embeddings: <tf.float>[batch_size, num_tokens, emb_dim]
embeddings = self.model.transformer.wte(outputs)
batched_outputs = {
"embs": embeddings,
@@ -532,7 +532,7 @@ def _pred(self, encoded_inputs, target_masks):
"""
input_ids = encoded_inputs["input_ids"]

- # <tf.float32>[batch_size, num_tokens]; ignore the last one in each row.
+ # <tf.int32>[batch_size, num_tokens]; ignore the last one in each row.
target_ids = tf.roll(encoded_inputs["input_ids"], shift=-1, axis=1)
##
# Process target masks
@@ -554,11 +554,11 @@ def _pred(self, encoded_inputs, target_masks):
axis=0,
)

- padded_target_masks = tf.constant(padded_target_masks, dtype=tf.float32)
+ padded_target_masks = tf.constant(padded_target_masks, dtype=tf.bool)
# Shift masks back so they align with target_ids.
loss_mask = tf.roll(padded_target_masks, shift=-1, axis=1)

- with tf.GradientTape(watch_accessed_variables=True) as tape:
+ with tf.GradientTape(watch_accessed_variables=False) as tape:
# We need to run the embedding layer ourselves so we can trace it.
# See here for how the model normally does this:
# http://google3/third_party/py/transformers/models/gpt2/modeling_tf_gpt2.py;l=450;rcl=578656271
@@ -574,18 +574,18 @@ def _pred(self, encoded_inputs, target_masks):
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction="none"
)
- # <tf.float32>[batch_size, num_tokens]
+ # <tf.float>[batch_size, num_tokens]
per_token_loss = loss_fn(target_ids, out.logits)
- masked_loss = per_token_loss * loss_mask
+ masked_loss = per_token_loss * tf.cast(loss_mask, per_token_loss.dtype)

grads = tape.gradient(
masked_loss, embs
- ) # <tf.float32>[batch_size, num_tokens, hdim]
+ ) # <tf.float>[batch_size, num_tokens, hdim]

- grad_l2 = tf.norm(grads, axis=2) # <tf.float32>[batch_size, num_tokens]
+ grad_l2 = tf.norm(grads, axis=2) # <tf.float>[batch_size, num_tokens]
grad_dot_input = tf.reduce_sum(
grads * embs, axis=2
- ) # <tf.float32>[batch_size, num_tokens]
+ ) # <tf.float>[batch_size, num_tokens]

batched_outputs = {
"input_ids": encoded_inputs["input_ids"],
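The recurring masked-loss change is what makes the lower precision work end to end: TensorFlow does not auto-promote dtypes, so once the per-token loss comes out in bfloat16, multiplying it by a float32 mask fails. A tiny illustration with toy tensors (an assumption about usage, not code from the commit):

```python
import tensorflow as tf

per_token_loss = tf.constant([[0.5, 1.25, 2.0]], dtype=tf.bfloat16)

# Old pattern: a float32 mask. Mixing dtypes in the multiply, e.g.
#   per_token_loss * tf.constant([[1.0, 1.0, 0.0]], dtype=tf.float32)
# raises a dtype-mismatch error once the model runs in bfloat16.

# New pattern: keep the mask boolean and cast it to whatever dtype the loss
# has, so the same line works under float32, float16, or bfloat16.
loss_mask = tf.constant([[True, True, False]])
masked_loss = per_token_loss * tf.cast(loss_mask, per_token_loss.dtype)
print(masked_loss)  # [[0.5, 1.25, 0.0]], dtype=bfloat16
```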
