In [None]:
# Work from the repository root: the following cells import the project packages
# (utils, models, preprocessing) and read 'resources/...' via relative paths.
# The guard keeps the cell idempotent: re-running it does not climb further up
# (a bare `cd ..` would, and it also relies on IPython's automagic being enabled).
import os
if not os.path.isdir('resources'):
    os.chdir('..')

In [None]:
import torch
import tensorflow as tf
from utils.visualize_bert import keras2torch
from models.bert_utils import load_bert
import numpy as np
from tokenizers import BertWordPieceTokenizer
from models.bert_utils import get_token_dict
from preprocessing.generate_data import seq2tokens, ALPHABET, seq2kmers

# Preparation

In [None]:
# keras bert model
# Loaded from its saved .h5 file via the project helper `load_bert`. Below it is
# called on a pair of batched inputs built from `seq2tokens` output (ins[0] = token
# ids, ins[1] presumably segment ids — TODO confirm) and serves as the reference
# for the converted PyTorch model.
m = load_bert('resources/bert_nc_C2_final.h5')

In [4]:
# converted 🤗 transformers model
tm = keras2torch(m)
# Switch to evaluation mode; the trailing `;` suppresses the displayed module repr.
# (Assigning the result to `_` instead would shadow IPython's output-history
# variable and stop IPython from updating `_`, `__` and `___`.)
tm.eval();

In [5]:
# Two tokenizers that are cross-checked against each other further below:
#   t       -- 🤗 WordPiece tokenizer built from the vocabulary file
#   tokend  -- the project's own token dictionary (token -> id)
t = BertWordPieceTokenizer('resources/vocab.txt')
tokend = get_token_dict()

In [6]:
# tokenizer evaluation (Archaea sequence)
# DNA test sequence (labelled Archaea). It is used to check that the project's
# tokenization (seq2tokens) agrees with the 🤗 WordPiece tokenizer, and afterwards
# as the input for the classification and attention-visualization cells.
test_string = ("CTGACGCACCCGGGTGCCATTCTTGCAAAAAGCCTTACAATTCCGCTGATTGATGCTACGCATTATGCAACTGAACTTCC"
               "GGGACTGTATCGTTTGCGAGATTTAATCGCTTCCTTTGGAGTCGAGTCAGCGGTATTTGATACTTCTGTTCCATGGAGAA"
               "TGAAAACCTATTATGAAAATTACTGAGTTAGAACAAAAAAAAGTACCGCACGGTGAAGTTGTCCTCATTGGTCTTGGCCG"
               "TCTTGGTCTGAGAACAGCCCTAAATCTCATGCATGTCAATCGGGGCGGACCAGTTCGGATAACTGTGTATGACGGACAAA"
               "AAATATCTGCCGATGATCTGATATTCCGCATGTGGGGTGGAGAAATTGGCGAATATAAAACAGATTTCCTCAAACGGCTT"
               "GCAGGCCCCGGATACAGCAGGGAAATAATATCAGTTCCAGAGTATATTTCTGAGGAAAATCTGTCTCTGATTACCGGAGG"
               "GGATGTTGCGTGTGTCGAGATTGCAGGCGGTGATACATTGCCTACTACCGGGGCTATTATCCGGCATGCCCAGTCTTTGG"
               "GCATGAAGACTATCAGTACGATGGGTGTATTTGGTATTTCCGGCGATAATGTTTATGCCGTTCCTCTGGAAGAAGCAAAT"
               "ACAGATAATCCAATTGTTGCCGCAATGCTTGAATACGGGATTTCCCATCATATGCTTGTCGGGACTGGAAAACTGATTCG"
               "TGACTGGGAACCTGTTACTCCGTATATCATGGATAAAATTGCAGAAGTGATGTCGTCAGAAATACTGCGTCTGACCGAGG"
               "GGAAATAATGCCGACGATATCGACTGCCGAATGCTTTACCCACGGAAAAGTTGCAAATGAGCTCCATGCATTTGCCCGCG"
               "GGTATCCGCATGAATATCTCTTTTCTATAGATAGGAAAAAAGTTGATATTTCCGTTGTGGCCGGGATGTTTATTCCAACA"
               "CTTACAGGTGTCAGAACTCTTCTGCATTTTGAGCCGCTGGAACCGCGGTTGGTTATAGACACGGTGAAAGTTTATGAACA"
               "GGATCAGGATTGTATTATGGCATGCCGGATGGCGGAGGCCGTTATGCGGGTGACCGGGGCAGATATTGGTATAGGAACTA"
               "CTGCAGGCATCGGGAAAGGCGCAGTGGCAATAGCCTCTCAGGATAAAATCTATTCCAAAGTCACAAGAATTGATGCAGAT"
               "TTCAGGACTTCAGATGCAAAAAAACTGATGCAGCGTGAAAAGTCAGGTGTTTTTACTGCACTGCGTTTGTTTGAGGAATT"
               "TTTGTTGGAGGGGGAGTTCCCCGATAGTTATAATAAATACATATAATTAGTAACACAAATTGCTATTAATATTAATATTA"
               "TAACTACATTAATCATATTGATTTTAACATATTTAGAAAGATTTATTACGAATATTATTAAATACACTATTGTTGTCACA"
               "TATTGATGGCAGTACAAACTGGAGATTACATACATGAAAGTAGCAATTTTAGGAGCAGGA")
# Tokenize with the project helper. `ins` is used by the rest of the notebook, so it
# is assigned explicitly rather than inside the comparison.
# NOTE(review): 502 is the length argument of seq2tokens — presumably the number of
# k-mers plus two special tokens; confirm against seq2tokens.
ins = seq2tokens(test_string, tokend, 502)
own_ids = np.asarray(ins[0])
# reference ids: 🤗 tokenizer applied to the space-separated k-mers
hf_ids = np.array(t.encode(' '.join(seq2kmers(test_string))).ids)
# Ignore the last position. A length mismatch yields False here instead of a
# broadcasting error or element-wise-comparison warning.
own_ids.shape == hf_ids.shape and bool((own_ids == hf_ids)[:-1].all())

True

Check whether the model correctly classifies the sequence (expected class: Archaea), first on the full input and then on a truncated input of 150 tokens.

In [None]:
# Class scores for the full sequence. Both inputs get a leading batch axis of size 1;
# unpacking the model output over that batch axis pairs each class name with its score.
list(zip(['Viruses', 'Archaea', 'Bacteria', 'Eukaryota'],
         *m([np.expand_dims(ins[0], 0), np.expand_dims(ins[1], 0)])))

In [None]:
# Class scores for a truncated input: the first 149 tokens plus the final token
# (150 in total). The final token is kept so that the input still ends like the full
# sequence; this matches the 150-token input visualized further below.
list(zip(['Viruses', 'Archaea', 'Bacteria', 'Eukaryota'],
         *m([np.array([np.concatenate((ins[0][:149], ins[0][-1:]))]),
             np.array([np.concatenate((ins[1][:149], ins[1][-1:]))])])))

In [None]:
# output attention data for external JS visualization script (-> doesn't work yet)
from bertviz.bertviz.util import format_attention
import json


def js_data(tmodel, in_ids, token_dict=None):
    """Collect the attention weights of ``tmodel`` for a single input sequence.

    Args:
        tmodel: converted 🤗 transformers model. It is called with
            ``output_attentions=True``, and its last output is taken as the attentions.
        in_ids: token ids of one sequence (no batch dimension; one is added here).
        token_dict: token -> id vocabulary used to label the ids. Defaults to the
            notebook-level ``tokend``.

    Returns:
        dict with ``'attn'`` (attention weights from ``format_attention`` as nested
        lists) and the token strings for both sides of the view (``'left_text'``,
        ``'right_text'``).
    """
    if token_dict is None:
        token_dict = tokend
    # inference only: no autograd graph is needed for the attention weights
    with torch.no_grad():
        out = tmodel(input_ids=torch.tensor(np.array([in_ids]), dtype=torch.long),
                     output_attentions=True)
    attn = format_attention(out[-1]).tolist()
    id_to_token = {v: k for k, v in token_dict.items()}
    tokens = [id_to_token[i] for i in in_ids]
    return {'attn': attn, 'left_text': tokens, 'right_text': tokens}


# the context manager closes (and flushes) the file, even if serialization fails
with open('test_vis.json', 'w') as fh:
    json.dump(js_data(tm, ins[0]), fh, indent=2)

In [7]:
# Full forward pass of the converted model on the complete sequence, including the
# attention weights. Run under no_grad because this is inference only: `tm.eval()` does
# not disable autograd, so without it the graph would be kept alive with `tm_all`.
with torch.no_grad():
    tm_all = tm(input_ids=torch.tensor(np.array([ins[0]]), dtype=torch.long), output_attentions=True)

In [8]:
# Output of the Keras model's 'NSP-Dense' layer for the full sequence: a sub-model
# sharing m's inputs is built and immediately applied to the batched inputs.
m_nsp = tf.keras.Model(
    inputs=m.input,
    outputs=[m.get_layer(name='NSP-Dense').output],
)([np.expand_dims(ins[0], 0), np.expand_dims(ins[1], 0)])

In [None]:
# Do the Keras and the converted PyTorch model agree? Compare the Keras NSP-Dense
# output with output [1] of the converted model element-wise, at an absolute
# tolerance of 0.01. The result should be mostly True.
np.isclose(a=m_nsp.numpy(), b=tm_all[1].detach().numpy(), atol=0.01)

# Visualization
**Note:** only one view can be displayed at a time, so always clear the output of the other view cells before running a view cell.

In [12]:
from bertviz.bertviz import head_view, model_view

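For compatibility reasons, `head_view` and `model_view` require different versions of d3: `head_view` needs the version configured in the first cell below (3.4.8), `model_view` the one in the second cell (5.7.0).
Also, when switching between them, the browser tab has to be closed and reopened, because the loaded d3 version is cached.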

In [10]:
%%javascript
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});

<IPython.core.display.Javascript object>

In [None]:
%%javascript
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min',
    jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});



In [13]:
# First 9 tokens plus the final token, i.e. 10 tokens in total. The final token is
# kept so that the truncated input still ends like the full sequence.
ins_10 = np.concatenate((ins[0][:9], ins[0][-1:]))
tokens_10 = list(map({v: k for k, v in tokend.items()}.__getitem__, ins_10))
# inference only: no autograd graph is needed for the attention weights
with torch.no_grad():
    head_view(tm(input_ids=torch.tensor(np.array([ins_10]), dtype=torch.long), output_attentions=True)[-1],
              tokens_10, None)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# First 9 tokens plus the final token, i.e. 10 tokens in total. The final token is
# kept so that the truncated input still ends like the full sequence.
ins_10 = np.concatenate((ins[0][:9], ins[0][-1:]))
tokens_10 = list(map({v: k for k, v in tokend.items()}.__getitem__, ins_10))
# inference only: no autograd graph is needed for the attention weights
with torch.no_grad():
    model_view(tm(input_ids=torch.tensor(np.array([ins_10]), dtype=torch.long), output_attentions=True)[-1],
               tokens_10, None)

In [None]:
# NOTE: too big, browser is likely gonna crash!!
# First 149 tokens plus the final token, i.e. 150 tokens in total
# (the model already knows it's archaea at this length).
ins_150 = np.concatenate((ins[0][:149], ins[0][-1:]))
tokens_150 = list(map({v: k for k, v in tokend.items()}.__getitem__, ins_150))
# inference only: no autograd graph is needed for the attention weights
with torch.no_grad():
    head_view(tm(input_ids=torch.tensor(np.array([ins_150]), dtype=torch.long), output_attentions=True)[-1],
              tokens_150, None)

In [None]:
# token strings of the full input, looked up in the inverted vocabulary (id -> token)
tokens = [*map(dict(zip(tokend.values(), tokend.keys())).__getitem__, ins[0])]

In [None]:
# NOTE: too big, browser is likely gonna crash!!
# all tokens
tokens = list(map({v: k for k, v in tokend.items()}.__getitem__, ins[0]))
# inference only: no autograd graph is needed for the attention weights
with torch.no_grad():
    head_view(tm(input_ids=torch.tensor(np.array([ins[0]]), dtype=torch.long), output_attentions=True)[-1],
              tokens, None)

In [None]:
# NOTE: too big, browser is likely gonna crash!!
# model view
# Token strings are rebuilt here so that this cell does not depend on `tokens`
# having been set by an earlier cell.
tokens = list(map({v: k for k, v in tokend.items()}.__getitem__, ins[0]))
# inference only: no autograd graph is needed for the attention weights
with torch.no_grad():
    model_view(tm(input_ids=torch.tensor(np.array([ins[0]]), dtype=torch.long), output_attentions=True)[-1],
               tokens, None)