In [None]:
!pip install datasets

In [None]:
!pip install sentence_transformers

In [107]:
import random
import pandas as pd
import numpy as np
import re
from tabulate import tabulate

import nltk
nltk.download('punkt_tab')
from nltk.tokenize import sent_tokenize

from transformers import pipeline
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util, CrossEncoder

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


## Load MS_MARCO Dataset

In [64]:
dataset = load_dataset("ms_marco",  "v1.1", trust_remote_code=True)

In [None]:
train_data = dataset['train'][:100]

In [None]:
print(train_data.keys())

dict_keys(['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'])


In [None]:
queries = train_data['query']
passages = [entry['passage_text'] for entry in train_data['passages']]
answers = train_data['answers']

### Preview Data

In [115]:
pd.set_option('display.max_colwidth', 100)
pd.set_option('display.colheader_justify', 'left')

df = pd.DataFrame({
    'query': queries,
    'passage': passages,
    'answer': answers,
})

print(tabulate(df.head(), headers='keys', tablefmt='fancy_grid'))

╒════╤═══════════════════════════════════════════════════════╤══════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════

## Query Preparation

In [88]:
def clean_query(query):
    # Remove special characters and digits, keep only letters and spaces
    cleaned_query = re.sub(r'[^a-zA-Z\s]', '', query)

    cleaned_query = cleaned_query.lower()

    return cleaned_query

## Answer Preparation

### Embedder

In [None]:
model_name = 'all-mpnet-base-v2'
embedder = SentenceTransformer(model_name)

##### Embed all passages

In [60]:
passage_embeddings = embedder.encode(passages, convert_to_tensor=True)

### Function to Extract the 3 most relevant sentences from the passage

In [None]:
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

In [81]:
def extract_relevant_sentence(query, passage, top_k=3):
    # Split passage into sentences
    sentences = sent_tokenize(passage)

    # Create sentence-query pairs
    sentence_query_pairs = [(query, sentence) for sentence in sentences]

    # Rank sentences using the Cross Encoder
    scores = cross_encoder.predict(sentence_query_pairs)

    # Get indices of sentences with top-k highest relevance (based on cross-encoder scores)
    top_sentence_indices = np.argsort(scores)[::-1][:top_k]

    # Get the most relevant sentences based on the top indices
    relevant_sentences = [sentences[i] for i in top_sentence_indices]

    return " ".join(relevant_sentences)

### Set up Summarizer

In [None]:
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

config.json:   0%|          | 0.00/1.58k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Device set to use cpu


In [None]:
def summarize_text(text, max_length, min_length):
    summary = summarizer(text, max_length=max_length, min_length=min_length, do_sample=False) # Deterministic not Random Output
    return summary[0]['summary_text']

## Testing the Model

### Query based on the queries that came with the dataset

In [92]:
queries_to_use = random.sample(queries, 5)
print(queries_to_use)

['what is the salary of a person with a biology degree', 'what are monocytes', 'what is slime', 'temperature of neptune in fahrenheit', 'gayla name meaning']


In [93]:
preview = []
queries_embeddings = embedder.encode(queries_to_use, convert_to_tensor=True)
hit = util.semantic_search(queries_embeddings, passage_embeddings, top_k=1)

for i in range(len(queries_to_use)):
    query = clean_query(queries_to_use[i])

    # Get most relevant passage
    result = hit[i][0]
    corpus_id = result['corpus_id']  # Use corpus_id to find the passage
    passage = ' '.join(passages[corpus_id])

    # Extract the 3 most relevant sentences from the passage
    best_sentences = extract_relevant_sentence(query, passage, top_k=3)

    # Summarize the best sentences
    summary = summarizer(best_sentences, max_length=50, min_length=10)

    # Store the preview data for this query
    preview.append({
        "query": query,
        "best_sentences": best_sentences,
        "summary": summary,
        "query_passage_similarity_score": result["score"]
    })

Your max_length is set to 50, but your input_length is only 45. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=22)


In [94]:
# Remove key 'summary_text'
for entry in preview:
    entry['summary'] = entry['summary'][0]['summary_text']

In [116]:
pd.set_option('display.max_colwidth', None)
pd.set_option('display.colheader_justify', 'left')

df_preview = pd.DataFrame(preview)

print(tabulate(df_preview, headers='keys', tablefmt='fancy_grid'))

╒════╤══════════════════════════════════════════════════════╤═════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╤════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╤══════════════════════════════════╕
│    │ query                                                │ best_sentences                                                                                                

### Input your own query

In [96]:
my_query = clean_query(input("My Question: "))
my_query_embedding = embedder.encode([my_query], convert_to_tensor=True)
my_hits = util.semantic_search(my_query_embedding, passage_embeddings, top_k=1)

my_preview = []

for result in my_hits[0]: # Most relevant passage
    corpus_id = result["corpus_id"]
    passage = " ".join(passages[corpus_id])

    # Extract the 3 most relevant sentences
    best_sentences = extract_relevant_sentence(my_query, passage, top_k=3)

    # Summarize the best sentences
    summary = summarize_text(best_sentences, max_length=50, min_length=10)

    my_preview.append({
        "query": my_query,
        "best_sentences": best_sentences,
        "summary": summary,
        "query_passage_similarity_score": result["score"]
    })

My Question: How to be a rainbow?


In [114]:
df_my_preview = pd.DataFrame(my_preview)

print(tabulate(df_my_preview, headers='keys', tablefmt='fancy_grid'))

╒════╤═════════════════════╤════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╤════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╤══════════════════════════════════╕
│    │ query               │ best_sentences                                                                                                                                                                                                                                             │ summary                                                                                                                                                                            │   query_passage_similarity_score │
╞═══