In [1]:
import warnings
warnings.filterwarnings("ignore")

import os 
from pathlib import Path 
from langchain_ollama import OllamaEmbeddings
import pandas as pd
import matplotlib.pyplot as plt 
import numpy as np 

import faiss
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain_ollama import ChatOllama
from langchain_ollama.llms import OllamaLLM
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
import torch
import re
from typing import List
from langchain_core.runnables import chain
from langchain_core.runnables import RunnablePassthrough
import json
from langchain.output_parsers import PydanticOutputParser
from langchain.schema.runnable.passthrough import RunnableAssign
from langchain_core.output_parsers import StrOutputParser

import gradio as gr
from functools import partial
from rich.console import Console
from rich.style import Style
from rich.theme import Theme

# Shared rich console used for styled output in this notebook.
console = Console()
# Default text style applied by `pprint` below: bold, colour #76B900.
base_style = Style(color="#76B900", bold=True)
# Convenience printer: console.print with `style=base_style` pre-bound.
# NOTE(review): the name matches the stdlib `pprint` module; no import of it is visible here, but keep the clash in mind.
pprint = partial(console.print, style=base_style)

In [2]:
# Resolve the dataset location relative to the current working directory:
# <parent of cwd>/data/data_files. The result therefore depends on where the kernel was started.
data_dir = Path.cwd().parent / 'data/data_files'
file_name = 'capstone_final_data_v1.csv'
# NOTE(review): the preview shows an "Unnamed: 0" column, so the CSV presumably carries a saved
# index column; confirm before switching to index_col=0, since later cells use data.index as document ids.
data = pd.read_csv(data_dir / file_name)
# Preview the first rows (Question / Answer columns).
data.head()

Unnamed: 0,Question,Answer
0,What is radiation therapy?,Radiation therapy (also called radiotherapy) i...
1,How is radiation therapy given?,Radiation therapy can be external beam or inte...
2,Who gets radiation therapy?,Many people with cancer need treatment with ra...
3,What does radiation therapy do to cancer cells?,"Given in high doses, radiation kills or slows ..."
4,How long does radiation therapy take to work?,Radiation therapy does not kill cancer cells r...


In [3]:
def init_embedder(embedder_config, device):
    """
    Build the embedding model described by ``embedder_config``.
    :param embedder_config: dict with a 'backend' key ("HF" or "OLLAMA") and a 'params'
        dict holding 'model_name' (plus an optional 'encode_kwargs' entry for the HF backend)
    :param device: device forwarded to the HuggingFace model through ``model_kwargs``;
        not used by the OLLAMA backend
    :return: the initialized embedder object
    :raises NotImplementedError: if the backend is neither "HF" nor "OLLAMA"
    """
    params = embedder_config['params']
    backend = embedder_config['backend']

    if backend == "HF":
        model_name = params['model_name']
        embedder = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={'device': device},
            encode_kwargs=params.get('encode_kwargs', {}),
        )
    elif backend == "OLLAMA":
        model_name = params['model_name']
        embedder = OllamaEmbeddings(model=model_name)
    else:
        raise NotImplementedError("Embedder backend not supported")

    # Both supported backends report the same confirmation line.
    print(f"Embedder Initialized with {model_name}")
    return embedder


def init_vectorstore(embedder):
    """
    Create an empty FAISS vector store built on a flat L2 index.
    :param embedder: embedding object exposing ``embed_query``; it is also installed as the
        store's embedding function
    :return: FAISS vector store with an empty in-memory docstore and no id mappings
    """
    # Embed a throwaway query only to discover the embedding dimensionality.
    probe_vector = embedder.embed_query("hello world")
    flat_index = faiss.IndexFlatL2(len(probe_vector))

    return FAISS(
        embedding_function=embedder,
        index=flat_index,
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
    )

def preprocess_text(text):
    """
    Normalize a string before it is embedded.
    :param text: raw input string
    :return: lowercased string with punctuation removed and whitespace collapsed/trimmed
    """
    # Applied in order: strip punctuation first, then squeeze whitespace runs into one space.
    substitutions = (
        (r'[^\w\s]', ''),
        (r'\s+', ' '),
    )
    cleaned = text.lower()
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()

def create_documents_from_questions(questions, question_ids):
    """
    Wrap each question in a Document after text preprocessing.

    The caller's ``questions`` sequence is not modified; preprocessing is applied to
    each item as its Document is built.
    :param questions: iterable of raw question strings
    :param question_ids: iterable of ids, paired positionally with ``questions``
    :return: list of Documents whose page_content is the preprocessed question and
        whose metadata is ``{"id": <question id>}``
    """
    # zip() pairs items positionally and stops at the shorter of the two inputs.
    return [
        Document(page_content=preprocess_text(text), metadata={"id": doc_id})
        for text, doc_id in zip(questions, question_ids)
    ]

In [4]:
# Retrieval settings
SCORE_THRESHOLD = 0.5  # retrieved hits must score strictly above this to be kept
TOPK = 3               # number of neighbours requested per query

# Embedding model settings (consumed by init_embedder)
# NOTE(review): embeddings are requested un-normalized while the retriever filters on
# relevance scores; confirm the score range this combination produces.
EMBEDDER_CONFIG = {
    "backend": "HF",
    "params": {
        "model_name": "sentence-transformers/all-mpnet-base-v2",
        "encode_kwargs": {"normalize_embeddings": False},
    },
}

# Use the GPU when torch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Initialize the embedder, then an empty vector store sized for its embeddings
embedder = init_embedder(embedder_config=EMBEDDER_CONFIG, device=device)
vectorstore = init_vectorstore(embedder)

# Index every question; the DataFrame index doubles as the document id
question_ids = data.index.tolist()
documents = create_documents_from_questions(data.Question.tolist(), question_ids)
vectorstore.add_documents(documents=documents, ids=question_ids)
print(f"{len(documents)} Documents added to Vector Store")

cuda
Embedder Initialized with sentence-transformers/all-mpnet-base-v2
325 Documents added to Vector Store


In [31]:
# Aggregate similarity scores by their mean (other averaging schemes could be swapped in)
def calculate_mean_similarity(sims):
    """Return the arithmetic mean of ``sims``, or -1 when ``sims`` is empty."""
    if len(sims) > 0:
        return np.array(sims).mean()
    return -1

def calculate_max_similarity(sims):
    """Return the largest value in ``sims``, or -1 when ``sims`` is empty."""
    if len(sims) > 0:
        return np.array(sims).max()
    return -1

# Define Custom Functions as Runnables to use in retrieval chain
@chain
def retriever(query: dict) -> List[Document]:
    """Custom Retriever Logic to filter based on SIM THRESHOLD

    :param query: state dict; its 'query' entry (default '') is preprocessed and searched
    :return: documents whose relevance score is strictly greater than SCORE_THRESHOLD,
        each with the score stored under metadata["score"]; an empty list when the
        search yields nothing or nothing passes the threshold
    """
    query_text = preprocess_text(query.get('query', ''))
    # Iterate the (doc, score) pairs directly: unpacking `zip(*pairs)` into two names
    # raises ValueError when the search returns no pairs at all.
    scored_docs = vectorstore.similarity_search_with_relevance_scores(query_text, k=TOPK)
    result = []
    for doc, score in scored_docs:
        if score > SCORE_THRESHOLD:
            # NOTE(review): this writes onto the returned Document object; confirm whether
            # the vector store hands back its stored instances before relying on metadata staying clean.
            doc.metadata["score"] = score
            result.append(doc)
    return result

@chain
def format_retrieved_docs(documents: List[Document]) -> dict:
    """Context Formatter

    Pairs every retrieved question with its answer from ``data`` and renders the pairs
    as prompt context.
    :param documents: retrieved documents carrying 'id' and 'score' in their metadata
    :return: dict with 'context' (newline-joined "Q: ...\\nA: ..." entries, '' when there
        are no documents) and 'retrieval_score' (highest score, -1 when there are no documents)
    """
    qa_rows = []
    for doc in documents:
        score = doc.metadata.get('score')
        doc_index = doc.metadata.get('id')
        # Document ids were taken from data.index (labels), so look the row up by label
        # rather than by position.
        answer = data.loc[doc_index].Answer
        qa_rows.append((doc.page_content, answer, score))

    context_text = "\n".join([f"Q: {row[0]}\nA: {row[1]}" for row in qa_rows])
    # The best single match is used as the retrieval confidence
    # (calculate_mean_similarity is the available alternative).
    max_sim = calculate_max_similarity([row[2] for row in qa_rows])
    return {"context" : context_text, "retrieval_score" : max_sim}


def RExtract(pydantic_class, llm, prompt):
    '''
    Runnable Extraction module
    Builds a chain that fills ``pydantic_class`` by slot-filling extraction.

    :param pydantic_class: pydantic model class the LLM output is parsed into
    :param llm: runnable whose output is a plain string (``preparse`` uses string operations on it)
    :param prompt: prompt template that consumes a 'format_instructions' variable
    :return: runnable ``instruct_merge | prompt | llm | preparse | parser``
    '''
    parser = PydanticOutputParser(pydantic_object=pydantic_class)
    # Inject the parser's format instructions into the incoming state for the prompt.
    instruct_merge = RunnableAssign({'format_instructions' : lambda x: parser.get_format_instructions()})
    def preparse(string):
        """Patch up common LLM formatting slips before the string reaches the parser."""
        # Add a brace when the text contains none of that kind.
        if '{' not in string: string = '{' + string
        if '}' not in string: string = string + '}'
        # Un-escape markdown-style escapes and flatten newlines. The backslashes are written
        # as "\\" because "\]" and "\[" are invalid escape sequences in a normal string literal.
        string = (string
            .replace("\\_", "_")
            .replace("\n", " ")
            .replace("\\]", "]")
            .replace("\\[", "[")
        )
        # print(string)  ## Good for diagnostics
        return string
    return instruct_merge | prompt | llm | preparse | parser

In [42]:
class KnowledgeBase(BaseModel):
    # Slot-filling schema for what is known about the user and the conversation so far.
    # The field descriptions are passed to the LLM through the parser's format instructions,
    # so they are effectively prompt text.
    # (Documented with comments on purpose: a class docstring may be emitted into the
    # generated schema and change the prompt - TODO confirm.)
    # NOTE(review): age/gender/education_level default to concrete values ("25", "male",
    # "12th grade") rather than "unknown"; the conversation prompt personalizes on them, so
    # confirm these defaults are intended.
    name: str = Field(
        default="unknown", description="Name of the user, first name and last name"
    )
    age: str = Field(default="25", description="Age of the user")
    gender: str = Field(default="male", description="Gender of the user")
    education_level : str = Field(default='12th grade', description="Highest education level of the user.")
    disease_site : str = Field(default='unknown', description="Description of location of Cancer. If specified as general, then it is general cancer")
    query: str = Field(
        default="unknown",
        description="Frame user input into a question format using disease_site and current summary",
    )
    # `default=` is spelled out like every other field; "conversion" corrected to "conversation".
    summary: str = Field(
        default="unknown",
        description="Running summary of conversation under 500 words. Create a summary of the conversation using previous chatbot response and user inputs",
    )

# Prompt used to refresh the knowledge base after each user message.
# Inputs: format_instructions, know_base, response, input.
# Each instruction sentence starts with a space: adjacent string literals are concatenated
# verbatim, so without it the sentences run together ("...knowledge base.Record your...").
# NOTE(review): the text asks the model to fill a 'response' tag, but KnowledgeBase declares
# no such field; confirm whether that sentence is still wanted.
knowledge_prompt = ChatPromptTemplate.from_template(
    "You are chatting with a user. The user just responded ('input'). Please update the knowledge base."
    " Record your response in the 'response' tag to continue the conversation."
    " Do not hallucinate any details, and make sure the knowledge base is not redundant."
    " Update the entries frequently to adapt to the conversation flow."
    "\n{format_instructions}"
    "\n\nOLD KNOWLEDGE BASE: {know_base}"
    "\n\nOLD CHATBOT RESPONSE : {response}"
    "\n\nNEW MESSAGE: {input}"
    "\n\nNEW KNOWLEDGE BASE:"
)

# System instructions for the user-facing chatbot, assembled from ordered fragments.
# Template variables: know_base, summary, context, query, gender, age, education_level.
_conversation_system_text = "".join([
    "You are a chatbot for NU Medicine, and you are providing users cancer treatment assistance.",
    " Please chat with them! Stay clear and provide detailed answers to Questions as much as possible!",
    " Your running knowledge base is: {know_base}.",
    " Your running summary of the conversation is : {summary}. Use this to have a fluent conversation.\n",
    " This is for you only; Do not mention it!",
    " \nUsing that, we retrieved the following: {context}\n",
    "\nHere is the query that user want to get answered: {query}\n",
    "\nMake sure you provide elaborated and detailed answers using the retrieved context\n",
    " Do not ask them any other personal info.",
    "\nAdjust tone, complexity, and empathy in responses based on the user's inferred gender, age, and education level for personalized and context-sensitive communication.\n",
    "\ngender:{gender} age:{age} education_level:{education_level}\n",
])

# Message order: system instructions, the assistant turn ({output}), then the user turn ({input}).
conversation_prompt = ChatPromptTemplate.from_messages([
    ("system", _conversation_system_text),
    ("assistant", "{output}"),
    ("user", "{input}"),
])

In [33]:
# Ollama model settings, named so they can be tuned in one place
LLM_MODEL_NAME = 'llama3.1:70b'
LLM_TEMPERATURE = 0.4

llm = OllamaLLM(model=LLM_MODEL_NAME, temperature=LLM_TEMPERATURE)
# The same model instance serves as the instruction-following LLM for extraction
instruct_llm = llm

# Extraction chain that produces an updated KnowledgeBase, stored back under 'know_base'
extractor = RExtract(KnowledgeBase, instruct_llm, knowledge_prompt)
info_update = RunnableAssign({'know_base': extractor})

@chain
def extract_query(state):
    """Return the `query` attribute of the knowledge base stored under state['know_base']."""
    knowledge = state.get('know_base')
    return knowledge.query

@chain
def merge_outputs(ret_results):
    """Merge the first two mappings in ``ret_results`` into a new dict; the second wins on key clashes."""
    merged = {}
    merged.update(ret_results[0])
    merged.update(ret_results[1])
    return merged

# Retrieval sub-chain: fetch matching documents, then format them into prompt context
ret_chain = retriever | format_retrieved_docs


def _attach_retrieval(x):
    """Pair the incoming state with the retrieval result computed from it."""
    return (x, ret_chain.invoke(x))


# Internal pipeline: refresh the knowledge base, expose its query at the top level,
# run retrieval, then fold the retrieval output back into the state dict.
internal_chain = (
    info_update
    | RunnableAssign({"query": extract_query})
    | _attach_retrieval
    | merge_outputs
)

In [43]:
# First turn: start from an empty knowledge base and the greeting as the previous bot response
response_str = "Welcome to NU Medicine. We assure you are at the right place."
state = {
    'know_base': KnowledgeBase(),
    'response': response_str,
    'input': "What is the treatment?",
}
state = internal_chain.invoke(state)
state

{'know_base': KnowledgeBase(name='unknown', age='25', gender='male', education_level='12th grade', disease_site='unknown', query='What is the treatment for my condition?', summary='The user has reached out to NU Medicine seeking information on treatment options. They have not specified a particular disease or condition.'),
 'response': 'Welcome to NU Medicine. We assure you are at the right place.',
 'input': 'What is the treatment?',
 'query': 'What is the treatment for my condition?',
 'context': '',
 'retrieval_score': -1}

In [35]:
# Second turn: record the previous bot reply and the user's short answer, then re-run the chain
state.update(
    response="Radiation therapy is useful for treating lung cancer. Do you want to know more ?",
    input="Yes",
)
state = internal_chain.invoke(state)

In [36]:
# Inspect the knowledge base produced by the second turn.
state['know_base']

KnowledgeBase(name='unknown', age='25', gender='male', education_level='12th grade', disease_site='lung', query='What are the details of radiation therapy for lung cancer?', summary='The user has come to NU Medicine seeking information about their condition. They have asked about the treatment options available and specifically wanted to know more about radiation therapy.')

In [37]:
# List the keys currently held in the state, for comparison with the conversation prompt's inputs.
state.keys()

dict_keys(['know_base', 'response', 'input', 'query', 'context', 'retrieval_score'])

In [38]:
# Surface the knowledge-base fields that the conversation prompt reads as top-level state keys
profile = state['know_base']
state.update({
    'age': profile.age,
    'gender': profile.gender,
    'education_level': profile.education_level,
    'summary': profile.summary,
})

In [25]:
# Show which variables the conversation prompt expects to find in the state.
conversation_prompt.input_variables

['age',
 'context',
 'education_level',
 'gender',
 'input',
 'know_base',
 'output',
 'query',
 'summary']

In [39]:
# User-facing chain: render the conversation prompt and send it to the LLM.
# Reuses the `llm` instance created earlier (same model and temperature) instead of building
# a second client with duplicated settings, so the model configuration lives in one place.
external_chain = (
                conversation_prompt 
                | llm
                )

In [41]:
# The conversation prompt declares an 'output' variable (its assistant message);
# it is filled with an empty string here.
state['output'] = ""
# Generate the user-facing answer from the populated state.
external_chain.invoke(state)

"Hello! I'm here to help you understand more about radiation therapy for lung cancer. \n\nRadiation therapy is a treatment option that uses high-energy rays or particles to kill cancer cells. In the case of lung cancer, it can be used as a primary treatment, or in combination with other treatments like chemotherapy or surgery.\n\nThere are two main types of radiation therapy: external beam therapy (EBT) and brachytherapy. EBT is the most common type and uses a machine outside the body to direct x-rays or gamma rays at the tumor. Proton therapy is a form of EBT that uses charged atoms instead of x-rays.\n\nOn the other hand, brachytherapy involves placing small radioactive sources directly inside or near the tumor. This allows for high doses of radiation to be delivered directly to the cancer cells while minimizing damage to surrounding healthy tissue.\n\nWhen it comes to lung cancer specifically, radiation therapy can be used in different ways depending on the stage and location of the