# Using Caching to Handle High-Traffic LLM Apps Without Dropping Requests

🚀 **Do this in 1 line of code with [reliableGPT](https://github.com/BerriAI/reliableGPT)**

```python
reliableGPT(openai.ChatCompletion.create, caching=True)
```

## Environment Set-up

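The code below uses the pre-1.0 `openai` interface (`openai.ChatCompletion`), which was removed in `openai` 1.0, so pin the version. Flask is needed for the server later on.

```python
!pip install "openai<1.0" flask waitress chromadb
```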

```python
import openai
openai.api_key = "YOUR_OPENAI_KEY"
```

## Create our Server

In a real app, OpenAI calls usually come from a backend server, so let's build one.

### Wrap OpenAI

We'll wrap the main OpenAI chat completions endpoint (`openai.ChatCompletion.create`). This lets us cache every OpenAI response and, when the server is under heavy load, reply with a cached response instead of making a new call.
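The wrapping pattern itself does not depend on OpenAI or Chroma. As a minimal sketch, here is the same idea with an exact-match dictionary instead of a vector store; `fake_create` is a hypothetical stand-in for `openai.ChatCompletion.create` so the example runs offline, and `exact_match_cache` is an illustrative name, not part of any library.

```python
import functools

def exact_match_cache(fn):
  """Wrap fn so calls with the same message contents reuse the stored reply."""
  replies = {}  # prompt text -> reply (exact-match keys; the model argument is ignored)

  @functools.wraps(fn)
  def wrapped(*args, **kwargs):
    if "messages" not in kwargs:
      return fn(*args, **kwargs)  # nothing to key on, so call straight through
    prompt = "\n".join(m["content"] for m in kwargs["messages"])
    if prompt not in replies:
      replies[prompt] = fn(*args, **kwargs)  # cache miss: make the real call once
    return replies[prompt]

  return wrapped

# Hypothetical stand-in for openai.ChatCompletion.create that counts real calls
live_calls = {"n": 0}

def fake_create(model, messages):
  live_calls["n"] += 1
  return f"echo: {messages[-1]['content']}"

cached_create = exact_match_cache(fake_create)
hello = [{"role": "user", "content": "Hello world"}]
print(cached_create(model="gpt-3.5-turbo", messages=hello))  # echo: Hello world
print(cached_create(model="gpt-3.5-turbo", messages=hello))  # echo: Hello world (from the cache)
print(live_calls["n"])  # 1
```

Like the notebook's wrapper, this keys only on the message contents, so two requests with the same text but different models share a cache entry.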

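**Key components:**
- `chromadb`: our in-memory cache. Chroma is a vector store, so a lookup returns the stored prompt most *similar* to the new one (by embedding distance), not only an exact match.
- `wrapper_fn`: wraps the OpenAI chat endpoint. Under high load it answers from the cache; otherwise it makes a live call and stores the response in the cache.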
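Because a similarity search always returns the *closest* stored prompt, even an unrelated question gets an answer unless you apply a cutoff. The toy below is a hypothetical stand-in for Chroma, not how Chroma works internally: it uses bag-of-words counts and cosine similarity instead of a learned embedding model, and its `min_similarity` threshold is an assumption that the notebook's wrapper does not have.

```python
import math
from collections import Counter

def embed(text):
  # Toy "embedding": lowercase word counts (Chroma uses a learned embedding model)
  return Counter(text.lower().split())

def cosine(a, b):
  dot = sum(a[w] * b[w] for w in a)  # Counter returns 0 for missing words
  norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
  return dot / norm if norm else 0.0

class ToySemanticCache:
  def __init__(self, min_similarity=0.5):
    self.entries = []  # list of (embedding, response) pairs
    self.min_similarity = min_similarity  # hypothetical cutoff for counting a hit

  def add(self, prompt, response):
    self.entries.append((embed(prompt), response))

  def query(self, prompt):
    """Return the response for the most similar stored prompt, or None if none is close enough."""
    if not self.entries:
      return None
    q = embed(prompt)
    best_vec, best_response = max(self.entries, key=lambda e: cosine(q, e[0]))
    return best_response if cosine(q, best_vec) >= self.min_similarity else None

toy_cache = ToySemanticCache()
toy_cache.add("Hello world", "Hi there!")
print(toy_cache.query("hello world"))                     # Hi there!
print(toy_cache.query("What is the capital of France?"))  # None
```

Without the threshold, the second query would also return "Hi there!", which is what a top-1 lookup with no distance check does.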

```python
from threading import active_count  # number of Python threads currently alive
import traceback
from uuid import uuid4

import chromadb

# In-memory cache! get_or_create_collection lets this cell be re-run safely
chroma_client = chromadb.Client()
cache = chroma_client.get_or_create_collection(name="cache")

def prompt_text(messages):
  # Join all message contents into one string; this is what we store and search
  return "\n".join(message["content"] for message in messages)

# Our wrapper function!
def wrapper_fn(fn, max_threads=1):
  def wrapped_fn(*args, **kwargs):
    try:
      # Note: active_count() counts every live thread in the process, not just request handlers
      thread_utilization = active_count() / max_threads
      # High load (utilization > 85%): try to answer from the cache first
      if thread_utilization > 0.85 and "messages" in kwargs:
        try:
          if cache.count() > 0:
            result = cache.query(
              query_texts=[prompt_text(kwargs["messages"])],
              n_results=1,  # nearest stored prompt only
            )
            # Return the cached response text, the same type a live call returns
            return result["metadatas"][0][0]["result"]
        except Exception:
          pass  # cache lookup failed; fall back to a live call
      result = fn(*args, **kwargs)
      response_text = result.choices[0].message.content
      if "messages" in kwargs:
        # Store the prompt with its response so later requests can reuse it
        cache.add(
          documents=[prompt_text(kwargs["messages"])],
          metadatas=[{"result": response_text}],
          ids=[str(uuid4())],
        )
      return response_text
    except Exception:
      traceback.print_exc()
      return None  # the caller receives None if the live call fails
  return wrapped_fn
```
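`active_count()` always includes at least the main thread, so with `max_threads=1` the utilization is never below 1.0 and the cache path is tried on every call once the cache is non-empty. A more targeted load signal is the number of calls currently inside the wrapper. The sketch below is an alternative we are swapping in, not what the notebook does; `InFlightCounter` and `capacity` are hypothetical names.

```python
import threading

class InFlightCounter:
  """Track how many wrapped calls are running right now (hypothetical helper)."""

  def __init__(self, capacity):
    self.capacity = capacity  # number of concurrent calls we treat as "full"
    self.in_flight = 0
    self.lock = threading.Lock()

  def __enter__(self):
    with self.lock:
      self.in_flight += 1
      return self.in_flight / self.capacity  # utilization, counting this call

  def __exit__(self, *exc):
    with self.lock:
      self.in_flight -= 1
    return False  # don't swallow exceptions from the wrapped call

counter = InFlightCounter(capacity=4)
with counter as utilization:
  print(utilization)    # 0.25: only this call is in flight
print(counter.in_flight)  # 0
```

Inside `wrapped_fn`, you could wrap the body in `with counter as utilization:` and test `utilization > 0.85` instead of `active_count() / max_threads`.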

**Key components:**
- `openai.ChatCompletion.create`: the endpoint we're wrapping.

```python
# Wrapping the OpenAI endpoint! Keep the original reference in case we need to restore it
original_fn = openai.ChatCompletion.create
openai.ChatCompletion.create = wrapper_fn(original_fn)

# Test OpenAI call: the wrapped function returns the response text (a string)
chat_completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])
print(chat_completion)
```
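Running the wrapping cell a second time wraps the already-wrapped function again. One way to make it idempotent is to build wrappers with `functools.wraps`, which records the wrapped function as `__wrapped__`, and to unwrap before re-wrapping. In this sketch `install`, `passthrough_wrapper`, and `stub_create` are hypothetical names, and a `types.SimpleNamespace` stands in for `openai.ChatCompletion` so it runs without the OpenAI package. The notebook's `wrapper_fn` does not use `functools.wraps`, so it would need that added for this check to work.

```python
import functools
import types

def passthrough_wrapper(fn):
  # Hypothetical wrapper; functools.wraps sets wrapped.__wrapped__ = fn
  @functools.wraps(fn)
  def wrapped(*args, **kwargs):
    return fn(*args, **kwargs)
  return wrapped

def install(namespace, attr, wrapper):
  """Wrap namespace.attr exactly once, even if called repeatedly (e.g. a re-run cell)."""
  current = getattr(namespace, attr)
  unwrapped = getattr(current, "__wrapped__", current)  # undo a previous install
  setattr(namespace, attr, wrapper(unwrapped))
  return unwrapped  # the original, un-wrapped function

def stub_create(**kwargs):
  return "response"

ChatCompletion = types.SimpleNamespace(create=stub_create)  # stand-in for openai.ChatCompletion
saved_fn = install(ChatCompletion, "create", passthrough_wrapper)
install(ChatCompletion, "create", passthrough_wrapper)  # simulate re-running the cell
print(ChatCompletion.create.__wrapped__ is saved_fn)    # True: still wrapped only once
```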

### Create our Flask Server

Let's make our Flask server. We'll run it with the Waitress WSGI server, which handles request queuing and threading for us; `threads=1` gives it a single worker thread, so the server is easy to saturate.

```python
# Creating our Flask server!
import threading

import openai
from flask import Flask
from waitress import serve

app = Flask(__name__)

@app.route("/test_func")
def test_fn():
  # Calls our wrapped (cached) endpoint and returns the response text
  result = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])
  return result

# Run the app with Waitress. We wrap this in a function so it can run in a
# background thread; otherwise serve() would block the Jupyter notebook cell.
def run_app():
  serve(app, host="0.0.0.0", port=3000, threads=1)

# Start the server in a background thread (daemon=True so it won't keep the process alive)
flask_thread = threading.Thread(target=run_app, daemon=True)
flask_thread.start()
```
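With `threads=1`, Waitress has a single worker thread, so requests that arrive together are handled one at a time rather than in parallel. As a stdlib-only analogy (not Waitress's own code), a `ThreadPoolExecutor` with one worker shows the same effect: no two simulated handlers ever overlap. `handle_request` is a hypothetical stand-in for a view function.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor

active_handlers = 0
peak_handlers = 0
lock = threading.Lock()

def handle_request(i):
  # Simulated request handler that records how many handlers overlap
  global active_handlers, peak_handlers
  with lock:
    active_handlers += 1
    peak_handlers = max(peak_handlers, active_handlers)
  time.sleep(0.01)  # pretend to do some work
  with lock:
    active_handlers -= 1
  return i

with ThreadPoolExecutor(max_workers=1) as pool:  # analogous to serve(..., threads=1)
  handled = list(pool.map(handle_request, range(5)))

print(handled)        # [0, 1, 2, 3, 4]
print(peak_handlers)  # 1: requests waited in the queue instead of running together
```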

## 🚀 Test with traffic!

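Let's send 100 simultaneous requests to our `/test_func` endpoint and see what happens. Note that `app.test_client()` calls the Flask app directly in this process, so these requests bypass the Waitress server on port 3000. The 100 client threads still count toward `active_count()`, which is the load signal our wrapper uses.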

```python
# Making 100 calls at once!
import concurrent.futures

# Send one request to /test_func through Flask's in-process test client
def call_test_fn():
  with app.test_client() as client:
    response = client.get("/test_func")
    return response.data

# Fire 100 concurrent requests at the Flask endpoint (one worker thread per request)
with concurrent.futures.ThreadPoolExecutor(max_workers=100) as executor:
  # Submit the requests and gather the future objects
  futures = [executor.submit(call_test_fn) for _ in range(100)]

  # Wait for all futures to complete
  concurrent.futures.wait(futures)

  # Retrieve the results
  results = [future.result() for future in futures]

for result in results:
  print(result)
```