# BioMedBERT BigQuery Data Analysis / Pre-training

In [28]:
# imports
import os
import json
import numpy as np
import pandas as pd
import textwrap
import tensorflow as tf
from google.cloud import bigquery

In [29]:
project_id = 'ai-vs-covid19'
client = bigquery.Client(project=project_id)

## Query Analysis

In [3]:
# Get number of rows
row_count = client.query('''
  SELECT 
    COUNT(*) as total
  FROM `ai-vs-covid19.BigBioMedBERT2.ncbi_comm_use`''').to_dataframe().total
row_count

0    1526206
Name: total, dtype: int64

In [5]:
# get column names
col_names = client.query('''
  SELECT column_name
  FROM `ai-vs-covid19.BigBioMedBERT2`.INFORMATION_SCHEMA.COLUMNS
  WHERE table_name = 'ncbi_comm_use'
''').to_dataframe()
col_names

Unnamed: 0,column_name
0,Refs
1,Body
2,Front
3,Meta
4,Filename


In [6]:
# get general table information schema
table_schema = client.query('''
  SELECT *
  FROM `ai-vs-covid19.BigBioMedBERT2`.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS
  WHERE table_name = 'ncbi_comm_use'
''').to_dataframe()
table_schema

Unnamed: 0,table_catalog,table_schema,table_name,column_name,field_path,data_type,description
0,ai-vs-covid19,BigBioMedBERT2,ncbi_comm_use,Refs,Refs,STRING,
1,ai-vs-covid19,BigBioMedBERT2,ncbi_comm_use,Body,Body,STRING,
2,ai-vs-covid19,BigBioMedBERT2,ncbi_comm_use,Front,Front,STRING,
3,ai-vs-covid19,BigBioMedBERT2,ncbi_comm_use,Meta,Meta,STRING,
4,ai-vs-covid19,BigBioMedBERT2,ncbi_comm_use,Filename,Filename,STRING,


In [32]:
# select first 10 rows
first_10_rows = client.query('''
  SELECT *
  FROM `ai-vs-covid19.BigBioMedBERT2.ncbi_comm_use`
  LIMIT 10
''').to_dataframe()
first_10_rows

Unnamed: 0,Refs,Body,Front,Meta,Filename
0,,Noorizan Abd. Aziz\nMalaysia\nKwame Adjei\nGha...,Infect Dis PovertyInfect Dis PovertyInfectious...,,comm_use.I-N.txt.tar.gz-unpacked/Infect_Dis_Po...
1,,Isabella Aboderin\nKenya\nRabiul Ahasan\nMalay...,Infect Dis PovertyInfect Dis PovertyInfectious...,,comm_use.I-N.txt.tar.gz-unpacked/Infect_Dis_Po...
2,,Hanin Abdel-Haq\nItaly\nSuneth Agampodi\nSri L...,Infect Dis PovertyInfect Dis PovertyInfectious...,,comm_use.I-N.txt.tar.gz-unpacked/Infect_Dis_Po...
3,,Dear Editor\nWe appreciate the comments. Curre...,Int J Surg Case RepInt J Surg Case RepInternat...,,comm_use.I-N.txt.tar.gz-unpacked/Int_J_Surg_Ca...
4,,Daijiro Abe\nJapan\nFerihan Ahmed-Popova\nBulg...,J Physiol AnthropolJ Physiol AnthropolJournal ...,,comm_use.I-N.txt.tar.gz-unpacked/J_Physiol_Ant...
5,,Chris Abbiss\nAustralia\nDaijiro Abe\nJapan\nV...,J Physiol AnthropolJ Physiol AnthropolJournal ...,,comm_use.I-N.txt.tar.gz-unpacked/J_Physiol_Ant...
6,,Tatsuro Amano\nJapan\nSumiko Anno\nJapan\nKiyo...,J Physiol AnthropolJ Physiol AnthropolJournal ...,,comm_use.I-N.txt.tar.gz-unpacked/J_Physiol_Ant...
7,,Daijiro Abe\nJapan\nYasuyo Abe\nJapan\nVicent ...,J Physiol AnthropolJ Physiol AnthropolJournal ...,,comm_use.I-N.txt.tar.gz-unpacked/J_Physiol_Ant...
8,,"CITATION\nTreacy RB, McBryde CW, Shears E, Pyn...",Indian J OrthopIJOrthoIndian Journal of Orthop...,,comm_use.I-N.txt.tar.gz-unpacked/Indian_J_Orth...
9,,"Prof. A. K. Gupta, M.S, F.R.C.S, M.Ch.Orth 191...",Indian J OrthopIJOIndian Journal of Orthopaedi...,,comm_use.I-N.txt.tar.gz-unpacked/Indian_J_Orth...


In [25]:
# print body text
def print_body(body_series: str) -> None:
  """Print a body text re-wrapped to ``textwrap.fill``'s default line width.

  Args:
    body_series: The text of a single ``Body`` cell. Despite the parameter
      name, this is a plain string (e.g. ``first_10_rows.Body[3]``), not a
      whole pandas Series.

  Returns:
    None. The wrapped text is written to standard output.
  """
  print(textwrap.fill(body_series))

In [39]:
print_body(first_10_rows.Body[3])

Dear Editor We appreciate the comments. Currently the erection in our
patient is complete without restrictions. However, we are aware of the
possibility of problems with growth. Hence, we have informed the
parents regarding these possible problems in the future. Intra-
operatively, we considered using scrotal flaps for the coverage of
penile shaft. However. The diameter of the penile shaft was so small
and the scrotal flap was thought to be too thick and hence would hide
the definition of the shaft at this age. The excellent myo-cutaneous
flap described by the authors still remains an option if needed in the
future. We promise the readers to write a letter to the Editor in the
future to document if such problems would arise. Conflict of interest
None. Funding None. Ethical approval The original study was approved
by the research committee, National Hospital (Care), Riyadh, Saudi
Arabia. The current reply does not require this. Consent Not
applicable Author contribution Reply is from th

## Data Preprocessing

In [2]:
!pip install sentencepiece
# !git clone https://github.com/google-research/bert



In [3]:
# base imports
import os
import sys
import nltk
import sentencepiece as spm

In [32]:
#import bert modules
sys.path.append("bert")
from bert import modeling, optimization, tokenization
from bert.run_pretraining import input_fn_builder, model_fn_builder

### Preprocess text
Remove punctuation and non-UTF-8 symbols, and convert uppercase letters to lowercase.

In [6]:
# Matches maximal runs of word characters, so punctuation and whitespace are
# dropped. Raw string: "\w" in a normal literal is an invalid escape sequence.
regex_tokenizer = nltk.RegexpTokenizer(r"\w+")

def normalize_text(text):
  """Lowercase ``text`` and reduce it to space-separated word tokens.

  Args:
    text: Value to normalize; it is converted with ``str()`` first.

  Returns:
    The lowercased text with every run of non-word characters replaced by a
    single space (an empty string when ``text`` has no word characters).
  """
  # lowercase text
  text = str(text).lower()
  # drop anything that cannot be encoded as UTF-8
  text = text.encode("utf-8", "ignore").decode()
  # remove punctuation symbols
  text = " ".join(regex_tokenizer.tokenize(text))
  return text

In [28]:
# example to normalize text from ``Body`` in BioMedBERT dataset
print_body(normalize_text(first_10_rows.Body[1]))

## Create Expanded CSV dataset

In [33]:
import os
import glob
import pandas as pd
from google.cloud import storage

In [10]:
storage_client = storage.Client(project=project_id)

In [11]:
bucket=storage_client.get_bucket('big_bio_med_bert_dump_csv')
# List all objects that satisfy the filter.
blobs=bucket.list_blobs(prefix='ncbi_comm_use')

In [12]:
blob = [blob for blob in blobs]

In [13]:
print(len(blob))
print(len(blob) // 5)
print((len(blob) // 5)*5)

250
50
250


In [14]:
# Generator that cuts a sequence into consecutive slices of a fixed size.
def split_list(data, chunk):
    """Yield successive slices of ``data`` that are ``chunk`` items long.

    The final slice is shorter when ``len(data)`` is not a multiple of
    ``chunk``.
    """
    # Start offsets of every slice: 0, chunk, 2 * chunk, ...
    starts = range(0, len(data), chunk)
    yield from (data[start:start + chunk] for start in starts)

In [15]:
# list of length in which we have to split 
blob_split_1, blob_split_2, blob_split_3, blob_split_4, blob_split_5 = list(split_list(blob, 50))

In [16]:
print(len(blob))
print(len(blob_split_1))
print(len(blob_split_2))
print(len(blob_split_3))
print(len(blob_split_4))
print(len(blob_split_5))

250
50
50
50
50
50


In [17]:
blob_split_1[0:5]

[<Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000000.csv, 1585586809935848>,
 <Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000001.csv, 1585586811418366>,
 <Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000002.csv, 1585586808669575>,
 <Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000003.csv, 1585586810202260>,
 <Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000004.csv, 1585586810903214>]

In [18]:
blob_split_2[0:5]

[<Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000050.csv, 1585586816240877>,
 <Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000051.csv, 1585586814357784>,
 <Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000052.csv, 1585586814721827>,
 <Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000053.csv, 1585586817661717>,
 <Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000054.csv, 1585586811616956>]

In [19]:
def download_to_local(folder, blob_lst):
    """Download every blob in ``blob_lst`` into the local directory ``folder``.

    Each blob is saved under the last ``/``-separated component of its name,
    so blobs that share that component overwrite one another.

    Args:
        folder: Local directory to download into; created if it is missing.
        blob_lst: Iterable of objects exposing ``name`` and
            ``download_to_filename(path)``.
    """
    print('File download Started…. Wait for the job to complete.')
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(folder, exist_ok=True)
    # Iterating through for loop one by one using API call
    for blob in blob_lst:
        file_name = blob.name.split('/')[-1]
        if not file_name:
            # A name ending in '/' leaves no file name to write to
            print('Skipped {} (no file name)'.format(blob.name))
            continue
        destination_uri = '{}/{}'.format(folder, file_name)
        blob.download_to_filename(destination_uri)
        print('Exported {} to {}'.format(blob.name, destination_uri))

In [1]:
# !rm -rf data #data_1 data_2 data_3 data_4 #data_5
# !rm ncbi_comm_use_csv_A.csv

In [38]:
# download first part of csv's
# download_to_local('data', blob)
# download_to_local('data_1', blob_split_1)
# download_to_local('data_2', blob_split_2)
# download_to_local('data_3', blob_split_3)
# download_to_local('data_4', blob_split_4)
# download_to_local('data_5', blob_split_5)

In [39]:
# make combined csv
def combined_csv(data_folder):
    """Concatenate every ``*.csv`` file in ``data_folder`` into one DataFrame.

    Files are read in sorted path order so the row order is reproducible
    (``glob.glob`` returns paths in arbitrary order).

    Args:
        data_folder: Directory containing the CSV files.

    Returns:
        A pandas DataFrame with the rows of all files; each file keeps its own
        row index, so index values repeat across files.

    Raises:
        ValueError: If ``data_folder`` contains no ``.csv`` files.
    """
    extension = 'csv'
    all_filenames = sorted(glob.glob('{}/*.{}'.format(data_folder, extension)))
    if not all_filenames:
        raise ValueError('No .{} files found in {}'.format(extension, data_folder))
    # combine all files in the list
    return pd.concat([pd.read_csv(f) for f in all_filenames])

In [40]:
# blob_csv_A = combined_csv('data')
# blob_csv = combined_csv('data')

In [21]:
# len(blob_csv)

In [19]:
# blob_csv.to_csv('gs://ekaba-assets/ncbi_comm_use.csv')

In [43]:
# blob_csv_A.to_csv( "ncbi_comm_use_csv_A.csv", index=False, encoding='utf-8-sig')
# blob_csv.to_csv( "ncbi_comm_use.csv", index=False, encoding='utf-8-sig')

In [6]:
# copy files from gcs bucket
# !gsutil -m cp gs://ekaba-assets/ncbi_comm_use_BODY.csv .

Copying gs://ekaba-assets/ncbi_comm_use_BODY.csv...
- [1/1 files][ 42.8 GiB/ 42.8 GiB] 100% Done  96.5 MiB/s ETA 00:00:00           
Operation completed over 1 objects/42.8 GiB.                                     


In [None]:
# body = pd.read_csv('ncbi_comm_use.csv')

In [27]:
# body_sel = body[['Body']]

In [29]:
# body_sel.to_csv('gs://ekaba-assets/ncbi_comm_use_BODY.csv')

In [7]:
# body_sel.to_csv( "ncbi_comm_use_BODY.csv", index=False, encoding='utf-8-sig')

In [9]:
# remove FULL ncbi_comm_use.csv
# !rm ncbi_comm_use.csv

In [10]:
# convert csv to txt: every CSV record is written as its fields joined by spaces
import csv
import sys
maxInt = sys.maxsize
# raise the csv field size limit to the maximum so very large cells can be read
csv.field_size_limit(maxInt)

csv_file = 'ncbi_comm_use_BODY.csv'
txt_file = 'ncbi_comm_use_BODY.txt'
# The input is opened first so a missing CSV does not create an empty output.
# newline='' is what the csv module asks for on files handed to csv.reader;
# utf-8-sig reads UTF-8 with or without the BOM that this notebook's
# to_csv(..., encoding='utf-8-sig') calls emit. The output is plain UTF-8,
# which is how the normalization step reads it back.
with open(csv_file, "r", encoding="utf-8-sig", newline="") as my_input_file, \
     open(txt_file, "w", encoding="utf-8") as my_output_file:
    for row in csv.reader(my_input_file):
        my_output_file.write(" ".join(row)+'\n')

In [12]:
# move text file to GCS
# !gsutil -m cp ncbi_comm_use_BODY.txt gs://ekaba-assets/

In [13]:
# remove csv file
# !rm ncbi_comm_use_BODY.csv

In [14]:
from tensorflow.keras.utils import Progbar
def count_lines(filename):
  """Return the number of lines in the text file ``filename``.

  The file is decoded as UTF-8 rather than the locale's default encoding, the
  same way the normalization loop opens the raw corpus.

  Args:
    filename: Path of the file to count.

  Returns:
    The line count as an int (0 for an empty file).
  """
  with open(filename, encoding="utf-8") as fi:
    return sum(1 for _ in fi)

In [30]:
# Normalize the whole corpus, one input line per output line
RAW_DATA_FPATH = "ncbi_comm_use_BODY.txt"
PRC_DATA_FPATH = "processed_ncbi_comm_use_BODY.txt"

# the line count is only used to size the progress bar
total_lines = count_lines(RAW_DATA_FPATH)
bar = Progbar(total_lines)

with open(RAW_DATA_FPATH, encoding="utf-8") as raw_file, \
     open(PRC_DATA_FPATH, "w", encoding="utf-8") as processed_file:
  for raw_line in raw_file:
    processed_file.write(normalize_text(raw_line) + "\n")
    bar.add(1)



In [32]:
# move processed text file to GCS
# !gsutil -m cp processed_ncbi_comm_use_BODY.txt gs://ekaba-assets/

In [34]:
# remove intermediate files
# !rm ncbi_comm_use_BODY.csv ncbi_comm_use_BODY.txt #processed_ncbi_comm_use_BODY.txt

## Building the vocabulary

In [11]:
# Corpus path, output prefix and size settings for the SentencePiece trainer
PRC_DATA_FPATH = "processed_ncbi_comm_use_BODY.txt"
MODEL_PREFIX = "biomedbert" #@param {type: "string"}
VOC_SIZE = 32000 #@param {type:"integer"}
SUBSAMPLE_SIZE = 12800000 #@param {type:"integer"}
NUM_PLACEHOLDERS = 256 #@param {type:"integer"}

# The trained vocabulary is NUM_PLACEHOLDERS entries smaller than VOC_SIZE.
# One flag per list item; the items are joined with single spaces.
SPM_COMMAND = " ".join([
    '--input={}'.format(PRC_DATA_FPATH),
    '--model_prefix={}'.format(MODEL_PREFIX),
    '--vocab_size={}'.format(VOC_SIZE - NUM_PLACEHOLDERS),
    '--input_sentence_size={}'.format(SUBSAMPLE_SIZE),
    '--shuffle_input_sentence=true',
    '--bos_id=-1',
    '--eos_id=-1',
])

In [None]:
spm.SentencePieceTrainer.Train(SPM_COMMAND)

In [1]:
!ls

bert  ekaba_biomedbert.ipynb  processed_ncbi_comm_use_BODY.txt


In [None]:
!head -n 30 tokenizer.vocab

<unk>	0
▁the	-3.06425
▁of	-3.55653
▁and	-3.71063
▁in	-3.91177
▁to	-4.20013
▁a	-4.22434
s	-4.23072
▁	-4.25891
ed	-4.50333
▁0	-4.74595
▁for	-4.76638
▁with	-4.79447
ing	-4.82313
▁1	-4.87834
▁was	-5.08902
▁were	-5.12589
▁is	-5.12706
▁that	-5.16514
d	-5.18435
▁2	-5.26383
▁as	-5.27425
▁by	-5.29105
ly	-5.37886
▁be	-5.4833
▁3	-5.50984
▁on	-5.54573
▁are	-5.65021
▁from	-5.65672
▁5	-5.66851


In [None]:
def read_sentencepiece_vocab(filepath):
  """Load the token column of a SentencePiece ``.vocab`` file.

  Every line is split on tabs and only its first field is kept. The entry from
  the first line (the ``<unk>`` token) is left out of the result.
  """
  with open(filepath, encoding='utf-8') as vocab_file:
    tokens = [entry.split("\t")[0] for entry in vocab_file]
  # drop the leading <unk> entry
  return tokens[1:]

In [None]:
import random

snt_vocab = read_sentencepiece_vocab("{}.vocab".format(MODEL_PREFIX))
print("Learnt vocab size: {}".format(len(snt_vocab)))
print("Sample tokens: {}".format(random.sample(snt_vocab, 10)))

Learnt vocab size: 31743
Sample tokens: ['▁burns', '▁buds', 'kis', 'imoto', '▁consent', 'm', 'cgcagcacc', 'oa', '▁unf', '▁responsibility']


In [None]:
def parse_sentencepiece_token(token):
    """Rewrite one SentencePiece token in the "##" continuation notation.

    A token that starts with the "▁" marker is returned without that marker;
    any other token is returned with "##" prepended.
    """
    has_marker = token.startswith("▁")
    return token[1:] if has_marker else "##" + token

In [None]:
bert_vocab = list(map(parse_sentencepiece_token, snt_vocab))

In [None]:
ctrl_symbols = ["[PAD]","[UNK]","[CLS]","[SEP]","[MASK]"]
bert_vocab = ctrl_symbols + bert_vocab

In [None]:
bert_vocab += ["[UNUSED_{}]".format(i) for i in range(VOC_SIZE - len(bert_vocab))]
print(len(bert_vocab))

32000


In [13]:
# write vocabulary to file, one token per line
VOC_FNAME = "vocab.txt"

# UTF-8 explicitly: the tokens were read from a UTF-8 file, so the locale's
# default encoding must not decide how they are written back out.
with open(VOC_FNAME, "w", encoding="utf-8") as fo:
  for token in bert_vocab:
    fo.write(token+"\n")

In [None]:
bert_tokenizer = tokenization.FullTokenizer(VOC_FNAME)
bert_tokenizer.tokenize(first_10_rows.Body[0])[0:20]




['background',
 'screening',
 'for',
 'increased',
 'waist',
 'circumference',
 'and',
 'hyper',
 '##triglyceride',
 '##mia',
 '[UNK]',
 'the',
 'hyper',
 '##triglyceride',
 '##mic',
 '[UNK]',
 'waist',
 'phenotype',
 '[UNK]',
 'is']

## Generating pre-training data

In [None]:
# sharding the dataset
!mkdir ./shards
!split -a 4 -l 5560 -d $PRC_DATA_FPATH ./shards/shard_
!ls ./shards/

shard_0000  shard_0002	shard_0004  shard_0006
shard_0001  shard_0003	shard_0005


In [None]:
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param
MAX_PREDICTIONS = 20 #@param {type:"integer"}
DO_LOWER_CASE = True #@param {type:"boolean"}
PROCESSES = 2 #@param {type:"integer"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}

For each shard we need to call the `create_pretraining_data.py` script.

In [None]:
# Shell pipeline template: list the shard files and have xargs start one
# create_pretraining_data.py run per shard, with -P parallel processes.
# Every '{}' below is a positional str.format() slot filled in further down.
XARGS_CMD = ("ls ./shards/ | "
             "xargs -n 1 -P {} -I{} "
             "python3 bert/create_pretraining_data.py "
             "--input_file=./shards/{} "
             "--output_file={}/{}.tfrecord "
             "--vocab_file={} "
             "--do_lower_case={} "
             "--max_predictions_per_seq={} "
             "--max_seq_length={} "
             "--masked_lm_prob={} "
             "--random_seed=34 "
             "--dupe_factor=5")

# The three literal '{}' arguments put '{}' back into the command (after -I,
# in the input path and in the output file name) so that xargs, not Python,
# substitutes the shard name there. The remaining slots, in order, are:
# PROCESSES, PRETRAINING_DIR, VOC_FNAME, DO_LOWER_CASE, MAX_PREDICTIONS,
# MAX_SEQ_LENGTH, MASKED_LM_PROB.
XARGS_CMD = XARGS_CMD.format(PROCESSES, '{}', '{}', PRETRAINING_DIR, '{}', 
                             VOC_FNAME, DO_LOWER_CASE, 
                             MAX_PREDICTIONS, MAX_SEQ_LENGTH, MASKED_LM_PROB)

In [None]:
tf.gfile.MkDir(PRETRAINING_DIR)
!$XARGS_CMD

In [7]:
# !gsutil -m cp -r gs://ekaba-assets/pre_trained_data .

Save model assets and checkpoints to GCS

In [9]:
BUCKET_NAME = "ekaba-assets"
MODEL_DIR = "bert_model"
tf.io.gfile.mkdir(MODEL_DIR)

Hyperparameter configuration for BERT BASE

In [15]:
VOC_SIZE = 32000
VOC_FNAME = "biomedbert-8M.txt"

In [16]:
# use this for BERT-base

bert_base_config = {
  "attention_probs_dropout_prob": 0.1, 
  "directionality": "bidi", 
  "hidden_act": "gelu", 
  "hidden_dropout_prob": 0.1, 
  "hidden_size": 768, 
  "initializer_range": 0.02, 
  "intermediate_size": 3072, 
  "max_position_embeddings": 512, 
  "num_attention_heads": 12, 
  "num_hidden_layers": 12, 
  "pooler_fc_size": 768, 
  "pooler_num_attention_heads": 12, 
  "pooler_num_fc_layers": 3, 
  "pooler_size_per_head": 128, 
  "pooler_type": "first_token_transform", 
  "type_vocab_size": 2, 
  "vocab_size": VOC_SIZE
}

with open("{}/bert_config.json".format(MODEL_DIR), "w") as fo:
    json.dump(bert_base_config, fo, indent=2)
  
# with open("{}/{}".format(MODEL_DIR, VOC_FNAME), "w") as fo:
#   for token in bert_vocab:
#     fo.write(token+"\n")

In [17]:
!gsutil -m cp -r $MODEL_DIR gs://$BUCKET_NAME

Copying file://bert_model/bert_config.json [Content-Type=application/json]...
Copying file://bert_model/biomedbert-8M.txt [Content-Type=text/plain]...
Copying file://bert_model/.ipynb_checkpoints/bert_config-checkpoint.json [Content-Type=application/json]...
Copying file://bert_model/.ipynb_checkpoints/biomedbert-8M-checkpoint.txt [Content-Type=text/plain]...
/ [4/4 files][505.2 KiB/505.2 KiB] 100% Done                                    
Operation completed over 4 objects/505.2 KiB.                                    


In [None]:
# if BUCKET_NAME:
#   !gsutil -m cp -r $MODEL_DIR $PRETRAINING_DIR gs://$BUCKET_NAME

# Train the BioMedBERT model

In [1]:
import json
import os
import sys

import tensorflow as tf

In [2]:
#import bert modules
sys.path.append("bert")
from bert import modeling, optimization, tokenization
from bert.run_pretraining import input_fn_builder, model_fn_builder

In [3]:
import logging
# configure logging
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)

In [4]:
BUCKET_NAME = "ekaba-assets" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
PRETRAINING_DIR = "pre_trained_data" #@param {type:"string"}
VOC_FNAME = "biomedbert-8M.txt" #@param {type:"string"}

# Input data pipeline config
TRAIN_BATCH_SIZE = 128 #@param {type:"integer"}
MAX_PREDICTIONS = 20 #@param {type:"integer"}
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param

# Training procedure config
EVAL_BATCH_SIZE = 64
LEARNING_RATE = 2e-5
TRAIN_STEPS = 1000000 #@param {type:"integer"}
SAVE_CHECKPOINTS_STEPS = 2500 #@param {type:"integer"}
NUM_TPU_CORES = 8

if BUCKET_NAME:
    BUCKET_PATH = "gs://{}".format(BUCKET_NAME)
else:
    BUCKET_PATH = "."

BERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR)
DATA_GCS_DIR = "{}/{}".format(BUCKET_PATH, PRETRAINING_DIR)

VOCAB_FILE = os.path.join(BERT_GCS_DIR, VOC_FNAME)
CONFIG_FILE = os.path.join(BERT_GCS_DIR, "bert_config.json")

INIT_CHECKPOINT = tf.train.latest_checkpoint(BERT_GCS_DIR)

bert_config = modeling.BertConfig.from_json_file(CONFIG_FILE)
input_files = tf.io.gfile.glob(os.path.join(DATA_GCS_DIR,'*tfrecord'))

log.info("Using checkpoint: {}".format(INIT_CHECKPOINT))
log.info("Using {} data shards".format(len(input_files)))

INFO:tensorflow:Using checkpoint: None
INFO:tensorflow:Using 10000 data shards


**Train on TPUs**

In [5]:
# Use the TPU when the COLAB_TPU_ADDR environment variable is set; otherwise
# run without one. Both branches define USE_TPU and TPU_ADDRESS so that later
# cells can reference them.
if 'COLAB_TPU_ADDR' in os.environ:
  log.info("Using TPU runtime")
  USE_TPU = True
  TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']

  with tf.Session(TPU_ADDRESS) as session:
    log.info('TPU address is ' + TPU_ADDRESS)
    # Upload credentials to TPU.
    with open('/content/adc.json', 'r') as f:
      auth_info = json.load(f)
    tf.contrib.cloud.configure_gcs(session, credentials=auth_info)

else:
  log.warning('Not connected to TPU runtime')
  USE_TPU = False
  # keep the name defined so later cells do not raise NameError off-TPU
  TPU_ADDRESS = None



In [None]:
# Model function from bert/run_pretraining.py, configured from the settings above
model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=INIT_CHECKPOINT,
      learning_rate=LEARNING_RATE,
      num_train_steps=TRAIN_STEPS,
      num_warmup_steps=10,
      use_tpu=USE_TPU,
      use_one_hot_embeddings=True)

# Only resolve a TPU cluster when one was detected: without a TPU there is no
# address to resolve, and RunConfig accepts cluster=None.
tpu_cluster_resolver = (
    tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)
    if USE_TPU else None)

# Checkpoints go to BERT_GCS_DIR; iterations_per_loop is set to the same value
# as the checkpoint interval.
run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=BERT_GCS_DIR,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=SAVE_CHECKPOINTS_STEPS,
        num_shards=NUM_TPU_CORES,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=USE_TPU,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=TRAIN_BATCH_SIZE,
    eval_batch_size=EVAL_BATCH_SIZE)

# Input pipeline over the TFRecord shards collected in input_files
train_input_fn = input_fn_builder(
        input_files=input_files,
        max_seq_length=MAX_SEQ_LENGTH,
        max_predictions_per_seq=MAX_PREDICTIONS,
        is_training=True)

INFO:tensorflow:Using config: {'_model_dir': 'gs://ebisong-covid-19-temp/bert_model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 2500, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.105.86.2:8470"
    }
  }
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f90a6fbe6d8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.105.86.2:8470', '_evaluation_master': 'grpc://10.105.86.2:8470', '_is_chief': True, '_num_ps_repli

In [None]:
estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS)

INFO:tensorflow:Querying Tensorflow master (grpc://10.105.86.2:8470) for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 4843677424759378058)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 5239469049172511286)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 16512040750734129682)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 15874515171332526121)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 482179580865349395)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worke