In [None]:
%%javascript
// Register CDN locations for d3 and jQuery with RequireJS so later cells that
// emit JavaScript can require() them by name.
// NOTE(review): presumably needed by bertviz/head_view.js (injected by show());
// the pinned versions (d3 3.4.8, jQuery 2.0.0) are assumed to be the ones that
// script was written against — confirm. Run this cell before show().
require.config({
  paths: {
      d3: 'https://cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: 'https://ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});

In [None]:
import json
import logging
from IPython.core.display import display, HTML, Javascript
import os
import numpy as np
from collections import OrderedDict
from collections import defaultdict

import torch
import tokenization
from modeling import BertConfig, BertForNER
from run_bioner import convert_examples_to_features, InputExample, AICupProcessor

# A single processor instance is enough; `procr` is kept as an alias so any
# cell that still refers to it keeps working.
processor = AICupProcessor()
procr = processor
label_list = processor.get_labels()

# Only WARNING and above from the 'run_bioner' logger.
logger = logging.getLogger('run_bioner')
logger.setLevel(logging.WARNING)

bert_config_file = "pretrained_model/bert_config_bioner.json"
bert_config = BertConfig.from_json_file(bert_config_file)
vocab_file = "pretrained_model/vocab.txt"
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=False)

# Not read by the prediction helpers: they size each example from its own
# sub-token count via a local `max_seq_length`.
max_seq_length = 400
device = 'cpu'  # 'cuda' or 'cpu'
init_checkpoint = "model_step_2564.pt"

model = BertForNER(bert_config, len(label_list))
model_params_dict = model.state_dict()
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
pretrained_dict = torch.load(init_checkpoint, map_location='cpu')
# Overlay checkpoint weights on the freshly initialised parameters, then load.
model_params_dict.update(pretrained_dict)
model.load_state_dict(model_params_dict)
model.eval()
model.to(device)
print("Loaded model")

In [None]:
def show(attention=None, js_path="bertviz/head_view.js"):
    """Render the bertviz head view for an attention payload in the notebook.

    Args:
      attention: dict in the shape built by ``get_attention`` (with the
        ``'left_text'`` / ``'right_text'`` token lists added by the caller).
        Defaults to the notebook-level ``attention_data`` so existing
        ``show()`` calls keep working.
      js_path: path of the visualization script to inject.
    """
    if attention is None:
        attention = attention_data  # global set by the driver cell
    vis_html = """
      <span style="user-select:none">
        Layer: <select id="layer"></select>
      </span>
      <div id='vis'></div> 
    """
    display(HTML(vis_html))
    # Context manager closes the handle; explicit encoding avoids locale surprises.
    with open(js_path, encoding="utf-8") as js_file:
        vis_js = js_file.read()
    params = {
        'attention': attention,
        'default_filter': "all"
    }
    # Publish params on window before injecting the script
    # (presumably head_view.js reads window.params — confirm).
    display(Javascript('window.params = %s' % json.dumps(params)))
    display(Javascript(vis_js))

In [None]:
def get_model_predictions_restored(model, sentence, label_id_to_name, device='cpu'):

    sentence = sentence.split()
    sent_toks, _ = tokenizer.tokenize_with_map(sentence)
    max_seq_length = len(sent_toks) + 5
    input_example = InputExample(1, sentence, label=['O'] * len(sentence))

    tmp_feats = convert_examples_to_features([input_example, ], label_list, max_seq_length, tokenizer)
    tmp_input_ids = torch.tensor([f.input_ids for f in tmp_feats], dtype=torch.long).to(device)
    tmp_input_mask = torch.tensor([f.input_mask for f in tmp_feats], dtype=torch.long).to(device)
    tmp_segment_ids = torch.tensor([f.segment_ids for f in tmp_feats], dtype=torch.long).to(device)
    
    with torch.no_grad():
        logits = model(tmp_input_ids, tmp_segment_ids, tmp_input_mask)
    logits = logits[0]
    pred_labels_id = logits.argmax(dim=-1)
    
    if pred_labels_id.device != 'cpu':
        pred_labels_id = pred_labels_id.cpu()
    pred_labels_id = pred_labels_id.numpy().tolist()
    pred_labels = [label_id_to_name[l] for l in pred_labels_id]
    pred_labels = pred_labels[1:]
    sent_toks = input_example.text_a
    sent_toks_map = input_example.text_a_map

    restored_tags = []
    for i_prd, pred_ne_tag in enumerate(pred_labels):
        if i_prd >= len(sent_toks_map):
            break
        if (i_prd > 0) and (sent_toks_map[i_prd - 1] == sent_toks_map[i_prd]):
            continue
        restored_tags.append(pred_ne_tag)
    return list(zip(sent_toks, restored_tags))

def get_model_predictions(model, sentence, label_id_to_name, device='cpu'):
    sentence = sentence.split()
    sent_toks, _ = tokenizer.tokenize_with_map(sentence)
    max_seq_length = len(sent_toks) + 5
    input_example = InputExample(1, sentence, label=['O'] * len(sentence))
    tmp_feats = convert_examples_to_features([input_example, ], label_list, max_seq_length, tokenizer)
    tmp_input_ids = torch.tensor([f.input_ids for f in tmp_feats], dtype=torch.long).to(device)
    tmp_input_mask = torch.tensor([f.input_mask for f in tmp_feats], dtype=torch.long).to(device)
    tmp_segment_ids = torch.tensor([f.segment_ids for f in tmp_feats], dtype=torch.long).to(device)
    
    with torch.no_grad():
        logits = model(tmp_input_ids, tmp_segment_ids, tmp_input_mask)
    logits = logits[0]
    pred_labels_id = logits.argmax(dim=-1)
    
    if pred_labels_id.device != 'cpu':
        pred_labels_id = pred_labels_id.cpu()
    pred_labels_id = pred_labels_id.numpy().tolist()
    pred_labels = [label_id_to_name[l] for l in pred_labels_id]
    pred_labels = pred_labels[1:]
    return list(zip(sent_toks, pred_labels))

def get_attention(attention_probs, max_len):
    """Package one layer's attention weights for the bertviz head-view script.

    Position 0 is dropped from both token axes. Only positions
    ``1 .. max_len - 1`` are kept, so the matrices line up with token lists
    that exclude the leading position.

    Args:
      attention_probs: 3-axis array/tensor indexed ``[head, source, target]``
        that supports slicing and ``.tolist()``.
      max_len: exclusive upper bound on the token positions to keep.

    Returns:
      ``{'all': {'attn': [matrices]}}``. ``matrices`` is a nested list of
      shape ``[num_heads, n, n]``, where ``n`` is at most ``max_len - 1``.
      The outer list has a single entry because only one layer is passed in.
      The caller must add the ``'left_text'`` / ``'right_text'`` token lists
      that the visualization also expects under ``'all'``.
    """
    trimmed = attention_probs[:, 1:max_len, 1:max_len]
    return {
        'all': {
            'attn': [trimmed.tolist()],
        }
    }

In [None]:
sentence = "FHL2 interacts with EGFR to promote glioblastoma growth."

with torch.no_grad():
    # (sub_token, label) pairs from a forward pass over the sentence.
    predictions = get_model_predictions(model, sentence, label_list, device)
    # NOTE(review): assumes the last layer's self-attention module caches
    # `attention_probs` from the most recent forward pass — confirm in modeling.py.
    # [0] selects the single batch item: 12 heads, max_len, max_len.
    attention_probs = model.bert.encoder.layer[-1].attention.self.attention_probs[0]
    # +1 because get_attention drops position 0 and keeps positions
    # 1..max_len-1, i.e. one row/column per predicted sub-token.
    max_len = len(predictions) + 1
    attention_data = get_attention(attention_probs, max_len)
    attention_data['all']['left_text'] = [token for token, _ in predictions]
    attention_data['all']['right_text'] = [label for _, label in predictions]

show()


In [None]:
# sentence = "FHL2 interacts with EGFR to promote glioblastoma growth."
# sentence = "Stress - induced activation of Mst1 in cardiomyocytes promoted accumulation of p62 and aggresome formation, accompanied by the disappearance of autophagosomes ."
# sentence = "Unexpectedly, mixed-lineage kinase domain-like (MLKL) is also required for the induction of aerobic respiration, and we further show that it is required for RIP3 translocation to meet mitochondria - localized PDC ."
# sentence = "IGF - I strongly induced migration of the four cell lines through IGF - IR."
