# GraphRAG

## Import packages

In [2]:
! pip install nltk numpy pandas unidecode scikit-learn tqdm llm-blender rouge-score xmltodict arxiv biopython
! pip install langchain langchain-core langchain-community langchain_experimental langchain-openai langchain-chroma langchain_mistralai langgraph langchainhub

Collecting tiktoken<1,>=0.7 (from langchain-openai)
  Downloading tiktoken-0.8.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (6.6 kB)
Downloading tiktoken-0.8.0-cp312-cp312-macosx_11_0_arm64.whl (982 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m982.6/982.6 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: tiktoken
  Attempting uninstall: tiktoken
    Found existing installation: tiktoken 0.4.0
    Uninstalling tiktoken-0.4.0:
      Successfully uninstalled tiktoken-0.4.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
llama-index-core 0.10.37.post1 requires networkx>=3.0, but you have networkx 2.6.3 which is incompatible.
readmeai 0.5.99.post5 requires tiktoken<0.5.0,>=0.4.0, but you have tiktoken 0.8.0 which is incompatible.[0m[31m
[0mSuccessfully installed tiktoken-0.8.0


In [1]:
import os
import re
import nltk
import string
import numpy as np
import pandas as pd
from unidecode import unidecode
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from pathlib import Path
import pickle
from rouge_score import rouge_scorer
import json
import llm_blender
from operator import itemgetter
import operator
from dotenv import load_dotenv
from getpass import getpass
from typing import List, Annotated
from typing_extensions import TypedDict
from pydantic import BaseModel, Field
from Bio import Entrez, SeqIO
import torch

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
from langchain.schema import Document
from langchain_community.document_loaders import PDFMinerLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_openai import ChatOpenAI
from langchain.embeddings.cache import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.llms import Ollama
from langgraph.graph import START, END, StateGraph
from langchain_core.output_parsers import PydanticOutputParser
from langchain.output_parsers import RetryOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.retrievers import PubMedRetriever, ArxivRetriever
from langchain_community.tools.tavily_search import TavilySearchResults

  from .autonotebook import tqdm as notebook_tqdm


## Disable warnings

In [2]:
import warnings
warnings.filterwarnings('ignore')

## Setup environment variables

You have to define the following environment variables in the `.env` file, terminal environment, or input field within this Jupyter notebook:
1. MISTRAL_API_KEY
2. OPENAI_API_KEY
3. OPENAI_PROXY
4. TAVILY_API_KEY
5. ENTREZ_EMAIL

## Import packages

In [3]:
env_variables = [
  'MISTRAL_API_KEY',
  'OPENAI_API_KEY',
  'OPENAI_PROXY',
  'TAVILY_API_KEY',
  'ENTREZ_EMAIL',
]

load_dotenv()

for key in env_variables:
  value = os.getenv(key)

  if value is None:
    value = getpass(key)

  os.environ[key] = value

## Setup metrics

### Download NLTK dictionaries

These dictionaries are needed for further text preprocessing.

In [4]:
dict_ids = [
  'punkt_tab',
  'punkt',
  'stopwords',
  'wordnet',
]

for dict_id in dict_ids:
  nltk.download(dict_id, quiet=True)

### Text preprocessing

Define a function for text preprocessing, which is an important step before calculating any metrics. This preprocessing function will help in cleaning the text data, making it ready for further analysis. The preprocessing involves several steps:
1. Lowercasing
2. Stopwords removal
3. Lemmatization
4. Remove accents from characters

In [5]:
lemmatizer = nltk.stem.WordNetLemmatizer()

def preprocess(corpus: str) -> str:
  corpus = corpus.lower()
  stopset = nltk.corpus.stopwords.words('english') + nltk.corpus.stopwords.words('russian') + list(string.punctuation)
  tokens = nltk.word_tokenize(corpus)
  tokens = [t for t in tokens if t not in stopset]
  tokens = [lemmatizer.lemmatize(t) for t in tokens]
  corpus = ' '.join(tokens)
  corpus = unidecode(corpus)
  return corpus

### Embedding Initialization

Here we are initializing the Llama 3 embeddings model. The `OllamaEmbeddings` class is a component of the Ollama library, a set of pre-trained language models. This model is capable of embedding corpora of any length into a 4096-dimensional vector.

The use of `OllamaEmbeddings` requires the installation of a local Ollama server, which can be found at https://ollama.com.

In [6]:
embeddings = OllamaEmbeddings(model='llama3.1')
store = LocalFileStore("./.embeddings_cache")

cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
  embeddings,
  store,
  namespace=embeddings.model,
)

### Average embeddings cosine similarity metric

This function calculates the average cosine similarity between expected answers and LLM predicted answers using their respective embeddings. Cosine similarity is a measure of similarity between two non-zero vectors of an inner product space that measures the cosine of the angle between them:

$$
K(a, b) = \frac{\sum \limits_{i=1}^n a_i b_i}{\sqrt{\sum \limits_{i=1}^n a_i^2} \cdot \sqrt{\sum \limits_{i=1}^n b_i^2}}
$$

In [7]:
def embeddings_cosine_sim_metric(expected_answers: list[str], predicted_answers: list[str]) -> float:
  results = []

  for expected_answer, predicted_answer in zip(expected_answers, predicted_answers):
    expected_answer = preprocess(expected_answer)
    predicted_answer = preprocess(predicted_answer)

    expected_embedding = np.array(cached_embeddings.embed_query(expected_answer))
    predicted_embedding = np.array(cached_embeddings.embed_query(predicted_answer))

    sim = cosine_similarity(
      expected_embedding.reshape(1, -1),
      predicted_embedding.reshape(1, -1),
    )[0][0]

    results.append(sim)

  return np.mean(results)

In [8]:
smoothie_f = nltk.translate.bleu_score.SmoothingFunction().method4

def bleu_metric(expected_answers, predicted_answers):
  scores = []

  for expected_answer, predicted_answer in zip(expected_answers, predicted_answers):
    expected_answer = preprocess(expected_answer)
    predicted_answer = preprocess(predicted_answer)

    predicted_tokens = nltk.word_tokenize(predicted_answer)
    expected_tokens = [nltk.word_tokenize(expected_answer)]

    score = nltk.translate.bleu_score.sentence_bleu(
      expected_tokens,
      predicted_tokens,
      smoothing_function=smoothie_f,
    )

    scores.append(score)

  return np.mean(scores)

In [9]:
rogue_1_scorer = rouge_scorer.RougeScorer(['rouge1'], use_stemmer=True)

def rogue_1_metric(expected_answers, predicted_answers):
  scores = []

  for expected_answer, predicted_answer in zip(expected_answers, predicted_answers):
    expected_answer = preprocess(expected_answer)
    predicted_answer = preprocess(predicted_answer)

    result = rogue_1_scorer.score(expected_answer, predicted_answer)

    scores.append(result['rouge1'])

  return np.mean(scores)

In [12]:
rogue_l_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

def rogue_l_metric(expected_answers, predicted_answers):
  scores = []

  for expected_answer, predicted_answer in zip(expected_answers, predicted_answers):
    expected_answer = preprocess(expected_answer)
    predicted_answer = preprocess(predicted_answer)

    result = rogue_l_scorer.score(expected_answer, predicted_answer)

    scores.append(result['rougeL'])

  return np.mean(scores)

## Load documents

In [13]:
docs_dir = Path('./docs')
docs_cache_dir = Path('./.docs_cache')
raw_docs_pkl_path = docs_cache_dir / 'parsed_docs_cache.pkl'

if os.path.exists(raw_docs_pkl_path):
  with open(raw_docs_pkl_path, 'rb') as f:
    docs = pickle.load(f)
else:
  docs = []

  for file in docs_dir.iterdir():
    file_docs = PDFMinerLoader(file, concatenate_pages=False).load()
    for doc in file_docs:
      doc.page_content = unidecode(doc.page_content)
      page = doc.metadata['page']
      doc.metadata['source'] = f'{file.stem} ({page})'
    docs.extend(file_docs)

  with open(raw_docs_pkl_path, 'wb') as f:
    pickle.dump(docs, f)

len(docs)

4438

## Split documents

In [14]:
splitted_docs_pkl_path = docs_cache_dir / 'splitted_docs_cache.pkl'

if os.path.exists(splitted_docs_pkl_path):
  with open(splitted_docs_pkl_path, 'rb') as f:
    splitted_docs = pickle.load(f)
else:
  text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=750,
    chunk_overlap=250,
    length_function=len,
    is_separator_regex=False,
    separators=[
      '.',
      '\uff0e', # Fullwidth full stop
      '\u3002', # Ideographic full stop
      '\n\n',
    ],
  )
  splitted_docs = text_splitter.create_documents([doc.page_content for doc in docs])

  with open(splitted_docs_pkl_path, 'wb') as f:
    pickle.dump(splitted_docs, f)

len(splitted_docs)

35663

## Setup vector store

In [15]:
vector_store = Chroma(
  collection_name='neurorag',
  embedding_function=cached_embeddings,
  persist_directory='./chroma_db'
)
retriever = vector_store.as_retriever()

## Define JSON extractor

In [37]:
def extract_json(response):
  json_pattern = r'\{.*?\}'
  match = re.search(json_pattern, response, re.DOTALL)

  if match:
    return match.group().strip().replace('\\\\', '\\')

  return response

## Build LLM

In [35]:
llm = Ollama(model='llama3.1', temperature=0)

## Build chains

### Route chain

In [None]:
class RouteQuery(BaseModel):
  sources: List[str] = Field(
    description='Given a user question select the retrieval methods you consider the most appropriate for addressing this question. You may also return an empty array if no methods are required.',
  )

route_parser = PydanticOutputParser(pydantic_object=RouteQuery)
route_retry_parser = RetryOutputParser.from_llm(
  parser=route_parser,
  llm=llm,
  max_retries=3,
)

route_template = """
You are an expert at selecting retrieval methods.
Given a user question select the retrieval methods you consider the most appropriate for addressing user question.
You may also return an empty array if no methods are required.

Possible retrieval methods:
1. The "vectorstore" retriever contains documents related to neurobiology and medicine. Use the vectorstore for questions on these topics.
2. The "pubmed" retriever contains biomedical literature and research articles. It is particularly useful for answering detailed questions about medical research, clinical studies, and scientific discoveries.
3. The "arxiv" retriever contains preprints of research papers across various scientific fields, including physics, mathematics, computer science, and biology. Use the arxiv for questions on recent scientific research and theoretical studies in these areas.
4. The "ncbi_protein" retriever contains protein sequence and functional information. Use the NCBI protein DB for questions related to protein sequences, structures, and functions.
5. The "ncbi_gene" retriever contains gene sequence and functional information. Use the NCBI gene DB for questions related to gene sequences, structures, and functions.

{format_instructions}

User question:
{question}
"""
route_prompt = PromptTemplate(
  template=route_template,
  input_variables=['question'],
  partial_variables={'format_instructions': route_parser.get_format_instructions()},
)

question_router = RunnableParallel(
  completion=route_prompt | llm | extract_json, prompt_value=route_prompt
) | RunnableLambda(lambda x: route_retry_parser.parse_with_prompt(**x))
print(question_router.invoke({'question': 'Who will the Bears draft first in the NFL draft?'}))
print(question_router.invoke({'question': 'What are the functions of the oculomotor nerve?'}))

sources=[]
sources=['vectorstore', 'ncbi_protein']


### Grade documents chain

In [19]:
class GradeDocuments(BaseModel):
  binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")

docs_grader_parser = PydanticOutputParser(pydantic_object=GradeDocuments)
docs_grader_retry_parser = RetryOutputParser.from_llm(
  parser=docs_grader_parser,
  llm=llm,
  max_retries=3,
)

docs_grader_template = """
You are a grader assessing relevance of a retrieved document to a user question.
If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant.
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.

{format_instructions}

User question:
{question}

Retrieved document:
{document}
"""
docs_grader_prompt = PromptTemplate(
  template=docs_grader_template,
  input_variables=['document', 'question'],
  partial_variables={'format_instructions': docs_grader_parser.get_format_instructions()},
)

docs_grader_grader = RunnableParallel(
  completion=docs_grader_prompt | llm | extract_json, prompt_value=docs_grader_prompt
) | RunnableLambda(lambda x: docs_grader_retry_parser.parse_with_prompt(**x))
docs_grader_grader.invoke({'question': 'What is the color of the sky?', 'document': 'The color of the sky is blue'})

GradeDocuments(binary_score='yes')

### Hallucinations chain

In [20]:
class GradeHallucinations(BaseModel):
  binary_score: str = Field(description="Answer is grounded in the facts, 'yes' or 'no'")

hallucination_parser = PydanticOutputParser(pydantic_object=GradeHallucinations)
hallucination_retry_parser = RetryOutputParser.from_llm(
  parser=hallucination_parser,
  llm=llm,
  max_retries=3,
)

hallucination_template = """
You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."

{format_instructions}

Set of facts:
{documents}

LLM generation:
{generation}
"""
hallucination_prompt = PromptTemplate(
  template=hallucination_template,
  input_variables=['question', 'generation'],
  partial_variables={'format_instructions': hallucination_parser.get_format_instructions()},
)

hallucination_grader = RunnableParallel(
  completion=hallucination_prompt | llm | extract_json, prompt_value=hallucination_prompt
) | RunnableLambda(lambda x: hallucination_retry_parser.parse_with_prompt(**x))
print(hallucination_grader.invoke({'documents': ['Sky is blue'], 'generation': 'The color of the sky is blue'}))

binary_score='yes'


### Answer grade chain

In [21]:
class GradeAnswer(BaseModel):
  binary_score: str = Field(description="Answer addresses the question, 'yes' or 'no'")

grade_parser = PydanticOutputParser(pydantic_object=GradeAnswer)
grade_retry_parser = RetryOutputParser.from_llm(
  parser=grade_parser,
  llm=llm,
  max_retries=3,
)

grade_template = """
You are a grader assessing whether an answer addresses / resolves a question. \n
Give a binary score 'yes' or 'no'. 'yes' means that the answer resolves the question.

User question:
{question}

LLM generation:
{generation}

{format_instructions}
"""
grade_prompt = PromptTemplate(
  template=grade_template,
  input_variables=['question', 'generation'],
  partial_variables={'format_instructions': grade_parser.get_format_instructions()},
)

answer_grader = RunnableParallel(
  completion=grade_prompt | llm | extract_json, prompt_value=grade_prompt
) | RunnableLambda(lambda x: grade_retry_parser.parse_with_prompt(**x))
print(answer_grader.invoke({"question": "What is the order of the cranial nerves?", 'generation': 'I do not know.'}))

binary_score='no'


### HyDE chain

In [22]:
hyde_template = """
Please write a scientific paper passage to answer the question

Question: {question}

Passage:
"""
hyde_prompt = ChatPromptTemplate.from_template(hyde_template)
hyde_chain = hyde_prompt | llm | StrOutputParser()

hyde_chain.invoke({"question": 'What is the order of the cranial nerves ?'})

"Here's a scientific paper-style passage answering the question:\n\n**Title:** The Cranial Nerve Plexus: A Review of the Anatomical and Neurological Organization\n\n**Abstract:**\n\nThe cranial nerves, a complex network of 12 pairs of nerves that arise from the brainstem, play a crucial role in regulating various physiological functions. Understanding their organization is essential for appreciating the intricate relationships between the central nervous system and the peripheral nervous system. This review aims to provide an overview of the order of the cranial nerves, highlighting their anatomical and neurological characteristics.\n\n**Introduction:**\n\nThe cranial nerves are a group of nerves that emerge from the brainstem, which includes the midbrain, pons, and medulla oblongata. These nerves are responsible for controlling various functions such as sensation, movement, and autonomic regulation. The order of the cranial nerves is a fundamental concept in neuroanatomy, and their co

### Step-back

In [23]:
step_back_template = """
You are an AI assistant tasked with generating broader, more general queries to improve context retrieval in a RAG system.
Given the original query, generate a step-back query that is more general and can help retrieve relevant background information.

Original query: {question}

Step-back query:
"""
step_back_prompt = ChatPromptTemplate.from_template(step_back_template)
step_back_chain = step_back_prompt | llm | StrOutputParser()

step_back_chain.invoke({"question": 'What is Benedict’s syndrome?'})

'What is the definition of a neurological or medical syndrome?'

### Query Rewriting

In [24]:
rewrite_query_template = """
You are an AI assistant tasked with reformulating user queries to improve retrieval in a RAG system.
Given the original query, rewrite it to be more specific, detailed, and likely to retrieve relevant information.

Original query: {question}

Rewritten query:
"""
rewrite_query_prompt = ChatPromptTemplate.from_template(rewrite_query_template)
rewrite_query_chain = rewrite_query_prompt | llm | StrOutputParser()

rewrite_query_chain.invoke({"question": 'What is the order of the cranial nerves?'})

'Here\'s a rewritten version of the query that\'s more specific, detailed, and likely to retrieve relevant information:\n\n"What is the correct anatomical order of the 12 pairs of cranial nerves, including their names (I-XII), functions, and any notable characteristics or associations with specific brain regions or structures?"\n\nThis revised query adds specificity by:\n\n* Specifying the number of cranial nerves (12)\n* Including their names (I-XII) for clarity\n* Mentioning their functions to provide context\n* Asking about notable characteristics or associations to encourage retrieval of relevant information\n\nBy making these changes, the rewritten query is more likely to retrieve accurate and detailed information from a RAG system.'

### Decomposition

In [None]:
class DecompositionAnswer(BaseModel):
  subqueries: List[str] = Field(description="Given the original query, decompose it into 2-4 simpler sub-queries as json array of strings")

decomposition_parser = PydanticOutputParser(pydantic_object=DecompositionAnswer)
decomposition_retry_parser = RetryOutputParser.from_llm(
  parser=decomposition_parser,
  llm=llm,
  max_retries=3,
)

decomposition_template = """
You are an AI assistant tasked with breaking down complex queries into simpler sub-queries for a RAG system.
Given the original query, decompose it into 2-4 simpler sub-queries that, when answered together, would provide a comprehensive response to the original query.

Original query: {question}

example: What are the impacts of climate change on the environment?

Sub-queries:
1. What are the impacts of climate change on biodiversity?
2. How does climate change affect the oceans?
3. What are the effects of climate change on agriculture?
4. What are the impacts of climate change on human health?

{format_instructions}
"""
decomposition_prompt = PromptTemplate(
  template=decomposition_template,
  input_variables=['question'],
  partial_variables={'format_instructions': decomposition_parser.get_format_instructions()},
)

decomposition_chain = RunnableParallel(
  completion=decomposition_prompt | llm | extract_json, prompt_value=decomposition_prompt
) | RunnableLambda(lambda x: decomposition_retry_parser.parse_with_prompt(**x))
print(decomposition_chain.invoke({"question": "What is Benedict’s syndrome?"}))

subqueries=["What are the symptoms of Benedict's syndrome?", "What are the causes of Benedict's syndrome?", "How is Benedict's syndrome diagnosed?", "What are the treatment options for Benedict's syndrome?"]


: 

### RAG chain

In [25]:
device = (
  'cuda'
  if torch.cuda.is_available()
  else 'mps'
  if torch.backends.mps.is_available()
  else 'cpu'
)
blender = llm_blender.Blender()
blender.loadranker('llm-blender/PairRM', device=device)
blender.loadfuser('llm-blender/gen_fuser_3b', device=device)



In [26]:
rag_template = """
You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use five paragraphs maximum and keep the answer verbose and structured.

Question: {question}

Context:

{context}

Answer:
"""
rag_prompt = PromptTemplate.from_template(rag_template)

gpt_llm = ChatOpenAI(model='gpt-4o-mini', temperature=0)
openbio_llm = Ollama(model='taozhiyuai/openbiollm-llama-3:70b_q2_k', temperature=0)
biomistral_llm = Ollama(model='cniongolo/biomistral', temperature=0)

gpt_chain = rag_prompt | gpt_llm | StrOutputParser()
openbio_chain = rag_prompt | openbio_llm | StrOutputParser()
biomistral_chain = rag_prompt | biomistral_llm | StrOutputParser()

def fuse_generations(dict):
  question = dict['question']

  gpt_res = dict['gpt_res']
  openbio_res = dict['openbio_res']
  biomistral_res = dict['biomistral_res']
  answers = [gpt_res, openbio_res, biomistral_res]

  # fuse_generations, ranks = blender.rank_and_fuse(
  #   [question],
  #   [answers],
  #   instructions=['keep the similar length of the output as the candidates.'],
  #   return_scores=False,
  #   batch_size=1,
  #   top_k=5,
  # )
  # return fuse_generations[0]
  ranks = blender.rank([question], [answers], return_scores=False)
  idx = np.argmin(ranks[0])
  return answers[idx]

rag_chain = (
  {
    'gpt_res': gpt_chain,
    'openbio_res': openbio_chain,
    'biomistral_res': biomistral_chain,
    'question': itemgetter('question')
  }
  | RunnableLambda(fuse_generations)
)

final_res = rag_chain.invoke({"context": '', "question": 'What is subunit composition of NMDA receptors and role of each subunit?'})
print(final_res)

APIConnectionError: Connection error.

### Web Search Chain

In [None]:
web_search_tool = TavilySearchResults(k=5)

[{'url': 'https://www.worldhistory.org/RMS_Titanic/',
  'content': 'The RMS Titanic was a White Star Line ocean liner, which sank after hitting an iceberg on its maiden voyage from Southampton to New York on 15 April 1912.Over 1,500 men, women, and children lost their lives.There were 705 survivors. In 1985, the Titanic wreck was found several miles deep on the Atlantic seafloor by Robert D. Ballard.. The largest ship built at the time, Titanic was considered'},
 {'url': 'https://www.youtube.com/watch?v=1PhMWUoPDsk',
  'content': 'On April 15, 1912 the RMS Titanic tragically sunk to the bottom of the sea. 73 years later, National Geographic Explorer-in-Residence Dr. Robert Ballard and'},
 {'url': 'https://www.history.com/topics/early-20th-century-us/titanic',
  'content': 'While it has always been assumed that the ship sank as a result of the gash that caused the bulkhead compartments to flood, various other theories have emerged over the decades, including that the ship’s steel plates

### PubMed Retriever

In [None]:
pub_med_retriever = PubMedRetriever()
pub_med_retriever.invoke('What is the order of the cranial nerves?')

[]

### Arxiv Retriever

In [None]:
arxiv_retriever = ArxivRetriever(load_max_docs=3, get_ful_documents=True)
arxiv_retriever.invoke('What is the order of the cranial nerves?')

[Document(metadata={'Entry ID': 'http://arxiv.org/abs/1912.10601v2', 'Published': datetime.date(2021, 3, 13), 'Title': 'Optimized Cranial Bandeau Remodeling', 'Authors': 'James Drake, Marina Drygala, Ricardo Fukasawa, Jochen Koenemann, Andre Linhares, Thomas Looi, John Phillips, David Qian, Nikoo Saber, Justin Toth, Chris Woodbeck, Jessie Yeung'}, page_content="Craniosynostosis, a condition affecting 1 in 2000 infants, is caused by\npremature fusing of cranial vault sutures, and manifests itself in abnormal\nskull growth patterns. Left untreated, the condition may lead to severe\ndevelopmental impairment. Standard practice is to apply corrective cranial\nbandeau remodeling surgery in the first year of the infant's life. The most\nfrequent type of surgery involves the removal of the so-called fronto-orbital\nbar from the patient's forehead and the cutting of well-placed incisions to\nreshape the skull in order to obtain the desired result. In this paper, we\npropose a precise optimizati

### NCBI Protein DB retriever

In [32]:
db_params = {
  'gene': {
    'rettype': 'xml',
    'retmode': 'xml',
  },
  'protein': {
    'rettype': 'gb',
    'retmode': 'text',
  },
}

class NCBIRetriever(BaseRetriever):
  db: str
  k: int

  def __init__(self, db: str, k: int):
    super().__init__(db=db, k=k)

    self.db = db
    self.k = k

    entrez_email = os.getenv('ENTREZ_EMAIL')
    if entrez_email == None:
      raise ValueError('ENTREZ_EMAIL is not defined')
    Entrez.email = entrez_email

  def _search(self, term):
    handle = Entrez.esearch(db=self.db, term=term, retmax=self.k)
    record = Entrez.read(handle)
    handle.close()
    return record['IdList']

  def _fetch(self, ids):
    rettype = db_params[self.db]["rettype"]
    retmode = db_params[self.db]["retmode"]

    handle = Entrez.efetch(db=self.db, id=ids, rettype=rettype, retmode=retmode)
    if self.db == 'gene':
      records = Entrez.read(handle)
    else:
      records = [SeqIO.read(handle, rettype)]
    handle.close()
    return records

  def _get_gene_document(self, record):
    gene_id = record['Entrezgene_track-info']['Gene-track']['Gene-track_geneid']
    gene_symbol = record['Entrezgene_gene']['Gene-ref']['Gene-ref_locus']
    gene_description = record.get('Entrezgene_summary', 'N/A')
    organism_name = record['Entrezgene_source']['BioSource']['BioSource_org']['Org-ref']['Org-ref_taxname']
    page_content = (
      f'Gene ID: {gene_id}\n'
      f'Gene Symbol: {gene_symbol}\n'
      f'Organism: {organism_name}\n'
      f'Description: {gene_description}'
    )
    source = f'https://www.ncbi.nlm.nih.gov/gene/{gene_id}'
    document = Document(page_content=page_content, metadata={'source': source})
    return document

  def _get_protein_document(self, record):
    molecule_type = record.annotations.get("molecule_type", "N/A")
    organism = record.annotations.get("organism", "N/A")
    comment = record.annotations.get("comment", "N/A")
    page_content = (
      f'Protein ID: {record.id}\n'
      f'Type: {molecule_type}\n'
      f'Name: {record.name}\n'
      f'Organism: {organism}\n'
      f'Description: {record.description}\n'
      f'Comment: {comment}\n'
      f'Sequence: {record.seq}'
    )
    source = f'https://www.ncbi.nlm.nih.gov/protein/{record.id}'
    document = Document(page_content=page_content, metadata={'source': source})
    return document

  def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
    ids = self._search(query)
    records = self._fetch(ids)

    docs = []

    for record in records:
      if self.db == 'gene':
        docs.append(self._get_gene_document(record))
      elif self.db == 'protein':
        docs.append(self._get_protein_document(record))

    return docs

In [25]:
ncbi_protein_retriever = NCBIRetriever(db='protein', k=3)
ncbi_protein_retriever.invoke('ABW05875')

[Document(metadata={'source': 'https://www.ncbi.nlm.nih.gov/protein/ABW05875.1'}, page_content='Protein ID: ABW05875.1\nType: protein\nName: ABW05875\nOrganism: Bromelia pinguin\nDescription: PsbN (chloroplast) [Bromelia pinguin]\nComment: Method: conceptual translation.\nSequence: METATLVAISISGLLVSFTGYALYTAFGQPSQQLRDPFEEHGD')]

In [33]:
ncbi_gene_retriever = NCBIRetriever(db='gene', k=3)
ncbi_gene_retriever.invoke('peng')

[Document(metadata={'source': 'https://www.ncbi.nlm.nih.gov/gene/42862'}, page_content='Gene ID: 42862\nGene Symbol: Tsc1\nOrganism: Drosophila melanogaster\nDescription: Enables GTPase activator activity and protein kinase binding activity. Involved in several processes, including cellular response to chloroquine; negative regulation of developmental growth; and regulation of signal transduction. Located in cytoplasm. Part of TSC1-TSC2 complex. Is expressed in embryonic/larval salivary gland and organism. Used to study autism spectrum disorder; tuberous sclerosis; and tuberous sclerosis 1. Human ortholog(s) of this gene implicated in intellectual disability; lymphangioleiomyomatosis; tuberous sclerosis; and tuberous sclerosis 1. Orthologous to human TSC1 (TSC complex subunit 1). [provided by Alliance of Genome Resources, Dec 2024]'),
 Document(metadata={'source': 'https://www.ncbi.nlm.nih.gov/gene/406949'}, page_content='Gene ID: 406949\nGene Symbol: MIR15B\nOrganism: Homo sapiens\nDe

### NCBI Protein DB chain

In [None]:
class NCBIProteinDBAnswer(BaseModel):
  query: str = Field(description='Given the original query, please find a protein locus for the NCBI protein database.')

ncbi_protein_db_parser = PydanticOutputParser(pydantic_object=NCBIProteinDBAnswer)
ncbi_protein_db_retry_parser = RetryOutputParser.from_llm(
  parser=ncbi_protein_db_parser,
  llm=llm,
  max_retries=3,
)

ncbi_protein_db_template = """
As an expert in bioinformatics and user query optimization for biological databases, your task is to transform user questions into precise and effective queries suitable for the NCBI protein database.
Create a query with only locus of a protein for search within the NCBI protein database.

Original query: {question}

{format_instructions}
"""
ncbi_protein_db_prompt = PromptTemplate(
  template=ncbi_protein_db_template,
  input_variables=['question'],
  partial_variables={'format_instructions': ncbi_protein_db_parser.get_format_instructions()},
)

query_extractor = lambda res: res.query

ncbi_protein_db_chain = RunnableParallel(
  completion=ncbi_protein_db_prompt | llm | extract_json, prompt_value=ncbi_protein_db_prompt
) | RunnableLambda(lambda x: ncbi_protein_db_retry_parser.parse_with_prompt(**x)) | query_extractor | ncbi_protein_retriever
print(ncbi_protein_db_chain.invoke({"question": "Calculate the frequency of each amino acid in the ABW05875 protein sequence"}))

[Document(metadata={}, page_content='Protein ID: 157955043\nType: protein\nName: ABW05875\nOrganism: Bromelia pinguin\nDescription: PsbN (chloroplast) [Bromelia pinguin]\nComment: Method: conceptual translation.\nSequence: METATLVAISISGLLVSFTGYALYTAFGQPSQQLRDPFEEHGD')]


### NCBI Gene DB chain

In [38]:
class NCBIGeneDBAnswer(BaseModel):
  query: str = Field(description='Given the original query, please find a gene locus for the NCBI gene database.')

ncbi_gene_db_parser = PydanticOutputParser(pydantic_object=NCBIGeneDBAnswer)
ncbi_gene_db_retry_parser = RetryOutputParser.from_llm(
  parser=ncbi_gene_db_parser,
  llm=llm,
  max_retries=3,
)

ncbi_gene_db_template = """
As an expert in bioinformatics and user query optimization for biological databases, your task is to transform user questions into precise and effective queries suitable for the NCBI gene database.
Create a query with only locus of a gene for search within the NCBI gene database.

Original query: {question}

{format_instructions}
"""
ncbi_gene_db_prompt = PromptTemplate(
  template=ncbi_gene_db_template,
  input_variables=['question'],
  partial_variables={'format_instructions': ncbi_gene_db_parser.get_format_instructions()},
)

query_extractor = lambda res: res.query

ncbi_gene_db_chain = RunnableParallel(
  completion=ncbi_gene_db_prompt | llm | extract_json, prompt_value=ncbi_gene_db_prompt
) | RunnableLambda(lambda x: ncbi_gene_db_retry_parser.parse_with_prompt(**x)) | query_extractor | ncbi_gene_retriever
print(ncbi_gene_db_chain.invoke({"question": "Calculate the frequency of each amino acid in the peng gene sequence"}))

[Document(metadata={'source': 'https://www.ncbi.nlm.nih.gov/gene/42862'}, page_content='Gene ID: 42862\nGene Symbol: Tsc1\nOrganism: Drosophila melanogaster\nDescription: Enables GTPase activator activity and protein kinase binding activity. Involved in several processes, including cellular response to chloroquine; negative regulation of developmental growth; and regulation of signal transduction. Located in cytoplasm. Part of TSC1-TSC2 complex. Is expressed in embryonic/larval salivary gland and organism. Used to study autism spectrum disorder; tuberous sclerosis; and tuberous sclerosis 1. Human ortholog(s) of this gene implicated in intellectual disability; lymphangioleiomyomatosis; tuberous sclerosis; and tuberous sclerosis 1. Orthologous to human TSC1 (TSC complex subunit 1). [provided by Alliance of Genome Resources, Dec 2024]'), Document(metadata={'source': 'https://www.ncbi.nlm.nih.gov/gene/406949'}, page_content='Gene ID: 406949\nGene Symbol: MIR15B\nOrganism: Homo sapiens\nDes

## Build graph app

In [None]:
class GraphState(TypedDict):
  question: str

  specialized_srcs: List[str]

  step_back_query: str
  rewritten_query: str
  subqueries: List[str]

  generated_docs: List[str]

  documents: Annotated[list, operator.add]

  web_search: str

  generation: str
  generations_num: int

def determine_specialized_srcs(state):
  print('---DETERMINE SPECIALIZED SOURCES---')

  question = state['question']

  try:
    res = question_router.invoke({'question': question})
    srcs = [src.strip().lower() for src in res.sources]
  except:
    srcs = []

  return {'specialized_srcs': srcs}

def route_question(state):
  print('---ROUTE QUESTION---')

  sources = state['specialized_srcs']

  if len(sources) == 0:
    print('---ROUTE QUESTION TO WEB SEARCH---')
    return 'websearch'
  else:
    print(f'---ROUTE QUESTION TO SPECIALIZED SOURCES: {", ".join([source.upper() for source in sources])}---')
    return 'specialized_srcs'

def generate_step_back_query(state):
  print('---GENERATE STEP-BACK QUERY---')

  question = state['question']

  step_back_query = step_back_chain.invoke({'question': question})

  return {'step_back_query': step_back_query}

def generate_rewritten_query(state):
  print('---GENERATE REWRITTEN QUERY---')

  question = state['question']

  rewritten_query = rewrite_query_chain.invoke({'question': question})

  return {'rewritten_query': rewritten_query}

def generate_subqueries(state):
  print('---GENERATE SUBQUERIES---')

  question = state['question']

  try:
    decomposition_answer = decomposition_chain.invoke({'question': question})
    subqueries = decomposition_answer.subqueries
    # Limit to a maximum of four subqueries
    subqueries = subqueries[:4]
  except:
    subqueries = []

  print(f'---FINAL SUBQUERIES NUMBER: {len(subqueries)}---')

  return {'subqueries': subqueries}

def generate_hyde_docs(state):
  print('---GENERATE HYDE DOCUMENTS---')

  question = state['question']
  step_back_query = state['step_back_query']
  rewritten_query = state['rewritten_query']
  subqueries = state['subqueries']

  queries = [question, step_back_query, rewritten_query, *subqueries]
  generated_docs = []

  for query in queries:
    generated_doc = hyde_chain.invoke({'question': query})
    generated_docs.append(generated_doc)

  return {'question': question, 'generated_docs': generated_docs}

def vector_store_retriever_node(state):
  generated_docs = state['generated_docs']
  specialized_srcs = state['specialized_srcs']

  if 'vectorstore' not in specialized_srcs:
    return {'documents': []}

  print('---RETRIEVE FROM VECTOR STORE---')

  documents = []

  for generated_doc in generated_docs:
    documents.extend(retriever.invoke(generated_doc))

  return {'documents': documents}

def pub_med_retriever_node(state):
  specialized_srcs = state['specialized_srcs']

  if 'pubmed' not in specialized_srcs:
    return {'documents': []}

  print('---RETRIEVE FROM PUBMED---')

  question = state['question']
  step_back_query = state['step_back_query']
  rewritten_query = state['rewritten_query']
  subqueries = state['subqueries']

  queries = [question, step_back_query, rewritten_query, *subqueries]
  documents = []

  for query in queries:
    try:
      documents.extend(pub_med_retriever.invoke(query))
    except:
      pass

  return {'documents': documents}

def arxiv_retriever_node(state):
  specialized_srcs = state['specialized_srcs']

  if 'arxiv' not in specialized_srcs:
    return {'documents': []}

  print('---RETRIEVE FROM ARXIV---')

  question = state['question']
  step_back_query = state['step_back_query']
  rewritten_query = state['rewritten_query']
  subqueries = state['subqueries']

  queries = [question, step_back_query, rewritten_query, *subqueries]
  documents = []

  for query in queries:
    try:
      documents.extend(arxiv_retriever.invoke(query))
    except:
      pass

  return {'documents': documents}

def ncbi_protein_db_retriever_node(state):
  specialized_srcs = state['specialized_srcs']

  if 'ncbi_protein' not in specialized_srcs:
    return {'documents': []}

  print('---RETRIEVE FROM NCBI PROTEIN DB---')

  question = state['question']
  step_back_query = state['step_back_query']
  rewritten_query = state['rewritten_query']
  subqueries = state['subqueries']

  queries = [question, step_back_query, rewritten_query, *subqueries]
  documents = []

  for query in queries:
    try:
      documents.extend(ncbi_protein_db_chain.invoke(query))
    except:
      pass

  return {'documents': documents}

def ncbi_gene_db_retriever_node(state):
  specialized_srcs = state['specialized_srcs']

  if 'ncbi_gene' not in specialized_srcs:
    return {'documents': []}

  print('---RETRIEVE FROM NCBI GENE DB---')

  question = state['question']
  step_back_query = state['step_back_query']
  rewritten_query = state['rewritten_query']
  subqueries = state['subqueries']

  queries = [question, step_back_query, rewritten_query, *subqueries]
  documents = []

  for query in queries:
    try:
      documents.extend(ncbi_gene_db_chain.invoke(query))
    except:
      pass

  return {'documents': documents}

def grade_documents(state):
  print('---CHECK DOCUMENT RELEVANCE TO QUESTION---')

  question = state['question']
  documents = state['documents']

  print(f'---INITIAL DOCUMENTS NUMBER: {len(documents)}---')

  filtered_documents = []
  seen_contents = set()
  web_search = 'No'

  for index, document in enumerate(documents):
    print(f'---GRADE DOCUMENT ({index + 1}/{len(documents)})---')

    if document.page_content in seen_contents:
      print('---GRADE: DOCUMENT IS REPEATED---')
      continue
    seen_contents.add(document.page_content)

    try:
      score = docs_grader_grader.invoke({
        'question': question,
        'document': document.page_content,
      })
      grade = score.binary_score
    except:
      grade = 'No'

    if grade.lower() == 'yes':
      print('---GRADE: DOCUMENT RELEVANT---')
      filtered_documents.append(document)
    else:
      print('---GRADE: DOCUMENT NOT RELEVANT---')
      web_search = 'Yes'
      continue

  print(f'---FINAL DOCUMENTS NUMBER: {len(filtered_documents)}---')

  state['documents'].clear()
  return {
    'documents': filtered_documents,
    'web_search': web_search,
  }

def decide_to_generate(state):
  print('---ASSESS GRADED DOCUMENTS---')

  web_search = state['web_search']

  if web_search == 'Yes':
    print('---DECISION: SOME DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---')
    return 'websearch'
  else:
    print('---DECISION: GENERATE---')
    return 'generate'

def web_search(state):
  print('---WEB SEARCH---')

  question = state['question']

  web_results = web_search_tool.invoke({'query': question})
  docs = [Document(page_content=result['content'], metadata={'source': result['url']}) for result in web_results]

  return {'documents': docs}

def generate(state):
  print('---GENERATE---')

  question = state['question']
  documents = state['documents']
  generations_num = state.get('generations_num', 0)

  context = '\n\n'.join(map(lambda doc: doc.page_content, documents))
  generation = rag_chain.invoke({'context': context, 'question': question})

  return {'generation': generation, 'generations_num': generations_num + 1}

def grade_generation(state):
  print('---CHECK HALLUCINATIONS---')

  question = state['question']
  documents = state['documents']
  generation = state['generation']
  generations_num = state['generations_num']

  if generations_num >= 2:
    return 'useful'

  try:
    context = '\n\n'.join(map(lambda doc: doc.page_content, documents))
    score = hallucination_grader.invoke({
      'documents': context,
      'generation': generation,
    })
    grade = score.binary_score
  except:
    grade = 'no'

  if grade == 'yes':
    print('---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---')
    print('---GRADE GENERATION vs QUESTION---')

    try:
      score = answer_grader.invoke({'question': question,'generation': generation})
      grade = score.binary_score
    except:
      grade = 'no'

    if grade == 'yes':
      print('---DECISION: GENERATION ADDRESSES QUESTION---')
      return 'useful'
    else:
      print('---DECISION: GENERATION DOES NOT ADDRESS QUESTION---')
      return 'not useful'
  else:
    print('---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---')
    return 'not supported'

In [None]:
workflow = StateGraph(GraphState)

In [None]:
workflow.add_node('determine_specialized_srcs', determine_specialized_srcs)

workflow.add_node('generate_step_back_query', generate_step_back_query)
workflow.add_node('generate_rewritten_query', generate_rewritten_query)
workflow.add_node('generate_subqueries', generate_subqueries)

workflow.add_node('generate_hyde_docs', generate_hyde_docs)

workflow.add_node('vector_store_retriever', vector_store_retriever_node)
workflow.add_node('pub_med_retriever', pub_med_retriever_node)
workflow.add_node('arxiv_retriever', arxiv_retriever_node)
workflow.add_node('ncbi_protein_db_retriever', ncbi_protein_db_retriever_node)
workflow.add_node('ncbi_gene_db_retriever_node', ncbi_gene_db_retriever_node)

workflow.add_node('websearch', web_search)
workflow.add_node('generate', generate)
workflow.add_node('grade_documents', grade_documents)

<langgraph.graph.state.StateGraph at 0x33ed49a30>

In [None]:
workflow.add_edge(START, 'determine_specialized_srcs')
workflow.add_conditional_edges(
  'determine_specialized_srcs',
  route_question,
  {
    'websearch': 'websearch',
    'specialized_srcs': 'generate_step_back_query',
  },
)

workflow.add_edge('generate_step_back_query', 'generate_rewritten_query')
workflow.add_edge('generate_rewritten_query', 'generate_subqueries')
workflow.add_edge('generate_subqueries', 'generate_hyde_docs')

workflow.add_edge('generate_hyde_docs', 'vector_store_retriever')
workflow.add_edge('generate_hyde_docs', 'pub_med_retriever')
workflow.add_edge('generate_hyde_docs', 'arxiv_retriever')
workflow.add_edge('generate_hyde_docs', 'ncbi_protein_db_retriever')
workflow.add_edge('generate_hyde_docs', 'ncbi_gene_db_retriever_node')

workflow.add_edge('vector_store_retriever', 'grade_documents')
workflow.add_edge('pub_med_retriever', 'grade_documents')
workflow.add_edge('arxiv_retriever', 'grade_documents')
workflow.add_edge('ncbi_protein_db_retriever', 'grade_documents')
workflow.add_edge('ncbi_gene_db_retriever_node', 'grade_documents')

workflow.add_conditional_edges(
  'grade_documents',
  decide_to_generate,
  {
    'websearch': 'websearch',
    'generate': 'generate',
  },
)
workflow.add_edge('websearch', 'generate')
workflow.add_conditional_edges(
  'generate',
  grade_generation,
  {
    'not supported': 'generate',
    'useful': END,
    'not useful': 'websearch',
  },
)

<langgraph.graph.state.StateGraph at 0x33ed49a30>

In [None]:
app = workflow.compile()

In [None]:
app.invoke({'question': 'Count each amino acid in the ABW05875 sequence'})

---DETERMINE SPECIALIZED SOURCES---
---ROUTE QUESTION---
---ROUTE QUESTION TO SPECIALIZED SOURCES: NCBI_PROTEIN---
---GENERATE STEP-BACK QUERY---
---GENERATE REWRITTEN QUERY---
---GENERATE SUBQUERIES---
---FINAL SUBQUERIES NUMBER: 4---
---GENERATE HYDE DOCUMENTS---
---RETRIEVE FROM NCBI PROTEIN DB---
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---INITIAL DOCUMENTS NUMBER: 5---
---GRADE DOCUMENT (1/5)---
---GRADE: DOCUMENT RELEVANT---
---GRADE DOCUMENT (2/5)---
---GRADE: DOCUMENT IS REPEATED---
---GRADE DOCUMENT (3/5)---
---GRADE: DOCUMENT IS REPEATED---
---GRADE DOCUMENT (4/5)---
---GRADE: DOCUMENT IS REPEATED---
---GRADE DOCUMENT (5/5)---
---GRADE: DOCUMENT IS REPEATED---
---FINAL DOCUMENTS NUMBER: 1---
---ASSESS GRADED DOCUMENTS---
---DECISION: GENERATE---
---GENERATE---
---CHECK HALLUCINATIONS---
---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---
---GRADE GENERATION vs QUESTION---
---DECISION: GENERATION DOES NOT ADDRESS QUESTION---
---WEB SEARCH---
---GENERATE---
---CHECK HALLUCINA

{'question': 'Count each amino acid in the ABW05875 sequence',
 'specialized_srcs': ['ncbi_protein'],
 'step_back_query': 'To generate a step-back query that improves context retrieval for the original query "Count each amino acid in the ABW05875 sequence", I would ask:\n\n"What is the general process of analyzing protein sequences, and what types of information can be derived from them?"\n\nThis more general query aims to retrieve relevant background information on protein sequence analysis, including topics such as:\n\n* Sequence annotation\n* Amino acid composition\n* Protein structure prediction\n* Bioinformatics tools and techniques\n\nBy asking this step-back query, the RAG system can provide a broader context that may include relevant information about protein sequences in general, which can then be used to improve the retrieval of specific information related to the original query.',
 'rewritten_query': 'Here\'s a rewritten query that is more specific, detailed, and likely to r

## Evaluate RAG

### Load QA dataset

In [None]:
qa_df = pd.read_csv('brainscape.csv')
qa_df

Unnamed: 0,question,answer
0,What are the afferent cranial nerve nuclei?,Trigeminal sensory nucleus- fibres carry gener...
1,What is the order of the cranial nerves ?,1-olfactory\n2-optic\n3-oculomotor\n4-trochlea...
2,What are the efferent cranial nerve nuclei?,Edinger-westphal nucleus\nOculomotor nucleus\n...
3,Which nuclei share the embryo logical origin -...,Oculomotor nucleus Trochlear nucleus Abducens ...
4,Which nuclei share the embryo logical origin- ...,Trigeminal motor nucleus Facial motor nucleus ...
...,...,...
1047,What is the purpose of gephyrin in the glycine...,Involved in anchoring the receptor to a specif...
1048,What is the glycine receptor involved in ?,Reflex response\nCauses reciprocal inhibition ...
1049,What happens in hyperperplexia ?,It’s an exaggerated reflex Often caused by a m...
1050,What is hyperperplexia treated with ?,Benzodiazepine


### Load cached RAGs responses

In [None]:
cache_path = Path('cache.json')

if not os.path.exists(cache_path):
  data = {}
  with open(cache_path, 'w') as file:
    json.dump(data, file)

with open(cache_path, 'r') as f:
  cache = json.load(f)

len(cache.keys())

1043

In [None]:
questions = list(qa_df['question'].tolist())
expected_answers = list(qa_df['answer'].tolist())
predicted_answers = []

for index, question in tqdm(enumerate(questions)):
  if not question in cache:
    cache[question] = app.invoke({'question': question})['generation']

  predicted_answers.append(cache[question])

  with open(cache_path, 'w') as f:
    json.dump(cache, f)

cos_score = embeddings_cosine_sim_metric(expected_answers, predicted_answers)
bleu_score = bleu_metric(expected_answers, predicted_answers)
rogue_1_score = rogue_1_metric(expected_answers, predicted_answers)
rogue_l_score = rogue_l_metric(expected_answers, predicted_answers)

cos_score, bleu_score, rogue_1_score, rogue_l_score

1052it [00:01, 981.70it/s]


KeyboardInterrupt: 