In [105]:
from langchain_community.chat_models import ChatOllama
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.runnables import RunnableLambda
from bs4 import BeautifulSoup
import re

In [99]:
# Sentence-embedding model handed to the Chroma vector store as its embedding function.
model_name = "dunzhang/stella_en_400M_v5"
# NOTE(review): 'cuda' means this cell needs a GPU, and trust_remote_code=True
# allows Python code shipped in the model repository to run locally — confirm
# the checkpoint is trusted.
model_kwargs = dict(device="cuda", trust_remote_code=True)

embedding_model = HuggingFaceEmbeddings(model_kwargs=model_kwargs, model_name=model_name)

Some weights of the model checkpoint at dunzhang/stella_en_400M_v5 were not used when initializing NewModel: ['new.pooler.dense.bias', 'new.pooler.dense.weight']
- This IS expected if you are initializing NewModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing NewModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [100]:
# Chroma collection "faq", persisted under ./db (resolved against the current
# working directory) and configured with the "cosine" HNSW space.
vector_store = Chroma(
    embedding_function=embedding_model,
    persist_directory="./db",
    collection_name="faq",
    collection_metadata={"hnsw:space": "cosine"},
)

In [11]:
# NOTE(review): destructive — presumably empties the persisted "faq" collection.
# Run only when re-indexing from scratch, then re-run the ingestion cells.
vector_store.reset_collection()

In [10]:
from langchain_text_splitters import CharacterTextSplitter

# Shared splitter for both ingestion cells: breaks text on newlines with a
# 4000-character chunk size and a 100-character overlap.
text_splitter = CharacterTextSplitter(
    chunk_size=4000,
    chunk_overlap=100,
    separator="\n",
)

In [None]:
from loaders.HTMLDirectory import HTMLDirectoryLoader

def faq_html_parser(html):
    """Extract the question and answer text from one FAQ article page.

    Args:
        html: Raw HTML markup of a knowledge-base article.

    Returns:
        The stripped question text and stripped answer text joined by a
        newline, with every run of consecutive newlines collapsed to one, or
        ``None`` when the page has no ``kb_article_question`` or no
        ``kb_article_text`` element.
    """
    # Name the parser explicitly: without it BeautifulSoup picks whichever
    # parser happens to be installed (and warns about guessing), so the same
    # page could be parsed differently on different machines.
    soup = BeautifulSoup(html, "html.parser")
    question = soup.find(id="kb_article_question")
    answer = soup.find(id="kb_article_text")

    # find() returns None for a missing element; such pages are not FAQ articles.
    if question is None or answer is None:
        return None

    qa = f"{question.text.strip()}\n{answer.text.strip()}"
    return re.sub(r'\n{2,}', '\n', qa)

# Load the archived FAQ pages through faq_html_parser, chunk them, and index
# the chunks in the vector store.
faq_html_loader = HTMLDirectoryLoader("../web-scraper/faq-archive", faq_html_parser)
faq_documents = [document for document in faq_html_loader.lazy_load()]
faq_split_documents = text_splitter.split_documents(faq_documents)
vector_store.add_documents(faq_split_documents)

In [None]:
from loaders.JSONFile import JSONFileLoader

def json_parser(d):
    """Map one scraped JSON record onto document fields.

    Args:
        d: Mapping that provides the page text under ``"extracted"`` and the
            page address under ``"url"``.

    Returns:
        A dict whose ``"page_content"`` is the extracted text and whose
        ``"metadata"`` holds the URL under the ``"source"`` key.
    """
    page_content = d["extracted"]
    metadata = {"source": d["url"]}
    return {"page_content": page_content, "metadata": metadata}

# Load the scraped pages listed in urls.json, chunk them, and index the chunks.
json_file_loader = JSONFileLoader("../web-scraper/data/urls.json", json_parser)
# Materialise the documents (as the FAQ cell does): lazy_load() presumably
# returns a one-shot iterator, which would leave `json_documents` empty for
# any later cell or re-run that reads it again.
json_documents = list(json_file_loader.lazy_load())
json_split_documents = text_splitter.split_documents(json_documents)
vector_store.add_documents(json_split_documents)

In [101]:
# Chat model shared by question rewriting, answering, and the needs-sources
# classifier. temperature=0 presumably keeps replies as repeatable as possible.
llm = ChatOllama(model="gemma2", temperature=0)

In [128]:
from langchain.chains import create_history_aware_retriever
from langchain_core.prompts import MessagesPlaceholder
from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor

from langchain.globals import set_verbose
from langchain.callbacks.tracers import ConsoleCallbackHandler
# Switch on LangChain's global verbose flag while developing the chain.
set_verbose(True)


# Vector-store retriever limited to the two best matches per query.
retriever = vector_store.as_retriever(search_kwargs={"k": 2})

# Ensemble with a single member for now; further retrievers can be appended
# to the list.
combined_retriever = EnsembleRetriever(retrievers=[retriever])

# NOTE(review): compression_retriever is constructed but nothing in this
# notebook passes it on — the history-aware retriever is built from
# combined_retriever. Confirm whether compression was meant to be used.
compressor = LLMChainExtractor.from_llm(llm)
compression_retriever = ContextualCompressionRetriever(
    base_retriever=combined_retriever,
    base_compressor=compressor,
)

# Instruction for the rewriting step: turn a follow-up question into a
# standalone one before it is sent to the retriever.
contextualize_q_system_prompt = (
    "Given a chat history and the latest user question "
    "which might reference context in the chat history, "
    "formulate a standalone question which can be understood "
    "without the chat history. Do NOT answer the question. "
    "just reformulate it if needed and otherwise return it as is. "
    "if there is no chat history, return the input as is. "
    "if the input is a greeting, return the input as is. "
)

# Message order: instruction, prior turns, then the newest user input.
contextualize_q_prompt = ChatPromptTemplate.from_messages([
    ("system", contextualize_q_system_prompt),
    MessagesPlaceholder(variable_name="chat_history"),
    ("human", "{input}"),
])

history_aware_retriever = create_history_aware_retriever(
    llm=llm,
    retriever=combined_retriever,
    prompt=contextualize_q_prompt,
)

# Persona and answering rules for the final response. Every rule ends with a
# newline: adjacent string literals are concatenated verbatim, so a missing
# separator would glue two sentences together (e.g. "UH Manoa.Answer ...").
system_prompt = (
    "Your name is Hoku. You are an assistant for answering questions about UH Manoa.\n"
    "Answer the question given ONLY the provided context.\n"
    "If the answer DOES NOT appear in the context, say 'I'm sorry I don't know the answer to that'.\n"
    "Use three sentences maximum and keep the answer concise and speak nicely.\n"
    "DO NOT mention the context, users do not see it.\n"
    "if the user greets you, greet them back nicely"
)

# Message order: rules, prior turns, then the retrieved context together with
# the current question.
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "context:{context}\n\nquestion: {input}"),
    ]
)

# Retrieval (history-aware) feeding the document-stuffing answer chain.
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

# In-memory chat histories keyed by session id.
store = {}

def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history for ``session_id``, creating it on first use.

    Args:
        session_id: Key identifying one conversation.

    Returns:
        The history object stored for that key in the module-level ``store``;
        a new ``ChatMessageHistory`` is registered when the key is unknown.
    """
    try:
        return store[session_id]
    except KeyError:
        history = store[session_id] = ChatMessageHistory()
        return history

# Few-shot examples for the needs-sources classifier: small talk maps to
# "no", factual questions map to "yes".
sources_examples = [
    {"input": text, "output": label}
    for text, label in (
        ("Hi Hoku!", "no"),
        ("How are you?", "no"),
        ("What is duo mobile used for?", "yes"),
        ("what specs should i have for a mac laptop?", "yes"),
        ("Thank you!", "no"),
    )
]

# Each example is rendered as a human turn followed by the expected AI label.
example_prompt = ChatPromptTemplate.from_messages([
    ("human", "{input}"),
    ("ai", "{output}"),
])

few_shot_prompt = FewShotChatMessagePromptTemplate(
    examples=sources_examples,
    example_prompt=example_prompt,
    input_variables=["input"],
)

# Classifier prompt: instruction, the worked examples, then the real input.
final_prompt = ChatPromptTemplate.from_messages([
    ("system", "Your job is to classify a user input as needing sources 'yes' or not needing sources 'no'."),
    few_shot_prompt,
    ("human", "{input}"),
])

def requires_source(inp: dict):
    """Ask the LLM classifier whether a user input warrants source links.

    Args:
        inp: Prompt variables for ``final_prompt`` (an ``"input"`` entry).

    Returns:
        ``True`` when the model's reply, lower-cased, contains ``"yes"``;
        ``False`` otherwise.
    """
    classifier = final_prompt | llm
    verdict = classifier.invoke(inp).content
    return verdict.lower().find("yes") != -1


def add_sources_to_response_if_needed(inp: dict) -> dict:
    """Append source links to the answer when the question calls for them.

    Args:
        inp: Chain output with the user's ``"input"``, the generated
            ``"answer"`` and, normally, ``"context"`` (the retrieved
            documents, each exposing a ``metadata`` dict).

    Returns:
        ``inp`` itself. It is left untouched when the classifier says no
        sources are needed or when no retrieved document carries a
        ``"source"``; otherwise its ``"answer"`` is replaced by the stripped
        answer followed by the list of distinct source links.
    """
    if not requires_source({"input": inp["input"]}):
        return inp

    # dict.fromkeys de-duplicates while preserving retrieval order, so the
    # links come out in a stable order instead of set iteration order.
    # Documents without a "source" entry are skipped rather than raising.
    sources = list(dict.fromkeys(
        doc.metadata["source"]
        for doc in (inp.get("context") or [])
        if doc.metadata.get("source")
    ))

    # Nothing to link to: avoid a dangling "check out these links" header.
    if not sources:
        return inp

    sources_text = "\n".join(sources)
    inp["answer"] = f"{inp['answer'].strip()}\n\nFor more information, check out these links\n{sources_text}"
    return inp

# Wrap the RAG chain so each session's chat history is loaded before the call
# and updated after it.
conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    history_messages_key="chat_history",
    input_messages_key="input",
    output_messages_key="answer",
)

# Post-processing step that may append source links to the answer.
append_sources_step = RunnableLambda(add_sources_to_response_if_needed)
conversational_rag_chain_with_sources = conversational_rag_chain | append_sources_step

In [132]:
# Forget all chat histories: rebinding the module-level `store` makes
# get_session_history start every session from an empty history again.
store = {}

In [None]:
# Minimal console chat: one question per line; an empty line, EOF, or a
# kernel interrupt at the prompt ends the conversation.
while True:
    try:
        user_input = input()
    except (EOFError, KeyboardInterrupt):
        # Leave quietly instead of dumping a traceback into the cell output.
        break

    if not user_input:
        break

    # Echo the question so the cell output reads as a full transcript.
    print(user_input)

    answer = conversational_rag_chain_with_sources.invoke(
        {"input": user_input},
        config={
            "configurable": {"session_id": "1"},
            # 'callbacks': [ConsoleCallbackHandler()]
        },
    )

    print(answer["answer"])
    print()