## Imports

In [3]:
from sentence_transformers import SentenceTransformer, InputExample, losses, models
from torch.utils.data import DataLoader
from collections import defaultdict
from scipy import spatial
import random
def def_value():
    return []

2022-11-14 23:32:39.413273: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Initialize model

In [4]:
word_embedding_model = models.Transformer('bert-base-uncased', max_seq_length=256)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())

model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## SimANS Implementation

In [106]:
def get_batches(model,docs,relevance_dict,batch_size):
    embeddings = model.encode(docs)
    input_batches = []
    for id in relevance_dict.keys():
        if id == 5:
            break
        if not relevance_dict[id] == []:
            cosine_sims = [1 - spatial.distance.cosine(embedding, embeddings[id]) for embedding in embeddings]
            positive = random.choice(relevance_dict[id])
            negatives = []
            sim_ranked_indexes = [index for _, index in sorted(zip(cosine_sims, range(len(cosine_sims))))]
            positive_index = sim_ranked_indexes.index(positive)
            index = 1
            while len(negatives) < batch_size-1:
                if positive_index-index > 0 and positive_index-index not in relevance_dict[id]:
                    if len(negatives) < batch_size-1:
                        negatives.append(positive_index-index)
                if positive_index+index < len(docs) and positive_index+index not in relevance_dict[id]:
                    if len(negatives) < batch_size-1:
                        negatives.append(positive_index+index)
                
                index += 1
            
            batch = [sim_ranked_indexes[neg] for neg in negatives]
            labels = [0.0 for _ in range(batch_size-1)]
            batch.append(positive)
            labels.append(1.0)

            zipped = list(zip(batch, labels))
            random.shuffle(zipped)
            batch, labels = zip(*zipped)

            input_batch = create_input_batch(batch,labels,docs,id)
            input_batches.append(input_batch)

    return [pair for batch in input_batches for pair in batch]
        

In [96]:
def create_input_batch(batches,batch_labels,docs,id):
    input_batch = []
    for i,doc in enumerate(batches):
        input_batch.append(InputExample(texts=[docs[id],docs[doc]],label=batch_labels[i]))
    return input_batch
        

## Data Pre-Processing

In [6]:

def prepare_data(filename):

    with open(filename) as f:
        lines = f.readlines()

    docs = []
    relevances_list = []
    for index, line in enumerate(lines):
        previous_line = index-1
        if lines[previous_line][:2] == ".W":
            sentence = ""
            while not previous_line+1 == len(lines) and not lines[previous_line+1][:1] == ".":
                previous_line += 1
                sentence += " " + lines[previous_line]
            docs.append(sentence)
        elif lines[previous_line][:2] == ".X":
            while not previous_line+1 == len(lines) and not lines[previous_line+1][0] == ".":
                previous_line += 1
                relevances_list.append(lines[previous_line])

    relevance_dict = defaultdict(def_value)
    for relevance in relevances_list:
        metadata = relevance.replace("\n","").split("\t")
        if not metadata[0] == metadata[2]:
            relevance_dict[int(metadata[2])-1].append(int(metadata[0])-1)

    return docs, relevance_dict

### Prepare Data

In [7]:
docs, relevance_dict = prepare_data("CISI.ALL")

## Train Model

In [104]:
def fine_tune_BERT(batch_size,epochs,calc_negatives_per_epoch):
    train_loss = losses.CosineSimilarityLoss(model)
    for epoch in range(epochs):
        print("epoch", epoch+1)
        if epoch % calc_negatives_per_epoch == 0:
            print("calculating Batches")
            input_batches = get_batches(model,docs,relevance_dict,batch_size)
        print("Training Model")
        train_dataloader = DataLoader(input_batches, shuffle=False, batch_size=batch_size)
        model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=100)


In [None]:
fine_tune_BERT(batch_size=20,epochs=8,calc_negatives_per_epoch=2)

## Retriever

In [10]:
def retriever(query,corpus_embedded,k):
    query_embedding = model.encode(query)
    cosine_sims = [1 - spatial.distance.cosine(embedding, query_embedding) for embedding in corpus_embedded]
    sim_ranked_indexes = [index for _, index in sorted(zip(cosine_sims, range(len(cosine_sims))))]
    print(sim_ranked_indexes)
    return sim_ranked_indexes[:k]

### Example

In [None]:
corpus_embedded = model.encode(docs)

In [11]:
top_k = retriever(docs[0],corpus_embedded,20)
print(top_k)

1
[1295, 1287, 1118, 1288, 54, 1289, 1085, 442, 1165, 1300, 1294, 1319, 1278, 1286, 1281, 429, 1099, 1305, 680, 1301]
