LIT: Refactor instrumented Keras LMs to support TF and Torch
Summary of changes:

* Wherever possible, uses `keras.ops` for Tensor operations. These functions are designed to create and operate over Tensors from any of the Keras 3 backends.
* Identifies a common API surface for salience predictions that supports all existing (and likely any future) Keras 3 backends. Functions implementing this API should be named `_pred_{framework}` and return a tuple of the GradNorm and GradDotInput salience scores (see the sketch after this list).
* Refactors `KerasSalienceModel._pred()` to perform common operations for data before calling out to backend-specific prediction functions.
* Extracts TensorFlow code in `KerasSalienceModel` to a new `_pred_tf()` function.
* Implements a `KerasSalienceModel._pred_torch()` function based on the HF implementation in `lit_nlp/examples/models/pretrained_lms.py`.
* Provides a stub for `KerasSalienceModel._pred_jax()` with a detailed comment outlining the JAX idiosyncrasies that we need to adapt to in order to support this backend.
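
For orientation, here is a minimal sketch of the dispatch pattern described above. It is illustrative only: signatures are simplified and the backend-specific bodies are elided relative to the real `KerasSalienceModel` in `lit_nlp/examples/models/instrumented_keras_lms.py`.

    import keras
    from keras import ops

    class KerasSalienceModel:  # Simplified; the real class wraps a KerasNLP CausalLM.

      def _pred(self, input_ids, target_masks):
        """Backend-agnostic salience entry point."""
        # Common preprocessing uses keras.ops so it runs on any Keras 3 backend.
        target_masks = ops.convert_to_tensor(target_masks)
        backend = keras.config.backend()
        if backend == "tensorflow":
          return self._pred_tf(input_ids, target_masks)
        elif backend == "torch":
          return self._pred_torch(input_ids, target_masks)
        elif backend == "jax":
          return self._pred_jax(input_ids, target_masks)
        raise ValueError(f"Unsupported Keras backend: {backend}")

      def _pred_tf(self, input_ids, target_masks):
        # Gradients via tf.GradientTape over the input embeddings; returns
        # (grad_l2, grad_dot_input).
        ...

      def _pred_torch(self, input_ids, target_masks):
        # Gradients via requires_grad_(True) on the embeddings plus a backward
        # pass, mirroring the HF implementation; returns (grad_l2, grad_dot_input).
        ...

      def _pred_jax(self, input_ids, target_masks):
        # Stubbed: see the in-source comment on JAX idiosyncrasies (e.g., the
        # need for a purely functional loss for jax.grad) before this is supported.
        raise NotImplementedError("JAX backend is not supported yet.")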

PiperOrigin-RevId: 622966761
RyanMullins authored and LIT team committed Apr 8, 2024
1 parent 82abec6 commit 5ee7064
Showing 9 changed files with 339 additions and 197 deletions.
7 changes: 7 additions & 0 deletions lit_nlp/api/model.py
@@ -49,6 +49,13 @@ def maybe_copy_np(arr):
# If this is not a view of another array.
if arr.base is None:
return arr
# Tensorflow provides a bridge to share memory between tensorflow and numpy
# arrays. This looks like a view into an array but the base is a
# tensorflow_wrapper not an array, so the view heuristics below don't work. We
# can check for this case by checking if arr.base has the ndim attribute.
# https://github.com/tensorflow/tensorflow/blob/6ed79e8429730c33dc894175da7a1849a8e3e57f/tensorflow/python/lib/core/ndarray_tensor_bridge.cc#L90
if not hasattr(arr.base, 'ndim'):
return np.copy(arr)
# Heuristic to check if we should 'detach' this array from the parent blob.
# We want to know if this array is a view that might leak memory.
# The simplest check is if arr.base is larger than arr, but we don't want to
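As an aside, a small self-contained illustration of the guard added in the (truncated) hunk above; `copy_if_foreign_base` is a hypothetical mini-version of `maybe_copy_np`, not LIT code:

    import numpy as np

    def copy_if_foreign_base(arr: np.ndarray) -> np.ndarray:
      """Copy when the array's base is not a regular ndarray (illustrative only)."""
      if arr.base is not None and not hasattr(arr.base, "ndim"):
        # The base is not an ndarray (e.g., TF's tensor/ndarray bridge object),
        # so the size-based view heuristics cannot apply; copy defensively.
        return np.copy(arr)
      return arr

    view = np.arange(8.0)[:4]          # An ordinary view; its base is an ndarray.
    print(hasattr(view.base, "ndim"))  # True -> the usual view heuristics apply.

    raw = np.frombuffer(bytearray(32), dtype=np.float64)  # Base is the raw buffer,
    print(hasattr(raw.base, "ndim"))   # typically False -> the guard copies instead.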
185 changes: 98 additions & 87 deletions lit_nlp/examples/lm_salience_demo.py
@@ -1,33 +1,43 @@
r"""Demo for sequence salience with a left-to-right language model.
To use with Gemma models, install the latest versions of Keras and KerasNLP:
To use with the Gemma, Llama, or Mistral models, install the latest versions of
Keras, KerasNLP, and/or HuggingFace Transformers:
pip install keras>=3.0.5 keras-nlp>=0.8.0
pip install keras>=3.1.0 keras-nlp>=0.9.0 transformers>=4.38.0
To run with the default configuration (Gemma on TensorFlow via Keras):
To run:
blaze run -c opt examples:lm_salience_demo -- \
--models=gemma_instruct_2b_en:gemma_instruct_2b_en \
--port=8890 --alsologtostderr
We strongly recommend a GPU or other accelerator to run this demo, although for
testing, the smaller GPT-2 models run well on CPU. To use tensorflow weights of
GPT2, set the flag values as below:
--models=gpt2:https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2.tar.gz
--hf_framework=tensorflow
We also support pytorch weights for GPT-2 model, simply set the flag values:
--models=gpt2:https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2-pt.tar.gz
--hf_framework=pytorch
A few more examples of the flag setup for other supported models (GPU required):
Llama2: --hf_framework=pytorch
--models=llama2:meta-llama/Llama-2-7b-hf
Mistral: --hf_framework=pytorch
--models=mistral:mistralai/Mistral-7B-v0.1
By default this include a small set of sample prompts, but you can load your
own examples using the --datasets flag or through the "Configure" menu in the
UI.
MODELS:
We strongly recommend a GPU or other accelerator to run this server with LLMs.
The table below shows the model names and presets for common models. Use these
to parameterize the --models flag with comma-separated `{model}:{preset}`
strings, and remember the number of models loaded will be limited by the memory
available on your accelerator.
| Model | dl_framework | dl_backend=tensorflow Preset | dl_backend=torch Preset |
| ------- | ------------ | ---------------------------- | ------------------------------------ |
| Gemma | kerasnlp | gemma_1.1_instruct_7b_en | gemma_1.1_instruct_7b_en |
| Gemma | transformers | Unavailable | google/gemma-1.1-7b-it |
| Llama 2 | kerasnlp | llama2_instruct_7b_en | llama2_instruct_7b_en |
| Llama 2 | transformers | Unavailable | meta-llama/Llama-2-7b-hf |
| Mistral | kerasnlp | mistral_instruct_7b_en | mistral_instruct_7b_en |
| Mistral | transformers | Unavailable | mistralai/Mistral-7B-Instruct-v0.2 |
Additional model presets can be found at the following locations, though
compatibility with the LIT model wrappers is not guaranteed:
* KerasNLP: https://keras.io/api/keras_nlp/models/
* HuggingFace Transformers: https://huggingface.co/models
DATASETS:
By default this includes a small set of sample prompts. You can load your own
examples using the --datasets flag or through the "Configure" menu in the UI.
"""

from collections.abc import Sequence
@@ -37,31 +47,15 @@
import sys
from typing import Optional

# TODO(b/327281789): remove once keras 3 is the default.
# Temporary; need to set this before importing keras_nlp
os.environ["FORCE_KERAS_3"] = "True"

# pylint: disable=g-import-not-at-top
from absl import app
from absl import flags
from absl import logging
import keras
from keras_nlp import models as keras_models
from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.api import layout
from lit_nlp.examples.datasets import lm as lm_data
from lit_nlp.examples.models import instrumented_keras_lms as lit_keras
from lit_nlp.examples.models import pretrained_lms
from lit_nlp.lib import file_cache

# pytype: disable=import-error
try:
import torch
except (ModuleNotFoundError, ImportError):
logging.warning("PyTorch is not available.")
# pytype: enable=import-error

# NOTE: additional flags defined in server_flags.py

FLAGS = flags.FLAGS
@@ -70,16 +64,14 @@

_MODELS = flags.DEFINE_list(
"models",
[
"gemma_instruct_2b_en:gemma_instruct_2b_en",
"gpt2:https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2.tar.gz",
],
"Models to load, as <name>:<path>. Currently supports Gemma (Keras NLP) and"
"HuggingFace models. For HuggingFace models, GPT2, Llama, Mistral have been"
"verified to work with this demo. Thereotically, supported decoder models"
"in `transformers.AutoModelForCausalLM` should work, but adjustments might"
"be needed on their tokenizers (e.g. need to define custom pad_token when"
"eos_token is not available to use as pad_token).",
["gemma_instruct_2b_en:gemma_instruct_2b_en"],
"Models to load, as <name>:<path>. Path can be a URL, a local file path, or"
" the name of a preset for the configured Deep Learning framework (either"
" KerasNLP or HuggingFace Transformers; see --dl_framework for more). This"
" demo is tested with Gemma, GPT2, Llama, and Mistral on all supported"
" --dl_framework values. Other models should work, but adjustments might be"
" needed on their tokenizers (e.g., to define custom pad_token"
" when eos_token is not available to use as pad_token).",
)

_DATASETS = flags.DEFINE_list(
@@ -99,19 +91,31 @@
),
)

_HF_FRAMEWORK = flags.DEFINE_enum(
"hf_framework",
_DL_BACKEND = flags.DEFINE_enum(
"dl_backend",
"tensorflow",
["tensorflow", "pytorch"],
"Deep learning framework for the HuggingFace model.",
["jax", "torch", "tensorflow"],
"The deep learning backend framework that the model runs on. All models"
" loaded by this server will use the same backend, incompatibilities will"
" result in errors.",
)

_DL_FRAMEWORK = flags.DEFINE_enum(
"dl_framework",
"kerasnlp",
["kerasnlp", "transformers"],
"The deep learning framework that loads and runs the model on the backend."
" This server will attempt to load all models specified by the --models"
" flag with the configured framework, incompatibilities will result in"
" errors.",
)

_PRECISION = flags.DEFINE_enum(
"precision",
"bfloat16",
["bfloat16", "float32"],
"Floating point precision for the HuggingFace (PyTorch) and Keras models,"
"only `bfloat16` and `float32` are supported for now.",
"Floating point precision for the models, only `bfloat16` and `float32` are"
" supported for now.",
)


@@ -201,20 +205,28 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
raise app.UsageError("Too many command-line arguments.")

# Set Keras backend and floating-point precision.
os.environ["KERAS_BACKEND"] = "tensorflow"
if hasattr(keras, "config") and hasattr(keras.config, "set_floatx"):
if _DL_FRAMEWORK.value == "kerasnlp":
# NOTE: Keras and KerasNLP require that certain environment variables are
# set before they are imported.
# TODO(b/327281789): Remove FORCE_KERAS_3 once Keras 3 is the default.
os.environ["FORCE_KERAS_3"] = "True"
os.environ["KERAS_BACKEND"] = _DL_BACKEND.value

# NOTE: Imported here and not at the top of the file to avoid
# initialization issues with the environment variables above. We should also
# import keras before any other Keras-related modules (e.g., KerasNLP or the
# LIT wrappers) to limit the potential for improperly configured backends.
import keras # pylint: disable=g-import-not-at-top

keras.config.set_floatx(_PRECISION.value)
else:
# TODO(b/327281789): remove once we can guarantee Keras 3.
logging.warn(
"keras.config.set_floatx() not available; using default precision."
)
elif _DL_BACKEND.value == "torch":
# NOTE: Keras sets precision for all backends with set_floatx(), but for
# HuggingFace Transformers with PyTorch we need to set it explicitly.
import torch # pylint: disable=g-import-not-at-top # pytype: disable=import-error

if _HF_FRAMEWORK.value == "pytorch":
if _PRECISION.value == "bfloat16":
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.float32)
torch.set_default_dtype(
torch.bfloat16 if _PRECISION.value == "bfloat16" else torch.float32
)

plaintextPrompts = functools.partial( # pylint: disable=invalid-name
lm_data.PlaintextSents, field_name="prompt"
@@ -259,38 +271,37 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
# Load models, according to the --models flag.
models = {}
for model_string in _MODELS.value:
# Only split on the first ':', because path may be a URL
# containing 'https://'
# Only split on the first ':' as path may be a URL containing 'https://'
model_name, path = model_string.split(":", 1)
logging.info("Loading model '%s' from '%s'", model_name, path)
if model_name.startswith("gemma"):
path = file_cache.cached_path(
path,
extract_compressed_file=path.endswith(".tar.gz"),
copy_directories=True,
)

path = file_cache.cached_path(
path,
extract_compressed_file=path.endswith(".tar.gz"),
copy_directories=True,
)

if _DL_FRAMEWORK.value == "kerasnlp":
# pylint: disable=g-import-not-at-top
from keras_nlp import models as keras_models
from lit_nlp.examples.models import instrumented_keras_lms as lit_keras
# pylint: enable=g-import-not-at-top
# Load the weights once for the underlying Keras model.
gemma_keras_model = keras_models.GemmaCausalLM.from_preset(path) # pytype: disable=module-attr
models = models | lit_keras.initialize_model_group_for_salience(
model_name, gemma_keras_model, max_length=512, batch_size=4
model = keras_models.CausalLM.from_preset(path)
models |= lit_keras.initialize_model_group_for_salience(
model_name, model, max_length=512, batch_size=4
)
# Disable embeddings from the generation model.
# TODO(lit-dev): re-enable embeddings if we can figure out why UMAP was
# crashing? Maybe need n > 2 examples.
models[model_name].output_embeddings = False
else:
# NOTE: (Style Deviation) Imported here to limit unnecessary imports.
from lit_nlp.examples.models import pretrained_lms # pylint: disable=g-import-not-at-top
# Assuming a valid decoder model name supported by
# `transformers.AutoModelForCausalLM` is provided to "path".
models[model_name] = pretrained_lms.HFGenerativeModel(
path, framework=_HF_FRAMEWORK.value, max_new_tokens=512
)
# Salience wrapper, using same underlying Keras models so as not to
# load the weights twice.
models[f"_{model_name}_salience"] = (
pretrained_lms.HFSalienceModel.from_loaded(models[model_name])
)
models[f"_{model_name}_tokenizer"] = (
pretrained_lms.HFTokenizerModel.from_loaded(models[model_name])
models |= pretrained_lms.initialize_model_group_for_salience(
model_name, path, framework=_DL_BACKEND.value, max_new_tokens=512
)
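The helper called here is new in this change; judging from the inline code it replaces (shown above), it presumably bundles the generation, salience, and tokenizer wrappers around a single set of loaded weights. A hypothetical caller-side equivalent, for illustration only (`build_hf_model_group` is not a real LIT function):

    from lit_nlp.examples.models import pretrained_lms

    def build_hf_model_group(name: str, path: str, framework: str, **kw) -> dict:
      """Sketch of what the new helper appears to do, inferred from the removed code."""
      generator = pretrained_lms.HFGenerativeModel(path, framework=framework, **kw)
      return {
          name: generator,
          # The salience and tokenizer wrappers reuse the already-loaded model and
          # tokenizer, so the weights are only loaded once.
          f"_{name}_salience": pretrained_lms.HFSalienceModel.from_loaded(generator),
          f"_{name}_tokenizer": pretrained_lms.HFTokenizerModel.from_loaded(generator),
      }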

for name in datasets:
