In [1]:
from bioc import biocxml
from collections import defaultdict
from enum import Enum
from typing import Any, List
import json
import glob
import random
import os
import warnings

In [2]:
# Paths for a sample-sized run (presumably a reduced copy of the corpus used
# for quick experiments -- TODO confirm; not referenced by the cells below).
SAMPLE_INPUT_FILE = 'data/bioc/input/ebm_nlp_2_00_sample_ssplit.xml'
SAMPLE_OUTPUT_PATH = 'data/bioc/json/sample'
SAMPLE_SPAN_OUTPUT_PATH = 'data/bioc/span/sample'

# BioC XML file that is parsed into `collection` with biocxml.load.
INPUT_FILE = 'data/bioc/input/ebm_nlp_2_00_ssplit.xml'
# Directory that receives the generated JSON-lines datasets. The commented-out
# values are alternative destinations that can be swapped in by hand.
# OUTPUT_PATH = 'data/bioc/json/no_overlap_training'
# OUTPUT_PATH = 'data/bioc/json/sample'
OUTPUT_PATH = 'data/bioc/json'

class DatasetSplit(Enum):
    '''Dataset partition to export; the member name becomes the output file stem.'''
    train = 0
    validation = 1
    test = 2

class PicoType(Enum):
    '''
    PICO element types annotated in the corpus.

    The lower-cased member name selects the 'starting_spans/<name>' infon on
    each token annotation. The values are distinct bits (4, 2, 1) so that
    create_dataset can OR them into a single integer label per token.
    '''
    PARTICIPANTS = 4
    INTERVENTIONS = 2
    OUTCOMES = 1

### Format EBM-NLP dataset for Sequence Labeling.

See https://github.com/gzhang64/EBM-NLP for a detailed description of the dataset. To conveniently fine-tune transformer models for sequence labeling, we convert each sentence in an annotated PubMed abstract to a JSON object with the following format:

Column Name | Data Type | Description
:--- | :--- | :---
pmid | String | PubMed ID of the annotated abstract.
tokens | List(String) | The list of tokens in a sentence from an annotated abstract.
interventions | List(String) | Intervention tags corresponding to each token, e.g. "B-INTERVENTIONS", "I-INTERVENTIONS", "O".
outcomes | List(String) | Outcome tags corresponding to each token.
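participants | List(String) | Participant tags corresponding to each token.

**Note:** the files written by `create_dataset` below store the three tag lists merged into a single `labels` field: one integer per token, formed by OR-ing the bits of every PICO type that covers the token (PARTICIPANTS = 4, INTERVENTIONS = 2, OUTCOMES = 1; 0 means no type).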

In [3]:
def load_pmid(folder):
    '''
    Collects the document ids of all aggregated annotation files in a folder.

    Parameters:
        folder (string): directory containing '<id>.AGGREGATED.ann' files.

    Returns:
        A set with the '<id>' part of every file name ending in '.AGGREGATED.ann'.
        Files with any other name are ignored.
    '''
    suffix = '.AGGREGATED.ann'
    # Slice the suffix off instead of using str.rstrip: rstrip treats its
    # argument as a set of characters and would also eat trailing id characters
    # such as 'A', 'D', 'a', 'n' or '.'.
    return {
        file_name[:-len(suffix)]
        for file_name in os.listdir(folder)
        if file_name.endswith(suffix)
    }

In [4]:
# Root of the aggregated starting-span annotations; each PICO type has a
# 'train' folder and a 'test/gold' folder below it.
_STARTING_SPANS_DIR = 'data/raw/ebm_nlp_2_00/annotations/aggregated/starting_spans'

participants_train = load_pmid(f'{_STARTING_SPANS_DIR}/participants/train')
participants_test = load_pmid(f'{_STARTING_SPANS_DIR}/participants/test/gold')

interventions_train = load_pmid(f'{_STARTING_SPANS_DIR}/interventions/train')
interventions_test = load_pmid(f'{_STARTING_SPANS_DIR}/interventions/test/gold')

outcomes_train = load_pmid(f'{_STARTING_SPANS_DIR}/outcomes/train')
outcomes_test = load_pmid(f'{_STARTING_SPANS_DIR}/outcomes/test/gold')

# Report how many annotated abstracts each PICO type has per split.
for _title, _train_ids, _test_ids in (
    ('PARTICIPANTS', participants_train, participants_test),
    ('INTERVENTIONS', interventions_train, interventions_test),
    ('OUTCOMES', outcomes_train, outcomes_test),
):
    print('{}: train {}, test {}'.format(_title, len(_train_ids), len(_test_ids)))

PARTICIPANTS: train 4792, test 189
INTERVENTIONS: train 4782, test 188
OUTCOMES: train 4670, test 190


In [5]:
# Parse the sentence-split EBM-NLP corpus (BioC XML) into an in-memory collection.
with open(INPUT_FILE, encoding='UTF-8') as bioc_xml:
    collection = biocxml.load(bioc_xml)

In [6]:
# Partition the abstracts into train / validation / test by PMID.
all_pmid = {document.id for document in collection.documents}
# An abstract belongs to the test split only if it has gold test annotations
# for all three PICO types.
test_pmid = participants_test.intersection(interventions_test, outcomes_test)

# Iteration order of a set of strings changes between interpreter runs (string
# hashes are randomized per process), so list(set) would give a different
# train/validation split after every kernel restart. Sort first, then shuffle
# with a dedicated seeded generator so the split is reproducible and does not
# disturb the global `random` state.
SPLIT_SEED = 0
train_pmid_list = sorted(all_pmid - test_pmid)
random.Random(SPLIT_SEED).shuffle(train_pmid_list)

# The first 95% of the shuffled ids are used for training, the rest for validation.
train_validation_boundary = int(0.95 * len(train_pmid_list))
train_pmid = set(train_pmid_list[:train_validation_boundary])
validation_pmid = set(train_pmid_list[train_validation_boundary:])

print(f'Total size: {len(collection.documents)}')
# Report the actual size of the training set instead of deriving it by subtraction.
print(f'Train size: {len(train_pmid)}')
print(f'Validation size: {len(validation_pmid)}')
print(f'Test size: {len(test_pmid)}')

Total size: 4993
Train size: 4566
Validation size: 241
Test size: 186


In [9]:
def extract_pio_labels(sentence: Any, pico_type: Any) -> List[Any]:
    '''
    Converts the raw 0/1 span labels of one PICO type into B/I/O tags.

    Parameters:
        sentence (BioC sentence): sentence whose annotations are the tokens; each
            annotation carries its raw labels in `infons`.
        pico_type (enum PicoType): one of PicoType.{INTERVENTIONS, OUTCOMES, PARTICIPANTS}

    Returns:
        None if the sentence has no annotations.
        A list of None (one entry per token) if the first token has no
        'starting_spans/<type>' infon.
        Otherwise one tag per token: 'O' for raw label '0', 'I-<TYPE>' when the
        previous raw label is '1', and 'B-<TYPE>' in every other case.
    '''
    infon_key = 'starting_spans/{}'.format(pico_type.name.lower())
    token_annotations = sentence.annotations
    if not token_annotations:
        return None
    if infon_key not in token_annotations[0].infons:
        return [None] * len(token_annotations)

    begin_tag = 'B-{}'.format(pico_type.name)
    inside_tag = 'I-{}'.format(pico_type.name)

    tags = []
    previous_raw = None  # no predecessor yet, so the first in-span token opens a span
    for annotation in token_annotations:
        current_raw = annotation.infons[infon_key]
        if current_raw == '0':
            tags.append('O')
        elif previous_raw == '1':
            tags.append(inside_tag)
        else:
            tags.append(begin_tag)
        previous_raw = current_raw
    return tags

def extract_tokens(sentence: Any) -> List[str]:
    '''Returns the text of every annotation of a BioC sentence, in order.'''
    tokens = []
    for annotation in sentence.annotations:
        tokens.append(annotation.text)
    return tokens

def check_overlapping(sentence: Any) -> bool:
    '''
    Checks if PIO labels have overlapping tokens.

    Parameters:
        sentence (BioC sentence): sentence whose annotations are the tokens.

    Returns:
        True if at least one token is inside spans of more than one PICO type,
        False otherwise. A sentence without annotations has no overlap.
    '''
    def extract_raw_span_labels(pico_type):
        # Raw labels of one PICO type as ints; all zeros when the type is not annotated.
        label_name = 'starting_spans/{}'.format(pico_type.name.lower())
        if not sentence.annotations or label_name not in sentence.annotations[0].infons:
            return [0] * len(sentence.annotations)
        return [int(a.infons[label_name]) for a in sentence.annotations]

    labels_per_type = [
        extract_raw_span_labels(pico_type)
        for pico_type in (PicoType.PARTICIPANTS, PicoType.INTERVENTIONS, PicoType.OUTCOMES)
    ]
    # any() instead of max() so that an empty sentence yields False rather than
    # raising ValueError on an empty sequence.
    return any(sum(token_labels) > 1 for token_labels in zip(*labels_per_type))

# Smoke test on the first sentence of the corpus: print each token next to its
# participant / intervention / outcome tag.
print('Example of extracted tokens and labels (Interventions):')
example_sentence = collection.documents[0].passages[0].sentences[0]
example_tokens = extract_tokens(example_sentence)
example_participant_labels, example_intervention_labels, example_outcome_labels = (
    extract_pio_labels(example_sentence, pico_type)
    for pico_type in (PicoType.PARTICIPANTS, PicoType.INTERVENTIONS, PicoType.OUTCOMES)
)

# Every tag list must line up with the token list.
for example_labels in (
    example_participant_labels,
    example_intervention_labels,
    example_outcome_labels,
):
    assert len(example_tokens) == len(example_labels)

example_rows = zip(
    example_tokens,
    example_participant_labels,
    example_intervention_labels,
    example_outcome_labels,
)
for token, p, i, o in example_rows:
    print(f'{token}\t{p}\t{i}\t{o}')

print('Has overlapping span: ', check_overlapping(example_sentence))

Example of extracted tokens and labels (Interventions):
[	O	O	O
Triple	O	O	O
therapy	O	O	O
regimens	O	O	O
involving	O	O	O
H2	O	B-INTERVENTIONS	O
blockaders	O	I-INTERVENTIONS	O
for	O	O	O
therapy	O	O	O
of	O	O	O
Helicobacter	O	O	B-OUTCOMES
pylori	O	O	I-OUTCOMES
infections	O	O	I-OUTCOMES
]	O	O	I-OUTCOMES
.	O	O	O
Has overlapping span:  False


In [11]:
def create_dataset(
    collection: Any,
    dataset_split: Any,
    output_path: str,
    include_overlapping_spans: bool = True,
    only_overlapping_spans: bool = False,
):
    '''
    Creates a training/validation/test dataset.

    Creates a dataset and saves it on disk as '<output_path>/<split name>.json'.
    Each line of the output file is a JSON dump with the keys 'pmid', 'tokens'
    and 'labels'; the output file itself is not a valid JSON object. 'labels'
    holds one integer per token, obtained by OR-ing the PicoType values of every
    type whose tag for that token is neither 'O' nor None.

    Sentences for which any of the three tag lists is missing or empty are
    skipped with a warning. The printed overlap statistics count every
    remaining sentence of the split, regardless of the two filter flags.

    Parameters:
        collection (BioC Collection): a BioC collection objects containing raw data.
        dataset_split (enum DatasetSplit): one of DatasetSplit.{train, validation, test}.
            Any value other than train/validation selects the test PMIDs.
        output_path (string): directory where json files are stored (created if missing).
        include_overlapping_spans (bool): whether to include overlapping spans, default True
        only_overlapping_spans (bool): whether to include only overlapping spans, default False
    '''
    if only_overlapping_spans and not include_overlapping_spans:
        # The two filters exclude each other; make the empty result visible.
        warnings.warn(
            'only_overlapping_spans=True with include_overlapping_spans=False '
            'selects no sentences; the output file will be empty.')

    if dataset_split == DatasetSplit.train:
        target_pmid_set = train_pmid
    elif dataset_split == DatasetSplit.validation:
        target_pmid_set = validation_pmid
    else:
        target_pmid_set = test_pmid

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_path, exist_ok=True)
    output_file = os.path.join(output_path, '{}.json'.format(dataset_split.name))

    pico_types = (PicoType.PARTICIPANTS, PicoType.INTERVENTIONS, PicoType.OUTCOMES)
    overlap_sentence_count = 0
    total_sentence_count = 0

    # The file is only written, so plain 'w' is enough; the encoding is pinned so
    # the result does not depend on the platform default.
    with open(output_file, 'w', encoding='utf-8') as fout:
        for document in collection.documents:
            pmid = document.id
            if pmid not in target_pmid_set:
                continue
            for passage in document.passages:
                for sentence in passage.sentences:
                    tags_by_type = [
                        extract_pio_labels(sentence, pico_type) for pico_type in pico_types
                    ]
                    # extract_pio_labels returns None for a sentence without annotations.
                    if not all(tags_by_type):
                        warnings.warn('Empty annotations in abstract {}.'.format(pmid))
                        continue

                    total_sentence_count += 1
                    if check_overlapping(sentence):
                        overlap_sentence_count += 1
                        if not include_overlapping_spans:
                            continue
                    elif only_overlapping_spans:
                        continue

                    # Fold the three tag sequences into one bitmask per token.
                    labels = []
                    for token_tags in zip(*tags_by_type):
                        label = 0
                        for pico_type, tag in zip(pico_types, token_tags):
                            if tag not in ('O', None):
                                label |= pico_type.value
                        labels.append(label)

                    data = {
                        'pmid': pmid,
                        'tokens': extract_tokens(sentence),
                        'labels': labels,
                    }
                    fout.write('{}\n'.format(json.dumps(data)))
    print('{}: {}/{} sentences have overlapping spans.'.format(
        dataset_split.name, overlap_sentence_count, total_sentence_count))

In [12]:
# Export every split, keeping all sentences, directly under OUTPUT_PATH.
for split in (DatasetSplit.train, DatasetSplit.validation, DatasetSplit.test):
    create_dataset(collection, split, OUTPUT_PATH)

train: 4490/48960 sentences have overlapping spans.
validation: 231/2542 sentences have overlapping spans.
test: 115/2042 sentences have overlapping spans.


In [13]:
# Train and test data restricted to sentences without overlapping spans.
no_overlap_dir = os.path.join(OUTPUT_PATH, 'no_overlap')
for split in (DatasetSplit.train, DatasetSplit.test):
    create_dataset(collection, split, no_overlap_dir, include_overlapping_spans=False)

train: 4490/48960 sentences have overlapping spans.
test: 115/2042 sentences have overlapping spans.


In [14]:
# Test data restricted to sentences that contain overlapping spans.
overlap_only_dir = os.path.join(OUTPUT_PATH, 'overlap_only')
create_dataset(
    collection,
    DatasetSplit.test,
    overlap_only_dir,
    only_overlapping_spans=True,
)

test: 115/2042 sentences have overlapping spans.


In [15]:
class Span:
    '''
    A contiguous run of tokens inside a sentence.

    Attributes:
        start: index of the first token of the span.
        length: number of tokens in the span (may be a float holding a whole number).
        end: start + length, i.e. the index one past the last token.

    Two spans are equal, and hash alike, when they have the same start and the
    same length after truncating the length to an int.
    '''

    def __init__(self, start, length):
        self.start = start
        self.length = length
        self.end = start + length

    def __repr__(self):
        return f'Span(start:{self.start}, length:{self.length})'

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types instead
        # of failing with AttributeError on `other.start`.
        if not isinstance(other, Span):
            return NotImplemented
        return (self.start, int(self.length)) == (other.start, int(other.length))

    def __hash__(self):
        # Must stay consistent with __eq__: same key, same int truncation.
        return hash((self.start, int(self.length)))

def extract_pico_spans(sentence, pico_type):
    '''
    Extracts the annotated spans of one PICO type from a sentence.

    A span is a maximal run of consecutive tokens whose raw
    'starting_spans/<type>' label is greater than zero.

    Parameters:
        sentence (BioC sentence): sentence whose annotations are the tokens.
        pico_type (enum PicoType): the PICO type to extract.

    Returns:
        A list of Span objects ordered by start index; span lengths are floats.
        The list is empty if the sentence has no annotations or its first token
        has no label for this type.
    '''
    infon_key = 'starting_spans/{}'.format(pico_type.name.lower())
    if not sentence.annotations:
        return []
    if infon_key not in sentence.annotations[0].infons:
        return []

    flags = [int(annotation.infons[infon_key]) for annotation in sentence.annotations]

    spans = []
    run_start = None   # index where the currently open run began, if any
    run_length = 0.0   # kept as a float, like the accumulated lengths it replaces
    for index, flag in enumerate(flags):
        if flag > 0:
            if run_start is None:
                run_start, run_length = index, 0.0
            run_length += 1
        elif run_start is not None:
            spans.append(Span(start=run_start, length=run_length))
            run_start = None
    if run_start is not None:
        # The sentence ended while a run was still open.
        spans.append(Span(start=run_start, length=run_length))
    return spans

In [16]:
def synthesize_span_util(span_a, span_b):
    '''
    Builds new spans from the boundaries of two existing spans.

    The spans are first ordered by start index (the argument order is kept on ties).
    - If the later span starts at or after the end of the earlier one, a single
      span covering both, from the earlier start to the later end, is returned.
    - Otherwise two spans are returned: one from the later start to the earlier
      span's end, and one from the earlier start to the later span's end. Nothing
      is returned when the earlier start equals the later span's start or end.

    Returns:
        A list with zero, one or two Span objects.
    '''
    if span_a.start > span_b.start:
        first, second = span_b, span_a
    else:
        first, second = span_a, span_b

    if second.start >= first.end:
        return [Span(first.start, second.end - first.start)]

    # NOTE(review): `first.start == second.end` looks like it was meant to be
    # `first.end == second.end` -- confirm. Behavior is intentionally unchanged.
    if first.start == second.start or first.start == second.end:
        return []

    return [
        Span(second.start, first.end - second.start),
        Span(first.start, second.end - first.start),
    ]
        

def synthesize_spans(span_list_a, span_list_b):
    '''
    Synthesizes spans from every pair (a, b) with a taken from span_list_a and
    b from span_list_b, and returns all results in one flat list.
    '''
    return [
        synthesized
        for a in span_list_a
        for b in span_list_b
        for synthesized in synthesize_span_util(a, b)
    ]
            

In [17]:
def _span_record(pmid, tokens, span, positive_type=None):
    '''Builds one output record for `span`; `positive_type` (if any) is flagged True.'''
    start, end = int(span.start), int(span.start + span.length)
    record = {
        'pmid': pmid,
        'tokens': tokens[start:end],
        'PARTICIPANTS': False,
        'INTERVENTIONS': False,
        'OUTCOMES': False,
    }
    if positive_type is not None:
        record[positive_type.name] = True
    return record


def create_span_clf_dataset(
    collection: Any,
    dataset_split: Any,
    output_path: str,
    negative_sample_limit: int = 1,
):
    '''
    Creates a training/validation/test dataset for span classification.

    Writes '<output_path>/<split name>_span_clf.json' with one JSON object per
    line. Every annotated span yields a record whose own PICO type is True.
    In addition, up to `negative_sample_limit` spans per sentence are sampled
    (using the global `random` state) from spans synthesized out of the
    boundaries of annotated spans of different types; these records have all
    three type flags False. Synthesized spans that coincide with an annotated
    span are discarded before sampling.

    Parameters:
        collection (BioC Collection): a BioC collection objects containing raw data.
        dataset_split (enum DatasetSplit): one of DatasetSplit.{train, validation, test}.
            Any value other than train/validation selects the test PMIDs.
        output_path (string): directory where json files are stored (created if missing).
        negative_sample_limit (int): maximum number of synthesized negative
            spans written per sentence, default 1.
    '''
    if dataset_split == DatasetSplit.train:
        target_pmid_set = train_pmid
    elif dataset_split == DatasetSplit.validation:
        target_pmid_set = validation_pmid
    else:
        target_pmid_set = test_pmid

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_path, exist_ok=True)
    output_file = os.path.join(output_path, '{}_span_clf.json'.format(dataset_split.name))

    count = defaultdict(int)
    with open(output_file, 'w', encoding='utf-8') as fout:
        for document in collection.documents:
            pmid = document.id
            if pmid not in target_pmid_set:
                continue
            for passage in document.passages:
                for sentence in passage.sentences:
                    tokens = extract_tokens(sentence)

                    # Positive examples: the annotated spans of each PICO type.
                    gold_spans = {}
                    for pico_type in PicoType:
                        spans = extract_pico_spans(sentence, pico_type)
                        gold_spans[pico_type] = spans
                        for span in spans:
                            record = _span_record(pmid, tokens, span, pico_type)
                            fout.write('{}\n'.format(json.dumps(record)))
                            count[pico_type.name] += 1

                    # Negative examples: spans synthesized from the boundaries
                    # of annotated spans of two different types.
                    synthesized_spans = (
                        synthesize_spans(
                            gold_spans[PicoType.PARTICIPANTS],
                            gold_spans[PicoType.INTERVENTIONS])
                        + synthesize_spans(
                            gold_spans[PicoType.PARTICIPANTS],
                            gold_spans[PicoType.OUTCOMES])
                        + synthesize_spans(
                            gold_spans[PicoType.INTERVENTIONS],
                            gold_spans[PicoType.OUTCOMES])
                    )
                    all_gold = set().union(*gold_spans.values())
                    # Drop gold spans BEFORE sampling; filtering after the
                    # truncation could discard the only sampled candidate even
                    # though other valid negatives exist. Sorting makes the
                    # shuffle input independent of set iteration order.
                    candidates = sorted(
                        {span for span in synthesized_spans if span not in all_gold},
                        key=lambda span: (span.start, span.length),
                    )
                    random.shuffle(candidates)
                    for span in candidates[:max(0, negative_sample_limit)]:
                        record = _span_record(pmid, tokens, span)
                        fout.write('{}\n'.format(json.dumps(record)))
                        count['SYNTHESIZED'] += 1

    print('{}： {}'.format(dataset_split.name, ', '.join([f'{k}: {count[k]}' for k in count])))
                    

In [18]:
# Seed the global generator once so the sampled negative spans are reproducible,
# then export the span-classification data for every split in a fixed order.
random.seed(0)
for split in (DatasetSplit.train, DatasetSplit.validation, DatasetSplit.test):
    create_span_clf_dataset(collection, split, OUTPUT_PATH)

train： INTERVENTIONS: 31177, OUTCOMES: 31844, PARTICIPANTS: 17011
validation： PARTICIPANTS: 941, INTERVENTIONS: 1682, OUTCOMES: 1710
test： PARTICIPANTS: 643, INTERVENTIONS: 1726, OUTCOMES: 1833
