In [1]:
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain_core.documents import Document
import pickle
import torch
from transformers import T5Tokenizer
from datasets import Dataset
from torch.utils.data import DataLoader
import numpy as np
import gc

import sys 
sys.path.insert(0, '/home/dzigen/Desktop/ITMO/ВКР/КМУ2024')

from src.colbert_model import ColBERT, ColBertTokenizer
gc.collect()

53

In [2]:
from transformers import AutoModel, AutoTokenizer

In [None]:
# Load the 'intfloat/e5-base-v2' checkpoint as a plain encoder via AutoModel.from_pretrained.
e5_model = AutoModel.from_pretrained('intfloat/e5-base-v2')

In [None]:
# Load the tokenizer that belongs to the same 'intfloat/e5-base-v2' checkpoint as the model.
e5_tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-base-v2')

In [12]:
# Each input text should start with "query: " or "passage: ".
# For tasks other than retrieval, you can simply use the "query: " prefix.
# The first two entries are queries and the last two are passages; later cells
# rely on this order when slicing the embeddings with [:2] and [2:].
input_texts = ['query: how much protein should a female eat',
               'query: summit define',
               "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
               "passage: Definition of summit for English Language Learners. : 1  the highest point of a mountain : the top of a mountain. : 2  the highest level. : 3  a meeting or series of meetings between the leaders of two or more governments."]

# Tokenize the input texts
# (padded as one batch, truncated at 512 tokens, returned as PyTorch tensors)
batch_dict = e5_tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt')

# Run the encoder; `outputs.last_hidden_state` is pooled in a later cell.
outputs = e5_model(**batch_dict)

In [21]:
def average_pool(last_hidden_states, attention_mask):
    """Mean-pool token states over the positions the attention mask keeps.

    Positions where ``attention_mask`` is zero are filled with 0.0 before the
    sum over dim 1, and the sum is divided by the per-row mask total.

    Args:
        last_hidden_states: tensor of token states; dim 1 is the one reduced.
        attention_mask: mask whose non-zero entries mark the tokens to keep.

    Returns:
        The masked sum over dim 1 divided by ``attention_mask.sum(dim=1)``.
    """
    padding = ~attention_mask.unsqueeze(-1).bool()
    last_hidden = last_hidden_states.masked_fill(padding, 0.0)
    # Shape trace left in on purpose: the notebook output shows this line.
    print(last_hidden.shape)
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return last_hidden.sum(dim=1) / token_counts

In [26]:
import torch.nn.functional as F

In [28]:
# Mean-pool the token states into one vector per input text.
embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

# normalize embeddings
embeddings = F.normalize(embeddings, p=2, dim=1)
# Rows [:2] are the queries and rows [2:] the passages (order of `input_texts`).
# The vectors are L2-normalized, so the dot product is cosine similarity, scaled by 100.
scores = (embeddings[:2] @ embeddings[2:].T) * 100
print(scores.tolist())

torch.Size([4, 75, 768])
[[89.99481201171875, 67.25457763671875], [68.84741973876953, 91.38423156738281]]


In [45]:
# Shape check: give the two query embeddings a singleton axis so they can broadcast against the passages.
embeddings[:2].unsqueeze(1).shape

torch.Size([2, 1, 768])

In [47]:
# Same query-vs-passage similarities computed with F.cosine_similarity and broadcasting
# (the values equal the matrix-product scores above divided by 100).
F.cosine_similarity(embeddings[:2].unsqueeze(1), embeddings[2:], dim=-1)

tensor([[0.8999, 0.6725],
        [0.6885, 0.9138]], grad_fn=<SumBackward1>)

In [None]:
# TODO: implement the E5 retriever
# TODO: test the BM25 + ColBERT retriever
# TODO: test the E5 retriever

In [2]:
class E5Retriever:
    """Placeholder for the E5-based retriever; no behaviour is implemented yet."""

In [3]:
class BM25ColBertRetriever:
    """Two-stage retriever: BM25 picks candidates, ColBERT re-scores them.

    Documents are tokenized once, when the BM25 index is built, and the
    tokenizer output is stored in each document's metadata. A search therefore
    only has to tokenize the query.
    """

    def __init__(self, bm25_candidates=3, colbert_candidates=2, colbert_reddim=64, docs_bs=4) -> None:
        """Create the ColBERT model and remember the retrieval settings.

        Args:
            bm25_candidates: number of documents BM25 returns (its ``k``);
                also forwarded to ``ColBERT`` as ``candidates``.
            colbert_candidates: number of documents kept after re-scoring.
            colbert_reddim: forwarded to ``ColBERT`` as ``reduced_dim``.
            docs_bs: batch size used when feeding candidates to ColBERT.
        """
        self.bm25_cands = bm25_candidates
        self.colbert_cands = colbert_candidates
        self.docs_bs = docs_bs

        self.colbert_model = ColBERT(candidates=bm25_candidates, reduced_dim=colbert_reddim)
        self.colbert_tokenizer = ColBertTokenizer

    def tokenize(self, x):
        """Tokenize ``x`` with the ColBERT tokenizer.

        Inputs are truncated to 512 tokens, padded, and returned as PyTorch
        tensors.
        """
        return self.colbert_tokenizer(
            x, max_length=512, truncation=True,
            padding=True, return_tensors='pt')

    def load_bm25_base(self, pickle_file):
        """Restore a previously pickled BM25 retriever from ``pickle_file``.

        The restored object replaces ``self.bm25_model`` as-is, including
        whatever candidate count it was built with.
        """
        # SECURITY: unpickling runs arbitrary code embedded in the file.
        # Only load index files that you created yourself.
        with open(pickle_file, 'rb') as bm25result_file:
            self.bm25_model = pickle.load(bm25result_file)

    def load_colbert_model(self, weights_path):
        """Load ColBERT weights from a ``torch.save``-d state dict."""
        # map_location='cpu' lets a checkpoint saved on another device be read
        # anywhere; load_state_dict then copies the values into the existing
        # parameters. torch.load is pickle-based: use trusted checkpoints only.
        state_dict = torch.load(weights_path, map_location='cpu')
        self.colbert_model.load_state_dict(state_dict)

    def texts2documents(self, texts):
        """Wrap each text in a ``Document`` that carries its tokenized form.

        The tokenizer output is stored under ``metadata['tokenized']`` so it
        can be reused at search time.
        """
        return [Document(page_content=txt, metadata={'tokenized': self.tokenize(txt)})
                for txt in texts]

    def make_bm25_base(self, texts, save_pickle_file=None):
        """Build the BM25 index over ``texts``.

        Args:
            texts: iterable of raw document strings.
            save_pickle_file: optional path; when given, the built BM25
                retriever is pickled there.
        """
        documents = self.texts2documents(texts)
        self.bm25_model = BM25Retriever.from_documents(documents,
                                                     k=self.bm25_cands)

        if save_pickle_file is not None:
            with open(save_pickle_file, 'wb') as bm25result_file:
                pickle.dump(self.bm25_model, bm25result_file)

    def search(self, query, tokenized_query=None):
        """Retrieve BM25 candidates for ``query`` and re-rank them with ColBERT.

        Args:
            query: raw query string, used for the BM25 lookup.
            tokenized_query: optional pre-tokenized query; when omitted the
                query is tokenized here.

        Returns:
            ``(documents, scores)`` as produced by ``colbert_retrieve``.
        """
        bm25_docs, docs_loader = self.bm25_retrieve(query)

        if tokenized_query is None:
            tokenized_query = self.tokenize(query)
        colbert_docs, scores = self.colbert_retrieve(
            tokenized_query, bm25_docs, docs_loader)

        return colbert_docs, scores

    def bm25_retrieve(self, query):
        """Run the BM25 stage.

        Returns:
            A NumPy array with the candidate texts and a ``DataLoader`` that
            yields their stored tokenizations in the same order, in batches
            of ``docs_bs``.
        """
        relevant_documents = self.bm25_model.get_relevant_documents(query)
        text_docs = np.array([doc.page_content for doc in relevant_documents])
        tokenized_docs = [doc.metadata['tokenized'] for doc in relevant_documents]

        docs_dataset = DocDataset(tokenized_docs)
        # shuffle=False keeps loader order aligned with `text_docs`.
        docs_loader = DataLoader(docs_dataset, batch_size=self.docs_bs,
                                 collate_fn=custom_collate, shuffle=False)

        return text_docs, docs_loader

    def colbert_retrieve(self, tokenized_query, bm25_docs, docs_loader):
        """Score candidates with ColBERT and keep the best ``colbert_cands``.

        Args:
            tokenized_query: mapping with 'input_ids' and 'attention_mask'.
            bm25_docs: NumPy array of candidate texts, in loader order.
            docs_loader: iterable of batches with 'input_ids' and
                'attention_mask'.

        Returns:
            ``(documents, scores)``: the selected entries of ``bm25_docs`` and
            their scores, best first. Both are empty when there are no
            candidates.
        """
        # Collect per-batch scores and concatenate once instead of growing a
        # tensor inside the loop. The model output is assumed to be shaped
        # (1, docs_in_batch), which is what concatenation along dim=1 needs.
        batch_scores = [
            self.colbert_model(
                tokenized_query['input_ids'], tokenized_query['attention_mask'],
                doc_batch['input_ids'], doc_batch['attention_mask'])
            for doc_batch in docs_loader
        ]
        if not batch_scores:
            return bm25_docs[:0], torch.empty(0)

        flat_scores = torch.cat(batch_scores, dim=1).reshape(-1)
        _, indices = torch.sort(flat_scores, descending=True)
        relevant_ids = indices[:self.colbert_cands]

        # Index the NumPy array with a NumPy index array rather than a torch
        # tensor: this keeps the result an array for any number of selected
        # documents and works when the scores live on another device.
        doc_ids = relevant_ids.detach().cpu().numpy()

        return bm25_docs[doc_ids], flat_scores.take(relevant_ids)
    
class DocDataset(torch.utils.data.Dataset):
    """Map-style dataset over a list of pre-tokenized documents.

    The base class is PyTorch's ``Dataset`` because instances are only ever
    handed to a ``DataLoader``. The ``Dataset`` name imported from
    ``datasets`` at the top of the notebook is a different class whose own
    initializer is never run here, so it is not a suitable base.
    """

    def __init__(self, docs):
        """Keep a reference to ``docs``, a sequence of tokenizer outputs."""
        self._data = docs

    def __len__(self):
        """Number of stored documents."""
        return len(self._data)

    def __getitem__(self, idx):
        """Return the tokenized document at position ``idx``."""
        return self._data[idx]

    def __getitems__(self, idxs):
        """Return the documents at ``idxs`` as a list (batched fetch)."""
        return [self.__getitem__(idx) for idx in idxs]
        

def custom_collate(data):
    """Merge per-document tokenizer outputs into a single batch.

    Args:
        data: sequence of items indexable by 'input_ids' and
            'attention_mask', each holding a tensor.

    Returns:
        A dict with the same two keys, where the tensors of all items are
        concatenated along dim 0.
    """
    def merge(field):
        # Concatenate one field across every item of the batch.
        return torch.cat([sample[field] for sample in data], 0)

    return {"input_ids": merge("input_ids"), "attention_mask": merge("attention_mask")}


In [None]:
# Drop any retriever left over from a previous run so its model can be freed.
# pop() instead of `del` keeps this cell runnable on a fresh kernel and when re-run.
globals().pop('retriever', None)
gc.collect()

In [4]:
# Build the retriever: BM25 returns 5 candidates and all 5 are kept after ColBERT re-scoring.
retriever = BM25ColBertRetriever(colbert_candidates=5, bm25_candidates=5)

Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Index five toy documents with BM25; no pickle path is given, so nothing is written to disk.
retriever.make_bm25_base(["France", "Russia", "United Kingdom", "Japan", "USA"])

In [9]:
# Show the (documents, scores) pair returned by `retriever.search`.
# NOTE: `res` is assigned by the search cell further down, so that cell must run first.
res

(array(['United Kingdom', 'France', 'USA', 'Russia', 'Japan'], dtype='<U14'),
 tensor([4.0000, 3.6716, 3.6680, 3.6404, 3.6176], grad_fn=<TakeBackward0>))

In [None]:
# Drop the previous search result (its scores keep the autograd graph alive).
# pop() instead of `del` keeps this cell runnable on a fresh kernel and when re-run.
globals().pop('res', None)
gc.collect()

In [8]:
# Query the two-stage retriever; `res` is a (documents, scores) pair.
res = retriever.search("United Kingdom")

doc_batch -  torch.Size([4, 512])
=='colbert foward'-func
q_ids -  torch.Size([1, 512])
q_masks -  torch.Size([1, 512])
d_ids -  torch.Size([4, 512])
d_masks -  torch.Size([4, 512])

=='q_enc'-function: 
q_ids -  torch.Size([1, 512])
q_masks -  torch.Size([1, 512])

out:
encoder - torch.Size([1, 512, 768])
dim reduce - torch.Size([1, 512, 64])
ecoded queries =  torch.Size([1, 512, 64])
=='d_enc'-function: 
d_ids -  torch.Size([4, 512])
d_masks -  torch.Size([4, 512])

out: 
encoder -  torch.Size([4, 512, 768])
dim reduce -  torch.Size([4, 512, 64])
encoded documents =  torch.Size([4, 512, 64])
=='compute_score'-function: 
q_hidden -  torch.Size([1, 512, 64])
q_mask -  torch.Size([1, 512])
d_hidden -  torch.Size([4, 512, 64])
d_mask -  torch.Size([4, 512])

cos similarity - torch.Size([4, 512, 512])
mask doc tokens -  torch.Size([4, 512, 512])
find max sim -  torch.Size([4, 512])
mask query tokens -  torch.Size([4, 512])
sum scores -  torch.Size([4])
out scores =  torch.Size([1, 4])
doc

In [None]:
# Backpropagate from the summed scores (res[1]) so parameter gradients can be inspected below.
res[1].sum().backward()

In [None]:
# Reset the gradients accumulated on the ColBERT model's parameters.
retriever.colbert_model.zero_grad()

In [None]:
# Inspect the gradient of the output dense weight in layer 11 of the query encoder.
retriever.colbert_model.query_encoder.encoder.layer[11].output.dense.weight.grad