In [78]:
import gc
import os
import pickle
from glob import glob

import numpy as np
import pandas as pd
from tqdm import tqdm

from psycho_embeddings.fresh_embedder import ContextualizedEmbedder

In [None]:
# Contextual embedder for the "bert-base-cased" checkpoint, running on CPU.
# NOTE(review): the positional 300 is interpreted by ContextualizedEmbedder (not visible
# here) — confirm its meaning and promote it to a named constant in the config cell.
model = ContextualizedEmbedder(
    "bert-base-cased",
    300,
    device="cpu",
)

In [79]:
def find_index_for_word(word, data):
    """
    Return the index labels of the rows of `data` whose "words" column equals `word`.

    Parameters
    ----------
    word : the target word to look up.
    data : DataFrame with a "words" column.

    Returns
    -------
    list of index labels, in row order (empty if `word` does not occur).
    """
    return data[data["words"] == word].index.tolist()


def get_average_word_embeddings(word, data, embeds):
    """
    Average the contextual embeddings of every occurrence of `word` in `data`.

    Parameters
    ----------
    word : target word whose occurrences are averaged.
    data : DataFrame with a "words" column whose rows are aligned with `embeds`.
    embeds : indexable sequence (list or array) holding one embedding per row of `data`.

    Returns
    -------
    np.ndarray : the element-wise mean of the matching embeddings (a copy of the
    single matching embedding when `word` occurs exactly once).

    Raises
    ------
    ValueError : if `word` does not occur in `data["words"]`.
    """
    idxs = find_index_for_word(word, data)
    # NOTE: index labels are used directly as positions into `embeds`; this is only
    # correct when `data` has a default RangeIndex in the order the embeddings were built.
    if not idxs:
        raise ValueError(f"word {word!r} not found in data['words']")
    if len(idxs) == 1:
        # idxs is a flat list of labels; a single occurrence needs no averaging.
        return np.array(embeds[idxs[0]])
    return np.average([embeds[idx] for idx in idxs], axis=0)

In [80]:
# Toy dataset: each row pairs a target word with the sentence it occurs in.
data = pd.DataFrame(
    {
        "words": ["cat", "dog", "cat"],
        "target_text": [
            "the cat is on the table",
            "the dog is on the table",
            "the cat is on the table",
        ],
    }
)

In [81]:
# Number of DataFrame rows embedded (and dumped to disk) per chunk.
SIZE_CHUNKS = 2
# Root directory for all embedding dumps produced by this notebook.
FOLDER_NAME = "bert_embeddings"
# NOTE(review): GPU is not referenced by any visible cell; the embedder is built with device="cpu".
GPU = 0

In [8]:
# Layers 0..12 inclusive (13 layers) are extracted from the model.
layers_of_interest = list(range(13))

# Split `data` into consecutive row chunks of at most SIZE_CHUNKS rows each.
list_df = [
    data[start:start + SIZE_CHUNKS]
    for start in range(0, len(data), SIZE_CHUNKS)
]

# Create Embeddings for the Entire Dataset

In [15]:
# Embed the dataset chunk by chunk and dump one pickle per (layer, chunk) to disk,
# so the whole dataset never has to be embedded in a single call.
for chunk_idx, chunk_df in enumerate(tqdm(list_df, position=0)):

    #############################
    # DUMPING EMBEDDING ON DISK #
    #############################

    # Positional args to `model.embed`: sentences, target words, layers to return, and 8.
    # NOTE(review): 8 is presumably the batch size — confirm against ContextualizedEmbedder.embed.
    chunk_embeddings = model.embed(
        chunk_df["target_text"].tolist(),
        chunk_df["words"].tolist(),
        layers_of_interest,
        8,
        averaging=True,
    )

    for layer in layers_of_interest:
        layer_dir = f"{FOLDER_NAME}/{layer}/temp/"
        os.makedirs(layer_dir, exist_ok=True)

        # The integer suffix is the chunk number; the reconstruction step sorts on it.
        with open(f"{layer_dir}bert_embeddings_{chunk_idx}", "wb") as filino:
            pickle.dump(chunk_embeddings[layer], filino)

    # Every 10th chunk (starting with the first) force a collection to release memory
    # held by earlier chunks. Requires `import gc` in the imports cell.
    if chunk_idx % 10 == 0:
        gc.collect()

100%|██████████| 2/2 [00:26<00:00, 13.16s/it]


  0%|          | 0/2 [00:00<?, ?ex/s]

100%|██████████| 1/1 [00:00<00:00,  1.56it/s]
 50%|█████     | 1/2 [00:01<00:01,  1.20s/it]

  0%|          | 0/1 [00:00<?, ?ex/s]

100%|██████████| 1/1 [00:00<00:00,  2.85it/s]
100%|██████████| 2/2 [00:01<00:00,  1.04it/s]

# Reconstruct and Save Contextualized Embeddings

In [23]:
def _load_layer_embeddings(layer):
    """
    Rebuild the embedding matrix of one layer from its per-chunk pickles.

    Chunk files are read in numeric chunk order (the integer after the last "_" in
    the file name), so the resulting rows follow the row order of `data`.

    Raises
    ------
    AssertionError : if the number of chunk files differs from len(list_df).
    """
    emb_files = sorted(
        glob(f"{FOLDER_NAME}/{layer}/temp/*"),
        key=lambda path: int(os.path.basename(path).split("_")[-1]),
    )
    # Sanity check: exactly one dump per chunk of `list_df`.
    assert len(emb_files) == len(list_df), (
        f"layer {layer}: found {len(emb_files)} chunk files, expected {len(list_df)}"
    )

    rows = []
    for path in tqdm(emb_files, position=0):
        with open(path, "rb") as filino:
            chunk_embeddings = pickle.load(filino)
        for value in chunk_embeddings:
            # Entries of length 1 are unwrapped (presumably a singleton wrapper around
            # the vector — TODO confirm against ContextualizedEmbedder.embed output).
            rows.append(np.array(value[0]) if len(value) == 1 else np.array(value))
    return np.array(rows)


##################
# MAP 2 Sentence #
##################

# (word, sentence) -> row index label of `data`. It does not depend on the layer, so it
# is built once and dumped next to every layer's embeddings.
# NOTE: duplicate (word, sentence) pairs collapse onto the last matching row.
# TODO: this is probably dataset specific; check whether anything downstream still
# needs it, since it only stores row indices.
map_sentrepl2emb = {
    (row["words"], row["target_text"]): idx for idx, row in data.iterrows()
}

for LAYER in tqdm(layers_of_interest, desc="Layer"):
    # Concatenate this layer's chunk dumps, in order, into one array and persist it.
    all_the_embeddings = _load_layer_embeddings(LAYER)
    np.save(f"{FOLDER_NAME}/contextualized_embeddings_bert_{LAYER}_layer.npy", all_the_embeddings)
    del all_the_embeddings

    with open(f"{FOLDER_NAME}/map_sentrepl2embbert_{LAYER}.pkl", "wb") as file_to_save:
        pickle.dump(map_sentrepl2emb, file_to_save)

    

100%|██████████| 2/2 [00:00<00:00, 2658.83it/s]
100%|██████████| 2/2 [00:00<00:00, 3622.02it/s]
100%|██████████| 2/2 [00:00<00:00, 4851.71it/s]
100%|██████████| 2/2 [00:00<00:00, 5200.62it/s]
100%|██████████| 2/2 [00:00<00:00, 5429.52it/s]
100%|██████████| 2/2 [00:00<00:00, 10908.46it/s]
100%|██████████| 2/2 [00:00<00:00, 1390.68it/s]
100%|██████████| 2/2 [00:00<00:00, 2286.97it/s]
100%|██████████| 2/2 [00:00<00:00, 2004.93it/s]
100%|██████████| 2/2 [00:00<00:00, 3960.63it/s]
100%|██████████| 2/2 [00:00<00:00, 3309.12it/s]
100%|██████████| 2/2 [00:00<00:00, 4422.04it/s]
100%|██████████| 2/2 [00:00<00:00, 4044.65it/s]
Layer: 100%|██████████| 13/13 [00:00<00:00, 209.49it/s]


In [60]:
##############################
# Build Prototype Embeddings #
##############################

# Distinct target words; identical for every layer, so computed once.
unique_words = data["words"].unique()

for LAYER in tqdm(layers_of_interest, desc="Layer"):
    emb_files = sorted(
        glob(f"{FOLDER_NAME}/{LAYER}/temp/*"),
        key=lambda path: int(os.path.basename(path).split("_")[-1]),
    )
    assert len(emb_files) == len(list_df)  # sanity check: every chunk was dumped for this layer

    embeds = np.load(f"{FOLDER_NAME}/contextualized_embeddings_bert_{LAYER}_layer.npy")

    # Prototype of a word = mean of its contextual embeddings over all its occurrences.
    mega_embeddings = {
        word: get_average_word_embeddings(word, data, embeds)
        for word in tqdm(unique_words, position=0)
    }

    with open(f"{FOLDER_NAME}/prototype_embeddings_bert_{LAYER}.pkl", "wb") as filino:
        pickle.dump(mega_embeddings, filino)


100%|██████████| 2/2 [00:00<00:00, 1273.51it/s]
100%|██████████| 2/2 [00:00<00:00, 1597.22it/s]
100%|██████████| 2/2 [00:00<00:00, 2016.01it/s]
100%|██████████| 2/2 [00:00<00:00, 1544.58it/s]
100%|██████████| 2/2 [00:00<00:00, 2060.07it/s]
100%|██████████| 2/2 [00:00<00:00, 665.82it/s]
100%|██████████| 2/2 [00:00<00:00, 986.66it/s]
100%|██████████| 2/2 [00:00<00:00, 1280.51it/s]
100%|██████████| 2/2 [00:00<00:00, 1053.05it/s]
100%|██████████| 2/2 [00:00<00:00, 816.33it/s]
100%|██████████| 2/2 [00:00<00:00, 566.64it/s]
100%|██████████| 2/2 [00:00<00:00, 925.79it/s]
100%|██████████| 2/2 [00:00<00:00, 972.48it/s]
Layer: 100%|██████████| 13/13 [00:00<00:00, 135.84it/s]
