### Convert annotations to the format used by ATLOP for finetuning.

Define the imports

In [1]:
import json
import re

In [2]:
import nltk
print('Nltk version: {}.'.format(nltk.__version__))

from nltk.tokenize import TreebankWordTokenizer as twt
from nltk.tokenize import WordPunctTokenizer as wpt
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
nltk.download('punkt_tab')

Nltk version: 3.9.1.


[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\samue\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

Define paths to the annotation files

In [3]:
PATH_PLATINUM_TRAIN = r"C:\Users\samue\OneDrive\Desktop\ThesisPiron\train_platinum.json"
#PATH_GOLD_TRAIN = "../Annotations/Train/gold_quality/json_format/train_gold.json"
#PATH_SILVER_TRAIN = "../Annotations/Train/silver_quality/json_format/train_silver.json"
#PATH_BRONZE_TRAIN = "../Annotations/Train/bronze_quality/json_format/train_bronze.json"
#PATH_DEV = "../Annotations/Dev/json_format/dev.json"

Define output paths

In [4]:
PATH_OUTPUT_PLATINUM_TRAIN = r"C:\Users\samue\OneDrive\Desktop\ThesisPiron\train_platinumv2.json"
#PATH_OUTPUT_GOLD_TRAIN = "../Train/RE/data/train_gold.json"
#PATH_OUTPUT_SILVER_TRAIN = "../Train/RE/data/train_silver.json"
#PATH_OUTPUT_BRONZE_TRAIN = "../Train/RE/data/train_bronze.json"
#PATH_OUTPUT_DEV = "../Train/RE/data/dev.json"

Load the input files into dictionary variables

In [5]:
with open(PATH_PLATINUM_TRAIN, 'r', encoding='utf-8') as file:
	train_platinum = json.load(file)


Remove articles having less that 5 entities annotated

In [6]:
def get_articles_with_less_than_5_entities(data: dict):
    pmid_list = []
    for pmid, article in data.items():
        entities = article['entities']
        if len(entities) < 5:
            pmid_list.append(pmid)
    return pmid_list

def remove_articles_with_less_than_5_entities(data: dict, data_name: str):
    pmid_list = get_articles_with_less_than_5_entities(data)
    for pmid in pmid_list:
        del data[pmid]
    print(f'{data_name} - {len(pmid_list)} articles removed.')

In [7]:
remove_articles_with_less_than_5_entities(train_platinum, "train_platinum")

train_platinum - 0 articles removed.


Remove articles without relations annotated

In [8]:
def remove_articles_without_relations(data: dict, data_name: str):
    pmid_list = []
    for pmid, article in data.items():
        if len(article['relations']) == 0:
            pmid_list.append(pmid)

    for pmid in pmid_list:
        del data[pmid]
    print(f'{data_name} - {len(pmid_list)} articles removed.')

In [9]:
#remove_articles_without_relations(train_platinum, "train_platinum")

Tokenize the articles

In [32]:
# Define the tokens of length 2 that are not captured by the tokenizer
ILLEGAL_WORDS_2 = [').', '(<', '>)', '),', '.,', '].', '],', '.:', '>.', '>,', '))', '+)', '>-', '</', '[<', '-,', '.)', 'â„¢,']
# Define the tokens of length 3 that are not captured by the tokenizer
ILLEGAL_WORDS_3 = ['.),', '.].', '>),', '.).', '>).']

def tokenize_docs(data: dict, data_name: str):
	print(f"Tokenizing articles in set {data_name}...")

	for pmid, article in data.items():
		title = article['metadata']['title']
		abstract = article['metadata']['abstract']

		title_spans = list(wpt().span_tokenize(title))
		abstract_spans = list(wpt().span_tokenize(abstract))
		
		article['tokenized_title'] = []
		for start, end in title_spans:
			word = title[start:end]
			if word in ILLEGAL_WORDS_2:
				word1 = title[start:end-1]
				word2 = title[end-1:end]
				article['tokenized_title'].append((word1, start, end-1))
				article['tokenized_title'].append((word2, end-1, end))
			elif word in ILLEGAL_WORDS_3:
				word1 = title[start:start+1]
				word2 = title[start+1:end-1]
				word3 = title[end-1:end]
				article['tokenized_title'].append((word1, start, start+1))
				article['tokenized_title'].append((word2, start+1, end-1))
				article['tokenized_title'].append((word3, end-1, end))
			else:	
				article['tokenized_title'].append((word, start, end))
		
		article['tokenized_abstract'] = []
		for start, end in abstract_spans:
			word = abstract[start:end]
			if word in ILLEGAL_WORDS_2:
				word1 = abstract[start:end-1]
				word2 = abstract[end-1:end]
				article['tokenized_abstract'].append((word1, start, end-1))
				article['tokenized_abstract'].append((word2, end-1, end))
			elif word in ILLEGAL_WORDS_3:
				word1 = abstract[start:start+1]
				word2 = abstract[start+1:end-1]
				word3 = abstract[end-1:end]
				article['tokenized_abstract'].append((word1, start, start+1))
				article['tokenized_abstract'].append((word2, start+1, end-1))
				article['tokenized_abstract'].append((word3, end-1, end))
			else:
				article['tokenized_abstract'].append((word, start, end))

In [33]:
tokenize_docs(train_platinum, "train_platinum")

Tokenizing articles in set train_platinum...


Adjust wrong annotations (i.e., annotations including partially annotated words) from the silver collection 

In [12]:
# Each PMIDs maps to a dict of {annotation_with_wrong_text_span: correct_text_span}
PARTIAL_WORDS = {
    '35275534': {
        (1395, 1412, 'abstract', 'Ruminococcusgnavus'): 'Ruminococcusgnavusgroup'  
    },
    '38963982': {
        (92, 94, 'title', 'TAM'): 'TAMs',
        (435, 451, 'abstract', 'Intestinal tissue'): 'Intestinal tissues'
    },
    '38959280': {
        (74, 76, 'abstract', 'SGM'): 'SGMs',
        (695, 697, 'abstract', 'SGM'): 'SGMs',
        (266, 287, 'abstract', 'cisgender heterosexual'): 'cisgender heterosexuals',
        (713, 734, 'abstract', 'cisgender-heterosexual'): 'cisgender-heterosexuals' 
    },
    '38968876': {
        (764, 770, 'abstract', 'patient'): 'patients'
    },
    '38892525': {
        (1397, 1407, 'abstract', 'IBS symptom'): 'IBS symptoms'
    }
}

def fix_wrong_annotations(data: dict, data_name: str):
    print(f'Fixing annotations in set {data_name}...')
    
    for pmid in list(data.keys()):
        pmid_str = str(pmid)
        if pmid_str in PARTIAL_WORDS:
            replacements = PARTIAL_WORDS[pmid_str]
            
            # Fix entities:
            for entity in data[pmid]['entities']:
                wrong_entry = (entity['start_idx'], entity['end_idx'], entity['location'], entity['text_span'])
                if wrong_entry in replacements:
                    correct_text = replacements[wrong_entry]
                    entity['text_span'] = correct_text
                    # Update end index based on the new text length.
                    entity['end_idx'] = entity['start_idx'] + len(correct_text) - 1
                    print(f"Fixed entity in pmid {pmid}: '{wrong_entry[3]}' -> '{correct_text}'")
            
            # Fix relations:
            for relation in data[pmid]['relations']:
                # Check and fix subject if needed.
                wrong_entry = (relation['subject_start_idx'], relation['subject_end_idx'], relation['subject_location'], relation['subject_text_span'])
                if wrong_entry in replacements:
                    correct_text = replacements[wrong_entry]
                    relation['subject_text_span'] = correct_text
                    relation['subject_end_idx'] = relation['subject_start_idx'] + len(correct_text) - 1
                    print(f"Fixed subject in pmid {pmid}: '{wrong_entry[3]}' -> '{correct_text}'")
                    
                # Check and fix object if needed.
                wrong_entry = (relation['object_start_idx'], relation['object_end_idx'], relation['object_location'], relation['object_text_span'])
                if wrong_entry in replacements:
                    correct_text = replacements[wrong_entry]
                    relation['object_text_span'] = correct_text
                    relation['object_end_idx'] = relation['object_start_idx'] + len(correct_text) - 1
                    print(f"Fixed object in pmid {pmid}: '{wrong_entry[3]}' -> '{correct_text}'")

In [13]:
#fix_wrong_annotations(train_silver, 'train_silver')

Map annotated entities to tokens 

In [34]:
def map_entities_to_tokens(data: dict, data_name: str):
	print(f"Mapping entities to tokens in set {data_name}...")
	
	for pmid, article in data.items():
		for entity in article['entities']:
			location = entity['location']
			start = entity['start_idx']
			end = entity['end_idx']
			start_token = None
			end_token = None
			if location == 'title':
				for idx, token in enumerate(article['tokenized_title']):
					if start == token[1] and start is not None:
						start_token = idx
					if end == token[2]-1 and end is not None:
						end_token = idx
			elif location == 'abstract':
				for idx, token in enumerate(article['tokenized_abstract']):
					if start == token[1] and start is not None:
						start_token = idx
					if end == token[2]-1 and end is not None:
						end_token = idx
			else:
				raise Exception(f'{pmid} - Unrecognized Location: {location}')
			if start_token is not None and end_token is not None:
				entity['start_token'] = start_token
				entity['end_token'] = end_token
			else:
				print (data[pmid]['tokenized_title'])
				print(data[pmid]['tokenized_abstract'])
				raise Exception(f'{pmid} - Not able to assign token(s) to entity: {entity}')

In [35]:
map_entities_to_tokens(train_platinum, "train_platinum")

Mapping entities to tokens in set train_platinum...


Get the start and end indices for articles sentences

In [36]:
# Define abbreviations to not be splitted by the sentence tokenizer
extra_abbrevs = {'etc', 'etc.', 'etc.)', '<i>L', 'sp', 'subsp', '<i>A', '(<i>Hippophae rhamnoides</i> L.)'}
punkt_param = PunktParameters()
for abbr in extra_abbrevs:
	punkt_param.abbrev_types.add(abbr)
sentence_splitter = PunktSentenceTokenizer(punkt_param)

def get_sentence_spans(data: dict, data_name: str):
    print(f"Getting sentence spans in set {data_name}...")
    
    for pmid, article in data.items():
        title = article['metadata']['title']
        abstract = article['metadata']['abstract']

        # Convert the generator to a list so we can iterate it repeatedly.
        sentences = list(sentence_splitter.span_tokenize(abstract))

        # Prepare a list of booleans that will flag whether the sentence at a given index should be merged with the next one. 
        # A sentence is merged with the following one if an entity spans across them.
        # Initially, no merge is flagged.
        merge_next = [False] * len(sentences)
        
        # Process each entity in the article.
        for entity in article['entities']:
            location = entity['location']
            start = entity['start_idx']
            end = entity['end_idx']

            # For title entities, we do nothing regarding sentence spans.
            if location == 'title':
                if end > len(title):
                    raise Exception(f'{pmid} - Found title entity having illegal end index: {entity}')
                continue

            # Only process abstract entities.
            if location == 'abstract':
                start_sentence = None
                end_sentence = None

                # Iterate over the original sentence spans to determine in which sentences the entity start and end fall.
                for idx, s in enumerate(sentences):
                    # Using >= and <= to include boundaries.
                    if start >= s[0] and start <= s[1] and start_sentence is None:
                        start_sentence = idx
                        #print(f'Start sentence assigned: {idx}')
                    if end >= s[0] and end <= s[1] and end_sentence is None:
                        end_sentence = idx
                        #print(f'End sentence assigned: {idx}')

                if start_sentence is None:
                    raise Exception(f'{pmid} - Start sentence not assigned for entity: {entity}')
                if end_sentence is None:
                    raise Exception(f'{pmid} - End sentence not assigned for entity: {entity}')
                
                # If the entity falls in two different sentences, check if they are consecutive.
                if start_sentence != end_sentence:
                    if end_sentence - start_sentence == 1:
                        # Mark that sentence 'start_sentence' should be merged with its following sentence.
                        merge_next[start_sentence] = True
                        #print(f'{pmid} - Marking merge for sentences {start_sentence} and {end_sentence} due to entity: {entity}')
                    else:
                        raise Exception(f'{pmid} - Entity assigned to two non-consecutive sentences ({start_sentence}, {end_sentence}): {entity}')
        
        # At this point, we have a merge flag for each sentence that should be merged with its next one.
        # Now we build the updated list of sentence spans, merging as flagged.
        new_spans = []
        i = 0
        while i < len(sentences):
            start_val = sentences[i][0]
            end_val = sentences[i][1]
            # While the current sentence is flagged to merge with the next one, update the end_val.
            while i < len(sentences) - 1 and merge_next[i]:
                i += 1
                end_val = sentences[i][1]
            new_spans.append((start_val, end_val))
            i += 1

        # Add the updated list of sentence spans to the article dictionary.
        article['sentences'] = new_spans

In [37]:
get_sentence_spans(train_platinum, "train_platinum")

Getting sentence spans in set train_platinum...


Check if sentence spans have been computed correctly.

In [38]:
def check_sentence_spans(data: dict, data_name: str):
    print(f"Checking sentence spans in set {data_name}...")

    for pmid, article in data.items():
        # Process each entity in the article.
        for entity in article['entities']:
            location = entity['location']
            start = entity['start_idx']
            end = entity['end_idx']
            # For title entities, we do nothing regarding sentence spans.
            if location == 'title':
                continue

            # Only process abstract entities.
            if location == 'abstract':
                start_sentence = None
                end_sentence = None

                # Iterate over the original sentence spans to determine in which sentences the entity start and end fall.
                for idx, s in enumerate(article['sentences']):
                    # Using >= and <= to include boundaries.
                    if start >= s[0] and start <= s[1] and start_sentence is None:
                        start_sentence = idx
                        #print(f'Start sentence assigned: {idx}')
                    if end >= s[0] and end <= s[1] and end_sentence is None:
                        end_sentence = idx
                        #print(f'End sentence assigned: {idx}')

                if start_sentence is None:
                    raise Exception(f'{pmid} - Start sentence not assigned for entity: {entity}')
                if end_sentence is None:
                    raise Exception(f'{pmid} - End sentence not assigned for entity: {entity}')
                
                # If the entity falls in two different sentences, raise Exception.
                if start_sentence != end_sentence:
                      raise Exception(f'{pmid} - Entity assigned to two different sentences ({start_sentence}, {end_sentence}): {entity}')


In [39]:
check_sentence_spans(train_platinum, "train_platinum")

Checking sentence spans in set train_platinum...


Map tokens to the sentence in which they are located

In [40]:
def map_tokens_to_sentences(data: dict, data_name: str):
    """
    For each article, map tokens in the 'tokenized abstract' to the sentence in which they are located.
    Uses the 'sentences' field in the article, which is assumed to be a list of (start, end) tuples.
    
    The mapping is stored as a dictionary where the key is the token index (its position in the tokenized abstract)
    and the value is the sentence index. For example, if the first token belongs to sentence 0 and the third token
    belongs to sentence 1, the mapping will include entries {0: 0, 2: 1}.
    
    Raises an Exception if a token does not fall within any of the sentence spans.
    """
    print(f"Mapping tokens to sentences in set {data_name}...")

    for pmid, article in data.items():
        # Retrieve the tokenized abstract and the sentence spans.
        tokens = article.get('tokenized_abstract')
        sentences = article.get('sentences')
        
        if tokens is None:
            raise Exception(f"Article {pmid} is missing 'tokenized abstract'.")
        if sentences is None:
            raise Exception(f"Article {pmid} is missing 'sentences'. Make sure to run get_sentence_spans first.")
        
        token_to_sentence = {}
        
        # Iterate over each token and determine which sentence it belongs to.
        for token_index, token_entry in enumerate(tokens):
            # Each token_entry is assumed to be a tuple: (token_text, start_offset, end_offset)
            token_text, token_start, token_end = token_entry
            assigned_sentence = None
            
            # Check each sentence span to see if the token falls within it.
            for sentence_index, (sent_start, sent_end) in enumerate(sentences):
                # We assume a token belongs to a sentence if its start is >= sentence start and its end is <= sentence end.
                if token_start >= sent_start and token_end <= sent_end:
                    assigned_sentence = sentence_index
                    break  # Stop once we find the sentence that contains the token.
            
            if assigned_sentence is None:
                raise Exception(
                    f"Token '{token_text}' (index {token_index}, offsets {token_start}-{token_end}) "
                    f"in article {pmid} does not fall within any sentence span: {sentences}"
                )
            
            token_to_sentence[token_index] = assigned_sentence
        
        # Add the mapping to the article dictionary.
        article['tokens_to_sentences_map'] = token_to_sentence

In [41]:
map_tokens_to_sentences(train_platinum, "train_platinum")

Mapping tokens to sentences in set train_platinum...


Map entities to the token positions within each sentence containing them

In [42]:
def map_entities_to_tokens_within_sentences(data: dict, data_name: str) -> dict:
    """
    For each article, this function maps each entity (assumed to be in the abstract)
    to the token positions within the sentence that contains it.
    
    For each entity in article['entities'] (with location 'abstract'), it adds:
      - 'located_in_sentence': the sentence index in which the entity's tokens are located,
      - 'start_token_in_sentence': the position of the entity's start token within that sentence,
      - 'end_token_in_sentence': the position of the entity's end token within that sentence.
    
    This function relies on:
      - article['tokenized abstract']: a list of tokens of the form (token_text, start_offset, end_offset)
      - article['tokens_to_sentences_map']: a mapping { token_index -> sentence_index }
      - article['sentences']: a list of (start, end) sentence spans for the abstract.
    """
    print(f"Mapping entities to tokens within sentences in set {data_name}...")
    
    for pmid, article in data.items():
        # Retrieve required fields.
        tokens = article.get('tokenized_abstract')
        token_to_sentence = article.get('tokens_to_sentences_map')
        sentences = article.get('sentences')
        
        if tokens is None:
            raise Exception(f"Article {pmid} is missing 'tokenized abstract'.")
        if token_to_sentence is None:
            raise Exception(f"Article {pmid} is missing 'tokens_to_sentences_map'. Run map_tokens_to_sentences first.")
        if sentences is None:
            raise Exception(f"Article {pmid} is missing 'sentences'. Run get_sentence_spans first.")
        
        # Build a helper mapping: for each sentence index, list the token indices that fall into that sentence.
        sentence_to_token_indices = {}
        for token_index in range(len(tokens)):
            sent_idx = token_to_sentence.get(token_index)
            if sent_idx is None:
                raise Exception(
                    f"In article {pmid}, token index {token_index} is not mapped to any sentence. Tokens: {tokens[token_index]}"
                )
            sentence_to_token_indices.setdefault(sent_idx, []).append(token_index)
        
        # Now process each entity.
        for entity in article.get('entities', []):
            if entity.get('location') != 'abstract': # Only process entities in the abstract.
                continue
            
            # Retrieve the token indices for this entity.
            entity_start_token = entity.get('start_token')
            entity_end_token = entity.get('end_token')
            
            if entity_start_token is None or entity_end_token is None:
                raise Exception(
                    f"Entity in article {pmid} is missing start_token or end_token: {entity}"
                )
            
            # Determine the sentence in which the entity's tokens are located.
            sentence_for_start = token_to_sentence.get(entity_start_token)
            sentence_for_end = token_to_sentence.get(entity_end_token)
            
            if sentence_for_start is None or sentence_for_end is None:
                raise Exception(
                    f"Entity in article {pmid} has tokens not mapped to any sentence: {entity}"
                )
            
            if sentence_for_start != sentence_for_end:
                raise Exception(
                    f"Entity in article {pmid} spans multiple sentences (start in {sentence_for_start}, end in {sentence_for_end}): {entity}"
                )
            
            located_sentence = sentence_for_start  # or sentence_for_end, both are same.
            
            # Get the list of token indices for the sentence.
            tokens_in_sentence = sentence_to_token_indices.get(located_sentence)
            if tokens_in_sentence is None:
                raise Exception(
                    f"Sentence {located_sentence} not found in helper mapping for article {pmid}."
                )
            
            # Find the position within the sentence for the start token.
            try:
                start_token_in_sentence = tokens_in_sentence.index(entity_start_token)
            except ValueError:
                raise Exception(
                    f"Entity start token {entity_start_token} not found in sentence tokens {tokens_in_sentence} for article {pmid}."
                )
            
            # And the position within the sentence for the end token.
            try:
                end_token_in_sentence = tokens_in_sentence.index(entity_end_token)
            except ValueError:
                raise Exception(
                    f"Entity end token {entity_end_token} not found in sentence tokens {tokens_in_sentence} for article {pmid}."
                )
            
            # Add the new fields to the entity.
            entity['located_in_sentence'] = located_sentence
            entity['start_token_in_sentence'] = start_token_in_sentence
            entity['end_token_in_sentence'] = end_token_in_sentence

In [43]:
map_entities_to_tokens_within_sentences(train_platinum, "train_platinum")

Mapping entities to tokens within sentences in set train_platinum...


Map each relation's subject and object to an index corresponding to their position in the article's 'entities' list.

In [44]:
def map_relations_to_entities(data: dict, data_name: str) -> dict:
    """
    For each article, this function maps each relation's subject and object to an index 
    corresponding to their position in the article's 'entities' list. It adds two new fields 
    to each relation:
      - 'subject_entity_idx': the index of the subject entity in the article's 'entities' list.
      - 'object_entity_idx': the index of the object entity in the article's 'entities' list.
      
    The matching is performed based on:
      - For the subject:
            relation['subject_location'] == entity['location']
            relation['subject_start_idx'] == entity['start_idx']
            relation['subject_end_idx'] == entity['end_idx']
      - For the object:
            relation['object_location'] == entity['location']
            relation['object_start_idx'] == entity['start_idx']
            relation['object_end_idx'] == entity['end_idx']
            
    It is assumed that the article['entities'] list is ordered such that title entities come first (starting at index 0)
    and abstract entities follow (with indices continuing from the last title entity index).
    """
    print(f"Mapping relations to entities in set {data_name}...")
    
    for pmid, article in data.items():
        entities = article.get('entities')
        relations = article.get('relations')
        
        if entities is None:
            raise Exception(f"Article {pmid} is missing the 'entities' field.")
        if relations is None:
            # If there are no relations, there's nothing to map.
            continue
        
        # Process each relation.
        for relation in relations:
            # Map the subject.
            subject_idx_found = None
            for idx, entity in enumerate(entities):
                if (entity.get('location') == relation.get('subject_location') and
                    entity.get('start_idx') == relation.get('subject_start_idx') and
                    entity.get('end_idx') == relation.get('subject_end_idx')):
                    subject_idx_found = idx
                    break
                    
            if subject_idx_found is None:
                raise Exception(
                    f"Subject entity not found for relation in article {pmid}: {relation}"
                )
            relation['subject_entity_idx'] = subject_idx_found
            
            # Map the object.
            object_idx_found = None
            for idx, entity in enumerate(entities):
                if (entity.get('location') == relation.get('object_location') and
                    entity.get('start_idx') == relation.get('object_start_idx') and
                    entity.get('end_idx') == relation.get('object_end_idx')):
                    object_idx_found = idx
                    break
                    
            if object_idx_found is None:
                raise Exception(
                    f"Object entity not found for relation in article {pmid}: {relation}"
                )
            relation['object_entity_idx'] = object_idx_found

In [45]:
map_relations_to_entities(train_platinum, "train_platinum")

Mapping relations to entities in set train_platinum...


Convert processed annotations to the DocRED format used by ATLOP for finetuning

In [58]:
def convert_to_docred_format(data: dict, data_name: str, is_test=False) -> list:
    """
    Converts articles (in our intermediate format) to the DocRED format.
    
    For each article, a new dictionary is produced with the following keys:
      - "vertexSet": a list of entity mentions (each entity becomes a list with one mention).
          Each mention is a dict with:
              "pos": [start_token_in_sentence, end_token_in_sentence],
              "type": entity label,
              "sent_id": sentence id (0 for title; abstract sentences are numbered starting at 1),
              "name": the entity text span.
      - "labels": a list of relations, each a dict with:
              "r": relation predicate/label,
              "h": subject_entity_idx,
              "t": object_entity_idx,
              "evidence": a list of two sentence ids: [subject sentence id, object sentence id] 
                          (or one element if both are the same).
      - "paper_id": the id of the paper.
      - "title": the title string of the article.
      - "sents": a list of lists of tokens. The first entry is the tokenization of the title and subsequent
                 entries are the tokenizations of the abstract sentences.
    
    For abstract sentence tokenization, we use the sentence spans in article['sentences'] (a list of (start, end) offsets)
    and the tokenized abstract (article['tokenized abstract'], where each token is a tuple (token_text, start, end)).
    
    Returns a list of DocRED-formatted document dictionaries.
    """
    print(f"Converting articles to DocRED format for set {data_name}...")

    docred_docs = []
    for pmid, article in data.items():
        # 1. Build the vertexSet.
        vertexSet = []
        for entity in article.get('entities', []):
            # Determine sentence id: title entities are sentence 0,
            # and abstract entities use a pre-computed 'located_in_sentence' (plus 1).
            if entity.get('location') == 'title':
                sent_id = 0
            else:
                sent_id = entity.get('located_in_sentence', 0) + 1

            # Determine token offsets within the sentence.
            if entity.get('location') == 'title':
                pos = [entity.get('start_token'), entity.get('end_token') + 1]
            else:
                pos = [entity.get('start_token_in_sentence'), entity.get('end_token_in_sentence') + 1]
                
            mention = {
                "pos": pos,
                "type": entity.get("label").upper(),
                "sent_id": sent_id,
                "name": entity.get("text_span")
            }
            # Each vertexSet entry is a list of one mention.
            vertexSet.append([mention])
        
        # 2. Build the labels (relations).
        labels = []
        for relation in article.get("relations", []):
            subj_idx = relation.get("subject_entity_idx")
            obj_idx = relation.get("object_entity_idx")
            if subj_idx is None or obj_idx is None:
                raise Exception(f"Missing entity mapping in relation: {relation} in article {pmid}")
            
            subj_entity = article["entities"][subj_idx]
            obj_entity = article["entities"][obj_idx]
            
            if subj_entity.get("location") == "title":
                subj_sent = 0
            else:
                subj_sent = subj_entity.get("located_in_sentence", 0) + 1
                
            if obj_entity.get("location") == "title":
                obj_sent = 0
            else:
                obj_sent = obj_entity.get("located_in_sentence", 0) + 1
            
            relation_dict = {
                "r": relation.get("predicate").upper(),
                "h": subj_idx,
                "t": obj_idx,
                "evidence": [subj_sent, obj_sent] if subj_sent != obj_sent else [subj_sent]
            }
            labels.append(relation_dict)
        
        # 3. Build the sents field.
        # The first sentence is the tokenized title.
        title_tokens = article.get("tokenized_title", [])
        tokens_in_title = [token[0] for token in title_tokens]
        sents = [tokens_in_title]
            
        # For abstract sentences, use 'sentences' (spans) and 'tokenized_abstract'.
        abstract_tokens = article.get("tokenized_abstract", [])
        abstract_sents = []
        for span in article.get("sentences", []):
            s_start, s_end = span
            tokens_in_sentence = []
            for token, t_start, t_end in abstract_tokens:
                if t_start >= s_start and t_end <= s_end:
                    tokens_in_sentence.append(token)
            abstract_sents.append(tokens_in_sentence)
        sents.extend(abstract_sents)
        
        # 4. Get the title string.
        doc_title = article["metadata"]["title"]
        print(pmid)
        
        # 5. Build the final document dictionary.
        doc = {
            "vertexSet": vertexSet,
            "labels": labels,
            "paper_id": pmid,
            "title": doc_title,
            "sents": sents
        }
        docred_docs.append(doc)
    
    return docred_docs


In [59]:
docred_train_platinum = convert_to_docred_format(train_platinum, "train_platinum")

Converting articles to DocRED format for set train_platinum...
38068763
35965349
34870091
28158162
34172092
37845499
37371676
37574818
37571393
37841274
37485660
31955786
34098340
38350463
29352709
33511258
33422110
34985325
36550591
30459574
38026003
33194817
29022384
29857583
34758889
37881577
36984505
32979562
34961418
25034760
33067915
33271210
36794003
38132705
36900437
34603341
34422393
35914559
38422755
37228957
30717162
31248424
37469436
31179435
37995075
35326429
31083360
38010793
31685046
34444820
34092293
37927130
35432226
36757367
36493975
37213508
33046051
38204948
31952911
29023380
28572752
36346385
32459708
33177907
38089822
31646148
23981537
37657622
36760344
33722869
34776854
28976454
31053995
38576868
37511699
37464164
37368331
37396336
36517709
37978477
33713734
34830610
29843470
32469834
37960197
30707176
38139446
37209162
36738999
37627638
30394313
34393849
37606895
38508549
36564391
37095530
35083314
37377497
37207228
38707924
30584306
34422879
26927355
38409102
3

Dump the dictionary variable to a json file

In [60]:
def dump_to_json(docred_dict, output_file_path):
	dict_with_double_quotes = json.dumps(docred_dict, ensure_ascii=False)
	with open(output_file_path, 'w', encoding='utf-8') as f:
		f.write(dict_with_double_quotes)

dump_to_json(docred_train_platinum, PATH_OUTPUT_PLATINUM_TRAIN)