In [283]:
import pickle
from transformers import BertTokenizer, BertModel
import re
import pandas as pd
from string import punctuation
from math import log
import torch

In [284]:
tokenizer_name = 'bert-base-uncased'
embedding_model_name = 'bert-base-uncased'

# Set embedding model
embedding_model = BertModel.from_pretrained(embedding_model_name,
                                  output_hidden_states = True, # Whether the model returns all hidden-states.
                                  )
embedding_model.eval()

# Tokenize as self.self.self.self.sentence
tokenizer = BertTokenizer.from_pretrained(tokenizer_name)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [285]:
# Read in raw data
with open('data/docs_demo_processed_10.pickle', 'rb') as f:
    docs = pickle.load(f)

In [286]:
def merge_punctuation(sentence):
    sentence = sentence.replace('!', '.')
    sentence = sentence.replace(':', ',')
    sentence = sentence.replace('--', ',')
    
    reg = "(?<=[a-zA-Z])-(?=[a-zA-Z]{2,})"
    r = re.compile(reg, re.DOTALL)
    sentence = r.sub(' ', sentence)
    
    sentence = re.sub(r'\s-\s', ' , ', sentence)
    
    sentence = sentence.replace('-', ',')
    # sentence = sentence.replace('[\[\]]', ' ')
    sentence = sentence.replace(';', '.')
    sentence = sentence.replace(' ,', ',')
    sentence = sentence.replace('♫', '')
    sentence = sentence.replace('...', '')
    sentence = sentence.replace('.\"', ',')
    sentence = sentence.replace('"', ',')

    sentence = re.sub(r'--\s?--', '', sentence)
    sentence = re.sub(r'\s+', ' ', sentence)
    
    sentence = re.sub(r',\s?,', ',', sentence)
    sentence = re.sub(r',\s?\.', '.', sentence)
    sentence = re.sub(r'\?\s?\.', '?', sentence)
    sentence = re.sub(r'\s+', ' ', sentence)
    
    sentence = re.sub(r'\s+\?', '?', sentence)
    sentence = re.sub(r'\s+,', ',', sentence)
    sentence = re.sub(r'\.[\s+\.]+', '. ', sentence)
    sentence = re.sub(r'\s+\.', '.', sentence)
    
    return sentence.strip().lower()

In [287]:
# Process as sentence

def preprocess_text(sentence):
    # sentence = sentence.lower()  # Lowercase text
    sentence = re.sub(f"[{re.escape(punctuation)}]", "", sentence)  # Remove punctuation
    sentence = " ".join(sentence.split())  # Remove extra spaces, tabs, and new lines
    return(sentence)

In [288]:
def tokenize_text(sentence):
    sentence = tokenizer.tokenize(sentence)
    return(sentence)

In [289]:
# sentence = ['the', 'truth', 'of', 'the', 'matter', 'is', 'that', 'the', 'titanic', 'even', 'though', 'its', 'breaking', 'all', 'sorts', 'of', 'box', 'office', 'records', 'its', 'not', 'the', 'most', 'exciting', 'story', 'from', 'the', 'sea']

# Add sentence to vocab (with frequency)

def add_sentence_to_vocab(sentence, vocab):
    for word in sentence:
        if word in vocab.keys():
            vocab[word] = vocab[word] + 1
        else:
            vocab[word] = 1

def add_doc_vocab_to_corpus(doc_vocab, corpus_vocab):
    for token in doc_vocab.keys():
        if token in corpus_vocab.keys():
            corpus_vocab[token] = corpus_vocab[token] + doc_vocab[token]
        else:
            corpus_vocab[token] = doc_vocab[token]

In [290]:
def add_token_embeddings_to_doc(tokens, token_embeddings, doc_token_embeddings):
    for idx, token in enumerate(tokens):
        if token in doc_token_embeddings.keys():
            doc_token_embeddings[token] = torch.cat((doc_token_embeddings[token], token_embeddings[idx].unsqueeze(1)), dim=1)

        else:
            doc_token_embeddings[token] = token_embeddings[idx].unsqueeze(1) # Get 2-dim vector (768 x 1) prepare for potential concat

In [291]:
def add_token_embeddings_to_corpus(doc_token_embeddings, corpus_token_embeddings):
    for token in doc_token_embeddings.keys():
        if token in corpus_token_embeddings.keys():
            corpus_token_embeddings[token] = torch.cat((doc_token_embeddings[token].unsqueeze(1), corpus_token_embeddings[token]), dim=1)
        else:
            corpus_token_embeddings[token] = doc_token_embeddings[token].unsqueeze(1) # Get 2-dim vector (768 x 1) prepare for potential concat

In [292]:
corpus_token_embeddings = {}

In [293]:
def get_mean_token_embeddings(token_embeddings):
    for token in token_embeddings.keys():
        token_embeddings[token] = token_embeddings[token].mean(dim=1)
    return token_embeddings

def create_doc_token_embeddings(doc):

    doc_token_embeddings = {}

    for sentence in doc:
        tokens = tokenize_text(sentence)
        token_embeddings = get_token_embeddings(tokens)
        # print(token_embeddings)
        add_token_embeddings_to_doc(tokens, token_embeddings, doc_token_embeddings)
    
    return get_mean_token_embeddings(doc_token_embeddings)

In [294]:
# List of tokenized words for one document

# Loop through sentences, do preprocessing & tokenization then add tokens to doc list and doc vocab
def preprocess_doc(doc):
    tokens_from_doc = list()
    doc_vocab = {}

    for sentence in doc:
        # sentence = preprocess_text(sentence)
        sentence = merge_punctuation(sentence)
        sentence = tokenize_text(sentence)
        add_sentence_to_vocab(sentence, doc_vocab)
        [tokens_from_doc.append(token) for token in sentence]
    
    return tokens_from_doc, doc_vocab

In [295]:
def add_new_tokens_to_corpus_adj_list(doc_vocab, corpus_adj_list, printout = False):
    new_tokens = set(doc_vocab.keys()).difference(set(corpus_adj_list.keys()))

    print('Doc has %d new tokens' % len(new_tokens))
    for token in new_tokens:
        corpus_adj_list[token] = {}

In [296]:
def get_doc_windows(tokens_from_doc, window_size, padding):
    doc_windows = []
    for i in range(0, len(tokens_from_doc), window_size - padding):
        doc_windows.append(tokens_from_doc[i:i + window_size])
    return(doc_windows)


In [297]:
def token_count_in_windows(token, windows):
    return sum([token in window for window in windows])

def token_pair_count_in_windows(token_1, token_2, windows):
    return sum([(token_1 in window) & (token_2 in window) for window in windows])

In [298]:
def probability_of_token_in_windows(token, windows):
    return round(sum([token in window for window in windows])/len(windows), 4)

def probability_of_token_pair_in_windows(token_1, token_2, windows):
    return round(sum([(token_1 in window) & (token_2 in window) for window in windows])/len(windows), 4)

In [299]:
def mpni(token_1, token_2, windows, printout = False):
    prob_both_tokens = probability_of_token_pair_in_windows(token_1, token_2, windows)

    if prob_both_tokens == 0:
        npmi = -1
        return(npmi)

    prob_token_1 = probability_of_token_in_windows(token_1, windows)
    prob_token_2 = probability_of_token_in_windows(token_2, windows)

    npmi = (-1/log(prob_both_tokens, 10))*log((prob_both_tokens/(prob_token_1*prob_token_2)), 10)

    if printout == True:
        print('#i = %d' % token_count_in_windows(token_1, windows))
        print('#j = %d' % token_count_in_windows(token_2, windows))
        print('#i,j = %d' % token_pair_count_in_windows(token_1, token_2, windows))
        print('p(i,j) = %.4f' % prob_both_tokens)
        print('log(p(i,j) = %.4f' % log(prob_both_tokens, 10))
        print('p(i) = %.4f' % prob_token_1)
        print('p(j) = %.4f' % prob_token_2)

    return(npmi)

In [300]:
def get_token_embeddings(tokens):
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokens)
    segments_ids = [1] * len(tokens)
    tokens_tensor = torch.tensor([indexed_tokens]) # Get Token IDs from BERT vocab
    segments_tensors = torch.tensor([segments_ids]) # Sentence ID, separation, e.g. 000/111

    # Get raw output, for all BERT layers
    ''' Object hidden_states has 4 dim in following order:
    The layer number (13 layers), 1st is input embeddings
    The batch number (1 sentence)
    The word / token number (22 tokens in our sentence)
    The hidden unit / feature number (768 features)
    '''
    with torch.no_grad():
        outputs = embedding_model(tokens_tensor, segments_tensors)
        hidden_states = outputs[2]
    
    # Reshape hidden_states
    token_embeddings = torch.stack(hidden_states, dim=0) # Convert to 4-dim tensor from tuple
    token_embeddings = torch.squeeze(token_embeddings, dim = 1) # Remove batch layer (only 1 sentence)
    token_embeddings = token_embeddings.permute(1, 0, 2) # Swap token & layer number
    token_embeddings = token_embeddings[:,11,:] # Get second-to-last layer as final embedding
    # according to Han Xiao https://mccormickml.com/2019/05/14/BERT-word-embeddings-tutorial/

    return(token_embeddings)

In [301]:
test_1 = get_token_embeddings(tokenize_text('i love you very much'))
test_2 = get_token_embeddings(tokenize_text('i love you'))

In [302]:
from scipy.spatial.distance import cosine

cosine(test_1[1], test_2[1])

0.17923825979232788

In [308]:
def get_token_adjancency_from_window(adj, window):

    for token in window:
        if token not in adj.keys(): adj[token] = {}
        window_temp = filter(lambda x: x != token, window)
        for token_temp in window_temp:
            if token_temp not in adj[token].keys():
                adj[token][token_temp] =1
            else: adj[token][token_temp] +=1

## Execution

In [305]:
corpus_vocab = {}
corpus_token_embeddings = {}
corpus_adj_list = {}
corpus_windows_count = 0
corpus_windows = []
corpus_token_appeared_in_window = {}
corpus_token_pair_appeared_in_window = {}

window_size = 10
padding = 2

- Preprocess docs
- Get all corpus_windows (for calculating MNPI)
- Get ordered corpus_vocab

In [306]:
for doc in docs:

    # Create doc vocab and add to corpus
    tokens_from_doc, doc_vocab = preprocess_doc(doc)
    print(doc_vocab)
    add_doc_vocab_to_corpus(doc_vocab, corpus_vocab=corpus_vocab)

    # Create doc token embeddings and add to corpus
    doc_token_embeddings = create_doc_token_embeddings(doc)
    add_token_embeddings_to_corpus(doc_token_embeddings, corpus_token_embeddings)

    # Create adjacency matrix
    add_new_tokens_to_corpus_adj_list(doc_vocab, corpus_adj_list=corpus_adj_list, printout=True)

    doc_windows = get_doc_windows(tokens_from_doc, window_size=20, padding=5)

    for window in doc_windows:
        token_adj_from_window(window)
    
    for window in doc_windows: corpus_windows.append(window) 
    
    # print('Corpus windows count = %d' % len(corpus_windows))
    # print('Doc windows count = %d' % len(doc_windows))

corpus_vocab = dict(sorted(corpus_vocab.items()))
corpus_token_embeddings = get_mean_token_embeddings(corpus_token_embeddings)


{'david': 1, 'gallo': 2, ',': 190, 'this': 54, 'is': 59, 'bill': 5, 'lange': 3, '.': 203, 'i': 10, "'": 131, 'm': 1, 'dave': 1, 'and': 89, 'we': 64, 're': 34, 'going': 13, 'to': 52, 'tell': 1, 'you': 37, 'some': 6, 'stories': 4, 'from': 21, 'the': 171, 'sea': 16, 'here': 35, 'in': 45, 'video': 3, 've': 9, 'got': 9, 'of': 94, 'most': 8, 'incredible': 4, 'titanic': 3, 'that': 77, 's': 71, 'ever': 3, 'been': 5, 'seen': 1, 'not': 9, 'show': 5, 'any': 4, 'it': 63, 'truth': 1, 'matter': 1, 'even': 4, 'though': 1, 'breaking': 1, 'all': 23, 'sorts': 2, 'box': 1, 'office': 1, 'records': 1, 'exciting': 1, 'story': 4, 'problem': 2, 'think': 7, 'take': 4, 'ocean': 10, 'for': 12, 'granted': 1, 'when': 3, 'about': 14, 'oceans': 5, 'are': 26, '75': 1, 'percent': 1, 'planet': 7, 'water': 11, 'average': 3, 'depth': 5, 'two': 8, 'miles': 7, 'part': 1, 'stand': 1, 'at': 27, 'beach': 2, 'or': 7, 'see': 22, 'images': 1, 'like': 19, 'look': 3, 'out': 25, 'great': 1, 'big': 3, 'blue': 1, 'expanse': 1, 'shimm

In [362]:
corpus_token_embeddings[','].size()

torch.Size([768])

In [314]:
corpus_adj_list = {}

for window in corpus_windows:
    get_token_adjancency_from_window(corpus_adj_list, window)

In [328]:
len(corpus_adj_list.keys())

2984

In [327]:
len(corpus_vocab.keys())

2984

In [145]:
corpus_vocab_index = {token: idx for idx, token in enumerate(corpus_vocab.keys())}

In [329]:
def token2idx(token):
    return(corpus_vocab_index[token])
    
def idx2token(idx):
    return([key for key in corpus_vocab_index.keys() if corpus_vocab_index[key] == idx].pop())

In [331]:
corpus_token_embeddings[',']

tensor([[ 0.2291,  0.2721, -0.0126,  ...,  0.0952,  0.2752,  0.2363],
        [-0.0341, -0.0138,  0.0637,  ..., -0.0755,  0.0757,  0.0567],
        [ 0.2248,  0.2510,  0.5798,  ...,  0.3062,  0.2787,  0.2863],
        ...,
        [-0.5499, -0.6126, -0.6591,  ..., -0.4822, -0.7565, -0.5348],
        [ 0.3075,  0.3771,  0.2886,  ...,  0.3298,  0.1368,  0.2476],
        [ 0.3932,  0.5512,  0.6295,  ...,  0.5005,  0.5362,  0.5525]])

In [332]:
from torch_geometric.data import Data

In [353]:
data.num_node_features

3

'.'