In [1]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
from allennlp.models.archival import load_archive
from chemdataextractor.data import find_data
import json, os, appdirs
import pprint

In [2]:
# Override the archived config's BERT ``pretrained_model`` entry with the path
# that find_data resolves for the SciBERT weights.
# Named ``archive_overrides`` (not ``overrides``) because a later cell runs
# ``from overrides import overrides`` and would rebind that name to a decorator.
archive_overrides = {
    "model.text_field_embedder.token_embedders.bert.pretrained_model":
        find_data("models/scibert_cased_weights-1.0.tar.gz"),
}
cde_bert_archive = load_archive(
    find_data('models/bert_finetuned_crf_model-1.0a'),
    overrides=json.dumps(archive_overrides),
)
cde_bertcrf_model = cde_bert_archive.model
# Source weights that get remapped onto the Hugging Face model further down.
cde_bertcrf_model_state_dict = cde_bertcrf_model.state_dict()

In [3]:
import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel, AutoConfig, PretrainedConfig
from torch.utils.data import Dataset, DataLoader
from chemdataextractor.nlp.crf import ConditionalRandomField, allowed_transitions
from chemdataextractor.nlp.allennlp_modules import TimeDistributed
from chemdataextractor.errors import ConfigurationError
from typing import Dict, Optional, List, Tuple
from overrides import overrides
import numpy as np
from chemdataextractor.nlp.util import (combine_initial_dims, get_device_of,
                                        get_range_vector,
                                        uncombine_initial_dims)

class BertCrfConfig(PretrainedConfig):
    """Configuration for :class:`BertCrfModel`.

    Every argument is stored unchanged as an attribute of the same name; any
    additional keyword arguments are forwarded to ``PretrainedConfig``.

    Parameters
    ----------
    num_tags :
        Size of the tag space (output width of the projection layer and the
        number of CRF tags).
    dropout :
        Dropout probability applied to the encoder output.
    label_namespace :
        Stored on the config; not read by the model code in this notebook.
    label_encoding :
        Tagging scheme name (e.g. ``"BIO"``) used to derive CRF transition
        constraints. Required when ``constrain_crf_decoding`` is true.
    index_and_label :
        ``(index, label)`` pairs defining the tag vocabulary.
    constrain_crf_decoding :
        Whether the CRF is restricted to transitions allowed by
        ``label_encoding``.
    include_start_end_transitions :
        Passed straight through to the CRF layer.
    model_name_or_path :
        Identifier handed to ``AutoConfig.from_pretrained`` to obtain the
        encoder configuration.
    """

    model_type = 'bert'

    def __init__(
        self,
        num_tags: int = 3,
        dropout=0.1,
        label_namespace: str = "labels",
        label_encoding: Optional[str] = None,
        index_and_label: List[Tuple[int, str]] = None,
        constrain_crf_decoding: bool = True,
        include_start_end_transitions: bool = True,
        model_name_or_path: str = None,
        **kwargs
    ):
        # Encoder / projection settings.
        self.model_name_or_path = model_name_or_path
        self.dropout = dropout
        self.num_tags = num_tags

        # Tag vocabulary.
        self.index_and_label = index_and_label
        self.label_namespace = label_namespace
        self.label_encoding = label_encoding

        # CRF behaviour.
        self.constrain_crf_decoding = constrain_crf_decoding
        self.include_start_end_transitions = include_start_end_transitions

        # Base-class initialisation stays last, as in the original ordering.
        super().__init__(**kwargs)


class BertCrfModel(PreTrainedModel):
    """BERT encoder followed by a linear tag projection and a CRF decoder.

    The encoder architecture is built from the configuration of
    ``config.model_name_or_path``. ``AutoModel.from_config`` is used rather
    than ``from_pretrained``, so the weights are expected to be supplied
    afterwards (e.g. via ``load_state_dict`` or ``from_pretrained`` on this
    class).
    """

    config_class = BertCrfConfig  # Required for saving/loading

    def __init__(self, config):
        super().__init__(config)

        # The label mapping is built unconditionally below, so a missing
        # vocabulary would otherwise surface as an opaque
        # "'NoneType' object is not iterable" TypeError.
        if config.index_and_label is None:
            raise ConfigurationError("index_and_label must be provided to map "
                                     "tag indices to labels.")

        encoder_config = AutoConfig.from_pretrained(config.model_name_or_path)
        self.bert_model = AutoModel.from_config(encoder_config)

        self.num_tags = config.num_tags
        self.dropout = nn.Dropout(config.dropout)
        self.tag_projection_layer = TimeDistributed(
            nn.Linear(self.bert_model.config.hidden_size, self.num_tags, bias=True)
        )

        self.label_encoding = config.label_encoding
        self.index_and_label = config.index_and_label
        self.index_to_label = self._index_to_label()
        self.label_to_index = self._label_to_index()

        if config.constrain_crf_decoding:
            if not config.label_encoding:
                raise ConfigurationError("constrain_crf_decoding is True, but "
                                         "no label_encoding was specified.")
            constraints = allowed_transitions(config.label_encoding,
                                              self.index_to_label)
        else:
            constraints = None

        self.include_start_end_transitions = config.include_start_end_transitions
        self.crf = ConditionalRandomField(
            self.num_tags, constraints,
            include_start_end_transitions=config.include_start_end_transitions
        )

    def _index_to_label(self):
        """Return a ``{index: label}`` dict built from ``self.index_and_label``."""
        return {index: label for index, label in self.index_and_label}

    def _label_to_index(self):
        """Return a ``{label: index}`` dict built from ``self.index_and_label``."""
        return {label: index for index, label in self.index_and_label}

    def forward(self, input_ids, offsets, crf_mask, token_type_ids=None):
        """Encode wordpieces, pick one vector per original token, and decode tags.

        Parameters
        ----------
        input_ids :
            Wordpiece ids. Positions equal to 0 are treated as padding and are
            masked out of the encoder's attention.
        offsets :
            For each original token, the index of the wordpiece whose hidden
            state represents it.
        crf_mask :
            Mask passed to the CRF's Viterbi decoding.
        token_type_ids :
            Optional segment ids; defaults to all zeros.

        Returns
        -------
        dict
            ``"logits"`` (tag scores from the projection layer), ``"mask"``
            (``crf_mask`` unchanged) and ``"tags"`` (one list of predicted tag
            indices per sequence).
        """
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Token id 0 is treated as padding.
        input_mask = (input_ids != 0).long()

        # input_ids may have extra dimensions, so we reshape down to 2-d
        # before calling the BERT model and then reshape back at the end.
        outputs = self.bert_model(input_ids=combine_initial_dims(input_ids),
                                  token_type_ids=combine_initial_dims(
                                      token_type_ids),
                                  attention_mask=combine_initial_dims(input_mask))
        last_hidden_state = self.dropout(outputs.last_hidden_state)

        # last_hidden_state is
        #   (batch_size * d1 * ... * dn, sequence_length, embedding_dim)
        # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
        offsets2d = combine_initial_dims(offsets)
        # offsets2d is (batch_size * d1 * ... * dn, orig_sequence_length)
        range_vector = get_range_vector(
            offsets2d.size(0),
            device=get_device_of(last_hidden_state)).unsqueeze(1)
        # Advanced indexing: row i of offsets2d picks wordpiece positions from
        # sequence i, giving
        #   (batch_size * d1 * ... * dn, orig_sequence_length, embedding_dim)
        # (assuming get_range_vector returns a 1-D index vector).
        selected_embeddings = last_hidden_state[range_vector, offsets2d]

        output_embeddings = uncombine_initial_dims(
            selected_embeddings, offsets.size())

        # TODO: Separate this function into two parts: one for the BERT
        # embeddings and the other for the CRF.
        # Project onto tag space
        logits = self.tag_projection_layer(output_embeddings)
        best_paths = self.crf.viterbi_tags(logits, crf_mask)

        # viterbi_tags yields pairs; only the first element (the tag
        # sequence) is kept.
        predicted_tags = [path for path, _ in best_paths]

        return {"logits": logits, "mask": crf_mask, "tags": predicted_tags}

    def forward_on_instances(self, instances: Dict[str, torch.Tensor]) -> List[Dict[str, np.ndarray]]:
        """Run a batch through the model and split the result per instance.

        The tensors are moved to the model's device, passed through
        :func:`self.forward()` and :func:`self.decode()` under
        ``torch.no_grad()``, and the batched output is separated into one dict
        per batch element. Tensor outputs are converted to numpy arrays.

        Note that this method does not switch the model to evaluation mode;
        call ``model.eval()`` beforehand if dropout should be disabled.

        Parameters
        ----------
        instances : Dict[str, torch.Tensor], required
            Keyword arguments for :func:`self.forward()`. Must contain
            ``"input_ids"``, whose first dimension is taken as the batch size.

        Returns
        -------
        A list with one output dict per batch element.
        """
        batch_size = instances['input_ids'].size(0)
        with torch.no_grad():
            instances = {k: v.to(self.device) for k, v in instances.items()}
            outputs = self.decode(self(**instances))

            instance_separated_output: List[Dict[str, np.ndarray]] = [{} for _ in range(batch_size)]
            for name, output in outputs.items():
                if isinstance(output, torch.Tensor):
                    output = output.detach().cpu().numpy()
                # Iterating a batched output yields one element per instance.
                for instance_output, batch_element in zip(instance_separated_output, output):
                    instance_output[name] = batch_element
            return instance_separated_output

    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Converts the tag ids to the actual tags.

        ``output_dict["tags"]`` is a list of lists of tag ids; each id is
        replaced by its label via ``self.index_to_label``. The dict is
        modified in place and also returned.
        """
        output_dict["tags"] = [
            [self.index_to_label[tag] for tag in instance_tags]
            for instance_tags in output_dict["tags"]
        ]
        return output_dict


In [6]:
# Settings for the Hugging Face port of the tagger.
bertcrf_config = BertCrfConfig(
    # Encoder: checkpoint whose configuration defines the BERT architecture.
    model_name_or_path="allenai/scibert_scivocab_cased",
    dropout=0.1,
    # Tag vocabulary. NOTE(review): the index order presumably has to match
    # the label indices of the AllenNLP model being converted - confirm.
    num_tags=3,
    index_and_label=[(0, "O"), (1, "I-CEM"), (2, "B-CEM")],
    label_namespace="labels",
    label_encoding="BIO",
    # CRF layer.
    constrain_crf_decoding=True,
    include_start_end_transitions=False,
)

In [2]:
# Output directory for the converted model, under ChemDataExtractor's
# per-user data directory.
cde_data_root = appdirs.user_data_dir('ChemDataExtractor')
save_dir = os.path.join(cde_data_root, 'models/hf_bert_crf_tagger')
print(save_dir)

/home/dh582/.local/share/ChemDataExtractor/models/hf_bert_crf_tagger


In [7]:
# Persist the tagger configuration to save_dir.
bertcrf_config.save_pretrained(save_dir)

In [8]:
# Instantiate the Hugging Face port; the converted AllenNLP weights are
# loaded into it in a later cell.
bertcrf_tagger = BertCrfModel(bertcrf_config)

In [9]:
# The AllenNLP model nests the BERT weights under this prefix. Stripping it
# (trailing dot included) leaves keys that are expected to line up with
# BertCrfModel's own parameter names; all other keys are copied unchanged.
ALLENNLP_BERT_PREFIX = "text_field_embedder.token_embedder_bert."

state_dict = {}
for k, v in cde_bertcrf_model_state_dict.items():
    if k.startswith(ALLENNLP_BERT_PREFIX):
        k = k[len(ALLENNLP_BERT_PREFIX):]
    state_dict[k] = v

In [10]:
# Copy the remapped AllenNLP weights into the Hugging Face model.
bertcrf_tagger.load_state_dict(state_dict=state_dict)

<All keys matched successfully>

In [11]:
# Persist the populated model to save_dir.
bertcrf_tagger.save_pretrained(save_dir)

In [3]:
from transformers import BertTokenizer

# Build the tokenizer from the cased SciBERT vocabulary (lower-casing
# disabled) and store it next to the model in save_dir.
scibert_vocab_path = find_data('models/scibert_cased_vocab-1.0.txt')
hf_tokenizer = BertTokenizer(
    vocab_file=scibert_vocab_path,
    do_lower_case=False,
)
hf_tokenizer.save_pretrained(save_dir)

('/home/dh582/.local/share/ChemDataExtractor/models/hf_bert_crf_tagger/tokenizer_config.json',
 '/home/dh582/.local/share/ChemDataExtractor/models/hf_bert_crf_tagger/special_tokens_map.json',
 '/home/dh582/.local/share/ChemDataExtractor/models/hf_bert_crf_tagger/vocab.txt',
 '/home/dh582/.local/share/ChemDataExtractor/models/hf_bert_crf_tagger/added_tokens.json')

In [14]:
from allennlp.data.token_indexers import PretrainedBertIndexer

# AllenNLP-side indexer built on the same cased SciBERT vocabulary file.
allen_indexer = PretrainedBertIndexer(
    pretrained_model=find_data("models/scibert_cased_vocab-1.0.txt"),
    do_lowercase=False,
    use_starting_offsets=True,
    truncate_long_sequences=False,
)

In [20]:
from chemdataextractor.doc import Sentence

# Sample sentence: a long systematic chemical name followed by NMR text.
test_text = '2-(4-Chloro-2-fluoro-3-difluoromethylphenyl)-[1,3,2]-dioxaborinane 1H NMR (CDCl3):'
test_s = Sentence(test_text)

In [21]:
# Display the (token, tag) pairs ChemDataExtractor assigns to the test sentence.
test_s.ner_tagged_tokens

input_ids:
 tensor([[  101,   957,   152, 30171, 30118,  1359,   578,   143,   957,   152,
         30171, 30118,  1359,   578, 19732, 30110,   578,   957,   152, 30171,
         30118,  1359,   578,  8225, 30110,   578,   957,   152, 30171, 30118,
          1359,   578, 14972,  8086, 21532, 13981,   551,   578,   268,   957,
           152, 30171, 30118,  1359,  1914,   578,   432,   783,  2321,  1923,
           647,   155, 30155,  6052,   143, 15918, 30141,   551,   864,   102]])


[('2', 'B-CM'),
 ('-', 'I-CM'),
 ('(', 'I-CM'),
 ('4', 'I-CM'),
 ('-', 'I-CM'),
 ('Chloro', 'I-CM'),
 ('-', 'I-CM'),
 ('2', 'I-CM'),
 ('-', 'I-CM'),
 ('fluoro', 'I-CM'),
 ('-', 'I-CM'),
 ('3', 'I-CM'),
 ('-', 'I-CM'),
 ('difluoromethylphenyl', 'I-CM'),
 (')', 'I-CM'),
 ('-', 'I-CM'),
 ('[', 'I-CM'),
 ('1,3,2', 'I-CM'),
 (']', 'I-CM'),
 ('-', 'I-CM'),
 ('dioxaborinane', 'I-CM'),
 ('1H', 'O'),
 ('NMR', 'O'),
 ('(', None),
 ('CDCl3', None),
 (')', None),
 (':', 'O')]