In [80]:
!python -m pip install spacy -qq # в данном примере нужен только displacy для красивых визуализаций

In [81]:
!python -m pip install datasets -qq  # чтобы загружать датасеты с HF

In [82]:
!python -m pip install transformers -qq  # чтобы работать с трансформерными нейросетями и загружать их с HF

In [83]:
!python -m pip install evaluate -qq  # библиотека из экосистемы HF для интеграции метрик качества в Trainer

In [84]:
!python -m pip install seqeval -qq  # нужно для расчёта F1 по сущностям, а не по токенам

In [85]:
!python -m pip install tensorboard -qq # чтобы выводить графики обучения

In [86]:
import os
import random
from typing import Dict, List, Tuple, Union  # эмуляция статической типизации

In [87]:
import matplotlib.colors as mcolors  # раскрасим наши именованные сущности
from nltk.tokenize.treebank import TreebankWordDetokenizer
from spacy import displacy
import evaluate
import numpy as np
from datasets import load_dataset
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import DataCollatorForTokenClassification, TrainingArguments, Trainer
from transformers import pipeline


In [88]:
# Fix every random seed we can, so that repeated runs of the notebook are comparable
RANDOM_SEED = 42
random.seed(RANDOM_SEED)  # Python's built-in generator (used below for the entity colors)
torch.manual_seed(RANDOM_SEED)  # PyTorch generator
torch.cuda.manual_seed(RANDOM_SEED)  # PyTorch generator of the current CUDA device
np.random.seed(RANDOM_SEED)  # NumPy legacy global generator

# Загрузим данные
Хороший датасет для NER по текстам на тему астрофизики.
Этой задаче в 2022 году была посвящена статья https://aclanthology.org/2022.wiesp-1.1.pdf и даже целый воркшоп https://ui.adsabs.harvard.edu/WIESP/ .   
Ссылка на данные https://huggingface.co/datasets/adsabs/WIESP2022-NER

In [89]:
DATASET_NAME = 'adsabs/WIESP2022-NER'

In [90]:
trainset = load_dataset(DATASET_NAME, split='train')

## Посмотрим на все теги сущностей

In [91]:
# Gather every distinct tag that occurs anywhere in the training split.
label_set = set().union(*trainset['ner_tags'])
# Keep 'O' at index 0; all remaining tags follow in alphabetical order.
label_list = ['O', *sorted(label_set - {'O'})]

In [92]:
for it in label_list: print(it)

O
B-Archive
B-CelestialObject
B-CelestialObjectRegion
B-CelestialRegion
B-Citation
B-Collaboration
B-ComputingFacility
B-Database
B-Dataset
B-EntityOfFutureInterest
B-Event
B-Fellowship
B-Formula
B-Grant
B-Identifier
B-Instrument
B-Location
B-Mission
B-Model
B-ObservationalTechniques
B-Observatory
B-Organization
B-Person
B-Proposal
B-Software
B-Survey
B-Tag
B-Telescope
B-TextGarbage
B-URL
B-Wavelength
I-Archive
I-CelestialObject
I-CelestialObjectRegion
I-CelestialRegion
I-Citation
I-Collaboration
I-ComputingFacility
I-Database
I-Dataset
I-EntityOfFutureInterest
I-Event
I-Fellowship
I-Formula
I-Grant
I-Identifier
I-Instrument
I-Location
I-Mission
I-Model
I-ObservationalTechniques
I-Observatory
I-Organization
I-Person
I-Proposal
I-Software
I-Survey
I-Tag
I-Telescope
I-TextGarbage
I-URL
I-Wavelength


In [93]:
print(f'Size of the entity tag dictionary is {len(label_list)}.')

Size of the entity tag dictionary is 63.


## Посмотрим на все классы

In [94]:
# Drop the two-character "B-"/"I-" prefix from every non-'O' tag and keep
# the distinct class names in alphabetical order.
entity_classes = sorted({tag[2:] for tag in label_list if tag != 'O'})

In [95]:
for it in entity_classes: print(it)

Archive
CelestialObject
CelestialObjectRegion
CelestialRegion
Citation
Collaboration
ComputingFacility
Database
Dataset
EntityOfFutureInterest
Event
Fellowship
Formula
Grant
Identifier
Instrument
Location
Mission
Model
ObservationalTechniques
Observatory
Organization
Person
Proposal
Software
Survey
Tag
Telescope
TextGarbage
URL
Wavelength


In [96]:
print(f'Number of entity classes is {len(entity_classes)}.')

Number of entity classes is 31.


# Визуализируем текст с разметкой

In [97]:
# One random light color per entity class: every RGB channel is drawn from [0.5, 1.0).
entity_colors = [
    mcolors.rgb2hex(tuple(0.5 + random.random() / 2 for _ in range(3)))
    for _ in entity_classes
]

In [98]:
assert len(entity_classes) == len(entity_colors)

In [None]:
def bio_to_spans(bio: List[str]) -> List[Dict[str, Union[int, str]]]:
    """Convert a BIO tag sequence into token-level entity spans.

    Args:
        bio: One tag per token: ``O``, ``B-<Class>`` or ``I-<Class>``.
            The ``O`` / ``B-`` / ``I-`` markers are matched case-insensitively.

    Returns:
        Spans in order of appearance. Each span is a dict with
        ``start_token`` (inclusive token index), ``end_token`` (exclusive
        token index) and ``label`` (the class name without its prefix).

    An ``I-`` tag that has no open span to continue, or whose class differs
    from the open span, starts a new span instead of being dropped or merged
    into a span of another class. Tags that are neither ``O``, ``B-`` nor
    ``I-`` leave the current state untouched.
    """
    bounds = []
    ne_tag = ''
    start_pos = -1
    for idx, val in enumerate(bio):
        marker = val.upper()
        if marker == 'O':
            closes_span, opens_span = True, False
        elif marker.startswith('B-'):
            closes_span, opens_span = True, True
        elif marker.startswith('I-'):
            # A well-formed continuation changes nothing; an orphan or
            # class-mismatched I- tag is treated like a B- tag.
            is_continuation = (start_pos >= 0) and (val[2:].upper() == ne_tag.upper())
            closes_span = opens_span = not is_continuation
        else:
            continue
        if closes_span and start_pos >= 0:
            bounds.append({
                'start_token': start_pos,
                'end_token': idx,
                'label': ne_tag
            })
        if opens_span:
            start_pos = idx
            ne_tag = val[2:]
        elif closes_span:
            start_pos = -1
            ne_tag = ''
    # Flush a span that runs up to the end of the sequence.
    if start_pos >= 0:
        bounds.append({
            'start_token': start_pos,
            'end_token': len(bio),
            'label': ne_tag
        })
    return bounds

In [100]:
sample_for_rendering = {
    'text': TreebankWordDetokenizer().detokenize(trainset[0]['tokens']),
    'spans': bio_to_spans(trainset[0]['ner_tags']),
    'tokens': trainset[0]['tokens'],
}

In [101]:
rendered = displacy.render(
    sample_for_rendering, style='span',
    options={'ents': entity_classes, 'colors': dict(zip(entity_classes, entity_colors))},
    manual=True, jupyter=True
)

In [None]:
def bio_to_ent(tokens: List[str], bio: List[str]) -> Tuple[str, List[Dict[str, Union[int, str]]]]:
    """Detokenize a sentence and express its BIO entities as character spans.

    Args:
        tokens: Word tokens of one example.
        bio: BIO tag for each token; must have the same length as ``tokens``.

    Returns:
        A pair ``(full_text, entity_bounds)``. ``full_text`` is the output of
        ``TreebankWordDetokenizer().detokenize(tokens)``; ``entity_bounds`` is
        a list of dicts with character offsets ``start`` / ``end`` into
        ``full_text`` and the entity ``label``.

    Raises:
        RuntimeError: If the two lists differ in length, or if a token cannot
            be located in the detokenized text after the previous token.
    """
    if len(tokens) != len(bio):
        err_msg = f'Tokens do not correspond to their labels: {len(tokens)} != {len(bio)}!'
        raise RuntimeError(err_msg)
    full_text = TreebankWordDetokenizer().detokenize(tokens)
    token_bounds = []
    search_from = 0
    for cur in tokens:
        # Search in place from the end of the previous token instead of
        # slicing the text, which would copy its tail for every token.
        token_start = full_text.find(cur, search_from)
        if token_start < 0:
            err_msg = f'The token {cur} is not found in the text "{full_text}".'
            raise RuntimeError(err_msg)
        token_end = token_start + len(cur)
        token_bounds.append((token_start, token_end))
        search_from = token_end
    # Map token-index spans onto character offsets: start of the first token
    # and end of the last token (end_token is exclusive, hence the -1).
    entity_bounds = [
        {
            'start': token_bounds[span['start_token']][0],
            'end': token_bounds[span['end_token'] - 1][1],
            'label': span['label']
        }
        for span in bio_to_spans(bio)
    ]
    return full_text, entity_bounds

In [103]:
sample_for_rendering_2 = dict(zip(
    ('text', 'ents'),
    bio_to_ent(trainset[0]['tokens'], trainset[0]['ner_tags'])
))

In [104]:
rendered_2 = displacy.render(
    sample_for_rendering_2, style='ent',
    options={'ents': entity_classes, 'colors': dict(zip(entity_classes, entity_colors))},
    manual=True, jupyter=True
)

# Загрузим модель и токенайзер

In [105]:
MODEL_NAME = 'adsabs/astroBERT'

In [106]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME,add_prefix_space=True)

In [107]:
example = trainset[0]
tokenized_input = tokenizer(example['tokens'], is_split_into_words=True)
tokens = tokenizer.convert_ids_to_tokens(tokenized_input['input_ids'])
print(tokens)

['[CLS]', 'Whilst', 'a', 'reasonable', 'harmonic', 'fit', 'to', 'the', 'ESP', '##a', '##DO', '##nS', 'data', 'can', 'be', 'achieved', 'using', 'this', 'period', ',', 'it', 'does', 'not', 'produce', 'an', 'acceptable', 'phasing', 'of', 'all', 'available', '〈', 'B', 'z', '〉', 'measurements', '.', 'Figure', '1', '.', 'Photometric', '(', 'top', ')', 'and', 'magnetic', '〈', 'B', 'z', '〉', '(', 'bottom', ')', 'measurements', ',', 'phased', 'with', 'periods', 'determined', 'from', '(', 'left', 'to', 'right', ')', 'K2', 'photometry', ',', 'all', '〈', 'B', 'z', '〉', 'measurements', ',', 'and', 'all', 'photometric', 'measurements', '.', '〈', 'B', 'z', '〉', 'measurements', 'were', 'obtained', 'from', 'ESP', '##a', '##DO', '##nS', 'by', 'Shu', '##l', '##tz', 'et', 'al', '.', '(', '2018', ')', 'and', 'photop', '##olar', '##imetric', 'data', 'by', 'Bor', '##ra', 'et', 'al', '.', '(', '1983', ',', 'BL', '##T', '##83', ')', 'and', 'Boh', '##len', '##der', 'et', 'al', '.', '(', '1993', ',', 'BL', '##T'

# Токенизируем текст

In [108]:
def tokenize_and_align_labels(examples):
    """Tokenize pre-split words and align word-level NER tags with sub-word tokens.

    Written for ``Dataset.map(..., batched=True)``: ``examples`` is a batch
    with a ``tokens`` column (list of words per example) and a ``ner_tags``
    column (one tag string per word).

    Inputs are truncated to 512 sub-word tokens. Only the first sub-word of
    each word receives the id of the word's tag (its index in ``label_list``);
    special tokens and the remaining sub-words receive -100, the value that
    is filtered out when metrics are computed in this notebook.

    Returns:
        The tokenizer output with an added ``labels`` entry.

    Raises:
        KeyError: If a tag does not occur in ``label_list``.
    """
    tokenized_inputs = tokenizer(
        examples['tokens'],
        truncation=True,
        max_length=512,
        is_split_into_words=True,
    )
    # Build the tag -> id mapping once per batch instead of scanning
    # label_list for every word; setdefault keeps the first index of a tag.
    tag_to_id = {}
    for tag_id, tag in enumerate(label_list):
        tag_to_id.setdefault(tag, tag_id)
    labels = []
    for i, label in enumerate(examples['ner_tags']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if (word_idx is None) or (word_idx == previous_word_idx):
                # Special token or a non-initial piece of the same word.
                label_ids.append(-100)
            else:
                label_ids.append(tag_to_id[label[word_idx]])
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs

In [109]:
tokenized_trainset = trainset.map(tokenize_and_align_labels, batched=True)

In [110]:
for k in sorted(tokenized_trainset[0].keys()):
    print(f'{k}\t{tokenized_trainset[0][k]}')

attention_mask	[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [111]:
valset = load_dataset(DATASET_NAME, split='validation')

In [112]:
tokenized_valset = valset.map(tokenize_and_align_labels, batched=True)

In [113]:
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

# Загрузим метрику

In [114]:
seqeval = evaluate.load('seqeval')

In [115]:
def compute_metrics(p):
    """Compute entity-level seqeval scores for one evaluation pass.

    Args:
        p: A pair ``(predictions, labels)``. The predicted class is taken as
            the argmax over axis 2 of ``predictions``; ``labels`` holds the
            reference label ids, with -100 at positions to ignore.

    Returns:
        A dict with ``precision``, ``recall``, ``f1`` and ``accuracy`` taken
        from the ``overall_*`` entries of the seqeval result.
    """
    scores, gold_ids = p
    predicted_ids = np.argmax(scores, axis=2)

    true_predictions = []
    true_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold_ids):
        # Keep only the positions that carry a real label.
        kept_pairs = [
            (predicted_id, gold_id)
            for predicted_id, gold_id in zip(predicted_row, gold_row)
            if gold_id != -100
        ]
        true_predictions.append([label_list[predicted_id] for predicted_id, _ in kept_pairs])
        true_labels.append([label_list[gold_id] for _, gold_id in kept_pairs])

    results = seqeval.compute(predictions=true_predictions, references=true_labels, zero_division=0)
    metric_names = ('precision', 'recall', 'f1', 'accuracy')
    return {metric_name: results['overall_' + metric_name] for metric_name in metric_names}

In [116]:
id2label = dict(enumerate(label_list))

In [117]:
print(id2label)

{0: 'O', 1: 'B-Archive', 2: 'B-CelestialObject', 3: 'B-CelestialObjectRegion', 4: 'B-CelestialRegion', 5: 'B-Citation', 6: 'B-Collaboration', 7: 'B-ComputingFacility', 8: 'B-Database', 9: 'B-Dataset', 10: 'B-EntityOfFutureInterest', 11: 'B-Event', 12: 'B-Fellowship', 13: 'B-Formula', 14: 'B-Grant', 15: 'B-Identifier', 16: 'B-Instrument', 17: 'B-Location', 18: 'B-Mission', 19: 'B-Model', 20: 'B-ObservationalTechniques', 21: 'B-Observatory', 22: 'B-Organization', 23: 'B-Person', 24: 'B-Proposal', 25: 'B-Software', 26: 'B-Survey', 27: 'B-Tag', 28: 'B-Telescope', 29: 'B-TextGarbage', 30: 'B-URL', 31: 'B-Wavelength', 32: 'I-Archive', 33: 'I-CelestialObject', 34: 'I-CelestialObjectRegion', 35: 'I-CelestialRegion', 36: 'I-Citation', 37: 'I-Collaboration', 38: 'I-ComputingFacility', 39: 'I-Database', 40: 'I-Dataset', 41: 'I-EntityOfFutureInterest', 42: 'I-Event', 43: 'I-Fellowship', 44: 'I-Formula', 45: 'I-Grant', 46: 'I-Identifier', 47: 'I-Instrument', 48: 'I-Location', 49: 'I-Mission', 50: '

In [118]:
label2id = dict((val, idx) for idx, val in enumerate(label_list))

In [119]:
print(label2id)

{'O': 0, 'B-Archive': 1, 'B-CelestialObject': 2, 'B-CelestialObjectRegion': 3, 'B-CelestialRegion': 4, 'B-Citation': 5, 'B-Collaboration': 6, 'B-ComputingFacility': 7, 'B-Database': 8, 'B-Dataset': 9, 'B-EntityOfFutureInterest': 10, 'B-Event': 11, 'B-Fellowship': 12, 'B-Formula': 13, 'B-Grant': 14, 'B-Identifier': 15, 'B-Instrument': 16, 'B-Location': 17, 'B-Mission': 18, 'B-Model': 19, 'B-ObservationalTechniques': 20, 'B-Observatory': 21, 'B-Organization': 22, 'B-Person': 23, 'B-Proposal': 24, 'B-Software': 25, 'B-Survey': 26, 'B-Tag': 27, 'B-Telescope': 28, 'B-TextGarbage': 29, 'B-URL': 30, 'B-Wavelength': 31, 'I-Archive': 32, 'I-CelestialObject': 33, 'I-CelestialObjectRegion': 34, 'I-CelestialRegion': 35, 'I-Citation': 36, 'I-Collaboration': 37, 'I-ComputingFacility': 38, 'I-Database': 39, 'I-Dataset': 40, 'I-EntityOfFutureInterest': 41, 'I-Event': 42, 'I-Fellowship': 43, 'I-Formula': 44, 'I-Grant': 45, 'I-Identifier': 46, 'I-Instrument': 47, 'I-Location': 48, 'I-Mission': 49, 'I-Mo

# Загрузим модель

In [120]:
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id
)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at adsabs/astroBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [121]:
MODEL_NAME_ON_DISK = os.path.abspath('astro_ner')

In [122]:
# Hyper-parameters of the fine-tuning run.
training_args = TrainingArguments(
    output_dir=MODEL_NAME_ON_DISK,

    learning_rate=2e-5,
    warmup_ratio=0.1,

    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,

    num_train_epochs=10,
    weight_decay=0.01,

    # Evaluate and checkpoint once per epoch; log the training loss every 50 steps.
    eval_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=1,
    logging_strategy='steps',
    logging_steps=50,
    # Choose the logging backend explicitly instead of relying on the
    # WANDB_DISABLED environment variable, which this notebook only sets
    # after the Trainer has already been built.
    report_to='tensorboard',

    # Keep the checkpoint with the highest validation F1.
    metric_for_best_model='f1',
    greater_is_better=True,
    load_best_model_at_end=True,

    # Mixed precision needs a CUDA device; fall back to fp32 on CPU.
    fp16=torch.cuda.is_available(),

    seed=RANDOM_SEED,
    data_seed=RANDOM_SEED
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [123]:
from transformers import EarlyStoppingCallback

# Fine-tuning Trainer: evaluates on the validation split and stops early when
# the tracked metric has not improved for 3 consecutive evaluations.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_trainset,
    eval_dataset=tokenized_valset,
    # `processing_class` supersedes the deprecated `tokenizer` argument of Trainer.
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

  trainer = Trainer(


In [124]:
# Ask the Weights & Biases integration to stay off.
# NOTE(review): this runs after TrainingArguments and Trainer were created in the
# cells above, so on a fresh kernel it may come too late — TODO confirm / move earlier.
os.environ["WANDB_DISABLED"] = "true"

In [125]:
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 16340}.


Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,0.4804,0.261936,0.552626,0.645955,0.595657,0.939226
2,0.1629,0.142708,0.699314,0.776891,0.736064,0.962193
3,0.107,0.120204,0.763942,0.81569,0.788968,0.967871
4,0.0832,0.117034,0.764311,0.832429,0.796917,0.968756
5,0.0621,0.120132,0.780919,0.831684,0.805503,0.969255
6,0.0525,0.120458,0.793844,0.835089,0.813944,0.970429
7,0.0478,0.122476,0.789856,0.837784,0.813114,0.97027
8,0.0397,0.126405,0.789772,0.836295,0.812368,0.970026
9,0.0339,0.129342,0.797353,0.837394,0.816883,0.970445
10,0.0332,0.131696,0.798802,0.837217,0.817558,0.970656


TrainOutput(global_step=1100, training_loss=0.19386046160351147, metrics={'train_runtime': 775.7535, 'train_samples_per_second': 22.597, 'train_steps_per_second': 1.418, 'total_flos': 4583058296555520.0, 'train_loss': 0.19386046160351147, 'epoch': 10.0})

In [126]:
# Every checkpoint directory found in the output folder, ordered by the
# number of files it contains (largest first).
possible_checkpoints = sorted(
    (
        os.path.join(MODEL_NAME_ON_DISK, entry_name)
        for entry_name in os.listdir(MODEL_NAME_ON_DISK)
        if entry_name.startswith('checkpoint-')
    ),
    key=lambda checkpoint_dir: -len(os.listdir(checkpoint_dir))
)
# The file count says nothing about model quality, so when the Trainer has
# recorded its best checkpoint, move that one to the front: the cells below
# load possible_checkpoints[0].
best_checkpoint = trainer.state.best_model_checkpoint
if best_checkpoint is not None:
    best_checkpoint = os.path.abspath(best_checkpoint)
    if best_checkpoint in possible_checkpoints:
        possible_checkpoints.remove(best_checkpoint)
        possible_checkpoints.insert(0, best_checkpoint)

In [127]:
for it in possible_checkpoints: print(it)

/content/astro_ner/checkpoint-1100


In [128]:
import gc
gc.collect()
torch.cuda.empty_cache()

# Получаем готовый пайплайн

In [129]:
# Token-classification pipeline built from the selected checkpoint.
# Use the first GPU when one is available, otherwise run on CPU (device=-1).
classifier = pipeline(
    'ner',
    model=possible_checkpoints[0],
    device=0 if torch.cuda.is_available() else -1
)

Device set to use cuda:0


In [130]:
original_text = 'The authors would like to thank Adam Burgasser, Brendan Bowler, Kelle Cruz, Mike Cushing, Michael Liu, and Emily Rice for useful discussions on benchmark systems, data treatment, and various data-model comparison approaches. The authors thank Richard Freedman and Roxana Lupu for providing gas opacities and Caroline Morley for radiative transfer code comparisons and helpful discussions. We thank Jacob Lustig-Yeager and Kyle Luther for rewriting portions of the code in python and C for significant speed improvements and also Dan Foreman-Mackey for making EMCEE available to the community. Finally, we thank the anonymous referee and statistics consultant for useful and insightful comments.'

In [131]:
original_res = classifier(original_text)
for it in original_res: print(it)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


{'entity': 'B-Person', 'score': np.float32(0.99849176), 'index': 7, 'word': 'Adam', 'start': 32, 'end': 36}
{'entity': 'I-Person', 'score': np.float32(0.99904865), 'index': 8, 'word': 'Burgasser', 'start': 37, 'end': 46}
{'entity': 'B-Person', 'score': np.float32(0.9986143), 'index': 10, 'word': 'Bren', 'start': 48, 'end': 52}
{'entity': 'I-Person', 'score': np.float32(0.8725239), 'index': 11, 'word': '##da', 'start': 52, 'end': 54}
{'entity': 'I-Person', 'score': np.float32(0.980095), 'index': 12, 'word': '##n', 'start': 54, 'end': 55}
{'entity': 'I-Person', 'score': np.float32(0.99912935), 'index': 13, 'word': 'Bow', 'start': 56, 'end': 59}
{'entity': 'I-Person', 'score': np.float32(0.99885774), 'index': 14, 'word': '##ler', 'start': 59, 'end': 62}
{'entity': 'B-Person', 'score': np.float32(0.9985781), 'index': 16, 'word': 'Kell', 'start': 64, 'end': 68}
{'entity': 'I-Person', 'score': np.float32(0.6958668), 'index': 17, 'word': '##e', 'start': 68, 'end': 69}
{'entity': 'I-Person', '

In [132]:
original_res_2 = classifier(original_text, aggregation_strategy='first')
for it in original_res_2: print(it)

{'entity_group': 'Person', 'score': np.float32(0.99877024), 'word': 'Adam Burgasser', 'start': 32, 'end': 46}
{'entity_group': 'Person', 'score': np.float32(0.9988718), 'word': 'Brendan Bowler', 'start': 48, 'end': 62}
{'entity_group': 'Person', 'score': np.float32(0.9988537), 'word': 'Kelle Cruz', 'start': 64, 'end': 74}
{'entity_group': 'Person', 'score': np.float32(0.9988067), 'word': 'Mike Cushing', 'start': 76, 'end': 88}
{'entity_group': 'Person', 'score': np.float32(0.9987556), 'word': 'Michael Liu', 'start': 90, 'end': 101}
{'entity_group': 'Person', 'score': np.float32(0.99866116), 'word': 'Emily Rice', 'start': 107, 'end': 117}
{'entity_group': 'Person', 'score': np.float32(0.9987589), 'word': 'Richard Freedman', 'start': 243, 'end': 259}
{'entity_group': 'Person', 'score': np.float32(0.9986496), 'word': 'Roxana Lupu', 'start': 264, 'end': 275}
{'entity_group': 'Person', 'score': np.float32(0.9985831), 'word': 'Caroline Morley', 'start': 308, 'end': 323}
{'entity_group': 'Per

In [133]:
sample_for_rendering_3 = {
    'text': original_text,
    'ents': [{'start': it['start'], 'end': it['end'], 'label': it['entity_group']}
             for it in original_res_2]
}

In [134]:
rendered_3 = displacy.render(
    sample_for_rendering_3, style='ent',
    options={'ents': entity_classes, 'colors': dict(zip(entity_classes, entity_colors))},
    manual=True, jupyter=True
)

In [135]:
ru_text = 'Авторы хотели бы поблагодарить Адама Бургассера, Брендана Боулера, Куз, Майка Кушинга, Майкла Лю и Эмили Райс за панныхвторы благодарят Ричарда Фридмана и Роксану Лупу за предоставление непрозрачности газа и Кэролайн Морли за сравнения кодов переноса излучения и полезные обсуждения. Мы благодарим Джейкоба Люстига-Йегера и Кайла Лютера за переписывание частей кода на Python и C для значительного улучшения скорости, а также Дэна Формана-Макки за предоставление EMCEE сообществу. Наконец, мы благодарим анонимного рецензента и консультанта по статистике за полезные и проницательные комментарии.'

In [136]:
ru_res = classifier(
    ru_text,
    aggregation_strategy='first',
)
for it in ru_res: print(it)

In [137]:
sample_for_rendering_4 = {
    'text': ru_text,
    'ents': [{'start': it['start'], 'end': it['end'], 'label': it['entity_group']}
             for it in ru_res]
}

In [138]:
rendered_4 = displacy.render(
    sample_for_rendering_4, style='ent',
    options={'ents': entity_classes, 'colors': dict(zip(entity_classes, entity_colors))},
    manual=True, jupyter=True
)

In [139]:
testset = load_dataset(DATASET_NAME, split='test')

In [140]:
def find_token(token_bounds: List[Tuple[int, int]], char_idx: int) -> int:
    """Return the index of the first token whose span contains ``char_idx``.

    Args:
        token_bounds: ``(start, end)`` character offsets per token; ``start``
            is inclusive and ``end`` is exclusive.
        char_idx: Character position to look up.

    Returns:
        The index of the first matching token, or -1 when no token covers
        the position.
    """
    return next(
        (
            position
            for position, (left, right) in enumerate(token_bounds)
            if left <= char_idx < right
        ),
        -1,
    )

In [141]:
def predictions_to_bio(text: str, tokens: List[str], predictions: List[Tuple[int, int, str]]) -> List[str]:
    """Project character-level predictions onto word tokens as BIO tags.

    Args:
        text: Text in which every token occurs, in order.
        tokens: Word tokens of ``text``.
        predictions: ``(start, end, label)`` triples with character offsets
            into ``text`` (``end`` exclusive) and a tag such as ``B-Person``
            or ``I-Person``.

    Returns:
        One tag per token. Tokens not covered by any prediction get ``O``.
        A tag that is neither ``O`` nor ``B-...`` is rewritten to
        ``B-<class>`` when it follows ``O`` or a tag of another class.

    Raises:
        RuntimeError: If a token cannot be found in ``text`` after the
            previous token.
    """
    token_bounds = []
    token_labels = []
    search_from = 0
    for cur_token in tokens:
        # Search in place from the end of the previous token instead of
        # slicing the text, which would copy its tail for every token.
        token_start = text.find(cur_token, search_from)
        if token_start < 0:
            err_msg = f'The token {cur_token} is not found in the text {text}'
            raise RuntimeError(err_msg)
        token_end = token_start + len(cur_token)
        search_from = token_end
        token_bounds.append((token_start, token_end))
        token_labels.append('O')
    for span_start, span_end, span_label in predictions:
        start_token = find_token(token_bounds, span_start)
        end_token = find_token(token_bounds, span_end - 1)
        if (start_token >= 0) and (end_token >= 0):
            for token_idx in range(start_token, end_token + 1):
                token_labels[token_idx] = span_label
        elif start_token >= 0:
            # Only the first character of the span falls inside a token.
            token_labels[start_token] = span_label
        elif end_token >= 0:
            # Only the last character of the span falls inside a token.
            token_labels[end_token] = span_label
    corrected_token_labels = []
    previous_label = 'O'
    for cur_label in token_labels:
        is_continuation_tag = (cur_label != 'O') and (not cur_label.startswith('B-'))
        has_no_predecessor = (previous_label == 'O') or (previous_label[2:] != cur_label[2:])
        if (cur_label != previous_label) and is_continuation_tag and has_no_predecessor:
            # Nothing to continue: the entity has to begin here.
            corrected_token_labels.append('B-' + cur_label[2:])
        else:
            corrected_token_labels.append(cur_label)
        # The comparison always uses the raw (uncorrected) previous tag.
        previous_label = cur_label
    return corrected_token_labels

In [142]:
from tqdm.notebook import tqdm

In [None]:
# y_true = []
# y_pred = []
# for tokens, reference_tags in tqdm(zip(testset['tokens'], testset['ner_tags']), total=len(testset)):
#     y_true.append(reference_tags)
#     cur_text = TreebankWordDetokenizer().detokenize(tokens)
#     cur_res = classifier(cur_text)
#     y_pred.append(
#         predictions_to_bio(
#             cur_text,
#             tokens,
#             [(it['start'], it['end'], it['entity']) for it in cur_res]
#         )
#     )

from transformers import AutoTokenizer, AutoModelForTokenClassification, DataCollatorForTokenClassification, Trainer, TrainingArguments
import numpy as np

# Reload the tokenizer and the fine-tuned model from the selected checkpoint
# (this rebinds the `tokenizer` and `model` names used earlier in the notebook).
tokenizer = AutoTokenizer.from_pretrained(possible_checkpoints[0])
model = AutoModelForTokenClassification.from_pretrained(possible_checkpoints[0])
# Switch to evaluation mode for inference
model.eval()
# Move the model to the GPU when one is available
if torch.cuda.is_available():
    model.to('cuda')

# Tokenization function
def tokenize_and_align_labels_for_inference(examples):
    """Tokenize pre-split words and align word-level NER tags with sub-word tokens.

    Written for ``Dataset.map(..., batched=True)``: ``examples`` is a batch
    with a ``tokens`` column (list of words per example) and a ``ner_tags``
    column (one tag string per word).

    Inputs are truncated to 512 sub-word tokens. Only the first sub-word of
    each word receives the id of the word's tag (its index in ``label_list``).
    Special tokens, the remaining sub-words, and words whose tag is missing
    from ``label_list`` receive -100.

    Returns:
        The tokenizer output with an added ``labels`` entry.
    """
    tokenized_inputs = tokenizer(
        examples['tokens'],
        truncation=True,
        is_split_into_words=True,
        max_length=512
    )
    # Build the tag -> id mapping once per batch instead of scanning
    # label_list twice for every word; setdefault keeps the first index of a tag.
    tag_to_id = {}
    for tag_id, tag in enumerate(label_list):
        tag_to_id.setdefault(tag, tag_id)
    labels = []
    for i, label in enumerate(examples['ner_tags']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if (word_idx is None) or (word_idx == previous_word_idx):
                # Special token or a non-initial piece of the same word.
                label_ids.append(-100)
            else:
                label_ids.append(tag_to_id.get(label[word_idx], -100))
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs

# Tokenize the test set
tokenized_testset = testset.map(tokenize_and_align_labels_for_inference, batched=True)

# Inference through Trainer: no logging integrations, CPU fallback when CUDA is missing
training_args = TrainingArguments(
    output_dir='./temp_infer',
    per_device_eval_batch_size=32,
    report_to="none",
    use_cpu=not torch.cuda.is_available(),
    dataloader_pin_memory=False
)

trainer = Trainer(
    model=model,
    args=training_args,
    # `processing_class` supersedes the deprecated `tokenizer` argument of Trainer.
    processing_class=tokenizer,
    data_collator=DataCollatorForTokenClassification(tokenizer),
)

# Run batched inference over the tokenized test set; the third returned value is not used
predictions, labels, _ = trainer.predict(tokenized_testset)
# Pick the highest-scoring label id along axis 2
predictions = np.argmax(predictions, axis=2)

# Convert label ids to BIO tags
id2label = model.config.id2label

# Predicted tags, restricted to positions that carry a real label (label != -100)
true_predictions = [
    [id2label[p] for (p, l) in zip(pred, label) if l != -100]
    for pred, label in zip(predictions, labels)
]

# Reference tags at the same positions
true_labels = [
    [id2label[l] for (p, l) in zip(pred, label) if l != -100]
    for pred, label in zip(predictions, labels)
]

# Assign directly
# NOTE(review): tokenization above truncates to max_length=512, so for longer examples
# these sequences are shorter than the original tag lists — only the first example is
# checked by the prints below; TODO confirm for the whole test set.
y_true = true_labels
y_pred = true_predictions

print(f"Пример длины: исходный тестсет = {len(testset)}, y_true = {len(y_true)}, y_pred = {len(y_pred)}")
print(f"Длина первого примера: токены = {len(testset[0]['tokens'])}, теги = {len(y_true[0])}")

Map:   0%|          | 0/2505 [00:00<?, ? examples/s]

  trainer = Trainer(


Пример длины: исходный тестсет = 2505, y_true = 2505, y_pred = 2505
Длина первого примера: токены = 153, теги = 153


In [159]:
for true_label, predicted_label in zip(y_true[0], y_pred[0]):
    print('{0:>25}   {1:>25}'.format(true_label, predicted_label))

                        O                           O
                        O                           O
                        O                           O
                        O                           O
                        O                           O
                        O                           O
                 B-Person                    B-Person
                 I-Person                    I-Person
                 B-Person                    B-Person
                 I-Person                    I-Person
                 B-Person                    B-Person
                 I-Person                    I-Person
                 B-Person                    B-Person
                 I-Person                    I-Person
                 B-Person                    B-Person
                 I-Person                    I-Person
                        O                           O
                 B-Person                    B-Person
                 I-Person   

In [160]:
from seqeval.scheme import IOB2
from seqeval.metrics import classification_report

In [161]:
print(classification_report(y_true, y_pred, digits=4))

  _warn_prf(average, modifier, msg_start, len(result))


                         precision    recall  f1-score   support

                Archive     0.8516    0.8723    0.8619       329
        CelestialObject     0.8212    0.8804    0.8497      2733
  CelestialObjectRegion     0.4064    0.2091    0.2761       550
        CelestialRegion     0.4848    0.5096    0.4969       157
               Citation     0.9496    0.9758    0.9625      6778
          Collaboration     0.6554    0.7411    0.6957       367
      ComputingFacility     0.6382    0.6621    0.6499       586
               Database     0.5255    0.5410    0.5331       305
                Dataset     0.4427    0.4825    0.4617       400
 EntityOfFutureInterest     0.0000    0.0000    0.0000       338
                  Event     0.3448    0.3636    0.3540        55
             Fellowship     0.5885    0.6886    0.6346       594
                Formula     0.7766    0.7929    0.7846      2665
                  Grant     0.4911    0.5224    0.5063      5069
             Identifier 

In [162]:
print(classification_report(y_true, y_pred, digits=4, mode='strict', scheme=IOB2))

                         precision    recall  f1-score   support

                Archive     0.8750    0.8723    0.8737       329
        CelestialObject     0.8299    0.8785    0.8535      2733
  CelestialObjectRegion     0.4730    0.2073    0.2882       550
        CelestialRegion     0.5678    0.4268    0.4873       157
               Citation     0.9602    0.9749    0.9675      6778
          Collaboration     0.7452    0.7411    0.7432       367
      ComputingFacility     0.6696    0.6570    0.6632       586
               Database     0.5449    0.5377    0.5413       305
                Dataset     0.5040    0.4775    0.4904       400
 EntityOfFutureInterest     0.0000    0.0000    0.0000       338
                  Event     0.5000    0.3636    0.4211        55
             Fellowship     0.6303    0.6717    0.6504       594
                Formula     0.8583    0.7910    0.8233      2665
                  Grant     0.5155    0.5145    0.5150      5069
             Identifier 