In [91]:
# imports
import sys
sys.path.append("../")

from scripts.load_data import mapping, read_tsv_file, extract_labeled_tokens
from collections import defaultdict
import random

# seed Python's global RNG so the random.shuffle calls in later cells produce
# the same split on every fresh Restart & Run All
random.seed(42)

## Load original datasets

In [92]:
# locations of the three original news splits
path_news_train = "../data/da_news/da_news_train.tsv"
path_news_dev = "../data/da_news/da_news_dev.tsv"
path_news_test = "../data/da_news/da_news_test.tsv"

# the label mapping is built from the training file only
label2id, id2label = mapping(path_news_train)

# load the DaN+ sentences for train, dev and test (read in that order)
train_data_news, dev_data_news, test_data_news = (
    read_tsv_file(split_path, label2id)
    for split_path in (path_news_train, path_news_dev, path_news_test)
)

In [93]:
# number of sentences in each original split, then the grand total
orig_sizes = [len(split) for split in (train_data_news, dev_data_news, test_data_news)]

print("train size:", orig_sizes[0])
print("dev size:", orig_sizes[1])
print("test size:", orig_sizes[2])
print("total dataset size:", orig_sizes[0] + orig_sizes[1] + orig_sizes[2])

train size: 4383
dev size: 564
test size: 565
total dataset size: 5512


In [94]:
# concatenate the original train, dev and test splits (in that order) into one pool;
# the sentence IDs used in later cells are positions in this pooled sequence
total_data = train_data_news + dev_data_news + test_data_news

In [95]:
# extract the unique entities from the pooled data
# (presumably the set of tokens carrying a non-"O" label; confirm in scripts/load_data)
total_entities = extract_labeled_tokens(total_data)

## Build mapping from entity to sentence and sentence to entity

In [96]:
# Two lookup tables filled in a single pass over the pooled sentences:
#   entity_to_sents  : entity token -> IDs of the sentences where it is tagged
#   sent_to_entities : sentence ID  -> entity tokens tagged in that sentence
entity_to_sents = defaultdict(set)
sent_to_entities = defaultdict(set)

for sent_id, sentence in enumerate(total_data):
    for position, token in enumerate(sentence["tokens"]):
        # ignore tokens that are not among the extracted entities
        if token not in total_entities:
            continue
        # an entity string only counts where it is actually labelled in this sentence
        if sentence['ner_tags'][position] == 'O':
            continue
        entity_to_sents[token].add(sent_id)
        sent_to_entities[sent_id].add(token)

## Group sentences by overlapping entities

In [97]:
# Collect groups of sentences that are linked, directly or through a chain of
# other sentences, by at least one shared entity.

visited = set()
sentence_groups = []

for start_id in sent_to_entities:
    if start_id not in visited:
        component = set()
        pending = [start_id]  # LIFO work list of sentence IDs still to expand

        while pending:
            node = pending.pop()
            if node not in visited:
                visited.add(node)
                component.add(node)
                # schedule every sentence that shares an entity with this one
                for shared_entity in sent_to_entities[node]:
                    pending.extend(entity_to_sents[shared_entity])

        sentence_groups.append(component)

In [98]:
# Randomise the group order, then hand out whole groups: train until 70% of the
# grouped sentences are placed, dev until 85%, test afterwards.

random.shuffle(sentence_groups)

total = sum(len(g) for g in sentence_groups)
train_cutoff = int(total * 0.7)
dev_cutoff = int(total * 0.85)

train_group, dev_group, test_group = [], [], []
count = 0

for group in sentence_groups:
    # the destination is chosen from the running count *before* this group is
    # added, so a group is never divided between two splits
    if count >= dev_cutoff:
        target = test_group
    elif count >= train_cutoff:
        target = dev_group
    else:
        target = train_group

    target.extend(group)
    count += len(group)

## Add sentences with only 'O' tags

In [99]:
# Sentences not yet assigned to a split and tagged "O" throughout are shuffled
# and distributed 70/15/15 over train/dev/test.
used = set().union(train_group, dev_group, test_group)

o_tagged = [
    idx
    for idx, sent in enumerate(total_data)
    if idx not in used and all(tag == "O" for tag in sent["ner_tags"])
]

random.shuffle(o_tagged)

n_o_tagged = len(o_tagged)
cut1 = int(n_o_tagged * 0.7)
cut2 = int(n_o_tagged * 0.85)

train_group.extend(o_tagged[:cut1])
dev_group.extend(o_tagged[cut1:cut2])
test_group.extend(o_tagged[cut2:])

In [100]:
# materialise each split: look up the sentences by ID, in ascending ID order
train_data, dev_data, test_data = (
    [total_data[i] for i in sorted(split_ids)]
    for split_ids in (train_group, dev_group, test_group)
)

## Check sizes and overlap

In [101]:
# sizes of new datasets
print("train size:", len(train_data))
print("dev size:", len(dev_data))
print("test size:", len(test_data))
print("total dataset size:", len(train_data) + len(dev_data) + len(test_data))

# The three splits together must account for every pooled sentence; a mismatch
# means sentences were dropped or duplicated while building the splits.
assert len(train_data) + len(dev_data) + len(test_data) == len(total_data), (
    "new splits do not add up to the pooled dataset size"
)

train size: 4270
dev size: 489
test size: 753
total dataset size: 5512


In [102]:
# labeled (non-"O") tokens found in each of the new splits
train_tokens, dev_tokens, test_tokens = map(
    extract_labeled_tokens, (train_data, dev_data, test_data)
)

# pairwise intersections: labeled tokens that occur in both splits of a pair
train_dev_overlap = train_tokens & dev_tokens
dev_test_overlap = dev_tokens & test_tokens
train_test_overlap = train_tokens & test_tokens

In [103]:
# check for overlap between datasets
print('overlap between train and dev:', len(train_dev_overlap))
print('overlap between dev and test:', len(dev_test_overlap))
print('overlap between train and test:', len(train_test_overlap))

# Fail loudly instead of only printing: the files are written in a later cell,
# so stopping here keeps splits that share labeled tokens from reaching disk.
assert len(train_dev_overlap) == 0, "labeled tokens shared between train and dev"
assert len(dev_test_overlap) == 0, "labeled tokens shared between dev and test"
assert len(train_test_overlap) == 0, "labeled tokens shared between train and test"

overlap between train and dev: 0
overlap between dev and test: 0
overlap between train and test: 0


## Write to tsv files

In [104]:
def write_tsv_file(data, path):
    """Write sentences to ``path`` as a two-column, tab-separated file.

    Each token goes on its own line as ``token<TAB>tag`` and every sentence is
    followed by one empty line.

    Args:
        data: iterable of sentences; each sentence is a mapping with a
            'tokens' sequence and a parallel 'ner_tags' sequence.
        path: destination file path. Missing parent directories are created
            and an existing file is overwritten.
    """
    import os  # local import so the function does not rely on another cell

    # open() does not create directories, so make sure the target folder exists
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(path, 'w', encoding='utf-8') as f:
        for sentence in data:
            tokens = sentence['tokens']
            ner_tags = sentence['ner_tags']
            # zip pairs tokens and tags by position and stops at the shorter one
            for token, tag in zip(tokens, ner_tags):
                f.write(f"{token}\t{tag}\n")
            f.write("\n")  # blank line marks the sentence boundary

In [105]:
# write each new split to its own file in the no-overlap data folder
for split_data, out_path in (
    (train_data, '../data/no_overlap_da_news/da_news_train.tsv'),
    (dev_data, '../data/no_overlap_da_news/da_news_dev.tsv'),
    (test_data, '../data/no_overlap_da_news/da_news_test.tsv'),
):
    write_tsv_file(split_data, out_path)