In [6]:
import os, sys

In [7]:
# Put ./ner on the import path (presumably so modules inside the `ner` package can
# import each other by bare name — TODO confirm). Guarded so re-running this cell
# does not append a duplicate entry each time.
ner_dir = os.path.join(os.getcwd(), 'ner')
if ner_dir not in sys.path:
    sys.path.append(ner_dir)

In [2]:
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

In [3]:
# Same pretrained name as the question-answering model created further below.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

In [4]:
# Keyword arguments for every `tokenizer(question_list, context_list, ...)` call:
# - truncation="only_second" shortens only the second sequence (the context), never the question;
# - padding=True pads each batch to its longest example;
# - return_offsets_mapping=True yields per-token character offsets, which are handed to
#   get_top_valid_spans (presumably to map spans back to context characters).
tokenizer_kwargs = dict(
    max_length=512,
    truncation="only_second",
    padding=True,
    return_tensors="pt",
    return_offsets_mapping=True,
)

In [11]:
# Entity tag -> human-readable name. The same words appear in the demo questions
# (e.g. "What is the location?"), and the mapping is used to print predictions.
PROMPT_MAPPER = dict(
    LOC="location",
    PER="person",
    ORG="organization",
    MISC="miscellaneous entity",
)

In [16]:
import torch
from torch import load

In [35]:
# Fine-tuned weights for the QA model; tensors are mapped onto the CPU so no GPU is needed.
# NOTE(review): torch.load unpickles the file and can execute arbitrary code. Only load
# trusted checkpoints; for a plain state dict, `weights_only=True` (torch >= 1.13) is safer.
checkpoint = torch.load('./qa_ner_0.pkl', map_location='cpu')

In [42]:
# BERT encoder with a newly initialised question-answering head (see the warning below);
# all of its weights are then replaced by the fine-tuned checkpoint in the next cell.
model = AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased a

In [43]:
# strict=True (the default, spelled out here) raises if any key is missing or unexpected.
model.load_state_dict(checkpoint, strict=True)

<All keys matched successfully>

In [77]:
from ner.qa_types import QAInstance, QASpan

In [45]:
from ner.metrics.qa import get_top_valid_spans



In [107]:
def check_questions(instances, model, max_answer_length=100):
    """Run the QA-style NER model on a batch and return the top-1 span(s) per instance.

    Each instance's ``question`` and ``context`` are tokenized as a pair with the
    notebook-level ``tokenizer`` and ``tokenizer_kwargs``, run through ``model``, and
    decoded with ``get_top_valid_spans`` using ``n_best_size=1``.

    Args:
        instances: sequence of objects exposing ``.context`` and ``.question``
            (``QAInstance`` in this notebook).
        model: question-answering model (the ``BertForQuestionAnswering`` loaded
            above). Side effect: it is switched to eval mode.
        max_answer_length: forwarded to ``get_top_valid_spans``; defaults to the
            previously hard-coded value of 100.

    Returns:
        A list with one entry per instance, in input order. Each entry is the span
        list produced by ``get_top_valid_spans``, or ``[QASpan(token="", label="O",
        start_context_char_pos=0, end_context_char_pos=0)]`` when no valid span was
        found. An empty ``instances`` yields ``[]``.
    """
    if not instances:
        # Nothing to predict; don't push an empty batch through the tokenizer/model.
        return []

    context_list = [instance.context for instance in instances]
    question_list = [instance.question for instance in instances]

    # Disable dropout so predictions are deterministic regardless of the caller's mode.
    model.eval()

    # Question first, context second: matches truncation="only_second" in tokenizer_kwargs.
    tokenized_batch = tokenizer(question_list, context_list, **tokenizer_kwargs)
    # offset_mapping is not a model input; keep it separately for span decoding.
    offset_mapping_batch = tokenized_batch.pop("offset_mapping")

    with torch.no_grad():
        # Call the module itself rather than .forward() so registered hooks run.
        outputs = model(**tokenized_batch)
        spans_pred_batch_top_1 = get_top_valid_spans(
            context_list=context_list,
            question_list=question_list,
            prompt_mapper=PROMPT_MAPPER,
            inputs=tokenized_batch,
            outputs=outputs,
            offset_mapping_batch=offset_mapping_batch,
            n_best_size=1,
            max_answer_length=max_answer_length,
        )

    # Give every instance at least one span; label "O" marks "no entity found".
    return [
        spans if spans else [
            QASpan(
                token="",
                label="O",
                start_context_char_pos=0,
                end_context_char_pos=0,
            )
        ]
        for spans in spans_pred_batch_top_1
    ]

In [112]:
# Demo inputs: two sentences, each probed with different entity-type questions.
SOBYANIN_CONTEXT = 'Sergey Sobyanin started reforms in Moscow in september 2022'
UN_CONTEXT = 'UN chief delays his next trip to focus on Russia\'s suspension of the Black Sea grain deal'

instances = [
    QAInstance(context=context, question=question, answer=None)
    for context, question in (
        (SOBYANIN_CONTEXT, 'What is the location?'),
        (SOBYANIN_CONTEXT, 'What is the person?'),
        (UN_CONTEXT, 'What is the organization?'),
        (UN_CONTEXT, 'What is the location?'),
    )
]

In [113]:
# One list of predicted spans per demo instance, in the same order as `instances`.
answers = check_questions(instances=instances, model=model)

In [117]:
# Show each question with its predicted entities. Spans labelled 'O' are the
# placeholder check_questions returns when nothing was found.
for instance, predicted_spans in zip(instances, answers):
    print(f'Context: {instance.context} | Question: {instance.question}')
    entity_spans = [span for span in predicted_spans if span.label != 'O']
    if not entity_spans:
        # Make an empty prediction visible instead of silently printing nothing.
        print('\t(no entity found)')
    for span in entity_spans:
        print(f'\t{span.token} is {PROMPT_MAPPER[span.label]}')

Context: Sergey Sobyanin started reforms in Moscow in september 2022 Question: What is the location?
	Moscow is location
Context: Sergey Sobyanin started reforms in Moscow in september 2022 Question: What is the person?
	Sergey is person
Context: UN chief delays his next trip to focus on Russia's suspension of the Black Sea grain deal Question: What is the organization?
	UN is organization
Context: UN chief delays his next trip to focus on Russia's suspension of the Black Sea grain deal Question: What is the location?
