In [30]:
from uuid import UUID
from pathlib import Path
import tiktoken
import os
import logging

from langchain_community.chat_models import ChatLiteLLM
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_elasticsearch import ApproxRetrievalStrategy, ElasticsearchStore
from elasticsearch import Elasticsearch

from redbox.models import Settings
from redbox.models.settings import ElasticLocalSettings
from redbox.storage import ElasticsearchStorageHandler

from core_api.src.callbacks import LoggerCallbackHandler

from dotenv import find_dotenv, load_dotenv

# Repository root — assumes the kernel's working directory is one level below it.
ROOT = Path().resolve().parent
ENV_PATH = ROOT / '.env'

# Export .env values into the process environment before anything reads them.
_ = load_dotenv(find_dotenv(ENV_PATH))

logging.basicConfig(level=logging.INFO)
log = logging.getLogger()

# Local-development settings: MinIO object store and Elasticsearch on localhost.
env = Settings(
    _env_file=ENV_PATH,
    minio_host="localhost",
    object_store="minio",
    elastic=ElasticLocalSettings(host="localhost"),
)

# Model weights are cached under <repo root>/models (the same directory as "../models/"
# when run from one level below the root, but unaffected by any later chdir).
embedding_model = SentenceTransformerEmbeddings(
    model_name=env.embedding_model, cache_folder=str(ROOT / "models")
)

# Connect using the host/port/scheme/credentials from settings.
es = Elasticsearch(
    hosts=[
        {
            "host": env.elastic.host,
            "port": env.elastic.port,
            "scheme": env.elastic.scheme,
        }
    ],
    basic_auth=(env.elastic.user, env.elastic.password),
)

# Hybrid (keyword + vector) search is only enabled on platinum/enterprise.
# NOTE(review): presumably because Elasticsearch's rank fusion needs a paid licence — confirm.
if env.elastic.subscription_level == "basic":
    strategy = ApproxRetrievalStrategy(hybrid=False)
elif env.elastic.subscription_level in ("platinum", "enterprise"):
    strategy = ApproxRetrievalStrategy(hybrid=True)
else:
    # Fail fast here rather than with a NameError on `strategy` below.
    raise ValueError(
        f"Unsupported Elastic subscription level: {env.elastic.subscription_level!r}"
    )

# Same "<root index>-chunk" index that get_parameterised_retriever queries.
vector_store = ElasticsearchStore(
    es_connection=es,
    index_name=f"{env.elastic_root_index}-chunk",
    embedding=embedding_model,
    strategy=strategy,
    vector_query_field="embedding",
)

# See core_api.src.dependencies for details on this hack: the Azure API version is
# supplied through the environment (the request logs below show it as `api-version`).
os.environ["AZURE_API_VERSION"] = env.openai_api_version

# Attached to the LLM so every call is logged through the root logger.
logger_callback = LoggerCallbackHandler(logger=log)

llm = ChatLiteLLM(
    model=env.azure_openai_model,
    streaming=True,
    # NOTE(review): confirm ChatLiteLLM honours `azure_key` (its documented field may be
    # `azure_api_key`); the key might otherwise be coming from the environment.
    azure_key=env.azure_openai_api_key,
    api_base=env.azure_openai_endpoint,
    # Must match the reply budget the chains reserve from the context window
    # (env.ai.context_window_size - env.llm_max_tokens), so read it from settings.
    max_tokens=env.llm_max_tokens,
    callbacks=[logger_callback],
)

storage_handler = ElasticsearchStorageHandler(es_client=es, root_index=env.elastic_root_index)

# cl100k_base is the tiktoken encoding for the GPT-4 family used here.
tokeniser = tiktoken.get_encoding("cl100k_base")

INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2
INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: mps


In [21]:
# Smoke-test that INFO records reach the notebook output.
log.log(logging.INFO, "hi")

INFO:root:hi


In [32]:
# Inspect the root logger and its effective level.
logging.getLogger()

<RootLogger root (INFO)>

In [31]:
# Smoke-test the Azure-backed chat model end to end.
llm.invoke(input="hi")

INFO:httpx:HTTP Request: POST https://oai-i-dot-ai-playground-sweden.openai.azure.com//openai/deployments/gpt-4/chat/completions?api-version=2024-02-01 "HTTP/1.1 200 OK"


AIMessage(content='Hello! How can I assist you today?', id='run-af1e6245-01ef-4ccc-9b5d-f9f1c5191ace-0')

In [29]:
# Same smoke test, passing the logging callback per call through the run config.
# NOTE(review): logger_callback is also attached to `llm` at construction; check
# whether it is de-duplicated or logs each event twice.
_ = llm.invoke("hi", config={"callbacks": [logger_callback]})

INFO:httpx:HTTP Request: POST https://oai-i-dot-ai-playground-sweden.openai.azure.com//openai/deployments/gpt-4/chat/completions?api-version=2024-02-01 "HTTP/1.1 200 OK"


# RAG scratch

In [3]:
from core_api.src.retriever import ParameterisedElasticsearchRetriever, get_all_chunks_query
from langchain_core.runnables import ConfigurableField

def get_parameterised_retriever(env, es):
    """Create an Elasticsearch retriever runnable with overridable search parameters.

    The runnable takes a dict keyed by ``question``, ``file_uuids`` and ``user_uuid``
    and returns a list of LangChain ``Document`` objects, each wrapping a stored chunk
    (its ``text`` as ``page_content`` and the raw hit as ``metadata``).

    Search parameters default to ``env.ai.rag_k`` results from
    ``env.ai.rag_num_candidates`` candidates, unit boosts and no similarity threshold.
    They can be replaced per call via the ``params`` configurable field, e.g.
    ``retriever.with_config(configurable={"params": {...}})``.

    Args:
        env: settings providing ``ai.rag_k``, ``ai.rag_num_candidates`` and
            ``elastic_root_index``.
        es: Elasticsearch client used to run the searches.

    Note:
        Embeds queries with the notebook-level ``embedding_model`` global, which must be
        defined before this is called.
    """
    default_params = {
        "size": env.ai.rag_k,
        "num_candidates": env.ai.rag_num_candidates,
        "match_boost": 1,
        "knn_boost": 1,
        "similarity_threshold": 0,
    }
    return ParameterisedElasticsearchRetriever(
        es_client=es,
        index_name=f"{env.elastic_root_index}-chunk",
        params=default_params,
        embedding_model=embedding_model,
        content_field="text",
    ).configurable_fields(
        params=ConfigurableField(
            id="params", name="Retriever parameters", description="A dictionary of parameters to use for the retriever."
        )
    )

# Retriever over the chunk index, using the default search parameters.
retriever = get_parameterised_retriever(env=env, es=es)

In [46]:
# Ad-hoc retrieval check: search two uploaded papers on behalf of one user.
# NOTE(review): later cells use different file UUIDs for the same two papers — confirm
# which ones are current.
kan_paper_uuid = "50a8a0a0-63de-435e-a51a-31cdbae24de2"
mamba_paper_uuid = "db17bb46-8b44-489f-8336-d69455210619"

retriever.invoke(
    input={
        "question": "KAN",
        "file_uuids": [kan_paper_uuid, mamba_paper_uuid],
        "user_uuid": "5c37bf4c-002c-458d-9e68-03042f76a5b1",
    }
)

INFO:elastic_transport.transport:POST http://localhost:9200/redbox-data-chunk/_search [status:200 duration:0.009s]


[Document(page_content='the “Human-constructed KAN shape” in Table 2.\n\n(2) KANs without pruning. We fix the KAN shape to width 5 and depths are swept over {2,3,4,5,6}.\n\n(3) KAN with pruning. We use the sparsification (λ = 10−2 or 10−3) and the pruning technique\n\nfrom Section 2.5.1 to obtain a smaller KAN from a fixed-shape KAN from (2).\n\n(4) MLPs with fixed width 5, depths swept\n\nin {2, 3, 4, 5, 6}, and activations chosen from\n\n{Tanh, ReLU, SiLU}.', metadata={'_index': 'redbox-data-chunk', '_id': '8ab01636-63fa-422a-8b3c-14d064887162', '_score': 9.445025, '_ignored': ['text.keyword'], '_source': {'uuid': '8ab01636-63fa-422a-8b3c-14d064887162', 'created_datetime': '2024-06-26T07:21:20.962466', 'creator_user_uuid': '5c37bf4c-002c-458d-9e68-03042f76a5b1', 'parent_file_uuid': '50a8a0a0-63de-435e-a51a-31cdbae24de2', 'index': 148, 'metadata': {'parent_doc_uuid': '50a8a0a0-63de-435e-a51a-31cdbae24de2', 'languages': ['eng'], 'link_texts': None, 'link_urls': None, 'links': None, 'pa

In [53]:
from langchain_core.runnables import (
    Runnable,
    RunnableLambda,
    RunnablePassthrough,
    chain,
)
from langchain.schema import StrOutputParser
from operator import itemgetter
from redbox.models import ChatRoute
from redbox.models.chain import ChainInput

from core_api.src.format import format_documents
from core_api.src.runnables import make_chat_prompt_from_messages_runnable


def build_retrieval_chain(
    llm,
    retriever,
    tokeniser,
    env,
) -> Runnable:
    """Assemble a single-shot RAG chain: retrieve, format, then answer.

    The input dict is sent to ``retriever``. Its results are stored under
    ``documents`` and rendered with ``format_documents`` into
    ``formatted_documents`` before the retrieval prompt is filled in.

    Returns a runnable whose output dict has ``response`` (the LLM answer as a
    string), ``source_documents`` (the retrieved documents) and ``route_name``
    (always ``ChatRoute.search.value``).
    """
    # Leave env.llm_max_tokens of the context window free for the model's reply.
    prompt_budget = env.ai.context_window_size - env.llm_max_tokens

    answer_chain = (
        make_chat_prompt_from_messages_runnable(
            system_prompt=env.ai.retrieval_system_prompt,
            question_prompt=env.ai.retrieval_question_prompt,
            input_token_budget=prompt_budget,
            tokeniser=tokeniser,
        )
        | llm
        | StrOutputParser()
    )
    render_documents = RunnableLambda(itemgetter("documents")) | format_documents

    with_documents = RunnablePassthrough.assign(documents=retriever)
    with_rendered = RunnablePassthrough.assign(formatted_documents=render_documents)

    return with_documents | with_rendered | {
        "response": answer_chain,
        "source_documents": itemgetter("documents"),
        "route_name": RunnableLambda(lambda _: ChatRoute.search.value),
    }

rag = build_retrieval_chain(llm, retriever, tokeniser, env)

# A follow-up question whose meaning depends on the earlier turn in the history.
prior_turns = [
    {"text": "What is the fastest attention that the authors are aware of?", "role": "user"},
    {"text": "The fastest implementation of attention, according to the authors, is **FlashAttention-2 (Dao 2024)** with a causal mask. It's stated that this version of FlashAttention-2 is approximately **1.7× faster** than the version without a causal mask because roughly half of the attention entries are computed.", "role": "ai"},
]

params = ChainInput(
    question="Give the full citation.",
    file_uuids=[
        "36ed2f1a-57a5-489c-a4cb-fbdd25e2b038",  # KAN paper
        "1a9d18a7-9499-47b6-abcc-4e82370028ee",  # MAMBA paper
    ],
    user_uuid="5c37bf4c-002c-458d-9e68-03042f76a5b1",
    chat_history=prior_turns,
)

rag.invoke(params.model_dump())

INFO:elastic_transport.transport:POST http://localhost:9200/redbox-data-chunk/_search [status:200 duration:0.022s]
INFO:httpx:HTTP Request: POST https://oai-i-dot-ai-playground-sweden.openai.azure.com//openai/deployments/gpt-4/chat/completions?api-version=2024-02-01 "HTTP/1.1 200 OK"


{'response': "I'm sorry, but the document excerpts provided do not contain a full citation for **FlashAttention-2 (Dao 2024)**.",
 'source_documents': [Document(page_content='provide better estimates of the full softmax kernel (rather than just the exp-transformed numerator).', metadata={'_index': 'redbox-data-chunk', '_id': 'adaca12e-c530-4a7d-9281-972c0ad784ef', '_score': 7.465159, '_source': {'uuid': 'adaca12e-c530-4a7d-9281-972c0ad784ef', 'created_datetime': '2024-06-26T07:21:32.070537', 'creator_user_uuid': '5c37bf4c-002c-458d-9e68-03042f76a5b1', 'parent_file_uuid': 'db17bb46-8b44-489f-8336-d69455210619', 'index': 302, 'metadata': {'parent_doc_uuid': 'db17bb46-8b44-489f-8336-d69455210619', 'languages': ['eng'], 'link_texts': None, 'link_urls': None, 'links': None, 'page_number': 27}, 'embedding': [-0.048336658626794815, -0.07596364617347717, 0.006124784238636494, 0.0008513265638612211, 0.100751131772995, -0.005539730191230774, -0.0620366595685482, 0.04322299361228943, 0.0507054962

In [51]:
from langchain_core.runnables import (
    Runnable,
    RunnableLambda,
    RunnablePassthrough,
    chain,
)
from langchain.schema import StrOutputParser
from operator import itemgetter
from redbox.models import ChatRoute
from redbox.models.chain import ChainInput

from core_api.src.format import format_documents
from core_api.src.runnables import make_chat_prompt_from_messages_runnable

# Prompts for rewriting a follow-up question into a standalone one.
# NOTE(review): build_condense_retrieval_chain reads env.ai.condense_* prompts by
# default; these constants are only used by the commented-out experiment below.
CONDENSE_SYSTEM_PROMPT = " ".join(
    [
        "Given the following conversation and a follow up question, generate a follow up question to be a standalone question.",
        "You are only allowed to generate one question in response.",
        "Include sources from the chat history in the standalone question created, when they are available.",
        "If you don't know the answer, just say that you don't know, don't try to make up an answer.",
    ]
) + " \n"

CONDENSE_QUESTION_PROMPT = "{question}\n=========\n Standalone question: "


def build_condense_retrieval_chain(
    llm,
    retriever,
    tokeniser,
    env,
    condense_system_prompt=None,
    condense_question_prompt=None,
) -> Runnable:
    """Build a RAG chain that first rewrites follow-ups into standalone questions.

    If the input dict has a non-empty ``chat_history``, ``question`` is replaced by
    the LLM's condensed, standalone version of it. Otherwise the input passes
    through unchanged. The result is then retrieved against, formatted and
    answered with the retrieval prompts from ``env.ai``.

    Args:
        llm: chat model used for both the condense step and the final answer.
        retriever: runnable mapping the input dict to a list of documents.
        tokeniser: tokeniser used to keep prompts within the token budget.
        env: settings supplying prompts, ``ai.context_window_size`` and
            ``llm_max_tokens``.
        condense_system_prompt: optional override for
            ``env.ai.condense_system_prompt`` (e.g. ``CONDENSE_SYSTEM_PROMPT``).
        condense_question_prompt: optional override for
            ``env.ai.condense_question_prompt`` (e.g. ``CONDENSE_QUESTION_PROMPT``).

    Returns:
        Runnable whose output dict has ``response``, ``source_documents`` and
        ``route_name``.
    """
    if condense_system_prompt is None:
        condense_system_prompt = env.ai.condense_system_prompt
    if condense_question_prompt is None:
        condense_question_prompt = env.ai.condense_question_prompt

    # Leave env.llm_max_tokens of the context window free for the model's reply.
    input_token_budget = env.ai.context_window_size - env.llm_max_tokens

    # Built once here rather than on every invocation inside `route`.
    condense_question = RunnablePassthrough.assign(
        question=make_chat_prompt_from_messages_runnable(
            system_prompt=condense_system_prompt,
            question_prompt=condense_question_prompt,
            input_token_budget=input_token_budget,
            tokeniser=tokeniser,
        )
        | llm
        | StrOutputParser()
    )
    passthrough = RunnablePassthrough()

    def route(input_dict: dict) -> Runnable:
        # A Runnable returned from a RunnableLambda is invoked with the same input,
        # so the condense branch only runs when there is history to condense.
        # `.get` treats a missing or None chat_history as "no history".
        if input_dict.get("chat_history"):
            return condense_question
        return passthrough

    return (
        RunnableLambda(route)
        | RunnablePassthrough.assign(documents=retriever)
        | RunnablePassthrough.assign(
            formatted_documents=(RunnablePassthrough() | itemgetter("documents") | format_documents)
        )
        | {
            "response": make_chat_prompt_from_messages_runnable(
                system_prompt=env.ai.retrieval_system_prompt,
                question_prompt=env.ai.retrieval_question_prompt,
                input_token_budget=input_token_budget,
                tokeniser=tokeniser,
            )
            | llm
            | StrOutputParser(),
            "source_documents": itemgetter("documents"),
            "route_name": RunnableLambda(lambda _: ChatRoute.search.value),
        }
    )

# crag = make_chat_prompt_from_messages_runnable(
#     system_prompt=CONDENSE_SYSTEM_PROMPT,
#     question_prompt=CONDENSE_QUESTION_PROMPT,
#     input_token_budget=env.ai.context_window_size - env.llm_max_tokens,
#     tokeniser=tokeniser,
# ) | llm

crag = build_condense_retrieval_chain(llm=llm, retriever=retriever, tokeniser=tokeniser, env=env)

# With an empty chat_history the chain skips the condense step and retrieves
# using the question exactly as written.
paper_file_uuids = [
    "36ed2f1a-57a5-489c-a4cb-fbdd25e2b038",  # KAN paper
    "1a9d18a7-9499-47b6-abcc-4e82370028ee",  # MAMBA paper
]

params = ChainInput(
    # question="Give the full citation.",
    question="What is the fastest attention that the authors are aware of?",
    file_uuids=paper_file_uuids,
    user_uuid="5c37bf4c-002c-458d-9e68-03042f76a5b1",
    chat_history=[
        # {"text": "What is the fastest attention that the authors are aware of?", "role": "user"},
        # {"text": "The fastest implementation of attention, according to the authors, is **FlashAttention-2 (Dao 2024)** with a causal mask. It's stated that this version of FlashAttention-2 is approximately **1.7× faster** than the version without a causal mask because roughly half of the attention entries are computed.", "role": "ai"},
    ],
)

crag.invoke(params.model_dump())

INFO:elastic_transport.transport:POST http://localhost:9200/redbox-data-chunk/_search [status:200 duration:0.029s]
INFO:httpx:HTTP Request: POST https://oai-i-dot-ai-playground-sweden.openai.azure.com//openai/deployments/gpt-4/chat/completions?api-version=2024-02-01 "HTTP/1.1 200 OK"


{'response': 'The fastest attention implementation the authors are aware of is **FlashAttention-2 (Dao 2024)** with causal mask. This implementation is about 1.7× faster than without a causal mask because approximately only half of the attention entries are computed.',
 'source_documents': [Document(page_content='For attention, we compare against the fastest implementation that we are aware of (FlashAttention-2 (Dao 2024)), with causal mask. Note that FlashAttention-2 with causal mask is about 1.7× faster than without causal mask, since approximately only half of the attention entries are computed.', metadata={'_index': 'redbox-data-chunk', '_id': 'a0b15ad6-d1e1-4ca9-82b6-aab56b080732', '_score': 22.066563, '_ignored': ['text.keyword'], '_source': {'metadata': {'parent_file_uuid': '1a9d18a7-9499-47b6-abcc-4e82370028ee', 'creator_user_uuid': '5c37bf4c-002c-458d-9e68-03042f76a5b1', 'index': 647, 'page_number': 36, 'languages': ['eng'], 'link_texts': None, 'link_urls': None, 'links': None