<a href="https://colab.research.google.com/github/Diangelion/thesis-tutor/blob/main/thesis_tutor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# THESIS TUTOR AI AGENT

## Installation and Libraries

In [None]:
!pip install -qU langchain==0.3.24 \
                 langchain-core==0.3.55 \
                 langchain-community==0.3.22 \
                 langchain-huggingface==0.1.2 \
                 langchain-chroma==0.2.3 \
                 hf_xet pymupdf pymupdf4llm \
                 firebase-admin sympy pylatexenc \
                 langchain-experimental langgraph \
                 langgraph-checkpoint-sqlite \
                 langsmith transformers tavily-python \
                 presidio-analyzer presidio-anonymizer

In [None]:
#=======================
# Libraries
#=======================
import os
import shutil
import pymupdf
import firebase_admin

from google.colab import userdata
from tqdm import tqdm
from pathlib import Path
from pymupdf4llm import to_markdown
from firebase_admin import credentials, firestore
from sympy import sympify, latex
from pylatexenc.latex2text import LatexNodes2Text
from typing import List
from typing_extensions import TypedDict
from pydantic import BaseModel, Field
from tavily import TavilyClient

from langchain import hub
from langchain.schema import Document
from langchain.agents import AgentExecutor, create_react_agent, AgentOutputParser
from langchain.text_splitter import MarkdownTextSplitter
from langchain_chroma import Chroma
from langchain_core.runnables import (
    ConfigurableFieldSpec,
    RunnablePassthrough,
    RunnableLambda
)

from langchain_huggingface.llms import HuggingFacePipeline
from langchain_huggingface import (
    HuggingFaceEmbeddings,
    HuggingFaceEndpoint,
    ChatHuggingFace
)

from langchain_core.agents import AgentFinish
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.tools import tool
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    AIMessage
)
from langchain_core.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    PromptTemplate
)

from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM
)

from presidio_analyzer import AnalyzerEngine
from presidio_anonymizer import AnonymizerEngine

### Configurations

In [None]:
#=======================
# Configuration
#=======================
# Single settings dictionary read by the functions defined in this notebook.
CONFIG = {
    # User
    # Passed as the "session_id" configurable when the agent is invoked in main().
    "user_session_id": "testing",
    # Misc
    # Root folder; get_doc_sources() lists its "Papers/" and "Books/" sub-folders.
    "base_source_dir": "/content/drive/MyDrive/Skripsi/",
    # Directory handed to Chroma as persist_directory.
    "persist_chroma_db": "/content/drive/MyDrive/LLM/thesis-tutor/chroma_db",
    # Forwarded to MarkdownTextSplitter as chunk_size / chunk_overlap.
    "chunk_size": 1000,
    "chunk_overlap": 400,
    # When True, main() deletes the persisted Chroma directory before starting.
    "RESET_DB": False,
    # Firebase
    # Service-account key file given to credentials.Certificate() in main().
    "firebase_key_file": "/content/drive/MyDrive/LLM/thesis-tutor/firebase-admin.json",
    # Collection and document id under which the "history" subcollection is read/written.
    "firebase_collection": "chat_history",
    "firebase_user_id": "id-1",
    # LLM
    # Hugging Face repo id loaded by create_llm().
    "llm_repo_id": "Qwen/Qwen3-1.7B",
    # Embedding
    # Arguments forwarded to HuggingFaceEmbeddings in get_retriever().
    "embedding_model_name": "sentence-transformers/all-mpnet-base-v2",
    "embedding_model_kwargs": {"device": "cpu"},
    "embedding_encode_kwargs": {"normalize_embeddings": False},
    # Safety
    # Model used for the text-classification pipeline built in main().
    "safety_model_classifier": "protectai/deberta-v3-base-prompt-injection-v2",
    # NOTE(review): not referenced anywhere in the visible code — confirm whether it is still needed.
    "safety_max_history_length": 10,
    # Entity types passed to the Presidio analyzer in sanitize_input().
    "safety_pii_entities": ["US_SSN", "PHONE_NUMBER", "EMAIL_ADDRESS"],
    # Substrings that make main() refuse a request.
    "safety_risk_keywords": ["password", "credit card", "social security"],
    # Value the classifier score is compared against in is_safe_input().
    "safety_threshold": 0.85
}

## Preparations

In [None]:
#=======================
# ChromaDB Checking
#=======================
def chroma_db_exists(persist_dir: str):
    """Report whether a persisted Chroma directory is already on disk.

    Args:
        persist_dir: Path expected to hold the persisted vector store.

    Returns:
        True only when ``persist_dir`` exists and is a directory.
    """
    # isdir() is already False for a missing path, so it answers both
    # "does it exist" and "is it a directory" in a single call.
    return os.path.isdir(persist_dir)

In [None]:
#=======================
# Source Gathering
#=======================
def get_doc_sources(base_dir: "str | None" = None) -> list:
  """Collect the paths of every source file in the Papers and Books folders.

  Both folders are looked up directly below the base directory. Only regular
  files are returned; sub-directories are skipped and not descended into.

  Args:
    base_dir: Folder that contains "Papers" and "Books". Defaults to
      CONFIG["base_source_dir"] when omitted.

  Returns:
    File paths as strings: all Papers entries first, then all Books entries,
    each group in sorted order.

  Raises:
    FileNotFoundError: If either folder does not exist (raised by os.listdir).
  """
  if base_dir is None:
    base_dir = CONFIG["base_source_dir"]

  files_path = []
  for folder in ("Papers", "Books"):
    # os.path.join keeps the result valid whether or not base_dir ends with a
    # separator; plain string concatenation would silently build "XPapers/".
    source_dir = os.path.join(base_dir, folder)
    # os.listdir() returns entries in arbitrary order; sorting makes the
    # ingestion order reproducible between runs.
    for file_name in sorted(os.listdir(source_dir)):
      file_path = os.path.join(source_dir, file_name)
      if os.path.isfile(file_path):
        files_path.append(file_path)

  return files_path

In [None]:
#=======================
# Document Loading
#=======================
def get_doc_loaded(doc_sources: "list | None" = None) -> list:
  """Convert each source PDF into one Markdown ``Document`` per page.

  Args:
    doc_sources: Paths of the PDF files to load. ``None`` or an empty list
      yields an empty result.

  Returns:
    A flat list of ``Document`` objects covering all pages of all files, in
    input order. Each document's metadata holds "title" (the file name) and
    "page" (the page value reported by pymupdf4llm for that chunk).
  """
  all_docs = []
  if not doc_sources:
    return all_docs

  for file_path in tqdm(doc_sources, desc="Processing PDFs"):
    path = Path(file_path)
    with pymupdf.open(path) as doc:
      # page_chunks=True makes to_markdown return one entry per page instead
      # of a single string, so every page keeps its own metadata.
      page_md = to_markdown(doc, page_chunks=True, show_progress=False)
      for page in page_md:
        all_docs.append(Document(
            page_content=page["text"],
            metadata={
                "title": path.name,
                "page": page["metadata"]["page"],
            }
        ))

  return all_docs

In [None]:
#=======================
# Document Chunking
#=======================
def get_doc_chunked(doc_loaded: "list | None" = None) -> list:
  """Split loaded documents into overlapping, Markdown-aware chunks.

  Chunk size and overlap come from CONFIG["chunk_size"] and
  CONFIG["chunk_overlap"].

  Args:
    doc_loaded: Documents to split. ``None`` or an empty list yields an empty
      result without touching the splitter.

  Returns:
    The list of chunk documents produced by ``MarkdownTextSplitter``.
  """
  if not doc_loaded:
    return []

  splitter = MarkdownTextSplitter(
      chunk_size=CONFIG["chunk_size"],
      chunk_overlap=CONFIG["chunk_overlap"]
  )
  chunks = splitter.split_documents(doc_loaded)

  # The splitter can return nothing (e.g. every page had empty text), so only
  # print the preview when there is a first chunk to show.
  if chunks:
    # For documents built by get_doc_loaded this prints something like
    # {"title": "paper1.pdf", "page": 1}.
    print(f"Metadata example: {chunks[0].metadata}")
    # Expected to be structured Markdown such as "## Introduction\n...".
    print(f"Chunk example: {chunks[0].page_content}")

  return chunks

In [None]:
#=======================
# Document Embedding
#=======================
def get_retriever(doc_chunked: "list | None" = None):
  """Build a retriever backed by the persisted Chroma vector store.

  Args:
    doc_chunked: Chunk documents to embed and store. When ``None`` or empty,
      the store is opened from CONFIG["persist_chroma_db"] instead of being
      built from documents.

  Returns:
    A retriever over the vector store that returns the top 10 matches.
  """
  embeddings = HuggingFaceEmbeddings(
      model_name=CONFIG["embedding_model_name"],
      model_kwargs=CONFIG["embedding_model_kwargs"],
      encode_kwargs=CONFIG["embedding_encode_kwargs"],
      show_progress=True
  )

  if doc_chunked:
    # Embed the given chunks and persist them to the configured directory.
    vector_store = Chroma.from_documents(
        documents=doc_chunked,
        embedding=embeddings,
        persist_directory=CONFIG["persist_chroma_db"]
    )
    print(f"Vector store saved to {CONFIG['persist_chroma_db']}")
  else:
    # No new chunks: open whatever is stored in the persist directory.
    vector_store = Chroma(
        persist_directory=CONFIG["persist_chroma_db"],
        embedding_function=embeddings
    )
    print(f"Vector store loaded from {CONFIG['persist_chroma_db']}")

  return vector_store.as_retriever(search_kwargs={"k": 10})

In [None]:
# Get the retriever
def prepare_retriever():
  """Return a retriever, ingesting the source documents only when needed.

  If the Chroma persist directory already exists, the stored vectors are
  reused. Otherwise the full pipeline runs: gather sources, load pages,
  chunk them, then embed and persist.
  """
  persist_dir = CONFIG["persist_chroma_db"]

  if not chroma_db_exists(persist_dir):
    # Nothing persisted yet: build the store from the source documents.
    sources = get_doc_sources()
    pages = get_doc_loaded(sources)
    pieces = get_doc_chunked(pages)
    return get_retriever(pieces)

  # A store is already on disk, so just open it.
  return get_retriever()

## Firebase Utils

In [None]:
#=======================
# Firebase Utils
#=======================

# Load conversation history from Firestore
def load_chat_history(
    db: "firestore.Client",
    collection: str,
    user_id: str
) -> List[BaseMessage]:
  """Read a user's stored conversation and rebuild it as chat messages.

  Args:
    db: Firestore client used for the query.
    collection: Top-level collection holding one document per user.
    user_id: Document id whose "history" subcollection is read.

  Returns:
    Messages ordered by the stored "timestamp" field. Each stored record
    contributes a HumanMessage (if it has a "human" field) followed by an
    AIMessage (if it has an "ai" field).
  """
  history_ref = (
      db.collection(collection)
      .document(user_id)
      .collection("history")
  )

  # Field name -> message class, in the order the messages must be emitted.
  field_to_message = (("human", HumanMessage), ("ai", AIMessage))

  messages = []
  for snapshot in history_ref.order_by("timestamp").stream():
    record = snapshot.to_dict()
    for field_name, message_cls in field_to_message:
      if field_name in record:
        messages.append(message_cls(content=record[field_name]))

  return messages

# Save conversation to Firestore
def save_chat_history(
    db: "firestore.Client",
    collection: str, user_id: str,
    human_msg: str, ai_msg: str
):
  """Append one question/answer exchange to a user's stored history.

  Args:
    db: Firestore client used for the write.
    collection: Top-level collection holding one document per user.
    user_id: Document id whose "history" subcollection receives the record.
    human_msg: Text stored under "human".
    ai_msg: Text stored under "ai".
  """
  history_ref = db.collection(collection).document(user_id).collection("history")

  # document() is called without an id, so each exchange gets its own new
  # document rather than overwriting an earlier one.
  entry_ref = history_ref.document()

  # SERVER_TIMESTAMP is the field load_chat_history orders by.
  entry = {
    "human": human_msg,
    "ai": ai_msg,
    "timestamp": firestore.SERVER_TIMESTAMP
  }
  entry_ref.set(entry)

## CORE

In [None]:
#=======================
# LLM
#=======================
# Chat models already built by create_llm(), keyed by Hugging Face repo id.
_LLM_CACHE = {}


def create_llm():
  """Return the local chat model for CONFIG["llm_repo_id"], loading it once.

  The tokenizer and causal-LM weights are loaded with transformers, wrapped
  in a text-generation pipeline, and exposed through ChatHuggingFace.

  create_llm() is called from several places (the RAG chain, the agent, and
  every generate_code tool call). Without the cache each call would load
  another full copy of the model, so the built chat model is kept per repo id
  and reused.

  Returns:
    A ChatHuggingFace instance wrapping the local pipeline.
  """
  repo_id = CONFIG["llm_repo_id"]

  cached = _LLM_CACHE.get(repo_id)
  if cached is not None:
    return cached

  tokenizer = AutoTokenizer.from_pretrained(repo_id)
  model = AutoModelForCausalLM.from_pretrained(
      repo_id,
      torch_dtype="auto",
      device_map="auto"
  )
  pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=32768
  )
  chat = ChatHuggingFace(llm=HuggingFacePipeline(pipeline=pipe))

  _LLM_CACHE[repo_id] = chat
  return chat

In [None]:
#=======================
# RAG Chain
#=======================
def rag_pipeline(retriever):
  """Assemble the retrieval-augmented answering chain.

  The chain takes a dict with an "input" question, retrieves matching
  documents, renders them (with title and page) into the system prompt, asks
  the LLM, and returns the reply as a plain string.

  Args:
    retriever: Object whose ``invoke(question)`` returns the documents to use
      as context.

  Returns:
    A runnable chain; call ``.invoke({"input": question})`` on it.
  """

  def format_docs_with_metadata(docs):
    # Render each document as a numbered section that carries its title and
    # page so the model can cite them.
    return "\n\n".join(
        f"Document {position} "
        f"(Title: {doc.metadata.get('title', 'Unknown title')}, "
        f"Page: {doc.metadata.get('page', 'N/A')}):\n{doc.page_content}"
        for position, doc in enumerate(docs, start=1)
    )

  def build_context(payload):
    # Retrieve for the user's question and format the hits for the prompt.
    return format_docs_with_metadata(retriever.invoke(payload["input"]))

  def pass_question(payload):
    # Forward the question unchanged to the human message slot.
    return payload["input"]

  system_template = """You are a thesis tutor in Artificial Intelligence or Machine Learning.
  Use the context to answer questions.
  IMPORTANT!
  1. Keep the context metadata with 'title' and 'page' fields along with your answer.
  2. Count the sources used.
  3. Number the sources used, e.g. the first context metadata become number 1, the second is number 2, and so on.
  3. If you don't know the answer, just say that you don't know.

  Context with sources:
  {context}

  Therefore, structure responses as:
  ---------------------------------------------------------------------
  Answer:
  YOUR ANSWER HERE

  Source:
  source_number. title, page
  source_number. title, page
  etc.
  ---------------------------------------------------------------------"""

  prompt = ChatPromptTemplate.from_messages([
      ("system", system_template),
      ("human", "{input}")
  ])
  llm = create_llm()

  # LCEL: the dict fills the prompt variables, then prompt -> model -> text.
  prompt_inputs = {"context": build_context, "input": pass_question}
  return prompt_inputs | prompt | llm | StrOutputParser()

## Tools

In [None]:
#=======================
# Equations
#=======================
# Process LaTeX equations from research papers.
# Example input: '$\frac{d}{dx}f(x)$'
@tool
def display_equation(latex_equation: str) -> str:
  """Convert a LaTeX equation into plain‑text and symbolic form."""
  # The docstring above is the tool description shown to the agent, so it is
  # kept as-is.
  #
  # Pipeline: LaTeX source -> plain text (pylatexenc) -> SymPy expression ->
  # LaTeX rendered by SymPy. Any failure along the way is reported back to the
  # agent as a message instead of raising.
  #
  # NOTE(review): sympify() evaluates the string it is given, and this string
  # is produced by the agent from user-influenced text. Consider a restricted
  # parser before exposing this outside a trusted notebook.
  try:
    converter = LatexNodes2Text()
    plain_text = converter.latex_to_text(latex_equation)
    # Drop any "$" math delimiters left at either end before parsing.
    expression = sympify(plain_text.strip('$'))
    return latex(expression)
  except Exception as err:
    return f"Could not process equation: {str(err)}"

#=======================
# Code Generator
#=======================
# Generate code snippets
@tool
def generate_code(task: str) -> str:
  """Generate Python codes if needed based on specific task.
  e.g. 1DCNN code, XGBoost code, etc."""
  # The docstring above is the tool description shown to the agent, so it is
  # kept as-is.
  #
  # Flow: search the web with Tavily for code examples about `task`, render
  # the hits (title, url, content) into a context block, then ask the LLM to
  # write code from that context. Returns the LLM's reply as a string, or a
  # short error message when the search cannot be performed.

  # Read the key from the environment first, then from Colab secrets, rather
  # than shipping a placeholder string that can never authenticate.
  try:
    api_key = os.environ.get("TAVILY_API_KEY") or userdata.get("TAVILY_API_KEY")
    client = TavilyClient(api_key=api_key)
    results = client.search(
        query=f"Python code examples in {task}",
        max_results=5
    )
  except Exception as e:
    # Report the failure to the agent as an observation instead of raising.
    return f"Could not search for code examples: {e}"

  # The search response is a dict; its hits live under "results". Build the
  # context by hand because a dict has no to_string() method.
  hits = results.get("results", [])
  context = "\n\n".join(
      f"Source {position} "
      f"(Title: {hit.get('title', 'Unknown title')}, "
      f"URL: {hit.get('url', 'N/A')}):\n{hit.get('content', '')}"
      for position, hit in enumerate(hits, start=1)
  )

  # The template must be a prompt object: a plain str cannot be piped with
  # "|" into the model.
  prompt = PromptTemplate.from_template("""You are a coder expert in Artificial Intelligence or Machine Learning.
  YOU MUST write Python code examples for {task} using the context.
  IMPORTANT!
  1. Keep the context sources like 'title' and 'url' fields along with your answer.
  2. Count the sources used.
  3. Number the sources used, e.g. the first context metadata become number 1, the second is number 2, and so on.
  3. If you don't know the answer, just say that you don't know.
  Context:
  {context}""")

  llm = create_llm()
  generate_code_chain = prompt | llm | StrOutputParser()
  return generate_code_chain.invoke({"task": task, "context": context})

#=======================
# Format Bibliography
#=======================
# Format numbered bibliography entries
# Shape of one bibliography entry accepted by make_bibliography.
class BibEntry(TypedDict):
  # Number printed in front of the entry.
  source_number: int
  # Title printed for the entry.
  paper_title:   str
  # Page printed after "page:".
  page_number:   int
@tool
def make_bibliography(entries: List[BibEntry]) -> str:
    """
    Make bibliography based on multiple paper title and page number of the paper.
    The paper title and page number is from the thesis_expert tool.
    The source number is only for the numbering of the bibliography.

    Args:
      entries: a list of dicts, each with keys:
        - source_number (int)
        - paper_title (str)
        - page_number (int)

    Returns:
      a multi-line bibliography string.
    """
    # The docstring above is the tool description shown to the agent, so it is
    # kept as-is.
    #
    # One "<number>. <title>, page: <page>" line per entry, in input order,
    # joined with newlines. An empty list gives an empty string.
    return "\n".join(
        f"{entry['source_number']}. {entry['paper_title']}, page: {entry['page_number']}"
        for entry in entries
    )

#=======================
# RAG
#=======================
@tool
def thesis_expert(user_question: str) -> str:
  """
  Need to answer the question based on research paper.
  """
  # The docstring above is the tool description shown to the agent, so it is
  # kept as-is.
  #
  # This tool depends on a module-level name `rag_chain`. Look it up
  # explicitly so that a missing chain produces an actionable error instead of
  # a bare NameError from deep inside the agent loop.
  chain = globals().get("rag_chain")
  if chain is None:
    raise RuntimeError(
        "thesis_expert needs a module-level `rag_chain`; "
        "assign `rag_chain = rag_pipeline(retriever)` at module scope "
        "before running the agent."
    )
  return chain.invoke({"input": user_question})

## History

In [None]:
# #=======================
# # In-Memory Chat History
# #=======================
class InMemoryHistory(BaseChatMessageHistory, BaseModel):
  """Chat message history held in a plain list on the instance."""
  # Messages in the order they were added; starts empty for each instance.
  messages: list[BaseMessage] = Field(default_factory=list)

  def add_messages(self, messages: list[BaseMessage]) -> None:
    """Append the given messages, in order, to the end of the history."""
    current = self.messages
    current.extend(messages)

  def clear(self) -> None:
    """Drop every stored message by replacing the list with an empty one."""
    self.messages = list()

In [None]:
# Here we use a global variable to store the chat message history.
# This will make it easier to inspect it to see the underlying results.
store = {}

def get_session_history(
    session_id: str,
    firestore_client: "firestore.Client",
    firebase_collection: str,
    firebase_user_id: str
) -> BaseChatMessageHistory:
  """Return the in-memory history for a session, seeding it on first use.

  On the first request for a ``session_id`` the stored conversation is loaded
  from Firestore and cached in the module-level ``store``. Later requests
  return the cached object without touching Firestore.

  Args:
    session_id: Key under which the history is cached in ``store``.
    firestore_client: Client used to read the stored conversation.
    firebase_collection: Collection that holds the user's document.
    firebase_user_id: Document id whose history is loaded.

  Returns:
    The cached history object for ``session_id``.

  Raises:
    Whatever ``load_chat_history`` raises. In that case nothing is cached, so
    the next call retries the load.
  """
  print(f"""Getting session history for:
  Session ID: {session_id}
  Firestore Client: {firestore_client}
  Firebase Collection: {firebase_collection}
  Firebase User ID: {firebase_user_id}
  =============================================================
  """)
  if session_id not in store:
      # Build the history completely before caching it. Caching an empty
      # object first would leave the session permanently without its stored
      # conversation if the Firestore load failed.
      history = InMemoryHistory()
      history.add_messages(
          load_chat_history(
              firestore_client,
              firebase_collection,
              firebase_user_id
          )
      )
      store[session_id] = history
  return store[session_id]

## Output Parser

In [None]:
class StrictThesisOutputParser(AgentOutputParser):
  """Parse ReAct-formatted model output into an agent step.

  Output containing "Final Answer:" finishes the run with the text after the
  last occurrence of that marker. Any other output is parsed as a ReAct
  "Action / Action Input" step so the agent can actually call its tools.
  """

  def parse(self, text: str) -> "AgentAction | AgentFinish":
    """Turn one model reply into an AgentFinish or an AgentAction.

    Args:
      text: Raw model output for the current step.

    Returns:
      AgentFinish when the reply contains "Final Answer:", otherwise the
      AgentAction parsed from the reply.

    Raises:
      OutputParserException: When the reply has neither a final answer nor a
        parseable action. This is a ValueError subclass, and it is the type
        the executor's handle_parsing_errors option reacts to.
    """
    # Local import: the file's import block does not provide this parser.
    from langchain.agents.output_parsers import ReActSingleInputOutputParser

    if "Final Answer:" in text:
      # Keep only what follows the last marker, without surrounding blanks.
      final_answer = text.split("Final Answer:")[-1].strip()
      return AgentFinish(
          return_values={"output": final_answer},
          log=text
      )

    # No final answer yet: this must be a tool call (or a malformed step, for
    # which the standard parser raises OutputParserException).
    return ReActSingleInputOutputParser().parse(text)

## Agent

In [None]:
#=======================
# Agent
#=======================
def create_agent(rag_chain, retriever, db):
  """Build the ReAct agent, its executor, and the history-aware wrapper.

  Args:
    rag_chain: RAG chain that answers thesis questions. When given, the
      agent's thesis_expert tool is bound to this chain; when ``None`` the
      module-level thesis_expert tool is used instead.
    retriever: Currently unused; kept so existing callers keep working.
    db: Currently unused; kept so existing callers keep working.

  Returns:
    A RunnableWithMessageHistory around the agent executor. Invoke it with
    ``{"input": question}`` and a config whose "configurable" section
    provides session_id, firestore_client, firebase_collection and
    firebase_user_id.
  """
  # Local import: the file's import block does not provide this helper.
  from langchain_core.messages import get_buffer_string

  if rag_chain is not None:
    # Bind the tool to the chain that was passed in, so it does not depend on
    # a module-level `rag_chain` that may never have been assigned.
    @tool("thesis_expert")
    def bound_thesis_expert(user_question: str) -> str:
      """
      Need to answer the question based on research paper.
      """
      return rag_chain.invoke({"input": user_question})

    expert_tool = bound_thesis_expert
  else:
    expert_tool = thesis_expert

  # Utilities needed
  tools = [display_equation, generate_code, make_bibliography, expert_tool]
  llm = create_llm()

  # {history} receives the earlier conversation as text; without it the
  # history loaded by RunnableWithMessageHistory never reaches the model.
  agent_prompt = """You are a Thesis Tutor Expert in AI/ML. Follow STRICT steps:
  Step 1: You MUST call thesis_expert on the user's question and get the core answer.
  Step 2: Use any of the other tools except thesis_expert to support or beautify that answer.

  Tools available:
  {tools}

  Previous conversation (may be empty):
  {history}

  Follow this ReACT format:
  Question: {input}
  Thought: I will first gather the core answer using thesis_expert, KEEP the source and page metadata.
  Action: thesis_expert
  Action Input: {input}
  Observation: <result from thesis_expert>

  Thought: now I have the thesis-based answer. Do I need to generate code/equations/bibliography?
  Action: the action to take, should be one of [{tool_names}], but EXCLUDING the thesis_expert
  Action Input: think what you need to beautify
  Observation: <result>
  ...(you can repeat Thought/Action/Action Input/Observation as needed)

  Thought: I now have all the information needed to answer. I need to return the final answer that ALWAYS starts with 'Final Answer:' and NEVER use JSON/markdown/URLs.
  Final Answer: <comprehensive answer with any citations or references>

  Begin!
  Question: {input}
  Thought: {agent_scratchpad}"""
  agent_prompt = PromptTemplate.from_template(agent_prompt)

  # Create the agent and executor
  agent = create_react_agent(
      llm=llm,
      tools=tools,
      prompt=agent_prompt,
      stop_sequence=True,
      output_parser=StrictThesisOutputParser()
  )
  agent_executor = AgentExecutor.from_agent_and_tools(
      agent=agent,
      tools=tools,
      handle_parsing_errors=True,
      verbose=True
  )

  def history_as_text(inputs):
    # The wrapper injects "history" as a list of message objects; the string
    # prompt needs readable text instead.
    return get_buffer_string(inputs.get("history") or [])

  agent_with_history_text = (
      RunnablePassthrough.assign(history=history_as_text) | agent_executor
  )

  # Add memory usage into the agent executor
  agent_with_memory = RunnableWithMessageHistory(
      agent_with_history_text,
      get_session_history=get_session_history,
      input_messages_key="input",
      history_messages_key="history",
      output_messages_key="output",
      # Each spec below becomes a keyword argument of get_session_history and
      # is read from config["configurable"] at invoke time.
      history_factory_config=[
          ConfigurableFieldSpec(
              id="session_id",
              annotation=str,
              name="Session ID",
              description="Unique identifier for the user session.",
              default="",
              is_shared=True,
          ),
          ConfigurableFieldSpec(
              id="firestore_client",
              annotation="firestore.Client",
              name="Firestore Client",
              description="Firestore client to connect the class with Firebase Firestore.",
              default=None,
              is_shared=True,
          ),
          ConfigurableFieldSpec(
              id="firebase_collection",
              annotation=str,
              name="Firebase Collection",
              description="Collection name in Firebase to load the chat from.",
              default="",
              is_shared=True,
          ),
          ConfigurableFieldSpec(
              id="firebase_user_id",
              annotation=str,
              name="Firebase User ID ",
              description="Used to select data from specific ID in a Firebase collection.",
              default="",
              is_shared=True,
          )
      ]
  )

  return agent_with_memory

## Guardrails & Safety

In [None]:
#=======================
# Safety Utilities
#=======================
# Remove PII from user input
def sanitize_input(
    analyzer,
    anonymizer,
    text: str
) -> str:
  """
  Replace personally identifiable information in user text.

  Only the entity types listed in CONFIG["safety_pii_entities"] are searched
  for, and the text is analyzed as English.

  Args:
    analyzer: The Presidio analyzer engine
    anonymizer: The Presidio anonymizer engine
    text: The input text to sanitize

  Returns:
    The anonymized text, or "" when the input is empty or only whitespace
  """
  # Nothing to scan: skip both engines.
  if not (text and text.strip()):
    return ""

  # Step 1: locate the configured entity types in the text.
  findings = analyzer.analyze(
      text=text,
      language="en",
      entities=CONFIG["safety_pii_entities"]
  )

  # Step 2: let the anonymizer rewrite the spans the analyzer found.
  return anonymizer.anonymize(text=text, analyzer_results=findings).text

# Check text against the safety classifier
def is_safe_input(
    classifier,
    text: str,
    threshold=None
) -> bool:
  """
  Check if input is safe based on classifier results.

  The classifier reports a predicted label and the confidence of that label.
  The score alone says nothing about safety: a confident "INJECTION" verdict
  has a high score too. The verdict is therefore first converted into "how
  likely is this unsafe", and only then compared with the threshold.

  Args:
    classifier: The classification pipeline; called as ``classifier(text)``
      and expected to return a list whose first item has "label" and "score"
    text: The input text to check
    threshold: Unsafe-confidence at or above which the text is rejected.
      Defaults to CONFIG["safety_threshold"] when omitted

  Returns:
    Boolean indicating if the input is safe. Empty or whitespace-only text is
    safe; any error during classification counts as unsafe
  """
  if not text or not text.strip():
    return True

  try:
    if threshold is None:
      threshold = CONFIG["safety_threshold"]

    result = classifier(text)[0]
    label = str(result["label"]).upper()
    score = float(result["score"])

    # Only an explicit "SAFE" label is trusted; its confidence is turned into
    # the complementary unsafe score. Every other label (including unknown
    # ones) is treated as the unsafe class, so the check fails closed.
    unsafe_score = 1.0 - score if label == "SAFE" else score

    print(f"Safety check label: {label}, score: {score}")
    return unsafe_score < threshold
  except Exception as e:
    print(f"Safety check error: {e}")
    return False  # Fail closed - if error, assume unsafe

## Main

In [None]:
def main():
  """Run the interactive thesis-tutor chat loop.

  Sets up Firebase, the retriever, the RAG chain, the agent and the safety
  tools, then repeatedly reads a question, sanitizes and screens it, asks the
  agent, screens the answer, stores the exchange in Firestore and prints the
  answer. Typing 'exit' ends the loop.
  """
  #=======================
  # Reset DB
  #=======================
  if CONFIG["RESET_DB"]:
    shutil.rmtree(CONFIG["persist_chroma_db"], ignore_errors=True)

  #=======================
  # Firebase Setup
  #=======================
  # Check if Firebase app is already initialized
  try:
    firebase_admin.get_app()
  except ValueError:
    # Firebase isn't initialized, so initialize it. The key file is only read
    # here, where it is actually needed.
    cred = credentials.Certificate(CONFIG["firebase_key_file"])
    firebase_admin.initialize_app(cred)
  db = firestore.client()

  #=======================
  # Preparation
  #=======================
  retriever = prepare_retriever()
  rag_chain = rag_pipeline(retriever)
  agent_with_memory = create_agent(rag_chain, retriever, db)

  #=======================
  # Guardrails & Safety
  #=======================
  analyzer = AnalyzerEngine()
  anonymizer = AnonymizerEngine()
  safe_classifier = pipeline("text-classification", model=CONFIG["safety_model_classifier"])

  print(
      "Welcome to Thesis Tutor!\n"
      "I could provide you informations about:\n"
      "\t1. Explainable AI: SHAP.\n"
      "\t2. XGBoost.\n"
      "\t3. 1DCNN.\n"
      "\t4. Usage of XAI and both model in finance.\n"
      "Type 'exit' to end."
  )

  #=======================
  # Conversation
  #=======================
  while True:
    user_input = input("Your question: ").strip()
    if user_input.lower() in ['exit']:
      break
    if not user_input:
      # Nothing was typed; ask again instead of invoking the agent.
      continue

    try:
      # Input sanitization
      clean_input = sanitize_input(analyzer, anonymizer, user_input)

      # Content moderation: refuse this request but keep the session alive.
      if not is_safe_input(safe_classifier, clean_input):
        print("I cannot respond to that request")
        continue

      # Context monitoring (case-insensitive keyword match)
      lowered_input = clean_input.lower()
      if any(kw.lower() in lowered_input for kw in CONFIG["safety_risk_keywords"]):
        print("Security alert: Sensitive topic detected")
        continue

      # Invoke with context and history
      response = agent_with_memory.invoke(
          { "input": clean_input },
          config={
              "configurable": {
                  "session_id": CONFIG["user_session_id"],
                  "firestore_client": db,
                  "firebase_collection": CONFIG["firebase_collection"],
                  "firebase_user_id": CONFIG["firebase_user_id"]
              }
          }
      )
      ai_output = response["output"]

      # Output Validation: replace the text that is both stored and shown.
      if not is_safe_input(safe_classifier, ai_output):
        ai_output = "I cannot provide that information"

      # Add chat history into the firebase. The sanitized question is stored
      # so the PII removed above is not persisted.
      try:
        save_chat_history(
            db,
            CONFIG["firebase_collection"],
            CONFIG["firebase_user_id"],
            clean_input,
            ai_output
        )
      except Exception as e:
        # Saving is best-effort; the answer is still shown to the user.
        print(f"Error saving history: {e}")

      print(f"AI Response:\n{ai_output}")
    except Exception as e:
      # Keep the chat loop running after a failed turn.
      print(f"\n:rotating_light: Error: {e}")

  print("\nThank you for using the thesis tutor!")

In [None]:
# Start the interactive tutor only when this file is run as the main program.
if __name__ == "__main__":
  main()