In [1]:
%pip install transformers

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
from sentence_transformers.util import cos_sim

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
ctx_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-multiset-base')
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-multiset-base')


Some weights of the model checkpoint at facebook/dpr-ctx_encoder-multiset-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.weight', 'ctx_encoder.bert_model.pooler.dense.bias']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenize

In [4]:
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
question_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-multiset-base')
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-multiset-base')


Some weights of the model checkpoint at facebook/dpr-question_encoder-multiset-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
from transformers import DPRReader, DPRReaderTokenizer
reader = DPRReader.from_pretrained('facebook/dpr-reader-multiset-base')
reader_tokenizer = DPRReaderTokenizer.from_pretrained('facebook/dpr-reader-multiset-base')


Some weights of the model checkpoint at facebook/dpr-reader-multiset-base were not used when initializing DPRReader: ['span_predictor.encoder.bert_model.pooler.dense.bias', 'span_predictor.encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRReader from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRReader from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRReaderTokenizer'.


In [6]:

with open(f'snowwhite.txt', 'r', encoding='utf8') as f:
    lines = f.readlines()
    text = ''
    for line in lines:
        if len(line.strip()) > 0:
            text += ' ' + line.strip()
    sentences = list(map(lambda x: x.strip(), text.split('.')))
    
contexts = []
for i in range(0, len(sentences)-3):
    context = '.'.join(sentences[i:i+3])
    contexts.append(context)
    

In [7]:
# Assuming ctx_encoder and ctx_tokenizer are already loaded
inputs = ctx_tokenizer(contexts, return_tensors="pt", truncation=True, max_length=256, padding='max_length')
passage_embeddings = ctx_encoder(**inputs).pooler_output



In [8]:
passage_embeddings.shape

torch.Size([143, 768])

In [9]:
questions = [
    "What did the wicked Queen ask the Magic Mirror?",
    "What did the Queen ordered the Huntsman to do to the Snow White?",
    "Where did she go after the Huntsman left her deep into the forest?",
    "What did she do when she entered the cottage?",
    "What happened to the Snow White after she ate the poisoned apple?",
    "What did the Dwarfs do when they found Snow White lying on the ground?",
    "Who came to rescue Snow White?"
]


In [10]:
question_inputs = question_tokenizer(questions, return_tensors="pt", truncation=True, max_length=256, padding='max_length')
question_embedding = question_encoder(**question_inputs).pooler_output
question_embedding.shape

torch.Size([7, 768])

In [11]:
def find_answer(question, contexts):

    reader_inputs = reader_tokenizer(
        question, 
        contexts, 
        return_tensors="pt", 
    )

    reader_output = reader(**reader_inputs)
    start_logits, end_logits = reader_output.start_logits, reader_output.end_logits

    # Find the tokens with the highest `start` and `end` logits
    answer_start = torch.argmax(start_logits)
    answer_end = torch.argmax(end_logits) + 1  # Add 1 to get inclusive range

    # Convert tokens to answer
    answer_tokens = reader_inputs.input_ids[0, answer_start:answer_end]
    answer = reader_tokenizer.decode(answer_tokens)
    return answer


In [12]:
def find_contexts(question_embedding, passage_embeddings, contexts, k=3):
    probs = cos_sim(question_embedding, passage_embeddings)
    # topk = torch.topk(probs, k=k)
    # indices = topk.indices.tolist()[0]
    # return ' \n'.join([contexts[i] for i in indices])
    index = torch.argmax(probs)
    return contexts[index]


In [13]:

for i, xq_vec in enumerate(question_embedding):
    matched_contexts = find_contexts(xq_vec, passage_embeddings, contexts)
    print('matched contexts:', matched_contexts)
    answer = find_answer(questions[i], matched_contexts)
    print(questions[i])
    print('ans:', answer)
    print('---')

matched contexts: The glass answered, oh, queen, of all here the fairest art thou, but the young queen is fairer by far as I trow.Then the wicked woman uttered a curse, and was so wretched, so utterly wretched that she knew not what to do.At first she would not go to the wedding at all, but she had no peace, and had to go to see the young queen
What did the wicked Queen ask the Magic Mirror?
ans: wicked woman uttered a curse
---
matched contexts: She called a huntsman, and said, take the child away into the forest.I will no longer have her in my sight.Kill her, and bring me back her lung and liver as a token
What did the Queen ordered the Huntsman to do to the Snow White?
ans: take the child away into the forest. i will no longer have her in my sight. kill
---
matched contexts: And envy and pride grew higher and higher in her heart like a weed, so that she had no peace day or night.She called a huntsman, and said, take the child away into the forest.I will no longer have her in my sigh