## Imports

In [1]:
from sentence_transformers import SentenceTransformer, InputExample, losses, models
from torch.utils.data import DataLoader
from collections import defaultdict
from scipy import spatial
import random
import re
from scipy import spatial
stopwords = ["ourselves", "hers", "between", "yourself", "but", "again", 
            "there", "about", "once", "during", "out", "very", "having", 
            "with", "they", "own", "an", "be", "some", "for", "do", "its", 
            "yours", "such", "into", "of", "most", "itself", "other", "off", 
            "is", "s", "am", "or", "who", "as", "from", "him", "each", "the", 
            "themselves", "until", "below", "are", "we", "these", "your", "his", 
            "through", "don", "nor", "me", "were", "her", "more", "himself", "this", 
            "down", "should", "our", "their", "while", "above", "both", "up", "to", 
            "ours", "had", "she", "all", "no", "when", "at", "any", "before", "them", 
            "same", "and", "been", "have", "in", "will", "on", "does", "yourselves", 
            "then", "that", "because", "what", "over", "why", "so", "can", "did", "not", 
            "now", "under", "he", "you", "herself", "has", "just", "where", "too", "only", 
            "myself", "which", "those", "i", "after", "few", "whom", "t", "being", "if", 
            "theirs", "my", "against", "a", "by", "doing", "it", "how", "further", "was", 
            "here", "than"]
def def_value():
    return []

2022-11-15 18:06:33.923088: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Initialize model

In [9]:
BERT_model = SentenceTransformer('all-MiniLM-L6-v2')

## SimANS Implementation

In [50]:
def get_batches(BERT_model,docs,relevance_dict,batch_size):
    embeddings = BERT_model.encode(docs)
    input_batches = []
    for id in relevance_dict.keys():
        if id == 5:
            break
        if not relevance_dict[id] == []:
            cosine_sims = [1 - spatial.distance.cosine(embedding, embeddings[id]) for embedding in embeddings]
            positive = random.choice(relevance_dict[id])
            negatives = []
            sim_ranked_indexes = [index for _, index in sorted(zip(cosine_sims, range(len(cosine_sims))))]
            positive_index = sim_ranked_indexes.index(positive)
            index = 1
            while len(negatives) < batch_size-1:
                if positive_index-index > 0 and positive_index-index not in relevance_dict[id]:
                    if len(negatives) < batch_size-1:
                        negatives.append(positive_index-index)
                if positive_index+index < len(docs) and positive_index+index not in relevance_dict[id]:
                    if len(negatives) < batch_size-1:
                        negatives.append(positive_index+index)
                
                index += 1
            
            batch = [sim_ranked_indexes[neg] for neg in negatives]
            labels = [0.0 for _ in range(batch_size-1)]
            batch.append(positive)
            labels.append(1.0)

            zipped = list(zip(batch, labels))
            random.shuffle(zipped)
            batch, labels = zip(*zipped)

            input_batch = create_input_batch(batch,labels,docs,id)
            input_batches.append(input_batch)

    return [pair for batch in input_batches for pair in batch]
        

In [51]:
def create_input_batch(batches,batch_labels,docs,id):
    input_batch = []
    for i,doc in enumerate(batches):
        input_batch.append(InputExample(texts=[docs[id],docs[doc]],label=batch_labels[i]))
    return input_batch
        

## Data Pre-Processing

In [5]:

def prepare_data(filename):

    with open(filename) as f:
        lines = f.readlines()

    docs = []
    relevances_list = []
    for index, line in enumerate(lines):
        previous_line = index-1
        if lines[previous_line][:2] == ".W":
            sentence = ""
            while not previous_line+1 == len(lines) and not lines[previous_line+1][:1] == ".":
                previous_line += 1
                sentence += " " + lines[previous_line]
            docs.append(sentence)
        elif lines[previous_line][:2] == ".X":
            while not previous_line+1 == len(lines) and not lines[previous_line+1][0] == ".":
                previous_line += 1
                relevances_list.append(lines[previous_line])

    relevance_dict = defaultdict(def_value)
    for relevance in relevances_list:
        metadata = relevance.replace("\n","").split("\t")
        if not metadata[0] == metadata[2]:
            relevance_dict[int(metadata[2])-1].append(int(metadata[0])-1)

    return docs, relevance_dict

### Prepare Data

In [6]:
docs, relevance_dict = prepare_data("CISI.ALL")

## Train Model

In [54]:
def fine_tune_BERT(BERT_model,docs,relevance_dict,batch_size,epochs,calc_negatives_per_epoch):
    train_loss = losses.CosineSimilarityLoss(BERT_model)
    for epoch in range(epochs):
        print("epoch", epoch+1)
        if epoch % calc_negatives_per_epoch == 0:
            print("calculating Batches")
            input_batches = get_batches(BERT_model,docs,relevance_dict,batch_size)
        print("Training Model")
        train_dataloader = DataLoader(input_batches, shuffle=False, batch_size=batch_size)
        BERT_model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=100)


In [55]:
fine_tune_BERT(BERT_model,docs,relevance_dict,batch_size=20,epochs=1,calc_negatives_per_epoch=2)

epoch 1
calculating Batches
Training Model



[A
[A
[A
[A
[A
Iteration: 100%|██████████| 5/5 [00:22<00:00,  4.53s/it]
Epoch: 100%|██████████| 1/1 [00:22<00:00, 22.68s/it]


## Retriever

In [7]:
def retrieve_top_k(BERT_model,query,corpus_embedded,k):
    query_embedding = BERT_model.encode(query)
    cosine_sims = [1 - spatial.distance.cosine(embedding, query_embedding) for embedding in corpus_embedded]
    sim_ranked_indexes = [index for _, index in sorted(zip(cosine_sims, range(len(cosine_sims))),reverse=True)]
    return sim_ranked_indexes[:k]

### Example

In [None]:
corpus_embedded = BERT_model.encode(docs)

In [None]:
success_rate = 0
for query in range(len(docs)):
    if not len(relevance_dict[query]) == 0:
        top_k = retrieve_top_k(BERT_model,docs[query],corpus_embedded,100)
        retreived = 0
        for positive in relevance_dict[query]:
            if positive in top_k:
                retreived += 1
        success_rate += retreived/len(relevance_dict[query])
print("success_rate", success_rate/len(docs))


success_rate 0.31303500346829805


## ColBERT

In [18]:
def preprocess_text(text,is_query,max_query_lenght=0):
    clean_text = re.sub(r'[^\w]', ' ', text)
    clean_text_list = clean_text.split()
    clean_text_list_no_stopwords = [token for token in clean_text_list if token not in stopwords]
    if is_query:
        return clean_text_list_no_stopwords[:max_query_lenght]
    return clean_text_list_no_stopwords

In [25]:
def get_corpus_token_embeddings(model,docs):
    corpus_token_embeddings = []
    for doc in docs:
        doc_embeddings = get_text_token_embedding(model,doc)
        corpus_token_embeddings.append(doc_embeddings)
    return corpus_token_embeddings

In [26]:
def get_text_token_embedding(model,text,is_query=False,max_query_lenght=0):
    doc_tokens = preprocess_text(text,is_query,max_query_lenght)
    doc_embeddings = [model.encode(token) for token in doc_tokens]
    return doc_embeddings

In [27]:
def ColBERT_similarity(query_embeddings,doc_embeddings):
    doc_tokens_tree = spatial.KDTree(doc_embeddings)
    total_sim = 0
    for query_embedding in query_embeddings:
        closest_embedding_idx = doc_tokens_tree.query(query_embedding)[1]
        max_cosine_sim = 1 - spatial.distance.cosine(doc_embeddings[closest_embedding_idx], query_embedding)
        total_sim += max_cosine_sim

    return total_sim

In [46]:
def re_rank(query,top_k_token_embeddings,top_k_idx,max_query_tokens):
    query_embeddings = get_text_token_embedding(BERT_model,query,True,max_query_tokens)
    similarities = []
    for token_embeddings in top_k_token_embeddings:
        similarities.append(ColBERT_similarity(query_embeddings,token_embeddings))
    sim_ranked_indexes = [index for _, index in sorted(zip(similarities, top_k_idx),reverse=True)]
    return sim_ranked_indexes
    

In [19]:
corpus_doc_embeddings = get_corpus_token_embeddings(BERT_model,docs)

## Test

In [None]:
for query_index in range(len(corpus_doc_embeddings)):

    false = 0
    false_count = 0
    true = 0
    true_count = 0

    query_embeddings = get_text_token_embedding(BERT_model,docs[query_index],True,250)

    for doc_idx in range(len(corpus_doc_embeddings)):
        doc_embeddings = corpus_doc_embeddings[doc_idx]
        total_sim = ColBERT_similarity(query_embeddings,doc_embeddings)
        if doc_idx == query_index:
            pass
        elif doc_idx in relevance_dict[query_index]:
            true += total_sim
            true_count += 1
        else:
            false += total_sim
            false_count += 1

    false = false / false_count
    if true_count > 0:
        true = true / true_count

    print("False", false)
    print("True", true)



# Full Model

In [144]:
def document_retrieval_model(query,corpus_embedded,corpus_token_embedded,BERT_model):

    ## HYPER PARAMS ##
    k = 100
    max_query_tokens = 250

    top_k_idx = retrieve_top_k(BERT_model,query,corpus_embedded,k)
    top_k_token_embeddings = [corpus_token_embedded[idx] for idx in top_k_idx]
    top_k_re_ranked = re_rank(query,top_k_token_embeddings,top_k_idx,max_query_tokens)
    
    return top_k_re_ranked

## Test Full Model

In [30]:
corpus_embedded = BERT_model.encode(docs)
corpus_token_embedded = get_corpus_token_embeddings(BERT_model,docs)

### Load new dataset

In [112]:
def prepare_data(docs_file,query_file,relevance_file):

    with open(docs_file) as f:
        doc_lines = f.readlines()

    docs = []
    for index, line in enumerate(doc_lines):
        previous_line = index-1
        if doc_lines[previous_line][:2] == ".W":
            sentence = ""
            while not previous_line+1 == len(doc_lines) and not doc_lines[previous_line+1][:1] == ".":
                previous_line += 1
                sentence += " " + doc_lines[previous_line]
            docs.append(sentence.replace("\n",""))
    
    with open(query_file) as f:
        query_lines = f.readlines()
    
    queries = []
    for index, line in enumerate(query_lines):
        previous_line = index-1
        if query_lines[previous_line][:2] == ".W":
            sentence = ""
            while not previous_line+1 == len(query_lines) and not query_lines[previous_line+1][:1] == ".":
                previous_line += 1
                sentence += " " + query_lines[previous_line]
            queries.append(sentence.replace("\n",""))
    
    with open(relevance_file) as f:
        relevance_lines = f.readlines()

    relevance_dict = defaultdict(def_value)
    relevance_dict_rated = {}
    for line in relevance_lines:
        indexes = line.split()
        relevance_dict[int(indexes[0])-1].append(int(indexes[1])-1)
        relevance_dict_rated[str(int(indexes[0])-1) + "-" + str(int(indexes[1])-1)] = int(indexes[2])


    return docs, queries, relevance_dict, relevance_dict_rated

In [113]:
docs2, queries2, relevance_dict2, relevance_dict_rated = prepare_data("cran/cran.all.1400","cran/cran.qry","cran/cranqrel")

In [77]:
corpus_embedded = BERT_model.encode(docs2)
corpus_token_embedded = get_corpus_token_embeddings(BERT_model,docs2)

In [145]:
query_idx = 87
print(len(relevance_dict2[query_idx]))

query = queries2[query_idx]
top_k_ranked = document_retrieval_model(query,corpus_embedded,corpus_token_embedded,BERT_model)
for idx in top_k_ranked:
    if idx in relevance_dict2[query_idx]:
        print(relevance_dict_rated[str(query_idx) + "-" + str(idx)])
    else:
        print("__")

# print("QUESTION:")
# print(queries2[query_idx])
# for i in range(3):
#     print("ANSWER:")
#     print(docs2[top_k_ranked[i]])

7
-1
__
__
2
4
__
3
__
__
__
__
__
3
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
__
