In [2]:
import ollama
import chromadb
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import glob
import sqlparse
import re
import json

# Chroma client and the "docs" collection that the helper functions in this
# notebook use as a notebook-level global.
# get_or_create_collection keeps this cell re-runnable: create_collection
# raises when "docs" already exists in the running kernel.
client = chromadb.Client()
collection = client.get_or_create_collection(name="docs")

In [3]:
# Drop and recreate the "docs" collection so this run starts from an empty
# index (document_embedding adds entries under ids "0", "1", ... on every run).
client.delete_collection(name="docs")
collection = client.create_collection(name="docs")

# Column name(s) to search for; main() splits this string on whitespace.
# NOTE(review): the trailing comment looks like an alternative search value — TODO confirm.
columns = "DSPD" # ADM Check [Handling and Storage.Op.8177]

# NOTE(review): main() and its helpers are defined in later cells, so on a
# fresh kernel this call raises NameError (see the recorded output below).
# Run the definition cells before this one.
main(columns)

NameError: name 'main' is not defined

#### `main`

In [None]:
def main(columns, embedding_model="mxbai-embed-large:latest", llm_model="gemma2:2b"):
    """Find the report whose query best matches the requested columns and print the answer.

    Runs the whole pipeline: load the query workbook, reduce each SQL query to
    a compact text document, embed and index the documents, retrieve the best
    match for ``columns``, score it, and let the LLM phrase the answer.

    Args:
        columns: Whitespace-separated column names, e.g. ``"DSPD"``.
        embedding_model: Ollama model name used for the document embeddings
            and for the query embedding.
        llm_model: Ollama model name passed to generate_answer.

    Returns:
        None. generate_answer prints the response and the similarity score.
    """
    query_frame = read_in_excel_files()
    processed_documents = preprocess_queries(query_frame)

    # The question shown to the LLM lists the columns comma-separated.
    prompt = f"Is there a Report containing Columns '{', '.join(columns.split())}'? If so, provide Report Name"

    # Must run before retrieve(): this call also fills the global collection.
    document_embeddings = document_embedding(processed_documents, embedding_model)

    # The raw column string (not the full question) is what gets embedded
    # and compared against the documents.
    prompt_embeddings = ollama.embeddings(model=embedding_model, prompt=columns)

    retrieval_results = retrieve(prompt_embeddings)
    similarity = calculate_similarity(document_embeddings, prompt_embeddings, retrieval_results)

    generate_answer(llm_model, retrieval_results, prompt, similarity)

##### `read_in_excel_files`

In [None]:
def read_in_excel_files(path="../Query Processing/1_migrated_excel_queries/Queries_3.xlsx"):
    """Load every Excel workbook matching ``path`` into one DataFrame.

    Args:
        path: File path or glob pattern of the workbook(s) to read. Defaults
            to the workbook this notebook was written against.

    Returns:
        One DataFrame holding the rows of all matched workbooks, re-indexed
        from 0.

    Raises:
        FileNotFoundError: If nothing matches ``path``. Without this check the
            failure surfaces as an opaque "No objects to concatenate" error
            from ``pd.concat``.
    """
    # glob gives no ordering guarantee; sort so the row order (and therefore
    # the position-based ids assigned in document_embedding) is reproducible.
    workbook_paths = sorted(glob.glob(path))
    if not workbook_paths:
        raise FileNotFoundError(f"No Excel file matches {path!r}")

    frames = [pd.read_excel(workbook, engine="openpyxl") for workbook in workbook_paths]
    return pd.concat(frames, ignore_index=True)

##### `preprocess_queries`

In [50]:
# SQL keywords, functions and date-format tokens that are stripped from a query
# so that (mostly) identifiers remain.
_SQL_EXPRESSIONS = [
    "SELECT",
    "DISTINCT",
    "FROM",
    "TRIM",
    "BOTH",
    "SUBSTR",
    "COUNT",
    "IN",
    "AS",
    "ORDER BY",
    "GROUP BY",
    "AND",
    "WHERE",
    "CONCAT",
    "CASE",
    "SUM",
    "WHEN",
    "TO_DATE",
    "TO_CHAR",
    "CAST",
    "BETWEEN",
    "TRUNC",
    "OVER",
    "PARTITION BY",
    "TIMESTAMP",
    "THEN",
    "ELSE",
    "NULL",
    "END",
    "INNER JOIN",
    "OUTER JOIN",
    "SQL",
    "CURRENT_DATE",
    "YYYY",
    "MMMM",
    "SSSS",
    "DAYS",
    "CURRENT_YEAR",
    "FIRST_VALUE",
    "AT",
]

# Whole-word match only. Plain substring replacement would cut "IN" out of
# INVOICE or "AT" out of STATION, and would mangle later list entries such as
# CASE / INNER JOIN before they are reached. Longest entries come first so
# multi-word expressions are preferred over their prefixes.
_SQL_EXPRESSION_PATTERN = re.compile(
    r"\b(?:"
    + "|".join(re.escape(expr) for expr in sorted(_SQL_EXPRESSIONS, key=len, reverse=True))
    + r")\b"
)

# Accented letter -> unaccented letter, grouped by replacement.
_ACCENT_FOLDS = {
    "a": "àâä", "c": "ç", "e": "êéèë", "i": "îï", "n": "ñ", "o": "ôö", "u": "ûùü", "y": "ÿ",
    "A": "ÀÂÄ", "C": "Ç", "E": "ÉÈÊË", "I": "ÎÏ", "N": "Ñ", "O": "ÔÖ", "U": "ÛÙÜ", "Y": "Ÿ",
}
_ACCENT_TABLE = str.maketrans(
    {accented: plain for plain, group in _ACCENT_FOLDS.items() for accented in group}
)

# Leading `WITH "name" AS ( ... ) SELECT` clause, collapsed to a bare SELECT.
_WITH_CLAUSE = re.compile(r'WITH\s+"\w+"\s+AS\s*\(.*?\)\s*SELECT', flags=re.DOTALL)
# Qualifier prefixes in front of a column: `"schema".` and `alias.`
_QUOTED_QUALIFIER = re.compile(r'"[^"]*"\.')
_BARE_QUALIFIER = re.compile(r'\w+\.')


def _collapse_whitespace(text):
    """Replace every run of whitespace with a single space and trim the ends."""
    return " ".join(text.split())


def _strip_sql_to_identifiers(query):
    """Reduce one SQL string to the words left after removing syntax and keywords."""
    text = query.upper()
    text = sqlparse.format(text, reindent=True, keyword_case='upper', strip_comments=True).strip()
    text = text.replace("'", '"').replace("\n", " ")
    text = _WITH_CLAUSE.sub('SELECT', text)
    text = _QUOTED_QUALIFIER.sub('', text)
    text = _BARE_QUALIFIER.sub('', text)
    text = _collapse_whitespace(text)
    # `"X" AS "X"` -> `"X"`, then drop a quoted alias that directly follows a quoted name.
    text = re.sub(r'"(\w+)" AS "\1"', r'"\1"', text)
    text = re.sub(r'("\w+")\s+"\w+"', r'\1', text)
    # Punctuation, operators and digits become spaces; so do stand-alone single letters.
    text = re.sub(r'[",\'\-"\)\(=<>\:/\*\d]', " ", text)
    text = re.sub(r'\b[a-zA-Z]\b', " ", text)
    return _SQL_EXPRESSION_PATTERN.sub(" ", text)


def preprocess_queries(query_df):
    """Turn each report's SQL query into a short text document for embedding.

    Rows without SQL are dropped and remaining missing values become empty
    strings. Accented letters are folded to their unaccented form, the SQL is
    upper-cased and formatted with sqlparse, and then qualifiers, punctuation,
    digits, single letters and the expressions in ``_SQL_EXPRESSIONS`` are
    removed.

    Differences from the previous implementation:
      * Each cleaned query stays paired with its own row. Previously positional
        indices were used as labels after ``dropna``, which misaligned rows
        whenever a row had been dropped.
      * SQL expressions are removed only as whole words, so identifiers that
        merely contain a keyword (INVOICE, STATION, BRAND, ...) stay intact.
      * ``Ê``, ``ë`` and ``Ë`` are folded like the other accented letters.
      * A frame with no usable rows returns an empty list.

    Args:
        query_df: DataFrame with at least the columns ``"SQL"`` and
            ``"Report Name"``. It is not modified.

    Returns:
        A list of strings of the form
        ``"Report Name: <name>, Columns: <cleaned sql words>"``, one per row
        that has SQL, in row order.

    Raises:
        KeyError: If ``"SQL"`` or ``"Report Name"`` is missing.
    """
    usable_rows = query_df.dropna(subset=["SQL"]).fillna("")
    if usable_rows.empty:
        return []

    sql_texts = usable_rows["SQL"].str.translate(_ACCENT_TABLE)
    column_texts = [_strip_sql_to_identifiers(sql) for sql in sql_texts]

    documents = []
    # zip pairs values by position, so gaps in the index left by dropna are irrelevant.
    for report_name, column_text in zip(usable_rows["Report Name"], column_texts):
        name_part = _collapse_whitespace(f"Report Name: {report_name}")
        columns_part = _collapse_whitespace(f"Columns: {column_text}")
        documents.append(f"{name_part}, {columns_part}")

    return documents

##### `document_embedding`

In [8]:
def document_embedding(documents, EMBEDDING_MODEL):
    """Embed every document with Ollama and add it to the global ``collection``.

    Each document is stored under its position in ``documents`` (as a string
    id) together with the ``"embeddings"`` value of the Ollama response.

    Args:
        documents: Iterable of document strings.
        EMBEDDING_MODEL: Name of the Ollama embedding model.

    Returns:
        A one-element tuple whose only item is the list of per-document
        ``response["embeddings"]`` values, in input order.
    """
    per_document = []

    for position, text in enumerate(documents):
        vectors = ollama.embed(model=EMBEDDING_MODEL, input=text)["embeddings"]
        per_document.append(vectors)
        collection.add(ids=[str(position)], embeddings=vectors, documents=[text])

    # Kept as a 1-tuple: calculate_similarity reads the list via document_embeddings[0].
    return (per_document,)

##### `retrieve`

In [9]:
def retrieve(embedded_prompt):
    """Query the global ``collection`` for the top match of an embedded prompt.

    Args:
        embedded_prompt: Mapping with the query vector under ``"embedding"``.

    Returns:
        Whatever ``collection.query`` returns for a single requested result.
    """
    query_vector = embedded_prompt["embedding"]
    return collection.query(query_embeddings=query_vector, n_results=1)

##### `calculate_similarity`

In [10]:
def calculate_similarity(document_embeddings, embedded_prompt, retrieval_results):
    """Cosine similarity between the retrieved document and the embedded prompt.

    Args:
        document_embeddings: Value returned by document_embedding; the list of
            stored embeddings is read from index 0.
        embedded_prompt: Mapping with the prompt vector under ``"embedding"``.
        retrieval_results: Query result; ``["ids"][0][0]`` is parsed as the
            integer position of the retrieved document.

    Returns:
        The similarity as a single scalar.
    """
    stored_embeddings = document_embeddings[0]
    retrieved_position = int(retrieval_results["ids"][0][0])

    # Both vectors are flattened into a single row before comparison.
    document_vector = np.array(stored_embeddings[retrieved_position]).reshape(1, -1)
    prompt_vector = np.array(embedded_prompt["embedding"]).reshape(1, -1)

    return cosine_similarity(document_vector, prompt_vector).flatten()[0]

##### `generate_answer`

In [11]:
def generate_answer(LLM_MODEL, retrieval_results, PROMPT, similarity):
    """Ask the LLM to answer ``PROMPT`` from the retrieved document and print the result.

    Args:
        LLM_MODEL: Name of the Ollama model used for generation.
        retrieval_results: Query result; the document text is read from
            ``["documents"][0][0]``.
        PROMPT: The question to answer.
        similarity: Similarity score in fractional form; printed as a
            percentage rounded to two decimals.

    Returns:
        None. The model response and the similarity are printed.
    """
    # Looked up outside the f-string: reusing double quotes inside a
    # double-quoted f-string is a SyntaxError before Python 3.12.
    retrieved_document = retrieval_results["documents"][0][0]

    output = ollama.generate(
        model=LLM_MODEL,
        prompt=f"Using this data: {retrieved_document}. Respond to this prompt: {PROMPT}",
    )

    print(output["response"])
    print(f"Similarity: {round(similarity * 100, 2)}%")