<a href="https://colab.research.google.com/github/algodigger/playground/blob/main/notebook/reranking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Reranking

#### Let's check how the reranking model works in the following use case
0. Set up everything and initialize the [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding/tree/master/FlagEmbedding/reranker) model for reranking
1. Upload a PDF document (PDF for simplicity)
2. Split it into chunks and convert them to OpenAI embeddings
3. Store the embeddings in ChromaDB
4. Query the document
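
As a rough illustration of the splitting part of step 2, here is a minimal sketch of a fixed-size character splitter with overlap. The function name `split_into_chunks` and its `chunk_size` and `overlap` parameters are hypothetical and chosen only for this example; it is a stand-in under simple assumptions, not the splitter this notebook or any library actually uses.

```python
from typing import List


def split_into_chunks(text: str, chunk_size: int = 200, overlap: int = 50) -> List[str]:
    """Split text into character windows of chunk_size that overlap by `overlap` characters.

    Hypothetical helper for illustration only.
    """
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError("need chunk_size > 0 and 0 <= overlap < chunk_size")

    step = chunk_size - overlap  # how far the window advances each time
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        # Stop once the current window already reaches the end of the text.
        if start + chunk_size >= len(text):
            break
    return chunks


sample = "abcdefghij"
print(split_into_chunks(sample, chunk_size=4, overlap=1))
```

The overlap keeps a little shared context between neighbouring chunks, so a sentence cut at a chunk boundary still appears whole in at least one of them when the overlap is large enough.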

In [None]:
%pip install langchain chromadb transformers

#### Sanity check:
Test the reranker on a simple list of question-answer pairs using the example from the repo

In [11]:
import torch
from torch.nn.functional import softmax
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-large')
model = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-large')
model.eval()

# Each pair is [query, candidate passage]; the model scores every pair.
pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
    # One raw relevance logit per pair, flattened to a 1-D tensor.
    logits = model(**inputs, return_dict=True).logits.view(-1)
    # Softmax over dim=0 normalizes across the pairs, so the values are relative to each other.
    probabilities = softmax(logits, dim=0).tolist()
    print(f"Scores: {logits.float()}")
    print(f"Probabilities: {probabilities}")


Scores: tensor([-5.6085,  5.7623])
Probabilities: [1.1526946764206514e-05, 0.9999884366989136]
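
The cell above applies softmax across the pairs (`dim=0`), so each probability is relative to the other candidates in the same call and the values always sum to 1. The sketch below reproduces that step in plain Python from the rounded scores printed above and, as an alternative that is assumed here rather than taken from this notebook, also applies a sigmoid to each score on its own, which gives a per-pair value in [0, 1] that does not depend on how many candidates were scored. The helper names `softmax_list` and `sigmoid` are hypothetical.

```python
import math


def softmax_list(scores):
    """Softmax over a plain list of floats, shifted by the max for numerical stability."""
    highest = max(scores)
    exps = [math.exp(s - highest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(score):
    """Map a single logit to (0, 1) independently of any other score."""
    return 1.0 / (1.0 + math.exp(-score))


# Rounded logits copied from the printed "Scores" line above.
logits = [-5.6085, 5.7623]

relative = softmax_list(logits)             # depends on the whole candidate set
independent = [sigmoid(s) for s in logits]  # each pair is judged on its own

print(f"Softmax across pairs: {relative}")
print(f"Sigmoid per pair:     {independent}")
```

With softmax, adding or removing a candidate changes every probability, which matters when comparing scores across queries with different numbers of candidates.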


**Observations**: for simple ranking like this, we should probably consider a smaller reranker to make the setup more realistic for a production environment

Evaluate inference time using examples generated by ChatGPT

In [15]:
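import time

def evaluate_time(pairs, model=model, tokenizer=tokenizer):
    """Score the given [query, passage] pairs and print the wall-clock inference time."""
    # perf_counter is a monotonic clock intended for measuring short durations.
    start_time = time.perf_counter()
    with torch.no_grad():
        inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
        logits = model(**inputs, return_dict=True).logits.view(-1)
        # Softmax across the pairs: probabilities are relative within this batch.
        probabilities = softmax(logits, dim=0).tolist()
    inference_time = time.perf_counter() - start_time

    print(f"Inference time: {inference_time} seconds")
    print(f"Probabilities: {probabilities}")
    # Index of the highest-scoring pair in the input list.
    print(f"Best pair is: {probabilities.index(max(probabilities))}")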

pairs2 = [['what is panda?', 'hi'],
          ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]

pairs4 = [
    ['tell me about cars', 'Cars have wheels.'],
    ['tell me about cars', 'Cars are used for transportation on roads.'],
    ['tell me about cars', 'Cars come in various shapes, sizes, and colors, and are manufactured by different companies around the world.'],
    ['tell me about cars', 'The car industry is constantly evolving with advancements in technology such as electric and self-driving cars.']
]

pairs6 = [
    ['tell me about math', 'Math is the study of numbers and quantities.'],
    ['tell me about math', 'Math involves operations like addition, subtraction, multiplication, and division.'],
    ['tell me about math', 'Math is used to solve problems and understand patterns in the world.'],
    ['tell me about math', 'Math is a fundamental part of science and technology.'],
    ['tell me about math', 'Advanced math includes topics like calculus, algebra, and geometry.'],
    ['tell me about math', 'Math plays a crucial role in fields such as physics, engineering, and computer science.']
]

evaluate_time(pairs2)
evaluate_time(pairs4)
evaluate_time(pairs6)

Inference time: 1.2678768634796143 seconds
Probabilities: [1.1526946764206514e-05, 0.9999884366989136]
Best pair is: 1
Inference time: 2.345053195953369 seconds
Probabilities: [0.9752643704414368, 0.017187338322401047, 0.007121056783944368, 0.0004271531943231821]
Best pair is: 0
Inference time: 2.7637195587158203 seconds
Probabilities: [0.18119004368782043, 0.08866485208272934, 0.16214492917060852, 0.43197494745254517, 0.046347152441740036, 0.08967802673578262]
Best pair is: 3
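
`evaluate_time` reports only the index of the best pair, while reranking usually needs the full ordering of the candidates. The sketch below sorts the six math passages by the probabilities printed above (rounded to four decimals). The helper name `rank_candidates` is hypothetical and the scores are hard-coded from that output, so this illustrates the sorting step only and does not call the model.

```python
def rank_candidates(pairs, scores):
    """Return (original_index, score, passage) tuples ordered from most to least relevant."""
    if len(pairs) != len(scores):
        raise ValueError("pairs and scores must have the same length")
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(i, scores[i], pairs[i][1]) for i in order]


pairs6 = [
    ['tell me about math', 'Math is the study of numbers and quantities.'],
    ['tell me about math', 'Math involves operations like addition, subtraction, multiplication, and division.'],
    ['tell me about math', 'Math is used to solve problems and understand patterns in the world.'],
    ['tell me about math', 'Math is a fundamental part of science and technology.'],
    ['tell me about math', 'Advanced math includes topics like calculus, algebra, and geometry.'],
    ['tell me about math', 'Math plays a crucial role in fields such as physics, engineering, and computer science.'],
]

# Probabilities from the last output above, rounded to four decimals.
scores6 = [0.1812, 0.0887, 0.1621, 0.4320, 0.0463, 0.0897]

ranked = rank_candidates(pairs6, scores6)
for index, score, passage in ranked:
    print(f"{score:.4f}  [{index}]  {passage}")
```

In a retrieval pipeline, the top few entries of such a ranking would typically be the ones passed on as context, with the cut-off chosen per use case.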
