## BERT Knowledge Representation
We want to compute how a user's internal knowledge representation changes at each step of their session (i.e., after each clicked document). We compare a few different methods for doing so, all using BERT embeddings as a starting point.
A few assumptions:
- The user starts with an empty knowledge representation.
- The user READS every document they click, and its content is added to their knowledge.

Document embeddings for BERT can come in a few different forms. Check `Compute_BERT_embeddings.ipynb` for how we compute each:
- SUM: sum of the embeddings for each sentence of the document
- MEAN: mean of the embeddings for each sentence of the document
- TRUNC: Truncate the document at the first 384 tokens.
- maxp_pairwise: compare every sentence of the document with every sentence of the Wikipedia topic article; the document is represented by its sentence with the highest similarity to any Wikipedia sentence
- maxp_sum: the document is represented by its sentence with the highest similarity to the SUM of the Wikipedia sentence embeddings
- maxp_mean: the document is represented by its sentence with the highest similarity to the MEAN of the Wikipedia sentence embeddings
- maxp_trunc: the document is represented by its sentence with the highest similarity to the truncated Wikipedia article
    
These are the ways we aggregate a user's knowledge as they click through documents. The resulting knowledge vector is compared against each of the Wikipedia article embeddings (SUM, MEAN and TRUNC):

- MEAN: the final knowledge is the MEAN of the embeddings of all clicked documents.
- SUM: as the user clicks on documents, their embeddings are added up; the final knowledge is this SUM.


In [225]:
import json
import pickle
import urllib.parse
from collections import OrderedDict, defaultdict
from pprint import pprint

import numpy as np
from scipy.spatial.distance import cosine, euclidean
from scipy.stats import pearsonr, spearmanr
from tqdm.auto import tqdm

# If True, load the "*_embeddings.pkl" files; if False, the "*_embeddings_un.pkl" files.
normalized = True
# Model whose precomputed embeddings are loaded. hidden_size must equal the length of the
# loaded embedding vectors (it sizes the zero vectors the SUM aggregation starts from).
# base_model, hidden_size = "all-MiniLM-L6-v2", 384
base_model, hidden_size = "msmarco-MiniLM-L6-cos-v5", 384
# base_model, hidden_size = "all-mpnet-base-v2", 768


def load_embeddings(name, doc="docs", normalized=normalized):
    """Load a pickled embedding mapping from ``../data``.

    Reads ``../data/{base_model}_{doc}_{name}_embeddings.pkl``, or the
    ``..._embeddings_un.pkl`` variant when ``normalized`` is False.

    Args:
        name: embedding variant, e.g. "mean", "sum", "trunc", "maxp_pairwise".
        doc: corpus prefix used in the file name, e.g. "docs" or "wikipedia".
        normalized: pick the normalized or un-normalized file. Defaults to the
            module-level ``normalized`` flag as it was when this function was defined.

    Returns:
        The unpickled object (used in this notebook as a mapping keyed by document
        URL or by quoted Wikipedia topic title).

    NOTE: ``pickle.load`` can execute arbitrary code; only load trusted data files.
    """
    suffix = "embeddings" if normalized else "embeddings_un"
    path = f"../data/{base_model}_{doc}_{name}_{suffix}.pkl"
    # Context manager so the file handle is closed as soon as loading finishes.
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Session logs, one record per user. Keys read later in this notebook: "userID", "ALG",
# "RPL", "topic_title" and "clicks" (a list of entries with a "url" key).
with open("../data/logs_with_position.json", encoding="utf-8") as logs_file:
    dataset = json.load(logs_file)

# Document embeddings, one mapping per method, looked up by clicked URL in the
# aggregation loop. Key order matters: it defines the order of `methods`.
embeddings = {
    "docs_mean": load_embeddings("mean"),
    "docs_sum": load_embeddings("sum"),
    "docs_trunc": load_embeddings("trunc"),
    "maxp_pairwise": load_embeddings("maxp_pairwise"),
    "maxp_sum": load_embeddings("maxp_sum"),
    "maxp_mean": load_embeddings("maxp_mean"),
    "maxp_trunc": load_embeddings("maxp_trunc"),
}

# Wikipedia topic-article embeddings, looked up by URL-quoted topic title.
wikipedia_embeddings = {
    "sum": load_embeddings("sum", "wikipedia"),
    "mean": load_embeddings("mean", "wikipedia"),
    "trunc": load_embeddings("trunc", "wikipedia"),
}

In [228]:
# Display the embedding-method names. Uses `embeddings` directly because `methods` is only
# assigned in the next cell (In [226]); referencing it here raises NameError when the
# notebook is run top to bottom.
list(embeddings)

['docs_mean',
 'docs_sum',
 'docs_trunc',
 'maxp_pairwise',
 'maxp_sum',
 'maxp_mean',
 'maxp_trunc']

In [226]:
methods = list(embeddings.keys())
users_knowledge_MEAN = []  # not populated in this cell
users_knowledge_SUM = []  # not populated in this cell
# final_knowledges[f"{AGG}_{method}_{wiki_type}"][userID] -> {"RPL", "ALG", "final_sim"},
# with AGG in {"SUM", "MEAN"}: 2 * len(methods) * len(wikipedia_embeddings) keys.
final_knowledges = defaultdict(OrderedDict)
# Clicked URLs skipped because they have no usable embedding.
missing_docs = set()

for u in dataset:
    u_id = u["userID"]
    ALG = u["ALG"]
    RPL = u["RPL"]
    # Wikipedia embeddings are keyed by the URL-quoted topic title.
    topic = urllib.parse.quote(u["topic_title"])

    # MEAN: collect each valid clicked document's embedding; SUM: keep a running total.
    user_knowledge_mean = {k: [] for k in methods}
    user_knowledge_sum = {k: np.zeros(hidden_size) for k in methods}
    for d in u["clicks"]:
        url = d["url"]
        # Skip the click if any method lacks an embedding for this URL (avoids a KeyError
        # below) or if its docs_mean embedding is all zeros.
        lacks_embedding = any(url not in embeddings[m] for m in methods)
        if lacks_embedding or not np.any(embeddings["docs_mean"][url]):
            missing_docs.add(url)
            continue
        for method in methods:
            emb = embeddings[method][url]
            user_knowledge_mean[method].append(emb)
            user_knowledge_sum[method] += emb

    # Normalize the aggregated knowledge and score it against each Wikipedia embedding.
    # NOTE(review): `normalize_vector` is not defined anywhere earlier in this notebook; it
    # must be defined in some other cell before this one runs — confirm. (The saved output
    # shows SUM_* != MEAN_* for the same method, so it is apparently not a plain L2
    # normalization.)
    # NOTE(review): for a user with no valid clicks, np.mean over an empty list is NaN
    # (with a RuntimeWarning), which likely makes that user's MEAN_* final_sim NaN.
    for method in methods:
        knowledge_mean = normalize_vector(np.mean(user_knowledge_mean[method], axis=0))
        knowledge_sum = normalize_vector(user_knowledge_sum[method])
        for emb_type, wiki_by_topic in wikipedia_embeddings.items():
            wiki_emb = wiki_by_topic[topic]
            # Similarity = 1 - Euclidean distance (higher means closer to the article).
            sum_dict = {"RPL": RPL, "ALG": ALG, "final_sim": 1 - euclidean(knowledge_sum, wiki_emb)}
            mean_dict = {"RPL": RPL, "ALG": ALG, "final_sim": 1 - euclidean(knowledge_mean, wiki_emb)}

            final_knowledges[f"SUM_{method}_{emb_type}"][u_id] = sum_dict
            final_knowledges[f"MEAN_{method}_{emb_type}"][u_id] = mean_dict

# Context managers guarantee both output files are flushed and closed.
with open(f"../data/{base_model}_knowledge_gains.pkl", "wb") as outf:
    pickle.dump(dict(final_knowledges), outf)
# Sorted so the file content is identical across runs (set order depends on hash seeding).
with open("../data/missing_docs.txt", "w") as outf:
    for missing_url in sorted(missing_docs):
        outf.write(f"{missing_url}\n")

## Correlation between learning and embedding methods

Finally, for each of the 42 combinations tried above (2 aggregations × 7 document embeddings × 3 Wikipedia embeddings), we want to know which one has the highest correlation with learning gains (ALG).

In [227]:
# Per-user learning outcomes in dataset order. Every score list built below follows the
# same `u_ids` order, so scores and outcomes line up position by position.
u_ids = [u["userID"] for u in dataset]
ALGs = [u["ALG"] for u in dataset]
RPLs = [u["RPL"] for u in dataset]

pearsons_ALG = {}
spearman_ALG = {}
corr_ALG = {}  # not populated in this cell

correlations = []
for method, _users in final_knowledges.items():
    results = [_users[x]["final_sim"] for x in u_ids]
    pearsons_ALG[method] = pearsonr(results, ALGs)[0]
    spearman_ALG[method] = spearmanr(results, ALGs)[0]
    # NOTE: rebuilt on every iteration, so after the loop `correlations` (like `results`)
    # holds only the records of the last method in `final_knowledges`.
    correlations = [
        {"userID": x, "score": score, "ALG": alg, "RPL": rpl}
        for x, score, alg, rpl in zip(u_ids, results, ALGs, RPLs)
    ]

# Methods ranked by Spearman correlation with ALG, highest first.
pprint(sorted(spearman_ALG.items(), key=lambda x: x[1], reverse=True))

[('MEAN_docs_mean_sum', 0.2033924875048431),
 ('MEAN_docs_mean_mean', 0.2033924875048431),
 ('MEAN_docs_trunc_sum', 0.154210292516518),
 ('MEAN_docs_trunc_mean', 0.154210292516518),
 ('MEAN_maxp_trunc_trunc', 0.1422547739095751),
 ('MEAN_maxp_sum_trunc', 0.07315063988820826),
 ('MEAN_maxp_mean_trunc', 0.07315063988820826),
 ('MEAN_docs_trunc_trunc', 0.07177579940816829),
 ('MEAN_maxp_sum_sum', 0.06382409890155592),
 ('MEAN_maxp_sum_mean', 0.06382409890155592),
 ('MEAN_maxp_mean_sum', 0.06382409890155592),
 ('MEAN_maxp_mean_mean', 0.06382409890155592),
 ('MEAN_maxp_pairwise_trunc', 0.06187223757764479),
 ('MEAN_docs_mean_trunc', 0.05695872845304796),
 ('MEAN_maxp_trunc_sum', 0.036524536221875674),
 ('MEAN_maxp_trunc_mean', 0.036524536221875674),
 ('MEAN_maxp_pairwise_sum', 0.02269369987235148),
 ('MEAN_maxp_pairwise_mean', 0.02269369987235148),
 ('SUM_docs_trunc_sum', -0.034287108459926254),
 ('SUM_docs_trunc_mean', -0.034287108459926254),
 ('SUM_docs_trunc_trunc', -0.03449613131663469)

In [74]:
# `results` is whatever the correlation loop left behind: the per-user final_sim scores of
# the last key in `final_knowledges`. NOTE(review): the output below comes from an earlier
# execution (In [74]) and may not match the current cells.
pearsonr(results, ALGs)

(0.1745923355832255, 0.049622036223161144)

In [161]:
import matplotlib.pyplot as plt