# GraphRAG

## Import packages

In [80]:
! pip install nltk numpy pandas unidecode scikit-learn tqdm llm-blender
! pip install langchain langchain-core langchain-community langchain_experimental langchain-openai langchain-chroma langchain_mistralai langgraph

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [81]:
import os
import re
import nltk
import string
import numpy as np
import pandas as pd
from unidecode import unidecode
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from pathlib import Path
import pickle
from rouge_score import rouge_scorer
import json
import llm_blender
from operator import itemgetter
from dotenv import load_dotenv
from getpass import getpass
from typing import List
from typing_extensions import TypedDict
from langchain.schema import Document

from langchain_community.document_loaders import PDFMinerLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_openai import ChatOpenAI
from langchain.embeddings.cache import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.llms import Ollama
from langgraph.graph import END, StateGraph
from langchain_core.output_parsers import PydanticOutputParser
from langchain.output_parsers import RetryOutputParser
from typing import Literal
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnableLambda, RunnableParallel
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

## Disable warnings

In [82]:
import warnings
warnings.filterwarnings('ignore')

## Setup environment variables

You have to define the following environment variables in the `.env` file, terminal environment, or input field within this Jupyter notebook:
1. MISTRAL_API_KEY
2. OPENAI_API_KEY

## Import packages

In [83]:
env_variables = [
  'MISTRAL_API_KEY',
  'OPENAI_API_KEY',
]

load_dotenv()

for key in env_variables:
  value = os.getenv(key)

  if value is None:
    value = getpass(key)

  os.environ[key] = value

## Setup metrics

### Download NLTK dictionaries

These dictionaries are needed for further text preprocessing.

In [84]:
dict_ids = [
  'punkt_tab',
  'punkt',
  'stopwords',
  'wordnet',
]

for dict_id in dict_ids:
  nltk.download(dict_id, quiet=True)

### Text preprocessing

Define a function for text preprocessing, which is an important step before calculating any metrics. This preprocessing function will help in cleaning the text data, making it ready for further analysis. The preprocessing involves several steps:
1. Lowercasing
2. Stopwords removal
3. Lemmatization
4. Remove accents from characters

In [85]:
lemmatizer = nltk.stem.WordNetLemmatizer()

def preprocess(corpus: str) -> str:
  corpus = corpus.lower()
  stopset = nltk.corpus.stopwords.words('english') + nltk.corpus.stopwords.words('russian') + list(string.punctuation)
  tokens = nltk.word_tokenize(corpus)
  tokens = [t for t in tokens if t not in stopset]
  tokens = [lemmatizer.lemmatize(t) for t in tokens]
  corpus = ' '.join(tokens)
  corpus = unidecode(corpus)
  return corpus

### Embedding Initialization

Here we are initializing the Llama 3 embeddings model. The `OllamaEmbeddings` class is a component of the Ollama library, a set of pre-trained language models. This model is capable of embedding corpora of any length into a 4096-dimensional vector.

The use of `OllamaEmbeddings` requires the installation of a local Ollama server, which can be found at https://ollama.com.

In [86]:
embeddings = OllamaEmbeddings(model='llama3.1')
store = LocalFileStore("./.embeddings_cache")

cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
  embeddings,
  store,
  namespace=embeddings.model,
)

### Average embeddings cosine similarity metric

This function calculates the average cosine similarity between expected answers and LLM predicted answers using their respective embeddings. Cosine similarity is a measure of similarity between two non-zero vectors of an inner product space that measures the cosine of the angle between them:

$$
K(a, b) = \frac{\sum \limits_{i=1}^n a_i b_i}{\sqrt{\sum \limits_{i=1}^n a_i^2} \cdot \sqrt{\sum \limits_{i=1}^n b_i^2}}
$$

In [87]:
def embeddings_cosine_sim_metric(expected_answers: list[str], predicted_answers: list[str]) -> float:
  results = []

  for expected_answer, predicted_answer in zip(expected_answers, predicted_answers):
    expected_answer = preprocess(expected_answer)
    predicted_answer = preprocess(predicted_answer)

    expected_embedding = np.array(cached_embeddings.embed_query(expected_answer))
    predicted_embedding = np.array(cached_embeddings.embed_query(predicted_answer))

    sim = cosine_similarity(
      expected_embedding.reshape(1, -1),
      predicted_embedding.reshape(1, -1),
    )[0][0]

    results.append(sim)

  return np.mean(results)

In [88]:
smoothie_f = nltk.translate.bleu_score.SmoothingFunction().method4

def bleu_metric(expected_answers, predicted_answers):
  scores = []

  for expected_answer, predicted_answer in zip(expected_answers, predicted_answers):
    expected_answer = preprocess(expected_answer)
    predicted_answer = preprocess(predicted_answer)

    predicted_tokens = nltk.word_tokenize(predicted_answer)
    expected_tokens = [nltk.word_tokenize(expected_answer)]

    score = nltk.translate.bleu_score.sentence_bleu(
      expected_tokens,
      predicted_tokens,
      smoothing_function=smoothie_f,
    )

    scores.append(score)

  return np.mean(scores)

In [89]:
rogue_1_scorer = rouge_scorer.RougeScorer(['rouge1'], use_stemmer=True)

def rogue_1_metric(expected_answers, predicted_answers):
  scores = []

  for expected_answer, predicted_answer in zip(expected_answers, predicted_answers):
    expected_answer = preprocess(expected_answer)
    predicted_answer = preprocess(predicted_answer)

    result = rogue_1_scorer.score(expected_answer, predicted_answer)

    scores.append(result['rouge1'])

  return np.mean(scores)

In [90]:
rogue_l_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

def rogue_l_metric(expected_answers, predicted_answers):
  scores = []

  for expected_answer, predicted_answer in zip(expected_answers, predicted_answers):
    expected_answer = preprocess(expected_answer)
    predicted_answer = preprocess(predicted_answer)

    result = rogue_l_scorer.score(expected_answer, predicted_answer)

    scores.append(result['rougeL'])

  return np.mean(scores)

## Load documents

In [91]:
docs_dir = Path('./docs')
docs_cache_dir = Path('./.docs_cache')
raw_docs_pkl_path = docs_cache_dir / 'parsed_docs_cache.pkl'

docs = None

if os.path.exists(raw_docs_pkl_path):
  with open(raw_docs_pkl_path, 'rb') as f:
    docs = pickle.load(f)
else:
  docs = []
  for file in docs_dir.iterdir():
    docs.extend(PDFMinerLoader(file).load())

  with open(raw_docs_pkl_path, 'wb') as f:
    pickle.dump(docs, f)

## Split documents

In [92]:
splitted_docs_pkl_path = docs_cache_dir / 'splitted_docs_cache.pkl'

if os.path.exists(splitted_docs_pkl_path):
  with open(splitted_docs_pkl_path, 'rb') as f:
    splitted_docs = pickle.load(f)
else:
  text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=750,
    chunk_overlap=250,
    length_function=len,
    is_separator_regex=False,
  )
  splitted_docs = text_splitter.create_documents([doc.page_content for doc in docs])

  with open(splitted_docs_pkl_path, 'wb') as f:
    pickle.dump(splitted_docs, f)

len(splitted_docs)

17443

## Setup vector store

In [93]:
vector_store = Chroma(
  embedding_function=cached_embeddings,
  persist_directory='./chroma_db'
)
retriever = vector_store.as_retriever()

## Define JSON extractor

In [94]:
def extract_json(response):
  json_pattern = r'\{.*?\}'
  match = re.search(json_pattern, response, re.DOTALL)

  if match:
    return match.group().strip()

  return response

## Build LLM

In [95]:
llm = Ollama(model='llama3.1', temperature=0)

## Build chains

### Route chain

In [96]:
class RouteQuery(BaseModel):
  data_source: Literal['vectorstore', 'websearch'] = Field(
    description='Given a user question choose to route it to web search or a vectorstore.',
  )

route_parser = PydanticOutputParser(pydantic_object=RouteQuery)
route_retry_parser = RetryOutputParser.from_llm(
  parser=route_parser,
  llm=llm,
  max_retries=3,
)

route_template = """
You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to neurobiology and medicine.
Use the vectorstore for questions on these topics. For all else, use web-search.

{format_instructions}

{question}
"""
route_prompt = PromptTemplate(
  template=route_template,
  input_variables=['question'],
  partial_variables={'format_instructions': route_parser.get_format_instructions()},
)

question_router = RunnableParallel(
  completion=route_prompt | llm | extract_json, prompt_value=route_prompt
) | RunnableLambda(lambda x: route_retry_parser.parse_with_prompt(**x))
print(question_router.invoke({'question': 'Who will the Bears draft first in the NFL draft?'}))
print(question_router.invoke({'question': 'What is the order of the cranial nerves?'}))

data_source='websearch'
data_source='vectorstore'


### Grade documents chain

In [97]:
class GradeDocuments(BaseModel):
  binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")

docs_grade_parser = PydanticOutputParser(pydantic_object=GradeDocuments)
docs_grade_retry_parser = RetryOutputParser.from_llm(
  parser=docs_grade_parser,
  llm=llm,
  max_retries=3,
)

docs_grade_template = """
You are a grader assessing relevance of a retrieved document to a user question.
If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant.
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.

{format_instructions}

User question:
{question}

Retrieved document:
{document}
"""
docs_grade_prompt = PromptTemplate(
  template=docs_grade_template,
  input_variables=['document', 'question'],
  partial_variables={'format_instructions': docs_grade_parser.get_format_instructions()},
)

docs_grade_grader = RunnableParallel(
  completion=docs_grade_prompt | llm | extract_json, prompt_value=docs_grade_prompt
) | RunnableLambda(lambda x: docs_grade_retry_parser.parse_with_prompt(**x))
docs_grade_grader.invoke({'question': 'What is the color of the sky?', 'document': 'The color of the sky is blue'})

GradeDocuments(binary_score='yes')

### Hallucinations chain

In [98]:
class GradeHallucinations(BaseModel):
  binary_score: str = Field(description="Answer is grounded in the facts, 'yes' or 'no'")

hallucination_parser = PydanticOutputParser(pydantic_object=GradeHallucinations)
hallucination_retry_parser = RetryOutputParser.from_llm(
  parser=hallucination_parser,
  llm=llm,
  max_retries=3,
)

hallucination_template = """
You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."

{format_instructions}

Set of facts:
{documents}

LLM generation:
{generation}
"""
hallucination_prompt = PromptTemplate(
  template=hallucination_template,
  input_variables=['question', 'generation'],
  partial_variables={'format_instructions': hallucination_parser.get_format_instructions()},
)

hallucination_grader = RunnableParallel(
  completion=hallucination_prompt | llm | extract_json, prompt_value=hallucination_prompt
) | RunnableLambda(lambda x: hallucination_retry_parser.parse_with_prompt(**x))
print(hallucination_grader.invoke({'documents': ['Sky is blue'], 'generation': 'The color of the sky is blue'}))

binary_score='yes'


### Answer grade chain

In [99]:
class GradeAnswer(BaseModel):
  binary_score: str = Field(description="Answer addresses the question, 'yes' or 'no'")

grade_parser = PydanticOutputParser(pydantic_object=GradeAnswer)
grade_retry_parser = RetryOutputParser.from_llm(
  parser=grade_parser,
  llm=llm,
  max_retries=3,
)

grade_template = """
You are a grader assessing whether an answer addresses / resolves a question. \n
Give a binary score 'yes' or 'no'. 'yes' means that the answer resolves the question.

User question:
{question}

LLM generation:
{generation}

{format_instructions}
"""
grade_prompt = PromptTemplate(
  template=grade_template,
  input_variables=['question', 'generation'],
  partial_variables={'format_instructions': grade_parser.get_format_instructions()},
)

answer_grader = RunnableParallel(
  completion=grade_prompt | llm | extract_json, prompt_value=grade_prompt
) | RunnableLambda(lambda x: grade_retry_parser.parse_with_prompt(**x))
print(answer_grader.invoke({"question": "What is the order of the cranial nerves?", 'generation': 'I do not know.'}))

binary_score='no'


### HyDE chain

In [100]:
hyde_template = """
Please write a scientific paper passage to answer the question

Question: {question}

Passage:
"""
hyde_prompt = ChatPromptTemplate.from_template(hyde_template)
hyde_chain = hyde_prompt | llm | StrOutputParser()

hyde_chain.invoke({"question": 'What is the order of the cranial nerves ?'})

"Here's a scientific paper-style passage answering the question:\n\n**Title:** The Cranial Nerve Plexus: A Review of the Anatomical and Neurological Organization\n\n**Abstract:**\n\nThe cranial nerves, a complex network of 12 pairs of nerves that arise from the brainstem, play a crucial role in regulating various physiological functions. Understanding their organization is essential for appreciating the intricate relationships between the central nervous system and the peripheral nervous system. This review aims to provide an overview of the order of the cranial nerves, highlighting their anatomical and neurological characteristics.\n\n**Introduction:**\n\nThe cranial nerves are a group of nerves that emerge from the brainstem, which includes the midbrain, pons, and medulla oblongata. These nerves are responsible for controlling various functions such as sensation, movement, and autonomic regulation. The order of the cranial nerves is a fundamental concept in neuroanatomy, and their co

### Step-back

In [101]:
step_back_template = """
You are an AI assistant tasked with generating broader, more general queries to improve context retrieval in a RAG system.
Given the original query, generate a step-back query that is more general and can help retrieve relevant background information.

Original query: {question}

Step-back query:
"""
step_back_prompt = ChatPromptTemplate.from_template(step_back_template)
step_back_chain = step_back_prompt | llm | StrOutputParser()

step_back_chain.invoke({"question": 'What is Benedict’s syndrome?'})

'What is the definition of a neurological or medical syndrome?'

### Query Rewriting

In [102]:
rewrite_query_template = """
You are an AI assistant tasked with reformulating user queries to improve retrieval in a RAG system.
Given the original query, rewrite it to be more specific, detailed, and likely to retrieve relevant information.

Original query: {question}

Rewritten query:
"""
rewrite_query_prompt = ChatPromptTemplate.from_template(rewrite_query_template)
rewrite_query_chain = rewrite_query_prompt | llm | StrOutputParser()

rewrite_query_chain.invoke({"question": 'What is the order of the cranial nerves?'})

'Here\'s a rewritten version of the query that\'s more specific, detailed, and likely to retrieve relevant information:\n\n"What is the correct anatomical order of the 12 pairs of cranial nerves, including their names (I-XII), functions, and any notable characteristics or associations with specific brain regions or structures?"\n\nThis revised query adds specificity by:\n\n* Specifying the number of cranial nerves (12)\n* Including their names (I-XII) for clarity\n* Mentioning their functions to provide context\n* Asking about notable characteristics or associations to encourage retrieval of relevant information\n\nBy making these changes, the rewritten query is more likely to retrieve accurate and detailed information from a RAG system.'

### Decomposition

In [103]:
class DecompositionAnswer(BaseModel):
  subqueries: List[str] = Field(description="Given the original query, decompose it into 2-4 simpler sub-queries as json array of strings")

decomposition_parser = PydanticOutputParser(pydantic_object=DecompositionAnswer)
decomposition_retry_parser = RetryOutputParser.from_llm(
  parser=decomposition_parser,
  llm=llm,
  max_retries=3,
)

decomposition_template = """
You are an AI assistant tasked with breaking down complex queries into simpler sub-queries for a RAG system.
Given the original query, decompose it into 2-4 simpler sub-queries that, when answered together, would provide a comprehensive response to the original query.

Original query: {question}

example: What are the impacts of climate change on the environment?

Sub-queries:
1. What are the impacts of climate change on biodiversity?
2. How does climate change affect the oceans?
3. What are the effects of climate change on agriculture?
4. What are the impacts of climate change on human health?

{format_instructions}
"""
decomposition_prompt = PromptTemplate(
  template=decomposition_template,
  input_variables=['question'],
  partial_variables={'format_instructions': decomposition_parser.get_format_instructions()},
)

decomposition_chain = RunnableParallel(
  completion=decomposition_prompt | llm | extract_json, prompt_value=decomposition_prompt
) | RunnableLambda(lambda x: decomposition_retry_parser.parse_with_prompt(**x))
print(decomposition_chain.invoke({"question": "What is the order of the cranial nerves?"}))

subqueries=['What are the names of all cranial nerves?', 'Which cranial nerve has the highest number of branches?', 'What is the sequence of cranial nerves in relation to their function (e.g., sensory, motor, mixed)?', 'Are there any notable exceptions or variations in the order of cranial nerves across different species?']


### RAG chain

In [104]:
blender = llm_blender.Blender()
blender.loadranker('llm-blender/PairRM', device='mps')
blender.loadfuser('llm-blender/gen_fuser_3b', device='mps')



Successfully loaded ranker from  /Users/vladimirskvortsov/.cache/huggingface/hub/llm-blender/PairRM


In [105]:
prompt = hub.pull('rlm/rag-prompt')

llama_llm = Ollama(model='llama3.1', temperature=0)
mistral_llm = ChatMistralAI(model='mistral-large-latest', temperature=0)
gpt_llm = ChatOpenAI(model='gpt-4o-mini', temperature=0)

llama_chain = prompt | llama_llm | StrOutputParser()
mistral_chain = prompt | mistral_llm | StrOutputParser()
gpt_chain = prompt | gpt_llm | StrOutputParser()

def fuse_generations(dict):
  question = dict['question']

  llama_res = dict['llama_res']
  mistral_res = dict['mistral_res']
  gpt_res = dict['gpt_res']
  answers = [llama_res, mistral_res, gpt_res]

  fuse_generations, ranks = blender.rank_and_fuse(
    [question],
    [answers],
    instructions=[''],
    return_scores=False,
    batch_size=2,
    top_k=3
  )
  return fuse_generations[0]


rag_chain = (
  {
    'llama_res': llama_chain,
    'mistral_res': mistral_chain,
    'gpt_res': gpt_chain,
    'question': itemgetter('question')
  }
  | RunnableLambda(fuse_generations)
)

rag_chain.invoke({"context": '', "question": 'What is the order of the cranial nerves ?'})

Ranking candidates: 100%|██████████| 1/1 [00:04<00:00,  4.90s/it]
Fusing candidates: 100%|██████████| 1/1 [00:21<00:00, 21.47s/it]


'The order of the cranial nerves is as follows: I. Olfactory II. Optic III. Oculomotor IV. Trochlear V. Trigeminal VI. Abducens VII. Facial VIII. Auditory (or vestibulocochlear) nerve IX. Glossopharyngeal X. Vagus XI. Spinal accessory XII. Hypoglossal.'

### Web search chain

In [106]:
from langchain_community.tools.tavily_search import TavilySearchResults
os.environ['TAVILY_API_KEY'] = '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'
web_search_tool = TavilySearchResults(k=5)

## Build graph app

In [107]:
class GraphState(TypedDict):
  question: str
  step_back_query: str
  rewritten_query: str
  subqueries: List[str]
  generated_docs: str
  documents: List[str]
  web_search: str
  generation: str
  generations_num: int

def route_question(state):
  print('---ROUTE QUESTION---')

  question = state['question']

  source = question_router.invoke({'question': question})

  if source.data_source == 'websearch':
    print('---ROUTE QUESTION TO WEB SEARCH---')
    return 'websearch'
  elif source.data_source == 'vectorstore':
    print('---ROUTE QUESTION TO VECTOR STORE---')
    return 'vectorstore'

def generate_step_back_query(state):
  print('---GENERATE STEP-BACK QUERY---')

  question = state['question']

  step_back_query = step_back_chain.invoke({'question': question})

  return {'step_back_query': step_back_query}

def generate_rewritten_query(state):
  print('---GENERATE REWRITTEN QUERY---')

  question = state['question']

  rewritten_query = rewrite_query_chain.invoke({'question': question})

  return {'rewritten_query': rewritten_query}

def generate_subqueries(state):
  print('---GENERATE SUBQUERIES---')

  question = state['question']

  try:
    decomposition_answer = decomposition_chain.invoke({'question': question})
    subqueries = decomposition_answer.subqueries
    # Limit to a maximum of four subqueries
    subqueries = subqueries[:4]
  except:
    subqueries = []

  print(f'---FINAL SUBQUERIES NUMBER: {len(subqueries)}---')

  return {'subqueries': subqueries}

def generate_hyde_docs(state):
  print('---GENERATE HYDE DOCUMENTS---')

  question = state['question']
  step_back_query = state['step_back_query']
  rewritten_query = state['rewritten_query']
  subqueries = state['subqueries']

  print('question', question)
  print('step_back_query', step_back_query)
  print('rewritten_query', rewritten_query)
  print('subqueries', subqueries)

  queries = [question, step_back_query, rewritten_query, *subqueries]
  generated_docs = []

  for query in queries:
    generated_doc = hyde_chain.invoke({'question': query})
    generated_docs.append(generated_doc)

  return {'question': question, 'generated_docs': generated_docs}

def retrieve(state):
  print('---RETRIEVE---')

  question = state['question']
  generated_docs = state['generated_docs']

  docs = []

  for generated_doc in generated_docs:
    docs.extend(retriever.invoke(generated_doc))

  unique_docs = []
  seen_contents = set()

  for doc in docs:
    if doc.page_content in seen_contents:
      continue

    unique_docs.append(doc)
    seen_contents.add(doc.page_content)


  return {'question': question, 'documents': unique_docs}

def grade_documents(state):
  print('---CHECK DOCUMENT RELEVANCE TO QUESTION---')

  question = state['question']
  documents = state['documents']

  # Score each doc
  filtered_docs = []
  web_search = 'No'
  for index, d in enumerate(documents):
    print(f'---GRADE DOCUMENT ({index + 1}/{len(documents)})---')
    try:
      score = docs_grade_grader.invoke({'question': question, 'document': d.page_content})
      grade = score.binary_score
    except:
      grade = 'No'
    # Document relevant
    if grade.lower() == 'yes':
      print('---GRADE: DOCUMENT RELEVANT---')
      filtered_docs.append(d)
    # Document not relevant
    else:
      print('---GRADE: DOCUMENT NOT RELEVANT---')
      # We do not include the document in filtered_docs
      # We set a flag to indicate that we want to run web search
      web_search = 'Yes'
      continue

  print(f'---FINAL DOCUMENTS NUMBER: {len(filtered_docs)}---')

  return {
    'question': question,
    'documents': filtered_docs,
    'web_search': web_search,
  }

def decide_to_generate(state):
  print('---ASSESS GRADED DOCUMENTS---')

  web_search = state['web_search']

  if web_search == 'Yes':
    # Some documents have been filtered check_relevance
    # We will re-generate a new query
    print('---DECISION: SOME DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---')
    return 'websearch'
  else:
    # We have relevant documents, so generate answer
    print('---DECISION: GENERATE---')
    return 'generate'

def web_search(state):
  print('---WEB SEARCH---')

  question = state['question']
  documents = state.get('documents')

  try:
    docs = web_search_tool.invoke({'query': question})
    web_results = '\n'.join([d['content'] for d in docs])
    web_results = Document(page_content=web_results)
    if documents is not None:
      documents.append(web_results)
    else:
      documents = [web_results]
  except:
    pass

  return {
    'question': question,
    'documents': documents,
  }

def generate(state):
  print('---GENERATE---')

  question = state['question']
  documents = state['documents']
  generations_num = state.get('generations_num', 0)

  # RAG generation
  generation = rag_chain.invoke({'context': documents, 'question': question})
  return {
    'question': question,
    'documents': documents,
    'generation': generation,
    'generations_num': generations_num + 1,
  }

def grade_generation(state):
  print('---CHECK HALLUCINATIONS---')

  question = state['question']
  documents = state['documents']
  generation = state['generation']
  generations_num = state['generations_num']

  if generations_num >= 2:
    return 'useful'

  try:
    score = hallucination_grader.invoke({'documents': documents, 'generation': generation})
    grade = score.binary_score
  except:
    grade = 'no'

  # Check hallucination
  if grade == 'yes':
    print('---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---')
    # Check question-answering
    print('---GRADE GENERATION vs QUESTION---')
    try:
      score = answer_grader.invoke({'question': question,'generation': generation})
      grade = score.binary_score
    except:
      grade = 'no'

    if grade == 'yes':
      print('---DECISION: GENERATION ADDRESSES QUESTION---')
      return 'useful'
    else:
      print('---DECISION: GENERATION DOES NOT ADDRESS QUESTION---')
      return 'not useful'
  else:
    print('---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---')
    return 'not supported'

workflow = StateGraph(GraphState)

workflow.add_node('generate_step_back_query', generate_step_back_query)
workflow.add_node('generate_rewritten_query', generate_rewritten_query)
workflow.add_node('generate_subqueries', generate_subqueries)
workflow.add_node('generate_hyde_docs', generate_hyde_docs)
workflow.add_node('retrieve', retrieve)
workflow.add_node('websearch', web_search)
workflow.add_node('generate', generate)
workflow.add_node('grade_documents', grade_documents)

workflow.set_conditional_entry_point(
  route_question,
  {
    'websearch': 'websearch',
    'vectorstore': 'generate_step_back_query',
  },
)
workflow.add_edge('generate_step_back_query', 'generate_rewritten_query')
workflow.add_edge('generate_rewritten_query', 'generate_subqueries')
workflow.add_edge('generate_subqueries', 'generate_hyde_docs')
workflow.add_edge('generate_hyde_docs', 'retrieve')
workflow.add_edge('retrieve', 'grade_documents')
workflow.add_conditional_edges(
  'grade_documents',
  decide_to_generate,
  {
    'websearch': 'websearch',
    'generate': 'generate',
  },
)
workflow.add_edge('websearch', 'generate')
workflow.add_conditional_edges(
  'generate',
  grade_generation,
  {
    'not supported': 'generate',
    'useful': END,
    'not useful': 'websearch',
  },
)

app = workflow.compile()

In [108]:
app.invoke({"question": 'What is the order of the cranial nerves?'})

---ROUTE QUESTION---
---ROUTE QUESTION TO VECTOR STORE---
---GENERATE STEP-BACK QUERY---
---GENERATE REWRITTEN QUERY---
---GENERATE SUBQUERIES---
---FINAL SUBQUERIES NUMBER: 4---
---GENERATE HYDE DOCUMENTS---
question What is the order of the cranial nerves?
step_back_query The step-back query would be:

"What are the main categories or classifications of the human nervous system?"

This query takes a step back from the original question, which focuses on a specific aspect (the order of cranial nerves), and instead asks for more general information about the broader category (the human nervous system). This can help retrieve relevant background information that might include details about the structure and organization of the nervous system, including the classification of cranial nerves.
rewritten_query Here's a rewritten version of the query that's more specific, detailed, and likely to retrieve relevant information:

"What is the correct anatomical order of the 12 pairs of cranial n

Ranking candidates: 100%|██████████| 1/1 [00:06<00:00,  6.97s/it]
Fusing candidates: 100%|██████████| 1/1 [00:16<00:00, 16.09s/it]


---CHECK HALLUCINATIONS---
---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---
---GRADE GENERATION vs QUESTION---
---DECISION: GENERATION ADDRESSES QUESTION---


{'question': 'What is the order of the cranial nerves?',
 'step_back_query': 'The step-back query would be:\n\n"What are the main categories or classifications of the human nervous system?"\n\nThis query takes a step back from the original question, which focuses on a specific aspect (the order of cranial nerves), and instead asks for more general information about the broader category (the human nervous system). This can help retrieve relevant background information that might include details about the structure and organization of the nervous system, including the classification of cranial nerves.',
 'rewritten_query': 'Here\'s a rewritten version of the query that\'s more specific, detailed, and likely to retrieve relevant information:\n\n"What is the correct anatomical order of the 12 pairs of cranial nerves, including their names (I-XII), functions, and any notable characteristics or associations with specific brain regions or structures?"\n\nThis revised query adds specificity by

## Evaluate RAG

### Load QA dataset

In [109]:
qa_df = pd.read_csv('brainscape.csv')
qa_df

Unnamed: 0,question,answer
0,What are the afferent cranial nerve nuclei?,Trigeminal sensory nucleus- fibres carry gener...
1,What is the order of the cranial nerves ?,1-olfactory\n2-optic\n3-oculomotor\n4-trochlea...
2,What are the efferent cranial nerve nuclei?,Edinger-westphal nucleus\nOculomotor nucleus\n...
3,Which nuclei share the embryo logical origin -...,Oculomotor nucleus Trochlear nucleus Abducens ...
4,Which nuclei share the embryo logical origin- ...,Trigeminal motor nucleus Facial motor nucleus ...
...,...,...
1047,What is the purpose of gephyrin in the glycine...,Involved in anchoring the receptor to a specif...
1048,What is the glycine receptor involved in ?,Reflex response\nCauses reciprocal inhibition ...
1049,What happens in hyperperplexia ?,It’s an exaggerated reflex Often caused by a m...
1050,What is hyperperplexia treated with ?,Benzodiazepine


### Load cached RAGs responses

In [110]:
cache_path = Path('cache.json')

if not os.path.exists(cache_path):
  data = {}
  with open(cache_path, 'w') as file:
    json.dump(data, file)

with open(cache_path, 'r') as f:
  cache = json.load(f)

len(cache.keys())

1043

In [111]:
questions = list(qa_df['question'].tolist())
expected_answers = list(qa_df['answer'].tolist())
predicted_answers = []

for index, question in tqdm(enumerate(questions)):
  if not question in cache:
    cache[question] = app.invoke({'question': question})['generation']

  predicted_answers.append(cache[question])

  with open(cache_path, 'w') as f:
    json.dump(cache, f)

cos_score = embeddings_cosine_sim_metric(expected_answers, predicted_answers)
bleu_score = bleu_metric(expected_answers, predicted_answers)
rogue_1_score = rogue_1_metric(expected_answers, predicted_answers)
rogue_l_score = rogue_l_metric(expected_answers, predicted_answers)

cos_score, bleu_score, rogue_1_score, rogue_l_score

1052it [00:01, 743.18it/s]


KeyboardInterrupt: 