# Compute Bert Embeddings
There will be a lot of BERT embeddings to be computed. It's easier to pre-compute all of them, and then use only what we ACTUALLY need.

## Embeddings for Wikipedia pages
Wikipedia pages can be too long. So we need a strategy for pooling such long documents.
We are using the [`msmarco-MiniLM-L6-cos-v5`](https://huggingface.co/sentence-transformers/msmarco-MiniLM-L6-cos-v5) model as base, using a mean pooling strategy from `sentence-transformers` to compute embeddings.

Each document is split into sentences, using NLTK. If a sentence is longer than the limit of the model (384 tokens in this case), we truncate it.

In [33]:
import itertools
import pickle
import urllib.parse

from collections import defaultdict

import numpy as np
import torch

from nltk.tokenize import sent_tokenize, word_tokenize
from scipy.spatial.distance import cdist
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import normalize
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer
from utils import rawcount

In [29]:
# Load the sentence-embedding model. Loading can be slow when no GPU is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Alternatives that were tried and did not work as well:
#   "all-MiniLM-L6-v2"
#   "all-mpnet-base-v2"
base_model = "msmarco-MiniLM-L6-cos-v5"

bert_model = SentenceTransformer(base_model, device=device)

In [3]:
# Per-topic Wikipedia embeddings, keyed by topic title. Every pooling strategy
# is stored twice: built from the raw sentence embeddings, and built from the
# row-normalized sentence embeddings (the ``*_norm`` dictionaries).
wikipedia_sum = {}
wikipedia_mean = {}
wikipedia_all = {}
wikipedia_trunc = {}

wikipedia_sum_norm = {}
wikipedia_mean_norm = {}
wikipedia_all_norm = {}
wikipedia_trunc_norm = {}

# Each line holds "<topic title>\t<page text>". The context manager guarantees
# the file is closed even if encoding fails part-way through.
# NOTE(review): assumes the TSV is UTF-8 encoded -- confirm.
with open("../data/wikipedia_texts.tsv", encoding="utf-8") as wikipedia_file:
    for line in wikipedia_file:
        line = line.strip()
        if not line:
            # A blank line carries no page; skip it rather than fail on the unpack below.
            continue
        topic_title, text = line.split("\t", maxsplit=1)

        # One embedding per sentence, plus a row-normalized copy.
        sentences = sent_tokenize(text)
        embeddings = bert_model.encode(sentences, normalize_embeddings=False)
        embeddings_norm = normalize(embeddings)

        # Whole page encoded in a single call (presumably truncated by the model
        # at its maximum sequence length -- TODO confirm).
        wikipedia_trunc[topic_title] = bert_model.encode(text, normalize_embeddings=False)
        # normalize() expects a 2-D array, hence the reshape/flatten round trip.
        wikipedia_trunc_norm[topic_title] = normalize(wikipedia_trunc[topic_title].reshape(1, -1)).flatten()

        # Pooling over the raw sentence embeddings.
        wikipedia_sum[topic_title] = np.sum(embeddings, axis=0)
        wikipedia_mean[topic_title] = np.mean(embeddings, axis=0)

        # Pooling over the normalized sentence embeddings, re-normalized afterwards.
        wikipedia_sum_norm[topic_title] = normalize(np.sum(embeddings_norm, axis=0).reshape(1, -1)).flatten()
        wikipedia_mean_norm[topic_title] = normalize(np.mean(embeddings_norm, axis=0).reshape(1, -1)).flatten()

        # Keep every sentence embedding for the sentence-vs-sentence (MaxP) comparisons.
        wikipedia_all[topic_title] = embeddings
        wikipedia_all_norm[topic_title] = embeddings_norm

In [16]:
# Dump every Wikipedia embedding dictionary to its own pickle file.
# Each file is opened in a ``with`` block so the handle is flushed and closed
# before a later cell tries to read the pickle back.
for _kind, _embeddings_by_topic in (
    ("mean_embeddings", wikipedia_mean),
    ("sum_embeddings", wikipedia_sum),
    ("all_embeddings", wikipedia_all),
    ("trunc_embeddings", wikipedia_trunc),
    ("mean_embeddings_norm", wikipedia_mean_norm),
    ("sum_embeddings_norm", wikipedia_sum_norm),
    ("all_embeddings_norm", wikipedia_all_norm),
    ("trunc_embeddings_norm", wikipedia_trunc_norm),
):
    with open(f"../data/bert_embeddings/{base_model}_wikipedia_{_kind}.pkl", "wb") as _pickle_file:
        pickle.dump(_embeddings_by_topic, _pickle_file)

In [46]:
# Just load the embeddings....


def _load_wikipedia_embeddings(kind):
    """Return the pickled Wikipedia embeddings stored under *kind* for ``base_model``.

    The file is closed as soon as it has been read.
    NOTE: pickle can execute arbitrary code on load; only read files produced
    by this notebook.
    """
    path = f"../data/bert_embeddings/{base_model}_wikipedia_{kind}.pkl"
    with open(path, "rb") as pickle_file:
        return pickle.load(pickle_file)


wikipedia_mean = _load_wikipedia_embeddings("mean_embeddings")
wikipedia_sum = _load_wikipedia_embeddings("sum_embeddings")
wikipedia_all = _load_wikipedia_embeddings("all_embeddings")
wikipedia_trunc = _load_wikipedia_embeddings("trunc_embeddings")

wikipedia_mean_norm = _load_wikipedia_embeddings("mean_embeddings_norm")
wikipedia_sum_norm = _load_wikipedia_embeddings("sum_embeddings_norm")
# The original cell loaded ``trunc_embeddings_norm`` twice and never restored
# the per-sentence normalized embeddings; load them here instead.
wikipedia_all_norm = _load_wikipedia_embeddings("all_embeddings_norm")
wikipedia_trunc_norm = _load_wikipedia_embeddings("trunc_embeddings_norm")

## Embeddings for clicked documents
Clicked documents can also be too long, so we need a strategy for pooling such long documents. Here we use three different approaches: pooling (with mean/sum), truncation, and a BIRCH-like one, where we keep ALL of the sentence embeddings and, when computing similarity, we only keep the sentence with the highest score against the "golden" Wikipedia page. We will try the following:

- Mean of embeddings vs mean of wikipedia
- Sum of embeddings vs mean of wikipedia
- Truncate doc at maximum length (384)
- BIRCH MaxP all vs all: Keep Best score (max) for each sentence vs each sentence from wikipedia
- BIRCH MaxP all vs SUM: Keep Best score (max) for each sentence vs SUM of wikipedia
- BIRCH MaxP all vs MEAN: Keep Best score (max) for each sentence vs MEAN of wikipedia

We are using the [`msmarco-MiniLM-L6-cos-v5`](https://huggingface.co/sentence-transformers/msmarco-MiniLM-L6-cos-v5) model as base, using a mean pooling strategy from `sentence-transformers` to compute embeddings.

Each document is split into sentences, using NLTK. If a sentence is longer than the limit of the model (384 tokens in this case), we truncate it. (does it actually happen?)

In [52]:
# Display the loaded model so its modules and their settings can be inspected.
bert_model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 384, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

In [47]:
# Per-document embeddings of the clicked documents, keyed by URL. As for the
# Wikipedia pages, the ``*_norm`` dictionaries are built from the
# row-normalized sentence embeddings.
docs_sum = {}
docs_mean = {}
docs_trunc = {}

docs_sum_norm = {}
docs_mean_norm = {}
docs_trunc_norm = {}

docs_file = "../data/clicked_docs.tsv"

# Each line holds "<url>\t<document text>". The context manager guarantees the
# file is closed even if encoding fails part-way through.
# NOTE(review): assumes the TSV is UTF-8 encoded -- confirm.
with open(docs_file, encoding="utf-8") as docs_handle:
    for line in tqdm(docs_handle, total=rawcount(docs_file)):
        line = line.strip()
        if not line:
            # A blank line carries no document; skip it rather than fail on the unpack below.
            continue
        url, text = line.split("\t", maxsplit=1)
        sentences = sent_tokenize(text)

        # One embedding per sentence, plus a row-normalized copy.
        embeddings = bert_model.encode(sentences, normalize_embeddings=False)
        embeddings_norm = normalize(embeddings)

        # NOTE: the truncated variant is disabled, so docs_trunc and
        # docs_trunc_norm stay empty.
        # docs_trunc[url] = bert_model.encode(text, normalize_embeddings=False)
        # docs_trunc_norm[url] = normalize(docs_trunc[url].reshape(1, -1)).flatten()

        # Pooling over the raw sentence embeddings.
        docs_sum[url] = np.sum(embeddings, axis=0)
        docs_mean[url] = np.mean(embeddings, axis=0)

        # Pooling over the normalized sentence embeddings, matching how the
        # wikipedia_*_norm vectors are built (the original pooled the raw
        # embeddings here and left embeddings_norm unused).
        # normalize() expects a 2-D array, hence the reshape/flatten round trip.
        docs_sum_norm[url] = normalize(np.sum(embeddings_norm, axis=0).reshape(1, -1)).flatten()
        docs_mean_norm[url] = normalize(np.mean(embeddings_norm, axis=0).reshape(1, -1)).flatten()

  0%|          | 0/1074 [00:00<?, ?it/s]

In [48]:
# Dump every clicked-document embedding dictionary to its own pickle file.
# Each file is opened in a ``with`` block so the handle is flushed and closed
# once the dump finishes.
for _kind, _embeddings_by_url in (
    ("mean_embeddings", docs_mean),
    ("sum_embeddings", docs_sum),
    ("trunc_embeddings", docs_trunc),
    ("mean_embeddings_norm", docs_mean_norm),
    ("sum_embeddings_norm", docs_sum_norm),
    ("trunc_embeddings_norm", docs_trunc_norm),
):
    with open(f"../data/bert_embeddings/{base_model}_docs_{_kind}.pkl", "wb") as _pickle_file:
        pickle.dump(_embeddings_by_url, _pickle_file)