# BERT Functions

I define two functions to run Bert models. First part processes a single text document into a format that is recognizable by BERT. The second part uses the tokenized text to generate embedding values using pre-trained BERT models. 

In [1]:
from transformers import BertModel, BertTokenizer, AutoTokenizer
import numpy as np
import streamlit as st
import re
import pandas as pd
from datetime import datetime
import nltk
import torch

In [2]:
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

In [3]:
#input is "light.csv" which does not include stop words. 
df = pd.read_csv('../../../data/processed/paragraph.csv')
# Filter
timestamps = df.year.to_list()
texts = df.text.to_list()
text = texts[1]

In [10]:
df.head(1)

Unnamed: 0.1,Unnamed: 0,ccode_iso,session,year,paragraph_index,text
0,1,AFG,7,1952,1,I consider it a great honour and privilege to ...


In [4]:
print(type(text))

<class 'str'>


# Define functions

In [5]:
def bert_preprocess(text):
    """
    Preprocesses a document into a BERT-recognizable format 
    Input: text in a string format
    output: three objects ready to be used for Bert modeling 
        marked_text (list)
        indexed_tokens(list)
        attention_mask(list)

    """
    # Tokenize the text
    tokenized_text = tokenizer.tokenize(text)
    truncate_length = len(tokenized_text) - 512 + 2  # +2 to account for [CLS] and [SEP]
    
    # Truncate the beginning and end of the text
    truncated_text = tokenized_text[truncate_length//2 : -truncate_length//2]
    
    # Add padding
    
    # Add special tokens [CLS] and [SEP], convert tokens to ids, and create attention mask
    marked_text = ["[CLS] "] + truncated_text + [" [SEP]"]
    indexed_tokens = tokenizer.convert_tokens_to_ids(marked_text)
    attention_mask = [1] * len(indexed_tokens)

    # Pad sequences to max_seq_length
    if len(indexed_tokens) < 512:
        indexed_tokens.append(0)
        attention_mask.append(0)
    
    return marked_text, indexed_tokens, attention_mask

In [6]:
marked_text, indexed_tokens, attention_mask = bert_preprocess(text)

In [43]:
help(get_bert_embeddings)

Help on function get_bert_embeddings in module __main__:

get_bert_embeddings(marked_text, indexed_tokens, attention_mask)
    input: processed text
    output: dataframe of embedding weights for each token 
        ex) dimension of 512*768 where row represents token, column represents bert features



In [7]:
def get_bert_embeddings(marked_text, indexed_tokens, attention_mask):
    """
    Generates embedding values for tokenized text 
    input: processed text, indexed_tokens and attention mask (all in list format)
    output: dataframe of embedding weights for each token 
        ex) dimension of 512*768 where row represents token, column represents bert features
    
    """
    # Convert lists to PyTorch tensors
    tokens_tensors = torch.tensor([indexed_tokens])
    attention_masks = torch.tensor([attention_mask])
    
    with torch.no_grad():
        #Run the embedding
        outputs = model(input_ids=tokens_tensors.view(-1, tokens_tensors.size(-1)), 
                        attention_mask=attention_masks.view(-1, attention_masks.size(-1)))

        # Extract the hidden states 
        hidden_states = outputs[2][0].squeeze().numpy()
        
        # Convert to data frame
        pd_words = pd.Series(marked_text, name='term')
        df_outputs = pd.DataFrame(hidden_states)
        df_outputs['term'] = pd_words
        
        # Move 'term' column to the first position
        df_outputs = df_outputs[['term'] + [col for col in df_outputs.columns if col != 'term']]
        
        # Remove duplicate tokens by averaging them out
        df_outputs_embedding = df_outputs.groupby(['term']).mean()
    return df_outputs_embedding

In [8]:
get_bert_embeddings(marked_text, indexed_tokens, attention_mask)

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
term,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
[SEP],-0.599731,-0.287527,0.995737,-0.0676,-0.116662,-0.319243,-0.035646,-0.550722,-0.154269,0.226906,...,0.433558,0.360281,0.753348,-1.1558,0.198939,0.126193,0.058285,0.035218,-0.225301,-0.376395
##s,-0.082928,0.06461,0.062934,1.201868,0.41649,-0.351008,-0.419693,0.793464,-0.682201,-0.435875,...,-1.640653,-0.082774,1.440754,0.477181,0.555801,0.517778,0.029644,0.16733,-0.804072,1.10095
",",0.238827,-0.49953,-0.229385,-0.420359,0.382101,-0.133325,-0.423249,-0.133761,0.079275,-0.810453,...,-0.476354,0.31068,-0.071447,-0.350534,-0.166876,-0.15276,0.157087,0.18291,-0.305537,0.119838
-,0.211686,-0.337158,-0.282966,-0.379349,0.21355,-0.254544,-0.361127,-0.094978,0.072562,-0.825429,...,-0.457777,0.271165,0.238502,-0.122299,-0.090638,-0.070707,0.017792,-0.04912,-0.269647,0.226714
.,0.117108,-0.388444,-0.088623,0.064858,0.52323,-0.428734,-0.267266,-0.420127,0.190312,-0.766927,...,-0.490892,0.354983,-0.355035,-0.418002,0.253672,0.08662,0.108094,-0.150912,-0.198612,0.161442
In,1.179606,0.055646,0.182922,1.080442,0.191964,-0.443555,-0.347383,0.757875,-0.356111,-1.096234,...,0.154227,0.468894,0.522242,-0.500889,0.947098,1.150616,-0.767531,-0.148597,-0.223548,0.134304
Nations,0.006905,-0.580956,0.575177,-1.220215,0.371888,-0.279317,0.700901,-0.903072,0.631587,-0.73947,...,0.495864,0.132003,0.06781,0.293461,0.424441,0.552663,0.318386,-0.620372,-0.583365,0.04838
United,-0.077556,-0.105724,1.057649,-0.423865,0.314821,-0.185792,-0.714371,-0.607652,-0.070977,0.103325,...,-0.70418,-1.008595,-0.426895,-0.018947,0.567454,0.48305,-0.180623,-0.287286,-0.55601,0.511452
[CLS],-0.118897,-0.518255,0.159338,-0.461482,-0.003488,-0.453042,-0.212884,-0.229699,-0.063944,-0.421272,...,0.242429,0.009379,0.467546,-0.957577,0.114305,-0.36999,0.035248,0.089144,-0.146707,-0.127492
accomplish,-0.085958,0.289747,0.502433,-0.699373,-0.112547,0.115439,1.011047,-0.57065,-0.124616,-0.372738,...,-1.153663,-0.272904,0.580479,0.201596,-0.194586,-0.872371,-0.358961,0.039148,0.10191,-0.16974
