## Imports

In [1]:
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM
    )
import torch
import numpy as np
import os
from tqdm import tqdm
import time
import json
from typing import List, Dict, Any, Literal,Annotated
# import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


## Common Utils

In [2]:
def get_stored_result(path:str, type_:Literal["json","csv"]):
    # Check if the path is valid
    if not os.path.exists(path):
        print("The path is not valid or the file does not exist")
        return None 
    if type_ == "json":
        with open(path, 'r') as f:
            data = json.load(f)
        return data
    elif type_ == "csv":
        with open(path, 'r') as f:
            data = json.load(f)
        return data
    else:
        print("Invalid type")
        return None 
    
def save_result(result, path:str, type_:Literal["json","csv"], indent:int=None):
    if type_ == "json":
        with open(path, 'w') as f:
            if indent is not None:
                json.dump(result, f, indent=indent)
            else:
                json.dump(result, f)
    elif type_ == "csv":
        #  check if the result is a dataframe
        if isinstance(result, pd.DataFrame):
            result.to_csv(path, index=False)
        else:
            print("The result is not a dataframe")
            raise ValueError("The result is not a dataframe")
    else:
        print("Invalid type")
        raise ValueError("Invalid type")


## Splade Model

In [3]:
class SparseModel:
  def __init__(self,max_length=512):
    self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.model = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil").to(self.device)
    self.max_length = max_length
    self.sparse_tokenizer = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil",truncation=True, padding='max_length', max_length=self.max_length)
    

  def decode_sparse_dict(self, sparse_dict,trim=None):
    a = np.zeros((30522))
    a[sparse_dict['indices']] = sparse_dict['values']
    if trim is not None:
      a[a.argsort()[:-trim]] = 0
    return a

  def decode_sparse_dicts(self, sparse_dicts,trim=None):
    res = []
    for _ in sparse_dicts:
      res.append(self.decode_sparse_dict(_,trim).tolist())
    return res

  def formalize(self, sparse_dict):
    idx2token = {idx: token for token, idx in self.sparse_tokenizer.get_vocab().items()}
    sparse_dict_tokens = {
        idx2token[idx]: weight for idx, weight in zip(sparse_dict['indices'], sparse_dict['values'])
    }
    sparse_dict_tokens = {
        k: v for k, v in sorted(
            sparse_dict_tokens.items(),
            key=lambda item: item[1],
            reverse=True
        )
    }
    return sparse_dict_tokens

  # This function is used to encode a list of texts into sparse vectors. Process all texts at once
  # Use this function if you have a GPU. Padding is used to make all texts have the same length
  # Faster in GPU , Slower in CPU
  def encode_texts(self, texts:List[str]):
    input_ids = self.sparse_tokenizer(texts, return_tensors='pt', padding='max_length' if len(texts)>1 else False).to(self.device)
    input_ids = {k: v[:self.max_length] for k, v in input_ids.items()}

    with torch.no_grad():
      logits = self.model(**input_ids).logits

    sparse_vecs = torch.max(
        torch.log(
            1+torch.relu(logits)
        )*input_ids['attention_mask'].unsqueeze(-1),
    dim=1)[0].cpu()

    sparse_dicts = []
    for sparse_vec in sparse_vecs:
      indices = sparse_vec.nonzero().squeeze().tolist()
      values = sparse_vec[indices].tolist()
      sparse_dict = {'indices': indices, 'values': values}
      sparse_dicts.append(sparse_dict)

    return sparse_dicts
  
  # This function is used to encode a single text into a sparse vector
  def encode_text(self, text:str):
    return self.encode_texts([text])[0]

  # This function is used to encode a list of texts into sparse vectors (It iterates over one text at a time)
  # Faster in CPU , Slower in GPU. Maintain batch_size=1
  def encode_text_list(self, texts:list, batch_size:int=1):

    sparse_dicts = []

    for i in tqdm(range(0, len(texts), batch_size)):
      batch = texts[i:i+batch_size]
      sparse_dicts += self.encode_texts(batch)

    return sparse_dicts

## Spalde Model Object

In [4]:
splade_model = SparseModel()



## Data Read From JSON

In [5]:
sample_json_data = get_stored_result("./sample-data/data1.json", "json")

In [6]:
## Elpased Time Calulation for processing a text at a time
encoded_result = {}
start = time.time()
for idx, metadata in sample_json_data.items():
    encoded_result[idx] = splade_model.encode_text(metadata['content'])
    encoded_result[idx]['metadata'] = metadata
end = time.time()
print(f"Elapsed Time for processing a text at a time: {end-start}")

Elapsed Time for processing a text at a time: 11.123236179351807


In [7]:
## Elpased Time Calulation for processing a list of texts
encoded_result = {}
start = time.time()
texts = [metadata['content'] for metadata in sample_json_data.values()]
encoded_result = splade_model.encode_text_list(texts)
end = time.time()
print(f"Elapsed Time for processing a list of texts: {end-start}")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 50/50 [00:07<00:00,  6.62it/s]

Elapsed Time for processing a list of texts: 7.633612155914307





In [8]:
# ## Elpased Time Calulation for processing all texts at once
# encoded_result = {}
# start = time.time()
# texts = [metadata['content'] for metadata in sample_json_data.values()]
# results = splade_model.encode_texts(texts)
# end = time.time()
# print(f"Elapsed Time for processing all texts at once: {end-start}")

In [9]:
save_result(encoded_result, "./sample-data/encoded_data1.json", "json")

# Using Qdrant

In [5]:
from qdrant_client import QdrantClient, models
from langchain_community.retrievers import QdrantSparseVectorRetriever
from langchain_core.documents import Document

In [6]:
QCLIENT = QdrantClient(path="./.qdrant-vectors")

In [7]:
class QdrantHandler:
    def __init__(self, qclient, **kwargs):
        self.qclient = qclient
        self.retreivers = {}
        self.sparse_enocder = kwargs.get("sparse_enocder", None)
    
    def init_collections(self, collection_names:List[str],reset_before_init:bool=False):
        for collection_name in collection_names:
            if reset_before_init:
                self.qclient.delete_collection(collection_name)
            try:
                if not self.qclient.collection_exists(collection_name):
                    self.qclient.create_collection(
                        collection_name,
                        vectors_config={},
                        sparse_vectors_config={
                            "sparse_vector": models.SparseVectorParams(
                                index=models.SparseIndexParams(
                                    on_disk=False,
                                )
                            )
                        },
                    )
            except Exception as e:
                print(f"Collection {collection_name} already exists. Reusing the collection")

            print(f"Collection {collection_name} is created")
            retriever = QdrantSparseVectorRetriever(
                client=self.qclient,
                collection_name=collection_name,
                sparse_vector_name="sparse_vector",
                sparse_encoder=self.sparse_enocder
            )

            self.retreivers[collection_name] = retriever

    def add_documents(self, collection_name:str, data:List[Dict[str, Any]], pk:str="id",content_key:str="content", batch_size:int=10):
        
        document_list = [ 
            Document(
                metadata = doc,
                page_content = doc[content_key]
            ) for doc in data
        ]

        id_list = [doc[pk] for doc in data]

        for i in tqdm(range(0, len(document_list), batch_size)):
            batch = document_list[i:i+batch_size]
            batch_ids = id_list[i:i+batch_size]
            self.retreivers[collection_name].add_documents(documents=batch, ids=batch_ids)
    
    def retrieve(self, collection_name:str, query:str, top_k:int=10):
        return self.retreivers[collection_name].invoke(query)

In [8]:
QDRANT_HANDLER = QdrantHandler(QCLIENT, sparse_enocder=lambda x: tuple(splade_model.encode_text(x).values()))

In [9]:
QDRANT_HANDLER.init_collections(["sample_collection"], reset_before_init=False)

Collection sample_collection is created


  retriever = QdrantSparseVectorRetriever(


In [10]:
# QDRANT_HANDLER.add_documents("sample_collection", list(sample_json_data.values()))

In [84]:
sample_result = QDRANT_HANDLER.retrieve("sample_collection", "What is nerural network?")

In [87]:
sample_result[0].page_content

'A neural network consists of layers of interconnected nodes. Each node in a layer processes an input, applies a weight and an activation function, and passes the result to the next layer. The final output layer generates predictions based on the input data.'

# LLM Utils

In [11]:
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from typing_extensions import TypedDict
from langchain_core.messages import HumanMessage, SystemMessage,ToolMessage
from langchain.prompts import PromptTemplate
from typing import List, Any

In [12]:
from together import Together
from dotenv import load_dotenv
import os

load_dotenv()

True

In [13]:

def prompt_generator(query: str, contexts: List[str], history: List[Any]):
    # Define the system prompt for concise question answering
    system_prompt = """
    You are an assistant for question-answering tasks. 
    Use the following pieces of retrieved context to answer the question. 
    If you don't know the answer, just say that you don't know. 
    Use three sentences maximum and keep the answer concise.
    """

    # Define a template with placeholders for system prompt, query, context, and history
    template = """
    System Instruction:
    {system_prompt}

    Given the following context:
    {contexts}

    And the previous conversation history:
    {history}

    Answer the following query:
    {query}
    """

    # Combine the inputs into a prompt template using Langchain's PromptTemplate class
    prompt = PromptTemplate(
        input_variables=["system_prompt", "contexts", "history", "query"],
        template=template
    )

    history_text = "\n".join([f"{'Human' if isinstance(item, HumanMessage) else 'System'}: {item.content}" for item in history]) 

    # Format the prompt with the actual values
    formatted_prompt = prompt.format(
        system_prompt=system_prompt.strip(),
        contexts="\n".join(contexts),
        history=history_text,
        query=query
    )
    
    return formatted_prompt


In [14]:
class LLMEngine:
    def __init__(self, model_path:str="meta-llama/Meta-Llama-3-8B-Instruct-Turbo"):
        self.client = Together(api_key=os.environ.get('TOGETHER_API_KEY'))
        self.model_path = model_path
    
    def invoke(self, query:str, contexts:List[str], history:List[Any]):
        prompt=prompt_generator(
                query=query,
                contexts=contexts,
                history=history
                )
        response = self.client.chat.completions.create(
            model=self.model_path,
            messages=[{"role": "user", "content": prompt}],
        )

        response_content = response.choices[0].message.content

        return response_content

In [15]:
LLM_ENGINE = LLMEngine()

# Langgraph Agent

In [16]:
import langgraph
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.sqlite import SqliteSaver
from IPython.display import Image, display
from langgraph.checkpoint.memory import MemorySaver

In [17]:
class AgentState(TypedDict):
    cur_state: str
    messages: Annotated[list, add_messages]
    contexts: List[Document]

In [50]:
class AgentConfig(TypedDict):
    qdrant_handler: QdrantHandler = QDRANT_HANDLER
    sample_collection: str = "sample_collection"

class AgentFunctions:
    @staticmethod
    def welcome(state:AgentState)->AgentState:
        msg = SystemMessage("Welcome to the system. How can I help you?")
        state['cur_state'] = "welcome"
        print(f"Current State: {state['cur_state']}")
        state['messages'].append(msg)
        return state

    @staticmethod
    def human_input(state:AgentState)->AgentState:
        state['cur_state'] = "human_input"
        print(f"Current State: {state['cur_state']}")
        return state
    
    @staticmethod
    def retreive_documents(state:AgentState)->AgentState:
        query = state['messages'][-1].content
        retrieved_nodes = AgentConfig.qdrant_handler.retrieve(AgentConfig.sample_collection, query)
        state['contexts'] = [node.page_content for node in retrieved_nodes] 
        state['cur_state'] = "retrieval"
        print(f"Current State: {state['cur_state']}")
        return state

    @staticmethod
    def invoke_llm(state:AgentState)->AgentState:
        query = state['messages'][-1].content
        history = state['messages'][-4:-1]
        contexts = state['contexts']
        response = LLM_ENGINE.invoke(query, contexts, history)
        msg = SystemMessage(response)
        state['messages'].append(msg)
        state['cur_state'] = "invoke_llm"
        print(f"Current State: {state['cur_state']}")
        return state
    
    @staticmethod
    def check_end_of_conversation(state:AgentState)->bool:
        state['cur_state'] = "check_end_of_conversation"
        print(f"Current State: {state['cur_state']}")
        if state['messages'][-1].content == "/quit":
            return "quit"
        return "continue"
    
    @staticmethod
    def exit_conversation(state:AgentState)->AgentState:
        state['cur_state'] = "exit_conversation"
        print(f"Current State: {state['cur_state']}")
        msg = SystemMessage("Goodbye!")
        state['messages'].append(msg)
        return state

In [51]:
class MyAgent:
    def __init__(self,thread_id):
        self.config = None
        self.app = None
        self.build(thread_id)

    def draw_workflow(self):
        display(Image(self.app.get_graph(xray=True).draw_mermaid_png()))
    
    def build(self, thread_id):
        workflow = StateGraph(AgentState)
        workflow.add_node("welcome", AgentFunctions.welcome)
        workflow.add_node("human_input", AgentFunctions.human_input)
        workflow.add_node("retreive_documents", AgentFunctions.retreive_documents)
        workflow.add_node("invoke_llm", AgentFunctions.invoke_llm)
        workflow.add_node("exit_conversation", AgentFunctions.exit_conversation)

        
        workflow.add_edge("welcome", "human_input")
        # workflow.add_conditional_edges(
        #     "human_input", 
        #     AgentFunctions.check_end_of_conversation,
        #     {
        #         "quit": "exit_conversation",
        #         "continue": "retreive_documents"
        #     }
        # )

        workflow.add_edge("human_input", "retreive_documents")
        workflow.add_edge("retreive_documents", "invoke_llm")
        workflow.add_edge("invoke_llm", "exit_conversation")
        workflow.add_edge("exit_conversation", langgraph.graph.END)

        workflow.set_entry_point("welcome")
        
        memory = MemorySaver()

        app = workflow.compile(
            checkpointer=memory,
            interrupt_before=["human_input"],
        )

        
        # Assign the compiled app before the connection closes
        self.app = app
        self.config = {"configurable": {"thread_id": str(thread_id)}}

    def get_recent_state_snap(self):
        return self.app.get_state(config=self.config).values.copy()

    def get_last_message(self):
        snap = self.get_recent_state_snap()
        return snap["messages"][-1]
    
    def continue_flow(self, state):
        self.app.invoke(state, config=self.config)
        return self.get_recent_state_snap()
    
    def resume_with_user_input(self, user_input:str):
        snap = self.get_recent_state_snap()
        snap["messages"].append(HumanMessage(user_input))
        self.app.update_state(self.config, snap)
        return self.continue_flow(snap)

## Running the agent

In [52]:
sample_agent = MyAgent(thread_id=50)

In [54]:
# sample_agent.draw_workflow()

In [55]:
sample_snap0 =  sample_agent.continue_flow({
    "cur_state": "start",
    "messages": [],
    "contexts": [],
})

Current State: welcome


In [56]:
sample_snap0['messages']

[SystemMessage(content='Welcome to the system. How can I help you?', additional_kwargs={}, response_metadata={}, id='f28fe5d8-04b0-4782-b81f-6c694142b3aa')]

In [36]:
print(sample_agent.get_last_message().content) 

Welcome to the system. How can I help you?


In [29]:
sample_agent.get_recent_state_snap()

{'cur_state': 'welcome',
 'messages': [SystemMessage(content='Welcome to the system. How can I help you?', additional_kwargs={}, response_metadata={}, id='a0cd0610-d334-4a55-84b7-3aef04da3de2')],
 'contexts': []}

In [57]:
sample_snap1 = sample_agent.resume_with_user_input("Hi! I want to know about neural networks")

Current State: welcome
