In [16]:
import os
import random
import sys
from collections import defaultdict

sys.path.append("../")

from scripts.load_data import mapping, read_tsv_file, extract_labeled_tokens

### Prep data

In [17]:
# Original DaN+ news splits (train / dev / test)
path_news_train = "../data/da_news/da_news_train.tsv"
path_news_dev = "../data/da_news/da_news_dev.tsv"
path_news_test = "../data/da_news/da_news_test.tsv"

# Label vocabulary is derived from the training file only.
# NOTE(review): presumably dev/test contain no labels absent from train — confirm in read_tsv_file.
label2id, id2label = mapping(path_news_train)

# Parse every split with the same label mapping (train, then dev, then test)
train_data_news, dev_data_news, test_data_news = (
    read_tsv_file(split_path, label2id)
    for split_path in (path_news_train, path_news_dev, path_news_test)
)

In [18]:
# Sentence counts of the original DaN+ news splits
for split_name, split_sents in (
    ("train", train_data_news),
    ("dev", dev_data_news),
    ("test", test_data_news),
):
    print(f"{split_name} size:", len(split_sents))
print("total dataset size:", sum(len(s) for s in (train_data_news, dev_data_news, test_data_news)))

train size: 4383
dev size: 564
test size: 565
total dataset size: 5512


In [19]:
# Pool all three original splits so they can be re-split without entity overlap
total_data = [*train_data_news, *dev_data_news, *test_data_news]

### Extract labeled tokens

In [20]:
# Tokens carrying a non-"O" NER label somewhere in the pooled corpus (used with `in` below).
# NOTE(review): exact return type and tag criterion live in scripts.load_data — confirm there.
total_entities = extract_labeled_tokens(total_data)

### Mapping from entities to sentences 

In [21]:
# Map each entity token to the indices of the sentences where it appears *with a non-"O" tag*.
# A sentence index may repeat if the token is tagged several times in that sentence.
entity_to_sents = {}

for sent_idx, sentence in enumerate(total_data):
    for tok_idx, ent in enumerate(sentence["tokens"]):
        # skip tokens that are never entities, and untagged occurrences of entity tokens
        if ent not in total_entities or sentence['ner_tags'][tok_idx] == 'O':
            continue
        entity_to_sents.setdefault(ent, []).append(sent_idx)

### Group entities that share sentences, so that no entity is split across groups

In [None]:
# list of tuples (entity, set of sentence IDs)
entity_sets = [(entity, set(sents)) for entity, sents in entity_to_sents.items()]

# Merge entities that (transitively) share at least one sentence.
# Union-find over entity indices replaces the previous "first-fit merge, then repeat
# until stable" scan, which compared every group against every other group per pass.
parent = list(range(len(entity_sets)))


def find_root(idx):
    '''Return the representative index of idx's component, halving the path as it walks.'''
    while parent[idx] != idx:
        parent[idx] = parent[parent[idx]]
        idx = parent[idx]
    return idx


# First entity index seen in each sentence; every later entity in that sentence is unioned with it
sent_owner = {}
for idx, (_, sents) in enumerate(entity_sets):
    for sid in sents:
        owner = sent_owner.setdefault(sid, idx)
        root_a, root_b = find_root(idx), find_root(owner)
        if root_a != root_b:
            parent[root_a] = root_b

# Collect each component as {'entities': set of entities, 'sents': union of their sentence IDs}
components = {}
for idx, (entity, sents) in enumerate(entity_sets):
    group = components.setdefault(find_root(idx), {'entities': set(), 'sents': set()})
    group['entities'].add(entity)
    group['sents'].update(sents)
groups = list(components.values())

# Extract just the entity sets (groups are pairwise disjoint in both entities and sentences)
result = [sorted(group['entities']) for group in groups]

In [23]:
# Inspect the (entity, set of sentence IDs) pairs built from entity_to_sents
entity_sets

[('SID', {0, 1936}),
 ('Kjeld', {0, 2646, 3876, 4379}),
 ('Christensen',
  {0,
   521,
   805,
   1473,
   1549,
   2646,
   2682,
   3224,
   3275,
   3876,
   4347,
   4379,
   4481,
   5017}),
 ('Elise', {3, 2785, 4505}),
 ('Gug', {3, 4505}),
 ('Mogens', {4, 585, 1702, 2278, 3067, 4480, 4967, 5168}),
 ('Lykketoft', {4, 4480, 5168}),
 ('Werwolf', {5}),
 ('Storbritannien', {9, 667, 4733}),
 ('Steen',
  {10,
   77,
   240,
   1157,
   1239,
   1490,
   1615,
   1791,
   1876,
   2220,
   2281,
   4029,
   4248,
   5084,
   5307}),
 ('Gade', {10, 240, 1157, 1239, 4248}),
 ('SF',
  {10, 306, 891, 1303, 1905, 1936, 2701, 3215, 3327, 4066, 4787, 4803, 5446}),
 ('Peter',
  {12,
   218,
   323,
   341,
   558,
   986,
   1014,
   1620,
   1659,
   1771,
   2315,
   2513,
   2522,
   2961,
   3301,
   3716,
   3732,
   3837,
   3873,
   4368,
   4488,
   4865,
   4992,
   5064,
   5220,
   5228,
   5271,
   5381,
   5425}),
 ('Elmegaard', {12, 1014, 4368}),
 ('Robinson', {13}),
 ('Crusoe', {1

In [24]:
# Seed the global RNG used by the group shuffle below and the O-only shuffle in a later cell
random.seed(20)

# Step 1: invert entity -> sentences into sentence -> set of entities
sent_to_entities = defaultdict(set)
for entity, sent_ids in entity_to_sents.items():
    for sent_id in sent_ids:
        sent_to_entities[sent_id].add(entity)

# Step 2: connected groups of sentence IDs are collected into this list
sentence_groups = []

# Sentence IDs already placed in some group
visited = set()

def collect_connected_sents(start_sid, sent_to_entities, entity_to_sents):
    '''
    Return the set of sentence IDs connected to start_sid through shared entities.

    Two sentences are connected when some entity occurs in both. The walk uses an
    explicit stack (depth-first order), and each entity's sentence list is scanned at
    most once, so frequent entities do not cause repeated re-expansion.

    Parameters:
        start_sid (int): Sentence ID to start from.
        sent_to_entities (Mapping[int, Iterable[str]]): sentence ID -> entities it contains.
        entity_to_sents (Mapping[str, Iterable[int]]): entity -> sentence IDs containing it.

    Returns:
        set[int]: The connected component containing start_sid (always includes start_sid).
    '''
    component = set()
    expanded_entities = set()
    stack = [start_sid]
    while stack:
        sid = stack.pop()
        if sid in component:
            continue
        component.add(sid)
        for entity in sent_to_entities[sid]:
            # Once an entity is expanded, all its sentences are already queued or visited
            if entity in expanded_entities:
                continue
            expanded_entities.add(entity)
            stack.extend(s for s in entity_to_sents[entity] if s not in component)
    return component

# Build all connected components: sentences linked (transitively) by a shared entity end up
# in the same group, and every entity-bearing sentence lands in exactly one group.
for sid in sent_to_entities:
    if sid not in visited:
        group = collect_connected_sents(sid, sent_to_entities, entity_to_sents)
        sentence_groups.append(group)
        visited.update(group)

# Step 3: Shuffle the groups and split them ~70/15/15 by sentence count.
# NOTE: the names group_75 / group_15a / group_15b are kept because later cells use them,
# but the training share is 70%, not 75%.
random.shuffle(sentence_groups)

total = sum(len(group) for group in sentence_groups)

i1 = int(total * 0.7)        # train boundary
i2 = i1 + int(total * 0.15)  # dev boundary; everything after goes to test

# Whole groups are assigned (never broken up), so proportions are only approximate:
# each group goes to the first split whose boundary the running count has not yet reached.
group_75, group_15a, group_15b = [], [], []
count = 0

for group in sentence_groups:
    if count < i1:
        group_75.extend(group)
    elif count < i2:
        group_15a.extend(group)
    else:
        group_15b.extend(group)
    count += len(group)

# Sorted for readability
print("70% group:", sorted(group_75))
print("15% group A:", sorted(group_15a))
print("15% group B:", sorted(group_15b))

75% group: [0, 3, 4, 9, 10, 12, 29, 30, 36, 38, 40, 42, 43, 45, 57, 61, 66, 68, 74, 77, 81, 92, 99, 104, 120, 122, 125, 126, 127, 137, 139, 143, 145, 152, 154, 157, 158, 159, 164, 166, 167, 170, 178, 184, 186, 192, 202, 211, 214, 218, 219, 220, 224, 226, 229, 235, 237, 239, 240, 245, 254, 259, 261, 267, 271, 281, 283, 285, 288, 290, 295, 300, 306, 315, 320, 323, 326, 341, 349, 353, 356, 357, 358, 366, 368, 369, 374, 375, 376, 378, 388, 391, 396, 397, 399, 401, 402, 403, 410, 411, 413, 420, 423, 428, 444, 448, 462, 473, 486, 488, 494, 496, 500, 505, 506, 515, 519, 521, 523, 526, 527, 528, 531, 533, 540, 543, 546, 555, 557, 558, 559, 565, 573, 576, 579, 581, 584, 585, 590, 599, 603, 604, 605, 609, 611, 613, 615, 618, 622, 625, 626, 628, 630, 632, 637, 645, 649, 653, 657, 658, 666, 667, 671, 673, 681, 684, 685, 689, 691, 692, 699, 706, 708, 709, 710, 718, 719, 736, 740, 742, 744, 745, 747, 756, 757, 762, 776, 778, 784, 787, 789, 791, 797, 804, 805, 812, 813, 817, 819, 820, 824, 827, 832, 

In [25]:
# Step 1: Create a mapping from sentence ID to its data
sid_to_data = dict(enumerate(total_data))

# Step 2: Identify O-only sentences (every tag is "O") not already placed via an entity group
used_sids = set(group_75 + group_15a + group_15b)

o_only_sents = [
    sid
    for sid, data in sid_to_data.items()
    if sid not in used_sids and all(tag == "O" for tag in data["ner_tags"])
]

# Step 3: Distribute O-only sentences with the same 70/15/15 proportions.
# Uses the global `random` state seeded earlier, so the result depends on running the cells
# in order; `random` is imported in the first cell.
random.shuffle(o_only_sents)

n = len(o_only_sents)
n75 = int(n * 0.70)
n15a = int(n * 0.15)

group_75 += o_only_sents[:n75]
group_15a += o_only_sents[n75:n75 + n15a]
group_15b += o_only_sents[n75 + n15a:]

# Sort each split for readability
group_75 = sorted(group_75)
group_15a = sorted(group_15a)
group_15b = sorted(group_15b)

# Sanity checks: every sentence must be in exactly one split. A sentence that carries an
# entity tag but was never reached by the grouping step would otherwise be dropped silently.
all_assigned = group_75 + group_15a + group_15b
assert len(all_assigned) == len(total_data), "some sentences were not assigned to a split (or were assigned twice)"
assert len(set(all_assigned)) == len(all_assigned), "a sentence was assigned to more than one split"

# Now you can retrieve the full sentence data if needed:
data_75 = [sid_to_data[sid] for sid in group_75]
data_15a = [sid_to_data[sid] for sid in group_15a]
data_15b = [sid_to_data[sid] for sid in group_15b]

In [26]:
# Turn each split's sentence IDs into the corresponding sentence dictionaries
train_data, dev_data, test_data = (
    [total_data[sent_id] for sent_id in split_ids]
    for split_ids in (group_75, group_15a, group_15b)
)

In [27]:
# Sentence counts of the re-split data (should sum to the original total)
for new_split_name, new_split in (
    ("train", train_data),
    ("dev", dev_data),
    ("test", test_data),
):
    print(f"{new_split_name} size:", len(new_split))
print("total dataset size:", sum(len(s) for s in (train_data, dev_data, test_data)))

train size: 3896
dev size: 790
test size: 826
total dataset size: 5512


In [28]:
# extract tokens with non-"O" labels from each split
train_tokens = extract_labeled_tokens(train_data)
dev_tokens = extract_labeled_tokens(dev_data)
test_tokens = extract_labeled_tokens(test_data)

# overlap between datasets
train_dev_overlap = train_tokens & dev_tokens
dev_test_overlap = dev_tokens & test_tokens
train_test_overlap = train_tokens & test_tokens

In [29]:
# Report the size of each pairwise entity overlap
for pair_label, pair_overlap in (
    ("train and dev", train_dev_overlap),
    ("dev and test", dev_test_overlap),
    ("train and test", train_test_overlap),
):
    print(f"overlap between {pair_label}:", len(pair_overlap))

overlap between train and dev: 0
overlap between dev and test: 0
overlap between train and test: 0


In [30]:
# union of all overlapping tokens
all_tokens_overlap = train_dev_overlap | dev_test_overlap | train_test_overlap

print('Number of unique overlapping tokens:', len(all_tokens_overlap))

Number of unique overlapping tokens: 0


In [31]:
def write_tsv_file(data, path):
    '''
    Writes a list of sentence dictionaries (with 'tokens' and 'ner_tags') to a TSV file.
    Each token-label pair is written on its own line, separated by a tab.
    Sentences are separated by empty lines (an empty line follows every sentence,
    including the last one).

    Parameters:
        data (List[dict]): List of sentence dictionaries.
        path (str): Path to write the TSV file to (overwritten if it exists).

    Raises:
        ValueError: If a sentence has a different number of tokens and tags. This is
            checked before the file is opened, so no partial file is written.
    '''
    sentences = list(data)

    # zip() would silently drop the unmatched tail, corrupting the written split
    for sent_idx, sentence in enumerate(sentences):
        n_tokens, n_tags = len(sentence['tokens']), len(sentence['ner_tags'])
        if n_tokens != n_tags:
            raise ValueError(
                f"sentence {sent_idx} has {n_tokens} tokens but {n_tags} ner_tags"
            )

    with open(path, 'w', encoding='utf-8') as f:
        for sentence in sentences:
            for token, tag in zip(sentence['tokens'], sentence['ner_tags']):
                f.write(f"{token}\t{tag}\n")
            f.write("\n")  # sentence separator

In [32]:
# Create the output directory first: open(..., 'w') fails if it does not exist yet
os.makedirs('../data/da_news_new', exist_ok=True)

write_tsv_file(train_data, '../data/da_news_new/new_da_news_train.tsv')
write_tsv_file(dev_data, '../data/da_news_new/new_da_news_dev.tsv')
write_tsv_file(test_data, '../data/da_news_new/new_da_news_test.tsv')