# Compute BERT Embeddings
There will be a lot of BERT embeddings to compute, so it's easier to pre-compute all of them and then use only what we ACTUALLY need.

## Embeddings for Wikipedia pages
Wikipedia pages can be too long for the model, so we need a strategy for pooling the embeddings of such long documents.
We use a `sentence-transformers` bi-encoder (selected via `base_model` in the cell below, e.g. [`all-mpnet-base-v2`](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)) with its mean pooling strategy to compute embeddings.

Each document is split into sentences using NLTK's `sent_tokenize`. If a sentence is longer than the model's limit (384 tokens for `all-mpnet-base-v2`), it is truncated.

In [2]:
import itertools
import pickle
import urllib.parse
from collections import defaultdict

import numpy as np
import torch
from nltk.tokenize import sent_tokenize, word_tokenize
from scipy.spatial.distance import cdist
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import normalize
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer

In [3]:
# Select the sentence-transformers bi-encoder used for every embedding below.
# Other candidates that were tried: "all-MiniLM-L6-v2", "all-mpnet-base-v2".
base_model = "msmarco-MiniLM-L6-cos-v5"

# Use the GPU when one is visible; loading and encoding on the CPU works but is slow.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

bert_model = SentenceTransformer(base_model, device=device)

In [5]:
# Embed every Wikipedia topic page in several ways, keyed by topic title:
#   *_trunc: the whole text encoded in one pass (over-long input is truncated by the model)
#   *_sum / *_mean: sum / mean of the per-sentence embeddings
#   *_all: the full matrix of per-sentence embeddings (one row per sentence)
# The *_norm variants L2-normalise each sentence embedding before pooling and
# re-normalise the pooled vector.
wikipedia_sum = {}
wikipedia_mean = {}
wikipedia_all = {}
wikipedia_trunc = {}

wikipedia_sum_norm = {}
wikipedia_mean_norm = {}
wikipedia_all_norm = {}
wikipedia_trunc_norm = {}

# Each row is "<topic title>\t<page text>".
with open("../data/wikipedia_texts.tsv", encoding="utf-8") as wiki_file:
    for line in wiki_file:
        topic_title, text = line.strip().split("\t", maxsplit=1)
        sentences = sent_tokenize(text)
        embeddings = bert_model.encode(sentences, normalize_embeddings=False)
        embeddings_norm = normalize(embeddings)  # row-wise L2 normalisation

        wikipedia_trunc[topic_title] = bert_model.encode(text, normalize_embeddings=False)
        wikipedia_trunc_norm[topic_title] = normalize(wikipedia_trunc[topic_title].reshape(1, -1)).flatten()

        wikipedia_sum[topic_title] = np.sum(embeddings, axis=0)
        wikipedia_mean[topic_title] = np.mean(embeddings, axis=0)

        wikipedia_sum_norm[topic_title] = normalize(np.sum(embeddings_norm, axis=0).reshape(1, -1)).flatten()
        wikipedia_mean_norm[topic_title] = normalize(np.mean(embeddings_norm, axis=0).reshape(1, -1)).flatten()

        wikipedia_all[topic_title] = embeddings
        wikipedia_all_norm[topic_title] = embeddings_norm

In [6]:
# Dump every Wikipedia embedding variant to a pickle file.
# Naming convention: "_un" suffix = unnormalised embeddings, no suffix = L2-normalised ones.
wikipedia_outputs = [
    (wikipedia_mean, "wikipedia_mean_embeddings_un"),
    (wikipedia_sum, "wikipedia_sum_embeddings_un"),
    (wikipedia_all, "wikipedia_all_embeddings_un"),
    (wikipedia_trunc, "wikipedia_trunc_embeddings_un"),
    (wikipedia_mean_norm, "wikipedia_mean_embeddings"),
    (wikipedia_sum_norm, "wikipedia_sum_embeddings"),
    (wikipedia_all_norm, "wikipedia_all_embeddings"),
    (wikipedia_trunc_norm, "wikipedia_trunc_embeddings"),
]
for embeddings_by_topic, name in wikipedia_outputs:
    # `with` guarantees each file is flushed and closed before moving on.
    with open(f"../data/bert_embeddings/{base_model}_{name}.pkl", "wb") as out_file:
        pickle.dump(embeddings_by_topic, out_file)

## Embeddings for clicked documents
Clicked documents can also be too long, so they need a pooling strategy as well. We try pooling (mean/sum of sentence embeddings), truncation, and a BIRCH-like MaxP approach, where we only keep the document sentence with the highest similarity to the "golden" Wikipedia page (comparing against all of its sentence embeddings or a pooled version of them). We will try the following:

- Mean of embeddings vs mean of wikipedia
- Sum of embeddings vs mean of wikipedia
- Truncate doc at maximum length (384)
- BIRCH MaxP all vs all: Keep Best score (max) for each sentence vs each sentence from wikipedia
- BIRCH MaxP all vs SUM: Keep Best score (max) for each sentence vs SUM of wikipedia
- BIRCH MaxP all vs MEAN: Keep Best score (max) for each sentence vs MEAN of wikipedia

We use the same `sentence-transformers` bi-encoder as for the Wikipedia pages (selected via `base_model`, e.g. [`all-mpnet-base-v2`](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)) with its mean pooling strategy to compute embeddings.

Each document is split into sentences using NLTK's `sent_tokenize`. If a sentence is longer than the model's limit (384 tokens for `all-mpnet-base-v2`), it is truncated.

In [7]:
# Embed every clicked document with the same strategies used for the Wikipedia
# pages: the whole text in one pass (*_trunc, over-long input truncated by the
# model) and the sum / mean of its per-sentence embeddings, keyed by url.
# The *_norm variants L2-normalise each sentence embedding before pooling and
# re-normalise the pooled vector, exactly like the Wikipedia *_norm variants,
# so documents and Wikipedia pages are pooled the same way.
docs_sum = {}
docs_mean = {}
docs_trunc = {}
# maxp_pairwise = {}
# maxp_sum = {}
# maxp_mean = {}
# maxp_trunc = {}

docs_sum_norm = {}
docs_mean_norm = {}
docs_trunc_norm = {}
# maxp_pairwise_norm = {}
# maxp_sum_norm = {}
# maxp_mean_norm = {}
# maxp_trunc_norm = {}

# Every per-url store; rows without text get an all-zero vector in each of them.
doc_stores = (docs_sum, docs_mean, docs_trunc, docs_sum_norm, docs_mean_norm, docs_trunc_norm)

# Alternative input: "../data/clicked_docs.tsv" (947 rows).
# Each row is "<url>\t<topic>\t<text>".
with open("../data/clicked_docs_with_topics_more.tsv", encoding="utf-8") as docs_file:
    for line in tqdm(docs_file, total=1137):
        try:
            url, topic, text = line.strip().split("\t", maxsplit=2)
        except ValueError:
            # No text for this url, so we don't have its embedding: store zeros and
            # move on. The `continue` matters: falling through would embed the
            # previous row's `text` (or raise NameError on the first row).
            url = line.strip().split("\t")[0]
            dim = bert_model.get_sentence_embedding_dimension()
            for store in doc_stores:
                store[url] = np.zeros(dim)
            continue

        sentences = sent_tokenize(text)
        embeddings = bert_model.encode(sentences, normalize_embeddings=False)
        embeddings_norm = normalize(embeddings)  # row-wise L2 normalisation

        docs_trunc[url] = bert_model.encode(text, normalize_embeddings=False)
        docs_trunc_norm[url] = normalize(docs_trunc[url].reshape(1, -1)).flatten()

        docs_sum[url] = np.sum(embeddings, axis=0)
        docs_mean[url] = np.mean(embeddings, axis=0)

        # Pool the normalised sentence embeddings (as for Wikipedia), then re-normalise.
        docs_sum_norm[url] = normalize(np.sum(embeddings_norm, axis=0).reshape(1, -1)).flatten()
        docs_mean_norm[url] = normalize(np.mean(embeddings_norm, axis=0).reshape(1, -1)).flatten()

        # MaxP variants (disabled): keep only the document sentence closest to the topic's page.
        # wikipedia_embeddings = wikipedia_all[topic]
        # wikipedia_embeddings_norm = wikipedia_all_norm[topic]

        # dimensions: (document_embeddings, wikipedia_embeddings)
        # pairwise_distances = cdist(embeddings, wikipedia_embeddings, metric="cosine")
        # best_doc, _ = np.unravel_index(pairwise_distances.argmin(), pairwise_distances.shape)
        # maxp_pairwise[url] = embeddings[best_doc]

        # maxp_sum[url] = embeddings[np.argmin(cdist(embeddings, [wikipedia_sum[topic]]))]
        # maxp_mean[url] = embeddings[np.argmin(cdist(embeddings, [wikipedia_mean[topic]]))]
        # maxp_trunc[url] = embeddings[np.argmin(cdist(embeddings, [wikipedia_trunc[topic]]))]

        # normalized version
        # pairwise_distances = cdist(embeddings_norm, wikipedia_embeddings_norm, metric="cosine")
        # best_doc, _ = np.unravel_index(pairwise_distances.argmin(), pairwise_distances.shape)
        # maxp_pairwise_norm[url] = embeddings_norm[best_doc]

        # maxp_sum_norm[url] = embeddings_norm[np.argmin(cdist(embeddings_norm, [wikipedia_sum_norm[topic]]))]
        # maxp_mean_norm[url] = embeddings_norm[np.argmin(cdist(embeddings_norm, [wikipedia_mean_norm[topic]]))]
        # maxp_trunc_norm[url] = embeddings_norm[np.argmin(cdist(embeddings_norm, [wikipedia_trunc_norm[topic]]))]

  0%|          | 0/1137 [00:00<?, ?it/s]

In [10]:
# Dump the clicked-document embeddings computed from the "_more" TSV.
# Naming convention (same as the Wikipedia dumps): "_un" = unnormalised, no suffix = L2-normalised.
docs_outputs = [
    (docs_mean, "docs_mean_embeddings_un_more"),
    (docs_sum, "docs_sum_embeddings_un_more"),
    (docs_trunc, "docs_trunc_embeddings_un_more"),
    # (maxp_pairwise, "docs_maxp_pairwise_embeddings_un_more"),
    # (maxp_sum, "docs_maxp_sum_embeddings_un_more"),
    # (maxp_mean, "docs_maxp_mean_embeddings_un_more"),
    # (maxp_trunc, "docs_maxp_trunc_embeddings_un_more"),
    (docs_mean_norm, "docs_mean_embeddings_more"),
    (docs_sum_norm, "docs_sum_embeddings_more"),
    (docs_trunc_norm, "docs_trunc_embeddings_more"),
    # (maxp_pairwise_norm, "docs_maxp_pairwise_embeddings_more"),
    # (maxp_sum_norm, "docs_maxp_sum_embeddings_more"),
    # (maxp_mean_norm, "docs_maxp_mean_embeddings_more"),
    # (maxp_trunc_norm, "docs_maxp_trunc_embeddings_more"),
]
for embeddings_by_url, name in docs_outputs:
    # `with` guarantees each file is flushed and closed before moving on.
    with open(f"../data/bert_embeddings/{base_model}_{name}.pkl", "wb") as out_file:
        pickle.dump(embeddings_by_url, out_file)

In [96]:
# dump to a pickle file
pickle.dump(docs_mean, open(f"../data/bert_embeddings/{base_model}_docs_mean_embeddings.pkl", "wb"))
pickle.dump(docs_sum, open(f"../data/bert_embeddings/{base_model}_docs_sum_embeddings.pkl", "wb"))
pickle.dump(docs_trunc, open(f"../data/bert_embeddings/{base_model}_docs_trunc_embeddings.pkl", "wb"))
pickle.dump(maxp_pairwise, open(f"../data/bert_embeddings/{base_model}_docs_maxp_pairwise_embeddings.pkl", "wb"))
pickle.dump(maxp_sum, open(f"../data/bert_embeddings/{base_model}_docs_maxp_sum_embeddings.pkl", "wb"))
pickle.dump(maxp_mean, open(f"../data/bert_embeddings/{base_model}_docs_maxp_mean_embeddings.pkl", "wb"))
pickle.dump(maxp_trunc, open(f"../data/bert_embeddings/{base_model}_docs_maxp_trunc_embeddings.pkl", "wb"))

pickle.dump(docs_mean_norm, open(f"../data/bert_embeddings/{base_model}_docs_mean_embeddings_un.pkl", "wb"))
pickle.dump(docs_sum_norm, open(f"../data/bert_embeddings/{base_model}_docs_sum_embeddings_un.pkl", "wb"))
pickle.dump(docs_trunc_norm, open(f"../data/bert_embeddings/{base_model}_docs_trunc_embeddings_un.pkl", "wb"))
pickle.dump(maxp_pairwise_norm, open(f"../data/bert_embeddings/{base_model}_docs_maxp_pairwise_embeddings_un.pkl", "wb"))
pickle.dump(maxp_sum_norm, open(f"../data/bert_embeddings/{base_model}_docs_maxp_sum_embeddings_un.pkl", "wb"))
pickle.dump(maxp_mean_norm, open(f"../data/bert_embeddings/{base_model}_docs_maxp_mean_embeddings_un.pkl", "wb"))
pickle.dump(maxp_trunc_norm, open(f"../data/bert_embeddings/{base_model}_docs_maxp_trunc_embeddings_un.pkl", "wb"))

## Alternative: A cross-encoder model that predicts relevance (0-1)
The final score is either the SUM or the MEAN of the relevance scores over all clicked documents.

In [3]:
# Load every Wikipedia topic page once, keyed by topic title: the raw text (used for
# whole-page "truncated" pairs) and its NLTK sentence split (used for sentence pairs).
wiki_texts = {}
wiki_sentences = {}
# Each row is "<topic title>\t<page text>".
with open("../data/wikipedia_texts.tsv", encoding="utf-8") as wiki_file:
    for line in wiki_file:
        topic_title, text = line.strip().split("\t", maxsplit=1)
        wiki_texts[topic_title] = text
        wiki_sentences[topic_title] = sent_tokenize(text)

In [18]:
from sentence_transformers import CrossEncoder

# Cross-encoder that scores a (text_a, text_b) pair directly.
# NOTE: this rebinds `base_model` and `bert_model`, which the bi-encoder cells also use;
# re-run the bi-encoder model-setup cell before re-running any of those cells.
base_model = "cross-encoder/stsb-roberta-base"
# base_model = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# Must be loaded: the scoring loop calls `bert_model.predict`, which only the cross-encoder has.
bert_model = CrossEncoder(base_model)

sim_matrixes = {}  # url -> score matrix of (doc sentence, wiki sentence) pairs vs its topic page
both_truncated_scores = {}  # url -> one score for the (whole doc, whole wiki page) pair
doc_truncated_scores = {}  # url -> scores of (whole doc, each wiki sentence)
wiki_truncated_scores = {}  # url -> scores of (each doc sentence, whole wiki page)

In [None]:
# Score every clicked document against its topic's Wikipedia page with the cross-encoder.
# Only raw scores are stored here; aggregating them into a single relevance value per url
# (max / mean / best-sentence selection) is done separately from the saved score pickles.
batch_size = 128

# Each row is "<url>\t<topic>\t<text>".
with open("../data/clicked_docs_with_topics.tsv", encoding="utf-8") as docs_file:
    for line in tqdm(docs_file, total=947):
        try:
            url, topic, text = line.strip().split("\t", maxsplit=2)
        except ValueError:
            continue  # rows without text cannot be scored
        sentences = sent_tokenize(text)
        wiki_text = wiki_texts[topic]
        wiki_sentence = wiki_sentences[topic]

        # Every (doc sentence, wiki sentence) pair. `product` iterates doc sentences in the
        # outer loop, so the flat predictions reshape into (doc sentences, wiki sentences).
        pairs_of_sentences = list(itertools.product(sentences, wiki_sentence))
        # Each doc sentence vs the whole wiki page.
        wiki_truncated = list(itertools.product(sentences, [wiki_text]))
        # The whole doc vs each wiki sentence.
        doc_truncated = list(itertools.product([text], wiki_sentence))

        scores = bert_model.predict(pairs_of_sentences, show_progress_bar=True, batch_size=batch_size)
        sim_matrixes[url] = scores.reshape((len(sentences), len(wiki_sentence)))
        doc_truncated_scores[url] = bert_model.predict(doc_truncated, batch_size=batch_size)
        wiki_truncated_scores[url] = bert_model.predict(wiki_truncated, batch_size=batch_size)
        # A single [text_a, text_b] pair is scored as one pair.
        both_truncated_scores[url] = bert_model.predict([text, wiki_text])

In [None]:
# Dump the raw cross-encoder scores. `base_model` contains an organisation prefix
# ("cross-encoder/..."), so these files land in a matching subdirectory of ../data/.
raw_score_outputs = [
    (sim_matrixes, "CE_scores"),
    (both_truncated_scores, "CE_truncated_scores"),
    (doc_truncated_scores, "CE_doc_truncated_scores"),
    (wiki_truncated_scores, "CE_wiki_truncated_scores"),
]
for scores_by_url, name in raw_score_outputs:
    # `with` guarantees each file is flushed and closed before moving on.
    with open(f"../data/{base_model}_{name}.pkl", "wb") as out_file:
        pickle.dump(scores_by_url, out_file)

In [5]:
base_model = "cross-encoder/stsb-roberta-base"

# Reload the raw cross-encoder scores (locally produced pickles; only load trusted files).
with open(f"../data/{base_model}_CE_scores.pkl", "rb") as in_file:
    sim_matrixes = pickle.load(in_file)
with open(f"../data/{base_model}_CE_truncated_scores.pkl", "rb") as in_file:
    both_truncated_scores = pickle.load(in_file)
with open(f"../data/{base_model}_CE_doc_truncated_scores.pkl", "rb") as in_file:
    doc_truncated_scores = pickle.load(in_file)
with open(f"../data/{base_model}_CE_wiki_truncated_scores.pkl", "rb") as in_file:
    wiki_truncated_scores = pickle.load(in_file)

In [None]:
# get best wikipedia article match for this url and use it
# get best PARAGRAPH and use it

In [69]:
# Aggregate each url's cross-encoder scores into single relevance values.
# sim_matrix has one row per document sentence and one column per Wikipedia sentence.

# (1) Select the best Wikipedia sentence (column) first, either by its mean score
# over all document sentences or by its single highest score, then summarise it.
best_wiki_mean_mean = {}  # wiki sentence with the highest mean score; mean of its scores
best_wiki_mean_max = {}  # wiki sentence with the highest mean score; MAX of its scores
best_wiki_max_mean = {}  # wiki sentence with the highest single score; mean of its scores

# (2) Same idea, selecting the best document sentence (row) instead
best_sentence_mean_mean = {}
best_sentence_max_mean = {}
best_sentence_mean_max = {}

# (3) Best pair overall
best_pairwise = {}

# (4) Truncated variations
max_wiki_truncated = {}  # MAX over the scores of the doc sentences over the truncated wiki
mean_wiki_truncated = {}  # MEAN over the scores of the doc sentences over the truncated wiki
max_sentence_truncated = {}  # MAX over the scores of the WIKI sentences over the truncated DOC
mean_sentence_truncated = {}  # MEAN over the scores of the WIKI sentences over the truncated DOC
both_truncated = both_truncated_scores

# Each row is "<url>\t<topic>\t<text>".
with open("../data/clicked_docs_with_topics.tsv", encoding="utf-8") as docs_file:
    for line in tqdm(docs_file, total=947):
        try:
            url, topic, text = line.strip().split("\t", maxsplit=2)
        except ValueError:
            continue
        sim_matrix = sim_matrixes[url]

        # (1) argmax of the column sums equals argmax of the column means,
        # since every column has the same length.
        best_wiki_sentence = sim_matrix[:, np.argmax(np.sum(sim_matrix, axis=0))]
        best_wiki_mean_mean[url] = np.mean(best_wiki_sentence)
        best_wiki_mean_max[url] = np.max(best_wiki_sentence)

        # Select by the column MAX, mirroring best_sentence_max_mean below.
        best_wiki_sentence = sim_matrix[:, np.argmax(np.max(sim_matrix, axis=0))]
        best_wiki_max_mean[url] = np.mean(best_wiki_sentence)

        # (2)
        best_doc_sentence = sim_matrix[np.argmax(np.sum(sim_matrix, axis=1))]
        best_sentence_mean_mean[url] = np.mean(best_doc_sentence)
        best_sentence_mean_max[url] = np.max(best_doc_sentence)

        best_doc_sentence = sim_matrix[np.argmax(np.max(sim_matrix, axis=1))]
        best_sentence_max_mean[url] = np.mean(best_doc_sentence)

        # (3)
        best_pairwise[url] = np.max(sim_matrix)

        # (4)
        max_wiki_truncated[url] = np.max(wiki_truncated_scores[url])
        mean_wiki_truncated[url] = np.mean(wiki_truncated_scores[url])
        max_sentence_truncated[url] = np.max(doc_truncated_scores[url])
        mean_sentence_truncated[url] = np.mean(doc_truncated_scores[url])

  0%|          | 0/947 [00:00<?, ?it/s]

In [80]:
# Dump every aggregated cross-encoder score under ../data/cross_encoder_scores/,
# prefixed with the bare model name (organisation prefix such as "cross-encoder/" dropped).
m_name = base_model.split("/")[-1]
score_outputs = [
    (best_wiki_mean_mean, "best_wiki_mean_mean"),
    (best_wiki_mean_max, "best_wiki_mean_max"),
    (best_wiki_max_mean, "best_wiki_max_mean"),
    (best_sentence_mean_mean, "best_sentence_mean_mean"),
    (best_sentence_mean_max, "best_sentence_mean_max"),
    (best_sentence_max_mean, "best_sentence_max_mean"),
    (best_pairwise, "best_pairwise"),
    (max_wiki_truncated, "max_wiki_truncated"),
    (mean_wiki_truncated, "mean_wiki_truncated"),
    (max_sentence_truncated, "max_sentence_truncated"),
    (mean_sentence_truncated, "mean_sentence_truncated"),
]
for scores_by_url, name in score_outputs:
    # `with` guarantees each file is flushed and closed before moving on.
    with open(f"../data/cross_encoder_scores/{m_name}_{name}.pkl", "wb") as out_file:
        pickle.dump(scores_by_url, out_file)