In [0]:
# @title Preparation
# Install keras-bert and fetch the pretrained uncased BERT checkpoint
# (L-12 / H-768 / A-12 per the archive name). The unzipped directory contains
# bert_config.json, bert_model.ckpt.* and vocab.txt, which the next cell points at.
# NOTE(review): keras-bert is installed unpinned; pin a known-good version for
# reproducibility (this notebook uses tf.contrib, i.e. a TensorFlow 1.x runtime).
!pip install -q keras-bert
# NOTE(review): without -nc, re-running this cell keeps the old archive and saves
# the new download as *.zip.1 — consider `wget -q -nc` if the cell is re-run often.
!wget -q https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
# -o overwrites already-extracted files without prompting, so the unzip step can be re-run.
!unzip -o uncased_L-12_H-768_A-12.zip

Archive:  uncased_L-12_H-768_A-12.zip
  inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.meta  
  inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001  
  inflating: uncased_L-12_H-768_A-12/vocab.txt  
  inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.index  
  inflating: uncased_L-12_H-768_A-12/bert_config.json  


In [0]:
# @title Environment
import os

# Directory created by unzipping the checkpoint archive in the Preparation cell.
pretrained_path = 'uncased_L-12_H-768_A-12'

# Relative paths to the model config, the checkpoint prefix and the vocabulary,
# all resolved inside the pretrained directory.
config_path, checkpoint_path, vocab_path = (
    os.path.join(pretrained_path, file_name)
    for file_name in ('bert_config.json', 'bert_model.ckpt', 'vocab.txt')
)

# keras-bert must run on tf.keras for TPU use; TF_KERAS selects that backend.
# NOTE(review): presumably read when keras_bert is imported, so keep this cell
# ahead of every cell that imports keras_bert.
os.environ.update(TF_KERAS='1')

In [0]:
# @title Initialize TPU Strategy
# Connect to the Colab TPU and build the distribution strategy that the model
# is created under in the next cell.
import os

import tensorflow as tf
from keras_bert import get_custom_objects

# The TPU's gRPC address is read from COLAB_TPU_ADDR; it is expected to exist only
# on a TPU runtime. Fail with an actionable message instead of a bare KeyError.
if 'COLAB_TPU_ADDR' not in os.environ:
    raise RuntimeError(
        'COLAB_TPU_ADDR is not set: switch the Colab runtime type to TPU '
        'and re-run the notebook from the top.')
TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']

# NOTE(review): tf.contrib was removed in TensorFlow 2.0, so this cell requires a
# TF 1.x runtime.
resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)
tf.contrib.distribute.initialize_tpu_system(resolver)
strategy = tf.contrib.distribute.TPUStrategy(resolver)

In [0]:
# @title Load Basic Model
import codecs
from keras_bert import load_trained_model_from_checkpoint

# Vocabulary lookup: one token per line of vocab.txt, mapped to an integer id.
# Each id is the count of distinct tokens stored so far, which equals the 0-based
# line number when the file has no duplicate entries.
# codecs.open is kept deliberately: its line splitting differs from built-in open().
token_dict = {}
with codecs.open(vocab_path, 'r', 'utf8') as vocab_file:
    for vocab_token in map(str.strip, vocab_file):
        token_dict[vocab_token] = len(token_dict)

# Build the pretrained BERT model inside the TPU strategy's scope (strategy comes
# from the TPU initialisation cell) so its variables are created under that strategy.
with strategy.scope():
    model = load_trained_model_from_checkpoint(config_path, checkpoint_path)

In [0]:
# @title Extraction
# Run one sentence through BERT and print a short preview of each token's output vector.
import numpy as np
from keras_bert import Tokenizer

# The same sentence is repeated to form the batch. NOTE(review): 8 presumably
# matches the number of TPU cores so each replica receives one example — confirm.
TPU_BATCH_SIZE = 8
# Sequence length passed to encode(); the model output has this many positions.
MAX_SEQ_LEN = 512
# How many leading values of each token's output vector to print.
PREVIEW_DIMS = 19

tokenizer = Tokenizer(token_dict)
text = 'From that day forth... my arm changed... and a voice echoed'
# tokenize() is called without max_len, so `tokens` is not bounded by MAX_SEQ_LEN.
tokens = tokenizer.tokenize(text)
indices, segments = tokenizer.encode(first=text, max_len=MAX_SEQ_LEN)
print(tokens)

batch_indices = np.array([indices] * TPU_BATCH_SIZE)
batch_segments = np.array([segments] * TPU_BATCH_SIZE)
# Every row of the batch is identical, so only the first example's output is kept.
predicts = model.predict([batch_indices, batch_segments])[0]

# zip() stops at the shorter sequence; indexing predicts[i] for every token would
# raise IndexError once the text tokenizes to more than MAX_SEQ_LEN tokens.
for token, vector in zip(tokens, predicts):
    print(token, vector[:PREVIEW_DIMS].tolist())

['[CLS]', 'from', 'that', 'day', 'forth', '.', '.', '.', 'my', 'arm', 'changed', '.', '.', '.', 'and', 'a', 'voice', 'echoed', '[SEP]']
[CLS] [0.24250675737857819, 0.04605229198932648, -0.24484458565711975, -0.5553151369094849, -0.16091349720954895, -0.046603765338659286, 0.2648216784000397, 0.4632352590560913, -0.15888850390911102, -0.5054463148117065, 0.1872796267271042, -0.0844820961356163, -0.10551590472459793, 1.1445354223251343, 0.35309019684791565, 0.4587448537349701, 0.22255975008010864, 0.23128646612167358, 0.17905741930007935]
from [0.2858668565750122, 0.12927496433258057, 0.08937370777130127, -0.06506256759166718, -0.18307062983512878, 0.4357893466949463, -0.4666714668273926, 1.1149680614471436, 0.26170825958251953, -1.0477269887924194, -0.7197380661964417, -0.30874621868133545, 0.3589649498462677, 0.43190720677375793, 0.510287880897522, 0.4445205330848694, 0.6695327162742615, 0.11726009100675583, 0.34817394614219666]
that [-0.7514970302581787, 0.14548861980438232, 0.3224588