In [None]:
import os
import pickle
import tiktoken
from code_library import queries, functions
from utils.embeddings_utils import get_embedding

# --- Embedding configuration -------------------------------------------------
# OpenAI embedding model used when embedding the stored queries and search text.
embedding_model = "text-embedding-ada-002"
# tiktoken encoding that corresponds to text-embedding-ada-002.
embedding_encoding = "cl100k_base"
# Per-input token budget, kept slightly below the model's 8191-token limit.
max_tokens = 8000

# Tokenizer instance for the encoding configured above.
encoding = tiktoken.get_encoding(embedding_encoding)

def _truncate_to_token_budget(text):
    """Return ``text`` unchanged if it fits in ``max_tokens`` tokens, else its first ``max_tokens`` tokens decoded back to text."""
    # disallowed_special=() makes literal special-token strings (e.g. "<|endoftext|>")
    # count as ordinary text instead of raising inside tiktoken.
    tokens = encoding.encode(text, disallowed_special=())
    if len(tokens) <= max_tokens:
        return text
    return encoding.decode(tokens[:max_tokens])


def generate_embeddings(embeddings_file, queries):
    """Embed every query with ``embedding_model`` and pickle the vectors to disk.

    Any existing file at ``embeddings_file`` is overwritten.

    Args:
        embeddings_file: Path of the pickle file the list of embeddings is written to.
        queries: Iterable of query strings to embed. Inside this function the
            name shadows the module-level ``queries`` imported from ``code_library``.

    Returns:
        list: One embedding per query, in the same order as ``queries``.
    """
    queries = list(queries)

    # Apply the configured token budget so an over-long query is truncated rather
    # than sent to the embedding API past the model's input limit.
    texts = [_truncate_to_token_budget(query) for query in queries]
    embeddings = [get_embedding(text, model=embedding_model) for text in texts]

    # Sanity-check before persisting, so an inconsistent result never lands on disk.
    assert len(embeddings) == len(queries), "expected exactly one embedding per query"

    with open(embeddings_file, "wb") as f:
        pickle.dump(embeddings, f)

    return embeddings

In [None]:
import qdrant_client
from qdrant_client.http import models

# Connect to a Qdrant server on localhost; prefer_grpc=True asks the client to
# use the gRPC transport where it can.
client = qdrant_client.QdrantClient(host="localhost", prefer_grpc=True)

In [None]:
def generate_qdrant_collection(collection_name, embeddings, queries, functions):
    """(Re)build a Qdrant collection holding one point per query.

    Point ``i`` gets id ``i``, the named vector ``"query"`` set to ``embeddings[i]``,
    and the payload ``{"function": functions[i]}``. An existing collection with the
    same name is dropped and recreated.

    Args:
        collection_name: Name of the Qdrant collection to (re)create.
        embeddings: Sequence of embedding vectors, one per query.
        queries: Sequence of query strings; only its length is used here.
        functions: Sequence of payload values aligned index-by-index with ``queries``.

    Raises:
        ValueError: If ``embeddings`` is empty or the three sequences differ in length.
        Exception: Whatever the client raises when the Qdrant server is unreachable
            (re-raised after printing "Qdrant not connected.").
    """
    # Validate before touching the server so an existing collection is not dropped
    # and then left half-built because of mismatched inputs.
    if len(embeddings) == 0:
        raise ValueError("embeddings must not be empty")
    if not (len(embeddings) == len(queries) == len(functions)):
        raise ValueError(
            f"length mismatch: {len(embeddings)} embeddings, "
            f"{len(queries)} queries, {len(functions)} functions"
        )

    # Check connectivity first: every call below needs the server, so a failure
    # should stop here rather than be reported after the fact.
    try:
        client.get_collections()
    except Exception:
        print("Qdrant not connected.")
        raise
    print("Qdrant connection established.")

    # recreate_collection deletes any existing collection of this name before
    # creating it, so no separate delete_collection call is needed.
    client.recreate_collection(
        collection_name=collection_name,
        vectors_config={
            # Vector size comes from the data rather than a hard-coded 1536, so a
            # different embedding model does not silently break the collection.
            "query": models.VectorParams(
                size=len(embeddings[0]),
                distance=models.Distance.COSINE,
            ),
        },
    )

    client.upsert(
        collection_name=collection_name,
        points=models.Batch(
            ids=list(range(len(queries))),
            vectors={"query": embeddings},
            payloads=[{"function": function} for function in functions],
        ),
    )

    # count() returns a CountResult; print the number itself, not its repr.
    point_count = client.count(collection_name=collection_name).count
    print(f"Collection size for {collection_name}: {point_count}")

In [None]:
# Build the "TradeSweep" index: embed every query (the vectors are also written to
# ./embeddings_tradesweep.pkl), then load vectors and function payloads into Qdrant.
embeddings_tradesweep = generate_embeddings(
    embeddings_file="./embeddings_tradesweep.pkl",
    queries=queries,
)
generate_qdrant_collection(
    collection_name="TradeSweep",
    embeddings=embeddings_tradesweep,
    queries=queries,
    functions=functions,
)

In [None]:
# search data: configure OpenAI credentials for the search below
import openai

# Never hard-code the key in the notebook (it leaks through version control and
# shared copies); read it from the environment and fail early with a clear message.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError("Set the OPENAI_API_KEY environment variable before running this notebook.")
openai.api_key = os.environ["OPENAI_API_KEY"]

def query_qdrant(query, collection_name='prompt-embeddings', vector_name='query', top_k=3):
    """Return the ``top_k`` nearest points to ``query`` in a Qdrant collection.

    Args:
        query: Natural-language text to search for.
        collection_name: Collection to search. NOTE: the default
            ('prompt-embeddings') is not the "TradeSweep" collection built in this
            notebook, so pass the intended name explicitly.
        vector_name: Named vector to compare against; collections built by
            ``generate_qdrant_collection`` store their vectors under "query".
        top_k: Maximum number of hits to return.

    Returns:
        The hits returned by ``client.search`` for the embedded query.
    """
    # Embed with the same helper and model that produced the stored vectors, so
    # query and corpus embeddings come from the same code path.
    embedded_query = get_embedding(query, model=embedding_model)

    return client.search(
        collection_name=collection_name,
        query_vector=(vector_name, embedded_query),
        limit=top_k,
    )

In [None]:
# perform search against the collection built above. query_qdrant's default
# collection ('prompt-embeddings') is not the one this notebook creates.
query_results = query_qdrant(
    query="Clean dates to have yyyy-mm-dd format.",
    collection_name="TradeSweep",
    top_k=3,
)

# display results in the order Qdrant returned them, with their similarity scores
for rank, hit in enumerate(query_results):
    print(f"Function {rank} (score={hit.score:.4f}):\n\t{hit.payload['function']}")