In [1]:
import os
import json
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import string
import numpy as np
import pandas as pd
from unidecode import unidecode
import transformers
import torch
from sklearn.metrics.pairwise import cosine_similarity
from pathlib import Path
from tqdm import tqdm
from dotenv import load_dotenv

from langchain_community.document_loaders import PDFMinerLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama
from langchain_mistralai.chat_models import ChatMistralAI
from langchain.indexes import VectorstoreIndexCreator
from langchain.vectorstores import DocArrayInMemorySearch
from langchain.chains import RetrievalQA
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains.question_answering.stuff_prompt import CHAT_PROMPT as DEFAULT_PROMPT
from langchain.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
load_dotenv()
# load_dotenv() already exports the .env entries into os.environ, so copying
# the value back onto itself is unnecessary. Only check that the key exists,
# and fail with a readable message instead of the TypeError that assigning
# None to os.environ would raise.
if os.getenv('MISTRAL_API_KEY') is None:
  raise RuntimeError('MISTRAL_API_KEY is not set; define it in .env or in the environment')

In [3]:
# preprocess() reads the NLTK stopword lists and lemmatizes with WordNet, so
# those corpora are needed in addition to the tokenizer models.
nltk.download('stopwords', quiet=True)
nltk.download('wordnet', quiet=True)
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to
[nltk_data]     /Users/vladimirskvortsov/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

## Setup device

In [4]:
# Backend preference: CUDA first, then Apple MPS, otherwise plain CPU.
if torch.cuda.is_available():
  device = 'cuda'
elif torch.backends.mps.is_available():
  device = 'mps'
else:
  device = 'cpu'

## Setup metric

In [5]:
lemmatizer = nltk.stem.WordNetLemmatizer()
_stopset = None

def _get_stopset():
  """Return the tokens to drop: English and Russian stopwords plus ASCII punctuation.

  Built on first use and cached, so the corpora are read once rather than on
  every preprocess() call; a frozenset keeps membership tests constant-time.
  """
  global _stopset
  if _stopset is None:
    _stopset = frozenset(
      nltk.corpus.stopwords.words('english')
      + nltk.corpus.stopwords.words('russian')
      + list(string.punctuation)
    )
  return _stopset

def preprocess(corpus):
  """Normalize a text before it is scored by the metrics.

  Steps: lowercase, tokenize with NLTK, drop stopword/punctuation tokens,
  lemmatize each remaining token with WordNet, join with single spaces and
  transliterate the result with unidecode.

  Args:
    corpus: the text to normalize (must support `.lower()`).

  Returns:
    The normalized text as a single space-joined string.
  """
  stopset = _get_stopset()
  tokens = nltk.word_tokenize(corpus.lower())
  lemmas = [lemmatizer.lemmatize(t) for t in tokens if t not in stopset]
  return unidecode(' '.join(lemmas))

In [6]:
# Shared embedding model: used by the cosine-similarity metric and handed to
# both vector store builders.
embeddings = OllamaEmbeddings(model='llama3')

In [7]:
def embeddings_cosine_sim_metric(expected_answers, predicted_answers):
  """Mean cosine similarity between embedded expected and predicted answers.

  Answers are paired positionally with zip. Each one is run through
  `preprocess` and embedded with the module-level `embeddings` model.

  Args:
    expected_answers: iterable of reference answer strings.
    predicted_answers: iterable of generated answer strings.

  Returns:
    The numpy mean of the per-pair similarities.
  """
  similarities = []

  for expected, predicted in zip(expected_answers, predicted_answers):
    cleaned_pair = (preprocess(expected), preprocess(predicted))

    # reshape(1, -1) gives the single-row 2-D layout cosine_similarity takes
    expected_vec, predicted_vec = (
      np.array(embeddings.embed_query(text)).reshape(1, -1)
      for text in cleaned_pair
    )

    similarities.append(cosine_similarity(expected_vec, predicted_vec)[0][0])

  return np.mean(similarities)

In [8]:
def bleu_metric(expected_answers, predicted_answers):
  """Mean sentence-level BLEU of the predictions against the expected answers.

  Answers are paired positionally with zip. Both sides are normalized with
  `preprocess` and re-tokenized with NLTK before scoring.

  Args:
    expected_answers: iterable of reference answer strings.
    predicted_answers: iterable of generated answer strings.

  Returns:
    The numpy mean of the per-pair BLEU scores.
  """
  def _pair_score(expected, predicted):
    reference_text, hypothesis_text = preprocess(expected), preprocess(predicted)

    hypothesis = nltk.word_tokenize(hypothesis_text)
    # sentence_bleu takes a list of references; there is exactly one here
    references = [nltk.word_tokenize(reference_text)]

    return sentence_bleu(
      references,
      hypothesis,
      smoothing_function=SmoothingFunction().method4,
    )

  return np.mean([
    _pair_score(expected, predicted)
    for expected, predicted in zip(expected_answers, predicted_answers)
  ])

## Load QA dataset

In [9]:
# Question/answer pairs used as the evaluation set: the experiment loop reads
# the `question` column as input and the `answer` column as the reference.
qa_df = pd.read_csv('../research-neurobiology-qa-dataset/brainscape.csv')
qa_df

Unnamed: 0,question,answer
0,What are the afferent cranial nerve nuclei?,Trigeminal sensory nucleus- fibres carry gener...
1,What is the order of the cranial nerves ?,1-olfactory\n2-optic\n3-oculomotor\n4-trochlea...
2,What are the efferent cranial nerve nuclei?,Edinger-westphal nucleus\nOculomotor nucleus\n...
3,Which nuclei share the embryo logical origin -...,Oculomotor nucleus Trochlear nucleus Abducens ...
4,Which nuclei share the embryo logical origin- ...,Trigeminal motor nucleus Facial motor nucleus ...
...,...,...
1047,What is the purpose of gephyrin in the glycine...,Involved in anchoring the receptor to a specif...
1048,What is the glycine receptor involved in ?,Reflex response\nCauses reciprocal inhibition ...
1049,What happens in hyperperplexia ?,It’s an exaggerated reflex Often caused by a m...
1050,What is hyperperplexia treated with ?,Benzodiazepine


## Load documents

In [10]:
docs_dir = Path('./docs')

# iterdir() yields entries in arbitrary order; sorting keeps the chunk order
# stable between runs, and a concrete list gives tqdm a total to display.
pdf_files = sorted(
  path for path in docs_dir.iterdir()
  if path.is_file() and path.suffix == '.pdf'
)

raw_docs = []
for file in tqdm(pdf_files):
  loader = PDFMinerLoader(file)
  raw_docs.extend(loader.load())

text_splitter = RecursiveCharacterTextSplitter(
  chunk_size=700,
  chunk_overlap=0,
  length_function=len,
)

# `docs` holds the split chunks that the vector stores index
docs = text_splitter.split_documents(raw_docs)

## Setup LLMs

In [11]:
def get_llama2_llm():
  """Create the Ollama `llama2` LLM with temperature 0."""
  llm = Ollama(model='llama2', temperature=0)
  return llm

In [12]:
def get_llama3_llm():
  """Create the Ollama `llama3` LLM with temperature 0."""
  llm = Ollama(model='llama3', temperature=0)
  return llm

In [13]:
def openbiollm_parser(output):
  """Strip everything up to and including the first 'Helpful Answer: ' marker.

  Args:
    output: raw text produced by the model.

  Returns:
    The text following the first occurrence of the marker, or `output`
    unchanged when the marker is absent.
  """
  # A single marker is used for both the search and the cut. The original
  # searched for one literal and measured a differently-cased one, which
  # worked only because the two happen to have the same length.
  marker = 'Helpful Answer: '
  _, found, answer = output.partition(marker)
  return answer if found else output

def get_openbiollm_8b_llm():
  """Build the OpenBioLLM-Llama3-8B LLM with its output post-processed.

  The model is loaded once, as a transformers text-generation pipeline placed
  on the module-level `device`, and wrapped for LangChain. Previously a full
  pipeline was created only to read the tokenizer and a second copy of the
  model was loaded by `from_model_id`, which never received `device`.

  Returns:
    The wrapped LLM piped into `openbiollm_parser`.
  """
  model = 'aaditya/OpenBioLLM-Llama3-8B'
  model_kwargs = {'torch_dtype': torch.bfloat16}

  # The stop-token ids only require the tokenizer, so load it on its own
  # rather than instantiating the whole model to get at it.
  tokenizer = transformers.AutoTokenizer.from_pretrained(model)
  terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids('<|eot_id|>'),
  ]

  # Generation settings are given at construction so they apply to every
  # call made through the LangChain wrapper.
  pipeline = transformers.pipeline(
    'text-generation',
    model=model,
    tokenizer=tokenizer,
    model_kwargs=model_kwargs,
    device=device,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.0001,
    top_p=0.9,
  )
  llm = HuggingFacePipeline(pipeline=pipeline, model_id=model)
  return llm | openbiollm_parser

In [14]:
def get_mistral_llm():
  """Create the Mistral chat model with temperature 0.

  The temperature is pinned to match the Ollama getters, so all compared
  models are configured the same way instead of Mistral alone running with
  its default sampling temperature.
  """
  return ChatMistralAI(temperature=0)

## Setup index stores

In [15]:
def get_doc_array_vector_store(docs=None):
  """Build a DocArray in-memory vector store over `docs`.

  Args:
    docs: documents to index; `None` (the default) indexes nothing.

  Returns:
    The vector store created by `VectorstoreIndexCreator`, embedding with the
    module-level `embeddings` model.
  """
  # None replaces the mutable [] default, which is a single list object
  # shared by every call that omits the argument.
  documents = [] if docs is None else docs
  index = VectorstoreIndexCreator(
    vectorstore_cls=DocArrayInMemorySearch,
    embedding=embeddings,
  ).from_documents(documents)
  return index.vectorstore

In [16]:
def get_chroma_vector_store(docs=None):
  """Build a Chroma vector store over `docs`.

  Args:
    docs: documents to index; `None` (the default) indexes nothing.

  Returns:
    The Chroma store, embedding with the module-level `embeddings` model.
  """
  # None replaces the mutable [] default, which is a single list object
  # shared by every call that omits the argument.
  documents = [] if docs is None else docs
  vector_store = Chroma.from_documents(documents, embeddings)
  return vector_store

## Setup prompt templates

In [17]:

# Human-side template used for every few-shot example
example_prompt_template = """
Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
Question: {question}
"""
# Static question/answer demonstrations shown to the model before the real question
few_shot_examples = [
  {
    "question": "Which cranial nerves are motor?",
    "answer": "Oculomotor\nTrochlear \nAbducens\nAccessory\nHypoglossal"
  },
  {
    # The leading "Wh" was missing and the answer items were run together
    "question": "Which of the cranial nerves have both sensory and motor control ?",
    "answer": "Trigeminal\nFacial\nGlossopharyngeal\nVagus"
  },
  {
    "question": "Which regions of the cross section of the spinal cord have a larger ventral horn ?",
    "answer": "The cervical and lumbar regions have larger ventral horns. The thoracic region has a smaller ventral horn region because it controls the trunk so not many motor neurones are coming out. Thoracic region has a more prominent lateral horn where preganglionic neurones are present"
  },
  {
    "question": "What are the subdivisions of the vertebral column ?",
    "answer": "Cervical = 8\nThoracic= 12\nLumbar=5 \nSacral=5 \nCoccygeal"
  },
]
# Each example renders as a human message followed by the AI answer
example_prompt = ChatPromptTemplate.from_messages(
  [
    ("human", example_prompt_template),
    ("ai", "{answer}\n"),
  ],
)
few_shot_prompt = FewShotChatMessagePromptTemplate(
  example_prompt=example_prompt,
  examples=few_shot_examples,
  input_variables=["question"],
)
# Final instruction carrying the retrieved context and the actual question.
# The prompt ends at "Helpful Answer:"; a stray double quote used to follow it.
base_prompt = ChatPromptTemplate.from_template("""
Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
You answer in very short sentences and do not include extra information.

{context}

Question: {question}
Helpful Answer:
""")
# Few-shot demonstrations first, then the real request
final_few_shot_prompt = ChatPromptTemplate.from_messages(
  [
    few_shot_prompt,
    base_prompt
  ]
)

## Setup experiments

In [18]:
# (display name, LLM) pairs. The getters are called right here, so every
# model is instantiated when this cell runs.
llms = [
  ('LLaMA 2', get_llama2_llm()),
  ('LLaMA 3', get_llama3_llm()),
  ('OpenBioLLM Llama3 8B', get_openbiollm_8b_llm()),
  ('Mistral', get_mistral_llm()),
]

# (display name, factory) pairs. The factories are not called here; the
# experiment loop invokes them with the documents to index.
vector_stores = [
  ('DocArray', get_doc_array_vector_store),
  ('Chroma', get_chroma_vector_store),
]

# (display name, prompt template) pairs handed to the QA chain as its prompt
prompts = [
  ('Default', DEFAULT_PROMPT),
  ('Few-shot prompting', final_few_shot_prompt),
]

In [19]:
cache_path = Path('cache.json')

# On a fresh checkout there is no cache file yet: start empty instead of
# failing with FileNotFoundError.
if cache_path.exists():
  with open(cache_path, 'r', encoding='utf-8') as f:
    cache = json.load(f)
else:
  cache = {}

cache.keys()

dict_keys(['LLaMA 3_Doc Array In Memory Search_True', 'LLaMA 3_Doc Array In Memory Search_False', 'OpenBioLLM Llama3 8B_Doc Array In Memory Search_True', 'LLaMA 2_Doc Array In Memory Search_False', 'LLaMA 2_Doc Array In Memory Search_True', 'OpenBioLLM Llama3 8B_Doc Array In Memory Search_False'])

In [20]:
# frac=1 keeps every row and only shuffles the order
sample_df = qa_df.sample(frac=1)
questions = sample_df['question'].tolist()
expected_answers = sample_df['answer'].tolist()

# An index depends on neither the LLM nor the prompt, so build each one once
# up front instead of re-embedding every chunk inside the inner loops.
# use_docs=False is treated as the no-retrieval baseline: the store is built
# over no documents, and since the store type is then irrelevant only the
# DocArray variant is run. Previously the flag was recorded in the results
# but never changed what was indexed.
# NOTE(review): confirm the in-memory index tolerates being queried while empty.
indexed_stores = {}
for vector_store_name, get_vector_store in vector_stores:
  for use_docs in (False, True):
    if not use_docs and vector_store_name != 'DocArray':
      continue
    store_docs = docs if use_docs else []
    indexed_stores[(vector_store_name, use_docs)] = get_vector_store(store_docs)

records = []

for llm_name, llm in tqdm(llms, desc='LLMs'):
  for (vector_store_name, use_docs), vector_store in tqdm(indexed_stores.items(), desc='Vector Stores', leave=False):
    retriever = vector_store.as_retriever(search_kwargs={"k" : 10})

    for prompt_name, prompt_template in tqdm(prompts, desc='Prompts', leave=False):
      qa_llm = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=retriever,
        verbose=False,
        chain_type_kwargs={
          'prompt': prompt_template,
          'document_separator': '<<<<<>>>>>'
        },
      )

      # The prompt is part of the key: answers depend on it, so the prompt
      # variants must not read each other's cached answers.
      key = f'{llm_name}_{vector_store_name}_{use_docs}_{prompt_name}'
      cached_answers = cache.setdefault(key, {})

      predicted_answers = []

      for question in tqdm(questions, desc='Questions', leave=False):
        if question not in cached_answers:
          cached_answers[question] = qa_llm.invoke(question)['result']
          # Persist only when a new answer was generated; rewriting the whole
          # file on every cache hit is wasted I/O.
          with open(cache_path, 'w') as f:
            json.dump(cache, f)

        predicted_answers.append(cached_answers[question])

      records.append({
        'llm': llm_name,
        'vector_store': vector_store_name,
        'use_docs': use_docs,
        'prompt': prompt_name,
        'cos_sim': embeddings_cosine_sim_metric(expected_answers, predicted_answers),
        'bleu': bleu_metric(expected_answers, predicted_answers),
      })

# Build the frame once from all rows rather than concatenating per iteration
df = pd.DataFrame(records)

LLMs:   0%|          | 0/2 [00:00<?, ?it/s]
[A

[A[A

[A[A

[A[A

[A[A
[A
[A

[A[A

[A[A

[A[A

[A[A
[A
[A

[A[A

[A[A

[A[A

[A[A
[A
[A

[A[A

[A[A

[A[A

[A[A
LLMs: 100%|██████████| 2/2 [00:00<00:00, 71.21it/s]


In [21]:
# NOTE(review): the saved output of this cell looks stale. It lists Chroma
# rows with use_docs=False, which the experiment loop skips, and every score
# is 0. Re-run the notebook from a fresh kernel to refresh it.
df

Unnamed: 0,llm,vector_store,use_docs,prompt,cos_sim,bleu
0,LLaMA 2,DocArray,False,Default,0,0
1,LLaMA 2,DocArray,False,Few-shot prompting,0,0
2,LLaMA 2,DocArray,True,Default,0,0
3,LLaMA 2,DocArray,True,Few-shot prompting,0,0
4,LLaMA 2,Chroma,False,Default,0,0
5,LLaMA 2,Chroma,False,Few-shot prompting,0,0
6,LLaMA 2,Chroma,True,Default,0,0
7,LLaMA 2,Chroma,True,Few-shot prompting,0,0
8,LLaMA 3,DocArray,False,Default,0,0
9,LLaMA 3,DocArray,False,Few-shot prompting,0,0
