# SET MODE:

In [None]:
# SCORING: True  -> the network is fed hand-crafted relevance features
#                   (cosine, Jaccard, BM25 and, without BART, a logistic-regression probability);
#          False -> it is fed the concatenated query/document representations.
SCORING = False                     # Choose True or False
# REPRESENTATION selects which data/{split}_{REPRESENTATION}.pickle files are loaded.
REPRESENTATION = "tf_idf"           # Choose from ["tf_idf", "non_cont_word_emb", "bart_tokenized"]
# BART_EMB_TYPE must be set only when REPRESENTATION == "bart_tokenized":
#   "word" -> average word embedding, "doc" -> <EOS> embedding. Leave None otherwise.
BART_EMB_TYPE = None

# Fail fast on an inconsistent configuration instead of deep inside train()/test().
if REPRESENTATION not in ("tf_idf", "non_cont_word_emb", "bart_tokenized"):
    raise ValueError(f"Unknown REPRESENTATION {REPRESENTATION!r}")
if REPRESENTATION == "bart_tokenized":
    if BART_EMB_TYPE not in ("word", "doc"):
        raise ValueError('BART_EMB_TYPE must be "word" or "doc" when REPRESENTATION is "bart_tokenized"')
elif BART_EMB_TYPE is not None:
    raise ValueError('BART_EMB_TYPE must stay None unless REPRESENTATION is "bart_tokenized"')

# Imports

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pickle5 as pickle
from collections import defaultdict
from torch.optim import AdamW
import itertools
import numpy as np
from tqdm import tqdm
from transformers import BartModel
import random
from scipy import sparse

In [5]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## BART Helper

In [6]:
def get_bart_embeddings(batch, embedding_type, bart):
    """Encode a batch of tokenised texts with BART and pool one vector per text.

    Args:
        batch: tensor indexed as ``batch[:, 0]`` = input ids and
            ``batch[:, 1]`` = attention mask (presumably shape
            (batch, 2, seq_len) - confirm against the tokenisation pickles).
        embedding_type: ``"doc"`` returns the last hidden state at the final
            attended position (the <EOS> token for right-padded inputs);
            ``"word"`` returns the mean hidden state over attended
            (non-padding) positions.
        bart: a ``BartModel``; inputs are moved to the device its parameters
            live on.

    Returns:
        Tensor of shape (batch, hidden_size), computed without gradients.

    Raises:
        ValueError: if ``embedding_type`` is neither ``"doc"`` nor ``"word"``.
    """
    if embedding_type not in ("doc", "word"):
        raise ValueError(f'embedding_type must be "doc" or "word", got {embedding_type!r}')

    model_device = next(bart.parameters()).device
    with torch.no_grad():
        input_ids = batch[:, 0].to(model_device)
        attention_mask = batch[:, 1].to(model_device)
        hidden = bart(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        if embedding_type == "doc":
            # Index of the last attended token in each row (assumes right padding).
            eos_positions = attention_mask.sum(dim=-1) - 1
            rows = torch.arange(hidden.size(0), device=hidden.device)
            return hidden[rows, eos_positions]

        # Masked mean: padding positions must not dilute the average.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)
        return (hidden * mask).sum(dim=1) / token_counts

## Score Helpers

In [7]:
from scipy.sparse import csr_matrix
def pack_csrs(unpacked_data):
    """Rebuild scipy CSR matrices from their serialised component form.

    Each input record is indexable as
    ``[id_a, id_b, query_parts, doc_parts, ..., last]``, where each ``*_parts``
    is the ``((data, indices, indptr), shape)`` tuple produced by
    ``unpack_csrs``. Returns a list of 5-element lists
    ``[id_a, id_b, query_csr, doc_csr, last]``.
    """
    return [
        [record[0], record[1], csr_matrix(*record[2]), csr_matrix(*record[3]), record[-1]]
        for record in unpacked_data
    ]

def unpack_csrs(data):
    """Inverse of ``pack_csrs``: split the two CSR matrices of each record into
    picklable ``((data, indices, indptr), shape)`` tuples.

    Returns a list of 5-element lists ``[id_a, id_b, query_parts, doc_parts, last]``.
    """
    def components(matrix):
        return ((matrix.data, matrix.indices, matrix.indptr), matrix.shape)

    return [[row[0], row[1], components(row[2]), components(row[3]), row[-1]] for row in data]

In [8]:
from operator import pos
def get_avg_doc_len(bows):
    """Return the mean document length (total term count) over ``bows``.

    ``bows`` is a sequence of packed records whose element 3 is a sparse
    document count vector (presumably one row per record); each length is that
    row's sum.
    """
    doc_lengths = sparse.vstack([entry[3] for entry in bows]).sum(axis=-1)
    return doc_lengths.mean()

def get_bim_weights(bows):
    """Compute per-term BIM / BM25-style IDF weights.

    ``N`` is the number of distinct values at position 1 of the records
    (presumably document ids). For every term, ``n`` is the number of
    positively-labelled records whose document count vector (position 3)
    contains that term. Returns ``log((N - n + 0.5) / (n + 0.5) + 1)`` as a
    1-D numpy array over the vocabulary.

    Raises:
        ValueError: if no record has a truthy label.
    """
    n_docs = len({record[1] for record in bows})
    positive_docs = [record[3] for record in bows if record[-1]]
    if not positive_docs:
        raise ValueError("get_bim_weights needs at least one positively-labelled record")
    pos_matrix = sparse.vstack(positive_docs)
    # The BIM counts documents *containing* a term (binary presence), not total
    # term occurrences, so binarise before summing down the columns.
    doc_freq = (pos_matrix > 0).sum(axis=0).getA().squeeze()
    return np.log(((n_docs - doc_freq + 0.5) / (doc_freq + 0.5)) + 1)


In [9]:
# Load the training bag-of-words (count-vector) records and derive the corpus
# statistics that compute_bm25 reads as globals.
with open("data/train_count_vector_unpacked.pickle", "rb") as bows_file:
    bows = pack_csrs(pickle.load(bows_file))
AVG_DOC_LEN = get_avg_doc_len(bows)
BIM_WEIGHTS = get_bim_weights(bows)
del bows  # only the two statistics above are needed from here on

# Load the pretrained logistic regression whose input matches REPRESENTATION
# (TF-IDF vectors or embeddings); it is used as a feature when SCORING is on.
if REPRESENTATION == "tf_idf":
    with open("./models/LR_tfidf_fit.pickle", "rb") as logreg_file:
        LOGREG = pickle.load(logreg_file)
else:
    with open("./models/LR_emb_fit.pickle", "rb") as logreg_file:
        LOGREG = pickle.load(logreg_file)


## Scoring Functions

In [11]:
def compute_bm25(query_vec, doc_vec, k=1.5, b=0.75, bim_weights=None, avg_doc_len=None):
    """Okapi BM25 relevance of each document to its paired query.

    Args:
        query_vec: batch of query bag-of-words count vectors, one row per pair
            (torch tensor or array-like, shape (batch, vocab)).
        doc_vec: batch of document count vectors with the same shape.
        k: term-frequency saturation parameter (k1).
        b: document-length normalisation strength.
        bim_weights: per-term IDF weights; defaults to the module-level BIM_WEIGHTS.
        avg_doc_len: average document length; defaults to the module-level AVG_DOC_LEN.

    Returns:
        float64 tensor of shape (batch, 1). Every term with a non-zero query
        count contributes; a query with no such term scores 0.
    """
    if bim_weights is None:
        bim_weights = BIM_WEIGHTS
    if avg_doc_len is None:
        avg_doc_len = AVG_DOC_LEN

    queries = torch.as_tensor(query_vec)
    docs = torch.as_tensor(doc_vec).to(torch.float64)
    weights = torch.as_tensor(np.asarray(bim_weights, dtype=np.float64), device=docs.device)

    # 1.0 where the term occurs in the query, 0.0 elsewhere.
    in_query = (queries != 0).to(device=docs.device, dtype=torch.float64)
    doc_len = docs.sum(dim=-1, keepdim=True)
    length_norm = k * (1 - b) + k * b * doc_len / float(avg_doc_len)
    term_scores = docs * (k + 1) / (docs + length_norm)
    scores = (term_scores * weights * in_query).sum(dim=-1, keepdim=True)
    # b == 1 with an empty document gives 0/0; treat that as no relevance.
    return torch.nan_to_num(scores, nan=0.0)
    
def get_batch_LR_proba(query_vecs, doc_vecs, logreg):
    """Score concatenated (query, doc) vectors with a fitted logistic regression.

    Args:
        query_vecs: 2-D tensor of query vectors.
        doc_vecs: 2-D tensor of document vectors, same size as ``query_vecs``.
        logreg: fitted classifier exposing ``predict_proba``.

    Returns:
        1-D tensor holding column 0 of ``predict_proba`` for each row.
        NOTE(review): for a scikit-learn model that is P(classes_[0]), i.e. the
        probability of *non*-relevance if classes are [0, 1] - confirm intended.

    Raises:
        ValueError: if the two inputs differ in size.
    """
    if query_vecs.size() == doc_vecs.size():
        features = torch.concat((query_vecs, doc_vecs), dim=1)
        probabilities = logreg.predict_proba(features)
        return torch.tensor(probabilities[:, 0])
    raise ValueError('Arrays are not of the same size')

def compute_cosine_similarity(query_vec, doc_vec):
    """Row-wise cosine similarity between two batches of vectors (eps=1e-6)."""
    return F.cosine_similarity(query_vec, doc_vec, dim=1, eps=1e-6)


def compute_jaccard_similarity(query_vec, doc_vec):
    """Row-wise Jaccard similarity of the term sets of two batches of vectors.

    A term belongs to a row's set when its entry is > 0. Two empty sets give
    0/0, which is reported as 0.
    """
    as_set = torch.get_default_dtype()
    query_set = (query_vec > 0).to(as_set)
    doc_set = (doc_vec > 0).to(as_set)

    overlap = (query_set * doc_set).sum(dim=1)
    combined = torch.maximum(query_set, doc_set).sum(dim=1)
    return torch.nan_to_num(overlap / combined, nan=0.0)

## NN Model

In [13]:
class Net(nn.Module):
    """Four-layer ReLU MLP emitting one relevance logit per input row.

    The ``scoring`` variant (a handful of hand-crafted features as input) uses
    hidden widths 16-8-8; the representation variant uses 256-64-32.
    """

    def __init__(self, input_size, scoring):
        super().__init__()
        first, second, third = (16, 8, 8) if scoring else (256, 64, 32)
        # Attribute names fc1..fc4 are the state_dict keys of saved models.
        self.fc1 = nn.Linear(input_size, first)
        self.fc2 = nn.Linear(first, second)
        self.fc3 = nn.Linear(second, third)
        self.fc4 = nn.Linear(third, 1)

    def forward(self, x):
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            x = F.relu(hidden_layer(x))
        return self.fc4(x)

In [18]:
def visualize_net(input_size=10, scoring=False, filename="nn"):
    """Render the autograd graph of a freshly initialised ``Net`` to ``<filename>.png``.

    Args:
        input_size: width of the dummy input vector (default 10).
        scoring: which ``Net`` variant to draw (``Net`` requires this argument).
        filename: output file stem passed to ``render``.

    Requires the optional ``torchviz`` package, imported lazily so the rest of
    the notebook does not depend on it.
    """
    from torchviz import make_dot

    x = torch.ones(input_size, requires_grad=True)
    net = Net(input_size, scoring)
    pred = net(x)
    make_dot(pred, params=dict(net.named_parameters())).render(filename, format="png")

## Load the data: Representations

In [15]:
class PointwiseDataset(Dataset):
    """Pointwise (query, document, label) dataset backed by pickled representations.

    Each record in ``filename`` is indexable as
    ``[query_id, doc_id, query_rep, doc_rep, ..., label]``. When
    ``bows_filename`` is given, the dense bag-of-words count vectors of the
    record at the same index are appended to each item's representations
    (they feed the Jaccard and BM25 features).
    """

    def __init__(self, filename, bows_filename=None):
        with open(filename, "rb") as handle:
            self.data = [list(record) for record in pickle.load(handle)]

        # Empty queries/docs were stored as a bare float; replace them with zero
        # arrays shaped like a real representation from the same column so the
        # DataLoader can stack them.
        for column in (2, 3):
            if any(isinstance(record[column], float) for record in self.data):
                shape = self._first_array_shape(column)
                for record in self.data:
                    if isinstance(record[column], float):
                        record[column] = np.zeros(shape)

        self.bows = None
        if bows_filename:
            with open(bows_filename, "rb") as handle:
                self.bows = pack_csrs(pickle.load(handle))

    def _first_array_shape(self, column):
        """Return the shape of the first non-placeholder entry in ``column``."""
        for record in self.data:
            if not isinstance(record[column], float):
                return np.shape(record[column])
        raise ValueError(f"column {column} has no non-empty representation to size placeholders from")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        ids = record[0:2]
        label = record[-1]
        representations = [np.array(item) for item in record[2:4]]

        if self.bows:
            # NOTE(review): assumes the BOW pickle is row-aligned with ``filename``.
            bow_record = self.bows[idx]
            representations.extend([bow_record[2].toarray().squeeze(), bow_record[3].toarray().squeeze()])

        return ids, representations, label

## Training

In [16]:
def _pointwise_batch_inputs(reps, scoring, bart_emb_type, bart):
    """Build the network input matrix (on ``device``) for one collated batch.

    ``reps`` is the batched representation list yielded by PointwiseDataset:
    [query_rep, doc_rep], plus [query_bow, doc_bow] when BOW files were given.
    """
    reps = list(reps)
    if bart_emb_type:
        # Tokenised BART inputs become fixed-size embeddings first.
        reps[0] = get_bart_embeddings(reps[0], bart_emb_type, bart)
        reps[1] = get_bart_embeddings(reps[1], bart_emb_type, bart)

    if scoring:
        features = [
            compute_cosine_similarity(reps[0], reps[1]).unsqueeze(-1).to(device),
            compute_jaccard_similarity(reps[2], reps[3]).unsqueeze(-1).to(device),
            compute_bm25(reps[2], reps[3]).to(device),
        ]
        if not bart_emb_type:
            # The logistic regression works on CPU tensors; move its output so
            # every feature column is on the same device before concatenation.
            features.append(get_batch_LR_proba(reps[0], reps[1], LOGREG).unsqueeze(-1).to(device))
        return torch.concat(features, dim=1)

    query_rep = torch.as_tensor(reps[0]).to(device)
    doc_rep = torch.as_tensor(reps[1]).to(device)
    return torch.concat((query_rep, doc_rep), dim=1)


def train(representation, scoring, epochs=2, batch_size=16, bart_emb_type=None):
    """Train the pointwise relevance network for one representation.

    Args:
        representation: name used in the data file names (e.g. "tf_idf").
        scoring: if True, train on cosine/Jaccard/BM25(/LR) features instead of
            the concatenated representations.
        epochs: number of passes over the training set.
        batch_size: training and dev batch size.
        bart_emb_type: "word" or "doc" to embed tokenised inputs with BART,
            None otherwise.

    Side effects: saves a checkpoint per epoch under ./models/checkpoints/ and
    the final weights under ./models/, and prints train/dev losses.
    """
    if scoring:
        train_dataset = PointwiseDataset(filename=f"data/train_{representation}.pickle", bows_filename="data/train_count_vector_unpacked.pickle")
        dev_dataset = PointwiseDataset(f"data/dev_{representation}.pickle", bows_filename="data/dev_count_vector_unpacked.pickle")
        # cosine, Jaccard, BM25, plus the LR probability when BART is not used.
        vector_size = 3 if bart_emb_type else 4
    else:
        train_dataset = PointwiseDataset(filename=f"data/train_{representation}.pickle")
        dev_dataset = PointwiseDataset(f"data/dev_{representation}.pickle")
        if bart_emb_type:
            vector_size = 768 * 2  # BART embedding size, times two for query+doc concatenation
        else:
            vector_size = train_dataset.data[0][3].shape[-1] * 2

    bart = None
    if bart_emb_type:
        bart = BartModel.from_pretrained("facebook/bart-base").to(device)
        bart.eval()

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    dev_dataloader = DataLoader(dev_dataset, batch_size=batch_size)

    net = Net(vector_size, scoring).to(device)
    criterion = nn.BCEWithLogitsLoss()  # Binary Cross Entropy on raw logits
    optimizer = AdamW(net.parameters())

    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        for i, (_ids, reps, labels) in enumerate(tqdm(train_dataloader)):
            inputs = _pointwise_batch_inputs(reps, scoring, bart_emb_type, bart)

            optimizer.zero_grad()
            outputs = net(inputs.float())
            # BCEWithLogitsLoss requires targets shaped like the (batch, 1) logits;
            # collated labels may be (batch,), so reshape them to match.
            loss = criterion(outputs, labels.float().to(device).view_as(outputs))
            loss.backward()
            optimizer.step()

            # Report the mean loss of every 100 consecutive batches.
            running_loss += loss.item()
            if i % 100 == 99:
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
                running_loss = 0.0

        if scoring:
            torch.save(net.state_dict(), f=f'./models/checkpoints/{representation}_word_pointwise_scoring_{epoch}.model')
        else:
            torch.save(net.state_dict(), f=f'./models/checkpoints/{representation}_pointwise_{epoch}.model')

        # Evaluate on the dev set with the same input pipeline as training
        # (including BART embedding, which the dev loop previously skipped).
        net.eval()
        eval_loss = 0.0
        with torch.no_grad():
            for _ids, reps, labels in dev_dataloader:
                inputs = _pointwise_batch_inputs(reps, scoring, bart_emb_type, bart)
                outputs = net(inputs.float())
                eval_loss += criterion(outputs, labels.float().to(device).view_as(outputs)).item()
        print(f'Epoch {epoch + 1} dev loss: {eval_loss / max(len(dev_dataloader), 1)}')

    if scoring:
        torch.save(net.state_dict(), f=f'./models/{representation}_pointwise_scoring.model')
    else:
        torch.save(net.state_dict(), f=f'./models/{representation}_pointwise.model')
   


In [None]:
train(
    representation=REPRESENTATION,
    scoring=SCORING,
    epochs=2,
    batch_size=120,
    bart_emb_type=BART_EMB_TYPE,
)

## Test

In [None]:
def test(representation, scoring, batch_size, bart_emb_type=None):
    """Score the test split with a trained pointwise model and pickle the rankings.

    Args:
        representation: name used in the data/model file names (e.g. "tf_idf").
        scoring: whether the model was trained on the scoring features.
        batch_size: inference batch size.
        bart_emb_type: "word" or "doc" when using tokenised BART inputs, else None.

    Side effects: writes ``model_predictions/<...>_preds.pickle`` holding a
    mapping query_id -> list of (doc_id, score).
    """
    if scoring:
        test_dataset = PointwiseDataset(filename=f"./data/test_{representation}.pickle", bows_filename="./data/test_count_vector_unpacked.pickle")
        vector_size = 3 if bart_emb_type else 4
        model_file = f'./models/{representation}_pointwise_scoring.model'
    else:
        test_dataset = PointwiseDataset(filename=f"./data/test_{representation}.pickle")
        if bart_emb_type:
            vector_size = 768 * 2  # BART embedding size, times two for query+doc concatenation
        else:
            vector_size = test_dataset.data[0][3].shape[-1] * 2
        # train() writes the final non-scoring weights to ./models/, not ./models/checkpoints/.
        model_file = f'./models/{representation}_pointwise.model'

    bart = None
    if bart_emb_type:
        bart = BartModel.from_pretrained("facebook/bart-base").to(device)
        bart.eval()

    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)
    net = Net(vector_size, scoring).to(device)
    # map_location lets GPU-trained weights load on a CPU-only machine.
    net.load_state_dict(torch.load(model_file, map_location=device))
    net.eval()

    def as_python_list(column):
        # Integer ids collate into tensors, which hash by object identity;
        # convert them to plain Python values so they work as dict keys.
        return column.tolist() if torch.is_tensor(column) else list(column)

    query_ids, doc_ids, scores = [], [], []
    with torch.no_grad():
        for ids, reps, _labels in tqdm(test_dataloader):
            reps = list(reps)
            if bart_emb_type:
                reps[0] = get_bart_embeddings(reps[0], bart_emb_type, bart)
                reps[1] = get_bart_embeddings(reps[1], bart_emb_type, bart)

            if scoring:
                features = [
                    compute_cosine_similarity(reps[0], reps[1]).unsqueeze(-1).to(device),
                    compute_jaccard_similarity(reps[2], reps[3]).unsqueeze(-1).to(device),
                    compute_bm25(reps[2], reps[3]).to(device),
                ]
                if not bart_emb_type:
                    features.append(get_batch_LR_proba(reps[0], reps[1], LOGREG).unsqueeze(-1).to(device))
                inputs = torch.concat(features, dim=1)
            else:
                repr1 = torch.as_tensor(reps[0]).to(device)
                repr2 = torch.as_tensor(reps[1]).to(device)
                inputs = torch.concat((repr1, repr2), dim=1)

            outputs = net(inputs.float())
            # reshape(-1) keeps a list even for a final batch of size 1.
            scores.extend(outputs.reshape(-1).cpu().tolist())
            query_ids.extend(as_python_list(ids[0]))
            doc_ids.extend(as_python_list(ids[1]))

    test_outputs = defaultdict(list)
    for query_id, doc_id, score in zip(query_ids, doc_ids, scores):
        test_outputs[query_id].append((doc_id, score))

    if scoring:
        filename = f'model_predictions/{representation}_pointwise_scoring_preds.pickle'
    else:
        filename = f'model_predictions/{representation}_pointwise_preds.pickle'
    with open(filename, 'wb') as handle:
        pickle.dump(test_outputs, handle, protocol=pickle.HIGHEST_PROTOCOL)


In [None]:
# test() takes no `epochs` argument; passing one raised TypeError.
test(REPRESENTATION, SCORING, batch_size=120, bart_emb_type=BART_EMB_TYPE)