In [33]:
import torch
from src.ERNIE.code.knowledge_bert import BertTokenizer, BertModel, BertForMaskedLM
# Use TAGME
import tagme
# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging

# Root logger at INFO so knowledge_bert's loading messages are shown.
logging.basicConfig(level="INFO")

In [2]:
# Word-piece tokenizer whose vocabulary matches the ERNIE base checkpoint.
# (For a plain-BERT baseline, load 'bert-base-uncased' instead.)
tokenizer = BertTokenizer.from_pretrained('ernie_base')

# Sentence pair used for the masked-LM demo: a question and its answer.
text_a, text_b = "Who was Jim Henson ? ", "Jim Henson was a puppeteer ."

INFO:src.ERNIE.code.knowledge_bert.tokenization:loading vocabulary file src/ERNIE/data/ernie_base/vocab.txt


In [3]:
import os

# Set the authorization token for subsequent TagMe calls.
# The token is a personal credential: read it from the environment instead of
# committing it with the notebook (e.g. `export GCUBE_TOKEN=...` before starting Jupyter).
try:
    tagme.GCUBE_TOKEN = os.environ["GCUBE_TOKEN"]
except KeyError:
    raise RuntimeError(
        "Set the GCUBE_TOKEN environment variable to your TagMe (D4Science) token"
    ) from None

text_a_ann = tagme.annotate(text_a)
text_b_ann = tagme.annotate(text_b)
# NOTE(review): the tagme client appears to return None when a request fails
# (e.g. bad token); fail here instead of with an AttributeError inside get_ents.
if text_a_ann is None or text_b_ann is None:
    raise RuntimeError("TagMe annotation failed; check GCUBE_TOKEN and network access")

In [4]:
# An ICD-style diagnosis description to run through the same entity-linking step.
code = (
    'Tuberculous pneumonia [any form], '
    'tubercle bacilli not found by bacteriological or histological examination, '
    'but tuberculosis confirmed by other methods [inoculation of animals]'
)
code_ann = tagme.annotate(code)

In [5]:
# List every TagMe annotation of the ICD text that passes the 0.1 score threshold.
for annotation in code_ann.get_annotations(0.1):
    print(annotation)

Tuberculous pneumonia -> Tuberculosis (score: 0.6046293377876282)
form -> Morphology (biology) (score: 0.21826446056365967)
tubercle -> Tubercle (score: 0.4407349228858948)
bacilli -> Bacillus (shape) (score: 0.3832319974899292)
bacteriological -> Bacteriology (score: 0.3213954567909241)
histological -> Histology (score: 0.4029560685157776)
tuberculosis -> Tuberculosis (score: 0.5171748399734497)
confirmed -> Confirmation (score: 0.16835619509220123)
methods -> Methodology (score: 0.1495819091796875)
inoculation -> Inoculation (score: 0.17543242871761322)


In [6]:
# Read entity map: one "<entity title>\t<QID>" pair per line, keyed by title.
# Titles may contain non-ASCII characters, so decode explicitly instead of relying
# on the platform default encoding (assumes the file is UTF-8).
ent_map = {}
with open("src/ERNIE/data/kg_embed/entity_map.txt", encoding="utf-8") as fin:
    for line in fin:
        line = line.strip()
        if not line:
            continue  # tolerate blank / trailing lines instead of failing the unpack
        name, qid = line.split("\t")
        ent_map[name] = qid

In [7]:
def get_ents(ann, threshold=0.1, entity_map=None):
    """Convert TagMe annotations into ERNIE entity spans.

    Args:
        ann: a TagMe annotation response. Only ``ann.get_annotations(threshold)``
            is called, and each returned annotation's ``entity_title``, ``begin``,
            ``end`` and ``score`` attributes are read.
        threshold: score threshold passed to ``get_annotations``. Defaults to 0.1,
            the value this notebook has always used.
        entity_map: mapping from entity title to QID. Defaults to the module-level
            ``ent_map`` loaded from ``entity_map.txt``.

    Returns:
        A list of ``[qid, begin, end, score]`` lists, one per annotation whose title
        is present in ``entity_map``, in the order TagMe returned them. Annotations
        whose title has no QID are dropped.
    """
    if entity_map is None:
        entity_map = ent_map
    return [
        [entity_map[a.entity_title], a.begin, a.end, a.score]
        for a in ann.get_annotations(threshold)
        if a.entity_title in entity_map
    ]

In [8]:
# Entity spans for the question and the answer sentence, in that order.
ents_a, ents_b = map(get_ents, (text_a_ann, text_b_ann))

In [9]:
# Entity spans ([qid, begin, end, score]) for the ICD description.
ents_code = get_ents(code_ann)

In [10]:
# Display the [qid, begin, end, score] spans kept for the ICD description
# (annotations whose title is missing from ent_map are already dropped).
ents_code

[['Q12204', 0, 21, 0.6046293377876282],
 ['Q183252', 27, 31, 0.21826446056365967],
 ['Q243748', 64, 79, 0.3213954567909241],
 ['Q7168', 83, 95, 0.4029560685157776],
 ['Q12204', 113, 125, 0.5171748399734497],
 ['Q188613', 126, 135, 0.16835619509220123],
 ['Q185698', 145, 152, 0.1495819091796875]]

In [11]:
# Tokenize each sentence into word pieces together with its entity spans.
# NOTE(review): the cells below concatenate entities_* position-by-position with
# tokens_*, i.e. they assume `tokenize` yields one entity entry per token — confirm
# in knowledge_bert.tokenization.
tokens_a, entities_a = tokenizer.tokenize(text_a, ents_a)
tokens_b, entities_b = tokenizer.tokenize(text_b, ents_b)

In [12]:
# Tokenize the ICD description with its entity spans (not used further below).
tokens_code, entities_code = tokenizer.tokenize(code, ents_code)

In [13]:
# tokens_code, entities_code

In [14]:
# Assemble the BERT sentence-pair input: [CLS] A [SEP] B [SEP].
tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
# Entity slots run parallel to the tokens; the special tokens carry no entity.
ents = ["UNK"] + entities_a + ["UNK"] + entities_b + ["UNK"]
# Segment ids: 0 for [CLS] + sentence A + first [SEP], 1 for sentence B + final [SEP].
# Derived from the tokenized lengths rather than a hard-coded list that only fit
# this particular sentence pair.
segments_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
input_mask = [1] * len(tokens)
assert len(ents) == len(tokens) == len(segments_ids), (len(ents), len(tokens), len(segments_ids))

In [40]:
# Inspect the assembled token sequence.
# NOTE(review): the stored output already shows '[MASK]' at index 8 — it was produced
# (In [40]) after the masking cell below ran, not what a fresh kernel shows here.
tokens

['[CLS]',
 'who',
 'was',
 'jim',
 'henson',
 '?',
 '[SEP]',
 'jim',
 '[MASK]',
 'was',
 'a',
 'puppet',
 '##eer',
 '.',
 '[SEP]']

In [15]:
# Mask a token that we will try to predict back with `BertForMaskedLM`.
# Index 8 is 'henson' in the second segment ("[SEP] jim henson was ...").
masked_index = 8
# The index is hard-coded for this sentence pair; fail fast if tokenization shifts it.
# '[MASK]' is accepted too so the cell stays safe to re-run.
assert tokens[masked_index] in ('henson', '[MASK]'), tokens[masked_index]
tokens[masked_index] = '[MASK]'

In [16]:
# Convert word-piece tokens (including [CLS]/[SEP]/[MASK]) to vocabulary ids.
indexed_tokens = tokenizer.convert_tokens_to_ids(tokens)

In [17]:
# Convert ents: map each entity QID to its integer id (used below, shifted by +1,
# as the row index into the entity embedding table).
entity2id = {}
with open("src/ERNIE/data/kg_embed/entity2id.txt", encoding="utf-8") as fin:
    fin.readline()  # skip the first line (presumably a header / entity count — confirm)
    for line in fin:
        line = line.strip()
        if not line:
            continue  # tolerate blank / trailing lines instead of failing the unpack
        qid, eid = line.split('\t')
        entity2id[qid] = int(eid)

In [18]:
# Translate per-token entity QIDs into entity-table ids. Slots with no entity
# ("UNK") or with a QID missing from entity2id get id -1 and mask 0.
indexed_ents = [
    entity2id[ent] if ent != "UNK" and ent in entity2id else -1
    for ent in ents
]
ent_mask = [int(ent != "UNK" and ent in entity2id) for ent in ents]
# Position 0 ([CLS]) is always flagged in the mask.
# NOTE(review): mirrors the reference ERNIE example; rationale not visible here.
ent_mask[0] = 1

In [38]:
# Inspect vocabulary ids and entity ids side by side (-1 = no linked entity).
indexed_tokens, indexed_ents

([101,
  2040,
  2001,
  3958,
  27227,
  1029,
  102,
  3958,
  103,
  2001,
  1037,
  13997,
  11510,
  1012,
  102],
 [-1, 125869, -1, 145996, -1, -1, -1, 145996, -1, -1, -1, 77021, -1, -1, -1])

In [20]:
# Convert inputs to PyTorch tensors, each a batch of one sequence.
tokens_tensor = torch.tensor([indexed_tokens])
ents_tensor = torch.tensor([indexed_ents])
segments_tensors = torch.tensor([segments_ids])
# `ent_mask` is rebound from a list to a tensor here. `torch.tensor([ent_mask])` errors
# out if the cell is re-run on the already-converted tensor; as_tensor + view(1, -1)
# gives the same (1, seq_len) int64 tensor from the list and is a no-op on a re-run.
ent_mask = torch.as_tensor(ent_mask).view(1, -1)

In [37]:
# Inspect tensor shapes.
# NOTE(review): the stored output (ents_tensor of shape [1, 15, 100]) comes from a run
# after the device cell further down replaced ents_tensor with entity embeddings; on a
# fresh kernel at this position ents_tensor is still the (1, seq_len) id tensor.
ents_tensor.shape, tokens_tensor.shape

(torch.Size([1, 15, 100]), torch.Size([1, 15]))

In [39]:
# Inspect ents_tensor.
# NOTE(review): the stored output shows embedded vectors on cuda, i.e. it was executed
# (In [39]) after the device cell below; on a fresh run here it holds entity ids.
ents_tensor

tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0086,  0.0453, -0.0619,  ..., -0.0537,  0.0081, -0.0052],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],
       device='cuda:0')

In [22]:
# Load pre-trained ERNIE weights. from_pretrained returns a pair here; only the model
# is kept (NOTE(review): second element's meaning not visible in this notebook).
model, _ = BertForMaskedLM.from_pretrained('ernie_base')
# Plain-BERT alternative: BertForMaskedLM.from_pretrained('bert-large-uncased')
model.eval()  # inference mode: disables the Dropout layers

INFO:src.ERNIE.code.knowledge_bert.modeling:loading archive file src/ERNIE/data/ernie_base.tar.gz
INFO:src.ERNIE.code.knowledge_bert.modeling:extracting archive file src/ERNIE/data/ernie_base.tar.gz to temp dir /home/xpeng/research/projects/medicalAI_torch/src/KG_medical/tmpma7gxgfs
INFO:src.ERNIE.code.knowledge_bert.modeling:Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_types": [
    "sim",
    "sim",
    "sim",
    "sim",
    "sim",
    "mix",
    "norm",
    "norm",
    "norm",
    "norm",
    "norm",
    "norm"
  ],
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

INFO:src.ERNIE.code.knowledge_bert.modeling:Weights from pretrained model not used in BertForMaskedLM: ['cls.predictions_ent.transform.dense.weight', 'cls.predictions_ent.transform.den

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer_simple(
          (attention): BertAttention_simple(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
    

In [32]:
# Handle to the (vocab_size, hidden_size) word-embedding matrix.
# NOTE(review): not used anywhere below, and the stored output under this cell came from
# an earlier version that displayed the value — candidate for deletion.
word_embedding = model.bert.embeddings.word_embeddings.weight

Parameter containing:
tensor([[-0.0201, -0.0753, -0.0225,  ..., -0.0219, -0.0424, -0.0157],
        [-0.0083, -0.0599, -0.0288,  ..., -0.0109, -0.0359, -0.0184],
        [-0.0145, -0.0638, -0.0294,  ..., -0.0111, -0.0447, -0.0137],
        ...,
        [-0.0187, -0.0525, -0.0118,  ...,  0.0039, -0.0224, -0.0287],
        [-0.0478, -0.0471,  0.0090,  ...,  0.0142, -0.0075, -0.0089],
        [-0.0020, -0.0877, -0.0065,  ...,  0.0041, -0.0331,  0.0725]],
       device='cuda:0', requires_grad=True)

In [23]:
# Load the pre-trained entity embeddings (one tab-separated vector per line) and
# prepend an all-zero row 0, so that entity id -1 ("no entity") shifted by +1 maps
# to a zero vector.
vecs = []
with open("src/ERNIE/data/kg_embed/entity2vec.vec", 'r') as fin:
    for line in fin:
        line = line.strip()
        if not line:
            continue  # a trailing blank line would otherwise crash float('')
        vecs.append([float(x) for x in line.split('\t')])
# Size the padding row from the data instead of hard-coding 100, so it always matches
# the vector width (100 remains only as the fallback for an empty file).
vecs.insert(0, [0.0] * (len(vecs[0]) if vecs else 100))
embed = torch.nn.Embedding.from_pretrained(torch.tensor(vecs, dtype=torch.float32))

In [24]:
# Use the GPU when one is available, otherwise fall back to CPU so the notebook
# still runs end to end.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokens_tensor = tokens_tensor.to(device)
# Look up entity embeddings from the id list instead of overwriting the id tensor in
# place (re-running `embed(ents_tensor + 1)` would try to embed embeddings).
# Shift by +1 so id -1 ("no entity") hits row 0, the zero padding vector.
ents_tensor = embed(torch.tensor([indexed_ents]) + 1).to(device)
ent_mask = ent_mask.to(device)
segments_tensors = segments_tensors.to(device)
# nn.Module.to moves parameters in place and returns the module itself.
model = model.to(device)

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer_simple(
          (attention): BertAttention_simple(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
    

In [25]:
# Inspect the token ids, entity mask and segment ids that are fed to the model.
tokens_tensor,ent_mask,segments_tensors

(tensor([[  101,  2040,  2001,  3958, 27227,  1029,   102,  3958,   103,  2001,
           1037, 13997, 11510,  1012,   102]], device='cuda:0'),
 tensor([[1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]], device='cuda:0'),
 tensor([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0'))

In [26]:
# Forward pass without gradient tracking, then check that the vocabulary entry with
# the highest score at the masked position is the original word 'henson'.
with torch.no_grad():
    predictions = model(tokens_tensor, ents_tensor, ent_mask, segments_tensors)
    predicted_index = predictions[0, masked_index].argmax().item()
    predicted_token, = tokenizer.convert_ids_to_tokens([predicted_index])
    assert predicted_token == 'henson'

In [27]:
# NOTE(review): scratch check of Python list ordering, unrelated to the pipeline —
# candidate for deletion.
sorted([[3,[12,12]], [2,[2,4]]])

[[2, [2, 4]], [3, [12, 12]]]

In [28]:
# Show the token predicted at the masked position.
predicted_token

'henson'