In [None]:
# imports
import sys
sys.path.append("../")

from scripts.load_data import label_mapping, extract_labeled_tokens, read_tsv_file, write_tsv_file, write_iob2_file
from collections import defaultdict
from collections import Counter
import random

random.seed(20) # set seed

## Load original datasets

In [None]:
# path to the data files
path_news_train = "../data/da_news/da_news_train.tsv"
path_news_dev = "../data/da_news/da_news_dev.tsv"
path_news_test = "../data/da_news/da_news_test.tsv"

# create mapping
label2id, id2label = label_mapping(path_news_train)

# read in the DaN+ data
train_data_news = read_tsv_file(path_news_train, label2id)
dev_data_news = read_tsv_file(path_news_dev, label2id)
test_data_news = read_tsv_file(path_news_test, label2id)

In [501]:
# dataset sizes
print("train size:", len(train_data_news))
print("dev size:", len(dev_data_news))
print("test size:", len(test_data_news))
print("total dataset size:", len(train_data_news) + len(dev_data_news) + len(test_data_news))

train size: 4383
dev size: 564
test size: 565
total dataset size: 5512


In [502]:
# concatenate datasets
total_data = train_data_news + dev_data_news + test_data_news

In [503]:
# extraxt unique entities
total_entities = extract_labeled_tokens(total_data)

## Build mapping from entity to sentence and sentence to entity

In [504]:
# dict with entities as keys and lists of sentence IDs as values
entity_to_sents = defaultdict(set)
sent_to_entities = defaultdict(set) # also creating mapping from sentence ID to entity

for sent_id, sent in enumerate(total_data):

    for tok_id, ent in enumerate(sent["tokens"]):

        if ent in total_entities and sent['ner_tags'][tok_id] != 'O':

            entity_to_sents[ent].add(sent_id)

            sent_to_entities[sent_id].add(ent)

## Group sentences by overlapping entities

In [505]:
# group sentences by shared entities

visited = set()
sentence_groups = []

for sent_id in sent_to_entities:

    if sent_id in visited:
        continue

    group, queue = set(), [sent_id]

    while queue:

        current = queue.pop()

        if current in visited:
            continue

        visited.add(current)
        group.add(current)

        for entity in sent_to_entities[current]:

            queue.extend(entity_to_sents[entity])

    sentence_groups.append(group)

In [506]:
# shuffle and split groups by total sentence count

random.shuffle(sentence_groups)

train_group, dev_group, test_group, count = [], [], [], 0
total = sum(len(g) for g in sentence_groups)
train_cutoff, dev_cutoff = int(total * 0.8), int(total * 0.9)

for group in sentence_groups:

    if count < train_cutoff:
        train_group += group
        
    elif count < dev_cutoff:
        dev_group += group

    else:
        test_group += group

    count += len(group)

## Add sentences with only 'O' tags

In [507]:
# add unused sentences with all 'O' tags
used = set(train_group + dev_group + test_group)
o_tagged = []

for idx, sent in enumerate(total_data):
    if idx not in used and all(tag == "O" for tag in sent["ner_tags"]):
        o_tagged.append(idx)

random.shuffle(o_tagged)

cut1, cut2 = int(len(o_tagged) * 0.8), int(len(o_tagged) * 0.9)

train_group += o_tagged[:cut1]
dev_group += o_tagged[cut1:cut2]
test_group += o_tagged[cut2:]

In [508]:
# final splits
train_data = [total_data[i] for i in sorted(train_group)]
dev_data = [total_data[i] for i in sorted(dev_group)]
test_data = [total_data[i] for i in sorted(test_group)]

## Check sizes and overlap

In [509]:
# sizes of new datasets
print("train size:", len(train_data))
print("dev size:", len(dev_data))
print("test size:", len(test_data))
print("total dataset size:", len(train_data) + len(dev_data) + len(test_data))

train size: 4411
dev size: 549
test size: 552
total dataset size: 5512


In [510]:
# extract tokens with non-"O" labels from each split
train_tokens = extract_labeled_tokens(train_data)
dev_tokens = extract_labeled_tokens(dev_data)
test_tokens = extract_labeled_tokens(test_data)

# overlap between datasets
train_dev_overlap = train_tokens & dev_tokens
dev_test_overlap = dev_tokens & test_tokens
train_test_overlap = train_tokens & test_tokens

In [511]:
# check for overlap between datasets
print('overlap between train and dev:', len(train_dev_overlap))
print('overlap between dev and test:', len(dev_test_overlap))
print('overlap between train and test:', len(train_test_overlap))

overlap between train and dev: 0
overlap between dev and test: 0
overlap between train and test: 0


## Look at distribution of tokens

In [512]:
train_tokens = extract_labeled_tokens(train_data, include_label_pair=True)
dev_tokens = extract_labeled_tokens(dev_data, include_label_pair=True)
test_tokens = extract_labeled_tokens(test_data, include_label_pair=True)

train_distr = Counter(tag for _, tag in train_tokens)
test_distr = Counter(tag for _, tag in test_tokens)
dev_distr = Counter(tag for _, tag in dev_tokens)

print(train_distr)
print(dev_distr)
print(test_distr)

Counter({'I-PER': 598, 'B-PER': 521, 'B-ORG': 481, 'B-LOC': 398, 'I-ORG': 329, 'I-MISC': 219, 'B-MISC': 195, 'I-LOC': 64})
Counter({'B-PER': 73, 'B-ORG': 73, 'B-LOC': 55, 'B-MISC': 48, 'I-PER': 27, 'I-ORG': 20, 'I-MISC': 16, 'I-LOC': 7})
Counter({'B-PER': 90, 'B-ORG': 74, 'B-LOC': 47, 'I-PER': 44, 'B-MISC': 33, 'I-ORG': 28, 'I-MISC': 12, 'I-LOC': 5})


In [513]:
def get_percentage_distribution(counter):
    total = sum(counter.values())
    return {tag: round((count / total) * 100, 2) for tag, count in counter.items()}

# calculate percentage distributions
train_percent = get_percentage_distribution(train_distr)
dev_percent = get_percentage_distribution(dev_distr)
test_percent = get_percentage_distribution(test_distr)

# print results
print("Train Percentage Distribution:")
print(train_percent)
print("\nDev Percentage Distribution:")
print(dev_percent)
print("\nTest Percentage Distribution:")
print(test_percent)

Train Percentage Distribution:
{'I-ORG': 11.73, 'I-PER': 21.32, 'I-MISC': 7.81, 'B-MISC': 6.95, 'I-LOC': 2.28, 'B-PER': 18.57, 'B-ORG': 17.15, 'B-LOC': 14.19}

Dev Percentage Distribution:
{'I-ORG': 6.27, 'I-PER': 8.46, 'B-MISC': 15.05, 'B-PER': 22.88, 'B-ORG': 22.88, 'B-LOC': 17.24, 'I-MISC': 5.02, 'I-LOC': 2.19}

Test Percentage Distribution:
{'B-PER': 27.03, 'B-MISC': 9.91, 'B-ORG': 22.22, 'I-ORG': 8.41, 'I-PER': 13.21, 'B-LOC': 14.11, 'I-MISC': 3.6, 'I-LOC': 1.5}


## Check overlap and label distribution in ME data

In [514]:
## checking for ME data ##

# path to the data files
path_me_dev = "../data/me_data/middle_eastern_dev.tsv"
path_me_test = "../data/me_data/middle_eastern_test.tsv"

# read in the data
me_dev_data = read_tsv_file(path_me_dev, label2id)
me_test_data = read_tsv_file(path_me_test, label2id)

# extract labels
me_dev_tokens = extract_labeled_tokens(me_dev_data)
me_test_tokens = extract_labeled_tokens(me_test_data)

# overlap between datasets
me_train_dev_overlap = train_tokens & me_dev_tokens
me_train_test_overlap = train_tokens & me_test_tokens
me_dev_test_overlap = me_dev_tokens & me_test_tokens

print('overlap between train and ME_dev:', len(me_train_dev_overlap))
print('overlap between train and ME_test:', len(me_train_test_overlap))
print('overlap between ME_dev and ME_test:', len(me_dev_test_overlap))

overlap between train and ME_dev: 0
overlap between train and ME_test: 0
overlap between ME_dev and ME_test: 41


In [515]:
me_dev_tokens = extract_labeled_tokens(me_dev_data, include_label_pair=True)
me_test_tokens = extract_labeled_tokens(me_test_data, include_label_pair=True)

me_test_distr = Counter(tag for _, tag in me_test_tokens)
me_dev_distr = Counter(tag for _, tag in me_dev_tokens)

print(me_dev_distr)
print(me_test_distr)

Counter({'B-PER': 94, 'B-ORG': 71, 'B-LOC': 66, 'B-MISC': 48, 'I-PER': 30, 'I-ORG': 18, 'I-MISC': 16, 'I-LOC': 8})
Counter({'B-PER': 117, 'B-ORG': 66, 'B-LOC': 54, 'I-PER': 53, 'B-MISC': 33, 'I-ORG': 26, 'I-MISC': 12, 'I-LOC': 5})


In [516]:
test_tokens = extract_labeled_tokens(test_data, include_label_pair=True)
me_test_tokens = extract_labeled_tokens(me_test_data, include_label_pair=True)

# Get tokens where the label is not 'MISC'
non_misc_overlap = {(token, label) for token, label in (test_tokens & me_test_tokens) if label not in ['B-MISC', 'I-MISC']}

print(non_misc_overlap)

{('Bagdad', 'B-LOC')}


## Write to tsv files

In [517]:
write_tsv_file(train_data, '../data/no_overlap_da_news/da_news_train.tsv')
write_tsv_file(dev_data, '../data/no_overlap_da_news/da_news_dev.tsv')
write_tsv_file(test_data, '../data/no_overlap_da_news/da_news_test.tsv')

In [518]:
write_iob2_file(test_data, path="../data/no_overlap_da_news/da_news_test.iob2", gold=True)