In [1]:
from allennlp.models.archival import load_archive
from chemdataextractor.data import find_data
import json
import pprint

In [65]:
import torch
import torch.nn as nn
from transformers import AutoModel
from torch.utils.data import Dataset, DataLoader
from chemdataextractor.nlp.crf import ConditionalRandomField, allowed_transitions
from chemdataextractor.nlp.allennlp_modules import TimeDistributed
from typing import Dict, Optional
from overrides import overrides

class BertCrfTagger(nn.Module):
    """Sequence tagger: transformer encoder -> linear tag projection -> CRF.

    The sub-module attribute names (``bert_model``, ``tag_projection_layer``,
    ``crf``) determine the keys of this module's state dict, so they must stay
    unchanged for weights saved under those keys to load.
    """

    def __init__(self, vocab,
                 model_name="allenai/scibert_scivocab_cased",
                 dropout=0.1,
                 label_namespace: str = "labels",
                 label_encoding: Optional[str] = None,
                 constrain_crf_decoding: Optional[bool] = None,
                 include_start_end_transitions: bool = True):
        """
        Parameters
        ----------
        vocab :
            Vocabulary object; must provide ``get_vocab_size``,
            ``get_index_to_token_vocabulary`` and ``get_token_from_index``.
        model_name :
            Checkpoint name or path handed to ``AutoModel.from_pretrained``.
        dropout :
            Dropout probability applied to the encoder output.
        label_namespace :
            Vocabulary namespace that holds the tag labels.
        label_encoding :
            Tag scheme name (e.g. ``"BIO"``); required when
            ``constrain_crf_decoding`` is truthy.
        constrain_crf_decoding :
            If truthy, restrict CRF transitions to those that
            ``allowed_transitions`` returns for ``label_encoding``.
        include_start_end_transitions :
            Forwarded unchanged to ``ConditionalRandomField``.

        Raises
        ------
        ConfigurationError
            If ``constrain_crf_decoding`` is truthy but ``label_encoding`` is
            not given.
        """
        super().__init__()
        self.vocab = vocab
        # Stored because ``decode`` needs the namespace to map ids back to labels.
        self.label_namespace = label_namespace
        self.bert_model = AutoModel.from_pretrained(model_name)
        self.num_tags = vocab.get_vocab_size(label_namespace)
        self.tag_projection_layer = TimeDistributed(
            nn.Linear(self.bert_model.config.hidden_size, self.num_tags)
        )

        self.label_encoding = label_encoding
        if constrain_crf_decoding:
            if not label_encoding:
                # Imported lazily: only this error path needs it, and the
                # notebook's import cells do not bring the name into scope.
                from allennlp.common.checks import ConfigurationError
                raise ConfigurationError("constrain_crf_decoding is True, but "
                                         "no label_encoding was specified.")
            labels = self.vocab.get_index_to_token_vocabulary(label_namespace)
            constraints = allowed_transitions(label_encoding, labels)
        else:
            constraints = None

        self.include_start_end_transitions = include_start_end_transitions
        self.crf = ConditionalRandomField(
            self.num_tags, constraints,
            include_start_end_transitions=include_start_end_transitions
        )

        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask, labels=None):
        """Encode the input, project to tag space and run Viterbi decoding.

        Parameters
        ----------
        input_ids :
            Token ids for the encoder (assumed ``(batch, seq_len)`` -- TODO confirm).
        attention_mask :
            Mask with the same leading shape as ``input_ids``; passed to both
            the encoder and the CRF.
        labels :
            Optional gold tag ids. When given, the negated CRF log-likelihood
            is returned under ``"loss"``.
            NOTE(review): assumes ``labels`` is aligned with ``input_ids``
            position by position -- confirm against the data pipeline.

        Returns
        -------
        dict
            ``"logits"``, ``"mask"`` and ``"tags"`` (one list of tag ids per
            sequence), plus ``"loss"`` when ``labels`` is supplied.
        """
        # BERT embeddings
        encoder_outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(encoder_outputs.last_hidden_state)

        # Project onto tag space
        logits = self.tag_projection_layer(sequence_output)

        # viterbi_tags yields (tag_ids, score) pairs; only the tag ids are kept.
        best_paths = self.crf.viterbi_tags(logits, attention_mask)
        predicted_tags = [tag_ids for tag_ids, _score in best_paths]

        output = {"logits": logits, "mask": attention_mask, "tags": predicted_tags}

        if labels is not None:
            # The CRF module returns a log-likelihood; negate it to obtain a loss.
            log_likelihood = self.crf(logits, labels, attention_mask)
            output["loss"] = -log_likelihood

        return output

    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Converts the tag ids to the actual tags.
        ``output_dict["tags"]`` is a list of lists of tag_ids; each id is looked
        up in ``self.label_namespace`` of the vocabulary. The dict is modified
        in place and also returned.
        """
        output_dict["tags"] = [
            [self.vocab.get_token_from_index(tag, namespace=self.label_namespace)
             for tag in instance_tags]
            for instance_tags in output_dict["tags"]
        ]
        return output_dict


In [40]:
# Resolve both packaged resources first, then load the archive with the
# embedder's `pretrained_model` config entry overridden by the weights path.
# NOTE(review): this rebinds `overrides`, the same name the class-definition
# cell imports from the `overrides` package.
scibert_weights_path = find_data("models/scibert_cased_weights-1.0.tar.gz")
bert_crf_archive_path = find_data('models/bert_finetuned_crf_model-1.0a')
overrides = {
    "model.text_field_embedder.token_embedders.bert.pretrained_model": scibert_weights_path,
}
cde_bert_archive = load_archive(bert_crf_archive_path, overrides=json.dumps(overrides))

In [56]:
# Take the model object out of the loaded archive and print its repr
# (for this model, the nested module tree shown in the output below).
cde_bertcrf_model = cde_bert_archive.model
pprint.pprint(cde_bertcrf_model)

_BertCrfTagger(
  (text_field_embedder): BasicTextFieldEmbedder(
    (token_embedder_bert): PretrainedBertEmbedder(
      (bert_model): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(31116, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): BertLayerNorm()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
          

In [50]:
# Write the archived model's vocabulary files into the current directory.
# Uses `cde_bertcrf_model`, the name the model is actually bound to in this
# notebook; `cde_bert_model` is never assigned, so the old line failed on a
# fresh kernel. The call is no longer wrapped in pprint because it returned
# None (see the recorded output), so there was nothing to print.
cde_bertcrf_model.vocab.save_to_files('.')

None


In [57]:
# Snapshot the archived model's state dict (name -> tensor); a later cell
# re-keys it so it can be loaded into the Hugging Face based tagger.
cde_bertcrf_model_state_dict = cde_bertcrf_model.state_dict()

In [58]:
from transformers import AutoModel, AutoTokenizer
# Load the stock SciBERT (cased) tokenizer and encoder by checkpoint name.
SCIBERT_CHECKPOINT = "allenai/scibert_scivocab_cased"
scibert_tokenizer = AutoTokenizer.from_pretrained(SCIBERT_CHECKPOINT)
scibert_model = AutoModel.from_pretrained(SCIBERT_CHECKPOINT)
# Disabled by the author. NOTE(review): `cde_bert_model_state_dict` is not
# defined in any cell shown in this notebook.
# scibert_model.load_state_dict(cde_bert_model_state_dict, strict=False)

Some weights of the model checkpoint at allenai/scibert_scivocab_cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [71]:
# Build the Hugging Face based tagger on the archived model's vocabulary.
# Uses `cde_bertcrf_model` (the name assigned from the archive); the previous
# `cde_bert_model` is never defined in this notebook and only worked through
# leftover kernel state.
hf_tagger = BertCrfTagger(
    vocab=cde_bertcrf_model.vocab,
    model_name="allenai/scibert_scivocab_cased",
    label_encoding="BIO",
    constrain_crf_decoding=True,
    include_start_end_transitions=False,
)

Some weights of the model checkpoint at allenai/scibert_scivocab_cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [72]:
# Re-key the archived weights for BertCrfTagger: drop the AllenNLP embedder
# wrapper prefix so encoder keys start at "bert_model."; all other keys are
# kept as they are.
# The prefix includes the trailing dot and the slice length is derived from
# it, replacing the hard-coded offset 40 that was paired with a 39-character
# startswith() test and would mangle any sibling key sharing that stem.
EMBEDDER_PREFIX = "text_field_embedder.token_embedder_bert."
state_dict = {
    (key[len(EMBEDDER_PREFIX):] if key.startswith(EMBEDDER_PREFIX) else key): tensor
    for key, tensor in cde_bertcrf_model_state_dict.items()
}
print(state_dict.keys())
        

dict_keys(['bert_model.embeddings.word_embeddings.weight', 'bert_model.embeddings.position_embeddings.weight', 'bert_model.embeddings.token_type_embeddings.weight', 'bert_model.embeddings.LayerNorm.weight', 'bert_model.embeddings.LayerNorm.bias', 'bert_model.encoder.layer.0.attention.self.query.weight', 'bert_model.encoder.layer.0.attention.self.query.bias', 'bert_model.encoder.layer.0.attention.self.key.weight', 'bert_model.encoder.layer.0.attention.self.key.bias', 'bert_model.encoder.layer.0.attention.self.value.weight', 'bert_model.encoder.layer.0.attention.self.value.bias', 'bert_model.encoder.layer.0.attention.output.dense.weight', 'bert_model.encoder.layer.0.attention.output.dense.bias', 'bert_model.encoder.layer.0.attention.output.LayerNorm.weight', 'bert_model.encoder.layer.0.attention.output.LayerNorm.bias', 'bert_model.encoder.layer.0.intermediate.dense.weight', 'bert_model.encoder.layer.0.intermediate.dense.bias', 'bert_model.encoder.layer.0.output.dense.weight', 'bert_model

In [73]:
# strict=False tolerates key mismatches, so check the returned report
# explicitly instead of relying on someone reading the cell output.
load_report = hf_tagger.load_state_dict(state_dict=state_dict, strict=False)

# The only key the recorded run reported as missing; anything else means the
# re-keying above no longer lines up with the tagger's parameter names.
TOLERATED_MISSING_KEYS = {"bert_model.embeddings.position_ids"}
assert not load_report.unexpected_keys, (
    f"unexpected keys in converted state dict: {load_report.unexpected_keys}"
)
assert set(load_report.missing_keys) <= TOLERATED_MISSING_KEYS, (
    f"keys missing from converted state dict: {load_report.missing_keys}"
)
load_report

_IncompatibleKeys(missing_keys=['bert_model.embeddings.position_ids'], unexpected_keys=[])