In [2]:
# ==========================
# Cell 1 - Install packages
# ==========================
!pip install -q langgraph langchain langchain-community chromadb sentence-transformers google-generativeai pypdf gTTS tiktoken

# (Colab users may need to restart runtime after large installs; usually not necessary.)

In [15]:
# ==========================
# Cell 2 - Imports & Config
# ==========================
import os
import re
import json
import time
import logging
from typing import List, Dict, Any
from IPython.display import Audio, display
from google.colab import files

# LangGraph & LangChain
from langgraph.graph import StateGraph, START
from langgraph.graph import MessagesState
from langchain_core.messages import HumanMessage, AIMessage

# LangChain doc loaders / embeddings / vectorstores
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

# Generative model & TTS
import google.generativeai as genai
from gtts import gTTS

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("drug-chatbot")

# === Configuration ===
# Replace with your Gemini API key OR set env var GEMINI_API_KEY
GENIE_API_KEY = os.getenv("GEMINI_API_KEY", "AIzaSyCUKNZaIJ4xTwHm-hq66qm7jYg3xfC2Y18")
genai.configure(api_key=GENIE_API_KEY)

# Vector DB persistence directory (optional). If you want ephemeral (Colab), set to None
VECTORDB_PERSIST_DIR = "chroma_persist"   # choose None for in-memory
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
CHUNK_SIZE = 800
CHUNK_OVERLAP = 120
RETRIEVER_K = 4

# Create directories if needed
if VECTORDB_PERSIST_DIR:
    os.makedirs(VECTORDB_PERSIST_DIR, exist_ok=True)

logger.info("Config set. Ready.")

In [16]:
# =================================
# Cell 3 - Helpers: PDF upload/load
# =================================
def upload_pdfs() -> List[str]:
    """Open Colab filepicker and return list of uploaded filenames."""
    uploaded = files.upload()
    filenames = list(uploaded.keys())
    logger.info("Uploaded files: %s", filenames)
    return filenames

def load_pdfs_as_docs(pdf_paths: List[str]):
    """Load PDF pages with PyPDFLoader, attach metadata (source filename & page)."""
    docs = []
    for path in pdf_paths:
        loader = PyPDFLoader(path)
        loaded = loader.load()  # returns one Document per page (with .page_content and metadata['page'])
        for d in loaded:
            if d.metadata is None:
                d.metadata = {}
            d.metadata["source_file"] = os.path.basename(path)
            docs.append(d)
    logger.info("Loaded %d pages from %d PDFs.", len(docs), len(pdf_paths))
    return docs

In [17]:
# ===========================================
# Cell 4 - Build / update vector index (Chroma)
# ===========================================
def build_vector_index(documents, persist_dir: str = VECTORDB_PERSIST_DIR):
    """Chunk docs, embed, and save to Chroma vectorstore."""
    logger.info("Splitting into chunks (chunk_size=%d overlap=%d)...", CHUNK_SIZE, CHUNK_OVERLAP)
    splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    chunks = splitter.split_documents(documents)
    logger.info("Total chunks created: %d", len(chunks))

    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

    logger.info("Creating/Loading Chroma vectorstore (persist_dir=%s)...", persist_dir)
    if persist_dir:
        vectordb = Chroma.from_documents(chunks, embedding=embeddings, persist_directory=persist_dir)
        vectordb.persist()
    else:
        vectordb = Chroma.from_documents(chunks, embedding=embeddings)
    retriever = vectordb.as_retriever(search_kwargs={"k": RETRIEVER_K})
    logger.info("Vectorstore ready.")
    return vectordb, retriever

# If you already have PDFs loaded and want to build index:
# pdf_files = upload_pdfs(); docs = load_pdfs_as_docs(pdf_files); vectordb, retriever = build_vector_index(docs)

In [18]:
# ===========================================
# Cell 5 - Section extraction helper (best-effort)
# ===========================================
SECTION_HEADERS = [
    "INDICATIONS", "INDICATION", "DOSAGE", "DOSAGE AND ADMINISTRATION", "ADMINISTRATION",
    "CONTRAINDICATIONS", "WARNINGS", "WARNINGS AND PRECAUTIONS", "PRECAUTIONS",
    "ADVERSE REACTIONS", "ADVERSE EVENTS", "TOXICOLOGY",
    "DRUG INTERACTIONS", "INTERACTIONS", "USE IN SPECIFIC POPULATIONS", "PREGNANCY", "PEDIATRIC", "GERIATRIC"
]

header_re = re.compile(r"^\s*([A-Z][A-Z0-9\s&/-]{3,})\s*$", re.M)

def extract_sections(text: str) -> Dict[str, str]:
    """
    Best-effort split of a chunk into likely PI headings.
    Returns dict mapping header->text block.
    """
    sections = {}
    # naive approach: find heading lines and split
    matches = list(header_re.finditer(text))
    if not matches:
        # fallback: try to find known header words anywhere
        for h in SECTION_HEADERS:
            pat = re.compile(rf"({h}[:\s\n])", re.I)
            m = pat.search(text)
            if m:
                # take a window after header
                start = m.start()
                end = start + 800  # slice short window
                sections[h] = text[start:end].strip().replace("\n", " ")
        return sections

    # build sections from headings
    spans = []
    for i, m in enumerate(matches):
        start = m.start()
        end = matches[i+1].start() if i+1 < len(matches) else len(text)
        heading = m.group(1).strip()
        block = text[start:end].strip().replace("\n", " ")
        sections[heading] = block
    return sections

In [21]:
# ===========================================
# Cell 6 - LangGraph state & nodes
# ===========================================
from typing import TypedDict

# Define state type (inherits messages state for history convenience)
class BotState(MessagesState):
    query: str
    retrieved: List[Dict[str, Any]]
    model_response_raw: str
    answer: str
    citations: List[str]
    tts_file: str

# Node: accept query & add HumanMessage to history
def node_accept_query(state: BotState):
    q = state.get("query", "").strip()
    if not q:
        raise ValueError("No query provided (state['query'] is empty).")
    # append human message to messages history
    human = HumanMessage(content=q)
    return {"messages": [human]}

# Node: retrieval
def node_retrieve(state: BotState):
    q = state["query"]
    hits = retriever.get_relevant_documents(q)
    retrieved = []
    for i, d in enumerate(hits):
        source = d.metadata.get("source_file", "unknown.pdf")
        page = d.metadata.get("page", "N/A")
        snippet = d.page_content.strip().replace("\n", " ")
        label = f"[Source {i+1}: {source} | Page {page}]"
        # attempt to extract sections from the snippet for improved citation later
        sections = extract_sections(d.page_content)
        retrieved.append({
            "label": label,
            "text": snippet,
            "source": source,
            "page": page,
            "sections": sections
        })
    return {"retrieved": retrieved}

# Node: generate (call Gemini with JSON-enforcement)
def node_generate(state: BotState):
    q = state["query"]
    retrieved = state.get("retrieved", [])
    # Build context string (label + snippet). Keep snippets reasonably sized.
    context_pieces = []
    for r in retrieved:
        # cap snippet length
        text_snip = (r["text"][:1200] + "...") if len(r["text"]) > 1200 else r["text"]
        context_pieces.append(f"{r['label']}\n{text_snip}")
    context = "\n\n".join(context_pieces)

    system_instructions = (
        "You are a drug information assistant for caregivers. Use ONLY the provided context for factual claims. "
        "Be concise and user-friendly. If the question requires clinical judgement or is ambiguous, recommend consulting a healthcare professional. "
        "Return STRICT JSON with keys: 'answer' (string) and 'citations' (list of citation strings). "
        "Each citation must reference labels from the context (e.g., '[Source 1: file.pdf | Page 3]')."
    )

    prompt = f"""{system_instructions}

Context:
{context}

Conversation history (most recent messages):
{state.get('messages', [])}

User question:
{q}

Return only valid JSON. Example:
{{"answer":"...","citations":["[Source 1: file.pdf | Page 3]"]}}
"""

    # Call Gemini synchronously
    logger.info("Calling Gemini for question: %s", q[:120])
    resp = genai.GenerativeModel("gemini-2.5-flash").generate_content(prompt)
    raw_text = resp.text
    # try to extract JSON
    m = re.search(r"\{.*\}", raw_text, re.S)
    if m:
        try:
            parsed = json.loads(m.group())
            answer = parsed.get("answer", "").strip()
            citations = parsed.get("citations", [])
        except Exception as e:
            logger.warning("JSON parse failed: %s", e)
            answer = raw_text.strip()
            citations = []
    else:
        answer = raw_text.strip()
        citations = []

    # append assistant message to history
    ai_msg = AIMessage(content=answer)
    return {
        "model_response_raw": raw_text,
        "answer": answer,
        "citations": citations,
        "messages": [ai_msg]
    }

# Node: produce TTS
def node_tts(state: BotState):
    answer = state.get("answer", "")
    if not answer:
        return {}
    fname = f"drug_answer_{int(time.time())}.mp3"
    try:
        tts = gTTS(text=answer, lang="en")
        tts.save(fname)
        logger.info("Saved TTS to %s", fname)
        return {"tts_file": fname}
    except Exception as e:
        logger.warning("TTS failed: %s", e)
        return {}

In [22]:
# ===========================================
# Cell 7 - Build and compile the LangGraph
# ===========================================
# Create a builder-style graph similar to earlier examples
builder = StateGraph(BotState)

# register nodes (names must be unique)
builder.add_node("accept_query", node_accept_query)
builder.add_node("retrieve", node_retrieve)
builder.add_node("generate", node_generate)
builder.add_node("tts", node_tts)

# Wire edges: START -> accept_query -> retrieve -> generate -> tts -> END
builder.add_edge(START, "accept_query")
builder.add_edge("accept_query", "retrieve")
builder.add_edge("retrieve", "generate")
builder.add_edge("generate", "tts")

graph = builder.compile()
logger.info("LangGraph compiled.")

In [23]:
# ===========================================
# Cell 8 - Persistent session state & API
# ===========================================
# persistent_state preserves messages across queries in this notebook session
persistent_state = {
    "messages": [],        # stores HumanMessage/AIMessage as objects (MessagesState-aware)
    "query": "",
    "retrieved": [],
    "model_response_raw": "",
    "answer": "",
    "citations": [],
    "tts_file": ""
}

def ask_drug_bot(query: str, keep_history: bool = True) -> Dict[str, Any]:
    """Invoke the LangGraph with a query. Returns the resulting state."""
    if keep_history:
        init_state = persistent_state.copy()
    else:
        init_state = {"messages": [], "query": "", "retrieved": [], "model_response_raw": "", "answer": "", "citations": [], "tts_file": ""}

    init_state["query"] = query
    result = graph.invoke(init_state)
    # merge back into persistent_state for next round
    persistent_state.update(result)
    # Print nicely
    print("\n--- Bot Answer ---")
    print(result.get("answer", ""))
    if result.get("citations"):
        print("\n--- Citations ---")
        for c in result["citations"]:
            print("-", c)
    # Play TTS if available
    tts_path = result.get("tts_file")
    if tts_path and os.path.exists(tts_path):
        display(Audio(tts_path, autoplay=True))
    return result

# Example:
# ask_drug_bot("Summarize the main indications and primary safety warnings in the uploaded PDFs.")

In [24]:
# ===========================================
# Cell 9 - Admin helpers (upload additional PDFs / rebuild index)
# ===========================================
def add_more_pdfs_and_reindex():
    new_files = upload_pdfs()
    docs = load_pdfs_as_docs(new_files)
    # if vectordb exists persist dir in memory, load existing docs from that directory vs re-create entirely
    global vectordb, retriever
    try:
        # load any existing docs in the vectorstore by re-building with all docs
        # For simplicity here we will re-build vectorstore with newly loaded docs + existing ones if we kept the 'all_docs' list
        # If you kept `all_docs` earlier, append new docs and rebuild.
        all_docs.extend(docs)
    except NameError:
        # if all_docs not defined previously, create it
        globals()["all_docs"] = docs.copy()
    vectordb, retriever = build_vector_index(all_docs, persist_dir=VECTORDB_PERSIST_DIR)
    logger.info("Rebuilt vector index with new files.")

# Usage:
# add_more_pdfs_and_reindex()

In [25]:
# ===========================================
# Cell 10 - Full example flow (run after building index)
# ===========================================
# 1) Upload PDFs and build vector index (run once)
pdf_files = upload_pdfs()
all_docs = load_pdfs_as_docs(pdf_files)
vectordb, retriever = build_vector_index(all_docs, persist_dir=VECTORDB_PERSIST_DIR)

# 2) Interact with the bot
ask_drug_bot("What is the recommended adult dosage and administration for YONDELIS (trabectedin)?")
ask_drug_bot("What are the contraindications and main warnings for this drug?")
ask_drug_bot("Is there guidance for pregnant patients?")

Saving YONDELIS.pdf to YONDELIS (1).pdf

--- Bot Answer ---
The recommended dosage for YONDELIS (trabectedin) is 1.5 mg/m2. It is administered as an intravenous infusion over 24 hours through a central venous line every 21 days (3 weeks), continuing until disease progression or unacceptable toxicity.

--- Citations ---
- [Source 3: YONDELIS.pdf | Page 1]
- [Source 4: YONDELIS (1).pdf | Page 1]



--- Bot Answer ---



--- Bot Answer ---
Yes, there is guidance for pregnant patients. YONDELIS can potentially harm an unborn baby, and pregnant women should be advised of this risk. Patients should not become pregnant during treatment with YONDELIS. If a female becomes pregnant or suspects pregnancy while on treatment, she should contact her healthcare provider.

--- Citations ---
- [Source 1: YONDELIS.pdf | Page 17]
- [Source 2: YONDELIS (1).pdf | Page 17]
- [Source 3: YONDELIS.pdf | Page 18]
- [Source 4: YONDELIS (1).pdf | Page 18]


{'messages': [HumanMessage(content='What is the recommended adult dosage and administration for YONDELIS (trabectedin)?', additional_kwargs={}, response_metadata={}, id='1be2d528-ecb9-4f14-b8b2-e7c3aa6d842f'),
  AIMessage(content='The recommended dosage for YONDELIS (trabectedin) is 1.5 mg/m2. It is administered as an intravenous infusion over 24 hours through a central venous line every 21 days (3 weeks), continuing until disease progression or unacceptable toxicity.', additional_kwargs={}, response_metadata={}, id='f3895e25-9a7f-41ab-bcbe-550042bb84c5'),
  HumanMessage(content='Is there guidance for pregnant patients?', additional_kwargs={}, response_metadata={}, id='92d69132-433a-477a-9a39-22a99fdab3ba'),
  AIMessage(content='Yes, there is guidance for pregnant patients. YONDELIS can potentially harm an unborn baby, and pregnant women should be advised of this risk. Patients should not become pregnant during treatment with YONDELIS. If a female becomes pregnant or suspects pregnancy