# Documentation for backend/inference.py

In [None]:
import os
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_mistralai import ChatMistralAI
from .document_loading import (
	load_documents_from_directory, 
	load_or_create_faiss_vector_store,
	get_hybrid_retriever
)
from .prompts import prompt
from .citations import get_answer_with_source

# Import and load environment variables
from dotenv import load_dotenv
load_dotenv(override=True)

_Explanation_:

__Imports__
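1. ```os```: Standard library module, used here to read environment variables.

2. ```create_retrieval_chain```: LangChain function that builds a chain which retrieves documents for the input and passes them to a document-combining chain.

3. ```create_stuff_documents_chain```: LangChain function that builds a chain which inserts ("stuffs") all supplied documents into a single prompt for the LLM.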

4. ```ChatMistralAI```: Class for interacting with the Mistral AI model.

5. Document Loading Functions:

- ```load_documents_from_directory```: Loads PDF documents from a directory and splits them into chunks.
- ```load_or_create_faiss_vector_store```: Loads or creates a FAISS vector store.
- ```get_hybrid_retriever```: Creates a hybrid retriever combining BM25 and vector search.
  
6. ```prompt```: The prompt template imported from the ```.prompts``` module; it is passed to ```create_stuff_documents_chain``` together with the LLM.

7. ```get_answer_with_source```: Returns answers with their source references.

8. ```load_dotenv```: Loads environment variables from a .env file.

In [None]:
# Load documents and the embeddings from the FAISS vector store
if os.getenv("CORPUS_SOURCE") == "":
    document_path = "data/default/textbook"
    persist_directory = "data/default/faiss_indexes"
document_path = "data/default/textbook"
persist_directory = "data/default/faiss_indexes"

top_k = 15 # number of relevant documents to be returned
documents = load_documents_from_directory(document_path)
faiss_store = load_or_create_faiss_vector_store(documents, "pdf_collection", persist_directory)
retriever = get_hybrid_retriever(documents = documents, vector_store = faiss_store, k = top_k)

_Explanation:_

1. Check Environment Variable:

- Load documents and the embeddings from the FAISS vector store:
```
if os.getenv("CORPUS_SOURCE") == "":
    document_path = "data/default/textbook"
    persist_directory = "data/default/faiss_indexes"
```

- The code checks whether the environment variable ```CORPUS_SOURCE``` is set to an empty string. Because ```os.getenv``` returns ```None``` for an unset variable, the comparison with ```""``` is false when the variable is missing altogether.
- If the check passes, it sets ```document_path``` and ```persist_directory``` to default values:
   - ```document_path```: ```"data/default/textbook"```, the path where the text documents are located.
   - ```persist_directory```: ```"data/default/faiss_indexes"```, the directory where the FAISS vector store will be saved or loaded from.


```
document_path = "data/default/textbook"
persist_directory = "data/default/faiss_indexes"
```


2. Set Default Paths:
- Regardless of the environment variable, the code assigns ```document_path``` and ```persist_directory``` the same default values immediately after the conditional. This makes the preceding check redundant: both variables always end up with the defaults, so ```CORPUS_SOURCE``` currently has no effect on which corpus is loaded.
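The sketch below shows one way the defaults could apply only when ```CORPUS_SOURCE``` is unset or empty. The helper ```resolve_corpus_paths``` is hypothetical and not part of ```inference.py```. The rule that a non-empty value names a folder under ```data/``` is an assumption made purely for illustration, since the source does not say how a custom corpus would be located.

```python
import os

DEFAULT_DOCUMENT_PATH = "data/default/textbook"
DEFAULT_PERSIST_DIRECTORY = "data/default/faiss_indexes"


def resolve_corpus_paths(env=None):
    """Hypothetical helper: choose corpus paths based on CORPUS_SOURCE."""
    env = os.environ if env is None else env
    source = env.get("CORPUS_SOURCE")  # None when unset, "" when set but empty
    if not source:  # true for both None and ""
        return DEFAULT_DOCUMENT_PATH, DEFAULT_PERSIST_DIRECTORY
    # Assumption for illustration only: a non-empty value names a folder under data/.
    return f"data/{source}/textbook", f"data/{source}/faiss_indexes"


print(resolve_corpus_paths({}))
print(resolve_corpus_paths({"CORPUS_SOURCE": ""}))
print(resolve_corpus_paths({"CORPUS_SOURCE": "custom"}))
```

Passing the environment as a plain mapping keeps the helper easy to test without touching the real process environment.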


```top_k = 15 # number of relevant documents to be returned```

3. Define ```top_k``` :
- This variable specifies the number of relevant documents to be returned by the retriever. In this case, it is set to 15.

```documents = load_documents_from_directory(document_path)```


4. Load Documents:
- The ```load_documents_from_directory()``` function is called with the specified ```document_path```. This function loads and splits the documents from the directory into manageable chunks, returning a list of document chunks.

```faiss_store = load_or_create_faiss_vector_store(documents, "pdf_collection", persist_directory)```

5. Load or Create FAISS Vector Store:
- The ```load_or_create_faiss_vector_store()``` function is called with the loaded documents, a collection name (```"pdf_collection"```), and the ```persist_directory```.
- This function either loads an existing FAISS vector store or creates a new one, returning the vector store object.

```retriever = get_hybrid_retriever(documents=documents, vector_store=faiss_store, k=top_k)```

6. Create Hybrid Retriever:
- The ```get_hybrid_retriever()``` function is called with the loaded documents, the FAISS vector store, and the ```top_k``` value. This function returns an ```EnsembleRetriever``` that combines the BM25 and vector retrieval methods, so results can reflect both exact keyword matches and semantic similarity.
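To give an intuition for how two rankings can be merged into one, here is a minimal sketch of weighted reciprocal rank fusion, the kind of scheme LangChain's ```EnsembleRetriever``` uses. The function name, the toy document IDs, the equal weights, and the constant ```c=60``` are assumptions for illustration. How ```get_hybrid_retriever``` actually configures the ensemble is not shown in this file.

```python
from collections import defaultdict


def reciprocal_rank_fusion(rankings, weights=None, c=60):
    """Fuse several ranked lists of document IDs into a single ranking."""
    weights = weights or [1.0] * len(rankings)
    scores = defaultdict(float)
    for ranking, weight in zip(rankings, weights):
        for rank, doc_id in enumerate(ranking, start=1):
            # Documents near the top of any list contribute a larger score.
            scores[doc_id] += weight / (rank + c)
    return sorted(scores, key=scores.get, reverse=True)


bm25_ranking = ["doc_a", "doc_b", "doc_c"]    # toy keyword-based ranking
vector_ranking = ["doc_c", "doc_a", "doc_d"]  # toy embedding-based ranking
fused = reciprocal_rank_fusion([bm25_ranking, vector_ranking])
print(fused)
```

Documents that appear in both lists accumulate score from each, so they tend to rise above documents found by only one method.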

_Process Overview_:
- This code snippet sets up the environment for document retrieval by loading documents from a directory, creating or loading a FAISS vector store for embeddings, and configuring a hybrid retriever to return a specified number of relevant documents. Although the code reads the ```CORPUS_SOURCE``` environment variable, the document source is currently always the default path.

In [None]:
# Get Mistral API Key from the environment variables
api_key = os.getenv("MISTRAL_API_KEY")
if not api_key:
	raise ValueError("MISTRAL_API_KEY not found in .env")

def load_llm_api(model_name):
	"""
	Load and configure the Mistral AI LLM.
	Returns:
		ChatMistralAI: Configured LLM instance.
	"""
	return ChatMistralAI(
		model=model_name,
		mistral_api_key=api_key,
		temperature=0.2,
		max_tokens=256,
		top_p=0.4,
	)
MODEL_NAME = "open-mistral-7b"
llm = load_llm_api(MODEL_NAME)

_Explanation_:

_Get Mistral API Key from the environment variables_

```
api_key = os.getenv("MISTRAL_API_KEY")
if not api_key:
    raise ValueError("MISTRAL_API_KEY not found in .env")
```

1. Retrieve API Key:
The code attempts to retrieve the Mistral API key using ```os.getenv("MISTRAL_API_KEY")```. If the key is missing or empty (```not api_key``` is true for both ```None``` and ```""```), a ```ValueError``` is raised with a message indicating that the API key is not present in the .env file.
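The same check can be wrapped in a small reusable function. The helper ```require_env``` below is hypothetical and not part of ```inference.py```; it is a sketch of the pattern rather than the module's actual code.

```python
import os


def require_env(name, env=None):
    """Hypothetical helper: return a non-empty environment value or raise."""
    env = os.environ if env is None else env
    value = env.get(name)
    if not value:  # true for None (unset) and "" (set but empty)
        raise ValueError(f"{name} not found in .env")
    return value


print(require_env("MISTRAL_API_KEY", {"MISTRAL_API_KEY": "example-key"}))
```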

```
def load_llm_api(model_name):
    """
    Load and configure the Mistral AI LLM.
    Returns:
        ChatMistralAI: Configured LLM instance.
    """
```
2. Define Function ```load_llm_api``` :
This function takes a single parameter, ```model_name```, which specifies the name of the Mistral model to be loaded. It is responsible for creating and returning a configured instance of the ```ChatMistralAI``` class.

```
    return ChatMistralAI(
        model=model_name,
        mistral_api_key=api_key,
        temperature=0.2,
        max_tokens=256,
        top_p=0.4,
    )
```
3. Configure LLM:
Inside the function, an instance of ```ChatMistralAI``` is created with the following parameters:
- ```model```: The name of the model passed as an argument.
- ```mistral_api_key```: The API key retrieved earlier.
- ```temperature```: A parameter controlling the randomness of the output (set to 0.2 for relatively deterministic responses).
- ```max_tokens```: The maximum number of tokens to generate in the response (set to 256).
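- ```top_p```: A parameter for nucleus sampling (set to 0.4), which restricts sampling to the smallest set of most likely tokens whose cumulative probability reaches 0.4 and so limits the diversity of the output.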
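The effect of ```temperature``` and ```top_p``` can be seen on a toy distribution. The sketch below is a simplified illustration of temperature scaling and nucleus selection, not Mistral's implementation. The function names and the four logit values are made up for the example.

```python
import math


def apply_temperature(logits, temperature):
    """Softmax over logits divided by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def nucleus(probs, top_p):
    """Indices of the smallest set of tokens whose cumulative probability reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return kept


logits = [1.0, 0.9, 0.5, 0.1]  # toy scores for four candidate tokens
for temperature in (1.0, 0.2):
    probs = apply_temperature(logits, temperature)
    print(temperature, [round(p, 3) for p in probs], nucleus(probs, top_p=0.4))
```

Lowering the temperature concentrates probability on the top token, which in turn shrinks the set of candidates that survive the ```top_p``` cutoff.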

```
MODEL_NAME = "open-mistral-7b"
llm = load_llm_api(MODEL_NAME)
```
4. Load LLM Instance:
The model name (```"open-mistral-7b"```) is assigned to the variable ```MODEL_NAME```. The ```load_llm_api()``` function is then called with ```MODEL_NAME``` to create and configure the LLM instance, which is stored in the variable ```llm```.

_Process Overview:_

- This code snippet sets up the configuration for using the Mistral AI language model by retrieving the necessary API key from the environment variables, defining a function to load and configure the model, and creating an instance of the LLM for subsequent use. The low ```temperature``` and ```top_p``` values keep the output focused, and ```max_tokens``` caps each response at 256 tokens.


In [None]:
def chat_completion(question):
  """
  Generate a response to a given question using the RAG (Retrieval-Augmented Generation) chain,
  streaming parts of the response as they are generated.

  Args:
    question (str): The user question to be answered.

  Yields:
    tuple[str, str]: A chunk of the generated response and the model name.
  """
  print(f"Running prompt: {question}")
  question_answer_chain = create_stuff_documents_chain(llm, prompt)
  rag_chain = create_retrieval_chain(retriever, question_answer_chain)

  # Stream response from LLM
  full_response = {"answer": "", "context": []}
  for chunk in rag_chain.stream({"input": question}):
    if "answer" in chunk:
      full_response["answer"] += chunk["answer"]
      yield (chunk["answer"], MODEL_NAME)
    if "context" in chunk:
      full_response["context"].extend(chunk["context"])

  # After streaming is complete, use the full response to extract citations
  final_answer = get_answer_with_source(full_response)
  # Yield any remaining part of the answer with citations
  remaining_answer = final_answer[len(full_response["answer"]):]
  if remaining_answer:
    yield (remaining_answer, MODEL_NAME)

_Explanation:_

```chat_completion```
- This function generates a response to a user-provided question using a retrieval-augmented generation (RAG) chain. It streams parts of the response as they are generated, allowing for a more interactive experience.

Parameters:
- ```question (str)```: The user question that needs to be answered.

Yields:
- ```(str, str)```: A tuple of the next chunk of generated text and the model name (```MODEL_NAME```), so the caller receives the answer incrementally.

Function Logic
1. Print Question:
- The function starts by printing the question being processed to the console for logging purposes:

```print(f"Running prompt: {question}")```

2. Create Chains:

It creates two chains using LangChain:
- ```question_answer_chain```: This chain is created using ```create_stuff_documents_chain(llm, prompt)```, which combines the LLM and a predefined prompt to answer questions based on document content.
- ```rag_chain```: This is created by combining the retriever and the ```question_answer_chain``` using ```create_retrieval_chain(retriever, question_answer_chain)```, so that retrieved documents are passed to the generation step.

3. Stream Response:

- An initial dictionary, full_response, is created to store the answer and context:

```full_response = {"answer": "", "context": []}```


- The function then enters a loop that streams the response from the RAG chain:

```for chunk in rag_chain.stream({"input": question}):```

- During this streaming, it checks for ```answer``` and ```context``` keys in each ```chunk```:
   - If an ```answer``` is found, it appends it to ```full_response["answer"]``` and yields the chunk of the answer along with the model name:

```yield (chunk["answer"], MODEL_NAME)```

   - If ```context``` is found, it extends ```full_response["context"]``` with the new context data.
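The accumulation pattern can be tried without a real chain. In the sketch below, ```fake_rag_stream``` is a hypothetical stand-in for ```rag_chain.stream(...)``` that yields partial dictionaries. The context entries are plain strings here, whereas a real chain would supply document objects. The helper ```accumulate``` is likewise only for illustration.

```python
def fake_rag_stream(question):
    """Hypothetical stand-in for rag_chain.stream(...): yields partial dicts."""
    yield {"input": question}
    yield {"context": ["chunk from page 3", "chunk from page 7"]}
    for piece in ("Photosynthesis ", "converts light ", "into chemical energy."):
        yield {"answer": piece}


def accumulate(stream):
    """Collect streamed answer pieces and context into one response dict."""
    full_response = {"answer": "", "context": []}
    streamed = []
    for chunk in stream:
        if "answer" in chunk:
            full_response["answer"] += chunk["answer"]
            streamed.append(chunk["answer"])  # chat_completion yields here instead
        if "context" in chunk:
            full_response["context"].extend(chunk["context"])
    return streamed, full_response


streamed, full_response = accumulate(fake_rag_stream("What is photosynthesis?"))
print(streamed)
print(full_response)
```

Chunks that carry neither key, such as the one echoing the input, are simply skipped.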

4. Extract Citations:

- After the streaming is complete, the function uses ```get_answer_with_source(full_response)``` to extract citations and construct the final answer. This ensures that the response includes source references for the provided information.

5. Yield Remaining Answer:

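- The function calculates any remaining part of the answer that was not yielded during the streaming process by slicing off the first ```len(full_response["answer"])``` characters of ```final_answer```. This relies on ```final_answer``` starting with the streamed text unchanged, so that only the appended part (such as citations) remains after the slice: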

```remaining_answer = final_answer[len(full_response["answer"]):]```

- If there is any remaining answer, it yields this part along with the model name.
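The slicing step can be illustrated with plain strings. The sample answers and the "Sources" suffix below are made up, and the actual output format of ```get_answer_with_source``` is not shown in this file. The second case is a sketch of what happens if the final text does not begin with the streamed text.

```python
def remaining_after_stream(streamed_answer, final_answer):
    """Return the part of final_answer that follows the already streamed text."""
    return final_answer[len(streamed_answer):]


streamed = "Paris is the capital of France."

# Case 1: the final answer is the streamed text plus an appended suffix.
final_with_sources = streamed + "\n\nSources: textbook.pdf, p. 12"
print(repr(remaining_after_stream(streamed, final_with_sources)))

# Case 2: the final answer was reworded, so the slice cuts at an arbitrary position.
reworded = "The capital of France is Paris. [1]"
print(reworded.startswith(streamed))
print(repr(remaining_after_stream(streamed, reworded)))
```

Checking ```final_answer.startswith(...)``` before slicing is one way to detect the second case.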

Process Overview: 
- The ```chat_completion``` function implements a retrieval-augmented generation mechanism to provide responses to user questions in an interactive manner. It streams answers in chunks and enriches the response with context and citations, enhancing the reliability and traceability of the information generated.
