In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Dataset

In [None]:
import datasets
# Load the CoNLL-2003 dataset through the `datasets` library.
conll = datasets.load_dataset("conll2003")
# Ordered label names of the `ner_tags` feature; the integer tag ids stored in
# each sample are later used as indices into this list.
CONLL_NER_TAGS = conll['train'].features['ner_tags'].feature.names
print(CONLL_NER_TAGS)
# Display one raw test sample to inspect the record layout.
conll["test"][2]

In [None]:
# Distinct entity types: drop the B-/I- prefix from every tag and exclude 'O'.
sorted({label.split('-')[-1] for label in CONLL_NER_TAGS} - {'O'})

In [None]:
def preprocessing_dataset(dataset):
    """Turn token-level samples into sentence strings plus string tag sequences.

    Args:
        dataset: iterable of samples, each with a 'tokens' list of strings and
            a 'ner_tags' list of integer indices into ``CONLL_NER_TAGS``.

    Returns:
        dict with two parallel lists: 'X' holds each sample's tokens joined by
        a single space, 'y_true' holds the tag names of each sample.
    """
    sentences, tag_sequences = [], []
    for example in dataset:
        sentences.append(' '.join(example['tokens']))
        tag_sequences.append([CONLL_NER_TAGS[tag_id] for tag_id in example['ner_tags']])
    return {'X': sentences, 'y_true': tag_sequences}

In [None]:
from itertools import islice

# Convert both splits to {'X': sentences, 'y_true': tag names}; the CoNLL test
# split is used as the out-of-sample ("oos") set.
train_data = preprocessing_dataset(conll['train'])
oos_data = preprocessing_dataset(conll['test'])

In [None]:
import numpy as np
import torch
from tqdm import tqdm
from transformers import (pipeline, 
        AutoModelForTokenClassification, AutoTokenizer, 
        BertForTokenClassification, BertTokenizer)


class ModelWrapper():
    """Adapter exposing a Hugging Face token-classification model via ``predict``.

    The wrapped ``pipeline('ner', ..., aggregation_strategy=None)`` reports
    predictions per word piece. This class folds those pieces back onto the
    whitespace-separated tokens of each sentence and returns one tag per token,
    using 'O' for tokens without a reported prediction.
    """

    # Fields of a pipeline entity that the post-processing relies on.
    _ENTITY_KEYS = ('entity', 'word', 'start', 'end')

    def __init__(self, model, tokenizer, classes):
        """Build the NER pipeline.

        Args:
            model: model object passed to ``transformers.pipeline``.
            tokenizer: tokenizer passed to ``transformers.pipeline``.
            classes: entity class names; stored unchanged as ``classes_``.
        """
        self.model = pipeline('ner', model=model, tokenizer=tokenizer, aggregation_strategy=None)
        self.classes_ = classes

    def _unite_entities(self, entities):
        """Merge '##' continuation pieces into the entity they continue.

        A piece is merged only when its word starts with '##' AND it begins at
        the character where the previous entity ends, so a continuation whose
        first piece was not reported is not glued onto an unrelated entity.
        When offsets are unavailable (``None``) the comparison is ``None ==
        None`` and every '##' piece is merged into its predecessor.

        Args:
            entities: list of pipeline dicts with at least the keys
                'entity', 'word', 'start' and 'end'.

        Returns:
            New list of dicts restricted to those four keys; the merged entity
            keeps the tag of its first piece.
        """
        united_result = []
        for entity in entities:
            piece = {key: entity[key] for key in self._ENTITY_KEYS}
            continues_previous = (
                bool(united_result)
                and piece['word'].startswith('##')
                and piece['start'] == united_result[-1]['end']
            )
            if continues_previous:
                # Drop exactly the two-character marker; lstrip('#') would also
                # eat '#' characters that belong to the text itself.
                united_result[-1]['word'] += piece['word'][2:]
                united_result[-1]['end'] = piece['end']
            else:
                united_result.append(piece)
        return united_result

    @staticmethod
    def _token_spans(tokens, text=None):
        """Return the ``(start, end)`` character span of every token.

        With ``text`` given, each token is located in it from left to right
        (raises ``ValueError`` if a token does not occur). Without it, tokens
        are assumed to have been joined by single spaces.
        """
        spans = []
        cursor = 0
        for token in tokens:
            start = cursor if text is None else text.index(token, cursor)
            end = start + len(token)
            spans.append((start, end))
            # Skip the single separating space only in the "joined" model.
            cursor = end + 1 if text is None else end
        return spans

    def _convert_entities_to_bio(self, tokens, entities, text=None):
        """Assign one tag per token from the predicted entities.

        Entities are aligned to tokens by character offsets: a token receives
        the tag of the first entity that starts inside its span, otherwise 'O'.
        Unlike exact word matching, an entity that covers only part of a token
        (e.g. the 'U' of 'U.S.') neither gets lost nor blocks later entities.

        Args:
            tokens: whitespace-separated tokens of the sentence.
            entities: entity dicts ordered by position, as returned by
                ``_unite_entities``.
            text: optional original sentence, used to compute exact token
                offsets; if omitted, tokens are assumed single-space joined.

        Returns:
            List of tags with the same length as ``tokens``.
        """
        if any(entity.get('start') is None for entity in entities):
            # No offsets available: fall back to best-effort matching of whole
            # words, in order.
            bio_tags = []
            cur_entity_idx = 0
            for token in tokens:
                if (cur_entity_idx < len(entities))\
                        and (token == entities[cur_entity_idx]['word']):
                    bio_tags.append(entities[cur_entity_idx]['entity'])
                    cur_entity_idx += 1
                else:
                    bio_tags.append('O')
            return bio_tags

        bio_tags = []
        cur_entity_idx = 0
        for token_start, token_end in self._token_spans(tokens, text):
            # Entities that start before this token belong to earlier tokens.
            while (cur_entity_idx < len(entities))\
                    and (entities[cur_entity_idx]['start'] < token_start):
                cur_entity_idx += 1
            if (cur_entity_idx < len(entities))\
                    and (entities[cur_entity_idx]['start'] < token_end):
                bio_tags.append(entities[cur_entity_idx]['entity'])
            else:
                bio_tags.append('O')
        return bio_tags

    def _postprocessing(self, tokens, model_output, text=None):
        """Convert the raw pipeline output of one sentence into per-token tags."""
        entities = self._unite_entities(model_output)
        bio_tags = self._convert_entities_to_bio(tokens, entities, text=text)
        return bio_tags

    def predict(self, X):
        """Predict one tag sequence per sentence.

        Args:
            X: list of sentences (strings); tokens are obtained with
                ``str.split()``.

        Returns:
            List with, for every sentence, a list of tags (one per token).
        """
        with torch.no_grad():
            ner_entities = self.model(X)
            tags = []
            for text, entities in tqdm(zip(X, ner_entities)):
                tags.append(self._postprocessing(text.split(), entities, text=text))
            return tags


# Load pretrained model and tokenizer for English NER task (dslim/bert-base-NER)
model_name = "dslim/bert-base-NER"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BertForTokenClassification.from_pretrained(model_name)
# The last argument is stored on the wrapper as `classes_`; it lists the entity
# types without their B-/I- prefixes.
wrapped_model = ModelWrapper(model, tokenizer, ['LOC', 'MISC', 'ORG', 'PER'])

In [None]:
# Run the wrapped model over every sentence of both splits and store the
# predicted per-token tag lists next to the gold tags under 'y_pred'.
# NOTE(review): this runs inference over the full splits and is likely slow.
train_data['y_pred'] = wrapped_model.predict(train_data['X'])
oos_data['y_pred'] = wrapped_model.predict(oos_data['X'])

In [None]:
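# One-off step, disabled by default: uncomment to write the current
# train/oos dicts (including predictions) to the pickle files that the
# following cell reads.
# import pickle

# with open('./train_data.pkl', 'wb') as f:
#     pickle.dump(train_data, f)
# with open('./oos_data.pkl', 'wb') as f:
#     pickle.dump(oos_data, f)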

In [None]:
import os
import pickle


def _load_or_cache(path, get_data):
    """Return the object pickled at `path`, creating the file first if needed.

    Args:
        path: location of the pickle file.
        get_data: zero-argument callable returning the in-memory dict to cache;
            only called when `path` does not exist.

    Returns:
        The unpickled object if `path` exists, otherwise the result of
        `get_data()` after it has been written to `path`.

    Raises:
        RuntimeError: if the file is missing and the in-memory data has no
            'y_pred' entry yet, so an incomplete cache is never written.
    """
    if os.path.exists(path):
        # NOTE: unpickling can execute arbitrary code; only load files that
        # were produced by this notebook.
        with open(path, 'rb') as f:
            return pickle.load(f)

    data = get_data()
    if 'y_pred' not in data:
        raise RuntimeError(
            f"{path} does not exist and predictions are not computed yet; "
            "run the prediction cell first"
        )
    with open(path, 'wb') as f:
        pickle.dump(data, f)
    return data


# Prefer cached predictions; on a clean checkout fall back to the predictions
# computed above and cache them, instead of failing with FileNotFoundError.
train_data = _load_or_cache('./train_data.pkl', lambda: train_data)
oos_data = _load_or_cache('./oos_data.pkl', lambda: oos_data)

In [None]:
from sbe_vallib.validation.sampler.ner_sampler import NerSampler
from sbe_vallib.validation.scorer.ner_scorer import NerScorer
from sbe_vallib.validation.utils.metrics import NER_IOB_METRICS


# Sampler built from the train and out-of-sample dicts.
# NOTE(review): presumably NerSampler reads the 'X', 'y_true' and 'y_pred'
# entries of these dicts — confirm against sbe_vallib.
sampler = NerSampler(train=train_data, oos=oos_data)


# Scorer configured with the library's NER_IOB_METRICS; as a sanity check,
# score the out-of-sample predictions against the gold tags and display them.
scorer = NerScorer(metrics=NER_IOB_METRICS)
scores = scorer.score(oos_data['y_true'], oos_data['y_pred'])
scores

In [None]:
from sbe_vallib import Validation

# Assemble the validation run from the wrapped model, scorer and sampler.
# NOTE(review): `pipeline` is a relative path to an Excel config and is
# presumably resolved against the kernel's working directory — confirm the
# notebook is started from the repository root.
ner_validation = Validation(model = wrapped_model, scorer=scorer, sampler=sampler, pipeline='sbe_vallib/validation/nlp/pipelines/Config_45.xlsx')
res = ner_validation.validate()

In [None]:
# List the top-level entries of the validation result.
res.keys()