# TP08 - Part 2


In [65]:
import numpy as np

In [66]:
from huggingface_hub import notebook_login

In [67]:
from datasets import load_dataset, load_metric

## Load dataset SQuAD v2

In [68]:
dataset_squadv2 = load_dataset("squad_v2")

Reusing dataset squad_v2 (/root/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)


  0%|          | 0/2 [00:00<?, ?it/s]

In [69]:
dataset_squadv2

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 130319
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 11873
    })
})

## Load dataset **dbpedia**

In [70]:
from beir import util
from beir.datasets.data_loader import GenericDataLoader

dataset_dbpedia = "dbpedia-entity"
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset_dbpedia)
data_path = util.download_and_unzip(url, "datasets")
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

  0%|          | 0/4635922 [00:00<?, ?it/s]

### keep only the first 9000 paragraphs (with length >= 50) of the dbpedia dataset

In [71]:
text_validation_dbpedia = [corpus[title]['text'] for title in corpus if len(corpus[title]['text'].split(" ")) >= 50]
text_validation_dbpedia = text_validation_dbpedia[:9000]

In [72]:
len(text_validation_dbpedia) 

9000

### Add the new paragraphs to the SQuAD dataset

In [73]:
for para in text_validation_dbpedia:
  new_item = {'id': '', 'title': '', 'context': para,  'question': '', 'answers': None, 'answer_start': None}
  dataset_squadv2['validation'] = dataset_squadv2['validation'].add_item(new_item)

In [74]:
dataset_squadv2['validation']

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 20873
})

In [75]:
from collections import Counter
print("Total of unique paragraphs: {}".format(len(Counter(dataset_squadv2['validation']['context']))))

Total of unique paragraphs: 10204


## Indexing our corpus

In [76]:
from sentence_transformers import SentenceTransformer, util
# dot product because long answer to retrieve
model = SentenceTransformer('msmarco-distilbert-base-dot-prod-v3')

In [77]:
passage_embedding = model.encode(dataset_squadv2['validation']['context'], device='cuda', show_progress_bar=True)

Batches:   0%|          | 0/653 [00:00<?, ?it/s]

In [78]:
query_embedding = model.encode(dataset_squadv2['validation']['question'], device='cuda', show_progress_bar=True)

Batches:   0%|          | 0/653 [00:00<?, ?it/s]

In [83]:
result_list = []
for question in query_embedding[:1000]:
  result_list.append(np.argsort([util.pytorch_cos_sim(passage, question) for passage in passage_embedding]))

In [84]:
c = Counter(dataset_squadv2['validation']['context'])
counter = [value for value in c.values()]

In [85]:
i = 0
j = 0
c = 0
while(i < len(result_list) and i < 20):
  desc = 0
  limit = counter[c]
  while (desc < limit and i < len(result_list) and i < 20):
    print("valid: [{} - {}] - result: {}".format(j,limit + j, result_list[i][-1]))
    i += 1
    desc += 1
  j += limit
  c += 1

valid: [0 - 9] - result: 18616
valid: [0 - 9] - result: 0
valid: [0 - 9] - result: 18848
valid: [0 - 9] - result: 18848
valid: [0 - 9] - result: 0
valid: [0 - 9] - result: 0
valid: [0 - 9] - result: 15049
valid: [0 - 9] - result: 119
valid: [0 - 9] - result: 13347
valid: [9 - 17] - result: 99
valid: [9 - 17] - result: 102
valid: [9 - 17] - result: 0
valid: [9 - 17] - result: 11
valid: [9 - 17] - result: 16037
valid: [9 - 17] - result: 14668
valid: [9 - 17] - result: 101
valid: [9 - 17] - result: 115
valid: [17 - 21] - result: 19
valid: [17 - 21] - result: 18
valid: [17 - 21] - result: 20


In [86]:
def rrm(result_list):
  sum = 0
  for res in result_list:
    sum +=  (1 / (res[-1] + 1))
  return sum / len(result_list)

In [89]:
print("mean reciprocal rank: {}".format(rrm(result_list)))

mean reciprocal rank: 0.012234932901954727


The score is not good, maybe because there is several identical paragraphs (for example, the eight first questions refer to the same paragraphs)