In [4]:
!python -m pip install datasets



In [5]:
!python -m pip install transformers



In [6]:
!pip install evaluate



In [7]:
!pip install seqeval



In [8]:
import os
import random
from typing import Dict, List, Tuple, Union  # эмуляция статической типизации

In [11]:
import matplotlib.colors as mcolors  # красиво раскрасим наши именованные сущности
import nltk
from nltk.tokenize.treebank import TreebankWordDetokenizer
from nltk.corpus import wordnet
import evaluate
import numpy as np
from datasets import load_dataset, Dataset
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import DataCollatorForTokenClassification, TrainingArguments, Trainer
from transformers import pipeline

In [12]:
nltk_data_path = "/kaggle/working/nltk_data"
if not os.path.exists(nltk_data_path):
    os.makedirs(nltk_data_path)
nltk.data.path.append(nltk_data_path)

nltk.download('wordnet', download_dir=nltk_data_path)

[nltk_data] Downloading package wordnet to
[nltk_data]     /kaggle/working/nltk_data...


True

In [13]:
%cd /kaggle/working/nltk_data/corpora
!unzip wordnet.zip 
%cd /kaggle/working/

/kaggle/working/nltk_data/corpora


  pid, fd = os.forkpty()


Archive:  wordnet.zip
   creating: wordnet/
  inflating: wordnet/lexnames        
  inflating: wordnet/data.verb       
  inflating: wordnet/index.adv       
  inflating: wordnet/adv.exc         
  inflating: wordnet/index.verb      
  inflating: wordnet/cntlist.rev     
  inflating: wordnet/data.adj        
  inflating: wordnet/index.adj       
  inflating: wordnet/LICENSE         
  inflating: wordnet/citation.bib    
  inflating: wordnet/noun.exc        
  inflating: wordnet/verb.exc        
  inflating: wordnet/README          
  inflating: wordnet/index.sense     
  inflating: wordnet/data.noun       
  inflating: wordnet/data.adv        
  inflating: wordnet/index.noun      
  inflating: wordnet/adj.exc         
/kaggle/working


In [14]:
DATASET_NAME = 'adsabs/WIESP2022-NER'

In [15]:
RANDOM_SEED = 42

random.seed(RANDOM_SEED)

torch.manual_seed(RANDOM_SEED)

torch.cuda.manual_seed(RANDOM_SEED)

np.random.seed(RANDOM_SEED)

In [67]:
trainset = load_dataset(DATASET_NAME, split='train')

In [68]:
def get_entity(dataset):
    sequences = []
    
    for example in dataset:
        tokens = example["tokens"]
        tags = example["ner_tags"]
        ids = example["ner_ids"]
        
        i = 0
        while i < len(tags):
            if tags[i].startswith("B-"):
                entity_tokens = [tokens[i]]
                entity_tags = [tags[i]]
                entity_ids = [ids[i]]
                
                j = i + 1
                while j < len(tags) and tags[j].startswith("I-"):
                    entity_tokens.append(tokens[j])
                    entity_tags.append(tags[j])
                    entity_ids.append(ids[j])
                    j += 1
                
                sequences.append((entity_tokens, entity_tags, entity_ids))
                i = j 
            else:
                i += 1
    
    return sequences

def insert_random_entity(example, entity_sequences):
    tokens = example["tokens"]
    ner_tags = example["ner_tags"]
    ner_ids = example["ner_ids"]
    
    random_tokens, random_tags, random_ids = random.choice(entity_sequences)
    
    possible_positions = [i for i, tag in enumerate(ner_tags) if tag == "O"]
    if not possible_positions:
        return example 
    
    insert_position = random.choice(possible_positions) 
    
    tokens = tokens[:insert_position] + random_tokens + tokens[insert_position:]
    ner_tags = ner_tags[:insert_position] + random_tags + ner_tags[insert_position:]
    ner_ids = ner_ids[:insert_position] + random_ids + ner_ids[insert_position:]
    
    example["tokens"] = tokens
    example["ner_tags"] = ner_tags
    example["ner_ids"] = ner_ids
    
    return example

entity_sequences = get_entity(trainset)

augmented_dataset = trainset.map(lambda x: insert_random_entity(x, entity_sequences))

trainset = Dataset.from_dict({
    "bibcode": trainset["bibcode"] + augmented_dataset["bibcode"],
    "label_studio_id": trainset["label_studio_id"] + augmented_dataset["label_studio_id"],
    "ner_ids": trainset["ner_ids"] + augmented_dataset["ner_ids"],
    "ner_tags": trainset["ner_tags"] + augmented_dataset["ner_tags"],
    "section": trainset["section"] + augmented_dataset["section"],
    "tokens": trainset["tokens"] + augmented_dataset["tokens"],
    "unique_id": trainset["unique_id"] + augmented_dataset["unique_id"],
})

In [69]:
def mask_tokens(example):
    tokens = example["tokens"]
    augmented_tokens = []

    for token in tokens:
        if random.random() < 0.25:
            augmented_tokens.append("[MASK]")
        else:
            augmented_tokens.append(token)

    example["augmented_tokens"] = augmented_tokens
    return example


augmented_dataset = trainset.map(mask_tokens)

trainset = Dataset.from_dict({
    "bibcode": trainset["bibcode"] + augmented_dataset["bibcode"],
    "label_studio_id": trainset["label_studio_id"] + augmented_dataset["label_studio_id"],
    "ner_ids": trainset["ner_ids"] + augmented_dataset["ner_ids"],
    "ner_tags": trainset["ner_tags"] + augmented_dataset["ner_tags"],
    "section": trainset["section"] + augmented_dataset["section"],
    "tokens": trainset["tokens"] + augmented_dataset["tokens"],
    "unique_id": trainset["unique_id"] + augmented_dataset["unique_id"],
    "augmented_tokens": [None] * len(trainset) + augmented_dataset["augmented_tokens"]
})

Map:   0%|          | 0/3506 [00:00<?, ? examples/s]

In [15]:
'''
def replace_with_synonyms(example):
    tokens = example["tokens"]
    augmented_tokens = []

    for token in tokens:
        if random.random() < 0.3:
            synonyms = wordnet.synsets(token)
            if synonyms:
                synonym = random.choice(synonyms).lemmas()[0].name()  # Берем случайный синоним
                augmented_tokens.append(synonym)
            else:
                augmented_tokens.append(token)
        else:
            augmented_tokens.append(token)
    
    example["augmented_tokens"] = augmented_tokens
    return example


augmented_dataset_1 = trainset.map(replace_with_synonyms)

trainset = Dataset.from_dict({
    "bibcode": trainset["bibcode"] + augmented_dataset_1["bibcode"],
    "label_studio_id": trainset["label_studio_id"] + augmented_dataset_1["label_studio_id"],
    "ner_ids": trainset["ner_ids"] + augmented_dataset_1["ner_ids"],
    "ner_tags": trainset["ner_tags"] + augmented_dataset_1["ner_tags"],
    "section": trainset["section"] + augmented_dataset_1["section"],
    "tokens": trainset["tokens"] + augmented_dataset_1["tokens"],
    "unique_id": trainset["unique_id"] + augmented_dataset_1["unique_id"],
    "augmented_tokens": [None] * len(trainset) + augmented_dataset_1["augmented_tokens"]
})
'''



Map:   0%|          | 0/1753 [00:00<?, ? examples/s]

In [77]:
trainset.shape

(7012, 8)

In [78]:
label_set = set()

for it in trainset['ner_tags']:
    label_set |= set(it)
    
label_list = ['O'] + sorted(list(label_set - {'O'}))

In [79]:
entity_classes = sorted(list(set(
    map(lambda it2: it2[2:],filter(lambda it1: it1 != 'O',label_list))
)))

In [80]:
entity_colors = [mcolors.rgb2hex((0.5 + random.random() / 2, 0.5 +random.random() / 2, 0.5 +random.random() / 2))
                 for _ in range(len(entity_classes))]

In [81]:
MODEL_NAME = 'FacebookAI/xlm-roberta-base'  # бертоподобный языконезависимый трансформер-энкодер

In [82]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [83]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples['tokens'], truncation=True, is_split_into_words=True)
    labels = []
    for i, label in enumerate(examples['ner_tags']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)  # Map tokens to their respective word.
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:  # Set the special tokens to -100.
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:  # Only label the first token of a given word.
                label_ids.append(label_list.index(label[word_idx]))
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs['labels'] = labels
    return tokenized_inputs

In [84]:
tokenized_trainset = trainset.map(tokenize_and_align_labels, batched=True)

Map:   0%|          | 0/7012 [00:00<?, ? examples/s]

In [85]:
valset = load_dataset(DATASET_NAME, split='validation')

In [86]:
tokenized_valset = valset.map(tokenize_and_align_labels, batched=True)

In [87]:
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [88]:
seqeval = evaluate.load('seqeval')

In [89]:
def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]



    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        'precision': results['overall_precision'],
        'recall': results['overall_recall'],
        'f1': results['overall_f1'],
        'accuracy': results['overall_accuracy'],
    }

In [90]:
id2label = dict(enumerate(label_list))

In [91]:
label2id = dict((val, idx) for idx, val in enumerate(label_list))

In [92]:
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id
)

Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at FacebookAI/xlm-roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [93]:
MODEL_NAME_ON_DISK = os.path.abspath('astro_ner')

In [94]:
training_args = TrainingArguments(
    output_dir=MODEL_NAME_ON_DISK,
    logging_dir=os.path.join(MODEL_NAME_ON_DISK, 'logs'),
    learning_rate=1e-4,
    warmup_ratio=0.5,  # делаем "прогрев": начинаем с околонулевого lr и до середины обучения (то есть до пятой эпохи, если у нас их 10) линейно увеличиваем до 1e-4
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,  # при тестировании не считаются градиенты, поэтому мини-батч можно и побольше
    num_train_epochs=10,
    weight_decay=0.01,  # для регуляризации обновлений весов
    eval_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=2,  # чтобы не забивать жёсткий диск, будем сохранять только 2 чекпойнта: лучший и последний
    logging_strategy='epoch',
    report_to='tensorboard',  # хотим красиво нарисовать графики обучения в tensorboard
    metric_for_best_model='f1',
    greater_is_better=True,  # чем больше f1, тем лучше
    load_best_model_at_end=True,
    seed=RANDOM_SEED,
    data_seed=RANDOM_SEED
)

In [95]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_trainset,
    eval_dataset=tokenized_valset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

  trainer = Trainer(


In [96]:
trainer.train()

Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,1.0161,0.205406,0.671623,0.701005,0.686,0.95071
2,0.1291,0.15426,0.750896,0.80714,0.778003,0.961694
3,0.0658,0.169222,0.7562,0.822043,0.787748,0.96203
4,0.0389,0.17536,0.79295,0.816036,0.804327,0.96586
5,0.0271,0.186534,0.785341,0.814919,0.799856,0.964508
6,0.0191,0.19709,0.799226,0.818924,0.808955,0.965359
7,0.0079,0.22203,0.809301,0.818308,0.813779,0.966752
8,0.0033,0.238912,0.804267,0.827435,0.815687,0.966834
9,0.0014,0.250974,0.80676,0.826395,0.816459,0.966799
10,0.0008,0.258696,0.809337,0.82782,0.818474,0.967009


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


TrainOutput(global_step=4390, training_loss=0.13096066903416279, metrics={'train_runtime': 4351.3203, 'train_samples_per_second': 16.115, 'train_steps_per_second': 1.009, 'total_flos': 1.833223318622208e+16, 'train_loss': 0.13096066903416279, 'epoch': 10.0})

In [97]:
for it in os.listdir(MODEL_NAME_ON_DISK): print(it)

checkpoint-3951
logs
checkpoint-4390


In [98]:
possible_checkpoints = sorted(
    list(map(
        lambda it2: os.path.join(MODEL_NAME_ON_DISK, it2),
        filter(
            lambda it1: it1.startswith('checkpoint-'),
            os.listdir(MODEL_NAME_ON_DISK)
        )
    )),
    key=lambda it3: -len(os.listdir(it3))
)

In [99]:
import gc

gc.collect()

torch.cuda.empty_cache()

In [100]:
classifier = pipeline('ner', model=possible_checkpoints[0], device=0)

In [101]:
testset = load_dataset(DATASET_NAME, split='test')

In [102]:
def find_token(token_bounds: List[Tuple[int, int]], char_idx: int) -> int:
    res = -1
    for token_idx, (token_start, token_end) in enumerate(token_bounds):
        if (char_idx >= token_start) and (char_idx < token_end):
            res = token_idx
            break
    return res

In [103]:
def predictions_to_bio(text: str, tokens: List[str], predictions: List[Tuple[int, int, str]]) -> List[str]:
    token_bounds = []
    token_labels = []
    start_pos = 0
    for cur_token in tokens:
        found_idx = text[start_pos:].find(cur_token)
        
        if found_idx < 0:
            err_msg = f'The token {cur_token} is not found in the text {text}'
            raise RuntimeError(err_msg)
            
        token_start = found_idx + start_pos
        token_end = token_start + len(cur_token)
        start_pos = token_end
        token_bounds.append((token_start, token_end))
        token_labels.append('O')

    for span_start, span_end, span_label in predictions:
        start_token = find_token(token_bounds, span_start)
        end_token = find_token(token_bounds, span_end - 1)
        if (start_token >= 0) and (end_token >= 0):
            for token_idx in range(start_token, end_token + 1):
                token_labels[token_idx] = span_label

        elif start_token >= 0:
            token_labels[start_token] = span_label

        elif end_token >= 0:
            token_labels[end_token] = span_label

    corrected_token_labels = []
    previous_label = 'O'

    for cur_label in token_labels:
        if cur_label == previous_label:
            corrected_token_labels.append(cur_label)

        else:
            if (cur_label == 'O') or cur_label.startswith('B-'):
                corrected_token_labels.append(cur_label)

            else:
                if previous_label == 'O':
                    corrected_token_labels.append('B-' + cur_label[2:])

                elif previous_label[2:] != cur_label[2:]:
                    corrected_token_labels.append('B-' + cur_label[2:])

                else:
                    corrected_token_labels.append(cur_label)

        previous_label = cur_label

    return corrected_token_labels

In [104]:
from tqdm.notebook import tqdm

In [105]:
y_true = []
y_pred = []

for tokens, reference_tags in tqdm(zip(testset['tokens'], testset['ner_tags']), total=len(testset)):
    y_true.append(reference_tags)
    cur_text = TreebankWordDetokenizer().detokenize(tokens)
    cur_res = classifier(cur_text)
    y_pred.append(
        predictions_to_bio(
            cur_text,
            tokens,
            [(it['start'], it['end'], it['entity']) for it in cur_res]
        )
    )

  0%|          | 0/2505 [00:00<?, ?it/s]

In [106]:
from seqeval.scheme import IOB2
from seqeval.metrics import classification_report

In [65]:
print(classification_report(y_true, y_pred, digits=4))

                         precision    recall  f1-score   support

                Archive     0.7052    0.4930    0.5803       359
        CelestialObject     0.7127    0.5051    0.5912      3609
  CelestialObjectRegion     0.3188    0.0913    0.1419       723
        CelestialRegion     0.2626    0.2488    0.2555       209
               Citation     0.8430    0.5731    0.6824      8621
          Collaboration     0.7066    0.6752    0.6906       428
      ComputingFacility     0.5111    0.4926    0.5017       607
               Database     0.3138    0.2661    0.2880       342
                Dataset     0.4230    0.2500    0.3143       516
 EntityOfFutureInterest     0.3333    0.0069    0.0135       435
                  Event     0.4500    0.4576    0.4538        59
             Fellowship     0.5168    0.5074    0.5121       607
                Formula     0.7135    0.4948    0.5843      3452
                  Grant     0.4157    0.3615    0.3867      5259
             Identifier 