In [2]:

from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import AutoModel
from collections import defaultdict
import torch
import re
from tqdm import tqdm
import faiss
import numpy as np


raw_datasets = load_dataset("squad")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class Dataset:
    
    def __init__(self) -> None:
        self.ids = []
        self.contexts = []
        self.questions = []
        self.answers = []
        self.spans_input_ids = []
        model_checkpoint = "DeepPavlov/distilrubert-tiny-cased-conversational-v1"
        self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
        
    def load_ds(self, size=100):

        id = 0
        contexts = raw_datasets['train']['context'] + raw_datasets['validation']['context']
        questions = raw_datasets['train']['question'] + raw_datasets['validation']['question']
        answers = raw_datasets['train']['answers'] + raw_datasets['validation']['answers']
        if size:
            treshold = size
            self.size = size
        else:
            treshold = len(contexts)
        for i in tqdm(range(treshold)):
            try:
                self.ids.append(id)
                id += 1
                self.contexts.append(contexts[i])
                self.questions.append(questions[i])
                self.answers.append(answers[i]['text'])
                answer_tokenized = self.tokenizer(answers[i]['text'][0])['input_ids'][1:]
                context_tokenized = self.tokenizer(contexts[i])['input_ids'][1:]
                start = context_tokenized.index(answer_tokenized[0])
                end = context_tokenized.index(answer_tokenized[-1])
                self.spans_input_ids.append({
                    'start':start,
                    'end':end
                })
            except:
                None

class Model:
    def __init__(self, ds, hidden_dim=264):
        self.hidden_dim = hidden_dim
        
        model_checkpoint = "DeepPavlov/distilrubert-tiny-cased-conversational-v1"
        self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

        self.model = AutoModel.from_pretrained(model_checkpoint)
        self.model_1 = AutoModel.from_pretrained(model_checkpoint)
        self.model_2 = AutoModel.from_pretrained(model_checkpoint)
        self.ds = ds            

    def create_dump(self):
        token_id2cnt = defaultdict(int)
        token_id2j = {}
        j = 0
        
        for context in tqdm(self.ds.contexts):
            input_ids = self.tokenizer(context, truncation=True, max_length=512, return_tensors="pt")
            last_hidden_state = self.model(**input_ids).last_hidden_state[0].detach().numpy()
            for token_num, token_id in enumerate(input_ids['input_ids'][0]):
                if not token_id in self.tokenizer.added_tokens_decoder.keys(): 
                    last_hidden_state_token = last_hidden_state[token_num].reshape((1, self.hidden_dim))
                    
                    token_id = token_id.item()
                    if token_id in token_id2j:
                        H[token_id2j[token_id]] += last_hidden_state_token[0]
                    else:
                        if j == 0:
                            H = last_hidden_state_token
                        else:
                            print(last_hidden_state_token.shape, H.shape)
                            j += 1
                            H = np.vstack((H, last_hidden_state_token))
                        token_id2j[token_id] = j
                            
                    token_id2cnt[token_id] += 1
                
        for token_id, cnt in token_id2cnt.items():
            H[token_id2j[token_id]] = H[token_id2j[token_id]] / cnt
        self.index = faiss.IndexFlatIP(self.hidden_dim)
        self.index.add(H)
        self.H = H

    def predict(self, id, k=100, verbose=False):
        question = self.ds.questions[id]
        answer = self.ds.answers[id]
        context = self.ds.contexts[id]
        if verbose:
            print(f"Q: {question}")
            print(f"C: {context}")

        input_ids = self.tokenizer(question, truncation=True, max_length=512, return_tensors="pt")
        last_hidden_state_1 = self.model_1(**input_ids).last_hidden_state.detach().numpy()[0][0].reshape((1, self.hidden_dim))
        last_hidden_state_2 = self.model_2(**input_ids).last_hidden_state.detach().numpy()[0][0].reshape((1, self.hidden_dim))
        S_1, I_1 = self.index.search(last_hidden_state_1, k)
        S_2, I_2 = self.index.search(last_hidden_state_2, k)
        
        context_ids =  self.tokenizer(context)['input_ids']
        
        answer_candidate2cumscore = {}
        for num_i_1, i_1 in enumerate(I_1[0]):
            for num_i_2, i_2 in enumerate(I_2[0]):
                if i_1 in context_ids and i_2 in context_ids:
                    start_index = context_ids.index(i_1)
                    end_index = context_ids.index(i_2)
                    if start_index <= end_index:
                        answer_candidate_ids = context_ids[start_index:end_index]
                        answer_candidate = self.tokenizer.decode(answer_candidate_ids)
                        answer_candidate2cumscore[answer_candidate] = S_1[0][num_i_1] + S_2[0][num_i_2]
                        
        if answer_candidate2cumscore:
            answer, score = sorted(answer_candidate2cumscore.items(), key=lambda x: -x[1])[0]
            answer = re.sub(r'#', '', answer)
            return answer, score
        else:
            return '',  0.

    def evaluate(self, k=100):
        tp = 0
        for id in tqdm(range(self.size)):
            answer, score = self.predict(id, k)
            if answer == self.ds.answers[id]:
                tp += 1
        return tp / self.size
                
dataset = Dataset()
dataset.load_ds(size=1000)
model = Model(dataset)


100%|██████████| 1000/1000 [00:00<00:00, 1773.81it/s]


In [188]:
model.create_dump()

100%|██████████| 1000/1000 [00:16<00:00, 62.22it/s]


In [189]:
model.predict(0, k=100, verbose=True)

Q: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
C: Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.


('', 0.0)

In [95]:
model.evaluate()

100%|██████████| 400/400 [00:58<00:00,  6.88it/s]


0.0

In [30]:
class PhraseModel(torch.nn.Module):
    def __init__(self, hidden_dim, ds):
        super(PhraseModel, self).__init__()

        self.hidden_dim = hidden_dim
        
        model_checkpoint = "DeepPavlov/distilrubert-tiny-cased-conversational-v1"
        self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

        self.model = AutoModel.from_pretrained(model_checkpoint)
        self.model_1 = AutoModel.from_pretrained(model_checkpoint)
        self.model_2 = AutoModel.from_pretrained(model_checkpoint)
        self.ds = ds
        self.softmax = torch.nn.Softmax()
        

    def forward(self, id):
        context_ids = self.tokenizer(self.ds.contexts[id], truncation=True, max_length=512, return_tensors="pt")
        question_ids = self.tokenizer(self.ds.questions[id], truncation=True, max_length=512, return_tensors="pt")
        last_hidden_state = self.model(**context_ids).last_hidden_state[0]
        q_start = self.model_1(**question_ids).last_hidden_state[0][0]
        q_end = self.model_2(**question_ids).last_hidden_state[0][0]
       
        z_start = torch.matmul(last_hidden_state, q_start.T).reshape(-1)
        P_start = self.softmax(z_start)
        loss_start = P_start[self.ds.spans_input_ids[id]['start']]
        z_end = torch.matmul(last_hidden_state, q_end.T).reshape(-1)
        P_end = self.softmax(z_end)
        loss_end = P_end[self.ds.spans_input_ids[id]['start']]
        return loss_start + loss_end
    
phrasemodel = PhraseModel(hidden_dim=264, ds=dataset)
optimizer = torch.optim.Adam(phrasemodel.parameters(), lr=0.001)



In [31]:
phrasemodel(140)

  return self._call_impl(*args, **kwargs)


tensor(0., grad_fn=<AddBackward0>)

In [238]:
class QueryModel(torch.nn.Module):
    def __init__(self, hidden_dim, H, ds):
        super(QueryModel, self).__init__()

        self.hidden_dim = hidden_dim
        
        model_checkpoint = "DeepPavlov/distilrubert-tiny-cased-conversational-v1"
        self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

        self.model = AutoModel.from_pretrained(model_checkpoint)
        self.model_1 = AutoModel.from_pretrained(model_checkpoint)
        self.model_2 = AutoModel.from_pretrained(model_checkpoint)
        self.H = torch.Tensor(H)
        self.ds = ds

    def forward(self, id):
        question = self.ds.questions[id]
        context_ids = self.tokenizer(self.ds.contexts[id])['input_ids']
        answer_ids = self.tokenizer(self.ds.answers[id])['input_ids']
        k = 100
        input_ids = self.tokenizer(question, truncation=True, max_length=512, return_tensors="pt")
        last_hidden_state_1 = self.model_1(**input_ids).last_hidden_state[0][0].reshape((1, self.hidden_dim))
        last_hidden_state_2 = self.model_2(**input_ids).last_hidden_state[0][0].reshape((1, self.hidden_dim))
        dot1 = torch.matmul(self.H, last_hidden_state_1.T)
        dot2 = torch.matmul(self.H, last_hidden_state_2.T)
        
        I_1 = np.argsort(-dot1.detach().numpy())[:k]
        I_2 = np.argsort(-dot2.detach().numpy())[:k]
        
        S_1 = dot1[I_1]
        S_2 = dot2[I_2]
        
        answer2cumscore = {}
        answer_correct2cumscore = {}
        for num_i_1, i_1 in enumerate(I_1[0]):
            for num_i_2, i_2 in enumerate(I_2[0]):
                if i_1 in context_ids and i_2 in context_ids:
                    start_index = context_ids.index(i_1)
                    end_index = context_ids.index(i_2)
                    if start_index <= end_index:
                        answer_candidate_ids = context_ids[start_index:end_index]
                        answer_candidate = self.tokenizer.decode(answer_candidate_ids)
                        answer2cumscore[answer_candidate] = torch.Tensor(S_1[0][num_i_1] + S_2[0][num_i_2])
                        if answer_candidate_ids == answer_ids:
                            answer_correct2cumscore[answer_candidate] = torch.Tensor(S_1[0][num_i_1] + S_2[0][num_i_2])
        
        print(answer2cumscore)
        print(answer_correct2cumscore)
        if answer_correct2cumscore:
            scores_numenator = torch.vstack(tuple(answer_correct2cumscore.values()))
        else:
            scores_numenator = torch.Tensor([0.], )
        if answer2cumscore:
            scores_denominator = torch.vstack(tuple(answer2cumscore.values()))
        else:
            scores_denominator = torch.Tensor([0.])
        numenator = torch.sum(torch.exp(scores_numenator))
        denominator = torch.sum(torch.exp(scores_denominator))
        loss = - torch.log(numenator / denominator)
        return loss
    
querymodel = QueryModel(hidden_dim=264, H=model.H, ds=dataset)
optimizer = torch.optim.Adam(querymodel.parameters(), lr=0.001)



In [241]:
loss.requires_grad

False

In [239]:
optimizer.zero_grad()
loss.backward()
optimizer.step()

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn