# BERT Functions

I define two functions to run Bert models. First part processes a single text document into a format that is recognizable by BERT. The second part uses the tokenized text to generate embedding values using pre-trained BERT models. 

In [1]:
from transformers import BertModel, BertTokenizer, AutoTokenizer
import numpy as np
import streamlit as st
import re
import pandas as pd
from datetime import datetime
import nltk
import torch

In [2]:
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

In [3]:
#input is "light.csv" which does not include stop words. 
df = pd.read_csv('../../../data/processed/light.csv')
# Filter
timestamps = df.year.to_list()
texts = df.text.to_list()
text = texts[1]

In [6]:
print(type(text))

<class 'str'>


# Define functions

In [34]:
def bert_preprocess(text):
    """
    Preprocesses a document into a BERT-recognizable format 
    Input: text in a string format
    output: three objects ready to be used for Bert modeling 
        marked_text (list)
        indexed_tokens(list)
        attention_mask(list)
    
    """
    # Tokenize the text
    tokenized_text = tokenizer.tokenize(text)
    truncate_length = len(tokenized_text) - 512 + 2  # +2 to account for [CLS] and [SEP]
    
    # Truncate the beginning and end of the text
    truncated_text = tokenized_text[truncate_length//2 : -truncate_length//2]
    
    # Add padding
    
    # Add special tokens [CLS] and [SEP], convert tokens to ids, and create attention mask
    marked_text = ["[CLS] "] + truncated_text + [" [SEP]"]
    indexed_tokens = tokenizer.convert_tokens_to_ids(marked_text)
    attention_mask = [1] * len(indexed_tokens)

    # Pad sequences to max_seq_length
    if len(indexed_tokens) < 512:
        indexed_tokens.append(0)
        attention_mask.append(0)
    
    return marked_text, indexed_tokens, attention_mask

In [35]:
marked_text, indexed_tokens, attention_mask = bert_preprocess(text)

In [43]:
help(get_bert_embeddings)

Help on function get_bert_embeddings in module __main__:

get_bert_embeddings(marked_text, indexed_tokens, attention_mask)
    input: processed text
    output: dataframe of embedding weights for each token 
        ex) dimension of 512*768 where row represents token, column represents bert features



In [38]:
def get_bert_embeddings(marked_text, indexed_tokens, attention_mask):
    """
    Generates embedding values for tokenized text 
    input: processed text, indexed_tokens and attention mask (all in list format)
    output: dataframe of embedding weights for each token 
        ex) dimension of 512*768 where row represents token, column represents bert features
    
    """
    # Convert lists to PyTorch tensors
    tokens_tensors = torch.tensor([indexed_tokens])
    attention_masks = torch.tensor([attention_mask])
    
    with torch.no_grad():
        #Run the embedding
        outputs = model(input_ids=tokens_tensors.view(-1, tokens_tensors.size(-1)), 
                        attention_mask=attention_masks.view(-1, attention_masks.size(-1)))

        # Extract the hidden states 
        hidden_states = outputs[2][0].squeeze().numpy()
        
        # Convert to data frame
        pd_words = pd.Series(marked_text, name='term')
        df_outputs = pd.DataFrame(hidden_states)
        df_outputs['term'] = pd_words
        
        # Move 'term' column to the first position
        df_outputs = df_outputs[['term'] + [col for col in df_outputs.columns if col != 'term']]
        
        # Remove duplicate tokens by averaging them out
        df_outputs_embedding = df_outputs.groupby(['term']).mean()
    return df_outputs_embedding

In [39]:
get_bert_embeddings(marked_text, indexed_tokens, attention_mask)

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
term,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
[SEP],-0.409629,-0.746030,0.954548,-0.989581,-0.666579,-0.456214,-0.482902,-0.474517,-0.324448,-0.554452,...,0.937598,0.148158,0.841253,-0.650847,-0.348678,-0.864865,-0.147155,0.445907,0.226717,-1.860195
##al,-0.524131,0.479899,-0.476182,-0.390430,0.571847,0.344540,-0.890469,-0.778739,-0.369768,-0.369769,...,-0.796855,-0.407300,0.098990,0.707968,-0.098804,-1.176002,0.376283,0.099251,-0.782967,-0.062336
##atic,-1.258919,-1.749650,0.189675,-1.009726,1.250488,-0.127454,-0.174475,0.132727,0.210920,0.159041,...,0.256487,0.588991,-1.045582,-0.885318,-0.134264,-0.281723,0.578229,-0.318711,-0.002958,0.664755
##ation,-0.321686,-0.599758,0.114046,-0.144932,0.953009,-0.936072,0.038542,0.732640,0.667231,-0.173093,...,-0.485267,0.513310,0.546876,0.632246,0.262685,-0.174507,0.236400,-0.632015,-0.797872,-1.417162
##ci,-0.209563,0.008983,0.226166,-0.037539,-0.431536,-0.854313,-0.930581,-0.289056,0.014480,0.836836,...,0.035582,-0.291024,0.075818,0.033696,0.396918,-0.968211,-0.381200,0.396398,0.112123,-0.490416
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
words,0.397322,-0.091103,0.350278,0.339480,0.188769,-0.122880,0.119570,-0.366882,0.528809,0.175150,...,-0.813381,-0.366494,0.197109,-0.413089,0.482895,-0.190066,-0.568232,0.127709,-0.095941,0.127358
world,0.991384,-0.232762,0.670588,0.508774,0.526652,0.077158,-0.581935,0.377466,0.073109,0.119594,...,-0.505561,0.136024,0.009123,0.495411,0.342043,-1.043479,-0.131786,-0.333381,-0.390346,-0.141330
worth,-0.107531,0.469970,-0.674297,0.848895,0.115055,0.818750,-0.049695,-0.560952,-0.509267,-0.288416,...,0.032292,-0.336956,0.691349,0.267413,-0.045436,0.660112,1.184719,-0.789721,0.254513,-0.662614
years,0.478471,-0.254025,-0.260787,0.681818,-0.097176,0.219434,-0.531882,-0.517204,0.729554,0.117572,...,0.398866,0.446946,-0.115180,-0.216491,-0.296925,0.555254,-0.139911,-0.880707,-0.042803,0.468518
