In [None]:
!pip install langchain sentence-transformers langchain-community chromadb torch

In [None]:
!pip install anvil-uplink

In [3]:
import json
from langchain_text_splitters import RecursiveJsonSplitter

In [4]:
doc_path = '/content/summarized.json'
# Read the summarized source data. The context manager closes the file handle
# deterministically instead of leaving it open until garbage collection; JSON
# text is UTF-8, so decode it as such regardless of the platform locale.
with open(doc_path, encoding='utf-8') as doc_file:
    json_content = json.load(doc_file)

In [5]:
# Split the loaded JSON into LangChain Documents, one per chunk, with
# max_chunk_size=1000 bounding each serialized chunk.
# NOTE(review): create_documents(texts=...) is given json_content directly, which
# assumes summarized.json's top level is a list of JSON objects — verify; a
# top-level dict would be iterated key-by-key.
splitter = RecursiveJsonSplitter(max_chunk_size=1000)
docs = splitter.create_documents(texts=json_content)

In [None]:
from langchain.embeddings import HuggingFaceEmbeddings
# Sentence-transformers embedding model; it is used below both to index the
# chunks in Chroma and (through the retriever) to embed incoming questions.
# NOTE(review): the `langchain.embeddings` import path appears to be a legacy
# re-export in newer LangChain releases; consider `langchain_community.embeddings`.
embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

In [7]:
from langchain.vectorstores import Chroma

# Embed every chunk with embeddings_model and store it in a Chroma vector store.
# NOTE(review): no persist_directory / collection name is passed, so presumably
# the index lives only in this kernel session and is rebuilt on every run — confirm.
db = Chroma.from_documents(docs, embeddings_model)

In [None]:
from langchain.llms.base import LLM
from typing import Any, List, Optional, Dict
from pydantic import Field, PrivateAttr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM

class HuggingFaceLLM(LLM):
    """Minimal LangChain ``LLM`` wrapper around a Hugging Face seq2seq model.

    The tokenizer and model are loaded eagerly in ``__init__`` through
    ``AutoTokenizer`` / ``AutoModelForSeq2SeqLM``, so ``model_name`` must name
    an encoder-decoder checkpoint (e.g. ``google/flan-t5-large``).

    Attributes:
        model_name: Hugging Face Hub id (or local path) of the model.
        max_new_tokens: Default upper bound on generated tokens per call.
    """

    model_name: str = Field(..., description="The name of the Hugging Face model to use")
    max_new_tokens: int = Field(100, description="Maximum number of new tokens generated per call")
    _tokenizer: Any = PrivateAttr()
    _model: Any = PrivateAttr()

    def __init__(self, model_name: str, **kwargs):
        """Load the tokenizer and model for ``model_name``.

        The Hub access token is taken from the ``HF_TOKEN`` (or legacy
        ``HUGGING_FACE_HUB_TOKEN``) environment variable instead of being
        hard-coded in the notebook. When neither is set, ``token=None`` is
        passed, which lets transformers fall back to any cached login and is
        sufficient for public checkpoints.
        """
        import os  # local import keeps this cell self-contained

        super().__init__(model_name=model_name, **kwargs)
        hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
        # `token` replaces the deprecated `use_auth_token` keyword.
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=hf_token)
        self._model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name, token=hf_token)

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Fully rendered prompt text.
            stop: Optional stop sequences. The decoded text is cut at the
                earliest occurrence of any of them (previously ignored).
            run_manager: LangChain callback manager; accepted for API
                compatibility and not used here.
            **kwargs: Extra keyword arguments forwarded to ``generate``;
                ``max_new_tokens`` here overrides the instance default.

        Returns:
            The decoded generation with special tokens removed.
        """
        generation_kwargs = {"max_new_tokens": self.max_new_tokens, **kwargs}
        inputs = self._tokenizer(prompt, return_tensors="pt")
        outputs = self._model.generate(**inputs, **generation_kwargs)
        text = self._tokenizer.decode(outputs[0], skip_special_tokens=True)

        if stop:
            # Ignore empty stop strings (str.find("") is always 0).
            hits = [pos for pos in (text.find(seq) for seq in stop if seq) if pos != -1]
            if hits:
                text = text[:min(hits)]
        return text

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM type."""
        return "huggingface"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Parameters that distinguish one configured instance from another."""
        return {"model_name": self.model_name, "max_new_tokens": self.max_new_tokens}

In [None]:
# Instantiate the shared LLM. HuggingFaceLLM.__init__ fetches both the tokenizer
# and the AutoModelForSeq2SeqLM weights for this checkpoint via from_pretrained.
llm = HuggingFaceLLM(model_name="google/flan-t5-large")

In [None]:
from langchain_core.prompts import ChatPromptTemplate

# Scaffold shared by every prompt: create_stuff_documents_chain fills {context}
# with the retrieved documents and create_retrieval_chain fills {input} with
# the user's question.
_PROMPT_TAIL = """
Context: {context} \n\n
Question: {input} \n\n

Answer:"""


def _build_prompt(instructions):
    """Return a ChatPromptTemplate made of ``instructions`` followed by the shared tail.

    Plain concatenation (not an f-string) is used so the ``{context}`` and
    ``{input}`` placeholders reach the template untouched.
    """
    return ChatPromptTemplate.from_template(instructions + _PROMPT_TAIL)


# Open-ended questions: one-word answer drawn from the context.
inference_template = _build_prompt("""
You are a helpful assistant that answers questions ONLY using the context provided in ONE WORD.
Think STEP by STEP, use only the RELEVANT context. Do NOT Explain.
You will be tipped $1000 if the answer is good.
""")

# Comparison questions: YES/NO answer.
comparison_template = _build_prompt("""
You are a helpful assistant that accepts context along with a comparison question.
The answer to the question should be a SINGLE word, YES OR NO.
ONLY use the context given to generate the answer. NO EXPLANATION.
You will be tipped $1000 if the answer is good.
""")

# Temporal questions: YES/NO answer.
temporal_template = _build_prompt("""
You are a helpful assistant that accepts context along with a temporal question.
The answer to the question should be a SINGLE word, YES OR NO.
ONLY use the context given to generate the answer. NO EXPLANATION.
You will be tipped $1000 if the answer is good.
""")

In [11]:
from langchain.chains.combine_documents import create_stuff_documents_chain

# One "stuff documents" chain per question type. Each one renders all retrieved
# documents into its template's {context} slot and sends the prompt to `llm`.
doc_chain_inf, doc_chain_cmp, doc_chain_tmp = (
    create_stuff_documents_chain(llm, prompt)
    for prompt in (inference_template, comparison_template, temporal_template)
)

In [13]:
from langchain.chains import create_retrieval_chain
# Only chunks whose relevance score reaches this threshold are passed on. When
# none do, the chains receive an empty context, which get_answer reports as
# "Insufficient Information".
RETRIEVAL_SCORE_THRESHOLD = 0.22

retriever = db.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"score_threshold": RETRIEVAL_SCORE_THRESHOLD},
)

# Pair the shared retriever with each question-type-specific document chain.
retriever_chain_inf, retriever_chain_cmp, retriever_chain_tmp = (
    create_retrieval_chain(retriever, doc_chain)
    for doc_chain in (doc_chain_inf, doc_chain_cmp, doc_chain_tmp)
)

In [14]:
def classify_query(query):
    """Heuristically classify ``query`` by its leading word.

    The query is case-folded, stripped of everything except alphanumerics and
    whitespace, and split into words. Only the FIRST WORD is compared with the
    keyword sets. Plain ``str.startswith`` matched any prefix, so "Donald ..."
    hit "do" and "Israel ..." hit "is", and leading whitespace defeated every
    rule. Contractions such as "doesn't" become "doesnt" after cleaning; they
    are listed explicitly so they keep their previous category.

    Args:
        query: Raw user question.

    Returns:
        One of "comparison_query", "temporal_query", "inference_query" or
        "Unclassified".
    """
    cleaned = ''.join(ch for ch in query.casefold() if ch.isalnum() or ch.isspace())
    words = cleaned.split()
    if not words:
        return "Unclassified"
    first_word = words[0]

    # Rule for comparison-based queries.
    comparison_keywords = {"does", "do", "are", "doesnt", "dont", "arent"}
    if first_word in comparison_keywords:
        return "comparison_query"

    # Rule for temporal-based queries.
    temporal_keywords = {
        "between", "after", "has", "was", "is", "before", "considering",
        "regarding", "according", "hasnt", "wasnt", "isnt",
    }
    if first_word in temporal_keywords:
        return "temporal_query"

    # Rule for inference-based queries. "considering" used to be listed here
    # too, but it could never match because the temporal rule claims it first.
    inference_keywords = {"which", "who", "whose", "whom", "what", "whats", "in"}
    if first_word in inference_keywords:
        return "inference_query"

    # "did" questions are ambiguous: they are comparisons only when they contain
    # a comparison cue as a whole word/phrase, otherwise temporal. Matching on
    # word boundaries keeps "was the" from hitting "as the", "thank" from
    # hitting "than" and "unlikely" from hitting "unlike".
    if first_word in {"did", "didnt"}:
        padded = " " + " ".join(words) + " "
        comparison_cues = (
            "disagree with", "identically", "both", "while", "also", "similarly",
            "as the", "in the same manner that", "similar", "similarity", "than",
            "align with", "does the same", "compared", "contrast", "contrasts",
            "contrasting", "contrasted", "unlike",
        )
        if any(" " + cue + " " in padded for cue in comparison_cues):
            return "comparison_query"
        return "temporal_query"

    # Default case if no rules match.
    return "Unclassified"

In [15]:
import ast

In [21]:
def _parse_evidence(page_content):
    """Decode one retrieved chunk's ``page_content`` back into a Python object.

    RecursiveJsonSplitter stores chunks as JSON text, so ``json.loads`` is the
    matching decoder. ``ast.literal_eval`` rejects JSON's true/false/null.
    Python-literal text is still accepted as a fallback; anything undecodable
    is returned as the raw string so one bad chunk cannot fail the request.
    """
    try:
        return json.loads(page_content)
    except json.JSONDecodeError:
        pass
    try:
        return ast.literal_eval(page_content)
    except (ValueError, SyntaxError):
        return page_content


def get_answer(ques, verbose=False):
    """Answer ``ques`` using the retrieval chain that matches its question type.

    Args:
        ques: The user's question.
        verbose: When True, print the raw chain response for debugging (this
            used to happen unconditionally).

    Returns:
        dict with keys "query", "question_type", "answer" and "evidence_list".
        When the retriever returned no documents, "question_type" is set to
        "null_query" and "answer" to "Insufficient Information".
    """
    question_type = classify_query(ques)

    # Inference and unclassified queries both fall through to the inference chain.
    chain_by_type = {
        "comparison_query": retriever_chain_cmp,
        "temporal_query": retriever_chain_tmp,
    }
    chain = chain_by_type.get(question_type, retriever_chain_inf)
    response = chain.invoke({"input": ques})
    if verbose:
        print(response)

    evidence_list = [_parse_evidence(doc.page_content) for doc in response['context']]

    # Key order matches the original (it shows up in the JSON written by callers).
    answer_dict = {
        'query': ques,
        'question_type': question_type,
        'answer': response['answer'],
        'evidence_list': evidence_list,
    }
    if not evidence_list:
        answer_dict['question_type'] = "null_query"
        answer_dict['answer'] = "Insufficient Information"

    return answer_dict

In [None]:
import anvil.server

import getpass
import os

# The uplink key is a credential and must not live in the notebook source.
# It is read from ANVIL_UPLINK_KEY, or prompted for without echo when unset.
# NOTE(review): the key previously committed here should be treated as leaked
# and regenerated in the Anvil app's uplink settings.
anvil.server.connect(os.environ.get("ANVIL_UPLINK_KEY") or getpass.getpass("Anvil uplink key: "))

In [None]:
@anvil.server.callable
def get_final_result(query):
    """Anvil-callable entry point for answering ``query``.

    Writes the full result from get_answer as 4-space-indented JSON to
    'answer.json', overwriting any previous contents.

    Returns:
        Tuple of (answer, question_type, pretty-printed JSON of the full result).
    """
    result = get_answer(query)
    with open('answer.json', 'w') as answer_file:
        serialized = json.dumps(result, indent=4)
        answer_file.write(serialized)
    return result['answer'], result['question_type'], serialized

In [None]:
# Block this cell so the uplink connection stays open and the
# @anvil.server.callable function above (get_final_result) keeps serving calls.
anvil.server.wait_forever()