In [2]:
import torch
from src.KEMCE.knowledge_bert import DescTokenizer, EntityTokenizer, SeqsTokenizer, BertConfig
from pytorch_pretrained_bert import BertModel
import pickle
# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging
import numpy as np
logging.basicConfig(level=logging.INFO)

In [5]:
# Example entity-code sequence (not referenced by any other visible cell).
ent_seq = "[UNK] D_5641 D_5363 [UNK] D_56782 D_34690 D_53642"

# Resources under outputs/kemce/KG: entity vocabulary, entity embeddings, code descriptions.
ent_vocab_file = "outputs/kemce/KG/entity2id"
ent_embd_file = "outputs/kemce/KG/embeddings/CCS_TransR_entity.npy"
code2desc_file = "outputs/kemce/KG/code2desc.pickle"

# Resources under outputs/kemce/data/raw: sequence vocabulary plus the pickled
# sequence and entity files loaded in the next cell.
seqs_vocab_file = "outputs/kemce/data/raw/mimic_vocab.txt"
seqs_file = "outputs/kemce/data/raw/mimic.seqs"
ent_file = "outputs/kemce/data/raw/mimic.entity"

# Model configuration JSON.
config_json = "src/KEMCE/kemce_config.json"

In [6]:
# Load the pickled sequence and entity collections.
# The context managers close each file handle as soon as it has been read,
# instead of leaving it open until garbage collection.
# NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
with open(seqs_file, 'rb') as seqs_fh:
    seqs = pickle.load(seqs_fh)
with open(ent_file, 'rb') as ent_fh:
    ents = pickle.load(ent_fh)

In [7]:
# Take the first record of each loaded collection as the working example.
vist, ent = seqs[0], ents[0]

In [8]:
vist

'[CLS] D_60000 D_4241 D_3899 D_4111 D_2724 D_V4582 D_4019 D_41401 [SEP] D_78039 D_2720 D_V4581 D_V1582 D_4241 D_V4579 D_2724 D_2252 D_3485 D_4019 [SEP]'

In [9]:
# Build one tokenizer per input stream, each from its own resource file.
# NOTE(review): the INFO log emitted by this cell shows a bert-base-uncased
# vocabulary being loaded from cache — presumably by one of these constructors;
# confirm which one.
# Sequence tokenizer, built from the sequence vocabulary file.
seqs_tokenizer = SeqsTokenizer(seqs_vocab_file)
# Entity tokenizer, built from the entity vocabulary file.
ent_tokenize = EntityTokenizer(ent_vocab_file)
# Description tokenizer, built from the code-to-description pickle.
desc_tokenize = DescTokenizer(code2desc_file)

INFO:pytorch_pretrained_bert.tokenization:loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/xpeng/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


In [10]:
# Tokenize the first example with each tokenizer; every call is unpacked into
# a (tokens, input) pair.
seq_tokens, seq_input = seqs_tokenizer.tokenize(seqs[0])
ent_tokens, ent_input = ent_tokenize.tokenize(ents[0])
# The description tokenizer is fed the entity record as well (not the sequence).
desc_tokens, desc_input = desc_tokenize.tokenize(ents[0])

In [19]:
# Position to hide; the prediction is read back at this same index later on.
masked_index = 8
# Overwrite the token in place. The original token at this position is not saved.
seq_tokens[masked_index] = '[MASK]'

In [22]:
# Re-encode the (modified) token list so seq_input matches the current seq_tokens.
seq_input = seqs_tokenizer.convert_tokens_to_ids(seq_tokens)

In [23]:
seq_input 

[2,
 10,
 543,
 99,
 41,
 62,
 14,
 16,
 4,
 9,
 457,
 65,
 10,
 26,
 616,
 217,
 62,
 45,
 24,
 16,
 3]

In [8]:
# ent_mask holds 1 for every entity token that is not '[UNK]' and 0 for '[UNK]'.
# A comprehension keeps its loop variable local, so it does not overwrite the
# notebook-level `ent` that was assigned from `ents[0]` in an earlier cell.
ent_mask = [0 if ent_token == "[UNK]" else 1 for ent_token in ent_tokens]
# The first position is always switched on, whatever its entity token is.
ent_mask[0] = 1
# One entry per sequence token, all set to 1.
input_mask = [1] * len(seq_tokens)

In [9]:
# type_mask is 0 up to and including the first token starting with '[SEP',
# and 1 for every position after it.
# An integer dtype is used because the array is cast to long when it is turned
# into a tensor later.
type_mask = np.zeros(len(seq_tokens), dtype=np.int64)
# Position of the first separator, or None when the sequence contains none.
index = next(
    (pos for pos, tok in enumerate(seq_tokens) if tok.startswith('[SEP')),
    None,
)
# Without a separator the whole sequence stays 0. Previously the fallback
# index 0 marked every position after the first as 1.
if index is not None:
    type_mask[index+1:] = 1

In [33]:
# NOTE(review): BertForMaskedLM is imported here but not used in this cell.
from src.ERNIE.code.knowledge_bert import BertForMaskedLM

# config = BertConfig.from_json_file(config_json)
# Load the pretrained 'bert-base-uncased' model; its word-embedding weight is
# passed to KemceModel in a later cell.
model_bert = BertModel.from_pretrained('bert-base-uncased')

INFO:pytorch_pretrained_bert.modeling:loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at /home/xpeng/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
INFO:pytorch_pretrained_bert.modeling:extracting archive file /home/xpeng/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /home/xpeng/research/projects/medicalAI_torch/src/KG_medical/tmpi7eu088a
INFO:pytorch_pretrained_bert.modeling:Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}



In [34]:
model_bert

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): BertLayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Lin

In [24]:
# Wrap each per-example input in a list so every tensor gets a leading batch
# dimension of size 1.
seq_input_tensor = torch.tensor([seq_input])
ent_input_tensor = torch.tensor([ent_input])
desc_input_tensor = torch.tensor([desc_input])
ent_mask_tensor = torch.tensor([ent_mask])
input_mask_tensor = torch.tensor([input_mask])
# type_mask is a NumPy array: convert it directly and add the batch axis,
# instead of wrapping the array in a Python list (a slow path that newer
# PyTorch versions warn about). Result is the same (1, seq_len) long tensor.
type_mask_tensor = torch.as_tensor(type_mask, dtype=torch.long).unsqueeze(0)

In [15]:
# `config` is otherwise only defined in a commented-out line of an earlier
# cell, which makes this cell fail with NameError on a fresh kernel. Build it
# here from the JSON path set in the configuration cell.
config = BertConfig.from_json_file(config_json)
# NOTE(review): KemceModel is not imported in any visible cell — confirm its
# module and add the import to the first cell.
# The model receives the config, the pretrained BERT word-embedding weight and
# the path of the entity-embedding file.
kemce = KemceModel(config, model_bert.embeddings.word_embeddings.weight, ent_embd_file)

In [25]:
# Forward pass over the single-example batch; positional order is preserved.
prediction = kemce(
    seq_input_tensor,
    type_mask_tensor,
    ent_input_tensor,
    desc_input_tensor,
    ent_mask_tensor,
    input_mask_tensor,
)

torch.Size([1, 21, 100]) torch.Size([1, 21, 1])


In [26]:
prediction.shape

torch.Size([1, 21, 4890])

In [29]:
# Scores of the first batch item at the masked position.
masked_scores = prediction[0, masked_index]
# Index of the highest score, as a plain Python int.
predicted_index = masked_scores.argmax().item()
# Map that index back to its token through the sequence tokenizer.
predicted_tokens = seqs_tokenizer.convert_ids_to_tokens([predicted_index])
predicted_token = predicted_tokens[0]

In [42]:
prediction.shape

torch.Size([1, 21, 4890])

In [41]:
prediction.view(-1).shape

torch.Size([102690])

In [None]:
# Load the entity embeddings from the .npy file as a NumPy array.
ent_embd = np.load(ent_embd_file)

In [None]:
# Prepend one all-zero row to the entity embeddings.
# The embeddings are re-read from disk under a separate name so this cell is
# idempotent. Rebinding `ent_embd` from its own previous value would prepend
# another zero row on every re-run.
ent_embd_raw = torch.as_tensor(np.load(ent_embd_file))
# Match the embeddings' dtype so the concatenation never mixes dtypes.
pad_embed = torch.zeros(1, ent_embd_raw.shape[1], dtype=ent_embd_raw.dtype)
ent_embd = torch.cat([pad_embed, ent_embd_raw])

In [None]:
ent_embd[0]

In [None]:
# Load the pre-trained model tokenizer (vocabulary) registered as 'ernie_base'.
# NOTE(review): BertTokenizer is not imported in any visible cell — confirm
# which package provides it and add the import to the first cell.
tokenizer = BertTokenizer.from_pretrained('ernie_base')

In [None]:
# Free-text description used as a tokenizer example (presumably the text for
# code 01195, going by the variable name — TODO confirm).
code_01195 = 'Pulmonary tuberculosis, unspecified, tubercle bacilli not found by bacteriological examination, but tuberculosis confirmed histologically'
# Split the description into tokens.
code_tokens = tokenizer.tokenize(code_01195)

In [None]:
'Surgical operation with anastomosis, bypass, or graft, with natural or artificial tissues used as implant causing abnormal patient reaction, or later complication, without mention of misadventure at time of operation'

In [None]:
# Convert the description tokens to vocabulary indices.
indexed_tokens = tokenizer.convert_tokens_to_ids(code_tokens)

In [38]:
# Load pre-trained model (weights)
# The ERNIE BertModel is imported under an alias. Importing it as plain
# `BertModel` would rebind the name used for the pytorch_pretrained_bert class,
# so re-running the earlier `BertModel.from_pretrained('bert-base-uncased')`
# cell would silently use this class instead.
from src.ERNIE.code.knowledge_bert import BertForMaskedLM, BertModel as ErnieBertModel

# This from_pretrained returns a pair; only the model is kept.
model, _ = ErnieBertModel.from_pretrained('ernie_base')
# Switch to evaluation mode; as the last expression it also displays the model.
model.eval()

INFO:src.ERNIE.code.knowledge_bert.modeling:loading archive file src/ERNIE/data/ernie_base.tar.gz
INFO:src.ERNIE.code.knowledge_bert.modeling:extracting archive file src/ERNIE/data/ernie_base.tar.gz to temp dir /home/xpeng/research/projects/medicalAI_torch/src/KG_medical/tmpajl_1idh
INFO:src.ERNIE.code.knowledge_bert.modeling:Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_types": [
    "sim",
    "sim",
    "sim",
    "sim",
    "sim",
    "mix",
    "norm",
    "norm",
    "norm",
    "norm",
    "norm",
    "norm"
  ],
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}



BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): BertLayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer_simple(
        (attention): BertAttention_simple(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate_simple(
          (dense

In [None]:
# `model` is a bare BertModel whose `embeddings` module sits at the top level
# (see its printed structure), so there is no `.bert` wrapper attribute to go
# through. This matches how the word-embedding weight is read from `model_bert`.
pre_trained_embed = model.embeddings.word_embeddings.weight

In [None]:
# Hyperparameters handed to CNN_Text in the next cell.
kernel_num, size_output = 128, 100
kernel_sizes = [3, 4, 5]
dropout = 0.5

In [None]:
# NOTE(review): CNN_Text is not imported or defined in any visible cell —
# confirm its module and add the import to the first cell.
# Build the CNN from the pretrained embedding weight and the hyperparameters
# set in the previous cell.
cnn = CNN_Text(pre_trained_embed, kernel_num, kernel_sizes, size_output, dropout)

In [None]:
cnn

In [None]:
# Wrap the token indices in a list to add a leading batch dimension of size 1.
code_tokens_tensor = torch.tensor([indexed_tokens])

In [None]:
# Run the CNN on the single-example batch of token indices.
outs = cnn(code_tokens_tensor)

In [48]:
import torch.nn as nn

# Minimal CrossEntropyLoss demo: 3 samples, 5 classes, random logits and targets.
# Seed the global RNG so the random tensors (and the loss displayed in the next
# cell) are reproducible across kernel restarts.
torch.manual_seed(0)
loss = nn.CrossEntropyLoss()
# NOTE: `input` shadows the Python builtin. The name is kept because the next
# cell displays it.
input = torch.randn(3, 5, requires_grad=True)
# Integer class targets drawn from [0, 5).
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
# Populate input.grad with the gradient of the loss with respect to the logits.
output.backward()

In [49]:
output,target, input

(tensor(2.4676, grad_fn=<NllLossBackward>),
 tensor([4, 1, 4]),
 tensor([[ 1.0092,  0.4476,  0.0469,  1.5486, -0.4610],
         [-1.3207, -1.1270, -0.6534, -0.0093,  0.4398],
         [-0.0490,  1.6708,  0.4965, -0.4875,  0.1177]], requires_grad=True))