Note: If using a TPU from Google Cloud (not the Colab TPU), make sure to run this notebook on a VM with access to all GCP APIs, and make sure TPUs are enabled for the GCP project

Note: Run multiple copies of this notebook in multiple VMs to train multiple models in parallel

# Configure settings

In [None]:
#@markdown ## General Config
USE_GCP_TPU = False #@param {type:"boolean"}
MAX_SEQ_LENGTH =  1024#@param {type:"integer"}
# MAX_PREDICTIONS is configured once, in the input data pipeline section below
# (a second form field here was always overridden by that later assignment).
DO_LOWER_CASE = False #@param {type:"boolean"}
PROCESSES = 2 #@param {type:"integer"}
BUCKET_NAME = "theodore_jiang" #@param {type:"string"}
MODEL_DIR = "bert_model_modified_large" #@param {type:"string"}
PRETRAINING_DIR = "pretraining_data_1024_modified_large" #@param {type:"string"}
LOGGING_DIR = "bert_model_pretraining_loss_spam" #@param {type:"string"}
#@markdown ######for miscellaneous temporary storage
TEMP_DIR = "modified_large_temp" #@param {type:"string"}
RUN_NAME = "bert_model_modified_large" #@param {type:"string"}

#@markdown ## Input data pipeline config
DATA_COPIES = 20 #@param {type:"integer"}
TRAIN_BATCH_SIZE =  32 #@param {type:"integer"}
MAX_PREDICTIONS = 20 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param
#@markdown ######When checking data, how long to wait between each check (to minimize interaction with GCS, should be around the same time it takes for the data generation script to generate 1 epoch worth of data)
CHECK_DATA_EVERY_N_SECS = 1200 #@param {type:"integer"}

#@markdown ### Training procedure config
EVAL_BATCH_SIZE = 64
INIT_LEARNING_RATE =  2e-5#@param {type:"number"}
END_LEARNING_RATE = 1e-9
SAVE_CHECKPOINTS_STEPS =  1000#@param {type:"integer"}
NUM_TPU_CORES = 8
PLANNED_TOTAL_SEQUENCES_SEEN =  1e9 #@param {type:"number"}
#@markdown ###### (PLANNED_TOTAL_STEPS will override PLANNED_TOTAL_SEQUENCES_SEEN; if you wish to use PLANNED_TOTAL_SEQUENCES_SEEN, set PLANNED_TOTAL_STEPS to -1)
PLANNED_TOTAL_STEPS =  2e6#@param {type:"number"}
# -1 is the sentinel meaning "derive the step count from the planned number of sequences".
PLANNED_TOTAL_STEPS = PLANNED_TOTAL_SEQUENCES_SEEN/TRAIN_BATCH_SIZE if PLANNED_TOTAL_STEPS==-1 else PLANNED_TOTAL_STEPS
# Constant per-step learning-rate change that moves the rate from INIT_LEARNING_RATE to
# END_LEARNING_RATE over PLANNED_TOTAL_STEPS (negative whenever END < INIT).
DECAY_PER_STEP = (END_LEARNING_RATE-INIT_LEARNING_RATE)/PLANNED_TOTAL_STEPS
#@markdown ## Model Config:
#@markdown ######Possible values for MODEL_TO_USE: orig, withConv:
MODEL_TO_USE = "withConv" #@param {type:"string"}
HIDDEN_SIZE =   768#@param {type:"integer"}
HIDDEN_LAYERS =  12#@param {type:"integer"}


CUSTOM_MODEL = None ##change this to a model_fn style function if you wish to use a custom model

# Model hyperparameters; "vocab_size" is added to this dict by a later cell before it is
# serialized to config.json.
bert_config = {
  "hidden_size": HIDDEN_SIZE, 
  "hidden_act": "gelu", 
  "initializer_range": 0.02, 
  "hidden_dropout_prob": 0.1, 
  # The head count is tied to HIDDEN_LAYERS rather than being a separate setting.
  # NOTE(review): presumably HIDDEN_SIZE must be divisible by it - confirm against the modeling code.
  "num_attention_heads": HIDDEN_LAYERS, 
  "type_vocab_size": 2, 
  "max_position_embeddings": MAX_SEQ_LENGTH, 
  "num_hidden_layers": HIDDEN_LAYERS, 
  "intermediate_size": 3072, 
  "attention_probs_dropout_prob": 0.1
}


# If running on a GCP TPU, use these commands prior to running this notebook

To ssh into the VM:

```
gcloud beta compute ssh --zone <COMPUTE ZONE> <VM NAME> --project <PROJECT NAME> -- -L 8888:localhost:8888
```

Make sure the port above matches the port below (in this case it's 8888)

```
sudo apt-get update
sudo apt-get -y install python3 python3-pip
sudo apt-get install pkg-config
sudo apt-get install libhdf5-serial-dev
sudo apt-get install libffi6 libffi-dev
sudo -H pip3 install jupyter tensorflow==1.14 google-api-python-client tqdm
sudo -H pip3 install jupyter_http_over_ws
jupyter serverextension enable --py jupyter_http_over_ws
jupyter notebook   --NotebookApp.allow_origin='https://colab.research.google.com'   --port=8888   --NotebookApp.port_retries=0   --no-browser

(one command):sudo apt-get update ; sudo apt-get -y install python3 python3-pip ; sudo apt-get install pkg-config ; sudo apt-get -y install libhdf5-serial-dev ; sudo apt-get install libffi6 libffi-dev; sudo -H pip3 install jupyter tensorflow==1.14 google-api-python-client tqdm ; sudo -H pip3 install jupyter_http_over_ws ; jupyter serverextension enable --py jupyter_http_over_ws ; jupyter notebook   --NotebookApp.allow_origin='https://colab.research.google.com'   --port=8888   --NotebookApp.port_retries=0   --no-browser
```
Then copy the link containing "localhost: ..." that Jupyter prints, and paste it into Colab's "Connect to a local runtime" option


### Also run this code segment, which creates a TPU

In [None]:
GCE_PROJECT_NAME = "genome-project-319100" #@param {type:"string"}
TPU_ZONE = "us-central1-f" #@param {type:"string"}
TPU_NAME = "mutformer-tpu" #@param {type:"string"}

!gcloud alpha compute tpus create $TPU_NAME --accelerator-type=tpu-v2 --version=1.15.5 --zone=$TPU_ZONE ##create new TPU

!gsutil iam ch serviceAccount:`gcloud alpha compute tpus describe $TPU_NAME | grep serviceAccount | cut -d' ' -f2`:admin gs://theodore_jiang && echo 'Successfully set permissions!' ##give TPU access to GCS

# Clone the repo

In [None]:
if USE_GCP_TPU:
  !sudo apt-get -y install git
#@markdown ######where to clone the repo into (only value that it can't be is "mutformer"):
REPO_DESTINATION_PATH = "code/mutformer" #@param {type:"string"}
import os,shutil
if not os.path.exists(REPO_DESTINATION_PATH):
  os.makedirs(REPO_DESTINATION_PATH)
else:
  shutil.rmtree(REPO_DESTINATION_PATH)
  os.makedirs(REPO_DESTINATION_PATH)
cmd = "git clone https://github.com/WGLab/mutformer.git \"" + REPO_DESTINATION_PATH + "\""
!{cmd}

# Imports/Authenticate for GCP

In [None]:
# Colab-only setup (skipped when USE_GCP_TPU is set): select TensorFlow 1.x before
# TensorFlow is imported below, then authenticate the user through google.colab's auth helper.
if not USE_GCP_TPU:
  %tensorflow_version 1.x
  from google.colab import auth
  print("Authorize for GCS:")
  auth.authenticate_user()
  print("Authorize done")
import sys
import json
import random
import logging
import tensorflow as tf
import time
import os
import shutil

# Replace any existing local "mutformer" folder with a fresh copy of the cloned model code.
if os.path.exists("mutformer"):
  shutil.rmtree("mutformer")
shutil.copytree(REPO_DESTINATION_PATH+"/mutformer_model_code","mutformer")
# Move "mutformer" to the end of sys.path, dropping one earlier entry if there is one.
try:
  sys.path.remove("mutformer")
except ValueError:
  pass
sys.path.append("mutformer")

from mutformer import modeling, optimization, tokenization
from mutformer.modeling import BertModel,BertModelModified
from mutformer.run_pretraining import input_fn_builder, model_fn_builder

# Map each supported MODEL_TO_USE value to the model class it selects.
_MODEL_CHOICES = {"orig": BertModel, "withConv": BertModelModified}
if MODEL_TO_USE not in _MODEL_CHOICES:
  raise Exception("The model specified was not one of the available models: [\"orig\", \"withConv\"].")
MODEL = _MODEL_CHOICES[MODEL_TO_USE]
print("Using model: " + MODEL_TO_USE)

  
# configure logging
# All messages go through the 'tensorflow' logger at INFO level, to the console and
# (optionally) to a log file.
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)

# Start from a clean handler list so handlers are not stacked when this cell is re-run.
log.handlers = []

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

#@markdown ###### Whether or not to write logs to a file
DO_FILE_LOGGING = True #@param {type:"boolean"}
if DO_FILE_LOGGING:
  #@markdown ###### If using file logging, what path to write logs to
  FILE_LOGGING_PATH = 'file_logging/spam.log' #@param {type:"string"}
  # Create the parent folder only when the path has one; a bare file name such as
  # "spam.log" has an empty directory part, which os.makedirs would reject.
  log_dir = os.path.dirname(FILE_LOGGING_PATH)
  if log_dir:
    os.makedirs(log_dir, exist_ok=True)
  fh = logging.FileHandler(FILE_LOGGING_PATH)
  fh.setLevel(logging.INFO)
  fh.setFormatter(formatter)
  log.addHandler(fh)

# Console handler, always attached.
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
log.addHandler(ch)

# Work out the TPU address and configure GCS access for the TPU session.
if USE_GCP_TPU:
  # GCP TPU: resolve the address from the TPU's name, zone and project.
  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      TPU_NAME, zone=TPU_ZONE, project=GCE_PROJECT_NAME)
  TPU_ADDRESS = tpu_cluster_resolver.get_master()
  with tf.Session(TPU_ADDRESS) as session:
    log.info('TPU address is {}'.format(TPU_ADDRESS))
    # Upload credentials to TPU.
    tf.contrib.cloud.configure_gcs(session)
elif 'COLAB_TPU_ADDR' not in os.environ:
  # Neither a GCP TPU nor a Colab TPU runtime is available.
  raise Exception('Not connected to TPU runtime, TPU required to run mutformer')
else:
  # Colab TPU: the runtime publishes its address through an environment variable.
  log.info("Using TPU runtime")
  TPU_ADDRESS = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])

  with tf.Session(TPU_ADDRESS) as session:
    log.info('TPU address is {}'.format(TPU_ADDRESS))
    # Upload credentials to TPU.
    with tf.gfile.Open('/content/adc.json', 'r') as f:
      auth_info = json.load(f)
    tf.contrib.cloud.configure_gcs(session, credentials=auth_info)


# Auto-detect the number of training steps per epoch in the source data / mount Drive if needed

In [None]:
#@markdown ###### if not USE_GCP_TPU and data was stored in drive, folder where the original data was stored (if data was stored in GCS or USE_GCP_TPU is true, leave this item blank)
data_folder = "/content/drive/My Drive/BERT pretraining/mutformer_pretraining_data" #@param {type: "string"}

if not USE_GCP_TPU and "/content/drive" in data_folder:
  from google.colab import drive
  !fusermount -u /content/drive
  drive.flush_and_unmount()
  drive.mount('/content/drive', force_remount=True)
  DRIVE_PATH = "/content/drive/My Drive"

  data_path_train = drive_data_folder+"/train.txt" 

  lines = tf.gfile.Open(data_path_train).read().split("\n")
  SEQUENCES_PER_EPOCH = len(lines)
  STEPS_PER_EPOCH = int(SEQUENCES_PER_EPOCH/TRAIN_BATCH_SIZE)

  print("sequences per epoch:",SEQUENCES_PER_EPOCH, "steps per epoch:",STEPS_PER_EPOCH)
else:
  from tqdm import tqdm
  def steps_getter(input_files):
    tot_sequences = 0
    for input_file in input_files:
      print("reading:",input_file)

      d = tf.data.TFRecordDataset(input_file)

      with tf.Session() as sess:
        tot_sequences+=sess.run(d.reduce(0, lambda x,_: x+1))

    return tot_sequences

  BUCKET_PATH = "gs://{}".format(BUCKET_NAME)
  got_data = False
  while not got_data: ##will keep trying to access the data until available
    for f in range(0,DATA_COPIES):
        DATA_GCS_DIR_train = "{}/{}".format(BUCKET_PATH, PRETRAINING_DIR+"/"+str(f))
        train_input_files = tf.gfile.Glob(os.path.join(DATA_GCS_DIR_train,'*tfrecord'))
        print("Using:",train_input_files)
        if len(train_input_files)>0:
          got_data = True
          try:
            SEQUENCES_PER_EPOCH = steps_getter(train_input_files)
            STEPS_PER_EPOCH = int(SEQUENCES_PER_EPOCH/TRAIN_BATCH_SIZE)
            print("sequences per epoch:",SEQUENCES_PER_EPOCH, "steps per epoch:",STEPS_PER_EPOCH)
            break
          except:
            got_data=False
    if got_data:
      break
    print("Could not find data, waiting for data generation...trying again in another "+str(CHECK_DATA_EVERY_N_SECS)+" seconds.")
    time.sleep(CHECK_DATA_EVERY_N_SECS)

# Upload config to GCS

In [None]:
bert_config["vocab_size"] = len(vocab.split("\n"))

with tf.gfile.Open("{}/config.json".format(MODEL_DIR), "w") as fo:
  json.dump(bert_config, fo, indent=2)

!gsutil -m cp -r $MODEL_DIR gs://$BUCKET_NAME

# Run Training

In [None]:
import time

operating_files = ["available_indexes","epoch"]

def download_tmp_files(operating_files): ##for downloading tmp files from drive or GCS
  for op_file in operating_files:
    if USE_GCP_TPU: ##If using GCP TPU, drive isn't available, so we need to store temporary files in GCS
      cmd = "gsutil -m cp -r gs://"+BUCKET_NAME+"/"+TEMP_DIR+"/"+op_file+".txt "+TEMP_DIR+"/"+op_file+".txt"
      !{cmd}
    else:
      shutil.copy(DRIVE_PATH+"/"+TEMP_DIR+"/"+op_file+".txt",TEMP_DIR+"/"+op_file+".txt")

def upload_tmp_files(operating_files): ##for uploading tmp files to drive or GCS
  for op_file in operating_files:
    if USE_GCP_TPU: ##doing the same thing as above^^
      cmd = "gsutil -m cp -r "+TEMP_DIR+"/"+op_file+".txt gs://"+BUCKET_NAME+"/"+TEMP_DIR+"/"+op_file+".txt"
      !{cmd}
    else:
      shutil.copy(TEMP_DIR+"/"+op_file+".txt",DRIVE_PATH+"/"+TEMP_DIR+"/"+op_file+".txt")

# Fetch the shared state files so training can resume where a previous run stopped.
download_tmp_files(operating_files)

if os.path.exists(TEMP_DIR+"/epoch.txt"): ##detect the current epoch
  # Read through a context manager so the file handle is closed right away.
  with tf.gfile.Open(TEMP_DIR+"/epoch.txt") as epoch_file:
    current_epoch = int(epoch_file.read())
else:
  current_epoch=0

# GCS locations used by the training loop: the model/checkpoint folder, the per-run
# logging folder, and the model config uploaded earlier.
BUCKET_PATH = "gs://{}".format(BUCKET_NAME)
BERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR)
GCS_LOGGING_DIR = "{}/{}".format(BUCKET_PATH, LOGGING_DIR+"/"+RUN_NAME)

CONFIG_FILE = os.path.join(BERT_GCS_DIR, "config.json")

while True: ##training loop
  print("\n\n\n\n\nEPOCH:"+str(current_epoch)+"\n\n\n\n\n\n")
  
  got_data = False
  while not got_data:
    for f in range(0,DATA_COPIES): ##try to access any of the data bins
      print("trying to access training data from saved sector number "+str(f))
      DATA_GCS_DIR_train = "{}/{}".format(BUCKET_PATH, PRETRAINING_DIR+"/"+str(f))
      train_input_files = tf.gfile.Glob(os.path.join(DATA_GCS_DIR_train,'*tfrecord'))
      print("train_input_files:",train_input_files)
      if len(train_input_files)>0:
        got_data = True
        break
      else:
        current_available_indexes = tf.gfile.Open(TEMP_DIR+"/available_indexes.txt").read().split("\n")[:-1]
        print("current:",current_available_indexes)

        new_inds = ""
        for ind in current_available_indexes:
          if int(ind) != f:
            new_inds += ind +"\n"
        print("new_inds",new_inds)
        tf.gfile.Open(TEMP_DIR+"/available_indexes.txt","w+").write(new_inds)
    upload_tmp_files(["available_indexes"])
    if not got_data:
      time.sleep(CHECK_DATA_EVERY_N_SECS)
        

  INIT_CHECKPOINT = tf.train.latest_checkpoint(BERT_GCS_DIR)
  try:
    INIT_CHECKPOINT_STEP = INIT_CHECKPOINT.split("-")[-1]
    print("CURRENT STEP:",INIT_CHECKPOINT_STEP)
    if int(INIT_CHECKPOINT_STEP)>=PLANNED_TOTAL_STEPS: ##if reached planed total steps, stop
      break
  except:
    pass

  config = modeling.BertConfig.from_json_file(CONFIG_FILE)

  log.info("Using checkpoint: {}".format(INIT_CHECKPOINT))
  log.info("Using {} data shards for training".format(len(train_input_files)))
  model_fn = model_fn_builder(
      bert_config=config,
      logging_dir=GCS_LOGGING_DIR,
      init_checkpoint=INIT_CHECKPOINT,
      init_learning_rate=INIT_LEARNING_RATE,
      decay_per_step=DECAY_PER_STEP,
      num_warmup_steps=10,
      use_tpu=True,
      use_one_hot_embeddings=True,
      bert=MODEL)

  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)

  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=BERT_GCS_DIR,
      save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=SAVE_CHECKPOINTS_STEPS,
          num_shards=NUM_TPU_CORES,
          per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=True,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=TRAIN_BATCH_SIZE,
      eval_batch_size=EVAL_BATCH_SIZE)
    
  train_input_fn = input_fn_builder(
          input_files=train_input_files,
          max_seq_length=MAX_SEQ_LENGTH,
          max_predictions_per_seq=MAX_PREDICTIONS,
          is_training=True)
  try:
    estimator.train(input_fn=train_input_fn, steps=STEPS_PER_EPOCH)
    current_epoch+=1
  except:
    pass

  # For dynamic masking, a parallel data generation is used. This portion deletes the current data and 
  # updates the list of available data via a txt (to minimize interaction with GCS) so that the data 
  # generation algortihm can generate the data with different masking positions 
  cmd = "gsutil -m rm -r "+DATA_GCS_DIR_train
  !{cmd}
  current_available_indexes = tf.gfile.Open(TEMP_DIR+"/available_indexes.txt").read().split("\n")[:-1]
  print("current:",current_available_indexes)

  new_inds = ""
  for ind in current_available_indexes:
    if int(ind) != f:
      new_inds += ind +"\n"
  print("new_inds",new_inds)
  tf.gfile.Open(TEMP_DIR+"/available_indexes.txt","w+").write(new_inds)
  tf.gfile.Open(TEMP_DIR+"/epoch.txt","w+").write(str(current_epoch))
  upload_tmp_files(operating_files)
