# Usage
This notebook evaluates Llama 3 8B Instruct on MMLU (clinical subsets), Winogrande, and Belebele in three South African languages — Afrikaans, Zulu, and Xhosa — as well as in English.

We recommend using at least a single A100 GPU for all experiments.

# Installation

In [0]:
# Install Python packages
!pip install transformers torch accelerate
!pip install -i https://pypi.org/simple/ bitsandbytes
!pip install flash-attn --no-build-isolation
!pip install pandas tqdm seaborn matplotlib

In [0]:
# dbutils.library.restartPython()  TODO: Uncomment this if using Azure Databricks

In [0]:
# Set up the Hugging Face access token from the shared project config file

import json
import os

with open('../../config.json', 'r') as config_file:
    config = json.load(config_file)

hf_token = config['huggingface_read_token']
# Exported as HF_TOKEN so the Hugging Face downloads in the next cell can authenticate
os.environ['HF_TOKEN'] = hf_token

In [0]:
# Download and load Llama 3 8B Instruct (with some optimizations)
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# Short name (text after the last '/') used to tag custom_ids and result file names
model_name = model_id.rsplit('/', 1)[-1]

# NOTE: this 8-bit config is NOT passed to from_pretrained below, so the model is loaded
# unquantized in bfloat16. Add `quantization_config=quant_config` to evaluate the 8-bit variant.
quant_config = BitsAndBytesConfig(load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Speed up inference: static KV cache + compiled forward pass.
model.generation_config.cache_implementation = "static"
# Compile `forward` in place (the pattern from the transformers static-cache docs). Wrapping the
# whole module with torch.compile returns an OptimizedModule whose `.generate` resolves to the
# original module's method, which then calls the *uncompiled* forward, so compilation is never used.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# Stop generation at either the tokenizer's EOS token or Llama 3's end-of-turn token
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

# Inference

In [0]:
# Define inference function that accepts a row in an OpenAI Batch API-formatted JSONL and produces a response
def infer(jsonl_row):
    """Run one OpenAI Batch API request row through the local model and return its reply.

    Args:
        jsonl_row: dict whose 'body' is an OpenAI chat-completions request body.
            'messages' and 'max_tokens' are required. 'temperature' and 'top_p' are
            optional (missing or null means 1.0, the OpenAI API default).

    Returns:
        The generated text that follows the prompt, with special tokens removed.
    """
    body = jsonl_row['body']
    input_ids = tokenizer.apply_chat_template(
        body['messages'],
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    temperature = body.get('temperature')
    temperature = 1.0 if temperature is None else temperature
    top_p = body.get('top_p')
    top_p = 1.0 if top_p is None else top_p

    if temperature > 0:
        sampling_kwargs = {'do_sample': True, 'temperature': temperature, 'top_p': top_p}
    else:
        # temperature=0 is valid in the OpenAI API (near-deterministic output), but
        # transformers rejects a non-positive temperature when sampling, so decode greedily.
        sampling_kwargs = {'do_sample': False}

    outputs = model.generate(
        input_ids,
        max_new_tokens=body['max_tokens'],
        eos_token_id=terminators,  # stop at EOS or Llama 3's <|eot_id|>
        **sampling_kwargs,
    )

    # Drop the prompt tokens so only the newly generated continuation is decoded
    response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return response

# Smoke test: send a single MMLU clinical_knowledge prompt through `infer` before the full run.
# The custom_id ends in 'answer-A', i.e. the expected reply starts with the letter 'A'.
# `infer` only reads body['messages'], 'max_tokens', 'temperature' and 'top_p'; the
# '<|MODEL|>' placeholders in the id/model fields are not used here.
test_row = {"custom_id": "<|MODEL|>-on-en-mmlu-clinical_knowledge-0-answer-A", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "<|MODEL|>", "messages": [{"role": "user", "content": "The following are multiple choice questions (with answers) about clinical knowledge.\n\nQuestion 1: The energy for all forms of muscle contraction is provided by:\nA. ATP.\nB. ADP.\nC. phosphocreatine.\nD. oxidative phosphorylation.\nAnswer: A\n\nQuestion 2: What is the difference between a male and a female catheter?\nA. Male and female catheters are different colours.\nB. Male catheters are longer than female catheters.\nC. Male catheters are bigger than female catheters.\nD. Female catheters are longer than male catheters.\nAnswer: B\n\nQuestion 3: In the assessment of the hand function which of the following is true?\nA. Abduction of the thumb is supplied by spinal root T2\nB. Opposition of the thumb by opponens policis is supplied by spinal root T1\nC. Finger adduction is supplied by the median nerve\nD. Finger abduction is mediated by the palmar interossei\nAnswer: B\n\nQuestion 4: How many attempts should you make to cannulate a patient before passing the job on to a senior colleague, according to the medical knowledge of 2020?\nA. 4\nB. 3\nC. 2\nD. 1\nAnswer: C\n\nQuestion 5: Glycolysis is the name given to the pathway involving the conversion of:\nA. glycogen to glucose-1-phosphate.\nB. glycogen or glucose to fructose.\nC. glycogen or glucose to pyruvate or lactate.\nD. glycogen or glucose to pyruvate or acetyl CoA.\nAnswer: C\n\nNow, given the following question and answer choices, output only the letter corresponding to the correct answer. Do not add any explanation.\n\nQuestion: What size of cannula would you use in a patient who needed a rapid blood transfusion (as of 2020 medical knowledge)?\nA. 18 gauge.\nB. 20 gauge.\nC. 22 gauge.\nD. 24 gauge.\nAnswer:\n"}], "max_tokens": 512, "temperature": 0.7, "top_p": 0.9}}
infer(test_row)  # last expression, so the decoded reply is shown as the cell output

In [0]:
import os
import pandas as pd
from tqdm import tqdm
import json

# Infer on all prompts
prompts_path = '../../data/evaluation_batches/gpt_style_batch_evaluation_template.jsonl'
generations_path = f'../../results/out_of_the_box_performance/generations_{model_name}.json'

# Make sure the results folder exists BEFORE the long inference loop; otherwise the final
# write would raise FileNotFoundError only after every prompt had already been generated.
os.makedirs(os.path.dirname(generations_path), exist_ok=True)

with open(prompts_path, 'r') as fp:
    all_prompts_jsonl = pd.read_json(fp, lines=True)

generations_map = {}
for _, row in tqdm(all_prompts_jsonl.iterrows(), total=all_prompts_jsonl.shape[0]):
    # Replace the '<|MODEL|>' placeholder in the id with this model's name to tag the result
    generation_id = row['custom_id'].replace('<|MODEL|>', model_name)
    generations_map[generation_id] = infer({'custom_id': row['custom_id'], 'body': row['body']})

# Save generations so they never have to be run again
with open(generations_path, 'w') as fp:
    json.dump(generations_map, fp, indent=2)

# Evaluation

In [0]:
# Define response-to-correctness functions

def check_mc_answer(custom_id, generation):
    """Return True if a multiple-choice generation picks the correct option letter.

    The correct letter is the last character of ``custom_id`` (e.g. ``...-answer-A``).
    The generation is stripped, parentheses are removed and it is upper-cased, so
    ``" (b) "`` reads as ``B``. A leading ``Answer:`` echo of the prompt format is
    skipped first; otherwise the initial 'A' of "Answer" would be scored as choice A.
    """
    parsed_gen = generation.strip().replace('(', '').replace(')', '').upper()
    if parsed_gen.startswith('ANSWER'):
        parsed_gen = parsed_gen[len('ANSWER'):].lstrip(': \t\r\n')
    return len(parsed_gen) > 0 and parsed_gen[0] == custom_id[-1]  # answer is the last character of custom_id

def check_winogrande_answer(custom_id, generation):
    """Return True if the generation names only the correct Winogrande option.

    The correct option digit ('1' or '2') is the last character of ``custom_id``.
    The generation counts as correct when it contains that digit and does not
    contain the other option's digit anywhere.
    """
    right = custom_id[-1]
    wrong = str(3 - int(right))  # 1 <-> 2
    if right not in generation:
        return False
    return wrong not in generation

## MMLU-Clinical-ZA

In [0]:
import re
import seaborn as sns
import matplotlib.pyplot as plt
import json
import pandas as pd

# Reload the saved generations so evaluation can run without repeating inference
with open(f'../../results/out_of_the_box_performance/generations_{model_name}.json', 'r') as fp:
    generations_map = json.load(fp)

# Get and display MMLU performance

sections = [
    'clinical_knowledge',
    'college_medicine',
]

mmlu_langs = [
    'en',
    'af',
    'zu',
    'xh',
]

# One row (this model) x one column per language; cells hold accuracy in percent
matrix = pd.DataFrame(
    data=0.0,
    index=[model_name],
    columns=mmlu_langs
)

for lang in mmlu_langs:
    total_score = 0
    q_cnt = 0

    for section in sections:
        # custom_ids look like "<model>-on-<lang>-mmlu-<section>-<n>-answer-<letter>"
        pattern = re.compile(rf".*-on-{re.escape(lang)}-mmlu-{re.escape(section)}.*")

        matching_generations = [(c_id, gen) for c_id, gen in generations_map.items() if pattern.match(c_id)]
        print(f"{lang}/{section}: {len(matching_generations)} generations")

        for (c_id, gen) in matching_generations:
            if check_mc_answer(c_id, gen):
                total_score += 1
            q_cnt += 1

    if q_cnt == 0:
        # Nothing matched (e.g. a language missing from the generations file): record NaN,
        # which the heatmap leaves blank, instead of crashing with ZeroDivisionError.
        print(f"WARNING: no MMLU generations found for language '{lang}'")
        matrix.at[model_name, lang] = float('nan')
        continue

    final_score = total_score / q_cnt
    matrix.at[model_name, lang] = round(final_score*100, 1)

# Create the heatmap on an explicit, large, high-resolution figure
fig, ax = plt.subplots(figsize=(12, 8), dpi=100)
sns.heatmap(matrix, annot=matrix, cmap="Greens", cbar=False, annot_kws={"size": 16}, fmt='.1f', ax=ax)
ax.set_title(f"{model_name}: MMLU clinical accuracy (%)", fontsize=16)

# Horizontal y-axis labels and larger tick labels on both axes
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=16)
ax.set_xticklabels(ax.get_xticklabels(), fontsize=16)

# Display the heatmap and persist the scores
fig.tight_layout()
plt.show()
matrix.to_csv(f'../../results/out_of_the_box_performance/mmlu_{model_name}.csv')


## Winogrande-ZA

In [0]:
# Get and display Winogrande performance

winogrande_langs = [
    'en',
    'af',
    'zu',
    'xh',
]

# One row (this model) x one column per language; cells hold accuracy in percent
matrix = pd.DataFrame(
    data=0.0,
    index=[model_name],
    columns=winogrande_langs
)

for lang in winogrande_langs:
    # Winogrande custom_ids contain "-on-<lang>-winogrande"
    pattern = re.compile(rf".*-on-{re.escape(lang)}-winogrande.*")

    matching_generations = [(c_id, gen) for c_id, gen in generations_map.items() if pattern.match(c_id)]
    print(f"{lang}: {len(matching_generations)} generations")

    if not matching_generations:
        # Nothing matched for this language: record NaN (blank heatmap cell) instead of
        # crashing with ZeroDivisionError.
        print(f"WARNING: no Winogrande generations found for language '{lang}'")
        matrix.at[model_name, lang] = float('nan')
        continue

    total_score = sum(check_winogrande_answer(c_id, gen) for c_id, gen in matching_generations)
    final_score = total_score / len(matching_generations)
    matrix.at[model_name, lang] = round(final_score*100, 1)

# Create the heatmap on an explicit, large, high-resolution figure
fig, ax = plt.subplots(figsize=(12, 8), dpi=100)
sns.heatmap(matrix, annot=matrix, cmap="Greens", cbar=False, annot_kws={"size": 16}, fmt='.1f', ax=ax)
ax.set_title(f"{model_name}: Winogrande accuracy (%)", fontsize=16)

# Horizontal y-axis labels and larger tick labels on both axes
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=16)
ax.set_xticklabels(ax.get_xticklabels(), fontsize=16)

# Display the heatmap and persist the scores
fig.tight_layout()
plt.show()
matrix.to_csv(f'../../results/out_of_the_box_performance/winogrande_{model_name}.csv')


## Belebele-ZA

In [0]:
# Get and display Belebele performance

belebele_langs = [
    'en',
    'af',
    'zu',
    'xh',
]

# One row (this model) x one column per language; cells hold accuracy in percent
matrix = pd.DataFrame(
    data=0.0,
    index=[model_name],
    columns=belebele_langs
)

for lang in belebele_langs:
    # Belebele custom_ids contain "-on-<lang>-belebele"
    pattern = re.compile(rf".*-on-{re.escape(lang)}-belebele.*")

    matching_generations = [(c_id, gen) for c_id, gen in generations_map.items() if pattern.match(c_id)]
    print(f"{lang}: {len(matching_generations)} generations")

    if not matching_generations:
        # Nothing matched for this language: record NaN (blank heatmap cell) instead of
        # crashing with ZeroDivisionError.
        print(f"WARNING: no Belebele generations found for language '{lang}'")
        matrix.at[model_name, lang] = float('nan')
        continue

    total_score = sum(check_mc_answer(c_id, gen) for c_id, gen in matching_generations)
    final_score = total_score / len(matching_generations)
    matrix.at[model_name, lang] = round(final_score*100, 1)

# Create the heatmap on an explicit, large, high-resolution figure
fig, ax = plt.subplots(figsize=(12, 8), dpi=100)
sns.heatmap(matrix, annot=matrix, cmap="Greens", cbar=False, annot_kws={"size": 16}, fmt='.1f', ax=ax)
ax.set_title(f"{model_name}: Belebele accuracy (%)", fontsize=16)

# Horizontal y-axis labels and larger tick labels on both axes
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=16)
ax.set_xticklabels(ax.get_xticklabels(), fontsize=16)

# Display the heatmap and persist the scores
fig.tight_layout()
plt.show()
matrix.to_csv(f'../../results/out_of_the_box_performance/belebele_{model_name}.csv')
