<a href="https://colab.research.google.com/github/Aggregate-Intellect/xir/blob/main/wikiQA_fineturning_retriever.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install beir
!pip install tensorflow-text
!pip install farm-haystack
!pip install --upgrade pip
!pip install git+https://github.com/deepset-ai/haystack.git

In [99]:
from typing import List
import requests
import pandas as pd
from haystack import Document
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import RAGenerator, DensePassageRetriever
from haystack.utils import fetch_archive_from_http, print_answers, print_documents

from haystack import Document

from haystack.pipelines import ExtractiveQAPipeline
from haystack.pipelines import DocumentSearchPipeline

In [100]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

# Load wikiQA dataset

In [101]:
from datasets import load_dataset

wikiqa_data = load_dataset(
    'wiki_qa',
    split='train',
)
next(iter(wikiqa_data))



{'answer': 'A partly submerged glacier cave on Perito Moreno Glacier .',
 'document_title': 'Glacier cave',
 'label': 0,
 'question': 'how are glacier caves formed?',
 'question_id': 'Q1'}

In [102]:
import json
import pandas as pd

pd_wikiqa_data = pd.DataFrame(wikiqa_data)
pd_wikiqa_data

Unnamed: 0,question_id,question,document_title,answer,label
0,Q1,how are glacier caves formed?,Glacier cave,A partly submerged glacier cave on Perito Moreno Glacier .,0
1,Q1,how are glacier caves formed?,Glacier cave,The ice facade is approximately 60 m high,0
2,Q1,how are glacier caves formed?,Glacier cave,Ice formations in the Titlis glacier cave,0
3,Q1,how are glacier caves formed?,Glacier cave,A glacier cave is a cave formed within the ice of a glacier .,1
4,Q1,how are glacier caves formed?,Glacier cave,"Glacier caves are often called ice caves , but this term is properly used to...",0
...,...,...,...,...,...
20355,Q3043,what is section eight housing,Section 8 (housing),A tenant who leaves a subsidized project will lose access to the project-bas...,0
20356,Q3043,what is section eight housing,Section 8 (housing),The United States Department of Housing and Urban Development and United Sta...,0
20357,Q3044,what is the main type of restaurant,Category:Types of restaurants,Restaurants categorized by type and information about these different types.,0
20358,Q3046,what is us dollar worth based on,History of the United States dollar,U.S. Federal Reserve notes in the mid-1990s,0


# Preprocessing and document store

In [104]:
pd_data = pd.DataFrame(wikiqa_data)

# make sure all sentences ending with .
# if last character is not . ? !, add . at the end
for index in range(0,pd_data.shape[0]):
  last_char = pd_data['answer'][index][-1]
  if not (last_char =='.' or last_char =='!' or last_char =='?'):
    pd_data['answer'][index] = pd_data['answer'][index] + '.'

pd_data.to_csv("wikiQA_dataset.csv", index=False)

processed_wikiqa = pd.DataFrame(columns = ["question_id", "question", "document_title", "document_content", "answer"])
processed_all_wikiqa = pd.DataFrame(columns = ["question_id", "document_title", "document_content"])

for question in pd_data['question_id'].unique().tolist():
    pd_question = pd_data.loc[pd_data['question_id'] == question]   
    content = ' '.join(pd_question["answer"].tolist())
    pd_answer = pd_question[pd_question['label'] == 1]
    processed_all_wikiqa = processed_all_wikiqa.append({"question_id": question,
                                                        "document_title": pd_question["document_title"].iloc[0],
                                                        "document_content" : content}, ignore_index=True)
    if not pd_answer.empty:
        answer = ' '.join(pd_answer['answer'].tolist())
        new_row = {"question_id": pd_question["question_id"].iloc[0], 
                  "question" : pd_question["question"].iloc[0],
                  "document_title" : pd_question["document_title"].iloc[0],
                  "document_content" : content,
                  "answer": answer}
        processed_wikiqa = processed_wikiqa.append(new_row, ignore_index=True)
processed_wikiqa.to_csv("processed_wikiqa.csv", index=False)

docs = []
for index, row in processed_all_wikiqa.iterrows():
    # create haystack document object with text content and doc metadata
    doc = Document(
        content = row['document_content'],
        meta = {
            "document_title": row['document_title']
        }
    )
    docs.append(doc)

document_store = FAISSDocumentStore(faiss_index_factory_str="Flat", return_embedding=True)

# Delete existing documents in documents store
document_store.delete_documents()

document_store.write_documents(docs)
document_store.get_document_count()

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  


Writing Documents:   0%|          | 0/2118 [00:00<?, ?it/s]

1994

# Fine turning on retriver model

In [105]:
data = []
for question_id in pd_wikiqa_data['question_id'].unique().tolist():
    pd_question = pd_wikiqa_data.loc[pd_wikiqa_data['question_id'] == question_id]   
    pd_neg = pd_question[pd_question['label'] == 0]
    neg_content = ' '.join(pd_neg["answer"].tolist())
    pd_answer = pd_question[pd_question['label'] == 1.0]
    if not pd_answer.empty:
        answer = ' '.join(pd_answer['answer'].tolist())
        new_row = {"question" : pd_question["question"].iloc[0],
                  "neg_doc": neg_content,
                  "pos_doc": answer,
                  "score": 1.0}
        data.append(new_row)

In [106]:
fine_turned_retriever = EmbeddingRetriever(
            document_store=document_store,
            embedding_model="flax-sentence-embeddings/all_datasets_v3_mpnet-base",
            model_format="sentence_transformers"
)
fine_turned_retriever.train(training_data = data, learning_rate = 2e-5, n_epochs = 1, num_warmup_steps = None, batch_size = 1)

INFO - haystack.modeling.utils -  Using devices: CUDA:0
INFO - haystack.modeling.utils -  Number of GPUs: 1
INFO - haystack.nodes.retriever.dense -  Init retriever using embeddings of model flax-sentence-embeddings/all_datasets_v3_mpnet-base
INFO - haystack.nodes.retriever._embedding_encoder -  GPL training/adapting SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: MPNetModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
) with 873 examples


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/873 [00:00<?, ?it/s]

In [107]:
from typing import Union
fine_turned_retriever.save("/content/my retriever") 

# Validation

Load test dataset

In [108]:
vali_wikiqa_data = load_dataset(
    'wiki_qa',
    split='test',
)
next(iter(vali_wikiqa_data))



{'answer': 'African immigration to the United States refers to immigrants to the United States who are or were nationals of Africa .',
 'document_title': 'African immigration to the United States',
 'label': 0,
 'question': 'HOW AFRICAN AMERICANS WERE IMMIGRATED TO THE US',
 'question_id': 'Q0'}

In [109]:
pd_test_data = pd.DataFrame(vali_wikiqa_data)
pd_test_data

Unnamed: 0,question_id,question,document_title,answer,label
0,Q0,HOW AFRICAN AMERICANS WERE IMMIGRATED TO THE US,African immigration to the United States,African immigration to the United States refers to immigrants to the United ...,0
1,Q0,HOW AFRICAN AMERICANS WERE IMMIGRATED TO THE US,African immigration to the United States,The term African in the scope of this article refers to geographical or nati...,0
2,Q0,HOW AFRICAN AMERICANS WERE IMMIGRATED TO THE US,African immigration to the United States,"From the Immigration and Nationality Act of 1965 to 2007, an estimated total...",0
3,Q0,HOW AFRICAN AMERICANS WERE IMMIGRATED TO THE US,African immigration to the United States,African immigrants in the United States come from almost all regions in Afri...,0
4,Q0,HOW AFRICAN AMERICANS WERE IMMIGRATED TO THE US,African immigration to the United States,"They include people from different national, linguistic, ethnic, racial, cul...",0
...,...,...,...,...,...
6160,Q3045,what is an open mare?,Mare,"The word can also be used for other female equine animals, particularly mule...",0
6161,Q3045,what is an open mare?,Mare,A broodmare is a mare used for breeding.,0
6162,Q3045,what is an open mare?,Mare,A horse's female parent is known as its dam.,0
6163,Q3045,what is an open mare?,Mare,"An adult male horse is called a stallion , or, if castrated , a gelding .",0


In [110]:
pd_test_data = pd.DataFrame(vali_wikiqa_data)

# make sure all sentences ending with .
# if last character is not . ? !, add . at the end
for index in range(0,pd_test_data.shape[0]):
  last_char = pd_test_data['answer'][index][-1]
  if not (last_char =='.' or last_char =='!' or last_char =='?'):
    pd_test_data['answer'][index] = pd_test_data['answer'][index] + '.'

pd_test_data.to_csv("wikiQA_dataset.csv", index=False)

processed_wikiqa = pd.DataFrame(columns = ["question_id", "question", "document_title", "document_content", "answer"])
processed_all_wikiqa = pd.DataFrame(columns = ["question_id", "document_title", "document_content"])

for question_id in pd_test_data['question_id'].unique().tolist():
    # print(question_id)
    pd_question = pd_test_data.loc[pd_test_data['question_id'] == question_id]   
    # print(pd_question)
    content = ' '.join(pd_question["answer"].tolist())
    pd_answer = pd_question[pd_question['label'] == 1]
    # print(pd_question["document_title"].iloc[0])
    processed_all_wikiqa = processed_all_wikiqa.append({"question_id": question_id,
                                                        "document_title": pd_question["document_title"].iloc[0],
                                                        "document_content" : content}, ignore_index=True)
    if not pd_answer.empty:
        answer = ' '.join(pd_answer['answer'].tolist())
        new_row = {"question_id": pd_question["question_id"].iloc[0], 
                  "question" : pd_question["question"].iloc[0],
                  "document_title" : pd_question["document_title"].iloc[0],
                  "document_content" : content,
                  "answer": answer}
        processed_wikiqa = processed_wikiqa.append(new_row, ignore_index=True)

processed_wikiqa

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  


Unnamed: 0,question_id,question,document_title,document_content,answer
0,Q0,HOW AFRICAN AMERICANS WERE IMMIGRATED TO THE US,African immigration to the United States,African immigration to the United States refers to immigrants to the United ...,"As such, African immigrants are to be distinguished from African American pe..."
1,Q4,how a water pump works,Pump,"A small, electrically powered pump. A large, electrically driven pump (elect...","Pumps operate by some mechanism (typically reciprocating or rotary ), and co..."
2,Q20,how old was sue lyon when she made lolita,Lolita (1962 film),Lolita is a 1962 comedy-drama film by Stanley Kubrick based on the classic n...,"The actress who played Lolita, Sue Lyon , was fourteen at the time of filming."
3,Q33,how are antibodies used in,antibody,Each antibody binds to a specific antigen ; an interaction similar to a lock...,"An antibody (Ab), also known as an immunoglobulin (Ig), is a large Y-shaped ..."
4,Q59,HOW MUCH IS CENTAVOS IN MEXICO,Mexican peso,The peso ( sign : $; code : MXN) is the currency of Mexico . Modern peso and...,"The peso is subdivided into 100 centavos, represented by "" ¢ ""."
...,...,...,...,...,...
238,Q2990,what is the main component of vaccines,vaccine,Jonas Salk in 1955 holds two bottles of a culture used to grow polio vaccine...,The agent stimulates the body's immune system to recognize the agent as fore...
239,Q2994,what is preciosa crystal?,Preciosa (corporation),Preciosa is the luxury brand name for the range of precision-cut lead crysta...,Preciosa is the luxury brand name for the range of precision-cut lead crysta...
240,Q3004,who are all of the jonas brothers,Jonas Brothers,"The Jonas Brothers are an American pop rock band. Formed in 2005, they have ...","Formed in 2005, they have gained popularity from the Disney Channel children..."
241,Q3008,who is mary matalin married to,Mary Matalin,"This is about the political professional. For the actress, see Marlee Matlin...",She is married to Democratic political consultant James Carville .


In [112]:
docs = []
for index, row in processed_all_wikiqa.iterrows():
    # create haystack document object with text content and doc metadata
    doc = Document(
        content = row['document_content'],
        meta = {
            "document_title": row['document_title']
        }
    )
    docs.append(doc)

document_store = FAISSDocumentStore(faiss_index_factory_str="Flat", return_embedding=True)

# Delete existing documents in documents store
document_store.delete_documents()

document_store.write_documents(docs)
document_store.get_document_count()

Writing Documents:   0%|          | 0/633 [00:00<?, ?it/s]

619

In [113]:
retriever = EmbeddingRetriever(
   document_store=document_store,
   embedding_model="flax-sentence-embeddings/all_datasets_v3_mpnet-base",
   model_format="sentence_transformers"
)

document_store.update_embeddings(retriever, batch_size=128)

IR_pipeline = DocumentSearchPipeline(retriever = retriever)

INFO - haystack.modeling.utils -  Using devices: CUDA:0
INFO - haystack.modeling.utils -  Number of GPUs: 1
INFO - haystack.nodes.retriever.dense -  Init retriever using embeddings of model flax-sentence-embeddings/all_datasets_v3_mpnet-base
INFO - haystack.document_stores.faiss -  Updating embeddings for 619 docs...


Updating Embedding:   0%|          | 0/619 [00:00<?, ? docs/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

In [121]:
processed_wikiqa

Unnamed: 0,question_id,question,document_title,document_content,answer
0,Q0,HOW AFRICAN AMERICANS WERE IMMIGRATED TO THE US,African immigration to the United States,African immigration to the United States refers to immigrants to the United ...,"As such, African immigrants are to be distinguished from African American pe..."
1,Q4,how a water pump works,Pump,"A small, electrically powered pump. A large, electrically driven pump (elect...","Pumps operate by some mechanism (typically reciprocating or rotary ), and co..."
2,Q20,how old was sue lyon when she made lolita,Lolita (1962 film),Lolita is a 1962 comedy-drama film by Stanley Kubrick based on the classic n...,"The actress who played Lolita, Sue Lyon , was fourteen at the time of filming."
3,Q33,how are antibodies used in,antibody,Each antibody binds to a specific antigen ; an interaction similar to a lock...,"An antibody (Ab), also known as an immunoglobulin (Ig), is a large Y-shaped ..."
4,Q59,HOW MUCH IS CENTAVOS IN MEXICO,Mexican peso,The peso ( sign : $; code : MXN) is the currency of Mexico . Modern peso and...,"The peso is subdivided into 100 centavos, represented by "" ¢ ""."
...,...,...,...,...,...
238,Q2990,what is the main component of vaccines,vaccine,Jonas Salk in 1955 holds two bottles of a culture used to grow polio vaccine...,The agent stimulates the body's immune system to recognize the agent as fore...
239,Q2994,what is preciosa crystal?,Preciosa (corporation),Preciosa is the luxury brand name for the range of precision-cut lead crysta...,Preciosa is the luxury brand name for the range of precision-cut lead crysta...
240,Q3004,who are all of the jonas brothers,Jonas Brothers,"The Jonas Brothers are an American pop rock band. Formed in 2005, they have ...","Formed in 2005, they have gained popularity from the Disney Channel children..."
241,Q3008,who is mary matalin married to,Mary Matalin,"This is about the political professional. For the actress, see Marlee Matlin...",She is married to Democratic political consultant James Carville .


# Acccuracy of original model

In [None]:
from haystack.document_stores import InMemoryDocumentStore

number_correctly_retrieved_passage = 0
number_correctly_retrieved_sentence = 0
number_documents = 244


for index, row in processed_wikiqa.iterrows():
    query = row['question']
    result = IR_pipeline.run(
             query=query, 
             params={"Retriever": {"top_k": 1}}
             )
    retrieved_document = result["documents"][0].content
    target_document = processed_wikiqa['document_content'][index]
    
    if retrieved_document == target_document:
        number_correctly_retrieved_passage += 1
        
        sentence_document_store = InMemoryDocumentStore()
        sentence_document_store.delete_documents()
        docs = []
        sentence_list = nltk.tokenize.sent_tokenize(retrieved_document)
        for sentence in sentence_list:
            doc = Document(content=sentence,
                           meta={"article_title": processed_wikiqa['document_title']})
            docs.append(doc)
            sentence_document_store.write_documents(docs)
        
        original_sentence_retriever = EmbeddingRetriever(
            document_store=sentence_document_store,
            embedding_model="flax-sentence-embeddings/all_datasets_v3_mpnet-base",
            model_format="sentence_transformers"
        )

        sentence_document_store.update_embeddings(original_sentence_retriever)

        sentence_pipeline = DocumentSearchPipeline(retriever = original_sentence_retriever)
        
        sentence_result = sentence_pipeline.run(query=query, 
                                                params={"Retriever": {"top_k": 1}})
        if sentence_result['documents'] != []:
            retrieved_sentence = sentence_result["documents"][0].content
        target_sentence = processed_wikiqa['answer'][index]
        if str(retrieved_sentence) in " ".join(str(target_sentence).split()):
            number_correctly_retrieved_sentence += 1
            print("Found")
        else:
            print("Not Found")
    if index == number_documents-1:
        break

In [123]:
passage_accuracy = number_correctly_retrieved_passage / number_documents
sentence_accuracy = number_correctly_retrieved_sentence/ number_documents
print(passage_accuracy)
print(sentence_accuracy)

0.9877049180327869
0.6598360655737705


# Accuracy of fine-turned model

In [None]:
from haystack.document_stores import InMemoryDocumentStore

number_correctly_retrieved_passage = 0
number_correctly_retrieved_sentence = 0
number_documents = 244


for index, row in processed_wikiqa.iterrows():
    query = row['question']
    result = IR_pipeline.run(
             query=query, 
             params={"Retriever": {"top_k": 1}}
             )
    # print(result)
    retrieved_document = result["documents"][0].content
    target_document = processed_wikiqa['document_content'][index]
    
    if retrieved_document == target_document:
        number_correctly_retrieved_passage += 1
        
        sentence_document_store = InMemoryDocumentStore()
        sentence_document_store.delete_documents()
        docs = []
        sentence_list = nltk.tokenize.sent_tokenize(retrieved_document)
        # print(sentence_list)
        for sentence in sentence_list:
            doc = Document(content=sentence,
                           meta={"article_title": processed_wikiqa['document_title']})
            docs.append(doc)
            sentence_document_store.write_documents(docs)
        
        sentence_retriever = EmbeddingRetriever(
            document_store=sentence_document_store,
            embedding_model="/content/my retriever",
            model_format="sentence_transformers"
        )

        sentence_document_store.update_embeddings(sentence_retriever)

        sentence_pipeline = DocumentSearchPipeline(retriever = sentence_retriever)
        
        sentence_result = sentence_pipeline.run(query=query, 
                                                params={"Retriever": {"top_k": 1}})
        if sentence_result['documents'] != []:
            retrieved_sentence = sentence_result["documents"][0].content
        target_sentence = processed_wikiqa['answer'][index]
        if str(retrieved_sentence) in " ".join(str(target_sentence).split()):
            number_correctly_retrieved_sentence += 1
            print("Found")
        else:
            print("Not Found")
    if index == number_documents-1:
        break

In [119]:
passage_accuracy = number_correctly_retrieved_passage / number_documents
sentence_accuracy = number_correctly_retrieved_sentence/ number_documents
print(passage_accuracy)
print(sentence_accuracy)

0.9877049180327869
0.7622950819672131
