# GENIA data

Source: http://www.geniaproject.org/genia-corpus/term-corpus

For GENIA, the authors use `GENIAcorpus3.02p3`, and follow the train/dev/test split of previous works (Finkel and Manning, 2009; Lu and Roth, 2015) i.e.:
1. Split first $81\%$, subsequent $9\%$, and last $10\%$ as train, dev and test set, respectively.
2. Collapse all DNA, RNA, and protein subtypes into DNA, RNA, and protein, keeping cell line and cell type.
3. Remove all other entity types, resulting in 5 entity types.

For this reason, we are going to generate three different files: `train.genia.json`, `valid.genia.json` and `test.genia.json`.

In [1]:
from transformers import BertTokenizer
from tokenizers import BertWordPieceTokenizer

import xml.etree.ElementTree as ET
import numpy as np
import re
import os
import json
import copy

# Seed NumPy's global random generator for reproducibility.
np.random.seed(seed=42)

## General functions

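Once we have preprocessed the data, the file should be a list of items where each item looks as follows (each item additionally stores the original, untokenized sentence under `"text"`). `span` holds token indices; the second entity below (`[2, 3]` in a three-token sentence) shows that the end index is exclusive:

```
{
  "tokens": ["token0", "token1", "token2"],
  "entities": [
    {
      "entity_type": "PER",
      "span": [0, 1]
    },
    {
      "entity_type": "ORG",
      "span": [2, 3]
    }
  ]
}
```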

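Since we will use GloVe embeddings later, one option is the spaCy tokenizer, which produces the same tokens as GloVe (see https://spacy.io/models/en). The code below, however, tokenizes with the BioBERT WordPiece tokenizer (`dmis-lab/biobert-v1.1`) and merges word pieces (`##...`) back into whole words, so <u>its tokens can differ from the spaCy/GloVe tokens, and BERT itself will still use its own word-piece tokenization!</u>

If you use the spaCy tokenizer instead, you may need to install the English model first:
`python -m spacy download en_core_web_lg`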

In [2]:
def get_tokenizer(artifacts_path='artifacts/', bert_name='dmis-lab/biobert-v1.1'):
    """
    Return a fast WordPiece tokenizer for `bert_name`, caching its vocabulary locally.

    The vocabulary is downloaded (through the slow `BertTokenizer`) only when
    `<artifacts_path>/<bert_name>/vocab.txt` does not exist yet. The files are
    saved there and then loaded into the Rust-backed `BertWordPieceTokenizer`.

    Arguments:
    - artifacts_path (str): directory where tokenizer files are cached.
    - bert_name (str): name of the pretrained model whose vocabulary is used.

    Returns:
    - BertWordPieceTokenizer built from the cached vocab.txt.
    """
    # os.path.join also works when artifacts_path has no trailing slash.
    save_path = os.path.join(artifacts_path, bert_name)
    vocab_file = os.path.join(save_path, 'vocab.txt')

    # Check for the vocabulary file rather than the directory, so a directory
    # left behind without vocab.txt does not make the load below fail; this
    # also avoids re-fetching the slow tokenizer when the cache is complete.
    if not os.path.isfile(vocab_file):
        slow_tokenizer = BertTokenizer.from_pretrained(bert_name)
        os.makedirs(save_path, exist_ok=True)
        slow_tokenizer.save_pretrained(save_path)

    # Same vocabulary as the slow tokenizer, but the Rust implementation is much faster.
    # NOTE(review): lowercase=True is hard-coded; confirm it matches the casing
    # of the chosen checkpoint.
    return BertWordPieceTokenizer(vocab_file, lowercase=True)

In [3]:
def tokenize_text(tokenizer, text, lower=False):
    """
    Tokenize `text` into whole words together with their character spans.

    Word pieces produced by the tokenizer are merged back into words: every
    piece matching "##<something>" is appended to the token before it. A
    "##" piece with no preceding token is kept unchanged.

    Arguments:
    - tokenizer: object whose `encode(text)` returns an encoding with parallel
      `tokens` and `offsets` lists; the first and last entries are assumed to
      be special tokens (e.g. [CLS] / [SEP]) and are discarded.
    - text (str): text to tokenize.
    - lower (bool): lowercase `text` before encoding.

    Returns:
    - tokens (list of str): the merged tokens.
    - spans (list of [start, end]): inclusive character offsets of each token.
    """
    if lower:
        text = text.lower()

    encoded = tokenizer.encode(text)

    tokens = []
    spans = []
    for piece, (start, end) in zip(encoded.tokens[1:-1], encoded.offsets[1:-1]):
        if tokens and re.search(r"^##.+", piece):
            tokens[-1] += piece[2:]
            # Take the end from the encoder offsets; offsets are exclusive, spans inclusive.
            spans[-1][1] = end - 1
        else:
            tokens.append(piece)
            spans.append([start, end - 1])

    return tokens, spans

In [4]:
# Build the BioBERT WordPiece tokenizer (explicit values equal the defaults).
tokenizer = get_tokenizer(artifacts_path='artifacts/', bert_name='dmis-lab/biobert-v1.1')

In [5]:
# Debug: print every lowercased token with its inclusive character span,
# with a blank line after every group of 10 tokens.
text = 'In primary T lymphocytes we show that CD28 ligation leads to the rapid intracellular formation of reactive oxygen intermediates (ROIs) which are required for CD28-mediated activation of the NF-kappa B/CD28-responsive complex and IL-2 expression.'

for i, (token, span) in enumerate(zip(*tokenize_text(tokenizer, text, True)), start=1):
    print(token, span)
    if not i % 10:
        print()

in [0, 1]
primary [3, 9]
t [11, 11]
lymphocytes [13, 23]
we [25, 26]
show [28, 31]
that [33, 36]
cd28 [38, 41]
ligation [43, 50]
leads [52, 56]

to [58, 59]
the [61, 63]
rapid [65, 69]
intracellular [71, 83]
formation [85, 93]
of [95, 96]
reactive [98, 105]
oxygen [107, 112]
intermediates [114, 126]
( [128, 128]

rois [129, 132]
) [133, 133]
which [135, 139]
are [141, 143]
required [145, 152]
for [154, 156]
cd28 [158, 161]
- [162, 162]
mediated [163, 170]
activation [172, 181]

of [183, 184]
the [186, 188]
nf [190, 191]
- [192, 192]
kappa [193, 197]
b [199, 199]
/ [200, 200]
cd28 [201, 204]
- [205, 205]
responsive [206, 215]

complex [217, 223]
and [225, 227]
il [229, 230]
- [231, 231]
2 [232, 232]
expression [234, 243]
. [244, 244]


## Preprocess

In [6]:
def get_inner_xml(element):
    """
    Serialize the content of `element`: its text and children, without its own tag.

    The leading text is XML-escaped (&, <, >) the same way ElementTree escapes
    the children's text and tails, so the whole result is consistently escaped.
    Without this, a '&' or '<' before the first child would come out raw while
    the same character later on would appear as '&amp;' / '&lt;'.
    """
    from xml.sax.saxutils import escape

    children = ''.join(ET.tostring(child, 'unicode') for child in element)
    return escape(element.text or '') + children

def parse_cons_attributes(match):
    """
    Parse the `name="value"` pairs of a <cons> opening tag into a dict.

    `match` is the attribute part of the tag, e.g. 'lex="CD28" sem="G#protein_molecule"'.
    Only lowercase names with non-empty double-quoted values are recognised;
    if a name repeats, its last value wins.
    """
    pairs = (
        re.search(r'^([^=]+)="([^"]+)"', chunk.strip())
        for chunk in re.findall(r'[a-z]+="[^"]+"\s*', match)
    )
    return {pair[1]: pair[2] for pair in pairs}

def reverse_enumerate(L):
    """Yield `(index, item)` pairs of sequence `L`, from the last item to the first."""
    for idx in range(len(L) - 1, -1, -1):
        yield idx, L[idx]

def get_raw_entities(xml_sentences):
    """
    Extract the plain text and the <cons> annotations of each sentence.

    Arguments:
    - xml_sentences (iterable of Element): <sentence> elements.

    Returns:
    - list of dict, one per sentence, with keys:
      - 'tokens' (str): the sentence text with all markup removed and XML
        character entities (&amp;, &lt;, ...) decoded (tokenized later on).
      - 'entities' (list of dict): one per <cons> element, in opening-tag
        order, with 'span' = [start, end) character offsets into 'tokens'
        (end exclusive) and 'attributes' = the element's XML attributes.
    """
    data = []
    for sentence in xml_sentences:
        text, entities = _flatten_annotated(sentence)
        data.append({'tokens': text, 'entities': entities})
    return data


def _flatten_annotated(element):
    """Walk `element` in document order and return (plain_text, cons_entities)."""
    pieces = []
    entities = []
    length = 0

    def emit(chunk):
        # Append a text chunk (text or tail may be None) and advance the offset.
        nonlocal length
        if chunk:
            pieces.append(chunk)
            length += len(chunk)

    def visit(node):
        emit(node.text)
        for child in node:
            if child.tag == 'cons':
                # Register before descending so outer entities precede nested ones.
                entity = {'span': [length, -1], 'attributes': dict(child.attrib)}
                entities.append(entity)
                visit(child)
                entity['span'][1] = length
            else:
                # Any other markup contributes only its text content.
                visit(child)
            emit(child.tail)

    visit(element)
    return ''.join(pieces), entities

def split_sem_lex_info(lexsem_types):
    """
    Split a (possibly coordinated) 'lex' or 'sem' attribute into its components.

    A leading "(AND ", "(OR ", "(AND/OR ", "(AS_WELL_AS " or "(BUT_NOT " and a
    trailing ")" are removed; the remainder is split on whitespace.
    """
    without_operator = re.sub(r'^\((AND|OR|AND/OR|AS\_WELL\_AS|BUT\_NOT) ', '', lexsem_types)
    without_paren = re.sub(r'\)$', '', without_operator)
    return without_paren.strip().split()

def expand_nested_attributes(entities):
    """
    Expands 'sem' attribute from the parent entity to the children.

    A "parent" is an entity whose 'sem' attribute starts with "(AND" (the
    prefix test also matches "(AND/OR"). Its 'lex' and 'sem' attributes are
    split into component lists with split_sem_lex_info. The 'sem' components
    are then assigned, in list order, to the entities that lie inside the
    parent's character span, have no 'sem' of their own, and whose 'lex' does
    not start with '*'. Once the components run out, the remaining children
    get 'G#other_name'.

    Arguments:
    - entities (list of dict): entities with 'span' and 'attributes' keys
      (as built by get_raw_entities).

    Returns:
    - the same list; the entity dicts are modified in place.
    """
    for e_parent in entities:
        # Only coordinated entities ("(AND ..." / "(AND/OR ...") act as parents.
        if 'sem' not in e_parent['attributes'] or not re.search(r'^\(AND', e_parent['attributes']['sem']):
            continue
        
        lex_types = split_sem_lex_info(e_parent['attributes']['lex'])
        sem_types = split_sem_lex_info(e_parent['attributes']['sem'])
        
        # For every comma inside lex component i, insert another copy of
        # sem_types[i] at position i.
        # NOTE(review): presumably a comma marks several children sharing one
        # type — confirm. i indexes lex_types while sem_types grows with each
        # insertion, so later positions may be shifted; sem_types[i] raises
        # IndexError when i >= len(sem_types).
        for i, lex_type in enumerate(lex_types):
            for _ in re.findall(r',', lex_type):
                sem_types = sem_types[:i] + [sem_types[i]] + sem_types[i:]
        
        # Hand out the component types to the untyped children in list order.
        i_sem = 0
        for e_child in entities:
            if 'sem' in e_child['attributes']:
                # Already typed: the parent itself, and children filled by an earlier parent.
                continue
            elif e_parent['span'][0] > e_child['span'][0] or e_parent['span'][1] < e_child['span'][1]:
                # Not contained in the parent's character span.
                continue
            elif re.search(r'^\*', e_child['attributes']['lex']):
                # Children whose 'lex' starts with '*' receive no type
                # (this line assumes the child has a 'lex' attribute).
                continue
            
            if len(sem_types) > i_sem:
                e_child['attributes']['sem'] = sem_types[i_sem]
            else:
                # More children than listed components.
                e_child['attributes']['sem'] = 'G#other_name'

            i_sem += 1
        
    return entities

# 'sem' prefixes that are kept, mapped to the collapsed entity type.
_SEM_PREFIX_TO_ENTITY_TYPE = (
    ('G#RNA', 'RNA'),
    ('G#DNA', 'DNA'),
    ('G#protein', 'protein'),
    ('G#cell_line', 'cell_line'),
    ('G#cell_type', 'cell_type'),
)


def get_entity_type(sem):
    """
    Collapse a GENIA 'sem' value into one of the five target entity types.

    DNA, RNA and protein subtypes (e.g. 'G#protein_molecule') are collapsed
    into 'DNA', 'RNA' and 'protein'; 'G#cell_line' and 'G#cell_type' are kept.

    Returns:
    - str: the collapsed type, or None for every other value (including
      'G#other...'), meaning the entity should be discarded.
    """
    for prefix, entity_type in _SEM_PREFIX_TO_ENTITY_TYPE:
        if sem.startswith(prefix):
            return entity_type
    return None

def get_norm_entities(entities):
    """
    Keep only the five target entity types and replace 'attributes' with 'entity_type'.

    - Subtypes of DNA, RNA and protein are collapsed into DNA, RNA and protein;
      cell_line and cell_type are kept (see get_entity_type).
    - Entities without a 'sem' attribute, or whose 'sem' maps to no target
      type, are removed.

    For a coordinated entity (a 'sem' such as "(AND G#x G#y)"), the type of the
    first listed component is used.

    The list is filtered in place and returned. Each kept entity dict loses its
    'attributes' key and gains 'entity_type'.
    """
    survivors = []
    for entity in entities:
        attributes = entity['attributes']
        if 'sem' not in attributes:
            continue

        sem = attributes['sem']
        if re.search(r'^\((AND|OR|AND/OR|AS\_WELL\_AS|BUT\_NOT) ', sem):
            # Coordinated entity: drop the "(OP" prefix and ")" suffix, keep the first component.
            components = re.sub(r'^\([^\s]+', '', sem)
            components = re.sub(r'\)$', '', components)
            sem = components.strip().split()[0]

        entity_type = get_entity_type(sem)
        if entity_type is None:
            continue

        del entity['attributes']
        entity['entity_type'] = entity_type
        survivors.append(entity)

    entities[:] = survivors
    return entities

def _split_at_boundaries(tokens, spans, boundaries):
    """
    Split tokens so that no boundary offset falls strictly inside a token.

    Spans are inclusive [first_char, last_char]. A boundary b splits a token
    when first_char < b <= last_char, giving [first_char, b - 1] and
    [b, last_char]. `boundaries` must be sorted ascending.
    """
    new_tokens = []
    new_spans = []
    for token, (start, end) in zip(tokens, spans):
        for b in boundaries:
            if start < b <= end:
                # Slice relative to the token's own start, not the absolute offset.
                # NOTE(review): assumes len(token) == end - start + 1, which can fail
                # if tokenizer normalization changed the character count.
                cut = b - start
                new_tokens.append(token[:cut])
                new_spans.append([start, b - 1])
                token, start = token[cut:], b
        new_tokens.append(token)
        new_spans.append([start, end])
    return new_tokens, new_spans


def transform_text_spans(text, entity_list, tokenizer, debug=False):
    """
    Tokenize `text` and convert entity character spans into token spans.

    Arguments:
    - text (str): the sentence (plain text from get_raw_entities).
    - entity_list (list of dict): entities whose 'span' is [start, end) in
      character offsets of `text` (end exclusive).
    - tokenizer: passed to tokenize_text.
    - debug (bool): print each entity and its computed token span.

    Returns:
    - tokens (list of str): tokens of `text`. A token is split in two wherever
      an entity boundary falls strictly inside it.
    - entities (list of dict): the entities covering at least one token, with
      'span' rewritten in place as [start, end) token indices (end exclusive,
      so end may equal len(tokens)). Entities covering no token are dropped.
    """
    from bisect import bisect_left

    tokens, spans = tokenize_text(tokenizer, text)

    # Split for ALL entities first: splitting while mapping one entity would
    # shift the token indices of entities that were already mapped.
    boundaries = sorted({offset for entity in entity_list for offset in entity['span']})
    tokens, spans = _split_at_boundaries(tokens, spans, boundaries)
    starts = [span[0] for span in spans]

    kept = []
    for entity in entity_list:
        char_start, char_end = entity['span']
        # Because of the splits above, a token starting before char_end also
        # ends before it. So the first token starting at/after each boundary
        # gives the start index and the exclusive end index directly.
        tok_start = bisect_left(starts, char_start)
        tok_end = bisect_left(starts, char_end)

        if debug:
            print('-' * 10)
            print(entity)
            print('Chars:', char_start, char_end, '-> tokens:', tok_start, tok_end)

        if tok_start < tok_end:
            entity['span'] = [tok_start, tok_end]
            kept.append(entity)

    return tokens, kept

In [7]:
def parse_genia(filepath, tokenizer, total_layers, no_entity='O'):
    """
    Main function to parse the GENIA corpus.

    Arguments:
    - filepath (str): path to the GENIAcorpus3.02.xml file.
    - tokenizer: tokenizer used to split every sentence (see tokenize_text).
    - total_layers (int): not used by this function; kept so existing calls
      keep working.
    - no_entity (str): label placed at index 0 of the entity-type list.

    Returns:
    - genia_data (list of dict): one item per <sentence> of every article's
      title and abstract, in document order, with keys 'tokens' (list of str),
      'entities' (list of dicts with 'span' and 'entity_type') and 'text'
      (the sentence before tokenization).
    - genia_et (list of str): no_entity followed by the entity types found,
      in order of first appearance.
    - genia_et_idx (dict): maps every label in genia_et to its index.
    """
    from collections import Counter

    genia_data = []
    root = ET.parse(filepath).getroot()

    for article in root:
        if article.tag != 'article':
            continue

        article_data = []
        for section_tag in ('title', 'abstract'):
            section = article.find(section_tag)
            # Skip a missing section instead of failing with AttributeError on None.
            if section is not None:
                article_data += get_raw_entities(section.findall('sentence'))

        for c in article_data:
            c['entities'] = expand_nested_attributes(c['entities'])
            c['entities'] = get_norm_entities(c['entities'])

            # get_raw_entities stores the plain sentence under 'tokens'; keep it
            # as 'text' before 'tokens' is replaced by the real token list.
            c['text'] = c['tokens']
            c['tokens'], c['entities'] = transform_text_spans(c['tokens'], c['entities'], tokenizer)

        genia_data += article_data

    # Counter preserves first-insertion order, so genia_et lists types in order of first appearance.
    genia_et_freq = Counter(e['entity_type'] for item in genia_data for e in item['entities'])
    genia_et = [no_entity] + list(genia_et_freq)
    genia_et_idx = {e: i for i, e in enumerate(genia_et)}

    return genia_data, genia_et, genia_et_idx

In [8]:
GENIA_CORPUS = './data/GENIA_term_3.02/GENIAcorpus3.02.xml'
NO_ENTITY = 'O'  # "outside any entity" label, as in the IOB format

genia_data, genia_et, genia_et_idx = parse_genia(
    GENIA_CORPUS, tokenizer, total_layers=16, no_entity=NO_ENTITY
)

# Display entity types
genia_et

['O', 'DNA', 'protein', 'cell_type', 'cell_line', 'RNA']

## Sanity check

In [9]:
def sanity_check(dataset, debug=False):
    """
    Count the items that contain at least one entity with an invalid token span.

    A span [start, end] is accepted when 0 <= start <= end <= len(tokens). The
    upper bound allows an exclusive end index equal to the number of tokens.

    Arguments:
    - dataset (list of dict): items with 'tokens' and 'entities' keys
      ('text' is read only when printing in debug mode).
    - debug (bool): print the text of every bad item.

    Returns:
    - int: number of bad items (each item is counted at most once).
    """
    n_bad = 0
    for i, item in enumerate(dataset):
        length = len(item['tokens'])
        spans = (entity['span'] for entity in item['entities'])
        if all(0 <= start <= end <= length for start, end in spans):
            continue

        n_bad += 1
        if debug:
            print('Wrong entity span for item %d:' % (i), item['text'])
            print()
    return n_bad

In [10]:
# Count (and print) items whose entity spans are out of range.
n_bad = sanity_check(dataset=genia_data, debug=True)

In [11]:
print(f'Bad items: {n_bad} / {len(genia_data)}')

Bad items: 0 / 18546


## Split and store

In [12]:
# Split first 81%, subsequent 9%, and last 10% as train, dev and test set, respectively
TRAIN_RATIO = 0.81
DEV_RATIO = 0.09

n_items = len(genia_data)
train_end = int(TRAIN_RATIO * n_items)
dev_end = train_end + int(DEV_RATIO * n_items)

# [start, stop) index ranges of each split; the test set takes the remainder.
train_idx = [0, train_end]
dev_idx = [train_end, dev_end]
test_idx = [dev_end, n_items]

train_dataset, dev_dataset, test_dataset = (
    genia_data[lo:hi] for lo, hi in (train_idx, dev_idx, test_idx)
)

In [13]:
train_file = './data/train.genia.json'
valid_file = './data/valid.genia.json'
test_file = './data/test.genia.json'

# Write each split as a JSON list of items.
for path, split in ((train_file, train_dataset), (valid_file, dev_dataset), (test_file, test_dataset)):
    with open(path, 'w') as fp:
        json.dump(split, fp)