Install dependencies

In [None]:
pip install -q -U tensorflow-text

In [None]:
pip install -q tf-models-official==2.4.0

In [None]:
pip install transformers

In [4]:
import os
import io
import re
import sys

import numpy as np
import pandas as pd
from time import time
import matplotlib.pyplot as plt

import pickle
from csv import reader

import tensorflow as tf

from transformers import BertTokenizer, TFBertModel

In [5]:
# path = 'drive/MyDrive/MIDS/chemical_patent_cer_ee'

## Generate BERT inputs

In [5]:
# Load the pretrained 'bert-base-cased' tokenizer. It is used further down to split
# snippet words into tokens and to convert the final token lists into ids.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

In [6]:
# full_path = f'{path}/data/sre_em/sre_em_sample.csv'
# Location of the sample file, given relative to the current working directory.
# Despite the .csv extension, every line is later split on tab characters.
full_path = '../data/sre_em/sre_em_sample.csv'

In [62]:
# Fixed length of every BERT input sequence, counting the leading [CLS], the final [SEP]
# and any [PAD] tokens. Longer snippets are truncated to fit, shorter ones are padded.
# NOTE(review): the reason for choosing 97 is not shown in this notebook — document it.
max_length = 97

In [8]:
# Read the whole sample file into memory as a list of lines (line endings are kept).
# errors='ignore' silently drops bytes that are not valid UTF-8 instead of raising.
with open(full_path, mode='r', encoding='utf-8', errors='ignore') as f:
    text = list(f)

In [9]:
# BERT inputs, one entry per kept snippet: token ids, attention masks, segment ids
bertTokenIDs, bertMasks, bertSeqIDs = [], [], []

# entity masks (start-marker masks and pooled-mention masks for each entity)
entity1StartMasks, entity2StartMasks = [], []
entity1PooledMasks, entity2PooledMasks = [], []

# labels, both as read from the file and converted to integer codes
origLabels, codedLabels = [], []

In [71]:
# list for designating markers
em_markers = ['[E1]', '[/E1]', '[E2]', '[/E2]']
em_start = ['[E1]', '[E2]']
em_end = ['[/E1]', '[/E2]']


# dictionary for converting labels to code
code = {'ARG1': 0, 'ARGM': 1}

# lists for processing
snippetLengthList = []   # token count (incl. [CLS], before truncation) of every kept snippet
discardedEntries = []    # ids of snippets whose [/E2] marker does not fit within max_length


for line in text:

    # each line holds three tab-separated fields: snippet id, label, marked-up snippet
    parsed_line = line.strip().split('\t')

    snippet_id = parsed_line[0]
    label = parsed_line[1]
    snippet = parsed_line[2].split()

    # generate inputs for BERT
    # convert snippets to tokenIDs, cap snippet length using max_length
    # snippets which are shorter than max_length are padded
    # snippets which are longer are truncated
    # truncated snippets with only one entity are discarded
    # all snippets end with a [SEP] token, padded or not

    # tokenize snippet, except for entity markers
    snippetTokens = ['[CLS]']

    for word in snippet:
        if word not in em_markers:
            snippetTokens.extend(tokenizer.tokenize(word))
        else:
            # keep the marker as a single token instead of letting the tokenizer split it
            snippetTokens.append(word)

    # check that both entities will make it within max_length
    # by finding the index for [/E2] and comparing it to (max_length - 1)
    # NOTE(review): this assumes [/E2] is the last marker in the snippet, and .index()
    # raises ValueError when the marker is missing — confirm against the data.
    check = snippetTokens.index('[/E2]')

    # discard if only one entity will make it
    # (index max_length - 1 is reserved for the final [SEP] token)
    if check >= (max_length - 1):
        discardedEntries.append(snippet_id)
        continue

    # figure out sentence length for padding or truncating
    snippetLength = len(snippetTokens)
    snippetLengthList.append(snippetLength)

    # create space for at least a final [SEP] token:
    # keep indices 0 .. max_length - 2, which is exactly the range the check above
    # guarantees for [/E2], so the closing marker is never cut off
    if snippetLength > max_length - 1:
        snippetTokens = snippetTokens[:(max_length - 1)]

    # number of real (non-[SEP], non-[PAD]) tokens after any truncation
    sentenceLength = len(snippetTokens)

    # add [SEP] token and padding
    snippetTokens += ['[SEP]'] + ['[PAD]'] * ((max_length - 1) - sentenceLength)

    # generate BERT input lists
    # NOTE(review): the marker strings go through convert_tokens_to_ids unchanged; unless
    # they were added to the tokenizer's vocabulary they presumably all map to the
    # unknown-token id — confirm.
    bertTokenIDs.append(tokenizer.convert_tokens_to_ids(snippetTokens))
    # attention mask: 1 for the real tokens plus [SEP], 0 for padding
    bertMasks.append([1] * (sentenceLength + 1) + [0] * (max_length - 1 - sentenceLength))
    # single-segment input, so every segment id is 0
    bertSeqIDs.append([0] * (max_length))

    # generate label lists
    origLabels.append(label)
    codedLabels.append(code[label])

In [90]:
def generate_entity_start_mask(snippetTokens, max_length):
    """Build boolean masks marking where each entity's start marker sits.

    Args:
        snippetTokens: list of token strings for one snippet.
        max_length: length of the masks to produce.

    Returns:
        A pair (e1_mask, e2_mask) of boolean numpy arrays of shape
        (max_length,). e1_mask is True at positions holding '[E1]',
        e2_mask is True at positions holding '[E2]'.
    """

    token_array = np.array(snippetTokens)
    start_masks = []

    for start_marker in ('[E1]', '[E2]'):
        marker_mask = np.zeros(shape=(max_length,), dtype=bool)
        marker_positions = np.argwhere(token_array == start_marker)
        marker_mask[marker_positions] = True
        start_masks.append(marker_mask)

    return start_masks[0], start_masks[1]

In [86]:
def generate_entity_mention_mask(snippetTokens, max_length):
    """Build boolean masks covering the tokens of each entity mention.

    A token belongs to an entity when it lies between that entity's
    markers; the markers themselves are never selected. A token that
    lies inside both entities is assigned to entity 1 only.

    Args:
        snippetTokens: list of token strings for one snippet.
        max_length: length of the masks to produce.

    Returns:
        A pair (e1_mask, e2_mask) of boolean numpy arrays of shape
        (max_length,).
    """

    first_entity_markers = ("[E1]", "[/E1]")
    second_entity_markers = ("[E2]", "[/E2]")

    first_mask = np.zeros(shape=(max_length,), dtype=bool)
    second_mask = np.zeros(shape=(max_length,), dtype=bool)

    inside_first = False
    inside_second = False

    for position, token in enumerate(snippetTokens):
        # any marker of an entity flips that entity's "inside" state
        if token in first_entity_markers:
            inside_first = not inside_first
            continue
        if token in second_entity_markers:
            inside_second = not inside_second
            continue

        # ordinary token: entity 1 takes precedence over entity 2
        if inside_first:
            first_mask[position] = True
        elif inside_second:
            second_mask[position] = True

    return first_mask, second_mask

In [None]:
# NOTE(review): this cell is an incomplete fragment. The `elif` below has no matching
# `if` inside the cell, so the cell cannot be executed on its own; it looks like a
# branch meant to live inside a mask-generating helper — confirm and move or delete.
elif marker_type == 'ner':
        # NOTE(review): both masks are computed with the same expression, so e1_mask and
        # e2_mask are identical — confirm whether that is intended.
        # ner_start is assigned further down in this cell, after these lines.
        e1_mask = np.isin(np.array(snippetTokens), ner_start)
        e2_mask = np.isin(np.array(snippetTokens), ner_start)


# Single-character start markers; ner_markers below is the same list with '[/E]' prepended.
# NOTE(review): these appear to be Greek capital letters, several of which look identical
# to Latin capitals — verify the code points before comparing against tokens.
ner_start = ['Α', 'Β', 'Π', 'Σ', 'Ο', 'Τ', 'Θ', 'Ψ', 'Υ', 'Χ', 'Λ', 'Δ']
ner_markers = ['[/E]', 'Α', 'Β', 'Π', 'Σ', 'Ο', 'Τ', 'Θ', 'Ψ', 'Υ', 'Χ', 'Λ', 'Δ']