In [None]:
import itertools
import pickle
import urllib.parse
from collections import defaultdict

import numpy as np
import torch
from nltk.tokenize import sent_tokenize, word_tokenize
from scipy.spatial.distance import cdist
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import normalize
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer

In [None]:
# Load the sentence-embedding model, placing it on the GPU when one is visible to torch.
# Loading and encoding may take a while on CPU only.
base_model = "msmarco-MiniLM-L6-cos-v5"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

bert_model = SentenceTransformer(base_model, device=device)

In [None]:
# Embed every golden target text, keyed by integer topic id:
#   golden_trunc[topic_id] -- one embedding of the whole text in a single encode call
#                             (presumably "trunc" because SentenceTransformer truncates
#                             input longer than the model's max_seq_length)
#   golden_sum[topic_id]   -- sum of the unit-normalized per-sentence embeddings
#   golden_mean[topic_id]  -- mean of the unit-normalized per-sentence embeddings
#   golden_all[topic_id]   -- the stacked per-sentence embeddings themselves
# Each input line is `topic_id<TAB>text` (tab-separated despite the .csv extension).
golden_sum = {}
golden_mean = {}
golden_all = {}
golden_trunc = {}

with open("../data/nirmal_targets.csv") as targets_file:
    for line in targets_file:
        topic_id, text = line.strip().split("\t", maxsplit=1)
        topic_id = int(topic_id)
        sentences = sent_tokenize(text)
        embeddings = bert_model.encode(sentences, normalize_embeddings=True)

        golden_trunc[topic_id] = bert_model.encode(text, normalize_embeddings=True)
        golden_sum[topic_id] = np.sum(embeddings, axis=0)
        golden_mean[topic_id] = np.mean(embeddings, axis=0)
        golden_all[topic_id] = embeddings

# Persist each table; `with` guarantees every pickle file is flushed and closed.
for kind, table in (
    ("mean", golden_mean),
    ("sum", golden_sum),
    ("all", golden_all),
    ("trunc", golden_trunc),
):
    with open(f"../data/bert_embeddings/{base_model}_golden_{kind}_embeddings.pkl", "wb") as out_file:
        pickle.dump(table, out_file)


In [None]:
# Embed each clicked document the same ways as the golden targets, keyed by url:
#   docs_trunc[url] -- one embedding of the whole text in a single encode call
#   docs_sum[url]   -- sum of the unit-normalized per-sentence embeddings
#   docs_mean[url]  -- mean of the unit-normalized per-sentence embeddings
# Each TSV line is expected to be `url<TAB>topic<TAB>text`. Lines without a text field
# get all-zero placeholder vectors in all three tables so every url still has an entry.
docs_sum = {}
docs_mean = {}
docs_trunc = {}

# Length of the vectors `bert_model.encode` returns; used to size the zero placeholders.
embedding_dim = bert_model.get_sentence_embedding_dimension()

with open("../data/clicked_docs_nirmal.tsv") as docs_file:
    # 383 only sets the tqdm progress-bar length (assumed line count of the file).
    for line in tqdm(docs_file, total=383):
        try:
            url, _topic, text = line.strip().split("\t", maxsplit=2)
        except ValueError:
            # Empty. We don't have this embedding now.
            # `continue` is essential: falling through would embed a stale `text`
            # (from the previous line, or from the golden cell) and overwrite the
            # placeholder with another document's vectors.
            url = line.split("\t", maxsplit=1)[0].strip()
            placeholder = np.zeros(embedding_dim, dtype=np.float32)
            docs_sum[url] = placeholder
            docs_mean[url] = placeholder.copy()
            docs_trunc[url] = placeholder.copy()
            continue

        sentences = sent_tokenize(text)
        embeddings = bert_model.encode(sentences, normalize_embeddings=True)

        docs_trunc[url] = bert_model.encode(text, normalize_embeddings=True)
        docs_sum[url] = np.sum(embeddings, axis=0)
        docs_mean[url] = np.mean(embeddings, axis=0)

# NOTE: unlike the golden files, these names have no "_" between base_model and
# "nirmal"; kept unchanged because downstream code may load these exact paths.
for kind, table in (("mean", docs_mean), ("sum", docs_sum), ("trunc", docs_trunc)):
    with open(f"../data/bert_embeddings/{base_model}nirmal_docs_{kind}_embeddings.pkl", "wb") as out_file:
        pickle.dump(table, out_file)
