# BioMedBERT BigQuery Data Analysis/ Pre-training

In [0]:
# imports
import os
import json
import numpy as np
import pandas as pd
import textwrap
import tensorflow as tf
from google.colab import auth
from google.cloud import bigquery

In [0]:
# authenticate user
auth.authenticate_user()
print('Authenticated')

Authenticated


In [0]:
project_id = 'ai-vs-covid19'
client = bigquery.Client(project=project_id)

## Query Analysis

In [0]:
# Get number of rows
row_count = client.query('''
  SELECT 
    COUNT(*) as total
  FROM `ai-vs-covid19.BigBioMedBERT2.ncbi_comm_use`''').to_dataframe().total
row_count

0    1526206
Name: total, dtype: int64

In [0]:
# get column names
col_names = client.query('''
  SELECT column_name
  FROM `ai-vs-covid19.BigBioMedBERT2`.INFORMATION_SCHEMA.COLUMNS
  WHERE table_name = 'ncbi_comm_use'
''').to_dataframe()
col_names

Unnamed: 0,column_name
0,Refs
1,Body
2,Front
3,Meta
4,Filename


In [0]:
# get general table information schema
table_schema = client.query('''
  SELECT *
  FROM `ai-vs-covid19.BigBioMedBERT2`.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS
  WHERE table_name = 'ncbi_comm_use'
''').to_dataframe()
table_schema

Unnamed: 0,table_catalog,table_schema,table_name,column_name,field_path,data_type,description
0,ai-vs-covid19,BigBioMedBERT2,ncbi_comm_use,Refs,Refs,STRING,
1,ai-vs-covid19,BigBioMedBERT2,ncbi_comm_use,Body,Body,STRING,
2,ai-vs-covid19,BigBioMedBERT2,ncbi_comm_use,Front,Front,STRING,
3,ai-vs-covid19,BigBioMedBERT2,ncbi_comm_use,Meta,Meta,STRING,
4,ai-vs-covid19,BigBioMedBERT2,ncbi_comm_use,Filename,Filename,STRING,


In [0]:
# select first 10 rows
first_10_rows = client.query('''
  SELECT *
  FROM `ai-vs-covid19.BigBioMedBERT2.ncbi_comm_use`
  LIMIT 10
''').to_dataframe()
first_10_rows

Unnamed: 0,Refs,Body,Front,Meta,Filename
0,,Background\nScreening for increased waist circ...,J Int AIDS SocJ Int AIDS SocJournal of the Int...,,comm_use.I-N.txt.tar.gz-unpacked/J_Int_AIDS_So...
1,,Introduction\nTesting for HIV tropism is recom...,J Int AIDS SocJ Int AIDS SocJournal of the Int...,,comm_use.I-N.txt.tar.gz-unpacked/J_Int_AIDS_So...
2,,Background\nThe goal of highly active antiretr...,J Int AIDS SocJ Int AIDS SocJournal of the Int...,,comm_use.I-N.txt.tar.gz-unpacked/J_Int_AIDS_So...
3,,Purpose\nAbacavir use has been associated with...,J Int AIDS SocJ Int AIDS SocJournal of the Int...,,comm_use.I-N.txt.tar.gz-unpacked/J_Int_AIDS_So...
4,,"Co-infections with HCV and HIV are common, bec...",J Int AIDS SocJ Int AIDS SocJournal of the Int...,,comm_use.I-N.txt.tar.gz-unpacked/J_Int_AIDS_So...
5,,Background\nGuidelines recommend starting trea...,J Int AIDS SocJ Int AIDS SocJournal of the Int...,,comm_use.I-N.txt.tar.gz-unpacked/J_Int_AIDS_So...
6,,"Table 1 Rates of death from suicide, according...",J Int AIDS SocJ Int AIDS SocJIASJournal of the...,,comm_use.I-N.txt.tar.gz-unpacked/J_Int_AIDS_So...
7,,We are living in an extraordinary moment in th...,J Int AIDS SocJ Int AIDS SocJIASJournal of the...,,comm_use.I-N.txt.tar.gz-unpacked/J_Int_AIDS_So...
8,,In the past 3-5 years cognitive impairment hav...,J Int AIDS SocJ Int AIDS SocJournal of the Int...,,comm_use.I-N.txt.tar.gz-unpacked/J_Int_AIDS_So...
9,,Table 1 Number of patients with treatment emer...,J Int AIDS SocJ Int AIDS SocJIASJournal of the...,,comm_use.I-N.txt.tar.gz-unpacked/J_Int_AIDS_So...


In [0]:
# print body text
def print_body(body_series: pd.Series) -> str:
  print(textwrap.fill(body_series))

In [0]:
print_body(first_10_rows.Body[1])

Introduction Testing for HIV tropism is recommended before prescribing
a chemokine receptor blocker. To date, in most European countries HIV
tropism is determined using a phenotypic test. Recently, new data have
emerged supporting the use of a genotypic HIV V3-loop sequence
analysis as the basis for tropism determination. The European
guidelines group on clinical management of HIV-1 tropism testing was
established to make recommendations to clinicians and virologists.
Methods We searched online databases for articles from Jan 2006 until
March 2010 with the terms: tropism or CCR5-antagonist or CCR5
antagonist or maraviroc or vicriviroc. Additional articles and/or
conference abstracts were identified by hand searching. This strategy
identified 712 potential articles and 1240 abstracts. All were
reviewed and finally 57 papers and 42 abstracts were included and used
by the panel to reach a consensus statement. Results The panel
recommends HIV-tropism testing for the following indications: 

## Data Preprocessing

In [0]:
!pip install sentencepiece
!git clone https://github.com/google-research/bert

Collecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)
[K     |▎                               | 10kB 19.1MB/s eta 0:00:01[K     |▋                               | 20kB 2.2MB/s eta 0:00:01[K     |█                               | 30kB 3.2MB/s eta 0:00:01[K     |█▎                              | 40kB 2.1MB/s eta 0:00:01[K     |█▋                              | 51kB 2.6MB/s eta 0:00:01[K     |██                              | 61kB 3.1MB/s eta 0:00:01[K     |██▏                             | 71kB 3.6MB/s eta 0:00:01[K     |██▌                             | 81kB 4.1MB/s eta 0:00:01[K     |██▉                             | 92kB 4.5MB/s eta 0:00:01[K     |███▏                            | 102kB 3.5MB/s eta 0:00:01[K     |███▌                            | 112kB 3.5MB/s eta 0:00:01[K     |███▉                     

In [0]:
# base imports
import os
import sys
import nltk
import sentencepiece as spm

In [0]:
#import bert modules
sys.path.append("bert")
from bert import modeling, optimization, tokenization
from bert.run_pretraining import input_fn_builder, model_fn_builder




### Preprocess text
Remove punсtuation, uppercase letters and non-utf symbols.

In [0]:
regex_tokenizer = nltk.RegexpTokenizer("\w+")

def normalize_text(text):
  # lowercase text
  text = str(text).lower()
  # remove non-UTF
  text = text.encode("utf-8", "ignore").decode()
  # remove punctuation symbols
  text = " ".join(regex_tokenizer.tokenize(text))
  return text

In [0]:
# example to normalize text from ``Body`` in BioMedBERT dataset
print_body(normalize_text(first_10_rows.Body[0]))

background screening for increased waist circumference and
hypertriglyceridemia the hypertriglyceridemic waist phenotype is an
inexpensive approach to identify patients at risk of coronary artery
disease in apparently healthy individuals who may be at increased risk
of type 2 diabetes and coronary heart disease because of an excess of
intra abdominal visceral fat we examined the relationship between the
hypertriglyceridemic waist and selected cardiometabolic risk factors
in hiv individuals methods the hw phenotype was defined as a waist
circumference of 90 cm or more and a triglyceride level of 2 0 mmol l
or more in men and a waist circumference of 85 cm or more and a
triglyceride level of 1 5 mmol l or more in women using these
threshold values a total of 2322 patients 841 women and 1481 men with
hiv aged 18 75 years were divided into 4 groups low tg low wc high tg
low wc low tg high wc high tg high wc continuous variables were
analyzed using anova or kruskal wallis test where appropr

In [0]:
# get all document blobs into textfile as a new line. --- FAIL -- Too Large
# body_blob = client.query('''
#   SELECT Body
#   FROM `ai-vs-covid19.BigBioMedBERT2.ncbi_comm_use`
# ''')

In [0]:
# download queried data from GCS
# ! wget https://storage.googleapis.com/ebisong-covid-19-temp/results-20200327-140409.csv

## Create Expanded csv dataset

In [0]:
import os
import glob
import pandas as pd
from google.cloud import storage

In [0]:
storage_client = storage.Client(project=project_id)

In [0]:
bucket=storage_client.get_bucket('big_bio_med_bert_dump_csv')
# List all objects that satisfy the filter.
blobs=bucket.list_blobs(prefix='ncbi_comm_use')

In [0]:
blob = [blob for blob in blobs]

In [0]:
print(len(blob))
print(len(blob) // 10)
print((len(blob) // 10)*10)

250
25
250


In [0]:
# Create a function called "chunks" with two arguments, l and n:
def split_list(data, chunk):
    # For item i in a range that is a length of data (l),
    for i in range(0, len(data), chunk):
        # Create an index range for data of chunk (e.g. 5) items:
        yield data[i:i+chunk]

In [0]:
# list of length in which we have to split 
blob_split_A, blob_split_B, blob_split_C, blob_split_D, blob_split_E = list(split_list(blob, 50))

In [0]:
print(len(blob))
print(len(blob_split_A))
print(len(blob_split_B))
print(len(blob_split_C))
print(len(blob_split_D))
print(len(blob_split_E))

250
50
50
50
50
50


In [0]:
blob_split_A[0:5]

[<Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000000.csv, 1585586809935848>,
 <Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000001.csv, 1585586811418366>,
 <Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000002.csv, 1585586808669575>,
 <Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000003.csv, 1585586810202260>,
 <Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000004.csv, 1585586810903214>]

In [0]:
blob_split_B[0:5]

[<Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000050.csv, 1585586816240877>,
 <Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000051.csv, 1585586814357784>,
 <Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000052.csv, 1585586814721827>,
 <Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000053.csv, 1585586817661717>,
 <Blob: big_bio_med_bert_dump_csv, ncbi_comm_use/000000000054.csv, 1585586811616956>]

In [0]:
def download_to_local(folder, blob_lst):
    print('File download Started…. Wait for the job to complete.')
    # Create this folder locally if not exists
    if not os.path.exists(folder):
        os.makedirs(folder)
    # Iterating through for loop one by one using API call
    for blob in blob_lst:
        print('Blobs: {}'.format(blob.name))
        destination_uri = '{}/{}'.format(folder, (blob.name).split('/')[-1])
        blob.download_to_filename(destination_uri)
        print('Exported {} to {}'.format(blob.name, destination_uri))

In [0]:
!rm -rf data
# !rm ncbi_comm_use_csv_A.csv

In [0]:
# download first part of csv's
# download_to_local('data', blob_split_A)
download_to_local('data', blob_split_B)

File download Started…. Wait for the job to complete.
Blobs: ncbi_comm_use/000000000050.csv
Exported ncbi_comm_use/000000000050.csv to data/000000000050.csv
Blobs: ncbi_comm_use/000000000051.csv
Exported ncbi_comm_use/000000000051.csv to data/000000000051.csv
Blobs: ncbi_comm_use/000000000052.csv
Exported ncbi_comm_use/000000000052.csv to data/000000000052.csv
Blobs: ncbi_comm_use/000000000053.csv
Exported ncbi_comm_use/000000000053.csv to data/000000000053.csv
Blobs: ncbi_comm_use/000000000054.csv
Exported ncbi_comm_use/000000000054.csv to data/000000000054.csv
Blobs: ncbi_comm_use/000000000055.csv
Exported ncbi_comm_use/000000000055.csv to data/000000000055.csv
Blobs: ncbi_comm_use/000000000056.csv
Exported ncbi_comm_use/000000000056.csv to data/000000000056.csv
Blobs: ncbi_comm_use/000000000057.csv
Exported ncbi_comm_use/000000000057.csv to data/000000000057.csv
Blobs: ncbi_comm_use/000000000058.csv
Exported ncbi_comm_use/000000000058.csv to data/000000000058.csv
Blobs: ncbi_comm_us

In [0]:
# make combined csv
def combined_csv(data_folder):
    extension = 'csv'
    all_filenames = [i for i in glob.glob('{}/*.{}'.format(data_folder, extension))]
    #combine all files in the list
    combined_csv = pd.concat([pd.read_csv(f) for f in all_filenames ])
    return combined_csv

In [0]:
# blob_csv_A = combined_csv('data')
blob_csv_B = combined_csv('data')

In [0]:
len(blob_csv_B)

317187

In [0]:
# blob_csv_A.to_csv('gs://big_bio_med_bert_dump_csv/ncbi_comm_use')

In [0]:
# blob_csv_A.to_csv( "ncbi_comm_use_csv_A.csv", index=False, encoding='utf-8-sig')
blob_csv_B.to_csv( "ncbi_comm_use_csv_B.csv", index=False, encoding='utf-8-sig')

In [0]:
# copy files from gcs bucket
# !gsutil -m cp -r gs://big_bio_med_bert_dump_csv/ncbi_comm_use data/

In [0]:
# convert csv to txt
import csv
import sys
maxInt = sys.maxsize
csv.field_size_limit(maxInt)

csv_file = 'results-20200327-140409.csv'
txt_file = 'dataset.txt'
with open(txt_file, "w") as my_output_file:
    with open(csv_file, "r") as my_input_file:
        [ my_output_file.write(" ".join(row)+'\n') for row in csv.reader(my_input_file)]
    my_output_file.close()

In [0]:
from tensorflow.keras.utils import Progbar
def count_lines(filename):
  count = 0
  with open(filename) as fi:
    for line in fi:
      count += 1
  return count

In [0]:
# Apply normalization to entire dataset
RAW_DATA_FPATH = "dataset.txt"
PRC_DATA_FPATH = "processed_dataset.txt"

# apply normalization to the dataset
# this will take a minute or two

total_lines = count_lines(RAW_DATA_FPATH)
bar = Progbar(total_lines)

with open(RAW_DATA_FPATH,encoding="utf-8") as fi:
  with open(PRC_DATA_FPATH, "w",encoding="utf-8") as fo:
    for l in fi:
      fo.write(normalize_text(l)+"\n")
      bar.add(1)



## Building the vocabulary

In [0]:
MODEL_PREFIX = "tokenizer" #@param {type: "string"}
VOC_SIZE =  32000#@param {type:"integer"} #3200
SUBSAMPLE_SIZE = 12800000 #@param {type:"integer"} #12800000
NUM_PLACEHOLDERS = 256 #@param {type:"integer"} # 256

SPM_COMMAND = ('--input={} --model_prefix={} '
               '--vocab_size={} --input_sentence_size={} '
               '--shuffle_input_sentence=true ' 
               '--bos_id=-1 --eos_id=-1 ').format(
              #  '--hard_vocab_limit=false ').format(
               PRC_DATA_FPATH, MODEL_PREFIX, 
               VOC_SIZE - NUM_PLACEHOLDERS, SUBSAMPLE_SIZE)

In [0]:
spm.SentencePieceTrainer.Train(SPM_COMMAND)

True

In [0]:
!ls

adc.json  dataset.txt		 results-20200327-140409.csv  tokenizer.model
bert	  processed_dataset.txt  sample_data		      tokenizer.vocab


In [0]:
!head -n 30 tokenizer.vocab

<unk>	0
▁the	-3.06425
▁of	-3.55653
▁and	-3.71063
▁in	-3.91177
▁to	-4.20013
▁a	-4.22434
s	-4.23072
▁	-4.25891
ed	-4.50333
▁0	-4.74595
▁for	-4.76638
▁with	-4.79447
ing	-4.82313
▁1	-4.87834
▁was	-5.08902
▁were	-5.12589
▁is	-5.12706
▁that	-5.16514
d	-5.18435
▁2	-5.26383
▁as	-5.27425
▁by	-5.29105
ly	-5.37886
▁be	-5.4833
▁3	-5.50984
▁on	-5.54573
▁are	-5.65021
▁from	-5.65672
▁5	-5.66851


In [0]:
def read_sentencepiece_vocab(filepath):
  voc = []
  with open(filepath, encoding='utf-8') as fi:
    for line in fi:
      voc.append(line.split("\t")[0])
  # skip the first <unk> token
  voc = voc[1:]
  return voc

In [0]:
import random

snt_vocab = read_sentencepiece_vocab("{}.vocab".format(MODEL_PREFIX))
print("Learnt vocab size: {}".format(len(snt_vocab)))
print("Sample tokens: {}".format(random.sample(snt_vocab, 10)))

Learnt vocab size: 31743
Sample tokens: ['▁burns', '▁buds', 'kis', 'imoto', '▁consent', 'm', 'cgcagcacc', 'oa', '▁unf', '▁responsibility']


In [0]:
def parse_sentencepiece_token(token):
    if token.startswith("▁"):
        return token[1:]
    else:
        return "##" + token

In [0]:
bert_vocab = list(map(parse_sentencepiece_token, snt_vocab))

In [0]:
ctrl_symbols = ["[PAD]","[UNK]","[CLS]","[SEP]","[MASK]"]
bert_vocab = ctrl_symbols + bert_vocab

In [0]:
bert_vocab += ["[UNUSED_{}]".format(i) for i in range(VOC_SIZE - len(bert_vocab))]
print(len(bert_vocab))

32000


In [0]:
# write vocabulary to file
VOC_FNAME = "vocab.txt"

with open(VOC_FNAME, "w") as fo:
  for token in bert_vocab:
    fo.write(token+"\n")

In [0]:
bert_tokenizer = tokenization.FullTokenizer(VOC_FNAME)
bert_tokenizer.tokenize(first_10_rows.Body[0])[0:20]




['background',
 'screening',
 'for',
 'increased',
 'waist',
 'circumference',
 'and',
 'hyper',
 '##triglyceride',
 '##mia',
 '[UNK]',
 'the',
 'hyper',
 '##triglyceride',
 '##mic',
 '[UNK]',
 'waist',
 'phenotype',
 '[UNK]',
 'is']

## Generating pre-trained data

In [0]:
# sharding the dataset
!mkdir ./shards
!split -a 4 -l 5560 -d $PRC_DATA_FPATH ./shards/shard_
!ls ./shards/

shard_0000  shard_0002	shard_0004  shard_0006
shard_0001  shard_0003	shard_0005


In [0]:
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param
MAX_PREDICTIONS = 20 #@param {type:"integer"}
DO_LOWER_CASE = True #@param {type:"boolean"}
PROCESSES = 2 #@param {type:"integer"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}

For each shard we need to call `create_pretraining_data.py` script

In [0]:
XARGS_CMD = ("ls ./shards/ | "
             "xargs -n 1 -P {} -I{} "
             "python3 bert/create_pretraining_data.py "
             "--input_file=./shards/{} "
             "--output_file={}/{}.tfrecord "
             "--vocab_file={} "
             "--do_lower_case={} "
             "--max_predictions_per_seq={} "
             "--max_seq_length={} "
             "--masked_lm_prob={} "
             "--random_seed=34 "
             "--dupe_factor=5")

XARGS_CMD = XARGS_CMD.format(PROCESSES, '{}', '{}', PRETRAINING_DIR, '{}', 
                             VOC_FNAME, DO_LOWER_CASE, 
                             MAX_PREDICTIONS, MAX_SEQ_LENGTH, MASKED_LM_PROB)

In [0]:
tf.gfile.MkDir(PRETRAINING_DIR)
!$XARGS_CMD



W0328 16:13:16.199445 140097396144000 module_wrapper.py:139] From bert/create_pretraining_data.py:437: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.


W0328 16:13:16.199707 140097396144000 module_wrapper.py:139] From bert/create_pretraining_data.py:437: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.


W0328 16:13:16.199911 140097396144000 module_wrapper.py:139] From /content/bert/tokenization.py:125: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.



W0328 16:13:16.207443 139782562772864 module_wrapper.py:139] From bert/create_pretraining_data.py:437: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.


W0328 16:13:16.207705 139782562772864 module_wrapper.py:139] From bert/create_pretraining_data.py:437: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.


W0328 16:13:16.20

Save model assets and checkpoints to GCS

In [0]:
BUCKET_NAME = "ebisong-covid-19-temp"
MODEL_DIR = "bert_model"
tf.gfile.MkDir(MODEL_DIR)

Hyparameter configuration for BERT BASE

In [0]:
# use this for BERT-base

bert_base_config = {
  "attention_probs_dropout_prob": 0.1, 
  "directionality": "bidi", 
  "hidden_act": "gelu", 
  "hidden_dropout_prob": 0.1, 
  "hidden_size": 768, 
  "initializer_range": 0.02, 
  "intermediate_size": 3072, 
  "max_position_embeddings": 512, 
  "num_attention_heads": 12, 
  "num_hidden_layers": 12, 
  "pooler_fc_size": 768, 
  "pooler_num_attention_heads": 12, 
  "pooler_num_fc_layers": 3, 
  "pooler_size_per_head": 128, 
  "pooler_type": "first_token_transform", 
  "type_vocab_size": 2, 
  "vocab_size": VOC_SIZE
}

with open("{}/bert_config.json".format(MODEL_DIR), "w") as fo:
  json.dump(bert_base_config, fo, indent=2)
  
with open("{}/{}".format(MODEL_DIR, VOC_FNAME), "w") as fo:
  for token in bert_vocab:
    fo.write(token+"\n")

In [0]:
if BUCKET_NAME:
  !gsutil -m cp -r $MODEL_DIR $PRETRAINING_DIR gs://$BUCKET_NAME

Copying file://bert_model/bert_config.json [Content-Type=application/json]...
Copying file://bert_model/vocab.txt [Content-Type=text/plain]...
Copying file://pretraining_data/shard_0003.tfrecord [Content-Type=application/octet-stream]...
Copying file://pretraining_data/shard_0000.tfrecord [Content-Type=application/octet-stream]...
Copying file://pretraining_data/shard_0002.tfrecord [Content-Type=application/octet-stream]...
/ [0/9 files][    0.0 B/ 50.7 MiB]   0% Done                                    / [0/9 files][    0.0 B/ 50.7 MiB]   0% Done                                    / [0/9 files][    0.0 B/ 50.7 MiB]   0% Done                                    / [0/9 files][    0.0 B/ 50.7 MiB]   0% Done                                    / [0/9 files][    0.0 B/ 50.7 MiB]   0% Done                                    Copying file://pretraining_data/shard_0005.tfrecord [Content-Type=application/octet-stream]...
/ [0/9 files][    0.0 B/ 50.7 MiB]   0% Done                            

# Train the BioMedBERT model

In [0]:
import logging
# configure logging
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)

In [0]:
BUCKET_NAME = "ebisong-covid-19-temp" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
VOC_FNAME = "vocab.txt" #@param {type:"string"}

# Input data pipeline config
TRAIN_BATCH_SIZE = 128 #@param {type:"integer"}
MAX_PREDICTIONS = 20 #@param {type:"integer"}
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param

# Training procedure config
EVAL_BATCH_SIZE = 64
LEARNING_RATE = 2e-5
TRAIN_STEPS = 1000000 #@param {type:"integer"}
SAVE_CHECKPOINTS_STEPS = 2500 #@param {type:"integer"}
NUM_TPU_CORES = 8

if BUCKET_NAME:
  BUCKET_PATH = "gs://{}".format(BUCKET_NAME)
else:
  BUCKET_PATH = "."

BERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR)
DATA_GCS_DIR = "{}/{}".format(BUCKET_PATH, PRETRAINING_DIR)

VOCAB_FILE = os.path.join(BERT_GCS_DIR, VOC_FNAME)
CONFIG_FILE = os.path.join(BERT_GCS_DIR, "bert_config.json")

INIT_CHECKPOINT = tf.train.latest_checkpoint(BERT_GCS_DIR)

bert_config = modeling.BertConfig.from_json_file(CONFIG_FILE)
input_files = tf.gfile.Glob(os.path.join(DATA_GCS_DIR,'*tfrecord'))

log.info("Using checkpoint: {}".format(INIT_CHECKPOINT))
log.info("Using {} data shards".format(len(input_files)))


INFO:tensorflow:Using checkpoint: gs://ebisong-covid-19-temp/bert_model/model.ckpt-397500
INFO:tensorflow:Using 7 data shards


**Train on TPUs**

In [0]:
if 'COLAB_TPU_ADDR' in os.environ:
  log.info("Using TPU runtime")
  USE_TPU = True
  TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']

  with tf.Session(TPU_ADDRESS) as session:
    log.info('TPU address is ' + TPU_ADDRESS)
    # Upload credentials to TPU.
    with open('/content/adc.json', 'r') as f:
      auth_info = json.load(f)
    tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
    
else:
  log.warning('Not connected to TPU runtime')
  USE_TPU = False

INFO:tensorflow:Using TPU runtime
INFO:tensorflow:TPU address is grpc://10.105.86.2:8470
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [0]:
model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=INIT_CHECKPOINT,
      learning_rate=LEARNING_RATE,
      num_train_steps=TRAIN_STEPS,
      num_warmup_steps=10,
      use_tpu=USE_TPU,
      use_one_hot_embeddings=True)

tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)

run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=BERT_GCS_DIR,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=SAVE_CHECKPOINTS_STEPS,
        num_shards=NUM_TPU_CORES,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=USE_TPU,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=TRAIN_BATCH_SIZE,
    eval_batch_size=EVAL_BATCH_SIZE)
  
train_input_fn = input_fn_builder(
        input_files=input_files,
        max_seq_length=MAX_SEQ_LENGTH,
        max_predictions_per_seq=MAX_PREDICTIONS,
        is_training=True)

INFO:tensorflow:Using config: {'_model_dir': 'gs://ebisong-covid-19-temp/bert_model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 2500, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.105.86.2:8470"
    }
  }
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f90a6fbe6d8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.105.86.2:8470', '_evaluation_master': 'grpc://10.105.86.2:8470', '_is_chief': True, '_num_ps_repli

In [0]:
estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS)

INFO:tensorflow:Querying Tensorflow master (grpc://10.105.86.2:8470) for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 4843677424759378058)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 5239469049172511286)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 16512040750734129682)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 15874515171332526121)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 482179580865349395)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worke