### Setup

In [78]:
import json
import random
import nltk
nltk.download('punkt')

from tqdm.auto import tqdm
from typing import List

from nltk.tokenize import sent_tokenize, word_tokenize

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


### Config

In [79]:
class Config:
    val_size = 0.15
    seed = 42
    
    do_split = False
    save_nerel = True

random.seed(Config.seed)

### Read data

In [80]:
train_sentences = []

with open("/kaggle/input/ner-datasets/datasets/nerel/nerel_train_raw.jsonl") as f:
    for line in f:
        train_sentences.append(json.loads(line))
        #train_sentences[-1]['sentences'] = train_sentences[-1]['sentences'].split('\n\n')

dev_sentences = []

with open("/kaggle/input/ner-datasets/datasets/nerel/nerel_dev_raw.jsonl") as f:
    for line in f:
        line = json.loads(line)
        line['sentences'] = line['senences']
        del line['senences']
        dev_sentences.append(line)

### Split train

In [81]:
if Config.do_split:
    random.shuffle(train_sentences)
    split_ix = int(len(train_sentences) * Config.val_size)
    train_train_sentences = train_sentences[split_ix:]
    train_val_sentences = train_sentences[:split_ix]
    
    with open('nerel_train_train_raw.jsonl', 'w', encoding='utf-8') as f:
        for line in train_train_sentences:
            print(json.dumps(line), file=f)
    
    with open('nerel_train_val_raw.jsonl', 'w', encoding='utf-8') as f:
        for line in train_val_sentences:
            print(json.dumps(line), file=f)
else:
    train_train_sentences = []
    with open("/kaggle/input/ner-datasets/datasets/nerel/nerel_train_train_raw.jsonl") as f:
        for line in f:
            train_train_sentences.append(json.loads(line))
            
    train_val_sentences = []
    with open("/kaggle/input/ner-datasets/datasets/nerel/nerel_train_val_raw.jsonl") as f:
        for line in f:
            train_val_sentences.append(json.loads(line))

In [82]:
print(len(train_train_sentences), len(train_val_sentences))

442 77


### Preprocess functions

In [83]:
def tokens_to_indices(text: str, tokens: list[str]):
    token_ix = 0
    token_char_ix = 0
    for i, char in enumerate(text):
        # check if token was found
        if len(tokens[token_ix]) == token_char_ix:
            yield i - token_char_ix, i - 1
            token_char_ix = 0
            token_ix += 1
            
            # check if all tokens returned
            if len(tokens) == token_ix:
                break
        
        # fix quoting bug
        if (tokens[token_ix] == "``" or tokens[token_ix] == "''") and char == '\"':
            yield i, i
            token_char_ix = 0
            token_ix += 1
            
            # check if all tokens returned
            if len(tokens) == token_ix:
                break

        if tokens[token_ix][token_char_ix] == char:
            token_char_ix += 1
        else:
            token_char_ix = 0
            
    # token not found error
    if len(tokens) != token_ix:
        if len(tokens[token_ix]) == token_char_ix:
            yield i + 1 - token_char_ix, i
            token_char_ix = 0
            token_ix += 1
        else:
            print(char)
            raise ValueError(f"Token '{tokens[token_ix]}' by index {token_ix} not found")

def index_to_tokens(token_indices: list[tuple[int]], index: tuple[int]):
    token_overlaps = None
    for i, token_ix in enumerate(token_indices):
        # before overlap
        if token_ix[1] < index[0]:
            continue
        # overlap
        if token_ix[0] <= index[1]:
            if token_overlaps is None:
                token_overlaps = (i, i + 1)
            else:
                token_overlaps = (token_overlaps[0], i + 1)
        # no overlap
        if token_ix[0] > index[1]:
            return token_overlaps
    return token_overlaps

### Transformation functions

In [84]:
def nerel_to_sentinfo(nerel_data: List[dict], train=True) -> List[dict]:
    sentences = []
    for batch_ix, nerel_batch in tqdm(enumerate(nerel_data), total=len(nerel_data), desc="Parsing NEREL data"):
        sentence_tokens = sent_tokenize(nerel_batch['sentences'], language='russian')
        for sent_ix, sent_offset in enumerate(tokens_to_indices(nerel_batch['sentences'], sentence_tokens)):
            sent = sentence_tokens[sent_ix]
            
            # token info
            word_tokens = word_tokenize(sent, language='russian')
            tokens = []
            tokens_offsets = []
            for token, token_offset in zip(word_tokens, tokens_to_indices(sent, word_tokens)):
                tokens.append(token)
                tokens_offsets.append(token_offset)
            
            # ner info
            if train:
                ners = []
                for ner in nerel_batch['ners']:
                    # compute offset relative to sentence
                    ner_offset = ner[0] - sent_offset[0], ner[1] - sent_offset[0]
                    sent_len = sent_offset[1] - sent_offset[0]

                    # skip ners in other sentences
                    if ner_offset[1] < 0 or ner_offset[0] > sent_len:
                        continue

                    ner_indices = index_to_tokens(tokens_offsets, (ner_offset[0], ner_offset[1]))
                    if ner_indices:
                        ners.append({"type": ner[2], "start": ner_indices[0], "end": ner_indices[1]})
                    else:
                        print(tokens_offsets, (ner_offset[0], ner_offset[1]))
                        assert False, 'Token not found'
            
                sentences.append({
                    'sentence': sentence_tokens[sent_ix],
                    'sent_offset': sent_offset,
                    'tokens': tokens,
                    'tokens_offsets': tokens_offsets,
                    'batch_ix': batch_ix,
                    'sent_ix': sent_ix,
                    'ners': ners
                })
            else:
                sentences.append({
                    'sentence': sentence_tokens[sent_ix],
                    'sent_offset': sent_offset,
                    'tokens': tokens,
                    'tokens_offsets': tokens_offsets,
                    'batch_ix': batch_ix,
                    'sent_ix': sent_ix,
                    'id': nerel_batch['id']
                })
    
    return sentences

def sentinfo_to_spert(sentinfo: List[dict]) -> List[dict]:
    return [{
        'tokens': sent['tokens'],
        'entities': sent['ners'],
        'relations': []
    } for sent in sentinfo]

def sentinfo_to_spert_pred(sentinfo: List[dict]) -> List[dict]:
    return [{'tokens': sent['tokens']} for sent in sentinfo]

def pred_to_nerel(batches_count: int, sentinfo: List[dict], pred_data: List[dict]) -> List[dict]:
    ners = [[] for _ in range(batches_count)]
    
    for sent, pred in zip(sentinfo[5:], pred_data[5:]):
        for ner in pred['entities']:
            offsets = (
                sent['sent_offset'][0] + sent['tokens_offsets'][ner['start']][0],
                sent['sent_offset'][0] + sent['tokens_offsets'][ner['end'] - 1][1]
            )
            nerel_format = [offsets[0], offsets[1], ner['type']]
            ners[sent['batch_ix']].append(nerel_format)
    
    return ners

### Convert SPERT predictions to NEREL

In [85]:
# load spert prediction
with open('/kaggle/input/nerel-predictions/predictions/dev/rubert-large-15/rubert-large-15.json', 'r') as f:
    spert_prediction = json.load(f)

In [86]:
# load NEREL input
pred_sentences = []

with open("/kaggle/input/ner-datasets/datasets/nerel/nerel_dev_raw.jsonl") as f:
    for line in f:
        line = json.loads(line)
        line['sentences'] = line['senences']
        del line['senences']
        pred_sentences.append(line)

In [87]:
# convert NEREL input to sentinfo
pred_sentinfo = nerel_to_sentinfo(pred_sentences, train=False)
ners = pred_to_nerel(len(pred_sentences), pred_sentinfo, spert_prediction)

# write to file
with open('test.jsonl', 'w', encoding='utf-8') as f:
    for batch, ner_line in zip(pred_sentences, ners):
        print(json.dumps({"id": batch["id"], "ners": ner_line}), file=f)

# for batch_ix, ner_line in enumerate(ners):
#     print(pred_sentences[batch_ix]['sentences'])
#     for ner in ner_line:
#         print(ner, pred_sentences[batch_ix]['sentences'][ner[0]:ner[1] + 1])
#     break

Parsing NEREL data:   0%|          | 0/65 [00:00<?, ?it/s]

### Write NEREL in spert format

In [17]:
if Config.save_nerel:
    train_spert = sentinfo_to_spert(nerel_to_sentinfo(train_sentences))
    with open("nerel_train.json", "w", encoding="utf-8") as f:
        json.dump(train_spert, f)
        
    train_train_spert = sentinfo_to_spert(nerel_to_sentinfo(train_train_sentences))
    with open("nerel_train-train.json", "w", encoding="utf-8") as f:
        json.dump(train_train_spert, f)
        
    train_val_spert = sentinfo_to_spert(nerel_to_sentinfo(train_val_sentences))
    with open("nerel_train-val.json", "w", encoding="utf-8") as f:
        json.dump(train_val_spert, f)
        
    dev_spert = sentinfo_to_spert_pred(nerel_to_sentinfo(dev_sentences, train=False))
    with open("nerel_dev.json", "w", encoding="utf-8") as f:
        json.dump(dev_spert, f)

Parsing NEREL data:   0%|          | 0/519 [00:00<?, ?it/s]

Parsing NEREL data:   0%|          | 0/442 [00:00<?, ?it/s]

Parsing NEREL data:   0%|          | 0/77 [00:00<?, ?it/s]

Parsing NEREL data:   0%|          | 0/65 [00:00<?, ?it/s]

### Sanity check

In [11]:
# test if batches are tokenized into sentences and detokenized correctly
def test_sent_tokenization(nerel_data: List[List[dict]]):
    print('Sentence tokenization test:')
    for test_ix, test_data in enumerate(nerel_data, start=1):
        for ix in tqdm(range(len(test_data)), desc=f'Test {test_ix}/{len(nerel_data)}'):
            sentence_tokens = sent_tokenize(test_data[ix]['sentences'], language='russian')
            for i, index in enumerate(tokens_to_indices(test_data[ix]['sentences'], sentence_tokens)):
                orig = test_data[ix]['sentences'][index[0]:index[1] + 1]
                tok = sentence_tokens[i]
                if orig != tok:
                    print(orig)
                    print(tok)
                    assert False, f"Sentence tokenization test {test_ix}/{len(nerel_data)} failed."
    
    print('All sentence tokenization tests passed.')
    return True

In [12]:
def test_word_tokenization(nerel_data: List[List[dict]]):
    print('Word tokenization test:')
    for test_ix, test_data in enumerate(nerel_data, start=1):
        sentinfo = nerel_to_sentinfo(test_data)
        for ix in tqdm(range(len(sentinfo)), desc=f'Test {test_ix}/{len(nerel_data)}'):
            sent = sentinfo[ix]['sentence']
            word_tokens = word_tokenize(sent, language='russian')
            for token_ix, index in enumerate(tokens_to_indices(sentinfo[ix]['sentence'], word_tokens)):
                orig = sentinfo[ix]['sentence'][index[0]:index[1] + 1]
                tok = word_tokens[token_ix]
                # fix quoting bug
                if (tok == "``" or tok == "''") and orig == '\"':
                    tok = '\"'
                if orig != tok:
                    print(orig)
                    print(tok)
                    assert False, f"Word tokenization test {test_ix}/{len(nerel_data)} failed."
    
    print('All word tokenization tests passed.')
    return True

In [13]:
tests = [
    test_sent_tokenization([train_sentences, dev_sentences]),
    test_word_tokenization([train_sentences, dev_sentences])
]

if all(tests):
    print('All tests passed.')
else:
    failed_count = len(list(filter(lambda x: not x, tests)))
    print(f'{failed_count}/{len(tests)} tests failed.')

Sentence tokenization test:


Test 1/2:   0%|          | 0/519 [00:00<?, ?it/s]

Test 2/2:   0%|          | 0/65 [00:00<?, ?it/s]

All sentence tokenization tests passed.
Word tokenization test:


Parsing NEREL data:   0%|          | 0/519 [00:00<?, ?it/s]

Test 1/2:   0%|          | 0/6278 [00:00<?, ?it/s]

Parsing NEREL data:   0%|          | 0/65 [00:00<?, ?it/s]

KeyError: 'ners'