# SET MODE:

In [None]:
SCORING = False                     # Choose True or False. Passed as the `scoring` argument of train() and test().
REPRESENTATION = "tf_idf"           # Choose from ["tf_idf", "non_cont_word_emb", "bart_tokenized"]. Used to build the data and model file names.
BART_EMB_TYPE = None                # Select only if REPRESENTATION = "bart_tokenized"! Choose "word" for the average word embedding or "doc" for the <EOS> embedding.

## Imports

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pickle5 as pickle
from collections import defaultdict
from torch.optim import AdamW
import numpy as np
from tqdm import tqdm
from transformers import BartModel, BartTokenizerFast
from scipy import sparse

In [5]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## BART Helper

In [7]:
def get_bart_embeddings(batch, embedding_type, bart):
    """Embed a batch of tokenized texts with BART without tracking gradients.

    Args:
        batch: Tensor indexed as ``batch[:, 0]`` for the input ids and
            ``batch[:, 1]`` for the attention mask.
        embedding_type: ``"doc"`` to return the hidden state at the last
            attended position of every sequence, ``"word"`` to return the
            mean of the hidden states over the sequence dimension.
        bart: Model called as ``bart(input_ids=..., attention_mask=...)``
            whose output exposes ``last_hidden_state``.

    Returns:
        One embedding per sequence (sequence dimension removed).

    Raises:
        ValueError: If ``embedding_type`` is neither ``"doc"`` nor ``"word"``.
            Previously the function silently returned ``None`` in that case.
    """
    # Validate before running the (expensive) BART forward pass.
    if embedding_type not in ("doc", "word"):
        raise ValueError(
            f'embedding_type must be "doc" or "word", got {embedding_type!r}'
        )

    with torch.no_grad():
        input_ids = batch[:, 0].to(device)
        attention_mask = batch[:, 1].to(device)

        doc_embeds = bart(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state

        if embedding_type == "word":
            # NOTE(review): the mean runs over every position, including the
            # ones the attention mask switches off. Kept as is so features
            # stay comparable with already trained models - confirm whether
            # a masked mean is wanted.
            return torch.mean(doc_embeds, dim=1)

        # "doc": pick the hidden state at index (number of attended tokens - 1).
        # Presumably this is the <EOS> token, which assumes right-padded
        # sequences - TODO confirm.
        eos_positions = torch.sum(attention_mask, dim=-1) - 1
        gather_index = eos_positions.view(-1, 1, 1).expand(-1, 1, doc_embeds.size(2))
        return torch.gather(doc_embeds, 1, gather_index).squeeze(1)

# Score Helpers

In [8]:
from scipy.sparse import csr_matrix
def pack_csrs(unpacked_data):
    """Rebuild the CSR matrices of every datapoint produced by ``unpack_csrs``.

    Each datapoint keeps its first two entries and its last entry untouched;
    entries 2 and 3 are ``((data, indices, indptr), shape)`` tuples that are
    turned back into ``csr_matrix`` objects.
    """
    return [
        [
            point[0],
            point[1],
            csr_matrix(*point[2]),
            csr_matrix(*point[3]),
            point[-1],
        ]
        for point in unpacked_data
    ]

def unpack_csrs(data):
    """Replace the CSR matrices of every datapoint by plain tuples.

    Entries 2 and 3 of each datapoint are converted to
    ``((data, indices, indptr), shape)``; the first two entries and the last
    entry are copied over unchanged.
    """
    def as_tuple(matrix):
        return ((matrix.data, matrix.indices, matrix.indptr), matrix.shape)

    return [
        [point[0], point[1], as_tuple(point[2]), as_tuple(point[3]), point[-1]]
        for point in data
    ]

In [9]:
from operator import pos
def get_avg_doc_len(bows):
    """Return the mean row sum of the sparse vectors stored at index 3 of each entry."""
    doc_matrix = sparse.vstack([entry[3] for entry in bows])
    # One total per stacked row (i.e. per document vector).
    totals_per_doc = doc_matrix.sum(axis=-1)
    return totals_per_doc.mean()

def get_bim_weights(bows):
    """Compute one weight per vocabulary column from the positively labelled entries.

    ``N`` is the number of distinct values at index 1 of the entries. The
    per-column counts are the column sums of the index-3 vectors of all
    entries whose last element is truthy. The weight of a column with count
    ``c`` is ``log((N - c + 0.5) / (c + 0.5) + 1)``.
    """
    num_docs = len({entry[1] for entry in bows})
    positive_rows = [entry[3] for entry in bows if bool(entry[-1])]
    column_totals = sparse.vstack(positive_rows).sum(axis=0)
    pos_counts = np.asarray(column_totals).squeeze()
    ratio = (num_docs - pos_counts + 0.5) / (pos_counts + 0.5)
    return np.log(ratio + 1)

In [10]:
# Load the Bag-of-Words representations and calculate the weights for BM25.
# The files are opened in `with` blocks so the handles are closed again.
# NOTE(review): this reads from "data/", while train() reads the same file
# name from "preprocessed_data/" - confirm which directory is intended.
with open("data/train_count_vector_unpacked.pickle", "rb") as bows_file:
    bows = pack_csrs(pickle.load(bows_file))
AVG_DOC_LEN = get_avg_doc_len(bows)
BIM_WEIGHTS = get_bim_weights(bows)
del bows  # only the two statistics above are needed afterwards

# Load the pretrained Logistic Regression model that matches the representation.
if REPRESENTATION == "tf_idf":
    logreg_path = "./models/LR_tfidf_fit.pickle"
else:
    logreg_path = "./models/LR_emb_fit.pickle"
with open(logreg_path, "rb") as logreg_file:
    LOGREG = pickle.load(logreg_file)

In [10]:
def compute_bm25(query_vec, doc_vec, k=1.5, b=0.75):
    """Score every (query, document) row pair of a batch with BM25.

    For each row the score is the sum, over all vocabulary indices that are
    non-zero in the query, of

        BIM_WEIGHTS[t] * c_t * (k + 1) / (c_t + k * (1 - b) + k * b * len / AVG_DOC_LEN)

    where ``c_t`` is the document's value at index ``t`` and ``len`` is the
    sum of the document vector.

    The previous implementation took ``np.nonzero(t)[0]`` of a torch tensor,
    which selects only the *first* non-zero index, so all other query terms
    were ignored. Every query term is now included.
    NOTE(review): this changes the BM25 feature values, so models trained on
    the old feature may need retraining.

    Args:
        query_vec: Batch of 1-D count vectors, one row per query.
        doc_vec: Batch of 1-D count vectors, one row per document, aligned
            with ``query_vec``.
        k: BM25 term-frequency saturation parameter.
        b: BM25 length-normalisation parameter.

    Returns:
        CPU tensor of dtype float64 and shape ``(batch, 1)``. Rows whose
        query has no non-zero entry score 0.
    """
    bim_weights = torch.as_tensor(BIM_WEIGHTS, dtype=torch.float64)
    avg_doc_len = float(AVG_DOC_LEN)

    relevances = []
    for query, doc in zip(query_vec, doc_vec):
        query = torch.as_tensor(query)
        doc = torch.as_tensor(doc).to(torch.float64)

        # Indices of all words occurring in the query.
        term_ids = torch.nonzero(query).flatten().to(doc.device)
        if term_ids.numel() == 0:
            relevances.append(torch.zeros((), dtype=torch.float64))
            continue

        counts = doc[term_ids]
        weights = bim_weights.to(doc.device)[term_ids]
        length_norm = k * (1 - b) + k * b * (doc.sum() / avg_doc_len)
        frac = (counts * (k + 1)) / (counts + length_norm)
        relevances.append(torch.sum(frac * weights).cpu())

    if not relevances:
        return torch.zeros((0, 1), dtype=torch.float64)
    return torch.stack(relevances).unsqueeze(-1)
    
def get_batch_LR_proba(query_vecs, doc_vecs, logreg):
    """Run the fitted logistic regression on concatenated query/doc vectors.

    Args:
        query_vecs: 2-D tensor with one query vector per row.
        doc_vecs: 2-D tensor with one document vector per row; must have the
            same size as ``query_vecs``.
        logreg: Fitted model exposing ``predict_proba``.

    Returns:
        Tensor built from column 0 of ``logreg.predict_proba`` for every row.

    Raises:
        ValueError: If the two inputs differ in size.
    """
    if query_vecs.shape != doc_vecs.shape:
        raise ValueError('Arrays are not of the same size')

    # Each row becomes [query | doc] before it is handed to the classifier.
    features = torch.cat((query_vecs, doc_vecs), dim=1)
    first_column = logreg.predict_proba(features)[:, 0]
    return torch.tensor(first_column)

def compute_cosine_similarity(query_vec, doc_vec):
    """Return the row-wise cosine similarity (along dim 1) of two batches."""
    return F.cosine_similarity(query_vec, doc_vec, dim=1, eps=1e-6)


def compute_jaccard_similarity(query_vec, doc_vec):
    """Return the row-wise Jaccard similarity of the positive entries.

    An index counts as present when its value is greater than 0. Rows where
    neither vector has a positive entry yield 0 instead of NaN.
    """
    in_query = query_vec > 0
    in_doc = doc_vec > 0

    shared = (in_query & in_doc).sum(dim=1)
    combined = (in_query | in_doc).sum(dim=1)

    # 0 / 0 produces NaN for empty rows; map those to 0.
    return torch.nan_to_num(shared / combined, nan=0.0)

## NN Model

In [6]:
class Net(nn.Module):
    """Pairwise ranking network built from four linear layers.

    Both inputs pass through the same three ReLU layers (fc1-fc3); the
    difference of the two results goes through fc4 and a sigmoid.
    """

    def __init__(self, input_size, scoring=False):
        """Create the layers; ``scoring=True`` selects the smaller hidden sizes."""
        super().__init__()
        # Hidden widths of fc1, fc2 and fc3.
        widths = (16, 8, 8) if scoring else (256, 64, 32)
        self.fc1 = nn.Linear(input_size, widths[0])
        self.fc2 = nn.Linear(widths[0], widths[1])
        self.fc3 = nn.Linear(widths[1], widths[2])
        self.fc4 = nn.Linear(widths[2], 1)

    def forward(self, x, y):
        """Take the query-doc concatenation of document x and of document y.

        Returns the sigmoid of fc4 applied to (encoded x - encoded y), i.e. a
        value between 0 and 1 per row.
        """
        def encode(vec):
            # Shared tower: the same three layers are used for both inputs.
            for layer in (self.fc1, self.fc2, self.fc3):
                vec = F.relu(layer(vec))
            return vec

        difference = encode(x) - encode(y)
        return torch.sigmoid(self.fc4(difference))

In [None]:
def visualize_net():
    """Render the autograd graph of a small ``Net`` with torchviz.

    The graph is rendered under the name "nn" in PNG format.
    """
    from torchviz import make_dot

    # Net.forward takes the two inputs of a pair, so two dummy vectors are
    # needed; calling it with a single argument raised a TypeError.
    x = torch.ones(10, requires_grad=True)
    y = torch.ones(10, requires_grad=True)
    net = Net(10)
    pred = net(x, y)
    make_dot(pred, params=dict(net.named_parameters())).render("nn", format="png")

## Load the data: Representations

In [13]:
class PairwiseDataset(Dataset):
    """Dataset of (query, doc1, doc2) pairs for pairwise training.

    The pickled file holds one record per query-document pair, laid out as
    ``[qid, docid, query_vector, doc_vector, label]``. Records are grouped by
    ``qid``; every query must have exactly two documents with different
    labels. Each resulting datapoint is

        [qid, doc1_id, doc2_id, query_vector, doc1_vector, doc2_vector, label]

    where ``label`` is ``False`` when doc1 is the positive document and
    ``True`` otherwise.

    When ``bows_filename`` is given, the Bag-of-Words matrices of the query
    and both documents are loaded as well and appended to the
    representations returned by ``__getitem__``.
    """

    def __init__(self, filename, bows_filename=None):
        """Load and pair the data.

        Args:
            filename: Pickle file with the representation records.
            bows_filename: Optional pickle file with the unpacked
                Bag-of-Words records, aligned one-to-one with ``filename``.

        Raises:
            AssertionError: If the two files are misaligned or a query does
                not have exactly two documents with different labels.
            ValueError: If an empty vector must be zero-filled but no
                non-empty vector exists to take the size from.
        """
        # `with` makes sure the file handles are closed again.
        with open(filename, "rb") as handle:
            raw_data = [list(elem) for elem in pickle.load(handle)]

        raw_bows = None
        if bows_filename:
            with open(bows_filename, "rb") as handle:
                raw_bows = pack_csrs(pickle.load(handle))
            assert len(raw_data) == len(raw_bows)

        # Create dictionary with {qid: [[docid, query_vector, doc_vector, label], ...]}
        # For the training and dev set there are always two entries per qid,
        # because there is exactly one positive and one negative sample.
        sorted_data = defaultdict(list)
        sorted_bow_data = defaultdict(list)
        for i, item in enumerate(raw_data):
            sorted_data[item[0]].append(item[1:])
            if bows_filename:
                assert item[0] == raw_bows[i][0]
                sorted_bow_data[item[0]].append(raw_bows[i][1:])
        del raw_data
        del raw_bows

        self.data = []
        self.bows = []
        for key, value in sorted_data.items():
            assert len(value) == 2  # Make sure there are really just two docs
            # Both documents must belong to the same query (same query_vector).
            # NOTE(review): np.array_equal treats NaN as unequal by default, so
            # a query stored as a NaN placeholder would fail here before the
            # zero-filling below runs - confirm whether that can occur.
            assert np.array_equal(value[0][1], value[1][1])
            assert value[0][-1] != value[1][-1]  # Make sure they have different labels
            if bows_filename:
                assert len(sorted_bow_data[key]) == 2
                # Make sure the BoW belongs to the same datapoint as the representation vector
                assert value[0][0] == sorted_bow_data[key][0][0]

            # Append: [qid, doc1_id, doc2_id, query_vector, doc1_vector, doc2_vector, label]
            # where label is 0 when doc1 is the positive one and 1 otherwise
            self.data.append([
                key,
                value[0][0],
                value[1][0],
                value[0][1],
                value[0][2],
                value[1][2],
                not bool(value[0][-1]),
            ])
            if bows_filename:
                # [query_bow, doc1_bow, doc2_bow]
                self.bows.append([
                    sorted_bow_data[key][0][1],
                    sorted_bow_data[key][0][2],
                    sorted_bow_data[key][1][2],
                ])

        # Fix empty entries resulting from empty queries/docs: they are stored
        # as a float placeholder and are replaced by an all-zero vector.
        for column in (3, 4, 5):
            zero_size = None
            for entry in self.data:
                if isinstance(entry[column], float):
                    if zero_size is None:
                        zero_size = self._reference_size(column)
                    entry[column] = np.zeros(zero_size)

    def _reference_size(self, column):
        """Return the size of the first non-placeholder vector in ``column``.

        Looking for the first usable vector (instead of always reading
        datapoint 0) keeps zero-filling working when the first datapoint is
        itself empty.
        """
        for entry in self.data:
            if not isinstance(entry[column], float):
                return np.size(entry[column])
        raise ValueError(
            f"Cannot zero-fill column {column}: it contains no non-empty vector"
        )

    def __len__(self):
        """Return the number of (query, doc1, doc2) datapoints."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(ids, representations, label)`` for datapoint ``idx``.

        ``ids`` is ``[qid, doc1_id, doc2_id]``; ``representations`` holds the
        query, doc1 and doc2 vectors as arrays, followed by the dense
        query/doc1/doc2 Bag-of-Words vectors when those were loaded.
        """
        ids = self.data[idx][0:3]
        representations = [np.array(item) for item in self.data[idx][3:-1]]
        label = self.data[idx][-1]
        if self.bows:
            representations.extend(
                [item.toarray().squeeze() for item in self.bows[idx]]
            )
        return ids, representations, label

# Training

In [14]:
def train(representation, scoring, epochs=2, batch_size=16, bart_emb_type=None):
    """Train the pairwise ``Net`` and save it under ``./models``.

    Args:
        representation: Representation name (e.g. "tf_idf"); used to build
            the data and model file names.
        scoring: If true, the net is fed the scoring features (cosine,
            Jaccard, BM25 and - without BART - the logistic-regression
            probability) of each query-doc pair instead of the concatenated
            raw vectors.
        epochs: Number of passes over the training set.
        batch_size: Batch size of the training DataLoader.
        bart_emb_type: ``None`` for precomputed vectors, otherwise the
            embedding type handed to ``get_bart_embeddings``.
    """
    if scoring:
        train_dataset = PairwiseDataset(
            filename=f"preprocessed_data/train_{representation}.pickle",
            bows_filename=f"preprocessed_data/train_count_vector_unpacked.pickle",
        )
        # The input vector holds one entry per scoring function: three with
        # BART (no logistic regression), four otherwise.
        if bart_emb_type:
            vector_size = 3
        else:
            vector_size = 4
    else:
        train_dataset = PairwiseDataset(filename=f"preprocessed_data/train_{representation}.pickle")
        if bart_emb_type:
            vector_size = 768 * 2  # Chosen BART embedding size, for query + doc
        else:
            # Size of the representations, times two for the query-doc concatenation
            vector_size = train_dataset.data[0][3].shape[-1] * 2

    if bart_emb_type:
        bart = BartModel.from_pretrained("facebook/bart-base").to(device)
        bart.eval()

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    net = Net(vector_size, scoring).to(device)

    criterion = nn.BCELoss()  # Binary Cross Entropy Loss
    optimizer = AdamW(net.parameters())
    for epoch in range(epochs):  # loop over the dataset multiple times
        net.train()
        running_loss = 0.0
        for i, data in enumerate(tqdm(train_dataloader)):
            _, reps, labels = data

            if bart_emb_type:
                reps[0] = get_bart_embeddings(reps[0], bart_emb_type, bart)
                reps[1] = get_bart_embeddings(reps[1], bart_emb_type, bart)
                reps[2] = get_bart_embeddings(reps[2], bart_emb_type, bart)

            if scoring:
                # reps[0:3] are the query/doc1/doc2 representations,
                # reps[3:6] the matching Bag-of-Words vectors.
                cosine1 = compute_cosine_similarity(reps[0], reps[1]).unsqueeze(-1).to(device)
                cosine2 = compute_cosine_similarity(reps[0], reps[2]).unsqueeze(-1).to(device)
                jacc1 = compute_jaccard_similarity(reps[3], reps[4]).unsqueeze(-1).to(device)
                jacc2 = compute_jaccard_similarity(reps[3], reps[5]).unsqueeze(-1).to(device)
                bm25_1 = compute_bm25(reps[3], reps[4]).to(device)
                bm25_2 = compute_bm25(reps[3], reps[5]).to(device)

                if not bart_emb_type:
                    log_prob1 = get_batch_LR_proba(reps[0], reps[1], LOGREG).unsqueeze(-1).to(device)
                    log_prob2 = get_batch_LR_proba(reps[0], reps[2], LOGREG).unsqueeze(-1).to(device)
                    inputs1 = torch.concat((cosine1, jacc1, bm25_1, log_prob1), dim=1).to(device)
                    inputs2 = torch.concat((cosine2, jacc2, bm25_2, log_prob2), dim=1).to(device)
                else:
                    inputs1 = torch.concat((cosine1, jacc1, bm25_1), dim=1).to(device)
                    inputs2 = torch.concat((cosine2, jacc2, bm25_2), dim=1).to(device)

            else:
                # as_tensor avoids copy-constructing a tensor from a tensor.
                query_vec = torch.as_tensor(reps[0]).to(device)
                doc1_vec = torch.as_tensor(reps[1]).to(device)
                doc2_vec = torch.as_tensor(reps[2]).to(device)

                # concatenate query and doc representation
                inputs1 = torch.concat((query_vec, doc1_vec), dim=-1)
                inputs2 = torch.concat((query_vec, doc2_vec), dim=-1)

            # zero the parameter gradients
            optimizer.zero_grad()
            outputs = net(inputs1.float(), inputs2.float())
            # squeeze(-1) only drops the feature dimension. A plain squeeze()
            # would also drop the batch dimension of a final batch of size 1,
            # which BCELoss rejects against the (1,)-shaped labels.
            loss = criterion(outputs.squeeze(-1), labels.float().to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
                running_loss = 0.0

            # The checkpoint of the current epoch is overwritten after every batch.
            if scoring:
                torch.save(net.state_dict(), f=f'./models/checkpoints/{representation}_pairwise_scoring_{epoch}.model')
            else:
                torch.save(net.state_dict(), f=f'./models/checkpoints/{representation}_pairwise_{epoch}.model')

        # Save the model at the end of every epoch (overwrites the previous epoch's file).
        if scoring:
            torch.save(net.state_dict(), f=f'./models/{representation}_pairwise_scoring.model')
        else:
            torch.save(net.state_dict(), f=f'./models/{representation}_pairwise.model')


In [None]:
# Train the pairwise model with the mode selected by SCORING, REPRESENTATION and BART_EMB_TYPE.
train(REPRESENTATION, SCORING, epochs=2, batch_size=120, bart_emb_type=BART_EMB_TYPE)

## Test

In [14]:
class TestingDataset(Dataset):
    """Dataset that pairs every test document with a per-query reference document.

    The pickled file holds records whose first four entries are
    ``[qid, docid, query_vector, doc_vector]``. For every query a reference
    document (id "-") is built: the mean of the query's document vectors, or
    an empty tokenized sequence when ``bart`` is given. Each document then
    yields one datapoint

        [qid, "-", doc_id, query_vector, reference_vector, doc_vector]

    When ``bows_filename`` is given, the dense Bag-of-Words vectors of the
    query, the reference (mean of the query's document BoWs) and the document
    are appended to the representations returned by ``__getitem__``.
    """

    def __init__(self, filename, bows_filename=None, bart=None):
        """Load the data and build the reference comparisons.

        Args:
            filename: Pickle file with the representation records.
            bows_filename: Optional pickle file with the unpacked
                Bag-of-Words records, aligned one-to-one with ``filename``.
            bart: Anything other than ``None`` switches the reference
                document to an empty BART-tokenized sequence.

        Raises:
            AssertionError: If the two files are misaligned.
            ValueError: If an empty vector must be zero-filled but no
                non-empty vector exists to take the size from.
        """
        # `with` makes sure the file handles are closed again.
        with open(filename, "rb") as handle:
            # Cast tuples to lists so they can be edited in place.
            raw_data = [list(elem) for elem in pickle.load(handle)]

        # Fix empty entries resulting from empty queries/docs: they are stored
        # as a float placeholder and are replaced by an all-zero vector.
        for column in (2, 3):
            zero_size = None
            for record in raw_data:
                if isinstance(record[column], float):
                    if zero_size is None:
                        zero_size = self._reference_size(raw_data, column)
                    record[column] = np.zeros(zero_size)

        # Initialised up front so the `del` below also works without BoWs
        # (it used to raise UnboundLocalError when bows_filename was None).
        raw_bows = None
        if bows_filename:
            with open(bows_filename, "rb") as handle:
                raw_bows = pack_csrs(pickle.load(handle))
            assert len(raw_data) == len(raw_bows)

        # Group by qid: {qid: [[docid, query_vector, doc_vector, ...], ...]}
        sorted_data = defaultdict(list)
        sorted_bow_data = defaultdict(list)
        for i, item in enumerate(raw_data):
            sorted_data[item[0]].append(item[1:])
            if bows_filename:
                assert item[0] == raw_bows[i][0], f"{item[0]} {raw_bows[i][0]}"
                sorted_bow_data[item[0]].append(raw_bows[i][1:])
        del raw_data
        del raw_bows

        if bart is not None:
            tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")
            empty_seq = tokenizer("", truncation=True, padding=True, max_length=512)
            empty_seq = (empty_seq["input_ids"], empty_seq["attention_mask"])

        self.data = []
        self.bows = []
        for key, value in sorted_data.items():
            if bart is not None:
                reference_doc_vec = empty_seq
            else:
                # Average of all document vectors of this query
                reference_doc_vec = np.mean(np.vstack([doc[2] for doc in value]), axis=0)
                assert reference_doc_vec.size == value[0][2].size

            if bows_filename:
                reference_bow_vec = sparse.vstack(
                    [doc[2] for doc in sorted_bow_data[key]]
                ).mean(axis=0).A

            # Reference "document": [id, query_vector, reference_vector]
            reference_doc = ["-", value[0][1], reference_doc_vec]

            # Compare every document of the query with the reference
            for i, doc in enumerate(value):
                # Append: [qid, doc1_id, doc2_id, query_vector, doc1_vector, doc2_vector]
                self.data.append([
                    key,
                    reference_doc[0],
                    doc[0],
                    reference_doc[1],
                    reference_doc[2],
                    doc[2],
                ])
                if bows_filename:
                    assert doc[0] == sorted_bow_data[key][i][0]
                    # [query_bow, reference_bow, doc_bow] as dense vectors
                    self.bows.append([
                        sorted_bow_data[key][i][1].toarray().squeeze(),
                        reference_bow_vec.squeeze(),
                        sorted_bow_data[key][i][2].toarray().squeeze(),
                    ])

    @staticmethod
    def _reference_size(records, column):
        """Return the size of the first non-placeholder vector in ``column``.

        Searching for a usable record (instead of always reading record 1)
        keeps zero-filling working for short files and when that record is
        itself empty.
        """
        for record in records:
            if not isinstance(record[column], float):
                return np.size(record[column])
        raise ValueError(
            f"Cannot zero-fill column {column}: it contains no non-empty vector"
        )

    def __len__(self):
        """Return the number of (reference, document) comparisons."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(ids, representations)`` for datapoint ``idx``.

        ``ids`` is ``[qid, "-", doc_id]``; ``representations`` holds the
        query, reference and document vectors as arrays, followed by the
        Bag-of-Words vectors when those were loaded.
        """
        ids = self.data[idx][0:3]
        representations = [np.array(item) for item in self.data[idx][3:]]
        if self.bows:
            representations.extend(self.bows[idx])
        return ids, representations

In [15]:
def test(representation, scoring, epochs=2, batch_size=16, bart_emb_type=None):
    """Score the test set with a trained pairwise ``Net`` and pickle the result.

    Every test document is compared with its query's reference document; the
    stored score is the net's output for (reference, document). The result is
    a dict ``{query_id: [(doc_id, score), ...]}`` written to
    ``model_predictions/``.

    Args:
        representation: Representation name (e.g. "tf_idf"); used to build
            the data, model and prediction file names.
        scoring: Must match the value the model was trained with.
        epochs: Unused; kept so existing calls keep working.
        batch_size: Batch size of the test DataLoader.
        bart_emb_type: ``None`` for precomputed vectors, otherwise the
            embedding type handed to ``get_bart_embeddings``.
    """
    import os

    def as_list(values):
        # Collated ids may arrive as a tensor; plain Python values are needed
        # so that equal ids end up under the same dictionary key.
        return values.tolist() if torch.is_tensor(values) else list(values)

    test_dataset = TestingDataset(
        filename=f"preprocessed_data/test_{representation}.pickle",
        bows_filename=f"preprocessed_data/test_count_vector_unpacked.pickle",
        bart=bart_emb_type,
    )
    if scoring:
        # Number of scoring functions is the size of the input vector.
        vector_size = 3 if bart_emb_type else 4
        legacy_model_file = f'./models/{representation}_word_pairwise_scoring.model'
        trained_model_file = f'./models/{representation}_pairwise_scoring.model'
    else:
        if bart_emb_type:
            vector_size = 768 * 2  # BART embedding size, for query + doc
        else:
            # Size of the representations, times two because document and
            # query vector are always concatenated.
            vector_size = test_dataset.data[0][3].shape[-1] * 2
        legacy_model_file = f'./models/{representation}_word_pairwise.model'
        trained_model_file = f'./models/{representation}_pairwise.model'

    # The "_word" name was hard-coded here, but train() saves the model
    # without that infix. Keep loading the "_word" file when it exists and
    # otherwise fall back to the file train() writes.
    model_file = legacy_model_file if os.path.exists(legacy_model_file) else trained_model_file

    if bart_emb_type:
        bart = BartModel.from_pretrained("facebook/bart-base").to(device)
        bart.eval()

    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)
    net = Net(vector_size, scoring).to(device)
    # map_location lets a model saved on the GPU be loaded on a CPU-only machine.
    net.load_state_dict(torch.load(model_file, map_location=device))

    # Save here the query ids, doc ids and model scores for evaluation
    query_ids = []
    doc_ids = []
    scores = []

    with torch.no_grad():
        net.eval()
        for i, data in enumerate(tqdm(test_dataloader)):
            ids, reps = data

            if bart_emb_type:
                reps[0] = get_bart_embeddings(reps[0], bart_emb_type, bart)
                reps[1] = get_bart_embeddings(reps[1], bart_emb_type, bart)
                reps[2] = get_bart_embeddings(reps[2], bart_emb_type, bart)

            if scoring:
                # reps[0:3] are the query/reference/doc representations,
                # reps[3:6] the matching Bag-of-Words vectors.
                cosine1 = compute_cosine_similarity(reps[0], reps[1]).unsqueeze(-1).to(device)
                cosine2 = compute_cosine_similarity(reps[0], reps[2]).unsqueeze(-1).to(device)
                jacc1 = compute_jaccard_similarity(reps[3], reps[4]).unsqueeze(-1).to(device)
                jacc2 = compute_jaccard_similarity(reps[3], reps[5]).unsqueeze(-1).to(device)
                bm25_1 = compute_bm25(reps[3], reps[4]).to(device)
                bm25_2 = compute_bm25(reps[3], reps[5]).to(device)

                if not bart_emb_type:
                    log_prob1 = get_batch_LR_proba(reps[0], reps[1], LOGREG).unsqueeze(-1).to(device)
                    log_prob2 = get_batch_LR_proba(reps[0], reps[2], LOGREG).unsqueeze(-1).to(device)
                    inputs1 = torch.concat((cosine1, jacc1, bm25_1, log_prob1), dim=1).to(device)
                    inputs2 = torch.concat((cosine2, jacc2, bm25_2, log_prob2), dim=1).to(device)
                else:
                    inputs1 = torch.concat((cosine1, jacc1, bm25_1), dim=1).to(device)
                    inputs2 = torch.concat((cosine2, jacc2, bm25_2), dim=1).to(device)

            else:
                # as_tensor avoids copy-constructing a tensor from a tensor.
                query_vec = torch.as_tensor(reps[0]).to(device)
                doc1_vec = torch.as_tensor(reps[1]).to(device)
                doc2_vec = torch.as_tensor(reps[2]).to(device)

                # concatenate query and doc representation
                inputs1 = torch.concat((query_vec, doc1_vec), dim=-1)
                inputs2 = torch.concat((query_vec, doc2_vec), dim=-1)

            outputs = net(inputs1.float(), inputs2.float())
            outputs = outputs.cpu().numpy()

            # Score of doc2 being more relevant than the reference. reshape(-1)
            # always gives a list, also for a final batch of size 1 (squeeze()
            # produced a bare float there, which extend() cannot take).
            scores.extend(outputs.reshape(-1).tolist())
            query_ids.extend(as_list(ids[0]))
            doc_ids.extend(as_list(ids[2]))  # Add doc2 ID

    test_outputs = defaultdict(list)
    for query_id, doc_id, score in zip(query_ids, doc_ids, scores):
        test_outputs[query_id].append((doc_id, score))

    if scoring:
        filename = f'model_predictions/{representation}_pairwise_scoring_preds.pickle'
    else:
        filename = f'model_predictions/{representation}_pairwise_preds.pickle'

    with open(filename, 'wb') as handle:
        pickle.dump(test_outputs, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Score the test set with the mode selected by SCORING, REPRESENTATION and BART_EMB_TYPE.
# `epochs` is accepted by test() but not used by it.
test(REPRESENTATION, SCORING, epochs=2, batch_size=120, bart_emb_type=BART_EMB_TYPE)

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.68k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/532M [00:00<?, ?B/s]

  3%|▎         | 8/271 [00:37<20:29,  4.68s/it]