In [2]:
import warnings
warnings.filterwarnings("ignore")

import os 
from pathlib import Path 
from langchain_ollama import OllamaEmbeddings
import pandas as pd
import matplotlib.pyplot as plt 

import faiss
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain_ollama import ChatOllama
from langchain_ollama.llms import OllamaLLM
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
import torch
import re
from typing import List
from langchain_core.runnables import chain

from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

import gradio as gr
from functools import partial
from rich.console import Console
from rich.style import Style
from rich.theme import Theme

console = Console()
base_style = Style(color="#76B900", bold=True)
pprint = partial(console.print, style=base_style)

## Load Dataset from CSV

In [3]:
data_dir = Path.cwd().parent / 'data'
file_name = 'prelim_data_cleaned.csv'
data = pd.read_csv(data_dir / file_name)
data.head()

Unnamed: 0,Question,Answer,question_len,answer_len
0,What is radiation therapy?,Radiation therapy (also called radiotherapy) i...,4,48
1,How is radiation therapy given?,Radiation therapy can be external beam or inte...,5,65
2,Who gets radiation therapy?,Many people with cancer need treatment with ra...,4,34
3,What does radiation therapy do to cancer cells?,"Given in high doses, radiation kills or slows ...",8,87
4,How long does radiation therapy take to work?,Radiation therapy does not kill cancer cells r...,8,35


## Define Retrieval Utilities

In [4]:
def init_embedder(embedder_config, device):
    embedder_params = embedder_config['params']
    if embedder_config['backend'] == "HF":
        embedder = HuggingFaceEmbeddings(
            model_name=embedder_params['model_name'],
            model_kwargs={'device': device},
            encode_kwargs=embedder_params.get('encode_kwargs', {})
        )
        print(f"Embedder Initialized with {embedder_params['model_name']}")
    elif embedder_config['backend'] == "OLLAMA":
        embedder = OllamaEmbeddings(model=embedder_params['model_name'])
        print(f"Embedder Initialized with {embedder_params['model_name']}")
    else:
        raise NotImplementedError("Embedder backend not supported")
    return embedder


def init_vectorstore(embedder):
    # Create the FAISS index for storing embeddings
    embedding_size = len(embedder.embed_query("hello world"))  # Example to get embedding size
    index = faiss.IndexFlatL2(embedding_size)
    
    vector_store = FAISS(
        embedding_function=embedder,
        index=index,
        docstore=InMemoryDocstore(),
        index_to_docstore_id={}
    )
    return vector_store

def preprocess_text(text):
    """
    Preprocess the input text for embedding by normalizing and cleaning.
    :param text: input string to preprocess
    :return: cleaned and preprocessed string
    """
    text = text.lower()  # Convert to lowercase
    text = re.sub(r'[^\w\s]', '', text)  # Remove punctuation
    text = re.sub(r'\s+', ' ', text).strip()  # Remove extra whitespace
    return text

def create_documents_from_questions(questions, question_ids):
    for index in range(len(questions)):
        questions[index] = preprocess_text(questions[index])
    
    documents = []
    for text, doc_id in zip(questions, question_ids):
        _ = Document(
            page_content=text,
            metadata={"id" : doc_id}
        )
        documents.append(_)        
    return documents

## Create Retrieval Chain

In [5]:
# Define Configs and Settings For Retrieval
SCORE_THRESHOLD = 0.5
TOPK=5

EMBEDDER_CONFIG = {
    "backend" : "HF", 
    "params" : {
        "model_name" : "sentence-transformers/all-mpnet-base-v2",
        "encode_kwargs" : {'normalize_embeddings': False}
    }
}

device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [6]:
# Intialize Embedder and VectorStore
embedder = init_embedder(embedder_config=EMBEDDER_CONFIG, device=device)
vectorstore = init_vectorstore(embedder)

Embedder Initialized with sentence-transformers/all-mpnet-base-v2


In [7]:
# Add Documents to VectorStore
documents = create_documents_from_questions(data.Question.tolist(), data.index.tolist())
vectorstore.add_documents(documents=documents, ids=data.index.tolist())
print(f"{len(documents)} Documents added to Vector Store")

99 Documents added to Vector Store


In [8]:
# Define Custom Functions as Runnables to use in retrieval chain
@chain
def retriever(query: dict) -> List[Document]:
    """Custom Retriever Logic to filter based on SIM THRESHOLD"""
    query = preprocess_text(query.get('query',''))
    docs, scores = zip(*vectorstore.similarity_search_with_relevance_scores(query, k=TOPK))
    result = []
    for doc, score in zip(docs, scores):
        if score > SCORE_THRESHOLD:
            doc.metadata["score"] = score
            result.append(doc)
    return result

@chain
def format_retrieved_docs(documents: List[Document]) -> str:
    """Context Formatter"""
    docs = []
    for doc in documents:
        score = doc.metadata.get('score')
        doc_index = doc.metadata.get('id')
        answer = data.iloc[doc_index].Answer
        docs.append((doc.page_content, answer, score))

    context_text = "\n".join([f"Q: {ctx[0]}\nA: {ctx[1]}" for ctx in docs])
    return context_text

In [9]:
ret_chain = retriever | format_retrieved_docs

## Create a Chain to fill User's InfoBase

In [10]:
from langchain.output_parsers import PydanticOutputParser
from langchain.schema.runnable.passthrough import RunnableAssign
from langchain_core.output_parsers import StrOutputParser


def RExtract(pydantic_class, llm, prompt):
    '''
    Runnable Extraction module
    Returns a knowledge dictionary populated by slot-filling extraction
    '''
    parser = PydanticOutputParser(pydantic_object=pydantic_class)
    instruct_merge = 
    ({'format_instructions' : lambda x: parser.get_format_instructions()})
    def preparse(string):
        if '{' not in string: string = '{' + string
        if '}' not in string: string = string + '}'
        string = (string
            .replace("\\_", "_")
            .replace("\n", " ")
            .replace("\]", "]")
            .replace("\[", "[")
        )
        # print(string)  ## Good for diagnostics
        return string
    return instruct_merge | prompt | llm | preparse | parser

class KnowledgeBase(BaseModel):
    name: str = Field(default="unknown", description="Name of the user, first name and last name")
    age : str = Field(default="unknown", description="Age of the user")
    gender: str = Field(default="unknown", description="Gender of the user")
    query: str = Field(default="unknown", description="Detailed User's Query to answer. It should be framed in Question format.")
    summary: str = Field('unknown', description="Running summary of conversation. Update this with new input")
    response: str = Field('unknown', description="An ideal response to the user based on their new message")

parser_prompt = ChatPromptTemplate.from_template(
    "You are chatting with a user. The user just responded ('input'). Please update the knowledge base."
    "Record your response in the 'response' tag to continue the conversation."
    "Do not hallucinate any details, and make sure the knowledge base is not redundant."
    "Update the entries frequently to adapt to the conversation flow."
    "\n{format_instructions}"
    "\n\nOLD KNOWLEDGE BASE: {know_base}"
    "\n\nNEW MESSAGE: {input}"
    "\n\nNEW KNOWLEDGE BASE:"
)


llm = OllamaLLM(model='llama3.1:70b', temperature=0.4)
instruct_llm = llm #| StrOutputParser
extractor = RExtract(KnowledgeBase, instruct_llm, parser_prompt)
info_update = RunnableAssign({'know_base' : extractor})

## Initialize the knowledge base and see what you get
state = {'know_base' : KnowledgeBase()}
state['input'] = "What is radiation therapy and how can it be used to treat cancer?"
state = info_update.invoke(state)
print(state)

{'know_base': KnowledgeBase(name='unknown', age='unknown', gender='unknown', query='What is radiation therapy and how can it be used to treat cancer?', summary='User asked about radiation therapy and its use in treating cancer.', response='Radiation therapy, also known as radiotherapy, is a type of cancer treatment that uses high-energy rays to kill or shrink cancer cells. It can be used alone or in combination with other treatments, such as surgery or chemotherapy. Radiation therapy works by damaging the DNA of cancer cells, which prevents them from growing and dividing.'), 'input': 'What is radiation therapy and how can it be used to treat cancer?'}


In [11]:
from operator import itemgetter
itemgetter(state.get('know_base'))

operator.itemgetter(KnowledgeBase(name='unknown', age='unknown', gender='unknown', query='What is radiation therapy and how can it be used to treat cancer?', summary='User asked about radiation therapy and its use in treating cancer.', response='Radiation therapy, also known as radiotherapy, is a type of cancer treatment that uses high-energy rays to kill or shrink cancer cells. It can be used alone or in combination with other treatments, such as surgery or chemotherapy. Radiation therapy works by damaging the DNA of cancer cells, which prevents them from growing and dividing.'))

In [12]:
@chain
def extract_query(state):
    return state.get('know_base').query

internal_chain = (
    info_update
    | RunnableAssign({"query" : extract_query})
    | RunnableAssign({'context' : ret_chain})
)

In [13]:
state = {'know_base' : KnowledgeBase()}
state['input'] = "I am Ayush Agarwal. Who am I?"
state = internal_chain.invoke(state)

In [14]:
state['output'] = ""
state.keys()

dict_keys(['know_base', 'input', 'query', 'context', 'output'])

In [15]:
external_prompt = ChatPromptTemplate.from_messages([
    ("system", (
        "You are a chatbot for NU Medicine, and you are providing users cancer treatment assistance."
        " Please chat with them! Stay concise and clear!"
        " Your running knowledge base is: {know_base}."
        " This is for you only; Do not mention it!"
        " \nUsing that, we retrieved the following: {context}\n"
        "\nHere is the query that user want to get answered: {query}\n"
        "\nKeep asking follow up questions, until user provides basic info like name, age and gender\n"
        "\nModerate your tone according to user's age and gender\n"
        "If they provide info and the retrieval fails, ask to confirm their name, age, gender and query"
        " Do not ask them any other personal info."
    )),
    ("assistant", "{output}"),
    ("user", "{input}"),
])

external_chain = external_prompt | OllamaLLM(model='llama3.1:70b', temperature=0.4)

In [16]:
external_chain.invoke(state)

"Hello Ayush! You've just told me your name, but I don't have any information about you beyond that. Could you please provide more context or clarify what you mean by 'who am I'? Are you looking for general information about yourself or is there something specific you're trying to understand?"

In [17]:
state = {'know_base' : KnowledgeBase()}

def chat_gen(message, history=[], return_buffer=True):

    ## Pulling in, updating, and printing the state
    global state
    state['input'] = message
    state['history'] = history
    state['output'] = "" if not history else history[-1][1]

    ## Generating the new state from the internal chain
    state = internal_chain.invoke(state)
    #print("State after chain run:")
    #pprint({k:v for k,v in state.items() if k != "history"})
    
    ## Streaming the results
    buffer = ""
    for token in external_chain.stream(state):
        buffer += token
        yield buffer if return_buffer else token

def queue_fake_streaming_gradio(chat_stream, history = [], max_questions=8):

    ## Mimic of the gradio initialization routine, where a set of starter messages can be printed off
    for human_msg, agent_msg in history:
        if human_msg: print("\n[ Human ]:", human_msg)
        if agent_msg: print("\n[ Agent ]:", agent_msg)

    ## Mimic of the gradio loop with an initial message from the agent.
    for _ in range(max_questions):
        message = input("\n[ Human ]: ")
        print("\n[ Agent ]: ")
        history_entry = [message, ""]
        for token in chat_stream(message, history, return_buffer=False):
            print(token, end='')
            history_entry[1] += token
        history += [history_entry]
        print("\n")

## history is of format [[User response 0, Bot response 0], ...]
chat_history = [[None, "Hello! I'm your NU Medicine Agent! How can I help you?"]]

## Simulating the queueing of a streaming gradio interface, using python input
# queue_fake_streaming_gradio(
#     chat_stream = chat_gen,
#     history = chat_history
# )

In [18]:
state = {'know_base' : KnowledgeBase()}

chatbot = gr.Chatbot(value=[[None, "Hello! I'm your NU Medicine Agent! How can I help you?"]])
demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue().launch(debug=True, share=True)

Running on local URL:  http://127.0.0.1:7860


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Running on public URL: https://8a67ac0cb0de98e8cf.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://8a67ac0cb0de98e8cf.gradio.live
