# Compute BERT Embeddings
There will be a lot of BERT embeddings to compute. It's easier to pre-compute all of them, and then use only what we actually need.

## Embeddings for Wikipedia pages
Wikipedia pages can be too long for the model, so we need a strategy for pooling such long documents.
We use the [`all-mpnet-base-v2`](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) model as the base, with the mean pooling strategy from `sentence-transformers`, to compute embeddings.

Each document is split into sentences, using NLTK. If a sentence is longer than the limit of the model (384 tokens in this case), we truncate it.

In [14]:
from transformers import AutoTokenizer, AutoModel
from nltk.tokenize import sent_tokenize, word_tokenize
from collections import defaultdict
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm

import numpy as np
import torch
import pickle


def normalize_vector(v):
    return v / np.linalg.norm(v)

In [4]:
# Prepare model. May take a while without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
bert_model = SentenceTransformer("all-mpnet-base-v2", device=device)

In [8]:
wikipedia_sum = {}
wikipedia_mean = {}
wikipedia_all = {}

with open("../data/wikipedia_texts.tsv") as f:
    for line in f:
        topic_title, text = line.strip().split("\t", maxsplit=1)
        sentences = sent_tokenize(text)
        # One L2-normalized embedding per sentence
        embeddings = bert_model.encode(sentences, normalize_embeddings=True)
        wikipedia_sum[topic_title] = normalize_vector(np.sum(embeddings, axis=0))
        wikipedia_mean[topic_title] = normalize_vector(np.mean(embeddings, axis=0))
        wikipedia_all[topic_title] = embeddings

# Dump each dictionary to its own pickle file
for name, data in [("mean", wikipedia_mean), ("sum", wikipedia_sum), ("all", wikipedia_all)]:
    with open(f"wikipedia_{name}_embeddings.pkl", "wb") as f:
        pickle.dump(data, f)
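
One property of the pooling in this cell may be worth checking: the mean is just the sum divided by the number of sentences, so after L2 normalization both should give the same unit vector. The sketch below illustrates this with random stand-in embeddings (the `fake` array is made up for illustration, not real model output):

```python
import numpy as np


def normalize_vector(v):
    return v / np.linalg.norm(v)


rng = np.random.default_rng(0)
# Stand-in for the (n_sentences, dim) array returned by `encode`
fake = rng.normal(size=(5, 8))
fake /= np.linalg.norm(fake, axis=1, keepdims=True)  # unit-length rows

pooled_sum = normalize_vector(np.sum(fake, axis=0))
pooled_mean = normalize_vector(np.mean(fake, axis=0))

# True when both pooled vectors are numerically identical
same_direction = np.allclose(pooled_sum, pooled_mean)
```

If that holds for the real embeddings too, the sum and mean variants would only differ when the final normalization is skipped.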

## Embeddings for clicked documents
Clicked documents can also be too long, so we again need a strategy for pooling them. Here we use three different approaches: pooling with the mean, pooling with the sum, and a BIRCH-like one, where we keep ALL of the sentence embeddings and, when computing similarity, only keep the sentence with the highest score against the "golden" Wikipedia page. We will try the following:

- Mean of embeddings vs mean of Wikipedia
- Sum of embeddings vs mean of Wikipedia
- BIRCH MaxP all vs all: keep the best score (max) of each document sentence vs each sentence from Wikipedia
- BIRCH MaxP all vs SUM: keep the best score (max) of each document sentence vs the SUM of Wikipedia
- BIRCH MaxP all vs MEAN: keep the best score (max) of each document sentence vs the MEAN of Wikipedia
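
As a rough illustration of the MaxP variants listed above, here is a minimal NumPy sketch under one possible reading: every document sentence is scored against the Wikipedia side and only the single best score is kept. The function name `maxp_scores` and the toy vectors are hypothetical, and all embeddings are assumed to be L2-normalized so that a dot product equals cosine similarity.

```python
import numpy as np


def maxp_scores(doc_emb, wiki_emb, wiki_pooled):
    """Hypothetical helper: best sentence-level similarity scores.

    doc_emb: (n, d) document sentence embeddings
    wiki_emb: (m, d) Wikipedia sentence embeddings
    wiki_pooled: (d,) pooled (sum or mean) Wikipedia embedding
    """
    pairwise = doc_emb @ wiki_emb.T    # (n, m) similarity of every sentence pair
    vs_pooled = doc_emb @ wiki_pooled  # (n,) similarity to the pooled vector
    return float(pairwise.max()), float(vs_pooled.max())


# Toy 2-dimensional unit vectors, made up for illustration
doc = np.array([[1.0, 0.0], [0.0, 1.0]])
wiki = np.array([[1.0, 0.0], [0.6, 0.8]])
pooled = wiki.sum(axis=0) / np.linalg.norm(wiki.sum(axis=0))

all_vs_all, all_vs_pooled = maxp_scores(doc, wiki, pooled)
```

Taking the max over the whole matrix gives one score per document; keeping `pairwise.max(axis=1)` instead would retain one best score per document sentence, if that is the granularity needed.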

As before, we use the [`all-mpnet-base-v2`](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) model as the base, with the mean pooling strategy from `sentence-transformers`, to compute embeddings.

Each document is split into sentences, using NLTK. If a sentence is longer than the limit of the model (384 tokens in this case), we truncate it.

In [12]:
!wc -l ../data/clicked_docs_with_topics.tsv

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
947 ../data/clicked_docs_with_topics.tsv


In [18]:
docs_sum = {}
docs_mean = {}
maxp_pairwise = {}
maxp_sum = {}
maxp_mean = {}

for line in tqdm(open("../data/clicked_docs_with_topics.tsv"), total=947):
    url, topic, text = line.split("\t", maxsplit=2)
    # TODO: `topic` still needs to be processed here (not implemented yet)
    sentences = sent_tokenize(text)
    embeddings = bert_model.encode(sentences, normalize_embeddings=True)
    docs_sum[url] = normalize_vector(np.sum(embeddings, axis=0))
    docs_mean[url] = normalize_vector(np.mean(embeddings, axis=0))

    # TODO: get the pairwise (MaxP) data; until then, stop after the first document
    break

  0%|          | 0/947 [00:00<?, ?it/s]

In [20]:
wikipedia_all.keys()

dict_keys(['Subprime%20mortgage%20crisis', 'Irritable%20bowel%20syndrome', 'Genetically%20modified%20organism', 'Noise-induced%20hearing%20loss', 'Business%20cycle', 'Ethics', 'Radiocarbon%20dating%20considerations'])