# Colab specific

In [None]:
# Mount Google Drive at /content/drive so files stored there can be read
# and written from this runtime. Only meaningful when running on Google Colab.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install transformers
!pip install word2vec

# Main script

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config IPCompleter.greedy=True

In [None]:
# lib_path has to point at the root directory of the git repository
# (named CategoryLearning by default). If the check at the end of this cell
# fails, edit lib_path by hand; an absolute path is recommended.
import os
from pathlib import Path
import re

# DEFINE PATH BELOW
# lib_path = Path('/content/drive/My Drive/Colab/CategoryLearning')
lib_path = Path('../').resolve()
print(f'Current absolute path to Git root directory: {lib_path}')

# The repository root is recognised by the presence of a `catlearn` entry.
assert any(entry == 'catlearn' for entry in os.listdir(lib_path)), 'It seems notebook is run not from the root directory. Modify lib_path.'

In [None]:
# Modify notebook environment path in case it was updated in the cell above.
# From here on, relative paths used in later cells are resolved against lib_path.
# NOTE(review): if lib_path was built from the relative '../', re-running the
# cell above after this one resolves it against the new working directory.
os.chdir(lib_path)

In [None]:
import torch
from transformers import BertTokenizer, BertModel
import logging
import matplotlib.pyplot as plt
from scipy.spatial.distance import cosine, euclidean
from sklearn.decomposition import PCA
from data.config import (raw_dataset_pat, preproc_pat)
from data.utils import read_file, write_file
from typing import Dict, List, Set, Union, Any
from tqdm import tqdm_notebook as tqdm
import pickle as pkl

In [None]:
# Prefer the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs end to end (embedding generation is just slower there).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Newly created float tensors default to CUDA tensors from here on.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

In [None]:
# Tokenizer loaded from the pretrained 'bert-base-uncased' checkpoint;
# the same checkpoint name is used for the model so both share one vocabulary.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
model = BertModel.from_pretrained(
    'bert-base-uncased',
    # Return the hidden states of every layer, not only the last one;
    # word2vec_bert() reads them from the model output.
    output_hidden_states=True,
)
# The model is used for inference only: make evaluation mode explicit so that
# dropout is disabled and repeated runs give identical embeddings.
# The trailing ';' suppresses the long module repr in the cell output.
model.eval();

In [None]:
!nvidia-smi

In [None]:
def word2vec_bert(sentence: str):
    """
    Run one sentence through BERT and return all hidden states per token.

    The sentence is wrapped as '[CLS] sentence [SEP]' before tokenization,
    so the first and the last rows of the result belong to these markers.

    sentence: raw text to embed.
    return: tensor of shape [n_token, n_layer, emb_size], where
        n_token  -- number of tokens produced by the tokenizer,
                    including the [CLS] and [SEP] markers,
        n_layer  -- number of hidden states returned by the model; for
                    'bert-base-uncased' this is 13: the embedding output
                    (index 0) followed by the 12 encoder layers (1..12),
        emb_size -- hidden size of the model (768 for 'bert-base-uncased').
    Use token_embeddings[:, 12] (or [:, -1]) to get the final layer output
    for every token.
    """
    marked_sentence = '[CLS] ' + sentence + ' [SEP]'

    tokenized = tokenizer.tokenize(marked_sentence)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized)

    # Create the inputs on the device that holds the model weights, so the
    # call does not depend on the global default tensor type.
    model_device = next(model.parameters()).device
    input_ids = torch.tensor([indexed_tokens], device=model_device)
    # A single, unpadded sentence: every position is a real token,
    # hence the attention mask is all ones.
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        # Call the module itself rather than .forward() so registered hooks run.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # With output_hidden_states=True the third output is the tuple of
        # per-layer hidden states, each of shape [1, n_token, emb_size].
        hidden_states = outputs[2]

    # [n_layer, 1, n_token, emb_size]
    token_embeddings = torch.stack(hidden_states, dim=0)
    # Drop the batch dimension: [n_layer, n_token, emb_size]
    token_embeddings = torch.squeeze(token_embeddings, dim=1)
    # Token-major layout: [n_token, n_layer, emb_size]
    token_embeddings = token_embeddings.permute(1, 0, 2)
    return token_embeddings

Note: On average, embeddings of words encoded without the [CLS] and [SEP] separators differ noticeably from those of the same words encoded with the separators.
The separators are therefore kept for consistency of the results.

## Read dataset and create dictionary

In [None]:
# Directory that holds the word set file read by get_embeddings().
# The path is relative, so it depends on the current working directory.
path_dir_wordset = './Datasets/wn18rr/text'

In [None]:
def get_embeddings(path_dir_wordset: str, layer: int = 12):
    """
    Generate one embedding vector for each word of the word set.

    path_dir_wordset: path to the directory with the word set file
        (the file name is taken from preproc_pat['wordset']).
    layer: index of the BERT hidden state used as the embedding;
        the default 12 selects the final output returned by word2vec_bert.
    return: (word_embeddings, lengthes)
        word_embeddings: dict {str: ndarray} -- word as the key and its
            embedding vector (768 dimensions for 'bert-base-uncased').
        lengthes: dict {str: int} -- number of embedding vectors
            (tokens) that were averaged for each word/phrase.
    Note: for compound words with several embeddings per word,
    all embedding vectors are averaged elementwise.
    Entries that yield no tokens (e.g. blank lines) are skipped with a
    warning, because the mean over zero vectors would be NaN.
    """
    word_set: List[str] = read_file(os.path.join(path_dir_wordset, preproc_pat['wordset']))

    word_embeddings: Dict = {}
    lengthes: Dict = {}
    for word in tqdm(word_set):
        # split compound words into their members
        sentence = word.replace('_', ' ')
        token_embeddings = word2vec_bert(sentence)
        # remove the [CLS]/[SEP] rows and keep the requested layer only
        layer_embeddings = token_embeddings[1:-1, layer]
        n_vectors = layer_embeddings.shape[0]
        if n_vectors == 0:
            logging.warning('No tokens produced for entry %r; skipping it.', word)
            continue
        lengthes[word] = n_vectors
        # Elementwise average: a no-op for a single vector, otherwise merges
        # the vectors of all tokens of the entity into one.
        word_embeddings[word] = torch.mean(layer_embeddings, dim=0).cpu().numpy()
    return word_embeddings, lengthes

In [None]:
# Generating the embeddings takes ~30 min on CPU or ~5-10 min on GPU
# Check first whether an embedding file is already available as .pkl or .txt
# If it is, you can jump directly to the section below that loads the saved file
# Get word embeddings with the BERT default dimension of 768
# %time word_set_embeddings, lengthes = get_embeddings(path_dir_wordset)

In [None]:
# Save to pickle
# with open('./bert_embed_np.pkl', 'wb') as f:
    # pkl.dump(word_set_embeddings, f)

In [None]:
# Save to .txt
# write_file('./bert_embed.txt', word_set_embeddings)

## Load file directly instead of generating a new one (if available)

In [None]:
# Reuse previously saved embeddings instead of regenerating them.
# NOTE: unpickling can execute code stored in the file -- only load
# pickle files that you created yourself.
with Path('./bert_embed_np.pkl').open('rb') as pkl_file:
    word_set_embeddings = pkl.load(pkl_file)

In [None]:
# Plot histogram of embedding vector numbers per node
# Note: all embeddings were averaged to get a single embedding per node
# Data below is for analysis purposes only
if 'lengthes' not in globals():
    # `lengthes` is produced only by get_embeddings(). When the embeddings were
    # loaded from the pickle instead, rebuild the counts with the tokenizer:
    # tokens of '[CLS] word [SEP]' minus the two markers, which is exactly the
    # number of rows get_embeddings() averages for each word.
    lengthes = {
        word: len(tokenizer.tokenize('[CLS] ' + word.replace('_', ' ') + ' [SEP]')) - 2
        for word in word_set_embeddings
    }
# list(...) because a dict view is not a sequence matplotlib can always handle
plt.hist(list(lengthes.values()))
plt.xlabel('Embedding vectors (tokens) per node')
plt.ylabel('Number of nodes')
plt.show()

In [None]:
# Define below new embedding dimension
# random_state is fixed because scikit-learn may choose a randomized SVD solver
# for large inputs, which would make the projection differ between runs.
pca = PCA(n_components=128, random_state=0)

In [None]:
# Embedding vectors only, in the iteration order of the dictionary
word_set_embeddings_only = list(word_set_embeddings.values())

In [None]:
# Fit the PCA projection on the full list of word embedding vectors
pca.fit(word_set_embeddings_only)

In [None]:
# Project every embedding vector with the fitted PCA; row order follows the input list
word_set_embeddings_only_transformed = pca.transform(word_set_embeddings_only)

In [None]:
# Final mapping word -> reduced embedding. Keys and transformed rows are paired
# by position, relying on the dictionary iteration order being unchanged.
word_set_embeddings_transformed = dict(zip(word_set_embeddings, word_set_embeddings_only_transformed))

In [None]:
# Save the PCA-reduced embeddings to pickle.
# The transformed dictionary is written here; the original, unreduced
# `word_set_embeddings` belongs in './bert_embed_np.pkl'.
with open('./bert_embed_np_128.pkl', 'wb') as f:
    pkl.dump(word_set_embeddings_transformed, f)