In [1]:
!pip install sentence_transformers



In [2]:
from sentence_transformers import SentenceTransformer, SentencesDataset, InputExample, losses, util, models, evaluation, LoggingHandler
from transformers import AutoTokenizer, AutoModel
import torch
import pandas as pd
import re
from sklearn.model_selection import train_test_split
import logging
import numpy as np

In [3]:
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

In [4]:
#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

Some weights of the model checkpoint at nlpaueb/legal-bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tensor(0.8943)


<h2>Step 1</h2>  
Use SentenceBERT fine tuned on Quora Duplicate Qusetions. Report average P@1 and average MRR.

In [5]:
import csv
from post_parser_record import PostParserRecord

def read_tsv_test_data(file_path):
  # Takes in the file path for test file and generate a dictionary
  # of question id as the key and the list of question ids similar to it
  # as value. It also returns the list of all question ids that have
  # at least one similar question
    dic_similar_questions = {}
    lst_all_test = []
    with open(file_path) as fd:
        rd = csv.reader(fd, delimiter="\t", quotechar='"')
        for row in rd:
            question_id = int(row[0])
            lst_similar = list(map(int, row[1:]))
            dic_similar_questions[question_id] = lst_similar
            lst_all_test.append(question_id)
            lst_all_test.extend(lst_similar)
    return dic_similar_questions, lst_all_test

dic_similar_questions, lst_all_test = read_tsv_test_data("duplicate_questions.tsv")
post_reader = PostParserRecord("Posts_law.xml")

In [7]:
from sentence_transformers import SentenceTransformer, util
import torch

# in question one, we are using the pre-trained model on quora with no further fine-tuning
model_name = 'distilbert-base-nli-stsb-quora-ranking'
model = SentenceTransformer(model_name)

# list of text to be indexed (encoded)
corpus = []
# this dictionary is used as key: corpus index [0, 1, 2, ...] and value: corresponding question id
index_to_question_id = {}
idx = 0

# indexing all the questions in the law stack exchange -- only using the question titles
for question_id in post_reader.map_questions:
    question = post_reader.map_questions[question_id]
    text = question.title
    q_id = question.post_id
    corpus.append(text)
    index_to_question_id[idx] = question_id
    idx += 1

# Indexing (embedding) each question text
corpus_embeddings = model.encode(corpus, convert_to_tensor=True, show_progress_bar=True)

2023-04-22 14:58:58 - Load pretrained SentenceTransformer: distilbert-base-nli-stsb-quora-ranking
2023-04-22 14:59:00 - Use pytorch device: cpu


Batches:   0%|          | 0/756 [00:00<?, ?it/s]

In [9]:
lst_test_question_ids = list(dic_similar_questions.keys())
top_k = 100

correct_count = 0
total = sum([len(pair_list) for pair_list in dic_similar_questions.values()])
rr_list = []

# for each question in train
for question_id in lst_test_question_ids:
    
    question_rr = 0
    #get embedding
    query_text = post_reader.map_questions[question_id].title
    query_embedding = model.encode(query_text, convert_to_tensor=True)

    # We use cosine-similarity and torch.topk to find the highest 100 scores
    # get similarity scores
    cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0] 
    
    # descending order: tensor of scores and tensor of corresponding indices
    top_results = torch.topk(cos_scores, k=top_k)
 
    # check results and calculate P@1
    # if top results match the ground truth update count
    similars_list = dic_similar_questions.get(question_id)

    # check for similar predictions
    for i in range(0,len(similars_list)):
        
        # accumulate correct preds
        # check is top index in list of similars
        q_index = index_to_question_id.get(top_results[1][i])
        if q_index in similars_list:
            correct_count += 1            
            
        # similars_list[i] is true index
        # we need to convert it to the model index
        similar_index = list(index_to_question_id.keys())[list(index_to_question_id.values()).index(similars_list[i])]        
        
        # collect mean reciprocal ranks 
        if similar_index in top_results[1] and question_rr <= 0:
            
            index = (top_results[1] == similar_index).nonzero(as_tuple=False)[0][0].item()
            question_rr = index/100

    rr_list.append(question_rr)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

<h3>Average P@1</h3>

In [10]:
ave_p1 = correct_count / total
print("Average P@1: ", ave_p1)

Average P@1:  0.0


<h3>Average MRR</h3>

In [11]:
sum(rr_list) / len(rr_list)

0.09829787234042558

<h2>Step 2</h2>  
Using a similar split as the previous assignment (10% for testing and the remaining
90% for fine-tuning), fine-tune a Sentence-BERT model with training data on Law Stack
Exchange. Report average P@1 and average MRR.

<h3>Preprocessing</h3>  
To prepare data for training, first we will store text for every question id into a helper dictionary. Then we generate dataframe of corresponding positive and negative pairs of sentences. 

In [114]:
dict_text = {}

# get test for each question
for question_id in post_reader.map_questions:
    
    question = post_reader.map_questions[question_id]
    title = question.title
    body = question.body
    sentence = title + body
    sentence = re.sub(r"(?s)<.*?>", "", sentence)
    dict_text[question_id] = sentence

In [115]:
pairs = []
negative_pairs = []
i = 0

# get positive pairs
for key in dic_similar_questions:
    
    # positive
    for value in dic_similar_questions[key]:
        pairs.append((key,value))
    
    # negative
    # if test id is not in similars for this key
    if lst_all_test[i] not in dic_similar_questions[key] and lst_all_test[i] != key:
        # create pair
        negative_pairs.append((lst_all_test[i],key))
        if i == (len(lst_all_test)-1):
            i = 0
        else:
            i+=1
            
    

In [116]:
print(len(negative_pairs))
print(len(pairs))

280
289


In [117]:
# function to load the dataframe
#
# embedding_dict - a dict to be loaded
# df - target dataframe
# pairs - list of tupples pairs of embeddings
# label - y
def load_df(text_dict, df, pairs, label):
    for index, (a,b) in enumerate(pairs):
    
        pairs = [a,b]
        first = text_dict[a]
        second = text_dict[b]
        df.loc[-1] = [ pairs, first, second, label]  # adding a row
        df.index = df.index + 1 

In [118]:
# construct pandas dataframe with pairs of positive and negative samples
df_positive = pd.DataFrame(columns = ['index pairs', 'a', 'b', 'y'])
df_negative = pd.DataFrame(columns = ['index pairs', 'a', 'b', 'y'])

load_df(dict_text, df_positive, pairs, 1)
load_df(dict_text, df_negative, negative_pairs, 0)
print(df_positive.shape)
print(df_negative.shape)

(289, 4)
(280, 4)


In [119]:
# Append dataframes
df = pd.concat([df_positive, df_negative],                     
                      ignore_index = True,
                      sort = False)
# Shuffle rows
df = df.sample(frac=1).reset_index(drop=True)

df.head()

Unnamed: 0,index pairs,a,b,y
0,"[17959, 9412]",Does the digitalisation (pdf) of a work alread...,Is OCRed text automatically copyright?If someo...,1
1,"[79197, 26950]",Can I sue Dominos for putting meat on my veg p...,Can I sue a restaurant for serving me meat in ...,1
2,"[42607, 5081]",EULA acceptanceIf I have a EULA radio button o...,"How to Prove, Legally, That a User Actually Cl...",1
3,"[25838, 15055]","Package recieved after full refund, illegal to...","After purchasing an item, then being issued a ...",1
4,"[23717, 74218]","In a marriage, can one spouse just run away an...",Can I display university logos on my website u...,0


In [120]:
# Split data
# 90 - 10 
def train_test_split_2(df, frac=0.1):
    
    # get random sample 
    test = df.sample(frac=frac, axis=0)

    # get everything but the test sample
    train = df.drop(index=test.index)

    return train, test

In [121]:
train, test = train_test_split_2(df)
train, dev = train_test_split_2(train)

In [122]:
print(train.shape)
print(dev.shape)
print(test.shape)

(461, 4)
(51, 4)
(57, 4)


In [123]:
from sentence_transformers.readers import InputExample
from sentence_transformers.cross_encoder import CrossEncoder
from torch.utils.data import DataLoader
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
import math
from datetime import datetime

In [124]:
train_samples = []
dev_samples = []

for index, row in train.iterrows():
    # append sentences and label
    train_samples.append(InputExample(texts=[row['a'], row['b']], label=int(row['y'])))
    
for index, row in dev.iterrows():
    # append sentences and label
    dev_samples.append(InputExample(texts=[row['a'], row['b']], label=int(row['y'])))

In [125]:
#Configuration
train_batch_size = 16
num_epochs = 4
model_save_path = 'output-'+ datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

In [126]:
# We use distilroberta-base with a single label, i.e., it
# will output a value between 0 and 1 indicating the similarity of the two questions

model = CrossEncoder('distilroberta-base', num_labels=1)

Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['classifier.out_proj.weig

2023-04-23 01:13:27 - Use pytorch device: cpu


In [127]:
# We wrap train_samples (which is a List[InputExample]) into a pytorch DataLoader
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)

In [128]:
# We add an evaluator, which evaluates the performance during training
evaluator = CEBinaryClassificationEvaluator.from_input_examples(dev_samples, name='Law-dev')

In [129]:
# Configure the training
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) #dev data for warm-up
logger.info("Warmup-steps: {}".format(warmup_steps))

2023-04-23 01:13:27 - Warmup-steps: 12


In [130]:
# Train the model
model.fit(train_dataloader=train_dataloader,
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=5000,
          warmup_steps=warmup_steps,
          output_path=model_save_path)

Epoch:   0%|          | 0/4 [00:00<?, ?it/s]

Iteration:   0%|          | 0/29 [00:00<?, ?it/s]

2023-04-23 01:24:49 - CEBinaryClassificationEvaluator: Evaluating the model on Law-dev dataset after epoch 0:
2023-04-23 01:25:19 - Accuracy:           60.78	(Threshold: 0.4908)
2023-04-23 01:25:19 - F1:                 71.43	(Threshold: 0.4859)
2023-04-23 01:25:19 - Precision:          58.14
2023-04-23 01:25:19 - Recall:             92.59
2023-04-23 01:25:19 - Average Precision:  62.06

2023-04-23 01:25:19 - Save model to output-2023-04-23_01-13-24


Iteration:   0%|          | 0/29 [00:00<?, ?it/s]

2023-04-23 01:36:01 - CEBinaryClassificationEvaluator: Evaluating the model on Law-dev dataset after epoch 1:
2023-04-23 01:36:31 - Accuracy:           60.78	(Threshold: 0.4754)
2023-04-23 01:36:31 - F1:                 72.97	(Threshold: 0.4754)
2023-04-23 01:36:31 - Precision:          57.45
2023-04-23 01:36:31 - Recall:             100.00
2023-04-23 01:36:31 - Average Precision:  55.13



Iteration:   0%|          | 0/29 [00:00<?, ?it/s]

2023-04-23 01:47:22 - CEBinaryClassificationEvaluator: Evaluating the model on Law-dev dataset after epoch 2:
2023-04-23 01:47:52 - Accuracy:           64.71	(Threshold: 0.4221)
2023-04-23 01:47:52 - F1:                 74.29	(Threshold: 0.4221)
2023-04-23 01:47:52 - Precision:          60.47
2023-04-23 01:47:52 - Recall:             96.30
2023-04-23 01:47:52 - Average Precision:  68.30

2023-04-23 01:47:52 - Save model to output-2023-04-23_01-13-24


Iteration:   0%|          | 0/29 [00:00<?, ?it/s]

2023-04-23 01:58:40 - CEBinaryClassificationEvaluator: Evaluating the model on Law-dev dataset after epoch 3:
2023-04-23 01:59:11 - Accuracy:           66.67	(Threshold: 0.3912)
2023-04-23 01:59:11 - F1:                 73.97	(Threshold: 0.2761)
2023-04-23 01:59:11 - Precision:          58.70
2023-04-23 01:59:11 - Recall:             100.00
2023-04-23 01:59:11 - Average Precision:  72.06

2023-04-23 01:59:11 - Save model to output-2023-04-23_01-13-24


In [132]:
# for prediction model takes the list of sentence pairs.
test_sentences = []
test_pair_indices = {}

i = 0
for index, row in test.iterrows():
    # append sentence pair
    test_sentences.append([row['a'], row['b']])
    test_pair_indices[i] = row['index pairs']
    i += 1

<h3>Predictions and Evaluation</h3>

In [133]:
preds = model.predict(sentences = test_sentences,
              show_progress_bar = True,
              convert_to_numpy = True
              )

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

<h3>Average MRR And P@1</h3>

In [134]:
# dict of dicts
# {question_id : {similar_id : score, similar_id : score} }
dic_scores_for_qs = {}
# get set of indices that have similars
# we will evaluate on them
for i in test_pair_indices:
    
    #get pair indices and their cosine similarity
    
    test_pair = test_pair_indices.get(i)
    similarity_score = preds[i] 
    
    
    # for eachpair member store scores in dict
    if test_pair[0] in dic_similar_questions:
        if test_pair[0] not in dic_scores_for_qs:        
            dic_scores_for_qs[test_pair[0]] = {test_pair[1]:similarity_score}
        else:
            dic_scores_for_qs[test_pair[0]][test_pair[1]] = similarity_score

    if test_pair[1] in dic_similar_questions:
        if test_pair[1] not in dic_scores_for_qs:        
            dic_scores_for_qs[test_pair[1]] = {test_pair[0]:similarity_score}
        else:
            dic_scores_for_qs[test_pair[1]][test_pair[0]] = similarity_score

In [135]:
# sort scores for similar questions in descending order
for question_id in dic_scores_for_qs:
    dic_scores_for_qs[question_id] = dict(sorted(dic_scores_for_qs[question_id].items(), key=lambda item: item[1], reverse = True))
 
rrs = []
correct_guess = 0
total = 0
# get ranks for each question
for question_id in dic_scores_for_qs:
    question_rr = 0
    
    dict_similars = dic_scores_for_qs[question_id]
    list_true_similars = dic_similar_questions[question_id]
    total += len(dict_similars)
    
    # check if any similar questions are in result
    for similar in list_true_similars:
        
        i = 0
        for similar_test in dict_similars:
            
            if similar_test == similar:
                correct_guess += 1
                if question_rr <= 0:
            
                    index = i
                    question_rr = index/100
                    i += 1
            
    
    

In [136]:
print('Average P@1')
correct_guess/total

Average P@1


0.4383561643835616

In [137]:
print('Average MRR')
sum(rr_list) / len(rr_list)

Average MRR


0.02726950354609928

<h2>Step 3</h2>  
Applying the first step on 10% dataset from step 2 and determining statistical difference between results.

<h4>Data Preprocessing</h4>

In [138]:
# we only need 10% for training
# get sample 
df2 = df.sample(frac=0.1, axis=0)
df2.shape

(57, 4)

In [139]:
train_ids = []
for index, row in df2.iterrows():
    
    train_ids.append(row['index pairs'][0])
    train_ids.append(row['index pairs'][1])
    
# eliminate duplicate ids for encoding 10% data
train_ids = set(train_ids)

<h3>Train Model</h3>

In [140]:
# in question one, we are using the pre-trained model on quora with no further fine-tuning
model_name = 'distilbert-base-nli-stsb-quora-ranking'
model = SentenceTransformer(model_name)

# list of text to be indexed (encoded)
corpus = []
# # this dictionary is used as key: corpus index [0, 1, 2, ...] and value: corresponding question id
index_to_question_id = {}
idx = 0

# indexing all the questions in the law stack exchange -- only using the question titles
for question_id in train_ids:
    question = post_reader.map_questions[question_id]
    text = question.title
    q_id = question.post_id
    corpus.append(text)
    index_to_question_id[idx] = question_id
    idx += 1

# Indexing (embedding) each question text
corpus_embeddings = model.encode(corpus, convert_to_tensor=True, show_progress_bar=True)

2023-04-23 02:21:22 - Load pretrained SentenceTransformer: distilbert-base-nli-stsb-quora-ranking
2023-04-23 02:21:22 - Use pytorch device: cpu


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

In [141]:
top_k_ids = index_to_question_id.values()

In [142]:
lst_test_question_ids = list(dic_similar_questions.keys())
# we get only as many similars as there are in the training corpus
top_k = len(top_k_ids)

correct_count = 0
total = 0
rr_list = []

# for each question in test
for question_id in lst_test_question_ids:
    
    question_rr = 0
    #get embedding
    query_text = post_reader.map_questions[question_id].title
    query_embedding = model.encode(query_text, convert_to_tensor=True)

    # We use cosine-similarity and torch.topk to find the highest k scores
    # get similarity scores
    cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0] 
    
    # descending order: tensor of scores and tensor of corresponding indices
    top_results = torch.topk(cos_scores, k=top_k)
 
    # check result
    # if top results match the ground truth update count
    similars_list = dic_similar_questions.get(question_id)

    # check for similar predictions
    for i in range(0,len(similars_list)):
        
        # accumulate correct preds
        # check is top index in list of similars
        q_index = index_to_question_id.get(top_results[1][i])
        if q_index in similars_list:
            correct_count += 1            
            
        # similars_list[i] is true index
        # we need to convert it to the model index
        # but also in this case we must keep in mind that model does not 
        # provide k for indices that aren't in 10% of training corpus
        if similars_list[i] in top_k_ids:
            similar_index = list(index_to_question_id.keys())[list(index_to_question_id.values()).index(similars_list[i])]        
        
            # collect mean reciprocal ranks 
            if similar_index in top_results[1] and question_rr <= 0:
            
                index = (top_results[1] == similar_index).nonzero(as_tuple=False)[0][0].item()
                question_rr = index/100
                
            total+=1
    rr_list.append(question_rr)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

<h3>Average P@1</h3>

In [143]:
ave_p1 = correct_count / total
print("Average P@1: ", ave_p1)

Average P@1:  0.0


<h3>Average MRR</h3>

In [144]:
sum(rr_list) / len(rr_list)

0.020177304964538997

<h1>Extra Credit Attempt</h1>

In [61]:
# list of text to be indexed (encoded)
corpus = []
# this dictionary is used as key: corpus index [0, 1, 2, ...] and value: corresponding question id
index_to_question_id = {}
idx = 0

# indexing all the questions in the law stack exchange -- only using the question titles
for question_id in post_reader.map_questions:
    question = post_reader.map_questions[question_id]
    text = question.title
    q_id = question.post_id
    corpus.append(text)
    index_to_question_id[idx] = question_id
    idx += 1

In [55]:
#Load AutoModel from huggingface model repository
tokenizer = AutoTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")
model = AutoModel.from_pretrained("nlpaueb/legal-bert-base-uncased")

#Tokenize sentences
encoded_input = tokenizer(corpus, padding=True, truncation=True, max_length=128, return_tensors='pt')

#Compute token embeddings
with torch.no_grad():
    model_output = model(**encoded_input)

Some weights of the model checkpoint at nlpaueb/legal-bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [83]:
# Perform pooling. In this case, mean pooling
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
#cos = torch.nn.CosineSimilarity(dim=0)

In [80]:
top_k = 100
correct_count = 0
total = sum([len(pair_list) for pair_list in dic_similar_questions.values()])
rr_list = []


# for each question in the test set
for question_id in dic_similar_questions.keys():
    
    #convert question_id to model index
    q_id = list(index_to_question_id.keys())[list(index_to_question_id.values()).index(question_id)] 
    
    # get similarity scores
    # We use cosine-similarity and torch.topk to find the highest 100 scores
    # get similarity scores
    cos_scores = util.cos_sim(sentence_embeddings[q_id], sentence_embeddings)[0] 
    
    # descending order: tensor of scores and tensor of corresponding indices
    top_results = torch.topk(cos_scores, k=top_k)
    
    
    # check results and calculate P@1
    # if top results match the ground truth update count
    similars_list = dic_similar_questions.get(question_id)
    
    # check for similar predictions
    for i in range(0,len(similars_list)):
        
        # accumulate correct preds
        # check is top index in list of similars
        q_index = index_to_question_id.get(top_results[1][i])
        if q_index in similars_list:
            correct_count += 1            
            
        # similars_list[i] is true index
        # we need to convert it to the model index
        similar_index = list(index_to_question_id.keys())[list(index_to_question_id.values()).index(similars_list[i])]        
        
        # collect mean reciprocal ranks 
        if similar_index in top_results[1] and question_rr <= 0:
            
            index = (top_results[1] == similar_index).nonzero(as_tuple=False)[0][0].item()
            question_rr = index/100
    
    rr_list.append(question_rr)

<h3>Average P@1</h3>

In [81]:
ave_p1 = correct_count / total
print("Average P@1: ", ave_p1)

Average P@1:  0.0


<h3>Average MRR</h3>

In [82]:
sum(rr_list) / len(rr_list)

0.00989361702127654