# MMLU (Massive Multitask Language Understanding) RAG Evaluation

The MMLU benchmark evaluates language models on multiple-choice questions spanning 57 subjects, from elementary mathematics and US history to professional law and medicine. Because it covers so many domains, it is widely used to measure the breadth of a model's knowledge and how well it generalizes across tasks.

In this notebook, we evaluate LangChain RAG pipelines on MMLU. The chain defined below does not yet include a retrieval step, so it measures closed-book accuracy.

References:

- https://huggingface.co/datasets/cais/mmlu
- https://docs.confident-ai.com/docs/benchmarks-mmlu
- https://luv-bansal.medium.com/benchmarking-llms-how-to-evaluate-language-model-performance-b5d061cc8679
- https://www.kaggle.com/code/debarshichanda/llm-evaluation-mmlu-style
- https://deepgram.com/learn/mmlu-llm-benchmark-guide
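To make the retrieval step of a RAG pipeline concrete, here is a minimal, stdlib-only sketch of what it adds on top of a closed-book prompt: rank a small set of passages by similarity to the question and prepend the best ones as context. It uses a toy bag-of-words cosine similarity as a stand-in for a real embedding model and vector store (such as the `OllamaEmbeddings` and `DocArrayInMemorySearch` components imported later). The helper names `retrieve` and `build_rag_prompt`, the prompt layout, and the example passages are hypothetical illustrations, not part of the original notebook.

```python
import math
import re
from collections import Counter

def _bag_of_words(text: str) -> Counter:
  """Lowercased word counts; a toy stand-in for a real embedding model."""
  return Counter(re.findall(r'[a-z0-9]+', text.lower()))

def _cosine(a: Counter, b: Counter) -> float:
  """Cosine similarity between two word-count vectors."""
  dot = sum(a[w] * b[w] for w in a.keys() & b.keys())
  norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
  return dot / norm if norm else 0.0

def retrieve(query: str, passages: list[str], k: int = 2) -> list[str]:
  """Return the k passages most similar to the query (hypothetical helper)."""
  q = _bag_of_words(query)
  ranked = sorted(passages, key=lambda p: _cosine(q, _bag_of_words(p)), reverse=True)
  return ranked[:k]

def build_rag_prompt(question: str, choices: list[str], passages: list[str], k: int = 2) -> str:
  """Prepend retrieved context to an MMLU-style question (hypothetical format)."""
  context = '\n'.join(f'- {p}' for p in retrieve(question, passages, k))
  options = '\n'.join(f'{letter}) {c}' for letter, c in zip('ABCD', choices))
  return f'Context:\n{context}\n\nQuestion: {question}\n{options}\n'

# Illustrative passages, not taken from MMLU
passages = [
  'Huntington disease is inherited in an autosomal dominant pattern.',
  "The mitochondrion produces most of the cell's ATP.",
  'Cystic fibrosis is inherited in an autosomal recessive pattern.',
]
print(retrieve('How is Huntington disease inherited?', passages, k=1))
```

In a real pipeline the similarity function would be replaced by embedding vectors and the passage list by a vector store, but the shape of the step (retrieve, then prompt) stays the same.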

## Import packages

In [1]:
```python
! pip install datasets langchain langchain-core langchain-community langchain-openai docarray
```


In [2]:
```python
import re

from datasets import load_dataset
from tqdm import tqdm

from langchain_community.llms import Ollama
from langchain_openai import ChatOpenAI

# Retrieval components (not yet used by the chain below)
from langchain.indexes import VectorstoreIndexCreator
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import DocArrayInMemorySearch
```


## Define evaluation function

In [3]:
```python
from langchain_core.exceptions import OutputParserException

# MMLU stores the gold answer as an integer index into `choices` (0 = A, ..., 3 = D)
letter_to_number = {'a': 0, 'b': 1, 'c': 2, 'd': 3}

def eval_rag(model, subset: str) -> float:
  """Return the accuracy of `model` on the test split of one MMLU subject.

  Answers that cannot be parsed into A-D are counted as incorrect.
  """
  dataset = load_dataset('cais/mmlu', subset)
  test_df = dataset['test'].to_pandas()

  correct_answers_count = 0

  for _, row in tqdm(test_df.iterrows(), total=len(test_df), desc='Questions'):
    choices = row['choices']

    try:
      llm_answer = model.invoke({
        'question': row['question'],
        'a': choices[0],
        'b': choices[1],
        'c': choices[2],
        'd': choices[3],
      })
    except OutputParserException:
      continue  # malformed model output counts as a wrong answer

    # Keep only the first character, e.g. "B" or "b) ..." -> "b"
    answer_text = llm_answer.correct_answer.strip().lower()
    if not answer_text or answer_text[0] not in letter_to_number:
      continue

    if letter_to_number[answer_text[0]] == row['answer']:
      correct_answers_count += 1

  return correct_answers_count / len(test_df)
```
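To check the scoring logic without downloading the dataset or running a model, the sketch below mirrors the core of `eval_rag` over in-memory rows and a stub model. It also reports how many replies could not be parsed, which `eval_rag` silently counts as wrong; that number is useful for telling a format problem apart from a knowledge problem. The names `evaluate_rows` and `stub_model` and the sample rows are hypothetical; the row keys follow the `question`, `choices`, and `answer` columns of the `cais/mmlu` test split.

```python
from typing import Callable

LETTER_TO_INDEX = {'a': 0, 'b': 1, 'c': 2, 'd': 3}

def evaluate_rows(model: Callable[[dict], str], rows: list[dict]) -> tuple[float, int]:
  """Return (accuracy, number of unparseable replies) over MMLU-style rows.

  Each row has 'question', 'choices' (4 strings) and 'answer' (gold index 0-3).
  """
  correct, unparsed = 0, 0
  for row in rows:
    # Same inputs as the chain: question plus options keyed 'a'..'d'
    inputs = {'question': row['question'], **dict(zip('abcd', row['choices']))}
    reply = model(inputs).strip().lower()
    if not reply or reply[0] not in LETTER_TO_INDEX:
      unparsed += 1  # counted as wrong, as in eval_rag
      continue
    if LETTER_TO_INDEX[reply[0]] == row['answer']:
      correct += 1
  return correct / len(rows), unparsed

def stub_model(inputs: dict) -> str:
  """Hypothetical model that always answers A."""
  return 'A'

rows = [
  {'question': 'q1', 'choices': ['w', 'x', 'y', 'z'], 'answer': 0},
  {'question': 'q2', 'choices': ['w', 'x', 'y', 'z'], 'answer': 2},
]
print(evaluate_rows(stub_model, rows))
```

A constant-answer stub like this is also a quick sanity baseline: on four-option questions, a model that always picks the same letter should land near 25% accuracy on a balanced subject.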

## Define LLM

In [4]:
```python
llm = Ollama(model='llama3.3:70b', temperature=0)

# Alternative: an OpenAI-hosted model (requires OPENAI_API_KEY)
# llm = ChatOpenAI(
#   model='gpt-4o',
#   temperature=0,
# )
```


## Define chain

In [5]:
```python
from langchain_core.output_parsers import PydanticOutputParser, StrOutputParser
from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel, Field

def extract_json(response: str) -> str:
  """Return the first {...} object in the model's reply, or the reply unchanged."""
  json_pattern = r'\{.*?\}'
  match = re.search(json_pattern, response, re.DOTALL)

  if match:
    # Collapse doubled backslashes that some models emit inside JSON strings
    return match.group().strip().replace('\\\\', '\\')

  return response

class Schema(BaseModel):
  correct_answer: str = Field(
    description='Given a question and answer options, the letter (A, B, C, or D) of the correct answer.',
  )

parser = PydanticOutputParser(pydantic_object=Schema)

template = """Answer the following multiple-choice question by giving the most appropriate response in JSON format. The answer must be one of A, B, C, or D.

{format_instructions}

Question: {question}

A) {a}
B) {b}
C) {c}
D) {d}
"""

prompt = PromptTemplate(
  template=template,
  input_variables=['question', 'a', 'b', 'c', 'd'],
  partial_variables={'format_instructions': parser.get_format_instructions()},
)

# extract_json is a plain function; LangChain wraps it in a RunnableLambda when piped
chain = prompt | llm | StrOutputParser() | extract_json | parser
```
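The chain depends on the model returning well-formed JSON; if it does not, the Pydantic parser fails and the question is scored as wrong. As a hedged, stdlib-only alternative (a swapped-in technique, not LangChain's `PydanticOutputParser`), the sketch below renders the options from the `choices` list with `zip` instead of four fixed variables, and extracts the answer by first trying `json.loads` on the first `{...}` span and then falling back to a regex for a standalone capital letter A-D. The names `render_options` and `parse_answer_letter` are hypothetical; the `correct_answer` key matches the `Schema` above.

```python
import json
import re
from typing import Optional

def render_options(choices: list[str]) -> str:
  """Format answer options as 'A) ...' lines, one per choice."""
  return '\n'.join(f'{letter}) {text}' for letter, text in zip('ABCD', choices))

def parse_answer_letter(reply: str) -> Optional[str]:
  """Return 'A'-'D' from a model reply, or None if no answer can be found."""
  match = re.search(r'\{.*?\}', reply, re.DOTALL)
  if match:
    try:
      value = str(json.loads(match.group()).get('correct_answer', '')).strip().upper()
      if value[:1] in ('A', 'B', 'C', 'D'):
        return value[0]
    except json.JSONDecodeError:
      pass  # not valid JSON; fall through to the regex
  # Fallback: first standalone capital letter A-D, e.g. "The answer is C."
  letter = re.search(r'\b([ABCD])\b', reply)
  return letter.group(1) if letter else None

print(render_options(['x', 'y', 'z', 'w']))
print(parse_answer_letter('Sure! {"correct_answer": "b"}'))
```

The regex fallback is deliberately crude: a reply that starts with the article "A" (as in "A good question...") would be read as answer A, so it is best treated as a last resort and tracked separately when reporting results.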

## Evaluate the model

Here we restrict the evaluation to MMLU subjects close to neurobiology, starting with `medical_genetics`.
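When the evaluation covers several related subjects (for example `anatomy`, `college_biology`, or `high_school_biology`, which are also MMLU configurations), per-subject scores can be combined in two ways. Micro-averaged accuracy, $\sum_s c_s / \sum_s n_s$, weights every question equally; macro-averaged accuracy, $\frac{1}{S}\sum_s c_s / n_s$, weights every subject equally. A minimal sketch, assuming (correct, total) counts have been recorded per subject; `micro_macro_accuracy` is a hypothetical helper and the counts below are made up for illustration, not results from this notebook.

```python
def micro_macro_accuracy(counts: dict[str, tuple[int, int]]) -> tuple[float, float]:
  """Return (micro, macro) accuracy from {subject: (correct, total)} counts."""
  total_correct = sum(c for c, _ in counts.values())
  total_questions = sum(n for _, n in counts.values())
  micro = total_correct / total_questions          # every question weighs the same
  macro = sum(c / n for c, n in counts.values()) / len(counts)  # every subject weighs the same
  return micro, macro

# Hypothetical counts, not results from this notebook
counts = {'medical_genetics': (80, 100), 'anatomy': (90, 135)}
print(micro_macro_accuracy(counts))
```

The two averages diverge when subjects have very different sizes, so it is worth stating which one is reported.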

In [6]:
```python
eval_rag(chain, 'medical_genetics')
```

```text
Generating test split: 100%|██████████| 100/100 [00:00<00:00, 30081.79 examples/s]
Generating validation split: 100%|██████████| 11/11 [00:00<00:00, 8124.20 examples/s]
Generating dev split: 100%|██████████| 5/5 [00:00<00:00, 3829.02 examples/s]
Questions:   0%|          | 0/100 [00:00<?, ?it/s]

ValueError: Ollama call failed with status code 500. Details: {"error":"model requires more system memory (38.4 GiB) than is available (13.4 GiB)"}
```
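The run above fails because the 70B model does not fit in the available memory. A rough rule of thumb for the weights alone is parameters × bits per weight ÷ 8 bytes; the KV cache, runtime overhead, and quantization formats that average slightly more than 4 bits come on top, which helps explain why this estimate lands below the 38.4 GiB that Ollama reports. A minimal sketch, assuming a flat 4 bits per weight and round parameter counts; `estimate_weight_gib` is a hypothetical helper.

```python
def estimate_weight_gib(n_params: float, bits_per_weight: float) -> float:
  """Rough memory for model weights only, in GiB (excludes KV cache and overhead)."""
  return n_params * bits_per_weight / 8 / 2**30

available_gib = 13.4  # from the Ollama error message above
for name, n_params in [('70B', 70e9), ('8B', 8e9)]:
  need = estimate_weight_gib(n_params, bits_per_weight=4)
  print(f'{name}: ~{need:.1f} GiB of weights, fits in {available_gib} GiB: {need < available_gib}')
```

By this estimate a 70B model needs well over twice the available memory even at 4-bit precision, while an 8B-class model fits comfortably; running the evaluation therefore calls for a smaller model, a machine with more memory, or the commented-out `ChatOpenAI` backend.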