In [13]:
import os, sys
import json
import pickle
import random
import numpy as np
import re
from pprint import pprint

from itertools import combinations
from collections import defaultdict
from utils import *

# Unify dataset format

The dataset has a list of items in json, where each item contains (in order):
```json
{
   "text": "(optional) str"
   "category": "str:(domain,intent,etc.)"
   "tokens":"list[str]",
   "token_spans": "(optional) list[tuple(int,int)]",
   "slot_tags": "list[str]:BIO or BIEO or BIEOS"
}
```

we adopt bert vocab file.

## Slot-Filling

### Snip

Download: https://github.com/snipsco/nlu-benchmark/tree/master/2017-06-custom-intent-engines

In [2]:
ROOT = './original/SNIP/2017-06-custom-intent-engines/'
train_paths = []
valid_paths = []
for path in os.listdir(ROOT):
    path = os.path.join(ROOT, path)
    if not os.path.isdir(path):
        continue
    for sub_path in os.listdir(path):
        if re.match('^train.*full.json$', sub_path):
            sub_path = os.path.join(path, sub_path)
            train_paths.append(sub_path)
        elif re.match('^valid.*json$', sub_path):
            sub_path = os.path.join(path, sub_path)
            valid_paths.append(sub_path)

In [4]:
def build_dataset(paths):
    dataset = []
    for path in paths:
        with open(path, 'r', encoding='utf-8', errors="ignore") as f:
            json_list = list(json.load(f).values())[0]
        for item in json_list:
                
            text = ''
            slots = []
            for sub_item in item['data']:
                _text = strip_accents(sub_item['text']).lower()

                if 'entity' in sub_item:
                    slots.append((sub_item['entity'], (len(text), len(text)+len(_text))))
                    text += _text
                else:
                    text += _text
            tokens, token_spans = tokenize_with_span(text)
            tags = ['O'] * len(tokens)
            
            if_add = True
            for slot_name, slot_span in slots:
                        
                i_slot_token = -1
                j_slot_token = -1
                for i, token_span in enumerate(token_spans):
                    if token_span[0] == slot_span[0]:
                        i_slot_token = i
                    if token_span[1] == slot_span[1] or token_span[1]+1 == slot_span[1]:
                        j_slot_token = i + 1
                if i_slot_token < 0 or j_slot_token < 0:
                    if_add = False
                    print('warning: not found.')
                    print(slot_name, slot_span)
                    print(text)
                    print([(token, token_span) for token, token_span in zip(tokens, token_spans)])
                else:
                    add_tags(tags, attr=slot_name, span=[i_slot_token, j_slot_token])

            if if_add:
                dataset.append({
                   "text": text,
                   "category": 'SNIPS' + '__' + split_path(path)[-2],
                   "tokens": [re.sub('\d', '0', t) for t in tokens],
                   "token_spans": token_spans,
                   "slot_tags": tags,
                })
            
    return dataset

In [5]:
trainset = build_dataset(train_paths)
validset = build_dataset(valid_paths)

random.shuffle(trainset)
ratio = len(validset) / len(trainset)
split_n = int(len(trainset)*ratio)
trainset, testset = trainset[split_n:], trainset[:split_n]

with open('./unified/train.snip.json', 'w', encoding='utf-8') as f:
    json.dump(trainset, f, sort_keys=True, indent=4)

with open('./unified/valid.snip.json', 'w', encoding='utf-8') as f:
    json.dump(validset, f, sort_keys=True, indent=4)
    
with open('./unified/test.snip.json', 'w', encoding='utf-8') as f: # 
    json.dump(testset, f, sort_keys=True, indent=4)

artist (11, 23)
live in l.ajoseph meyer please
[('live', (0, 4)), ('in', (5, 7)), ('l', (8, 9)), ('.', (9, 10)), ('ajoseph', (10, 17)), ('meyer', (18, 23)), ('please', (24, 30))]
timeRange (40, 46)
what kind of weather is forecast around one pmnear vatican?
[('what', (0, 4)), ('kind', (5, 9)), ('of', (10, 12)), ('weather', (13, 20)), ('is', (21, 23)), ('forecast', (24, 32)), ('around', (33, 39)), ('one', (40, 43)), ('pmnear', (44, 50)), ('vatican', (51, 58)), ('?', (58, 59))]
spatial_relation (46, 50)
what kind of weather is forecast around one pmnear vatican?
[('what', (0, 4)), ('kind', (5, 9)), ('of', (10, 12)), ('weather', (13, 20)), ('is', (21, 23)), ('forecast', (24, 32)), ('around', (33, 39)), ('one', (40, 43)), ('pmnear', (44, 50)), ('vatican', (51, 58)), ('?', (58, 59))]
movie_name (9, 21)
show the sexy dance 2times at the  closest movie house
[('show', (0, 4)), ('the', (5, 8)), ('sexy', (9, 13)), ('dance', (14, 19)), ('2times', (20, 26)), ('at', (27, 29)), ('the', (30, 33)), (

In [8]:
domains = set()
for item in trainset:
    domains.add(item['category'])

for exclude_domain in domains:
    _trainset = []
    _validset = []
    _testset = []
    for item in trainset:
        if item['category'] != exclude_domain:
            _trainset.append(item)
    for item in validset:
        if item['category'] != exclude_domain:
            _validset.append(item)
    for item in testset:
        if item['category'] != exclude_domain:
            _testset.append(item)
    with open(f'./unified/train.snip.no_{exclude_domain}.json', 'w', encoding='utf-8') as f:
        json.dump(_trainset, f, sort_keys=True, indent=4)
        
    with open(f'./unified/valid.snip.no_{exclude_domain}.json', 'w', encoding='utf-8') as f:
        json.dump(_validset, f, sort_keys=True, indent=4)
    
    with open(f'./unified/test.snip.no_{exclude_domain}.json', 'w', encoding='utf-8') as f:
        json.dump(_testset, f, sort_keys=True, indent=4)
        
    ###
    count_dict = defaultdict(int)
    for item in _trainset:
        for tag in item['slot_tags']:
            if tag[0] == 'B':
                count_dict[tag[2:]] += 1
    print(f"exclude_domain: {exclude_domain}")
    print(len(count_dict))
#     pprint(count_dict)

exclude_domain: SNIPS__SearchScreeningEvent
35
exclude_domain: SNIPS__AddToPlaylist
37
exclude_domain: SNIPS__GetWeather
35
exclude_domain: SNIPS__BookRestaurant
31
exclude_domain: SNIPS__PlayMusic
34
exclude_domain: SNIPS__SearchCreativeWork
39
exclude_domain: SNIPS__RateBook
34


In [10]:
from itertools import combinations, permutations

domains = set()
for item in trainset:
    domains.add(item['category'])

for exclude_domains in combinations(domains, 2):
    _trainset = []
    _validset = []
    _testset = []
    for item in trainset:
        if item['category'] not in exclude_domains:
            _trainset.append(item)
    for item in validset:
        if item['category'] not in exclude_domains:
            _validset.append(item)
    for item in testset:
        if item['category'] not in exclude_domains:
            _testset.append(item)
    with open(f"./unified/train.snip.no_{'_'.join(exclude_domains)}.json", 'w', encoding='utf-8') as f:
        json.dump(_trainset, f, sort_keys=True, indent=4)
        
    with open(f"./unified/valid.snip.no_{'_'.join(exclude_domains)}.json", 'w', encoding='utf-8') as f:
        json.dump(_validset, f, sort_keys=True, indent=4)
    
    with open(f"./unified/test.snip.no_{'_'.join(exclude_domains)}.json", 'w', encoding='utf-8') as f:
        json.dump(_testset, f, sort_keys=True, indent=4)
              
              
    for item in trainset:
        if item['category'] in exclude_domains:
            _trainset.append(item)
    for item in validset:
        if item['category'] in exclude_domains:
            _validset.append(item)
    for item in testset:
        if item['category'] in exclude_domains:
            _testset.append(item)
    with open(f"./unified/train.snip.{'_'.join(exclude_domains)}.json", 'w', encoding='utf-8') as f:
        json.dump(_trainset, f, sort_keys=True, indent=4)
        
    with open(f"./unified/valid.snip.{'_'.join(exclude_domains)}.json", 'w', encoding='utf-8') as f:
        json.dump(_validset, f, sort_keys=True, indent=4)
    
    with open(f"./unified/test.snip.{'_'.join(exclude_domains)}.json", 'w', encoding='utf-8') as f:
        json.dump(_testset, f, sort_keys=True, indent=4)
        
    ###
    count_dict = defaultdict(int)
    for item in _trainset:
        for tag in item['slot_tags']:
            if tag[0] == 'B':
                count_dict[tag[2:]] += 1
    print(f"exclude_domain: {exclude_domains}")
    print(len(count_dict))
#     pprint(count_dict)

exclude_domain: ('SNIPS__SearchScreeningEvent', 'SNIPS__AddToPlaylist')
39
exclude_domain: ('SNIPS__SearchScreeningEvent', 'SNIPS__GetWeather')
39
exclude_domain: ('SNIPS__SearchScreeningEvent', 'SNIPS__BookRestaurant')
39
exclude_domain: ('SNIPS__SearchScreeningEvent', 'SNIPS__PlayMusic')
39
exclude_domain: ('SNIPS__SearchScreeningEvent', 'SNIPS__SearchCreativeWork')
39
exclude_domain: ('SNIPS__SearchScreeningEvent', 'SNIPS__RateBook')
39
exclude_domain: ('SNIPS__AddToPlaylist', 'SNIPS__GetWeather')
39
exclude_domain: ('SNIPS__AddToPlaylist', 'SNIPS__BookRestaurant')
39
exclude_domain: ('SNIPS__AddToPlaylist', 'SNIPS__PlayMusic')
39
exclude_domain: ('SNIPS__AddToPlaylist', 'SNIPS__SearchCreativeWork')
39
exclude_domain: ('SNIPS__AddToPlaylist', 'SNIPS__RateBook')
39
exclude_domain: ('SNIPS__GetWeather', 'SNIPS__BookRestaurant')
39
exclude_domain: ('SNIPS__GetWeather', 'SNIPS__PlayMusic')
39
exclude_domain: ('SNIPS__GetWeather', 'SNIPS__SearchCreativeWork')
39
exclude_domain: ('SNIPS__

### MIT REST

Download: https://groups.csail.mit.edu/sls/downloads/restaurant/

In [19]:
def convert_tokens_tags(tokens, tags, use_tokenizer=True, digit2zero=True):
    '''
    tags in BIO format
    '''
    
    if use_tokenizer:
        _tokenize = tokenize
    else:
        _tokenize = lambda x: [x]
        
    if digit2zero:
        tokens = [re.sub('\d', '0', t) for t in tokens]
    
    target_tokens = []
    target_tags = []
    for token, tag in zip(tokens, tags):
        if tag[0:2] not in ['O', 'B-', 'I-']:
            tag = 'B-' + tag
        for sub_token in _tokenize(token):
            target_tokens.append(sub_token)
            target_tags.append(tag)
            if tag[0] == 'B':
                tag = 'I' + tag[1:]
    return target_tokens, target_tags

def read_dataset(path, category=None, splitter='\t', col_tag=0, col_token=1, use_tokenizer=False):
    dataset = []
    last_is_O = True
    with open(path, 'r', encoding='utf-8') as f:
        tokens, tags = [], []
        for line in f:
            line = line.strip()
            if line == '':
                tokens, tags = convert_tokens_tags(tokens, tags, use_tokenizer)
                item = {
                    "text": None,
                    "category": category,
                    "tokens": tokens,
                    "slot_tags": tags,
                }
                dataset.append(item)
                tokens, tags = [], []
                continue
            tmp = line.split(splitter)
            tag, token = tmp[col_tag], tmp[col_token]
            if last_is_O and (tag != 'O'):
                tag = 'B' + tag[1:]
            last_is_O = (tag == 'O')
            tokens.append(token)
            tags.append(tag)
    return dataset

In [24]:
trainset = read_dataset('./original/MIT-Restaurant/train.txt', category="MIT-Restaurant")
testset = read_dataset('./original/MIT-Restaurant/test.txt', category="MIT-Restaurant")

random.shuffle(trainset)
ratio = len(testset) / len(trainset)
split_n = int(len(trainset)*ratio)
trainset, validset = trainset[split_n:], trainset[:split_n]

with open('./unified/train.rest.json', 'w', encoding='utf-8') as f:
    json.dump(trainset, f, sort_keys=True, indent=4)

with open('./unified/valid.rest.json', 'w', encoding='utf-8') as f:
    json.dump(validset, f, sort_keys=True, indent=4)
    
with open('./unified/test.rest.json', 'w', encoding='utf-8') as f:
    json.dump(testset, f, sort_keys=True, indent=4)

### MIT Movie

Download: https://groups.csail.mit.edu/sls/downloads/movie/

In [25]:
flag = 'movie_eng'

trainset = read_dataset('./original/MIT-Movie/engtrain.txt', category="MIT-Movie-Eng")
testset = read_dataset('./original/MIT-Movie/engtest.txt', category="MIT-Movie-Eng")

random.shuffle(trainset)
ratio = len(testset) / len(trainset)
split_n = int(len(trainset)*ratio)
trainset, validset = trainset[split_n:], trainset[:split_n]

with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(trainset, f, sort_keys=True, indent=4)

with open(f'./unified/valid.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(validset, f, sort_keys=True, indent=4)
    
with open(f'./unified/test.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(testset, f, sort_keys=True, indent=4)

In [26]:
flag = 'movie_trivia10k13'

trainset = read_dataset('./original/MIT-Movie/trivia10k13train.txt', category="MIT-Movie-Triv")
testset = read_dataset('./original/MIT-Movie/trivia10k13test.txt', category="MIT-Movie-Triv")

random.shuffle(trainset)
ratio = len(testset) / len(trainset)
split_n = int(len(trainset)*ratio)
trainset, validset = trainset[split_n:], trainset[:split_n]

with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(trainset, f, sort_keys=True, indent=4)

with open(f'./unified/valid.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(validset, f, sort_keys=True, indent=4)
    
with open(f'./unified/test.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(testset, f, sort_keys=True, indent=4)

## Entity Recognition

### AnEM

https://github.com/juand-r/entity-recognition-datasets/tree/master/data/AnEM/CONLL-format

In [62]:
flag = 'AnEM'

trainset = read_dataset('./original/ER-AnEM/AnEM.train',category='AnEM' , splitter='\t', col_tag=-1, col_token=0)
testset = read_dataset('./original/ER-AnEM/AnEM.test',category='AnEM' , splitter='\t', col_tag=-1, col_token=0)
validset = read_dataset('./original/ER-AnEM/AnEM.test',category='AnEM' , splitter='\t', col_tag=-1, col_token=0)

with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(trainset, f, sort_keys=True, indent=4)

with open(f'./unified/valid.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(validset, f, sort_keys=True, indent=4)
    
with open(f'./unified/test.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(testset, f, sort_keys=True, indent=4)

### BTC

In [63]:
flag = 'BTC'

trainset = read_dataset('./original/ER-BTC/all.txt', category='BTC' , splitter='\t', col_tag=-1, col_token=0)

with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(trainset, f, sort_keys=True, indent=4)

### FIN5

In [64]:
flag = 'FIN5'

trainset = read_dataset('./original/ER-FIN5/train.txt',category=flag , splitter=' ', col_tag=-1, col_token=0)
testset = read_dataset('./original/ER-FIN5/test.txt',category=flag , splitter=' ', col_tag=-1, col_token=0)
validset = read_dataset('./original/ER-FIN5/test.txt',category=flag , splitter=' ', col_tag=-1, col_token=0)

with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(trainset, f, sort_keys=True, indent=4)

with open(f'./unified/valid.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(validset, f, sort_keys=True, indent=4)
    
with open(f'./unified/test.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(testset, f, sort_keys=True, indent=4)

### GUM

In [65]:
flag = 'GUM'

trainset = read_dataset('./original/ER-GUM/gum-train.conll',category=flag , splitter='\t', col_tag=-1, col_token=0)
testset = read_dataset('./original/ER-GUM/gum-test.conll',category=flag , splitter='\t', col_tag=-1, col_token=0)
validset = read_dataset('./original/ER-GUM/gum-test.conll',category=flag , splitter='\t', col_tag=-1, col_token=0)

with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(trainset, f, sort_keys=True, indent=4)

with open(f'./unified/valid.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(validset, f, sort_keys=True, indent=4)
    
with open(f'./unified/test.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(testset, f, sort_keys=True, indent=4)

### Ritter

In [66]:
flag = 'Ritter'

trainset = read_dataset('./original/ER-Ritter/train.txt',category=flag , splitter=' ', col_tag=-1, col_token=0)
testset = read_dataset('./original/ER-Ritter/test.txt',category=flag , splitter=' ', col_tag=-1, col_token=0)
validset = read_dataset('./original/ER-Ritter/dev.txt',category=flag , splitter=' ', col_tag=-1, col_token=0)

with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(trainset, f, sort_keys=True, indent=4)

with open(f'./unified/valid.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(validset, f, sort_keys=True, indent=4)
    
with open(f'./unified/test.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(testset, f, sort_keys=True, indent=4)

### WikiGold

In [67]:
flag = 'WikiGold'

trainset = read_dataset('./original/ER-WikiGold/train.txt', category=flag , splitter=' ', col_tag=-1, col_token=0)

with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(trainset, f, sort_keys=True, indent=4)

### CONLL2003

In [35]:
flag = 'CONLL2003-NER'
use_tokenizer = False

trainset = read_dataset('./original/CONLL2003/train.txt',category=flag , splitter=' ', col_tag=-1, col_token=0, use_tokenizer=use_tokenizer)
testset = read_dataset('./original/CONLL2003/test.txt',category=flag , splitter=' ', col_tag=-1, col_token=0, use_tokenizer=use_tokenizer)
validset = read_dataset('./original/CONLL2003/valid.txt',category=flag , splitter=' ', col_tag=-1, col_token=0, use_tokenizer=use_tokenizer)

with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(trainset, f, sort_keys=True, indent=4)

with open(f'./unified/valid.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(validset, f, sort_keys=True, indent=4)
    
with open(f'./unified/test.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(testset, f, sort_keys=True, indent=4)

    
flag = 'CONLL2003-CHUNK'

trainset = read_dataset('./original/CONLL2003/train.txt',category=flag , splitter=' ', col_tag=-2, col_token=0, use_tokenizer=use_tokenizer)
testset = read_dataset('./original/CONLL2003/test.txt',category=flag , splitter=' ', col_tag=-2, col_token=0, use_tokenizer=use_tokenizer)
validset = read_dataset('./original/CONLL2003/valid.txt',category=flag , splitter=' ', col_tag=-2, col_token=0, use_tokenizer=use_tokenizer)

with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(trainset, f, sort_keys=True, indent=4)

with open(f'./unified/valid.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(validset, f, sort_keys=True, indent=4)
    
with open(f'./unified/test.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(testset, f, sort_keys=True, indent=4)

    
flag = 'CONLL2003-POS'

trainset = read_dataset('./original/CONLL2003/train.txt',category=flag , splitter=' ', col_tag=-3, col_token=0, use_tokenizer=use_tokenizer)
testset = read_dataset('./original/CONLL2003/test.txt',category=flag , splitter=' ', col_tag=-3, col_token=0, use_tokenizer=use_tokenizer)
validset = read_dataset('./original/CONLL2003/valid.txt',category=flag , splitter=' ', col_tag=-3, col_token=0, use_tokenizer=use_tokenizer)

with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(trainset, f, sort_keys=True, indent=4)

with open(f'./unified/valid.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(validset, f, sort_keys=True, indent=4)
    
with open(f'./unified/test.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(testset, f, sort_keys=True, indent=4)

### ONTO

In [10]:
flag = 'ONTO'
use_tokenizer = False

trainset = read_dataset('./original/ontonotes/onto.train.ner',category=flag , splitter='\t', col_tag=-1, col_token=0, use_tokenizer=use_tokenizer)
testset = read_dataset('./original/ontonotes/onto.test.ner',category=flag , splitter='\t', col_tag=-1, col_token=0, use_tokenizer=use_tokenizer)
validset = read_dataset('./original/ontonotes/onto.development.ner',category=flag , splitter='\t', col_tag=-1, col_token=0, use_tokenizer=use_tokenizer)

with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(trainset, f, sort_keys=True, indent=4)

with open(f'./unified/valid.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(validset, f, sort_keys=True, indent=4)
    
with open(f'./unified/test.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(testset, f, sort_keys=True, indent=4)

In [18]:
print(len(trainset))
print(len(validset))
print(len(testset))

149153
20032
16069


## ACE05

In [4]:
# def read_ace_data(path):
    
#     valid_ett = {
#         'FAC', 'GPE', 'LOC', 'ORG', 'LOC', 'PER', 'VEH', 'WEA'
#     }
    
#     with open(path) as f:
#         dataset = json.load(f)

#     dataset = [
#         {
#             'text': item['sentence'],
#             'tokens': item['words'],
#             'entities': [
#                 {
#                     'entity_type': entity_item['entity-type'].split(':')[0],
#                     'span': (entity_item['start'], entity_item['end'],),
#                 } for entity_item in item['golden-entity-mentions'] if entity_item['entity-type'].split(':')[0] in valid_ett
#             ]
#         } for item in dataset
#     ]
#     return dataset

In [5]:
# test = read_ace_data('./original/ace_2005_td_v7/output/test.json')
# valid = read_ace_data('./original/ace_2005_td_v7/output/dev.json')
# train = read_ace_data('./original/ace_2005_td_v7/output/train.json')

# flag = 'ACE05_OLD'
# with open(f'./unified/test.{flag}.json', 'w', encoding='utf-8') as f:
#     json.dump(test, f, sort_keys=True, indent=4)
# with open(f'./unified/valid.{flag}.json', 'w', encoding='utf-8') as f:
#     json.dump(valid, f, sort_keys=True, indent=4)
# with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
#     json.dump(train, f, sort_keys=True, indent=4)

In [15]:
def read_nest_data(path, n_line_one_item=3, 
                   tokens_line=0, etts_line=1, pos_line=None,
                   end_plus=0):
    items = []
    with open(path) as f:
        last = {}
        for i, line in enumerate(f):
            if i % n_line_one_item == tokens_line:
                last['tokens'] = line.strip().split(' ')
#                 if len(line.strip().split(' ')) != len(line.strip().split()):
#                     print(line.strip().split(' '))
#                     print(line.strip().split())
            elif pos_line is not None and i % n_line_one_item == pos_line:
                last['pos'] = line.strip().split(' ')
                assert len(last['pos']) == len(last['tokens'])
            elif i % n_line_one_item == etts_line:
                line = line.strip()
                if line:
                    entities = {
                        (
                            txt.split()[1], tuple(int(p) for p in txt.split()[0].split(','))
                        ) for txt in line.split('|')
                    }
                    
                    last['entities'] = [
                        {
                            'entity_type': entity[0], 
                            'span': (entity[1][0], entity[1][1] + int(end_plus)),
                        } for entity in entities
                    ]
                    
                else:
                    last['entities'] = []
            elif i % n_line_one_item == n_line_one_item - 1:
                items.append(last)
                last = {}
    return items

In [16]:
def print_num_of_ett_by_len(dataset):
    count = defaultdict(int)
    for item in dataset:
        entities = item['entities']
        for i, entity in enumerate(entities):
            start, end = entity['span']
            count[end-start] += 1
    n_total = sum(count.values())
    for k in sorted(count.keys()):
        print(f"{k}: {count[k]/n_total*100:.2f}%, {count[k]}")
        
def print_num_of_ett_by_type(dataset):
    count = defaultdict(int)
    for item in dataset:
        entities = item['entities']
        entity_by_span = defaultdict(set)
        for entity in entities:
            start, end, ett_type = *entity['span'], entity['entity_type']
            entity_by_span[(start, end)].add(ett_type)
        for (start, end), ett_types in entity_by_span.items():
            ett_type = '|'.join(sorted(list(ett_types)))
            count[ett_type] += 1
            
    print(len(count))
    n_total = sum(count.values())
    for k, v in sorted(count.items(), key=lambda x: -x[1]):
        print(f"{k}: {count[k]/n_total*100:.2f}%, {count[k]}")
    
    return [count[k]/n_total*100 for k,v in sorted(count.items(), key=lambda x: -x[1])]
        
        
def print_overlapping_num_of_ett_by_len(dataset):
    
    count = defaultdict(int)
    n_total = 0
    for item in dataset:
        entities = item['entities']
        for i, entity in enumerate(entities):
            n_total += 1
            start, end = entity['span']
            if is_overlapping_list(entity['span'], [
                entities[j]['span'] for j in range(len(entities)) \
                    if i!=j #and entities[j]['span'][1] - entities[j]['span'][0] >= end - start
            ]):
                count[end-start] += 1
    #n_total = sum(count.values())
    for k in sorted(count.keys()):
        print(f"{k}: {count[k]/n_total*100:.2f}%, {count[k]}")
        
    return [count[k]/n_total*100 for k,v in sorted(count.items(), key=lambda x: -x[1])]
        
def is_overlapping(span_a, span_b):
    if span_a[0] <= span_b[0] < span_a[1] or span_a[0] < span_b[1] <= span_a[1]:
        return True
    span_a, span_b = span_b, span_a
    if span_a[0] <= span_b[0] < span_a[1] or span_a[0] < span_b[1] <= span_a[1]:
        return True
    return False

def is_nesting(span_a, span_b):
    if span_a[0] <= span_b[0] < span_a[1] and span_a[0] < span_b[1] <= span_a[1]:
        return True
    span_a, span_b = span_b, span_a
    if span_a[0] <= span_b[0] < span_a[1] and span_a[0] < span_b[1] <= span_a[1]:
        return True
    return False

def is_overlapping_list(span_a, span_list):
    for span_b in span_list:
        if is_overlapping(span_a, span_b):
            return True
    return False
        
def has_overlapping_but_not_nested(l):
    for last, curr in combinations(l, 2):
        if (curr[2][0] < last[2][0] < curr[2][1] and last[2][1] > curr[2][1]) or \
            (last[2][0] < curr[2][0] < last[2][1] and curr[2][1] > last[2][1]):
            return True
    return False

def has_overlapping(l):
    for last, curr in combinations(l, 2):
        if is_overlapping(last[2], curr[2]):
            return True
    return False

In [9]:
test = read_nest_data('./original/NEST_ACE2004/ace2004.test')
valid = read_nest_data('./original/NEST_ACE2004/ace2004.dev')
train = read_nest_data('./original/NEST_ACE2004/ace2004.train')

flag = 'ACE04'
with open(f'./unified/test.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(test, f, sort_keys=True, indent=4)
with open(f'./unified/valid.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(valid, f, sort_keys=True, indent=4)
with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(train, f, sort_keys=True, indent=4)
    
print(f"# train: {len(train)}")
print(f"# valid: {len(valid)}")
print(f"# test: {len(test)}")

# train: 6198
# valid: 742
# test: 809


In [17]:
test = read_nest_data('./original/NEST_ACE2005/ace2005.test')
valid = read_nest_data('./original/NEST_ACE2005/ace2005.dev')
train = read_nest_data('./original/NEST_ACE2005/ace2005.train')

flag = 'ACE05'
with open(f'./unified/test.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(test, f, sort_keys=True, indent=4)
with open(f'./unified/valid.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(valid, f, sort_keys=True, indent=4)
with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(train, f, sort_keys=True, indent=4)
    
print(f"# train: {len(train)}")
print(f"# valid: {len(valid)}")
print(f"# test: {len(test)}")

# train: 7285
# valid: 968
# test: 1058


In [18]:
# # 最长的实体
# for dataset in [ train, valid, test]:
#     mlen = max([max([0, *[e['span'][1]-e['span'][0] for e in item['entities']]]) for item in dataset])
#     print(mlen)
    
# print('===')
# # 数据集统计
# for dataset in [train, valid, test]:
#     n = n_total = 0
#     for item in dataset:
#         n_total += 1
#         if has_overlapping([(e['entity_type'], e['entity_type'], e['span']) for e in item['entities']]):
#             n += 1
#     print(n, n_total, n / n_total)
# print('===')
# for dataset in [train, valid, test]:
#     n = n_total = 0
#     for item in dataset:
#         entities = item['entities']
#         n_total += len(entities)
#         for i, entity in enumerate(entities):
#             if is_overlapping_list(entity['span'], [
#                 entities[j]['span'] for j in range(len(entities)) \
#                     if i!=j #and entities[j]['span'][1] - entities[j]['span'][0] >= end - start
#             ]):
#                 n += 1
#     print(n, n_total, n / n_total)

49
31
27
===
2797 7285 0.383939601921757
352 968 0.36363636363636365
339 1058 0.32041587901701324
===
9946 24700 0.4026720647773279
1191 3218 0.37010565568676196
1179 3029 0.3892373720699901


In [13]:
test = read_nest_data('./original/NEST_GENIA/genia.test')
valid = read_nest_data('./original/NEST_GENIA/genia.dev')
train = read_nest_data('./original/NEST_GENIA/genia.train')

flag = 'GENIA'
with open(f'./unified/test.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(test, f, sort_keys=True, indent=4)
with open(f'./unified/valid.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(valid, f, sort_keys=True, indent=4)
with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(train, f, sort_keys=True, indent=4)
    
print(f"# train: {len(train)}")
print(f"# valid: {len(valid)}")
print(f"# test: {len(test)}")

# train: 15022
# valid: 1669
# test: 1855


In [15]:
test = read_nest_data('./original/NNE/test.txt', n_line_one_item=4, tokens_line=0, etts_line=2, end_plus=1)
valid = read_nest_data('./original/NNE/dev.txt', n_line_one_item=4, tokens_line=0, etts_line=2, end_plus=1)
train = read_nest_data('./original/NNE/train.txt', n_line_one_item=4, tokens_line=0, etts_line=2, end_plus=1)

flag = 'NNE'
with open(f'./unified/test.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(test, f, sort_keys=True, indent=4)
with open(f'./unified/valid.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(valid, f, sort_keys=True, indent=4)
with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(train, f, sort_keys=True, indent=4)
    
print(f"# train: {len(train)}")
print(f"# valid: {len(valid)}")
print(f"# test: {len(test)}")

# train: 43457
# valid: 1989
# test: 3762


In [56]:
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.predictors.predictor import Predictor
predictor = Predictor.from_path(
    "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo-constituency-parser-2018.03.14.tar.gz",
    cuda_device=0,
)

predictor._tokenizer = SpacyWordSplitter(split_on_spaces=True)

In [61]:
def get_overlapping(sentences):
    json_list = [{'sentence':s} for s in sentences]
    objs = predictor.predict_batch_json(json_list)
    rets = []
    for obj in objs:
        l = [obj['hierplane_tree']['root']]
        ret = []
        while l:
            _tmp = l.pop()
            if _tmp['nodeType'] == 'NP' and any(c['nodeType']=='SBAR' for c in _tmp['children']):
                ret.append(_tmp)
            if 'children' in _tmp:
                l += _tmp['children']
        rets.append(ret)
    return rets

def find_sub_list(main_list, sub_list):
    for i in range(len(main_list) - len(sub_list)+1):
        if main_list[i] != sub_list[0]:
            continue
        if main_list[i:i+len(sub_list)] == sub_list:
            return i
    return None

In [91]:
from tqdm import tqdm
from copy import deepcopy

dataset = [item for item in train+valid+test if \
           find_sub_list(item['tokens'], [',', 'which']) is not None or \
           find_sub_list(item['tokens'], [',', 'who']) is not None]

rets = []
bs = 32
for i in tqdm(range(0, len(dataset), bs)):
    items = dataset[i:i+bs]
    sentences = [' '.join(item['tokens']) for item in items]
    batch_rets = get_overlapping(sentences)
    for item, _ret in zip(items, batch_rets):
        if len(_ret):
            rets.append((item, _ret))

In [89]:
len(rets)

2599

In [92]:
overlapping_dataset = []
for item, tree_list in rets:
    item = deepcopy(item)
    tokens = item['tokens']
    text = ' '.join(tokens)
    for tree in tree_list:
        root_tokens = tree['word'].split()
        if ',' in root_tokens:
            comma = root_tokens.index(',')
        else:
            continue
        NN_tokens = root_tokens[:comma]
        for _splitter in ['\'s', 'of']:
            if _splitter in NN_tokens:
                clause_tokens = root_tokens[NN_tokens.index(_splitter)+1:]
#                 print(' '.join(NN_tokens))
#                 print(' '.join(clause_tokens))
#                 print('=-=')
                break
        else:
            clause_tokens = root_tokens
#             print(text)
#             print('=-=')
            
        if clause_tokens[-1] in [',', '.']:
            clause_tokens = clause_tokens[:-1]
        
        start = find_sub_list(tokens, NN_tokens)
        end = start + len(NN_tokens)
        item['entities'].append({
            'entity_type': 'NN',
            'span': (start, end),
        })
        
        start = find_sub_list(tokens, clause_tokens)
        end = start + len(clause_tokens)
        item['entities'].append({
            'entity_type': 'CLA',
            'span': (start, end),
        })
            
    overlapping_dataset.append(item)

In [96]:
random.shuffle(overlapping_dataset)
valid, test, train = overlapping_dataset[:400], overlapping_dataset[400:1000], overlapping_dataset[1000:]

flag = 'NNE_OL'
with open(f'./unified/test.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(test, f, sort_keys=True, indent=4)
with open(f'./unified/valid.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(valid, f, sort_keys=True, indent=4)
with open(f'./unified/train.{flag}.json', 'w', encoding='utf-8') as f:
    json.dump(train, f, sort_keys=True, indent=4)
    
print(f"# train: {len(train)}")
print(f"# valid: {len(valid)}")
print(f"# test: {len(test)}")

# train: 1599
# valid: 400
# test: 600
