# Using a hot (in-memory) cache plus a hosted cache to prevent dropped customer requests in production LLM apps

# Create a class to manage your caching logic

In [None]:
import functools
import json
import threading
import traceback
from threading import active_count
from typing import Any

import requests
from flask import Flask, request
from fuzzywuzzy import process

class reliableCache:
    """Two-tier response cache that keeps a Flask LLM endpoint answering.

    Each request is first checked against a small in-process ("hot") cache.
    On a hot-cache miss the wrapped view runs and its result is uploaded to a
    hosted cache service. If the view raises, the hosted cache is asked for a
    stored response (matching is presumably done by the service using
    ``similarity_threshold`` - confirm against the service); a hit is returned
    and remembered in the hot cache, a miss re-raises the original error.

    Args:
        query_arg: Query-string argument holding the customer question.
            ``None`` means ``"query"``.
        customer_instance_arg: Query-string argument naming the instance the
            question/response pair belongs to.
        user_email: Email sent with every hosted-cache call and used in
            hot-cache keys.
        similarity_threshold: Sent as ``threshold`` to the hosted ``get_cache``
            endpoint.
        max_threads: Denominator for :meth:`get_wrapper_thread_utilization`.
        verbose: When True, :meth:`print_verbose` messages are printed.
        hot_cache_threshold: A stored hot-cache query matches only when its
            fuzzywuzzy score is strictly greater than this value (0-100).
        customer_id: ``customer_id`` sent to the hosted cache.
        cache_endpoint: Base URL of the hosted cache service.
        request_timeout: Seconds to wait on the hosted cache per call
            (``None`` waits forever).
    """

    def __init__(
        self,
        query_arg=None,
        customer_instance_arg=None,
        user_email=None,
        similarity_threshold=0.65,
        max_threads=100,
        verbose=False,
        hot_cache_threshold=70,
        customer_id="temp5@xyz.com",
        cache_endpoint="YOUR_HOSTED_CACHE_ENDPOINT",
        request_timeout=10,
    ) -> None:
        self.max_threads = max_threads
        self.verbose = verbose
        self.query_arg = query_arg
        self.customer_instance_arg = customer_instance_arg
        self.user_email = user_email
        self.threshold = similarity_threshold
        self.hot_cache_threshold = hot_cache_threshold
        self.customer_id = customer_id
        self.cache_endpoint = cache_endpoint
        self.request_timeout = request_timeout
        # Values are "active" flags read by get_wrapper_thread_utilization;
        # nothing in this class populates it.
        self.cache_wrapper_threads = {}
        # Two kinds of keys share this dict:
        #   (user_email, instance_id)        -> list of queries seen
        #   (user_email, instance_id, query) -> cached response for that query
        self.hot_cache = {}

    def print_verbose(self, print_statement):
        """Print ``print_statement`` with a fixed prefix when ``verbose`` is on."""
        if self.verbose:
            print("Cached Request: " + str(print_statement))

    def add_cache(self, user_email, instance_id, input_prompt, response):
        """Best-effort upload of a prompt/response pair to the hosted cache.

        Any failure (unserialisable response, network error, timeout, bad
        endpoint) is reported through :meth:`print_verbose` and otherwise
        ignored, so caching never breaks the request being served.
        """
        try:
            self.print_verbose(f"result being stored in cache: {response}")
            querystring = {
                "customer_id": self.customer_id,
                "instance_id": instance_id,
                "user_email": user_email,
                "input_prompt": input_prompt,
                # json.dumps stays inside the try: non-JSON results must not raise.
                "response": json.dumps({"response": response}),
            }
            requests.post(
                f"{self.cache_endpoint}/add_cache",
                params=querystring,
                timeout=self.request_timeout,
            )
        except Exception as err:
            self.print_verbose(f"failed to store result in cache: {err!r}")

    def try_cache_request(self, user_email, instance_id, query=None):
        """Ask the hosted cache for a stored response to ``query``.

        Returns:
            The ``"response"`` field of the service's JSON reply, or ``None``
            when the call fails, times out, or the reply lacks that field.
        """
        querystring = {
            "customer_id": self.customer_id,
            "instance_id": instance_id,
            "user_email": user_email,
            "input_prompt": query,
            "threshold": self.threshold,
        }
        try:
            response = requests.get(
                f"{self.cache_endpoint}/get_cache",
                params=querystring,
                timeout=self.request_timeout,
            )
            self.print_verbose(f"response: {response.text}")
            payload = response.json()
            extracted_result = payload["response"]
        except Exception as err:
            self.print_verbose(f"cache lookup failed: {err!r}")
        else:
            # Verbose-only: this dumps customer data, so keep it out of default logs.
            self.print_verbose(f"extracted_result: {extracted_result} \n\n original response: {payload}")
            return extracted_result
        self.print_verbose("cache miss!")
        return None

    def _lookup_hot_cache(self, instance_id, query):
        """Return the hot-cached response for the closest stored query, or None.

        A stored query matches only if its fuzzywuzzy score against ``query``
        is strictly greater than ``hot_cache_threshold``. A missing query or
        any matching error counts as a miss.
        """
        if query is None:
            return None
        choices = self.hot_cache.get((self.user_email, instance_id))
        if not choices:
            return None
        try:
            best = process.extractOne(query, choices)
        except Exception as err:
            self.print_verbose(f"hot cache lookup failed: {err!r}")
            return None
        if best is None or best[1] <= self.hot_cache_threshold:
            return None
        return self.hot_cache.get((self.user_email, instance_id, best[0]))

    def _store_hot_cache(self, instance_id, query, result):
        """Remember ``result`` for ``query`` in the hot cache (skips ``None`` queries)."""
        if query is None:
            return
        self.hot_cache[(self.user_email, instance_id, query)] = result
        seen = self.hot_cache.setdefault((self.user_email, instance_id), [])
        if query not in seen:
            seen.append(query)

    def cache_wrapper(self, func):
        """Decorate a Flask view so it is cached and falls back to the cache on errors.

        Per request the wrapper:

        1. returns a hot-cache response if a stored query for this
           ``(user_email, instance_id)`` scores above ``hot_cache_threshold``;
        2. otherwise calls ``func``, uploads the result to the hosted cache
           (synchronously, bounded by ``request_timeout``) and returns it;
        3. if ``func`` raises, returns a hosted-cache response when one exists
           (remembering it in the hot cache), else re-raises the error.

        ``functools.wraps`` keeps the view's ``__name__``, so Flask derives a
        distinct endpoint name for every decorated view.
        """
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # The customer question and the instance it belongs to come from the query string.
            query = request.args.get(self.query_arg or "query")
            instance_id = request.args.get(self.customer_instance_arg)

            hot_result = self._lookup_hot_cache(instance_id, query)
            if hot_result is not None:
                return hot_result

            try:
                result = func(*args, **kwargs)
            except Exception:
                cache_result = self.try_cache_request(user_email=self.user_email, instance_id=instance_id, query=query)
                if not cache_result:
                    print("Cache miss!")
                    raise
                print("cache hit!")
                self._store_hot_cache(instance_id, query, cache_result)
                return cache_result

            self.add_cache(self.user_email, instance_id=instance_id, input_prompt=query, response=result)
            self.print_verbose(f"final result: {result}")
            return result
        return wrapper

    def get_wrapper_thread_utilization(self):
        """Return the fraction of ``max_threads`` marked active in ``cache_wrapper_threads``.

        A thread counts as active when its value compares equal to ``True``.
        Returns 0.0 when ``max_threads`` is not positive instead of dividing by zero.
        """
        self.print_verbose(f"cache wrapper thread values: {self.cache_wrapper_threads.values()}")
        active_cache_threads = sum(
            1 for active in self.cache_wrapper_threads.values() if active == True  # noqa: E712
        )
        self.print_verbose(f"active_cache_threads: {active_cache_threads}")
        if self.max_threads <= 0:
            return 0.0
        return active_cache_threads / self.max_threads

# Wrap your query endpoint with the cache decorator

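The caching class exposes a decorator, `cache_wrapper`, which wraps the `berri_query` endpoint. If the endpoint raises an exception, the decorator catches it and returns a response from the hosted cache instead of dropping the request. Before calling the endpoint at all, each request is checked against a local (hot) in-memory cache, which reduces dropped requests during high-traffic periods. The snippet below assumes a Flask `app` and a `reliableCache` instance named `cache` already exist, e.g. `app = Flask(__name__)` and `cache = reliableCache(user_email=..., customer_instance_arg="instance_id")`.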

In [None]:
@app.route("/berri_query")
@cache.cache_wrapper
def berri_query():
    """Handle a customer query; ``cache.cache_wrapper`` provides the cache fallback.

    Placeholder view: put your execution logic below. As written it returns
    None.
    """
    print('Request receieved: ', request)
    # TODO: your execution logic goes here.
    # NOTE: Flask rejects a None return value, so the real implementation
    # must return a response.
    return None