#### Library Imports

In [7]:
import logging
import os
import re
from typing import Optional

import pandas as pd
import weaviate
from flashrank import Ranker, RerankRequest
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain_core.documents import Document
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_weaviate.vectorstores import WeaviateVectorStore
from pymongo import MongoClient
from weaviate.classes.query import Filter

import warnings
warnings.filterwarnings("ignore")

# Shared clients and settings used by the retrieval helpers further down.
weaviate_client = weaviate.connect_to_local()
embeddings = HuggingFaceEmbeddings(model_name="intfloat/e5-large", cache_folder="./embedding_model")
# Read the connection string from the environment so credentials are not committed
# with the notebook; the fallback is the original local-development default.
MONGO_URI = os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
DATABASE_NAME = "incident_db"
COLLECTION_NAME = "incident_collection"

INFO:httpx:HTTP Request: GET http://localhost:8080/v1/.well-known/openid-configuration "HTTP/1.1 404 Not Found"
INFO:httpx:HTTP Request: GET http://localhost:8080/v1/meta "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: GET https://pypi.org/pypi/weaviate-client/json "HTTP/1.1 200 OK"
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: intfloat/e5-large
INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: mps


In [8]:
from pydantic import root_validator

class CustomReranker(BaseDocumentCompressor):
    """Document compressor that reranks documents with a FlashRank model.

    Documents are scored against the query by FlashRank and only the first
    ``top_n`` results are kept. Each returned document's metadata gains
    ``id`` (its index in the input sequence) and ``relevance_score``.
    """

    # FlashRank client used to score passages; built by the validator when not supplied.
    client: Ranker
    # Number of documents to return after reranking.
    top_n: int = 3
    # FlashRank model name; "ms-marco-MiniLM-L-12-v2" is used when left unset or None.
    model: Optional[str] = None

    class Config:
        extra = 'forbid'
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_environment(cls, values):
        """Check that flashrank is installed and build ``client`` if needed.

        Raises:
            ImportError: if the ``flashrank`` package cannot be imported.
        """
        try:
            from flashrank import Ranker
        except ImportError:
            raise ImportError(
                "Could not import flashrank python package. "
                "Please install it with `pip install flashrank`."
            )

        # `or` (instead of a .get() default) also covers an explicit model=None,
        # which would otherwise be handed straight to Ranker.
        values["model"] = values.get("model") or "ms-marco-MiniLM-L-12-v2"
        # Respect a caller-supplied client rather than silently replacing it.
        if values.get("client") is None:
            values["client"] = Ranker(model_name=values["model"], cache_dir="reranker")
        return values

    def compress_documents(
        self,
        documents,
        query,
        callbacks = None):
        """Rerank ``documents`` against ``query`` and keep the first ``top_n``.

        Args:
            documents: Sequence of LangChain ``Document`` objects.
            query: Query string the documents are scored against.
            callbacks: Accepted for interface compatibility; not used.

        Returns:
            A list of at most ``top_n`` new ``Document`` objects in the order
            FlashRank returns them. Their metadata adds ``id`` and
            ``relevance_score``.
        """
        documents = list(documents)
        # Nothing to rank: skip the model call entirely.
        if not documents:
            return []

        passages = [
            {"id": i, "text": doc.page_content, "metadata": doc.metadata}
            for i, doc in enumerate(documents)
        ]
        rerank_request = RerankRequest(query=query, passages=passages)
        rerank_response = self.client.rerank(rerank_request)[:self.top_n]
        return [
            Document(
                page_content=r["text"],
                metadata={
                    **r["metadata"],
                    "id": r["id"],
                    "relevance_score": r["score"],
                },
            )
            for r in rerank_response
        ]

#### Build Chatbot

In [9]:
compressor = CustomReranker()

def create_retriever(industries):
    """Build a reranking retriever over the Weaviate ``incident`` index.

    Args:
        industries: ``'all'`` or ``None`` for no filter, a single industry
            name, or an iterable of industry names. Results are restricted to
            objects whose ``industry`` property equals any of the given names.

    Returns:
        A ``ContextualCompressionRetriever``. It fetches candidates with MMR
        search (``fetch_k=20``) and passes them through the module-level
        ``compressor``.
    """
    filters = None
    if industries is not None and industries != 'all':
        # A bare string would otherwise be iterated character by character.
        if isinstance(industries, str):
            industries = [industries]
        filters = Filter.any_of([Filter.by_property("industry").equal(industry) for industry in industries])
    db = WeaviateVectorStore(client=weaviate_client, index_name="incident", text_key="text", embedding=embeddings)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor = compressor,
        base_retriever = db.as_retriever(search_type="mmr", search_kwargs={"fetch_k": 20, 'filters': filters})
    )
    return compression_retriever

def get_documents_ids(retrieved_docs):
    """Collect the ``incident_id`` of each retrieved document as an ``int``.

    Returns ``None`` (not an empty list) when ``retrieved_docs`` is empty or ``None``.
    """
    if not retrieved_docs:
        return None
    incident_ids = []
    for retrieved in retrieved_docs:
        incident_ids.append(int(retrieved.metadata['incident_id']))
    return incident_ids

def get_documents_by_ids(ids):
    """Fetch incident records from MongoDB whose ``accident_id`` is in ``ids``.

    Args:
        ids: Iterable of accident ids (e.g. from ``get_documents_ids``), or ``None``.

    Returns:
        The matching documents as dicts, ordered to follow ``ids`` so the
        retriever's ranking is preserved. Returns ``[]`` for empty/``None``
        input, and also on a database error, which is logged instead of raised.
    """
    if not ids:
        return []
    ids = list(ids)
    try:
        # The context manager closes the client even if the query fails, and
        # avoids the unbound-`client` error the old `finally` hit when
        # MongoClient() itself raised.
        with MongoClient(MONGO_URI) as client:
            collection = client[DATABASE_NAME][COLLECTION_NAME]
            documents = list(collection.find({"accident_id": {"$in": ids}}))
    except Exception:
        logging.getLogger(__name__).exception("Failed to fetch documents %s from MongoDB", ids)
        return []

    # `$in` returns matches in storage order; restore the caller's (relevance) order,
    # using the first occurrence of any repeated id.
    rank = {}
    for position, doc_id in enumerate(ids):
        rank.setdefault(doc_id, position)
    documents.sort(key=lambda doc: rank.get(doc.get("accident_id"), len(rank)))
    return documents

In [10]:
# industries = ['processing of metals']
# filters = Filter.all_of([Filter.by_property("industry").equal(industry) for industry in industries])
# db = WeaviateVectorStore(client=weaviate_client, index_name="incident", text_key="text", embedding=embeddings)
# compression_retriever = ContextualCompressionRetriever(
#     base_compressor = compressor,
#     base_retriever = db.as_retriever(search_type="mmr", search_kwargs={"fetch_k": 20, 'filters': filters})
# )

# compression_retriever.invoke('what are probable causes of a fire in an industrial plant')

In [11]:
from langchain_openai.chat_models import ChatOpenAI
from operator import itemgetter
from langchain.chains.conversation.memory import ConversationSummaryMemory
from langchain.chains import ConversationChain
from langchain_core.runnables import RunnableBranch, RunnablePassthrough, RunnableParallel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.prompts import ChatPromptTemplate
import json
from datetime import datetime

# Groq exposes an OpenAI-compatible endpoint, so the OpenAI chat client is reused.
# The key is read from the environment: a key committed in a notebook leaks to
# everyone who can see the file (any previously hard-coded key must be revoked).
llm = ChatOpenAI(
  openai_api_base="https://api.groq.com/openai/v1/",
  model = "llama-3.3-70b-versatile",
  temperature=0.7,
  api_key=os.environ["GROQ_API_KEY"],
)

def format_chat_history(history):
  """Render chat messages as one tab-indented ``<ClassName>: <content>`` line each.

  Args:
    history: Iterable of message objects exposing ``.content``. The message's
      class name (e.g. ``HumanMessage``) is used as the speaker label.

  Returns:
    The formatted transcript, or ``""`` for an empty history.
  """
  lines = []
  for message in history:
    # Newlines would break the one-line-per-message layout. Replace them with a
    # space (not "") so words on either side of a line break are not glued together.
    content = message.content.replace("\n", " ")
    lines.append(f"\t{type(message).__name__}: {content}\n")
  return "".join(lines)

# Prompt that rewrites a follow-up question into a standalone one using the chat
# history. Its output is sent straight to the retriever, so the model is told to
# emit only the question.
CONTEXT_TEMPLATE = """
<|start_header_id|>system<|end_header_id|>
Given a discussion history and a follow-up question, rewrite the follow-up question to be fully self-contained and understandable without the context of the previous conversation. Keep it as close as possible to the original meaning but include any relevant details from the history if they add clarity or context. If no additional context is needed, leave the question unchanged. Output only the rewritten question, with no preamble, quotes, or explanation.
Discussion history:
{chat_history}
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Question: {question}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
Standalone question:"""

CONTEXT_PROMPT = ChatPromptTemplate.from_template(CONTEXT_TEMPLATE)

# Answering prompt. Doubled braces are literal braces for ChatPromptTemplate; the
# only template variables are {context} and {question}. The model must return
# bare JSON because the chain parses its output with json.loads.
SYSTEM_TEMPLATE = """
<|start_header_id|>system<|end_header_id|>
You are IncidentNavigator, an AI designed to assist in managing and understanding incidents using a dataset of incident records. Your role is to provide precise, concise, and clear responses based on the context of the documents you receive. If a question falls outside of the information available in the provided context, you should clearly state that you cannot provide an answer but will offer the best response based on what is available.
The documents you process include the following fields:
- accident_id: Unique identifier for each incident.
- event_type: Category of the incident (e.g., fire, collision).
- industry_type: The sector or industry where the incident occurred (e.g., construction, transportation).
- accident_title: A brief, descriptive title for the accident.
- start_date: The date and time the incident began.
- finish_date: The date and time the incident ended or was resolved.
- accident_description: A detailed account of how the accident occurred.
- causes_of_accident: Factors or conditions leading to the incident.
- consequences: Outcomes or impacts of the incident (e.g., injuries, damage).
- emergency_response: Immediate actions taken to manage the incident.
- lesson_learned: Insights or recommendations for future prevention.
- url: Reference link to the document webpage.
When answering questions, follow these guidelines:
- Context Provided: If the context includes information related to these fields, provide a direct and detailed response based on the relevant data.
- Context Missing or Insufficient: If no context or relevant information is provided:
  - State that you cannot provide a definitive answer because the requester does not have sufficient privileges or the information is unavailable.
  - Do not speculate but offer a general response or guidance based on the type of question, when possible.
Context: {context}
You must only output JSON conforming to the schema below. Output the raw JSON object and nothing else: no markdown code fences and no text before or after it.
{{
  "properties": {{
    "answer": {{
      "description": "Detailed answer or response based on the provided context or available data",
      "type": "string",
      "required": true
    }},
    "references": {{
      "description": "References to context documents",
      "type": "array",
      "items": {{
        "type": "object",
        "properties": {{
          "accident_id": {{ "type": "string" }},
          "event_type": {{ "type": "string" }},
          "industry_type": {{ "type": "string" }},
          "accident_title": {{ "type": "string" }},
          "start_date": {{ "type": "string" }},
          "finish_date": {{ "type": "string" }},
          "url": {{ "type": "string" }}
        }}
      }},
      "required": true
    }}
  }}
}}
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Question: {question}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""

SYSTEM_PROMPT = ChatPromptTemplate.from_template(SYSTEM_TEMPLATE)

class CustomJSONEncoder(json.JSONEncoder):
  """JSON encoder that serialises ``datetime`` values as ISO-8601 strings."""

  def default(self, obj):
      if not isinstance(obj, datetime):
          # Unsupported type: let the base encoder raise its usual TypeError.
          return super().default(obj)
      return obj.isoformat()
  
def retrieve(data):
  """Build the chain payload with retrieved incident records as ``context``.

  Args:
    data: Dict holding at least ``question`` (the standalone query) and
      ``industries`` (the filter accepted by ``create_retriever``).

  Returns:
    A new dict with every key of ``data`` except ``industries``, plus
    ``context``: the matching MongoDB records serialised one JSON object per
    paragraph, with ``_id`` removed. Neither ``data`` nor the fetched records
    are modified.
  """
  retriever = create_retriever(data['industries'])
  docs = retriever.invoke(data['question'])
  ids = get_documents_ids(docs)
  # get_documents_ids returns None when nothing was retrieved; skip the DB round trip.
  records = get_documents_by_ids(ids) if ids else []
  context = "\n\n".join(
      json.dumps({key: value for key, value in record.items() if key != "_id"}, cls=CustomJSONEncoder)
      for record in records
  )
  result = {key: value for key, value in data.items() if key != "industries"}
  result['context'] = context
  return result

def get_industry(placeholder = None):
  """Return the industries the current requester may query.

  Stub: ``placeholder`` (the chain input) is ignored and a fixed list is
  returned. A fresh list is built on every call.
  """
  allowed_industries = ['processing of metals', 'power generation']
  return allowed_industries

In [17]:
from langchain.memory import ConversationBufferWindowMemory
from langchain_core.runnables import RunnablePassthrough, RunnableLambda

def get_json_from_string(json_string):
    """Parse a JSON value from raw LLM output, tolerating common wrappers.

    The text is first parsed as-is, after stripping a surrounding markdown
    code fence (```json ... ```). If that fails, the substring from the first
    ``{`` to the last ``}`` is tried, which covers prose before or after an
    object.

    Args:
        json_string: Model output text, or ``None``.

    Returns:
        The decoded JSON value, or ``None`` (after printing the error) when
        the input is empty or nothing parseable is found.
    """
    if not json_string or not json_string.strip():
        print("Error decoding JSON: empty input")
        return None

    text = json_string.strip()
    fenced = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.DOTALL | re.IGNORECASE)
    if fenced:
        text = fenced.group(1)

    try:
        return json.loads(text)
    except json.JSONDecodeError as first_error:
        error = first_error

    # Fall back to the outermost braces when the object is surrounded by prose.
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        try:
            return json.loads(text[start:end + 1])
        except json.JSONDecodeError as second_error:
            error = second_error

    print(f"Error decoding JSON: {error}")
    return None

# Window memory of recent turns. return_messages=True yields message objects,
# which format_chat_history expects; the default returns one flat string, and
# iterating it would walk characters.
memory = ConversationBufferWindowMemory(return_messages=True)

# Step 1: rewrite the user's question into a standalone one using the chat history.
runnable_context = RunnablePassthrough.assign(
    chat_history = RunnableLambda(memory.load_memory_variables) | itemgetter("history") | format_chat_history
) | CONTEXT_PROMPT | llm | {'question': StrOutputParser()}

# Step 2: attach the allowed industries, then swap them for retrieved context.
runnable_system = RunnablePassthrough.assign(
    industries = RunnableLambda(get_industry)
) | retrieve

rag_chain = runnable_context | runnable_system | SYSTEM_PROMPT | llm | StrOutputParser() | get_json_from_string


def ask(question):
    """Answer ``question`` with ``rag_chain`` and record the turn in ``memory``.

    Returns the chain's parsed JSON, or ``None`` when the model output could
    not be parsed. Only parsed (dict) answers are saved to memory, so later
    follow-up questions can be rewritten against them. Without this step the
    memory stayed empty forever.
    """
    result = rag_chain.invoke({"question": question})
    if isinstance(result, dict):
        memory.save_context({"input": question}, {"output": str(result.get("answer", ""))})
    return result


ask("what are ways to prevent a fire in an industrial plant")

INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: GET http://localhost:8080/v1/schema/Incident "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: GET http://localhost:8080/v1/schema/Incident "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"


Error decoding JSON: Expecting value: line 1 column 1 (char 0)


#### Evaluation

In [13]:
# from datasets import load_dataset
# dataset = load_dataset(
#     "explodinggradients/amnesty_qa",
#     "english_v3",
#     trust_remote_code=True
# )

In [14]:
# dataset["eval"].to_pandas()

In [15]:
# from ragas.metrics import LLMContextRecall, Faithfulness, FactualCorrectness, SemanticSimilarity
# from ragas import evaluate