### Convert predicted entities from the ground-truth format to the ATLOP format

Define the imports

In [1]:
import json
import re

In [2]:
import nltk
print('Nltk version: {}.'.format(nltk.__version__))

from nltk.tokenize import TreebankWordTokenizer as twt
from nltk.tokenize import WordPunctTokenizer as wpt
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
nltk.download('punkt_tab')

Nltk version: 3.9.1.


[nltk_data] Downloading package punkt_tab to
[nltk_data]     /home/nlp/ronke21/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

Define path to the prediction file

In [3]:
# Input JSON holding the predicted entities; it is opened for reading in the loading cell.
PATH_NER_PREDICTIONS = "../Predictions/NER/predicted_entities_eval_format.json"

Define output path

In [4]:
# Destination JSON file for the converted documents (passed to dump_to_json in the last cell).
PATH_OUTPUT_NER_PREDICTIONS = "../Train/RE/data/predicted_entities_atlop_format.json"

Load the input file into a dictionary variable 

In [5]:
# Parse the whole predictions file; later cells iterate the result with .items().
with open(PATH_NER_PREDICTIONS, mode='r', encoding='utf-8') as predictions_file:
    ner_predictions = json.load(predictions_file)

Tokenize the articles

In [6]:
# Two-character punctuation clusters that the word/punct tokenizer returns as ONE token;
# each is split into two single-character tokens.
ILLEGAL_WORDS_2 = [').', '(<', '>)', '),', '.,', '].', '],', '.:', '>.', '>,', '))', '+)', '>-', '</', '[<', '-,', '.)', '™,', ')-', '™)', '+.']
# Three-character punctuation clusters that the tokenizer returns as ONE token;
# each is split into three single-character tokens.
ILLEGAL_WORDS_3 = ['.),', '.].', '>),', '.).', '>).', ')),', '>)-', '.</']


def _split_illegal_tokens(text: str, spans) -> list:
    """Turn character spans of `text` into (token_text, start, end) triples, splitting listed punctuation clusters."""
    tokens = []
    for start, end in spans:
        word = text[start:end]
        if word in ILLEGAL_WORDS_2:
            # Two characters -> cut once, before the last character.
            cut_points = (start, end - 1, end)
        elif word in ILLEGAL_WORDS_3:
            # Three characters -> cut after the first and before the last character.
            cut_points = (start, start + 1, end - 1, end)
        else:
            cut_points = (start, end)
        # Consecutive cut points delimit the emitted tokens; `end` stays exclusive.
        for left, right in zip(cut_points, cut_points[1:]):
            tokens.append((text[left:right], left, right))
    return tokens


def tokenize_docs(data: dict, data_name: str, tokenizer=None):
    """
    Tokenize the title and abstract of every article, in place.

    For each article in `data`, reads article['metadata']['title'] and
    article['metadata']['abstract'] and stores two new lists on the article:
      - article['tokenized_title']
      - article['tokenized_abstract']
    Each list holds (token_text, start_offset, end_offset) tuples, where the offsets
    index into the original string and end_offset is exclusive. Tokens listed in
    ILLEGAL_WORDS_2 / ILLEGAL_WORDS_3 are split into single-character tokens.

    Args:
        data: mapping of article id -> article dict (must contain 'metadata').
        data_name: label used only in the progress message.
        tokenizer: optional object exposing span_tokenize(text) -> iterable of
            (start, end) pairs. Defaults to a WordPunctTokenizer (imported as `wpt`).

    Returns:
        None; `data` is modified in place.
    """
    print(f"Tokenizing articles in set {data_name}...")

    if tokenizer is None:
        # One instance is enough for the whole run; the original built two per article.
        tokenizer = wpt()

    for pmid, article in data.items():
        title = article['metadata']['title']
        abstract = article['metadata']['abstract']

        title_spans = list(tokenizer.span_tokenize(title))
        abstract_spans = list(tokenizer.span_tokenize(abstract))

        article['tokenized_title'] = _split_illegal_tokens(title, title_spans)
        article['tokenized_abstract'] = _split_illegal_tokens(abstract, abstract_spans)

In [7]:
# Mutates ner_predictions in place: every article gains 'tokenized_title' and 'tokenized_abstract'.
tokenize_docs(ner_predictions, "ner_predictions")

Tokenizing articles in set ner_predictions...


Map annotated entities to tokens 

In [8]:
def map_entities_to_tokens(data: dict, data_name: str):
    """
    Attach token indices to every entity, in place.

    For each entity, the character offsets entity['start_idx'] (first character) and
    entity['end_idx'] (last character, inclusive) are matched against the token list
    of the entity's location ('tokenized_title' or 'tokenized_abstract'). On success
    the entity gains:
      - 'start_token': index of the token whose start offset equals start_idx
      - 'end_token':   index of the token whose (exclusive) end offset - 1 equals end_idx

    Args:
        data: mapping of article id -> article dict with 'entities' and the token lists.
        data_name: label used only in the progress message.

    Raises:
        Exception: if an entity has a location other than 'title'/'abstract', or if
            its offsets do not coincide with token boundaries (both token lists of
            the article are printed first to help debugging).
    """
    print(f"Mapping entities to tokens in set {data_name}...")

    for pmid, article in data.items():
        # location -> (start offset -> token idx, inclusive end offset -> token idx).
        # Built lazily, once per location, instead of rescanning all tokens per entity.
        boundary_lookups = {}

        for entity in article['entities']:
            location = entity['location']
            start = entity['start_idx']
            end = entity['end_idx']

            if location == 'title':
                token_field = 'tokenized_title'
            elif location == 'abstract':
                token_field = 'tokenized_abstract'
            else:
                raise Exception(f'{pmid} - Unrecognized Location: {location}')

            if token_field not in boundary_lookups:
                token_by_start = {}
                token_by_end = {}
                for idx, token in enumerate(article[token_field]):
                    # Later tokens overwrite earlier ones, i.e. the last match wins,
                    # exactly like a full linear scan without early exit.
                    token_by_start[token[1]] = idx
                    token_by_end[token[2] - 1] = idx
                boundary_lookups[token_field] = (token_by_start, token_by_end)
            token_by_start, token_by_end = boundary_lookups[token_field]

            start_token = token_by_start.get(start)
            end_token = token_by_end.get(end)

            if start_token is not None and end_token is not None:
                entity['start_token'] = start_token
                entity['end_token'] = end_token
            else:
                print(article['tokenized_title'])
                print(article['tokenized_abstract'])
                raise Exception(f'{pmid} - Not able to assign token(s) to entity: {entity}')

In [9]:
# Mutates ner_predictions in place: every entity gains 'start_token' and 'end_token';
# raises if an entity's offsets do not line up with token boundaries.
map_entities_to_tokens(ner_predictions, "ner_predictions")

Mapping entities to tokens in set ner_predictions...


Get the start and end indices of each article's sentences

In [10]:
# Abbreviations after which the sentence tokenizer must not split.
# NOTE(review): Punkt appears to compare abbreviations against a single token with its
# final period removed, so the entries containing periods, brackets or spaces may never
# match - confirm against the NLTK Punkt documentation before relying on them.
extra_abbrevs = {'etc', 'etc.', 'etc.)', '<i>L', 'sp', 'subsp', '<i>A', '(<i>Hippophae rhamnoides</i> L.)', 'Rupr'}
punkt_param = PunktParameters()
# Abbreviations are registered in lowercase.
punkt_param.abbrev_types.update(abbreviation.lower() for abbreviation in extra_abbrevs)
sentence_splitter = PunktSentenceTokenizer(punkt_param)

In [11]:
def _merge_flagged_spans(sentences: list, merge_next: list) -> list:
    """Collapse every run of sentences flagged in `merge_next` into a single (start, end) span."""
    merged = []
    i = 0
    while i < len(sentences):
        span_start = sentences[i][0]
        span_end = sentences[i][1]
        # Keep absorbing the following sentence while the current one is flagged,
        # so chains of flagged sentences end up in one span.
        while i < len(sentences) - 1 and merge_next[i]:
            i += 1
            span_end = sentences[i][1]
        merged.append((span_start, span_end))
        i += 1
    return merged


def get_sentence_spans(data: dict, data_name: str, splitter=None):
    """
    Compute the sentence spans of every abstract and store them in article['sentences'].

    The abstract is split into (start, end) character spans. If an abstract entity
    starts in one sentence and ends in the next one, those two sentences are merged
    so that no entity straddles a sentence boundary.

    Args:
        data: mapping of article id -> article dict with 'metadata' (title, abstract)
            and 'entities' (each with 'location', 'start_idx', 'end_idx').
        data_name: label used only in the progress message.
        splitter: optional object exposing span_tokenize(text) -> iterable of
            (start, end) pairs. Defaults to the module-level `sentence_splitter`.

    Returns:
        None; each article gains a 'sentences' list of (start, end) tuples.

    Raises:
        Exception: if a title entity ends beyond the last character of the title, if
            an abstract entity's start or end falls in no sentence, or if an entity
            spans two non-consecutive sentences.
    """
    print(f"Getting sentence spans in set {data_name}...")

    if splitter is None:
        splitter = sentence_splitter

    for pmid, article in data.items():
        title = article['metadata']['title']
        abstract = article['metadata']['abstract']

        # Materialize the spans: they are scanned once per entity.
        sentences = list(splitter.span_tokenize(abstract))

        # merge_next[i] is True when sentence i must be merged with sentence i + 1.
        merge_next = [False] * len(sentences)

        for entity in article['entities']:
            location = entity['location']
            start = entity['start_idx']
            end = entity['end_idx']

            if location == 'title':
                # end_idx is the index of the entity's LAST character (it is matched
                # against token_end - 1 when mapping entities to tokens), so the
                # largest legal value is len(title) - 1.
                if end >= len(title):
                    raise Exception(f'{pmid} - Found title entity having illegal end index: {entity}')
                continue

            if location != 'abstract':
                # Other locations carry no sentence information here.
                continue

            start_sentence = None
            end_sentence = None
            for idx, s in enumerate(sentences):
                # Both span boundaries are treated as inclusive; the first hit wins.
                if start_sentence is None and s[0] <= start <= s[1]:
                    start_sentence = idx
                if end_sentence is None and s[0] <= end <= s[1]:
                    end_sentence = idx

            if start_sentence is None:
                raise Exception(f'{pmid} - Start sentence not assigned for entity: {entity}')
            if end_sentence is None:
                raise Exception(f'{pmid} - End sentence not assigned for entity: {entity}')

            if start_sentence != end_sentence:
                if end_sentence - start_sentence == 1:
                    merge_next[start_sentence] = True
                else:
                    raise Exception(f'{pmid} - Entity assigned to two non-consecutive sentences ({start_sentence}, {end_sentence}): {entity}')

        article['sentences'] = _merge_flagged_spans(sentences, merge_next)

In [12]:
# Mutates ner_predictions in place: every article gains 'sentences', a list of (start, end) spans.
get_sentence_spans(ner_predictions, "ner_predictions")

Getting sentence spans in set ner_predictions...


Check if sentence spans have been computed correctly.

In [13]:
def check_sentence_spans(data: dict, data_name: str):
    """
    Validate that every abstract entity lies inside exactly one stored sentence span.

    Reads article['sentences'] and article['entities']; nothing is modified.
    Entities whose location is not 'abstract' are skipped.

    Raises:
        Exception: if the start or end offset of an abstract entity falls in no
            sentence span, or if the two offsets fall in different spans.
    """
    print(f"Checking sentence spans in set {data_name}...")

    def first_span_containing(char_idx, spans):
        # Index of the first span whose (inclusive) bounds enclose char_idx, else None.
        for position, span in enumerate(spans):
            if span[0] <= char_idx <= span[1]:
                return position
        return None

    for pmid, article in data.items():
        for entity in article['entities']:
            location = entity['location']
            start = entity['start_idx']
            end = entity['end_idx']

            # Title entities (and any other location) have no sentence span to check.
            if location != 'abstract':
                continue

            merged_spans = article['sentences']
            start_sentence = first_span_containing(start, merged_spans)
            end_sentence = first_span_containing(end, merged_spans)

            if start_sentence is None:
                raise Exception(f'{pmid} - Start sentence not assigned for entity: {entity}')
            if end_sentence is None:
                raise Exception(f'{pmid} - End sentence not assigned for entity: {entity}')

            # After merging, an entity must no longer cross a sentence boundary.
            if start_sentence != end_sentence:
                raise Exception(f'{pmid} - Entity assigned to two different sentences ({start_sentence}, {end_sentence}): {entity}')


In [14]:
# Sanity check only: raises if an abstract entity still crosses a sentence span; nothing is modified.
check_sentence_spans(ner_predictions, "ner_predictions")

Checking sentence spans in set ner_predictions...


Map tokens to the sentence in which they are located

In [15]:
def map_tokens_to_sentences(data: dict, data_name: str):
    """
    Record, for every abstract token, the index of the sentence that contains it.

    Inputs read from each article:
      - article['tokenized_abstract']: list of (token_text, start_offset, end_offset)
      - article['sentences']: list of (start, end) sentence spans

    Output written to each article:
      - article['tokens_to_sentences_map']: {token index -> sentence index}, e.g.
        {0: 0, 1: 0, 2: 1} when the first two tokens are in sentence 0 and the
        third one is in sentence 1.

    A token belongs to the first sentence span that fully encloses it.

    Raises:
        Exception: if either input field is missing, or if a token is enclosed by
            none of the sentence spans.
    """
    print(f"Mapping tokens to sentences in set {data_name}...")

    for pmid, article in data.items():
        abstract_tokens = article.get('tokenized_abstract')
        sentence_spans = article.get('sentences')

        if abstract_tokens is None:
            raise Exception(f"Article {pmid} is missing 'tokenized abstract'.")
        if sentence_spans is None:
            raise Exception(f"Article {pmid} is missing 'sentences'. Make sure to run get_sentence_spans first.")

        sentence_of_token = {}

        for position, (text, tok_start, tok_end) in enumerate(abstract_tokens):
            for sentence_no, (span_start, span_end) in enumerate(sentence_spans):
                if span_start <= tok_start and tok_end <= span_end:
                    sentence_of_token[position] = sentence_no
                    break
            else:
                # The inner loop finished without a break: no span encloses this token.
                raise Exception(
                    f"Token '{text}' (index {position}, offsets {tok_start}-{tok_end}) "
                    f"in article {pmid} does not fall within any sentence span: {sentence_spans}"
                )

        article['tokens_to_sentences_map'] = sentence_of_token

In [16]:
# Mutates ner_predictions in place: every article gains 'tokens_to_sentences_map' ({token index: sentence index}).
map_tokens_to_sentences(ner_predictions, "ner_predictions")

Mapping tokens to sentences in set ner_predictions...


Map entities to the token positions within each sentence containing them

In [17]:
def map_entities_to_tokens_within_sentences(data: dict, data_name: str) -> dict:
    """
    Express each abstract entity's token positions relative to its own sentence.

    Every entity with location 'abstract' gains three fields:
      - 'located_in_sentence': index of the sentence holding the entity's tokens
      - 'start_token_in_sentence': position of the entity's first token in that sentence
      - 'end_token_in_sentence': position of the entity's last token in that sentence

    Fields read from each article:
      - article['tokenized_abstract']: list of (token_text, start_offset, end_offset)
      - article['tokens_to_sentences_map']: {token index -> sentence index}
      - article['sentences']: list of (start, end) spans (only checked for presence)
      - entity['start_token'] / entity['end_token']: abstract-level token indices

    The function modifies `data` in place and does not return a value.

    Raises:
        Exception: if a required field is missing, a token has no sentence, or an
            entity's first and last token lie in different sentences.
    """
    print(f"Mapping entities to tokens within sentences in set {data_name}...")

    for pmid, article in data.items():
        abstract_tokens = article.get('tokenized_abstract')
        sentence_of_token = article.get('tokens_to_sentences_map')
        sentence_spans = article.get('sentences')

        if abstract_tokens is None:
            raise Exception(f"Article {pmid} is missing 'tokenized abstract'.")
        if sentence_of_token is None:
            raise Exception(f"Article {pmid} is missing 'tokens_to_sentences_map'. Run map_tokens_to_sentences first.")
        if sentence_spans is None:
            raise Exception(f"Article {pmid} is missing 'sentences'. Run get_sentence_spans first.")

        # sentence index -> ordered list of the abstract-level token indices it contains.
        members_by_sentence = {}
        for tok_idx, tok in enumerate(abstract_tokens):
            owner = sentence_of_token.get(tok_idx)
            if owner is None:
                raise Exception(
                    f"In article {pmid}, token index {tok_idx} is not mapped to any sentence. Tokens: {tok}"
                )
            if owner not in members_by_sentence:
                members_by_sentence[owner] = []
            members_by_sentence[owner].append(tok_idx)

        for entity in article.get('entities', []):
            # Title entities keep their document-level token indices.
            if entity.get('location') != 'abstract':
                continue

            first_tok = entity.get('start_token')
            last_tok = entity.get('end_token')
            if first_tok is None or last_tok is None:
                raise Exception(
                    f"Entity in article {pmid} is missing start_token or end_token: {entity}"
                )

            sentence_of_first = sentence_of_token.get(first_tok)
            sentence_of_last = sentence_of_token.get(last_tok)
            if sentence_of_first is None or sentence_of_last is None:
                raise Exception(
                    f"Entity in article {pmid} has tokens not mapped to any sentence: {entity}"
                )
            if sentence_of_first != sentence_of_last:
                raise Exception(
                    f"Entity in article {pmid} spans multiple sentences (start in {sentence_of_first}, end in {sentence_of_last}): {entity}"
                )

            members = members_by_sentence.get(sentence_of_first)
            if members is None:
                raise Exception(
                    f"Sentence {sentence_of_first} not found in helper mapping for article {pmid}."
                )

            if first_tok not in members:
                raise Exception(
                    f"Entity start token {first_tok} not found in sentence tokens {members} for article {pmid}."
                )
            if last_tok not in members:
                raise Exception(
                    f"Entity end token {last_tok} not found in sentence tokens {members} for article {pmid}."
                )

            # The offset inside the sentence is the position within its member list.
            entity['located_in_sentence'] = sentence_of_first
            entity['start_token_in_sentence'] = members.index(first_tok)
            entity['end_token_in_sentence'] = members.index(last_tok)

In [18]:
# Mutates ner_predictions in place: every abstract entity gains 'located_in_sentence',
# 'start_token_in_sentence' and 'end_token_in_sentence'.
map_entities_to_tokens_within_sentences(ner_predictions, "ner_predictions")

Mapping entities to tokens within sentences in set ner_predictions...


Convert processed annotations to the DocRED format used by ATLOP for finetuning

In [19]:
def convert_to_docred_format(data: dict, data_name: str, is_test=False) -> list:
    """
    Build one DocRED-style document dictionary per article.

    Each produced dictionary has three keys, in this order:
      - "vertexSet": one single-mention list per entity. A mention holds
            "pos":     [first token position, last token position + 1] inside its sentence,
            "type":    the entity label in upper case,
            "sent_id": 0 for title entities, abstract sentence index + 1 otherwise,
            "name":    the entity's 'text_span'.
      - "title": the article id and the title joined as '<pmid>||<title>'.
      - "sents": list of token-text lists; entry 0 is the tokenized title and the
                 following entries are the abstract sentences, in order.

    Abstract sentences are rebuilt from article['sentences'] (list of (start, end) spans)
    and article['tokenized_abstract'] (list of (token_text, start, end)): a token goes
    into a sentence when it lies completely inside that sentence's span.

    Args:
        data: mapping of article id -> processed article dict.
        data_name: label used only in the progress message.
        is_test: accepted for interface compatibility; not read by this function.

    Returns:
        A list with one DocRED-formatted dict per article; `data` is not modified.
    """
    print(f"Converting articles to DocRED format for set {data_name}...")

    converted = []
    for pmid, article in data.items():
        # --- vertexSet: one mention per entity, kept in the order of article['entities'] ---
        vertex_set = []
        for entity in article.get('entities', []):
            if entity.get('location') == 'title':
                # The title is sentence 0 of the document.
                sent_id = 0
            else:
                # Abstract sentences are numbered from 0 upstream; shift by one for the title.
                sent_id = entity.get('located_in_sentence', 0) + 1

            if entity['location'] == 'title':
                first_key, last_key = 'start_token', 'end_token'
            else:
                first_key, last_key = 'start_token_in_sentence', 'end_token_in_sentence'
            # The stored end token is inclusive; the emitted end position is exclusive.
            pos = [entity.get(first_key), entity.get(last_key) + 1]

            vertex_set.append([{
                "pos": pos,
                "type": entity.get("label").upper(),
                "sent_id": sent_id,
                "name": entity.get("text_span")
            }])

        # --- sents: title tokens first, then one token list per abstract sentence ---
        sents = [[tok[0] for tok in article.get("tokenized_title", [])]]
        abstract_tokens = article.get("tokenized_abstract", [])
        for s_start, s_end in article.get("sentences", []):
            sents.append([
                text
                for text, t_start, t_end in abstract_tokens
                if s_start <= t_start and t_end <= s_end
            ])

        # --- title string: article id and title separated by '||' ---
        doc_title = f'{pmid}||{article["metadata"]["title"]}'

        converted.append({
            "vertexSet": vertex_set,
            "title": doc_title,
            "sents": sents
        })

    return converted


In [20]:
# Builds a new list of DocRED-style dicts (one per article); ner_predictions is only read here.
docred_ner_predictions = convert_to_docred_format(ner_predictions, "ner_predictions")

Converting articles to DocRED format for set ner_predictions...


Dump the list of DocRED-formatted documents to a JSON file

In [21]:
def dump_to_json(docred_dict, output_file_path):
    """
    Serialize `docred_dict` as JSON and write it to `output_file_path` (UTF-8).

    Non-ASCII characters are written as-is rather than as \\uXXXX escapes.
    The object is serialized completely before the file is opened, so a
    serialization error does not truncate an existing output file.
    """
    serialized = json.dumps(docred_dict, ensure_ascii=False)
    with open(output_file_path, mode='w', encoding='utf-8') as out_handle:
        out_handle.write(serialized)

# Write the converted documents to PATH_OUTPUT_NER_PREDICTIONS as UTF-8 JSON.
dump_to_json(docred_ner_predictions, PATH_OUTPUT_NER_PREDICTIONS)