In [1]:
#!pip install -U transformers rank_bm25 evaluate unstructured bitsandbytes --quiet

In [1]:
import transformers

transformers.logging.disable_progress_bar()
transformers.logging.set_verbosity_error()

transformers.__version__

'4.38.1'

In [2]:
from datasets import load_dataset
import pandas as pd
import numpy as np
from model import EncoderModel, DecoderModel, BM25Model
from store import VectorStore
from tqdm import tqdm
import torch
from sklearn.metrics import ndcg_score

[nltk_data] Downloading package wordnet to /home/chkei001/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [3]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

# Dataset

In [4]:
ds = load_dataset("squad_v2")
ds

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 130319
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 11873
    })
})

In [5]:
df_val = ds["validation"].to_pandas()[["context", "question", "answers"]]
display(df_val.head(3))

Unnamed: 0,context,question,answers
0,The Normans (Norman: Nourmands; French: Norman...,In what country is Normandy located?,"{'text': ['France', 'France', 'France', 'Franc..."
1,The Normans (Norman: Nourmands; French: Norman...,When were the Normans in Normandy?,"{'text': ['10th and 11th centuries', 'in the 1..."
2,The Normans (Norman: Nourmands; French: Norman...,From which countries did the Norse originate?,"{'text': ['Denmark, Iceland and Norway', 'Denm..."


In [6]:
# extract first answer of answer list
extract_answers = lambda answer: "" if len(answer['text']) == 0 else answer['text'][0]
v_extract_answers = np.vectorize(extract_answers)

In [7]:
df_val["answers"] = v_extract_answers(df_val["answers"].values)

print(f"{df_val[df_val['answers'] == ''].shape[0]}/{df_val.shape[0]}")

5945/11873


In [8]:
# sample 400 answerable examples and 100 unanswerable examples 
test_set_answerable = df_val[df_val['answers'] != ''].sample(n=400, random_state=1)
test_set_not_answerable = df_val[df_val['answers'] == ''].sample(n=100, random_state=1)
test_set = pd.concat([test_set_answerable, test_set_not_answerable])
test_set

Unnamed: 0,context,question,answers
6719,According to PolitiFact the top 400 richest Am...,What did the richest 400 Americans have as chi...,grew up in substantial privilege
11420,"The British failures in North America, combine...",How many of the Pitt's planned expeditions wer...,"Two of the expeditions were successful, with F..."
7963,At the same time the Mongols imported Central ...,Who did the Mongols send to Bukhara as adminis...,Han Chinese and Khitans
9256,The other third of the water flows through the...,Where does the Nederrijn change it's name?,Wijk bij Duurstede
6749,"In Marxian analysis, capitalist firms increasi...",What do capitalist firms substitute equipment ...,labor inputs
...,...,...,...
4613,The Very high-speed Backbone Network Service (...,What were select locations connected to?,
257,"When considering computational problems, a pro...",What is a string over a Greek number when cons...,
233,Closely related fields in theoretical computer...,What is the process that asks a more specific ...,
4784,A variety of alternatives to the Y. pestis hav...,In what year was Scott and Duncan's research p...,


In [9]:
# calcualte true binary relevance for ndcg
def true_binary_relevance(result_idxs, original_id):
    return [1 if i == original_id else 0 for i in result_idxs]

In [10]:
import warnings
warnings.filterwarnings('ignore')

retriever_models = [
    "sentence-transformers/all-MiniLM-L6-v2",
    "BAAI/bge-base-en-v1.5",
    "WhereIsAI/UAE-Large-V1",
    "BAAI/bge-m3"
]
causal_models = [
    "google/gemma-7b-it",
    "HuggingFaceH4/zephyr-7b-beta",
    "mistralai/Mistral-7B-Instruct-v0.2",
    #"mistralai/Mixtral-8x7B-Instruct-v0.1"
    "meta-llama/Llama-2-7b-chat-hf"
]

retriever_results = []
causal_lm_results = []

# causal models loop
for causal_id in causal_models:
    causal_lm = DecoderModel(causal_id, device="cuda")
    
    # retriever models loop
    for retriever_id in retriever_models:
        
        # retriever setup loop
        for hybrid in [True, False]:
            # init new vector store with retriever
            db = VectorStore(retriever_id, hybrid)
            # embed documents
            db.add_documents(test_set["context"].values.tolist(), test_set.index.tolist())

            print(f"Retriever: {retriever_id} - Causal LM: {causal_id} - hybrid: {'yes' if hybrid else 'no'}")
            
            with tqdm(total=len(test_set.question.values)) as pbar:
                # loop through dataset
                for document_id, (_, query, correct_answer) in test_set.iterrows():
                    
                    best_contexts = ""
                    best_ndcg = 0
                    
                    # loop distance metrics
                    for distance_metric in ["cosine", "ip", "l2"]:
                        # retrieve documents
                        results = db.search(query)
                        
                        # unpack results
                        idxs = [result["id"] for result in results]
                        scores = [result["score"] for result in results]
                        contexts = [result["document"] for result in results]

                        # retriever results
                        true_relevance = true_binary_relevance(idxs, document_id)
                        ndcg = ndcg_score([true_relevance], [scores])
                        
                        # Only save results for examples for which a context could be found
                        if correct_answer != "":
                            retriever_results.append({
                                "model": retriever_id,
                                "ndcg": ndcg,
                                "metric": distance_metric,
                                "hybrid": "yes" if hybrid else "no"
                            })
                        
                        # caching to give generator best possible context
                        best_ndcg = ndcg if ndcg > best_ndcg else best_ndcg
                        if ndcg > best_ndcg:
                            best_ndcg = ndcg
                            best_contexts = contexts
                    
                    # concatenate list of contexts to one string
                    contexts = "\n\n".join(best_contexts)

                    # generate an answer
                    answer = causal_lm(query, contexts)

                    causal_lm_results.append(
                        {
                            "model": causal_id,
                            "question": query,
                            "answer": answer,
                            "context": contexts,
                            "correct_answer": correct_answer if correct_answer != "" else "Not answerable from the given context."
                        }
                    )
                    pbar.update(1)
                del db
                torch.cuda.empty_cache()
    del causal_lm
    torch.cuda.empty_cache()

Retriever: sentence-transformers/all-MiniLM-L6-v2 - Causal LM: google/gemma-7b-it - hybrid: yes


100%|██████████| 500/500 [12:41<00:00,  1.52s/it]


Retriever: sentence-transformers/all-MiniLM-L6-v2 - Causal LM: google/gemma-7b-it - hybrid: no


100%|██████████| 500/500 [12:21<00:00,  1.48s/it]


Retriever: BAAI/bge-base-en-v1.5 - Causal LM: google/gemma-7b-it - hybrid: yes


100%|██████████| 500/500 [13:46<00:00,  1.65s/it]


Retriever: BAAI/bge-base-en-v1.5 - Causal LM: google/gemma-7b-it - hybrid: no


100%|██████████| 500/500 [13:40<00:00,  1.64s/it]


Retriever: WhereIsAI/UAE-Large-V1 - Causal LM: google/gemma-7b-it - hybrid: yes


100%|██████████| 500/500 [14:21<00:00,  1.72s/it]


Retriever: WhereIsAI/UAE-Large-V1 - Causal LM: google/gemma-7b-it - hybrid: no


100%|██████████| 500/500 [14:09<00:00,  1.70s/it]


Retriever: BAAI/bge-m3 - Causal LM: google/gemma-7b-it - hybrid: yes


100%|██████████| 500/500 [14:13<00:00,  1.71s/it]


Retriever: BAAI/bge-m3 - Causal LM: google/gemma-7b-it - hybrid: no


100%|██████████| 500/500 [14:14<00:00,  1.71s/it]


Retriever: sentence-transformers/all-MiniLM-L6-v2 - Causal LM: HuggingFaceH4/zephyr-7b-beta - hybrid: yes


100%|██████████| 500/500 [28:41<00:00,  3.44s/it]


Retriever: sentence-transformers/all-MiniLM-L6-v2 - Causal LM: HuggingFaceH4/zephyr-7b-beta - hybrid: no


100%|██████████| 500/500 [28:43<00:00,  3.45s/it]


Retriever: BAAI/bge-base-en-v1.5 - Causal LM: HuggingFaceH4/zephyr-7b-beta - hybrid: yes


100%|██████████| 500/500 [29:30<00:00,  3.54s/it]


Retriever: BAAI/bge-base-en-v1.5 - Causal LM: HuggingFaceH4/zephyr-7b-beta - hybrid: no


100%|██████████| 500/500 [29:31<00:00,  3.54s/it]


Retriever: WhereIsAI/UAE-Large-V1 - Causal LM: HuggingFaceH4/zephyr-7b-beta - hybrid: yes


100%|██████████| 500/500 [29:44<00:00,  3.57s/it]


Retriever: WhereIsAI/UAE-Large-V1 - Causal LM: HuggingFaceH4/zephyr-7b-beta - hybrid: no


100%|██████████| 500/500 [29:39<00:00,  3.56s/it]


Retriever: BAAI/bge-m3 - Causal LM: HuggingFaceH4/zephyr-7b-beta - hybrid: yes


100%|██████████| 500/500 [29:45<00:00,  3.57s/it]


Retriever: BAAI/bge-m3 - Causal LM: HuggingFaceH4/zephyr-7b-beta - hybrid: no


100%|██████████| 500/500 [29:41<00:00,  3.56s/it]


Retriever: sentence-transformers/all-MiniLM-L6-v2 - Causal LM: mistralai/Mistral-7B-Instruct-v0.2 - hybrid: yes


100%|██████████| 500/500 [16:20<00:00,  1.96s/it]


Retriever: sentence-transformers/all-MiniLM-L6-v2 - Causal LM: mistralai/Mistral-7B-Instruct-v0.2 - hybrid: no


100%|██████████| 500/500 [16:08<00:00,  1.94s/it]


Retriever: BAAI/bge-base-en-v1.5 - Causal LM: mistralai/Mistral-7B-Instruct-v0.2 - hybrid: yes


100%|██████████| 500/500 [16:45<00:00,  2.01s/it]


Retriever: BAAI/bge-base-en-v1.5 - Causal LM: mistralai/Mistral-7B-Instruct-v0.2 - hybrid: no


100%|██████████| 500/500 [16:41<00:00,  2.00s/it]


Retriever: WhereIsAI/UAE-Large-V1 - Causal LM: mistralai/Mistral-7B-Instruct-v0.2 - hybrid: yes


100%|██████████| 500/500 [16:52<00:00,  2.02s/it]


Retriever: WhereIsAI/UAE-Large-V1 - Causal LM: mistralai/Mistral-7B-Instruct-v0.2 - hybrid: no


100%|██████████| 500/500 [16:52<00:00,  2.03s/it]


Retriever: BAAI/bge-m3 - Causal LM: mistralai/Mistral-7B-Instruct-v0.2 - hybrid: yes


100%|██████████| 500/500 [16:56<00:00,  2.03s/it]


Retriever: BAAI/bge-m3 - Causal LM: mistralai/Mistral-7B-Instruct-v0.2 - hybrid: no


100%|██████████| 500/500 [16:46<00:00,  2.01s/it]


Retriever: sentence-transformers/all-MiniLM-L6-v2 - Causal LM: meta-llama/Llama-2-7b-chat-hf - hybrid: yes


100%|██████████| 500/500 [34:24<00:00,  4.13s/it]


Retriever: sentence-transformers/all-MiniLM-L6-v2 - Causal LM: meta-llama/Llama-2-7b-chat-hf - hybrid: no


100%|██████████| 500/500 [34:11<00:00,  4.10s/it]


Retriever: BAAI/bge-base-en-v1.5 - Causal LM: meta-llama/Llama-2-7b-chat-hf - hybrid: yes


100%|██████████| 500/500 [35:21<00:00,  4.24s/it]


Retriever: BAAI/bge-base-en-v1.5 - Causal LM: meta-llama/Llama-2-7b-chat-hf - hybrid: no


100%|██████████| 500/500 [34:24<00:00,  4.13s/it]


Retriever: WhereIsAI/UAE-Large-V1 - Causal LM: meta-llama/Llama-2-7b-chat-hf - hybrid: yes


100%|██████████| 500/500 [34:09<00:00,  4.10s/it]


Retriever: WhereIsAI/UAE-Large-V1 - Causal LM: meta-llama/Llama-2-7b-chat-hf - hybrid: no


100%|██████████| 500/500 [35:33<00:00,  4.27s/it]


Retriever: BAAI/bge-m3 - Causal LM: meta-llama/Llama-2-7b-chat-hf - hybrid: yes


100%|██████████| 500/500 [36:34<00:00,  4.39s/it]


Retriever: BAAI/bge-m3 - Causal LM: meta-llama/Llama-2-7b-chat-hf - hybrid: no


100%|██████████| 500/500 [36:33<00:00,  4.39s/it]


In [11]:
pd.DataFrame(causal_lm_results).to_csv("causal_lm_results_v2.csv")
pd.DataFrame(retriever_results).to_csv("retriever_results_v2.csv")

# Evaluation

In [1]:
import pandas as pd
import numpy as np

causal_lm_results = pd.read_csv("causal_lm_results_v2.csv", index_col=0)
retriever_results = pd.read_csv("retriever_results_v2.csv", index_col=0)

## Retriever

In [2]:
retriever_results.groupby(['model', 'metric', 'hybrid']).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,ndcg
model,metric,hybrid,Unnamed: 3_level_1
BAAI/bge-base-en-v1.5,cosine,no,0.855504
BAAI/bge-base-en-v1.5,cosine,yes,0.869621
BAAI/bge-base-en-v1.5,ip,no,0.855504
BAAI/bge-base-en-v1.5,ip,yes,0.869621
BAAI/bge-base-en-v1.5,l2,no,0.855504
BAAI/bge-base-en-v1.5,l2,yes,0.869621
BAAI/bge-m3,cosine,no,0.795663
BAAI/bge-m3,cosine,yes,0.859185
BAAI/bge-m3,ip,no,0.795663
BAAI/bge-m3,ip,yes,0.859185


# Decoder

In [3]:
from nltk.translate.bleu_score import sentence_bleu
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=False)

In [4]:
causal_lm_results

Unnamed: 0,model,question,answer,context,correct_answer
0,google/gemma-7b-it,What did the richest 400 Americans have as chi...,<pad><pad><pad><eos>,,grew up in substantial privilege
1,google/gemma-7b-it,How many of the Pitt's planned expeditions wer...,I do know. The provided text does not contain ...,,"Two of the expeditions were successful, with F..."
2,google/gemma-7b-it,Who did the Mongols send to Bukhara as adminis...,"Sure, here is the answer to the question:\n\nT...",,Han Chinese and Khitans
3,google/gemma-7b-it,Where does the Nederrijn change it's name?,I do know. The text does not provide informati...,,Wijk bij Duurstede
4,google/gemma-7b-it,What do capitalist firms substitute equipment ...,"Sure, here is the answer to the question:\n\nI...",,labor inputs
...,...,...,...,...,...
15995,meta-llama/Llama-2-7b-chat-hf,What were select locations connected to?,I'm happy to help! Based on the information pr...,,Not answerable from the given context.
15996,meta-llama/Llama-2-7b-chat-hf,What is a string over a Greek number when cons...,"A string over an Greek letter, in the context ...",,Not answerable from the given context.
15997,meta-llama/Llama-2-7b-chat-hf,What is the process that asks a more specific ...,"Sure, I'd be happy to help! Can you please pro...",,Not answerable from the given context.
15998,meta-llama/Llama-2-7b-chat-hf,In what year was Scott and Duncan's research p...,I'm not sure when Scott and Duncans' research ...,,Not answerable from the given context.


In [5]:
grouped_results = causal_lm_results.groupby(["model"])

results = []

for name, values in grouped_results:
    group_result = {
        "bleu": [],
        "rouge_1_precision": [],
        "rouge_1_recall": [],
        "rouge_1_fmeasure": [],
        "rouge_L_precision": [],
        "rouge_L_recall": [],
        "rouge_L_fmeasure": [],
    }
    for answer, correct_answer in zip(values["answer"], values["correct_answer"]):
        bleu = sentence_bleu(
            references=correct_answer,
            hypothesis=answer
        )
    
        scores = scorer.score(correct_answer, answer)
        precision, recall, fmeasure = scores["rouge1"]
        precision_L, recall_L, fmeasure_L = scores["rougeL"]
        
        group_result["bleu"].append(bleu)
        group_result["rouge_1_precision"].append(precision)
        group_result["rouge_1_recall"].append(recall)
        group_result["rouge_1_fmeasure"].append(fmeasure)
        group_result["rouge_L_precision"].append(precision_L)
        group_result["rouge_L_recall"].append(recall_L)
        group_result["rouge_L_fmeasure"].append(fmeasure_L)
        
    results.append(
        {
            "model": name,
            "bleu": np.mean(group_result["bleu"]),
            "rouge_1_precision": np.mean(group_result["rouge_1_precision"]),
            "rouge_1_recall": np.mean(group_result["rouge_1_recall"]),
            "rouge_1_fmeasure": np.mean(group_result["rouge_1_fmeasure"]),
            "rouge_L_precision": np.mean(group_result["rouge_L_precision"]),
            "rouge_L_recall": np.mean(group_result["rouge_L_recall"]),
            "rouge_L_fmeasure": np.mean(group_result["rouge_L_fmeasure"])
        }
    )
    
pd.DataFrame(results)

The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


Unnamed: 0,model,bleu,rouge_1_precision,rouge_1_recall,rouge_1_fmeasure,rouge_L_precision,rouge_L_recall,rouge_L_fmeasure
0,"(HuggingFaceH4/zephyr-7b-beta,)",7.353941e-232,0.024778,0.36664,0.044254,0.022605,0.348476,0.040514
1,"(google/gemma-7b-it,)",8.771579e-232,0.027905,0.183684,0.044423,0.025373,0.173073,0.040591
2,"(meta-llama/Llama-2-7b-chat-hf,)",7.341089999999999e-232,0.02076,0.289247,0.036523,0.019008,0.272986,0.033484
3,"(mistralai/Mistral-7B-Instruct-v0.2,)",8.404847e-232,0.040925,0.238148,0.064073,0.032702,0.204071,0.051511
