In [1]:
from corus import load_wikiner
from transformers import AutoTokenizer, DataCollatorForTokenClassification, Trainer, TrainingArguments
from transformers import AutoModelForTokenClassification
import evaluate
import numpy as np
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
datagen = load_wikiner('./aij-wikiner-ru-wp3.bz2')

In [3]:
dict = []
possible_ner_tags = set()
for item in datagen:
    d = {'sent': [], 'tags': []}
    for x in item.tokens:
        d['sent'].append(x.text)
        d['tags'].append(x.tag)
        ## Кажется, можно было бы достать возможные теги лучше...
        if x.tag not in possible_ner_tags:
            possible_ner_tags.add(x.tag)
    dict.append(d)
ner_list = list(possible_ner_tags)
ner_list.remove("O")

In [4]:
ner_tag_to_idx = {}
ner_tag_to_idx['O'] = 0
for idx,tag in enumerate(ner_list):
    ner_tag_to_idx[tag] = idx+1
#ner_tag_to_idx = {tag: idx+1 for idx, tag in enumerate(ner_list)}
print(ner_tag_to_idx)

ner_idx_to_tag = {v: k for k,v in ner_tag_to_idx.items()}

{'O': 0, 'I-ORG': 1, 'I-LOC': 2, 'I-MISC': 3, 'B-MISC': 4, 'B-LOC': 5, 'I-PER': 6, 'B-ORG': 7, 'B-PER': 8}


In [5]:
for d in dict:
    d['ner_tags'] = [ner_tag_to_idx[tag] for tag in d['tags']]

In [6]:
tokenizer = AutoTokenizer.from_pretrained('DeepPavlov/rubert-base-cased')

In [7]:
toks = tokenizer(dict[0]['sent'],truncation=True, is_split_into_words=True)
print(toks)
tokenizer.convert_ids_to_tokens(toks['input_ids'])

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


{'input_ids': [101, 2988, 14576, 24340, 869, 105058, 128, 1469, 12266, 130, 130, 869, 69981, 128, 1469, 9059, 130, 13124, 130, 130, 236, 49322, 851, 37210, 33424, 3590, 132, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


['[CLS]',
 'На',
 'севере',
 'граничит',
 'с',
 'Латвией',
 ',',
 'на',
 'востоке',
 '-',
 '-',
 'с',
 'Белоруссией',
 ',',
 'на',
 'юго',
 '-',
 'западе',
 '-',
 '-',
 'c',
 'Польшей',
 'и',
 'Калининградской',
 'областью',
 'России',
 '.',
 '[SEP]']

In [8]:
toks

{'input_ids': [101, 2988, 14576, 24340, 869, 105058, 128, 1469, 12266, 130, 130, 869, 69981, 128, 1469, 9059, 130, 13124, 130, 130, 236, 49322, 851, 37210, 33424, 3590, 132, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [9]:
def align(dict):
    dataset = []
    for i in range(len(dict)):
        toks = tokenizer(dict[i]['sent'], truncation=True, is_split_into_words=True)
        
        token_ner = []
        for idx in toks.word_ids():
            if idx is None: token_ner.append(-100)
            else:
                token_ner.append(dict[i]['ner_tags'][idx])
        toks['labels'] = token_ner
        dataset.append(toks)
    return dataset

In [10]:
dict = align(dict)

In [11]:
## Для динамического пэддинга предложений
collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
seqeval = evaluate.load("seqeval")

In [12]:
def metrics(p):
    preds, labels = p
    preds = np.argmax(preds, axis=2)

    true_preds = [[ner_idx_to_tag[p] for (p, l) in zip(pred, label) if l != -100] for pred, label in zip(preds, labels)]
    true_labels = [[ner_idx_to_tag[l] for (p, l) in zip(pred, label) if l != -100] for pred, label in zip(preds, labels)]

    res = seqeval.compute(predictions=true_preds, references=true_labels)
    return {"precision": res["overall_precision"],
            "recall": res["overall_recall"],
            "f1": res["overall_f1"],
            "accuracy": res["overall_accuracy"],}

In [13]:
train_args = TrainingArguments(
    output_dir = 'russian_ner_test',
    learning_rate = 3e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=1,
    weight_decay=0.02,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
)

In [14]:
model = AutoModelForTokenClassification.from_pretrained('DeepPavlov/rubert-base-cased', num_labels=len(ner_idx_to_tag))
model.config.id2label = ner_idx_to_tag
model.config.label2id = ner_tag_to_idx

Some weights of BertForTokenClassification were not initialized from the model checkpoint at DeepPavlov/rubert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
dtrain, dtest = train_test_split(dict, test_size=0.2, random_state=42)

ner_dataset = DatasetDict({
    "train": Dataset.from_list(dtrain),
    "test": Dataset.from_list(dtest)
})

In [16]:
trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=ner_dataset["train"],
    eval_dataset=ner_dataset["test"],
    tokenizer=tokenizer,
    data_collator=collator,
    compute_metrics=metrics,
)

In [None]:
trainer.train()

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss


In [None]:
from datasets import load_dataset
wnut = load_dataset("wnut_17")
example = wnut['train'][0]
label_list = wnut["train"].features[f"ner_tags"].feature.names
labels = [label_list[i] for i in example[f"ner_tags"]]

In [None]:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def inference(text):
    inputs = tokenizer(text, return_tensors='pt').to(device)
    logits = model(**inputs).logits
    preds = torch.argmax(logits, dim=2)
    class_list = [model.config.id2label[t.item()] for t in preds[0]]
    print(text.split(' '))
    print(class_list)

In [None]:
inference("Мой Китик самый лучший на планете Земля и совсем скоро выздоровеет)")