In [1]:
from collections import defaultdict
from enum import Enum
import json
import os
import random

In [2]:
def load_pmids(folder):
    """Return the PMIDs of all ``<pmid>.txt`` files directly inside ``folder``.

    The ``.txt`` suffix is removed exactly once. ``str.rstrip('.txt')`` would
    instead strip any run of trailing '.', 't' or 'x' characters (e.g.
    ``'abc_text.txt'`` -> ``'abc_te'``).

    NOTE: the order follows ``os.listdir``, which is arbitrary and
    filesystem-dependent; any caller that slices this list (e.g. into
    train/val/test) inherits that order.
    """
    suffix = '.txt'
    return [name[:-len(suffix)] for name in os.listdir(folder) if name.endswith(suffix)]

In [3]:
def load_annotations(folder, pmid):
    """Read the brat annotations for ``pmid`` from ``<folder>/<pmid>.ann``.

    Every non-blank line must have exactly three tab-separated fields, e.g.
    ``T2<TAB>P 343 459<TAB>A total of 173 ...``.

    Returns:
        list of ``(label, tokens)`` pairs, where ``label`` is the first
        character of the second field (e.g. 'P', 'I', 'O', 'C') and
        ``tokens`` is the annotated text split on whitespace.

    Raises:
        AssertionError: if a non-blank line does not have three fields.
    """
    annotations = []
    with open(os.path.join(folder, f'{pmid}.ann'), 'r') as f:
        for line in f:
            stripped = line.rstrip()
            # Blank lines (e.g. a trailing empty line) carry no annotation;
            # splitting them would yield [''] and trip the format check.
            if not stripped:
                continue
            elements = stripped.split('\t')
            assert len(elements) == 3, f'Bad format annotation: {folder}, {pmid}, {line}'
            tokens = elements[-1].split()
            label = elements[1][0]
            annotations.append((label, tokens))
    return annotations

In [4]:
def load_sentences(folder, pmid):
    """Tokenise ``<folder>/<pmid>.txt`` line by line.

    Each line (trailing whitespace removed) is split on whitespace, giving one
    token list per line; blank lines become empty lists.
    """
    path = os.path.join(folder, f'{pmid}.txt')
    with open(path, 'r') as handle:
        return [raw_line.rstrip().split() for raw_line in handle]

In [5]:
# One bit per PICO label. A token covered by several spans gets the bitwise OR
# of their bits; any other label (e.g. 'C') maps to 0 via the default factory.
label_encoding = defaultdict(lambda: 0)
label_encoding['P'] = 4
label_encoding['I'] = 2
label_encoding['O'] = 1

def encode_pico_labels(tokens, annotations):
    """Encode each token of a sentence as PICO bit flags.

    Args:
        tokens: list of token strings for one sentence.
        annotations: iterable of ``(label, span_tokens)`` pairs, as returned by
            ``load_annotations``.

    Every occurrence of ``span_tokens`` inside ``tokens`` (exact token-list
    match, anywhere in the sentence) has ``label_encoding[label]`` OR-ed into
    the covered positions.

    Returns:
        list of ints, one per token; 0 means no P/I/O annotation.
    """
    labels = [0] * len(tokens)
    for label, span_tokens in annotations:
        # .get() avoids inserting unknown labels into the shared defaultdict
        # as a side effect of the lookup.
        encoding = label_encoding.get(label, 0)
        span_length = len(span_tokens)
        # Zero-valued labels and empty spans cannot change any token.
        if encoding == 0 or span_length == 0:
            continue
        # Only start positions where the whole span still fits can match.
        for i in range(len(tokens) - span_length + 1):
            if tokens[i:i + span_length] == span_tokens:
                for j in range(i, i + span_length):
                    labels[j] |= encoding
    return labels

In [6]:
# Sample corpus used by the sanity-check cells below.
test_folder = 'data/raw/brat/AD'
test_pmids = load_pmids(folder=test_folder)

In [7]:
'''
Test content: pmid: .../brat/AD/10354658

Text:
Title : Selegiline in the treatment of Alzheimer ' s disease : a long - term randomized placebo - controlled trial . 
Czech and Slovak Senile Dementia of Alzheimer Type Study Group . 
METHODS : Long - term , double - blind , placebo - controlled trial . 
Seven cities ( 1 or 2 nursing homes in each city ) in the Czech and Slovak Republics . 
A total of 173 nursing - home residents fulfilling the DSM - III criteria for mild to moderate Alzheimer ' s disease . 
Selegiline ( 10 mg per day ) or placebo ( both including 50 mg ascorbic acid ) administered for 24 weeks . 
Clinical Global Impressions scale and Nurses Observation Scale for Inpatient Evaluation at baseline and at weeks 6 , 12 and 24 ; Clock Drawing Test at baseline and 24 weeks , results of which were evaluated as normal or pathologic , and quantitatively on a modified 6 - point scale ; Sternberg ' s Memory Scanning test at baseline and at weeks 6 , 12 and 24 ; Mini Mental State Examination , and electroencephalogram at baseline and 24 weeks ; Structured Adverse Effects Rating Scale ; physical , laboratory , hematological 
and electrocardiographic examinations at baseline and weeks 12 and 24 . 

Annotation:
T1	I 8 18	Selegiline
T2	P 343 459	A total of 173 nursing - home residents fulfilling the DSM - III criteria for mild to moderate Alzheimer ' s disease
T3	I 463 473	Selegiline
T4	C 495 502	placebo
T5	O 571 604	Clinical Global Impressions scale
T6	O 609 658	Nurses Observation Scale for Inpatient Evaluation
T7	O 700 718	Clock Drawing Test
T8	O 855 889	Sternberg ' s Memory Scanning test
T9	O 931 960	Mini Mental State Examination
T10	O 967 987	electroencephalogram
T11	O 1015 1054	Structured Adverse Effects Rating Scale
T12	O 1057 1065	physical
T13	O 1068 1078	laboratory
T14	O 1081 1094	hematological
T15	O 1100 1133	electrocardiographic examinations

'''
# First PMID in os.listdir order; the note above assumes it is 10354658.
test_pmid = test_pmids[0]
test_annotations = load_annotations(folder=test_folder, pmid=test_pmid)
test_sentences = load_sentences(folder=test_folder, pmid=test_pmid)

In [8]:
# Show sentence index 4 of the sample document; for PMID 10354658 this should be
# the sentence containing the participants (P) annotation.
test_sentences[4]

['A',
 'total',
 'of',
 '173',
 'nursing',
 '-',
 'home',
 'residents',
 'fulfilling',
 'the',
 'DSM',
 '-',
 'III',
 'criteria',
 'for',
 'mild',
 'to',
 'moderate',
 'Alzheimer',
 "'",
 's',
 'disease',
 '.']

In [9]:
# Every token except the final '.' falls inside the P annotation, so the expected
# encoding is 4 (the P bit) for those tokens and 0 for the trailing '.'.
encode_pico_labels(test_sentences[4], test_annotations)

[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0]

In [10]:
def format_dataset_in_json(input_folder, output_folder, output_name='test_all.json'):
    """Convert every document in a brat folder into JSON-lines token/label records.

    Writes ``<output_folder>/<output_name>`` (overwriting it), one line per
    sentence: ``{"pmid": ..., "tokens": [...], "labels": [...]}`` where
    ``labels`` comes from ``encode_pico_labels``.

    Args:
        input_folder: folder holding ``<pmid>.txt`` / ``<pmid>.ann`` pairs.
        output_folder: destination folder; created if missing.
        output_name: output file name (defaults to the historical
            ``test_all.json``).
    """
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(output_folder, exist_ok=True)
    # List inputs before opening the output so a bad input folder does not
    # truncate an existing output file.
    pmids = load_pmids(input_folder)
    with open(os.path.join(output_folder, output_name), 'w') as fout:
        for pmid in pmids:
            annotations = load_annotations(input_folder, pmid)
            sentences = load_sentences(input_folder, pmid)
            for sentence in sentences:
                record = {
                    'pmid': pmid,
                    'tokens': sentence,
                    'labels': encode_pico_labels(sentence, annotations),
                }
                fout.write('{}\n'.format(json.dumps(record)))

In [11]:
# Convert each full corpus into a JSON-lines file under data/bioc/json/brat/<corpus>.
for raw_dir, json_dir in (
    ('data/raw/brat/AD', 'data/bioc/json/brat/AD'),
    ('data/raw/brat/COVID', 'data/bioc/json/brat/COVID'),
    ('data/raw/brat/EBM-NLP', 'data/bioc/json/brat/EBM-NLP'),
):
    format_dataset_in_json(raw_dir, json_dir)

In [12]:
def load_train_val_test_pmids(input_folder, train_end=.45, val_end=.5):
    """Split the PMIDs of ``input_folder`` into ``(train, val, test)`` lists.

    The list from ``load_pmids`` is cut at ``int(n * train_end)`` and
    ``int(n * val_end)``. With the defaults, the first 45% is train, the next
    5% is validation and the remaining 50% is test. No shuffling is done, so
    the split follows the order returned by ``load_pmids``.

    Args:
        input_folder: brat folder to list.
        train_end: fraction of PMIDs (cumulative) ending the train split.
        val_end: fraction of PMIDs (cumulative) ending the validation split.

    Raises:
        ValueError: unless ``0 <= train_end <= val_end <= 1``.
    """
    if not 0 <= train_end <= val_end <= 1:
        raise ValueError(f'expected 0 <= train_end <= val_end <= 1, got {train_end}, {val_end}')
    pmids = load_pmids(input_folder)
    total_pmids = len(pmids)
    train_cut = int(total_pmids * train_end)
    val_cut = int(total_pmids * val_end)
    return pmids[:train_cut], pmids[train_cut:val_cut], pmids[val_cut:]

def aggregate_brat_datasets(input_folders, pmid_lists, output_fn):
    """Write token/label JSON lines for selected PMIDs from several brat folders.

    ``input_folders`` and ``pmid_lists`` are paired positionally (``zip``
    semantics: surplus entries in the longer sequence are ignored). Each
    sentence of each PMID becomes one ``{"pmid", "tokens", "labels"}`` line in
    ``output_fn``, which is overwritten. The parent directory is created if
    it does not exist.
    """
    output_dir = os.path.dirname(output_fn)
    if output_dir:  # dirname is '' for a bare file name; makedirs('') would raise
        os.makedirs(output_dir, exist_ok=True)
    with open(output_fn, 'w') as fout:
        for input_folder, pmids in zip(input_folders, pmid_lists):
            for pmid in pmids:
                annotations = load_annotations(input_folder, pmid)
                sentences = load_sentences(input_folder, pmid)
                for sentence in sentences:
                    record = {
                        'pmid': pmid,
                        'tokens': sentence,
                        'labels': encode_pico_labels(sentence, annotations),
                    }
                    fout.write('{}\n'.format(json.dumps(record)))

def create_brat_train_and_val_set():
    """Write the combined AD + COVID + EBM-NLP train and validation files.

    Each corpus is split with ``load_train_val_test_pmids``. The train parts go
    to ``data/bioc/json/brat/train.json`` and the validation parts to
    ``data/bioc/json/brat/validation.json``. Test parts are not written here.
    """
    folders = [
        'data/raw/brat/AD',
        'data/raw/brat/COVID',
        'data/raw/brat/EBM-NLP',
    ]
    train_lists = []
    val_lists = []
    for folder in folders:
        train_pmids, val_pmids, _ = load_train_val_test_pmids(folder)
        train_lists.append(train_pmids)
        val_lists.append(val_pmids)

    for pmid_lists, output_fn in (
        (train_lists, 'data/bioc/json/brat/train.json'),
        (val_lists, 'data/bioc/json/brat/validation.json'),
    ):
        aggregate_brat_datasets(folders, pmid_lists, output_fn)

In [13]:
def create_brat_test_sets_util(input_folder, pmids, output_folder, output_name='test.json'):
    """Write token/label JSON lines for ``pmids`` of one brat folder.

    Output goes to ``<output_folder>/<output_name>`` (overwritten), one
    ``{"pmid", "tokens", "labels"}`` line per sentence. ``output_folder`` is
    created if missing, so this no longer depends on another cell having
    created it first.
    """
    os.makedirs(output_folder, exist_ok=True)
    with open(os.path.join(output_folder, output_name), 'w') as fout:
        for pmid in pmids:
            annotations = load_annotations(input_folder, pmid)
            sentences = load_sentences(input_folder, pmid)
            for sentence in sentences:
                record = {
                    'pmid': pmid,
                    'tokens': sentence,
                    'labels': encode_pico_labels(sentence, annotations),
                }
                fout.write('{}\n'.format(json.dumps(record)))

def create_brat_test_sets():
    """Write the test split of each corpus to its own JSON output folder.

    For every (raw folder, output folder) pair, the third element of
    ``load_train_val_test_pmids`` is handed to ``create_brat_test_sets_util``.
    """
    folder_pairs = (
        ('data/raw/brat/AD', 'data/bioc/json/brat/AD'),
        ('data/raw/brat/COVID', 'data/bioc/json/brat/COVID'),
        ('data/raw/brat/EBM-NLP', 'data/bioc/json/brat/EBM-NLP'),
    )
    for raw_dir, json_dir in folder_pairs:
        _, _, test_pmids = load_train_val_test_pmids(raw_dir)
        create_brat_test_sets_util(raw_dir, test_pmids, json_dir)

In [14]:
# create_brat_train_and_val_set()

In [15]:
# create_brat_test_sets()

In [16]:
def create_brat_ad_train_and_val_set():
    """Write AD-only train and validation files.

    Outputs are ``data/bioc/json/brat/train_ad.json`` and
    ``data/bioc/json/brat/validation_ad.json``, built from the AD split returned
    by ``load_train_val_test_pmids``.
    """
    ad_folder = 'data/raw/brat/AD'
    train_pmids, val_pmids, _ = load_train_val_test_pmids(ad_folder)
    for pmids, output_fn in (
        (train_pmids, 'data/bioc/json/brat/train_ad.json'),
        (val_pmids, 'data/bioc/json/brat/validation_ad.json'),
    ):
        aggregate_brat_datasets([ad_folder], [pmids], output_fn)

In [17]:
# Writes data/bioc/json/brat/train_ad.json and validation_ad.json from the AD corpus.
create_brat_ad_train_and_val_set()

In [18]:
class PicoType(Enum):
    """PICO element types; each value is the bit used for it in label encodings."""

    PARTICIPANTS = 1 << 2
    INTERVENTIONS = 1 << 1
    OUTCOMES = 1 << 0

class Span:
    """Token range ``[start, start + length)`` inside a sentence.

    ``end`` is computed once at construction as ``start + length``. Equality
    and hashing use ``(start, int(length))``, so a span whose length is stored
    as a float (e.g. ``3.0``) equals the integer-length span.
    """

    def __init__(self, start, length):
        self.start = start
        self.length = length
        self.end = start + length

    def __repr__(self):
        return f'Span(start:{self.start}, length:{self.length})'

    def _key(self):
        # Single source of truth for __eq__ and __hash__ so they stay consistent.
        return (self.start, int(self.length))

    def __eq__(self, other):
        # Let Python fall back for non-Span operands instead of raising
        # AttributeError (e.g. `span == None` or membership in a mixed list).
        if not isinstance(other, Span):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        return hash(self._key())

def extract_pico_spans_from_encodings(label_encodings, pico_type):
    """Return the maximal runs of tokens whose encoding contains ``pico_type``.

    Args:
        label_encodings: per-token PICO bit flags, e.g. the output of
            ``encode_pico_labels``.
        pico_type: a ``PicoType``; a token belongs to it when
            ``encoding & pico_type.value`` is non-zero.

    Returns:
        list of ``Span(start, length)``, one per contiguous run, ordered left to
        right. Lengths are plain ints.
    """
    encodings = list(label_encodings)
    # Guard on the argument itself. The previous check tested the module-level
    # ``label_encoding`` dict by mistake, which is always truthy.
    if not encodings:
        return []

    spans = []
    run_start = None
    for i, encoding in enumerate(encodings):
        if encoding & pico_type.value:
            if run_start is None:
                run_start = i
        elif run_start is not None:
            spans.append(Span(start=run_start, length=i - run_start))
            run_start = None
    # Close a run that extends to the end of the sentence.
    if run_start is not None:
        spans.append(Span(start=run_start, length=len(encodings) - run_start))
    return spans

In [19]:
def synthesize_span_util(span_a, span_b):
    """Build candidate spans by recombining the boundaries of two spans.

    The spans are first ordered so that ``span_a`` starts no later than
    ``span_b``. Then:

    * disjoint or adjacent (``span_b.start >= span_a.end``): one span from
      ``span_a.start`` to ``span_b.end`` covering both;
    * overlapping: the two cross-boundary spans ``[span_b.start, span_a.end)``
      and ``[span_a.start, span_b.end)``, unless the spans share a start or an
      end. In that case the cross spans would merely reproduce the inputs.

    Returns:
        list of ``Span`` (possibly empty).
    """
    if span_a.start > span_b.start:
        span_a, span_b = span_b, span_a

    if span_b.start >= span_a.end:
        return [Span(span_a.start, span_b.end - span_a.start)]

    # The original second clause compared span_a.start with span_b.end, which
    # is always unequal here; the boundary that matters is the shared end.
    if span_a.start != span_b.start and span_a.end != span_b.end:
        return [
            Span(span_b.start, span_a.end - span_b.start),
            Span(span_a.start, span_b.end - span_a.start),
        ]
    return []
        

def synthesize_spans(span_list_a, span_list_b):
    """Apply ``synthesize_span_util`` to every (a, b) pair and flatten the results.

    ``span_list_a`` is the outer loop, so results are grouped by ``a`` in input
    order.
    """
    return [
        candidate
        for a in span_list_a
        for b in span_list_b
        for candidate in synthesize_span_util(a, b)
    ]
            

In [20]:
def _span_clf_record(pmid, sentence, span, pico_name=None):
    """Build one span-classification example: all PICO flags False, ``pico_name`` (if given) True."""
    start, end = int(span.start), int(span.start + span.length)
    record = {
        'pmid': pmid,
        'tokens': sentence[start:end],
        'PARTICIPANTS': False,
        'INTERVENTIONS': False,
        'OUTCOMES': False,
    }
    if pico_name is not None:
        record[pico_name] = True
    return record


def create_brat_span_clf_datasets_util(input_folders, pmid_lists, output_fn, sample_limit=1):
    """Write a span-classification JSON-lines dataset and print category counts.

    ``input_folders`` and ``pmid_lists`` are paired positionally. For every
    sentence of every PMID:

    * each contiguous PARTICIPANTS / INTERVENTIONS / OUTCOMES span becomes one
      record with the matching flag set to True;
    * up to ``sample_limit`` synthesized spans become records with every flag
      False. Candidates are boundary recombinations of spans of two different
      PICO types (``synthesize_spans``) that do not coincide with any real
      span.

    Sampling uses the global ``random`` state; seed it beforehand for a
    reproducible file. ``output_fn`` is overwritten and its directory created
    if needed.
    """
    output_path = os.path.dirname(output_fn)
    if output_path:
        os.makedirs(output_path, exist_ok=True)

    count = defaultdict(lambda: 0)
    with open(output_fn, 'w') as fout:
        for input_folder, pmids in zip(input_folders, pmid_lists):
            for pmid in pmids:
                annotations = load_annotations(input_folder, pmid)
                sentences = load_sentences(input_folder, pmid)
                for sentence in sentences:
                    encodings = encode_pico_labels(sentence, annotations)

                    spans_by_type = {}
                    for pico_type in PicoType:
                        spans = extract_pico_spans_from_encodings(encodings, pico_type)
                        spans_by_type[pico_type] = spans
                        for span in spans:
                            record = _span_clf_record(pmid, sentence, span, pico_type.name)
                            fout.write('{}\n'.format(json.dumps(record)))
                            count[pico_type.name] += 1

                    participants_spans = spans_by_type[PicoType.PARTICIPANTS]
                    interventions_spans = spans_by_type[PicoType.INTERVENTIONS]
                    outcomes_spans = spans_by_type[PicoType.OUTCOMES]

                    # Synthesize negatives from boundaries of spans of different types.
                    synthesized_spans = set(
                        synthesize_spans(participants_spans, interventions_spans)
                        + synthesize_spans(participants_spans, outcomes_spans)
                        + synthesize_spans(interventions_spans, outcomes_spans)
                    )
                    real_spans = set(participants_spans + interventions_spans + outcomes_spans)
                    # Drop real spans before sampling so a collision does not
                    # consume the sample budget.
                    candidates = [s for s in synthesized_spans if s not in real_spans]
                    random.shuffle(candidates)
                    for span in candidates[:sample_limit]:
                        record = _span_clf_record(pmid, sentence, span)
                        # Written exactly once (the old code wrote each line twice).
                        fout.write('{}\n'.format(json.dumps(record)))
                        count['SYNTHESIZED'] += 1
    print('{}： {}'.format(output_fn, ', '.join([f'{k}: {count[k]}' for k in count])))


def create_brat_span_clf_datasets():
    """Write combined AD + COVID + EBM-NLP span-classification train/val files.

    Train parts of each corpus split go to
    ``data/bioc/json/brat/train_span_clf.json`` and validation parts to
    ``data/bioc/json/brat/validation_span_clf.json``.
    """
    folders = [
        'data/raw/brat/AD',
        'data/raw/brat/COVID',
        'data/raw/brat/EBM-NLP',
    ]
    train_lists = []
    val_lists = []
    for folder in folders:
        train_pmids, val_pmids, _ = load_train_val_test_pmids(folder)
        train_lists.append(train_pmids)
        val_lists.append(val_pmids)

    for pmid_lists, output_fn in (
        (train_lists, 'data/bioc/json/brat/train_span_clf.json'),
        (val_lists, 'data/bioc/json/brat/validation_span_clf.json'),
    ):
        create_brat_span_clf_datasets_util(folders, pmid_lists, output_fn)

In [21]:
# random.seed(0)
# create_brat_span_clf_datasets()

In [22]:
# def create_brat_ad_span_clf_datasets():
#     ad_folder = 'data/raw/brat/AD'
#     ad_pmids_train, ad_pmids_val, _ = load_train_val_test_pmids(ad_folder)
    
#     train_file = 'data/bioc/json/brat/train_ad_span_clf.json'
#     val_file = 'data/bioc/json/brat/validation_ad_span_clf.json'
    
#     create_brat_span_clf_datasets_util(
#         [ad_folder],
#         [ad_pmids_train],
#         train_file,
#     )
    
#     create_brat_span_clf_datasets_util(
#         [ad_folder],
#         [ad_pmids_val],
#         val_file,
#     )

In [23]:
# create_brat_ad_span_clf_datasets()

In [29]:
def count_entities(input_folders):
    """Print split sizes and PICO span counts for each brat folder.

    For every folder this prints the number of PMIDs in its train, val and test
    splits (one per line, from ``load_train_val_test_pmids``). It then prints
    one summary line with the number of contiguous spans of each PICO type
    across all three splits. Types are listed in order of first occurrence;
    types with no spans are omitted.
    """
    for input_folder in input_folders:
        count = defaultdict(lambda: 0)
        for pmids in load_train_val_test_pmids(input_folder):
            print(len(pmids))
            for pmid in pmids:
                annotations = load_annotations(input_folder, pmid)
                sentences = load_sentences(input_folder, pmid)
                for sentence in sentences:
                    encodings = encode_pico_labels(sentence, annotations)
                    for pico_type in PicoType:
                        spans = extract_pico_spans_from_encodings(encodings, pico_type)
                        # Only touch the key when spans exist so the printed
                        # key order stays "order of first occurrence".
                        if spans:
                            count[pico_type.name] += len(spans)
        print('{}： {}'.format(input_folder, ', '.join([f'{k}: {count[k]}' for k in count])))

In [30]:
# Raw brat corpora inspected by the counting cells below.
ad_folder, covid_folder, ebm_nlp_folder = (
    'data/raw/brat/AD',
    'data/raw/brat/COVID',
    'data/raw/brat/EBM-NLP',
)

In [31]:
# Prints the AD train/val/test PMID counts, then per-type PICO span counts.
count_entities([ad_folder])

67
8
75
data/raw/brat/AD： INTERVENTIONS: 490, PARTICIPANTS: 218, OUTCOMES: 656


In [32]:
# Prints the COVID train/val/test PMID counts, then per-type PICO span counts.
count_entities([covid_folder])

67
8
75
data/raw/brat/COVID： PARTICIPANTS: 263, INTERVENTIONS: 652, OUTCOMES: 619
