### Raw Data Processing

In [1]:
# import libraries

import os
import io
import re
import csv
from string import punctuation
from time import time
import numpy as np
from nltk.tokenize import sent_tokenize, word_tokenize

In [2]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /Users/lea/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

#### Helper Functions

In [3]:
def process_ann(ann_file, test=False):
    """Helper function that reads a .ann file,
       strips out newline characters, splits the tab-delimited entries,
       and extracts information for identifying entities and relations 
       in corresponding .txt file.
       Also adds negative relations within the span of +/- 5 entities 
       from each trigger word.
       If generating test data, only includes positive relations within
       the span of +/- 5 entities from the trigger word.
       
       Input:
       ann_file = tab-delimited brat annotation file with the following format
                  NER: [entity_ID]\t[label start_offset end_offset]\t[entity]
                  RE:  [relation_ID]\t[relation_type argument1 argument2]
       span = number of entities from the trigger word to look for pairs
       test = whether or not test data is being generated
       
       Outputs:
       cleaned_offsets = list of tuples for labeling corresponding .txt file
                         format: (offset, label, entity ID)
       relations = list of tuples for extracting relations from corresponding .txt file
                   format: (relation ID, relation_type, entity ID #1, entity ID #2)
       positives_missed = count of positive relations missed"""
    
    # dictionary to convert NER labels to NER-specific markers 
    # for BERT input coded with seldom used tokens
    ner_markers = {'STARTING_MATERIAL': 'Α', 'REAGENT_CATALYST': 'Β', 'REACTION_PRODUCT': 'Π', 
                   'SOLVENT': 'Σ', 'OTHER_COMPOUND': 'Ο', 'EXAMPLE_LABEL': 'Χ', 
                   'TIME': 'Τ', 'TEMPERATURE': 'Θ', 'YIELD_PERCENT': 'Ψ', 'YIELD_OTHER': 'Υ',
                   'WORKUP': 'Λ', 'REACTION_STEP': 'Δ'}
    
    with io.open(ann_file, 'r', encoding='utf-8', errors='ignore') as f:
        text = [x.strip().split('\t') for x in f.readlines()]
        
    ann = [x for x in text if x[0][0] == 'T']
    rel = [x for x in text if x[0][0] == 'R']
    
    # extract information for identifying entities
    offsets = []
    
    for x in ann:
        entity_id = x[0]
        start = int(x[1].split()[1])
        end = int(x[1].split()[2])
        label = x[1].split()[0]
        
        offsets.append((start, 'S', ner_markers[label], entity_id))
        offsets.append((end, 'E', ner_markers[label], entity_id))
    
    # sort offsets and clean overlapping entries
    sorted_offsets = sorted(offsets, key=lambda x:x[0])
    
    cleaned_offsets = []
    corrections = {}
    
    hold = None
    indicator = None
    
    for tup in sorted_offsets:
        
        if indicator == 'S':
            if tup[1] == 'E':
                cleaned_offsets.append(hold)
                hold = (tup[0], 'O', 'X')
                indicator = tup[1]
            elif tup[1] == 'S':
                if tup[2] != hold[1]:
                    corrections.update({tup[3]:'X'})
                else:
                    corrections.update({tup[3]:hold[2]})
                indicator = '*'
        
        elif indicator == 'E':
            cleaned_offsets.append(hold)
            hold = (tup[0], tup[2], tup[3])
            indicator = tup[1]
        
        elif indicator == '*':
            indicator = 'S'

        else:
            hold = (tup[0], tup[2], tup[3])
            indicator = tup[1]
            
    cleaned_offsets.append(hold)
    
    # extract information for identifying relations
    relations = []
    positives = {}
    included = []
    
    # add positive relations
    for r in rel:
        relation_id = r[0]
        relation_type = r[1].split()[0]
        entity1 = r[1].split()[1][5:]
        entity2 = r[1].split()[2][5:]
        
        if entity1 in corrections.keys():
            if corrections[entity1] == 'X':
                continue
            else:
                entity1 = corrections[entity1]
        if entity2 in corrections.keys():
            if corrections[entity2] == 'X':
                continue
            else:
                entity2 = corrections[entity2]
        
        positives.update({(entity1, entity2):(relation_id, relation_type)})
        
        if not test:
            relations.append((relation_id, relation_type, entity1, entity2))
            included.append((entity1, entity2))
    
    # negative relations
    negatives = []
    triggers = {x[2] for x in cleaned_offsets 
                if (x[1] == ner_markers['WORKUP'] or x[1] == ner_markers['REACTION_STEP'])}
    entity_order = [x[2] for x in cleaned_offsets if (x[2] != 'X' and x[1] != ner_markers['EXAMPLE_LABEL'])]
    trigger_indices = {i for i in range(len(entity_order)) if entity_order[i] in triggers}
    
    # find negative relations
    for index in trigger_indices:
        
        # find indices in span of +/- 5 from trigger
        find_span = [index - (i+1) for i in range(5)] + [index + (i+1) for i in range(5)]
        real_span = [i for i in find_span if (i >= 0 and i < len(entity_order))]
        final_span = [i for i in real_span if i not in trigger_indices]
        
        # make tuples of trigger words and entities from span indices
        potential_pairs = [(entity_order[index], entity_order[i]) for i in final_span]
        
        # check if tuple is in positives
        # only add to negatives if not in positives
        for pair in potential_pairs:
            if pair in positives.keys():
                if not test:
                    continue
                elif test:
                    if pair not in included:
                        relations.append((positives[pair][0], positives[pair][1], pair[0], pair[1]))
                        included.append(pair)
            else:
                negatives.append(pair)
    
    # make a list of negative relations 
    # in same format as positives to add to relations list
    i = 0
    for pair in negatives:
        negative_id = f'N{i}'
        relations.append((negative_id, 'NONE', pair[0], pair[1]))
        i += 1
        
    # count how many positive relations were missed
    positives_missed = len(positives.keys()) - len(included) 
    
    return cleaned_offsets, relations, positives_missed

In [4]:
def ann_chunker(txt_file, offsets):
    """Helper function that reads in a .txt file as one string,
       divides it based on the cleaned offsets from its .ann file
       and labels chunks with NER tags
       
       Inputs:
       txt_file = file that contains all the patent text
                  considered as one sentence in this task
       offsets = list of tuples for labeling corresponding .txt file
                 format: (offset, label, entity ID)
       
       Output:
       ann_chunks = list of annotated chunks based on .ann file offsets
                    format: (chunk, label, entity ID)"""
    
    with io.open(txt_file, 'r', encoding='utf-8', errors='ignore') as text:
        full_text = text.read()
    
    start = 0
    end = offsets[0][0]
    label = 'O'
    entity_id = 'X'
    
    ann_chunks = [(full_text[:end], label, entity_id)]
    
    for i in range(len(offsets)):
        start = offsets[i][0]
        label = offsets[i][1]
        entity_id = offsets[i][2]
        
        if i < len(offsets) - 1:
            end = offsets[i+1][0]
            term = [(full_text[start:end], label, entity_id)]
            if term[0]:
                ann_chunks.extend(term)
        
        else:
            term = [(full_text[start:], label, entity_id)]  
            ann_chunks.extend(term)
    
    return ann_chunks

In [5]:
def relation_input(snippet_id, rel_tup, chunks):
    """Helper function that creates one input snippet for BERT SRE:
       Inserts entity markers and truncates snippet to include only
       sentences containing the entities
    
       Inputs:
       snippet_id = filename of snippet
       rel_tup = tuple from relations list generated by process_ann()
             format: (relation ID, relation_type, entity ID #1, entity ID #2)
       chunks = list of annotated chunks from ann_chunker()
    
       Output:
       rel_input = input snippet ready for BERT SRE
                   format: [snippet_id+relation_id]/t[relation_type]/t[cleaned snippet with ner markers]"""
    
    # unpack relation_tup
    relation_id = rel_tup[0]
    relation_type = rel_tup[1]
    entity_list = [rel_tup[2], rel_tup[3]]
    
    new_id = snippet_id + '-' + relation_id
    
    # build cleaned snippet with ner markers
    snippets = []
    i = 1
    
    for tup in chunks:
        chunk, label, entity = tup
        
        # clean chunk: remove punctuation,
        # word tokenize if not an entity
        # split by whitespace if entity
        processed_chunk = []
            
        if label == 'O':
            nopunct = re.sub(r'[,/()":\-\[\]\']', '', chunk.strip())
            sentences = sent_tokenize(nopunct)
            if sentences:
                for s in sentences:
                    for x in word_tokenize(s):
                        processed_chunk.append(x)
                
        else:
            nopunct = re.sub(r'[,/()":\-\[\]\']', '', chunk)
            tokens = [x for x in nopunct.split(' ') if x]
            for t in tokens:
                processed_chunk.append(t)
        
        # add ner markers before and after entities in relation
        if entity in entity_list:
            snippets.append(label)
            if i == 1:
                e1 = label
            snippets.extend(processed_chunk)
            snippets.append(f'[/E{i}]')
            i += 1
        
        else:
            snippets.extend(processed_chunk)

    # keep only sentences containing the entities
    # sentences are marked by periods
    e1_index = snippets.index(e1)
    e2_index = snippets.index('[/E2]')
    
    periods = [i for i in range(len(snippets)) if snippets[i] == '.']
    periods_before = [i for i in periods if i < e1_index]
    periods_after = [i for i in periods if i > e2_index]
    
    if periods_before:
        start = max(periods_before) + 1
    else:
        start = 0
    
    if periods_after:
        end = min(periods_after) + 1
    else:
        end = None
    
    truncated_snippet = snippets[start:end]
    
    # join snippet chunks to one clean snippet
    cleaned_snippet = ' '.join(truncated_snippet)
    
    return [new_id, relation_type, cleaned_snippet]

In [6]:
def generate_re_files(filepaths, output_path, test=False):
    """Helper function that reads .txt and corresponding .ann files from a path
       and generates csv file with snippets ready for BERT SRE (one snippet per line)
       
       Inputs:
       filepaths = filepaths (folder + filename, but no extension) for .txt and .ann files
       output_path = filepath (folder + filename, but no extension) for output file
       missed_count = count of total positive relations missed"""
    
    start = time()
    
    snippets = []
    missed_positives = []
    
    for file in filepaths:
        
        snippet_id = file[-4:]
        
        cleaned_offsets, relations, missed = process_ann(f'{file}.ann', test=test)
        missed_positives.append(missed)
        chunks = ann_chunker(f'{file}.txt', cleaned_offsets)
        
        for tup in relations:
            line = relation_input(snippet_id, tup, chunks)
            snippets.append(line)
    
    missed_count = sum(missed_positives)
    print(f'Number of positive relations missed: {missed_count}')
    
    with open(f'{output_path}.csv', 'w') as f:
        writer = csv.writer(f, delimiter='\t')
        writer.writerows(snippets)
    
    end = time() - start
    print(f'Finished in {end:.3f} seconds')

In [7]:
def generate_re_files_subsampled(filepaths, output_path, fraction=0.5):
    """Helper function that reads .txt and corresponding .ann files from a path
       and generates csv file with snippets ready for BERT SRE (one snippet per line).
       Subsampled 25% of negative relations randomly.
       
       Inputs:
       filepaths = filepaths (folder + filename, but no extension) for .txt and .ann files
       output_path = filepath (folder + filename, but no extension) for output file"""
    
    start = time()
    
    snippets = []
    missed_positives = []
    
    for file in filepaths:
        
        snippet_id = file[-4:]
        
        cleaned_offsets, relations, missed = process_ann(f'{file}.ann')
        missed_positives.append(missed)
        chunks = ann_chunker(f'{file}.txt', cleaned_offsets)
        
        # add all positive relations
        positive_relations = [tup for tup in relations if tup[0][0] == 'R']
        
        for tup in positive_relations:
            line = relation_input(snippet_id, tup, chunks)
            snippets.append(line)
        
        # keep only 25% of negative relations
        negative_relations = [tup for tup in relations if tup[0][0] == 'N']
        neg_count = len(negative_relations)
        np.random.seed(424)
        negative_keep = np.random.binomial(1, fraction, neg_count)
        
        for i in range(neg_count):
            if negative_keep[i] == 1:
                line = relation_input(snippet_id, negative_relations[i], chunks)
                snippets.append(line)
                
    missed_count = sum(missed_positives)
    print(f'Number of positive relations missed: {missed_count}')
    
    with open(f'{output_path}.csv', 'w') as f:
        writer = csv.writer(f, delimiter='\t')
        writer.writerows(snippets)
    
    end = time() - start
    print(f'Finished in {end:.3f} seconds')

#### Test the Helper Functions

In [8]:
test_path = '../raw_data/EE/ee_train/1069'

In [9]:
with io.open(f'{test_path}.txt', 'r', encoding='utf-8', errors='ignore') as text:
    full_text = text.read()
full_text

"Intermediate 2: 2,3,4,5,6,7-hexahydropyrrolo[3',4':3,4]pyrazolo[1 ,5-b][1,2]thiazine-1,1-dioxide p-toluenesulfonate\nStep 1: tert-butyl 3-iodo-4,6-dihydropyrrolo[3,4-c]pyrazole-5(2H)-carboxylate\nTo a solution of tert-butyl 4,6-dihydropyrrolo[3,4-c]pyrazole-5(2H)-carboxylate (9.41 g, 45 mmol) in 200 mL 1,2-dichloroethane was added N-iodosuccinimide (13.16 g, 58.5 mmol), and refluxed overnight. The solvent was removed by evaporation, and the residue was purified by silica gel column chromatography to give tert-butyl 3-iodo-4,6-dihydropyrrolo[3,4-c]pyrazole-5(2H)-carboxylate (4.66 g). Yield: 31%."

In [10]:
with io.open(f'{test_path}.ann', 'r', encoding='utf-8', errors='ignore') as text:
    ann = [x.strip().split('\t') for x in text.readlines()] #if x.strip().split('\t')[0][0] == 'R']
ann[:3]

[['T12', 'WORKUP 455 463', 'purified'],
 ['T0', 'EXAMPLE_LABEL 121 122', '1'],
 ['T13', 'REACTION_STEP 375 383', 'refluxed']]

In [11]:
test_offsets, test_relations, test_missed = process_ann(f'{test_path}.ann', test=False)

In [12]:
test_missed

0

In [13]:
test_offsets[:3]

[(13, 'Χ', 'T6'), (14, 'O', 'X'), (16, 'Ο', 'T11')]

In [14]:
test_relations

[('R0', 'ARG1', 'T12', 'T7'),
 ('R1', 'ARG1', 'T14', 'T9'),
 ('R2', 'ARG1', 'T14', 'T5'),
 ('R3', 'ARG1', 'T14', 'T8'),
 ('R5', 'ARG1', 'T15', 'T3'),
 ('R6', 'ARGM', 'T15', 'T10'),
 ('R7', 'ARGM', 'T15', 'T2'),
 ('N0', 'NONE', 'T15', 'T7'),
 ('N1', 'NONE', 'T15', 'T5'),
 ('N2', 'NONE', 'T14', 'T4'),
 ('N3', 'NONE', 'T14', 'T11'),
 ('N4', 'NONE', 'T14', 'T7'),
 ('N5', 'NONE', 'T13', 'T5'),
 ('N6', 'NONE', 'T13', 'T9'),
 ('N7', 'NONE', 'T13', 'T8'),
 ('N8', 'NONE', 'T13', 'T4'),
 ('N9', 'NONE', 'T13', 'T7'),
 ('N10', 'NONE', 'T13', 'T3'),
 ('N11', 'NONE', 'T13', 'T10'),
 ('N12', 'NONE', 'T12', 'T5'),
 ('N13', 'NONE', 'T12', 'T9'),
 ('N14', 'NONE', 'T12', 'T8'),
 ('N15', 'NONE', 'T12', 'T3'),
 ('N16', 'NONE', 'T12', 'T10'),
 ('N17', 'NONE', 'T12', 'T2')]

In [15]:
trial_sentence = ann_chunker(f'{test_path}.txt', test_offsets)

In [16]:
trial_sentence[:3]

[('Intermediate ', 'O', 'X'), ('2', 'Χ', 'T6'), (': ', 'O', 'X')]

In [17]:
trial_snippet = relation_input('0000', test_relations[0], trial_sentence)
trial_snippet

['0000-R0',
 'ARG1',
 'The solvent was removed by evaporation and the residue was Λ purified [/E1] by Ο silica gel [/E2] column chromatography to give tertbutyl 3iodo46dihydropyrrolo34cpyrazole52Hcarboxylate 4.66 g .']

In [18]:
generate_re_files([test_path], '../raw_data/test', test=True)
with io.open('../raw_data/test.csv', 'r', encoding='utf-8', errors='ignore') as sample:
    output = sample.readlines()
len(output)

Number of positive relations missed: 0
Finished in 0.110 seconds


25

In [19]:
generate_re_files_subsampled([test_path], '../raw_data/test2')
with io.open('../raw_data/test2.csv', 'r', encoding='utf-8', errors='ignore') as sample:
    output2 = sample.readlines()
len(output2)

Number of positive relations missed: 0
Finished in 0.053 seconds


16

#### Process the Raw Data

In [20]:
# generate sample set
path_sample = '../raw_data/sample_ee'
filenames_sample = list({x[:4] for x in os.listdir(path_sample) if x[0] != '.'})
filepath_sample = [f'{path_sample}/{x}' for x in filenames_sample]

output_sample = '../data/sre_ner/sre_ner_sample'
generate_re_files(filepath_sample, output_sample)

with io.open(f'{output_sample}.csv', 'r', encoding='utf-8', errors='ignore') as sample:
    output = sample.readlines()
check = [1 for x in output]
print(f'Number of sample snippets: {len(check)}')

Number of positive relations missed: 0
Finished in 5.229 seconds
Number of sample snippets: 1977


In [21]:
# generate filename list for train, dev, and test sets
path_train = '../raw_data/EE/ee_train'
filenames_train = list({x[:4] for x in os.listdir(path_train) if x[0] != '.'})
print(f'Number of train files: {len(filenames_train)}')

path_dev = '../raw_data/EE/ee_dev'
filenames_dev = list({x[:4] for x in os.listdir(path_dev) if x[0] != '.'})
print(f'Number of dev files: {len(filenames_dev)}')

path_test = '../raw_data/EE/ee_test'
filenames_test = list({x[:4] for x in os.listdir(path_test) if x[0] != '.'})
print(f'Number of test files: {len(filenames_test)}')

path_test_ann = '../raw_data/EE/ee_test_ann'
filenames_test_ann = list({x[:4] for x in os.listdir(path_test_ann) if x[0] != '.'})
print(f'Number of test .ann files: {len(filenames_test_ann)}')

Number of train files: 900
Number of dev files: 225
Number of test files: 9999
Number of test .ann files: 375


In [22]:
# check how many test .txt files match the .ann files
intersect = list(set(filenames_test) & set(filenames_test_ann))
len(intersect)

375

In [23]:
# generate train set
filepath_train = [f'{path_train}/{x}' for x in filenames_train]

output_train = '../data/sre_ner/sre_ner_train'
generate_re_files(filepath_train, output_train)

with io.open(f'{output_train}.csv', 'r', encoding='utf-8', errors='ignore') as sample:
    output = sample.readlines()
check = [1 for x in output]
print(f'Number of train snippets: {len(check)}')

Number of positive relations missed: -2
Finished in 186.369 seconds
Number of train snippets: 45805


In [24]:
# generate dev set
filepath_dev = [f'{path_dev}/{x}' for x in filenames_dev]

output_dev = '../data/sre_ner/sre_ner_dev'
generate_re_files(filepath_dev, output_dev)

with io.open(f'{output_dev}.csv', 'r', encoding='utf-8', errors='ignore') as sample:
    output = sample.readlines()
check = [1 for x in output]
print(f'Number of dev sentences: {len(check)}')

Number of positive relations missed: 0
Finished in 35.089 seconds
Number of dev sentences: 10673


In [25]:
# generate test set
filepath_test = [f'{path_test}/{x}' for x in intersect]

output_test = '../data/sre_ner/sre_ner_test'
generate_re_files(filepath_test, output_test, test=True)

with io.open(f'{output_test}.csv', 'r', encoding='utf-8', errors='ignore') as sample:
    output = sample.readlines()
check = [1 for x in output if x[:8]]
print(f'Number of test sentences: {len(check)}')

Number of positive relations missed: 25
Finished in 71.302 seconds
Number of test sentences: 18488


#### Subsampled data

In [26]:
# generate train set
filepath_train = [f'{path_train}/{x}' for x in filenames_train]

output_train = '../data/sre_ner/sre_ner_train_subsampled'
generate_re_files_subsampled(filepath_train, output_train)

with io.open(f'{output_train}.csv', 'r', encoding='utf-8', errors='ignore') as sample:
    output = sample.readlines()
check = [1 for x in output]
print(f'Number of train snippets: {len(check)}')

Number of positive relations missed: -2
Finished in 117.221 seconds
Number of train snippets: 27771


In [27]:
# generate dev set
filepath_dev = [f'{path_dev}/{x}' for x in filenames_dev]

output_dev = '../data/sre_ner/sre_ner_dev_subsampled'
generate_re_files_subsampled(filepath_dev, output_dev)

with io.open(f'{output_dev}.csv', 'r', encoding='utf-8', errors='ignore') as sample:
    output = sample.readlines()
check = [1 for x in output]
print(f'Number of dev sentences: {len(check)}')

Number of positive relations missed: 0
Finished in 22.501 seconds
Number of dev sentences: 6447
