# Compute BERT Embeddings
There will be a lot of BERT embeddings to compute. It's easier to pre-compute all of them and then use only what we *actually* need.

## Embeddings for Wikipedia pages
Wikipedia pages can be too long for the model, so we need a strategy for pooling such long documents.
We use the [`all-mpnet-base-v2`](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) model as the base, with the mean pooling strategy from `sentence-transformers` to compute embeddings.
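For reference, mean pooling averages the token vectors produced by the transformer while skipping padding positions. The NumPy sketch below mirrors that idea (`sentence-transformers` does the equivalent in PyTorch inside its pooling module); `mean_pool` is a hypothetical name and the toy arrays stand in for real model outputs.

```python
import numpy as np


def mean_pool(token_embeddings, attention_mask):
    """Average token vectors per sequence, ignoring padding positions.

    token_embeddings: array of shape (batch, seq_len, dim)
    attention_mask:   array of shape (batch, seq_len), 1 for real tokens, 0 for padding
    """
    mask = attention_mask[..., None].astype(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)  # avoid division by zero
    return summed / counts


tokens = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])  # last position is padding
mask = np.array([[1, 1, 0]])
print(mean_pool(tokens, mask))  # [[2. 3.]]
```

The padded position (the `[100, 100]` vector) has no effect on the result.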

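Each document is split into sentences using NLTK, and each sentence is embedded separately. If a sentence is longer than the model's limit (384 tokens in this case), it is truncated. The sentence embeddings are then combined into a single document vector by summing or averaging.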
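To make the effect of truncation concrete, here is a sketch with a hypothetical `truncate_tokens` helper (not part of `sentence-transformers`, which truncates internally). It assumes the 384-token limit includes two special tokens added by the tokenizer; real tokenizers also split words into subword units, so 384 tokens usually means fewer than 384 words.

```python
# Hypothetical helper illustrating truncation to the model's limit.
# Assumes the sentence is already a list of tokens and that the tokenizer
# adds 2 special tokens (start and end) that count toward the limit.
MAX_SEQ_LENGTH = 384


def truncate_tokens(tokens, max_len=MAX_SEQ_LENGTH, n_special=2):
    """Split tokens into the part the model sees and the part it drops."""
    budget = max_len - n_special
    return tokens[:budget], tokens[budget:]


long_sentence = [f"tok{i}" for i in range(500)]
kept, dropped = truncate_tokens(long_sentence)
print(len(kept), len(dropped))  # 382 118
```

Whatever falls past the cut is not represented in the sentence embedding at all, which is why splitting documents into sentences before encoding matters.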

In [1]:
import pickle

import numpy as np
import torch
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer

# sent_tokenize may require a one-time nltk.download("punkt")
# (or "punkt_tab" on recent NLTK versions).


def normalize_vector(v):
    """Scale v to unit L2 norm."""
    return v / np.linalg.norm(v)

In [31]:
# Load the model. Encoding (below) is much slower without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
bert_model = SentenceTransformer("all-mpnet-base-v2", device=device)

In [128]:
wikipedia_sum = {}
wikipedia_mean = {}

# Each line of the TSV is: topic title <TAB> page text.
with open("../data/wikipedia_texts.tsv", encoding="utf-8") as f:
    for line in f:
        topic_title, text = line.strip().split("\t", maxsplit=1)
        sentences = sent_tokenize(text)
        # One unit-length vector per sentence; over-long sentences are truncated.
        embeddings = bert_model.encode(sentences, normalize_embeddings=True)
        # After normalization these two coincide, since mean = sum / n.
        wikipedia_sum[topic_title] = normalize_vector(np.sum(embeddings, axis=0))
        wikipedia_mean[topic_title] = normalize_vector(np.mean(embeddings, axis=0))

# Dump to pickle files.
with open("wikipedia_mean_embeddings.pkl", "wb") as f:
    pickle.dump(wikipedia_mean, f)
with open("wikipedia_sum_embeddings.pkl", "wb") as f:
    pickle.dump(wikipedia_sum, f)
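A quick check of the sum-versus-mean pooling above: because the mean is the sum divided by the number of sentences, the two point in the same direction and are identical once L2-normalized. Random unit vectors stand in for the output of `bert_model.encode(..., normalize_embeddings=True)` here, so no model is loaded.

```python
import numpy as np


def normalize_vector(v):
    return v / np.linalg.norm(v)


# Random unit vectors stand in for normalized sentence embeddings (768-dim, like mpnet).
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(5, 768))
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)

sum_vec = normalize_vector(np.sum(embeddings, axis=0))
mean_vec = normalize_vector(np.mean(embeddings, axis=0))
print(np.allclose(sum_vec, mean_vec))  # True
```

The two would only differ if the final `normalize_vector` call were dropped, in which case the summed vector's length also grows with the number of sentences.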

## Embeddings for clicked documents
Clicked documents can also be too long. Here we use three different approaches: pooling with the mean, pooling with the sum, and a BIRCH-like approach, where we keep ALL of the sentence embeddings and, when computing similarity, only keep the sentence with the highest score against the "golden" document.

As with Wikipedia pages, we use the [`all-mpnet-base-v2`](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) model, split each document into sentences using NLTK, and truncate any sentence longer than the model's limit (384 tokens).
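A minimal sketch of how pooling and the BIRCH-like scoring differ, assuming the "golden" document is a single vector (like the pooled Wikipedia embeddings above) and each clicked document is kept as its matrix of sentence embeddings. The function names are hypothetical, and toy 2-D vectors replace real embeddings.

```python
import numpy as np


def unit(v):
    """L2-normalize along the last axis (works for a vector or a matrix of rows)."""
    return v / np.linalg.norm(v, axis=-1, keepdims=True)


def mean_pooled_score(sentence_embs, golden):
    """Pooling: cosine similarity between the normalized mean sentence vector and the golden vector."""
    return float(unit(sentence_embs.mean(axis=0)) @ unit(golden))


def max_sentence_score(sentence_embs, golden):
    """BIRCH-like: score every sentence against the golden vector and keep only the best."""
    return float(np.max(unit(sentence_embs) @ unit(golden)))


doc = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])  # one on-topic sentence, two off-topic
golden = np.array([1.0, 0.0])
print(round(mean_pooled_score(doc, golden), 3))  # 0.447
print(max_sentence_score(doc, golden))  # 1.0
```

With one sentence that matches the golden vector exactly and two unrelated ones, pooling dilutes the match to 1/√5 ≈ 0.447, while the max-sentence score stays at 1.0.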