In [3]:
!pip install -r requirements.txt

In [1]:
from transformers import BertTokenizer, BertForMaskedLM, pipeline
import torch
import string
from collections import defaultdict
import nltk
from nltk.corpus import stopwords
import numpy as np
import datetime
import time 
import spacy
import pandas as pd
import re
nlp = spacy.load("en_core_web_sm")
stop_words = set(stopwords.words('english')) 


### configure the BERT network 
"""
Two models available, one is uncased, the other one is cased, 
Change according to the importance of CASE in the sentence

"""
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased').eval()

# bert_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
# bert_model = BertForMaskedLM.from_pretrained('bert-base-cased').eval()

"""
Input network using Gelphic
"""
# network = pd.read_csv("MLDA Synonym.csv")
Gelphi_output_file_path = "../source_data/Gelphi_output_1.csv" 


"""
corpus input path to read the original text
"""
corpus_input_csv_path = '../source_data/corpus_input.csv'

"""
output_name file name 
"""
bert_with_network_score_file_name="../output/bert_final_output"

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Processing to fit the BERT MODEL

In [2]:
### Get the predicted words from  BERT model
def decode(tokenizer, pred_idx, top_clean=5):
    ignore_tokens = string.punctuation + '[PAD]'
    tokens = []
    for w in pred_idx:
        token = ''.join(tokenizer.decode(w).split())
        if token not in ignore_tokens:
            tokens.append(token.replace('##', ''))
    return tokens[:top_clean]  ## each line one prediction

### Encode the words by BERT tokenizers
def encode(tokenizer, text_sentence, add_special_tokens=True):
    text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
    # if <mask> is the last token, append a "." so that models dont predict punctuation.
    if tokenizer.mask_token == text_sentence.split()[-1]:
        text_sentence += ' .'
    input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
    mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
    return input_ids, mask_idx

###
def get_predictions(text_sentence,top_clean=5):
    input_ids,mask_idx = encode(bert_tokenizer,text_sentence)
    with torch.no_grad():
        predict= bert_model(input_ids)[0]
    bert = decode(bert_tokenizer,predict[0,mask_idx,:].topk(top_clean).indices.tolist(),top_clean=top_clean)
    return bert

## Data Input cleaning

In [3]:
"""
Text pre-processing functions
To split those cut words with come with punctuation e.g. "like.", ". enable"
"""

def punctuation_corr(input_sent):
    ## correct punctuation position
    input_split = input_sent.split()
    for i in range(len(input_split)):
        if not input_split[i]: ## for \t\n char 
            continue
        ## word starts with a punctuation '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~' but not @ => @abc , may be a tweet mention
        if input_split[i][0] in string.punctuation and input_split[i][0]!='@':
            input_split[i]= input_split[i][0] + " " + input_split[i][1:]
        
        if input_split[i][-1] in string.punctuation:
            input_split[i]= input_split[i][:-1] + " " + input_split[i][-1]
            if input_split[i][-3] in string.punctuation: ## for ." 
                orig_punc = input_split[i][-3]
                input_split[i]= input_split[i][:-3] +" "+input_split[i][-3:]
                
        ### account for all CAP words, convert to lower case
        elif input_split[i].upper() == input_split[i] and len(input_split[i])> 1 :  
            input_split[i] = input_split[i].lower()
        
        ## for punct in between the words without space e.g."buy,i"

        try:
            punc_pos = len(re.search('\w+',input_split[i])[0])
            if punc_pos<len(input_split[i])-1:
                if input_split[i][punc_pos] in string.punctuation and input_split[i][punc_pos] !='-' and input_split[i][punc_pos] !="'" :
                    input_split[i] = input_split[i][:punc_pos] + " " +input_split[i][punc_pos]+" "+input_split[i][punc_pos+1:]
        except:
            pass

    input_sent = ' '.join(input_split)
    return input_sent



"""
Based on the cleaned text, mask out interested words for BERT prediction and get the BERT prediction
Two ways to split the sentence, by Spacy or NLTk (Spacy is more advanced and time-consuming)

input_sent (string): the corpus for prediction
top_k (int): number of prediction get from BERT model for each masked word 
useSpacy (boolean): whether the use of Spacy to split words and pos tagging
"""

def find_masked_words(input_sent,top_k=5, useSpacy=True):
    keyword = defaultdict(dict)
    if not useSpacy:
        input_sent = punctuation_corr(input_sent)
        input_split = input_sent.split()
            ## second filter with pos tagging for nltk
        ### POS TAGGGING 
        not_consider_pos = ['PRON','DET','ADP','CONJ','NUM','PRT','.',':','CC','CD','EX','DT','PDT',
                        'IN','LS','MD','NNP','NNPS','PRP','POS','PRP$','TO','UH','WDT','WP$','WP','WRB']
            #  refer to   https://www.learntek.org/blog/categorizing-pos-tagging-nltk-python/ 
        be_do_verb = ['is','are','was','were','did','does','do','not','had','have','has','ever']
        conjunction_words = ['therefore','thus']
        pos_res = nltk.pos_tag(input_sent.split())
        for item_num in range(len(pos_res)):
            if pos_res[item_num][1] not in not_consider_pos:
                ## another level of filtering of words before masking 
                if '@' in pos_res[item_num][0]:
                    continue
                if pos_res[item_num][0].lower() in be_do_verb:
                    continue
                if pos_res[item_num][0].lower() in conjunction_words:
                    continue
                if pos_res[item_num][0][-2:] in ["'s","'t","'r","'d","'l","'v","'m"]:
                    continue
                if 'http' in pos_res[item_num][0]:
                    continue
                if len(pos_res[item_num][0])<3:
                    continue
                if pos_res[item_num][0] in stop_words:
                    continue
                if pos_res[item_num][0][-1] in string.punctuation:
                    pos_res[item_num] = pos_res[item_num][:-1]
                orig = input_split[item_num]
                input_split[item_num]='<mask>'
                input_text_for_pred = ' '.join(input_split)
                input_split[item_num]=orig
                keyword[pos_res[item_num][0]+ "_"+str(item_num)]['prediction']=get_predictions(input_text_for_pred, top_clean=top_k)
    else:
        doc = nlp(input_sent)
        input_split = doc.text.split()
        ### reg_exp to detect punctuation and number in the word splits
        reg_exp= "["+string.punctuation+"0-9]"
        for i in range(len(input_split)):
            
            if len(doc[i].text)<3: ## skip words with length < 3
                continue
            if re.search(reg_exp,doc[i].text): ## skip punctuation and number
                continue
            ### remove words that are definitely not emo-denoting for easier computation
            if not doc[i].is_stop and doc[i].pos_ not in ['SPACE','PUNCT','ADX','CONJ','CCONJ',
                                                        'DET','INTJ','NUM','PRON','PROPN','SCONJ','SYM']:
                orig = input_split[i]
                input_split[i]= "<mask>"
                input_text_for_pred = ' '.join(input_split) ### join the split words together with <mask> for BERT prediction
                input_split[i]= orig
                keyword[doc[i].text+ "_"+str(i)]['prediction']=get_predictions(input_text_for_pred, top_clean=top_k)
            
        
    return keyword

## Link network

In [6]:
network = pd.read_csv(Gelphi_output_file_path)
network.set_index(network['Label'],inplace=True)

In [7]:
"""
Auxilary functions to find out the prediction from BERT model, change top_k_choic

"""

### match score pertaining to the masked words with the network metrics/score
## match_col : "Authority" , "modularity_class","Weighted Degree","betweenesscentrality"
def self_score(match_col="Authority",pred_out_pf=None):
    return pred_out_pf['cleaned_index'].map(network[match_col].to_dict())

### match score pertaining to the masked predictions with the network metrics/score
## match_col : "Authority" , "modularity_class","Weighted Degree","betweenesscentrality"
def pred_score(match_col="Authority",pred_out=None):
    pred_score_output = []
    for item in pred_out['prediction']:
        item = item.lower()
        try:
            pred_score_output.append(network[match_col].to_dict()[item])
        except:
            pred_score_output.append(-1)
    return pred_score_output


### aggregate function
def key_word_predict_with_network_from_sent(input_sent=None,top_k=None, filter_NA_pred=True):
    keyword_pred_from_bert_output = find_masked_words(input_sent, top_k=top_k)
    res_out = pd.DataFrame(keyword_pred_from_bert_output).transpose()
    res_out['cleaned_index']= [ item.split('_')[0].lower() for item in res_out.index]
    res_out['Label'] = self_score(match_col="Label",pred_out_pf=res_out)## check whether in the network
    res_out['self_auth'] = self_score(match_col="Authority",pred_out_pf=res_out)
    res_out['self_class'] = self_score(match_col="modularity_class",pred_out_pf=res_out)
    res_out['self_deg'] = self_score(match_col="Weighted Degree",pred_out_pf=res_out)
    res_out['self_betcent'] = self_score(match_col="betweenesscentrality",pred_out_pf=res_out)
    res_out['pred_betcent'] = res_out.apply(lambda row: pred_score(match_col="betweenesscentrality",pred_out=row),axis=1)
    res_out['pred_auth'] = res_out.apply(lambda row: pred_score(match_col="Authority",pred_out=row),axis=1)
    res_out['pred_deg'] = res_out.apply(lambda row: pred_score(match_col="Weighted Degree",pred_out=row),axis=1)
    res_out['pred_class'] = res_out.apply(lambda row: pred_score(match_col="modularity_class",pred_out=row),axis=1)
    return res_out

In [9]:
"""
Main function to generate final output 

"""
corpus_input_csv_path = '../source_data/corpus_input.csv'
start = datetime.datetime.now()
df = pd.read_csv(corpus_input_csv_path) ### load the corpus text 
df=df.reset_index()
res_new = pd.DataFrame()
# for i in range(0,len(df)):
for i in range(0,10):    
    input_text = df['Text'].loc[i]
    if i%100==0:
        print(f'{i} / {len(df)} done')
    res = key_word_predict_with_network_from_sent(input_text,top_k=5)
    res['from_textid'] = df['index'].loc[i]
    res_new = res_new.append(res)
end =datetime.datetime.now()
print(f'time taken {end-start}')
res_new2=res_new.reset_index()
res_new2.to_csv(bert_with_network_score_file_name+'.csv',index=False)
print(f'{round(27/116,2)}s average time taken to predict each masked word' )
print(f'BERT Final Prediction with score is saved in {bert_with_network_score_file_name}.csv')

0 / 824 done
time taken 0:00:26.766007


In [10]:
res_new2

Unnamed: 0,index,prediction,cleaned_index,Label,self_auth,self_class,self_deg,self_betcent,pred_betcent,pred_auth,pred_deg,pred_class,from_textid
0,good_1,"[the, new, little, old, more]",good,good,0.013005,18.0,398.0,0.000000,"[-1, 2012.54884, 4296.3050299999995, 0.0, 4938...","[-1, 0.001431, 0.001866, 0.000349, 0.000708]","[-1, 59, 162, 66, 61]","[-1, 10, 6, 10, 9]",0
1,Stooge_2,"[comedy, movie, old, story, funny]",stooge,,,,,,"[-1, 3307.571201, 0.0, 499.947545, 42958.099068]","[-1, 1e-05, 0.000349, 0.000131, 0.001755]","[-1, 6, 66, 68, 127]","[-1, 24, 10, 0, 10]",0
2,lovely_7,"[good, evil, mean, young, bad]",lovely,lovely,0.101399,17.0,226.0,0.000000,"[0.0, 776.6235889999999, 97091.279297, -1, 433...","[0.013005, 0.00123, 0.102042, -1, 0.004394]","[398, 130, 423, -1, 309]","[18, 11, 17, -1, 11]",0
3,evil_9,"[happy, funny, young, beautiful, different]",evil,evil,0.001230,11.0,130.0,776.623589,"[64807.099323, 42958.099068, -1, 0.0, 10103.08...","[0.000502, 0.001755, -1, 0.101401, 7.8e-05]","[61, 127, -1, 216, 6]","[18, 10, -1, 17, 5]",0
4,time_13,"[thing, way, person, as, woman]",time,time,0.000318,9.0,10.0,0.000000,"[37456.128994, 0.0, 71005.676086, 0.0, -1]","[0.000816, 0.016628, 0.000334, 0.000121, -1]","[172, 266, 61, 28, -1]","[3, 25, 3, 9, -1]",0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
111,way_21,"[dialogue, story, action, plot, movie]",way,way,0.016628,25.0,266.0,0.000000,"[0.0, 499.947545, 23449.4279, 0.0, 3307.571201]","[8e-06, 0.000131, 0.000116, 6.500000000000001e...","[7, 68, 39, 40, 6]","[0, 0, 6, 3, 24]",9
112,movie_23,"[so, just, even, exactly, being]",movie,movie,0.000010,24.0,6.0,3307.571201,"[2087.67817, 0.0, 0.0, 62.2058, 0.0]","[0.015486000000000003, 0.003158, 0.00074699999...","[222, 168, 106, 83, 125]","[25, 18, 5, 5, 3]",9
113,randomly_27,"[also, just, definitely, even, all]",randomly,,,,,,"[0.0, 0.0, 0.0, 0.0, 0.0]","[0.000391, 0.003158, 0.000491, 0.0007469999999...","[33, 168, 43, 106, 90]","[9, 18, 21, 5, 5]",9
114,silly_28,"[had, sent, carried, left, got]",silly,silly,0.001007,10.0,47.0,0.000000,"[-1, -1, -1, -1, 149.284191]","[-1, -1, -1, -1, 0.001297]","[-1, -1, -1, -1, 20]","[-1, -1, -1, -1, 9]",9
