# Configure settings

In [None]:
#@markdown ### General Config
# Settings shared by every cell below (rendered as a Colab form; keep the `#@param` lines intact).
# Passed to input_fn_builder as max_seq_length. NOTE(review): presumably must match the length the
# TFRecords were written with (the *_1024 directory names below suggest 1024) -- confirm.
MAX_SEQ_LENGTH =  1024 #@param {type:"integer"}
# Not referenced by any other cell in this notebook.
PROCESSES = 2 #@param {type:"integer"}
# Passed to tf.contrib.tpu.TPUConfig as num_shards.
NUM_TPU_CORES = 8 #@param {type:"integer"}
# GCS bucket holding data and checkpoints; paths are built as "gs://<BUCKET_NAME>/...".
BUCKET_NAME = "theodore_jiang" #@param {type:"string"}
#@markdown ###### For evaluating several models at once: xxx is a placeholder that is replaced by each model's identifier (when only one model is evaluated, it is replaced by that model's identifier)
# Model directory inside the bucket; "xxx" is replaced by each entry of models_to_evaluate.
MODEL_NAME_FORMAT = "bert_model_xxx" #@param {type:"string"}
# Logging directory name; "xxx" is replaced by the model name and RUN_NAME is appended as a subdirectory.
LOGGING_DIR_NAME_FORMAT = "bert_model_xxx_loss_spam" #@param {type:"string"}
# Bucket directory of the pretraining (training) data; only used to count training sequences.
PRETRAINING_DIR = "pretraining_data_1024" #@param {type:"string"}
# Bucket directory used when evaluating on the "dev" split.
EVAL_DIR = "eval_data_1024" #@param {type:"string"}
# Bucket directory used when evaluating on the "test" split.
TESTING_DIR = "testing_data_1024" #@param {type:"string"}
# Appended to each model's logging directory to separate runs.
RUN_NAME = "human_pretraining" #@param {type:"string"}
# Passed to input_fn_builder as max_predictions_per_seq.
MAX_PREDICTIONS = 20 #@param {type:"integer"}

#@markdown ### Evaluation procedure config
# Passed to TPUEstimator as eval_batch_size.
EVAL_TEST_BATCH_SIZE = 64 #@param {type:"integer"}


# Clone the repo

In [None]:
#@markdown ######where to clone the repo into (only value that it can't be is "mutformer"):
REPO_DESTINATION_PATH = "code/mutformer" #@param {type:"string"}
import os,shutil
if not os.path.exists(REPO_DESTINATION_PATH):
  os.makedirs(REPO_DESTINATION_PATH)
else:
  shutil.rmtree(REPO_DESTINATION_PATH)
  os.makedirs(REPO_DESTINATION_PATH)
cmd = "git clone https://github.com/WGLab/mutformer.git \"" + REPO_DESTINATION_PATH + "\""
!{cmd}

# Imports / authenticate for GCP

In [None]:
# Colab magic: switch this runtime to TensorFlow 1.x (tf.Session / tf.contrib are used below).
# NOTE(review): presumably must run before tensorflow is first imported -- confirm.
%tensorflow_version 1.x

import sys
import json
import random
import logging
import tensorflow as tf
import time
import os
import shutil
from google.colab import auth

# Interactive Google sign-in for this Colab session so gs:// paths can be accessed.
# NOTE(review): presumably this is also what produces /content/adc.json, read by the TPU setup below.
print("Authorize for GCS:")
auth.authenticate_user()
print("Authorize done")

# Timestamp the run in the cell output.
print("current date/time:",time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))

from glob import glob
from google.colab import auth, drive
from tensorflow.keras.utils import Progbar

# Replace any local "mutformer" package with a fresh copy of the cloned model code, so the
# `from mutformer import ...` lines that follow load this checkout.
_model_code_src = REPO_DESTINATION_PATH+"/mutformer_model_code"
if os.path.exists("mutformer"):
  shutil.rmtree("mutformer")
shutil.copytree(_model_code_src,"mutformer")
# Move "mutformer" to the end of sys.path: drop its first existing entry (if any), then append it.
try:
  sys.path.remove("mutformer")
except ValueError:
  pass  # it was not on the path yet
sys.path.append("mutformer")

from mutformer import modeling, optimization, tokenization
from mutformer.modeling import BertModel,BertModelModified
from mutformer.run_pretraining import input_fn_builder, model_fn_builder

  
# configure logging
# Send the 'tensorflow' logger's records (INFO and above) through the handlers configured here.
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Remove handlers attached earlier (e.g. by a previous run of this cell) so records are not emitted twice.
log.handlers = []
#@markdown ###### Whether or not to write logs to a file
DO_FILE_LOGGING = False #@param {type:"boolean"}
if DO_FILE_LOGGING:
  #@markdown ###### If using file logging, what path to write logs to
  FILE_LOGGING_PATH = '/content/drive/My Drive/spam.log' #@param {type:"string"}
  # Drive is otherwise only mounted in a later cell, so a Drive path would not exist yet and
  # FileHandler would fail with FileNotFoundError. Mount it here when needed.
  if FILE_LOGGING_PATH.startswith('/content/drive') and not os.path.ismount('/content/drive'):
    from google.colab import drive
    drive.mount('/content/drive')
  # FileHandler does not create missing parent directories.
  log_file_dir = os.path.dirname(FILE_LOGGING_PATH)
  if log_file_dir and not os.path.isdir(log_file_dir):
    os.makedirs(log_file_dir)
  # NOTE(review): the steps-per-epoch cell unmounts and force-remounts Drive; confirm this handler's
  # open file keeps working afterwards.
  fh = logging.FileHandler(FILE_LOGGING_PATH)
  fh.setLevel(logging.INFO)
  fh.setFormatter(formatter)
  log.addHandler(fh)

# Always echo log records to the notebook output as well.
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
log.addHandler(ch)

# Fail fast unless this is a Colab TPU runtime (it exposes the TPU address as COLAB_TPU_ADDR).
if 'COLAB_TPU_ADDR' not in os.environ:
  raise Exception('Not connected to TPU runtime, TPU required to run mutformer (evaluation will also run really slow without TPU)')

log.info("Using TPU runtime")
TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']

# Open a session on the TPU and hand it this notebook's GCS credentials from /content/adc.json.
with tf.Session(TPU_ADDRESS) as tpu_session:
  log.info('TPU address is ' + TPU_ADDRESS)
  with open('/content/adc.json', 'r') as adc_file:
    gcs_credentials = json.load(adc_file)
  tf.contrib.cloud.configure_gcs(tpu_session, credentials=gcs_credentials)

# Auto-detect the number of training steps per epoch in the source data / mount Drive if needed

In [None]:
#@markdown ###### if data was stored in drive, folder where the original data was stored (if data was stored in GCS, leave this item blank)
data_folder = "/content/drive/My Drive/BERT pretraining/mutformer_pretraining_data" #@param {type: "string"}
if "/content/drive" in data_folder:
  from google.colab import drive
  !fusermount -u /content/drive
  drive.flush_and_unmount()
  drive.mount('/content/drive', force_remount=True)
  DRIVE_PATH = "/content/drive/My Drive"

  data_path_train = drive_data_folder+"/train.txt" 

  lines = open(data_path_train).read().split("\n")
  SEQUENCES_PER_EPOCH = len(lines)
  STEPS_PER_EPOCH = int(SEQUENCES_PER_EPOCH/TRAIN_BATCH_SIZE)

  print("sequences per epoch:",SEQUENCES_PER_EPOCH, "steps per epoch:",STEPS_PER_EPOCH)
else:
  from tqdm import tqdm
  def steps_getter(input_files):
    tot_sequences = 0
    for input_file in input_files:
      print("reading:",input_file)

      d = tf.data.TFRecordDataset(input_file)

      with tf.Session() as sess:
        tot_sequences+=sess.run(d.reduce(0, lambda x,_: x+1))

    return tot_sequences

  BUCKET_PATH = "gs://{}".format(BUCKET_NAME)
  got_data = False
  while not got_data: ##will keep trying to access the data until available
    for f in range(0,DATA_COPIES):
        DATA_GCS_DIR_train = "{}/{}".format(BUCKET_PATH, PRETRAINING_DIR+"/"+str(f))
        train_input_files = tf.gfile.Glob(os.path.join(DATA_GCS_DIR_train,'*tfrecord'))
        print("Using:",train_input_files)
        if len(train_input_files)>0:
          got_data = True
          try:
            SEQUENCES_PER_EPOCH = steps_getter(train_input_files)
            STEPS_PER_EPOCH = int(SEQUENCES_PER_EPOCH/TRAIN_BATCH_SIZE)
            print("sequences per epoch:",SEQUENCES_PER_EPOCH, "steps per epoch:",STEPS_PER_EPOCH)
            break
          except:
            got_data=False
    if got_data:
      break
    print("Could not find data, waiting for data generation...trying again in another "+str(1200)+" seconds.")
    time.sleep(1200)

# Evaluation

### Set up the evaluation operation

In [None]:
def reload_ckpt(model_dir,logging_dir,current_ckpt,model,data_dir):
  """Build an evaluation TPUEstimator if `model_dir` has a checkpoint other than `current_ckpt`.

  Args:
    model_dir: model directory inside the bucket (BUCKET_PATH/model_dir); holds config.json
      and the checkpoints.
    logging_dir: forwarded to model_fn_builder as `logging_dir`.
    current_ckpt: checkpoint path evaluated last time, or a placeholder such as "N/A".
    model: model class forwarded to model_fn_builder as `bert`.
    data_dir: directory whose `*tfrecord` files are used as evaluation input.

  Returns:
    (checkpoint, estimator, input_fn, True) when a checkpoint different from `current_ckpt`
    exists. Otherwise (current_ckpt, None, None, False). The previous version returned None
    here, so a caller storing the first element forgot the evaluated checkpoint and
    re-evaluated it on the next pass.
  """
  BERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, model_dir)

  CONFIG_FILE = os.path.join(BERT_GCS_DIR, "config.json")

  INIT_CHECKPOINT = tf.train.latest_checkpoint(BERT_GCS_DIR)
  print("init chkpt:",INIT_CHECKPOINT)
  print("current chkpt:",current_ckpt)
  if INIT_CHECKPOINT is None:
    # No checkpoint has been saved yet, so there is nothing to evaluate for this model.
    log.warning("No checkpoint found in {}; skipping".format(BERT_GCS_DIR))
    return current_ckpt,None,None,False
  if INIT_CHECKPOINT == current_ckpt:
    return current_ckpt,None,None,False

  config = modeling.BertConfig.from_json_file(CONFIG_FILE)
  test_input_files = tf.gfile.Glob(os.path.join(data_dir,'*tfrecord'))
  log.info("Using {} data shards for testing".format(len(test_input_files)))
  # This estimator is only used for evaluate(); the learning-rate/warmup arguments (and the
  # train batch size / checkpoint frequency further down) are presumably placeholders.
  model_fn = model_fn_builder(
        bert_config=config,
        logging_dir=logging_dir,
        init_checkpoint=INIT_CHECKPOINT,
        init_learning_rate=1,
        decay_per_step=1,
        num_warmup_steps=10,
        use_tpu=True,
        use_one_hot_embeddings=True,
        bert=model)

  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)

  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=BERT_GCS_DIR,
      save_checkpoints_steps=1,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=1,
          num_shards=NUM_TPU_CORES,
          per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=True,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=1,
      eval_batch_size=EVAL_TEST_BATCH_SIZE)

  test_input_fn = input_fn_builder(
      input_files=test_input_files,
      max_seq_length=MAX_SEQ_LENGTH,
      max_predictions_per_seq=MAX_PREDICTIONS,
      is_training=False)
  return INIT_CHECKPOINT,estimator,test_input_fn,True

### Run evaluation

In [None]:
import time
BUCKET_PATH = "gs://{}".format(BUCKET_NAME)
#@markdown ###### whether to evaluate on the test set or the dev set ("test" or "dev") (test set will only run once, dev set will run continuously)
dataset = "test" #@param{type:"string"}

# Resolve the chosen split to its data directory inside the bucket.
_split_dirs = {"test": TESTING_DIR, "dev": EVAL_DIR}
if dataset not in _split_dirs:
  raise Exception("only datasets supported are dev and test")
DATA_DIR = _split_dirs[dataset]

models_to_evaluate = ["modified","orig","large","modified_medium","modified_large"] #@param #list of models to evaluate

# Architecture class per model name: the "modified*" models use BertModelModified, the others BertModel.
name2model = dict.fromkeys(("modified", "modified_medium", "modified_large"), BertModelModified)
name2model.update(dict.fromkeys(("orig", "large"), BertModel))

def write_metrics(metrics,dir): ##evaluation metrics will be written into google drive to minimize interations with GCS
  gs = metrics["global_step"]
  print("global step",gs)

  tf.compat.v1.disable_eager_execution()
  tf.reset_default_graph()  
  for key,value in metrics.items():
    print(key,value)
    x_scalar = tf.constant(value)
    first_summary = tf.summary.scalar(name=key, tensor=x_scalar)

    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        writer = tf.summary.FileWriter(dir)
        sess.run(init)
        summary = sess.run(first_summary)
        writer.add_summary(summary, gs)
        writer.flush()
        print('Done with writing the scalar summary')
    time.sleep(1)
  if not os.path.exists(DRIVE_PATH+"/"+dir):
    os.makedirs(DRIVE_PATH+"/"+dir)
  cmd = "cp -r \""+dir+"/.\" \""+DRIVE_PATH+"/"+dir+"\""
  !{cmd}
# Last evaluated checkpoint per model ("N/A" = none yet).
current_ckpts = ["N/A" for i in range(len(models_to_evaluate))]

total_metrics = {}

def count_tfrecords(input_files):
  """Return the total number of records across the given TFRecord files."""
  return sum(sum(1 for _ in tf.python_io.tf_record_iterator(path)) for path in input_files)

# The number of evaluation steps is derived from the data; TEST_STEPS / EVAL_STEPS were referenced
# but never defined. The count is done once, since every model reads the same DATA_DIR.
DATA_GCS_DIR = "{}/{}".format(BUCKET_PATH, DATA_DIR)
NUM_EVAL_EXAMPLES = count_tfrecords(tf.gfile.Glob(os.path.join(DATA_GCS_DIR,'*tfrecord')))
# Only whole batches are counted. NOTE(review): assumes input_fn_builder drops the final partial
# batch (usual for TPU input pipelines) -- confirm against mutformer.run_pretraining.
EVAL_TEST_STEPS = NUM_EVAL_EXAMPLES // EVAL_TEST_BATCH_SIZE
if EVAL_TEST_STEPS < 1:
  raise ValueError("{} has {} records, fewer than one batch of {}".format(DATA_GCS_DIR, NUM_EVAL_EXAMPLES, EVAL_TEST_BATCH_SIZE))
print("evaluating", NUM_EVAL_EXAMPLES, "examples in", EVAL_TEST_STEPS, "steps per model")

while True:
  for n,model in enumerate(models_to_evaluate):
    MODEL_DIR = MODEL_NAME_FORMAT.replace("xxx",model)
    LOCAL_LOGGING_DIR = "{}/{}".format(LOGGING_DIR_NAME_FORMAT.replace("xxx",model),RUN_NAME)
    # GCS_LOGGING_DIR was previously undefined. NOTE(review): assumed to be the bucket-side copy of
    # the local logging dir, passed to model_fn_builder as logging_dir -- confirm.
    GCS_LOGGING_DIR = "{}/{}".format(BUCKET_PATH, LOCAL_LOGGING_DIR)
    current_ckpt,estimator,test_input_fn,new = reload_ckpt(MODEL_DIR,GCS_LOGGING_DIR,current_ckpts[n],name2model[model],DATA_GCS_DIR)
    if not new:
      # Only record a checkpoint after it has been evaluated. reload_ckpt's first return value may
      # be None when nothing changed, and storing it would make the same checkpoint look new again.
      continue
    current_ckpts[n] = current_ckpt
    print("\n\nEVALUATING "+model+" MODEL\n\n")
    log.info("Using checkpoint: {}".format(current_ckpt))
    metrics = estimator.evaluate(input_fn=test_input_fn, steps=EVAL_TEST_STEPS)
    if dataset == "dev":
      # dev runs continuously, so each result is written out as soon as it is produced.
      write_metrics(metrics,LOCAL_LOGGING_DIR)
    else:
      total_metrics[LOCAL_LOGGING_DIR] = metrics

  print("finished 1 eval loop")
  if dataset=="test":
    break
  time.sleep(600)
# The test split runs once per model; report every collected result together at the end.
if dataset == "test":
  for run_dir in total_metrics:
    run_metrics = total_metrics[run_dir]
    print("Printing metrics for:", run_dir, "\n")
    for metric_name in run_metrics:
      print(metric_name + ":", run_metrics[metric_name])
    print("\n")