## Setting up

In [None]:
# Run this cell once to install the required libraries; uncomment the extra lines below only if those packages or system tools are missing
%pip install -r requirements.txt
# %pip install --upgrade openai langchain
# !apt-get install -y poppler-utils
# !apt-get install -y tesseract-ocr
# !apt-get install -y libmagic1

In [None]:
# Latest update to the langchain package causing issue here
!pip uninstall httpx -y

In [None]:
# Reinstall to this version to enable langchain to work properly
!pip install httpx==0.27.2

In [None]:
!pip show httpx

In [None]:
import os
import bs4
import ast
import fitz
import uuid
import json
import base64
import markdown
import pickle
from PIL import Image
from pprint import pprint
from PyPDF2 import PdfMerger
from base64 import b64decode
from datetime import datetime
from dotenv import load_dotenv
import matplotlib.pyplot as plt
from IPython.display import Image, display
from unstructured.partition.pdf import partition_pdf
from unstructured.documents.elements import Text, Image
from typing_extensions import Annotated, TypedDict, Sequence, List

from langchain import hub
from langchain_core.prompts import MessagesPlaceholder
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import SystemMessage, HumanMessage
from langchain.text_splitter  import RecursiveCharacterTextSplitter
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate

from langchain_community.document_loaders import(
  PyPDFLoader,
  Docx2txtLoader,
  UnstructuredPDFLoader,
  WebBaseLoader,
  UnstructuredMarkdownLoader,
  UnstructuredWordDocumentLoader,
  TextLoader,
  UnstructuredPDFLoader
)

from chromadb.config import Settings
from chromadb.api.types import Embedding
from langchain.schema import BaseMessage
from langchain.storage import InMemoryStore
from langchain.schema.document import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.chains import create_retrieval_chain
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_history_aware_retriever
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_community.vectorstores import Chroma, InMemoryVectorStore
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

from IPython.display import Image, display
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import MemorySaver

import gradio as gr

load_dotenv("template.env")

# SQL Graph

In [None]:
# Open the local SQLite database file brain_tumor_mri.db.
db = SQLDatabase.from_uri("sqlite:///brain_tumor_mri.db")

# Sanity check: SQL dialect in use, then the table names the wrapper exposes.
print(db.dialect, db.get_usable_table_names(), sep="\n")

# Last expression of the cell, so the notebook shows the distinct labels in mri_data.
db.run("SELECT DISTINCT label FROM mri_data")

In [None]:
# Tracing via Langsmith: these values come from the environment
# (load_dotenv("template.env") is called in the import cell).
trace = os.getenv("LANGCHAIN_TRACING_V2")
langsmith = os.getenv("LANGCHAIN_API_KEY")
openai_api_key = os.getenv("OPENAI_API_KEY")

# Chat model shared by the rest of the notebook (temperature=0).
gpt = ChatOpenAI(model="gpt-4-turbo", temperature=0, openai_api_key=openai_api_key)

# Embedding model, in case we do any embeddings.
embeddings = OpenAIEmbeddings(model="text-embedding-3-large", openai_api_key=openai_api_key)

# Small smoke test: one real request to check the key and model work.
response = gpt.invoke("Why Abhi bang the table?")
print(response.content)

In [None]:
 # LangGraph create our workflow!
class SqlState(TypedDict):
    """State passed between the nodes of the SQL question-answering graph."""
    question: str  # user question, read by write_query
    query: str     # SQL text produced by write_query
    result: str    # output of execute_query
    answer: str    # final reply produced by generate_answer

# Fetch the shared SQL-generation prompt from the LangChain hub.
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

# Make sure we only have one message; the assertion message makes a changed
# upstream prompt easy to diagnose.
assert len(query_prompt_template.messages) == 1, (
    "expected the pulled SQL prompt to contain exactly one message, "
    f"got {len(query_prompt_template.messages)}"
)

# query_prompt_template.messages[0].pretty_print()

# Structured-output schema for the generated SQL (a TypedDict, not a pydantic model).
# NOTE(review): the "Final Graph" section defines another class with this same name.
class QueryOutput(TypedDict):
    """Generated SQL query."""
    query: Annotated[str, ..., "Syntactically valid SQL query."]  # the trailing string is a hint describing what kind of query is acceptable

# Bind the structured-output model once, while QueryOutput is still the SQL schema.
# The "Final Graph" section later rebinds the name QueryOutput to a question-type
# schema; resolving the name inside write_query at call time would then yield a
# dict without a "query" key.
_sql_query_llm = gpt.with_structured_output(QueryOutput)

def write_query(state: SqlState):
    """Generate SQL query to fetch information.

    Fills the pulled SQL prompt with the database dialect, the table info and
    the user's question, then asks the model for a structured result.

    Args:
        state: graph state; only state["question"] is read.

    Returns:
        dict with a single "query" key holding the generated SQL string.
    """
    prompt = query_prompt_template.invoke(
        {
            "dialect": db.dialect,
            "top_k": 5,  # value supplied for the prompt's top_k variable
            "table_info": db.get_table_info(),
            "input": state["question"],
        }
    )
    result = _sql_query_llm.invoke(prompt)
    return {"query": result["query"]}

# print(write_query({"question": "How many rows are there?"}))

def execute_query(state: SqlState):
    """Run the SQL held in state["query"] against `db` and return it under "result".

    NOTE(review): the query text comes from the model and is executed as-is.
    """
    sql_tool = QuerySQLDataBaseTool(db=db)
    query_output = sql_tool.invoke(state["query"])
    return {"result": query_output}

# print(execute_query(write_query({"question": "How many Employees are there?"})))

def generate_answer(state: SqlState):
    """Turn the question, the SQL query and its result into a natural-language answer.

    Reads "question", "query" and "result" from the state, sends them to `gpt`
    in a single prompt and returns the reply text under "answer".
    """
    template = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
        "Question: {question}\n"
        "SQL Query: {query}\n"
        "SQL Result: {result}"
    )
    prompt = template.format(
        question=state["question"],
        query=state["query"],
        result=state["result"],
    )
    reply = gpt.invoke(prompt)
    return {"answer": reply.content}

# Wire the three nodes in order: write_query -> execute_query -> generate_answer.
graph_builder = StateGraph(SqlState)
graph_builder.add_sequence([write_query, execute_query, generate_answer])

# Entry point of the graph.
graph_builder.add_edge(START, "write_query")

graph = graph_builder.compile()

In [None]:
# Render the compiled SQL graph as a Mermaid PNG inside the notebook
display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
# Save your money don't need to run this
# graph.invoke({"question":"How many no tumor patients are there?"})['answer']

In [None]:
# Function to process the question and get the answer
# Recompile the same builder, this time with an in-memory checkpointer and a
# pause before the "execute_query" node so a human can approve the query.
memory = MemorySaver()
graph = graph_builder.compile(
    interrupt_before=["execute_query"],
    checkpointer=memory,
)

def answer_question():
  """Ask for a SQL question, show the generated query and run it only on approval.

  The graph is compiled with an interrupt before "execute_query", so the first
  stream stops once the query has been written. The paused state is then read
  with get_state(); the graph is resumed only when the user answers "yes".

  Returns:
    None. The answer, or a cancellation notice containing the query, is printed.
  """
  question = {'question': input("What is your SQL related question? ")}
  # A fresh thread per call keeps an earlier, un-approved run from being resumed here.
  config = {"configurable": {"thread_id": str(uuid.uuid4())}}

  # Run the graph up to the interrupt; the intermediate updates are not needed.
  for _ in graph.stream(question, config, stream_mode="updates"):
      pass

  # Read the checkpointed state rather than streaming again: streaming with None
  # resumes the graph, which would execute the query before the user has agreed.
  generated_query = graph.get_state(config).values.get("query")
  print(f"Generated query: {generated_query}")

  try:
    user_approval = input("Do you want to execute query? (yes/no): ")
  except Exception:
    # Deliberate best effort: if no answer can be read, treat it as a refusal.
    user_approval = "no"

  if user_approval.strip().lower() != "yes":
    print(f"Operation cancelled by user. Here is the query: {generated_query}")
    return None

  # Approved: resume from the checkpoint and keep the last emitted state.
  final_state = {}
  for final_state in graph.stream(None, config, stream_mode="values"):
      pass

  print(final_state.get("answer"))
  return None

In [None]:
# Save your money don't need to run this
# answer_question()

In [None]:
# Render the graph again; `graph` is now the version recompiled with the checkpointer and the interrupt
display(Image(graph.get_graph().draw_mermaid_png()))

# RAG system
## So far we have only handled PDFs that contain plain text. What if we want to deal with PDF files that hold multimedia content such as tables and images?

## Preprocessing

In [None]:
# Merge the PDF together
# Source PDFs, merged in this order into a single file.
pdf_files = ["cancer1.pdf", "cancer2.pdf"]
output_path = "merged.pdf"

merger = PdfMerger()
try:
    for pdf in pdf_files:
        merger.append(pdf)
    merger.write(output_path)
finally:
    # Close the merger even when appending or writing fails.
    merger.close()

In [None]:
# Run this to generate the chunks from the merged PDF
# We skip this step here because partitioning takes a long time; the next cell loads previously saved chunks from a pickle instead
'''
chunks = partition_pdf(
    filename='merged.pdf',
    infer_table_structure=True,            # extract tables
    strategy="hi_res",                     # mandatory to infer tables

    extract_image_block_types=["Image"],   # Add 'Table' to list to extract image of tables
    image_output_dir_path=output_path,   # if None, images and tables will saved in base64

    extract_image_block_to_payload=True,   # if true, will extract base64 for API usage

    chunking_strategy="by_title",          # or 'basic'
    max_characters=10000,                  # defaults to 500
    combine_text_under_n_chars=2000,       # defaults to 0
    new_after_n_chars=6000,
)
'''

In [None]:
output_file = "chunks.pkl"

'''
# For saving the chunks to a file
output_file = "chunks.pkl"
with open(output_file, "wb") as file:  # "wb" for write binary
    pickle.dump(chunks, file)

print(f"Chunks saved to {output_file}")
'''

# Load the chunks back.
# NOTE(review): pickle.load can execute arbitrary code stored in the file. Only
# load a chunks.pkl that you created yourself or that comes from a trusted source.
with open(output_file, "rb") as file:  # "rb" for read binary
    chunks = pickle.load(file)

# Should see only composite elements. The guard keeps an empty pickle from
# raising IndexError on chunks[0].
print(f"Chunks loaded successfully. Type of first element: {type(chunks[0]) if chunks else None}")

## Exploration

In [None]:
# Overview
# pprint(chunks)

# Chunk contents
# pprint(chunks[7].to_dict())

In [None]:
# Number of chunks loaded from the pickle
print(len(chunks))

In [None]:
# Distinct element classes found among the chunks returned by partition_pdf
print({str(type(chunk)) for chunk in chunks})

In [None]:
# Full dictionary form of the first chunk
pprint(chunks[0].to_dict())

In [None]:
# Dictionary form of one more chunk; index 7 is hard-coded (assumes at least 8 chunks — TODO confirm)
chunks[7].to_dict()

In [None]:
# pprint(chunks[3])
# Original elements stored in the metadata of chunk 3
pprint(chunks[3].metadata.orig_elements)

print("-------------------------")

# pprint(chunks[3].to_dict())

In [None]:
# Take a sneak peek at the text data inside chunk 3
elements = chunks[3].metadata.orig_elements

# Keep only the elements whose class name contains "NarrativeText"
chunk_texts = list(filter(lambda element: 'NarrativeText' in str(type(element)), elements))

# Show just the first narrative element, if there is one
for text in chunk_texts[:1]:
  pprint(str(text))

In [None]:
# Split the chunks by class name: table elements vs. composite text chunks
tables = [chunk for chunk in chunks if "Table" in str(type(chunk))]
texts = [chunk for chunk in chunks if "CompositeElement" in str(type(chunk))]

# Quick look at what ended up in each bucket
print(len(tables))
print(tables[0:1])
print(tables[0])
print("--------------")
print(len(texts))
print(texts[0:1])

In [None]:
# Get the images from the CompositeElement objects
def get_images_base64(chunks):
  """Collect the base64 payload of every image nested in CompositeElement chunks.

  Element kinds are recognised by class name (substring match on
  str(type(...))). Chunks that are not CompositeElements are skipped.

  Returns:
    A flat list of `metadata.image_base64` values, in the order encountered.
  """
  return [
      element.metadata.image_base64
      for chunk in chunks
      if "CompositeElement" in str(type(chunk))
      for element in chunk.metadata.orig_elements
      if "Image" in str(type(element))
  ]

images = get_images_base64(chunks)
# Base64 payloads are very long: show the count and a short prefix of each
# instead of dumping every full string into the cell output.
print(f"{len(images)} image(s) extracted")
pprint([image_b64[:40] + "..." for image_b64 in images])

In [None]:
# Decode the base64 string to binary
# Display the image

def display_base64_image(base64_code):
    """Decode a base64-encoded image string and render it inline in the notebook."""
    display(Image(data=base64.b64decode(base64_code)))

# Show the first extracted image (assumes at least one image was found — TODO confirm)
display_base64_image(images[0])

## Summarisation

In [None]:
# Inspect the raw dictionary form of the first table and the first text chunk before summarising them
pprint(tables[0].to_dict())
print("--------------")
pprint(texts[0].to_dict())

In [None]:
# Text summariser: prompt template shared by text and table summaries.
# {element} is the placeholder for the table or text chunk to summarise.
prompt_text = """
You are an assistant tasked with summarizing tables and text.
Give a concise summary of the table or text.

Respond only with the summary, no additionnal comment.
Do not start your message by saying "Here is a summary" or anything like that.
Just give the summary as it is.

Table or text chunk: {element}
"""


# Summary chain


# Summarize text, concurrency limits the number of task running at the same time


In [None]:
# Check text


In [None]:
# Check first index


In [None]:
# Summarize tables from their HTML rendering, with at most 3 requests in flight at a time
# NOTE(review): summarize_chain is not defined in the visible cells — presumably it is meant to be
# built in the "Summary chain" placeholder above; confirm before running this cell.
tables_html = [table.metadata.text_as_html for table in tables]
table_summaries = summarize_chain.batch(tables_html, {"max_concurrency": 3})

In [None]:
# Check table 


In [None]:
# Image summariser
prompt_template = """
  Describe the image in detail. For context,
  the image is part of a research paper explaining the meidcal treatment
  of brain cancer. Be specific about graphs such as bar plots if any
  or diagrams showing the human anatomy.
"""

# One user message carrying the instruction text plus the image as a base64
# data URL; {image} is the template variable filled in for each batch item.
messages = [
    (
        "user",
        [
            {"type": "text", "text": prompt_template},
            {
                "type": "image_url",
                "image_url": {"url": "data:image/jpeg;base64,{image}"},
            },
        ],
    )
]

prompt = ChatPromptTemplate.from_messages(messages)

# NOTE(review): `prompt` and `chain` are generic names that later cells reuse.
chain = prompt | gpt | StrOutputParser()

# Cap concurrency the same way the table summaries do, so a PDF with many
# images does not send every request at once.
image_summaries = chain.batch(images, {"max_concurrency": 3})

# Preview the first summary only when there is one (no IndexError without images).
if image_summaries:
    pprint(image_summaries[0])

In [None]:
# Directory handed to Chroma for persistence.
# NOTE(review): absolute Colab-style path, and only persist_directory is set on
# Settings — confirm whether this chromadb version actually persists to disk.
persist_directory = "/content/db"

# Metadata key that links each summary vector to its parent document.
id_key = "doc_id"

# The storage layer for the parent documents
store = InMemoryStore()

# The vectorstore to use to index the chunks
vectorstore = Chroma(
    client_settings=Settings(persist_directory=persist_directory),
    embedding_function=embeddings,
    collection_name="multi_modal_rag",
)

# The retriever (empty to start)
retriever = MultiVectorRetriever(id_key=id_key, docstore=store, vectorstore=vectorstore)

In [None]:
# Add texts: one fresh id per text chunk, shared by its summary (vector store)
# and the original chunk (docstore).
# NOTE(review): text_summaries is not defined in the visible cells — presumably
# produced in the "Summarize text" placeholder; confirm.
doc_ids = [str(uuid.uuid4()) for _ in texts]

summary_texts = [
    Document(page_content=text_summaries[i], metadata={id_key: doc_ids[i]})
    for i in range(len(text_summaries))
]

retriever.vectorstore.add_documents(summary_texts)
retriever.docstore.mset([(doc_id, text_chunk) for doc_id, text_chunk in zip(doc_ids, texts)])

In [None]:
# Smoke-test the retriever with a sample question and inspect the first hit
# (assumes at least one document is returned — TODO confirm)
results = retriever.invoke("How to treat meningioma?")
results[0].to_dict()

In [None]:
def _index_summaries(summaries, originals):
    """Index summaries in the vector store and their originals in the docstore.

    A fresh uuid is generated per original; the summary at the same position is
    stored as a Document carrying that id under `id_key`, and the original is
    stored in the docstore under the same id.

    Args:
        summaries: summary strings, positionally matched with `originals`.
        originals: the parent items (table elements, base64 images, ...).

    Returns:
        (ids, summary_docs): the generated ids and the Documents that were built.
    """
    ids = [str(uuid.uuid4()) for _ in originals]
    summary_docs = [
        Document(page_content=summary, metadata={id_key: doc_id})
        for doc_id, summary in zip(ids, summaries)
    ]
    # Skip the writes when there is nothing to add: a PDF without tables or
    # images should not hand an empty batch to the vector store.
    if summary_docs:
        retriever.vectorstore.add_documents(summary_docs)
    if ids:
        retriever.docstore.mset(list(zip(ids, originals)))
    return ids, summary_docs

# Add tables
table_ids, summary_tables = _index_summaries(table_summaries, tables)

# Add image summaries
img_ids, summary_img = _index_summaries(image_summaries, images)

In [None]:
# Retrieve


In [None]:
# Display the images!

def display_base64_image(base64_code):
  """Decode a base64 image string and show it inline.

  Same behaviour as the earlier definition of this name in the notebook.
  """
  raw_bytes = base64.b64decode(base64_code)
  rendered = Image(data=raw_bytes)
  display(rendered)

# Show up to the first four retrieved items; slicing avoids an IndexError when
# fewer than four came back.
# NOTE(review): `docs` is not defined in the visible cells (the "Retrieve" cell
# is a placeholder) and every item is assumed to be a base64 image string — confirm.
for doc in docs[:4]:
  display_base64_image(doc)

## Chaining

In [None]:
# Creating the image and text answer


# Testing


# Testing


In [None]:
def build_prompt(kwargs):
  """Assemble the multimodal prompt for the answering model.

  Args:
    kwargs: mapping with "question" (the user question) and "context", itself a
      mapping with "texts" (elements exposing `.text`) and "images" (base64
      strings).

  Returns:
    A ChatPromptTemplate built from one HumanMessage whose content is the text
    prompt followed by one image_url part per image.
  """
  retrieved = kwargs["context"]
  user_question = kwargs["question"]

  # Texts are concatenated back to back, with no separator between elements.
  context_text = "".join(element.text for element in retrieved["texts"])

  # construct prompt with context (including images)
  prompt_template = f"""
  Answer the question to the best of your abilities based only on the following
  context, which can include text, tables, and the image below.
  Context: {context_text}
  Question: {user_question}
  """

  # Each image travels as a base64 data URL in its own content part.
  image_parts = [
      {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image}"}}
      for image in retrieved["images"]
  ]
  prompt_content = [{"type": "text", "text": prompt_template}, *image_parts]

  return ChatPromptTemplate.from_messages([HumanMessage(content=prompt_content),])

In [None]:
# First chain 


In [None]:
# NOTE(review): in the visible cells `chain` was last assigned in the image-summariser cell; the
# "First chain" placeholder above presumably rebinds it to the RAG chain — confirm before running.
print(chain.invoke("How to treat meningioma?"))

In [None]:
# Chain with sources/ images and tables


In [None]:
# NOTE(review): chain_with_sources is not defined in the visible cells (the cell above is a placeholder).
# The cells below read response['response'] and response['context'], so the full dump may include long base64 image strings.
response = chain_with_sources.invoke("Can you show a picture of the brain")

pprint(response)

In [None]:
# response = chain_with_sources.invoke("Show me a picture of the brain")
# response = chain_with_sources.invoke("Show me a picture of craniotomy")
# response = chain_with_sources.invoke("How to treat meningioma")

# Print the answer unconditionally: an empty image list is still "not None",
# so gating the answer on that test would hide it every time.
print("Response:", response['response'])

print("\n\nContext:")
for text in response['context']['texts']:
    print(text.text)
    print("Page number: ", text.metadata.page_number)
    print("\n" + "-"*50 + "\n")

# Display each retrieved image exactly once (`or []` tolerates a None value).
for image in response['context']['images'] or []:
    display_base64_image(image)

In [None]:
# The answer is printed in every case.
print("Response:", response['response'])

# Show the first image only when at least one was retrieved: an empty list is
# not None, so test truthiness before indexing into it.
context_images = response['context']['images']
if context_images:
  display_base64_image(context_images[0])

In [None]:
source = chain_with_sources.invoke("Show me a picture of the brain")
# Print the answer of this call (`source`), not the `response` left over from an earlier cell.
print(source['response'])

# Only index into the image list when something was retrieved.
source_images = source['context']['images']
if source_images:
  display_base64_image(source_images[0])
else:
  print("No images were retrieved for this question.")

## RAG with LangGraph

In [None]:
from typing import TypedDict, List, Dict, Any
from langgraph.graph.state import StateGraph, START

class RagState(TypedDict):
    """State carried through the RAG graph."""
    question: str      # question put to the RAG chain
    response: str      # generated answer text
    images: List[str]  # images to show with the answer; showcase_answers treats them as base64 strings

# Generation node
    # """Generate a response using the chain_with_sources."""


def showcase_answers(state: RagState):
    """Showcase all answers, including the response and images.

    Displays every base64 image in state["images"], or prints a notice when
    there are none. The generation node in the "New RAG Graph" section stores
    the string "No diagrams" instead of a list when nothing was retrieved; a
    string value (or a missing key) is treated as "no images" rather than being
    iterated character by character.

    Returns:
        The state, unchanged.
    """
    # print("Generated Response:")
    # print(state["response"])

    images = state.get("images")
    if isinstance(images, str):
        images = []

    if images:
        for i, image in enumerate(images, start=1):
            print(f"Image {i}:")
            display_base64_image(image)
    else:
        print("No diagrams available.")

    return state

# Graph building


# NOTE(review): neither `rag` nor `config` is assigned anywhere above this line in the visible cells
# (the "Graph building" step is a placeholder; both are first assigned in later sections) — confirm the intended order.
result = rag.invoke({"question": "Can you show a picture of the brain?"}, config)

In [None]:
# Render the RAG graph as a Mermaid PNG (xray=True is passed to get_graph)
display(Image(rag.get_graph(xray=True).draw_mermaid_png()))

# Parent Graph


## New SQL graph

In [None]:
# LangGraph create our workflow!
class SqlState(TypedDict):
    """State shared by the three nodes of the SQL graph."""
    question: str  # what the user asked
    query: str     # SQL written by write_query
    result: str    # what execute_query got back
    answer: str    # reply written by generate_answer

# Fetch the shared SQL-generation prompt from the LangChain hub.
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

# Make sure we only have one message; the assertion message makes a changed
# upstream prompt easy to diagnose.
assert len(query_prompt_template.messages) == 1, (
    "expected the pulled SQL prompt to contain exactly one message, "
    f"got {len(query_prompt_template.messages)}"
)

# query_prompt_template.messages[0].pretty_print()

# Structured-output schema for the generated SQL (a TypedDict, not a pydantic model).
# NOTE(review): the "Final Graph" section defines another class with this same name.
class QueryOutput(TypedDict):
    """Generated SQL query."""
    query: Annotated[str, ..., "Syntactically valid SQL query."]

# Bind the structured-output model once, while QueryOutput is still the SQL schema.
# The "Final Graph" section later rebinds the name QueryOutput to a question-type
# schema; resolving the name inside write_query at call time would then yield a
# dict without a "query" key.
_sql_query_llm = gpt.with_structured_output(QueryOutput)

def write_query(state: SqlState):
    """Generate SQL query to fetch information.

    Fills the pulled SQL prompt with the database dialect, the table info and
    the user's question, then asks the model for a structured result.

    Args:
        state: graph state; only state["question"] is read.

    Returns:
        dict with a single "query" key holding the generated SQL string.
    """
    prompt = query_prompt_template.invoke(
        {
            "dialect": db.dialect,
            "top_k": 5,  # value supplied for the prompt's top_k variable
            "table_info": db.get_table_info(),
            "input": state["question"],
        }
    )
    result = _sql_query_llm.invoke(prompt)
    return {"query": result["query"]}

# print(write_query({"question": "How many rows are there?"}))

def execute_query(state: SqlState):
    """Run the SQL held in state["query"] against `db` and return it under "result".

    NOTE(review): the query text comes from the model and is executed as-is.
    """
    sql_tool = QuerySQLDataBaseTool(db=db)
    query_output = sql_tool.invoke(state["query"])
    return {"result": query_output}

# print(execute_query(write_query({"question": "How many Employees are there?"})))

def generate_answer(state: SqlState):
    """Turn the question, the SQL query and its result into a natural-language answer.

    Reads "question", "query" and "result" from the state, sends them to `gpt`
    in a single prompt and returns the reply text under "answer".
    """
    template = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
        "Question: {question}\n"
        "SQL Query: {query}\n"
        "SQL Result: {result}"
    )
    prompt = template.format(
        question=state["question"],
        query=state["query"],
        result=state["result"],
    )
    reply = gpt.invoke(prompt)
    return {"answer": reply.content}

# Build the SQL graph: write_query -> execute_query -> generate_answer
sql_builder = StateGraph(SqlState)
sql_builder.add_sequence([write_query, execute_query, generate_answer])
sql_builder.add_edge(START, "write_query")

# Compile with an in-memory checkpointer; runs below use thread id "1"
memory = MemorySaver()
sql = sql_builder.compile(checkpointer=memory)
config = {"configurable": {"thread_id": "1"}}

# One end-to-end run as a check (this performs real model and database calls)
print(sql.invoke({"question": "How many patients are there?"}, config))

In [None]:
# Render the new SQL graph as a Mermaid PNG (xray=True is passed to get_graph)
display(Image(sql.get_graph(xray=True).draw_mermaid_png()))

## New RAG Graph

In [None]:
class RagState(TypedDict):
    """State of the RAG graph; uses "answer" like the SQL graph's state does."""
    question: str      # question forwarded to chain_with_sources
    answer: str        # response text returned by the generation node
    images: List[str]  # images returned alongside the answer

def generation(state: RagState):
    """Generate a response using the chain_with_sources.

    Args:
        state: graph state; only state["question"] is read.

    Returns:
        dict with "answer" (the response text, or "No response generated" when
        the chain returned none) and "images" (a list of the images found in the
        response context; empty when nothing was retrieved).
    """
    response = chain_with_sources.invoke(state["question"])

    answer_text = response.get('response', "No response generated")

    # Keep "images" a list, as RagState declares. A string sentinel here would
    # be iterated character by character by any consumer that loops over it.
    context = response.get('context') or {}
    images = list(context.get('images') or [])

    return {
        "answer": answer_text,
        "images": images,
    }

# Single-node graph: START -> generation
rag_graph = StateGraph(RagState).add_sequence([generation])
rag_graph.add_edge(START, "generation")

# Compile with an in-memory checkpointer; the run below uses thread id "1"
memory = MemorySaver()
config = {"configurable": {"thread_id": "1"}}
rag = rag_graph.compile(checkpointer=memory)

result = rag.invoke({"question": "Can you show a picture of the brain?"}, config)
# Last expression of the cell: show what the node stored under "images"
result['images']

## Final Graph

In [None]:
class ParentGraph(TypedDict):
  """State of the parent graph built in the next cell."""
  question: str       # question passed to parent.invoke
  question_type: str  # presumably set by the classify_question node — TODO confirm
  answer: str         # final answer text
  images: List[str]   # images accompanying the answer, if any

# NOTE(review): this rebinds the module-level name QueryOutput, which the SQL cells above also
# define (with a "query" field). Code that looks the name up at call time sees this version
# once the cell has run — consider giving this schema a distinct name.
class QueryOutput(TypedDict):
    """Generated question type."""
    question_type: Annotated[str, ..., "Syntactically valid question type."]


In [None]:
# Define the parent graph
# NOTE(review): classify_question, sql_subgraph, rag_subgraph and route_based_on_question_type
# are not defined in the visible cells — presumably they belong in the empty cell above; confirm.
parent_graph = StateGraph(ParentGraph)

# Three nodes: the classifier plus one node per sub-graph
parent_graph.add_node("classify_question", classify_question)
parent_graph.add_node("sql_subgraph", sql_subgraph)
parent_graph.add_node("rag_subgraph", rag_subgraph)

# Add conditional edges based on the routing function
parent_graph.add_conditional_edges("classify_question", route_based_on_question_type)

# Every run starts at the classifier
parent_graph.add_edge(START, "classify_question")

# memory = MemorySaver()
# parent = parent_graph.compile(checkpointer=memory)
# config = {"configurable": {"thread_id": "1"}}

# Compiled without a checkpointer (the checkpointed variant is commented out above)
parent = parent_graph.compile()

result = parent.invoke({"question": "How many people are there?"})
# result = parent.invoke({"question": "Show me a picture of the brain"})

# print("Final Answer:", result["answer"])
# Prints the whole result state, not just the answer text
print("Final Answer:", result)

In [None]:
# Render the parent graph as a Mermaid PNG, displayed 1000 px wide
display(Image(parent.get_graph(xray=True).draw_mermaid_png(), width=1000))

## Gradio UI

In [None]:
import tempfile
import base64
import os

def process_question(user_input):
  """Run a question through the parent graph and prepare the outputs for Gradio.

  Args:
    user_input: the question typed by the user.

  Returns:
    (answer_text, image_paths): the graph's answer and a list of temporary
    .jpg file paths, one per image entry that decoded successfully (may be empty).
  """
  result = parent.invoke({"question": user_input})
  answer_text = result["answer"]
  images = result.get("images")

  # "No images" may arrive as None, an empty list or a plain string such as
  # 'No diagrams'; only a real list/tuple of payloads is worth decoding.
  if not isinstance(images, (list, tuple)):
    images = []

  temp_image_paths = []
  for img_data in images:
    if not isinstance(img_data, str):
      continue

    # Accept both raw base64 and "data:...;base64,<payload>" URLs.
    payload = img_data.split(',', 1)[1] if ',' in img_data else img_data
    try:
      img_bytes = base64.b64decode("".join(payload.split()), validate=True)
    except ValueError:  # binascii.Error is a ValueError subclass
      # Not base64: skip this entry instead of failing the whole request.
      continue
    if not img_bytes:
      continue

    # delete=False keeps the file on disk after it is closed, so the path
    # returned to the caller stays valid.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
      temp_file.write(img_bytes)
    temp_image_paths.append(temp_file.name)

  return answer_text, temp_image_paths

# process_question("Show me a picture of the brain")
# process_question("How many patients are there")
# process_question("Can you give me a breakdown of the number of people with the different type of cancers?")
# process_question("How many patients are there with no tumors detected?")

In [None]:
# Create a Gradio interface with a custom layout

# Remote banner image shown at the top of the page
banner_url = "https://i.imgur.com/jn2wz20.png"

# Build the UI; components appear on the page in the order they are created here
with gr.Blocks() as iface:
    # Add the banner image
    gr.Markdown(f"""
    <div style="text-align: center; margin-bottom: 20px;">
        <img src="{banner_url}" alt="Doctor Banner" style="max-width: 100%; height: auto;">
    </div>
    """)

    # Add the title and description
    gr.Markdown("""
    ## Doctor's Question & Answering System
    Enter a question, and the system will classify and process it accordingly. If applicable, images will also be displayed.
    """)

    # Add the input and outputs
    with gr.Row():
        question_input = gr.Textbox(
            label="Ask Your Question", lines=2, placeholder="Enter your question here..."
        )
    with gr.Row():
        answer_output = gr.Textbox(label="Answer", lines=8)
        images_output = gr.Gallery(label="Images", show_label=True)

    # Add the submit button
    submit_btn = gr.Button("Submit")
    # process_question returns (answer_text, image_paths), matching the two outputs below
    submit_btn.click(
        fn=process_question,
        inputs=question_input,
        outputs=[answer_output, images_output]
    )

# Launch the Gradio app
# NOTE(review): presumably this guard is always true when the cell is run in a notebook — confirm
if __name__ == "__main__":
    iface.launch(debug=True)