## 04.03 Setting up the Milvus Cache

In [1]:
#Setup database & collection
from pymilvus import connections
from pymilvus import db,Collection

from pymilvus import utility

#Names for connections, database and collections
conn_name = "cache_conn"
db_name="cache_db"
collection_name="llm_cache"

#Create a connection to Milvus
connections.add_connection(
    cache_conn={
        "host": "localhost",
        "port": "19530",
        "username" : "username",
        "password" : "password"
    })


#Connect
connections.connect(conn_name)

#Create a DB if not already present
current_dbs=db.list_database(using=conn_name)

if ( db_name not in current_dbs):
    print("Creating database :", db_name)
    resume_db = db.create_database(db_name, using=conn_name) #default db is "default"
else:
    print(db_name, ": Database already exists")

#Switch to the new database
db.using_database(db_name, using=conn_name)

cache_db : Database already exists


In [4]:
#Create a Collection for cache
from pymilvus import CollectionSchema, FieldSchema, DataType, Collection
import json

#Define fields in the cache
#Autogenerated ID field for each entity
cache_id = FieldSchema(
    name="cache_id",
    dtype=DataType.INT64,
    auto_id=True,
    is_primary=True,
    max_length=32)

#Text for the input prompt
prompt_text= FieldSchema(
    name="prompt_text",
    dtype=DataType.VARCHAR,
    max_length=2048)

#Text for the LLM response
response_text= FieldSchema(
    name="response_text",
    dtype=DataType.VARCHAR,
    max_length=2048)

#Embedding for the input prompt
prompt_embedding = FieldSchema(
    name="prompt_embedding",
    dtype=DataType.FLOAT_VECTOR,
    dim=1536 #Define based on embedding used
)

#Define the schema for the cache collection
cache_schema=CollectionSchema(
    fields=[cache_id, prompt_text, response_text, prompt_embedding],
    description="Cache for LLM",
    enable_dynamic_field=True
)

#Create the collection
cache_collection=Collection(
    name=collection_name,
    schema=cache_schema,
    using=conn_name,
    shard_num=2
)

print("Schema : ", cache_collection.schema, "\n")

#Build an index for the prompt embedding field
index_params = {
    "metric_type":"L2",
    "index_type":"IVF_FLAT",
    "params" :{"nlist":1024}
}

cache_collection.create_index(
    field_name="prompt_embedding",
    index_params=index_params
)

#Flush the collection to persist
cache_collection.flush()
#Load the collection in memory
cache_collection.load()

Schema :  {'auto_id': True, 'description': 'Cache for LLM', 'fields': [{'name': 'cache_id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': True}, {'name': 'prompt_text', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 2048}}, {'name': 'response_text', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 2048}}, {'name': 'prompt_embedding', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 1536}}], 'enable_dynamic_field': True} 



## 04.04. Inference Process with caching

In [11]:
from transformers import AutoTokenizer
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
import os
import time

 

#If you use the free tier, you may hit rate limits with the number of requests

OPENAI_API_KEY="sk-proj-Djbj3MLddNAqs97W1CbiT3BlbkFJzR1ZdzwqNUYvE97Qaiiw"
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

 
embeddings_model = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY, model="text-embedding-ada-002")


#Create an LLM object
#llm= OpenAI(temperature=0., model="text-davinci-003")
llm= OpenAI(temperature=0., model="gpt-3.5-turbo")
#llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, model="gpt-3.5-turbo")

#Setup embedding model for creating embeddings
# embeddings_model = OpenAIEmbeddings()
embeddings_model = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY, model="text-embedding-ada-002")
#setup threshold for similarity between vectors
similarity_threshold=0.3

search_params = {
    "metric_type": "L2", 
    "offset": 0, 
    "ignore_growing": False, 
    "params": {"nprobe": 20, "radius":similarity_threshold}
}

#create a function to run the inference loop
def get_response(prompt):
    
    start_time=time.time()
    #create embedding for incoming prompt
    prompt_embed=embeddings_model.embed_query(prompt)
    
    #Check cache if result exists
    cache_results=cache_collection.search(
        data=[prompt_embed],
        anns_field="prompt_embedding",
        param=search_params,
        limit=1, #Look for the top result only
        expr=None,
        output_fields=["prompt_text", "response_text"],
        consistency_level="Strong"
    )
        
    returned_response ="None"
    
    if ( len(cache_results[0]) > 0 ):
        
        #Cache hit
        print(prompt, " :\n Cache hit : ",cache_results[0])
        returned_response = cache_results[0][0].entity.get("response_text")
    
    else:
        #Find answer with LLM
        llm_response=llm(prompt)
        print(prompt, ":\n LLM returned :", llm_response)
        returned_response = llm_response
        
        #save prompt/response to cache
        prompt_text = [prompt]
        prompt_embedding=[prompt_embed]
        response_text = [llm_response]

        insert_data=[prompt_text, response_text, prompt_embedding]
        mr=cache_collection.insert(insert_data)
    
    end_time = time.time()
    print("Time elapsed :",  end_time - start_time, "\n")
    return returned_response
    

In [12]:
#Build up the cache
response=get_response("In which year was Abraham Lincoln born?")
response=get_response("What is distance between the sun and the moon?")
response=get_response("How many years have Lebron James played in the NBA?")
response=get_response("What are the advantages of the python language?")
response=get_response("What is the typical height of an elephant")


NotFoundError: Error code: 404 - {'error': {'message': 'This is a chat model and not supported in the v1/completions endpoint. Did you mean to use v1/chat/completions?', 'type': 'invalid_request_error', 'param': 'model', 'code': None}}

In [13]:
response=get_response("List some advantages of the python language")
response=get_response("How tall is an elephant?")

NotFoundError: Error code: 404 - {'error': {'message': 'This is a chat model and not supported in the v1/completions endpoint. Did you mean to use v1/chat/completions?', 'type': 'invalid_request_error', 'param': 'model', 'code': None}}