<a href="https://colab.research.google.com/github/Swastik3025/Wikiret/blob/main/Wikiret.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##IMPORTING DEPENDENCIES

In [1]:
!pip install -U sentence-transformers rank_bm25

Collecting sentence-transformers
  Downloading sentence_transformers-3.1.1-py3-none-any.whl.metadata (10 kB)
Collecting rank_bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Downloading sentence_transformers-3.1.1-py3-none-any.whl (245 kB)
Downloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Installing collected packages: rank_bm25, sentence-transformers
Successfully installed rank_bm25-0.2.2 sentence-transformers-3.1.1


In [2]:
import json
import os
import torch
import pandas as pd
import gzip
from sentence_transformers import SentenceTransformer,CrossEncoder,util


In [3]:
if not torch.cuda.is_available():
  print("Warning: No GPU found. Please add GPU to your notebook")
else:
  print("GPU is connected")

GPU is connected


##STARTUP CODE EXAMPLE

In [4]:
Query="How many people live in Shanghai?"
docs=["Around 25 million people live in Shanghai","Shanghai is one of the world's major centres for finance and economics."]

# Loading the pretrained model
model=SentenceTransformer('sentence-transformers/multi-qa-MiniLM-L6-cos-v1')

# Encoding query and documents
Query_embed=model.encode(Query)
docs_embed=model.encode(docs)

# Compute dot-product scores between the query and all document embeddings
# (equal to cosine similarity when the embeddings are unit-length)
scores=util.dot_score(Query_embed,docs_embed)[0].cpu().tolist()

# Combine docs and scores
doc_score_pairs=list(zip(docs,scores))

# Sort by decreasing scores
doc_score_pairs=sorted(doc_score_pairs,key=lambda x: x[1],reverse=True)

# Output sentences and scores
for doc,score in doc_score_pairs:
  print(doc,score)

Around 25 million people live in Shanghai 0.941273033618927
Shanghai is one of the world's major centres for finance and economics. 0.5358985066413879
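
The comment in the cell above treats cosine similarity and dot score as interchangeable. A minimal sketch of why that can hold, assuming the embeddings are unit-length (which the `cos` in the model name suggests, though this is not checked here). The toy 3-dimensional vectors and the helper names `dot`, `normalize` and `cosine` are illustrative and not part of the notebook:

```python
import math

def dot(a, b):
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]

def cosine(a, b):
    """Cosine similarity computed from its definition."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))

# Toy 3-dimensional vectors standing in for real 384-dimensional embeddings.
query_vec = normalize([1.0, 2.0, 2.0])
doc_vec = normalize([2.0, 1.0, 2.0])

dot_score = dot(query_vec, doc_vec)
cos_score = cosine(query_vec, doc_vec)
print(dot_score, cos_score)
```

For vectors that are not normalised, the two scores generally differ, so the choice between `util.dot_score` and `util.cos_sim` matters for other models.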


##DATASET

In [5]:
# Use Simple English Wikipedia as the dataset.

wiki_filepath="simplewiki-2020-11-01.jsonl.gz"

if not os.path.exists(wiki_filepath):
  util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz',wiki_filepath)

passages=[]
with gzip.open(wiki_filepath,'rt',encoding='utf8') as fIn:
  for line in fIn:
    data=json.loads(line.strip())
    passages.append(data['paragraphs'][0])

print("Passages:",len(passages))

Passages: 169597
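
The loading loop above reads a gzip-compressed JSON Lines file and keeps only the first paragraph of each article. The sketch below reproduces that pattern on a tiny in-memory file. The two records and their `title` key are made up for illustration; only the `paragraphs` key is taken from the notebook:

```python
import gzip
import io
import json

# Hypothetical records; only the 'paragraphs' key is taken from the notebook.
records = [
    {"title": "Shanghai", "paragraphs": ["Shanghai is a city in China.", "It is a port."]},
    {"title": "Cádiz", "paragraphs": ["Cádiz is a province of southern Spain."]},
]

# Write a small gzip-compressed JSON Lines payload into memory.
buffer = io.BytesIO()
with gzip.open(buffer, "wt", encoding="utf8") as f_out:
    for record in records:
        f_out.write(json.dumps(record) + "\n")

# Read it back the same way the dataset cell does: one JSON object per line,
# keeping only the first paragraph of each article.
buffer.seek(0)
first_paragraphs = []
with gzip.open(buffer, "rt", encoding="utf8") as f_in:
    for line in f_in:
        data = json.loads(line.strip())
        first_paragraphs.append(data["paragraphs"][0])

print(first_paragraphs)
```

Keeping only `paragraphs[0]` means the passage count equals the number of articles, which is why later paragraphs of an article can never be retrieved.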


##1.Lexical Search Model/Baseline Model

In [6]:
from rank_bm25 import BM25Okapi
from sklearn.feature_extraction import _stop_words
import string
from tqdm.autonotebook import tqdm
import numpy as np

# Lowercase the text, strip punctuation and remove stop words
def bm25_tokenizer(text):
  tokenized_doc=[]
  for token in text.lower().split():
    token=token.strip(string.punctuation)
    if len(token)>0 and token not in _stop_words.ENGLISH_STOP_WORDS:
      tokenized_doc.append(token)
  return tokenized_doc

tokenized_corpus=[]
for passage in tqdm(passages):
  tokenized_corpus.append(bm25_tokenizer(passage))

bm25=BM25Okapi(tokenized_corpus)
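
To make the lexical baseline less of a black box, here is a minimal sketch of a textbook BM25 scoring function in plain Python. It is not the `rank_bm25` implementation: the smoothed IDF term, the `k1` and `b` values, and the function name `bm25_scores` are assumptions of this sketch, and the library may handle rare or very common terms differently.

```python
import math
from collections import Counter

def bm25_scores(query_tokens, tokenized_docs, k1=1.5, b=0.75):
    """Score every document against the query with a textbook BM25 formula."""
    n_docs = len(tokenized_docs)
    avg_len = sum(len(doc) for doc in tokenized_docs) / n_docs
    # Document frequency: in how many documents each term appears.
    doc_freq = Counter(term for doc in tokenized_docs for term in set(doc))
    scores = []
    for doc in tokenized_docs:
        term_freq = Counter(doc)
        score = 0.0
        for term in query_tokens:
            if term not in term_freq:
                continue
            # Smoothed IDF (always positive); an assumption of this sketch.
            idf = math.log((n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5) + 1.0)
            tf = term_freq[term]
            # Longer-than-average documents are penalised through this factor.
            length_norm = 1.0 - b + b * len(doc) / avg_len
            score += idf * tf * (k1 + 1.0) / (tf + k1 * length_norm)
        scores.append(score)
    return scores

toy_docs = [
    ["shanghai", "population", "million", "people"],
    ["shanghai", "finance", "economics", "centre"],
    ["spain", "country", "europe"],
]
toy_scores = bm25_scores(["people", "shanghai"], toy_docs)
print(toy_scores)
```

A document scores above zero only if it shares at least one token with the query, which is the main limitation the semantic search in the next section is meant to address.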

##2.Semantic Search with SBERT Bi-Encoder

In [7]:
# Use the bi-encoder to encode all passages

bi_encod=SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
bi_encod.max_seq_length=256  # Truncate long passages to 256 tokens
top_k=128      # Number of passages we want to retrieve with bi-encoder

corpus_embeddings=bi_encod.encode(passages,convert_to_tensor=True)
corpus_embeddings

tensor([[-0.1050, -0.0670,  0.0066,  ..., -0.0062,  0.0119, -0.0517],
        [ 0.0713,  0.0376, -0.0308,  ..., -0.0017,  0.0601, -0.1008],
        [ 0.0027,  0.0527,  0.0639,  ...,  0.0340,  0.0052, -0.0556],
        ...,
        [ 0.0400, -0.0056, -0.0687,  ...,  0.0122,  0.0623, -0.0073],
        [ 0.0438, -0.0652,  0.1016,  ...,  0.0288,  0.0213, -0.0313],
        [-0.0807, -0.0066, -0.0762,  ..., -0.0126,  0.0771, -0.0163]],
       device='cuda:0')

In [8]:
corpus_embeddings.shape

torch.Size([169597, 384])

In [9]:
# Convert the torch tensor to a pandas DataFrame
embed=corpus_embeddings.detach().cpu().numpy()
pd_corpus=pd.DataFrame(embed)
pd_corpus

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,374,375,376,377,378,379,380,381,382,383
0,-0.105023,-0.066984,0.006559,-0.075239,-0.029778,0.032903,0.041868,0.096896,-0.032055,-0.030800,...,-0.081437,-0.041778,0.028318,-0.013225,-0.037180,-0.013523,-0.030022,-0.006164,0.011879,-0.051670
1,0.071267,0.037586,-0.030801,0.017431,0.017352,0.041926,0.016158,-0.008603,-0.015108,0.037280,...,0.042315,-0.069265,-0.020141,-0.001224,0.027026,0.064461,-0.022421,-0.001684,0.060108,-0.100817
2,0.002654,0.052741,0.063910,-0.026432,-0.042960,-0.117778,-0.041497,0.009521,0.081352,0.016946,...,-0.006478,0.076592,0.013565,-0.053890,0.091820,0.013788,0.014588,0.034018,0.005163,-0.055596
3,-0.032643,-0.005084,-0.104261,0.063471,-0.043725,0.012432,0.067527,0.021174,0.057384,-0.026403,...,-0.063869,-0.011750,-0.088418,-0.036881,0.011142,-0.038457,-0.000024,-0.072465,0.022039,0.046768
4,-0.023334,0.070958,0.068794,0.040282,-0.067307,0.007153,-0.046180,-0.107382,-0.000101,-0.074365,...,0.017998,0.160744,-0.020466,0.004663,-0.022195,0.108305,-0.021253,-0.067848,0.011429,0.050955
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
169592,0.031391,-0.020579,-0.022487,-0.080696,0.056671,0.085363,0.043364,0.053081,-0.042765,-0.015982,...,0.018944,-0.031684,-0.036291,0.002544,-0.015884,-0.019092,0.015174,0.016902,0.060815,0.037487
169593,0.078687,0.023156,0.025629,-0.007406,-0.045560,0.038647,0.006824,0.000921,-0.044711,0.000034,...,0.034964,0.049574,0.009712,0.017496,-0.072706,-0.052105,0.018976,0.042470,-0.118604,0.025879
169594,0.039964,-0.005592,-0.068719,-0.069237,0.041660,0.002297,0.076915,0.068725,0.037130,0.031586,...,0.050264,-0.025924,-0.000071,-0.027144,0.033330,0.013986,-0.020076,0.012242,0.062253,-0.007267
169595,0.043826,-0.065224,0.101616,-0.063292,-0.027100,0.122430,0.054029,0.028393,-0.051804,0.040918,...,0.019195,0.005024,0.027869,0.024391,-0.094816,-0.058863,-0.063887,0.028803,0.021326,-0.031320


In [10]:
# Save the embeddings to a CSV file, then reload them
pd_corpus.to_csv("corpus.csv",index_label=False)
saved_corpus=pd.read_csv("corpus.csv")
saved_corpus

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,374,375,376,377,378,379,380,381,382,383
0,-0.105023,-0.066984,0.006559,-0.075239,-0.029778,0.032903,0.041868,0.096896,-0.032055,-0.030800,...,-0.081437,-0.041778,0.028318,-0.013225,-0.037180,-0.013523,-0.030022,-0.006164,0.011879,-0.051670
1,0.071267,0.037586,-0.030801,0.017431,0.017352,0.041926,0.016158,-0.008603,-0.015108,0.037280,...,0.042315,-0.069265,-0.020141,-0.001224,0.027026,0.064461,-0.022421,-0.001684,0.060108,-0.100817
2,0.002654,0.052741,0.063910,-0.026432,-0.042960,-0.117778,-0.041497,0.009521,0.081352,0.016946,...,-0.006478,0.076592,0.013565,-0.053890,0.091820,0.013788,0.014588,0.034018,0.005163,-0.055596
3,-0.032643,-0.005084,-0.104261,0.063471,-0.043725,0.012432,0.067527,0.021174,0.057384,-0.026403,...,-0.063869,-0.011750,-0.088418,-0.036881,0.011142,-0.038457,-0.000024,-0.072465,0.022039,0.046768
4,-0.023334,0.070958,0.068794,0.040282,-0.067307,0.007153,-0.046180,-0.107382,-0.000101,-0.074365,...,0.017998,0.160744,-0.020466,0.004663,-0.022195,0.108305,-0.021253,-0.067848,0.011429,0.050955
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
169592,0.031391,-0.020579,-0.022487,-0.080696,0.056671,0.085363,0.043364,0.053081,-0.042765,-0.015982,...,0.018944,-0.031684,-0.036291,0.002544,-0.015884,-0.019092,0.015174,0.016902,0.060815,0.037487
169593,0.078687,0.023156,0.025629,-0.007406,-0.045560,0.038647,0.006824,0.000921,-0.044711,0.000034,...,0.034964,0.049574,0.009712,0.017496,-0.072706,-0.052105,0.018976,0.042470,-0.118604,0.025879
169594,0.039964,-0.005592,-0.068719,-0.069237,0.041660,0.002297,0.076915,0.068725,0.037130,0.031586,...,0.050264,-0.025924,-0.000071,-0.027144,0.033330,0.013986,-0.020076,0.012242,0.062253,-0.007267
169595,0.043826,-0.065224,0.101616,-0.063292,-0.027100,0.122430,0.054029,0.028393,-0.051804,0.040918,...,0.019195,0.005024,0.027869,0.024391,-0.094816,-0.058863,-0.063887,0.028803,0.021326,-0.031320
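
One side effect of the CSV round trip is worth seeing in isolation: the reloaded frame has string column labels and `float64` values, while the original embeddings are `float32`. The sketch below shows this on a toy matrix and, as one possible alternative rather than the notebook's method, a binary round trip with `np.save`/`np.load` that keeps the dtype. In-memory buffers are used here in place of files on disk:

```python
import io

import numpy as np
import pandas as pd

# Toy float32 "embeddings" standing in for the real (169597, 384) matrix.
toy_embed = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)

# CSV round trip, written to an in-memory buffer instead of corpus.csv.
csv_buffer = io.StringIO()
pd.DataFrame(toy_embed).to_csv(csv_buffer, index_label=False)
csv_buffer.seek(0)
from_csv = pd.read_csv(csv_buffer)

# Binary round trip with NumPy's .npy format, also in memory.
npy_buffer = io.BytesIO()
np.save(npy_buffer, toy_embed)
npy_buffer.seek(0)
from_npy = np.load(npy_buffer)

print(list(from_csv.columns), from_csv.dtypes.unique())
print(from_npy.dtype, np.array_equal(from_npy, toy_embed))
```

For a matrix of this size, a binary format is also likely to be much smaller and faster to load than text, though that is not measured here.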


In [11]:
# Query Answering
def search(query):
  print("Input question:",query)

  # Lexical Search
  bm25_scores=bm25.get_scores(bm25_tokenizer(query))
  top_n=np.argpartition(bm25_scores,-5)[-5:]
  bm25_preds=[{'corpus_id':idx,'score':bm25_scores[idx]} for idx in top_n]
  bm25_preds=sorted(bm25_preds,key=lambda x:x['score'],reverse=True)

  print("Top-3 lexical search predictions")
  for pred in bm25_preds[0:3]:
    print("\t{:.3f}\t{}".format(pred['score'],passages[pred['corpus_id']].replace("\n"," ")))

  # Semantic Search with Bi-encoder
  query_embedding=bi_encod.encode(query,convert_to_tensor=True)
  query_embedding_gpu=query_embedding.cuda()  ## speed optimisation

  # Function to perform cosine similarity search between a list of query embeddings and list of corpus embeddings
  preds=util.semantic_search(query_embedding_gpu,corpus_embeddings,top_k=top_k)
  preds=preds[0]   # take the predictions for the first query; up to 100 queries are processed in parallel

  print("\n-------------------\n")
  print("Top-3 Bi-encoder search predictions")
  preds=sorted(preds,key=lambda x:x['score'],reverse=True)
  for pred in preds[0:3]:
    print("\t{:.3f}\t{}".format(pred['score'],passages[pred['corpus_id']].replace("\n"," ")))
  return preds[:120],query_embedding,query
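
The `search` function relies on two ideas: picking the top entries with `np.argpartition` (which does not order them, hence the explicit sort) and ranking passages by cosine similarity. The following sketch combines both on a toy corpus with NumPy. It illustrates the idea only and is not how `util.semantic_search` is implemented; the name `top_k_cosine` and the 2-dimensional vectors are made up for the example:

```python
import numpy as np

def top_k_cosine(query_vec, corpus_matrix, k):
    """Return the k (corpus_id, score) pairs with the highest cosine similarity."""
    # Normalise so that a dot product equals cosine similarity.
    query_unit = query_vec / np.linalg.norm(query_vec)
    corpus_unit = corpus_matrix / np.linalg.norm(corpus_matrix, axis=1, keepdims=True)
    scores = corpus_unit @ query_unit
    # argpartition finds the k largest entries but leaves them unordered...
    top_ids = np.argpartition(scores, -k)[-k:]
    # ...so an explicit sort by descending score is still needed.
    top_ids = top_ids[np.argsort(scores[top_ids])[::-1]]
    return [(int(i), float(scores[i])) for i in top_ids]

toy_corpus = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]])
toy_query = np.array([1.0, 0.2])
hits = top_k_cosine(toy_query, toy_corpus, k=2)
print(hits)
```

Partitioning first and sorting only the selected entries avoids sorting the full score array, which matters more as the corpus grows.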

In [12]:
my_preds,query_embedding,query=search(query="Which is the southernmost country in Europe?")

Input question: Which is the southernmost country in Europe?
Top-3 lexical search predictions
	18.206	Cádiz is a province of southern Spain, in the southwestern part of the autonomous community of Andalusia. It is the southernmost part of mainland Spain, as well as the southernmost part of continental Europe.
	16.304	The Valle di Muggio is located in Ticino in Switzerland and is the southernmost valley of the country.
	14.502	Spain is a country in Europe.

-------------------

Top-3 Bi-encoder search predictions
	0.649	Southern Europe is a region of the European continent. The officialise définition of southern europ is iberic peninsula, italian peninsula and balkanic peninsula. Spain, Portugal, Italy and Greece and over all the Mediterranean countries of the European continent as parts of Southern Europe.
	0.607	Luxembourg (Dutch and German: "Luxemburg", Luxembourgish: Lëtzebuerg, Walloon: "Lussimbork") is the southernmost province of Belgium and Wallonia. With , it is the largest pro

##DATA VISUALISATION

In [13]:
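# First entry holds the dict keys; the remaining entries hold each prediction's values
temp=[[key for key in my_preds[0].keys()],*[list(idx.values()) for idx in my_preds]]
new_temp=temp[1:]   # drop the header entry, keeping [corpus_id, score] pairs
new_temp[:5]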

[[40019, 0.6488983631134033],
 [53603, 0.6073328256607056],
 [74853, 0.5905330777168274],
 [40016, 0.5872663855552673],
 [68183, 0.5511994957923889]]

In [14]:
corpus_nums=[int(item[0]) for item in new_temp]
corpus_nums

[40019,
 53603,
 74853,
 40016,
 68183,
 14717,
 79056,
 148983,
 26751,
 26715,
 26719,
 14718,
 82611,
 79072,
 116804,
 105550,
 31113,
 100586,
 39994,
 39986,
 59774,
 68256,
 158493,
 59538,
 42495,
 152266,
 101227,
 6297,
 149531,
 92190,
 43605,
 67900,
 22300,
 70467,
 82253,
 24756,
 41476,
 44321,
 48027,
 126161,
 17948,
 138712,
 53237,
 25433,
 82292,
 43511,
 59536,
 31424,
 101225,
 6476,
 125469,
 69659,
 164610,
 42128,
 87522,
 82610,
 138372,
 59591,
 14727,
 20431,
 63252,
 40775,
 59537,
 105531,
 84701,
 59962,
 59255,
 59175,
 124630,
 101221,
 165900,
 79393,
 14742,
 43077,
 56359,
 9462,
 152820,
 26087,
 46690,
 59532,
 82289,
 68200,
 135709,
 37320,
 116809,
 125627,
 160754,
 56159,
 83816,
 56223,
 148133,
 168567,
 15723,
 21204,
 83462,
 56502,
 60883,
 167149,
 67427,
 130631,
 30742,
 15042,
 6321,
 148189,
 106457,
 61707,
 139938,
 131463,
 9641,
 140498,
 100600,
 6620,
 161338,
 56435,
 57790,
 25851,
 40017,
 22011,
 57785,
 57827]
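
The two cells above pull the corpus ids out by position within each prediction's values. A slightly more direct sketch reads them by key, which does not depend on the order of keys in the dict. The list `toy_preds` is a hand-written stand-in with the same `corpus_id`/`score` shape, using the first three ids from the output above and rounded scores:

```python
# Hand-written predictions in the same shape the search function returns:
# a list of dicts with 'corpus_id' and 'score' keys.
toy_preds = [
    {"corpus_id": 40019, "score": 0.6489},
    {"corpus_id": 53603, "score": 0.6073},
    {"corpus_id": 74853, "score": 0.5905},
]

# Read the ids and scores by key instead of by position in the dict.
toy_ids = [int(pred["corpus_id"]) for pred in toy_preds]
toy_scores = [float(pred["score"]) for pred in toy_preds]
print(toy_ids)
```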

In [15]:
select_corpus_rows=saved_corpus.loc[corpus_nums]
select_corpus_rows

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,374,375,376,377,378,379,380,381,382,383
40019,0.071083,0.019608,-0.024786,-0.040884,0.015264,0.077679,-0.053445,-0.026805,0.055678,-0.093409,...,0.072429,-0.025151,0.027862,-0.041608,0.040765,0.043570,-0.013912,0.106731,-0.064934,0.041150
53603,0.114832,0.050414,0.014039,-0.055407,0.019493,0.042618,-0.061437,0.004949,0.021172,-0.036829,...,0.002637,0.025607,-0.055966,-0.066068,-0.009035,-0.035878,-0.032458,0.039817,-0.003605,0.071169
74853,0.049773,0.118941,-0.076711,-0.023432,-0.057506,0.040908,-0.050808,0.004607,-0.023578,-0.032519,...,-0.011541,0.005898,-0.060594,-0.105599,0.025999,0.074264,-0.000057,0.069740,-0.094150,0.086590
40016,0.043966,0.025232,0.002572,-0.007988,0.046726,0.029939,-0.017251,-0.081140,0.047377,-0.065888,...,-0.031731,-0.042814,-0.000533,-0.009554,0.053028,0.055127,0.040038,0.023960,-0.078571,0.089749
68183,0.070129,0.021692,-0.042192,-0.067544,0.010155,0.061137,-0.017072,-0.085666,0.028424,-0.036392,...,0.034788,-0.024006,0.048197,-0.012877,0.027108,0.017455,-0.049313,0.091547,-0.009843,0.056781
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
25851,0.020001,-0.049642,-0.073445,-0.097129,0.055834,-0.041047,-0.092540,-0.041759,0.018013,-0.024255,...,-0.013406,-0.016732,-0.037393,0.037056,-0.023069,-0.007736,0.019467,0.058769,-0.008285,0.041688
40017,0.097543,0.075366,0.005641,0.018544,0.020005,-0.030726,-0.021766,-0.017920,-0.009740,-0.091999,...,0.008052,0.028593,0.023252,0.030047,0.038598,0.019085,-0.053103,0.043428,-0.032566,-0.001432
22011,-0.019035,0.059168,-0.023327,-0.002730,-0.039367,0.079316,-0.012495,-0.023793,0.042873,-0.050651,...,0.012487,-0.044493,0.011198,0.082887,-0.007053,0.077061,0.001842,0.029862,-0.004057,-0.021167
57785,0.008028,0.076490,-0.017677,-0.010264,0.009262,-0.033786,-0.027153,-0.025191,0.002912,0.004834,...,0.002548,0.037848,0.020643,0.022682,0.060635,0.051334,-0.081913,0.043740,-0.026263,-0.007285


In [16]:
df_text=pd.DataFrame({'text':passages})
df_text_nums=df_text.loc[corpus_nums]
print(df_text_nums)

                                                    text
40019  Southern Europe is a region of the European co...
53603  Luxembourg (Dutch and German: "Luxemburg", Lux...
74853  Carinthia is the southernmost State of Austria...
40016  Northern Europe is the northern part of the Eu...
68183  Southeast Europe or Southeastern Europe is a r...
...                                                  ...
25851  It is the closest of the farther non-spherical...
40017  Western Europe is a geographic region of Europ...
22011  Europe is a Swedish rock band. The band was st...
57785  Venizel is a commune. It is found in the regio...
57827  Vénérolles is a commune. It is found in the re...

[120 rows x 1 columns]


In [17]:
embed_query=query_embedding.detach().cpu().numpy()

pd_query=pd.DataFrame(embed_query).T

print(type(pd_query))
print(pd_query.dtypes)
print(pd_query.shape)
print("Column headers from list(df.columns):",list(pd_query.columns))

<class 'pandas.core.frame.DataFrame'>
0      float32
1      float32
2      float32
3      float32
4      float32
        ...   
379    float32
380    float32
381    float32
382    float32
383    float32
Length: 384, dtype: object
(1, 384)
Column headers from list(df.columns): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 16

In [19]:
print(type(select_corpus_rows))
print(select_corpus_rows.dtypes)
print(select_corpus_rows.shape)
print("Column header from list(df.columns):",list(select_corpus_rows.columns))
select_corpus_rows.columns=np.arange(384)
print("Column header from list(df.columns):",list(select_corpus_rows.columns))

<class 'pandas.core.frame.DataFrame'>
0      float64
1      float64
2      float64
3      float64
4      float64
        ...   
379    float64
380    float64
381    float64
382    float64
383    float64
Length: 384, dtype: object
(120, 384)
Column header from list(df.columns): ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', '109', '110', '111', '112', '113', '114', '115', '116', '117', '118'
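
The relabelling with `np.arange(384)` matters because the query frame has integer column labels while the rows reloaded from CSV have string labels. A small sketch of what `pd.concat` does in each case, using 3 columns instead of 384 and made-up values:

```python
import pandas as pd

# Query row with integer column labels, as produced from a NumPy array.
query_row = pd.DataFrame([[0.1, 0.2, 0.3]], columns=[0, 1, 2])
# Corpus rows with string labels, as produced by reading a CSV file.
corpus_rows = pd.DataFrame([[0.4, 0.5, 0.6]], columns=["0", "1", "2"], index=[40019])

# Mismatched labels: concat keeps both sets of columns and fills gaps with NaN.
mismatched = pd.concat([query_row, corpus_rows], axis=0)

# After relabelling, the columns line up and no NaN values appear.
corpus_rows.columns = range(3)
aligned = pd.concat([query_row, corpus_rows], axis=0)

print(mismatched.shape, aligned.shape)
```

Without the relabelling, the combined frame would have twice as many columns, half of them empty in every row.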

In [20]:
df_con=pd.concat([pd_query,select_corpus_rows],ignore_index=False,axis=0)
df_con

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,374,375,376,377,378,379,380,381,382,383
0,0.080532,0.091979,-0.047747,-0.036801,-0.012664,0.013002,-0.046419,0.006748,0.018803,-0.075588,...,-0.001845,-0.006229,-0.014597,-0.023434,0.060017,0.050776,0.022065,0.078658,-0.021233,0.054354
40019,0.071083,0.019608,-0.024786,-0.040884,0.015264,0.077679,-0.053445,-0.026805,0.055678,-0.093409,...,0.072429,-0.025151,0.027862,-0.041608,0.040765,0.043570,-0.013912,0.106731,-0.064934,0.041150
53603,0.114832,0.050414,0.014039,-0.055407,0.019493,0.042618,-0.061437,0.004949,0.021172,-0.036829,...,0.002637,0.025607,-0.055966,-0.066068,-0.009035,-0.035878,-0.032458,0.039817,-0.003605,0.071169
74853,0.049773,0.118941,-0.076711,-0.023432,-0.057506,0.040908,-0.050808,0.004607,-0.023578,-0.032519,...,-0.011541,0.005898,-0.060594,-0.105599,0.025999,0.074264,-0.000057,0.069740,-0.094150,0.086590
40016,0.043966,0.025232,0.002572,-0.007988,0.046726,0.029939,-0.017251,-0.081140,0.047377,-0.065888,...,-0.031731,-0.042814,-0.000533,-0.009554,0.053028,0.055127,0.040038,0.023960,-0.078571,0.089749
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
25851,0.020001,-0.049642,-0.073445,-0.097129,0.055834,-0.041047,-0.092540,-0.041759,0.018013,-0.024255,...,-0.013406,-0.016732,-0.037393,0.037056,-0.023069,-0.007736,0.019467,0.058769,-0.008285,0.041688
40017,0.097543,0.075366,0.005641,0.018544,0.020005,-0.030726,-0.021766,-0.017920,-0.009740,-0.091999,...,0.008052,0.028593,0.023252,0.030047,0.038598,0.019085,-0.053103,0.043428,-0.032566,-0.001432
22011,-0.019035,0.059168,-0.023327,-0.002730,-0.039367,0.079316,-0.012495,-0.023793,0.042873,-0.050651,...,0.012487,-0.044493,0.011198,0.082887,-0.007053,0.077061,0.001842,0.029862,-0.004057,-0.021167
57785,0.008028,0.076490,-0.017677,-0.010264,0.009262,-0.033786,-0.027153,-0.025191,0.002912,0.004834,...,0.002548,0.037848,0.020643,0.022682,0.060635,0.051334,-0.081913,0.043740,-0.026263,-0.007285


In [21]:
df_con['text']=df_text_nums
df_con.at[0,'text']=query
df_con

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,375,376,377,378,379,380,381,382,383,text
0,0.080532,0.091979,-0.047747,-0.036801,-0.012664,0.013002,-0.046419,0.006748,0.018803,-0.075588,...,-0.006229,-0.014597,-0.023434,0.060017,0.050776,0.022065,0.078658,-0.021233,0.054354,Which is the southernmost country in Europe?
40019,0.071083,0.019608,-0.024786,-0.040884,0.015264,0.077679,-0.053445,-0.026805,0.055678,-0.093409,...,-0.025151,0.027862,-0.041608,0.040765,0.043570,-0.013912,0.106731,-0.064934,0.041150,Southern Europe is a region of the European co...
53603,0.114832,0.050414,0.014039,-0.055407,0.019493,0.042618,-0.061437,0.004949,0.021172,-0.036829,...,0.025607,-0.055966,-0.066068,-0.009035,-0.035878,-0.032458,0.039817,-0.003605,0.071169,"Luxembourg (Dutch and German: ""Luxemburg"", Lux..."
74853,0.049773,0.118941,-0.076711,-0.023432,-0.057506,0.040908,-0.050808,0.004607,-0.023578,-0.032519,...,0.005898,-0.060594,-0.105599,0.025999,0.074264,-0.000057,0.069740,-0.094150,0.086590,Carinthia is the southernmost State of Austria...
40016,0.043966,0.025232,0.002572,-0.007988,0.046726,0.029939,-0.017251,-0.081140,0.047377,-0.065888,...,-0.042814,-0.000533,-0.009554,0.053028,0.055127,0.040038,0.023960,-0.078571,0.089749,Northern Europe is the northern part of the Eu...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
25851,0.020001,-0.049642,-0.073445,-0.097129,0.055834,-0.041047,-0.092540,-0.041759,0.018013,-0.024255,...,-0.016732,-0.037393,0.037056,-0.023069,-0.007736,0.019467,0.058769,-0.008285,0.041688,It is the closest of the farther non-spherical...
40017,0.097543,0.075366,0.005641,0.018544,0.020005,-0.030726,-0.021766,-0.017920,-0.009740,-0.091999,...,0.028593,0.023252,0.030047,0.038598,0.019085,-0.053103,0.043428,-0.032566,-0.001432,Western Europe is a geographic region of Europ...
22011,-0.019035,0.059168,-0.023327,-0.002730,-0.039367,0.079316,-0.012495,-0.023793,0.042873,-0.050651,...,-0.044493,0.011198,0.082887,-0.007053,0.077061,0.001842,0.029862,-0.004057,-0.021167,Europe is a Swedish rock band. The band was st...
57785,0.008028,0.076490,-0.017677,-0.010264,0.009262,-0.033786,-0.027153,-0.025191,0.002912,0.004834,...,0.037848,0.020643,0.022682,0.060635,0.051334,-0.081913,0.043740,-0.026263,-0.007285,Venizel is a commune. It is found in the regio...


In [23]:
df_con.to_csv('df_con.csv',index_label=False)
print("Data saved as csv file")

Data saved as csv file


In [24]:
df_con_reload=pd.read_csv('df_con.csv',index_col=None)
df_con_reload

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,375,376,377,378,379,380,381,382,383,text
0,0.080532,0.091979,-0.047747,-0.036801,-0.012664,0.013002,-0.046419,0.006748,0.018803,-0.075588,...,-0.006229,-0.014597,-0.023434,0.060017,0.050776,0.022065,0.078658,-0.021233,0.054354,Which is the southernmost country in Europe?
40019,0.071083,0.019608,-0.024786,-0.040884,0.015264,0.077679,-0.053445,-0.026805,0.055678,-0.093409,...,-0.025151,0.027862,-0.041608,0.040765,0.043570,-0.013912,0.106731,-0.064934,0.041150,Southern Europe is a region of the European co...
53603,0.114832,0.050414,0.014039,-0.055407,0.019493,0.042618,-0.061437,0.004949,0.021172,-0.036829,...,0.025607,-0.055966,-0.066068,-0.009035,-0.035878,-0.032458,0.039817,-0.003605,0.071169,"Luxembourg (Dutch and German: ""Luxemburg"", Lux..."
74853,0.049773,0.118941,-0.076711,-0.023432,-0.057506,0.040908,-0.050808,0.004607,-0.023578,-0.032519,...,0.005898,-0.060594,-0.105599,0.025999,0.074264,-0.000057,0.069740,-0.094150,0.086590,Carinthia is the southernmost State of Austria...
40016,0.043966,0.025232,0.002572,-0.007988,0.046726,0.029939,-0.017251,-0.081140,0.047377,-0.065888,...,-0.042814,-0.000533,-0.009554,0.053028,0.055127,0.040038,0.023960,-0.078571,0.089749,Northern Europe is the northern part of the Eu...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
25851,0.020001,-0.049642,-0.073445,-0.097129,0.055834,-0.041047,-0.092540,-0.041759,0.018013,-0.024255,...,-0.016732,-0.037393,0.037056,-0.023069,-0.007736,0.019467,0.058769,-0.008285,0.041688,It is the closest of the farther non-spherical...
40017,0.097543,0.075366,0.005641,0.018544,0.020005,-0.030726,-0.021766,-0.017920,-0.009740,-0.091999,...,0.028593,0.023252,0.030047,0.038598,0.019085,-0.053103,0.043428,-0.032566,-0.001432,Western Europe is a geographic region of Europ...
22011,-0.019035,0.059168,-0.023327,-0.002730,-0.039367,0.079316,-0.012495,-0.023793,0.042873,-0.050651,...,-0.044493,0.011198,0.082887,-0.007053,0.077061,0.001842,0.029862,-0.004057,-0.021167,Europe is a Swedish rock band. The band was st...
57785,0.008028,0.076490,-0.017677,-0.010264,0.009262,-0.033786,-0.027153,-0.025191,0.002912,0.004834,...,0.037848,0.020643,0.022682,0.060635,0.051334,-0.081913,0.043740,-0.026263,-0.007285,Venizel is a commune. It is found in the regio...


In [25]:
!pip install umap-learn

Collecting umap-learn
  Downloading umap_learn-0.5.6-py3-none-any.whl.metadata (21 kB)
Collecting pynndescent>=0.5 (from umap-learn)
  Downloading pynndescent-0.5.13-py3-none-any.whl.metadata (6.8 kB)
Downloading umap_learn-0.5.6-py3-none-any.whl (85 kB)
Downloading pynndescent-0.5.13-py3-none-any.whl (56 kB)
Installing collected packages: pynndescent, umap-learn
Successfully installed pynndescent-0.5.13 umap-learn-0.5.6


In [27]:
import umap

# Reduce data from 384 dimensions to 3 dimensions
reduct_data=umap.UMAP(n_neighbors=5,min_dist=0.01,n_components=3)

In [28]:
df_con=pd.read_csv('df_con.csv',index_col=None)
print(df_con)

# Split the reloaded frame into the 384 embedding columns and the text column
df_con_384=df_con.iloc[:,:384]
df_con_text=df_con.iloc[:,384:385]

print(df_con_text)
text_ls=df_con_text['text'].values.tolist()

              0         1         2         3         4         5         6  \
0      0.080532  0.091979 -0.047747 -0.036801 -0.012664  0.013002 -0.046419   
40019  0.071083  0.019608 -0.024786 -0.040884  0.015264  0.077679 -0.053445   
53603  0.114832  0.050414  0.014039 -0.055407  0.019493  0.042618 -0.061437   
74853  0.049773  0.118941 -0.076711 -0.023432 -0.057506  0.040908 -0.050808   
40016  0.043966  0.025232  0.002572 -0.007988  0.046726  0.029939 -0.017251   
...         ...       ...       ...       ...       ...       ...       ...   
25851  0.020001 -0.049642 -0.073445 -0.097129  0.055834 -0.041047 -0.092540   
40017  0.097543  0.075366  0.005641  0.018544  0.020005 -0.030726 -0.021766   
22011 -0.019035  0.059168 -0.023327 -0.002730 -0.039367  0.079316 -0.012495   
57785  0.008028  0.076490 -0.017677 -0.010264  0.009262 -0.033786 -0.027153   
57827  0.061323  0.031976 -0.048341 -0.041653  0.039515  0.061644 -0.025313   

              7         8         9  ...       375 

In [30]:
embedding=reduct_data.fit_transform(df_con_384)
print(embedding.shape)
print(type(embedding))

(121, 3)
<class 'numpy.ndarray'>


In [31]:
# Wrap the reduced 3-D embeddings in a pandas DataFrame
df_embedding=pd.DataFrame(embedding)
df_embedding

Unnamed: 0,0,1,2
0,8.319954,7.160865,4.723001
1,7.363319,7.772455,4.667607
2,8.688131,6.901364,4.459262
3,9.231537,7.462975,5.494837
4,7.266113,7.939365,4.334959
...,...,...,...
116,8.623796,7.692831,4.519753
117,6.791823,7.780655,4.692522
118,7.563158,7.308461,4.856507
119,9.447074,2.162612,5.017587


In [36]:
# adding column containing text
df_embedding['text']=text_ls
df_embedding

Unnamed: 0,0,1,2,text
0,8.319954,7.160865,4.723001,Which is the southernmost country in Europe?
1,7.363319,7.772455,4.667607,Southern Europe is a region of the European co...
2,8.688131,6.901364,4.459262,"Luxembourg (Dutch and German: ""Luxemburg"", Lux..."
3,9.231537,7.462975,5.494837,Carinthia is the southernmost State of Austria...
4,7.266113,7.939365,4.334959,Northern Europe is the northern part of the Eu...
...,...,...,...,...
116,8.623796,7.692831,4.519753,It is the closest of the farther non-spherical...
117,6.791823,7.780655,4.692522,Western Europe is a geographic region of Europ...
118,7.563158,7.308461,4.856507,Europe is a Swedish rock band. The band was st...
119,9.447074,2.162612,5.017587,Venizel is a commune. It is found in the regio...


In [38]:
# renaming the columns
df_embedding.columns=['x','y','z','text']
df_embedding

Unnamed: 0,x,y,z,text
0,8.319954,7.160865,4.723001,Which is the southernmost country in Europe?
1,7.363319,7.772455,4.667607,Southern Europe is a region of the European co...
2,8.688131,6.901364,4.459262,"Luxembourg (Dutch and German: ""Luxemburg"", Lux..."
3,9.231537,7.462975,5.494837,Carinthia is the southernmost State of Austria...
4,7.266113,7.939365,4.334959,Northern Europe is the northern part of the Eu...
...,...,...,...,...
116,8.623796,7.692831,4.519753,It is the closest of the farther non-spherical...
117,6.791823,7.780655,4.692522,Western Europe is a geographic region of Europ...
118,7.563158,7.308461,4.856507,Europe is a Swedish rock band. The band was st...
119,9.447074,2.162612,5.017587,Venizel is a commune. It is found in the regio...


In [39]:
# Data Visualisation in 3D
!pip install plotly



In [44]:
import plotly.graph_objects as graph
data_1=graph.Scatter3d(x=df_embedding['x'],y=df_embedding['y'],z=df_embedding['z'],mode='markers',marker_color=df_embedding['x'],text=df_embedding['text'],name='all close embeddings')

# focus on query embedding
df_0=df_embedding[0:1]

data_2=graph.Scatter3d(x=df_0.x,y=df_0.y,z=df_0.z,marker=dict(size=200,color='yellow'),opacity=0.4,text=df_0.text.str[0:50],name='Query Sphere')

# focus on the embedding at row 4, which is treated here as the answer
df_4=df_embedding[4:5]

data_3=graph.Scatter3d(x=df_4.x,y=df_4.y,z=df_4.z,marker=dict(size=100,color='blue'),opacity=0.4,text=df_4.text.str[0:50],name='Answer Sphere')

data=[data_1,data_2,data_3]
fig=graph.Figure(data)
fig.update_layout(title='Q+A SBERT Bi-Encoder',height=1000)
fig.show()

##3.SBERT Cross-Encoder

In [45]:
# The SBERT bi-encoder retrieves the top_k (128) candidate passages
# The SBERT cross-encoder then re-ranks those candidates by relevance to the query

cross_encod=CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

def comp_search(query):
  print("Input question:", query)

  # Lexical Search
  bm25_scores=bm25.get_scores(bm25_tokenizer(query))
  top_n=np.argpartition(bm25_scores,-5)[-5:]
  bm25_preds=[{'corpus_id':idx,'score':bm25_scores[idx]} for idx in top_n]
  bm25_preds=sorted(bm25_preds,key=lambda x:x['score'],reverse=True)

  print("Top-3 lexical search predictions")
  for pred in bm25_preds[0:3]:
    print("\t{:.3f}\t{}".format(pred['score'],passages[pred['corpus_id']].replace("\n"," ")))

  # Semantic Search with Bi-encoder
  query_embedding=bi_encod.encode(query,convert_to_tensor=True)
  query_embedding_gpu=query_embedding.cuda()  ## speed optimisation

  # Function to perform cosine similarity search between a list of query embeddings and list of corpus embeddings
  preds=util.semantic_search(query_embedding_gpu,corpus_embeddings,top_k=top_k)
  preds=preds[0]

  # Reranking of limited set from bi-encoder SBERT
  cross_input=[[query,passages[pred['corpus_id']]] for pred in preds]
  cross_scores=cross_encod.predict(cross_input)

  # Attach the cross-encoder score to each bi-encoder hit (sorting happens below)
  for pred,cross_score in zip(preds,cross_scores):
    pred['cross-score']=cross_score

  # Output of top3 predictions from SBERT bi-encoder
  print("\n----------------------------------------------------------------\n")
  print("Top-3 Bi-encoder search predictions")
  preds=sorted(preds,key=lambda x:x['score'],reverse=True)
  for pred in preds[0:3]:
    print("\t{:.3f}\t{}".format(pred['score'],passages[pred['corpus_id']].replace("\n"," ")))

  #Output of top3 predictions from SBERT cross-encoder
  print("\n----------------------------------------------------------------\n")
  print("Top-3 Cross-encoder reranker search predictions")
  preds=sorted(preds,key=lambda x:x['cross-score'],reverse=True)
  for pred in preds[0:3]:
    print("\t{:.3f}\t{}".format(pred['cross-score'],passages[pred['corpus_id']].replace("\n"," ")))
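
The two-stage pattern used in `comp_search` (a cheap scorer narrows the corpus to `top_k` candidates, then a more expensive scorer re-orders only those candidates) can be sketched without loading any model. The scorers below, `word_count_score` and `overlap_ratio_score`, are hypothetical stand-ins invented for this illustration; they are not the bi-encoder or the cross-encoder, and the four passages are made up.

```python
from typing import Callable, Dict, List


def retrieve_and_rerank(
    query: str,
    passages: List[str],
    cheap_score: Callable[[str, str], float],
    precise_score: Callable[[str, str], float],
    top_k: int = 3,
) -> List[Dict]:
    """Stage 1: keep the top_k passages by cheap_score.
    Stage 2: re-sort those passages by precise_score."""
    hits = [{"corpus_id": i, "score": cheap_score(query, p)} for i, p in enumerate(passages)]
    hits = sorted(hits, key=lambda h: h["score"], reverse=True)[:top_k]
    for hit in hits:
        hit["cross-score"] = precise_score(query, passages[hit["corpus_id"]])
    return sorted(hits, key=lambda h: h["cross-score"], reverse=True)


# Hypothetical stand-in for the first stage: counts every occurrence of a query word.
def word_count_score(query: str, passage: str) -> float:
    query_words = set(query.lower().split())
    return float(sum(1 for word in passage.lower().split() if word in query_words))


# Hypothetical stand-in for the second stage: shared vocabulary relative to passage length.
def overlap_ratio_score(query: str, passage: str) -> float:
    query_words = set(query.lower().split())
    tokens = passage.lower().split()
    if not tokens:
        return 0.0
    return len(query_words & set(tokens)) / len(tokens)


toy_passages = [
    "paris is the capital of france",
    "the capital of a region of france is not the capital of france",
    "bananas are yellow",
    "france is in europe",
]

toy_hits = retrieve_and_rerank(
    "capital of france", toy_passages, word_count_score, overlap_ratio_score, top_k=3
)
for hit in toy_hits:
    print("{:.3f}\t{}".format(hit["cross-score"], toy_passages[hit["corpus_id"]]))
```

In this toy setup the first stage ranks the repetitive second passage highest, and the second stage moves the short, direct passage to the top. The cross-encoder plays the same role in `comp_search`, with a learned relevance score in place of the toy ratio.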




In [46]:
comp_search(query="Which is the southernmost country in Europe?")

Input question: Which is the southernmost country in Europe?
Top-3 lexical search predictions
	18.206	Cádiz is a province of southern Spain, in the southwestern part of the autonomous community of Andalusia. It is the southernmost part of mainland Spain, as well as the southernmost part of continental Europe.
	16.304	The Valle di Muggio is located in Ticino in Switzerland and is the southernmost valley of the country.
	14.502	Spain is a country in Europe.

----------------------------------------------------------------

Top-3 Bi-encoder search predictions
	0.649	Southern Europe is a region of the European continent. The officialise définition of southern europ is iberic peninsula, italian peninsula and balkanic peninsula. Spain, Portugal, Italy and Greece and over all the Mediterranean countries of the European continent as parts of Southern Europe.
	0.607	Luxembourg (Dutch and German: "Luxemburg", Luxembourgish: Lëtzebuerg, Walloon: "Lussimbork") is the southernmost province of Belgi
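
The lexical and bi-encoder scores above are on different scales. BM25 scores are unbounded sums of term weights, while `util.semantic_search` ranks by cosine similarity by default, which always lies between -1 and 1. A minimal pure-Python sketch of cosine similarity follows; the 3-dimensional vectors are made up for illustration and are not real embeddings.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Made-up vectors, chosen only to show the two extremes
query_vec = [1.0, 2.0, 0.0]
same_direction_vec = [2.0, 4.0, 0.0]   # points the same way as query_vec
orthogonal_vec = [-2.0, 1.0, 0.0]      # at a right angle to query_vec

print(cosine_similarity(query_vec, same_direction_vec))
print(cosine_similarity(query_vec, orthogonal_vec))
```

Because of this difference in scale, the BM25 and bi-encoder scores printed by `comp_search` are only meaningful within their own list, not against each other.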

In [47]:
comp_search(query="What is the capital of USA?")

Input question: What is the capital of USA?
Top-3 lexical search predictions
	10.602	Usa grew up around Usa Shrine which was established in the 8th century.
	9.881	Miss California USA is a pageant in California, the winner of which is nominated to that year´s Miss USA pageant.
	8.991	Miss Teen USA is an American pageant. It is like the Miss USA pageant, except it is for teenagers (people between the ages of 13 and 19). It was founded on August 30, 1983. The curren Miss Teen USA is Kaliegh Garris, of Connecticut. She was the second Connecticuter to win the pageant.

----------------------------------------------------------------

Top-3 Bi-encoder search predictions
	0.596	Cities in the United States:
	0.574	In the United States:
	0.571	United States of America;

----------------------------------------------------------------

Top-3 Cross-encoder reranker search predictions
	7.298	Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is 
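
One plausible reading of the lexical result above: if the tokenizer lowercases its input, the query term "USA" and the Japanese town "Usa" become the same token, so BM25 rewards a passage that has nothing to do with the United States. The sketch below uses a hypothetical `simple_tokenizer` (lowercase, strip surrounding punctuation); it is an assumption for illustration and may differ from the notebook's own `bm25_tokenizer`.

```python
import string


def simple_tokenizer(text):
    """Hypothetical lexical tokenizer: lowercase, then strip punctuation around each token."""
    tokens = []
    for token in text.lower().split():
        token = token.strip(string.punctuation)
        if token:
            tokens.append(token)
    return tokens


query_tokens = simple_tokenizer("What is the capital of USA?")
passage_tokens = simple_tokenizer(
    "Usa grew up around Usa Shrine which was established in the 8th century."
)

# Tokens that a purely lexical scorer can match on
shared_tokens = set(query_tokens) & set(passage_tokens)
print(sorted(shared_tokens))
```

The bi-encoder and cross-encoder work on the meaning of the whole sentence rather than on exact token matches, which is consistent with the cross-encoder surfacing the Washington, D.C. passage for the same query.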

In [48]:
comp_search(query="Name best band of Pakistan?")

Input question: Name best band of Pakistan?
Top-3 lexical search predictions
	10.667	The Bravery was an American rock band. They were from New York City. The band is best known for their song "An Honest Mistake".
	10.267	Blind Melon are an American alternative rock band from Los Angeles, California. The band is best known for their 1993 song "No Rain".
	9.896	Frankie Goes to Hollywood were a British dance band from Liverpool, England. The band is probably best known for their first single "Relax". This BBC banned the song.

----------------------------------------------------------------

Top-3 Bi-encoder search predictions
	0.565	Pakistani rock is a form of rock music from Pakistan. It mixes up Pakistani classical music with American rock. It has its own vibrant elements, with slightly different sounds and tunes. Pakistani rock is mostly sung in Urdu, however large numbers of songs are also sung in languages such as Punjabi, Sindhi, Pashto and Seraiki. Many new music bands are also no

In [49]:
comp_search(query="Which is the toughest Exam in the world?")

Input question: Which is the toughest Exam in the world?
Top-3 lexical search predictions
	15.224	ÖSS Turkish Öğrenci Seçme Sınavı is the exam for entering universities in Turkey. Every year in June, the ÖSS exam is done. OSS is a test exam which has many questions from different subjects such as Maths, Geography, and History.
	13.613	The National League East is a division in the MLB. All of the teams are on the east coast of the U.S. In 2007, 2008, and 2009, the Phillies won the division; they also won the 2008 World Series. The NL East is known as one of the toughest divisions in the MLB.
	13.162	A midterm exam is a test given around the middle of a term in schools.

----------------------------------------------------------------

Top-3 Bi-encoder search predictions
	0.469	An examination (exam) is a test. Many things may be examined, but the word is most often used for an assessment of a person. It measures a test-taker's knowledge, skill, aptitude, physical fitness, or ability or s

In [50]:
comp_search("Strongest empires of all time")

Input question: Strongest empires of all time
Top-3 lexical search predictions
	13.875	The Spanish Empire also known as "Spanish Monarchy" was one of the largest empires in history and became one of the first global empires in world history
	12.765	The Hellenistic period in Ancient Greece (323–146 BC) was the time period between the death of Alexander the Great when the generals of Alexander created their own empires and the Roman conquest of mainland Greece.
	12.633	Age of Empires III is a 2005 computer game made by the company Ensemble Studios. It is published by Microsoft. It is the third game of the Age of Empires games, and has better graphics than the ones before it. It is a real-time strategy game. The plot is from 1500 to 1860. An expansion pack, "Age of Empires III: The WarChiefs", was released for the game on October 17, 2006. The second expansion, "Age of Empires III: The Asian Dynasties", was released on October 23, 2007. "Age of Empires III: Definitive Edition was announce

In [51]:
comp_search("What are the features of Capybaras?")

Input question: What are the features of Capybaras?
Top-3 lexical search predictions
	15.539	Caviidae are a group of rodents that live in South America. Some Caviidae are guinea pigs and capybaras.
	8.622	Selenography is the study of the physical features of the Moon.
	8.562	Freemium is a business model, and a marketing term, used for digital products. It means that the base features of a product can be used without payment, but that more advanced features require payment.

----------------------------------------------------------------

Top-3 Bi-encoder search predictions
	0.563	Capybara ("Hydrochoerus hydrochaeris") is a semi-aquatic rodent of South America. It weighs about a hundred pounds, and is about two feet tall at the shoulder. The capybara is the world's largest rodent.
	0.448	The clypeus is one of the hard parts that makes up the face of an insect. The clypeus is often well-defined by grooves along its horizontal and vertical margins, and is most commonly rectangular in ove