# AllenNLP SRL BERT

## 1. Environment Set-up

In [1]:
!pip install allennlp==2.1.0 allennlp-models==2.1.0 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting allennlp==2.1.0
  Downloading allennlp-2.1.0-py3-none-any.whl (585 kB)
[K     |████████████████████████████████| 585 kB 4.9 MB/s 
[?25hCollecting allennlp-models==2.1.0
  Downloading allennlp_models-2.1.0-py3-none-any.whl (407 kB)
[K     |████████████████████████████████| 407 kB 54.8 MB/s 
Collecting torch<1.8.0,>=1.6.0
  Downloading torch-1.7.1-cp37-cp37m-manylinux1_x86_64.whl (776.8 MB)
[K     |████████████████████████████████| 776.8 MB 16 kB/s 
Collecting filelock<3.1,>=3.0
  Downloading filelock-3.0.12-py3-none-any.whl (7.6 kB)
Collecting torchvision<0.9.0,>=0.8.1
  Downloading torchvision-0.8.2-cp37-cp37m-manylinux1_x86_64.whl (12.8 MB)
[K     |████████████████████████████████| 12.8 MB 48.2 MB/s 
[?25hCollecting jsonnet>=0.10.0
  Downloading jsonnet-0.19.1.tar.gz (593 kB)
[K     |████████████████████████████████| 593 kB 69.6 MB/s 
[?25hCollecting transformers<4.4,>

In [2]:
!pip install spacy

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
!python -m spacy download en_core_web_sm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting en-core-web-sm==3.0.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0-py3-none-any.whl (13.7 MB)
[K     |████████████████████████████████| 13.7 MB 1.5 MB/s 
Installing collected packages: en-core-web-sm
  Attempting uninstall: en-core-web-sm
    Found existing installation: en-core-web-sm 3.4.1
    Uninstalling en-core-web-sm-3.4.1:
      Successfully uninstalled en-core-web-sm-3.4.1
Successfully installed en-core-web-sm-3.0.0
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')


## 2. Download pretrained predictor model

In [4]:
from allennlp.predictors.predictor import Predictor
import allennlp_models.tagging

# Coreference-resolution predictor (SpanBERT-large archive). Used in section 5
# through `coref.coref_resolved(document=...)` before sentences are split.
coref = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2021.03.10.tar.gz")
# Semantic-role-labelling predictor (SRL BERT archive). Used in section 5 through
# `srlbert.predict(sentence)`; its output is consumed by `post_process`.
srlbert = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/structured-prediction-srl-bert.2020.12.15.tar.gz")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...
downloading: 100%|##########| 1345986155/1345986155 [00:26<00:00, 50983125.37B/s]


Downloading:   0%|          | 0.00/414 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/665M [00:00<?, ?B/s]

Some weights of BertModel were not initialized from the model checkpoint at SpanBERT/spanbert-large-cased and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
downloading: 100%|##########| 405972254/405972254 [00:09<00:00, 44417618.82B/s]


Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

## 3. Pre-process pipeline

In [5]:
import re
import spacy
nlp = spacy.load('en_core_web_sm')

def check_generic(doc):
  """Describe the first named entity of a parsed document, if it has one.

  Args:
    doc: a parsed document exposing ``ents``; each entity must provide
      ``text``, ``start_char``, ``end_char`` and ``label_``.

  Returns:
    ``'No named entities found'`` when ``doc.ents`` is empty; otherwise the
    string ``text - start_char - end_char - label - explanation`` built from
    the first entity only (later entities are never inspected).
  """
  if not doc.ents:
    return 'No named entities found'

  for entity in doc.ents:
    fields = [
        entity.text,
        str(entity.start_char),
        str(entity.end_char),
        entity.label_,
        str(spacy.explain(entity.label_)),
    ]
    # Returning inside the loop is deliberate: only the first entity is reported.
    return ' - '.join(fields)

def clean_text(sentence):
  """
  Normalise raw text, lemmatize it, and blank out text without named entities.

  Input sentence: Raw text (a sentence or a whole document)
  Output sentence: cleaned text with -
                  (i) no extra whitespaces, no new lines, no tabs
                  (ii) every token replaced by its lemma, tokens joined by one space
                  (iii) text in which no named entity is found is returned as
                        an empty string

  """

  # Literal "\n" sequences, real newlines, tabs and stray backslashes all become
  # spaces; any run of whitespace is then collapsed to a single space.
  sentence = sentence.replace('\\n', ' ').replace('\n', ' ').replace('\t', ' ').replace('\\', ' ')
  sentence = re.sub(r'\s+', ' ', sentence)
  # Some inputs have no space after '?' or ')'; pad them so the tokenizer does
  # not glue the neighbouring words into a single token.
  sentence = sentence.replace('?', ' ? ').replace(')', ') ')

  doc = nlp(sentence)

  # Generic text (no named entity anywhere) is discarded outright, so there is
  # no point lemmatizing it.
  if check_generic(doc) == "No named entities found":
    return ""

  # Keep the surface form where the lemma is the "-PRON-" placeholder.
  # The original appended the Token object itself to a str here, which raises
  # TypeError; `token.text` is the string that was intended.
  lemmas = [
      token.text if token.lemma_ == "-PRON-" else token.lemma_
      for token in doc
  ]
  return " ".join(lemmas)

## 4. Post-process pipeline

In [6]:
def post_process(srloutput):
    """Convert one SRL prediction into (subject, verb, object) triplets.

    Args:
        srloutput: mapping with ``'words'`` (list of tokens) and ``'verbs'``
            (list of frames, each holding a ``'tags'`` list indexed like the
            words).

    Returns:
        A flat list of 3-tuples. Frames with fewer than two distinct
        ``B-ARG0`` .. ``B-ARG6`` tags are ignored. Every remaining frame yields
        ``(subject, verb, object)`` and, when it has a temporal argument, an
        additional ``(object, verb, temporal)`` tuple.
    """
    core_labels = ['B-ARG' + str(number) for number in range(7)]

    def phrase(tags, words, *needles):
        """Join the words whose tag contains at least one of ``needles``."""
        return ' '.join(
            words[position]
            for position, tag in enumerate(tags)
            if any(needle in tag for needle in needles)
        )

    def labels_present(tags):
        """Core argument labels that occur in ``tags``, lowest number first."""
        return [label for label in core_labels if label in tags]

    def frame_to_triplets(tags, words):
        present = labels_present(tags)

        # Lowest-numbered argument is the subject, the next one the object.
        # Dropping the "B-" prefix makes the continuation (I-) tags match too.
        subject = phrase(tags, words, present[0][2:]) if present else ''
        obj = phrase(tags, words, present[1][2:]) if len(present) > 1 else ''

        # The verb phrase also absorbs negation markers.
        verb = phrase(tags, words, '-V', 'NEG')
        temporal = phrase(tags, words, 'TMP')

        triplets = [(subject, verb, obj)]
        if temporal != '':
            triplets.append((obj, verb, temporal))
        return triplets

    # Only frames with at least two core arguments can form a triplet.
    usable_frames = [
        frame for frame in srloutput['verbs']
        if len(labels_present(frame['tags'])) > 1
    ]
    words = srloutput['words']

    collected = []
    for frame in usable_frames:
        collected += frame_to_triplets(frame['tags'], words)
    return collected

## 5. Model

On Google Colab, upload `cleaned_json.zip` to the runtime first.

In [None]:
!unzip cleaned_json.zip

In [None]:
import json
import os 
from nltk.tokenize import sent_tokenize
from tqdm import tqdm
import gc
 
# Every file in cleaned_json/ is expected to be a JSON document with at least
# 'text' and 'title', plus an author list (see the note on the key below).
files = sorted(os.listdir('cleaned_json'))

output = [] # Final output of triplets
tracking = {} # per-file counts: sentences before/after cleaning, triplets produced

for file in tqdm(files):
    t = {}

    # open file
    with open(f'cleaned_json/{file}', encoding='utf-8') as f:
        d = json.load(f)
    # NOTE: the misspelled key 'orignal' is kept because it ends up in tracking.json.
    t['orignal'] = len(sent_tokenize(d['text']))

    # pre-process (clean text); returns '' when the text has no named entities
    cleaned_text = clean_text(d['text'])
    t['cleaned'] = len(sent_tokenize(cleaned_text))

    tmp = len(output)

    # The original code reads the (misspelled) key 'auhtors'. Fall back to
    # 'authors' and skip the authorship triple when no author is listed, instead
    # of aborting the whole run with KeyError/IndexError.
    # NOTE(review): confirm which spelling the cleaned_json files actually use.
    authors = d.get('auhtors') or d.get('authors') or []
    if authors:
        output.append((authors[0],'is author of',d['title']))

    failed_sentences = 0
    # Guard on the cleaned text rather than the raw text: clean_text can return
    # an empty string for non-empty input, and there is nothing to resolve then.
    if cleaned_text:
        # coref resolution
        coref_resolved = coref.coref_resolved(document=cleaned_text)
        del cleaned_text

        # break into sentences for prediction
        sentences = sent_tokenize(coref_resolved)
        del coref_resolved

        # predict; a sentence the SRL model cannot handle is skipped (best
        # effort) but counted, and KeyboardInterrupt is no longer swallowed.
        for s in sentences:
            try:
                output.extend(post_process(srlbert.predict(s)))
            except Exception:
                failed_sentences += 1
    t['triplets'] = len(output) - tmp
    t['failed_sentences'] = failed_sentences

    tracking[file] = t
    gc.collect()

## 6. Heuristic filtering (based on KnowText paper)

In [11]:
# `collections` was misspelled as `collection`, which raises ModuleNotFoundError
# and stops every function in this cell from being defined; `duplicate` relies
# on `collections.Counter`.
import collections
import string
import spacy
# English pipeline used by the POS-based filters below.
nlp = spacy.load("en_core_web_sm")

def filter_triple(triple):
    """Decide whether a (subject, predicate, object) triple passes the POS heuristics.

    A triple is rejected (``False``) when any of the following holds:
      * the subject contains neither a NOUN nor a PROPN token;
      * any whitespace-separated word of the triple is a day of the week;
      * the subject or the object contains a VERB token;
      * the subject, predicate or object contains a PRON token;
      * subject and object are the same string.
    Otherwise the triple is kept (``True``).
    """
    weekdays = ('monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday')

    subject, predicate, object_ = triple[0], triple[1], triple[2]

    # Coarse part-of-speech tags of every token in each component.
    subject_pos = [token.pos_ for token in nlp(subject)]
    object_pos = [token.pos_ for token in nlp(object_)]
    predicate_pos = [token.pos_ for token in nlp(predicate)]

    # Case-insensitive match on whole whitespace-separated words only.
    mentions_weekday = any(
        [word.lower() in weekdays for word in ' '.join(triple).split()]
    )

    # subject must contain a noun or a proper noun
    if not ('NOUN' in subject_pos or 'PROPN' in subject_pos):
        return False
    # triple must not mention a day of the week
    if mentions_weekday:
        return False
    # no verbs inside subject or object
    if 'VERB' in subject_pos or 'VERB' in object_pos:
        return False
    # no pronouns anywhere in the triple
    if 'PRON' in subject_pos or 'PRON' in predicate_pos or 'PRON' in object_pos:
        return False
    # subject and object must differ
    if subject == object_:
        return False
    return True
  
def max_three(triple):
    """Keep a triple only if subject and object have at most three tokens each.

    Punctuation tokens (POS tag ``PUNCT``) are not counted. The predicate is
    not length-limited, so it is no longer parsed at all: the original ran a
    full spaCy parse of it and discarded the result.

    Args:
        triple: sequence whose items 0 and 2 are the subject and object strings.

    Returns:
        ``True`` when both subject and object have three or fewer
        non-punctuation tokens, ``False`` otherwise.
    """
    def content_token_count(text):
        """Number of tokens in ``text`` that are not punctuation."""
        return sum(1 for token in nlp(text) if token.pos_ != 'PUNCT')

    subject = triple[0]
    object_ = triple[2]

    if content_token_count(subject) > 3:
        return False
    if content_token_count(object_) > 3:
        return False
    return True
    
def min_char_count(triple):
    """Return True only if subject, predicate and object each have >= 2 characters."""
    components = (triple[0], triple[1], triple[2])
    return all(len(component) >= 2 for component in components)

def duplicate(triple):
    """Return True when no component of the triple repeats a word.

    Each of subject, predicate and object is split on whitespace; if any one of
    them contains the same (case-sensitive) word more than once, the triple is
    rejected with ``False``.
    """
    def repeats_a_word(text):
        # A repeated word makes the set of words smaller than the word list.
        tokens = text.split()
        return len(set(tokens)) != len(tokens)

    components = (triple[0], triple[1], triple[2])
    return not any([repeats_a_word(component) for component in components])

def special_characters(triple):
    """Return True when no component contains an ASCII punctuation character.

    Subject, predicate and object are scanned character by character against
    ``string.punctuation``; a single hit in any of them rejects the triple
    (``False``). Letters, digits, whitespace and any other characters are
    allowed.
    """
    def has_punctuation(text):
        return any(character in string.punctuation for character in text)

    for component in (triple[0], triple[1], triple[2]):
        if has_punctuation(component):
            return False
    return True

In [None]:
# Run the heuristic filters one after another, reporting how many triplets
# survive each stage. Every stage keeps its own name so it can be inspected.
triplets = output

filtered = [triple for triple in triplets if filter_triple(triple)]
print(f'after filter_triple: {len(filtered)}')

filtered_maxthree = [triple for triple in filtered if max_three(triple)]
print(f'after max_three: {len(filtered_maxthree)}')

filtered_minchar = [triple for triple in filtered_maxthree if min_char_count(triple)]
print(f'after min_char_count: {len(filtered_minchar)}')

filtered_duplicate = [triple for triple in filtered_minchar if duplicate(triple)]
print(f'after duplicate: {len(filtered_duplicate)}')

filtered_specchar = [triple for triple in filtered_duplicate if special_characters(triple)]
print(f'after special_character: {len(filtered_specchar)}')

# Later sections read the surviving triplets from `output`.
output = filtered_specchar

## 7. Remove similar triplets

In [None]:
import math
import re
from collections import Counter

WORD = re.compile(r"\w+")

def get_cosine(vec1, vec2):
  """Cosine similarity of two sparse count vectors given as mappings.

  Args:
    vec1, vec2: mappings from key to numeric weight (e.g. ``Counter`` objects).

  Returns:
    The dot product over the shared keys divided by the product of the two
    Euclidean norms, as a float; ``0.0`` when either vector has zero norm.
  """
  shared_keys = set(vec1.keys()) & set(vec2.keys())
  dot_product = sum(vec1[key] * vec2[key] for key in shared_keys)

  norm1 = math.sqrt(sum(weight ** 2 for weight in vec1.values()))
  norm2 = math.sqrt(sum(weight ** 2 for weight in vec2.values()))
  denominator = norm1 * norm2

  # An empty (or all-zero) vector has no direction; report no similarity.
  return float(dot_product) / denominator if denominator else 0.0

def text_to_vector(text):
    """Build a bag-of-words count vector from every match of the module-level WORD pattern in ``text``."""
    return Counter(WORD.findall(text))

def subset_phrase(triples, simScore):
  """Drop the shorter member of every pair of near-duplicate triples.

  Each triple is flattened to "subject predicate object" and turned into a
  bag-of-words vector. For every pair (i < j) whose cosine similarity is at
  least ``simScore``, the triple with the shorter flattened text is removed
  from the result; on equal length the earlier one (i) is removed.

  Args:
    triples: list of 3-item string sequences.
    simScore: similarity threshold at or above which two triples count as
      similar.

  Returns:
    A new list with the surviving triples in their original order; ``triples``
    itself is not modified.

  Note:
    Pairs are always formed from the original ``triples`` list, so a triple
    that has already been dropped can still cause another one to be dropped.
    This behaviour is preserved from the original implementation.
  """
  # Flattened text and word vector depend on one triple only, so build them once
  # per triple instead of once per pair as the original did.
  texts = [tri[0] + " " + tri[1] + " " + tri[2] for tri in triples]
  vectors = [text_to_vector(text) for text in texts]

  n = len(triples)
  new_triple = triples[:]
  for i in range(n):
    for j in range(i + 1, n):
      # Comparing whole-triple vectors also covers identical subjects, objects
      # or predicates without any special-casing.
      if get_cosine(vectors[i], vectors[j]) >= simScore:
        # Keep the longer text; ties keep the later triple.
        loser = triples[j] if len(texts[i]) > len(texts[j]) else triples[i]
        if loser in new_triple:
          new_triple.remove(loser)
  return new_triple


In [None]:
# Remove near-duplicate triplets: pairs with cosine similarity >= 0.5 lose their
# shorter member (see subset_phrase).
subset = subset_phrase(output,0.5)
print(f'after subset_phrase: {len(subset)}')
# Later cells read the final triplets from `output`.
output = subset

## 8. Output results into json file

In [None]:
import json
# Write the final triplets to allennlp.json under the single key "results".
dr = {"results" : output}
json_object = json.dumps(dr)
with open("allennlp.json", "w") as f:
  f.write(json_object)

In [None]:
# Write the per-file counts collected in section 5 (keyed by file name) to
# tracking.json.
dr = tracking
json_object = json.dumps(dr)
with open("tracking.json", "w") as f:
  f.write(json_object)

Please note that the output of this notebook may not match the output files on GitHub: the order of the filtering steps was changed while finalizing this notebook, which results in different triplets.