# MMLU (Massive Multitask Language Understanding) RAG Evaluation

The MMLU benchmark evaluates language models on multiple-choice questions across 57 subjects, from basic ones like history and mathematics to advanced fields such as law and medicine. Because it covers so many domains, it is a common way to measure how well a model's knowledge generalizes.

This notebook evaluates a RAG chain built with LangChain on a subset of MMLU subjects.

https://huggingface.co/datasets/cais/mmlu
https://docs.confident-ai.com/docs/benchmarks-mmlu
https://luv-bansal.medium.com/benchmarking-llms-how-to-evaluate-language-model-performance-b5d061cc8679
https://www.kaggle.com/code/debarshichanda/llm-evaluation-mmlu-style
https://deepgram.com/learn/mmlu-llm-benchmark-guide

## Import packages

In [7]:
! pip install datasets langchain langchain-core langchain-community docarray



In [8]:
from datasets import load_dataset
from tqdm import tqdm
import re

from langchain_community.llms import Ollama
from langchain.indexes import VectorstoreIndexCreator
from langchain_community.vectorstores import DocArrayInMemorySearch
from langchain_community.embeddings import OllamaEmbeddings
from langchain_core.output_parsers import StrOutputParser
from operator import itemgetter

## Define evaluation function

In [9]:
letter_to_number = { 'a': 0, 'b': 1, 'c': 2, 'd': 3 }

def eval_rag(model: object, subsets: list) -> float:
  accuracies = {}

  for subset in tqdm(subsets, desc='Subsets'):
    dataset = load_dataset('cais/mmlu', subset)
    test_df = dataset['test'].to_pandas()

    correct_answers_count = 0

    for index, row in tqdm(list(test_df.iterrows()), desc='Questions'):
      question = row['question']
      choices = row['choices']
      correct_answer = row['answer']

      llm_answer = model.invoke({
        'question': question,
        'a': choices[0],
        'b': choices[1],
        'c': choices[2],
        'd': choices[3],
      })

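      # Answers the parser could not map to a letter (e.g. None) are skipped,
      # so they count as incorrect in the accuracy below.
      if llm_answer not in letter_to_number:
        continue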

      llm_answer_num = letter_to_number[llm_answer]

      if llm_answer_num == correct_answer:
        correct_answers_count += 1

    accuracies[subset] = correct_answers_count / len(test_df)

  return sum(accuracies.values()) / len(accuracies.values())
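
Note that `eval_rag` returns the unweighted mean of the per-subset accuracies, so a small subset counts as much as a large one. The sketch below contrasts that macro average with a question-weighted (micro) average. The counts are made up for illustration, and `macro_accuracy` and `micro_accuracy` are hypothetical helper names that are not part of this notebook.

```python
# Hypothetical (correct, total) counts per subset, for illustration only.
counts = {
  'subset_small': (80, 100),
  'subset_large': (300, 600),
}

def macro_accuracy(counts):
  # Unweighted mean of per-subset accuracies (the quantity eval_rag returns).
  per_subset = [correct / total for correct, total in counts.values()]
  return sum(per_subset) / len(per_subset)

def micro_accuracy(counts):
  # Total correct answers divided by total questions across all subsets.
  correct = sum(c for c, _ in counts.values())
  total = sum(t for _, t in counts.values())
  return correct / total

print(f'macro: {macro_accuracy(counts):.3f}')
print(f'micro: {micro_accuracy(counts):.3f}')
```

With these made-up counts the two averages differ noticeably, which is worth keeping in mind when comparing the final score against numbers reported elsewhere.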

## Define prompt

In [10]:
from langchain.prompts import ChatPromptTemplate

template = """Answer the following multiple choice question by giving the most appropriate response. Answer should be one among [A, B, C, D].
Tell only the correct letter and nothing else. Use the following pieces of context to answer the question at the end.

{context}

Question: {question}\n
A) {a}\n
B) {b}\n
C) {c}\n
D) {d}\n


Answer:"""

prompt = ChatPromptTemplate.from_template(template)
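
To see roughly what text the model receives, the template can be filled in by hand. The sketch below uses a shortened copy of the template and plain `str.format` instead of `ChatPromptTemplate`, on the assumption that both substitute simple `{name}` placeholders the same way. The question and choices are invented examples.

```python
# Shortened copy of the template above. Plain str.format stands in for
# ChatPromptTemplate (an assumption that keeps the sketch dependency-free).
template = """Question: {question}\n
A) {a}\n
B) {b}\n

Answer:"""

# Invented example values, not taken from MMLU.
rendered = template.format(question='Example question?', a='first', b='second')
print(rendered)
```

Because the template is a regular (non-raw) triple-quoted string, each explicit `\n` becomes a real newline on top of the line break already present, so the rendered prompt has a blank line after the question and after each choice.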

## Build test model

In [11]:
llm = Ollama(model='llama3', temperature=0)
embeddings = OllamaEmbeddings(model='llama3')
index = VectorstoreIndexCreator(
  vectorstore_cls=DocArrayInMemorySearch,
  embedding=embeddings,
).from_documents([])
vector_store = index.vectorstore

def format_docs(docs):
  return '\n\n'.join(doc.page_content for doc in docs)

def test_answer_parser(result):
  result = result.lower().strip()

  if re.match(r'^[abcd](?:$|\))', result):
    return result[0]

  return None

qa_chain = (
  {
    'context': itemgetter('question') | vector_store.as_retriever() | format_docs,
    'question': itemgetter('question'),
    'a': itemgetter('a'),
    'b': itemgetter('b'),
    'c': itemgetter('c'),
    'd': itemgetter('d'),
  }
  | prompt
  | llm
  | StrOutputParser()
  | test_answer_parser
)
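
Since any answer the parser rejects is scored as wrong, it helps to check which output shapes it accepts. The sketch below repeats the parsing logic under the hypothetical name `parse_answer` (same body as `test_answer_parser`) so that it runs on its own, and applies it to a few made-up model outputs.

```python
import re

def parse_answer(result):
  # Same logic as test_answer_parser above, repeated so the sketch is standalone.
  result = result.lower().strip()
  if re.match(r'^[abcd](?:$|\))', result):
    return result[0]
  return None

# Made-up model outputs covering accepted and rejected shapes.
samples = ['B', ' c) ', 'A) Paris', 'B.', 'The answer is D']
for text in samples:
  print(repr(text), '->', parse_answer(text))
```

A bare letter or a letter followed by `)` is accepted, while `'B.'` and `'The answer is D'` both return `None`. If the model often answers in one of those forms, the measured accuracy will understate what it actually knows.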

## Evaluate the model

Here we evaluate only the MMLU subjects most closely related to neurobiology.

In [12]:
eval_subsets = [
  'anatomy',
  'college_biology',
  'high_school_biology',
  'college_medicine',
  'professional_medicine',
  'medical_genetics',
  'professional_psychology',
  'high_school_psychology',
]

eval_rag(qa_chain, eval_subsets)

Questions: 100%|██████████| 135/135 [00:46<00:00,  2.91it/s]
Questions: 100%|██████████| 144/144 [00:55<00:00,  2.61it/s]
Questions: 100%|██████████| 310/310 [02:06<00:00,  2.45it/s]
Questions: 100%|██████████| 173/173 [01:44<00:00,  1.65it/s]
Questions: 100%|██████████| 272/272 [03:52<00:00,  1.17it/s]
Questions: 100%|██████████| 100/100 [00:33<00:00,  2.98it/s]
Questions: 100%|██████████| 612/612 [03:57<00:00,  2.58it/s]
Questions: 100%|██████████| 545/545 [03:17<00:00,  2.76it/s]
Subsets: 100%|██████████| 8/8 [17:42<00:00, 132.87s/it]


0.6462949556814584