
# Setup

## Install packages

In [None]:
!pip install -U sentence-transformers rank_bm25 faiss-gpu datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentence-transformers
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
Collecting rank_bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
Collecting datasets
  Downloading datasets-2.6.1-py3-none-any.whl (441 kB)
Collecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
Collecting sentencepiece
  Downloading sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)

## Mount your Google Drive to save data

In [None]:
from google.colab import drive
from pathlib import Path
import os
drive.mount('/content/drive')





Mounted at /content/drive


In [None]:
!mkdir -p drive/MyDrive/ai_agents/hw3
!mkdir -p drive/MyDrive/ai_agents/hw3/.cache

In [None]:
os.chdir("drive/MyDrive/ai_agents/hw3")

## Download a small corpus of Wikipedia articles and split it into snippets

We will use the Simple English Wikipedia corpus that the SentenceTransformers author uses in his examples. This cell constructs a list, `passages`, containing one dict per paragraph with `title` and `passage` keys.

In [None]:
import json
import gzip
import os
import torch
from sentence_transformers import util  # provides http_get for the download below

if not torch.cuda.is_available():
    print("Warning: No GPU found. Please add a GPU to your notebook")


wikipedia_filepath = 'simplewiki-2020-11-01.jsonl.gz'

# Download the corpus only if it is not already cached on Drive
if not os.path.exists(wikipedia_filepath):
    util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz', wikipedia_filepath)


passages = []
with gzip.open(wikipedia_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        data = json.loads(line.strip())
        for paragraph in data['paragraphs']:
            # Store each paragraph as a dict with `title` and `passage` keys
            passages.append(dict(title=data['title'], passage=paragraph))

In [None]:
from datasets import Dataset

passages = Dataset.from_list(passages)
type(passages)

datasets.arrow_dataset.Dataset

# Problem 3

## **3.1**: Build a BM25 Search Index

Construct a search index that, given a query span, returns a top-n list of support passages.
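To make the scoring behind such an index concrete, here is a minimal stdlib-only sketch of Okapi BM25 over a toy corpus. It is not the `rank_bm25` implementation: the function names, the whitespace tokenizer, the smoothed IDF variant, and the defaults `k1=1.5`, `b=0.75` are assumptions chosen for illustration.

```python
import math
from collections import Counter

def bm25_scores(query, docs, k1=1.5, b=0.75):
    """Score every document in `docs` against `query` with Okapi BM25."""
    # Assumption: lowercased whitespace tokenization
    tokenized = [doc.lower().split() for doc in docs]
    n_docs = len(tokenized)
    avg_len = sum(len(toks) for toks in tokenized) / n_docs
    # Document frequency: number of documents containing each term
    doc_freq = Counter(term for toks in tokenized for term in set(toks))

    scores = []
    for toks in tokenized:
        term_freq = Counter(toks)
        score = 0.0
        for term in query.lower().split():
            if term not in term_freq:
                continue
            # Smoothed IDF; the "1 +" keeps it positive for very common terms
            idf = math.log(1 + (n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5))
            tf = term_freq[term]
            # Length normalization: long documents need more matches to score high
            norm = tf + k1 * (1 - b + b * len(toks) / avg_len)
            score += idf * tf * (k1 + 1) / norm
        scores.append(score)
    return scores

def top_n(query, docs, n=5):
    """Return the n best (document, score) pairs, highest score first."""
    scores = bm25_scores(query, docs)
    ranked = sorted(range(len(docs)), key=lambda i: scores[i], reverse=True)
    return [(docs[i], scores[i]) for i in ranked[:n]]

toy_docs = [
    "Birds fly in a V formation to save energy",
    "Blue is the color of the sky",
    "Aircraft sometimes fly in formation",
]
best = top_n("why do birds fly in a v formation", toy_docs, n=2)
```

On this toy corpus the first document shares the most query terms, so it ranks first; the second document shares none and scores zero.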

In [None]:
from typing import List, Tuple
import datasets

# Base class to use for both sparse and dense retrieval

class RetrievalIndex:

  def __init__(self, corpus: datasets.arrow_dataset.Dataset):
    self.corpus = corpus
    self.test_corpus = []

  def __getitem__(self, item):
    # Return the corpus row at integer position `item` as a dict
    return self.corpus[item]

  def preprocess_text(self):
    # Join title and passage into one string per row; reading whole columns
    # is much faster than indexing the dataset row by row
    self.test_corpus = [
        title + "[SEP]" + passage
        for title, passage in zip(self.corpus["title"], self.corpus["passage"])
    ]

  @classmethod
  def build_index(cls, corpus, **kwargs):
    """
    Class method that constructs a retrieval index from the corpus
    """
    return cls(corpus, **kwargs)

  def lookup(self, query_strs: List[str], topk = 5) -> List[List[Tuple[str, str, float]]]:
    """
    Accepts a list of query strings and returns a list of lists of (title, passage, score) tuples
    """
    raise NotImplementedError()

In [None]:
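from rank_bm25 import BM25Okapi
import numpy as np

class BM25RetrievalIndex(RetrievalIndex):
  def __init__(self, corpus):
    """
    Tokenize the corpus and initialize the BM25 index.
    Follows the simple usage shown on the library's GitHub page: https://github.com/dorianbrown/rank_bm25
    """
    super().__init__(corpus)
    # Assumption: lowercased whitespace tokenization of "title passage" is good enough here
    self.tokenized_corpus = [
        (title + " " + passage).lower().split()
        for title, passage in zip(corpus["title"], corpus["passage"])
    ]
    self.bm25 = BM25Okapi(self.tokenized_corpus)

  def lookup(self, query_strs, topk = 5) -> List[List[Tuple[str, str, float]]]:
    """
    Retrieve document scores from the BM25 index for each of a list of queries.
    Each returned list is sorted by document score, highest first.
    """
    results = []
    for query in query_strs:
      scores = self.bm25.get_scores(query.lower().split())
      top_ids = np.argsort(scores)[::-1][:topk]  # indices of the best-scoring passages
      results.append([
          (self.corpus[int(i)]["title"], self.corpus[int(i)]["passage"], float(scores[i]))
          for i in top_ids
      ])
    return results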

    



In [None]:
bm25_index = BM25RetrievalIndex.build_index(passages)

In [None]:
bm25_index.lookup(["why do birds fly in a v formation?"])

## **3.2**: Building a Dense Retrieval Index
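
As a rough mental model of what a flat (exact) dense index computes, the sketch below ranks passages by the inner product between a query vector and each passage vector. The vectors are made-up toy values and `dense_top_k` is a hypothetical helper; in the notebook the embeddings come from a SentenceTransformer encoder and the search is delegated to FAISS.

```python
import heapq

def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def dense_top_k(query_vec, passage_vecs, k=5):
    """Return (passage_id, score) pairs for the k largest inner products."""
    scored = ((dot(query_vec, vec), idx) for idx, vec in enumerate(passage_vecs))
    # nlargest keeps only the k best pairs, ordered from highest score down
    return [(idx, score) for score, idx in heapq.nlargest(k, scored)]

# Toy 3-dimensional "embeddings" (assumed values, for illustration only)
toy_passage_vecs = [[0.9, 0.1, 0.0], [0.0, 1.0, 0.0], [0.7, 0.0, 0.7]]
toy_query_vec = [1.0, 0.0, 0.2]
hits = dense_top_k(toy_query_vec, toy_passage_vecs, k=2)
```

Brute force like this costs one inner product per passage per query, which is why a library such as FAISS is used once the corpus reaches hundreds of thousands of passages.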

In [None]:
import faiss
from sentence_transformers import SentenceTransformer, CrossEncoder, util


class DenseRetrievalIndex(RetrievalIndex):
  def __init__(self, corpus: datasets.arrow_dataset.Dataset, precomputed_index: str = None):
    """
    Compute the embeddings for each passage in the wiki corpus, then feed them
    to the FAISS helpers built into HuggingFace's Dataset class
    https://huggingface.co/docs/datasets/v1.2.1/faiss_and_ea.html

    (Optional but recommended) if the filepath argument `precomputed_index` is not None,
    then this does not compute the embeddings but calls load_faiss_index on the path
    """
    super().__init__(corpus)
    self.encoder = SentenceTransformer('msmarco-MiniLM-L-6-v3')
    if precomputed_index is None:
      self.preprocess_text()
      embeddings = self.encoder.encode(self.test_corpus, show_progress_bar=True)
      # The inner-product metric ranks passages by dot product with the query
      self.corpus.add_faiss_index_from_external_arrays(
          external_arrays=embeddings,
          index_name="embeddings",
          metric_type=faiss.METRIC_INNER_PRODUCT,
      )
    else:
      # Reuse a saved index instead of re-encoding the whole corpus
      self.corpus.load_faiss_index("embeddings", precomputed_index)

  def save(self, file):
    """
    (Optional but recommended) helper that saves the index to a file using `save_faiss_index`
    """
    self.corpus.save_faiss_index("embeddings", file)

  def lookup(self, query_strs, topk=5):
    """
    Returns one list of hits per query, best first. Each hit is a dict with
    `corpus_id` (row number in the corpus) and `score` (inner product).
    """
    query_embeddings = self.encoder.encode(query_strs)
    found = self.corpus.search_batch("embeddings", query_embeddings, k=topk)
    return [
        [dict(corpus_id=int(idx), score=float(score)) for score, idx in zip(scores, indices)]
        for scores, indices in zip(found.total_scores, found.total_indices)
    ]

  def format_output(self, result):
    # Map the hits of a single query back to their corpus rows
    return [self.corpus[hit['corpus_id']] for hit in result]

  @classmethod
  def from_file(cls, corpus, file):
      """
      (Optional but recommended) helper that loads the index from the specified filepath
      """
      assert os.path.exists(file)
      return cls(corpus=corpus, precomputed_index=file)


In [None]:
dense_index = DenseRetrievalIndex.build_index(passages)

## Comment out the line above and initialize this way instead if you have already computed and saved the index
#dense_index = DenseRetrievalIndex.from_file(passages, 'msmarco_sbert_final.faiss')

In [None]:
dense_index.save("msmarco_sbert_final.faiss")


In [None]:
answers = dense_index.lookup(["why do flocks of birds fly in a v formation?"], topk=5)
print(dense_index.format_output(answers[0]))

[{'title': 'Flock', 'passage': 'Upon leaving beta, Flock has won a number of awards:'}, {'title': 'Bird flu', 'passage': 'Bird flu (also called avian influenza, avian flu, bird influenza, or grippe of the birds), is an illness caused by a virus. The virus, called "influenza A" or "type A", usually lives in birds, but sometimes infects mammals, including humans. It is called influenza when it infects humans.'}, {'title': 'Bird', 'passage': 'Most birds can fly. They do this by pushing through the air with their wings. The curved surfaces of the wings cause air currents (wind) which lift the bird. Flapping keeps the air current moving to create lift and also moves the bird forward.'}, {'title': 'Air force', 'passage': 'Aircraft in an air force sometimes fly in a formation. Formations are when the aircraft fly in a pattern. Air forces is part of the military.'}, {'title': 'Vulnerable species', 'passage': 'Vulnerable species of birds include:'}]

## **3.3**: Using a Reranking Cross Encoder

In [None]:
from sentence_transformers import CrossEncoder
import numpy as np

class RerankingDenseRetrievalIndex(DenseRetrievalIndex):
  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

  def lookup(self, query_strs, topk=5, initial_topk=50):
    """
    Retrieve `initial_topk` candidates as in the `DenseRetrievalIndex` class, then
    rerank them according to the scores of `self.cross_encoder`.
    Returns one list per query holding the `topk` best passages, best first.
    """
    reranked = []
    dense_hits = super().lookup(query_strs, topk=initial_topk)
    for query, hits in zip(query_strs, dense_hits):
      candidates = self.format_output(hits)
      scores = self.cross_encoder.predict(
          [(query, c["title"] + " " + c["passage"]) for c in candidates])
      best = np.argsort(scores)[::-1][:topk]  # highest cross-encoder score first
      reranked.append([candidates[i] for i in best])
    return reranked
  

   


In [None]:
## you should not need to recompute the embeddings or index if you implemented the 
## recommended helper functions
ranking_index = RerankingDenseRetrievalIndex.from_file(passages, 'msmarco_sbert_final.faiss')
ranking_index.lookup(["why do flocks of birds fly in a v formation?"],topk=5)

10
[ -9.955448    -9.862479    -2.3252575   -5.529788   -10.541218
 -10.869417   -10.516525    -0.22224829  -9.624777    -4.953411  ]


[{'title': 'Flying squirrel',
  'passage': 'There are reasons which may explain why gliding has evolved in mammals:'},
 {'title': 'Air force',
  'passage': 'Aircraft in an air force sometimes fly in a formation. Formations are when the aircraft fly in a pattern. Air forces is part of the military.'},
 {'title': 'Collective animal behaviour',
  'passage': '5. Special factors come into play with migrating birds, or birds which gather in huge flocks, such as starlings. Bird behavior has a larger component of learning than fish. In addition to factors listed above is the possibility that migrating bird flocks are good at teaching first-year birds how to do the migration successfully. The specific routes may be genetically programmed or learned to varying degrees. The routes taken on forward and return migration are often different.'},
 {'title': 'Bird',
  'passage': 'If a flock of birds were flying over a field, they would be calling "Fly! Fly!" But a hungry bird, seeing something good to 

In [None]:
ranking_index.lookup(["why is the sky blue?", "why do flocks of birds fly in a v formation?"], topk=5)

10
[  4.652691    -0.60267156   9.079623    -2.9247348    6.4590783
   1.6481416   -5.7430983  -10.83732     -5.2103176   -4.330052  ]


[{'title': 'Blue', 'passage': 'Blue is a color of the Jewish religion.'},
 {'title': 'Blue',
  'passage': 'Blue is one of the colors of the rainbow that people can see. It is one of the seven colors of the rainbow along with red, orange, yellow, green, indigo and violet. Apart from indigo and violet, it has the shortest wavelength of these colors (about 470 nanometers).'},
 {'title': 'Blue',
  'passage': "Blue is the color of the Earth's sky and sea. Earth looks blue when seen from outer space by astronauts."},
 {'title': 'Sky',
  'passage': 'The sky, which is made up of gas molecules, is blue because of the random scattering of sunlight by the molecules. Rayleigh scattering defines the amount of scattering of light rays. Blue light scatters much more than red, which is why the sky appears blue on a clear day. Depending on the time of day, the sky may appear different colors. At dawn or dusk the sky may appear red, orange, or even green and purple depending on how low the sun is and ho
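
The retrieve-then-rerank flow of 3.3 can be sketched without any model: a cheap first stage proposes candidates, and a more expensive scorer reorders them. In the sketch below `word_overlap` is a hypothetical stand-in for a cross-encoder score, and the candidate passages are toy strings, so it illustrates only the control flow, not the quality of a real reranker.

```python
def rerank(query, candidates, score_fn, topk=5):
    """Sort first-stage candidates by score_fn(query, candidate), best first."""
    scored = [(score_fn(query, cand), cand) for cand in candidates]
    # A full sort is fine here because the candidate list is small
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [cand for _, cand in scored[:topk]]

def word_overlap(query, passage):
    # Hypothetical scorer: number of distinct words shared by query and passage
    return len(set(query.lower().split()) & set(passage.lower().split()))

toy_candidates = [
    "Blue is a color",
    "The sky is blue because sunlight is scattered",
    "Birds fly south in winter",
]
top = rerank("why is the sky blue", toy_candidates, word_overlap, topk=2)
```

Sorting by score and slicing the first `topk` entries guarantees the result is ordered, which a partial selection such as `np.argpartition` does not do on its own.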

# Problem 4

In [None]:
from transformers import AutoTokenizer, AutoModel

qar_tokenizer = AutoTokenizer.from_pretrained('yjernite/retribert-base-uncased')
qar_model = AutoModel.from_pretrained('yjernite/retribert-base-uncased')



In [None]:
# TODO Write an eval loop that retrieves top-1 documents for each ELI5 dev question
# then feeds them (concatenated to the question) to the qar_model wrapped in a `pipeline`. 
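
One possible shape for the input-building half of that loop is sketched below. Everything here is an assumption for illustration: `build_reader_inputs` and `toy_retrieve_top1` are hypothetical names, the `" [SEP] "` separator is a guess at a reasonable joining convention, and the toy retriever stands in for a lookup against one of the Problem 3 indexes. The model call itself is left out.

```python
def build_reader_inputs(questions, retrieve_top1, sep=" [SEP] "):
    """Concatenate each question with its top-1 retrieved passage."""
    return [question + sep + retrieve_top1(question) for question in questions]

def toy_retrieve_top1(question):
    # Hypothetical retriever; a real one would query an index from Problem 3
    if "sky" in question.lower():
        return "The sky is blue because of Rayleigh scattering."
    return "Most birds can fly."

toy_questions = ["Why is the sky blue?", "How do birds fly?"]
reader_inputs = build_reader_inputs(toy_questions, toy_retrieve_top1)
```

Keeping the retriever as a plain callable makes it easy to swap the BM25, dense, and reranking indexes in and out when comparing them.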

# Problem 5

In [None]:
## TODO feed your 5 questions from the end of HW1 through the pipeline and perform
## qualitative analysis