# Compute BERT Embeddings
Many BERT embeddings need to be computed. It is easier to pre-compute all of them and then use only what we actually need.

## Embeddings for Wikipedia pages
Wikipedia pages can be too long for the model, so we need a strategy for pooling such long documents.
We use the [`all-mpnet-base-v2`](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) model as the base, with the mean pooling strategy from `sentence-transformers`, to compute embeddings.

Each document is split into sentences, using NLTK. If a sentence is longer than the limit of the model (384 tokens in this case), we truncate it.
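The pooling step can be sketched without loading the model. This is a minimal illustration, assuming random unit vectors as a stand-in for the output of `bert_model.encode(sentences, normalize_embeddings=True)`; the names `sentence_embeddings`, `doc_sum`, and `doc_mean` are hypothetical.

```python
import numpy as np


def normalize_vector(v):
    """Scale a vector to unit L2 norm."""
    return v / np.linalg.norm(v)


# Assumption: 5 sentences with 768-dimensional embeddings, standing in for
# the per-sentence output of the sentence-transformers model.
rng = np.random.default_rng(0)
sentence_embeddings = rng.normal(size=(5, 768))
sentence_embeddings /= np.linalg.norm(sentence_embeddings, axis=1, keepdims=True)

# Pool the sentence embeddings into one document vector, then re-normalize.
doc_sum = normalize_vector(sentence_embeddings.sum(axis=0))
doc_mean = normalize_vector(sentence_embeddings.mean(axis=0))

# The mean is the sum divided by the number of sentences, so after
# normalization both pooled vectors point in the same direction.
print(np.allclose(doc_sum, doc_mean))
```

Because the pooled vectors are re-normalized, sum and mean pooling differ only before the normalization step, which is worth keeping in mind when comparing the two variants.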

In [24]:
import pickle
import urllib.parse
from collections import defaultdict

import numpy as np
import torch
from nltk.tokenize import sent_tokenize, word_tokenize
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer


def normalize_vector(v):
    return v / np.linalg.norm(v)

In [25]:
# Prepare model. May take a while without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
bert_model = SentenceTransformer("all-mpnet-base-v2", device=device)

In [203]:
wikipedia_trunc.keys()

dict_keys(['Radiocarbon%2520dating%2520considerations'])

In [199]:
wikipedia_sum.keys()

dict_keys(['Subprime%20mortgage%20crisis', 'Irritable%20bowel%20syndrome', 'Genetically%20modified%20organism', 'Noise-induced%20hearing%20loss', 'Business%20cycle', 'Ethics', 'Radiocarbon%20dating%20considerations'])

In [201]:
wikipedia_trunc["Subprime%20mortgage%20crisis"].shape

KeyError: 'Subprime%20mortgage%20crisis'

In [206]:
wikipedia_sum = {}
wikipedia_mean = {}
wikipedia_all = {}
wikipedia_trunc = {}

with open("../data/wikipedia_texts.tsv") as f:
    for line in f:
        topic_title, text = line.strip().split("\t", maxsplit=1)
        # One embedding per sentence; over-long sentences are truncated by the model.
        sentences = sent_tokenize(text)
        embeddings = bert_model.encode(sentences, normalize_embeddings=True)
        wikipedia_sum[topic_title] = normalize_vector(np.sum(embeddings, axis=0))
        wikipedia_mean[topic_title] = normalize_vector(np.mean(embeddings, axis=0))
        # Whole page as a single input, truncated at the model's maximum length.
        # Normalized here for consistency with docs_trunc.
        wikipedia_trunc[topic_title] = bert_model.encode(text, normalize_embeddings=True)
        wikipedia_all[topic_title] = embeddings

# Dump each dictionary to its own pickle file, closing the file after writing.
wikipedia_outputs = {
    "mean": wikipedia_mean,
    "sum": wikipedia_sum,
    "all": wikipedia_all,
    "trunc": wikipedia_trunc,
}
for name, data in wikipedia_outputs.items():
    with open(f"../data/wikipedia_{name}_embeddings.pkl", "wb") as f:
        pickle.dump(data, f)

## Embeddings for clicked documents
Clicked documents can also be too long, so we again need a strategy for handling long documents. Here we use three different approaches: pooling (with mean or sum), truncation, and a BIRCH-like one, where we keep only the document sentence with the highest similarity to the "golden" Wikipedia page. We will try the following:

- Mean of embeddings vs. mean of Wikipedia
- Sum of embeddings vs. mean of Wikipedia
- Truncate doc at maximum length (384 tokens)
- BIRCH MaxP all vs. all: keep the best score (max) for each sentence vs. each sentence from Wikipedia
- BIRCH MaxP all vs. SUM: keep the best score (max) for each sentence vs. SUM of Wikipedia
- BIRCH MaxP all vs. MEAN: keep the best score (max) for each sentence vs. MEAN of Wikipedia
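The MaxP variants listed above can be sketched with small random matrices. This is only an illustration, assuming unit-normalized sentence embeddings so that a dot product equals cosine similarity; the helper `unit_rows` and the toy shapes (4 document sentences, 6 Wikipedia sentences, 8 dimensions) are hypothetical.

```python
import numpy as np


def unit_rows(m):
    """Normalize each row of a matrix to unit L2 norm."""
    return m / np.linalg.norm(m, axis=1, keepdims=True)


# Assumption: toy embeddings standing in for real sentence embeddings.
rng = np.random.default_rng(0)
doc_sentences = unit_rows(rng.normal(size=(4, 8)))
wiki_sentences = unit_rows(rng.normal(size=(6, 8)))

# MaxP all vs. all: similarity of every document sentence to every Wikipedia
# sentence, then keep the document sentence from the best-scoring pair.
similarities = doc_sentences @ wiki_sentences.T  # shape (4, 6)
best_doc, best_wiki = np.unravel_index(similarities.argmax(), similarities.shape)
maxp_pairwise = doc_sentences[best_doc]

# MaxP all vs. MEAN: compare every document sentence to one pooled Wikipedia
# vector and keep the best-scoring document sentence.
wiki_mean = wiki_sentences.mean(axis=0)
wiki_mean /= np.linalg.norm(wiki_mean)
scores = doc_sentences @ wiki_mean  # shape (4,)
maxp_mean = doc_sentences[np.argmax(scores)]
```

In every variant the stored vector is a single sentence embedding from the clicked document; only the reference it is scored against changes.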

As with the Wikipedia pages, we use the [`all-mpnet-base-v2`](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) model with the mean pooling strategy from `sentence-transformers`. Each document is split into sentences using NLTK, and any sentence longer than the limit of the model (384 tokens in this case) is truncated.

In [209]:
docs_sum = {}
docs_mean = {}
docs_trunc = {}
maxp_pairwise = {}
maxp_sum = {}
maxp_mean = {}
maxp_trunc = {}

with open("../data/clicked_docs_with_topics.tsv") as f:
    for line in tqdm(f, total=947):
        try:
            url, topic, text = line.strip().split("\t", maxsplit=2)
        except ValueError:
            # Row without text: store a placeholder and skip it. Without the
            # `continue`, `topic` and `text` would be undefined or left over
            # from the previous row.
            url = line.split("\t")[0]
            docs_sum[url] = np.zeros(768)  # Empty. We don't have this embedding now.
            continue
        sentences = sent_tokenize(text)
        embeddings = bert_model.encode(sentences, normalize_embeddings=True)
        docs_sum[url] = normalize_vector(np.sum(embeddings, axis=0))
        docs_mean[url] = normalize_vector(np.mean(embeddings, axis=0))
        docs_trunc[url] = bert_model.encode(text, normalize_embeddings=True)

        # Best PAIRWISE similarity
        wikipedia_embeddings = wikipedia_all[topic]
        # dimensions: (document_embeddings, wikipedia_embeddings)
        pairwise_similarities = np.dot(embeddings, wikipedia_embeddings.T)
        best_doc, _ = np.unravel_index(pairwise_similarities.argmax(), pairwise_similarities.shape)

        # Keep the document sentence that scores highest against each reference.
        maxp_pairwise[url] = embeddings[best_doc]
        maxp_sum[url] = embeddings[np.argmax(np.dot(embeddings, wikipedia_sum[topic]))]
        maxp_mean[url] = embeddings[np.argmax(np.dot(embeddings, wikipedia_mean[topic]))]
        maxp_trunc[url] = embeddings[np.argmax(np.dot(embeddings, wikipedia_trunc[topic]))]

# Dump each dictionary to its own pickle file, closing the file after writing.
docs_outputs = {
    "mean": docs_mean,
    "sum": docs_sum,
    "trunc": docs_trunc,
    "maxp_pairwise": maxp_pairwise,
    "maxp_sum": maxp_sum,
    "maxp_mean": maxp_mean,
    "maxp_trunc": maxp_trunc,
}
for name, data in docs_outputs.items():
    with open(f"../data/docs_{name}_embeddings.pkl", "wb") as f:
        pickle.dump(data, f)
