In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import datasets
conll = datasets.load_dataset("conll2003")
CONLL_NER_TAGS = conll['train'].features['ner_tags'].feature.names
print(CONLL_NER_TAGS)
conll["test"][2]

Found cached dataset conll2003 (/Users/azatsultanov/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)


  0%|          | 0/3 [00:00<?, ?it/s]

['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']


{'id': '2',
 'tokens': ['AL-AIN', ',', 'United', 'Arab', 'Emirates', '1996-12-06'],
 'pos_tags': [22, 6, 22, 22, 23, 11],
 'chunk_tags': [11, 0, 11, 12, 12, 12],
 'ner_tags': [5, 0, 5, 6, 6, 0]}

In [3]:
from transformers import (pipeline, 
        AutoModelForTokenClassification, AutoTokenizer, 
        BertForTokenClassification, BertTokenizer)

# Load pretrained model and tokenizer for English NER task (dslim/bert-base-NER)
model_name = "dslim/bert-base-NER"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BertForTokenClassification.from_pretrained(model_name)

In [20]:
import numpy as np
import torch

class ModelWrapper():
    def __init__(self, model, tokenizer):
        self.model = pipeline('ner', model=model, tokenizer=tokenizer, aggregation_strategy=None)

    def _unite_entities(self, entities):
        if len(entities) <= 1:
            return entities

        united_result = []
        cur_entity = {key: entities[0][key] for key in ['entity', 'word', 'start', 'end']}
        for entity in entities[1:]:
            if entity['word'].startswith('##'):
                cur_entity['word'] += entity['word'].lstrip('#')
                cur_entity['end'] = entity['end']
            else:
                united_result.append(cur_entity)
                cur_entity = {key: entity[key] for key in ['entity', 'word', 'start', 'end']}
        united_result.append(cur_entity)
        return united_result

    def _convert_entities_to_bio(self, tokens, entities):
        bio_tags = []
        cur_entity_idx = 0
        for token in tokens:
            if (cur_entity_idx < len(entities))\
                    and (token == entities[cur_entity_idx]['word']):
                bio_tags.append(entities[cur_entity_idx]['entity'])
                cur_entity_idx += 1
            else:
                bio_tags.append('O')
        return bio_tags

    def _postprocessing(self, tokens, model_output):
        entities = self._unite_entities(model_output)
        bio_tags = self._convert_entities_to_bio(tokens, entities)
        return bio_tags

    def predict(self, X):
        with torch.no_grad():
            ner_entitites = self.model(X)
            tags = []
            for text, entities in zip(X, ner_entitites):
                tags.append(self._postprocessing(text.split(), entities))
            return tags

wrapped_model = ModelWrapper(model, tokenizer)

In [21]:
def preprocessing_dataset(dataset):
    result = {'X': [], 'y_true': []}
    for sample in dataset:
        result['X'].append(' '.join(sample['tokens']))
        result['y_true'].append([CONLL_NER_TAGS[tag] for tag in sample['ner_tags']])
    return result

In [25]:
# from itertools import islice


# train_data = preprocessing_dataset(conll['train'])
# train_data['y_pred'] = wrapped_model.predict(train_data['X'])

In [26]:

# oos_data = preprocessing_dataset(conll['test'])
# oos_data['y_pred'] = wrapped_model.predict(oos_data['X'])

In [29]:
# import pickle

# with open('./train_data.pkl', 'wb') as f:
#     pickle.dump(train_data, f)
# with open('./oos_data.pkl', 'wb') as f:
#     pickle.dump(oos_data, f)


In [30]:
import pickle

with open('./train_data.pkl', 'rb') as f:
    train_data = pickle.load(f)

with open('./oos_data.pkl', 'rb') as f:
    oos_data = pickle.load(f)

In [32]:
all([len(i) == len(j) for i, j in zip(train_data['y_true'], train_data['y_pred'])])


True

In [33]:
all([len(i) == len(j) for i, j in zip(oos_data['y_true'], oos_data['y_pred'])])


True

In [37]:
from sbe_vallib.validation.sampler.ner_sampler import NerSampler

sampler = NerSampler(train=train_data, oos=oos_data)
sampler.set_seed(42, bootstrap=True)

In [39]:
'o'.split('-')

['o']

In [44]:
list(set([tag.split('-')[-1] for tag in CONLL_NER_TAGS]) - set(['O']))

['PER', 'MISC', 'LOC', 'ORG']

In [45]:
from nervaluate import Evaluator

evaluator = Evaluator(oos_data['y_true'], oos_data['y_pred'], tags=['PER', 'MISC', 'LOC', 'ORG'], loader='list')
ner_metrics, ner_metrics_by_tag = evaluator.evaluate()

In [47]:
ner_metrics_by_tag

{'PER': {'ent_type': {'correct': 1117,
   'incorrect': 33,
   'partial': 0,
   'missed': 467,
   'spurious': 15,
   'possible': 1617,
   'actual': 1165,
   'precision': 0.9587982832618026,
   'recall': 0.6907854050711194,
   'f1': 0.8030194104960461},
  'partial': {'correct': 1097,
   'incorrect': 0,
   'partial': 53,
   'missed': 467,
   'spurious': 15,
   'possible': 1617,
   'actual': 1165,
   'precision': 0.9643776824034335,
   'recall': 0.6948051948051948,
   'f1': 0.8076923076923076},
  'strict': {'correct': 1069,
   'incorrect': 81,
   'partial': 0,
   'missed': 467,
   'spurious': 15,
   'possible': 1617,
   'actual': 1165,
   'precision': 0.9175965665236051,
   'recall': 0.6611008039579468,
   'f1': 0.7685118619698059},
  'exact': {'correct': 1097,
   'incorrect': 53,
   'partial': 0,
   'missed': 467,
   'spurious': 15,
   'possible': 1617,
   'actual': 1165,
   'precision': 0.9416309012875537,
   'recall': 0.6784168212739641,
   'f1': 0.7886412652767792}},
 'MISC': {'ent_typ

In [53]:
from sbe_vallib.validation.scorer.base import BaseScorer

class NerScorer(BaseScorer):
    def __init__(self, metrics: dict,
                 custom_metrics={},
                 is_calc_ner_metrics=True,
                 tags=['PER', 'MISC', 'LOC', 'ORG'],
                 **kwargs, ):
        super().__init__(metrics, custom_metrics)
        self.tags = tags
        self.is_calc_ner_metrics = is_calc_ner_metrics

    def ner_metrics(self, y_true, y_pred):
        answer = {}
        evaluator = Evaluator(y_true, y_pred, tags=self.tags, loader='list')
        ner_metrics, ner_metrics_by_tag = evaluator.evaluate()

        for schema in ['ent_type', 'partial', 'strict', 'exact']:
            for metric in ['f1', 'recall', 'precision']:
                answer[f"{schema}_{metric}"] = ner_metrics[schema][metric]
                for tag in self.tags:
                    answer[f"{schema}_{metric}_by_{tag}"] = ner_metrics_by_tag[tag][schema][metric]
        return answer

    def score(self, y_true, y_pred, **kwargs):
        answer = {}
        for metric_name in self.metrics:
            answer[metric_name] = self.metrics[metric_name]["callable"](
                y_true, y_pred
            )
        if self.is_calc_ner_metrics:
            answer.update(self.ner_metrics(y_true, y_pred))    
        return answer


scorer = NerScorer({}, {})
scores = scorer.score(oos_data['y_true'], oos_data['y_pred'])

In [54]:
scores

{'ent_type_f1': 0.8238541153277475,
 'ent_type_f1_by_PER': 0.8030194104960461,
 'ent_type_f1_by_MISC': 0.7612179487179488,
 'ent_type_f1_by_LOC': 0.8062648961525366,
 'ent_type_f1_by_ORG': 0.8829452485840151,
 'ent_type_recall': 0.7399079320113314,
 'ent_type_recall_by_PER': 0.6907854050711194,
 'ent_type_recall_by_MISC': 0.6766381766381766,
 'ent_type_recall_by_LOC': 0.709832134292566,
 'ent_type_recall_by_ORG': 0.8446718844069837,
 'ent_type_precision': 0.9292861907938625,
 'ent_type_precision_by_PER': 0.9587982832618026,
 'ent_type_precision_by_MISC': 0.86996336996337,
 'ent_type_precision_by_LOC': 0.9330181245074862,
 'ent_type_precision_by_ORG': 0.924851680949242,
 'partial_f1': 0.8498767865943814,
 'partial_f1_by_PER': 0.8076923076923076,
 'partial_f1_by_MISC': 0.7908653846153845,
 'partial_f1_by_LOC': 0.8478038815117467,
 'partial_f1_by_ORG': 0.9118942731277533,
 'partial_recall': 0.7632790368271954,
 'partial_recall_by_PER': 0.6948051948051948,
 'partial_recall_by_MISC': 0.7029