In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import os
from psycho_embeddings.fresh_embedder import ContextualizedEmbedder
import numpy as np
import pandas as pd
from glob import glob
import pickle
import gc
from tqdm import tqdm
from operator import itemgetter
from numpy import zeros, dtype, float32 as REAL, ascontiguousarray, fromstring
from gensim import utils
import gensim

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
model = ContextualizedEmbedder("bert-base-cased", max_length=300, device="cpu")

loading configuration file config.json from cache at /home/vinid/.cache/huggingface/hub/models--bert-base-cased/snapshots/5532cc56f74641d4bb33641f5c76a55d11f846e0/config.json
Model config BertConfig {
  "_name_or_path": "bert-base-cased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_hidden_states": true,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.24.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 28996
}

loading weights file pytorch_model.bin from cache at /home/vinid/.cache/huggingface/hub/models--bert-base-cased/snapshots/5532cc56f746

In [4]:
def find_index_for_word(word, data):
    """
    Given a word and a dataframe, finds the idxs of that word in the dataframe
    """
    return data[data["words"] == word].index.tolist()

def get_average_word_embeddings(word, data, embeds):
    """
    Given a word, a data, and the embeddings, it averages the embeddings of that word
    """
    idxs = find_index_for_word(word, data)
    if len(idxs) > 1:
        return np.average(itemgetter(*idxs)(embeds), axis=0)
    else:
        return np.array(embeds[idxs[0]]) # idxs is a list of lists so we access the first element

In [5]:
data = pd.DataFrame({"words" : ["cat", "dog", "cat"], "target_text" : ["the cat is on the table", "the dog is on the table", "the cat is on the table"]})

In [6]:
SIZE_CHUNKS = 2  #chunk row size
FOLDER_NAME = "bert_embeddings"

In [7]:
layers_of_interest = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
list_df = [data[i:i+SIZE_CHUNKS] for i in range(0,data.shape[0],SIZE_CHUNKS)]

# Create Embeddings for The Enitre Dataset

In [8]:
pbar = tqdm(total=len(list_df), position=0)

for index, sub_portion_od_data in enumerate(list_df):

    #############################
    # DUMPING EMBEDDING ON DISK #
    #############################

    df_slice_embedded = embeddings = model.embed(
        words=sub_portion_od_data["words"].tolist(),
        layers_id=layers_of_interest,
        target_texts=sub_portion_od_data["target_text"].tolist(),
        batch_size=8,
        averaging=True,
        return_static=True,
        show_progress=True
    )

    for layer in [-1] + layers_of_interest:
        os.makedirs(f"{FOLDER_NAME}/{layer}/temp/", exist_ok=True)

        with open(f"{FOLDER_NAME}/{layer}/temp/bert_embeddings_{index}", "wb") as filino:
            pickle.dump((df_slice_embedded[layer]), filino)

    if index%10==0:
        gc.collect()
    pbar.update(1)

  0%|                                                                                                                                              | 0/2 [00:00<?, ?it/s]

Text tokenization:   0%|          | 0/2 [00:00<?, ?ex/s]

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.16it/s]
 50%|███████████████████████████████████████████████████████████████████                                                                   | 1/2 [00:01<00:01,  1.56s/it]

Text tokenization:   0%|          | 0/1 [00:00<?, ?ex/s]

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.84it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.17s/it]

# Reconstruct and Save Contextualzied Embeddings

In [10]:
for LAYER in tqdm(range(-1, 13), desc="Layer"):
    # We load all the embeddings from disk, in order and reconstruct the actual embedding for a specific layer for the entire dataframe.

    emb_files = sorted(glob(f"{FOLDER_NAME}/{LAYER}/temp/*"), key=lambda x: int(os.path.basename(x).split("_")[-1]))
    assert len(emb_files) == len(list_df) # sanity check

    all_the_embeddings = []
    pbar = tqdm(total=len(list_df), position=0)

    for ff in emb_files:
        with open(ff, "rb") as filino:
            ldata = pickle.load(filino)
            pbar.update(1)
            for value in ldata:
                if len(value) == 1:
                    all_the_embeddings.append(np.array(value[0]))
                else:
                    all_the_embeddings.append(np.array(value))
    pbar.close()

    all_the_embeddings = np.array(all_the_embeddings)


    with open(f'{FOLDER_NAME}/contextualized_embeddings_bert_{LAYER}_layer.npy', 'wb') as f:
        np.save(f, all_the_embeddings)

    del all_the_embeddings

    ##################
    # MAP 2 Sentence #
    ##################

    # NOTE:
    # NOTE: This is probably dataset specific? but also, do we really need this? seems to be only the index dumped on disk?
    # NOTE:

    map_sentrepl2emb = {
        (row["words"], row["words"], row["target_text"], idx): idx for idx, row in data.iterrows()
    }

    with open(f"{FOLDER_NAME}/map_sentrepl2embbert_{LAYER}.pkl", "wb") as file_to_save:
        pickle.dump(map_sentrepl2emb, file_to_save)



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 3405.85it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 8665.92it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 10837.99it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 13025.79it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 5111.89it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<

### Prototype Embeddings

In [11]:
for LAYER in tqdm(range(-1, 13), desc="Layer"):
    #emb_files = sorted(glob(f"{FOLDER_NAME}/{LAYER}/temp/*"), key=lambda x: int(os.path.basename(x).split("_")[-1]))
    #assert len(emb_files) == len(list_df) # sanity check

    ##############################
    # Build Prototype Embeddings #
    ##############################

    embeds = np.load(f"{FOLDER_NAME}/contextualized_embeddings_bert_{LAYER}_layer.npy")

    mega_embeddings = {}
    pbar = tqdm(total=len(data["words"].unique()), position=0)
    for word in data["words"].unique():
        emb = get_average_word_embeddings(word, data, embeds)
        mega_embeddings[word] = emb
        pbar.update(1)
    pbar.close()
    
    
    m = gensim.models.keyedvectors.Word2VecKeyedVectors(vector_size=768)
    m.add_vectors(list(mega_embeddings.keys()), list(mega_embeddings.values()))
    m.save_word2vec_format(f"{FOLDER_NAME}/gensim_prototype_embeddings_bert_{LAYER}.bin")

    with open(f"{FOLDER_NAME}/prototype_embeddings_bert_{LAYER}.pkl", "wb") as filino:
        pickle.dump(mega_embeddings, filino)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 611.50it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 2410.52it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 2524.41it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 1297.94it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 1740.01it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<

In [5]:
import pickle
with open(f"bert_embeddings/prototype_embeddings_bert_12.pkl", "rb") as file_to_save:
    dici = pickle.load(file_to_save)

In [2]:
import sys
!{sys.executable} -m pip install gensim

Collecting gensim
  Downloading gensim-4.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (24.0 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.0/24.0 MB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m0:01[0m:01[0m
[?25hCollecting smart-open>=1.8.1
  Using cached smart_open-6.2.0-py3-none-any.whl (58 kB)
Collecting scipy>=0.18.1
  Using cached scipy-1.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (33.7 MB)
Installing collected packages: smart-open, scipy, gensim
Successfully installed gensim-4.2.0 scipy-1.9.3 smart-open-6.2.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.3[0m[39;49m -> [0m[32;49m22.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [12]:
mega_embeddings

{'cat': array([ 4.70884860e-01,  5.09123877e-02,  1.45669132e-01,  4.24012989e-02,
        -8.90813768e-02,  2.72655785e-02, -2.06779540e-02,  3.81456316e-02,
        -1.74375713e-01, -1.18198097e-01, -2.79189646e-01,  4.78609443e-01,
        -5.07324934e-01,  4.16824147e-02,  3.07660758e-01, -8.68651196e-02,
        -4.68256883e-02,  5.76282069e-02, -7.61688948e-02, -2.09019288e-01,
        -3.38658243e-01,  5.90826850e-03, -2.11599201e-01, -2.96639800e-01,
         2.16006190e-01,  3.40278000e-02,  4.62816298e-01,  4.36603665e-01,
        -1.30769059e-01,  4.73028898e-01, -2.02750832e-01, -8.73628482e-02,
         4.30546016e-01,  3.52426171e-02,  1.65458694e-01,  3.81627887e-01,
         1.50419086e-01, -1.49662524e-01,  1.45483226e-01,  2.35010058e-01,
         5.28512478e-01, -3.86184990e-01, -5.65899163e-02,  2.64162511e-01,
        -3.00934911e-01, -2.52833843e-01,  4.08722013e-01, -2.25866679e-02,
        -3.37151796e-01, -3.99515070e-02, -1.83522195e-01, -3.41712952e-01,
     