In [1]:
from sentence_transformers import SentenceTransformer
import os
import pandas as pd
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModel, TrainingArguments, Trainer, DataCollatorForLanguageModeling, AutoTokenizer, AdamW, get_scheduler
from tqdm.auto import tqdm
import numpy as np
import torch
import torch.nn as nn
from datasets import Dataset, DatasetDict
import torch.nn.functional as F
import gc
import h5py

In [2]:
num_paras = None
# filepath = str(os.path.dirname(os.path.realpath(__file__)))

# if "hert5583" in filepath:
#     data_dir = "/data/inet-demtech/hert5583/embed_resonance_data/data"
#     finetuned_model_dir = "/data/inet-demtech/hert5583/embed_resonance_models/who_leads_model_final/checkpoint-500000"
# else:
data_dir = "/Users/HannahBailey/Documents/GitHub/embedding_resonance/data"
finetuned_model_dir = "/Users/HannahBailey/Documents/GitHub/embedding_resonance/model/who_leads_model_final"

In [3]:
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def compute_embeddings(pickle_path, col_name, model_directory, save_name, num_paras=None, batch_size=1000):
    print('Started')

    df = pd.read_pickle(pickle_path)
    # convert all columns to string
    # df = df.astype(str)
    print('Loaded dataframe')

    if num_paras is not None:
        df = df.head(num_paras)

    print("Computing embeddings for", len(df), "paragraphs")
    
    paragraphs = df[col_name].tolist()
    paragraphs = [' ' if para is None else para for para in paragraphs]

    # load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_directory)
    model = AutoModel.from_pretrained(model_directory)

    sentence_embedding_list = []
    last_index = 0
    batch_number = 0
    for i in range(0, len(paragraphs), batch_size):
        end_idx = min(i+batch_size, len(paragraphs))
        print(f"Processing batch {i//batch_size + 1}/{(len(paragraphs)-1)//batch_size + 1}")
        batch = paragraphs[i:end_idx]
        encoded_input = tokenizer(batch, padding=True, truncation=True, return_tensors='pt')

        with torch.no_grad():
            model_output = model(**encoded_input)

        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        sentence_embedding_list.append(sentence_embeddings)

        # Write embeddings to HDF5
        if i == 0:
            print("Creating dataset...")
            sentence_embeddings = torch.cat(sentence_embedding_list, dim=0)
            # Create dataset for embeddings]
            # Define the HDF5 file
            with h5py.File(os.path.join(model_directory, f'{save_name}_embeddings.h5'), 'w') as h5f:
                h5f.create_dataset('embeddings', data=sentence_embeddings.cpu().numpy(),
                                maxshape=(len(df), sentence_embeddings.shape[1]), chunks=True)
                for col in df:
                    print(col)
                    print(type(df[col][0]))
                    h5f.create_dataset(col, data=df[col].iloc[last_index:end_idx].str.encode('utf-8').values.astype('S'),
                        maxshape=(len(df),), chunks=True)
            sentence_embedding_list = []
            last_index = end_idx

        elif batch_number % 50 == 0 or end_idx == len(paragraphs):
            sentence_embeddings = torch.cat(sentence_embedding_list, dim=0)
            print("Saving dataset...")
            # Append to dataset for embeddings
            with h5py.File(os.path.join(model_directory, f'{save_name}_embeddings.h5'), 'a') as h5f:
                h5f['embeddings'].resize((h5f['embeddings'].shape[0] + sentence_embeddings.shape[0]), axis=0)
                h5f['embeddings'][-sentence_embeddings.shape[0]:] = sentence_embeddings.cpu().numpy()

                for col in df:
                    h5f[col].resize((h5f[col].shape[0] + sentence_embeddings.shape[0]), axis=0)
                    h5f[col][-sentence_embeddings.shape[0]:] = df[col].iloc[last_index:end_idx].str.encode('utf-8').values.astype('S')
            sentence_embedding_list = []
            last_index = end_idx

        batch_number += 1

        print(f"Completed step for batch {i//batch_size + 1}")

        # Clear memory
        del encoded_input, model_output, sentence_embeddings
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("All embeddings computed and saved to HDF5")

    # Save the final DataFrame
    final_save_path = os.path.join(model_directory, f'{save_name}_tuned_embeddings_final.pkl')
    df.to_pickle(final_save_path)
    print(df.head())
    print("All embeddings computed and saved to", final_save_path)


In [4]:
who_leads_folder = os.path.join(data_dir, 'who_leads_who_follows')  
pickle_path = os.path.join(who_leads_folder, 'cleaned_who_leads_df.pkl')
compute_embeddings(pickle_path=pickle_path, col_name='post_text', model_directory=finetuned_model_dir, save_name='who_leads_model', num_paras=num_paras, batch_size=10)

Started
Loaded dataframe
Computing embeddings for 35727478 paragraphs


Some weights of the model checkpoint at /Users/HannahBailey/Documents/GitHub/embedding_resonance/model/who_leads_model_final were not used when initializing XLMRobertaModel: ['lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.bias']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaModel were not initialized from the model checkpoint at /Users/HannahBailey/Documents/GitHub/embedding_resonance/model/who_leads_model_final and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.de

Processing batch 1/3572748
Creating dataset...
user_type
<class 'str'>
date
<class 'str'>
post_text
<class 'str'>
Completed step for batch 1
Processing batch 2/3572748
Completed step for batch 2
Processing batch 3/3572748
Completed step for batch 3
Processing batch 4/3572748
Completed step for batch 4
Processing batch 5/3572748
Completed step for batch 5
Processing batch 6/3572748
Saving dataset...
Completed step for batch 6
Processing batch 7/3572748
Completed step for batch 7
Processing batch 8/3572748
Completed step for batch 8
Processing batch 9/3572748
Completed step for batch 9
Processing batch 10/3572748
Completed step for batch 10
Processing batch 11/3572748
Saving dataset...
Completed step for batch 11
Processing batch 12/3572748
Completed step for batch 12
Processing batch 13/3572748
Completed step for batch 13
Processing batch 14/3572748
Completed step for batch 14
Processing batch 15/3572748
Completed step for batch 15
Processing batch 16/3572748
Saving dataset...
Completed

KeyboardInterrupt: 

In [10]:
# optional - check that the embeddings were saved correctly
# Path to your HDF5 file
hdf5_file_path = '/Users/HannahBailey/Documents/GitHub/embedding_resonance/model/who_leads_model_final/who_leads_model_embeddings.h5'

# Open the HDF5 file in read mode
with h5py.File(hdf5_file_path, 'r') as file:
    # List all groups
    print("Keys: %s" % file.keys())
    # Get the embeddings dataset
    embeddings = file['embeddings']
    print("Shape of embeddings:", embeddings.shape)
    
    # Optionally, read a subset of embeddings into memory
    # Here, we read the first 5 embeddings
    first_five_embeddings = embeddings[:5]
    print("First five embeddings:\n", first_five_embeddings)

Keys: <KeysViewHDF5 ['date', 'embeddings', 'post_text', 'user_type']>
Shape of embeddings: (660, 768)
First five embeddings:
 [[-0.0156679   0.00471961 -0.00389284 ...  0.05953515 -0.02004566
  -0.03649166]
 [-0.00964241 -0.02049181 -0.00454568 ...  0.03874613 -0.06312865
  -0.02873504]
 [-0.02146283  0.03928605 -0.00450213 ...  0.028684   -0.02098445
  -0.05096287]
 [-0.03896829  0.00304933 -0.00493355 ...  0.02121274 -0.03294142
  -0.02406158]
 [-0.04928146  0.07072747 -0.00353959 ...  0.04813822 -0.00219928
  -0.02595499]]
