Skip to content

An object-oriented wrapper around language models with caching, batching, and more.

Notifications You must be signed in to change notification settings

DaiseyCode/lmwrapper

Repository files navigation

lmwrapper provides a wrapper around OpenAI API and Hugging Face Language models, focusing on being a clean, object-oriented, and user-friendly interface. It has two main goals:

A) Make it easier to use the OpenAI API.

B) Make it easier to reuse your code for other language models with minimal changes.

Some key features currently include local caching of responses, and super simple use of the OpenAI batching API which can save 50% on costs.

lmwrapper is lightweight and can serve as a flexible stand-in for the OpenAI API.

Installation

For usage with just OpenAI models:

pip install lmwrapper

For usage with HuggingFace models as well:

pip install 'lmwrapper[hf]'

For development dependencies:

pip install 'lmwrapper[dev]'

Example usage

Chat

from lmwrapper.openai_wrapper import get_open_ai_lm, OpenAiModelNames
from lmwrapper.structs import LmPrompt, LmChatTurn

lm = get_open_ai_lm(
    model_name=OpenAiModelNames.gpt_4o_mini,
    api_key_secret=None,  # By default, this will read from the OPENAI_API_KEY environment variable.
    # If that isn't set, it will try the file ~/oai_key.txt
    # You need to place the key in one of these places,
    # or pass in a different location. You can get an API
    # key at (https://platform.openai.com/account/api-keys)
)

# Single user utterance
pred = lm.predict("What is 2+2?")
print(pred.completion_text)  # "2 + 2 equals 4."


# Use a LmPrompt to have more control of the parameters
pred = lm.predict(LmPrompt(
    "What is 2+6?",
    max_tokens=10,
    temperature=0, # Set this to 0 for deterministic completions
))
print(pred.completion_text)  # "2 + 6 equals 8."

# Conversation alternating between `user` and `assistant`.
pred = lm.predict(LmPrompt(
    [
        "What is 2+2?",  # user turn
        "4",  # assistant turn
        "What is 5+3?"  # user turn
        "8",  # assistant turn
        "What is 4+4?"  # user turn
        # We use few-shot turns to encourage the answer to be our desired format.
        #   If you don't give example turns you might get something like
        #   "4 + 4 equals 8." instead of just "8" as desired.
    ],
    max_tokens=10,
))
print(pred.completion_text)  # "8"

# If you want things like the system message, you can use LmChatTurn objects
pred = lm.predict(LmPrompt(
    text=[
        LmChatTurn(role="system", content="You always answer like a pirate"),
        LmChatTurn(role="user", content="How does bitcoin work?"),
    ],
    max_tokens=25,
    temperature=0,
))
print(pred.completion_text)
# "Arr, me matey! Bitcoin be a digital currency that be workin' on a technology called blockchain..."

Caching

Add caching = True in the prompt to cache the output to disk. Any subsequent calls with this prompt will return the same value. Note that this might be unexpected behavior if your temperature is non-zero. (You will always sample the same output on reruns). If you want to get multiple samples at a non-zero temperature while still using the cache, you set num_completions > 1 in a LmPrompt.

OpenAI Batching

The OpenAI batching API has a 50% reduced cost when willing to accept a 24-hour turnaround. This makes it good for processing datasets or other non-interactive tasks (which is the main target for lmwrapper currently).

lmwrapper takes care of managing the batch files and other details so that it's as easy as the normal API.

from lmwrapper.openai_wrapper import get_open_ai_lm, OpenAiModelNames
from lmwrapper.structs import LmPrompt
from lmwrapper.batch_config import CompletionWindow

def load_dataset() -> list:
    """Load some toy task"""
    return ["France", "United States", "China"]

def make_prompts(data) -> list[LmPrompt]:
    """Make some toy prompts for our data"""
    return [
        LmPrompt(
            f"What is the capital of {country}? Answer with just the city name.",
            max_tokens=10,
            temperature=0,
            cache=True,
        ) 
        for country in data
    ]

data = load_dataset()
prompts = make_prompts(data)
lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo)
predictions = lm.predict_many(
    prompts,
    completion_window=CompletionWindow.BATCH_ANY 
    #                 ^ swap out for CompletionWindow.ASAP
    #                   to complete as soon as possible via
    #                   the non-batching API at a higher cost.
) # The batch is submitted here

for ex, pred in zip(data, predictions):  # Will wait for the batch to complete
    print(f"Country: {ex} --- Capital: {pred.completion_text}")
    if ex == "France": assert pred.completion_text == "Paris" 
    # ...

The above code could technically take up to 24hrs to complete. However, OpenAI seems to complete these quicker (for example, these three prompts in ~1 minute or less). In a large batch, you don't have to keep the process running for hours. Thanks to lmwrapper cacheing it will automatically load or pick back up waiting on the existing batch when the script is reran.

The lmwrapper cache lets you also intermix cached and uncached examples.

# ... above code

def load_more_data() -> list:
    """Load some toy task"""
    return ["Mexico", "Canada"]

data = load_data() + load_more_data()
prompts = make_prompts(data)
# If we submit the five prompts, only the two new prompts will be
# submitted to the batch. The already completed prompts will
# be loaded near-instantly from the local cache.
predictions = list(lm.predict_many(
    prompts,
    completion_window=CompletionWindow.BATCH_ANY
))

lmwrapper is designed to automatically manage the batching of thousands or millions of prompts. If needed, it will automatically split up prompts into sub-batches and will manage issues around rate limits.

This feature is mostly designed for the OpenAI cost savings. You could swap out the model for HuggingFace and the same code will still work. However, internally it is like a loop over the prompts. Eventually in lmwrapper we want to do more complex batching if GPU/CPU/accelerator memory is available.

Caveats / Implementation needs

This feature is still somewhat experimental. It likely works in typical usecases, but there are few known things to sort out / TODOs:

  • Retry batch API connection errors
  • Automatically splitting up batches when have >50,000 prompts (limit from OpenAI)
  • Recovering / splitting up batches when hitting your token Batch Queue Limit (see docs on limits)
  • Handle canceled batches during current run (use the web interface to cancel)
  • Handle/recover canceled batches outside of current run
  • Handle if openai batch expires unfinished in 24hrs (though not actually tested or observed this)
  • Automatically splitting up batch when exceeding 100MB prompts limit
  • Handling of failed prompts (like when have too many tokens). Use LmPrediction.has_errors and LmPrediction.error_message to check for an error on a response.
  • Handle when there are duplicate prompts in batch submission
  • Handle when a given prompt has num_completions>1
  • Automatically clean up API files after done (right now end up with a lot of file in storage. There isn't an obvious cost for these batch files, but this might change and it would be better to clean them up.)
  • Test on free-tier accounts. It is not clear what the tiny request limit counts
  • Fancy batching of HF
  • Concurrent batching when in ASAP mode

Please open an issue if you want to discuss one of these or something else.

Note, in the progress bars in PyCharm can be bit cleaner if you enable terminal emulation in your run configuration.

Hugging Face models

Local Causal LM models on Hugging Face models can be used interchangeably with the OpenAI models.

Note: The universe of Huggingface models is diverse and inconsistent. Some (especially the non-completion ones) might require special prompt formatting to work as expected. Some models might not work at all.

from lmwrapper.huggingface_wrapper import get_huggingface_lm
from lmwrapper.structs import LmPrompt

lm = get_huggingface_lm("gpt2")  # Download the smallest 124M parameter model

prediction = lm.predict(LmPrompt(
    "The capital of Germany is Berlin. The capital of France is",
    max_tokens=1,
    temperature=0,
))
print(prediction.completion_text)
assert prediction.completion_text == " Paris"

Additionally, with HuggingFace models lmwrapper provides an interface for accessing the model internal states.

Retries on rate limit

from lmwrapper.openai_wrapper import *

lm = get_open_ai_lm(
    OpenAiModelNames.gpt_3_5_turbo_instruct,
    retry_on_rate_limit=True
)

Other features

Built-in token counting

from lmwrapper.openai_wrapper import *
from lmwrapper.structs import LmPrompt

lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo_instruct)
assert lm.estimate_tokens_in_prompt(
    LmPrompt("My name is Spingldorph", max_tokens=10)) == 7
assert not lm.could_completion_go_over_token_limit(LmPrompt(
    "My name is Spingldorph", max_tokens=1000))

TODOs

If you are interested in one of these particular features or something else please make a Github Issue.

  • Openai completion
  • Openai chat
  • Huggingface interface
  • Huggingface device checking on PyTorch
  • Move cache to be per project
  • Redesign cache away from generic diskcache to make it easier to manage
  • Smart caching when num_completions > 1 (reusing prior completions)
  • OpenAI batching interface (experimental)
  • Anthropic interface
  • Be able to add user metadata to a prompt
  • Use the huggingface chat templates for chat models if available
  • Automatic cache eviction to limit count or disk size (right now have to run a SQL query to delete entries before a certain time or matching your criteria)
  • Multimodal/images in super easy format (like automatically process pil, opencv, etc)
  • sort through usage of quantized models
  • Cost estimation of a prompt before running / log total cost
  • Additional Huggingface runtimes (TensorRT, BetterTransformers, etc)
  • async / streaming (not a top priority for non-interactive research use cases)
  • some lightweight utilities to help with tool use

About

An object-oriented wrapper around language models with caching, batching, and more.

Resources

Stars

Watchers

Forks

Packages

No packages published