In [23]:
import nltk
from transformers import DPRQuestionEncoder, DPRContextEncoder, DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer
from transformers import BartForConditionalGeneration, BartTokenizer
import torch

from  chunking import Chunking

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = "cuda:1"
# Download NLTK data
# I had to export as an env var where the data were downloaded : export NLTK_DATA=/home/hay4hi/nltk_data
nltk.set_proxy('http://rb-proxy-de.bosch.com:8080')
nltk.download('punkt')
nltk.download('punkt_tab')

# Load retriever models and tokenizers
question_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base').to(device)
context_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base').to(device)
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')

# Load generator model and tokenizer
generator = BartForConditionalGeneration.from_pretrained('facebook/bart-large').to(device)
generator_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

[nltk_data] Error loading punkt: <urlopen error Tunnel connection
[nltk_data]     failed: 407 Proxy Authentication Required>
[nltk_data] Error loading punkt_tab: <urlopen error Tunnel connection
[nltk_data]     failed: 407 Proxy Authentication Required>
Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint 

In [22]:
# Example large document
large_document = """
Paris is the capital of France. It is known for its art, fashion, and culture. The Eiffel Tower is one of the most famous landmarks in Paris.
Berlin is the capital of Germany. It has a rich history and is known for its museums and historical sites.
Madrid is the capital of Spain. It is famous for its vibrant nightlife and cultural heritage.
"""
chunker = Chunking()

# Step 1: Divide the text into chunks (e.g., sentences)
#chunker.chunking_into_sentences(large_document)
chunker.chunking_sliding_window(large_document, window_size=2, stride=1)
chunks = chunker.chunks
print(chunks)

# Step 2: Encode the chunks using the context encoder
chunk_embeddings = [context_encoder(**context_tokenizer(chunk, return_tensors='pt').to(device)).pooler_output for chunk in chunks]
print(len((chunk_embeddings[0].cpu().detach().numpy()[0])))

for chunk in chunks:
    print(context_tokenizer(chunk, return_tensors='pt'))

['\nParis is the capital of France. It is known for its art, fashion, and culture.', 'It is known for its art, fashion, and culture. The Eiffel Tower is one of the most famous landmarks in Paris.', 'The Eiffel Tower is one of the most famous landmarks in Paris. Berlin is the capital of Germany.', 'Berlin is the capital of Germany. It has a rich history and is known for its museums and historical sites.', 'It has a rich history and is known for its museums and historical sites. Madrid is the capital of Spain.', 'Madrid is the capital of Spain. It is famous for its vibrant nightlife and cultural heritage.']
768
{'input_ids': tensor([[ 101, 3000, 2003, 1996, 3007, 1997, 2605, 1012, 2009, 2003, 2124, 2005,
         2049, 2396, 1010, 4827, 1010, 1998, 3226, 1012,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
{'input_ids': tensor([[  101,  2009,

In [18]:
# Input query
query = "What is the capital of France?"

# Step 3: Encode the query using the question encoder
query_embedding = question_encoder(**question_tokenizer(query, return_tensors='pt').to(device)).pooler_output

In [21]:
# Step 4: Retrieve the most relevant chunk (simplified)
similarities = [torch.cosine_similarity(query_embedding, chunk_embedding) for chunk_embedding in chunk_embeddings]
retrieved_chunk = chunks[torch.argmax(torch.tensor(similarities))]
print(similarities)
print(chunk_embeddings[torch.argmax(torch.tensor(similarities))])
print(query_embedding)
print("Retrieved chunk:", retrieved_chunk)

[tensor([0.7176], device='cuda:1', grad_fn=<SumBackward1>), tensor([0.6171], device='cuda:1', grad_fn=<SumBackward1>), tensor([0.6185], device='cuda:1', grad_fn=<SumBackward1>), tensor([0.5604], device='cuda:1', grad_fn=<SumBackward1>), tensor([0.6354], device='cuda:1', grad_fn=<SumBackward1>), tensor([0.6012], device='cuda:1', grad_fn=<SumBackward1>)]
tensor([[ 3.7717e-01, -3.6294e-01, -1.4239e-01, -8.1584e-02, -8.7463e-02,
          2.4833e-01,  2.7685e-01,  9.0941e-01, -6.2721e-01, -6.8360e-01,
         -6.7071e-01, -8.6989e-02, -3.4711e-02,  6.0646e-01,  1.6714e-01,
          7.1914e-02,  6.4191e-01,  2.1475e-01,  1.4365e-01, -5.6612e-02,
         -6.9641e-01,  2.2993e-01,  1.6472e-01, -3.0849e-01,  3.7029e-01,
         -4.5719e-01, -5.9384e-01,  2.6634e-02,  1.7385e-01,  3.4582e-01,
         -9.9865e-02, -2.4540e-01, -5.0825e-01, -1.2908e-02,  1.1949e-01,
         -1.4771e-01,  1.4425e-01, -1.6174e-01,  1.9517e-01, -4.3521e-01,
          4.1186e-02, -4.7788e-01,  5.6292e-01,  9.70

In [20]:
input_ids = generator_tokenizer(query + " " + retrieved_chunk, return_tensors='pt').input_ids.to(device)
output_ids = generator.generate(input_ids)
response = generator_tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_ids)

print(response)

tensor([[    2,     0,  2264,    16,     5,   812,     9,  1470,   116,  1437,
         46303, 42199, 42593, 32826,     6,  1470,  1437, 46303, 36440,     2]],
       device='cuda:1')
What is the capital of France? ___________________________________________________________________________________Paris, France ________________________________________________________________________
