In [1]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
from allennlp.models.archival import load_archive
from chemdataextractor.data import find_data
import json, os, appdirs
import pprint

In [2]:
# Override of the archived model configuration: the BERT token embedder is
# pointed at the SciBERT weights file resolved by find_data.
# The dict is deliberately NOT called `overrides`: a later cell runs
# `from overrides import overrides`, and sharing the name would rebind it to a
# different kind of object depending on cell execution order.
archive_overrides = {
    "model.text_field_embedder.token_embedders.bert.pretrained_model":
        find_data("models/scibert_cased_weights-1.0.tar.gz")
}

# Load the archived BERT-CRF model and keep its weights for the conversion below.
cde_bert_archive = load_archive(
    find_data('models/bert_finetuned_crf_model-1.0a'),
    overrides=json.dumps(archive_overrides),
)
cde_bertcrf_model = cde_bert_archive.model
cde_bertcrf_model_state_dict = cde_bertcrf_model.state_dict()

In [3]:
# Inspect the index -> token tables of the archived model's vocabulary
# (private attribute access, used here for inspection only).
cde_bertcrf_model.vocab._index_to_token

_IndexToTokenDefaultDict(None, {'labels': {0: 'O', 1: 'I-CEM', 2: 'B-CEM'}})

In [4]:
import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel, AutoConfig, PretrainedConfig
from torch.utils.data import Dataset, DataLoader
from chemdataextractor.nlp.crf import ConditionalRandomField, allowed_transitions
from chemdataextractor.nlp.allennlp_modules import TimeDistributed
from chemdataextractor.errors import ConfigurationError
from typing import Dict, Optional, List, Tuple
from overrides import overrides

class BertCrfConfig(PretrainedConfig):
    """
    Configuration object consumed by ``BertCrfTagger``.

    Parameters
    ----------
    num_tags : number of output tags (size of the tag projection and of the CRF).
    dropout : dropout probability applied to the encoder output.
    label_namespace : name of the label namespace; stored on the config only.
    label_encoding : tagging scheme (e.g. ``"BIO"``) used to build the CRF
        transition constraints; required when ``constrain_crf_decoding`` is set.
    index_and_label : ``(index, label)`` pairs defining the tag vocabulary.
    constrain_crf_decoding : whether the CRF gets transition constraints derived
        from ``label_encoding`` and the labels.
    include_start_end_transitions : forwarded to the CRF layer.
    model_name_or_path : identifier handed to ``AutoConfig.from_pretrained`` to
        build the encoder.
    **kwargs : passed through to ``PretrainedConfig.__init__``.
    """

    # NOTE(review): 'bert' is also the model type of the stock BERT
    # configuration; confirm that sharing the identifier is intended.
    model_type = 'bert'

    def __init__(
        self,
        num_tags: int = 3,
        dropout=0.1,
        label_namespace: str = "labels",
        label_encoding: Optional[str] = None,
        index_and_label: List[Tuple[int, str]] = None,
        constrain_crf_decoding: bool = True,
        include_start_end_transitions: bool = True,
        model_name_or_path: str = None,
        **kwargs
    ):
        # Encoder settings.
        self.model_name_or_path = model_name_or_path
        self.dropout = dropout

        # Tag vocabulary.
        self.num_tags = num_tags
        self.label_namespace = label_namespace
        self.label_encoding = label_encoding
        self.index_and_label = index_and_label

        # CRF behaviour.
        self.constrain_crf_decoding = constrain_crf_decoding
        self.include_start_end_transitions = include_start_end_transitions

        # Generic options are handled by the base class.
        super().__init__(**kwargs)


class BertCrfTagger(PreTrainedModel):
    """
    Sequence tagger: BERT encoder -> dropout -> linear tag projection -> CRF.

    The encoder architecture is built from the configuration resolved by
    ``AutoConfig.from_pretrained(config.model_name_or_path)``; the tag
    vocabulary comes from ``config.index_and_label``.

    Raises
    ------
    ConfigurationError
        If ``config.index_and_label`` is missing, or if
        ``config.constrain_crf_decoding`` is set without ``config.label_encoding``.
    """
    config_class = BertCrfConfig  # Required for saving/loading

    def __init__(self, config):

        super().__init__(config)

        # Fail fast with a clear message: without the (index, label) pairs neither
        # the label lookup tables nor the CRF constraints can be built.
        if config.index_and_label is None:
            raise ConfigurationError("index_and_label must be provided to build "
                                     "the tag vocabulary.")

        # Only the architecture is created here (from_config); the weights are
        # expected to be supplied afterwards, e.g. through load_state_dict.
        self.bert_model = AutoModel.from_config(AutoConfig.from_pretrained(config.model_name_or_path))
        self.num_tags = config.num_tags
        self.tag_projection_layer = TimeDistributed(
            nn.Linear(self.bert_model.config.hidden_size, self.num_tags)
        )

        self.label_encoding = config.label_encoding
        self.index_and_label = config.index_and_label
        self.index_to_label = self._index_to_label()
        self.label_to_index = self._label_to_index()

        if config.constrain_crf_decoding:
            if not config.label_encoding:
                raise ConfigurationError("constrain_crf_decoding is True, but "
                                         "no label_encoding was specified.")
            labels = self.index_to_label
            constraints = allowed_transitions(config.label_encoding, labels)
        else:
            constraints = None

        self.include_start_end_transitions = config.include_start_end_transitions
        self.crf = ConditionalRandomField(
                self.num_tags, constraints,
                include_start_end_transitions=config.include_start_end_transitions
        )

        # Dropout for regularization
        self.dropout = nn.Dropout(config.dropout)

    def _index_to_label(self):
        """Return a dict mapping tag index -> label string."""
        return {index: label for index, label in self.index_and_label}

    def _label_to_index(self):
        """Return a dict mapping label string -> tag index."""
        return {label: index for index, label in self.index_and_label}

    def forward(self, input_ids, attention_mask, labels=None):
        """
        Encode the input, score every position against the tag set and decode
        the best tag sequence with the CRF.

        Parameters
        ----------
        input_ids, attention_mask
            Passed straight to the encoder; ``attention_mask`` is also used as
            the CRF mask.
        labels : optional
            Gold tag indices. When given, the negative CRF log-likelihood is
            returned under ``"loss"``.
            NOTE(review): assumes the same leading shape as ``attention_mask``
            and a valid tag index at every position (not e.g. -100) — confirm
            against chemdataextractor.nlp.crf.

        Returns
        -------
        dict with ``"logits"``, ``"mask"``, ``"tags"`` (one list of tag indices
        per sequence) and, only when ``labels`` is given, ``"loss"``.
        """
        # BERT embeddings
        outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)

        # Project onto tag space
        logits = self.tag_projection_layer(sequence_output)

        # viterbi_tags yields (tag_sequence, score) pairs; only the tags are kept.
        best_paths = self.crf.viterbi_tags(logits, attention_mask)
        predicted_tags = [tags for tags, _score in best_paths]

        output = {"logits": logits, "mask": attention_mask, "tags": predicted_tags}

        if labels is not None:
            # The CRF forward pass is expected to return the log-likelihood of
            # the gold sequence (AllenNLP convention); negate it to get a loss.
            log_likelihood = self.crf(logits, labels, attention_mask)
            output["loss"] = -log_likelihood

        return output

    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Converts the tag ids to the actual tags.
        ``output_dict["tags"]`` is a list of lists of tag_ids; it is replaced
        in place by the corresponding label strings and the same dict is returned.
        """
        output_dict["tags"] = [
                [self.index_to_label[tag]
                 for tag in instance_tags]
                for instance_tags in output_dict["tags"]
        ]
        return output_dict


In [5]:
# Take the index -> label mapping from the archived model's own vocabulary
# instead of retyping it by hand: the tag indices used by the converted weights
# must line up with this mapping, and a manual copy can silently swap labels.
label_vocab = cde_bertcrf_model.vocab.get_index_to_token_vocabulary("labels")
index_and_label = sorted(label_vocab.items())

bertcrf_config = BertCrfConfig(
    num_tags=len(index_and_label),
    label_namespace="labels",
    label_encoding="BIO",
    index_and_label=index_and_label,
    constrain_crf_decoding=True,
    include_start_end_transitions=False,
    dropout=0.1,
    model_name_or_path="allenai/scibert_scivocab_cased"
)

In [6]:
# Display the flag to confirm the config carries the value passed to the constructor.
bertcrf_config.constrain_crf_decoding

True

In [7]:
# Output folder for the converted model, placed under the directory that
# appdirs reports for the 'ChemDataExtractor' application.
cde_user_data_dir = appdirs.user_data_dir('ChemDataExtractor')
save_dir = os.path.join(cde_user_data_dir, 'models/hf_bert_crf_tagger')
print(save_dir)

/home/dh582/.local/share/ChemDataExtractor/models/hf_bert_crf_tagger


In [8]:
# Write the configuration to save_dir.
# NOTE(review): the later bertcrf_tagger.save_pretrained(save_dir) call
# presumably writes the config as well, which would make this call redundant — confirm.
bertcrf_config.save_pretrained(save_dir)

In [9]:
# Instantiate the tagger from the config. BertCrfTagger builds its encoder with
# AutoModel.from_config, i.e. from a configuration only; the trained weights
# are copied in by the load_state_dict call further down.
bertcrf_tagger = BertCrfTagger(bertcrf_config)

In [10]:
# In the archived model the encoder parameters sit under this prefix, while
# BertCrfTagger holds the encoder directly (attribute `bert_model`). Dropping
# the prefix, including its trailing dot, presumably leaves keys that start
# with `bert_model.` — confirm with the load_state_dict report below.
# All other keys (tag projection, CRF) are copied unchanged.
EMBEDDER_PREFIX = "text_field_embedder.token_embedder_bert."

state_dict = {
    (key[len(EMBEDDER_PREFIX):] if key.startswith(EMBEDDER_PREFIX) else key): tensor
    for key, tensor in cde_bertcrf_model_state_dict.items()
}

In [11]:
# Copy the converted weights into the tagger. With strict=False, missing or
# unexpected keys do not raise; they are only reported in the returned value,
# so check the displayed result before saving the model.
bertcrf_tagger.load_state_dict(state_dict=state_dict, strict=False)

<All keys matched successfully>

In [12]:
# Save the tagger, now holding the converted weights, to save_dir.
bertcrf_tagger.save_pretrained(save_dir)

In [13]:
from transformers import BertTokenizer

# Build a case-preserving tokenizer from the SciBERT vocabulary file resolved by find_data.
scibert_vocab_path = find_data('models/scibert_cased_vocab-1.0.txt')
tokenizer = BertTokenizer(vocab_file=scibert_vocab_path, do_lower_case=False)

# Encode one sample sentence; the inference cell wraps the result into tensors.
input_dict = tokenizer("The chemical formula of water is H2O.")

In [14]:
# Store the tokenizer files in the same folder as the model so both can be loaded from save_dir.
tokenizer.save_pretrained(save_dir)

('/home/dh582/.local/share/ChemDataExtractor/models/hf_bert_crf_tagger/tokenizer_config.json',
 '/home/dh582/.local/share/ChemDataExtractor/models/hf_bert_crf_tagger/special_tokens_map.json',
 '/home/dh582/.local/share/ChemDataExtractor/models/hf_bert_crf_tagger/vocab.txt',
 '/home/dh582/.local/share/ChemDataExtractor/models/hf_bert_crf_tagger/added_tokens.json')

In [22]:
# Switch to evaluation mode before running the sample through the model.
bertcrf_tagger.eval()

# Wrap the encoded sentence into batch-of-one tensors.
input_ids = torch.tensor([input_dict["input_ids"]])
attention_mask = torch.tensor([input_dict["attention_mask"]])

with torch.no_grad():
    output = bertcrf_tagger(input_ids, attention_mask)
    # decode() rewrites output["tags"] in place and returns the same dict.
    tag = bertcrf_tagger.decode(output)

In [23]:
# Show the decoded result (logits, mask and tag strings).
# NOTE(review): in the recorded output the largest logit at the three positions
# tagged as an entity sits at index 2, 1, 1. With the label vocabulary shown
# earlier (1 -> I-CEM, 2 -> B-CEM) that would read B-CEM, I-CEM, I-CEM if the
# Viterbi path follows the per-token maximum, yet the recorded tags are
# I-CEM, B-CEM, B-CEM. The execution counter also jumps from 14 to 22, so this
# output may stem from an earlier kernel state — re-run from a fresh kernel to confirm.
tag

{'logits': tensor([[[ 9.2465, -4.4532, -5.7068],
          [10.0801, -4.6235, -6.3135],
          [10.1462, -4.6416, -6.3580],
          [10.0347, -4.6187, -6.2793],
          [10.1487, -4.6174, -6.3801],
          [10.1001, -4.6115, -6.3406],
          [10.0827, -4.5815, -6.3484],
          [ 1.8490, -4.8789,  2.1975],
          [-0.2525,  4.1726, -2.9152],
          [-0.0103,  4.3760, -3.4589],
          [ 9.5042, -4.4805, -5.9208],
          [ 9.1265, -4.4352, -5.6125]]]),
 'mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
 'tags': [['O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'I-CEM',
   'B-CEM',
   'B-CEM',
   'O',
   'O']]}

In [26]:
# Reference run: tag the same sentence through ChemDataExtractor's Sentence
# API so its output can be compared with the converted model's tags.
# NOTE(review): this import would normally live in the import cell at the top.
from chemdataextractor.doc import Sentence
s = Sentence("The chemical formula of water is H2O.")
tagged_tokens = s.ner_tagged_tokens

tokens: {'bert': tensor([[  101,   186,  3556,  4841,   125,  1583,   163,   233, 30130, 30159,
           211,   102]]), 'bert-offsets': tensor([[ 1,  2,  3,  4,  5,  6,  7, 10]]), 'bert-type-ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}
embedded_text_input: tensor([[[ 1.4035, -0.6533, -0.3828,  ..., -0.8605,  0.0689,  0.2380],
         [ 1.4013, -0.6504, -0.3632,  ..., -0.8577,  0.0815,  0.2464],
         [ 1.4011, -0.6624, -0.3886,  ..., -0.8685,  0.0624,  0.2448],
         ...,
         [ 1.4000, -0.6552, -0.3846,  ..., -0.8585,  0.0710,  0.2238],
         [-0.0509, -0.5476, -0.2018,  ..., -0.2917,  0.4364,  0.9864],
         [ 1.3849, -0.6988, -0.4754,  ..., -0.9258, -0.0135,  0.2528]]]) 
of shape: torch.Size([1, 8, 768])
mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1]])


In [27]:
# Display the (token, tag) pairs produced by the reference pipeline.
tagged_tokens

[('The', 'O'),
 ('chemical', 'O'),
 ('formula', 'O'),
 ('of', 'O'),
 ('water', 'O'),
 ('is', 'O'),
 ('H2O', 'B-CM'),
 ('.', 'O')]